├── .gitignore ├── LICENSE ├── README.md ├── _config.yml ├── cococrawler ├── __init__.py └── getcoco17.py ├── command.csh ├── data ├── imgs │ └── train1.png ├── masks │ └── train1.png └── saved_models │ └── segcapsr3 │ ├── split-0_batch-1_shuff-1_aug-0_loss-dice_slic-1_sub--1_strid-1_lr-0.0001_recon-20.0_model_20180705-092846.hdf5 │ ├── split-0_batch-1_shuff-1_aug-0_loss-dice_slic-1_sub--1_strid-1_lr-0.01_recon-20.0_model_20180707-222802.hdf5 │ ├── split-0_batch-1_shuff-1_aug-0_loss-dice_slic-1_sub--1_strid-1_lr-0.01_recon-20.0_model_20180713-041900.hdf5 │ ├── split-0_batch-1_shuff-1_aug-0_loss-mar_slic-1_sub--1_strid-1_lr-0.01_recon-20.0_model_20180723-235354.hdf5 │ ├── split-0_batch-1_shuff-1_aug-1_loss-dice_slic-1_sub--1_strid-1_lr-0.01_recon-2.0_model_20180702-055808.hdf5 │ └── split-0_batch-1_shuff-1_aug-1_loss-mar_slic-1_sub--1_strid-1_lr-0.01_recon-20.0_model_20180724-060706.hdf5 ├── gen_mask.py ├── imgs ├── baselinecaps.png ├── manip_cropped.png ├── overfit-test.png ├── qualitative1.png ├── segcaps.png └── webcam.png ├── main.py ├── manip.py ├── models ├── __init__.py ├── densenets.py └── unet.py ├── notebook ├── 1.png ├── 2.png ├── 20180611-CapsulesForObjectSegmentation.ipynb ├── 20180613-COCO Dataset Processing for Class of Person .ipynb ├── 20180621-image_convert_testing.ipynb ├── 20180629-SegCaps-image-segmentation-with Color image input.ipynb ├── 20180629-image_convert_testing-Color-image.ipynb ├── 20180630-SegCaps-image-segmentation-with Color image input.ipynb ├── 20180701-SegCapsR3-image-segmentation-with Color image input.ipynb ├── 20180701-Unet-image-segmentation-with Color image input.ipynb ├── 20180702-Capsule Net-image-segmentation-with Color image input.ipynb ├── 3.png └── 4.png ├── raspberrypi ├── Raspi3-install.sh ├── __init__.py └── opencv-install.sh ├── requirements.txt ├── segcapsnet ├── __init__.py ├── capsnet.py ├── capsule_layers.py └── subpixel_upscaling.py ├── test.py ├── train.py └── utils ├── __init__.py ├── custom_data_aug.py ├── custom_losses.py ├── data_helper.py ├── load_2D_data.py ├── load_3D_data.py ├── load_data.py ├── metrics.py ├── model_helper.py └── threadsafe.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # Eclipse project file 107 | .pydevproject 108 | .project 109 | 110 | # SegCaps 111 | data/evaluationScript 112 | data/figs/ 113 | data/imgs/ 114 | data/logs/ 115 | data/masks/ 116 | data/np_files/ 117 | data/plots/ 118 | data/results/ 119 | reference/ 120 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /cococrawler/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/cococrawler/__init__.py -------------------------------------------------------------------------------- /cococrawler/getcoco17.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | Data crawler for MS COCO 2017 semantic segmentation. 5 | Tasks: Download specific category images from MSCOCO web and generate pixel level masking image files on PNG format. 6 | 7 | @author: Cheng-Lin Li a.k.a. Clark 8 | 9 | @copyright: 2018 Cheng-Lin Li@Insight AI. All rights reserved. 10 | 11 | @license: Licensed under the Apache License v2.0. http://www.apache.org/licenses/ 12 | 13 | @contact: clark.cl.li@gmail.com 14 | @version: 1.2 15 | 16 | @create: June 13, 2018 17 | @updated: June 28, 2018 18 | 19 | Tasks: 20 | The program implementation leverage pycocotools to batch download images by specific category and generate mask files for image segmentation tasks. 21 | 22 | 23 | Data: 24 | Currently focus on person category data. 25 | 26 | Enhancement: 27 | version 1.2: 28 | Support image download by Image IDs with specific masking class. 29 | 30 | ''' 31 | import logging 32 | import argparse 33 | from os.path import join 34 | from pycocotools.coco import COCO 35 | import numpy as np 36 | import skimage.io as io 37 | import matplotlib.pyplot as plt 38 | import os 39 | from tqdm import tqdm 40 | import cv2 41 | 42 | FILE_MIDDLE_NAME = 'train' 43 | IMAGE_FOLDER = 'imgs' 44 | MASK_FOLDER = 'masks' 45 | RESOLUTION = 512 # Resolution of the input for the model. 46 | BACKGROUND_COLOR = (0, 0, 0) # Black background color for padding areas 47 | 48 | def image_resize2square(image, desired_size = None): 49 | ''' 50 | Resize image to a square by specific resolution(desired_size). 51 | ''' 52 | assert (image is not None), 'Image cannot be None.' 53 | 54 | # Initialize the dimensions of the image to be resized and 55 | # grab the size of image 56 | old_size = image.shape[:2] 57 | 58 | # if both the width and height are None, then return the 59 | # original image 60 | if desired_size is None or desired_size == 0: 61 | return image 62 | 63 | # calculate the ratio of the height and construct theima 64 | # dimensions 65 | ratio = float(desired_size) / max(old_size) 66 | new_size = tuple([int(x * ratio) for x in old_size]) 67 | 68 | # new_size should be in (width, height) format 69 | resized = cv2.resize(image, (new_size[1], new_size[0])) 70 | 71 | delta_w = desired_size - new_size[1] 72 | delta_h = desired_size - new_size[0] 73 | top, bottom = delta_h // 2, delta_h - (delta_h // 2) 74 | left, right = delta_w // 2, delta_w - (delta_w // 2) 75 | 76 | # Assign background color for padding areas. Default is Black. 77 | bg_color = BACKGROUND_COLOR 78 | new_image = cv2.copyMakeBorder(resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value = bg_color) 79 | 80 | # return the resized image 81 | return new_image 82 | def create_path(data_dir): 83 | ''' 84 | Create a specific path to store result images. 
85 | - Under the data directory, two separated folders will store image and masking files 86 | - Example: 87 | - data_dir- 88 | |- IMAGE_FOLDER 89 | |- MASK_FOLDER 90 | 91 | ''' 92 | try: 93 | output_image_path = join(data_dir, IMAGE_FOLDER) 94 | if not os.path.isdir(output_image_path): 95 | os.makedirs(output_image_path) 96 | output_mask_path = join(data_dir, MASK_FOLDER) 97 | if not os.path.isdir(output_mask_path): 98 | os.makedirs(output_mask_path) 99 | return output_image_path, output_mask_path 100 | except Exception as e: 101 | logging.error('\nCreate folders error! Message: %s'%(str(e))) 102 | exit(0) 103 | 104 | 105 | def main(args): 106 | ''' 107 | The main entry point of the program 108 | - This program will download image from MS COCO 2017 (Microsoft Common Objects in Context) repo 109 | and generate annotation to the specific object classes. 110 | ''' 111 | plt.ioff() 112 | 113 | data_dir = args.data_root_dir 114 | category_list = list(args.category) 115 | annFile = args.annotation_file 116 | num = args.number 117 | file_name = '' 118 | 119 | #Create path for output 120 | output_image_path, output_mask_path = create_path(data_dir) 121 | 122 | # initialize COCO API for instance annotations 123 | coco=COCO(annFile) 124 | 125 | # get all images containing given categories, select one at random 126 | catIds = coco.getCatIds(catNms=category_list); 127 | 128 | if args.id is not None: 129 | imgIds = list(args.id) 130 | num = len(imgIds) 131 | else: 132 | # Get image id list from categories. 133 | imgIds = coco.getImgIds(catIds=catIds ); 134 | 135 | print('\nImage Generating...') 136 | for i in tqdm(range(num)): 137 | try: 138 | if args.id is not None: 139 | img = coco.loadImgs(imgIds[i])[0] 140 | else: 141 | img = coco.loadImgs(imgIds[np.random.randint(0,len(imgIds))])[0] 142 | except Exception as e: 143 | print('\nError: Image ID: %s cannot be found in the annotation file.'%(e)) 144 | continue 145 | 146 | # use url to load image 147 | I = io.imread(img['coco_url']) 148 | resolution = args.resolution 149 | if resolution != 0: 150 | I = image_resize2square(I, args.resolution) 151 | else: 152 | pass 153 | 154 | plt.axis('off') 155 | file_name = join(output_image_path, FILE_MIDDLE_NAME+str(i) + '.png') 156 | plt.imsave(file_name, I) 157 | 158 | # Get annotation 159 | annIds = coco.getAnnIds(imgIds=img['id'], catIds=catIds, iscrowd=None) 160 | anns = coco.loadAnns(annIds) 161 | mask = coco.annToMask(anns[0]) 162 | 163 | # Generate mask 164 | for j in range(len(anns)): 165 | mask += coco.annToMask(anns[j]) 166 | 167 | # Background color = (R,G,B)=[68, 1, 84] for MS COCO 2017 168 | # save the mask image 169 | mask = image_resize2square(mask, args.resolution) 170 | file_name = join(output_mask_path, FILE_MIDDLE_NAME+str(i) + '.png') 171 | plt.imsave(file_name, mask) 172 | 173 | print('\nProgram finished !') 174 | return True 175 | 176 | if __name__ == '__main__': 177 | ''' 178 | Main program for MS COCO 2017 annotation mask images generation. 179 | Example command: 180 | $python3 getcoco17 --data_root_dir ./data --category person dog --annotation_dir './annotations/instances_val2017.json --number 10' 181 | ''' 182 | 183 | parser = argparse.ArgumentParser(description = 'Download COCO 2017 image Data') 184 | parser.add_argument('--data_root_dir', type = str, required = False, 185 | help='The root directory for your data.') 186 | parser.add_argument('--category', nargs = '+', type=str, default = 'person', 187 | help='MS COCO object categories list (--category person dog cat). 
default value is person') 188 | parser.add_argument('--annotation_file', type = str, default = './instances_val2017.json', 189 | help='The annotation json file directory of MS COCO object categories list. file name should be instances_val2017.json') 190 | parser.add_argument('--resolution', type = int, default = 0, 191 | help='The resolution of images you want to transfer. It will be a square image.' 192 | 'Default is 0. resolution = 0 will keep original image resolution') 193 | parser.add_argument('--id', nargs = '+', type=int, 194 | help='The id of images you want to download from MS COCO dataset.' 195 | 'Number of images is equal to the number of ids. Masking will base on category.') 196 | parser.add_argument('--number', type = int, default = 10, 197 | help='The total number of images you want to download.') 198 | 199 | arguments = parser.parse_args() 200 | 201 | main(arguments) 202 | -------------------------------------------------------------------------------- /command.csh: -------------------------------------------------------------------------------- 1 | python3 ./main.py --train --data_root_dir=data --net segcapsr3 --initial_lr 0.01 --loglevel 2 --Kfold 4 --loss dice --dataset mscoco17 --recon_wei 20 --which_gpu -1 --gpus 1 --aug_data 0 2 | 3 | python3 ./main.py --train --data_root_dir=data --net capsbasic --initial_lr 0.01 --loglevel 2 --Kfold 4 --loss dice --dataset mscoco17 --recon_wei 20 --which_gpu -1 --gpus 1 --aug_data 0 4 | 5 | #python3 ./main.py --train --data_root_dir=data --net unet --initial_lr 0.01 --loglevel 2 --Kfold 4 --loss w_bce --dataset mscoco17 --recon_wei 2 6 | -------------------------------------------------------------------------------- /data/imgs/train1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/data/imgs/train1.png -------------------------------------------------------------------------------- /data/masks/train1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/data/masks/train1.png -------------------------------------------------------------------------------- /data/saved_models/segcapsr3/split-0_batch-1_shuff-1_aug-0_loss-dice_slic-1_sub--1_strid-1_lr-0.0001_recon-20.0_model_20180705-092846.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/data/saved_models/segcapsr3/split-0_batch-1_shuff-1_aug-0_loss-dice_slic-1_sub--1_strid-1_lr-0.0001_recon-20.0_model_20180705-092846.hdf5 -------------------------------------------------------------------------------- /data/saved_models/segcapsr3/split-0_batch-1_shuff-1_aug-0_loss-dice_slic-1_sub--1_strid-1_lr-0.01_recon-20.0_model_20180707-222802.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/data/saved_models/segcapsr3/split-0_batch-1_shuff-1_aug-0_loss-dice_slic-1_sub--1_strid-1_lr-0.01_recon-20.0_model_20180707-222802.hdf5 -------------------------------------------------------------------------------- /data/saved_models/segcapsr3/split-0_batch-1_shuff-1_aug-0_loss-dice_slic-1_sub--1_strid-1_lr-0.01_recon-20.0_model_20180713-041900.hdf5: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/data/saved_models/segcapsr3/split-0_batch-1_shuff-1_aug-0_loss-dice_slic-1_sub--1_strid-1_lr-0.01_recon-20.0_model_20180713-041900.hdf5 -------------------------------------------------------------------------------- /data/saved_models/segcapsr3/split-0_batch-1_shuff-1_aug-0_loss-mar_slic-1_sub--1_strid-1_lr-0.01_recon-20.0_model_20180723-235354.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/data/saved_models/segcapsr3/split-0_batch-1_shuff-1_aug-0_loss-mar_slic-1_sub--1_strid-1_lr-0.01_recon-20.0_model_20180723-235354.hdf5 -------------------------------------------------------------------------------- /data/saved_models/segcapsr3/split-0_batch-1_shuff-1_aug-1_loss-dice_slic-1_sub--1_strid-1_lr-0.01_recon-2.0_model_20180702-055808.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/data/saved_models/segcapsr3/split-0_batch-1_shuff-1_aug-1_loss-dice_slic-1_sub--1_strid-1_lr-0.01_recon-2.0_model_20180702-055808.hdf5 -------------------------------------------------------------------------------- /data/saved_models/segcapsr3/split-0_batch-1_shuff-1_aug-1_loss-mar_slic-1_sub--1_strid-1_lr-0.01_recon-20.0_model_20180724-060706.hdf5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/data/saved_models/segcapsr3/split-0_batch-1_shuff-1_aug-1_loss-mar_slic-1_sub--1_strid-1_lr-0.01_recon-20.0_model_20180724-060706.hdf5 -------------------------------------------------------------------------------- /gen_mask.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | ''' 4 | Generate image mask by trained model. 5 | Tasks: Input an image file and output a mask image file. 6 | 7 | @author: Cheng-Lin Li a.k.a. Clark 8 | 9 | @copyright: 2018 Cheng-Lin Li@Insight AI. All rights reserved. 10 | 11 | @license: Licensed under the Apache License v2.0. http://www.apache.org/licenses/ 12 | 13 | @contact: clark.cl.li@gmail.com 14 | 15 | Tasks: 16 | The program implementation will classify input image by a trained model and generate mask image as 17 | image segmentation results. 18 | 19 | 20 | Data: 21 | 22 | Currently focus on person category data. 23 | 24 | Reference: 25 | https://github.com/jrosebr1/imutils/blob/master/imutils/video/webcamvideostream.py 26 | 27 | ''' 28 | from threading import Thread 29 | import argparse 30 | import logging 31 | from os.path import join 32 | import numpy as np 33 | from skimage import measure, filters 34 | import scipy.ndimage.morphology 35 | 36 | from utils.model_helper import create_model 37 | 38 | # from data_helper import * 39 | from utils.load_2D_data import generate_test_image 40 | from utils.custom_data_aug import image_resize2square 41 | # from test import threshold_mask 42 | from datetime import datetime 43 | import cv2 44 | 45 | FILE_MIDDLE_NAME = 'train' 46 | IMAGE_FOLDER = 'imgs' 47 | MASK_FOLDER = 'masks' 48 | RESOLUTION = 512 # Resolution of the input for the model. 
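# Typical invocation (a sketch; the weight file name below is only a placeholder --
# use any .hdf5 produced by training. The flags match the argparse options defined
# in the __main__ block of this file):
#   python3 gen_mask.py --net segcapsr3 \
#       --weights_path data/saved_models/segcapsr3/<trained_weights>.hdf5
# The script builds the evaluation model once, then reads frames from the default
# webcam in a loop and overlays the predicted mask on each frame in red.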
49 | ARGS = None 50 | NET_INPUT = None 51 | 52 | 53 | class FPS: 54 | ''' 55 | Calculate Frame per Second 56 | ''' 57 | def __init__(self): 58 | # store the start time, end time, and total number of frames 59 | # that were examined between the start and end intervals 60 | self._start = None 61 | self._end = None 62 | self._numFrames = 0 63 | 64 | def start(self): 65 | # start the timer 66 | self._start = datetime.now() 67 | return self 68 | 69 | def stop(self): 70 | # stop the timer 71 | self._end = datetime.now() 72 | 73 | def update(self): 74 | # increment the total number of frames examined during the 75 | # start and end intervals 76 | self._numFrames += 1 77 | 78 | def elapsed(self): 79 | # return the total number of seconds between the start and 80 | # end interval 81 | return (self._end - self._start).total_seconds() 82 | 83 | def fps(self): 84 | # compute the (approximate) frames per second 85 | return self._numFrames / self.elapsed() 86 | 87 | class WebcamVideoStream: 88 | ''' 89 | Leverage thread to read video stream to speed up process time. 90 | ''' 91 | def __init__(self, src=0): 92 | # initialize the video camera stream and read the first frame 93 | # from the stream 94 | self.stream = cv2.VideoCapture(src) 95 | (self.grabbed, self.frame) = self.stream.read() 96 | 97 | # initialize the variable used to indicate if the thread should 98 | # be stopped 99 | self.stopped = False 100 | 101 | def start(self): 102 | # start the thread to read frames from the video stream 103 | t = Thread(target=self.update, args=()) 104 | t.daemon = True 105 | t.start() 106 | return self 107 | 108 | def update(self): 109 | # keep looping infinitely until the thread is stopped 110 | while True: 111 | # if the thread indicator variable is set, stop the thread 112 | if self.stopped: 113 | return 114 | 115 | # otherwise, read the next frame from the stream 116 | (self.grabbed, self.frame) = self.stream.read() 117 | 118 | def read(self): 119 | # return the frame most recently read 120 | return self.frame 121 | 122 | def stop(self): 123 | # indicate that the thread should be stopped 124 | self.stopped = True 125 | 126 | def threshold_mask(raw_output, threshold): #raw_output 3d:(119, 512, 512) 127 | if threshold == 0: 128 | try: 129 | threshold = filters.threshold_otsu(raw_output) 130 | except: 131 | threshold = 0.5 132 | 133 | logging.info('\tThreshold: {}'.format(threshold)) 134 | 135 | raw_output[raw_output > threshold] = 1 136 | raw_output[raw_output < 1] = 0 137 | 138 | #all_labels 3d:(119, 512, 512) 139 | all_labels = measure.label(raw_output) 140 | # props 3d: region of props=>list(_RegionProperties:) 141 | # with bbox. 
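    # Keep only the dominant connected component(s): if the largest region is more
    # than 5x the area of the runner-up, keep just the largest; otherwise keep the
    # two largest. Any holes in the kept region(s) are filled afterwards.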
142 | props = measure.regionprops(all_labels) 143 | props.sort(key=lambda x: x.area, reverse=True) 144 | thresholded_mask = np.zeros(raw_output.shape) 145 | 146 | if len(props) >= 2: 147 | # if the largest is way larger than the second largest 148 | if props[0].area / props[1].area > 5: 149 | thresholded_mask[all_labels == props[0].label] = 1 # only turn on the largest component 150 | else: 151 | thresholded_mask[all_labels == props[0].label] = 1 # turn on two largest components 152 | thresholded_mask[all_labels == props[1].label] = 1 153 | elif len(props): 154 | thresholded_mask[all_labels == props[0].label] = 1 155 | # threshold_mask: 3d=(119, 512, 512) 156 | thresholded_mask = scipy.ndimage.morphology.binary_fill_holes(thresholded_mask).astype(np.uint8) 157 | 158 | return thresholded_mask 159 | 160 | def apply_mask(image, mask): 161 | """apply mask to image""" 162 | 163 | 164 | redImg = np.zeros(image.shape, image.dtype) 165 | redImg[:,:] = (0,0,255) 166 | redMask = cv2.bitwise_and(redImg, redImg, mask=mask) 167 | cv2.addWeighted(redMask, 1, image, 1, 0, image) 168 | 169 | return image 170 | 171 | 172 | 173 | class segmentation_model(): 174 | ''' 175 | Model construction class for prediction 176 | ''' 177 | def __init__(self, args, net_input_shape): 178 | ''' 179 | Create evaluation model and load the pre-train weights for inference. 180 | ''' 181 | self.net_input_shape = net_input_shape 182 | weights_path = join(args.weights_path) 183 | # Create model object in inference mode but Disable decoder layer. 184 | _, eval_model, _ = create_model(args, net_input_shape, enable_decoder = False) 185 | 186 | # Load weights trained on MS-COCO by name because part of output layers are disable. 187 | eval_model.load_weights(weights_path, by_name=True) 188 | self.model = eval_model 189 | 190 | 191 | def detect(self, img_list, verbose = False): 192 | result = [] 193 | r = dict() 194 | 195 | for img_data in img_list: 196 | output_array = self.model.predict_generator(generate_test_image(img_data, 197 | self.net_input_shape, 198 | batchSize=1, 199 | numSlices=1, 200 | subSampAmt=0, 201 | stride=1), 202 | steps=1, max_queue_size=1, workers=4, 203 | use_multiprocessing=False, verbose=1) 204 | output = output_array[:,:,:,0] 205 | threshold_level = 0 206 | output_bin = threshold_mask(output, threshold_level) 207 | r['masks'] = output_bin[0,:,:] 208 | 209 | # If you want to test the masking without prediction, mark out above line and unmark below line. 210 | # Below line is make a dummy masking to test the speed. 211 | # r['masks'] = np.ones((512, 512), np.int8) # Testing 212 | result.append(r) 213 | return result 214 | 215 | 216 | if __name__ == '__main__': 217 | ''' 218 | Main program for images segmentation by mask image. 219 | Example command: 220 | $python3 gen_mask --input_file ../data/image/train1.png --net segcapsr3 --model_weight ../data/saved_models/segcapsr3/dice16-255.hdf5 221 | ''' 222 | 223 | parser = argparse.ArgumentParser(description = 'Mask image by segmentation algorithm') 224 | 225 | parser.add_argument('--net', type = str.lower, default = 'segcapsr3', 226 | choices = ['segcapsr3', 'segcapsr1', 'capsbasic', 'unet', 'tiramisu'], 227 | help = 'Choose your network.') 228 | parser.add_argument('--weights_path', type = str, required = True, 229 | help = '/path/to/trained_model.hdf5 from root. Set to "" for none.') 230 | parser.add_argument('--num_class', type = int, default = 2, 231 | help = 'Number of classes to segment. Default is 2. 
If number of classes > 2, ' 232 | ' the loss function will be softmax entropy and only apply on SegCapsR3' 233 | '** Current version only support binary classification tasks.') 234 | parser.add_argument('--which_gpus', type = str, default = '0', 235 | help='Enter "-2" for CPU only, "-1" for all GPUs available, ' 236 | 'or a comma separated list of GPU id numbers ex: "0,1,4".') 237 | parser.add_argument('--gpus', type = int, default = -1, 238 | help = 'Number of GPUs you have available for training. ' 239 | 'If entering specific GPU ids under the --which_gpus arg or if using CPU, ' 240 | 'then this number will be inferred, else this argument must be included.') 241 | 242 | 243 | args = parser.parse_args() 244 | net_input_shape = (RESOLUTION, RESOLUTION, 1) 245 | model = segmentation_model(args, net_input_shape) 246 | 247 | 248 | # # grab a pointer to the video stream and initialize the FPS counter 249 | # print('[INFO] sampling frames from webcam...') 250 | # cap = cv2.VideoCapture(0) 251 | 252 | # these 3 lines can control fps, frame width and height. 253 | # cap.set(cv2.CAP_PROP_FRAME_WIDTH, RESOLUTION) 254 | # cap.set(cv2.CAP_PROP_FRAME_HEIGHT, RESOLUTION) 255 | # cap.set(cv2.CAP_PROP_FPS, 0.1) 256 | # fps = FPS().start() 257 | 258 | # created a *threaded* video stream, allow the camera sensor to warmup, 259 | # and start the FPS counter 260 | print("[INFO] sampling THREADED frames from webcam...") 261 | vs = WebcamVideoStream(src=0).start() 262 | fps = FPS().start() 263 | 264 | # loop over some frames 265 | while fps._numFrames < 10000: 266 | # grab the frame from the capture stream and resize it to have a maximum 267 | # (grabbed, frame) = cap.read() 268 | frame = vs.read() 269 | frame = image_resize2square(frame, RESOLUTION) # frame = (512, 512, 3) 270 | 271 | # check to see if the frame should be displayed to our screen 272 | results = model.detect([frame], verbose=0) 273 | r = results[0] #r['masks'] = [512, 512] 274 | frame = apply_mask(frame, r['masks']) 275 | 276 | cv2.imshow("Frame", frame) 277 | # Press q or ESC to stop the video 278 | if cv2.waitKey(1) & 0xFF == ord('q') or cv2.waitKey(1) == 27: 279 | break 280 | else: 281 | pass 282 | # update the FPS counter 283 | fps.update() 284 | 285 | # stop the timer and display FPS information 286 | fps.stop() 287 | print("[INFO] elasped time: {:.2f}".format(fps.elapsed())) 288 | print("[INFO] approx. 
FPS: {:.2f}".format(fps.fps())) 289 | 290 | # do a bit of cleanup 291 | vs.release() 292 | cv2.destroyAllWindows() 293 | -------------------------------------------------------------------------------- /imgs/baselinecaps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/imgs/baselinecaps.png -------------------------------------------------------------------------------- /imgs/manip_cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/imgs/manip_cropped.png -------------------------------------------------------------------------------- /imgs/overfit-test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/imgs/overfit-test.png -------------------------------------------------------------------------------- /imgs/qualitative1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/imgs/qualitative1.png -------------------------------------------------------------------------------- /imgs/segcaps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/imgs/segcaps.png -------------------------------------------------------------------------------- /imgs/webcam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/imgs/webcam.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Capsules for Object Segmentation (SegCaps) 3 | Original Paper by Rodney LaLonde and Ulas Bagci (https://arxiv.org/abs/1804.04241) 4 | Code written by: Rodney LaLonde 5 | If you use significant portions of this code or the ideas from our paper, please cite it :) 6 | If you have any questions, please email me at lalonde@knights.ucf.edu. 7 | 8 | This is the main file for the project. From here you can train, test, 9 | and manipulate the SegCaps of models. 10 | Please see the README for detailed instructions for this project. 11 | 12 | ============== 13 | This is the entry point of the package to train UNet, tiramisu, 14 | Capsule Nets (capsbasic) or SegCaps(segcapsr1 or segcapsr3). 15 | 16 | @author: Cheng-Lin Li a.k.a. Clark 17 | 18 | @copyright:2018 Cheng-Lin Li@Insight AI. All rights reserved. 19 | 20 | @license: Licensed under the Apache License v2.0. 21 | http://www.apache.org/licenses/ 22 | 23 | @contact: clark.cl.li@gmail.com 24 | 25 | Tasks: 26 | The program load parameters for training, testing, manipulation 27 | for all models. 28 | 29 | 30 | Data: 31 | MS COCO 2017 or LUNA 2016 were tested on this package. 32 | You can leverage your own data set but the mask images should follow the format of MS COCO or with background color = 0 on each channel. 33 | 34 | Enhancement: 35 | 1. The program was modified to support python 3.6 on Ubuntu 18.04 and Windows 10. 36 | 2. 
Support not only 3D computed tomography scan images but also 2D Microsoft Common Objects in COntext (MS COCO) dataset images. 37 | 3. Add Kfold parameter for users to customize the cross validation task. K = 1 will force model to perform overfit. 38 | 4. Add retrain parameter to enable users to reload pre-trained weights and retrain the model. 39 | 5. Add initial learning rate for users to adjust. 40 | 6. Add steps per epoch for users to adjust. 41 | 7. Add number of patience for early stop of training to users. 42 | 8. Add 'bce_dice' loss function as binary cross entropy + soft dice coefficient. 43 | 9. Revise 'train', 'test', 'manip' flags from 0 or 1 to flags show up or not to indicate the behavior of main program. 44 | ''' 45 | 46 | from __future__ import print_function 47 | import sys 48 | import logging 49 | import platform 50 | from os.path import join 51 | from os import makedirs 52 | from os import environ 53 | import argparse 54 | import SimpleITK as sitk # image process 55 | from time import gmtime, strftime 56 | from keras.utils import print_summary 57 | from utils.load_data import load_data, split_data 58 | from utils.model_helper import create_model 59 | 60 | time = strftime("%Y%m%d-%H%M%S", gmtime()) 61 | RESOLUTION = 512 # Resolution of the input for the model. 62 | GRAYSCALE = True 63 | LOGGING_FORMAT = '%(levelname)s %(asctime)s: %(message)s' 64 | 65 | 66 | def main(args): 67 | # Ensure training, testing, and manip are not all turned off 68 | assert (args.train or args.test or args.manip), 'Cannot have train, test, and manip all set to 0, Nothing to do.' 69 | 70 | # Load the training, validation, and testing data 71 | try: 72 | train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num) 73 | except: 74 | # Create the training and test splits if not found 75 | logging.info('\nNo existing training, validate, test files...System will generate it.') 76 | split_data(args.data_root_dir, num_splits = args.Kfold) 77 | train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num) 78 | 79 | # Get image properties from first image. Assume they are all the same. 80 | logging.info('\nRead image files...%s'%(join(args.data_root_dir, 'imgs', train_list[0][0]))) 81 | # Get image shape from the first image. 82 | image = sitk.GetArrayFromImage(sitk.ReadImage(join(args.data_root_dir, 'imgs', train_list[0][0]))) 83 | img_shape = image.shape # # (x, y, channels) 84 | if args.dataset == 'luna16': 85 | net_input_shape = (img_shape[1], img_shape[2], args.slices) 86 | else: 87 | args.slices = 1 88 | if GRAYSCALE == True: 89 | net_input_shape = (RESOLUTION, RESOLUTION, 1) # only one channel 90 | else: 91 | net_input_shape = (RESOLUTION, RESOLUTION, 3) # Only access RGB 3 channels. 92 | # Create the model for training/testing/manipulation 93 | # enable_decoder = False only for SegCaps R3 to disable recognition image output on evaluation model 94 | # to speed up performance. 
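    # create_model returns a list of Keras models: the training model comes first
    # (model_list[0], summarized below and passed to train()); test() receives the
    # whole list, and manip() requires the three SegCaps models and uses
    # model_list[2] (see manip.py).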
95 | model_list = create_model(args=args, input_shape=net_input_shape, enable_decoder=True) 96 | print_summary(model=model_list[0], positions=[.38, .65, .75, 1.]) 97 | 98 | args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \ 99 | '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \ 100 | '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \ 101 | '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \ 102 | '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei) 103 | 104 | # args.output_name = 'sh-' + str(args.shuffle_data) + '_a-' + str(args.aug_data) 105 | 106 | args.time = time 107 | if platform.system() == 'Windows': 108 | args.use_multiprocessing = False 109 | else: 110 | args.use_multiprocessing = True 111 | args.check_dir = join(args.data_root_dir,'saved_models', args.net) 112 | try: 113 | makedirs(args.check_dir) 114 | except: 115 | pass 116 | 117 | args.log_dir = join(args.data_root_dir,'logs', args.net) 118 | try: 119 | makedirs(args.log_dir) 120 | except: 121 | pass 122 | 123 | args.tf_log_dir = join(args.log_dir, 'tf_logs') 124 | try: 125 | makedirs(args.tf_log_dir) 126 | except: 127 | pass 128 | 129 | args.output_dir = join(args.data_root_dir, 'plots', args.net) 130 | try: 131 | makedirs(args.output_dir) 132 | except: 133 | pass 134 | 135 | if args.train == True: 136 | from train import train 137 | # Run training 138 | train(args, train_list, val_list, model_list[0], net_input_shape) 139 | 140 | if args.test == True: 141 | from test import test 142 | # Run testing 143 | test(args, test_list, model_list, net_input_shape) 144 | 145 | if args.manip == True: 146 | from manip import manip 147 | # Run manipulation of segcaps 148 | manip(args, test_list, model_list, net_input_shape) 149 | 150 | 151 | def loglevel(level=0): 152 | assert isinstance(level, int) 153 | try: 154 | return [ 155 | # logging.CRITICAL, 156 | # logging.ERROR, 157 | logging.WARNING, # default 158 | logging.INFO, 159 | logging.DEBUG, 160 | logging.NOTSET, 161 | ][level] 162 | except LookupError: 163 | return logging.NOTSET 164 | 165 | 166 | if __name__ == '__main__': 167 | parser = argparse.ArgumentParser( 168 | description='Train on Medical Data or MS COCO dataset' 169 | ) 170 | parser.add_argument('--data_root_dir', type=str, required=True, 171 | help='The root directory for your data.') 172 | parser.add_argument('--weights_path', type=str, default='', 173 | help='/path/to/trained_model.hdf5 from root. 
Set to "" for none.') 174 | parser.add_argument('--split_num', type=int, default = 0, 175 | help='Which training split to train/test on.') 176 | parser.add_argument('--net', type=str.lower, default='segcapsr3', 177 | choices=['segcapsr3', 'segcapsr1', 'capsbasic', 'unet', 'tiramisu'], 178 | help='Choose your network.') 179 | parser.add_argument('--train', action='store_true', 180 | help='Add this flag to enable training.') 181 | parser.add_argument('--test', action='store_true', 182 | help='Add this flag to enable testing.') 183 | parser.add_argument('--manip', action='store_true', 184 | help='Add this flag to enable manipulation.') 185 | parser.add_argument('--shuffle_data', type=int, default=1, choices=[0, 1], 186 | help='Whether or not to shuffle the training data (both per epoch and in slice order.') 187 | parser.add_argument('--aug_data', type=int, default=1, choices=[0, 1], 188 | help='Whether or not to use data augmentation during training.') 189 | parser.add_argument('--loss', type=str.lower, default='w_bce', choices=['bce', 'w_bce', 'dice', 'bce_dice', 'mar', 'w_mar'], 190 | help='Which loss to use. "bce" and "w_bce": unweighted and weighted binary cross entropy' 191 | ', "dice": soft dice coefficient, "bce_dice": binary cross entropy + soft dice coefficient' 192 | ', "mar" and "w_mar": unweighted and weighted margin loss.') 193 | # TODO: multiclass segmentation. 194 | # # Calculate distance from actual labels using cross entropy 195 | # cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=correct_label_reshaped[:]) 196 | # #Take mean for total loss 197 | # loss_op = tf.reduce_mean(cross_entropy, name="fcn_loss") 198 | parser.add_argument('--batch_size', type=int, default=1, 199 | help='Batch size for training/testing.') 200 | parser.add_argument('--initial_lr', type=float, default=0.00001, 201 | help='Initial learning rate for Adam.') 202 | parser.add_argument('--steps_per_epoch', type=int, default=1000, 203 | help='Number of iterations in an epoch.') 204 | parser.add_argument('--epochs', type=int, default=20, 205 | help='Number of epochs for training.') 206 | parser.add_argument('--patience', type=int, default=10, 207 | help='Number of patience indicates the criteria of early stop training.' 208 | 'If score of metrics do not improve during the patience of epochs,' 209 | ' the training will be stopped.') 210 | parser.add_argument('--recon_wei', type=float, default=131.072, 211 | help='If using capsnet: The coefficient (weighting) for the loss of decoder') 212 | parser.add_argument('--slices', type=int, default=1, 213 | help='Number of slices to include for training/testing.') 214 | parser.add_argument('--subsamp', type=int, default=-1, 215 | help='Number of slices to skip when forming 3D samples for training. Enter -1 for random ' 216 | 'subsampling up to 5%% of total slices.') 217 | parser.add_argument('--stride', type=int, default=1, 218 | help='Number of slices to move when generating the next sample.') 219 | parser.add_argument('--verbose', type=int, default=1, choices=[0, 1, 2], 220 | help='Set the verbose value for training. 
0: Silent, 1: per iteration, 2: per epoch.') 221 | parser.add_argument('--save_raw', type=int, default=1, choices=[0, 1], 222 | help='Enter 0 to not save, 1 to save.') 223 | parser.add_argument('--save_seg', type=int, default=1, choices=[0, 1], 224 | help='Enter 0 to not save, 1 to save.') 225 | parser.add_argument('--save_prefix', type=str, default='', 226 | help='Prefix to append to saved CSV.') 227 | parser.add_argument('--thresh_level', type=float, default=0., 228 | help = 'Enter 0.0 for masking refine by Otsu algorithm.' 229 | ' Or set a value for thresholding level of masking. Value should between 0 and 1.') 230 | parser.add_argument('--compute_dice', type=int, default=1, 231 | help='0 or 1') 232 | parser.add_argument('--compute_jaccard', type=int, default=1, 233 | help='0 or 1') 234 | parser.add_argument('--compute_assd', type=int, default=0, 235 | help='0 or 1') 236 | parser.add_argument('--which_gpus', type = str, default='0', 237 | help='Enter "-2" for CPU only, "-1" for all GPUs available, ' 238 | 'or a comma separated list of GPU id numbers ex: "0,1,4".') 239 | parser.add_argument('--gpus', type=int, default=-1, 240 | help = 'Number of GPUs you have available for training. ' 241 | 'If entering specific GPU ids under the --which_gpus arg or if using CPU, ' 242 | 'then this number will be inferred, else this argument must be included.') 243 | # Enhancements: 244 | # TODO: implement softmax entroyp loss function for multiclass segmentation 245 | parser.add_argument('--dataset', type=str.lower, default='mscoco17', choices=['luna16', 'mscoco17'], 246 | help='Enter "mscoco17" for COCO dataset, "luna16" for CT images') 247 | parser.add_argument('--num_class', type=int, default=2, 248 | help='Number of classes to segment. Default is 2. If number of classes > 2, ' 249 | ' the loss function will be softmax entropy and only apply on SegCapsR3' 250 | '** Current version only support binary classification tasks.') 251 | parser.add_argument('--Kfold', type=int, default=4, help='Define K value for K-fold cross validate' 252 | ' default K = 4, K = 1 for over-fitting test') 253 | parser.add_argument('--retrain', type=int, default=0, choices=[0, 1], help='Retrain your model based on existing weights.' 254 | ' default 0=train your model from scratch, 1 = retrain existing model.' 255 | ' The weights file location of the model has to be provided by --weights_path parameter' ) 256 | parser.add_argument('--loglevel', type=int, default=4, help='loglevel 3 = debug, 2 = info, 1 = warning, ' 257 | ' 4 = error, > 4 =critical') 258 | arguments = parser.parse_args() 259 | 260 | # assuming loglevel is bound to the string value obtained from the 261 | # command line argument. Convert to upper case to allow the user to 262 | # specify --log=DEBUG or --log=debug 263 | logging.basicConfig(format=LOGGING_FORMAT, level=loglevel(arguments.loglevel), stream=sys.stderr) 264 | 265 | if arguments.which_gpus == -2: 266 | environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 267 | environ["CUDA_VISIBLE_DEVICES"] = "" 268 | elif arguments.which_gpus == '-1': 269 | assert (arguments.gpus != -1), 'Use all GPUs option selected under --which_gpus, with this option the user MUST ' \ 270 | 'specify the number of GPUs available with the --gpus option.' 
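    # Otherwise --which_gpus is a comma-separated list of GPU ids: the number of
    # GPUs is inferred from its length and CUDA_VISIBLE_DEVICES is set to that list.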
271 | else: 272 | arguments.gpus = len(arguments.which_gpus.split(',')) 273 | environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 274 | environ["CUDA_VISIBLE_DEVICES"] = str(arguments.which_gpus) 275 | 276 | if arguments.gpus > 1: 277 | assert arguments.batch_size >= arguments.gpus, 'Error: Must have at least as many items per batch as GPUs ' \ 278 | 'for multi-GPU training. For model parallelism instead of ' \ 279 | 'data parallelism, modifications must be made to the code.' 280 | 281 | main(arguments) 282 | -------------------------------------------------------------------------------- /manip.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Capsules for Object Segmentation (SegCaps) 3 | Original Paper by Rodney LaLonde and Ulas Bagci (https://arxiv.org/abs/1804.04241) 4 | Code written by: Rodney LaLonde 5 | If you use significant portions of this code or the ideas from our paper, please cite it :) 6 | If you have any questions, please email me at lalonde@knights.ucf.edu. 7 | 8 | This file is used for manipulating the vectors of the final layer of capsules (the SegCaps or segmentation capsules). 9 | This manipulation attempts to show what each dimension of these final vectors are storing (paying attention to), 10 | in terms of information about the positive input class. 11 | Please see the README for further details about how to use this file. 12 | ''' 13 | 14 | from __future__ import print_function 15 | 16 | from os.path import join 17 | from os import makedirs 18 | import SimpleITK as sitk 19 | from tqdm import tqdm, trange 20 | from PIL import Image 21 | import numpy as np 22 | import math 23 | 24 | from keras import backend as K 25 | K.set_image_data_format('channels_last') 26 | from keras.utils import print_summary 27 | 28 | 29 | def combine_images(generated_images, height=None, width=None): 30 | num = generated_images.shape[0] 31 | if width is None and height is None: 32 | width = int(math.sqrt(num)) 33 | height = int(math.ceil(float(num)/width)) 34 | elif width is not None and height is None: # height not given 35 | height = int(math.ceil(float(num)/width)) 36 | elif height is not None and width is None: # width not given 37 | width = int(math.ceil(float(num)/height)) 38 | 39 | shape = generated_images.shape[1:3] 40 | image = np.zeros((height*shape[0], width*shape[1]), 41 | dtype=generated_images.dtype) 42 | for index, img in enumerate(generated_images): 43 | i = int(index/width) 44 | j = index % width 45 | image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = \ 46 | img[:, :, 0] 47 | return image 48 | 49 | 50 | def manip(args, test_list, model_list, net_input_shape): 51 | if args.weights_path == '': 52 | weights_path = join(args.check_dir, args.output_name + '_model_' + args.time + '.hdf5') 53 | else: 54 | weights_path = join(args.data_root_dir, args.weights_path) 55 | 56 | output_dir = join(args.data_root_dir, 'results', args.net, 'split_' + str(args.split_num)) 57 | manip_out_dir = join(output_dir, 'manip_output') 58 | try: 59 | makedirs(manip_out_dir) 60 | except: 61 | pass 62 | 63 | assert(len(model_list) == 3), "Must be using segcaps with the three models." 64 | manip_model = model_list[2] 65 | try: 66 | manip_model.load_weights(weights_path) 67 | except: 68 | print('Unable to find weights path. Testing with random weights.') 69 | print_summary(model=manip_model, positions=[.38, .65, .75, 1.]) 70 | 71 | 72 | # Manipulating capsule vectors 73 | print('Testing... 
This will take some time...') 74 | 75 | for i, img in enumerate(tqdm(test_list)): 76 | sitk_img = sitk.ReadImage(join(args.data_root_dir, 'imgs', img[0])) 77 | img_data = sitk.GetArrayFromImage(sitk_img) 78 | num_slices = img_data.shape[0] 79 | sitk_mask = sitk.ReadImage(join(args.data_root_dir, 'masks', img[0])) 80 | gt_data = sitk.GetArrayFromImage(sitk_mask) 81 | 82 | x, y = img_data[num_slices//2, :, :], gt_data[num_slices//2, :, :] 83 | x, y = np.expand_dims(np.expand_dims(x, -1), 0), np.expand_dims(np.expand_dims(y, -1), 0) 84 | 85 | noise = np.zeros([1, 512, 512, 1, 16]) 86 | x_recons = [] 87 | for dim in trange(16): 88 | for r in [-0.25, -0.125, 0, 0.125, 0.25]: 89 | tmp = np.copy(noise) 90 | tmp[:, :, :, :, dim] = r 91 | x_recon = manip_model.predict([x, y, tmp]) 92 | x_recons.append(x_recon) 93 | 94 | x_recons = np.concatenate(x_recons) 95 | 96 | out_img = combine_images(x_recons, height=16) 97 | out_image = out_img * 4096 98 | out_image[out_image > 574] = 574 99 | out_image = out_image / 574 * 255 100 | 101 | Image.fromarray(out_image.astype(np.uint8)).save(join(manip_out_dir, img[0][:-4] + '_manip_output.png')) 102 | 103 | print('Done.') 104 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/models/__init__.py -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains an implementation of U-Net based on the paper 3 | "U-Net: Convolutional Network for Biomedical Image Segmentation" 4 | (https://arxiv.org/abs/1505.04597). 
5 | """ 6 | from keras.models import Model 7 | from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose 8 | 9 | def UNet(input_shape=(512,512,1)): 10 | inputs = Input(input_shape) 11 | conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs) 12 | conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv1) 13 | pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) 14 | 15 | conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool1) 16 | conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv2) 17 | pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) 18 | 19 | conv3 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool2) 20 | conv3 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv3) 21 | pool3 = MaxPooling2D(pool_size=(2, 2))(conv3) 22 | 23 | conv4 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool3) 24 | conv4 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv4) 25 | pool4 = MaxPooling2D(pool_size=(2, 2))(conv4) 26 | 27 | conv5 = Conv2D(1024, (3, 3), activation='relu', padding='same')(pool4) 28 | conv5 = Conv2D(1024, (3, 3), activation='relu', padding='same')(conv5) 29 | 30 | up6 = concatenate([Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3) 31 | conv6 = Conv2D(512, (3, 3), activation='relu', padding='same')(up6) 32 | conv6 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv6) 33 | 34 | up7 = concatenate([Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3) 35 | conv7 = Conv2D(256, (3, 3), activation='relu', padding='same')(up7) 36 | conv7 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv7) 37 | 38 | up8 = concatenate([Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3) 39 | conv8 = Conv2D(128, (3, 3), activation='relu', padding='same')(up8) 40 | conv8 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv8) 41 | 42 | up9 = concatenate([Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3) 43 | conv9 = Conv2D(64, (3, 3), activation='relu', padding='same')(up9) 44 | conv9 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv9) 45 | 46 | conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9) 47 | 48 | model = Model(inputs=[inputs], outputs=[conv10]) 49 | 50 | return model 51 | -------------------------------------------------------------------------------- /notebook/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/notebook/1.png -------------------------------------------------------------------------------- /notebook/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/notebook/2.png -------------------------------------------------------------------------------- /notebook/20180629-SegCaps-image-segmentation-with Color image input.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# SegCaps on Image Segmentation for Person\n", 8 | "## Input Color image files\n", 9 | "## Integrated with WebCam Video\n", 10 | "\n", 11 | "A quick intro to using the pre-trained model to detect and segment object of person.\n", 12 | "\n", 13 | "This notebook 
tests the model loading function from image file of a saved model." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 1, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stderr", 23 | "output_type": "stream", 24 | "text": [ 25 | "Using TensorFlow backend.\n" 26 | ] 27 | } 28 | ], 29 | "source": [ 30 | "import os\n", 31 | "from os import path\n", 32 | "from os.path import join, basename\n", 33 | "import sys\n", 34 | "import random\n", 35 | "import math\n", 36 | "\n", 37 | "import warnings\n", 38 | "warnings.filterwarnings('ignore')\n", 39 | "\n", 40 | "# import SimpleITK as sitk\n", 41 | "import numpy as np\n", 42 | "# import skimage.io\n", 43 | "# import matplotlib\n", 44 | "import matplotlib.pyplot as plt\n", 45 | "import sys\n", 46 | "# Add the ptdraft folder path to the sys.path list\n", 47 | "sys.path.append('../')\n", 48 | "\n", 49 | "from keras.utils import print_summary\n", 50 | "from keras import layers, models\n", 51 | "\n", 52 | "import segcapsnet.capsnet as modellib\n", 53 | "import models.unet as unet\n", 54 | "\n", 55 | "from utils.model_helper import create_model\n", 56 | "from utils.load_2D_data import generate_test_batches, generate_test_image\n", 57 | "from test import *\n", 58 | "from PIL import Image\n", 59 | "import scipy.ndimage.morphology\n", 60 | "from skimage import measure, filters\n", 61 | "from datetime import datetime\n", 62 | "\n", 63 | "\n", 64 | "%matplotlib inline \n", 65 | "\n", 66 | "RESOLUTION = 512\n", 67 | "\n", 68 | "# Root directory of the project\n", 69 | "ROOT_DIR = path.dirname(\"../\")\n", 70 | "DATA_DIR = path.join(ROOT_DIR, \"data\")\n", 71 | "\n", 72 | "# Directory to save logs and trained model\n", 73 | "# MODEL_DIR = path.join(DATA_DIR, \"saved_models/segcapsr3/m1.hdf5\") # LUNA16\n", 74 | "\n", 75 | "# Local path to trained weights file\n", 76 | "# loss function = Dice is better than BCE (Binary Cross Entropy)\n", 77 | "# COCO_MODEL_PATH = path.join(DATA_DIR, \"saved_models/segcapsr3/dice16-255.hdf5\") # MSCOCO17\n", 78 | "COCO_MODEL_PATH = path.join(DATA_DIR, \"saved_models/capsbasic/cb1.hdf5\") # MSCOCO17\n", 79 | "# COCO_MODEL_PATH = path.join(DATA_DIR, \"saved_models/segcapsr3/mar10-255.hdf5\") # MSCOCO17\n", 80 | "# COCO_MODEL_PATH = path.join(DATA_DIR, \"saved_models/segcapsr3/bce.hdf5\") # MSCOCO17\n", 81 | "# COCO_MODEL_PATH = path.join(DATA_DIR, \"saved_models/unet/unet1.hdf5\") # MSCOCO17\n", 82 | "\n", 83 | "\n", 84 | "# Directory of images to run detection on\n", 85 | "IMAGE_DIR = path.join(DATA_DIR, \"imgs\")" 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "## Create Model and Load Trained Weights" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 4, 98 | "metadata": { 99 | "scrolled": false 100 | }, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "WARNING:tensorflow:From ../segcapsnet/capsule_layers.py:322: calling softmax (from tensorflow.python.ops.nn_ops) with dim is deprecated and will be removed in a future version.\n", 107 | "Instructions for updating:\n", 108 | "dim is deprecated, use axis instead\n" 109 | ] 110 | }, 111 | { 112 | "name": "stderr", 113 | "output_type": "stream", 114 | "text": [ 115 | "WARNING:tensorflow:From ../segcapsnet/capsule_layers.py:322: calling softmax (from tensorflow.python.ops.nn_ops) with dim is deprecated and will be removed in a future version.\n", 116 | "Instructions for updating:\n", 117 | "dim is deprecated, use axis instead\n" 118 
| ] 119 | }, 120 | { 121 | "name": "stdout", 122 | "output_type": "stream", 123 | "text": [ 124 | "WARNING:tensorflow:From ../segcapsnet/capsule_layers.py:351: calling norm (from tensorflow.python.ops.linalg_ops) with keep_dims is deprecated and will be removed in a future version.\n", 125 | "Instructions for updating:\n", 126 | "keep_dims is deprecated, use keepdims instead\n" 127 | ] 128 | }, 129 | { 130 | "name": "stderr", 131 | "output_type": "stream", 132 | "text": [ 133 | "WARNING:tensorflow:From ../segcapsnet/capsule_layers.py:351: calling norm (from tensorflow.python.ops.linalg_ops) with keep_dims is deprecated and will be removed in a future version.\n", 134 | "Instructions for updating:\n", 135 | "keep_dims is deprecated, use keepdims instead\n" 136 | ] 137 | }, 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | "text": [ 142 | "__________________________________________________________________________________________________\n", 143 | "Layer (type) Output Shape Param # Connected to \n", 144 | "==================================================================================================\n", 145 | "input_2 (InputLayer) (None, 512, 512, 1) 0 \n", 146 | "__________________________________________________________________________________________________\n", 147 | "conv1 (Conv2D) (None, 512, 512, 256 6656 input_2[0][0] \n", 148 | "__________________________________________________________________________________________________\n", 149 | "reshape_1 (Reshape) (None, 512, 512, 1, 0 conv1[0][0] \n", 150 | "__________________________________________________________________________________________________\n", 151 | "primarycaps (ConvCapsuleLayer) (None, 512, 512, 8, 1638656 reshape_1[0][0] \n", 152 | "__________________________________________________________________________________________________\n", 153 | "seg_caps (ConvCapsuleLayer) (None, 512, 512, 1, 528 primarycaps[0][0] \n", 154 | "__________________________________________________________________________________________________\n", 155 | "mask_2 (Mask) (None, 512, 512, 1, 0 seg_caps[0][0] \n", 156 | "__________________________________________________________________________________________________\n", 157 | "reshape_3 (Reshape) (None, 512, 512, 16) 0 mask_2[0][0] \n", 158 | "__________________________________________________________________________________________________\n", 159 | "recon_1 (Conv2D) (None, 512, 512, 64) 1088 reshape_3[0][0] \n", 160 | "__________________________________________________________________________________________________\n", 161 | "recon_2 (Conv2D) (None, 512, 512, 128 8320 recon_1[0][0] \n", 162 | "__________________________________________________________________________________________________\n", 163 | "out_seg (Length) (None, 512, 512, 1) 0 seg_caps[0][0] \n", 164 | "__________________________________________________________________________________________________\n", 165 | "out_recon (Conv2D) (None, 512, 512, 1) 129 recon_2[0][0] \n", 166 | "==================================================================================================\n", 167 | "Total params: 1,655,377\n", 168 | "Trainable params: 1,655,377\n", 169 | "Non-trainable params: 0\n", 170 | "__________________________________________________________________________________________________\n" 171 | ] 172 | } 173 | ], 174 | "source": [ 175 | "# Create model object in inference mode.\n", 176 | "net_input_shape = (RESOLUTION, RESOLUTION, 1)\n", 177 | "num_class = 2\n", 178 | "# train_model, eval_model, 
manipulate_model = modellib.CapsNetR3(net_input_shape, num_class)\n", 179 | "train_model, eval_model, manipulate_model = modellib.CapsNetBasic(net_input_shape, num_class)\n", 180 | "# eval_model = unet.UNet(net_input_shape)\n", 181 | "\n", 182 | "# Load weights trained on MS-COCO\n", 183 | "eval_model.load_weights(COCO_MODEL_PATH)\n", 184 | "print_summary(model=eval_model)\n", 185 | "\n" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 5, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "\n", 195 | "def threshold_mask(raw_output, threshold):\n", 196 | " if threshold == 0:\n", 197 | " try:\n", 198 | " threshold = filters.threshold_otsu(raw_output)\n", 199 | " except:\n", 200 | " threshold = 0.5\n", 201 | "\n", 202 | " print('\\tThreshold: {}'.format(threshold))\n", 203 | "\n", 204 | " raw_output[raw_output > threshold] = 1\n", 205 | " raw_output[raw_output < 1] = 0\n", 206 | "\n", 207 | " all_labels = measure.label(raw_output)\n", 208 | " props = measure.regionprops(all_labels)\n", 209 | " props.sort(key=lambda x: x.area, reverse=True)\n", 210 | " thresholded_mask = np.zeros(raw_output.shape)\n", 211 | "\n", 212 | " if len(props) >= 2:\n", 213 | " if props[0].area / props[1].area > 5: # if the largest is way larger than the second largest\n", 214 | " thresholded_mask[all_labels == props[0].label] = 1 # only turn on the largest component\n", 215 | " else:\n", 216 | " thresholded_mask[all_labels == props[0].label] = 1 # turn on two largest components\n", 217 | " thresholded_mask[all_labels == props[1].label] = 1\n", 218 | " elif len(props):\n", 219 | " thresholded_mask[all_labels == props[0].label] = 1\n", 220 | "\n", 221 | " thresholded_mask = scipy.ndimage.morphology.binary_fill_holes(thresholded_mask).astype(np.uint8)\n", 222 | "\n", 223 | " return thresholded_mask\n" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": {}, 229 | "source": [ 230 | "## Predict the Segmentation of Person\n", 231 | "\n" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": 6, 237 | "metadata": {}, 238 | "outputs": [ 239 | { 240 | "name": "stdout", 241 | "output_type": "stream", 242 | "text": [ 243 | "2018-06-30 07:45:13.122611\n", 244 | "1/1 [==============================] - 40s 40s/step\n", 245 | "2018-06-30 07:45:52.984257\n" 246 | ] 247 | } 248 | ], 249 | "source": [ 250 | "\n", 251 | "img = ['train2.png']\n", 252 | "output_array = None\n", 253 | "\n", 254 | "\n", 255 | "# sitk_img = sitk.ReadImage(join(IMAGE_DIR, img[0]))\n", 256 | "# img_data = sitk.GetArrayFromImage(sitk_img)\n", 257 | "img_data = np.asarray(Image.open(join(IMAGE_DIR, img[0])))\n", 258 | "\n", 259 | " \n", 260 | "print(str(datetime.now()))\n", 261 | "output_array = eval_model.predict_generator(generate_test_batches(DATA_DIR, [img],\n", 262 | " net_input_shape,\n", 263 | " batchSize=1,\n", 264 | " numSlices=1,\n", 265 | " subSampAmt=0,\n", 266 | " stride=1),\n", 267 | " steps=1, max_queue_size=1, workers=1,\n", 268 | " use_multiprocessing=False, verbose=1)\n", 269 | "print(str(datetime.now()))" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "# output_array contain 2 masks in a list, show the first element.\n", 279 | "# print('len(output_array)=%d'%(len(output_array)))\n", 280 | "# print('test.test: output_array=%s'%(output_array[0]))\n" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 9, 286 | "metadata": {}, 287 | 
"outputs": [ 288 | { 289 | "name": "stdout", 290 | "output_type": "stream", 291 | "text": [ 292 | "test.test: output=[[[0.999551 0.9995523 0.99955237 ... 0.9995541 0.9995544 0.99955344]\n", 293 | " [0.99955153 0.9995528 0.9995528 ... 0.9995547 0.999555 0.9995539 ]\n", 294 | " [0.9995511 0.99955255 0.99955255 ... 0.99955446 0.9995547 0.9995538 ]\n", 295 | " ...\n", 296 | " [0.9995482 0.9995497 0.9995498 ... 0.9995611 0.99956137 0.9995608 ]\n", 297 | " [0.9995478 0.9995491 0.9995493 ... 0.9995608 0.9995611 0.9995604 ]\n", 298 | " [0.9995464 0.99954784 0.9995479 ... 0.9995595 0.99956 0.99955934]]]\n" 299 | ] 300 | } 301 | ], 302 | "source": [ 303 | "# output = (1, 512, 512)\n", 304 | "output = output_array[0][:,:,:,0] # A list with two images, get first one image and reshape it to 3 dimensions.\n", 305 | "recon = output_array[1][:,:,:,0]\n", 306 | "\n", 307 | "# For unet\n", 308 | "# output = output_array[:,:,:,0]\n", 309 | "# image store in tuple structure.\n", 310 | "print('test.test: output=%s'%(output))\n", 311 | "np.ndim(output)\n", 312 | "np_output = np.array(output)\n", 313 | "\n", 314 | "\n", 315 | "\n" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 10, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "name": "stdout", 325 | "output_type": "stream", 326 | "text": [ 327 | "Segmenting Output\n", 328 | "\tThreshold: 0.9995464934036136\n" 329 | ] 330 | } 331 | ], 332 | "source": [ 333 | "# output_img = sitk.GetImageFromArray(output[0,:,:], isVector=True)\n", 334 | "\n", 335 | "print('Segmenting Output')\n", 336 | "threshold_level = 0\n", 337 | "output_bin = threshold_mask(output, threshold_level)\n", 338 | "# output2d = output[0,:,:]\n", 339 | "# output2d = recon[0,:,:]\n", 340 | "# print(output2d)" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 11, 346 | "metadata": { 347 | "scrolled": true 348 | }, 349 | "outputs": [ 350 | { 351 | "data": { 352 | "text/plain": [ 353 | "" 354 | ] 355 | }, 356 | "execution_count": 11, 357 | "metadata": {}, 358 | "output_type": "execute_result" 359 | }, 360 | { 361 | "data": { 362 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAEddJREFUeJzt3V2sHOV9x/HvrzaGtKQY8AFZtqlB8QVctICOiCOqikJSgRvFXIAEioqFLFlqqUREpdS0UqtIvQi9CAipIrVqVFMlAZoX2UJuqWVAVS8wHIf3uMQHRPGRET4R4KRCaUvy78U+G6939mV2z87uvPw+0tHOPDu7+9+X+e3zzM7MUURgZtbp12ZdgJmVj4PBzDIcDGaW4WAwswwHg5llOBjMLKOQYJB0s6Q3JS1K2l3EY5hZcTTp/RgkrQJ+DHwBWAJeBO6MiB9N9IHMrDBF9BiuAxYj4u2I+F/gcWB7AY9jZgVZXcB9bgBOdMwvAZ8ddIN169bF5s2bCyjFzNqOHj36k4iYy7NsEcGgHm2Z8YqkXcAugMsuu4yFhYUCSjGzNkn/lXfZIoYSS8CmjvmNwMnuhSJiT0TMR8T83FyuEDOzKSkiGF4Etki6XNIa4A7gQAGPY2YFmfhQIiI+kfSnwNPAKuDRiHhj0o9jZsUpYhsDEXEQOFjEfZtZ8bzno5llOBjMLMPBYGYZDgYzy3AwmFmGg8HMMhwMZpbhYDCzDAeDmWU4GMwsw8FgZhkOBjPLcDCYWYaDwcwyHAxmluFgMLMMB4OZZTgYzCzDwWBmGYWc89HqTTrzr0Mm/S8OrRwcDJZbZyBYvTkYGi7Pt78DoXkcDA2Vd2V3KDSTg6Fh8qzoDgNzMDTIsBXegWBt/rnSzDIcDA1RRG8gIpDknkYNeSjRAJNecb3vQv05GGqu6FBwSNSThxKWm0OgORwMNeaxv43LQ4ka8jYFWykHQ82MGwpe+a2ThxLmULCMocEg6VFJpyS93tF2kaRDko6nywtTuyQ9LGlR0quSri2yeFs5h4L1kqfH8I/AzV1tu4HDEbEFOJzmAW4BtqS/XcAjkynT8hh1GOFQsH6GBkNE/DvwQVfzdmBfmt4H3NrR/li0PA+slbR+UsXa5DgUbJBxtzFcGhHvAaTLS1L7BuBEx3JLqS1D0i5JC5IWlpeXxyzDxuFQsGEmvfGxV1+256cwIvZExHxEzM/NzU24jObJO4xwKFge4wbD++0hQro8ldqXgE0dy20ETo5fnpnNwrjBcADYkaZ3APs72u9Kv05sBU63hxxWHO/haJM2dAcnSd8BbgDWSVoC/hr4OvCkpJ3Au8DtafGDwDZgEfgYuLuAms2sYEODISLu7HPVTT2WDeCelRZlZrPlPR8bxBseLS8Hg5llOBjMLMPBYGYZDoYG8c+alpeDoeK8slsRHAxmluFgaBj3MCwPB0PFed8EK4KDwcwyHAwN5OGEDeNgqIFxhhMOBxvEp49vsEHh4G0XzeYeQ014RbZJcjCYWYaHEjUSESvaduBeh7U5GGombzg4BGwQB0MNda70vULCoWDDeBtDzTkEbBwOhgZwONioHAxmluFgaAj3GmwUDgYzy/CvEjXW/kWi3Vtwr8Hyco+h5hwGNg4Hg5llOBhqzodX2zi8jaHGPIywcbnHYGYZDgYzy3AwmFmGg8HMMhwMZpYxNBgkbZL0rKRjkt6QdG9qv0jSIUnH0+WFqV2SHpa0KOlVSdcW/STMbLLy9Bg+Af4sIq4EtgL3SLoK2A0cjogtwOE0D3ALsCX97QIemXjVZlaoocEQEe9FxA/T9M+AY8AGYDuwLy22D7g1TW8HHouW54G1ktZPvHIzK8xI2xgkbQauAY4Al0bEe9AKD+CStNgG4ETHzZZSm5lVRO5gkHQ+8D3gKxHx00GL9mjL7IInaZekBUkLy8vLecswsynIFQySzqEVCt+KiO+n5vfbQ4R0eSq1LwGbOm6+ETjZfZ8RsSci5iNifm5ubtz6zawAeX6VELAXOBYR3+i46gCwI03vAPZ3tN+Vfp3YCpxuDznMrBryHER1PfBHwGuSXk5tfwF8HXhS0k7gXeD2dN1BYBuwCHwM3D3Ris2scEODISL+g97bDQBu6rF8APessC4zmyHv+WhmGQ4GM8twMJhZhoPBzDIcDGaW4WAwswwHg5llOBjMLMPBYGYZDgYzy3AwNIz/M5Xl4WBoEIeC5eVgaCAHhA3jYGgIh4GNwsHQUA4KG8TB0AAOARuVg6HmBoWCA8P6cTDUmFd8G1eecz5ahYwaBpJonY3P7AwHQw2stGdQ5nBoP7ey1ldXDoYK81DBiuJtDBUkaeKh4JCxTg6Gimn6Ctz05z8tDoaKKKKX0OsxyqbXtoUy1lk3DoaSm0YgdD9eFVSlzqpyMJSYP/xZnT0Ivz7F8a8SJTTrD7x/IjT3GEpk2sOGYYr69WOS91mm16tOHAwlULZA6GeldXbedpT7cc9l+hwMM1b2QJjkrwLd9zXr+7H+HAwzUpVeQtEm9Rr4tZwsb3ycgap+iIvq0pf5WI2mcjBMUVUDwZrHwTAFVQ6EztqL/FZ3r6Fchm5jkHSepBckvSLpDUlfS+2XSzoi6bikJyStSe3npvnFdP3mYp9CedVtO0KZdsnuFyJ1e81nJc/Gx/8BboyI3wGuBm6WtBV4AHgwIrYAHwI70/I7gQ8j4jPAg2m5Rqnzh7Po5zapnzHr+vpPy9BgiJb/TrPnpL8AbgS+m9r3Abem6e1pnnT9TWrIu1TnQOhW1ufp4chk5Pq5UtIqSS8Dp4BDwFvARxHxSVpkCdiQpjcAJwDS9aeBi3vc5y5JC5IWlpeXV/YsSqCsK0qRyv6cy15fmeUKhoj4RURcDWwErgOu7LVYuuz1bmRiPCL2RMR8RMzPzc3lrbd0mtRL6GUlOym1/yZ1n72O8Wjye7MSI+3gFBEfAc8BW4G1ktq/amwETqbpJWATQLr+AuCDSRRbNv7QtYz6OuTp7vu1na08v0rMSVqbpj8FfB44BjwL3JYW2wHsT9MH0jzp+mfCA7/aW8mKvJKPh0OmGHn2Y1gP7JO0ilaQPBkRT0n6EfC4pL8BXgL2puX3Av8kaZFWT+GOAuqeKX/QRjfsUO7u7v+4YdG+bURkDtry91N+Q4MhIl4FrunR/jat7Q3d7T8Hbp9IdVYZw1a6UVZK73o9ez6IakTuLVSHQ2B8DgZbsVmvgIMOw/Yh2uNxMIzAH6p68Ps4nIMhJ3+YrEkcDFYLow5nHPSDORhy8Ieov1lvX8ijCjWWjYPBxlb1Fc6B35+DwcZStVAYdP4Gy3IwmFmGg2EIf6NUT57drjs1/QjZXhwMVitVG+KUlYPBasHf+JPlYBjAH7ZqcW9hchwM1igOj3wcDGaW4WDow8OI/sr4rVuG8z3UiYPBGsnhMJiDwcwyHAxmluFg6MHbF5rBw4n+HAxmluFgsEZzr6E3B0MXDyPMHAxm7jX04GAwS9xbPCPPv6gzqz33Gs7mHoNZH03uQTgYzHoYJRTqeAYoB4NZH3mGF3ULhD
YHg1mXcVb2um2jcDCY9TBKb6FuoQAOhrPUtVtok1f3z4qDwazLsB5AZyjUsbcAIwSDpFWSXpL0VJq/XNIRScclPSFpTWo/N80vpus3F1O6zUrdvy0HaUIowGg9hnuBYx3zDwAPRsQW4ENgZ2rfCXwYEZ8BHkzLmVVeU0IBcgaDpI3AHwL/kOYF3Ah8Ny2yD7g1TW9P86Trb1KTv2Ks8rr3U6h7KED+HsNDwFeBX6b5i4GPIuKTNL8EbEjTG4ATAOn602n5s0jaJWlB0sLy8vKY5ZsVq/s7rQmhADmCQdIXgVMRcbSzuceikeO6Mw0ReyJiPiLm5+bmchVrNk1NDQXIdxDV9cCXJG0DzgN+k1YPYq2k1alXsBE4mZZfAjYBS5JWAxcAH0y88gnzaCe/uq4ggz4DdX3O/QztMUTE/RGxMSI2A3cAz0TEl4FngdvSYjuA/Wn6QJonXf9MNO1VtcrxF8PZVrIfw58D90lapLUNYW9q3wtcnNrvA3avrESzYjkUskY6H0NEPAc8l6bfBq7rsczPgdsnUJtZ4RwKvXnPR/zhMOvmYLDGyvuF0MRNZA4GswGaGArgYDDrq6mhAA4GayhvVxrMwWDWJSIG9haaECoOBrMOec7F0IQhhoPBGmfcb/wm9BTaHAzWKMMOn+618ncedt2E3gI4GKxBeq307e0JnSt853JN6iV08r+os0bIs4J39waadnKWTu4xWGM1tTeQh3sM1miDfmVoWi+hU+N7DP7WaIam75cwqsYHg1VPESuyw+FsDgYrjVGOdiziP0w7HM5wMFipjHoodDsg8gRFk7cZjMrBYJXVbwellXzzu9fQ4mCw0hll5ex3wNO4AeFeRYuDwUpp1JV6UEBMQ916Gg4GK41Bex6Oex82HgeDlUoRK3bR3+Z16y2Ag8HfMCXUOSyYVK+hqF8t6nrUZeODwcpr0uEwaXXsKbQ1Phjq/OZafqMGSd2PvGx8MFi5DTv/4rDbFqHuoQAOBjPrwcFgNoKmDD0dDFZrebr6/ZYZdgxGXYcR4BO1mPXU76SwUO9AaHOPwSyHpgwh2hwMVmvj7NTU2TPovr77bNJ1DYzGDyXaJ/2w+snzvvY65+Oop4Gr4xCj8cFg9TRK2Heu2JM6l0PVQyLXUELSO5Jek/SypIXUdpGkQ5KOp8sLU7skPSxpUdKrkq4t8gmY9bKS4x4moeq90FG2Mfx+RFwdEfNpfjdwOCK2AIfTPMAtwJb0twt4ZFLFmlVJlbdBrGTj43ZgX5reB9za0f5YtDwPrJW0fgWPY1ZpVQyIvMEQwL9JOippV2q7NCLeA0iXl6T2DcCJjtsupbazSNolaUHSwvLy8njVm5VAr/9/2UuVwiHvxsfrI+KkpEuAQ5L+c8CyvZ595hWLiD3AHoD5+flqb6mxxhrlF40qydVjiIiT6fIU8APgOuD99hAhXZ5Kiy8BmzpuvhE4OamCi1CXN9POVuT7upKjPqtgaDBI+g1Jn25PA38AvA4cAHakxXYA+9P0AeCu9OvEVuB0e8hhVhdV3G4wijxDiUuBH6QXYTXw7Yj4V0kvAk9K2gm8C9yelj8IbAMWgY+BuydetVkfg/5JbVGPB/XrdaoMT0jSz4A3Z11HTuuAn8y6iByqUidUp9aq1Am9a/2tiJjLc+Oy7Pn4Zsf+EaUmaaEKtValTqhOrVWpE1Zeqw+iMrMMB4OZZZQlGPbMuoARVKXWqtQJ1am1KnXCCmstxcZHMyuXsvQYzKxEZh4Mkm6W9GY6THv38FsUWsujkk5Jer2jrZSHl0vaJOlZScckvSHp3jLWK+k8SS9IeiXV+bXUfrmkI6nOJyStSe3npvnFdP3madTZUe8qSS9JeqrkdRZ7KoTOA0Cm/QesAt4CrgDWAK8AV82wnt8DrgVe72j7W2B3mt4NPJCmtwH/QuvYkK3AkSnXuh64Nk1/GvgxcFXZ6k2Pd36aPgc4kh7/SeCO1P5N4I/T9J8A30zTdwBPTPl1vQ/4NvBUmi9rne8A67raJvbeT+2J9HlynwOe7pi/H7h/xjVt7gqGN4H1aXo9rX0uAP4euLPXcjOqez/whTLXC/w68EPgs7R2vlnd/TkAngY+l6ZXp+U0pfo20jq3yI3AU2lFKl2d6TF7BcPE3vtZDyVyHaI9Yys6vHwaUjf2GlrfxqWrN3XPX6Z1oN0hWr3EjyLikx61/KrOdP1p4OJp1Ak8BHwV+GWav7ikdUIBp0LoNOs9H3Mdol1Spahd0vnA94CvRMRPBxzYM7N6I+IXwNWS1tI6OvfKAbXMpE5JXwRORcRRSTfkqGXW7//ET4XQadY9hiocol3aw8slnUMrFL4VEd9PzaWtNyI+Ap6jNc5dK6n9xdRZy6/qTNdfAHwwhfKuB74k6R3gcVrDiYdKWCdQ/KkQZh0MLwJb0pbfNbQ24hyYcU3dSnl4uVpdg73AsYj4RlnrlTSXegpI+hTweeAY8CxwW5862/XfBjwTaWBcpIi4PyI2RsRmWp/DZyLiy2WrE6Z0KoRpbnzqsxFlG60t6m8BfznjWr4DvAf8H62U3Ulr3HgYOJ4uL0rLCvi7VPdrwPyUa/1dWt3BV4GX09+2stUL/DbwUqrzdeCvUvsVwAu0Ds//Z+Dc1H5eml9M118xg8/BDZz5VaJ0daaaXkl/b7TXm0m+997z0cwyZj2UMLMScjCYWYaDwcwyHAxmluFgMLMMB4OZZTgYzCzDwWBmGf8PneSop33oBAIAAAAASUVORK5CYII=\n", 363 | "text/plain": [ 364 | "" 365 | ] 366 | }, 367 | "metadata": {}, 368 | "output_type": "display_data" 369 | } 370 | ], 371 | "source": [ 372 | "# plt.imshow(output[0,:,:], cmap='gray')\n", 373 | "# plt.imsave('raw_output' + img[0][-4:], output[0,:,:])\n", 374 | "plt.imshow(output_bin[0,:,:], cmap='gray')\n", 375 | "# plt.imsave('final_output' + img[0][-4:], output_bin[0,:,:])\n" 376 | ] 377 | } 378 | ], 379 | "metadata": { 380 | "kernelspec": { 381 | "display_name": "Python 3", 382 | "language": "python", 383 | "name": "python3" 384 | }, 385 | "language_info": { 386 | "codemirror_mode": { 387 | "name": "ipython", 388 | "version": 3 389 | }, 390 | "file_extension": ".py", 391 | "mimetype": "text/x-python", 392 | "name": "python", 393 | "nbconvert_exporter": "python", 394 | 
"pygments_lexer": "ipython3", 395 | "version": "3.6.4" 396 | } 397 | }, 398 | "nbformat": 4, 399 | "nbformat_minor": 2 400 | } 401 | -------------------------------------------------------------------------------- /notebook/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/notebook/3.png -------------------------------------------------------------------------------- /notebook/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/notebook/4.png -------------------------------------------------------------------------------- /raspberrypi/Raspi3-install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Install Tensorflow 1.8 on Python 3.5" 4 | sudo apt-get update 5 | sudo pip3 install --upgrade pip 6 | sudo pip3 install https://github.com/lhelontra/tensorflow-on-arm/releases/download/v1.8.0/tensorflow-1.8.0-cp35-none-linux_armv7l.whl 7 | sudo pip3 uninstall mock 8 | sudo pip3 install mock 9 | 10 | echo "Install Keras on Python 3.5" 11 | sudo apt-get install python3-numpy libblas-dev liblapack-dev python3-dev libatlas-base-dev gfortran python3-setuptools python3-scipy python3-h5py at-spi2-core 12 | sudo pip3 install keras 13 | 14 | echo "Install Rest of Packages" 15 | sudo apt-get install python3-matplotlib python3-sklearn python3-pil python3-skimage 16 | 17 | sudo pip3 install cairocffi 18 | sudo pip3 install jupyter 19 | reboot 20 | 21 | -------------------------------------------------------------------------------- /raspberrypi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/raspberrypi/__init__.py -------------------------------------------------------------------------------- /raspberrypi/opencv-install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Install Open CV" 4 | 5 | sudo apt-get install build-essential cmake pkg-config 6 | sudo apt-get install libjpeg-dev libtiff5-dev libjasper-dev libpng12-dev 7 | sudo apt-get install libavcodec-dev libavformat-dev libswscale-dev libv4l-dev 8 | sudo apt-get install libxvidcore-dev libx264-dev 9 | sudo apt-get install libgtk2.0-dev libgtk-3-dev 10 | 11 | wget -O opencv.zip https://github.com/opencv/opencv/archive/3.4.1.zip 12 | 13 | wget -O opencv_contrib.zip https://github.com/opencv/opencv_contrib/archive/3.4.1.zip 14 | 15 | unzip opencv.zip 16 | 17 | unzip opencv_contrib.zip 18 | 19 | cd ./opencv-3.4.1/ 20 | mkdir build 21 | cd build 22 | 23 | cmake -D CMAKE_BUILD_TYPE=RELEASE \ 24 | -D CMAKE_INSTALL_PREFIX=/usr/local \ 25 | -D INSTALL_PYTHON_EXAMPLES=ON \ 26 | -D OPENCV_EXTRA_MODULES_PATH=~/SegCaps/opencv_contrib-3.4.1/modules \ 27 | -D BUILD_EXAMPLES=ON .. 
28 | 29 | make -j4 30 | 31 | sudo make install 32 | sudo ldconfig 33 | sudo apt-get update 34 | 35 | python -c "import cv2 as cv2; print(cv2.__version__)" 36 | 37 | 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | matplotlib 4 | tqdm 5 | scikit-learn 6 | scikit-image 7 | pillow 8 | SimpleITK 9 | keras 10 | h5py 11 | 12 | # Depends on your environment. 13 | tensorflow 14 | # tensorflow-gpu 15 | 16 | # Below section 17 | opencv-python 18 | pycocotools -------------------------------------------------------------------------------- /segcapsnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/segcapsnet/__init__.py -------------------------------------------------------------------------------- /segcapsnet/capsnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Capsules for Object Segmentation (SegCaps) 3 | Original Paper by Rodney LaLonde and Ulas Bagci (https://arxiv.org/abs/1804.04241) 4 | Code written by: Rodney LaLonde 5 | If you use significant portions of this code or the ideas from our paper, please cite it :) 6 | If you have any questions, please email me at lalonde@knights.ucf.edu. 7 | 8 | This file contains the network definitions for the various capsule network architectures. 9 | ''' 10 | 11 | from keras import layers, models 12 | from keras import backend as K 13 | 14 | K.set_image_data_format('channels_last') 15 | 16 | from segcapsnet.capsule_layers import ConvCapsuleLayer, DeconvCapsuleLayer, Mask, Length 17 | 18 | def CapsNetR3(input_shape, n_class=2, enable_decoder=True): 19 | x = layers.Input(shape=input_shape) # x=keras_shape(None, 512, 512, 3) 20 | 21 | # Layer 1: Just a conventional Conv2D layer 22 | conv1 = layers.Conv2D(filters=16, kernel_size=5, strides=1, padding='same', activation='relu', name='conv1')(x) 23 | 24 | # Reshape layer to be 1 capsule x [filters] atoms 25 | _, H, W, C = conv1.get_shape() # _, 512, 512, 16 26 | conv1_reshaped = layers.Reshape((H.value, W.value, 1, C.value))(conv1) 27 | 28 | # Layer 1: Primary Capsule: Conv cap with routing 1 29 | primary_caps = ConvCapsuleLayer(kernel_size=5, num_capsule=2, num_atoms=16, strides=2, padding='same', 30 | routings=1, name='primarycaps')(conv1_reshaped) 31 | 32 | # Layer 2: Convolutional Capsule 33 | conv_cap_2_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=16, strides=1, padding='same', 34 | routings=3, name='conv_cap_2_1')(primary_caps) 35 | 36 | # Layer 2: Convolutional Capsule 37 | conv_cap_2_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=32, strides=2, padding='same', 38 | routings=3, name='conv_cap_2_2')(conv_cap_2_1) 39 | 40 | # Layer 3: Convolutional Capsule 41 | conv_cap_3_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=32, strides=1, padding='same', 42 | routings=3, name='conv_cap_3_1')(conv_cap_2_2) 43 | 44 | # Layer 3: Convolutional Capsule 45 | conv_cap_3_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=64, strides=2, padding='same', 46 | routings=3, name='conv_cap_3_2')(conv_cap_3_1) 47 | 48 | # Layer 4: Convolutional Capsule 49 | conv_cap_4_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=32, strides=1, padding='same', 50 | routings=3, name='conv_cap_4_1')(conv_cap_3_2) 51 | 52 | # Layer 
1 Up: Deconvolutional Capsule 53 | deconv_cap_1_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=8, num_atoms=32, upsamp_type='deconv', 54 | scaling=2, padding='same', routings=3, 55 | name='deconv_cap_1_1')(conv_cap_4_1) 56 | 57 | # Skip connection 58 | up_1 = layers.Concatenate(axis=-2, name='up_1')([deconv_cap_1_1, conv_cap_3_1]) 59 | 60 | # Layer 1 Up: Deconvolutional Capsule 61 | deconv_cap_1_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=32, strides=1, 62 | padding='same', routings=3, name='deconv_cap_1_2')(up_1) 63 | 64 | # Layer 2 Up: Deconvolutional Capsule 65 | deconv_cap_2_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=4, num_atoms=16, upsamp_type='deconv', 66 | scaling=2, padding='same', routings=3, 67 | name='deconv_cap_2_1')(deconv_cap_1_2) 68 | 69 | # Skip connection 70 | up_2 = layers.Concatenate(axis=-2, name='up_2')([deconv_cap_2_1, conv_cap_2_1]) 71 | 72 | # Layer 2 Up: Deconvolutional Capsule 73 | deconv_cap_2_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=16, strides=1, 74 | padding='same', routings=3, name='deconv_cap_2_2')(up_2) 75 | 76 | # Layer 3 Up: Deconvolutional Capsule 77 | deconv_cap_3_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=2, num_atoms=16, upsamp_type='deconv', 78 | scaling=2, padding='same', routings=3, 79 | name='deconv_cap_3_1')(deconv_cap_2_2) 80 | 81 | # Skip connection 82 | up_3 = layers.Concatenate(axis=-2, name='up_3')([deconv_cap_3_1, conv1_reshaped]) 83 | 84 | # Layer 4: Convolutional Capsule: 1x1 85 | seg_caps = ConvCapsuleLayer(kernel_size=1, num_capsule=1, num_atoms=16, strides=1, padding='same', 86 | routings=3, name='seg_caps')(up_3) 87 | 88 | # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape. 89 | out_seg = Length(num_classes=n_class, seg=True, name='out_seg')(seg_caps) 90 | 91 | # Decoder network. 92 | _, H, W, C, A = seg_caps.get_shape() #(?, 512, 512, 1, 16) 93 | y = layers.Input(shape=input_shape[:-1]+(1,)) #y: keras_shape(512, 512, 1) 94 | masked_by_y = Mask()([seg_caps, y]) # The true label is used to mask the output of capsule layer. For training (None, 512, 512, 1, 16) 95 | masked = Mask()(seg_caps) # Mask using the capsule with maximal length. 
For prediction () 96 | 97 | def shared_decoder(mask_layer): 98 | recon_remove_dim = layers.Reshape((H.value, W.value, A.value))(mask_layer) #mask_layer=(?, 512, 512, 1, 16) 99 | 100 | recon_1 = layers.Conv2D(filters=64, kernel_size=1, padding='same', kernel_initializer='he_normal', 101 | activation='relu', name='recon_1')(recon_remove_dim) 102 | 103 | recon_2 = layers.Conv2D(filters=128, kernel_size=1, padding='same', kernel_initializer='he_normal', 104 | activation='relu', name='recon_2')(recon_1) 105 | 106 | out_recon = layers.Conv2D(filters=1, kernel_size=1, padding='same', kernel_initializer='he_normal', 107 | activation='sigmoid', name='out_recon')(recon_2) 108 | 109 | return out_recon 110 | 111 | # Models for training and evaluation (prediction) 112 | train_model = models.Model(inputs=[x, y], outputs=[out_seg, shared_decoder(masked_by_y)]) 113 | if enable_decoder == True: 114 | eval_model = models.Model(inputs=x, outputs=[out_seg, shared_decoder(masked)]) 115 | else: 116 | eval_model = models.Model(inputs=x, outputs=[out_seg]) 117 | # manipulate model 118 | noise = layers.Input(shape=((H.value, W.value, C.value, A.value))) 119 | noised_seg_caps = layers.Add()([seg_caps, noise]) 120 | masked_noised_y = Mask()([noised_seg_caps, y]) 121 | manipulate_model = models.Model(inputs=[x, y, noise], outputs=shared_decoder(masked_noised_y)) 122 | 123 | return train_model, eval_model, manipulate_model 124 | 125 | 126 | def CapsNetR1(input_shape, n_class=2): 127 | x = layers.Input(shape=input_shape) 128 | 129 | # Layer 1: Just a conventional Conv2D layer 130 | conv1 = layers.Conv2D(filters=16, kernel_size=5, strides=1, padding='same', activation='relu', name='conv1')(x) 131 | 132 | # Reshape layer to be 1 capsule x [filters] atoms 133 | _, H, W, C = conv1.get_shape() 134 | conv1_reshaped = layers.Reshape((H.value, W.value, 1, C.value))(conv1) 135 | 136 | # Layer 1: Primary Capsule: Conv cap with routing 1 137 | primary_caps = ConvCapsuleLayer(kernel_size=5, num_capsule=2, num_atoms=16, strides=2, padding='same', 138 | routings=1, name='primarycaps')(conv1_reshaped) 139 | 140 | # Layer 2: Convolutional Capsule 141 | conv_cap_2_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=16, strides=1, padding='same', 142 | routings=1, name='conv_cap_2_1')(primary_caps) 143 | 144 | # Layer 2: Convolutional Capsule 145 | conv_cap_2_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=32, strides=2, padding='same', 146 | routings=3, name='conv_cap_2_2')(conv_cap_2_1) 147 | 148 | # Layer 3: Convolutional Capsule 149 | conv_cap_3_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=32, strides=1, padding='same', 150 | routings=1, name='conv_cap_3_1')(conv_cap_2_2) 151 | 152 | # Layer 3: Convolutional Capsule 153 | conv_cap_3_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=64, strides=2, padding='same', 154 | routings=3, name='conv_cap_3_2')(conv_cap_3_1) 155 | 156 | # Layer 4: Convolutional Capsule 157 | conv_cap_4_1 = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=32, strides=1, padding='same', 158 | routings=1, name='conv_cap_4_1')(conv_cap_3_2) 159 | 160 | # Layer 1 Up: Deconvolutional Capsule 161 | deconv_cap_1_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=8, num_atoms=32, upsamp_type='deconv', 162 | scaling=2, padding='same', routings=3, 163 | name='deconv_cap_1_1')(conv_cap_4_1) 164 | 165 | # Skip connection 166 | up_1 = layers.Concatenate(axis=-2, name='up_1')([deconv_cap_1_1, conv_cap_3_1]) 167 | 168 | # Layer 1 Up: Deconvolutional Capsule 169 | 
deconv_cap_1_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=32, strides=1, 170 | padding='same', routings=1, name='deconv_cap_1_2')(up_1) 171 | 172 | # Layer 2 Up: Deconvolutional Capsule 173 | deconv_cap_2_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=4, num_atoms=16, upsamp_type='deconv', 174 | scaling=2, padding='same', routings=3, 175 | name='deconv_cap_2_1')(deconv_cap_1_2) 176 | 177 | # Skip connection 178 | up_2 = layers.Concatenate(axis=-2, name='up_2')([deconv_cap_2_1, conv_cap_2_1]) 179 | 180 | # Layer 2 Up: Deconvolutional Capsule 181 | deconv_cap_2_2 = ConvCapsuleLayer(kernel_size=5, num_capsule=4, num_atoms=16, strides=1, 182 | padding='same', routings=1, name='deconv_cap_2_2')(up_2) 183 | 184 | # Layer 3 Up: Deconvolutional Capsule 185 | deconv_cap_3_1 = DeconvCapsuleLayer(kernel_size=4, num_capsule=2, num_atoms=16, upsamp_type='deconv', 186 | scaling=2, padding='same', routings=3, 187 | name='deconv_cap_3_1')(deconv_cap_2_2) 188 | 189 | # Skip connection 190 | up_3 = layers.Concatenate(axis=-2, name='up_3')([deconv_cap_3_1, conv1_reshaped]) 191 | 192 | # Layer 4: Convolutional Capsule: 1x1 193 | seg_caps = ConvCapsuleLayer(kernel_size=1, num_capsule=1, num_atoms=16, strides=1, padding='same', 194 | routings=1, name='seg_caps')(up_3) 195 | 196 | # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape. 197 | out_seg = Length(num_classes=n_class, seg=True, name='out_seg')(seg_caps) 198 | 199 | # Decoder network. 200 | _, H, W, C, A = seg_caps.get_shape() 201 | y = layers.Input(shape=input_shape[:-1]+(1,)) 202 | masked_by_y = Mask()([seg_caps, y]) # The true label is used to mask the output of capsule layer. For training 203 | masked = Mask()(seg_caps) # Mask using the capsule with maximal length. 
For prediction 204 | 205 | def shared_decoder(mask_layer): 206 | recon_remove_dim = layers.Reshape((H.value, W.value, A.value))(mask_layer) 207 | 208 | recon_1 = layers.Conv2D(filters=64, kernel_size=1, padding='same', kernel_initializer='he_normal', 209 | activation='relu', name='recon_1')(recon_remove_dim) 210 | 211 | recon_2 = layers.Conv2D(filters=128, kernel_size=1, padding='same', kernel_initializer='he_normal', 212 | activation='relu', name='recon_2')(recon_1) 213 | 214 | out_recon = layers.Conv2D(filters=1, kernel_size=1, padding='same', kernel_initializer='he_normal', 215 | activation='sigmoid', name='out_recon')(recon_2) 216 | 217 | return out_recon 218 | 219 | # Models for training and evaluation (prediction) 220 | train_model = models.Model(inputs=[x, y], outputs=[out_seg, shared_decoder(masked_by_y)]) 221 | eval_model = models.Model(inputs=x, outputs=[out_seg, shared_decoder(masked)]) 222 | 223 | # manipulate model 224 | noise = layers.Input(shape=((H.value, W.value, C.value, A.value))) 225 | noised_seg_caps = layers.Add()([seg_caps, noise]) 226 | masked_noised_y = Mask()([noised_seg_caps, y]) 227 | manipulate_model = models.Model(inputs=[x, y, noise], outputs=shared_decoder(masked_noised_y)) 228 | 229 | return train_model, eval_model, manipulate_model 230 | 231 | 232 | def CapsNetBasic(input_shape, n_class=2): 233 | x = layers.Input(shape=input_shape) 234 | 235 | # Layer 1: Just a conventional Conv2D layer 236 | conv1 = layers.Conv2D(filters=256, kernel_size=5, strides=1, padding='same', activation='relu', name='conv1')(x) 237 | 238 | # Reshape layer to be 1 capsule x [filters] atoms 239 | _, H, W, C = conv1.get_shape() 240 | conv1_reshaped = layers.Reshape((H.value, W.value, 1, C.value))(conv1) 241 | 242 | # Layer 1: Primary Capsule: Conv cap with routing 1 243 | primary_caps = ConvCapsuleLayer(kernel_size=5, num_capsule=8, num_atoms=32, strides=1, padding='same', 244 | routings=1, name='primarycaps')(conv1_reshaped) 245 | 246 | # Layer 4: Convolutional Capsule: 1x1 247 | seg_caps = ConvCapsuleLayer(kernel_size=1, num_capsule=1, num_atoms=16, strides=1, padding='same', 248 | routings=3, name='seg_caps')(primary_caps) 249 | 250 | # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape. 251 | out_seg = Length(num_classes=n_class, seg=True, name='out_seg')(seg_caps) 252 | 253 | # Decoder network. 254 | _, H, W, C, A = seg_caps.get_shape() 255 | y = layers.Input(shape=input_shape[:-1]+(1,)) 256 | masked_by_y = Mask()([seg_caps, y]) # The true label is used to mask the output of capsule layer. For training 257 | masked = Mask()(seg_caps) # Mask using the capsule with maximal length. 
For prediction 258 | 259 | def shared_decoder(mask_layer): 260 | recon_remove_dim = layers.Reshape((H.value, W.value, A.value))(mask_layer) 261 | 262 | recon_1 = layers.Conv2D(filters=64, kernel_size=1, padding='same', kernel_initializer='he_normal', 263 | activation='relu', name='recon_1')(recon_remove_dim) 264 | 265 | recon_2 = layers.Conv2D(filters=128, kernel_size=1, padding='same', kernel_initializer='he_normal', 266 | activation='relu', name='recon_2')(recon_1) 267 | 268 | out_recon = layers.Conv2D(filters=1, kernel_size=1, padding='same', kernel_initializer='he_normal', 269 | activation='sigmoid', name='out_recon')(recon_2) 270 | 271 | return out_recon 272 | 273 | # Models for training and evaluation (prediction) 274 | train_model = models.Model(inputs=[x, y], outputs=[out_seg, shared_decoder(masked_by_y)]) 275 | eval_model = models.Model(inputs=x, outputs=[out_seg, shared_decoder(masked)]) 276 | 277 | # manipulate model 278 | noise = layers.Input(shape=((H.value, W.value, C.value, A.value))) 279 | noised_seg_caps = layers.Add()([seg_caps, noise]) 280 | masked_noised_y = Mask()([noised_seg_caps, y]) 281 | manipulate_model = models.Model(inputs=[x, y, noise], outputs=shared_decoder(masked_noised_y)) 282 | 283 | return train_model, eval_model, manipulate_model 284 | -------------------------------------------------------------------------------- /segcapsnet/capsule_layers.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Capsules for Object Segmentation (SegCaps) 3 | Original Paper: https://arxiv.org/abs/1804.04241 4 | Code written by: Rodney LaLonde 5 | If you use significant portions of this code or the ideas from our paper, please cite it :) 6 | If you have any questions, please email me at lalonde@knights.ucf.edu. 7 | 8 | This file contains the definitions of the various capsule layers and dynamic routing and squashing functions. 9 | ''' 10 | 11 | import keras.backend as K 12 | import tensorflow as tf 13 | from keras import initializers, layers 14 | from keras.utils.conv_utils import conv_output_length, deconv_length 15 | import numpy as np 16 | 17 | class Length(layers.Layer): 18 | def __init__(self, num_classes, seg=True, **kwargs): 19 | super(Length, self).__init__(**kwargs) 20 | if num_classes == 2: 21 | self.num_classes = 1 22 | else: 23 | self.num_classes = num_classes 24 | self.seg = seg 25 | 26 | def call(self, inputs, **kwargs): 27 | if inputs.get_shape().ndims == 5: 28 | assert inputs.get_shape()[-2].value == 1, 'Error: Must have num_capsules = 1 going into Length' 29 | inputs = K.squeeze(inputs, axis=-2) 30 | return K.expand_dims(tf.norm(inputs, axis=-1), axis=-1) 31 | 32 | def compute_output_shape(self, input_shape): 33 | if len(input_shape) == 5: 34 | input_shape = input_shape[0:-2] + input_shape[-1:] 35 | if self.seg: 36 | return input_shape[:-1] + (self.num_classes,) 37 | else: 38 | return input_shape[:-1] 39 | 40 | def get_config(self): 41 | config = {'num_classes': self.num_classes, 'seg': self.seg} 42 | base_config = super(Length, self).get_config() 43 | return dict(list(base_config.items()) + list(config.items())) 44 | 45 | 46 | class Mask(layers.Layer): 47 | def __init__(self, resize_masks=False, **kwargs): 48 | super(Mask, self).__init__(**kwargs) 49 | self.resize_masks = resize_masks 50 | 51 | def call(self, inputs, **kwargs): 52 | if type(inputs) is list: # The true label is used to mask the output of capsule layer. 
For training 53 | assert len(inputs) == 2 54 | input, mask = inputs #(?, 512, 512, 1, 16), (?, 512, 512, 1) 55 | _, hei, wid, _, _ = input.get_shape() 56 | if self.resize_masks: 57 | mask = tf.image.resize_bicubic(mask, (hei.value, wid.value)) 58 | mask = K.expand_dims(mask, -1) #mask = (?, 512, 512, 1, 1) 59 | if input.get_shape().ndims == 3: 60 | masked = K.batch_flatten(mask * input) 61 | else: 62 | masked = mask * input 63 | 64 | else: # Mask using the capsule with maximal length. For prediction 65 | if inputs.get_shape().ndims == 3: 66 | x = K.sqrt(K.sum(K.square(inputs), -1)) 67 | mask = K.one_hot(indices=K.argmax(x, 1), num_classes=x.get_shape().as_list()[1]) 68 | masked = K.batch_flatten(K.expand_dims(mask, -1) * inputs) 69 | else: 70 | masked = inputs 71 | 72 | return masked 73 | 74 | def compute_output_shape(self, input_shape): 75 | if type(input_shape[0]) is tuple: # true label provided 76 | if len(input_shape[0]) == 3: 77 | return tuple([None, input_shape[0][1] * input_shape[0][2]]) 78 | else: 79 | return input_shape[0] 80 | else: # no true label provided 81 | if len(input_shape) == 3: 82 | return tuple([None, input_shape[1] * input_shape[2]]) 83 | else: 84 | return input_shape 85 | 86 | def get_config(self): 87 | config = {'resize_masks': self.resize_masks} 88 | base_config = super(Mask, self).get_config() 89 | return dict(list(base_config.items()) + list(config.items())) 90 | 91 | 92 | class ConvCapsuleLayer(layers.Layer): 93 | def __init__(self, kernel_size, num_capsule, num_atoms, strides=1, padding='same', routings=3, 94 | kernel_initializer='he_normal', **kwargs): 95 | super(ConvCapsuleLayer, self).__init__(**kwargs) 96 | self.kernel_size = kernel_size 97 | self.num_capsule = num_capsule 98 | self.num_atoms = num_atoms 99 | self.strides = strides 100 | self.padding = padding 101 | self.routings = routings 102 | self.kernel_initializer = initializers.get(kernel_initializer) 103 | 104 | def build(self, input_shape): 105 | assert len(input_shape) == 5, "The input Tensor should have shape=[None, input_height, input_width," \ 106 | " input_num_capsule, input_num_atoms]" 107 | self.input_height = input_shape[1] 108 | self.input_width = input_shape[2] 109 | self.input_num_capsule = input_shape[3] 110 | self.input_num_atoms = input_shape[4] 111 | 112 | # Transform matrix 113 | self.W = self.add_weight(shape=[self.kernel_size, self.kernel_size, 114 | self.input_num_atoms, self.num_capsule * self.num_atoms], 115 | initializer=self.kernel_initializer, 116 | name='W') 117 | 118 | self.b = self.add_weight(shape=[1, 1, self.num_capsule, self.num_atoms], 119 | initializer=initializers.constant(0.1), 120 | name='b') 121 | 122 | self.built = True 123 | 124 | def call(self, input_tensor, training=None): 125 | 126 | input_transposed = tf.transpose(input_tensor, [3, 0, 1, 2, 4]) 127 | input_shape = K.shape(input_transposed) 128 | input_tensor_reshaped = K.reshape(input_transposed, [ 129 | input_shape[0] * input_shape[1], self.input_height, self.input_width, self.input_num_atoms]) 130 | input_tensor_reshaped.set_shape((None, self.input_height, self.input_width, self.input_num_atoms)) 131 | 132 | conv = K.conv2d(input_tensor_reshaped, self.W, (self.strides, self.strides), 133 | padding=self.padding, data_format='channels_last') 134 | 135 | votes_shape = K.shape(conv) 136 | _, conv_height, conv_width, _ = conv.get_shape() 137 | 138 | votes = K.reshape(conv, [input_shape[1], input_shape[0], votes_shape[1], votes_shape[2], 139 | self.num_capsule, self.num_atoms]) 140 | votes.set_shape((None, 
self.input_num_capsule, conv_height.value, conv_width.value, 141 | self.num_capsule, self.num_atoms)) 142 | 143 | logit_shape = K.stack([ 144 | input_shape[1], input_shape[0], votes_shape[1], votes_shape[2], self.num_capsule]) 145 | biases_replicated = K.tile(self.b, [conv_height.value, conv_width.value, 1, 1]) 146 | 147 | activations = update_routing( 148 | votes=votes, 149 | biases=biases_replicated, 150 | logit_shape=logit_shape, 151 | num_dims=6, 152 | input_dim=self.input_num_capsule, 153 | output_dim=self.num_capsule, 154 | num_routing=self.routings) 155 | 156 | return activations 157 | 158 | def compute_output_shape(self, input_shape): 159 | space = input_shape[1:-2] 160 | new_space = [] 161 | for i in range(len(space)): 162 | new_dim = conv_output_length( 163 | space[i], 164 | self.kernel_size, 165 | padding=self.padding, 166 | stride=self.strides, 167 | dilation=1) 168 | new_space.append(new_dim) 169 | 170 | return (input_shape[0],) + tuple(new_space) + (self.num_capsule, self.num_atoms) 171 | 172 | def get_config(self): 173 | config = { 174 | 'kernel_size': self.kernel_size, 175 | 'num_capsule': self.num_capsule, 176 | 'num_atoms': self.num_atoms, 177 | 'strides': self.strides, 178 | 'padding': self.padding, 179 | 'routings': self.routings, 180 | 'kernel_initializer': initializers.serialize(self.kernel_initializer) 181 | } 182 | base_config = super(ConvCapsuleLayer, self).get_config() 183 | return dict(list(base_config.items()) + list(config.items())) 184 | 185 | 186 | class DeconvCapsuleLayer(layers.Layer): 187 | def __init__(self, kernel_size, num_capsule, num_atoms, scaling=2, upsamp_type='deconv', padding='same', routings=3, 188 | kernel_initializer='he_normal', **kwargs): 189 | super(DeconvCapsuleLayer, self).__init__(**kwargs) 190 | self.kernel_size = kernel_size 191 | self.num_capsule = num_capsule 192 | self.num_atoms = num_atoms 193 | self.scaling = scaling 194 | self.upsamp_type = upsamp_type 195 | self.padding = padding 196 | self.routings = routings 197 | self.kernel_initializer = initializers.get(kernel_initializer) 198 | 199 | def build(self, input_shape): 200 | assert len(input_shape) == 5, "The input Tensor should have shape=[None, input_height, input_width," \ 201 | " input_num_capsule, input_num_atoms]" 202 | self.input_height = input_shape[1] 203 | self.input_width = input_shape[2] 204 | self.input_num_capsule = input_shape[3] 205 | self.input_num_atoms = input_shape[4] 206 | 207 | # Transform matrix 208 | if self.upsamp_type == 'subpix': 209 | self.W = self.add_weight(shape=[self.kernel_size, self.kernel_size, 210 | self.input_num_atoms, 211 | self.num_capsule * self.num_atoms * self.scaling * self.scaling], 212 | initializer=self.kernel_initializer, 213 | name='W') 214 | elif self.upsamp_type == 'resize': 215 | self.W = self.add_weight(shape=[self.kernel_size, self.kernel_size, 216 | self.input_num_atoms, self.num_capsule * self.num_atoms], 217 | initializer=self.kernel_initializer, name='W') 218 | elif self.upsamp_type == 'deconv': 219 | self.W = self.add_weight(shape=[self.kernel_size, self.kernel_size, 220 | self.num_capsule * self.num_atoms, self.input_num_atoms], 221 | initializer=self.kernel_initializer, name='W') 222 | else: 223 | raise NotImplementedError('Upsampling must be one of: "deconv", "resize", or "subpix"') 224 | 225 | self.b = self.add_weight(shape=[1, 1, self.num_capsule, self.num_atoms], 226 | initializer=initializers.constant(0.1), 227 | name='b') 228 | 229 | self.built = True 230 | 231 | def call(self, input_tensor, training=None): 232 | 
input_transposed = tf.transpose(input_tensor, [3, 0, 1, 2, 4]) 233 | input_shape = K.shape(input_transposed) 234 | input_tensor_reshaped = K.reshape(input_transposed, [ 235 | input_shape[1] * input_shape[0], self.input_height, self.input_width, self.input_num_atoms]) 236 | input_tensor_reshaped.set_shape((None, self.input_height, self.input_width, self.input_num_atoms)) 237 | 238 | 239 | if self.upsamp_type == 'resize': 240 | upsamp = K.resize_images(input_tensor_reshaped, self.scaling, self.scaling, 'channels_last') 241 | outputs = K.conv2d(upsamp, kernel=self.W, strides=(1, 1), padding=self.padding, data_format='channels_last') 242 | elif self.upsamp_type == 'subpix': 243 | conv = K.conv2d(input_tensor_reshaped, kernel=self.W, strides=(1, 1), padding='same', 244 | data_format='channels_last') 245 | outputs = tf.depth_to_space(conv, self.scaling) 246 | else: 247 | batch_size = input_shape[1] * input_shape[0] 248 | 249 | # Infer the dynamic output shape: 250 | out_height = deconv_length(self.input_height, self.scaling, self.kernel_size, self.padding) 251 | out_width = deconv_length(self.input_width, self.scaling, self.kernel_size, self.padding) 252 | output_shape = (batch_size, out_height, out_width, self.num_capsule * self.num_atoms) 253 | 254 | outputs = K.conv2d_transpose(input_tensor_reshaped, self.W, output_shape, (self.scaling, self.scaling), 255 | padding=self.padding, data_format='channels_last') 256 | 257 | votes_shape = K.shape(outputs) 258 | _, conv_height, conv_width, _ = outputs.get_shape() 259 | 260 | votes = K.reshape(outputs, [input_shape[1], input_shape[0], votes_shape[1], votes_shape[2], 261 | self.num_capsule, self.num_atoms]) 262 | votes.set_shape((None, self.input_num_capsule, conv_height.value, conv_width.value, 263 | self.num_capsule, self.num_atoms)) 264 | 265 | logit_shape = K.stack([ 266 | input_shape[1], input_shape[0], votes_shape[1], votes_shape[2], self.num_capsule]) 267 | biases_replicated = K.tile(self.b, [votes_shape[1], votes_shape[2], 1, 1]) 268 | 269 | activations = update_routing( 270 | votes=votes, 271 | biases=biases_replicated, 272 | logit_shape=logit_shape, 273 | num_dims=6, 274 | input_dim=self.input_num_capsule, 275 | output_dim=self.num_capsule, 276 | num_routing=self.routings) 277 | 278 | return activations 279 | 280 | def compute_output_shape(self, input_shape): 281 | output_shape = list(input_shape) 282 | 283 | output_shape[1] = deconv_length(output_shape[1], self.scaling, self.kernel_size, self.padding) 284 | output_shape[2] = deconv_length(output_shape[2], self.scaling, self.kernel_size, self.padding) 285 | output_shape[3] = self.num_capsule 286 | output_shape[4] = self.num_atoms 287 | 288 | return tuple(output_shape) 289 | 290 | def get_config(self): 291 | config = { 292 | 'kernel_size': self.kernel_size, 293 | 'num_capsule': self.num_capsule, 294 | 'num_atoms': self.num_atoms, 295 | 'scaling': self.scaling, 296 | 'padding': self.padding, 297 | 'upsamp_type': self.upsamp_type, 298 | 'routings': self.routings, 299 | 'kernel_initializer': initializers.serialize(self.kernel_initializer) 300 | } 301 | base_config = super(DeconvCapsuleLayer, self).get_config() 302 | return dict(list(base_config.items()) + list(config.items())) 303 | 304 | 305 | def update_routing(votes, biases, logit_shape, num_dims, input_dim, output_dim, 306 | num_routing): 307 | if num_dims == 6: 308 | votes_t_shape = [5, 0, 1, 2, 3, 4] 309 | r_t_shape = [1, 2, 3, 4, 5, 0] 310 | elif num_dims == 4: 311 | votes_t_shape = [3, 0, 1, 2] 312 | r_t_shape = [1, 2, 3, 0] 313 | else: 
314 | raise NotImplementedError('Not implemented') 315 | 316 | votes_trans = tf.transpose(votes, votes_t_shape) 317 | _, _, _, height, width, caps = votes_trans.get_shape() 318 | 319 | def _body(i, logits, activations): 320 | """Routing while loop.""" 321 | # route: [batch, input_dim, output_dim, ...] 322 | route = tf.nn.softmax(logits, axis=-1) 323 | preactivate_unrolled = route * votes_trans 324 | preact_trans = tf.transpose(preactivate_unrolled, r_t_shape) 325 | preactivate = tf.reduce_sum(preact_trans, axis=1) + biases 326 | activation = _squash(preactivate) 327 | activations = activations.write(i, activation) 328 | act_3d = K.expand_dims(activation, 1) 329 | tile_shape = np.ones(num_dims, dtype=np.int32).tolist() 330 | tile_shape[1] = input_dim 331 | act_replicated = tf.tile(act_3d, tile_shape) 332 | distances = tf.reduce_sum(votes * act_replicated, axis=-1) 333 | logits += distances 334 | return (i + 1, logits, activations) 335 | 336 | activations = tf.TensorArray( 337 | dtype=tf.float32, size=num_routing, clear_after_read=False) 338 | logits = tf.fill(logit_shape, 0.0) 339 | 340 | i = tf.constant(0, dtype=tf.int32) 341 | _, logits, activations = tf.while_loop( 342 | lambda i, logits, activations: i < num_routing, 343 | _body, 344 | loop_vars=[i, logits, activations], 345 | swap_memory=True) 346 | 347 | return K.cast(activations.read(num_routing - 1), dtype='float32') 348 | 349 | 350 | def _squash(input_tensor): 351 | norm = tf.norm(input_tensor, axis=-1, keepdims=True) 352 | norm_squared = norm * norm 353 | return (input_tensor / norm) * (norm_squared / (1 + norm_squared)) 354 | -------------------------------------------------------------------------------- /segcapsnet/subpixel_upscaling.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | This file contains an implementation of Sub-pixel convolutional upscaling layer based on 4 | the paper "Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel 5 | Convolutional Neural Network" (https://arxiv.org/abs/1609.05158). 6 | """ 7 | 8 | from __future__ import absolute_import 9 | 10 | import tensorflow as tf 11 | from keras.engine import Layer 12 | from keras.utils.generic_utils import get_custom_objects 13 | from keras.utils.conv_utils import normalize_data_format 14 | 15 | class SubPixelUpscaling(Layer): 16 | """ This layer requires a Convolution2D prior to it, having output filters computed according to 17 | the formula : 18 | filters = k * (scale_factor * scale_factor) 19 | where k = a user defined number of filters (generally larger than 32) 20 | scale_factor = the upscaling factor (generally 2) 21 | This layer performs the depth to space operation on the convolution filters, and returns a 22 | tensor with the size as defined below. 23 | # Example : 24 | ```python 25 | # A standard subpixel upscaling block 26 | x = Convolution2D(256, 3, 3, padding='same', activation='relu')(...) 27 | u = SubPixelUpscaling(scale_factor=2)(x) 28 | [Optional] 29 | x = Convolution2D(256, 3, 3, padding='same', activation='relu')(u) 30 | ``` 31 | In practice, it is useful to have a second convolution layer after the 32 | SubPixelUpscaling layer to speed up the learning process. 33 | However, if you are stacking multiple SubPixelUpscaling blocks, it may increase 34 | the number of parameters greatly, so the Convolution layer after SubPixelUpscaling 35 | layer can be removed. 36 | # Arguments 37 | scale_factor: Upscaling factor. 
38 | data_format: Can be None, 'channels_first' or 'channels_last'. 39 | # Input shape 40 | 4D tensor with shape: 41 | `(samples, k * (scale_factor * scale_factor) channels, rows, cols)` if data_format='channels_first' 42 | or 4D tensor with shape: 43 | `(samples, rows, cols, k * (scale_factor * scale_factor) channels)` if data_format='channels_last'. 44 | # Output shape 45 | 4D tensor with shape: 46 | `(samples, k channels, rows * scale_factor, cols * scale_factor))` if data_format='channels_first' 47 | or 4D tensor with shape: 48 | `(samples, rows * scale_factor, cols * scale_factor, k channels)` if data_format='channels_last'. 49 | """ 50 | 51 | def __init__(self, scale_factor=2, data_format=None, **kwargs): 52 | super(SubPixelUpscaling, self).__init__(**kwargs) 53 | 54 | self.scale_factor = scale_factor 55 | self.data_format = normalize_data_format(data_format) 56 | 57 | def build(self, input_shape): 58 | pass 59 | 60 | def call(self, x, mask=None): 61 | y = tf.depth_to_space(x, self.scale_factor, self.data_format) 62 | return y 63 | 64 | def compute_output_shape(self, input_shape): 65 | if self.data_format == 'channels_first': 66 | b, k, r, c = input_shape 67 | return (b, k // (self.scale_factor ** 2), r * self.scale_factor, c * self.scale_factor) 68 | else: 69 | b, r, c, k = input_shape 70 | return (b, r * self.scale_factor, c * self.scale_factor, k // (self.scale_factor ** 2)) 71 | 72 | def get_config(self): 73 | config = {'scale_factor': self.scale_factor, 74 | 'data_format': self.data_format} 75 | base_config = super(SubPixelUpscaling, self).get_config() 76 | return dict(list(base_config.items()) + list(config.items())) 77 | 78 | 79 | get_custom_objects().update({'SubPixelUpscaling': SubPixelUpscaling}) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Capsules for Object Segmentation (SegCaps) 3 | Original Paper by Rodney LaLonde and Ulas Bagci (https://arxiv.org/abs/1804.04241) 4 | Code written by: Rodney LaLonde 5 | If you use significant portions of this code or the ideas from our paper, please cite it :) 6 | If you have any questions, please email me at lalonde@knights.ucf.edu. 7 | 8 | This file is used for testing models. Please see the README for details about testing. 9 | 10 | ============== 11 | This is the entry point of the test procedure for UNet, tiramisu, 12 | Capsule Nets (capsbasic) or SegCaps(segcapsr1 or segcapsr3). 13 | 14 | @author: Cheng-Lin Li a.k.a. Clark 15 | 16 | @copyright: 2018 Cheng-Lin Li@Insight AI. All rights reserved. 17 | 18 | @license: Licensed under the Apache License v2.0. 19 | http://www.apache.org/licenses/ 20 | 21 | @contact: clark.cl.li@gmail.com 22 | 23 | Tasks: 24 | The program based on parameters from main.py to perform testing tasks 25 | on all models. 26 | 27 | 28 | Data: 29 | MS COCO 2017 or LUNA 2016 were tested on this package. 30 | You can leverage your own data set but the mask images should follow the format of 31 | MS COCO or with background color = 0 on each channel. 32 | 33 | 34 | Enhancement: 35 | 1. Integrated with MS COCO 2017 dataset. 
36 | 37 | 38 | ''' 39 | 40 | from __future__ import print_function 41 | import logging 42 | import matplotlib 43 | import matplotlib.pyplot as plt 44 | from os.path import join 45 | from os import makedirs 46 | import csv 47 | import SimpleITK as sitk 48 | import numpy as np 49 | import scipy.ndimage.morphology 50 | from skimage import measure, filters 51 | from utils.metrics import dc, jc, assd 52 | from PIL import Image 53 | from keras import backend as K 54 | from keras.utils import print_summary 55 | from utils.data_helper import get_generator 56 | from utils.custom_data_aug import convert_img_data, convert_mask_data 57 | 58 | matplotlib.use('Agg') 59 | plt.ioff() 60 | K.set_image_data_format('channels_last') 61 | 62 | RESOLUTION = 512 63 | GRAYSCALE = True 64 | 65 | 66 | def threshold_mask(raw_output, threshold): 67 | # raw_output 3d:(119, 512, 512) 68 | if threshold == 0: 69 | try: 70 | threshold = filters.threshold_otsu(raw_output) 71 | except: 72 | threshold = 0.5 73 | 74 | logging.info('\tThreshold: {}'.format(threshold)) 75 | 76 | raw_output[raw_output > threshold] = 1 77 | raw_output[raw_output < 1] = 0 78 | 79 | # all_labels 3d:(119, 512, 512) 80 | all_labels = measure.label(raw_output) 81 | # props 3d: region of props=> 82 | # list(_RegionProperties:) 83 | # with bbox. 84 | props = measure.regionprops(all_labels) 85 | props.sort(key=lambda x: x.area, reverse=True) 86 | thresholded_mask = np.zeros(raw_output.shape) 87 | 88 | if len(props) >= 2: 89 | # if the largest is way larger than the second largest 90 | if props[0].area / props[1].area > 5: 91 | # only turn on the largest component 92 | thresholded_mask[all_labels == props[0].label] = 1 93 | else: 94 | # turn on two largest components 95 | thresholded_mask[all_labels == props[0].label] = 1 96 | thresholded_mask[all_labels == props[1].label] = 1 97 | elif len(props): 98 | thresholded_mask[all_labels == props[0].label] = 1 99 | # threshold_mask: 3d=(119, 512, 512) 100 | thresholded_mask = scipy.ndimage.morphology.binary_fill_holes(thresholded_mask).astype(np.uint8) 101 | 102 | return thresholded_mask 103 | 104 | 105 | def test(args, test_list, model_list, net_input_shape): 106 | if args.weights_path == '': 107 | weights_path = join(args.check_dir, args.output_name + '_model_' + args.time + '.hdf5') 108 | else: 109 | weights_path = join(args.data_root_dir, args.weights_path) 110 | 111 | output_dir = join(args.data_root_dir, 'results', args.net, 'split_' + str(args.split_num)) 112 | raw_out_dir = join(output_dir, 'raw_output') 113 | fin_out_dir = join(output_dir, 'final_output') 114 | fig_out_dir = join(output_dir, 'qual_figs') 115 | try: 116 | makedirs(raw_out_dir) 117 | except: 118 | pass 119 | try: 120 | makedirs(fin_out_dir) 121 | except: 122 | pass 123 | try: 124 | makedirs(fig_out_dir) 125 | except: 126 | pass 127 | 128 | if len(model_list) > 1: 129 | eval_model = model_list[1] 130 | else: 131 | eval_model = model_list[0] 132 | try: 133 | logging.info('\nWeights_path=%s'%(weights_path)) 134 | eval_model.load_weights(weights_path) 135 | except: 136 | logging.warning('\nUnable to find weights path. 
Testing with random weights.') 137 | print_summary(model=eval_model, positions=[.38, .65, .75, 1.]) 138 | 139 | # Set up placeholders 140 | outfile = '' 141 | if args.compute_dice: 142 | dice_arr = np.zeros((len(test_list))) 143 | outfile += 'dice_' 144 | if args.compute_jaccard: 145 | jacc_arr = np.zeros((len(test_list))) 146 | outfile += 'jacc_' 147 | if args.compute_assd: 148 | assd_arr = np.zeros((len(test_list))) 149 | outfile += 'assd_' 150 | 151 | # Testing the network 152 | logging.info('\nTesting... This will take some time...') 153 | 154 | with open(join(output_dir, args.save_prefix + outfile + 'scores.csv'), 'w') as csvfile: 155 | writer = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) 156 | 157 | row = ['Scan Name'] 158 | if args.compute_dice: 159 | row.append('Dice Coefficient') 160 | if args.compute_jaccard: 161 | row.append('Jaccard Index') 162 | if args.compute_assd: 163 | row.append('Average Symmetric Surface Distance') 164 | 165 | writer.writerow(row) 166 | 167 | for i, img in enumerate((test_list)): 168 | sitk_img = sitk.ReadImage(join(args.data_root_dir, 'imgs', img[0])) 169 | img_data = sitk.GetArrayFromImage(sitk_img) # 3d:(slices, 512, 512), 2d:(512, 512, channels=4) 170 | 171 | # Change RGB to single slice of grayscale image for MS COCO 17 dataset. 172 | if args.dataset == 'mscoco17': 173 | img_data = convert_img_data(img_data, 3) 174 | 175 | num_slices = 1 176 | logging.info('\ntest.test: eval_model.predict_generator') 177 | _, _, generate_test_batches = get_generator(args.dataset) 178 | output_array = eval_model.predict_generator(generate_test_batches(args.data_root_dir, [img], 179 | net_input_shape, 180 | batchSize=args.batch_size, 181 | numSlices=args.slices, 182 | subSampAmt=0, 183 | stride=1), 184 | steps=num_slices, max_queue_size=1, workers=4, 185 | use_multiprocessing=args.use_multiprocessing, 186 | verbose=1) 187 | logging.info('\ntest.test: output_array=%s'%(output_array)) 188 | if args.net.find('caps') != -1: 189 | # A list with two images [mask, recon], get mask image.#3d: 190 | # output_array=[mask(Slices, x=512, y=512, 1), recon(slices, x=512, y=512, 1)] 191 | output = output_array[0][:,:,:,0] # output = (slices, 512, 512) 192 | #recon = output_array[1][:,:,:,0] 193 | else: 194 | output = output_array[:,:,:,0] 195 | 196 | #output_image = RTTI size:[512, 512, 119] 197 | output_img = sitk.GetImageFromArray(output) 198 | print('Segmenting Output') 199 | # output_bin (119, 512, 512) 200 | output_bin = threshold_mask(output, args.thresh_level) 201 | # output_mask = RIIT (512, 512, 119) 202 | output_mask = sitk.GetImageFromArray(output_bin) 203 | if args.dataset == 'luna16': 204 | output_img.CopyInformation(sitk_img) 205 | output_mask.CopyInformation(sitk_img) 206 | 207 | print('Saving Output') 208 | sitk.WriteImage(output_img, join(raw_out_dir, img[0][:-4] + '_raw_output' + img[0][-4:])) 209 | sitk.WriteImage(output_mask, join(fin_out_dir, img[0][:-4] + '_final_output' + img[0][-4:])) 210 | else: # MS COCO 17 211 | plt.imshow(output[0,:,:], cmap = 'gray') 212 | plt.imsave(join(raw_out_dir, img[0][:-4] + '_raw_output' + img[0][-4:]), output[0,:,:]) 213 | plt.imshow(output_bin[0,:,:], cmap = 'gray') 214 | plt.imsave(join(fin_out_dir, img[0][:-4] + '_final_output' + img[0][-4:]), output_bin[0,:,:]) 215 | 216 | # Load gt mask 217 | # sitk_mask: 3d RTTI(512, 512, slices) 218 | sitk_mask = sitk.ReadImage(join(args.data_root_dir, 'masks', img[0])) 219 | # gt_data: 3d=(slices, 512, 512), Ground Truth data 220 | gt_data = 
sitk.GetArrayFromImage(sitk_mask) 221 | 222 | # Change RGB to single slice of grayscale image for MS COCO 17 dataset. 223 | if args.dataset == 'mscoco17': 224 | gt_data = convert_mask_data(gt_data) 225 | # Reshape numpy from 2 to 3 dimensions (slices, heigh, width) 226 | gt_data = gt_data.reshape([1, gt_data.shape[0], gt_data.shape[1]]) 227 | 228 | # Plot Qual Figure 229 | print('Creating Qualitative Figure for Quick Reference') 230 | f, ax = plt.subplots(1, 3, figsize=(15, 5)) 231 | 232 | if args.dataset == 'mscoco17': 233 | pass 234 | else: # 3D data 235 | ax[0].imshow(img_data[img_data.shape[0] // 3, :, :], alpha=1, cmap='gray') 236 | ax[0].imshow(output_bin[img_data.shape[0] // 3, :, :], alpha=0.5, cmap='Blues') 237 | ax[0].imshow(gt_data[img_data.shape[0] // 3, :, :], alpha=0.2, cmap='Reds') 238 | ax[0].set_title('Slice {}/{}'.format(img_data.shape[0] // 3, img_data.shape[0])) 239 | ax[0].axis('off') 240 | 241 | ax[1].imshow(img_data[img_data.shape[0] // 2, :, :], alpha=1, cmap='gray') 242 | ax[1].imshow(output_bin[img_data.shape[0] // 2, :, :], alpha=0.5, cmap='Blues') 243 | ax[1].imshow(gt_data[img_data.shape[0] // 2, :, :], alpha=0.2, cmap='Reds') 244 | ax[1].set_title('Slice {}/{}'.format(img_data.shape[0] // 2, img_data.shape[0])) 245 | ax[1].axis('off') 246 | 247 | ax[2].imshow(img_data[img_data.shape[0] // 2 + img_data.shape[0] // 4, :, :], alpha=1, cmap='gray') 248 | ax[2].imshow(output_bin[img_data.shape[0] // 2 + img_data.shape[0] // 4, :, :], alpha=0.5, 249 | cmap='Blues') 250 | ax[2].imshow(gt_data[img_data.shape[0] // 2 + img_data.shape[0] // 4, :, :], alpha=0.2, 251 | cmap='Reds') 252 | ax[2].set_title( 253 | 'Slice {}/{}'.format(img_data.shape[0] // 2 + img_data.shape[0] // 4, img_data.shape[0])) 254 | ax[2].axis('off') 255 | 256 | fig = plt.gcf() 257 | fig.suptitle(img[0][:-4]) 258 | 259 | plt.savefig(join(fig_out_dir, img[0][:-4] + '_qual_fig' + '.png'), 260 | format='png', bbox_inches='tight') 261 | plt.close('all') 262 | 263 | # Compute metrics 264 | row = [img[0][:-4]] 265 | if args.compute_dice: 266 | logging.info('\nComputing Dice') 267 | dice_arr[i] = dc(output_bin, gt_data) 268 | logging.info('\tDice: {}'.format(dice_arr[i])) 269 | row.append(dice_arr[i]) 270 | if args.compute_jaccard: 271 | logging.info('\nComputing Jaccard') 272 | jacc_arr[i] = jc(output_bin, gt_data) 273 | logging.info('\tJaccard: {}'.format(jacc_arr[i])) 274 | row.append(jacc_arr[i]) 275 | if args.compute_assd: 276 | logging.info('\nComputing ASSD') 277 | assd_arr[i] = assd(output_bin, gt_data, voxelspacing=sitk_img.GetSpacing(), connectivity=1) 278 | logging.info('\tASSD: {}'.format(assd_arr[i])) 279 | row.append(assd_arr[i]) 280 | 281 | writer.writerow(row) 282 | 283 | row = ['Average Scores'] 284 | if args.compute_dice: 285 | row.append(np.mean(dice_arr)) 286 | if args.compute_jaccard: 287 | row.append(np.mean(jacc_arr)) 288 | if args.compute_assd: 289 | row.append(np.mean(assd_arr)) 290 | writer.writerow(row) 291 | 292 | print('Done.') 293 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Capsules for Object Segmentation (SegCaps) 3 | Original Paper by Rodney LaLonde and Ulas Bagci (https://arxiv.org/abs/1804.04241) 4 | Code written by: Rodney LaLonde 5 | If you use significant portions of this code or the ideas from our paper, please cite it :) 6 | If you have any questions, please email me at lalonde@knights.ucf.edu. 
7 | 8 | This file is used for training models. Please see the README for details about training. 9 | 10 | ============== 11 | This is the entry point of the train procedure for UNet, tiramisu, Capsule Nets (capsbasic) or SegCaps(segcapsr1 or segcapsr3). 12 | 13 | @author: Cheng-Lin Li a.k.a. Clark 14 | 15 | @copyright: 2018 Cheng-Lin Li@Insight AI. All rights reserved. 16 | 17 | @license: Licensed under the Apache License v2.0. http://www.apache.org/licenses/ 18 | 19 | @contact: clark.cl.li@gmail.com 20 | 21 | Tasks: 22 | The program based on parameters from main.py to perform training tasks on all models. 23 | 24 | 25 | Data: 26 | MS COCO 2017 or LUNA 2016 were tested on this package. 27 | You can leverage your own data set but the mask images should follow the format of MS COCO or with background color = 0 on each channel. 28 | 29 | 30 | Enhancement: 31 | 1. Integrated with MS COCO 2017 dataset. 32 | 33 | ''' 34 | 35 | from __future__ import print_function 36 | 37 | import logging 38 | import matplotlib 39 | matplotlib.use('Agg') 40 | import matplotlib.pyplot as plt 41 | plt.ioff() 42 | 43 | from os.path import join 44 | import numpy as np 45 | 46 | from keras.optimizers import Adam 47 | from keras import backend as K 48 | K.set_image_data_format('channels_last') 49 | from keras.utils.training_utils import multi_gpu_model 50 | from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping, ReduceLROnPlateau, TensorBoard 51 | import tensorflow as tf 52 | 53 | from utils.custom_losses import dice_hard, weighted_binary_crossentropy_loss, dice_loss, margin_loss, bce_dice_loss 54 | from utils.load_data import load_class_weights 55 | from utils.data_helper import get_generator 56 | 57 | 58 | def get_loss(root, split, net, recon_wei, choice): 59 | if choice == 'w_bce': 60 | pos_class_weight = load_class_weights(root=root, split=split) 61 | loss = weighted_binary_crossentropy_loss(pos_class_weight) 62 | elif choice == 'bce': 63 | loss = 'binary_crossentropy' 64 | elif choice == 'dice': 65 | loss = dice_loss 66 | elif choice == 'w_mar': 67 | pos_class_weight = load_class_weights(root=root, split=split) 68 | loss = margin_loss(margin=0.4, downweight=0.5, pos_weight=pos_class_weight) 69 | elif choice == 'mar': 70 | loss = margin_loss(margin=0.4, downweight=0.5, pos_weight=1.0) 71 | elif choice == 'bce_dice': 72 | loss = bce_dice_loss 73 | else: 74 | raise Exception("Unknow loss_type") 75 | 76 | if net.find('caps') != -1: 77 | return {'out_seg': loss, 'out_recon': 'mse'}, {'out_seg': 1., 'out_recon': recon_wei} 78 | else: 79 | return loss, None 80 | 81 | def get_callbacks(arguments): 82 | if arguments.net.find('caps') != -1: 83 | monitor_name = 'val_out_seg_dice_hard' 84 | else: 85 | monitor_name = 'val_dice_hard' 86 | 87 | csv_logger = CSVLogger(join(arguments.log_dir, arguments.output_name + '_log_' + arguments.time + '.csv'), separator=',') 88 | tb = TensorBoard(arguments.tf_log_dir, batch_size=arguments.batch_size, histogram_freq=0) 89 | # Due to customized major layers and loss function, the program just store the model weights. 90 | # Model should be load by program then load the model weights for inference. 
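    # test.py follows this pattern: it rebuilds the network from its definition and then calls eval_model.load_weights() on the .hdf5 file written by this checkpoint callback.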
91 | model_checkpoint = ModelCheckpoint(join(arguments.check_dir, arguments.output_name + '_model_' + arguments.time + '.hdf5'), 92 | monitor=monitor_name, save_best_only=True, save_weights_only=False, 93 | verbose=1, mode='max') 94 | lr_reducer = ReduceLROnPlateau(monitor=monitor_name, factor=0.05, cooldown=0, patience=50,verbose=1, mode='max') 95 | early_stopper = EarlyStopping(monitor=monitor_name, min_delta=0, patience=arguments.patience, verbose=0, mode='max') 96 | 97 | return [model_checkpoint, csv_logger, lr_reducer, early_stopper, tb] 98 | 99 | def compile_model(args, net_input_shape, uncomp_model): 100 | # Set optimizer loss and metrics 101 | # opt = Adam(lr=args.initial_lr, beta_1=0.99, beta_2=0.999, decay=1e-6) 102 | # Revised decay rate to match with the original experiment parameter on the paper 103 | opt = Adam(lr=args.initial_lr, beta_1=0.9, beta_2=0.999, epsilon = 0.1, decay = 1e-6) 104 | if args.net.find('caps') != -1: 105 | metrics = {'out_seg': dice_hard} 106 | else: 107 | metrics = [dice_hard] 108 | 109 | loss, loss_weighting = get_loss(root=args.data_root_dir, split=args.split_num, net=args.net, 110 | recon_wei=args.recon_wei, choice=args.loss) 111 | 112 | # If using CPU or single GPU 113 | if args.gpus <= 1: 114 | uncomp_model.compile(optimizer=opt, loss=loss, metrics=metrics) 115 | return uncomp_model 116 | # If using multiple GPUs 117 | else: 118 | with tf.device("/cpu:0"): 119 | uncomp_model.compile(optimizer=opt, loss=loss, loss_weights=loss_weighting, metrics=metrics) 120 | model = multi_gpu_model(uncomp_model, gpus=args.gpus) 121 | model.__setattr__('callback_model', uncomp_model) 122 | model.compile(optimizer=opt, loss=loss, loss_weights=loss_weighting, metrics=metrics) 123 | return model 124 | 125 | 126 | def plot_training(training_history, arguments): 127 | f, (ax1, ax2) = plt.subplots(2, sharex=True, figsize=(10, 10)) 128 | f.suptitle(arguments.net, fontsize=18) 129 | 130 | if arguments.net.find('caps') != -1: 131 | ax1.plot(training_history.history['out_seg_dice_hard']) 132 | ax1.plot(training_history.history['val_out_seg_dice_hard']) 133 | else: 134 | ax1.plot(training_history.history['dice_hard']) 135 | ax1.plot(training_history.history['val_dice_hard']) 136 | ax1.set_title('Dice Coefficient') 137 | ax1.set_ylabel('Dice', fontsize=12) 138 | ax1.legend(['Train', 'Val'], loc='upper left') 139 | ax1.set_yticks(np.arange(0, 1.05, 0.05)) 140 | if arguments.net.find('caps') != -1: 141 | ax1.set_xticks(np.arange(0, len(training_history.history['out_seg_dice_hard']))) 142 | else: 143 | ax1.set_xticks(np.arange(0, len(training_history.history['dice_hard']))) 144 | ax1.grid(True) 145 | gridlines1 = ax1.get_xgridlines() + ax1.get_ygridlines() 146 | for line in gridlines1: 147 | line.set_linestyle('-.') 148 | 149 | ax2.plot(training_history.history['loss']) 150 | ax2.plot(training_history.history['val_loss']) 151 | ax2.set_title('Model Loss') 152 | ax2.set_ylabel('Loss', fontsize=12) 153 | ax2.set_xlabel('Epoch', fontsize=12) 154 | ax2.legend(['Train', 'Val'], loc='upper right') 155 | ax1.set_xticks(np.arange(0, len(training_history.history['loss']))) 156 | ax2.grid(True) 157 | gridlines2 = ax2.get_xgridlines() + ax2.get_ygridlines() 158 | for line in gridlines2: 159 | line.set_linestyle('-.') 160 | 161 | f.savefig(join(arguments.output_dir, arguments.output_name + '_plots_' + arguments.time + '.png')) 162 | plt.close() 163 | 164 | def train(args, train_list, val_list, u_model, net_input_shape): 165 | # Compile the loaded model 166 | model = compile_model(args=args, 
net_input_shape=net_input_shape, uncomp_model=u_model) 167 | if args.retrain == 1: 168 | # Retrain the model. Load re-train weights. 169 | weights_path = join(args.data_root_dir, args.weights_path) 170 | logging.info('\nRetrain model from weights_path=%s'%(weights_path)) 171 | model.load_weights(weights_path) 172 | else: # Train from scratch 173 | pass 174 | # Set the callbacks 175 | callbacks = get_callbacks(args) 176 | 177 | # Training the network 178 | # Original project parameters. TODO: Get hyper parameters from input. 179 | # history = model.fit_generator( 180 | # generate_train_batches(args.data_root_dir, train_list, net_input_shape, net=args.net, 181 | # batchSize=args.batch_size, numSlices=args.slices, subSampAmt=args.subsamp, 182 | # stride=args.stride, shuff=args.shuffle_data, aug_data=args.aug_data), 183 | # max_queue_size=40, workers=4, use_multiprocessing=False, 184 | # steps_per_epoch=10000, 185 | # validation_data=generate_val_batches(args.data_root_dir, val_list, net_input_shape, net=args.net, 186 | # batchSize=args.batch_size, numSlices=args.slices, subSampAmt=0, 187 | # stride=20, shuff=args.shuffle_data), 188 | # validation_steps=500, # Set validation stride larger to see more of the data. 189 | # epochs=200, 190 | # callbacks=callbacks, 191 | # verbose=1) 192 | 193 | # POC testing, change stride from 20 to args.stride in generate_val_batches 194 | generate_train_batches, generate_val_batches, _ = get_generator(args.dataset) 195 | history = model.fit_generator( 196 | generate_train_batches(args.data_root_dir, train_list, net_input_shape, net=args.net, 197 | batchSize=args.batch_size, numSlices=args.slices, subSampAmt=args.subsamp, 198 | stride=args.stride, shuff=args.shuffle_data, aug_data=args.aug_data), 199 | max_queue_size=8, workers=4, use_multiprocessing=args.use_multiprocessing, 200 | steps_per_epoch=args.steps_per_epoch, 201 | validation_data=generate_val_batches(args.data_root_dir, val_list, net_input_shape, net=args.net, 202 | batchSize=args.batch_size, numSlices=args.slices, subSampAmt=0, 203 | stride=args.stride, shuff=args.shuffle_data), 204 | validation_steps=5, # Set validation stride larger to see more of the data. 205 | epochs=args.epochs, 206 | callbacks=callbacks, 207 | verbose=1) 208 | # Plot the training data collected 209 | plot_training(history, args) 210 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cheng-Lin-Li/SegCaps/237209c2f53b4c39e8109390db144b89e0f49335/utils/__init__.py -------------------------------------------------------------------------------- /utils/custom_data_aug.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Capsules for Object Segmentation (SegCaps) 3 | Original Paper by Rodney LaLonde and Ulas Bagci (https://arxiv.org/abs/1804.04241) 4 | Code written by: Rodney LaLonde 5 | If you use significant portions of this code or the ideas from our paper, please cite it :) 6 | If you have any questions, please email me at lalonde@knights.ucf.edu. 7 | 8 | This file contains the custom data augmentation functions. 9 | 10 | ===== 11 | 12 | This program includes all data augmentation functions UNet, tiramisu, Capsule Nets (capsbasic) or SegCaps(segcapsr1 or segcapsr3). 13 | 14 | @author: Cheng-Lin Li a.k.a. Clark 15 | 16 | @copyright: 2018 Cheng-Lin Li@Insight AI. All rights reserved. 
17 | 18 | @license: Licensed under the Apache License v2.0. http://www.apache.org/licenses/ 19 | 20 | @contact: clark.cl.li@gmail.com 21 | 22 | Enhancement: 23 | 1. add image2float_array 24 | 2. add image_resize2square: Convert any size of image file to 512 X 512 resolutions 25 | 3. add image_enhance function to sifht image color space includes background color 26 | 4. add process_image to wrap image color shifting and resize the resolutions 27 | 5. add change_background_color to change mask background color of MS COCO files to black 28 | 6. add convert_mask_data function to resize and change mask color to black 29 | 7. add convert_img_data function to wrap process_image, normalized data, reshape to require dimension 3 or 4 30 | ''' 31 | 32 | import numpy as np 33 | import cv2 34 | from scipy.ndimage.interpolation import map_coordinates 35 | 36 | from keras.preprocessing.image import random_rotation, random_shift, random_zoom, random_shear 37 | 38 | GRAYSCALE = True 39 | RESOLUTION = 512 40 | COCO_BACKGROUND = (68, 1, 84, 255) 41 | MASK_BACKGROUND = (0,0,0,0) 42 | DEFAULT_RGB_SCALE_FACTOR = 256000.0 43 | DEFAULT_GRAY_SCALE_FACTOR = {np.uint8: 100.0, 44 | np.uint16: 1000.0, 45 | np.int32: DEFAULT_RGB_SCALE_FACTOR} 46 | 47 | # Function to distort image 48 | def elastic_transform(image, alpha=2000, sigma=40, alpha_affine=40, random_state=None): 49 | if random_state is None: 50 | random_state = np.random.RandomState(None) 51 | 52 | shape = image.shape 53 | shape_size = shape[:2] 54 | 55 | # Random affine 56 | center_square = np.float32(shape_size) // 2 57 | square_size = min(shape_size) // 3 58 | pts1 = np.float32([center_square + square_size, [center_square[0]+square_size, center_square[1]-square_size], center_square - square_size]) 59 | pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, size=pts1.shape).astype(np.float32) 60 | M = cv2.getAffineTransform(pts1, pts2) 61 | for i in range(shape[2]): 62 | image[:,:,i] = cv2.warpAffine(image[:,:,i], M, shape_size[::-1], borderMode=cv2.BORDER_REFLECT_101) 63 | image = image.reshape(shape) 64 | 65 | blur_size = int(4*sigma) | 1 66 | 67 | dx = cv2.GaussianBlur((random_state.rand(*shape_size) * 2 - 1), ksize=(blur_size, blur_size), sigmaX=sigma) * alpha 68 | dy = cv2.GaussianBlur((random_state.rand(*shape_size) * 2 - 1), ksize=(blur_size, blur_size), sigmaX=sigma) * alpha 69 | 70 | x, y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0])) 71 | indices = np.reshape(y+dy, (-1, 1)), np.reshape(x+dx, (-1, 1)) 72 | 73 | def_img = np.zeros_like(image) 74 | for i in range(shape[2]): 75 | def_img[:,:,i] = map_coordinates(image[:,:,i], indices, order=1).reshape(shape_size) 76 | 77 | return def_img 78 | 79 | 80 | def salt_pepper_noise(image, salt=0.2, amount=0.004): 81 | row, col, chan = image.shape 82 | num_salt = np.ceil(amount * row * salt) 83 | num_pepper = np.ceil(amount * row * (1.0 - salt)) 84 | 85 | for n in range(chan//2): # //2 so we don't augment the mask 86 | # Add Salt noise 87 | coords = [np.random.randint(0, i - 1, int(num_salt)) for i in image.shape[0:2]] 88 | image[coords[0], coords[1], n] = 1 89 | 90 | # Add Pepper noise 91 | coords = [np.random.randint(0, i - 1, int(num_pepper)) for i in image.shape[0:2]] 92 | image[coords[0], coords[1], n] = 0 93 | 94 | return image 95 | 96 | 97 | def flip_axis(x, axis): 98 | x = np.asarray(x).swapaxes(axis, 0) 99 | x = x[::-1, ...] 
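    # The requested axis now sits at position 0, so reversing the first dimension flips it; the swap below restores the original axis order.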
100 | x = x.swapaxes(0, axis) 101 | return x 102 | 103 | def image2float_array(image, scale_factor=None): 104 | ''' 105 | source: https://github.com/ahundt/robotics_setup/blob/master/datasets/google_brain_robot_data/depth_image_encoding.py 106 | 107 | Recovers the depth values from an image. 108 | Reverses the depth to image conversion performed by FloatArrayToRgbImage or 109 | 110 | FloatArrayToGrayImage. 111 | 112 | The image is treated as an array of fixed point depth values. Each 113 | value is converted to float and scaled by the inverse of the factor 114 | that was used to generate the Image object from depth values. If 115 | scale_factor is specified, it should be the same value that was 116 | specified in the original conversion. 117 | 118 | The result of this function should be equal to the original input 119 | within the precision of the conversion. 120 | 121 | Args: 122 | image: Depth image output of FloatArrayTo[Format]Image. 123 | scale_factor: Fixed point scale factor. 124 | 125 | Returns: 126 | A 2D floating point numpy array representing a depth image. 127 | ''' 128 | 129 | image_array = np.array(image) 130 | image_dtype = image_array.dtype 131 | image_shape = image_array.shape 132 | 133 | channels = image_shape[2] if len(image_shape) > 2 else 1 134 | assert 2 <= len(image_shape) <= 3 135 | 136 | if channels == 3: 137 | # RGB image needs to be converted to 24 bit integer. 138 | float_array = np.sum(image_array * [65536, 256, 1], axis=2) 139 | if scale_factor is None: 140 | scale_factor = DEFAULT_RGB_SCALE_FACTOR 141 | else: 142 | if scale_factor is None: 143 | scale_factor = DEFAULT_GRAY_SCALE_FACTOR[image_dtype.type] 144 | float_array = image_array.astype(np.float32) 145 | scaled_array = float_array / scale_factor 146 | 147 | return scaled_array 148 | 149 | 150 | def image_resize2square(image, desired_size = None): 151 | ''' 152 | Transform image to a square image with desired size(resolution) 153 | Padding image with black color which defined as MASK_BACKGROUND 154 | ''' 155 | 156 | # initialize dimensions of the image to be resized and 157 | # grab the image size 158 | old_size = image.shape[:2] 159 | 160 | # if both the width and height are None, then return the 161 | # original image 162 | if desired_size is None or (old_size[0]==desired_size and old_size[1]==desired_size): 163 | return image 164 | 165 | # calculate the ratio of the height and construct the 166 | # dimensions 167 | ratio = float(desired_size)/max(old_size) 168 | new_size = tuple([int(x*ratio) for x in old_size]) 169 | 170 | # new_size should be in (width, height) format 171 | resized = cv2.resize(image, (new_size[1], new_size[0])) 172 | 173 | delta_w = desired_size - new_size[1] 174 | delta_h = desired_size - new_size[0] 175 | top, bottom = delta_h // 2, delta_h - (delta_h // 2) 176 | left, right = delta_w // 2, delta_w - (delta_w // 2) 177 | 178 | new_image = cv2.copyMakeBorder(resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value = MASK_BACKGROUND) 179 | 180 | # return the resized image 181 | return new_image 182 | 183 | def image_enhance(image, shift): 184 | ''' 185 | Input image is a numpy array with unit8 grayscale. 186 | This function will enhance the bright by adding num to each pixel. 187 | perform normalization 188 | ''' 189 | if shift > 0: 190 | for i in range(shift): 191 | image += 1 192 | # If pixel value == 0 which means the value = 256 but overflow to 0 193 | # shift the overflow pix values to 255. 
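        # e.g. with shift=1, a uint8 pixel of 255 wraps around to 0; the assignment below pushes it back to 255, so the shift never introduces pure black pixels.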
194 | image[image == 0] = 255 195 | 196 | return image 197 | 198 | def process_image(img, shift, resolution): 199 | ''' 200 | Pre-process image before store in numpy file. 201 | shift: shift all pixels a distance with the shift value to avoid black color in image. 202 | resolution: change image resolution to fit model. 203 | ''' 204 | # Add 5 for each pixel on the grayscale image. 205 | img = image_enhance(img, shift = shift) 206 | 207 | # The source image should be 512X512 resolution. 208 | img = image_resize2square(img, resolution) 209 | 210 | return img 211 | 212 | 213 | def change_background_color(img, original_color, new_color): 214 | ''' 215 | Convert mask color of 4 channels png image to new color 216 | ''' 217 | 218 | r1, g1, b1, a1 = original_color[0], original_color[1], original_color[2], original_color[3] # Original value 219 | # mask background color (0,0,0,0) 220 | r2, g2, b2, a2 = new_color[0], new_color[1], new_color[2], new_color[3] # Value that we want to replace it with 221 | 222 | red, green, blue, alpha = img[:,:,0], img[:,:,1], img[:,:,2], img[:,:,3] 223 | mask = (red == r1) & (green == g1) & (blue == b1) & (alpha == a1) 224 | img[:,:,:4][mask] = [r2, g2, b2, a2] 225 | return img 226 | 227 | def convert_mask_data(mask, resolution = RESOLUTION, from_background_color = COCO_BACKGROUND, 228 | to_background_color = MASK_BACKGROUND): 229 | ''' 230 | 1. Resize mask to square with size of resolution. 231 | 2. Change back ground color to black 232 | 3. Change pixel value to 1 for masking 233 | 4. Change pixel value to 0 for non-masking area 234 | 5. Reduce data type to uint8 to reduce the file size of mask. 235 | ''' 236 | mask = image_resize2square(mask, resolution) 237 | 238 | mask = change_background_color(mask, from_background_color, to_background_color) 239 | if GRAYSCALE == True: 240 | # Only need one channel for black and white 241 | mask = mask[:,:,:1] 242 | else: 243 | mask = mask[:,:,:1] # keep 3 channels for RGB. Remove alpha channel. 244 | 245 | mask[mask >= 1] = 1 # The mask. ie. class of Person 246 | mask[mask != 1] = 0 # Non Person / Background 247 | mask = mask.astype(np.uint8) 248 | return mask 249 | 250 | def convert_img_data(img, dims = 4, resolution = RESOLUTION): 251 | ''' 252 | Convert image data by 253 | 1. Shift RGB channel with value 1 to avoid pure black color. 254 | 2. Resize image to square 255 | 3. Normalized data 256 | 4. reshape to require dimension 3 or 4 257 | ''' 258 | img = img[:,:,:3] 259 | if GRAYSCALE == True: 260 | # Add 1 for each pixel and change resolution on the image. 261 | img = process_image(img, shift = 1, resolution = resolution) 262 | 263 | # Translate the image to 24bits grayscale by PILLOW package 264 | img = image2float_array(img, 16777216-1) #2^24=16777216 265 | if dims == 3: 266 | # Reshape numpy from 2 to 3 dimensions 267 | img = img.reshape([img.shape[0], img.shape[1], 1]) 268 | else: # dimension = 4 269 | img = img.reshape([1, img.shape[0], img.shape[1], 1]) 270 | else: # Color image with 3 channels 271 | # Add 1 for each pixel and change resolution on the image. 
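        # e.g. a 640x480 RGB frame is brightness-shifted by 1 and padded/resized to a square of `resolution` pixels (512 by default) by process_image below.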
272 | img = process_image(img, shift = 1, resolution = resolution) 273 | if dims == 3: 274 | # Keep RGB channel, remove alpha channel 275 | img = img[:,:,:3] 276 | else: # dimensions = 4 277 | img = img[:,:,:,:3] 278 | return img 279 | 280 | 281 | def augmentImages(batch_of_images, batch_of_masks): 282 | for i in range(len(batch_of_images)): 283 | img_and_mask = np.concatenate((batch_of_images[i, ...], batch_of_masks[i,...]), axis=2) 284 | if img_and_mask.ndim == 4: # This assumes single channel data. For multi-channel you'll need 285 | # change this to put all channel in slices channel 286 | orig_shape = img_and_mask.shape 287 | img_and_mask = img_and_mask.reshape((img_and_mask.shape[0:3])) 288 | 289 | if np.random.randint(0,10) == 7: 290 | img_and_mask = random_rotation(img_and_mask, rg=45, row_axis=0, col_axis=1, channel_axis=2, 291 | fill_mode='constant', cval=0.) 292 | 293 | if np.random.randint(0, 5) == 3: 294 | img_and_mask = elastic_transform(img_and_mask, alpha=1000, sigma=80, alpha_affine=50) 295 | 296 | if np.random.randint(0, 10) == 7: 297 | img_and_mask = random_shift(img_and_mask, wrg=0.2, hrg=0.2, row_axis=0, col_axis=1, channel_axis=2, 298 | fill_mode='constant', cval=0.) 299 | 300 | if np.random.randint(0, 10) == 7: 301 | img_and_mask = random_shear(img_and_mask, intensity=16, row_axis=0, col_axis=1, channel_axis=2, 302 | fill_mode='constant', cval=0.) 303 | 304 | if np.random.randint(0, 10) == 7: 305 | img_and_mask = random_zoom(img_and_mask, zoom_range=(0.75, 0.75), row_axis=0, col_axis=1, channel_axis=2, 306 | fill_mode='constant', cval=0.) 307 | 308 | if np.random.randint(0, 10) == 7: 309 | img_and_mask = flip_axis(img_and_mask, axis=1) 310 | 311 | if np.random.randint(0, 10) == 7: 312 | img_and_mask = flip_axis(img_and_mask, axis=0) 313 | 314 | if np.random.randint(0, 10) == 7: 315 | salt_pepper_noise(img_and_mask, salt=0.2, amount=0.04) 316 | 317 | if batch_of_images.ndim == 4: 318 | batch_of_images[i, ...] = img_and_mask[...,0:img_and_mask.shape[2]//2] 319 | batch_of_masks[i,...] = img_and_mask[...,img_and_mask.shape[2]//2:] 320 | if batch_of_images.ndim == 5: 321 | img_and_mask = img_and_mask.reshape(orig_shape) 322 | batch_of_images[i, ...] = img_and_mask[...,0:img_and_mask.shape[2]//2, :] 323 | batch_of_masks[i,...] = img_and_mask[...,img_and_mask.shape[2]//2:, :] 324 | 325 | # Ensure the masks did not get any non-binary values. 326 | batch_of_masks[batch_of_masks > 0.5] = 1 327 | batch_of_masks[batch_of_masks <= 0.5] = 0 328 | 329 | return(batch_of_images, batch_of_masks) -------------------------------------------------------------------------------- /utils/custom_losses.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | Capsules for Object Segmentation (SegCaps) 4 | Original Paper: https://arxiv.org/abs/1804.04241 5 | Code written by: Rodney LaLonde 6 | If you use significant portions of this code or the ideas from our paper, please cite it :) 7 | If you have any questions, please email me at lalonde@knights.ucf.edu. 8 | 9 | This file contains the definitions of custom loss functions not present in the default Keras. 10 | 11 | ===== 12 | 13 | This program includes all custom loss functions UNet, tiramisu, Capsule Nets (capsbasic) or SegCaps(segcapsr1 or segcapsr3). 14 | 15 | @author: Cheng-Lin Li a.k.a. Clark 16 | 17 | @copyright: 2018 Cheng-Lin Li@Insight AI. All rights reserved. 18 | 19 | @license: Licensed under the Apache License v2.0. 
http://www.apache.org/licenses/ 20 | 21 | @contact: clark.cl.li@gmail.com 22 | 23 | Enhancement: 24 | 1. Revise default loss_type to jaccard on dice_soft function. 25 | 2. add bce_dice_loss for future usage. 26 | ''' 27 | 28 | import tensorflow as tf 29 | from keras import backend as K 30 | from keras.losses import binary_crossentropy 31 | 32 | def dice_soft(y_true, y_pred, loss_type='jaccard', axis=[1,2,3], smooth=1e-5, from_logits=False): 33 | """Soft dice (Sørensen or Jaccard) coefficient for comparing the similarity 34 | of two batch of data, usually be used for binary image segmentation 35 | i.e. labels are binary. The coefficient between 0 to 1, 1 means totally match. 36 | 37 | Parameters 38 | ----------- 39 | y_pred : tensor 40 | A distribution with shape: [batch_size, ....], (any dimensions). 41 | y_true : tensor 42 | A distribution with shape: [batch_size, ....], (any dimensions). 43 | loss_type : string 44 | ``jaccard`` or ``sorensen``, default is ``jaccard``. 45 | axis : list of integer 46 | All dimensions are reduced, default ``[1,2,3]``. 47 | smooth : float 48 | This small value will be added to the numerator and denominator. 49 | If both y_pred and y_true are empty, it makes sure dice is 1. 50 | If either y_pred or y_true are empty (all pixels are background), dice = ```smooth/(small_value + smooth)``, 51 | then if smooth is very small, dice close to 0 (even the image values lower than the threshold), 52 | so in this case, higher smooth can have a higher dice. 53 | 54 | Examples 55 | --------- 56 | >>> outputs = tl.act.pixel_wise_softmax(network.outputs) 57 | >>> dice_loss = 1 - tl.cost.dice_coe(outputs, y_) 58 | 59 | References 60 | ----------- 61 | - `Wiki-Dice `_ 62 | """ 63 | 64 | if not from_logits: 65 | # transform back to logits 66 | _epsilon = tf.convert_to_tensor(1e-7, y_pred.dtype.base_dtype) 67 | y_pred = tf.clip_by_value(y_pred, _epsilon, 1 - _epsilon) 68 | y_pred = tf.log(y_pred / (1 - y_pred)) 69 | 70 | inse = tf.reduce_sum(y_pred * y_true, axis=axis) 71 | if loss_type == 'jaccard': 72 | l = tf.reduce_sum(y_pred * y_pred, axis=axis) 73 | r = tf.reduce_sum(y_true * y_true, axis=axis) 74 | elif loss_type == 'sorensen': 75 | l = tf.reduce_sum(y_pred, axis=axis) 76 | r = tf.reduce_sum(y_true, axis=axis) 77 | else: 78 | raise Exception("Unknow loss_type") 79 | ## old axis=[0,1,2,3] 80 | # dice = 2 * (inse) / (l + r) 81 | # epsilon = 1e-5 82 | # dice = tf.clip_by_value(dice, 0, 1.0-epsilon) # if all empty, dice = 1 83 | ## new haodong 84 | dice = (2. * inse + smooth) / (l + r + smooth) 85 | ## 86 | dice = tf.reduce_mean(dice) 87 | return dice 88 | 89 | 90 | def dice_hard(y_true, y_pred, threshold=0.5, axis=[1,2,3], smooth=1e-5): 91 | """Non-differentiable Sørensen–Dice coefficient for comparing the similarity 92 | of two batch of data, usually be used for binary image segmentation i.e. labels are binary. 93 | The coefficient between 0 to 1, 1 if totally match. 94 | 95 | Parameters 96 | ----------- 97 | y_pred : tensor 98 | A distribution with shape: [batch_size, ....], (any dimensions). 99 | y_true : tensor 100 | A distribution with shape: [batch_size, ....], (any dimensions). 101 | threshold : float 102 | The threshold value to be true. 103 | axis : list of integer 104 | All dimensions are reduced, default ``[1,2,3]``. 105 | smooth : float 106 | This small value will be added to the numerator and denominator, see ``dice_coe``. 
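    With the default threshold of 0.5 this reduces to the standard hard Dice score, (2 * |A ∩ B| + smooth) / (|A| + |B| + smooth), computed on the binarised prediction and ground truth.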
107 | 108 | References 109 | ----------- 110 | - `Wiki-Dice `_ 111 | """ 112 | y_pred = tf.cast(y_pred > threshold, dtype=tf.float32) 113 | y_true = tf.cast(y_true > threshold, dtype=tf.float32) 114 | inse = tf.reduce_sum(tf.multiply(y_pred, y_true), axis=axis) 115 | l = tf.reduce_sum(y_pred, axis=axis) 116 | r = tf.reduce_sum(y_true, axis=axis) 117 | ## old axis=[0,1,2,3] 118 | # hard_dice = 2 * (inse) / (l + r) 119 | # epsilon = 1e-5 120 | # hard_dice = tf.clip_by_value(hard_dice, 0, 1.0-epsilon) 121 | ## new haodong 122 | hard_dice = (2. * inse + smooth) / (l + r + smooth) 123 | ## 124 | hard_dice = tf.reduce_mean(hard_dice) 125 | return hard_dice 126 | 127 | 128 | def dice_loss(y_true, y_pred, from_logits=False): 129 | return 1-dice_soft(y_true, y_pred, from_logits=False) 130 | 131 | 132 | def bce_dice_loss(y_true, y_pred): 133 | return binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred) 134 | 135 | 136 | 137 | def weighted_binary_crossentropy_loss(pos_weight): 138 | # pos_weight: A coefficient to use on the positive examples. 139 | def weighted_binary_crossentropy(target, output, from_logits=False): 140 | """Binary crossentropy between an output tensor and a target tensor. 141 | # Arguments 142 | target: A tensor with the same shape as `output`. 143 | output: A tensor. 144 | from_logits: Whether `output` is expected to be a logits tensor. 145 | By default, we consider that `output` 146 | encodes a probability distribution. 147 | # Returns 148 | A tensor. 149 | """ 150 | # Note: tf.nn.sigmoid_cross_entropy_with_logits 151 | # expects logits, Keras expects probabilities. 152 | if not from_logits: 153 | # transform back to logits 154 | _epsilon = tf.convert_to_tensor(1e-7, output.dtype.base_dtype) 155 | output = tf.clip_by_value(output, _epsilon, 1 - _epsilon) 156 | output = tf.log(output / (1 - output)) 157 | 158 | return tf.nn.weighted_cross_entropy_with_logits(targets=target, 159 | logits=output, 160 | pos_weight=pos_weight) 161 | return weighted_binary_crossentropy 162 | 163 | 164 | def margin_loss(margin=0.4, downweight=0.5, pos_weight=1.0): 165 | ''' 166 | Args: 167 | margin: scalar, the margin after subtracting 0.5 from raw_logits. 168 | downweight: scalar, the factor for negative cost. 169 | ''' 170 | 171 | def _margin_loss(labels, raw_logits): 172 | """Penalizes deviations from margin for each logit. 173 | 174 | Each wrong logit costs its distance to margin. For negative logits margin is 175 | 0.1 and for positives it is 0.9. First subtract 0.5 from all logits. Now 176 | margin is 0.4 from each side. 177 | 178 | Args: 179 | labels: tensor, one hot encoding of ground truth. 180 | raw_logits: tensor, model predictions in range [0, 1] 181 | 182 | 183 | Returns: 184 | A tensor with cost for each data point of shape [batch_size]. 185 | """ 186 | logits = raw_logits - 0.5 187 | positive_cost = pos_weight * labels * tf.cast(tf.less(logits, margin), 188 | tf.float32) * tf.pow(logits - margin, 2) 189 | negative_cost = (1 - labels) * tf.cast( 190 | tf.greater(logits, -margin), tf.float32) * tf.pow(logits + margin, 2) 191 | return 0.5 * positive_cost + downweight * 0.5 * negative_cost 192 | 193 | return _margin_loss 194 | 195 | -------------------------------------------------------------------------------- /utils/data_helper.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This program is a helper to separate the codes of 2D and 3D image processing. 3 | 4 | @author: Cheng-Lin Li a.k.a. 
Clark 5 | 6 | @copyright: 2018 Cheng-Lin Li@Insight AI. All rights reserved. 7 | 8 | @license: Licensed under the Apache License v2.0. http://www.apache.org/licenses/ 9 | 10 | @contact: clark.cl.li@gmail.com 11 | 12 | Tasks: 13 | The program is a helper to separate the codes of 2D and 3D image processing. 14 | 15 | This is a helper file for choosing which dataset functions to create. 16 | ''' 17 | import logging 18 | import utils.load_3D_data as ld3D 19 | import utils.load_2D_data as ld2D 20 | from enum import Enum, unique 21 | 22 | @unique 23 | class Dataset(Enum): 24 | luna16 = 1 25 | mscoco17 = 2 26 | 27 | def get_generator(dataset): 28 | if dataset == 'luna16': 29 | generate_train_batches = ld3D.generate_train_batches 30 | generate_val_batches = ld3D.generate_val_batches 31 | generate_test_batches = ld3D.generate_test_batches 32 | elif dataset == 'mscoco17': 33 | generate_train_batches = ld2D.generate_train_batches 34 | generate_val_batches = ld2D.generate_val_batches 35 | generate_test_batches = ld2D.generate_test_batches 36 | else: 37 | logging.error('Not valid dataset!') 38 | return None, None, None 39 | return generate_train_batches, generate_val_batches, generate_test_batches 40 | 41 | -------------------------------------------------------------------------------- /utils/load_2D_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This program includes all functions of 2D color image processing for UNet, tiramisu, Capsule Nets (capsbasic) or SegCaps(segcapsr1 or segcapsr3). 3 | 4 | @author: Cheng-Lin Li a.k.a. Clark 5 | 6 | @copyright: 2018 Cheng-Lin Li@Insight AI. All rights reserved. 7 | 8 | @license: Licensed under the Apache License v2.0. http://www.apache.org/licenses/ 9 | 10 | @contact: clark.cl.li@gmail.com 11 | 12 | Tasks: 13 | The program based on parameters from main.py to load 2D color image files from folders. 14 | 15 | The program will convert all image files into numpy format then store training/testing images into 16 | ./data/np_files and training (and testing) file lists under ./data/split_list folders. 17 | You need to remove these two folders every time if you want to replace your training image and mask files. 18 | The program will only read data from np_files folders. 19 | 20 | Data: 21 | MS COCO 2017 or LUNA 2016 were tested on this package. 22 | You can leverage your own data set but the mask images should follow the format of MS COCO or with background color = 0 on each channel. 23 | 24 | 25 | Features: 26 | 1. Integrated with MS COCO 2017 dataset. 27 | 2. Use PILLOW library instead of SimpleITK for better support on RasberryPi 28 | 3. 
add new generate_test_image function to process single image frame for video stream 29 | ''' 30 | 31 | from __future__ import print_function 32 | # import threading 33 | import logging 34 | from os.path import join, basename 35 | from os import makedirs 36 | 37 | import numpy as np 38 | from numpy.random import rand, shuffle 39 | from PIL import Image 40 | 41 | # import matplotlib 42 | # matplotlib.use('Agg') 43 | import matplotlib.pyplot as plt 44 | 45 | 46 | plt.ioff() 47 | 48 | from utils.custom_data_aug import augmentImages, convert_img_data, convert_mask_data 49 | from utils.threadsafe import threadsafe_generator 50 | 51 | debug = 0 52 | 53 | def convert_data_to_numpy(root_path, img_name, no_masks=False, overwrite=False): 54 | fname = img_name[:-4] 55 | numpy_path = join(root_path, 'np_files') 56 | img_path = join(root_path, 'imgs') 57 | mask_path = join(root_path, 'masks') 58 | fig_path = join(root_path, 'figs') 59 | try: 60 | makedirs(numpy_path) 61 | except: 62 | pass 63 | try: 64 | makedirs(fig_path) 65 | except: 66 | pass 67 | 68 | if not overwrite: 69 | try: 70 | with np.load(join(numpy_path, fname + '.npz')) as data: 71 | return data['img'], data['mask'] 72 | except: 73 | pass 74 | 75 | try: 76 | img = np.array(Image.open(join(img_path, img_name))) 77 | # Conver image to 3 dimensions 78 | img = convert_img_data(img, 3) 79 | 80 | if not no_masks: 81 | # Replace SimpleITK to PILLOW for 2D image support on Raspberry Pi 82 | mask = np.array(Image.open(join(mask_path, img_name))) # (x,y,4) 83 | 84 | mask = convert_mask_data(mask) 85 | 86 | if not no_masks: 87 | np.savez_compressed(join(numpy_path, fname + '.npz'), img=img, mask=mask) 88 | else: 89 | np.savez_compressed(join(numpy_path, fname + '.npz'), img=img) 90 | 91 | if not no_masks: 92 | return img, mask 93 | else: 94 | return img 95 | 96 | except Exception as e: 97 | print('\n'+'-'*100) 98 | print('Unable to load img or masks for {}'.format(fname)) 99 | print(e) 100 | print('Skipping file') 101 | print('-'*100+'\n') 102 | 103 | return np.zeros(1), np.zeros(1) 104 | 105 | 106 | def get_slice(image_data): 107 | return image_data[2] 108 | 109 | @threadsafe_generator 110 | def generate_train_batches(root_path, train_list, net_input_shape, net, batchSize=1, numSlices=1, subSampAmt=-1, 111 | stride=1, downSampAmt=1, shuff=1, aug_data=1): 112 | # Create placeholders for training 113 | # (img_shape[1], img_shape[2], args.slices) 114 | logging.info('\n2d_generate_train_batches') 115 | img_batch = np.zeros((np.concatenate(((batchSize,), net_input_shape))), dtype=np.float32) 116 | mask_batch = np.zeros((np.concatenate(((batchSize,), (net_input_shape[0], net_input_shape[1], 1)))), dtype=np.uint8) 117 | 118 | while True: 119 | if shuff: 120 | shuffle(train_list) 121 | count = 0 122 | for i, scan_name in enumerate(train_list): 123 | try: 124 | # Read image file from pre-processing image numpy format compression files. 
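                # Each train_list entry is a sequence whose first element is the image file name; the matching np_files/<name>.npz archive holds the pre-converted 'img' and 'mask' arrays.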
125 | scan_name = scan_name[0] 126 | path_to_np = join(root_path,'np_files',basename(scan_name)[:-3]+'npz') 127 | logging.info('\npath_to_np=%s'%(path_to_np)) 128 | with np.load(path_to_np) as data: 129 | train_img = data['img'] 130 | train_mask = data['mask'] 131 | except: 132 | logging.info('\nPre-made numpy array not found for {}.\nCreating now...'.format(scan_name[:-4])) 133 | train_img, train_mask = convert_data_to_numpy(root_path, scan_name) 134 | if np.array_equal(train_img,np.zeros(1)): 135 | continue 136 | else: 137 | logging.info('\nFinished making npz file.') 138 | 139 | if numSlices == 1: 140 | subSampAmt = 0 141 | elif subSampAmt == -1 and numSlices > 1: # Only one slices. code can be removed. 142 | np.random.seed(None) 143 | subSampAmt = int(rand(1)*(train_img.shape[2]*0.05)) 144 | # We don't need indicies in 2D image. 145 | indicies = np.arange(0, train_img.shape[2] - numSlices * (subSampAmt + 1) + 1, stride) 146 | if shuff: 147 | shuffle(indicies) 148 | 149 | for j in indicies: 150 | if not np.any(train_mask[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1]): 151 | continue 152 | if img_batch.ndim == 4: 153 | img_batch[count, :, :, :] = train_img[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 154 | mask_batch[count, :, :, :] = train_mask[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 155 | elif img_batch.ndim == 5: 156 | # Assumes img and mask are single channel. Replace 0 with : if multi-channel. 157 | img_batch[count, :, :, :, 0] = train_img[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 158 | mask_batch[count, :, :, :, 0] = train_mask[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 159 | else: 160 | logging.error('\nError this function currently only supports 2D and 3D data.') 161 | exit(0) 162 | 163 | count += 1 164 | if count % batchSize == 0: 165 | count = 0 166 | if aug_data: 167 | img_batch, mask_batch = augmentImages(img_batch, mask_batch) 168 | if debug: 169 | if img_batch.ndim == 4: 170 | plt.imshow(np.squeeze(img_batch[0, :, :, 0]), cmap='gray') 171 | plt.imshow(np.squeeze(mask_batch[0, :, :, 0]), alpha=0.15) 172 | elif img_batch.ndim == 5: 173 | plt.imshow(np.squeeze(img_batch[0, :, :, 0, 0]), cmap='gray') 174 | plt.imshow(np.squeeze(mask_batch[0, :, :, 0, 0]), alpha=0.15) 175 | plt.savefig(join(root_path, 'logs', 'ex_train.png'), format='png', bbox_inches='tight') 176 | plt.close() 177 | if net.find('caps') != -1: # if the network is capsule/segcaps structure 178 | # [(1, 512, 512, 3), (1, 512, 512, 1)], [(1, 512, 512, 1), (1, 512, 512, 3)] 179 | # or [(1, 512, 512, 3), (1, 512, 512, 3)], [(1, 512, 512, 3), (1, 512, 512, 3)] 180 | yield ([img_batch, mask_batch], [mask_batch, mask_batch*img_batch]) 181 | else: 182 | yield (img_batch, mask_batch) 183 | 184 | if count != 0: 185 | if aug_data: 186 | img_batch[:count,...], mask_batch[:count,...] = augmentImages(img_batch[:count,...], 187 | mask_batch[:count,...]) 188 | if net.find('caps') != -1: 189 | yield ([img_batch[:count, ...], mask_batch[:count, ...]], 190 | [mask_batch[:count, ...], mask_batch[:count, ...] 
* img_batch[:count, ...]]) 191 | else: 192 | yield (img_batch[:count,...], mask_batch[:count,...]) 193 | 194 | @threadsafe_generator 195 | def generate_val_batches(root_path, val_list, net_input_shape, net, batchSize=1, numSlices=1, subSampAmt=-1, 196 | stride=1, downSampAmt=1, shuff=1): 197 | logging.info('2d_generate_val_batches') 198 | # Create placeholders for validation 199 | img_batch = np.zeros((np.concatenate(((batchSize,), net_input_shape))), dtype=np.float32) 200 | mask_batch = np.zeros((np.concatenate(((batchSize,), net_input_shape))), dtype=np.uint8) 201 | 202 | while True: 203 | if shuff: 204 | shuffle(val_list) 205 | count = 0 206 | for i, scan_name in enumerate(val_list): 207 | try: 208 | scan_name = scan_name[0] 209 | path_to_np = join(root_path,'np_files',basename(scan_name)[:-3]+'npz') 210 | with np.load(path_to_np) as data: 211 | val_img = data['img'] 212 | val_mask = data['mask'] 213 | except: 214 | logging.info('\nPre-made numpy array not found for {}.\nCreating now...'.format(scan_name[:-4])) 215 | val_img, val_mask = convert_data_to_numpy(root_path, scan_name) 216 | if np.array_equal(val_img,np.zeros(1)): 217 | continue 218 | else: 219 | logging.info('\nFinished making npz file.') 220 | 221 | # New added for debugging 222 | if numSlices == 1: 223 | subSampAmt = 0 224 | elif subSampAmt == -1 and numSlices > 1: # Only one slices. code can be removed. 225 | np.random.seed(None) 226 | subSampAmt = int(rand(1)*(val_img.shape[2]*0.05)) 227 | 228 | # We don't need indicies in 2D image. 229 | indicies = np.arange(0, val_img.shape[2] - numSlices * (subSampAmt + 1) + 1, stride) 230 | if shuff: 231 | shuffle(indicies) 232 | 233 | for j in indicies: 234 | if not np.any(val_mask[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1]): 235 | continue 236 | if img_batch.ndim == 4: 237 | img_batch[count, :, :, :] = val_img[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 238 | mask_batch[count, :, :, :] = val_mask[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 239 | elif img_batch.ndim == 5: 240 | # Assumes img and mask are single channel. Replace 0 with : if multi-channel. 241 | img_batch[count, :, :, :, 0] = val_img[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 242 | mask_batch[count, :, :, :, 0] = val_mask[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 243 | else: 244 | logging.error('\nError this function currently only supports 2D and 3D data.') 245 | exit(0) 246 | 247 | count += 1 248 | if count % batchSize == 0: 249 | count = 0 250 | if net.find('caps') != -1: 251 | yield ([img_batch, mask_batch], [mask_batch, mask_batch * img_batch]) 252 | else: 253 | yield (img_batch, mask_batch) 254 | 255 | if count != 0: 256 | if net.find('caps') != -1: 257 | yield ([img_batch[:count, ...], mask_batch[:count, ...]], 258 | [mask_batch[:count, ...], mask_batch[:count, ...] 
* img_batch[:count, ...]]) 259 | else: 260 | yield (img_batch[:count,...], mask_batch[:count,...]) 261 | 262 | @threadsafe_generator 263 | def generate_test_batches(root_path, test_list, net_input_shape, batchSize=1, numSlices=1, subSampAmt=0, 264 | stride=1, downSampAmt=1): 265 | # Create placeholders for testing 266 | logging.info('\nload_2D_data.generate_test_batches') 267 | img_batch = np.zeros((np.concatenate(((batchSize,), net_input_shape))), dtype=np.float32) 268 | count = 0 269 | logging.info('\nload_2D_data.generate_test_batches: test_list=%s'%(test_list)) 270 | for i, scan_name in enumerate(test_list): 271 | try: 272 | scan_name = scan_name[0] 273 | path_to_np = join(root_path,'np_files',basename(scan_name)[:-3]+'npz') 274 | with np.load(path_to_np) as data: 275 | test_img = data['img'] # (512, 512, 1) 276 | except: 277 | logging.info('\nPre-made numpy array not found for {}.\nCreating now...'.format(scan_name[:-4])) 278 | test_img = convert_data_to_numpy(root_path, scan_name, no_masks=True) 279 | if np.array_equal(test_img,np.zeros(1)): 280 | continue 281 | else: 282 | logging.info('\nFinished making npz file.') 283 | 284 | indicies = np.arange(0, test_img.shape[2] - numSlices * (subSampAmt + 1) + 1, stride) 285 | for j in indicies: 286 | if img_batch.ndim == 4: 287 | # (1, 512, 512, 1) 288 | img_batch[count, :, :, :] = test_img[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 289 | elif img_batch.ndim == 5: 290 | # Assumes img and mask are single channel. Replace 0 with : if multi-channel. 291 | img_batch[count, :, :, :, 0] = test_img[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 292 | else: 293 | logging.error('\nError this function currently only supports 2D and 3D data.') 294 | exit(0) 295 | 296 | count += 1 297 | if count % batchSize == 0: 298 | count = 0 299 | yield (img_batch) 300 | 301 | if count != 0: 302 | yield (img_batch[:count,:,:,:]) 303 | 304 | @threadsafe_generator 305 | def generate_test_image(test_img, net_input_shape, batchSize=1, numSlices=1, subSampAmt=0, 306 | stride=1, downSampAmt=1): 307 | ''' 308 | test_img: numpy.array of image data, (height, width, channels) 309 | 310 | ''' 311 | # Create placeholders for testing 312 | logging.info('\nload_2D_data.generate_test_image') 313 | # Convert image to 4 dimensions 314 | test_img = convert_img_data(test_img, 4) 315 | 316 | yield (test_img) 317 | -------------------------------------------------------------------------------- /utils/load_3D_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Capsules for Object Segmentation (SegCaps) 3 | Original Paper by Rodney LaLonde and Ulas Bagci (https://arxiv.org/abs/1804.04241) 4 | Code written by: Rodney LaLonde 5 | If you use significant portions of this code or the ideas from our paper, please cite it :) 6 | If you have any questions, please email me at lalonde@knights.ucf.edu. 7 | 8 | This file is used for loading training, validation, and testing data into the models. 9 | It is specifically designed to handle 3D single-channel medical data. 10 | Modifications will be needed to train/test on normal 3-channel images. 11 | 12 | ===== 13 | This program includes all functions of 3D image processing for UNet, tiramisu, Capsule Nets (capsbasic) or SegCaps(segcapsr1 or segcapsr3). 14 | 15 | @author: Cheng-Lin Li a.k.a. Clark 16 | 17 | @copyright: 2018 Cheng-Lin Li@Insight AI. All rights reserved. 18 | 19 | @license: Licensed under the Apache License v2.0. 
http://www.apache.org/licenses/ 20 | 21 | @contact: clark.cl.li@gmail.com 22 | 23 | Tasks: 24 | The program based on parameters from main.py to load 3D image files from folders. 25 | 26 | The program will convert all image files into numpy format then store training/testing images into 27 | ./data/np_files and training (and testing) file lists under ./data/split_list folders. 28 | You need to remove these two folders every time if you want to replace your training image and mask files. 29 | The program will only read data from np_files folders. 30 | 31 | Data: 32 | MS COCO 2017 or LUNA 2016 were tested on this package. 33 | You can leverage your own data set but the mask images should follow the format of MS COCO or with background color = 0 on each channel. 34 | 35 | 36 | Enhancement: 37 | 1. Porting to Python version 3.6 38 | 2. Remove program code cleaning 39 | ''' 40 | 41 | from __future__ import print_function 42 | 43 | import logging 44 | from os.path import join, basename 45 | from os import makedirs 46 | 47 | import numpy as np 48 | from numpy.random import rand, shuffle 49 | import SimpleITK as sitk 50 | 51 | import matplotlib.pyplot as plt 52 | 53 | plt.ioff() 54 | 55 | from utils.custom_data_aug import augmentImages 56 | from utils.threadsafe import threadsafe_generator 57 | 58 | debug = 0 59 | 60 | def convert_data_to_numpy(root_path, img_name, no_masks=False, overwrite=False): 61 | fname = img_name[:-4] 62 | numpy_path = join(root_path, 'np_files') 63 | img_path = join(root_path, 'imgs') 64 | mask_path = join(root_path, 'masks') 65 | fig_path = join(root_path, 'figs') 66 | try: 67 | makedirs(numpy_path) 68 | except: 69 | pass 70 | try: 71 | makedirs(fig_path) 72 | except: 73 | pass 74 | # The min and max pixel values in a ct image file 75 | ct_min = -1024 76 | ct_max = 3072 77 | 78 | if not overwrite: 79 | try: 80 | with np.load(join(numpy_path, fname + '.npz')) as data: 81 | return data['img'], data['mask'] 82 | except: 83 | pass 84 | 85 | try: 86 | itk_img = sitk.ReadImage(join(img_path, img_name)) 87 | # img=(slices, x, y) e.g. (124, 512, 512) 88 | img = sitk.GetArrayFromImage(itk_img) 89 | # roll axis from (slices, x, y)=(124, 512, 512) to (x, y, slices)=(512, 512, 124) 90 | img = np.rollaxis(img, 0, 3) 91 | img = img.astype(np.float32) 92 | # Normalized image for each pixel 93 | img[img > ct_max] = ct_max # set max value 94 | img[img < ct_min] = ct_min # set min value 95 | img += -ct_min # shift all pixel value to let min value = 0 96 | img /= (ct_max + -ct_min) # normalized pixel based on the range of max and min. 
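        # Worked example of the normalization above (illustrative only, not executed):
        # with ct_min = -1024 and ct_max = 3072, a raw value of 0 HU becomes
        # (0 - (-1024)) / (3072 - (-1024)) = 1024 / 4096 = 0.25, so every voxel
        # intensity is shifted and scaled into the [0, 1] range before batching.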
97 | 98 | if not no_masks: 99 | itk_mask = sitk.ReadImage(join(mask_path, img_name)) 100 | mask = sitk.GetArrayFromImage(itk_mask) 101 | # Switch axis=0 to the position start from = 3 102 | mask = np.rollaxis(mask, 0, 3) 103 | mask[mask > 250] = 1 # In case using 255 instead of 1 104 | mask[mask > 4.5] = 0 # Trachea = 5 105 | mask[mask >= 1] = 1 # Left lung = 3, Right lung = 4 106 | mask[mask != 1] = 0 # Non-Lung/Background 107 | mask = mask.astype(np.uint8) 108 | 109 | try: 110 | f, ax = plt.subplots(1, 3, figsize=(15, 5)) 111 | 112 | ax[0].imshow(img[:, :, img.shape[2] // 3], cmap='gray') 113 | if not no_masks: 114 | ax[0].imshow(mask[:, :, img.shape[2] // 3], alpha=0.15) 115 | ax[0].set_title('Slice {}/{}'.format(img.shape[2] // 3, img.shape[2])) 116 | ax[0].axis('off') 117 | 118 | ax[1].imshow(img[:, :, img.shape[2] // 2], cmap='gray') 119 | if not no_masks: 120 | ax[1].imshow(mask[:, :, img.shape[2] // 2], alpha=0.15) 121 | ax[1].set_title('Slice {}/{}'.format(img.shape[2] // 2, img.shape[2])) 122 | ax[1].axis('off') 123 | 124 | ax[2].imshow(img[:, :, img.shape[2] // 2 + img.shape[2] // 4], cmap='gray') 125 | if not no_masks: 126 | ax[2].imshow(mask[:, :, img.shape[2] // 2 + img.shape[2] // 4], alpha=0.15) 127 | ax[2].set_title('Slice {}/{}'.format(img.shape[2] // 2 + img.shape[2] // 4, img.shape[2])) 128 | ax[2].axis('off') 129 | 130 | fig = plt.gcf() 131 | fig.suptitle(fname) 132 | 133 | plt.savefig(join(fig_path, fname + '.png'), format='png', bbox_inches='tight') 134 | plt.close(fig) 135 | except Exception as e: 136 | logging.error('\n'+'-'*100) 137 | logging.error('Error creating qualitative figure for {}'.format(fname)) 138 | logging.error(e) 139 | logging.error('-'*100+'\n') 140 | 141 | if not no_masks: 142 | np.savez_compressed(join(numpy_path, fname + '.npz'), img=img, mask=mask) 143 | else: 144 | np.savez_compressed(join(numpy_path, fname + '.npz'), img=img) 145 | 146 | if not no_masks: 147 | return img, mask 148 | else: 149 | return img 150 | 151 | except Exception as e: 152 | logging.error('\n'+'-'*100) 153 | logging.error('Unable to load img or masks for {}'.format(fname)) 154 | logging.error(e) 155 | logging.error('Skipping file') 156 | logging.error('-'*100+'\n') 157 | 158 | return np.zeros(1), np.zeros(1) 159 | 160 | 161 | @threadsafe_generator 162 | def generate_train_batches(root_path, train_list, net_input_shape, net, batchSize=1, numSlices=1, subSampAmt=-1, 163 | stride=1, downSampAmt=1, shuff=1, aug_data=1): 164 | # Create placeholders for training 165 | # (img_shape[1], img_shape[2], args.slices) 166 | img_batch = np.zeros((np.concatenate(((batchSize,), net_input_shape))), dtype=np.float32) 167 | mask_batch = np.zeros((np.concatenate(((batchSize,), net_input_shape))), dtype=np.uint8) 168 | 169 | while True: 170 | if shuff: 171 | shuffle(train_list) 172 | count = 0 173 | for i, scan_name in enumerate(train_list): 174 | try: 175 | scan_name = scan_name[0] 176 | path_to_np = join(root_path,'np_files',basename(scan_name)[:-3]+'npz') 177 | logging.info('\npath_to_np=%s'%(path_to_np)) 178 | with np.load(path_to_np) as data: 179 | train_img = data['img'] 180 | train_mask = data['mask'] 181 | except: 182 | logging.info('\nPre-made numpy array not found for {}.\nCreating now...'.format(scan_name[:-4])) 183 | train_img, train_mask = convert_data_to_numpy(root_path, scan_name) 184 | if np.array_equal(train_img,np.zeros(1)): 185 | continue 186 | else: 187 | logging.info('\nFinished making npz file.') 188 | 189 | if numSlices == 1: 190 | subSampAmt = 0 191 | elif subSampAmt 
== -1 and numSlices > 1: 192 | np.random.seed(None) 193 | subSampAmt = int(rand(1)*(train_img.shape[2]*0.05)) 194 | 195 | indicies = np.arange(0, train_img.shape[2] - numSlices * (subSampAmt + 1) + 1, stride) 196 | if shuff: 197 | shuffle(indicies) 198 | 199 | for j in indicies: 200 | if not np.any(train_mask[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1]): 201 | continue 202 | if img_batch.ndim == 4: 203 | img_batch[count, :, :, :] = train_img[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 204 | mask_batch[count, :, :, :] = train_mask[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 205 | elif img_batch.ndim == 5: 206 | # Assumes img and mask are single channel. Replace 0 with : if multi-channel. 207 | img_batch[count, :, :, :, 0] = train_img[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 208 | mask_batch[count, :, :, :, 0] = train_mask[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 209 | else: 210 | logging.error('\nError this function currently only supports 2D and 3D data.') 211 | exit(0) 212 | 213 | count += 1 214 | if count % batchSize == 0: 215 | count = 0 216 | if aug_data: 217 | img_batch, mask_batch = augmentImages(img_batch, mask_batch) 218 | if debug: 219 | if img_batch.ndim == 4: 220 | plt.imshow(np.squeeze(img_batch[0, :, :, 0]), cmap='gray') 221 | plt.imshow(np.squeeze(mask_batch[0, :, :, 0]), alpha=0.15) 222 | elif img_batch.ndim == 5: 223 | plt.imshow(np.squeeze(img_batch[0, :, :, 0, 0]), cmap='gray') 224 | plt.imshow(np.squeeze(mask_batch[0, :, :, 0, 0]), alpha=0.15) 225 | plt.savefig(join(root_path, 'logs', 'ex_train.png'), format='png', bbox_inches='tight') 226 | plt.close() 227 | if net.find('caps') != -1: # if the network is capsule/segcaps structure 228 | yield ([img_batch, mask_batch], [mask_batch, mask_batch*img_batch]) 229 | else: 230 | yield (img_batch, mask_batch) 231 | 232 | if count != 0: 233 | if aug_data: 234 | img_batch[:count,...], mask_batch[:count,...] = augmentImages(img_batch[:count,...], 235 | mask_batch[:count,...]) 236 | if net.find('caps') != -1: 237 | yield ([img_batch[:count, ...], mask_batch[:count, ...]], 238 | [mask_batch[:count, ...], mask_batch[:count, ...] 
* img_batch[:count, ...]]) 239 | else: 240 | yield (img_batch[:count,...], mask_batch[:count,...]) 241 | 242 | @threadsafe_generator 243 | def generate_val_batches(root_path, val_list, net_input_shape, net, batchSize=1, numSlices=1, subSampAmt=-1, 244 | stride=1, downSampAmt=1, shuff=1): 245 | # Create placeholders for validation 246 | img_batch = np.zeros((np.concatenate(((batchSize,), net_input_shape))), dtype=np.float32) 247 | mask_batch = np.zeros((np.concatenate(((batchSize,), net_input_shape))), dtype=np.uint8) 248 | 249 | while True: 250 | if shuff: 251 | shuffle(val_list) 252 | count = 0 253 | for i, scan_name in enumerate(val_list): 254 | try: 255 | scan_name = scan_name[0] 256 | path_to_np = join(root_path,'np_files',basename(scan_name)[:-3]+'npz') 257 | with np.load(path_to_np) as data: 258 | val_img = data['img'] 259 | val_mask = data['mask'] 260 | except: 261 | logging.info('\nPre-made numpy array not found for {}.\nCreating now...'.format(scan_name[:-4])) 262 | val_img, val_mask = convert_data_to_numpy(root_path, scan_name) 263 | if np.array_equal(val_img,np.zeros(1)): 264 | continue 265 | else: 266 | logging.info('\nFinished making npz file.') 267 | 268 | if numSlices == 1: 269 | subSampAmt = 0 270 | elif subSampAmt == -1 and numSlices > 1: 271 | np.random.seed(None) 272 | subSampAmt = int(rand(1)*(val_img.shape[2]*0.05)) 273 | 274 | indicies = np.arange(0, val_img.shape[2] - numSlices * (subSampAmt + 1) + 1, stride) 275 | if shuff: 276 | shuffle(indicies) 277 | 278 | for j in indicies: 279 | if not np.any(val_mask[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1]): 280 | continue 281 | if img_batch.ndim == 4: 282 | img_batch[count, :, :, :] = val_img[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 283 | mask_batch[count, :, :, :] = val_mask[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 284 | elif img_batch.ndim == 5: 285 | # Assumes img and mask are single channel. Replace 0 with : if multi-channel. 286 | img_batch[count, :, :, :, 0] = val_img[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 287 | mask_batch[count, :, :, :, 0] = val_mask[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 288 | else: 289 | logging.error('\nError this function currently only supports 2D and 3D data.') 290 | exit(0) 291 | 292 | count += 1 293 | if count % batchSize == 0: 294 | count = 0 295 | if net.find('caps') != -1: 296 | yield ([img_batch, mask_batch], [mask_batch, mask_batch * img_batch]) 297 | else: 298 | yield (img_batch, mask_batch) 299 | 300 | if count != 0: 301 | if net.find('caps') != -1: 302 | yield ([img_batch[:count, ...], mask_batch[:count, ...]], 303 | [mask_batch[:count, ...], mask_batch[:count, ...] 
* img_batch[:count, ...]]) 304 | else: 305 | yield (img_batch[:count,...], mask_batch[:count,...]) 306 | 307 | @threadsafe_generator 308 | def generate_test_batches(root_path, test_list, net_input_shape, batchSize=1, numSlices=1, subSampAmt=0, 309 | stride=1, downSampAmt=1): 310 | # Create placeholders for testing 311 | logging.info('\nload_3D_data.generate_test_batches') 312 | img_batch = np.zeros((np.concatenate(((batchSize,), net_input_shape))), dtype=np.float32) 313 | count = 0 314 | logging.info('\nload_3D_data.generate_test_batches: test_list=%s'%(test_list)) 315 | for i, scan_name in enumerate(test_list): 316 | try: 317 | scan_name = scan_name[0] 318 | path_to_np = join(root_path,'np_files',basename(scan_name)[:-3]+'npz') 319 | with np.load(path_to_np) as data: 320 | test_img = data['img'] 321 | except: 322 | logging.info('\nPre-made numpy array not found for {}.\nCreating now...'.format(scan_name[:-4])) 323 | test_img = convert_data_to_numpy(root_path, scan_name, no_masks=True) 324 | if np.array_equal(test_img,np.zeros(1)): 325 | continue 326 | else: 327 | logging.info('\nFinished making npz file.') 328 | 329 | if numSlices == 1: 330 | subSampAmt = 0 331 | elif subSampAmt == -1 and numSlices > 1: 332 | np.random.seed(None) 333 | subSampAmt = int(rand(1)*(test_img.shape[2]*0.05)) 334 | 335 | indicies = np.arange(0, test_img.shape[2] - numSlices * (subSampAmt + 1) + 1, stride) 336 | for j in indicies: 337 | if img_batch.ndim == 4: 338 | img_batch[count, :, :, :] = test_img[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 339 | elif img_batch.ndim == 5: 340 | # Assumes img and mask are single channel. Replace 0 with : if multi-channel. 341 | img_batch[count, :, :, :, 0] = test_img[:, :, j:j + numSlices * (subSampAmt+1):subSampAmt+1] 342 | else: 343 | logging.error('Error this function currently only supports 2D and 3D data.') 344 | exit(0) 345 | 346 | count += 1 347 | if count % batchSize == 0: 348 | count = 0 349 | yield (img_batch) 350 | 351 | if count != 0: 352 | yield (img_batch[:count,:,:,:]) 353 | -------------------------------------------------------------------------------- /utils/load_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Capsules for Object Segmentation (SegCaps) 3 | Original Paper by Rodney LaLonde and Ulas Bagci (https://arxiv.org/abs/1804.04241) 4 | Code written by: Rodney LaLonde 5 | If you use significant portions of this code or the ideas from our paper, please cite it :) 6 | If you have any questions, please email me at lalonde@knights.ucf.edu. 7 | 8 | This file is used for loading training, validation, and testing data into the models. 9 | It is specifically designed to handle 3D single-channel medical data. 10 | Modifications will be needed to train/test on normal 3-channel images. 11 | 12 | Enhancement: 13 | 0. Porting to Python version 3.6 14 | 1. Add image_resize2square to accept any size of images and change to 512 X 512 resolutions. 15 | 2. 
16 | ''' 17 | 18 | from __future__ import print_function 19 | 20 | import logging 21 | 22 | from os.path import join, basename 23 | from os import makedirs 24 | from glob import glob 25 | import csv 26 | from sklearn.model_selection import KFold 27 | import numpy as np 28 | 29 | import SimpleITK as sitk 30 | from sklearn.model_selection import train_test_split 31 | from tqdm import tqdm #Progress bar 32 | import matplotlib 33 | matplotlib.use('Agg') 34 | import matplotlib.pyplot as plt 35 | 36 | plt.ioff() 37 | 38 | debug = 0 39 | 40 | 41 | 42 | def load_data(root, split): 43 | # Load the training and testing lists 44 | with open(join(root, 'split_lists', 'train_split_' + str(split) + '.csv'), 'r') as f: 45 | reader = csv.reader(f) 46 | training_list = list(reader) 47 | 48 | with open(join(root, 'split_lists', 'test_split_' + str(split) + '.csv'), 'r') as f: 49 | reader = csv.reader(f) 50 | testing_list = list(reader) 51 | 52 | new_training_list, validation_list = train_test_split(training_list, test_size = 0.1, random_state = 7) 53 | if new_training_list == []: # if training_list only have 1 image file. 54 | new_training_list = validation_list 55 | return new_training_list, validation_list, testing_list 56 | 57 | def compute_class_weights(root, train_data_list): 58 | ''' 59 | We want to weight the the positive pixels by the ratio of negative to positive. 60 | Three scenarios: 61 | 1. Equal classes. neg/pos ~ 1. Standard binary cross-entropy 62 | 2. Many more negative examples. The network will learn to always output negative. In this way we want to 63 | increase the punishment for getting a positive wrong that way it will want to put positive more 64 | 3. Many more positive examples. We weight the positive value less so that negatives have a chance. 65 | ''' 66 | pos = 0.0 67 | neg = 0.0 68 | for img_name in tqdm(train_data_list): 69 | img = sitk.GetArrayFromImage(sitk.ReadImage(join(root, 'masks', img_name[0]))) 70 | for slic in img: 71 | if not np.any(slic): 72 | continue 73 | else: 74 | p = np.count_nonzero(slic) 75 | pos += p 76 | neg += (slic.size - p) 77 | 78 | return neg/pos 79 | 80 | def load_class_weights(root, split): 81 | class_weight_filename = join(root, 'split_lists', 'train_split_' + str(split) + '_class_weights.npy') 82 | try: 83 | return np.load(class_weight_filename) 84 | except: 85 | logging.warning('\nClass weight file {} not found.\nComputing class weights now. This may take ' 86 | 'some time.'.format(class_weight_filename)) 87 | train_data_list, _, _ = load_data(root, str(split)) 88 | value = compute_class_weights(root, train_data_list) 89 | np.save(class_weight_filename,value) 90 | logging.warning('\nFinished computing class weights. 
This value has been saved for this training split.') 91 | return value 92 | 93 | 94 | def split_data(root_path, num_splits): 95 | mask_list = [] 96 | for ext in ('*.mhd', '*.hdr', '*.nii', '*.png'): #add png file support 97 | mask_list.extend(sorted(glob(join(root_path,'masks',ext)))) # check imgs instead of masks 98 | 99 | assert len(mask_list) != 0, 'Unable to find any files in {}'.format(join(root_path,'masks')) 100 | 101 | outdir = join(root_path,'split_lists') 102 | try: 103 | makedirs(outdir) 104 | except: 105 | pass 106 | 107 | if num_splits == 1: 108 | # Testing model, training set = testing set = 1 image 109 | train_index = test_index = mask_list 110 | with open(join(outdir,'train_split_' + str(0) + '.csv'), 'w', encoding='utf-8', newline='') as csvfile: 111 | writer = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) 112 | print('basename=%s'%([basename(mask_list[0])])) 113 | writer.writerow([basename(mask_list[0])]) 114 | with open(join(outdir,'test_split_' + str(0) + '.csv'), 'w', encoding='utf-8', newline='') as csvfile: 115 | writer = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) 116 | writer.writerow([basename(mask_list[0])]) 117 | 118 | else: 119 | kf = KFold(n_splits=num_splits) 120 | n = 0 121 | for train_index, test_index in kf.split(mask_list): 122 | with open(join(outdir,'train_split_' + str(n) + '.csv'), 'w', encoding='utf-8', newline='') as csvfile: 123 | writer = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) 124 | for i in train_index: 125 | print('basename=%s'%([basename(mask_list[i])])) 126 | writer.writerow([basename(mask_list[i])]) 127 | with open(join(outdir,'test_split_' + str(n) + '.csv'), 'w', encoding='utf-8', newline='') as csvfile: 128 | writer = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) 129 | for i in test_index: 130 | writer.writerow([basename(mask_list[i])]) 131 | n += 1 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /utils/model_helper.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Capsules for Object Segmentation (SegCaps) 3 | Original Paper by Rodney LaLonde and Ulas Bagci (https://arxiv.org/abs/1804.04241) 4 | Code written by: Rodney LaLonde 5 | If you use significant portions of this code or the ideas from our paper, please cite it :) 6 | If you have any questions, please email me at lalonde@knights.ucf.edu. 7 | 8 | This is a helper file for choosing which model to create. 
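A minimal usage sketch (the argument values are hypothetical, not taken from this repository):

    model_list = create_model(args, input_shape=(512, 512, 1))
    model = model_list[0]  # every branch returns a list; single-model nets are wrapped as [model]

Only the selected network module is imported, so the other model definitions are never loaded.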
9 | ''' 10 | import tensorflow as tf 11 | 12 | def create_model(args, input_shape, enable_decoder=True): 13 | # If using CPU or single GPU 14 | if args.gpus <= 1: 15 | if args.net == 'unet': 16 | from models.unet import UNet 17 | model = UNet(input_shape) 18 | return [model] 19 | elif args.net == 'tiramisu': 20 | from models.densenets import DenseNetFCN 21 | model = DenseNetFCN(input_shape) 22 | return [model] 23 | elif args.net == 'segcapsr1': 24 | from segcapsnet.capsnet import CapsNetR1 25 | model_list = CapsNetR1(input_shape) 26 | return model_list 27 | elif args.net == 'segcapsr3': 28 | from segcapsnet.capsnet import CapsNetR3 29 | model_list = CapsNetR3(input_shape, args.num_class, enable_decoder) 30 | return model_list 31 | elif args.net == 'capsbasic': 32 | from segcapsnet.capsnet import CapsNetBasic 33 | model_list = CapsNetBasic(input_shape) 34 | return model_list 35 | else: 36 | raise Exception('Unknown network type specified: {}'.format(args.net)) 37 | # If using multiple GPUs 38 | else: 39 | with tf.device("/cpu:0"): 40 | if args.net == 'unet': 41 | from models.unet import UNet 42 | model = UNet(input_shape) 43 | return [model] 44 | elif args.net == 'tiramisu': 45 | from models.densenets import DenseNetFCN 46 | model = DenseNetFCN(input_shape) 47 | return [model] 48 | elif args.net == 'segcapsr1': 49 | from segcapsnet.capsnet import CapsNetR1 50 | model_list = CapsNetR1(input_shape) 51 | return model_list 52 | elif args.net == 'segcapsr3': 53 | from segcapsnet.capsnet import CapsNetR3 54 | model_list = CapsNetR3(input_shape, args.num_class, enable_decoder) 55 | return model_list 56 | elif args.net == 'capsbasic': 57 | from segcapsnet.capsnet import CapsNetBasic 58 | model_list = CapsNetBasic(input_shape) 59 | return model_list 60 | else: 61 | raise Exception('Unknown network type specified: {}'.format(args.net)) -------------------------------------------------------------------------------- /utils/threadsafe.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Capsules for Object Segmentation (SegCaps) 3 | Original Paper by Rodney LaLonde and Ulas Bagci (https://arxiv.org/abs/1804.04241) 4 | Code written by: Rodney LaLonde 5 | If you use significant portions of this code or the ideas from our paper, please cite it :) 6 | If you have any questions, please email me at lalonde@knights.ucf.edu. 7 | 8 | This file is used for loading training, validation, and testing data into the models. 9 | It is specifically designed to handle 3D single-channel medical data. 10 | Modifications will be needed to train/test on normal 3-channel images. 11 | 12 | 13 | ''' 14 | import threading 15 | 16 | ''' Make the generators threadsafe in case of multiple threads ''' 17 | class threadsafe_iter: 18 | """Takes an iterator/generator and makes it thread-safe by 19 | serializing call to the `next` method of given iterator/generator. 20 | """ 21 | def __init__(self, it): 22 | self.it = it 23 | self.lock = threading.Lock() 24 | 25 | def __iter__(self): 26 | return self 27 | 28 | def __next__(self): 29 | with self.lock: 30 | return self.it.__next__() 31 | 32 | 33 | def threadsafe_generator(f): 34 | """A decorator that takes a generator function and makes it thread-safe. 35 | """ 36 | def g(*a, **kw): 37 | return threadsafe_iter(f(*a, **kw)) 38 | return g 39 | 40 | --------------------------------------------------------------------------------
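Usage sketch for the threadsafe_generator decorator defined in utils/threadsafe.py (the generator,
thread count, and value range below are illustrative and not part of the repository):

    import threading
    from utils.threadsafe import threadsafe_generator

    @threadsafe_generator
    def counter(n):
        # A plain generator; without the decorator, calling next() from two threads at
        # once raises "generator already executing".
        for i in range(n):
            yield i

    gen = counter(1000)
    seen = []
    seen_lock = threading.Lock()

    def consume():
        for value in gen:          # each next() is serialized by threadsafe_iter's lock
            with seen_lock:
                seen.append(value)

    threads = [threading.Thread(target=consume) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert sorted(seen) == list(range(1000))  # every value is yielded exactly once

This is the same mechanism that allows generate_train_batches and the other decorated generators to be
shared across threads, e.g. when Keras' fit_generator is run with more than one worker.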