├── .bumpversion.cfg ├── .gitignore ├── CONTRIBUTING.md ├── Jenkinsfile ├── LICENSE.txt ├── README.md ├── aicsmlsegment ├── DataLoader3D │ └── Universal_Loader.py ├── Net3D │ ├── __init__.py │ ├── uNet_original.py │ ├── unet_xy.py │ └── unet_xy_enlarge.py ├── __init__.py ├── bin │ ├── curator │ │ ├── curator_merging.py │ │ ├── curator_sorting.py │ │ └── curator_takeall.py │ ├── predict.py │ └── train.py ├── custom_loss.py ├── custom_metrics.py ├── model_utils.py ├── tests │ ├── __init__.py │ └── dummy_test.py ├── training_utils.py ├── utils.py └── version.py ├── build.gradle ├── configs ├── predict_file_config.yaml ├── predict_folder_config.yaml └── train_config.yaml ├── docs ├── bb1.md ├── bb1_pic.png ├── bb2.md ├── bb2_pic.png ├── bb3.md ├── bb3_pic.png ├── check_cuda.md ├── cuda.png ├── demo1_pic.png ├── demo2_pic.png ├── demo_1.md ├── demo_2.md ├── dl_1_pic.png ├── dl_final.png ├── doc.rst ├── doc_pred_yaml.md ├── doc_train_yaml.md ├── nvidia_smi.png ├── overview.md ├── overview_pic.png └── wf_pic.png ├── gradle.properties ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── settings.gradle ├── setup.cfg └── setup.py /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.0.8.dev0 3 | parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.((?P<release>[a-z]*)(?P<devbuild>\d*)))? 
4 | serialize = 5 | {major}.{minor}.{patch}.{release}{devbuild} 6 | {major}.{minor}.{patch} 7 | 8 | [bumpversion:part:release] 9 | optional_value = rel 10 | values = 11 | dev 12 | rel 13 | 14 | [bumpversion:file:aicsmlsegment/version.py] 15 | search = {current_version} 16 | replace = {new_version} 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .gradle/ 3 | .idea/ 4 | *.iml 5 | .*.swp 6 | .*.swo 7 | *~ 8 | *.ipynb_checkpoints 9 | 10 | # Generated by build 11 | build/ 12 | dist/ 13 | venv/ 14 | .eggs/ 15 | *.egg-info 16 | **/__pycache__/ 17 | .pytest_cache/ 18 | activate 19 | .coverage 20 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Allen Cell Segmenter 2 | ------ 3 | 4 | Thank you for taking the time to read how to contribute to Allen Cell Segmenter! These are the basic guidelines for contributing to the two repositories of Allen Cell Segmenter: 5 | Classic Image Segmentation: https://github.com/AllenCell/aics-segmentation 6 | Iterative Deep Learning: https://github.com/AllenCell/aics-ml-segmentation 7 | 8 | 9 | ## Pull Request 10 | 11 | Both repositories (referred to as Allen-Repos below) are one-way mirrored from our internal Bitbucket repositories, which are integrated with our internal build infrastructure. For this reason, we cannot accept direct pull requests on the Allen-Repos. Instead, we welcome pull requests that follow the steps below. 
12 | 13 | 1. Fork the Allen-Repo. 14 | 2. Create a new branch in your fork and make your changes. 15 | 3. Create an issue on the Allen-Repo to notify us about the new branch in your fork, so that we can review it as a pull request. 16 | 4. We will do a code review and run build tests, and merge the new branch into the Allen-Repo when it is ready. 17 | 18 | ## Feature Request and Bug Report 19 | 20 | For feature requests or bug reports, we encourage you to post on the Allen Cell Discussion Forum: https://forum.allencell.org/c/software-code. Our internal notification system and issue tracker keep us up to date, and we will respond as soon as we can. 21 | 22 | 23 | ## Credit to contributors 24 | 25 | For a pull request that is eventually merged, all the commit history in the original branch of the forked repo will be preserved. Feature requests and bug reports submitted via the Allen Cell Discussion Forum will not show their authors the way a GitHub issue tracker would, but credit will be given to all contributors of pull requests, feature requests, and bug reports. 
26 | 
27 | 
28 | 
29 | 
47 | 
--------------------------------------------------------------------------------
/Jenkinsfile:
--------------------------------------------------------------------------------
1 | pipeline {  // Declarative CI: install deps, optionally switch to the release version, build/test, publish test+coverage reports, then publish from master.
2 |     parameters { booleanParam(name: 'create_release', defaultValue: false,
3 |             description: 'If true, create a release artifact and publish to ' +
4 |                     'the artifactory release PyPi or public PyPi.') }
5 |     options {
6 |         timeout(time: 1, unit: 'HOURS')
7 |     }
8 |     agent {
9 |         node {
10 |             label "python-gradle"
11 |         }
12 |     }
13 |     environment {
14 |         PATH = "/home/jenkins/.local/bin:$PATH"  // prepend the Jenkins user's local bin directory
15 |         REQUESTS_CA_BUNDLE = "/etc/ssl/certs"  // NOTE(review): points at a directory; requests expects a CA bundle file or a c_rehash-processed directory — confirm
16 |     }
17 |     stages {
18 |         stage ("create virtualenv") {
19 |             steps {
20 |                 this.notifyBB("INPROGRESS")
21 |                 sh "./gradlew -i cleanAll installCIDependencies"
22 |             }
23 |         }
24 |         // Runs only when the create_release parameter is true.
25 |         stage ("bump version pre-build") {
26 |             when {
27 |                 expression { return params.create_release }
28 |             }
29 |             steps {
30 |                 // This will drop the dev suffix if we are releasing
31 |                 // X.Y.Z.devN -> X.Y.Z
32 |                 sh "./gradlew -i bumpVersionRelease"
33 |             }
34 |         }
35 | 
36 |         stage ("test/build distribution") {
37 |             steps {
38 |                 sh "./gradlew -i build"
39 |             }
40 |         }
41 |         // Publish JUnit results and Cobertura coverage written by the build stage; never fails the build on coverage.
42 |         stage ("report on tests") {
43 |             steps {
44 |                 junit "build/test_report.xml"
45 | 
46 |                 cobertura autoUpdateHealth: false,
47 |                         autoUpdateStability: false,
48 |                         coberturaReportFile: 'build/coverage.xml',
49 |                         failUnhealthy: false,
50 |                         failUnstable: false,
51 |                         maxNumberOfBuilds: 0,
52 |                         onlyStable: false,
53 |                         sourceEncoding: 'ASCII',
54 |                         zoomCoverageChart: false
55 | 
56 | 
57 |             }
58 |         }
59 |         // Release path (master + create_release): publish, tag, then bump to the next version and push.
60 |         stage ("publish release") {
61 |             when {
62 |                 branch 'master'
63 |                 expression { return params.create_release }
64 |             }
65 |             steps {
66 |                 sh "./gradlew -i publishRelease"
67 |                 sh "./gradlew -i gitTagCommitPush"
68 |                 sh "./gradlew -i bumpVersionPostRelease gitCommitPush"
69 |             }
70 |         }
71 |         // Snapshot path (master without create_release): publish a snapshot, then bump the dev version.
72 |         stage ("publish snapshot") {
73 |             when {
74 |                 branch 'master'
75 |                 not { expression { return params.create_release } }
76 |             }
77 |             steps {
78 |                 sh "./gradlew -i publishSnapshot"
79 |                 script {
80 |                     def ignoreAuthors = ["jenkins", "Jenkins User", "Jenkins Builder"]
81 |                     if (!ignoreAuthors.contains(gitAuthor())) {  // skip when HEAD was authored by a Jenkins account — presumably so CI's own bump commits do not trigger another bump
82 |                         sh "./gradlew -i bumpVersionDev gitCommitPush"
83 |                     }
84 |                 }
85 |             }
86 |         }
87 | 
88 |     }
89 |     post {
90 |         always {
91 |             notifyBuildOnSlack(currentBuild.result, currentBuild.previousBuild?.result)
92 |             this.notifyBB(currentBuild.result)
93 |         }
94 |         cleanup {
95 |             deleteDir()
96 |         }
97 |     }
98 | }
99 | // Report `state` for GIT_COMMIT to the internal Bitbucket server; a null state is reported as SUCCESS.
100 | def notifyBB(String state) {
101 |     // on success, result is null
102 |     state = state ?: "SUCCESS"
103 | 
104 |     if (state == "SUCCESS" || state == "FAILURE") {
105 |         currentBuild.result = state
106 |     }
107 | 
108 |     notifyBitbucket commitSha1: "${GIT_COMMIT}",
109 |             credentialsId: 'aea50792-dda8-40e4-a683-79e8c83e72a6',
110 |             disableInprogressNotification: false,
111 |             considerUnstableAsSuccess: true,
112 |             ignoreUnverifiedSSLPeer: false,
113 |             includeBuildNumberInKey: false,
114 |             prependParentProjectKey: false,
115 |             projectKey: 'SW',
116 |             stashServerBaseUrl: 'https://aicsbitbucket.corp.alleninstitute.org'
117 | }
118 | // Slack: red message on any non-success; green BACK_TO_NORMAL on success when the prior status was not SUCCESS (including null, e.g. no previous build).
119 | def notifyBuildOnSlack(String buildStatus = 'STARTED', String priorStatus) {
120 |     // build status of null means successful
121 |     buildStatus = buildStatus ?: 'SUCCESS'
122 | 
123 |     // Notify on every non-successful build, and once on recovery after a non-successful one
124 |     if (buildStatus != 'SUCCESS') {
125 |         slackSend (
126 |                 color: '#FF0000',
127 |                 message: "${buildStatus}: '${env.JOB_NAME} [${env.BUILD_NUMBER}]' (${env.BUILD_URL})"
128 |         )
129 |     } else if (priorStatus != 'SUCCESS') {
130 |         slackSend (
131 |                 color: '#00FF00',
132 |                 message: "BACK_TO_NORMAL: '${env.JOB_NAME} [${env.BUILD_NUMBER}]' (${env.BUILD_URL})"
133 |         )
134 |     }
135 | }
136 | // Author name of the HEAD commit (Groovy returns the value of the last expression).
137 | def gitAuthor() {
138 |     sh(returnStdout: true, script: 'git log -1 --format=%an').trim()
139 | }
140 | 
--------------------------------------------------------------------------------
/LICENSE.txt:
-------------------------------------------------------------------------------- 1 | Allen Institute Software License – This software license is the 2-clause BSD 2 | license plus clause a third clause that prohibits redistribution and use for 3 | commercial purposes without further permission. 4 | 5 | Copyright © 2019. Allen Institute. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | 1. Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | 2. Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation and/or 15 | other materials provided with the distribution. 16 | 17 | 3. Redistributions and use for commercial purposes are not permitted without the 18 | Allen Institute’s written permission. For purposes of this license, commercial 19 | purposes are the incorporation of the Allen Institute's software into anything 20 | for which you will charge fees or other compensation or use of the software to 21 | perform a commercial service for a third party. Contact terms@alleninstitute.org 22 | for commercial licensing opportunities. 23 | 24 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 25 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 26 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
27 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 28 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 29 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 30 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 31 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 32 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 33 | OF THE POSSIBILITY OF SUCH DAMAGE. 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Overview 2 | 3 | The Allen Cell Structure Segmenter is a Python-based open source toolkit developed for 3D segmentation of intracellular structures in fluorescence microscope images, developed at the Allen Institute for Cell Science. This toolkit consists of two complementary elements, a classic image segmentation workflow with a restricted set of algorithms and parameters and an iterative deep learning segmentation workflow. We created a collection of 20 classic image segmentation workflows based on 20 distinct and representative intracellular structure localization patterns as a lookup table reference and starting point for users. The iterative deep learning workflow can take over when the classic segmentation workflow is insufficient. Two straightforward human-in-the-loop curation strategies convert a set of classic image segmentation workflow results into a set of 3D ground truth images for iterative model training without the need for manual painting in 3D. The Allen Cell Structure Segmenter thus leverages state of the art computer vision algorithms in an accessible way to facilitate their application by the experimental biology researcher. 
More details including algorithms, validations, examples, and video tutorials can be found at [allencell.org/segmenter](https://www.allencell.org/segmenter) or in our [bioRxiv paper](https://www.biorxiv.org/content/10.1101/491035v1). 4 | 5 | **Note: This repository has only the code for the "Iterative Deep Learning Workflow". The classic part can be found at [https://github.com/AllenCell/aics-segmentation](https://github.com/AllenCell/aics-segmentation)** 6 | 7 | ## Installation: 8 | 9 | 0. prerequisite: 10 | 11 | To perform training/prediction of the deep learning models in this package, we assume an [NVIDIA GPU](https://www.nvidia.com/en-us/deep-learning-ai/developer/) has been set up properly on a Linux operating system, either on a local machine or on a remote computation cluster. Make sure to check if your GPU supports at least CUDA 8.0 (CUDA 9.0 and up is preferred): [NVIDIA Driver check](https://www.nvidia.com/Download/index.aspx?lang=en-us). 12 | 13 | We used three types of GPUs to develop and test our package: (1) GeForce GTX 1080 Ti GPU (about 11GB GPU memory), (2) Titan Xp GPU (about 12GB GPU memory), (3) Tesla V100 for PCIe (with about 33GB memory). These cover common chips for personal workstations and data centers. 14 | 15 | **Note 1:** As remote GPU clusters could be set up differently from institute to institute, we will assume a local machine use case throughout the installation and demos. 16 | 17 | **Note 2:** We are investigating alternative cloud computing services to deploy our package and will have updates in the next few months. Stay tuned :) 18 | 19 | 20 | 1. create a conda environment: 21 | 22 | ```bash 23 | conda create --name mlsegmenter python=3.7 24 | ``` 25 | 26 | (For how to install conda, see [here](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html#installing-conda-on-a-system-that-has-other-python-installations-or-packages)) 27 | 28 | 2. 
activate your environment and do the installation within the environment: 29 | 30 | ```bash 31 | conda activate mlsegmenter 32 | ``` 33 | 34 | (Note: always check out [conda documentation](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#activating-an-environment) for updates. If you are using an older version of conda, you may need to activate the environment by `source activate mlsegmenter`.) 35 | 36 | 3. Install PyTorch 37 | 38 | Go to the [PyTorch website](https://pytorch.org/get-started/locally/), and find the right installation command for you. 39 | 40 | * we use version 1.0 (which is the stable version at the time of our development) 41 | * we use Linux (OS), Conda (package), python 3.6 (Language), CUDA=10.0 (Question about CUDA? see [setup CUDA](./docs/check_cuda.md)). 42 | 43 | ***Make sure you use either the automatically generated command on the PyTorch website, or the command recommended on the PyTorch website for installing an [older version](https://pytorch.org/get-started/previous-versions/)*** 44 | 45 | 46 | 47 | 4. Install Allen Cell Segmenter (deep learning part) 48 | 49 | ```bash 50 | git clone https://github.com/AllenCell/aics-ml-segmentation.git 51 | cd ./aics-ml-segmentation 52 | pip install -e .[all] 53 | ``` 54 | 55 | The `-e` flag when doing `pip install` will allow users to modify any of the source code without the need to reinstall the package afterward. You may do the installation without `-e` if you do not plan to change the code. 56 | 57 | ## Level of Support 58 | We are offering it to the community AS IS; we have used the toolkit within our organization. We are not able to provide guarantees of support. However, we welcome feedback and submission of issues. Users are encouraged to sign up on our [Allen Cell Discussion Forum](https://forum.allencell.org/) for community questions and comments. 
59 | 60 | 61 | # Link to [Documentations and Tutorials](./docs/overview.md) -------------------------------------------------------------------------------- /aicsmlsegment/DataLoader3D/Universal_Loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | import random 5 | 6 | from torch import from_numpy 7 | from aicsimageio import imread 8 | from random import shuffle 9 | import time 10 | from torchvision.transforms import ToTensor 11 | from torch.utils.data import Dataset 12 | 13 | 14 | # CODE for generic loader 15 | # No augmentation = NOAUG,simply load data and convert to tensor 16 | # Augmentation code: 17 | # RR = Rotate by a random degree from 1 to 180 18 | # R4 = Rotate by 0, 90, 180, 270 19 | # FH = Flip Horizantally 20 | # FV = Flip Vertically 21 | # FD = Flip Depth (i.e., along z dim) 22 | # SS = Size Scaling by a ratio between -0.1 to 0.1 (TODO) 23 | # IJ = Intensity Jittering (TODO) 24 | # DD = Dense Deformation (TODO) 25 | 26 | 27 | class RR_FH_M0(Dataset): 28 | 29 | def __init__(self, filenames, num_patch, size_in, size_out): 30 | 31 | self.img = [] 32 | self.gt = [] 33 | self.cmap = [] 34 | 35 | padding = [(x-y)//2 for x,y in zip(size_in, size_out)] 36 | total_in_count = size_in[0] * size_in[1] * size_in[2] 37 | total_out_count = size_out[0] * size_out[1] * size_out[2] 38 | 39 | num_data = len(filenames) 40 | shuffle(filenames) 41 | num_patch_per_img = np.zeros((num_data,), dtype=int) 42 | if num_data >= num_patch: 43 | # all one 44 | num_patch_per_img[:num_patch]=1 45 | else: 46 | basic_num = num_patch // num_data 47 | # assign each image the same number of patches to extract 48 | num_patch_per_img[:] = basic_num 49 | 50 | # assign one more patch to the first few images to achieve the total patch number 51 | num_patch_per_img[:(num_patch-basic_num*num_data)] = num_patch_per_img[:(num_patch-basic_num*num_data)] + 1 52 | 53 | for img_idx, fn in 
enumerate(filenames): 54 | 55 | if len(self.img)==num_patch: 56 | break 57 | 58 | label = np.squeeze(imread(fn+'_GT.ome.tif')) 59 | label = np.expand_dims(label, axis=0) 60 | 61 | input_img = np.squeeze(imread(fn+'.ome.tif')) 62 | if len(input_img.shape) == 3: 63 | # add channel dimension 64 | input_img = np.expand_dims(input_img, axis=0) 65 | elif len(input_img.shape) == 4: 66 | # assume number of channel < number of Z, make sure channel dim comes first 67 | if input_img.shape[0] > input_img.shape[1]: 68 | input_img = np.transpose(input_img, (1, 0, 2, 3)) 69 | 70 | costmap = np.squeeze(imread(fn+'_CM.ome.tif')) 71 | 72 | img_pad0 = np.pad(input_img, ((0,0),(0,0),(padding[1],padding[1]),(padding[2],padding[2])), 'constant') 73 | raw = np.pad(img_pad0, ((0,0),(padding[0],padding[0]),(0,0),(0,0)), 'constant') 74 | 75 | cost_scale = costmap.max() 76 | if cost_scale<1: ## this should not happen, but just in case 77 | cost_scale = 1 78 | 79 | deg = random.randrange(1,180) 80 | flip_flag = random.random() 81 | 82 | for zz in range(label.shape[1]): 83 | 84 | for ci in range(label.shape[0]): 85 | labi = label[ci,zz,:,:] 86 | labi_pil = Image.fromarray(np.uint8(labi)) 87 | new_labi_pil = labi_pil.rotate(deg,resample=Image.NEAREST) 88 | if flip_flag<0.5: 89 | new_labi_pil = new_labi_pil.transpose(Image.FLIP_LEFT_RIGHT) 90 | new_labi = np.array(new_labi_pil.convert('L')) 91 | label[ci,zz,:,:] = new_labi.astype(int) 92 | 93 | cmap = costmap[zz,:,:] 94 | cmap_pil = Image.fromarray(np.uint8(255*(cmap/cost_scale))) 95 | new_cmap_pil = cmap_pil.rotate(deg,resample=Image.NEAREST) 96 | if flip_flag<0.5: 97 | new_cmap_pil = new_cmap_pil.transpose(Image.FLIP_LEFT_RIGHT) 98 | new_cmap = np.array(new_cmap_pil.convert('L')) 99 | costmap[zz,:,:] = cost_scale*(new_cmap/255.0) 100 | 101 | for zz in range(raw.shape[1]): 102 | for ci in range(raw.shape[0]): 103 | str_im = raw[ci,zz,:,:] 104 | str_im_pil = Image.fromarray(np.uint8(str_im*255)) 105 | new_str_im_pil = 
str_im_pil.rotate(deg,resample=Image.BICUBIC) 106 | if flip_flag<0.5: 107 | new_str_im_pil = new_str_im_pil.transpose(Image.FLIP_LEFT_RIGHT) 108 | new_str_image = np.array(new_str_im_pil.convert('L')) 109 | raw[ci,zz,:,:] = (new_str_image.astype(float))/255.0 110 | new_patch_num = 0 111 | 112 | while new_patch_num < num_patch_per_img[img_idx]: 113 | 114 | pz = random.randint(0, label.shape[1] - size_out[0]) 115 | py = random.randint(0, label.shape[2] - size_out[1]) 116 | px = random.randint(0, label.shape[3] - size_out[2]) 117 | 118 | 119 | # check if this is a good crop 120 | ref_patch_cmap = costmap[pz:pz+size_out[0],py:py+size_out[1],px:px+size_out[2]] 121 | 122 | # confirmed good crop 123 | (self.img).append(raw[:,pz:pz+size_in[0],py:py+size_in[1],px:px+size_in[2]] ) 124 | (self.gt).append(label[:,pz:pz+size_out[0],py:py+size_out[1],px:px+size_out[2]]) 125 | (self.cmap).append(ref_patch_cmap) 126 | 127 | new_patch_num += 1 128 | 129 | def __getitem__(self, index): 130 | 131 | image_tensor = from_numpy(self.img[index].astype(float)) 132 | cmap_tensor = from_numpy(self.cmap[index].astype(float)) 133 | 134 | label_tensor = [] 135 | if self.gt[index].shape[0]>0: 136 | for zz in range(self.gt[index].shape[0]): 137 | label_tensor.append(from_numpy(self.gt[index][zz,:,:,:].astype(float)).float()) 138 | else: 139 | label_tensor.append(from_numpy(self.gt[index].astype(float)).float()) 140 | 141 | return image_tensor.float(), label_tensor, cmap_tensor.float() 142 | 143 | def __len__(self): 144 | return len(self.img) 145 | 146 | class RR_FH_M0C(Dataset): 147 | 148 | def __init__(self, filenames, num_patch, size_in, size_out): 149 | 150 | self.img = [] 151 | self.gt = [] 152 | self.cmap = [] 153 | 154 | padding = [(x-y)//2 for x,y in zip(size_in, size_out)] 155 | 156 | num_data = len(filenames) 157 | shuffle(filenames) 158 | 159 | num_trial_round = 0 160 | while len(self.img) < num_patch: 161 | 162 | # to avoid dead loop 163 | num_trial_round = num_trial_round + 1 164 | if 
num_trial_round > 2: 165 | break 166 | 167 | num_patch_to_obtain = num_patch - len(self.img) 168 | num_patch_per_img = np.zeros((num_data,), dtype=int) 169 | if num_data >= num_patch_to_obtain: 170 | # all one 171 | num_patch_per_img[:num_patch_to_obtain]=1 172 | else: 173 | basic_num = num_patch_to_obtain // num_data 174 | # assign each image the same number of patches to extract 175 | num_patch_per_img[:] = basic_num 176 | 177 | # assign one more patch to the first few images to achieve the total patch number 178 | num_patch_per_img[:(num_patch_to_obtain-basic_num*num_data)] = num_patch_per_img[:(num_patch_to_obtain-basic_num*num_data)] + 1 179 | 180 | 181 | for img_idx, fn in enumerate(filenames): 182 | 183 | if len(self.img)==num_patch: 184 | break 185 | 186 | label = np.squeeze(imread(fn+'_GT.ome.tif')) 187 | label = np.expand_dims(label, axis=0) 188 | 189 | input_img = np.squeeze(imread(fn+'.ome.tif')) 190 | if len(input_img.shape) == 3: 191 | # add channel dimension 192 | input_img = np.expand_dims(input_img, axis=0) 193 | elif len(input_img.shape) == 4: 194 | # assume number of channel < number of Z, make sure channel dim comes first 195 | if input_img.shape[0] > input_img.shape[1]: 196 | input_img = np.transpose(input_img, (1, 0, 2, 3)) 197 | 198 | costmap = np.squeeze(imread(fn+'_CM.ome.tif')) 199 | 200 | img_pad0 = np.pad(input_img, ((0,0),(0,0),(padding[1],padding[1]),(padding[2],padding[2])), 'constant') 201 | raw = np.pad(img_pad0, ((0,0),(padding[0],padding[0]),(0,0),(0,0)), 'constant') 202 | 203 | cost_scale = costmap.max() 204 | if cost_scale<1: ## this should not happen, but just in case 205 | cost_scale = 1 206 | 207 | deg = random.randrange(1,180) 208 | flip_flag = random.random() 209 | 210 | for zz in range(label.shape[1]): 211 | 212 | for ci in range(label.shape[0]): 213 | labi = label[ci,zz,:,:] 214 | labi_pil = Image.fromarray(np.uint8(labi)) 215 | new_labi_pil = labi_pil.rotate(deg,resample=Image.NEAREST) 216 | if flip_flag<0.5: 217 | 
new_labi_pil = new_labi_pil.transpose(Image.FLIP_LEFT_RIGHT) 218 | new_labi = np.array(new_labi_pil.convert('L')) 219 | label[ci,zz,:,:] = new_labi.astype(int) 220 | 221 | cmap = costmap[zz,:,:] 222 | cmap_pil = Image.fromarray(np.uint8(255*(cmap/cost_scale))) 223 | new_cmap_pil = cmap_pil.rotate(deg,resample=Image.NEAREST) 224 | if flip_flag<0.5: 225 | new_cmap_pil = new_cmap_pil.transpose(Image.FLIP_LEFT_RIGHT) 226 | new_cmap = np.array(new_cmap_pil.convert('L')) 227 | costmap[zz,:,:] = cost_scale*(new_cmap/255.0) 228 | 229 | for zz in range(raw.shape[1]): 230 | for ci in range(raw.shape[0]): 231 | str_im = raw[ci,zz,:,:] 232 | str_im_pil = Image.fromarray(np.uint8(str_im*255)) 233 | new_str_im_pil = str_im_pil.rotate(deg,resample=Image.BICUBIC) 234 | if flip_flag<0.5: 235 | new_str_im_pil = new_str_im_pil.transpose(Image.FLIP_LEFT_RIGHT) 236 | new_str_image = np.array(new_str_im_pil.convert('L')) 237 | raw[ci,zz,:,:] = (new_str_image.astype(float))/255.0 238 | 239 | new_patch_num = 0 240 | num_fail = 0 241 | while new_patch_num < num_patch_per_img[img_idx]: 242 | 243 | pz = random.randint(0, label.shape[1] - size_out[0]) 244 | py = random.randint(0, label.shape[2] - size_out[1]) 245 | px = random.randint(0, label.shape[3] - size_out[2]) 246 | 247 | 248 | # check if this is a good crop 249 | ref_patch_cmap = costmap[pz:pz+size_out[0],py:py+size_out[1],px:px+size_out[2]] 250 | if np.count_nonzero(ref_patch_cmap>1e-5) < 1000: #enough valida samples 251 | num_fail = num_fail + 1 252 | if num_fail > 50: 253 | break 254 | continue 255 | 256 | 257 | # confirmed good crop 258 | (self.img).append(raw[:,pz:pz+size_in[0],py:py+size_in[1],px:px+size_in[2]] ) 259 | (self.gt).append(label[:,pz:pz+size_out[0],py:py+size_out[1],px:px+size_out[2]]) 260 | (self.cmap).append(ref_patch_cmap) 261 | 262 | new_patch_num += 1 263 | 264 | def __getitem__(self, index): 265 | 266 | image_tensor = from_numpy(self.img[index].astype(float)) 267 | cmap_tensor = 
from_numpy(self.cmap[index].astype(float)) 268 | 269 | label_tensor = [] 270 | if self.gt[index].shape[0]>0: 271 | for zz in range(self.gt[index].shape[0]): 272 | label_tensor.append(from_numpy(self.gt[index][zz,:,:,:].astype(float)).float()) 273 | else: 274 | label_tensor.append(from_numpy(self.gt[index].astype(float)).float()) 275 | 276 | return image_tensor.float(), label_tensor, cmap_tensor.float() 277 | 278 | def __len__(self): 279 | return len(self.img) 280 | 281 | class NOAUG_M(Dataset): 282 | 283 | def __init__(self, filenames, num_patch, size_in, size_out): 284 | 285 | self.img = [] 286 | self.gt = [] 287 | self.cmap = [] 288 | 289 | padding = [(x-y)//2 for x,y in zip(size_in, size_out)] 290 | total_in_count = size_in[0] * size_in[1] * size_in[2] 291 | total_out_count = size_out[0] * size_out[1] * size_out[2] 292 | 293 | num_data = len(filenames) 294 | shuffle(filenames) 295 | num_patch_per_img = np.zeros((num_data,), dtype=int) 296 | if num_data >= num_patch: 297 | # all one 298 | num_patch_per_img[:num_patch]=1 299 | else: 300 | basic_num = num_patch // num_data 301 | # assign each image the same number of patches to extract 302 | num_patch_per_img[:] = basic_num 303 | 304 | # assign one more patch to the first few images to achieve the total patch number 305 | num_patch_per_img[:(num_patch-basic_num*num_data)] = num_patch_per_img[:(num_patch-basic_num*num_data)] + 1 306 | 307 | 308 | for img_idx, fn in enumerate(filenames): 309 | 310 | label = np.squeeze(imread(fn+'_GT.ome.tif')) 311 | label = np.expand_dims(label, axis=0) 312 | 313 | input_img = np.squeeze(imread(fn+'.ome.tif')) 314 | if len(input_img.shape) == 3: 315 | # add channel dimension 316 | input_img = np.expand_dims(input_img, axis=0) 317 | elif len(input_img.shape) == 4: 318 | # assume number of channel < number of Z, make sure channel dim comes first 319 | if input_img.shape[0] > input_img.shape[1]: 320 | input_img = np.transpose(input_img, (1, 0, 2, 3)) 321 | 322 | costmap = 
np.squeeze(imread(fn+'_CM.ome.tif')) 323 | 324 | img_pad0 = np.pad(input_img, ((0,0),(0,0),(padding[1],padding[1]),(padding[2],padding[2])), 'symmetric') 325 | raw = np.pad(img_pad0, ((0,0),(padding[0],padding[0]),(0,0),(0,0)), 'constant') 326 | 327 | new_patch_num = 0 328 | 329 | while new_patch_num < num_patch_per_img[img_idx]: 330 | 331 | pz = random.randint(0, label.shape[1] - size_out[0]) 332 | py = random.randint(0, label.shape[2] - size_out[1]) 333 | px = random.randint(0, label.shape[3] - size_out[2]) 334 | 335 | 336 | ## check if this is a good crop 337 | ref_patch_cmap = costmap[pz:pz+size_out[0],py:py+size_out[1],px:px+size_out[2]] 338 | 339 | 340 | # confirmed good crop 341 | (self.img).append(raw[:,pz:pz+size_in[0],py:py+size_in[1],px:px+size_in[2]] ) 342 | (self.gt).append(label[:,pz:pz+size_out[0],py:py+size_out[1],px:px+size_out[2]]) 343 | (self.cmap).append(ref_patch_cmap) 344 | 345 | new_patch_num += 1 346 | 347 | def __getitem__(self, index): 348 | 349 | image_tensor = from_numpy(self.img[index].astype(float)) 350 | cmap_tensor = from_numpy(self.cmap[index].astype(float)) 351 | 352 | #if self.gt[index].shape[0]>1: 353 | label_tensor = [] 354 | for zz in range(self.gt[index].shape[0]): 355 | tmp_tensor = from_numpy(self.gt[index][zz,:,:,:].astype(float)) 356 | label_tensor.append(tmp_tensor.float()) 357 | #else: 358 | # label_tensor = from_numpy(self.gt[index].astype(float)) 359 | # label_tensor = label_tensor.float() 360 | 361 | return image_tensor.float(), label_tensor, cmap_tensor.float() 362 | 363 | def __len__(self): 364 | return len(self.img) -------------------------------------------------------------------------------- /aicsmlsegment/Net3D/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/aicsmlsegment/Net3D/__init__.py 
-------------------------------------------------------------------------------- /aicsmlsegment/Net3D/uNet_original.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class UNet3D(nn.Module): 6 | def __init__(self, in_channel, n_classes, batchnorm_flag=True): 7 | self.in_channel = in_channel 8 | self.n_classes = n_classes 9 | super(UNet3D, self).__init__() 10 | 11 | self.ec1 = self.encoder(self.in_channel, 32, batchnorm=batchnorm_flag) 12 | self.ec2 = self.encoder(64, 64, batchnorm=batchnorm_flag) 13 | self.ec3 = self.encoder(128, 128, batchnorm=batchnorm_flag) 14 | self.ec4 = self.encoder(256, 256, batchnorm=batchnorm_flag) 15 | 16 | self.pool1 = nn.MaxPool3d(2) 17 | self.pool2 = nn.MaxPool3d(2) 18 | self.pool3 = nn.MaxPool3d(2) 19 | 20 | self.up3 = nn.ConvTranspose3d(512, 512, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True) 21 | self.dc3 = self.decoder(256 + 512, 256, batchnorm=batchnorm_flag) 22 | self.up2 = nn.ConvTranspose3d(256, 256, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True) 23 | self.dc2 = self.decoder(128 + 256, 128, batchnorm=batchnorm_flag) 24 | self.up1 = nn.ConvTranspose3d(128, 128, kernel_size=2, stride=2, padding=0, output_padding=0, bias=True) 25 | self.dc1 = self.decoder(64 + 128, 64, batchnorm=batchnorm_flag) 26 | 27 | self.dc0 = nn.Conv3d(64, n_classes, 1) 28 | self.softmax = F.log_softmax 29 | 30 | self.numClass = n_classes 31 | 32 | def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 33 | bias=True, batchnorm=False): 34 | if batchnorm: 35 | layer = nn.Sequential( 36 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 37 | nn.BatchNorm2d(out_channels, affine=False), 38 | nn.ReLU(), 39 | nn.Conv3d(out_channels, 2*out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 40 | nn.BatchNorm2d(2*out_channels, 
affine=False), 41 | nn.ReLU()) 42 | else: 43 | layer = nn.Sequential( 44 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 45 | nn.ReLU(), 46 | nn.Conv3d(out_channels, 2*out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 47 | nn.ReLU()) 48 | return layer 49 | 50 | 51 | def decoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 52 | bias=True, batchnorm=False): 53 | if batchnorm: 54 | layer = nn.Sequential( 55 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 56 | nn.BatchNorm2d(out_channels, affine=False), 57 | nn.ReLU(), 58 | nn.Conv3d(out_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 59 | nn.BatchNorm2d(out_channels, affine=False), 60 | nn.ReLU()) 61 | else: 62 | layer = nn.Sequential( 63 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 64 | nn.ReLU(), 65 | nn.Conv3d(out_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 66 | nn.ReLU()) 67 | return layer 68 | 69 | def forward(self, x): 70 | 71 | down1 = self.ec1(x) 72 | x1 = self.pool1(down1) 73 | down2 = self.ec2(x1) 74 | x2 = self.pool2(down2) 75 | down3 = self.ec3(x2) 76 | x3 = self.pool3(down3) 77 | 78 | u3 = self.ec4(x3) 79 | 80 | d3 = torch.cat((self.up3(u3), F.pad(down3,(-4,-4,-4,-4,-4,-4))), 1) 81 | u2 = self.dc3(d3) 82 | d2 = torch.cat((self.up2(u2), F.pad(down2,(-16,-16,-16,-16,-16,-16))), 1) 83 | u1 = self.dc2(d2) 84 | d1 = torch.cat((self.up1(u1), F.pad(down1,(-40,-40,-40,-40,-40,-40))), 1) 85 | u0 = self.dc1(d1) 86 | out = self.dc0(u0) 87 | 88 | out = out.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension 89 | out = out.view(out.numel() // self.numClass, self.numClass) 90 | out = self.softmax(out, dim=1) 91 | 92 | return out 93 | -------------------------------------------------------------------------------- 
/aicsmlsegment/Net3D/unet_xy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class UNet3D(nn.Module): 6 | def __init__(self, in_channel, n_classes, batchnorm_flag=True): 7 | self.in_channel = in_channel 8 | self.n_classes = n_classes 9 | super(UNet3D, self).__init__() 10 | 11 | self.ec1 = self.encoder(self.in_channel, 32, batchnorm=batchnorm_flag) # in --> 64 12 | self.ec2 = self.encoder(64, 64, batchnorm=batchnorm_flag) # 64 --> 128 13 | self.ec3 = self.encoder(128, 128, batchnorm=batchnorm_flag) # 128 --> 256 14 | self.ec4 = self.encoder(256, 256, batchnorm=batchnorm_flag) # 256 -->512 15 | 16 | self.pool1 = nn.MaxPool3d((1,2,2)) 17 | self.pool2 = nn.MaxPool3d((1,2,2)) 18 | self.pool3 = nn.MaxPool3d((1,2,2)) 19 | 20 | self.up3 = nn.ConvTranspose3d(512, 512, kernel_size=(1,2,2), stride=(1,2,2), padding=0, output_padding=0, bias=True) 21 | self.up2 = nn.ConvTranspose3d(256, 256, kernel_size=(1,2,2), stride=(1,2,2), padding=0, output_padding=0, bias=True) 22 | self.up1 = nn.ConvTranspose3d(128, 128, kernel_size=(1,2,2), stride=(1,2,2), padding=0, output_padding=0, bias=True) 23 | 24 | self.dc3 = self.decoder(256 + 512, 256, batchnorm=batchnorm_flag) 25 | self.dc2 = self.decoder(128 + 256, 128, batchnorm=batchnorm_flag) 26 | self.dc1 = self.decoder(64 + 128, 64, batchnorm=batchnorm_flag) 27 | 28 | self.dc0 = nn.Conv3d(64, n_classes[0], 1) 29 | 30 | self.up2a = nn.ConvTranspose3d(256, n_classes[2], kernel_size=(1,8,8), stride=(1,4,4), padding=0, output_padding=0, bias=True) 31 | self.up1a = nn.ConvTranspose3d(128, n_classes[1], kernel_size=(1,4,4), stride=(1,2,2), padding=0, output_padding=0, bias=True) 32 | 33 | self.conv2a = nn.Conv3d(n_classes[2], n_classes[2], 3, stride=1, padding=0, bias=True) 34 | self.conv1a = nn.Conv3d(n_classes[1], n_classes[1], 3, stride=1, padding=0, bias=True) 35 | 36 | self.predict2a = nn.Conv3d(n_classes[2], 
n_classes[2], 1) 37 | self.predict1a = nn.Conv3d(n_classes[1], n_classes[1], 1) 38 | 39 | #self.conv_final = nn.Conv3d(n_classes[0]+n_classes[1]+n_classes[2], n_classes[0]+n_classes[1]+n_classes[2], 3, stride=1, padding=1, bias=True) 40 | #self.predict_final = nn.Conv3d(n_classes[0]+n_classes[1]+n_classes[2], n_classes[3], 1) 41 | 42 | self.softmax = F.log_softmax # nn.LogSoftmax(1) 43 | 44 | self.final_activation = nn.Softmax(dim=1) 45 | 46 | self.numClass = n_classes[0] 47 | self.numClass1 = n_classes[1] 48 | self.numClass2 = n_classes[2] 49 | #self.numClass_combine = n_classes[3] 50 | 51 | def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 52 | bias=True, batchnorm=False): 53 | if batchnorm: 54 | layer = nn.Sequential( 55 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 56 | nn.BatchNorm3d(out_channels, affine=False), 57 | nn.ReLU(), 58 | nn.Conv3d(out_channels, 2*out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 59 | nn.BatchNorm3d(2*out_channels, affine=False), 60 | nn.ReLU()) 61 | else: 62 | layer = nn.Sequential( 63 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 64 | nn.ReLU(), 65 | nn.Conv3d(out_channels, 2*out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 66 | nn.ReLU()) 67 | return layer 68 | 69 | 70 | def decoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 71 | bias=True, batchnorm=False): 72 | if batchnorm: 73 | layer = nn.Sequential( 74 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 75 | nn.BatchNorm3d(out_channels, affine=False), 76 | nn.ReLU(), 77 | nn.Conv3d(out_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 78 | nn.BatchNorm3d(out_channels, affine=False), 79 | nn.ReLU()) 80 | else: 81 | layer = nn.Sequential( 82 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, 
padding=padding, bias=bias), 83 | nn.ReLU(), 84 | nn.Conv3d(out_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 85 | nn.ReLU()) 86 | return layer 87 | 88 | def forward(self, x): 89 | 90 | down1 = self.ec1(x) 91 | x1 = self.pool1(down1) 92 | down2 = self.ec2(x1) 93 | x2 = self.pool2(down2) 94 | down3 = self.ec3(x2) 95 | x3 = self.pool3(down3) 96 | 97 | u3 = self.ec4(x3) 98 | 99 | d3 = torch.cat((self.up3(u3), F.pad(down3,(-4,-4,-4,-4,-2,-2))), 1) 100 | u2 = self.dc3(d3) 101 | 102 | d2 = torch.cat((self.up2(u2), F.pad(down2,(-16,-16,-16,-16,-6,-6))), 1) 103 | u1 = self.dc2(d2) 104 | 105 | d1 = torch.cat((self.up1(u1), F.pad(down1,(-40,-40,-40,-40,-10,-10))), 1) 106 | u0 = self.dc1(d1) 107 | 108 | p0 = self.dc0(u0) 109 | 110 | p1a = F.pad(self.predict1a(self.conv1a(self.up1a(u1))),(-2,-2,-2,-2, -1, -1)) 111 | p2a = F.pad(self.predict2a(self.conv2a(self.up2a(u2))),(-7,-7,-7,-7,-3,-3)) 112 | 113 | p0_final = p0.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension 114 | p0_final = p0_final.view(p0_final.numel() // self.numClass, self.numClass) 115 | p0_final = self.softmax(p0_final, dim=1) 116 | 117 | p1_final = p1a.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension 118 | p1_final = p1_final.view(p1_final.numel() // self.numClass1, self.numClass1) 119 | p1_final = self.softmax(p1_final, dim=1) 120 | 121 | p2_final = p2a.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension 122 | p2_final = p2_final.view(p2_final.numel() // self.numClass2, self.numClass2) 123 | p2_final = self.softmax(p2_final, dim=1) 124 | 125 | ''' 126 | p_combine0 = self.predict_final(self.conv_final(torch.cat((p0, p1a, p2a), 1))) # BCZYX 127 | p_combine = p_combine0.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension 128 | p_combine = p_combine.view(p_combine.numel() // self.numClass_combine, self.numClass_combine) 129 | p_combine = 
self.softmax(p_combine) 130 | ''' 131 | 132 | return [p0_final, p1_final, p2_final] 133 | -------------------------------------------------------------------------------- /aicsmlsegment/Net3D/unet_xy_enlarge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class UNet3D(nn.Module): 6 | def __init__(self, in_channel, n_classes, down_ratio, batchnorm_flag=True): 7 | self.in_channel = in_channel 8 | self.n_classes = n_classes 9 | super(UNet3D, self).__init__() 10 | 11 | k = down_ratio 12 | 13 | self.ec1 = self.encoder(self.in_channel, 32, batchnorm=batchnorm_flag) # in --> 64 14 | self.ec2 = self.encoder(64, 64, batchnorm=batchnorm_flag) # 64 --> 128 15 | self.ec3 = self.encoder(128, 128, batchnorm=batchnorm_flag) # 128 --> 256 16 | self.ec4 = self.encoder(256, 256, batchnorm=batchnorm_flag) # 256 -->512 17 | 18 | self.pool0 = nn.MaxPool3d((1,k,k)) 19 | self.pool1 = nn.MaxPool3d((1,2,2)) 20 | self.pool2 = nn.MaxPool3d((1,2,2)) 21 | self.pool3 = nn.MaxPool3d((1,2,2)) 22 | 23 | self.up3 = nn.ConvTranspose3d(512, 512, kernel_size=(1,2,2), stride=(1,2,2), padding=0, output_padding=0, bias=True) 24 | self.up2 = nn.ConvTranspose3d(256, 256, kernel_size=(1,2,2), stride=(1,2,2), padding=0, output_padding=0, bias=True) 25 | self.up1 = nn.ConvTranspose3d(128, 128, kernel_size=(1,2,2), stride=(1,2,2), padding=0, output_padding=0, bias=True) 26 | self.up0 = nn.ConvTranspose3d(64, 64, kernel_size=(1,k,k), stride=(1,k,k), padding=0, output_padding=0, bias=True) 27 | 28 | self.dc3 = self.decoder(256 + 512, 256, batchnorm=batchnorm_flag) 29 | self.dc2 = self.decoder(128 + 256, 128, batchnorm=batchnorm_flag) 30 | self.dc1 = self.decoder(64 + 128, 64, batchnorm=batchnorm_flag) 31 | self.dc0 = self.decoder(64, 64, batchnorm=batchnorm_flag) 32 | 33 | self.predict0 = nn.Conv3d(64, n_classes[0], 1) 34 | 35 | self.up1a = nn.ConvTranspose3d(128, n_classes[1], 
kernel_size=(1,2*k,2*k), stride=(1,2*k,2*k), padding=0, output_padding=0, bias=True) 36 | self.up2a = nn.ConvTranspose3d(256, n_classes[2], kernel_size=(1,4*k,4*k), stride=(1,4*k,4*k), padding=0, output_padding=0, bias=True) 37 | 38 | self.conv2a = nn.Conv3d(n_classes[2], n_classes[2], 3, stride=1, padding=0, bias=True) 39 | self.conv1a = nn.Conv3d(n_classes[1], n_classes[1], 3, stride=1, padding=0, bias=True) 40 | 41 | self.predict2a = nn.Conv3d(n_classes[2], n_classes[2], 1) 42 | self.predict1a = nn.Conv3d(n_classes[1], n_classes[1], 1) 43 | 44 | #self.conv_final = nn.Conv3d(n_classes[0]+n_classes[1]+n_classes[2], n_classes[0]+n_classes[1]+n_classes[2], 3, stride=1, padding=1, bias=True) 45 | #self.predict_final = nn.Conv3d(n_classes[0]+n_classes[1]+n_classes[2], n_classes[3], 1) 46 | 47 | self.softmax = F.log_softmax # nn.LogSoftmax(1) 48 | 49 | self.final_activation = nn.Softmax(dim=1) 50 | 51 | self.numClass = n_classes[0] 52 | self.numClass1 = n_classes[1] 53 | self.numClass2 = n_classes[2] 54 | 55 | self.k = k 56 | #self.numClass_combine = n_classes[3] 57 | 58 | def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 59 | bias=True, batchnorm=False): 60 | if batchnorm: 61 | layer = nn.Sequential( 62 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 63 | nn.BatchNorm3d(out_channels, affine=False), 64 | nn.ReLU(), 65 | nn.Conv3d(out_channels, 2*out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 66 | nn.BatchNorm3d(2*out_channels, affine=False), 67 | nn.ReLU()) 68 | else: 69 | layer = nn.Sequential( 70 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 71 | nn.ReLU(), 72 | nn.Conv3d(out_channels, 2*out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 73 | nn.ReLU()) 74 | return layer 75 | 76 | 77 | def decoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 78 | bias=True, 
batchnorm=False): 79 | if batchnorm: 80 | layer = nn.Sequential( 81 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 82 | nn.BatchNorm3d(out_channels, affine=False), 83 | nn.ReLU(), 84 | nn.Conv3d(out_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 85 | nn.BatchNorm3d(out_channels, affine=False), 86 | nn.ReLU()) 87 | else: 88 | layer = nn.Sequential( 89 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 90 | nn.ReLU(), 91 | nn.Conv3d(out_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 92 | nn.ReLU()) 93 | return layer 94 | 95 | def forward(self, x): 96 | 97 | k = self.k 98 | 99 | x0 = self.pool0(x) 100 | 101 | down1 = self.ec1(x0) 102 | x1 = self.pool1(down1) 103 | down2 = self.ec2(x1) 104 | x2 = self.pool2(down2) 105 | down3 = self.ec3(x2) 106 | x3 = self.pool3(down3) 107 | 108 | u3 = self.ec4(x3) 109 | 110 | d3 = torch.cat((self.up3(u3), F.pad(down3,(-4,-4,-4,-4,-2,-2))), 1) 111 | u2 = self.dc3(d3) 112 | 113 | d2 = torch.cat((self.up2(u2), F.pad(down2,(-16,-16,-16,-16,-6,-6))), 1) 114 | u1 = self.dc2(d2) 115 | 116 | d1 = torch.cat((self.up1(u1), F.pad(down1,(-40,-40,-40,-40,-10,-10))), 1) 117 | u0 = self.dc1(d1) 118 | 119 | d0 = self.up0(u0) 120 | 121 | predict00 = self.predict0(self.dc0(d0)) 122 | p0_final = predict00.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension 123 | p0_final = p0_final.view(p0_final.numel() // self.numClass, self.numClass) 124 | p0_final = self.softmax(p0_final, dim=1) 125 | 126 | p1a = F.pad(self.predict1a(self.conv1a(self.up1a(u1))),(-2*k-1,-2*k-1,-2*k-1,-2*k-1, -3, -3)) 127 | p1_final = p1a.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension 128 | p1_final = p1_final.view(p1_final.numel() // self.numClass1, self.numClass1) 129 | p1_final = self.softmax(p1_final, dim=1) 130 | 131 | p2a = 
F.pad(self.predict2a(self.conv2a(self.up2a(u2))),(-6*k-1,-6*k-1,-6*k-1,-6*k-1,-5,-5)) ## fix +5 132 | p2_final = p2a.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension 133 | p2_final = p2_final.view(p2_final.numel() // self.numClass2, self.numClass2) 134 | p2_final = self.softmax(p2_final, dim=1) 135 | 136 | ''' 137 | p_combine0 = self.predict_final(self.conv_final(torch.cat((p0, p1a, p2a), 1))) # BCZYX 138 | p_combine = p_combine0.permute(0, 2, 3, 4, 1).contiguous() # move the class channel to the last dimension 139 | p_combine = p_combine.view(p_combine.numel() // self.numClass_combine, self.numClass_combine) 140 | p_combine = self.softmax(p_combine) 141 | ''' 142 | 143 | return [p0_final, p1_final, p2_final] 144 | -------------------------------------------------------------------------------- /aicsmlsegment/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import MODULE_VERSION 2 | 3 | 4 | def get_module_version(): 5 | return MODULE_VERSION 6 | 7 | 8 | -------------------------------------------------------------------------------- /aicsmlsegment/bin/curator/curator_merging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | import logging 6 | import argparse 7 | import traceback 8 | import importlib 9 | import pathlib 10 | import csv 11 | 12 | import pandas as pd 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | import matplotlib 16 | from glob import glob 17 | from random import shuffle 18 | from scipy import stats 19 | from skimage.io import imsave 20 | from skimage.draw import line, polygon 21 | 22 | from aicssegmentation.core.utils import histogram_otsu 23 | from aicsimageio import AICSImage, imread 24 | from aicsimageio.writers import OmeTiffWriter 25 | from aicsmlsegment.utils import input_normalization 26 | 27 | matplotlib.use('TkAgg') 28 | 29 | 
#################################################################################################### 30 | # global settings 31 | ignore_img = False 32 | flag_done = False 33 | pts = [] 34 | draw_img = None 35 | draw_mask = None 36 | draw_ax = None 37 | 38 | 39 | log = logging.getLogger() 40 | logging.basicConfig(level=logging.INFO, 41 | format='[%(asctime)s - %(name)s - %(lineno)3d][%(levelname)s] %(message)s') 42 | # 43 | # Set the default log level for other modules used by this script 44 | # logging.getLogger("labkey").setLevel(logging.ERROR) 45 | # logging.getLogger("requests").setLevel(logging.WARNING) 46 | # logging.getLogger("urllib3").setLevel(logging.WARNING) 47 | logging.getLogger("matplotlib").setLevel(logging.INFO) 48 | #################################################################################################### 49 | 50 | def draw_polygons(event): 51 | global pts, draw_img, draw_ax, draw_mask 52 | if event.button == 1: 53 | if not (event.ydata == None or event.xdata == None): 54 | pts.append([event.xdata,event.ydata]) 55 | if len(pts)>1: 56 | rr, cc = line(int(round(pts[-1][0])), int(round(pts[-1][1])), int(round(pts[-2][0])), int(round(pts[-2][1])) ) 57 | draw_img[cc,rr,:1]=255 58 | draw_ax.set_data(draw_img) 59 | plt.draw() 60 | elif event.button == 3: 61 | if len(pts)>2: 62 | # draw polygon 63 | pts_array = np.asarray(pts) 64 | rr, cc = polygon(pts_array[:,0], pts_array[:,1]) 65 | draw_img[cc,rr,:1]=255 66 | draw_ax.set_data(draw_img) 67 | draw_mask[cc,rr]=1 68 | pts.clear() 69 | plt.draw() 70 | else: 71 | print('need at least three clicks before finishing annotation') 72 | 73 | def quit_mask_drawing(event): 74 | global ignore_img 75 | if event.key == 'd': 76 | plt.close() 77 | elif event.key == 'b': 78 | ignore_img = True 79 | plt.close() 80 | elif event.key == 'q': 81 | exit() 82 | 83 | 84 | def create_merge_mask(raw_img, seg1, seg2, drawing_aim): 85 | global pts, draw_img, draw_mask, draw_ax 86 | 87 | offset = 20 88 | seg1_label = seg1 + 
offset # make it brighter 89 | seg1_label[seg1_label==offset]=0 90 | seg1_label = seg1_label.astype(float) * (255/seg1_label.max()) 91 | seg1_label = np.round(seg1_label) 92 | seg1_label = seg1_label.astype(np.uint8) 93 | 94 | offset = 25 95 | seg2_label = seg2 + offset # make it brighter 96 | seg2_label[seg2_label==offset]=0 97 | seg2_label = seg2_label.astype(float) * (255/seg2_label.max()) 98 | seg2_label = np.round(seg2_label) 99 | seg2_label = seg2_label.astype(np.uint8) 100 | 101 | 102 | bw = seg1>0 103 | z_profile = np.zeros((bw.shape[0],),dtype=int) 104 | for zz in range(bw.shape[0]): 105 | z_profile[zz] = np.count_nonzero(bw[zz,:,:]) 106 | mid_frame = int(round(histogram_otsu(z_profile)*bw.shape[0])) 107 | 108 | img = np.zeros((2*raw_img.shape[1], 3*raw_img.shape[2], 3),dtype=np.uint8) 109 | 110 | row_index = 0 111 | 112 | for cc in range(3): 113 | img[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], :raw_img.shape[2], cc]=np.amax(raw_img, axis=0) 114 | img[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], raw_img.shape[2]:2*raw_img.shape[2], cc]=np.amax(seg1_label, axis=0) 115 | img[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], 2*raw_img.shape[2]:, cc]=np.amax(seg2_label, axis=0) 116 | 117 | row_index = 1 118 | for cc in range(3): 119 | img[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], :raw_img.shape[2], cc]=raw_img[mid_frame,:,:] 120 | img[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], raw_img.shape[2]:2*raw_img.shape[2], cc]=seg1_label[mid_frame,:,:] 121 | img[row_index*raw_img.shape[1]:(row_index+1)*raw_img.shape[1], 2*raw_img.shape[2]:, cc]=seg2_label[mid_frame,:,:] 122 | 123 | draw_mask = np.zeros((img.shape[0],img.shape[1]),dtype=np.uint8) 124 | draw_img = img.copy() 125 | # display the image for good/bad inspection 126 | fig = plt.figure() 127 | figManager = plt.get_current_fig_manager() 128 | figManager.full_screen_toggle() 129 | ax = fig.add_subplot(111) 130 | ax.set_title('Interface for 
annotating '+drawing_aim+'. Left: raw, Middle: segmentation v1, Right: segmentation v2. \n' \ 131 | +'Top row: max z projection, Bottom row: middle z slice. \n'\ 132 | +'Please draw in the upper left panel \n'\ 133 | +'Left click to add a vertex; Right click to close the current polygon \n' \ 134 | +'Press D to finish annotating mask, Press Q to quit curation (can resume later)') 135 | draw_ax = ax.imshow(img) 136 | cid = fig.canvas.mpl_connect('button_press_event', draw_polygons) 137 | cid2 = fig.canvas.mpl_connect('key_press_event', quit_mask_drawing) 138 | plt.show() 139 | fig.canvas.mpl_disconnect(cid) 140 | fig.canvas.mpl_disconnect(cid2) 141 | 142 | class Args(object): 143 | """ 144 | Use this to define command line arguments and use them later. 145 | 146 | For each argument do the following 147 | 1. Create a member in __init__ before the self.__parse call. 148 | 2. Provide a default value here. 149 | 3. Then in p.add_argument, set the dest parameter to that variable name. 150 | 151 | See the debug parameter as an example. 152 | """ 153 | 154 | def __init__(self, log_cmdline=True): 155 | self.debug = False 156 | self.output_dir = '.'+os.sep 157 | self.struct_ch = 0 158 | self.xy = 0.108 159 | 160 | # 161 | self.__parse() 162 | # 163 | if self.debug: 164 | log.setLevel(logging.DEBUG) 165 | log.debug("-" * 80) 166 | self.show_info() 167 | log.debug("-" * 80) 168 | 169 | @staticmethod 170 | def __no_args_print_help(parser): 171 | """ 172 | This is used to print out the help if no arguments are provided. 173 | Note: 174 | - You need to remove it's usage if your script truly doesn't want arguments. 175 | - It exits with 1 because it's an error if this is used in a script with no args. 176 | That's a non-interactive use scenario - typically you don't want help there. 
177 | """ 178 | if len(sys.argv) == 1: 179 | parser.print_help() 180 | sys.exit(1) 181 | 182 | def __parse(self): 183 | p = argparse.ArgumentParser() 184 | # Add arguments 185 | p.add_argument('--d', '--debug', action='store_true', dest='debug', 186 | help='If set debug log output is enabled') 187 | p.add_argument('--raw_path', required=True, help='path to raw images') 188 | p.add_argument('--data_type', required=True, help='the type of raw images') 189 | p.add_argument('--input_channel', default=0, type=int) 190 | p.add_argument('--seg1_path', required=True, help='path to segmentation results v1') 191 | p.add_argument('--seg2_path', required=True, help='path to segmentation results v2') 192 | p.add_argument('--train_path', required=True, help='path to output training data') 193 | p.add_argument('--mask_path', help='[optional] the output directory for merging masks') 194 | p.add_argument('--ex_mask_path', help='[optional] the output directory for excluding masks') 195 | p.add_argument('--csv_name', required=True, help='the csv file to save the sorting results') 196 | p.add_argument('--Normalization', required=True, type=int, help='the normalization recipe to use') 197 | 198 | self.__no_args_print_help(p) 199 | p.parse_args(namespace=self) 200 | 201 | def show_info(self): 202 | log.debug("Working Dir:") 203 | log.debug("\t{}".format(os.getcwd())) 204 | log.debug("Command Line:") 205 | log.debug("\t{}".format(" ".join(sys.argv))) 206 | log.debug("Args:") 207 | for (k, v) in self.__dict__.items(): 208 | log.debug("\t{}: {}".format(k, v)) 209 | 210 | 211 | ############################################################################### 212 | 213 | class Executor(object): 214 | 215 | def __init__(self, args): 216 | 217 | if os.path.exists(args.csv_name): 218 | print('the csv file for saving sorting results exists, sorting will be resumed') 219 | else: 220 | print('no existing csv found, start a new sorting ') 221 | if not args.data_type.startswith('.'): 222 | 
args.data_type = '.' + args.data_type 223 | 224 | filenames = glob(args.raw_path + os.sep +'*' + args.data_type) 225 | filenames.sort() 226 | with open(args.csv_name, 'w') as csvfile: 227 | filewriter = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) 228 | filewriter.writerow(['raw','seg1','seg2','score','merging_mask','excluding_mask']) 229 | for _, fn in enumerate(filenames): 230 | seg1_fn = args.seg1_path + os.sep + os.path.basename(fn)[:-1*len(args.data_type)] + '_struct_segmentation.tiff' 231 | seg2_fn = args.seg2_path + os.sep + os.path.basename(fn)[:-1*len(args.data_type)] + '_struct_segmentation.tiff' 232 | assert os.path.exists(seg1_fn) 233 | assert os.path.exists(seg2_fn) 234 | filewriter.writerow([fn, seg1_fn , seg2_fn , None, None, None]) 235 | 236 | def execute(self, args): 237 | 238 | global draw_mask, ignore_img 239 | # part 1: do sorting 240 | df = pd.read_csv(args.csv_name, index_col=False) 241 | 242 | for index, row in df.iterrows(): 243 | 244 | if not np.isnan(row['score']) and (row['score']==1 or row['score']==0): 245 | continue 246 | 247 | reader = AICSImage(row['raw']) 248 | struct_img = reader.get_image_data("ZYX", S=0, T=0, C=args.input_channel) 249 | raw_img = (struct_img- struct_img.min() + 1e-8)/(struct_img.max() - struct_img.min() + 1e-8) 250 | raw_img = 255 * raw_img 251 | raw_img = raw_img.astype(np.uint8) 252 | 253 | seg1 = np.squeeze(imread(row['seg1'])) > 0.01 254 | seg2 = np.squeeze(imread(row['seg2'])) > 0.01 255 | 256 | create_merge_mask(raw_img, seg1.astype(np.uint8), seg2.astype(np.uint8), 'merging_mask') 257 | 258 | if ignore_img: 259 | df['score'].iloc[index]=0 260 | else: 261 | df['score'].iloc[index]=1 262 | 263 | mask_fn = args.mask_path + os.sep + os.path.basename(row['raw'])[:-5] + '_mask.tiff' 264 | crop_mask = np.zeros(seg1.shape, dtype=np.uint8) 265 | for zz in range(crop_mask.shape[0]): 266 | crop_mask[zz,:,:] = draw_mask[:crop_mask.shape[1],:crop_mask.shape[2]] 267 | 268 | crop_mask = 
crop_mask.astype(np.uint8) 269 | crop_mask[crop_mask>0]=255 270 | with OmeTiffWriter(mask_fn) as writer: 271 | writer.save(crop_mask) 272 | df['merging_mask'].iloc[index]=mask_fn 273 | 274 | need_mask = input('Do you need to add an excluding mask for this image, enter y or n: ') 275 | if need_mask == 'y': 276 | create_merge_mask(raw_img, seg1.astype(np.uint8), seg2.astype(np.uint8), 'excluding mask') 277 | 278 | mask_fn = args.ex_mask_path + os.sep + os.path.basename(row['raw'])[:-5] + '_mask.tiff' 279 | crop_mask = np.zeros(seg1.shape, dtype=np.uint8) 280 | for zz in range(crop_mask.shape[0]): 281 | crop_mask[zz,:,:] = draw_mask[:crop_mask.shape[1],:crop_mask.shape[2]] 282 | 283 | crop_mask = crop_mask.astype(np.uint8) 284 | crop_mask[crop_mask>0]=255 285 | with OmeTiffWriter(mask_fn) as writer: 286 | writer.save(crop_mask) 287 | df['excluding_mask'].iloc[index]=mask_fn 288 | 289 | 290 | df.to_csv(args.csv_name, index=False) 291 | 292 | 293 | ######################################### 294 | # generate training data: 295 | # (we want to do this step after "sorting" 296 | # (is mainly because we want to get the sorting 297 | # step as smooth as possible, even though 298 | # this may waster i/o time on reloading images) 299 | # ####################################### 300 | print('finish merging, start building the training data ...') 301 | existing_files = glob(args.train_path+os.sep+'img_*.ome.tif') 302 | print(len(existing_files)) 303 | 304 | training_data_count = len(existing_files)//3 305 | for index, row in df.iterrows(): 306 | if row['score']==1: 307 | training_data_count += 1 308 | 309 | # load raw image 310 | reader = AICSImage(row['raw']) 311 | img = reader.get_image_data("CZYX", S=0, T=0, C=[args.input_channel]).astype(np.float32) 312 | struct_img = input_normalization(img, args) 313 | struct_img= struct_img[0,:,:,:] 314 | 315 | seg1 = np.squeeze(imread(row['seg1'])) > 0.01 316 | seg2 = np.squeeze(imread(row['seg2'])) > 0.01 317 | 318 | if 
os.path.isfile(str(row['merging_mask'])): 319 | mask = np.squeeze(imread(row['merging_mask'])) 320 | seg1[mask>0]=0 321 | seg2[mask==0]=0 322 | seg1 = np.logical_or(seg1,seg2) 323 | 324 | cmap = np.ones(seg1.shape, dtype=np.float32) 325 | if os.path.isfile(str(row['excluding_mask'])): 326 | ex_mask = np.squeeze(imread(row['excluding_mask'])) > 0.01 327 | cmap[ex_mask>0]=0 328 | 329 | with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '.ome.tif') as writer: 330 | writer.save(struct_img) 331 | 332 | seg1 = seg1.astype(np.uint8) 333 | seg1[seg1>0]=1 334 | with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '_GT.ome.tif') as writer: 335 | writer.save(seg1) 336 | 337 | with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '_CM.ome.tif') as writer: 338 | writer.save(cmap) 339 | print('training data is ready') 340 | 341 | 342 | def main(): 343 | dbg = False 344 | try: 345 | args = Args() 346 | dbg = args.debug 347 | 348 | # Do your work here - preferably in a class or function, 349 | # passing in your args. E.g. 
350 | exe = Executor(args) 351 | exe.execute(args) 352 | 353 | except Exception as e: 354 | log.error("=============================================") 355 | if dbg: 356 | log.error("\n\n" + traceback.format_exc()) 357 | log.error("=============================================") 358 | log.error("\n\n" + str(e) + "\n") 359 | log.error("=============================================") 360 | sys.exit(1) 361 | 362 | 363 | if __name__ == "__main__": 364 | main() 365 | 366 | -------------------------------------------------------------------------------- /aicsmlsegment/bin/curator/curator_takeall.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | import logging 6 | import argparse 7 | import traceback 8 | import importlib 9 | import pathlib 10 | import csv 11 | import pandas as pd 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | from glob import glob 15 | from random import shuffle 16 | from scipy import stats 17 | from skimage.io import imsave 18 | from skimage.draw import line, polygon 19 | from scipy import ndimage as ndi 20 | 21 | from aicssegmentation.core.utils import histogram_otsu 22 | from aicsimageio import AICSImage, imread 23 | from aicsimageio.writers import OmeTiffWriter 24 | 25 | from aicsmlsegment.utils import input_normalization 26 | 27 | #################################################################################################### 28 | # global settings 29 | button = 0 30 | flag_done = False 31 | pts = [] 32 | draw_img = None 33 | draw_mask = None 34 | draw_ax = None 35 | 36 | 37 | log = logging.getLogger() 38 | logging.basicConfig(level=logging.INFO, 39 | format='[%(asctime)s - %(name)s - %(lineno)3d][%(levelname)s] %(message)s') 40 | # 41 | # Set the default log level for other modules used by this script 42 | # logging.getLogger("labkey").setLevel(logging.ERROR) 43 | # logging.getLogger("requests").setLevel(logging.WARNING) 44 | # 
logging.getLogger("urllib3").setLevel(logging.WARNING) 45 | logging.getLogger("matplotlib").setLevel(logging.INFO) 46 | #################################################################################################### 47 | 48 | class Args(object): 49 | """ 50 | Use this to define command line arguments and use them later. 51 | 52 | For each argument do the following 53 | 1. Create a member in __init__ before the self.__parse call. 54 | 2. Provide a default value here. 55 | 3. Then in p.add_argument, set the dest parameter to that variable name. 56 | 57 | See the debug parameter as an example. 58 | """ 59 | 60 | def __init__(self, log_cmdline=True): 61 | self.debug = False 62 | self.output_dir = '.' + os.sep 63 | self.struct_ch = 0 64 | self.xy = 0.108 65 | 66 | # 67 | self.__parse() 68 | # 69 | if self.debug: 70 | log.setLevel(logging.DEBUG) 71 | log.debug("-" * 80) 72 | self.show_info() 73 | log.debug("-" * 80) 74 | 75 | @staticmethod 76 | def __no_args_print_help(parser): 77 | """ 78 | This is used to print out the help if no arguments are provided. 79 | Note: 80 | - You need to remove it's usage if your script truly doesn't want arguments. 81 | - It exits with 1 because it's an error if this is used in a script with no args. 82 | That's a non-interactive use scenario - typically you don't want help there. 
83 | """ 84 | if len(sys.argv) == 1: 85 | parser.print_help() 86 | sys.exit(1) 87 | 88 | def __parse(self): 89 | p = argparse.ArgumentParser() 90 | # Add arguments 91 | p.add_argument('--d', '--debug', action='store_true', dest='debug', 92 | help='If set debug log output is enabled') 93 | p.add_argument('--raw_path', required=True, help='path to raw images') 94 | p.add_argument('--data_type', required=True, help='the type of raw images') 95 | p.add_argument('--input_channel', default=0, type=int) 96 | p.add_argument('--seg_path', required=True, help='path to segmentation results') 97 | p.add_argument('--train_path', required=True, help='path to output training data') 98 | p.add_argument('--mask_path', help='[optional] the output directory for masks') 99 | p.add_argument('--Normalization', default=0, help='the normalization method to use') 100 | 101 | self.__no_args_print_help(p) 102 | p.parse_args(namespace=self) 103 | 104 | def show_info(self): 105 | log.debug("Working Dir:") 106 | log.debug("\t{}".format(os.getcwd())) 107 | log.debug("Command Line:") 108 | log.debug("\t{}".format(" ".join(sys.argv))) 109 | log.debug("Args:") 110 | for (k, v) in self.__dict__.items(): 111 | log.debug("\t{}: {}".format(k, v)) 112 | 113 | 114 | ############################################################################### 115 | 116 | class Executor(object): 117 | 118 | def __init__(self, args): 119 | pass 120 | 121 | def execute(self, args): 122 | 123 | if not args.data_type.startswith('.'): 124 | args.data_type = '.' 
+ args.data_type 125 | 126 | filenames = glob(args.raw_path + os.sep +'*' + args.data_type) 127 | filenames.sort() 128 | 129 | existing_files = glob(args.train_path+os.sep+'img_*.ome.tif') 130 | print(len(existing_files)) 131 | 132 | training_data_count = len(existing_files)//3 133 | for _, fn in enumerate(filenames): 134 | 135 | training_data_count += 1 136 | 137 | # load raw 138 | reader = AICSImage(fn) 139 | struct_img = reader.get_image_data("CZYX", S=0, T=0, C=[args.input_channel]).astype(np.float32) 140 | struct_img = input_normalization(img, args) 141 | 142 | # load seg 143 | seg_fn = args.seg_path + os.sep + os.path.basename(fn)[:-1*len(args.data_type)] + '_struct_segmentation.tiff' 144 | seg = np.squeeze(imread(seg_fn)) > 0.01 145 | seg = seg.astype(np.uint8) 146 | seg[seg>0]=1 147 | 148 | # excluding mask 149 | cmap = np.ones(seg.shape, dtype=np.float32) 150 | mask_fn = args.mask_path + os.sep + os.path.basename(fn)[:-1*len(args.data_type)] + '_mask.tiff' 151 | if os.path.isfile(mask_fn): 152 | mask = np.squeeze(imread(mask_fn)) 153 | cmap[mask==0]=0 154 | 155 | with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '.ome.tif') as writer: 156 | writer.save(struct_img) 157 | 158 | with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '_GT.ome.tif') as writer: 159 | writer.save(seg) 160 | 161 | with OmeTiffWriter(args.train_path + os.sep + 'img_' + f'{training_data_count:03}' + '_CM.ome.tif') as writer: 162 | writer.save(cmap) 163 | 164 | 165 | def main(): 166 | dbg = False 167 | try: 168 | args = Args() 169 | dbg = args.debug 170 | 171 | # Do your work here - preferably in a class or function, 172 | # passing in your args. E.g. 
173 | exe = Executor(args) 174 | exe.execute(args) 175 | 176 | except Exception as e: 177 | log.error("=============================================") 178 | if dbg: 179 | log.error("\n\n" + traceback.format_exc()) 180 | log.error("=============================================") 181 | log.error("\n\n" + str(e) + "\n") 182 | log.error("=============================================") 183 | sys.exit(1) 184 | 185 | 186 | if __name__ == "__main__": 187 | main() 188 | 189 | -------------------------------------------------------------------------------- /aicsmlsegment/bin/predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import argparse 5 | import logging 6 | import traceback 7 | import os 8 | import pathlib 9 | import numpy as np 10 | 11 | from skimage.morphology import remove_small_objects 12 | from skimage.io import imsave 13 | from aicsimageio import AICSImage 14 | from scipy.ndimage import zoom 15 | 16 | from aicsmlsegment.utils import load_config, load_single_image, input_normalization, image_normalization 17 | from aicsmlsegment.utils import get_logger 18 | from aicsmlsegment.model_utils import build_model, load_checkpoint, model_inference, apply_on_image 19 | 20 | def main(): 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--config', required=True) 24 | args = parser.parse_args() 25 | 26 | config = load_config(args.config) 27 | 28 | # declare the model 29 | model = build_model(config) 30 | 31 | # load the trained model instance 32 | model_path = config['model_path'] 33 | print(f'Loading model from {model_path}...') 34 | load_checkpoint(model_path, model) 35 | 36 | # extract the parameters for running the model inference 37 | args_inference=lambda:None 38 | args_inference.size_in = config['size_in'] 39 | args_inference.size_out = config['size_out'] 40 | args_inference.OutputCh = config['OutputCh'] 41 | args_inference.nclass = config['nclass'] 42 | if 
config['RuntimeAug'] <=0: 43 | args_inference.RuntimeAug = False 44 | else: 45 | args_inference.RuntimeAug = True 46 | 47 | # run 48 | inf_config = config['mode'] 49 | if inf_config['name'] == 'file': 50 | fn = inf_config['InputFile'] 51 | data_reader = AICSImage(fn) 52 | 53 | if inf_config['timelapse']: 54 | assert data_reader.shape[1] > 1, "not a timelapse, check you data" 55 | 56 | for tt in range(data_reader.shape[1]): 57 | # Assume: dimensions = TCZYX 58 | img = data_reader.get_image_data("CZYX", S=0, T=tt, C=config['InputCh']).astype(float) 59 | img = image_normalization(img, config['Normalization']) 60 | 61 | if len(config['ResizeRatio'])>0: 62 | img = zoom(img, (1, config['ResizeRatio'][0], config['ResizeRatio'][1], config['ResizeRatio'][2]), order=2, mode='reflect') 63 | for ch_idx in range(img.shape[0]): 64 | struct_img = img[ch_idx,:,:,:] 65 | struct_img = (struct_img - struct_img.min())/(struct_img.max() - struct_img.min()) 66 | img[ch_idx,:,:,:] = struct_img 67 | 68 | # apply the model 69 | output_img = apply_on_image(model, img, model.final_activation, args_inference) 70 | 71 | # extract the result and write the output 72 | if len(config['OutputCh']) == 2: 73 | out = output_img[0] 74 | out = (out - out.min()) / (out.max()-out.min()) 75 | if len(config['ResizeRatio'])>0: 76 | out = zoom(out, (1.0, 1/config['ResizeRatio'][0], 1/config['ResizeRatio'][1], 1/config['ResizeRatio'][2]), order=2, mode='reflect') 77 | out = out.astype(np.float32) 78 | if config['Threshold']>0: 79 | out = out > config['Threshold'] 80 | out = out.astype(np.uint8) 81 | out[out>0]=255 82 | imsave(config['OutputDir'] + os.sep + pathlib.PurePosixPath(fn).stem + '_T_'+ f'{tt:03}' +'_struct_segmentation.tiff', out) 83 | else: 84 | for ch_idx in range(len(config['OutputCh'])//2): 85 | out = output_img[ch_idx] 86 | out = (out - out.min()) / (out.max()-out.min()) 87 | if len(config['ResizeRatio'])>0: 88 | out = zoom(out, (1.0, 1/config['ResizeRatio'][0], 1/config['ResizeRatio'][1], 
1/config['ResizeRatio'][2]), order=2, mode='reflect') 89 | out = out.astype(np.float32) 90 | if config['Threshold']>0: 91 | out = out > config['Threshold'] 92 | out = out.astype(np.uint8) 93 | out[out>0]=255 94 | imsave(config['OutputDir'] + os.sep + pathlib.PurePosixPath(fn).stem + '_T_'+ f'{tt:03}' +'_seg_'+ str(config['OutputCh'][2*ch_idx])+'.tiff',out) 95 | else: 96 | img = data_reader.get_image_data("CZYX", S=0, T=0, C=config['InputCh']).astype(float) 97 | img = image_normalization(img, config['Normalization']) 98 | 99 | if len(config['ResizeRatio'])>0: 100 | img = zoom(img, (1, config['ResizeRatio'][0], config['ResizeRatio'][1], config['ResizeRatio'][2]), order=2, mode='reflect') 101 | for ch_idx in range(img.shape[0]): 102 | struct_img = img[ch_idx,:,:,:] # note that struct_img is only a view of img, so changes made on struct_img also affects img 103 | struct_img = (struct_img - struct_img.min())/(struct_img.max() - struct_img.min()) 104 | img[ch_idx,:,:,:] = struct_img 105 | 106 | # apply the model 107 | output_img = apply_on_image(model, img, model.final_activation, args_inference) 108 | 109 | # extract the result and write the output 110 | if len(config['OutputCh']) == 2: 111 | out = output_img[0] 112 | out = (out - out.min()) / (out.max()-out.min()) 113 | if len(config['ResizeRatio'])>0: 114 | out = zoom(out, (1.0, 1/config['ResizeRatio'][0], 1/config['ResizeRatio'][1], 1/config['ResizeRatio'][2]), order=2, mode='reflect') 115 | out = out.astype(np.float32) 116 | if config['Threshold']>0: 117 | out = out > config['Threshold'] 118 | out = out.astype(np.uint8) 119 | out[out>0]=255 120 | imsave(config['OutputDir'] + os.sep + pathlib.PurePosixPath(fn).stem +'_struct_segmentation.tiff', out) 121 | else: 122 | for ch_idx in range(len(config['OutputCh'])//2): 123 | out = output_img[ch_idx] 124 | out = (out - out.min()) / (out.max()-out.min()) 125 | if len(config['ResizeRatio'])>0: 126 | out = zoom(out, (1.0, 1/config['ResizeRatio'][0], 
1/config['ResizeRatio'][1], 1/config['ResizeRatio'][2]), order=2, mode='reflect') 127 | out = out.astype(np.float32) 128 | if config['Threshold']>0: 129 | out = out > config['Threshold'] 130 | out = out.astype(np.uint8) 131 | out[out>0]=255 132 | imsave(config['OutputDir'] + os.sep + pathlib.PurePosixPath(fn).stem +'_seg_'+ str(config['OutputCh'][2*ch_idx])+'.tiff', out) 133 | print(f'Image {fn} has been segmented') 134 | 135 | elif inf_config['name'] == 'folder': 136 | from glob import glob 137 | filenames = glob(inf_config['InputDir'] + '/*' + inf_config['DataType']) 138 | filenames.sort() #(reverse=True) 139 | print('files to be processed:') 140 | print(filenames) 141 | 142 | for _, fn in enumerate(filenames): 143 | 144 | # load data 145 | data_reader = AICSImage(fn) 146 | img = data_reader.get_image_data('CZYX', S=0, T=0, C=config['InputCh']).astype(float) 147 | if len(config['ResizeRatio'])>0: 148 | img = zoom(img, (1,config['ResizeRatio'][0], config['ResizeRatio'][1], config['ResizeRatio'][2]), order=2, mode='reflect') 149 | img = image_normalization(img, config['Normalization']) 150 | 151 | # apply the model 152 | output_img = apply_on_image(model, img, model.final_activation, args_inference) 153 | 154 | # extract the result and write the output 155 | if len(config['OutputCh'])==2: 156 | if config['Threshold']<0: 157 | out = output_img[0] 158 | out = (out - out.min()) / (out.max()-out.min()) 159 | if len(config['ResizeRatio'])>0: 160 | out = zoom(out, (1.0, 1/config['ResizeRatio'][0], 1/config['ResizeRatio'][1], 1/config['ResizeRatio'][2]), order=2, mode='reflect') 161 | out = out.astype(np.float32) 162 | out = (out - out.min()) / (out.max()-out.min()) 163 | else: 164 | out = remove_small_objects(output_img[0] > config['Threshold'], min_size=2, connectivity=1) 165 | out = out.astype(np.uint8) 166 | out[out>0]=255 167 | imsave(config['OutputDir'] + os.sep + pathlib.PurePosixPath(fn).stem + '_struct_segmentation.tiff', out) 168 | else: 169 | for ch_idx in 
range(len(config['OutputCh'])//2): 170 | if config['Threshold']<0: 171 | out = output_img[ch_idx] 172 | out = (out - out.min()) / (out.max()-out.min()) 173 | out = out.astype(np.float32) 174 | else: 175 | out = output_img[ch_idx] > config['Threshold'] 176 | out = out.astype(np.uint8) 177 | out[out>0]=255 178 | imsave(config['OutputDir'] + os.sep + pathlib.PurePosixPath(fn).stem + '_seg_'+ str(config['OutputCh'][2*ch_idx])+'.ome.tif', out) 179 | 180 | print(f'Image {fn} has been segmented') 181 | 182 | if __name__ == '__main__': 183 | 184 | main() -------------------------------------------------------------------------------- /aicsmlsegment/bin/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import sys 4 | import argparse 5 | import logging 6 | import traceback 7 | 8 | from aicsmlsegment.utils import load_config 9 | 10 | from aicsmlsegment.training_utils import BasicFolderTrainer, get_loss_criterion, build_optimizer, get_train_dataloader 11 | from aicsmlsegment.utils import get_logger 12 | from aicsmlsegment.model_utils import get_number_of_learnable_parameters, build_model, load_checkpoint 13 | 14 | 15 | def main(): 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--config', required=True) 19 | args = parser.parse_args() 20 | 21 | # create logger 22 | logger = get_logger('ModelTrainer') 23 | config = load_config(args.config) 24 | logger.info(config) 25 | 26 | # Create model 27 | model = build_model(config) 28 | 29 | # Log the number of learnable parameters 30 | logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}') 31 | 32 | # check if resuming 33 | if config['resume'] is not None: 34 | print(f"Loading checkpoint '{config['resume']}'...") 35 | load_checkpoint(config['resume'], model) 36 | else: 37 | print('start a new training') 38 | 39 | # run the training 40 | trainer = BasicFolderTrainer(model, config, logger=logger) 41 | 
trainer.train() 42 | 43 | if __name__ == '__main__': 44 | main() -------------------------------------------------------------------------------- /aicsmlsegment/custom_loss.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.autograd import Variable, Function 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import torch 6 | import numpy as np 7 | 8 | class ElementNLLLoss(torch.nn.Module): 9 | def __init__(self, num_class): 10 | super(ElementNLLLoss,self).__init__() 11 | self.num_class = num_class 12 | 13 | def forward(self, input, target, weight): 14 | 15 | target_np = target.cpu().data.numpy() 16 | target_np = target_np.astype(np.uint8) 17 | 18 | row_num = target_np.shape[0] 19 | mask = np.zeros((row_num,self.num_class )) 20 | mask[np.arange(row_num), target_np]=1 21 | class_x = torch.masked_select(input, Variable(torch.from_numpy(mask).cuda().bool())) 22 | 23 | out = torch.mul(class_x,weight) 24 | loss = torch.mean(torch.neg(out),0) 25 | 26 | return loss 27 | 28 | class MultiAuxillaryElementNLLLoss(torch.nn.Module): 29 | def __init__(self,num_task, weight, num_class): 30 | super(MultiAuxillaryElementNLLLoss,self).__init__() 31 | self.num_task = num_task 32 | self.weight = weight 33 | 34 | self.criteria_list = [] 35 | for nn in range(self.num_task): 36 | self.criteria_list.append(ElementNLLLoss(num_class[nn])) 37 | 38 | def forward(self, input, target, cmap): 39 | 40 | total_loss = self.weight[0]*self.criteria_list[0](input[0], target.view(target.numel()), cmap.view(cmap.numel()) ) 41 | 42 | for nn in np.arange(1,self.num_task): 43 | total_loss = total_loss + self.weight[nn]*self.criteria_list[nn](input[nn], target.view(target.numel()), cmap.view(cmap.numel()) ) 44 | 45 | return total_loss 46 | 47 | class MultiTaskElementNLLLoss(torch.nn.Module): 48 | def __init__(self, weight, num_class): 49 | super(MultiTaskElementNLLLoss,self).__init__() 50 | self.num_task = len(num_class) 51 | self.weight = 
weight 52 | 53 | self.criteria_list = [] 54 | for nn in range(self.num_task): 55 | self.criteria_list.append(ElementNLLLoss(num_class[nn])) 56 | 57 | def forward(self, input, target, cmap): 58 | 59 | assert len(target) == self.num_task and len(input) == self.num_task 60 | 61 | total_loss = self.weight[0]*self.criteria_list[0](input[0], target[0].view(target[0].numel()), cmap.view(cmap.numel()) ) 62 | 63 | for nn in np.arange(1,self.num_task): 64 | total_loss = total_loss + self.weight[nn]*self.criteria_list[nn](input[nn], target[nn].view(target[nn].numel()), cmap.view(cmap.numel()) ) 65 | 66 | return total_loss 67 | 68 | class ElementAngularMSELoss(torch.nn.Module): 69 | def __init__(self): 70 | super(ElementAngularMSELoss,self).__init__() 71 | 72 | def forward(self, input, target, weight): 73 | 74 | #((input - target) ** 2).sum() / input.data.nelement() 75 | 76 | return torch.sum( torch.mul( torch.acos(torch.sum(torch.mul(input,target),dim=1))**2, weight) )/ torch.gt(weight,0).data.nelement() 77 | 78 | def compute_per_channel_dice(input, target, epsilon=1e-5, ignore_index=None, weight=None): 79 | # assumes that input is a normalized probability 80 | 81 | # input and target shapes must match 82 | assert input.size() == target.size(), "'input' and 'target' must have the same shape" 83 | 84 | # mask ignore_index if present 85 | if ignore_index is not None: 86 | mask = target.clone().ne_(ignore_index) 87 | mask.requires_grad = False 88 | 89 | input = input * mask 90 | target = target * mask 91 | 92 | input = flatten(input) 93 | target = flatten(target) 94 | 95 | target = target.float() 96 | # Compute per channel Dice Coefficient 97 | intersect = (input * target).sum(-1) 98 | if weight is not None: 99 | intersect = weight * intersect 100 | 101 | denominator = (input + target).sum(-1) 102 | return 2. * intersect / denominator.clamp(min=epsilon) 103 | 104 | 105 | class DiceLoss(nn.Module): 106 | """Computes Dice Loss, which just 1 - DiceCoefficient described above. 
107 | Additionally allows per-class weights to be provided. 108 | """ 109 | 110 | def __init__(self, epsilon=1e-5, weight=None, ignore_index=None, sigmoid_normalization=True, 111 | skip_last_target=False): 112 | super(DiceLoss, self).__init__() 113 | self.epsilon = epsilon 114 | self.register_buffer('weight', weight) 115 | self.ignore_index = ignore_index 116 | # The output from the network during training is assumed to be un-normalized probabilities and we would 117 | # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data, 118 | # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems. 119 | # However if one would like to apply Softmax in order to get the proper probability distribution from the 120 | # output, just specify sigmoid_normalization=False. 121 | if sigmoid_normalization: 122 | self.normalization = nn.Sigmoid() 123 | else: 124 | self.normalization = nn.Softmax(dim=1) 125 | # if True skip the last channel in the target 126 | self.skip_last_target = skip_last_target 127 | 128 | def forward(self, input, target): 129 | # get probabilities from logits 130 | input = self.normalization(input) 131 | if self.weight is not None: 132 | weight = Variable(self.weight, requires_grad=False) 133 | else: 134 | weight = None 135 | 136 | if self.skip_last_target: 137 | target = target[:, :-1, ...] 138 | 139 | per_channel_dice = compute_per_channel_dice(input, target, epsilon=self.epsilon, ignore_index=self.ignore_index, 140 | weight=weight) 141 | # Average the Dice score across all channels/classes 142 | return torch.mean(1. 
- per_channel_dice) 143 | 144 | 145 | class GeneralizedDiceLoss(nn.Module): 146 | """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf 147 | """ 148 | 149 | def __init__(self, epsilon=1e-5, weight=None, ignore_index=None, sigmoid_normalization=True): 150 | super(GeneralizedDiceLoss, self).__init__() 151 | self.epsilon = epsilon 152 | self.register_buffer('weight', weight) 153 | self.ignore_index = ignore_index 154 | if sigmoid_normalization: 155 | self.normalization = nn.Sigmoid() 156 | else: 157 | self.normalization = nn.Softmax(dim=1) 158 | 159 | def forward(self, input, target): 160 | # get probabilities from logits 161 | input = self.normalization(input) 162 | 163 | assert input.size() == target.size(), "'input' and 'target' must have the same shape" 164 | 165 | # mask ignore_index if present 166 | if self.ignore_index is not None: 167 | mask = target.clone().ne_(self.ignore_index) 168 | mask.requires_grad = False 169 | 170 | input = input * mask 171 | target = target * mask 172 | 173 | input = flatten(input) 174 | target = flatten(target) 175 | 176 | target = target.float() 177 | target_sum = target.sum(-1) 178 | class_weights = Variable(1. / (target_sum * target_sum).clamp(min=self.epsilon), requires_grad=False) 179 | 180 | intersect = (input * target).sum(-1) * class_weights 181 | if self.weight is not None: 182 | weight = Variable(self.weight, requires_grad=False) 183 | intersect = weight * intersect 184 | 185 | denominator = (input + target).sum(-1) * class_weights 186 | 187 | return torch.mean(1. - 2. 
* intersect / denominator.clamp(min=self.epsilon)) 188 | 189 | 190 | class WeightedCrossEntropyLoss(nn.Module): 191 | """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf 192 | """ 193 | 194 | def __init__(self, weight=None, ignore_index=-1): 195 | super(WeightedCrossEntropyLoss, self).__init__() 196 | self.register_buffer('weight', weight) 197 | self.ignore_index = ignore_index 198 | 199 | def forward(self, input, target): 200 | class_weights = self._class_weights(input) 201 | if self.weight is not None: 202 | weight = Variable(self.weight, requires_grad=False) 203 | class_weights = class_weights * weight 204 | return F.cross_entropy(input, target, weight=class_weights, ignore_index=self.ignore_index) 205 | 206 | @staticmethod 207 | def _class_weights(input): 208 | # normalize the input first 209 | input = F.softmax(input, _stacklevel=5) 210 | flattened = flatten(input) 211 | nominator = (1. - flattened).sum(-1) 212 | denominator = flattened.sum(-1) 213 | class_weights = Variable(nominator / denominator, requires_grad=False) 214 | return class_weights 215 | 216 | 217 | class BCELossWrapper: 218 | """ 219 | Wrapper around BCE loss functions allowing to pass 'ignore_index' as well as 'skip_last_target' option. 220 | """ 221 | 222 | def __init__(self, loss_criterion, ignore_index=-1, skip_last_target=False): 223 | if hasattr(loss_criterion, 'ignore_index'): 224 | raise RuntimeError(f"Cannot wrap {type(loss_criterion)}. Use 'ignore_index' attribute instead") 225 | self.loss_criterion = loss_criterion 226 | self.ignore_index = ignore_index 227 | self.skip_last_target = skip_last_target 228 | 229 | def __call__(self, input, target): 230 | if self.skip_last_target: 231 | target = target[:, :-1, ...] 
232 | 233 | assert input.size() == target.size() 234 | 235 | masked_input = input 236 | masked_target = target 237 | if self.ignore_index is not None: 238 | mask = target.clone().ne_(self.ignore_index) 239 | mask.requires_grad = False 240 | 241 | masked_input = input * mask 242 | masked_target = target * mask 243 | 244 | return self.loss_criterion(masked_input, masked_target) 245 | 246 | 247 | class PixelWiseCrossEntropyLoss(nn.Module): 248 | def __init__(self, class_weights=None, ignore_index=None): 249 | super(PixelWiseCrossEntropyLoss, self).__init__() 250 | self.register_buffer('class_weights', class_weights) 251 | self.ignore_index = ignore_index 252 | self.log_softmax = nn.LogSoftmax(dim=1) 253 | 254 | def forward(self, input, target, weights): 255 | assert target.size() == weights.size() 256 | # normalize the input 257 | log_probabilities = self.log_softmax(input) 258 | # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW) 259 | target = expand_as_one_hot(target, C=input.size()[1], ignore_index=self.ignore_index) 260 | # expand weights 261 | weights = weights.unsqueeze(0) 262 | weights = weights.expand_as(input) 263 | 264 | # mask ignore_index if present 265 | if self.ignore_index is not None: 266 | mask = Variable(target.data.ne(self.ignore_index).float(), requires_grad=False) 267 | log_probabilities = log_probabilities * mask 268 | target = target * mask 269 | 270 | # apply class weights 271 | if self.class_weights is None: 272 | class_weights = torch.ones(input.size()[1]).float().to(input.device) 273 | else: 274 | class_weights = self.class_weights 275 | class_weights = class_weights.view(1, input.size()[1], 1, 1, 1) 276 | class_weights = Variable(class_weights, requires_grad=False) 277 | # add class_weights to each channel 278 | weights = class_weights + weights 279 | 280 | # compute the losses 281 | result = -weights * target * log_probabilities 282 | # average the losses 283 | return result.mean() 284 | 
285 | 286 | def flatten(tensor): 287 | """Flattens a given tensor such that the channel axis is first. 288 | The shapes are transformed as follows: 289 | (N, C, D, H, W) -> (C, N * D * H * W) 290 | """ 291 | C = tensor.size(1) 292 | # new axis order 293 | axis_order = (1, 0) + tuple(range(2, tensor.dim())) 294 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) 295 | transposed = tensor.permute(axis_order) 296 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) 297 | return transposed.view(C, -1) 298 | 299 | 300 | def expand_as_one_hot(input, C, ignore_index=None): 301 | """ 302 | Converts NxDxHxW label image to NxCxDxHxW, where each label is stored in a separate channel 303 | :param input: 4D input image (NxDxHxW) 304 | :param C: number of channels/labels 305 | :param ignore_index: ignore index to be kept during the expansion 306 | :return: 5D output image (NxCxDxHxW) 307 | """ 308 | assert input.dim() == 4 309 | 310 | shape = input.size() 311 | shape = list(shape) 312 | shape.insert(1, C) 313 | shape = tuple(shape) 314 | 315 | # expand the input tensor to Nx1xDxHxW 316 | src = input.unsqueeze(0) 317 | 318 | if ignore_index is not None: 319 | # create ignore_index mask for the result 320 | expanded_src = src.expand(shape) 321 | mask = expanded_src == ignore_index 322 | # clone the src tensor and zero out ignore_index in the input 323 | src = src.clone() 324 | src[src == ignore_index] = 0 325 | # scatter to get the one-hot tensor 326 | result = torch.zeros(shape).to(input.device).scatter_(1, src, 1) 327 | # bring back the ignore_index in the result 328 | result[mask] = ignore_index 329 | return result 330 | else: 331 | # scatter to get the one-hot tensor 332 | return torch.zeros(shape).to(input.device).scatter_(1, src, 1) -------------------------------------------------------------------------------- /aicsmlsegment/custom_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from 
skimage import measure 4 | from aicsmlsegment.custom_loss import MultiAuxillaryElementNLLLoss, compute_per_channel_dice, expand_as_one_hot 5 | 6 | class DiceCoefficient: 7 | """Computes Dice Coefficient. 8 | Generalized to multiple channels by computing per-channel Dice Score 9 | (as described in https://arxiv.org/pdf/1707.03237.pdf) and theTn simply taking the average. 10 | Input is expected to be probabilities instead of logits. 11 | This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets). 12 | DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss. 13 | """ 14 | 15 | def __init__(self, epsilon=1e-5, ignore_index=None): 16 | self.epsilon = epsilon 17 | self.ignore_index = ignore_index 18 | 19 | def __call__(self, input, target): 20 | """ 21 | :param input: 5D probability maps torch tensor (NxCxDxHxW) 22 | :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot 23 | :return: Soft Dice Coefficient averaged over all channels/classes 24 | """ 25 | # Average across channels in order to get the final score 26 | return torch.mean(compute_per_channel_dice(input, target, epsilon=self.epsilon, ignore_index=self.ignore_index)) 27 | 28 | 29 | class MeanIoU: 30 | """ 31 | Computes IoU for each class separately and then averages over all classes. 32 | """ 33 | 34 | def __init__(self, skip_channels=(), ignore_index=None): 35 | """ 36 | :param skip_channels: list/tuple of channels to be ignored from the IoU computation 37 | :param ignore_index: id of the label to be ignored from IoU computation 38 | """ 39 | self.ignore_index = ignore_index 40 | self.skip_channels = skip_channels 41 | 42 | def __call__(self, input, target): 43 | """ 44 | :param input: 5D probability maps torch float tensor (NxCxDxHxW) 45 | :param target: 4D or 5D ground truth torch tensor. 
4D (NxDxHxW) tensor will be expanded to 5D as one-hot 46 | :return: intersection over union averaged over all channels 47 | """ 48 | n_classes = input.size()[1] 49 | if target.dim() == 4: 50 | target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index) 51 | 52 | # batch dim must be 1 53 | input = input[0] 54 | target = target[0] 55 | assert input.size() == target.size() 56 | 57 | binary_prediction = self._binarize_predictions(input) 58 | 59 | if self.ignore_index is not None: 60 | # zero out ignore_index 61 | mask = target == self.ignore_index 62 | binary_prediction[mask] = 0 63 | target[mask] = 0 64 | 65 | # convert to uint8 just in case 66 | binary_prediction = binary_prediction.bool() 67 | target = target.bool() 68 | 69 | per_channel_iou = [] 70 | for c in range(n_classes): 71 | if c in self.skip_channels: 72 | continue 73 | 74 | per_channel_iou.append(self._jaccard_index(binary_prediction[c], target[c])) 75 | 76 | assert per_channel_iou, "All channels were ignored from the computation" 77 | return torch.mean(torch.tensor(per_channel_iou)) 78 | 79 | def _binarize_predictions(self, input): 80 | """ 81 | Puts 1 for the class/channel with the highest probability and 0 in other channels. Returns byte tensor of the 82 | same size as the input tensor. 83 | """ 84 | _, max_index = torch.max(input, dim=0, keepdim=True) 85 | return torch.zeros_like(input, dtype=torch.uint8).scatter_(0, max_index, 1) 86 | 87 | def _jaccard_index(self, prediction, target): 88 | """ 89 | Computes IoU for a given target and prediction tensors 90 | """ 91 | return torch.sum(prediction & target).float() / torch.sum(prediction | target).float() 92 | 93 | 94 | class AveragePrecision: 95 | """ 96 | Computes Average Precision given boundary prediction and ground truth instance segmentation. 
97 | """ 98 | 99 | def __init__(self, threshold=0.4, iou_range=(0.5, 1.0), ignore_index=-1, min_instance_size=None, 100 | use_last_target=False): 101 | """ 102 | :param threshold: probability value at which the input is going to be thresholded 103 | :param iou_range: compute ROC curve for the the range of IoU values: range(min,max,0.05) 104 | :param ignore_index: label to be ignored during computation 105 | :param min_instance_size: minimum size of the predicted instances to be considered 106 | :param use_last_target: if True use the last target channel to compute AP 107 | """ 108 | self.threshold = threshold 109 | # always have well defined ignore_index 110 | if ignore_index is None: 111 | ignore_index = -1 112 | self.iou_range = iou_range 113 | self.ignore_index = ignore_index 114 | self.min_instance_size = min_instance_size 115 | self.use_last_target = use_last_target 116 | 117 | def __call__(self, input, target): 118 | """ 119 | :param input: 5D probability maps torch float tensor (NxCxDxHxW) / or 4D numpy.ndarray 120 | :param target: 4D or 5D ground truth instance segmentation torch long tensor / or 3D numpy.ndarray 121 | :return: highest average precision among channels 122 | """ 123 | if isinstance(input, torch.Tensor): 124 | assert input.dim() == 5 125 | # convert to numpy array 126 | input = input[0].detach().cpu().numpy() # 4D 127 | if isinstance(target, torch.Tensor): 128 | if not self.use_last_target: 129 | assert target.dim() == 4 130 | # convert to numpy array 131 | target = target[0].detach().cpu().numpy() # 3D 132 | else: 133 | # if use_last_target == True the target must be 5D (NxCxDxHxW) 134 | assert target.dim() == 5 135 | target = target[0, -1].detach().cpu().numpy() # 3D 136 | 137 | if isinstance(input, np.ndarray): 138 | assert input.ndim == 4 139 | if isinstance(target, np.ndarray): 140 | assert target.ndim == 3 141 | 142 | # filter small instances from the target and get ground truth label set (without 'ignore_index') 143 | target, 
target_instances = self._filter_instances(target) 144 | 145 | per_channel_ap = [] 146 | n_channels = input.shape[0] 147 | for c in range(n_channels): 148 | predictions = input[c] 149 | # threshold probability maps 150 | predictions = predictions > self.threshold 151 | # for connected component analysis we need to treat boundary signal as background 152 | # assign 0-label to boundary mask 153 | predictions = np.logical_not(predictions).astype(np.uint8) 154 | # run connected components on the predicted mask; consider only 1-connectivity 155 | predicted = measure.label(predictions, background=0, connectivity=1) 156 | ap = self._calculate_average_precision(predicted, target, target_instances) 157 | per_channel_ap.append(ap) 158 | 159 | # get maximum average precision across channels 160 | max_ap, c_index = np.max(per_channel_ap), np.argmax(per_channel_ap) 161 | #LOGGER.info(f'Max average precision: {max_ap}, channel: {c_index}') 162 | return max_ap 163 | 164 | def _calculate_average_precision(self, predicted, target, target_instances): 165 | recall, precision = self._roc_curve(predicted, target, target_instances) 166 | recall.insert(0, 0.0) # insert 0.0 at beginning of list 167 | recall.append(1.0) # insert 1.0 at end of list 168 | precision.insert(0, 0.0) # insert 0.0 at beginning of list 169 | precision.append(0.0) # insert 0.0 at end of list 170 | # make the precision(recall) piece-wise constant and monotonically decreasing 171 | # by iterating backwards starting from the last precision value (0.0) 172 | # see: https://www.jeremyjordan.me/evaluating-image-segmentation-models/ e.g. 
173 | for i in range(len(precision) - 2, -1, -1): 174 | precision[i] = max(precision[i], precision[i + 1]) 175 | # compute the area under precision recall curve by simple integration of piece-wise constant function 176 | ap = 0.0 177 | for i in range(1, len(recall)): 178 | ap += ((recall[i] - recall[i - 1]) * precision[i]) 179 | return ap 180 | 181 | def _roc_curve(self, predicted, target, target_instances): 182 | ROC = [] 183 | predicted, predicted_instances = self._filter_instances(predicted) 184 | 185 | # compute precision/recall curve points for various IoU values from a given range 186 | for min_iou in np.arange(self.iou_range[0], self.iou_range[1], 0.1): 187 | # initialize false negatives set 188 | false_negatives = set(target_instances) 189 | # initialize false positives set 190 | false_positives = set(predicted_instances) 191 | # initialize true positives set 192 | true_positives = set() 193 | 194 | for pred_label in predicted_instances: 195 | target_label = self._find_overlapping_target(pred_label, predicted, target, min_iou) 196 | if target_label is not None: 197 | # update TP, FP and FN 198 | if target_label == self.ignore_index: 199 | # ignore if 'ignore_index' is the biggest overlapping 200 | false_positives.discard(pred_label) 201 | else: 202 | true_positives.add(pred_label) 203 | false_positives.discard(pred_label) 204 | false_negatives.discard(target_label) 205 | 206 | tp = len(true_positives) 207 | fp = len(false_positives) 208 | fn = len(false_negatives) 209 | 210 | recall = tp / (tp + fn) 211 | precision = tp / (tp + fp) 212 | ROC.append((recall, precision)) 213 | 214 | # sort points by recall 215 | ROC = np.array(sorted(ROC, key=lambda t: t[0])) 216 | # return recall and precision values 217 | return list(ROC[:, 0]), list(ROC[:, 1]) 218 | 219 | def _find_overlapping_target(self, predicted_label, predicted, target, min_iou): 220 | """ 221 | Return ground truth label which overlaps by at least 'min_iou' with a given input label 'p_label' 222 | or 
None if such ground truth label does not exist. 223 | """ 224 | mask_predicted = predicted == predicted_label 225 | overlapping_labels = target[mask_predicted] 226 | labels, counts = np.unique(overlapping_labels, return_counts=True) 227 | # retrieve the biggest overlapping label 228 | target_label_ind = np.argmax(counts) 229 | target_label = labels[target_label_ind] 230 | # return target label if IoU greater than 'min_iou'; since we're starting from 0.5 IoU there might be 231 | # only one target label that fulfill this criterion 232 | mask_target = target == target_label 233 | # return target_label if IoU > min_iou 234 | if self._iou(mask_predicted, mask_target) > min_iou: 235 | return target_label 236 | return None 237 | 238 | @staticmethod 239 | def _iou(prediction, target): 240 | """ 241 | Computes intersection over union 242 | """ 243 | intersection = np.logical_and(prediction, target) 244 | union = np.logical_or(prediction, target) 245 | return np.sum(intersection) / np.sum(union) 246 | 247 | def _filter_instances(self, input): 248 | """ 249 | Filters instances smaller than 'min_instance_size' by overriding them with 'ignore_index' 250 | :param input: input instance segmentation 251 | :return: tuple: (instance segmentation with small instances filtered, set of unique labels without the 'ignore_index') 252 | """ 253 | if self.min_instance_size is not None: 254 | labels, counts = np.unique(input, return_counts=True) 255 | for label, count in zip(labels, counts): 256 | if count < self.min_instance_size: 257 | mask = input == label 258 | input[mask] = self.ignore_index 259 | 260 | labels = set(np.unique(input)) 261 | labels.discard(self.ignore_index) 262 | return input, labels -------------------------------------------------------------------------------- /aicsmlsegment/model_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | import time 5 | 
import logging 6 | import os 7 | import shutil 8 | import sys 9 | 10 | from aicsmlsegment.utils import get_logger 11 | 12 | SUPPORTED_MODELS = ['unet_xy_zoom', 'unet_xy'] 13 | 14 | def weights_init(m): 15 | classname = m.__class__.__name__ 16 | if classname.find('Conv3d') != -1: 17 | torch.nn.init.kaiming_normal_(m.weight) 18 | m.bias.data.zero_() 19 | 20 | def apply_on_image(model, input_img, softmax, args): 21 | 22 | if not args.RuntimeAug: 23 | return model_inference(model, input_img, softmax, args) 24 | else: 25 | from PIL import Image 26 | print('doing runtime augmentation') 27 | 28 | input_img_aug = input_img.copy() 29 | for ch_idx in range(input_img_aug.shape[0]): 30 | str_im = input_img_aug[ch_idx,:,:,:] 31 | input_img_aug[ch_idx,:,:,:] = np.flip(str_im, axis=2) 32 | 33 | out1 = model_inference(model, input_img_aug, softmax, args) 34 | 35 | input_img_aug = [] 36 | input_img_aug = input_img.copy() 37 | for ch_idx in range(input_img_aug.shape[0]): 38 | str_im = input_img_aug[ch_idx,:,:,:] 39 | input_img_aug[ch_idx,:,:,:] = np.flip(str_im, axis=1) 40 | 41 | out2 = model_inference(model, input_img_aug, softmax, args) 42 | 43 | input_img_aug = [] 44 | input_img_aug = input_img.copy() 45 | for ch_idx in range(input_img_aug.shape[0]): 46 | str_im = input_img_aug[ch_idx,:,:,:] 47 | input_img_aug[ch_idx,:,:,:] = np.flip(str_im, axis=0) 48 | 49 | out3 = model_inference(model, input_img_aug, softmax, args) 50 | 51 | out0 = model_inference(model, input_img, softmax, args) 52 | 53 | for ch_idx in range(len(out0)): 54 | out0[ch_idx] = 0.25*(out0[ch_idx] + np.flip(out1[ch_idx], axis=3) + np.flip(out2[ch_idx], axis=2) + np.flip(out3[ch_idx], axis=1)) 55 | 56 | return out0 57 | 58 | def model_inference(model, input_img, softmax, args): 59 | 60 | model.eval() 61 | 62 | if args.size_in == args.size_out: 63 | img_pad = np.np.expand_dims(input_img, axis=0) # add batch dimension 64 | else: # zero padding on input image 65 | padding = [(x-y)//2 for x,y in zip(args.size_in, 
args.size_out)] 66 | img_pad0 = np.pad(input_img, ((0,0),(0,0),(padding[1],padding[1]),(padding[2],padding[2])), 'symmetric')#'constant') 67 | img_pad = np.pad(img_pad0, ((0,0),(padding[0],padding[0]),(0,0),(0,0)), 'constant') 68 | 69 | output_img = [] 70 | for ch_idx in range(len(args.OutputCh)//2): 71 | output_img.append(np.zeros(input_img.shape)) 72 | 73 | # loop through the image patch by patch 74 | num_step_z = int(np.floor(input_img.shape[1]/args.size_out[0])+1) 75 | num_step_y = int(np.floor(input_img.shape[2]/args.size_out[1])+1) 76 | num_step_x = int(np.floor(input_img.shape[3]/args.size_out[2])+1) 77 | 78 | with torch.no_grad(): 79 | for ix in range(num_step_x): 80 | if ix0 and leaveout[0]<1: 68 | num_train = int(np.floor((1-leaveout[0]) * total_num)) 69 | shuffled_idx = np.arange(total_num) 70 | random.shuffle(shuffled_idx) 71 | train_idx = shuffled_idx[:num_train] 72 | valid_idx = shuffled_idx[num_train:] 73 | else: 74 | valid_idx = [int(leaveout[0])] 75 | train_idx = list(set(range(total_num)) - set(map(int, leaveout))) 76 | elif leaveout: 77 | valid_idx = list(map(int, leaveout)) 78 | train_idx = list(set(range(total_num)) - set(valid_idx)) 79 | 80 | valid_filenames = [] 81 | train_filenames = [] 82 | for _, fn in enumerate(valid_idx): 83 | valid_filenames.append(filenames[fn][:-11]) 84 | for _, fn in enumerate(train_idx): 85 | train_filenames.append(filenames[fn][:-11]) 86 | 87 | return train_filenames, valid_filenames 88 | 89 | class BasicFolderTrainer: 90 | """basic version of trainer. 
91 | Args: 92 | model: model to be trained 93 | optimizer (nn.optim.Optimizer): optimizer used for training 94 | loss_criterion (callable): loss function 95 | loaders (dict): 'train' and 'val' loaders 96 | checkpoint_dir (string): dir for saving checkpoints and tensorboard logs 97 | """ 98 | 99 | def __init__(self, model, config, logger=None): 100 | 101 | if logger is None: 102 | self.logger = get_logger('ModelTrainer', level=logging.DEBUG) 103 | else: 104 | self.logger = logger 105 | 106 | device = config['device'] 107 | self.logger.info(f"Sending the model to '{device}'") 108 | self.model = model.to(device) 109 | self.logger.debug(model) 110 | 111 | #self.optimizer = optimizer 112 | #self.scheduler = lr_scheduler 113 | #self.loss_criterion = loss_criterion 114 | self.device = device 115 | #self.loaders = loaders 116 | self.config = config 117 | 118 | 119 | def train(self): 120 | 121 | ### load settings ### 122 | config = self.config #TODO, fix this 123 | model = self.model 124 | 125 | # define loss 126 | #TODO, add more loss 127 | loss_config = config['loss'] 128 | if loss_config['name']=='Aux': 129 | criterion = MultiAuxillaryElementNLLLoss(3,loss_config['loss_weight'], config['nclass']) 130 | else: 131 | print('do not support other loss yet') 132 | quit() 133 | 134 | # dataloader 135 | validation_config = config['validation'] 136 | loader_config = config['loader'] 137 | args_inference=lambda:None 138 | if validation_config['metric'] is not None: 139 | print('prepare the data ... 
...') 140 | filenames = glob(loader_config['datafolder'] + '/*_GT.ome.tif') 141 | filenames.sort() 142 | total_num = len(filenames) 143 | LeaveOut = validation_config['leaveout'] 144 | if len(LeaveOut)==1: 145 | if LeaveOut[0]>0 and LeaveOut[0]<1: 146 | num_train = int(np.floor((1-LeaveOut[0]) * total_num)) 147 | shuffled_idx = np.arange(total_num) 148 | random.shuffle(shuffled_idx) 149 | train_idx = shuffled_idx[:num_train] 150 | valid_idx = shuffled_idx[num_train:] 151 | else: 152 | valid_idx = [int(LeaveOut[0])] 153 | train_idx = list(set(range(total_num)) - set(map(int, LeaveOut))) 154 | elif LeaveOut: 155 | valid_idx = list(map(int, LeaveOut)) 156 | train_idx = list(set(range(total_num)) - set(valid_idx)) 157 | 158 | valid_filenames = [] 159 | train_filenames = [] 160 | for fi, fn in enumerate(valid_idx): 161 | valid_filenames.append(filenames[fn][:-11]) 162 | for fi, fn in enumerate(train_idx): 163 | train_filenames.append(filenames[fn][:-11]) 164 | 165 | args_inference.size_in = config['size_in'] 166 | args_inference.size_out = config['size_out'] 167 | args_inference.OutputCh = validation_config['OutputCh'] 168 | args_inference.nclass = config['nclass'] 169 | 170 | else: 171 | #TODO, update here 172 | print('need validation') 173 | quit() 174 | 175 | if loader_config['name']=='default': 176 | from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0 as train_loader 177 | train_set_loader = DataLoader(train_loader(train_filenames, loader_config['PatchPerBuffer'], config['size_in'], config['size_out']), num_workers=loader_config['NumWorkers'], batch_size=loader_config['batch_size'], shuffle=True) 178 | elif loader_config['name']=='focus': 179 | from aicsmlsegment.DataLoader3D.Universal_Loader import RR_FH_M0C as train_loader 180 | train_set_loader = DataLoader(train_loader(train_filenames, loader_config['PatchPerBuffer'], config['size_in'], config['size_out']), num_workers=loader_config['NumWorkers'], batch_size=loader_config['batch_size'], 
shuffle=True) 181 | else: 182 | print('other loader not support yet') 183 | quit() 184 | 185 | num_iterations = 0 186 | num_epoch = 0 #TODO: load num_epoch from checkpoint 187 | 188 | start_epoch = num_epoch 189 | for _ in range(start_epoch, config['epochs']+1): 190 | 191 | # sets the model in training mode 192 | model.train() 193 | 194 | optimizer = None 195 | optimizer = optim.Adam(model.parameters(),lr = config['learning_rate'], weight_decay = config['weight_decay']) 196 | 197 | # check if re-load on training data in needed 198 | if num_epoch>0 and num_epoch % loader_config['epoch_shuffle'] ==0: 199 | print('shuffling data') 200 | train_set_loader = None 201 | train_set_loader = DataLoader(train_loader(train_filenames, loader_config['PatchPerBuffer'], config['size_in'], config['size_out']), num_workers=loader_config['NumWorkers'], batch_size=loader_config['batch_size'], shuffle=True) 202 | 203 | # Training starts ... 204 | epoch_loss = [] 205 | 206 | for i, current_batch in tqdm(enumerate(train_set_loader)): 207 | 208 | inputs = Variable(current_batch[0].cuda()) 209 | targets = current_batch[1] 210 | outputs = model(inputs) 211 | 212 | if len(targets)>1: 213 | for zidx in range(len(targets)): 214 | targets[zidx] = Variable(targets[zidx].cuda()) 215 | else: 216 | targets = Variable(targets[0].cuda()) 217 | 218 | optimizer.zero_grad() 219 | if len(current_batch)==3: # input + target + cmap 220 | cmap = Variable(current_batch[2].cuda()) 221 | loss = criterion(outputs, targets, cmap) 222 | else: # input + target 223 | loss = criterion(outputs,targets) 224 | 225 | loss.backward() 226 | optimizer.step() 227 | 228 | epoch_loss.append(loss.data.item()) 229 | num_iterations += 1 230 | 231 | average_training_loss = sum(epoch_loss) / len(epoch_loss) 232 | 233 | # validation 234 | if num_epoch % validation_config['validate_every_n_epoch'] ==0: 235 | validation_loss = np.zeros((len(validation_config['OutputCh'])//2,)) 236 | model.eval() 237 | 238 | for img_idx, fn in 
enumerate(valid_filenames): 239 | 240 | # target 241 | label = np.squeeze(imread(fn+'_GT.ome.tif')) 242 | label = np.expand_dims(label, axis=0) 243 | 244 | # input image 245 | input_img = np.squeeze(imread(fn+'.ome.tif')) 246 | if len(input_img.shape) == 3: 247 | # add channel dimension 248 | input_img = np.expand_dims(input_img, axis=0) 249 | elif len(input_img.shape) == 4: 250 | # assume number of channel < number of Z, make sure channel dim comes first 251 | if input_img.shape[0] > input_img.shape[1]: 252 | input_img = np.transpose(input_img, (1, 0, 2, 3)) 253 | 254 | # cmap tensor 255 | costmap = np.squeeze(imread(fn+'_CM.ome.tif')) 256 | 257 | # output 258 | outputs = model_inference(model, input_img, model.final_activation, args_inference) 259 | 260 | assert len(validation_config['OutputCh'])//2 == len(outputs) 261 | 262 | for vi in range(len(outputs)): 263 | if label.shape[0]==1: # the same label for all output 264 | validation_loss[vi] += compute_iou(outputs[vi][0,:,:,:]>0.5, label[0,:,:,:]==validation_config['OutputCh'][2*vi+1], costmap) 265 | else: 266 | validation_loss[vi] += compute_iou(outputs[vi][0,:,:,:]>0.5, label[vi,:,:,:]==validation_config['OutputCh'][2*vi+1], costmap) 267 | 268 | average_validation_loss = validation_loss / len(valid_filenames) 269 | print(f'Epoch: {num_epoch}, Training Loss: {average_training_loss}, Validation loss: {average_validation_loss}') 270 | else: 271 | print(f'Epoch: {num_epoch}, Training Loss: {average_training_loss}') 272 | 273 | 274 | if num_epoch % config['save_every_n_epoch'] == 0: 275 | save_checkpoint({ 276 | 'epoch': num_epoch, 277 | 'num_iterations': num_iterations, 278 | 'model_state_dict': model.state_dict(), 279 | #'best_val_score': self.best_val_score, 280 | 'optimizer_state_dict': optimizer.state_dict(), 281 | 'device': str(self.device), 282 | }, checkpoint_dir=config['checkpoint_dir'], logger=self.logger) 283 | num_epoch += 1 284 | 285 | # TODO: add validation step 286 | 287 | def _log_lr(self): 288 | lr 
= self.optimizer.param_groups[0]['lr'] 289 | self.writer.add_scalar('learning_rate', lr, self.num_iterations) 290 | 291 | def _log_stats(self, phase, loss_avg, eval_score_avg): 292 | tag_value = { 293 | f'{phase}_loss_avg': loss_avg, 294 | f'{phase}_eval_score_avg': eval_score_avg 295 | } 296 | 297 | for tag, value in tag_value.items(): 298 | self.writer.add_scalar(tag, value, self.num_iterations) 299 | 300 | def _log_params(self): 301 | self.logger.info('Logging model parameters and gradients') 302 | for name, value in self.model.named_parameters(): 303 | self.writer.add_histogram(name, value.data.cpu().numpy(), 304 | self.num_iterations) 305 | self.writer.add_histogram(name + '/grad', 306 | value.grad.data.cpu().numpy(), 307 | self.num_iterations) 308 | 309 | def _log_images(self, input, target, prediction): 310 | sources = { 311 | 'inputs': input.data.cpu().numpy(), 312 | 'targets': target.data.cpu().numpy(), 313 | 'predictions': prediction.data.cpu().numpy() 314 | } 315 | for name, batch in sources.items(): 316 | for tag, image in self._images_from_batch(name, batch): 317 | self.writer.add_image(tag, image, self.num_iterations, dataformats='HW') -------------------------------------------------------------------------------- /aicsmlsegment/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import sys 4 | from typing import List 5 | from aicsimageio import AICSImage 6 | from scipy.ndimage import zoom 7 | import os 8 | from scipy import ndimage as ndi 9 | from scipy import stats 10 | import argparse 11 | 12 | import yaml 13 | 14 | def load_config(config_path): 15 | import torch 16 | config = _load_config_yaml(config_path) 17 | # Get a device to train on 18 | device_name = config.get('device', 'cuda:0') 19 | device = torch.device(device_name if torch.cuda.is_available() else 'cpu') 20 | config['device'] = device 21 | return config 22 | 23 | 24 | def 
_load_config_yaml(config_file): 25 | return yaml.load(open(config_file, 'r')) 26 | 27 | def get_samplers(num_training_data, validation_ratio, my_seed): 28 | from torch.utils.data import sampler as torch_sampler 29 | indices = list(range(num_training_data)) 30 | split = int(np.floor(validation_ratio * num_training_data)) 31 | 32 | np.random.seed(my_seed) 33 | np.random.shuffle(indices) 34 | 35 | train_idx, valid_idx = indices[split:], indices[:split] 36 | 37 | train_sampler = torch_sampler.SubsetRandomSampler(train_idx) 38 | valid_sampler = torch_sampler.SubsetRandomSampler(valid_idx) 39 | 40 | return train_sampler, valid_sampler 41 | 42 | def simple_norm(img, a, b, m_high=-1, m_low=-1): 43 | idx = np.ones(img.shape, dtype=bool) 44 | if m_high>0: 45 | idx = np.logical_and(idx, img0: 47 | idx = np.logical_and(idx, img>m_low) 48 | img_valid = img[idx] 49 | m,s = stats.norm.fit(img_valid.flat) 50 | strech_min = max(m - a*s, img.min()) 51 | strech_max = min(m + b*s, img.max()) 52 | img[img>strech_max]=strech_max 53 | img[imgstrech_max]=strech_max 76 | struct_img[struct_img4000] 91 | m,s = stats.norm.fit(img_valid.flat) 92 | m,s = stats.norm.fit(struct_img.flat) 93 | strech_min = struct_img.min() 94 | strech_max = min(m + 25 *s, struct_img.max()) 95 | struct_img[struct_img>strech_max]=strech_max 96 | struct_img = (struct_img- strech_min + 1e-8)/(strech_max - strech_min + 1e-8) 97 | img[ch_idx,:,:,:] = struct_img[:,:,:] 98 | elif args.Normalization == 12: # nuc 99 | struct_img = background_sub(struct_img,50) 100 | struct_img = simple_norm(struct_img, 2.5, 10) 101 | img[ch_idx,:,:,:] = struct_img[:,:,:] 102 | print('subtracted background') 103 | elif args.Normalization == 11: 104 | struct_img = background_sub(struct_img,50) 105 | #struct_img = simple_norm(struct_img, 2.5, 10) 106 | img[ch_idx,:,:,:] = struct_img[:,:,:] 107 | elif args.Normalization == 13: # cellmask 108 | #struct_img[struct_img>10000] = struct_img.min() 109 | struct_img = background_sub(struct_img,50) 110 
| struct_img = simple_norm(struct_img, 2, 11) 111 | img[ch_idx,:,:,:] = struct_img[:,:,:] 112 | elif args.Normalization == 14: 113 | struct_img = simple_norm(struct_img, 1, 10) 114 | img[ch_idx,:,:,:] = struct_img[:,:,:] 115 | elif args.Normalization == 15: # lamin 116 | struct_img[struct_img>4000] = struct_img.min() 117 | struct_img = background_sub(struct_img,50) 118 | img[ch_idx,:,:,:] = struct_img[:,:,:] 119 | elif args.Normalization == 16: # lamin/h2b 120 | struct_img = background_sub(struct_img,50) 121 | struct_img = simple_norm(struct_img, 1.5, 6) 122 | img[ch_idx,:,:,:] = struct_img[:,:,:] 123 | elif args.Normalization == 17: # lamin 124 | struct_img = background_sub(struct_img,50) 125 | struct_img = simple_norm(struct_img, 1, 10) 126 | img[ch_idx,:,:,:] = struct_img[:,:,:] 127 | elif args.Normalization == 18: # h2b 128 | struct_img = background_sub(struct_img,50) 129 | struct_img = simple_norm(struct_img, 1.5, 10) 130 | img[ch_idx,:,:,:] = struct_img[:,:,:] 131 | else: 132 | print('no normalization recipe found') 133 | quit() 134 | return img 135 | 136 | def image_normalization(img, config): 137 | 138 | if type(config) is dict: 139 | ops = config['ops'] 140 | nchannel = img.shape[0] 141 | assert len(ops) == nchannel 142 | for ch_idx in range(nchannel): 143 | ch_ops = ops[ch_idx]['ch'] 144 | struct_img = img[ch_idx,:,:,:] 145 | for transform in ch_ops: 146 | if transform['name'] == 'background_sub': 147 | struct_img = background_sub(struct_img, transform['sigma']) 148 | elif transform['name'] =='auto_contrast': 149 | param = transform['param'] 150 | if len(param)==2: 151 | struct_img = simple_norm(struct_img, param[0], param[1]) 152 | elif len(param)==4: 153 | struct_img = simple_norm(struct_img, param[0], param[1], param[2], param[3]) 154 | else: 155 | print('bad paramter for auto contrast') 156 | quit() 157 | else: 158 | print(transform['name']) 159 | print('other normalization methods are not supported yet') 160 | quit() 161 | 162 | img[ch_idx,:,:,:] = 
struct_img[:,:,:] 163 | else: 164 | args_norm = lambda:None 165 | args_norm.Normalization = config 166 | 167 | img = input_normalization(img, args_norm) 168 | 169 | return img 170 | 171 | def load_single_image(args, fn, time_flag=False): 172 | 173 | if time_flag: 174 | img = fn[:,args.InputCh,:,:] 175 | img = img.astype(float) 176 | img = np.transpose(img, axes=(1,0,2,3)) 177 | else: 178 | data_reader = AICSImage(fn) 179 | if isinstance(args.InputCh, List): 180 | channel_list = args.InputCh 181 | else: 182 | channel_list = [args.InputCh] 183 | img = data_reader.get_image_data('CZYX', S=0, T=0, C=channel_list) 184 | 185 | # normalization 186 | if args.mode == 'train': 187 | for ch_idx in range(args.nchannel): 188 | struct_img = img[ch_idx,:,:,:] # note that struct_img is only a view of img, so changes made on struct_img also affects img 189 | struct_img = (struct_img - struct_img.min() )/(struct_img.max() - struct_img.min()) 190 | elif not args.Normalization == 0: 191 | img = input_normalization(img, args) 192 | 193 | # rescale 194 | if len(args.ResizeRatio)>0: 195 | img = zoom(img, (1, args.ResizeRatio[0], args.ResizeRatio[1], args.ResizeRatio[2]), order=1) 196 | 197 | return img 198 | 199 | 200 | def compute_iou(prediction, gt, cmap): 201 | 202 | area_i = np.logical_and(prediction, gt) 203 | area_i[cmap==0]=False 204 | area_u = np.logical_or(prediction, gt) 205 | area_u[cmap==0]=False 206 | 207 | return np.count_nonzero(area_i) / np.count_nonzero(area_u) 208 | 209 | def get_logger(name, level=logging.INFO): 210 | logger = logging.getLogger(name) 211 | logger.setLevel(level) 212 | # Logging to console 213 | stream_handler = logging.StreamHandler(sys.stdout) 214 | formatter = logging.Formatter( 215 | '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s') 216 | stream_handler.setFormatter(formatter) 217 | logger.addHandler(stream_handler) 218 | 219 | return logger 220 | -------------------------------------------------------------------------------- 
/aicsmlsegment/version.py: -------------------------------------------------------------------------------- 1 | # Autogenerated file - do NOT edit this by hand 2 | MODULE_VERSION = "0.0.8.dev0" 3 | 4 | # For snapshot, X.Y.Z.devN -> X.Y.Z.devN+1 5 | # bumpversion devbuild 6 | # 7 | # For release, X.Y.Z.devN -> X.Y.Z 8 | # bumpversion release 9 | # DO NOT CALL release on consecutive calls 10 | # DO NOT CALL release on 0.0.0.devN 11 | # 12 | # For preparing for next development cycle after release 13 | # bumpversion patch (X.Y.Z -> X.Y.Z+1.dev0) 14 | # bumpversion minor (X.Y.Z -> X.Y+1.Z.dev0) 15 | # bumpversion major (X.Y.Z -> X+1.Y.Z.dev0) 16 | # 17 | -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | buildscript { 2 | def buildScriptPlugins = ['scripts/common/buildscript-5.gradle'] 3 | println "> Applying script plugins in buildscripts:" 4 | for (scriptPlugin in buildScriptPlugins) { 5 | def pluginPath = "${scriptPluginPrefix}${scriptPlugin}${scriptPluginSuffix}${scriptPluginTag}" 6 | println "${pluginPath}" 7 | apply from: pluginPath, to: buildscript 8 | } 9 | } 10 | 11 | ////////////////////////////////////////////////////////////////////////////////////////////////////// 12 | 13 | def scriptPlugins = ['scripts/common/gradle-version-5.gradle', 14 | 'scripts/common/common-5.gradle', 15 | 'scripts/python/build.gradle', 16 | 'scripts/python/version.gradle', 17 | 'scripts/python/publish.gradle'] 18 | println "> Applying script plugins:" 19 | for (scriptPlugin in scriptPlugins) { 20 | def pluginPath = "${scriptPluginPrefix}${scriptPlugin}${scriptPluginSuffix}${scriptPluginTag}" 21 | println "${pluginPath}" 22 | apply from: pluginPath 23 | } 24 | 25 | 26 | // Add the environment variable to gradle for coverage report 27 | // Do not add this to setup.cfg since it will break IDE tools 28 | py.env.put("PYTEST_ADDOPTS", 
"--cov=${rootProject.name} --cov-config=setup.cfg --cov-report=html --cov-report=xml --cov-report=term") 29 | 30 | 31 | ////////////////////////////////////////////////////////////////////////////////////////////////////// 32 | py.uploadToPyPi = true 33 | project.group = "org.alleninstitute.aics.pypi" 34 | description = "AICS ML segmentation" 35 | // Project version will be managed outside of gradle in accordance with PEP 440 36 | // ("https://www.python.org/dev/peps/pep-0440/") 37 | 38 | ////////////////////////////////////////////////////////////////////////////////////////////////////// 39 | -------------------------------------------------------------------------------- /configs/predict_file_config.yaml: -------------------------------------------------------------------------------- 1 | ########################################################################################################## 2 | # model settings 3 | ########################################################################################################## 4 | model: 5 | name: unet_xy_zoom 6 | zoom_ratio: 3 7 | # path to the trained model 8 | model_path: '//allen/aics/assay-dev/Segmentation/DeepLearning/SavedModel/LAMINB1/20190204_01/unet_xy_p3-300-default.pth' 9 | # number of input channels to the model 10 | nchannel: 1 11 | # number of output channels 12 | nclass: [2,2,2] 13 | # the channel to extract from output tensors 14 | OutputCh: [0, 1] 15 | # input patch size given to the network (adapt to fit in your GPU memory, generally bigger patches are better) 16 | size_in: [52, 420, 420] 17 | # prediction patch size from the network (change according to input size) 18 | size_out: [20, 152, 152] 19 | 20 | ########################################################################################################## 21 | # Data Info 22 | ########################################################################################################## 23 | # the path to output folder 24 | OutputDir: 
'/allen/aics/assay-dev/Segmentation/Lamin_segmentation/' 25 | # the index of the input channel 26 | InputCh: [-1] 27 | # the ratio to resize the image 28 | ResizeRatio: [1.0,1.0,1.0] 29 | # the method to normalize your data 30 | Normalization: 10 31 | # the threshold to be applied on your data 32 | Threshold: 0.5 33 | # whether to use run time augmentation (may improve the accuracy, but takes much longer to run) 34 | RuntimeAug: False 35 | 36 | ########################################################################################################## 37 | # Execution mode: single file 38 | ########################################################################################################## 39 | mode: 40 | name: file 41 | # paths to the file 42 | InputFile: '/allen/aics/assay-dev/Segmentation/Lamin_segmentation/raw/3/3500000943_100X_20170530_1-Scene-3-P3-E04.czi' 43 | # is this file a timelapse image 44 | timelapse: False 45 | 46 | 47 | -------------------------------------------------------------------------------- /configs/predict_folder_config.yaml: -------------------------------------------------------------------------------- 1 | ########################################################################################################## 2 | # model settings 3 | ########################################################################################################## 4 | model: 5 | name: unet_xy 6 | # path to the trained model 7 | model_path: '/allen/aics/assay-dev/Segmentation/DeepLearning/for_april_2019_release/LMNB1_saved_model_iter_2/checkpoint_epoch_400.pytorch' 8 | # number of input channels to the model 9 | nchannel: 1 10 | # number of output channels 11 | nclass: [2,2,2] 12 | # the channel to extract from output tensors 13 | OutputCh: [0, 1] 14 | # input patch size given to the network (adapt to fit in your GPU memory, generally bigger patches are better) 15 | size_in: [88, 180, 180] #[62, 420, 420] 16 | # prediction patch size from the network 
(change according to input size) 17 | size_out: [60, 92, 92] # [30, 152, 152] 18 | 19 | ########################################################################################################## 20 | # Data Info 21 | ########################################################################################################## 22 | # the path to output folder 23 | OutputDir: '/allen/aics/assay-dev/Segmentation/DeepLearning/for_april_2019_release/LMNB1_test/400' 24 | # the index of the input channel 25 | InputCh: [0] 26 | # the ratio to resize the image 27 | ResizeRatio: [1.0, 1.0, 1.0] 28 | # the method to normalize your data 29 | Normalization: 10 30 | # the threshold to be applied on your data 31 | Threshold: 0.5 #0.5 # 0.3 32 | # whether to use run time augmentation (may improve the accuracy, but takes much longer to run) 33 | RuntimeAug: False 34 | 35 | ########################################################################################################## 36 | # Execution mode: all files of specific type within a directory 37 | ########################################################################################################## 38 | mode: 39 | name: folder 40 | # paths to the file 41 | InputDir: '/allen/aics/assay-dev/Segmentation/DeepLearning/for_april_2019_release/LMNB1_test' 42 | # the type of images to be processed in this folder 43 | DataType: .tiff 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /configs/train_config.yaml: -------------------------------------------------------------------------------- 1 | ########################################################################################################## 2 | # model settings 3 | ########################################################################################################## 4 | model: 5 | name: unet_xy 6 | # number of input channels to the model 7 | nchannel: 1 8 | # number of output channels, 9 | nclass: [2,2,2] 10 | # input patch 
size given to the network (adapt to fit in your GPU memory, generally bigger patches are better) 11 | size_in: [50, 156, 156] 12 | # prediction patch size from the network (change according to input size) 13 | size_out: [22, 68, 68] 14 | # path to save the checkpoint 15 | checkpoint_dir: '/allen/aics/assay-dev/Segmentation/DeepLearning/for_april_2019_release/LMNB1_saved_model_iter_2' 16 | # path to latest checkpoint; if provided the training will be resumed from that checkpoint 17 | resume: '/allen/aics/assay-dev/Segmentation/DeepLearning/for_april_2019_release/LMNB1_saved_model_iter_1/checkpoint_epoch_400.pytorch' 18 | 19 | ########################################################################################################## 20 | # training procedure setting 21 | ########################################################################################################## 22 | # initial learning rate 23 | learning_rate: 0.00001 24 | # weight decay 25 | weight_decay: 0.005 26 | # max number of epochs 27 | epochs: 400 28 | # save a checkpoint every N epochs 29 | save_every_n_epoch: 50 30 | # loss function configuration 31 | loss: 32 | # loss function to be used during training (Aux - training with auxiliary loss) 33 | name: Aux 34 | # A manual rescaling weight given to each auxiliary loss. 
35 | loss_weight: [1, 1, 1] 36 | # a target value that is ignored and does not contribute to the input gradient 37 | ignore_index: null 38 | 39 | ########################################################################################################## 40 | # data loaders configuration 41 | ########################################################################### 42 | loader: 43 | name: default 44 | # path to the folder containing the training data 45 | datafolder: '/allen/aics/assay-dev/Segmentation/DeepLearning/for_april_2019_release/LMNB1_training_data_iter_1/' 46 | # number of patches in each training batch (related to patch size and GPU memory) 47 | batch_size: 8 48 | # number of patches loaded to cache 49 | PatchPerBuffer: 160 50 | # number of epochs between clearing and resampling the cached patches (smaller = heavier i/o, larger = higher chance of overfitting) 51 | epoch_shuffle: 5 52 | # number of workers for loading data in each training iteration 53 | NumWorkers: 1 54 | 55 | ########################################################################################################## 56 | # validation setting 57 | ########################################################################################################## 58 | # evaluation metric configuration 59 | validation: 60 | # the metric for validation 61 | metric: default 62 | # how to make the validation set (only used if metric is not None) 63 | leaveout: [0] 64 | # the channel to extract from output tensors 65 | # this is a list of an even number of integers [out_1, ch_1, out_2, ch_2, ...] 
66 | # meaning: take the out_1'th tensor from the output list and extract its ch_1'th channel as output 67 | OutputCh: [0, 1, 1, 1, 2, 1] 68 | # how many epochs between validations 69 | validate_every_n_epoch: 25 70 | -------------------------------------------------------------------------------- /docs/bb1.md: -------------------------------------------------------------------------------- 1 | # Building Block 1: **Binarizer** 2 | 3 | 4 | The **Binarizer** is the core building block that actually computes the segmentation, using either a classic image segmentation workflow or a model trained with the iterative deep learning workflow. See [this documentation](./demo_1.md) for a demo of how to develop a classic image segmentation workflow for a specific cell structure, and the [Curator tutorial](./bb2.md) + [Trainer tutorial](./bb3.md) for how to train a deep learning based segmentation model. 5 | 6 | ![segmenter pic](./bb1_pic.png) 7 | 8 | ## Option 1: Classic image segmentation 9 | 10 | Suppose you have already built a classic image segmentation workflow for your data and named it, for example, "FBL_HIPSC". Assume the original images are multi-channel with the structure channel first (so use `--struct_ch 0`, since Python is zero-based). 
11 | 12 | ### Apply on one image 13 | 14 | 15 | ```bash 16 | batch_processing \ 17 | --workflow_name FBL_HIPSC \ 18 | --struct_ch 0 \ 19 | --output_dir /path/to/save/segmentation/ \ 20 | per_img \ 21 | --input /path/to/image_test.tiff 22 | ``` 23 | 24 | ### Apply on a folder of images 25 | 26 | To segment all `.tiff` files in one folder, run: 27 | 28 | ```bash 29 | batch_processing \ 30 | --workflow_name FBL_HIPSC \ 31 | --struct_ch 0 \ 32 | --output_dir /path/to/save/segmentation/ \ 33 | per_dir \ 34 | --input_dir /path/to/raw_images/ \ 35 | --data_type .tiff 36 | ``` 37 | 38 | 39 | ## Option 2: Deep learning segmentation model 40 | 41 | ### Understanding model output 42 | 43 | The raw prediction from a deep learning based segmentation model is not binary: the value of each voxel is a real number between 0 and 1. To make it binary, we usually apply a cutoff value, i.e., the `Threshold` parameter in the [configuration file](./doc_pred_yaml.md). Each model may need a different cutoff value. To determine a proper cutoff value, set `Threshold` to `-1` when running on sample images, open the output in ImageJ (with the [bio-formats importer](https://imagej.net/Bio-Formats#Bio-Formats_Importer)), and try out different threshold values. Then set `Threshold` to the chosen value and run on all images; the results will now be binary. 44 | 45 | 46 | ### Apply on one image 47 | 48 | Find/build a `.yaml` file for processing a single file (e.g., `./configs/predict_file_config.yaml`) and make sure to follow the list [**here**](./doc_pred_yaml.md) to change the parameters, such as the image file path, the output path, the model path, etc. 
49 | 50 | ```bash 51 | dl_predict --config /path/to/predict_file.yaml 52 | ``` 53 | 54 | ### Apply on a folder of images 55 | 56 | Find/build a `.yaml` file for processing a folder of images (e.g., `./configs/predict_folder_config.yaml`) and make sure to follow the list [**here**](./doc_pred_yaml.md) to change the parameters, such as the image folder path, the output path, the model path, etc. 57 | 58 | ```bash 59 | dl_predict --config /path/to/predict_folder.yaml 60 | ``` -------------------------------------------------------------------------------- /docs/bb1_pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/docs/bb1_pic.png -------------------------------------------------------------------------------- /docs/bb2.md: -------------------------------------------------------------------------------- 1 | # Building Block 2: **Curator** 2 | 3 | **Curator** is used to prepare training data for **Trainer**. It really emphasizes the iterative part of the DL workflow we presented in the [paper](https://www.biorxiv.org/content/10.1101/491035v1). Namely, after having some preliminary segmentation results, you can improve the performance by training a deep learning model with curated segmentations in an iterative fashion. 4 | 5 | *Annotating images is a way to incorporate human knowledge into training the deep learning based segmentation model. But we never want to draw the segmentation directly by hand, which would be very time-consuming and subjective. Instead, we want to draw areas to guide the model on whether the automatic segmentation result is reliable or which segmentation (always from automatic algorithms) to use in different areas.* 6 | 7 | ![segmenter pic](./bb2_pic.png) 8 | 9 | There are three scenarios that the current version of **Curator** can handle. 
The same "curation" idea can be adapted based on the current scripts for your special needs. 10 | 11 | 1. Sorting: If your segmentation algorithm only works well on a subset of images that you need to analyze (maybe due to instability of the algorithm or variations between images), you should sort out the successful cases and train your model with them. 12 | 2. Merging: If the objects to be segmented in each image form two sub-populations (e.g., mitotic cells vs. interphase cells) and different algorithms are needed to segment each sub-population, you should merge the two segmentation versions and train your model with the merged ground truth. 13 | 3. Take-All: If you already have ground truth data (e.g., by manual annotation or you are using simulated images with known ground truth), there is a simple script to convert your data into the format compatible with **Trainer**. 14 | 15 | 16 | ## Sorting: 17 | 18 | Suppose you have a set of raw images and their segmentations, and each raw image is multi-channel with the structure in the third channel (so `--input_channel`=2, zero-based). Training data will be generated automatically at the end of sorting. A `.csv` file needs to be generated to track and resume the process when necessary, and you can name it `curator_sorting_tracker.csv`. You might also want to create excluding masks (see special note 1 below) for one or more images. If so, you can specify their location with `--mask_path`. Additionally, make sure to check special note 2 for input image normalization. In the code example below, we are using "normalization recipe" 15. 19 | 20 | Here, we assume the files in `raw_path` and `seg_path` have the following correspondence: Suppose `raw_path` has an image `img_001.tiff`, then `seg_path` needs to have the corresponding segmentation result named `img_001_struct_segmentation.tiff`. 
Users do not need to worry about naming if the segmentation is generated by our segmenter (either classic or deep learning), as `_struct_segmentation.tiff` is the default name for segmentation results. 21 | 22 | ### How to run? 23 | 24 | ```bash 25 | curator_sorting \ 26 | --raw_path /path/to/raw/image/ \ 27 | --input_channel 2 \ 28 | --data_type .tiff \ 29 | --seg_path /path/to/segmentation/ \ 30 | --train_path /path/to/training_data/ \ 31 | --csv_name /path/to/curator_sorting_tracker.csv \ 32 | --mask_path /path/to/excluding_mask/ \ 33 | --Normalization 15 34 | ``` 35 | 36 | ### How to use? 37 | 38 | A side-by-side view of the original image and the segmentation will pop up when you run the code, and you can mark the segmentation as good or bad with a right-click or a left-click. 39 | 40 | * left-click = 'Bad' 41 | * right-click = 'Good' 42 | 43 | 44 | If an image is labeled as 'Good' (i.e., after a right-click), users will be asked if an excluding mask is needed. If yes, type `y` in the command line; otherwise, type `n`. If you type `y`, a new window will pop up for drawing polygons on the image as the excluding mask. 45 | 46 | * When adding a polygon, left-clicks will be recorded as the vertices of the polygon. After the last vertex, a right mouse click will close the polygon (connecting the last vertex to the first vertex). **Make sure you only draw within the upper left panel, i.e., the original image** 47 | * Multiple polygons can be added in one image 48 | * After you finish drawing on one image, press `D` to close the window and move on to the next one 49 | * Press `Q` to quit current annotation (can resume later) 50 | 51 | 52 | ## Merging: 53 | 54 | Suppose you have a set of raw images which are multi-channel with the structure in the third channel (so `--input_channel`=2, zero-based). 
Also, suppose we have two different versions of segmentations, say `seg1` and `seg2`, each of which is more suitable for certain cells in each image. We can think of `seg1` as the "base version" segmentation and `seg2` as the "patch version" segmentation. Then we can think of the mask as "patches" over the base version, so the segmentation in the "patch version" will be used inside the "patches", while the "base version" is used everywhere else. In this sense, *it is important to properly assign which path is passed in as `seg1` and which path is passed in as `seg2`.* For each mask, you will need to draw polygons on a 2D image as the "patches". All masks drawn in 2D will be duplicated on every z-slice (see special note 3 below for more details). Training data will be generated automatically at the end of curation. A `.csv` file needs to be generated to track and resume the process when necessary, and you can name it `curator_merging_tracker.csv`. Like in sorting, you might also want to create excluding masks (see special note 1 below) for one or more images. If so, you can specify their location with `--ex_mask_path`. Additionally, make sure to check special note 2 for input image normalization. In the code example below, we are using "normalization recipe" 15. 55 | 56 | Here, we assume the files in `raw_path`, `seg1_path` and `seg2_path` have the following correspondence: Suppose `raw_path` has an image `img_001.tiff`, then `seg1_path` and `seg2_path` both need to have a corresponding segmentation file named `img_001_struct_segmentation.tiff`. Users do not need to worry about naming if the segmentation is generated by our segmenter (either classic or deep learning), as `_struct_segmentation.tiff` is the default name for segmentation results. 57 | 58 | ### How to run? 
59 | 60 | ```bash 61 | curator_merging \ 62 | --raw_path /path/to/raw/image/ \ 63 | --input_channel 2 \ 64 | --data_type .tiff \ 65 | --seg1_path /path/to/seg_base_version/ \ 66 | --seg2_path /path/to/seg_patch_version/ \ 67 | --train_path /path/to/training_data/ \ 68 | --csv_name /path/to/curator_merging_tracker.csv \ 69 | --mask_path /path/to/merging_mask/ \ 70 | --ex_mask_path /path/to/excluding_mask/ \ 71 | --Normalization 15 72 | ``` 73 | 74 | ### How to use? 75 | 76 | A side-by-side view of the original image and the segmentation will pop up when you run the code, for you to select images and draw masks. If one image shouldn't be included in the training dataset, press `B` to label the image as "bad" and move on to the next one. Otherwise, you can start drawing polygons to assign regions to either workflow. **Make sure you only draw within the upper left panel, i.e., the original image. For each polygon you draw, the segmentation on the rightmost panel will be used to replace the corresponding part in the middle panel.** (Again, it is important to properly specify which path is `seg1` (v1) and which is `seg2` (v2) as execution parameters.) 77 | 78 | * When adding a polygon, left-clicks will be recorded as the vertices of the polygon. After the last vertex, a right mouse click will close the polygon (connecting the last vertex to the first vertex). 79 | * Multiple polygons can be added in one image 80 | * After you finish drawing on one image, press `D` to close the window and move on to the next one 81 | * Press `Q` to quit current annotation (can resume later) 82 | 83 | After assigning the regions in the images to either workflow, users will also be asked if an excluding mask is needed. If yes, type `y` in the command line; otherwise, type `n`. If you type `y`, a new window will pop up for drawing polygons on the image as the excluding mask (in the same way as drawing the merging mask). 
84 | 85 | 86 | ## Take-All: 87 | 88 | If you already have ground truth data (e.g., by manual annotation or you are using simulated images with known ground truth), **Curator** can take all of them and convert them to a specific format compatible with **Trainer** (e.g., naming convention and input normalization). Follow the steps below if you wish to do so. 89 | 90 | 1. Make sure all original images are in one folder and have the same format (e.g., multi-channel `ome.tif` with the target structure in channel `2`, matching `--input_channel 2` in the example below). 91 | 2. Make sure all segmentations (i.e., ground truth images) are in a different folder and each filename starts with the base name of the corresponding original image (without extension) and ends with `_struct_segmentation.tiff`. For example, the segmentation image for original image `img_001.tiff` should have the filename `img_001_struct_segmentation.tiff`. 92 | 3. (optional) If excluding masks are needed for certain images, make sure that they are saved in another folder (different from original and segmentation) and each filename starts with the base name of the corresponding original image (without extension) and ends with `_mask.tiff`. For example, the excluding mask image for original image `img_001.tiff` should have the filename `img_001_mask.tiff`. The areas to be excluded should have value `0` in the image, while other areas have positive values. 93 | 94 | You can use the following code example to run this version of **Curator**. 95 | 96 | ```bash 97 | curator_takeall \ 98 | --raw_path /path/to/raw/image/ \ 99 | --data_type .tiff \ 100 | --input_channel 2 \ 101 | --seg_path /path/to/segmentation/ \ 102 | --train_path /path/to/training_data/ \ 103 | --mask_path /path/to/excluding_mask/ \ 104 | --Normalization 15 105 | ``` 106 | 107 | ======================= 108 | 109 | ### Special note 1: Masking areas to be excluded 110 | 111 | It is not uncommon to have a small area in an image that should be excluded from training due to bad segmentation. 
In the context of sorting, for example, an image can be almost perfectly segmented except for a small area. You may not want to simply say the segmentation for this image failed, but instead, to include excluding masking areas to only ignore the small area that failed. This step is certainly optional and only meant to include more images for training. For sorting/merging, an optional step can be triggered to draw a mask (polygons) on a 2D image (max z-projection) to indicate the areas to be excluded. For take-all, an optional folder can be used to save all mask images. 112 | 113 | 114 | ### Special note 2: "--Normalization" 115 | 116 | It is important to normalize your images before feeding them into the deep learning model. For example, if your model is trained on images with intensity values between 300 and 400 with mean intensity 310, it can have a hard time when being applied to a new image with intensity values between 360 and 480 with mean intensity 400, even if the actual contents look very similar. We provide users with a set of pre-defined 'recipes' for image normalization and all of them are based on three basic functions: min-max, auto contrast and background subtraction. More details about the min-max and auto contrast functions can be found [here](). `suggest_normalization` in the `aicssegmentation` package can help to determine the proper parameter values for your data. Background subtraction is implemented as subtracting the gaussian smoothed image from the original image and rescale to [0, 1]. This function can be used to correct uneven intensity levels and the only parameter is the gaussian kernel size. As a rule of thumb, one can use half the size of average uneven areas. An optional parameter is to set an upper bound for intensity so that any values above the upper bound will be considered outliers and re-assign to the min intensity of the image. 
In case you need to add your own recipes, you can modify the function `input_normalization` in `utils.py` (e.g., copy and paste one of the current recipes, change the parameters, and give it a new recipe index). 117 | 118 | List of current pre-defined recipes: 119 | 120 | * 0: min-max 121 | * 1: auto contrast [mean - 2 * std, mean + 11 * std] 122 | * 2: auto contrast [mean - 2.5 * std, mean + 10 * std] 123 | * 7: auto contrast [mean - 1 * std, mean + 6 * std] 124 | * 10: auto contrast [min, mean + 25 * std] with upper bound intensity 4000 125 | * 12: background subtraction (kernel=50) + auto contrast [mean - 2.5 * std, mean + 10 * std] 126 | * 15: background subtraction (kernel=50) with upper bound intensity 4000 127 | 128 | ### Special note 3: Interface design 129 | 130 | #### Technical consideration 131 | 132 | Currently, the interface of merging/sorting is implemented only using `matplotlib` without any advanced packages for interface building. Our goal is to keep it simple and easy to set up. It should be robust across different machines and easy to hack so that users can customize their own curation interface when necessary. We are also investigating other ways of implementing the interface that are more user-friendly, easy to set up, and robust across different machines. 133 | 134 | #### Human computer interaction consideration 135 | 136 | The best way for annotation may vary from problem to problem. Such variation could be in two aspects: 137 | 138 | (1) how to visualize current results for determining where to annotate 139 | 140 | In the current implementation, the max-projection along z and the middle z slice are presented in *merging*, while extra slices above and below the middle slice are also shown in *sorting*. It is very likely that different ways (e.g., showing the top few z slices or not showing the max-projection) may be more suitable for specific problems. 
This should be easily hackable by modifying `PATH/TO/aicsmlsegment/bin/curator/curator_sorting.py` or `PATH/TO/aicsmlsegment/bin/curator/curator_merging.py`. Look for the code sections above `plt.figure()`. 141 | 142 | (2) how to annotate 143 | 144 | In the current implementation, all mask drawing is done as polygons in 2D, which are duplicated on every z slice to form a 3D mask. The motivation is that drawing in 3D is hard and time-consuming, whether slice by slice or directly in a 3D visualization. A very common question for *merging* is: **What if two cells overlap and one cell needs segmentation version 1 and the other cell needs segmentation version 2?** First of all, we assume that such overlap does not occur in every case that requires merging. In other words, we expect there are always at least some non-overlapping parts on the z-projection. Otherwise, we may visualize the images in a different way (e.g., viewing along the x or y axis). Then, the overlapping areas between the two cells can be marked as an 'excluding mask' to be ignored during training. Our goal is to make the drawing as simple as possible so that it is very easy and fast to get annotation on more images without worrying about throwing away some parts in certain images. 145 | -------------------------------------------------------------------------------- /docs/bb2_pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/docs/bb2_pic.png -------------------------------------------------------------------------------- /docs/bb3.md: -------------------------------------------------------------------------------- 1 | # Building Block 3: **Trainer** 2 | 3 | **Trainer** is used to train deep learning-based segmentation models. The input for **Trainer** should be data prepared by **Curator** (see [documentation](./bb2.md)) and the output should be a model that can be used in **Segmenter**. 
4 | 5 | ![segmenter pic](./bb3_pic.png) 6 | 7 | Find/build the `.yaml` file for training (e.g., `./configs/train_config.yaml`) and make sure to follow the list [**here**](./doc_train_yaml.md) to change the parameters, such as the training data path, the path for saving the model, etc. 8 | 9 | ```bash 10 | dl_train --config /home/config_files/train_lab.yaml 11 | ``` 12 | 13 | ### When multiple GPUs are available 14 | 15 | By default, **Trainer** will use the first available GPU for computation. If there are multiple GPUs on your machine, you can choose which GPU to use by setting `CUDA_VISIBLE_DEVICES` before running **Trainer**. 16 | 17 | ```bash 18 | CUDA_VISIBLE_DEVICES=2 dl_train --config /home/config_files/train_lab.yaml 19 | ``` 20 | -------------------------------------------------------------------------------- /docs/bb3_pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/docs/bb3_pic.png -------------------------------------------------------------------------------- /docs/check_cuda.md: -------------------------------------------------------------------------------- 1 | # How to set up CUDA and check the right CUDA version? 2 | 3 | What is [CUDA](https://developer.nvidia.com/cuda-toolkit)? It is a toolkit for using NVIDIA GPUs for high-performance computing. 4 | 5 | First of all, you need an NVIDIA GPU card properly installed on your Linux machine or computing cluster. Here we assume the GPU card is physically in place and the hardware driver has been installed. 6 | 7 | 1. check your GPU driver version 8 | 9 | Running the command `nvidia-smi` in your terminal will give an overview of your GPU cards, for example: 10 | 11 | ![nvidia_smi](./nvidia_smi.png) 12 | 13 | So, in this example the driver version is 390.87. 14 | 15 | 2. 
install CUDA 16 | 17 | Running the command `nvcc --version` in your terminal will give an overview of your CUDA installation. 18 | 19 | ![cuda](./cuda.png) 20 | 21 | You will not see this message if CUDA is not set up on your machine. If so, you need to install CUDA. In order to determine which CUDA version fits your GPU driver, check this [chart](https://stackoverflow.com/questions/30820513/what-is-the-correct-version-of-cuda-for-my-nvidia-driver/30820690#30820690). In our case (Driver Version = 390.87), CUDA 9.0 or 9.1 is compatible; newer CUDA releases require a newer driver. 22 | 23 | Go to the [CUDA website](https://developer.nvidia.com/cuda-toolkit) and follow the download and installation instructions to install CUDA. 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /docs/cuda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/docs/cuda.png -------------------------------------------------------------------------------- /docs/demo1_pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/docs/demo1_pic.png -------------------------------------------------------------------------------- /docs/demo2_pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/docs/demo2_pic.png -------------------------------------------------------------------------------- /docs/demo_1.md: -------------------------------------------------------------------------------- 1 | # Demo 1: Segmentation of ATP2A2 in 3D fluorescent microscopy images of 
hiPS cells. 4 | 5 | *Note: This demo only uses the classic segmentation workflow and thus does not require a GPU. See package [aics-segmentation](https://github.com/AllenCell/aics-segmentation).* 6 | 7 | ## Stage 1: Develop a classic image segmentation workflow 8 | 9 | We recommend that users start by identifying a structure in the [lookup table](https://www.allencell.org/segmenter.html) that looks the most similar to the segmentation task that you have. Once you have identified a structure, open the corresponding Jupyter Notebook and follow the instructions in the notebook to tune the workflow. After finalizing the algorithms and parameters in the workflow, modify batch_processing.py to batch process all images (file by file or folder by folder). 10 | 11 | #### Step 1: Find the structure in the lookup table with the most similar morphology to your data 12 | 13 | List of "playgrounds" for the lookup table: 14 | 15 | 1. playground_st6gal.ipynb: workflow for Sialyltransferase 1 16 | 2. playground_spotty.ipynb: workflow for Fibrillarin, Beta catenin 17 | 3. playground_npm1.ipynb: workflow for Nucleophosmin 18 | 4. playground_curvi.ipynb: workflows for Sec61 beta, Tom 20, Lamin B1 (mitosis-specific) 19 | 5. playground_lamp1.ipynb: workflow for LAMP-1 20 | 6. playground_dots.ipynb: workflows for Centrin-2, Desmoplakin, and PMP34 21 | 7. playground_gja1.ipynb: workflow for Connexin-43 22 | 8. playground_filament3d.ipynb: workflows for Tight junction protein ZO1, Beta actin, Non-muscle myosin IIB, Alpha-actinin-1, Alpha tubulin, Troponin I, and Titin 23 | 9. playground_shell.ipynb: workflow for Lamin B1 (Interphase-specific) 24 | 25 | In this example, ATP2A2 localizes to the nuclear periphery and ER tubules, very similar to Sec61B. Therefore, we start with `playground_curvi.ipynb`. 
26 | 27 | #### Step 2: Go to Jupyter Notebook and tune the workflow 28 | 29 | First, start your Jupyter Notebook App (make sure to activate your conda environment; see package [aics-segmentation](https://github.com/AllenCell/aics-segmentation) for details). 30 | 31 | ```bash 32 | jupyter notebook 33 | ``` 34 | 35 | Now, Jupyter Notebook should have opened in your default browser, and you can make a copy of `playground_curvi.ipynb` to start working. Simply follow the instructions embedded in the notebook to tune the workflow for your image. ([how to use a Jupyter Notebook?](https://jupyter-notebook-beginner-guide.readthedocs.io/en/latest/execute.html#executing-a-notebook)) 36 | 37 | #### Step 3: Batch run 38 | 39 | You can easily test your workflow on multiple images with batch processing by following the steps below. 40 | 41 | 1. Duplicate the template file `/aics-segmentation/aicssegmentation/structure_wrapper/seg_template.py` and rename the copy to `/aics-segmentation/aicssegmentation/structure_wrapper/seg_atp2a2.py` 42 | 2. Open `seg_atp2a2.py` 43 | 3. Change the function name from `Workflow_template()` to `Workflow_atp2a2()` 44 | 4. Insert parameters and functions at the placeholders. Meanwhile, make sure you `import` all the functions you want to use. You can check `seg_lamin_interphase.py` under structure_wrapper to see examples. 45 | 5. Save the file 46 | 6. Run (make sure to use your own path and structure channel index) 47 | 48 | ```bash 49 | batch_processing --workflow_name atp2a2 --struct_ch 1 --output_dir /path/to/output per_dir --input_dir /path/to/raw --data_type .czi 50 | ``` 51 | Alternatively, you can use these scripts (`aicssegmentation/bin/run_toolkit.sh` for Linux/macOS, `aicssegmentation/bin/run_toolkit.bat` for Windows). 52 | 53 | ## Stage 2: Evaluation 54 | 55 | The goal of the Jupyter Notebook "playground" is to design and assess the overall workflow on one or a couple of images. 
After applying the workflow to more images (ideally representing possible variations in the full dataset to be analyzed), we want to make sure the workflow works well on different examples. 56 | 57 | ### Case 1: 58 | 59 | If everything looks good, the script is ready for processing new data for analysis. 60 | 61 | ### Case 2: 62 | 63 | If the results are okay/acceptable on all images, but may need a little tuning (e.g., decreasing the cutoff value of `filament_3d_wrapper` to be more permissive), you can adjust the parameters in `seg_atp2a2.py`. You may need several rounds of fine-tuning of the parameters in batch mode to finally achieve the most satisfactory results on all representative images. You may not have to go back to the Jupyter Notebook file since the notebook is only meant to help you quickly test out the overall workflow and get reasonable parameters. 64 | 65 | ### Case 3: 66 | 67 | If the results are good on some images but bad on others, or the results are only good on certain cells, you may consider using the iterative deep learning workflow to improve the segmentation quality. See [demo 2](./demo_2.md) for details. 68 | 69 | 70 | -------------------------------------------------------------------------------- /docs/demo_2.md: -------------------------------------------------------------------------------- 1 | # Demo 2: Segmentation of Lamin B1 in 3D fluorescent microscopy images of hiPS cells 2 | 3 | In this demo, we will demonstrate how to get the segmentation of Lamin B1 in 3D fluorescent microscopy images of hiPS cells. Before starting this demo, make sure to check out [demo 1: build a classic image segmentation workflow](./demo_1.md), and detailed descriptions of the building blocks in our segmenter ([Binarizer](./bb1.md), [Curator](./bb2.md), [Trainer](./bb3.md)). 
The data used in this demo can be found in the [allencell quilt bucket](https://open.quiltdata.com/b/allencell/packages/aics/laminb1_sample_data). 4 | 5 | 6 | ## Stage 1: Run **Binarizer** (classic image segmentation workflow) and Assess Results 7 | 8 | Suppose you have already worked out a classic image segmentation workflow and saved it as `seg_lmnb1_interphase.py` (i.e., setting `workflow_name` as `lmnb1_interphase`). You can run 9 | 10 | ```bash 11 | batch_processing \ 12 | --workflow_name lmnb1_interphase \ 13 | --struct_ch 0 \ 14 | --output_dir /path/to/segmentation \ 15 | per_dir \ 16 | --input_dir /path/to/raw \ 17 | --data_type .tiff 18 | ``` 19 | to batch process all Lamin B1 images in a folder and evaluate them. 20 | 21 | During evaluation, some results appear to be good but some have errors (left: original; right: binary image from **Binarizer**). 22 | 23 | ![wf1 pic](./wf_pic.png) 24 | 25 | Some objects were missed in the segmentation due to the failure of an automatic seeding step (see yellow arrow). Also, this workflow performed poorly on mitotic cells (see blue arrow). The workflow was not able to produce consistent results on all images; however, we can leverage the successful ones to build a DL model. 26 | 27 | ## Stage 2: Run **Curator** (sorting) 28 | 29 | The goal of this curation step is to select those images that were successfully segmented, so it is appropriate to use the "sorting" strategy in **Curator**. It can be achieved by running the code below. 
30 | 31 | ```bash 32 | curator_sorting \ 33 | --raw_path /path/to/raw \ 34 | --data_type .tiff \ 35 | --input_ch 0 \ 36 | --seg_path /path/to/segmentation \ 37 | --mask_path /path/to/curator_sorting_excluding_mask \ 38 | --csv_name /path/to/sorting_record.csv \ 39 | --train_path /path/to/training_data \ 40 | --Normalization 10 41 | ``` 42 | 43 | ## Stage 3: Run **Trainer** 44 | 45 | Find/build the `.yaml` file for training (e.g., `./configs/train_config.yaml`) and make sure to follow the list [**here**](./doc_train_yaml.md) to change the parameters, such as the training data path, the path for saving the model, etc. 46 | 47 | ```bash 48 | dl_train --config /path/to/train_config.yaml 49 | ``` 50 | to start training of the model. 51 | 52 | Depending on the size of your training data, the training process can take 8~32 hours. 53 | 54 | ## Stage 4: Run **Binarizer** 55 | 56 | After the training is finished, you can apply the model either to one image or to a folder of images. Simply find the `.yaml` file for processing a folder of images (e.g., `./configs/predict_folder_config.yaml`) or the `.yaml` file for processing a single image (e.g., `./configs/predict_file_config.yaml`). Make sure to follow the list [**here**](./doc_pred_yaml.md) to change the parameters, such as the image path, the output path, the model path, etc. Then you can run 57 | 58 | ```bash 59 | dl_predict --config /path/to/predict_file_config.yaml 60 | ``` 61 | to apply the model to your data. 62 | 63 | Looking at the results, you may notice that Lamin B1 in all interphase cells was segmented well, but the model still failed to correctly segment the structure in mitotic cells. To improve the model accuracy, you can develop another classic image segmentation workflow specifically for Lamin B1 in mitotic cells and call it `lmnb1_mitotic`. 64 | 65 | In this demo, to be more efficient, we will use a mitotic dataset, where each image has at least one mitotic cell. Suppose all the images are saved in the folder `raw_mitosis`. 
66 | 67 | Then run the **Binarizer** twice: 68 | * First, run with the deep learning model (better for interphase) and save the segmentation in the folder `seg_v1`. (Again, make sure to follow the list [**here**](./doc_pred_yaml.md) to change the parameters.) 69 | 70 | ```bash 71 | dl_predict --config /path/to/predict_folder_config.yaml 72 | ``` 73 | 74 | * Second, run with the `lmnb1_mitotic` workflow (better for mitosis), and save the segmentation in the folder `seg_v2`. 75 | 76 | ```bash 77 | batch_processing \ 78 | --workflow_name lmnb1_mitotic \ 79 | --struct_ch 0 \ 80 | --output_dir /path/to/seg_v2 \ 81 | per_dir \ 82 | --input_dir /path/to/raw_mitosis \ 83 | --data_type .tiff 84 | ``` 85 | 86 | ## Stage 5: Run **Curator** 87 | 88 | Now with the combined results from the DL model and the classic segmentation workflow, it is necessary to perform another curation step to merge the two segmentation versions (for interphase and mitosis) of each image. Therefore, the "merging" strategy in **Curator** will be used by running the code below. The newly generated training data can be saved in the same folder as in the previous training step, so that all of them are used for training. 89 | 90 | ```bash 91 | curator_merging \ 92 | --raw_path /path/to/raw_mitosis/ \ 93 | --input_ch 0 \ 94 | --data_type .tiff \ 95 | --seg1_path /path/to/seg_v1 \ 96 | --seg2_path /path/to/seg_v2 \ 97 | --mask_path /path/to/curator_merging_mask \ 98 | --ex_mask_path /path/to/curator_merging_excluding_mask \ 99 | --csv_name /path/to/merging_record.csv \ 100 | --train_path /path/to/training_data \ 101 | --Normalization 10 102 | ``` 103 | 104 | ## Stage 6: Run **Trainer** 105 | 106 | Find/build the `.yaml` file for training (e.g., `./configs/train_config.yaml`) and make sure to follow the list [**here**](./doc_train_yaml.md) to change the parameters, such as the training data path, the path for saving the model, etc. 
107 | 108 | ```bash 109 | dl_train --config /path/to/train_config.yaml 110 | ``` 111 | to start training of the model. 112 | 113 | ## Stage 7: Run **Binarizer** 114 | 115 | After training is finished, find the `.yaml` file for processing a folder of images (e.g., `./configs/predict_folder_config.yaml`) and make sure to follow the list [**here**](./doc_pred_yaml.md) to change the parameters. Then you can run 116 | 117 | ```bash 118 | dl_predict --config /path/to/predict_folder_config.yaml 119 | ``` 120 | to apply the model to your data. 121 | 122 | In our case, Lamin B1 in both interphase cells and mitotic cells was correctly segmented after the second round of training the DL model. 123 | 124 | ![dl pic](./dl_final.png) 125 | -------------------------------------------------------------------------------- /docs/dl_1_pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/docs/dl_1_pic.png -------------------------------------------------------------------------------- /docs/dl_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/docs/dl_final.png -------------------------------------------------------------------------------- /docs/doc.rst: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/docs/doc.rst -------------------------------------------------------------------------------- /docs/doc_pred_yaml.md: -------------------------------------------------------------------------------- 1 | # Configuration for Running a DL Model in **Segmenter** 2 | 3 | This is a detailed description of the prediction configuration for running a DL model in **Segmenter** to 
generate the segmentation. There are a lot of parameters in the configuration file, which can be categorized into three types: 4 | 5 | 1. Parameters specific to each run (need to change every time), marked by :pushpin: 6 | 2. Parameters specific to each machine (only need to change once on a particular machine), marked by :computer: 7 | 3. Parameters pre-defined for the general training scheme (no need to change for most problems; adjusting them requires basic knowledge of deep learning), marked by :ok: 8 | 9 | 10 | ### model-related parameters 11 | 12 | 1. choose which model to use (:pushpin:) 13 | ```yaml 14 | model: 15 | name: unet_xy_zoom 16 | zoom_ratio: 3 17 | ``` 18 | or 19 | ```yaml 20 | model: 21 | name: unet_xy 22 | ``` 23 | see model parameters in [training configuration](./doc_train_yaml.md) 24 | 25 | 2. input and output type (:ok:) 26 | ```yaml 27 | nchannel: 1 28 | nclass: [2, 2, 2] 29 | OutputCh: [0, 1] 30 | ``` 31 | These are related to the model architecture and fixed by default. 32 | 33 | 3. patch size (:computer:) 34 | 35 | ```yaml 36 | size_in: [48, 148, 148] 37 | size_out: [20, 60, 60] 38 | ``` 39 | see patch size parameters in [training configuration](./doc_train_yaml.md) 40 | 41 | 4. model directory (:pushpin:) 42 | ```yaml 43 | model_path: '/home/model/checkpoint_epoch_300.pytorch' 44 | ``` 45 | This is the place to specify which trained model to run. 46 | 47 | 48 | ### Data Info (:pushpin:) 49 | ```yaml 50 | OutputDir: '//allen/aics/assay-dev/Segmentation/DeepLearning/for_april_2019_release/' 51 | InputCh: [0] 52 | ResizeRatio: [1.0, 1.0, 1.0] 53 | Threshold: 0.75 54 | RuntimeAug: False 55 | Normalization: 10 56 | mode: 57 | name: folder 58 | InputDir: '/allen/aics/assay-dev/Segmentation/DeepLearning/for_april_2019_release/LMNB1_fluorescent' 59 | DataType: tiff 60 | ``` 61 | 62 | `DataType` is the type of images to be processed in `InputDir`; the `InputCh`-th channel (keep the [ ]) of each image will be segmented.
If your model is trained on images of a certain resolution and your test images are of a different resolution, `ResizeRatio` needs to be set as [new_z_size/old_z_size, new_y_size/old_y_size, new_x_size/old_x_size]. The actual output is the likelihood of each voxel being the target structure. A `Threshold` between 0 and 1 needs to be set to generate the binary mask. We recommend a value between 0.6 and 0.9. When `Threshold` is set as `-1`, the raw prediction from the model will be saved, for users to determine a proper binary cutoff. `Normalization` is the index of a list of pre-defined normalization recipes and should be the same index as the one used when generating the training data (see [Curator](./bb2.md) for the full list of normalization recipes). 63 | 64 | 65 | -------------------------------------------------------------------------------- /docs/doc_train_yaml.md: -------------------------------------------------------------------------------- 1 | # Configuration for Training a DL Model in **Trainer** 2 | 3 | This is a detailed description of the configuration for training a DL model in **Trainer**. There are a lot of parameters in the configuration file, which can be categorized into four types: 4 | 5 | 1. Need to change on every run (i.e., parameters specific to each execution), marked by :warning: 6 | * `checkpoint_dir` (where to save trained models), `datafolder` (where the training data are), `resume` (whether to start from a previous model) 7 | 2. Need to change for every segmentation problem (i.e., parameters specific to one problem), marked by :pushpin: 8 | * `model`, `epochs`, `save_every_n_epoch`, `PatchPerBuffer` 9 | 3. Only need to change once on a particular machine (parameters specific to each machine), marked by :computer: 10 | 4. No need to change for most problems (parameters pre-defined as a general training scheme that require advanced deep learning knowledge to adjust), marked by :ok: 11 | 12 | 13 | ### Model related parameters 14 | 15 | 1.
choose which model to use (:pushpin:) 16 | ```yaml 17 | model: 18 | name: unet_xy_zoom 19 | zoom_ratio: 3 20 | ``` 21 | or 22 | ```yaml 23 | model: 24 | name: unet_xy 25 | ``` 26 | There are probably more than 100 models in the literature for 3D image segmentation. The two models we implemented here are carefully designed for cell structure segmentation in 3D microscopy images. Model `unet_xy` is suitable for smaller-scale structures, e.g., several voxels thick (e.g., tubulin, Lamin B1). Model `unet_xy_zoom` is more suitable for larger-scale structures, e.g., more than 100 voxels in diameter (e.g., nucleus), where `zoom_ratio` is an integer (e.g., 2 or 3) that can be estimated as the average diameter of the target object in voxels divided by 150. 27 | 28 | 2. start from an existing model? (:warning:) 29 | 30 | ```yaml 31 | resume: null 32 | ``` 33 | 34 | When doing iterative deep learning, it may be useful to start from the model trained in the previous step. The model can be specified at `resume`. If `null`, a new model will be trained from scratch. 35 | 36 | 3. input and output type (:ok:) 37 | ```yaml 38 | nchannel: 1 39 | nclass: [2, 2, 2] 40 | ``` 41 | These are related to the model architecture and fixed by default. We assume the input image has only one channel. 42 | 43 | 4. patch size (:computer:) (:pushpin:) 44 | 45 | ```yaml 46 | size_in: [50, 156, 156] 47 | size_out: [22, 68, 68] 48 | ``` 49 | In most situations, we cannot fit the entire image into the memory of a single GPU. These are also related to `batch_size` (a data loader parameter), which will be discussed shortly. `size_in` is the actual size of each patch fed into the model, while `size_out` is the size of the model's prediction. The prediction size is smaller than the input size because of the multiple convolution operations. The equations for calculating `size_in` and `size_out` are as follows.
50 | 51 | > For unet_xy, `size_in` = `[z, 8p+60, 8p+60]`, `size_out` = `[z-28, 8p-28, 8p-28]` 52 | 53 | > For unet_xy_zoom, with `zoom_ratio`=`k`, `size_in` = `[z, 8kp+60k, 8kp+60k]` and `size_out` = `[z-32, 8kp-28k-4, 8kp-28k-4]` 54 | 55 | Here, `p` and `z` can be any positive integers that make all values of `size_out` positive. 56 | 57 | Here are some pre-calculated values for different models on different types of GPUs. 58 | 59 | | | size_in | size_out | batch_size | 60 | | --------------------------------------|:-----------------:|:-----------------:|:-------------:| 61 | | unet_xy on 12GB GPU | [44, 140, 140] | [16, 52, 52] | 4 | 62 | | unet_xy on 33GB GPU | [50, 156, 156] | [22, 68, 68] | 8 | 63 | | unet_xy_zoom (ratio=3) on 12GB GPU | [52, 372, 372] | [20, 104, 104] | 4 | 64 | | unet_xy_zoom (ratio=3) on 33GB GPU | [52, 420, 420] | [20, 152, 152] | 8 | 65 | 66 | 5. model directory (:warning:) 67 | ```yaml 68 | checkpoint_dir: /home/model/xyz/ 69 | resume: null 70 | ``` 71 | This is the directory to save the trained model. If you want to start this training from a previously saved model, you may add the path to `resume`. 72 | 73 | ### Training scheme related parameters 74 | 1. optimization parameters (:ok:) 75 | ```yaml 76 | learning_rate: 0.00001 77 | weight_decay: 0.005 78 | loss: 79 | name: Aux 80 | loss_weight: [1, 1, 1] 81 | ignore_index: null 82 | ``` 83 | 84 | 2. training epochs (:pushpin:) 85 | ```yaml 86 | epochs: 400 87 | save_every_n_epoch: 40 88 | ``` 89 | `epochs` controls the number of training epochs. We suggest using values between 300 and 600. A model will be saved every `save_every_n_epoch` epochs. 90 | 91 | 92 | ### Data related parameters 93 | 94 | ```yaml 95 | loader: 96 | name: default 97 | datafolder: '/home/data/train/' 98 | batch_size: 8 99 | PatchPerBuffer: 200 100 | epoch_shuffle: 5 101 | NumWorkers: 1 102 | ``` 103 | `datafolder` (:warning:) and `PatchPerBuffer` (:pushpin:) need to be specified for each problem.
`datafolder` is the directory of the training data. `PatchPerBuffer` is the number of sample patches randomly drawn in each epoch, which can be set as *number of patches to draw from each image* **x** *number of training images*. `name`, `epoch_shuffle` and `NumWorkers` (:ok:) are fixed by default. `batch_size` is related to GPU memory and patch size (see values presented with patch size). 104 | 105 | ### Validation related parameter 106 | 107 | In machine learning studies, we usually do a validation after every few epochs to make sure things are not going wrong. For most 3D microscopy image segmentation problems, the training data is very limited, so we cannot set aside a large portion (e.g., 20%) of the training data for validation purposes. So, by default, we use leave-one-out for validation (:ok:), which only advanced users may need to adjust. -------------------------------------------------------------------------------- /docs/nvidia_smi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/docs/nvidia_smi.png -------------------------------------------------------------------------------- /docs/overview.md: -------------------------------------------------------------------------------- 1 | # Allen Cell Structure Segmenter Tutorials: 2 | 3 | The Allen Cell Structure Segmenter is an open source toolkit developed at the Allen Institute for Cell Science for 3D segmentation of intracellular structures in fluorescence microscope images, which brings together classic image segmentation and iterative deep learning workflows. Details including algorithms, validations, and examples can be found in our [bioRxiv paper](https://www.biorxiv.org/content/10.1101/491035v1) or [allencell.org/segmenter](https://allencell.org/segmenter).
This tutorial will focus on how to run the *Allen Cell Structure Segmenter* (both the classic image segmentation workflow and the iterative DL workflow) to get an accurate segmentation. 4 | 5 | The Allen Cell Structure Segmenter is implemented as two packages: [`aicssegmentation`](https://pypi.org/project/aicssegmentation/) (classic image segmentation) and [`aicsmlsegment`](https://pypi.org/project/aicsmlsegment/) (deep learning segmentation). The execution is based on three building blocks: **Binarizer**, **Curator** and **Trainer**. We will explain how each building block works and demonstrate with real examples. 6 | 7 | *Note: The image reader used in our package supports images in common formats, such as `.tiff`, `.tif`, `.ome.tif`. The only vendor-specific format supported by the reader is `.czi` (the file format for ZEISS microscopes). For other formats, images have to be converted to `.tiff` or `.ome.tif` in advance.* 8 | 9 | ## Installation: 10 | 11 | * `aicssegmentation` (classic image segmentation): [Installation instruction](https://github.com/AllenCell/aics-segmentation) (available on Linux, MacOS, Windows) 12 | * `aicsmlsegment` (deep learning segmentation): [Installation instruction](../README.md) (requires NVIDIA GPU and Linux OS) 13 | 14 | 15 | ## Understanding each building block: 16 | 17 | * **Binarizer**: [documentation](./bb1.md) 18 | * **Curator**: [documentation](./bb2.md) 19 | * **Trainer**: [documentation](./bb3.md) 20 | 21 | ## Challenges in deep learning-based segmentation 22 | 23 | Deep learning (DL) is a very powerful approach for 3D image segmentation. But it is not as simple as collecting a set of segmentation ground truth, feeding them into a DL model and getting a perfect segmentation model. DL for 3D image segmentation is still being investigated in the field of computer vision (see top conferences organized by [MICCAI](http://www.miccai.org/) and [CVF](https://www.thecvf.com/)).
It is possible for a model trained with our package to still fail to produce accurate results. This could be due to many reasons, and finding ways to improve the model is beyond the scope of this tutorial. Here, we want to focus on demonstrating how to use our package, the DL part of which is designed to (1) obtain a good segmentation model that works on images with wide variability and (2) be flexible enough for advanced users to develop their own research on DL-based 3D segmentation. 24 | 25 | ## Demos on real examples: 26 | 27 | ![overview pic](./overview_pic.png) 28 | 29 | The above flowchart is a simplified version of the segmenter showing the most important parts of the workflows. **Binarizer** can be either classic segmentation algorithms or a DL model to compute the binary segmentation. **Curator** and **Trainer** are used to improve the segmentation from **Binarizer** when necessary. More details can be found in our [bioRxiv paper](https://www.biorxiv.org/content/10.1101/491035v1). Here, we will demonstrate two examples: the first uses only **Binarizer** to solve the problem (i.e., a classic image segmentation workflow), and the second also requires **Curator** and **Trainer** (which together make up the iterative DL workflow).
30 | 31 | ### Example 1: Segmentation of ATP2A2 in 3D fluorescent microscopy images of hiPS cells 32 | 33 | ![demo1 pic](./demo1_pic.png) 34 | 35 | [Link to the demo documentation](./demo_1.md) 36 | 37 | [Link to the demo video](https://youtu.be/Ynl_Yt9N8p4) 38 | 39 | ### Example 2: Segmentation of Lamin B1 in 3D fluorescent microscopy images of hiPS cells 40 | 41 | ![demo2 pic](./demo2_pic.png) 42 | 43 | [Link to the demo documentation](./demo_2.md) 44 | 45 | [Link to demo data](https://open.quiltdata.com/b/allencell/packages/aics/laminb1_sample_data) 46 | 47 | [Link to the demo video](https://youtu.be/5jBSp38ezG8) 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /docs/overview_pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/docs/overview_pic.png -------------------------------------------------------------------------------- /docs/wf_pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/docs/wf_pic.png -------------------------------------------------------------------------------- /gradle.properties: -------------------------------------------------------------------------------- 1 | # This points to the location of the script plugins used to provide the build tasks. 2 | # This also provides the gradle version - run './gradlew wrapper' to set it. 
3 | 4 | #### 5 | scriptPluginPrefix=https://aicsbitbucket.corp.alleninstitute.org/projects/sw/repos/gradle-script-plugins/raw/ 6 | scriptPluginSuffix=?at=refs/tags/ 7 | 8 | # This variable is separated in order to allow it to be overridden with an environment variable 9 | # See: https://docs.gradle.org/current/userguide/build_environment.html#sec:project_properties 10 | # This will be overridden by Jenkins with the version defined in the Jenkins configuration. 11 | # The variable can also be used locally: `ORG_GRADLE_PROJECT_scriptPluginTag=2.4.3 ./gradlew tasks` 12 | scriptPluginTag=2.4.2 13 | -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenCell/aics-ml-segmentation/5c15d9720e1c3297377c586a98209792792a43f8/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionBase=GRADLE_USER_HOME 2 | distributionPath=wrapper/dists 3 | distributionUrl=https\://services.gradle.org/distributions/gradle-5.1-bin.zip 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | -------------------------------------------------------------------------------- /gradlew: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | ############################################################################## 4 | ## 5 | ## Gradle start up script for UN*X 6 | ## 7 | ############################################################################## 8 | 9 | # Attempt to set APP_HOME 10 | # Resolve links: $0 may be a link 11 | PRG="$0" 12 | # Need this for relative symlinks. 
13 | while [ -h "$PRG" ] ; do 14 | ls=`ls -ld "$PRG"` 15 | link=`expr "$ls" : '.*-> \(.*\)$'` 16 | if expr "$link" : '/.*' > /dev/null; then 17 | PRG="$link" 18 | else 19 | PRG=`dirname "$PRG"`"/$link" 20 | fi 21 | done 22 | SAVED="`pwd`" 23 | cd "`dirname \"$PRG\"`/" >/dev/null 24 | APP_HOME="`pwd -P`" 25 | cd "$SAVED" >/dev/null 26 | 27 | APP_NAME="Gradle" 28 | APP_BASE_NAME=`basename "$0"` 29 | 30 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 31 | DEFAULT_JVM_OPTS='"-Xmx64m"' 32 | 33 | # Use the maximum available, or set MAX_FD != -1 to use that value. 34 | MAX_FD="maximum" 35 | 36 | warn () { 37 | echo "$*" 38 | } 39 | 40 | die () { 41 | echo 42 | echo "$*" 43 | echo 44 | exit 1 45 | } 46 | 47 | # OS specific support (must be 'true' or 'false'). 48 | cygwin=false 49 | msys=false 50 | darwin=false 51 | nonstop=false 52 | case "`uname`" in 53 | CYGWIN* ) 54 | cygwin=true 55 | ;; 56 | Darwin* ) 57 | darwin=true 58 | ;; 59 | MINGW* ) 60 | msys=true 61 | ;; 62 | NONSTOP* ) 63 | nonstop=true 64 | ;; 65 | esac 66 | 67 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 68 | 69 | # Determine the Java command to use to start the JVM. 70 | if [ -n "$JAVA_HOME" ] ; then 71 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 72 | # IBM's JDK on AIX uses strange locations for the executables 73 | JAVACMD="$JAVA_HOME/jre/sh/java" 74 | else 75 | JAVACMD="$JAVA_HOME/bin/java" 76 | fi 77 | if [ ! -x "$JAVACMD" ] ; then 78 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 79 | 80 | Please set the JAVA_HOME variable in your environment to match the 81 | location of your Java installation." 82 | fi 83 | else 84 | JAVACMD="java" 85 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 86 | 87 | Please set the JAVA_HOME variable in your environment to match the 88 | location of your Java installation." 
89 | fi 90 | 91 | # Increase the maximum file descriptors if we can. 92 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then 93 | MAX_FD_LIMIT=`ulimit -H -n` 94 | if [ $? -eq 0 ] ; then 95 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 96 | MAX_FD="$MAX_FD_LIMIT" 97 | fi 98 | ulimit -n $MAX_FD 99 | if [ $? -ne 0 ] ; then 100 | warn "Could not set maximum file descriptor limit: $MAX_FD" 101 | fi 102 | else 103 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 104 | fi 105 | fi 106 | 107 | # For Darwin, add options to specify how the application appears in the dock 108 | if $darwin; then 109 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 110 | fi 111 | 112 | # For Cygwin, switch paths to Windows format before running java 113 | if $cygwin ; then 114 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 115 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 116 | JAVACMD=`cygpath --unix "$JAVACMD"` 117 | 118 | # We build the pattern for arguments to be converted via cygpath 119 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 120 | SEP="" 121 | for dir in $ROOTDIRSRAW ; do 122 | ROOTDIRS="$ROOTDIRS$SEP$dir" 123 | SEP="|" 124 | done 125 | OURCYGPATTERN="(^($ROOTDIRS))" 126 | # Add a user-defined pattern to the cygpath arguments 127 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 128 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 129 | fi 130 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 131 | i=0 132 | for arg in "$@" ; do 133 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 134 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 135 | 136 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 137 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 138 | else 139 | eval `echo args$i`="\"$arg\"" 140 | fi 141 | i=$((i+1)) 142 | done 143 | case $i in 144 | (0) set -- ;; 145 | (1) 
set -- "$args0" ;; 146 | (2) set -- "$args0" "$args1" ;; 147 | (3) set -- "$args0" "$args1" "$args2" ;; 148 | (4) set -- "$args0" "$args1" "$args2" "$args3" ;; 149 | (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 150 | (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 151 | (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; 152 | (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; 153 | (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 154 | esac 155 | fi 156 | 157 | # Escape application args 158 | save () { 159 | for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done 160 | echo " " 161 | } 162 | APP_ARGS=$(save "$@") 163 | 164 | # Collect all arguments for the java command, following the shell quoting and substitution rules 165 | eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" 166 | 167 | # by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong 168 | if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then 169 | cd "$(dirname "$0")" 170 | fi 171 | 172 | exec "$JAVACMD" "$@" 173 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @if "%DEBUG%" == "" @echo off 2 | @rem ########################################################################## 3 | @rem 4 | @rem Gradle startup script for Windows 5 | @rem 6 | @rem ########################################################################## 7 | 8 | @rem Set local scope for the variables with windows NT shell 9 | if "%OS%"=="Windows_NT" setlocal 10 | 11 | set DIRNAME=%~dp0 12 | if "%DIRNAME%" == "" set DIRNAME=. 
13 | set APP_BASE_NAME=%~n0 14 | set APP_HOME=%DIRNAME% 15 | 16 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 17 | set DEFAULT_JVM_OPTS="-Xmx64m" 18 | 19 | @rem Find java.exe 20 | if defined JAVA_HOME goto findJavaFromJavaHome 21 | 22 | set JAVA_EXE=java.exe 23 | %JAVA_EXE% -version >NUL 2>&1 24 | if "%ERRORLEVEL%" == "0" goto init 25 | 26 | echo. 27 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 28 | echo. 29 | echo Please set the JAVA_HOME variable in your environment to match the 30 | echo location of your Java installation. 31 | 32 | goto fail 33 | 34 | :findJavaFromJavaHome 35 | set JAVA_HOME=%JAVA_HOME:"=% 36 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 37 | 38 | if exist "%JAVA_EXE%" goto init 39 | 40 | echo. 41 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 42 | echo. 43 | echo Please set the JAVA_HOME variable in your environment to match the 44 | echo location of your Java installation. 45 | 46 | goto fail 47 | 48 | :init 49 | @rem Get command-line arguments, handling Windows variants 50 | 51 | if not "%OS%" == "Windows_NT" goto win9xME_args 52 | 53 | :win9xME_args 54 | @rem Slurp the command line arguments. 55 | set CMD_LINE_ARGS= 56 | set _SKIP=2 57 | 58 | :win9xME_args_slurp 59 | if "x%~1" == "x" goto execute 60 | 61 | set CMD_LINE_ARGS=%* 62 | 63 | :execute 64 | @rem Setup the command line 65 | 66 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 67 | 68 | @rem Execute Gradle 69 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 70 | 71 | :end 72 | @rem End local scope for the variables with windows NT shell 73 | if "%ERRORLEVEL%"=="0" goto mainEnd 74 | 75 | :fail 76 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 77 | rem the _cmd.exe /c_ return code! 
78 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 79 | exit /b 1 80 | 81 | :mainEnd 82 | if "%OS%"=="Windows_NT" endlocal 83 | 84 | :omega 85 | -------------------------------------------------------------------------------- /settings.gradle: -------------------------------------------------------------------------------- 1 | rootProject.name = 'aicsmlsegment' -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal=1 3 | 4 | [metadata] 5 | license_file = LICENSE.txt 6 | 7 | [aliases] 8 | test=pytest 9 | 10 | [tool:pytest] 11 | addopts = --junitxml=build/test_report.xml -v 12 | norecursedirs = aicsmlsegment/tests/checkouts .egg* build dist venv .gradle aicsmlsegment.egg-info/* 13 | 14 | [flake8] 15 | max-line-length = 130 16 | 17 | [coverage:html] 18 | directory = build/coverage_html 19 | title = Test coverage report for aicsbatch 20 | 21 | [coverage:xml] 22 | output = build/coverage.xml 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | PACKAGE_NAME = 'aicsmlsegment' 5 | 6 | """ 7 | Notes: 8 | We get the constants MODULE_VERSION from 9 | See (3) in following link to read about versions from a single source 10 | https://packaging.python.org/guides/single-sourcing-package-version/#single-sourcing-the-version 11 | """ 12 | 13 | MODULE_VERSION = "" 14 | exec(open(PACKAGE_NAME + "/version.py").read()) 15 | 16 | 17 | def readme(): 18 | with open('README.md') as f: 19 | return f.read() 20 | 21 | 22 | test_deps = ['pytest', 'pytest-cov'] 23 | lint_deps = ['flake8'] 24 | all_deps = [*test_deps, *lint_deps] 25 | extras = { 26 | 'test_group': test_deps, 27 | 'lint_group': lint_deps, 28 | 'all': all_deps 29 | } 30 | 31 | setup(name=PACKAGE_NAME, 
32 | version=MODULE_VERSION, 33 | description='Scripts for ML structure segmentation.', 34 | long_description=readme(), 35 | author='AICS', 36 | author_email='jianxuc@alleninstitute.org', 37 | license='Allen Institute Software License', 38 | packages=find_packages(exclude=['tests', '*.tests', '*.tests.*']), 39 | entry_points={ 40 | "console_scripts": [ 41 | "dl_train={}.bin.train:main".format(PACKAGE_NAME), 42 | "dl_predict={}.bin.predict:main".format(PACKAGE_NAME), 43 | "curator_merging={}.bin.curator.curator_merging:main".format(PACKAGE_NAME), 44 | "curator_sorting={}.bin.curator.curator_sorting:main".format(PACKAGE_NAME), 45 | "curator_takeall={}.bin.curator.curator_takeall:main".format(PACKAGE_NAME), 46 | ] 47 | }, 48 | install_requires=[ 49 | 'numpy>=1.15.1', 50 | 'scipy>=1.1.0', 51 | 'scikit-image', 52 | 'pandas>=0.23.4', 53 | 'aicsimageio>3.3.0', 54 | 'tqdm', 55 | 'pyyaml', 56 | 'aicssegmentation', 57 | #'pytorch=1.0.0' 58 | ], 59 | 60 | # For test setup. This will allow JUnit XML output for Jenkins 61 | setup_requires=['pytest-runner'], 62 | tests_require=test_deps, 63 | 64 | extras_require=extras, 65 | zip_safe=False 66 | ) 67 | --------------------------------------------------------------------------------