├── LICENSE.md ├── README.md ├── __init__.py ├── input_data ├── __init__.py ├── dataset.py ├── image_processing.py └── input_data_build_image_data_with_mask.py ├── segdec_data.py ├── segdec_model.py ├── segdec_print_eval.py └── segdec_train.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | # Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 4 | 5 | ### Using Creative Commons Public Licenses 6 | 7 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 8 | 9 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 10 | 11 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 
12 | 13 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License 14 | 15 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 16 | 17 | ### Section 1 – Definitions. 18 | 19 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 20 | 21 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 22 | 23 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License. 24 | 25 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 26 | 27 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 28 | 29 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 30 | 31 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike. 32 | 33 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 34 | 35 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 36 | 37 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 38 | 39 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. 
For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 40 | 41 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 42 | 43 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 44 | 45 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 46 | 47 | ### Section 2 – Scope. 48 | 49 | a. ___License grant.___ 50 | 51 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 52 | 53 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 54 | 55 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 56 | 57 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 58 | 59 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 60 | 61 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 62 | 63 | 5. __Downstream recipients.__ 64 | 65 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 66 | 67 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply. 68 | 69 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 70 | 71 | 6. 
__No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 72 | 73 | b. ___Other rights.___ 74 | 75 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 76 | 77 | 2. Patent and trademark rights are not licensed under this Public License. 78 | 79 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 80 | 81 | ### Section 3 – License Conditions. 82 | 83 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 84 | 85 | a. ___Attribution.___ 86 | 87 | 1. If You Share the Licensed Material (including in modified form), You must: 88 | 89 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 90 | 91 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 92 | 93 | ii. a copyright notice; 94 | 95 | iii. a notice that refers to this Public License; 96 | 97 | iv. a notice that refers to the disclaimer of warranties; 98 | 99 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 100 | 101 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 102 | 103 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 104 | 105 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 106 | 107 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 108 | 109 | b. ___ShareAlike.___ 110 | 111 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 112 | 113 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License. 114 | 115 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 116 | 117 | 3. 
You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. 118 | 119 | ### Section 4 – Sui Generis Database Rights. 120 | 121 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 122 | 123 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 124 | 125 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and 126 | 127 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 128 | 129 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 130 | 131 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 132 | 133 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 134 | 135 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 136 | 137 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 138 | 139 | ### Section 6 – Term and Termination. 140 | 141 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 142 | 143 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 144 | 145 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 146 | 147 | 2. upon express reinstatement by the Licensor. 
148 | 149 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 150 | 151 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 152 | 153 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 154 | 155 | ### Section 7 – Other Terms and Conditions. 156 | 157 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 158 | 159 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 160 | 161 | ### Section 8 – Interpretation. 162 | 163 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 164 | 165 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 166 | 167 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 168 | 169 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 170 | 171 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 172 | > 173 | > Creative Commons may be contacted at creativecommons.org 174 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Surface Defect Detection with Segmentation-Decision Network on KolektorSDD 2 | 3 | Official TensorFlow implementation for [Segmentation-based deep-learning approach for surface-defect detection](https://prints.vicos.si/publications/370) that uses segmentation and decision networks for the detection of surface defects. 
This work was done in collaboration with [Kolektor Group d.o.o.](http://www.kolektorvision.com/en/).
4 | 
5 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-shield]][cc-by-nc-sa]
6 | 
7 | Code and the dataset are licensed under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa]. For commercial use please contact danijel.skocaj@fri.uni-lj.si.
8 | 
9 | [![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa]
10 | 
11 | [cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/
12 | [cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png
13 | [cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg
14 | 
15 | # Citation:
16 | 
17 | Please cite the JIM 2019 journal paper:
18 | 
19 | ```
20 | @article{Tabernik2019JIM,
21 |   author = {Tabernik, Domen and {\v{S}}ela, Samo and Skvar{\v{c}}, Jure and Sko{\v{c}}aj, Danijel},
22 |   journal = {Journal of Intelligent Manufacturing},
23 |   title = {{Segmentation-Based Deep-Learning Approach for Surface-Defect Detection}},
24 |   year = {2019},
25 |   month = {May},
26 |   day = {15},
27 |   issn={1572-8145},
28 |   doi={10.1007/s10845-019-01476-x}
29 | }
30 | 
31 | ```
32 | 
33 | # Dependencies:
34 | 
35 | * python2.7
36 | * TensorFlow r1.1 or newer (tested up to r1.8)
37 | * python libs: numpy, scipy, six, PIL, sklearn, pylab, matplotlib
38 | 
39 | 
40 | # Dataset
41 | 
42 | The full Kolektor Surface Defect Dataset (KolektorSDD) is available [here](https://www.vicos.si/Downloads/KolektorSDD).
43 | 
44 | We split the dataset into three folds to perform 3-fold cross-validation. The splits are available at [http://box.vicos.si/skokec/gostop/KolektorSDD-training-splits.zip](http://box.vicos.si/skokec/gostop/KolektorSDD-training-splits.zip).
45 | 
46 | A fully prepared TensorFlow dataset, split into the same three folds, is available at [http://box.vicos.si/skokec/gostop/KolektorSDD-dilate=5-tensorflow.zip](http://box.vicos.si/skokec/gostop/KolektorSDD-dilate=5-tensorflow.zip).
47 | 
48 | 
49 | # Usage of training/evaluation code
50 | 
51 | The following files are used to train/evaluate the model:
52 | * `segdec_train.py`: MAIN ENTRY for training and evaluation
53 | * `segdec_model.py`: model file for the network
54 | * `segdec_data.py`: dataset class for training the model
55 | 
56 | 
57 | Using the TensorFlow-ready [KolektorSDD](https://www.vicos.si/Downloads/KolektorSDD) dataset (with dilate=5 masks) you can train and evaluate as follows:
58 | 
59 | 
60 | 
61 | ```bash
62 | 
63 | # 1. Download and extract `KolektorSDD-dilate=5-tensorflow.zip`
64 | mkdir db
65 | cd db
66 | wget http://box.vicos.si/skokec/gostop/KolektorSDD-dilate=5-tensorflow.zip
67 | unzip KolektorSDD-dilate=5-tensorflow.zip
68 | cd ..
69 | 
70 | 
71 | # Empty folder where models/results will be stored
72 | export OUTPUT_FOLDER=`pwd`/output
73 | 
74 | # Folder where `KolektorSDD-dilate=5-tensorflow.zip` is extracted (must contain the `KolektorSDD-dilate=5` subfolder).
75 | export DATASET_FOLDER=`pwd`/db
76 | 
77 | mkdir $OUTPUT_FOLDER
78 | 
79 | # 2. Train only the segmentation network first:
80 | 
81 | python -u segdec_train.py --fold=0,1,2 --gpu=0 --max_steps=6600 --train_subset=train \
82 |     --seg_net_type=ENTROPY \
83 |     --size_height=1408 \
84 |     --size_width=512 \
85 |     --with_seg_net=True \
86 |     --with_decision_net=False \
87 |     --storage_dir=$OUTPUT_FOLDER \
88 |     --dataset_dir=$DATASET_FOLDER \
89 |     --datasets=KolektorSDD-dilate=5 \
90 |     --name_prefix=full-size_cross-entropy
91 | 
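92 | # Optional: training progress can be monitored with TensorBoard. This assumes
93 | # that segdec_train.py writes its TensorFlow summaries beneath --storage_dir;
94 | # adjust --logdir if your log layout differs:
95 | #
96 | # tensorboard --logdir=$OUTPUT_FOLDER
97 | 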
98 | # 3. Train and evaluate the decision network on top of the existing segmentation network:
99 | 
100 | # The `--pretrained_main_folder` must point to the folder containing the 'fold_XY' subfolders with the trained segmentation models.
101 | # NOTE: Getting several `Not found: Key tower_0//decision` warnings when loading the model is OK, since the pre-trained model does not have the decision-net layers yet.
102 | 
103 | python -u segdec_train.py --fold=0,1,2 --gpu=0 --max_steps=6600 --train_subset=train \
104 |     --seg_net_type=ENTROPY \
105 |     --size_height=1408 \
106 |     --size_width=512 \
107 |     --with_seg_net=False \
108 |     --with_decision_net=True \
109 |     --storage_dir=$OUTPUT_FOLDER \
110 |     --dataset_dir=$DATASET_FOLDER \
111 |     --datasets=KolektorSDD-dilate=5 \
112 |     --name_prefix=decision-net_full-size_cross-entropy \
113 |     --pretrained_main_folder=$OUTPUT_FOLDER/segdec_train/KolektorSDD-dilate=5/full-size_cross-entropy
114 | 
115 | 
116 | # 4. Print evaluation metrics combined from all folds
117 | 
118 | python -u segdec_print_eval.py $OUTPUT_FOLDER/segdec_eval/KolektorSDD-dilate=5/decision-net_full-size_cross-entropy
119 | 
120 | ```
121 | 
122 | Note: training is sensitive to the random shuffling of the data, so different runs will yield somewhat different performance.
123 | 
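124 | # Preparing a custom dataset
125 | 
126 | The TFRecord files consumed above are produced by `input_data/input_data_build_image_data_with_mask.py`; the file ships no command-line wrapper, but its `_process_dataset` function can be called directly. Below is a minimal sketch for a hypothetical folder of `.jpg` images with `*_mask.png` groundtruth files; the paths, extension and shard/thread counts are illustrative assumptions, not values from this repository:
127 | 
128 | ```python
129 | from input_data.input_data_build_image_data_with_mask import _process_dataset
130 | 
131 | # Illustrative values only; num_shards must be divisible by num_threads.
132 | _process_dataset(name='train',
133 |                  directory_list=['/path/to/my-images'],  # hypothetical input folder
134 |                  data_extension='.jpg',
135 |                  mask_patterns=('.jpg', '_mask.png'),    # image -> mask filename mapping
136 |                  output_directory='/path/to/tfrecords',  # hypothetical output folder
137 |                  num_shards=1,
138 |                  num_threads=1,
139 |                  dilate=5)                               # mask dilation, as in KolektorSDD-dilate=5
140 | ```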
22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | from abc import ABCMeta 28 | from abc import abstractmethod 29 | import os 30 | 31 | 32 | import tensorflow as tf 33 | 34 | 35 | class Dataset(object): 36 | """A simple class for handling data sets.""" 37 | __metaclass__ = ABCMeta 38 | 39 | def __init__(self, name, subset, data_dir): 40 | """Initialize input_data using a subset and the path to the data.""" 41 | assert subset in self.available_subsets(), self.available_subsets() 42 | self.name = name 43 | self.subset = subset 44 | self.data_dir = data_dir 45 | 46 | @abstractmethod 47 | def num_classes(self): 48 | """Returns the number of classes in the data set.""" 49 | pass 50 | # return 10 51 | 52 | @abstractmethod 53 | def num_examples_per_epoch(self): 54 | """Returns the number of examples in the data subset.""" 55 | # read all sharded data files and count how may samples is there 56 | count = 0 57 | for tf_file in self.data_files(): 58 | for record in tf.python_io.tf_record_iterator(tf_file): 59 | count += 1 60 | return count 61 | 62 | @abstractmethod 63 | def download_message(self): 64 | """Prints a download message for the Dataset.""" 65 | pass 66 | 67 | def available_subsets(self): 68 | """Returns the list of available subsets.""" 69 | return ['train', 'train_pos', 'test'] 70 | 71 | def data_files(self): 72 | """Returns a python list of all (sharded) data subset files. 73 | 74 | Returns: 75 | python list of all (sharded) data set files. 76 | Raises: 77 | ValueError: if there are not data_files matching the subset. 78 | """ 79 | tf_record_pattern = os.path.join(self.data_dir, '%s-*' % self.subset) 80 | data_files = tf.gfile.Glob(tf_record_pattern) 81 | if not data_files: 82 | print('No files found for input_data %s/%s at %s' % (self.name, 83 | self.subset, 84 | self.data_dir)) 85 | 86 | self.download_message() 87 | exit(-1) 88 | return data_files 89 | 90 | def reader(self): 91 | """Return a reader for a single entry from the data set. 92 | 93 | See io_ops.py for details of Reader class. 94 | 95 | Returns: 96 | Reader object that reads the data set. 
97 | """ 98 | return tf.TFRecordReader() 99 | -------------------------------------------------------------------------------- /input_data/image_processing.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | from tensorflow.python.framework import tensor_shape 7 | 8 | 9 | class NetInputProcessing(object): 10 | def __init__(self, batch_size, num_preprocess_threads, input_size = None, mask_size = None, num_readers = 1, input_queue_memory_factor=16, use_random_rotation=False, ensure_posneg_balance=True): 11 | self.batch_size = batch_size 12 | 13 | self.input_size = input_size # input_size = (height, width, depth) 14 | self.mask_size = mask_size # mask_size = (height, width, depth) 15 | 16 | self.num_preprocess_threads = num_preprocess_threads 17 | self.num_readers = num_readers 18 | self.input_queue_memory_factor = input_queue_memory_factor 19 | self.use_random_rotation = use_random_rotation 20 | self.ensure_posneg_balance = ensure_posneg_balance 21 | 22 | if self.num_preprocess_threads is None: 23 | raise Exception("Missing num_preprocess_threads argument") 24 | 25 | if self.num_preprocess_threads % 1: 26 | raise ValueError('Please make num_preprocess_threads a multiple ' 27 | 'of 1 (%d % 1 != 0).', self.num_preprocess_threads) 28 | 29 | if self.num_readers is None: 30 | raise Exception("Missing num_readers argument") 31 | 32 | if self.num_readers < 1: 33 | raise ValueError('Please make num_readers at least 1') 34 | 35 | def add_inputs_nodes(self, dataset, train): 36 | """Generate batches of ImageNet images for evaluation. 37 | 38 | Use this function as the inputs for evaluating a network. 39 | 40 | Note that some (minimal) image preprocessing occurs during evaluation 41 | including central cropping and resizing of the image to fit the network. 42 | 43 | Args: 44 | dataset: instance of Dataset class specifying the input_data. 45 | batch_size: integer, number of examples in batch 46 | num_preprocess_threads: integer, total number of preprocessing threads but 47 | None defaults to FLAGS.num_preprocess_threads. 48 | 49 | Returns: 50 | images: Images. 4D tensor of size [batch_size, FLAGS.image_size, 51 | image_size, 3]. 52 | labels: 1-D integer Tensor of [FLAGS.batch_size]. 53 | """ 54 | 55 | # Force all input processing onto CPU in order to reserve the GPU for 56 | # the forward inference and back-propagation. 57 | with tf.device('/cpu:0'): 58 | images, labels, img_names = self.batch_inputs(dataset, train=train) 59 | 60 | return images, labels, img_names 61 | 62 | def batch_inputs(self, dataset, train): 63 | """Contruct batches of training or evaluation examples from the image input_data. 64 | 65 | Args: 66 | dataset: instance of Dataset class specifying the input_data. 67 | See input_data.py for details. 68 | batch_size: integer 69 | train: boolean 70 | num_preprocess_threads: integer, total number of preprocessing threads 71 | num_readers: integer, number of parallel readers 72 | 73 | Returns: 74 | images: 4-D float Tensor of a batch of images 75 | labels: 1-D integer Tensor of [batch_size]. 
76 | 77 | Raises: 78 | ValueError: if data is not found 79 | """ 80 | with tf.name_scope('batch_processing'): 81 | data_files = dataset.data_files() 82 | if data_files is None: 83 | raise ValueError('No data files found for this input_data') 84 | 85 | # Create filename_queue 86 | if train: 87 | filename_queue = tf.train.string_input_producer(data_files, 88 | shuffle=True, 89 | capacity=16) 90 | else: 91 | filename_queue = tf.train.string_input_producer(data_files, 92 | shuffle=False, 93 | capacity=1) 94 | 95 | # Approximate number of examples per shard. 96 | examples_per_shard = 1024 97 | # Size the random shuffle queue to balance between good global 98 | # mixing (more examples) and memory use (fewer examples). 99 | # 1 image uses 299*299*3*4 bytes = 1MB 100 | # The default input_queue_memory_factor is 16 implying a shuffling queue 101 | # size: examples_per_shard * 16 * 1MB = 17.6GB 102 | min_queue_examples = examples_per_shard * self.input_queue_memory_factor 103 | if train: 104 | examples_queue = tf.RandomShuffleQueue( 105 | capacity=min_queue_examples + 3 * self.batch_size, 106 | min_after_dequeue=min_queue_examples, 107 | dtypes=[tf.string]) 108 | else: 109 | examples_queue = tf.FIFOQueue( 110 | capacity=examples_per_shard + 3 * self.batch_size, 111 | dtypes=[tf.string]) 112 | 113 | # Create multiple readers to populate the queue of examples. 114 | if self.num_readers > 1: 115 | enqueue_ops = [] 116 | for _ in range(self.num_readers): 117 | reader = dataset.reader() 118 | _, value = reader.read(filename_queue) 119 | enqueue_ops.append(examples_queue.enqueue([value])) 120 | 121 | tf.train.queue_runner.add_queue_runner( 122 | tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops)) 123 | example_serialized = examples_queue.dequeue() 124 | else: 125 | reader = dataset.reader() 126 | _, example_serialized = reader.read(filename_queue) 127 | 128 | 129 | pos_queue = None 130 | neg_queue = None 131 | 132 | if self.batch_size < 2: 133 | pos_queue = tf.RandomShuffleQueue(name="pos-queue", capacity=10, min_after_dequeue=5, dtypes=[tf.float32, tf.float32, tf.string]) 134 | neg_queue = tf.RandomShuffleQueue(name="neg-queue", capacity=10, min_after_dequeue=5, dtypes=[tf.float32, tf.float32, tf.string]) 135 | 136 | pos_queue_enq = [] 137 | neg_queue_enq = [] 138 | 139 | with tf.name_scope('split-merge'): 140 | if train and self.ensure_posneg_balance: 141 | images_and_masks = [] 142 | for thread_id in range(self.num_preprocess_threads): 143 | # Parse a serialized Example proto to extract the image and metadata. 
144 | image_buffer, mask_buffer, img_name_ = self.parse_example_proto(example_serialized) 145 | 146 | image_ = self.image_preprocessing(image_buffer, img_size=(self.input_size[0],self.input_size[1]), num_channels=self.input_size[2]) 147 | mask_ = self.image_preprocessing(mask_buffer, img_size=(self.mask_size[0],self.mask_size[1]), num_channels=self.mask_size[2]) 148 | 149 | image_ = tf.expand_dims(image_, 0) 150 | mask_ = tf.expand_dims(mask_, 0) 151 | img_name_ = tf.expand_dims(img_name_,0) 152 | 153 | img_shape = tf.TensorShape([image_.shape[1], image_.shape[2], image_.shape[3]]) 154 | mask_shape = tf.TensorShape([mask_.shape[1], mask_.shape[2], mask_.shape[3]]) 155 | img_name_shape = tf.TensorShape([]) 156 | 157 | # initialize pos/neg queues with proper shape size on first 158 | if pos_queue is None or neg_queue is None: 159 | pos_queue = tf.RandomShuffleQueue(name="pos-queue", capacity=10, min_after_dequeue=5, dtypes=[tf.float32, tf.float32, tf.string], shapes=[img_shape, mask_shape, img_name_shape]) 160 | neg_queue = tf.RandomShuffleQueue(name="neg-queue", capacity=10, min_after_dequeue=5, dtypes=[tf.float32, tf.float32, tf.string], shapes=[img_shape, mask_shape, img_name_shape]) 161 | 162 | is_pos = tf.squeeze(tf.reduce_sum(mask_,[1,2], keep_dims=False)) 163 | 164 | neg_mask = tf.less_equal(is_pos, 0) 165 | 166 | pos_idx = tf.reshape(tf.where([tf.logical_not(neg_mask)]), [-1]) 167 | neg_idx = tf.reshape(tf.where([neg_mask]),[-1]) 168 | 169 | pos_data = [tf.gather(image_, pos_idx), 170 | tf.gather(mask_, pos_idx), 171 | tf.gather(img_name_, pos_idx)] 172 | neg_data = [tf.gather(image_, neg_idx), 173 | tf.gather(mask_, neg_idx), 174 | tf.gather(img_name_, neg_idx)] 175 | 176 | pos_queue_enq.append(pos_queue.enqueue_many(pos_data)) 177 | neg_queue_enq.append(neg_queue.enqueue_many(neg_data)) 178 | 179 | 180 | tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(pos_queue, pos_queue_enq)) 181 | tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(neg_queue, neg_queue_enq)) 182 | 183 | 184 | if self.batch_size >= 2: 185 | if self.batch_size % 2 != 0: 186 | raise Exception("'batch_size' mod 2 != 0 ! only even batch sizes supported at the moment") 187 | 188 | num_deque = int(self.batch_size / 2) 189 | 190 | pos_data = pos_queue.dequeue_many(num_deque) 191 | neg_data = neg_queue.dequeue_many(num_deque) 192 | 193 | concat_data = [tf.concat([pos_data[0], neg_data[0]], axis=0, name='Concat-img'), 194 | tf.concat([pos_data[1], neg_data[1]], axis=0, name='Concat-mask'), 195 | tf.concat([pos_data[2], neg_data[2]], axis=0, name='Concat-img-name')] 196 | 197 | # randomly permute within batch size (is this even necessary ??) 
198 | idx = tf.Variable(range(0, self.batch_size), trainable=False, dtype=tf.int32) 199 | idx = tf.random_shuffle(idx) 200 | 201 | images = tf.gather(concat_data[0],idx) 202 | masks = tf.gather(concat_data[1],idx) 203 | img_names = tf.gather(concat_data[2],idx) 204 | 205 | else: 206 | # positive only 207 | #images, masks, img_names = pos_queue.dequeue() 208 | 209 | # negative only 210 | #images, masks, img_names = neg_queue.dequeue() 211 | 212 | # mix 50/50 213 | counter = tf.Variable(initial_value=0, trainable=False, dtype=tf.int32) 214 | 215 | counter = tf.assign_add(counter, 1) 216 | condition_term = tf.equal(tf.mod(counter, 2), tf.constant(0)) 217 | images, masks, img_names = tf.cond(condition_term, 218 | lambda: pos_queue.dequeue(), 219 | lambda: neg_queue.dequeue()) 220 | 221 | if self.use_random_rotation: 222 | images.set_shape(tensor_shape.as_shape([None, None, 1])) 223 | masks.set_shape(tensor_shape.as_shape([None, None, 1])) 224 | 225 | # randomly rotate image by 90 degrees 226 | rot_factor = tf.random_uniform([1], minval=0, maxval=3, dtype=tf.int32) 227 | rot_factor = tf.gather(rot_factor,0) 228 | 229 | images = tf.image.rot90(images, k=rot_factor) 230 | masks = tf.image.rot90(masks, k=rot_factor) 231 | 232 | images = tf.expand_dims(images,axis=0) 233 | masks = tf.expand_dims(masks, axis=0) 234 | img_names = tf.expand_dims(img_names, axis=0) 235 | else: 236 | 237 | # Parse a serialized Example proto to extract the image and metadata. 238 | image_buffer, mask_buffer, img_names = self.parse_example_proto(example_serialized) 239 | 240 | images = self.image_preprocessing(image_buffer, 241 | img_size=(self.input_size[0], self.input_size[1]), 242 | num_channels=self.input_size[2]) 243 | masks = self.image_preprocessing(mask_buffer, img_size=(self.mask_size[0], self.mask_size[1]), 244 | num_channels=1) 245 | 246 | 247 | 248 | images = tf.expand_dims(images, axis=0) 249 | masks = tf.expand_dims(masks, axis=0) 250 | img_names = tf.expand_dims(img_names, axis=0) 251 | 252 | # Reshape images into these desired dimensions. 253 | images = tf.cast(images, tf.float32) 254 | masks = tf.cast(masks, tf.float32) 255 | 256 | images.set_shape(tensor_shape.as_shape([self.batch_size, None, None, self.input_size[2]])) 257 | masks.set_shape(tensor_shape.as_shape([self.batch_size, self.input_size[0], self.input_size[1], self.mask_size[2]])) 258 | 259 | # Display the training images in the visualizer. 260 | tf.summary.image('images', images) 261 | tf.summary.image('masks', masks) 262 | 263 | return images, masks, img_names 264 | 265 | def decode_png(self, image_buffer, num_channels, scope=None): 266 | """Decode a PNG string into one 3-D float image Tensor. 267 | 268 | Args: 269 | image_buffer: scalar string Tensor. 270 | scope: Optional scope for name_scope. 271 | Returns: 272 | 3-D float Tensor with values ranging from [0, 1). 273 | """ 274 | with tf.name_scope(values=[image_buffer], name=scope, 275 | default_name='decode_png'): 276 | # Decode the string as an PNG. 277 | # Note that the resulting image contains an unknown height and width 278 | # that is set dynamically by decode_jpeg. In other words, the height 279 | # and width of image is unknown at compile-time. 280 | image = tf.image.decode_png(image_buffer, channels=num_channels) 281 | 282 | # After this point, all image pixels reside in [0,1) 283 | # until the very end, when they're rescaled to (-1, 1). The various 284 | # adjust_* ops all require this range for dtype float. 
285 |             image = tf.image.convert_image_dtype(image, dtype=tf.float32)
286 |             return image
287 | 
288 | 
289 | 
290 |     def image_preprocessing(self, image_buffer, img_size, num_channels):
291 |         """Decode and preprocess one image for evaluation or training.
292 | 
293 |         Args:
294 |           image_buffer: PNG encoded string Tensor
295 | 
296 |         Returns:
297 |           3-D float Tensor containing an appropriately scaled image
298 |         """
299 | 
300 |         image = self.decode_png(image_buffer, num_channels)
301 | 
302 |         # Image size is unknown until run-time, so we need to resize the image to a specific size
303 |         image = tf.image.resize_images(image, size=img_size)
304 | 
305 |         # Rescaling to [-1, 1] is currently disabled; images are kept in [0, 1).
306 |         #image = tf.subtract(image, 0.5)
307 |         #image = tf.multiply(image, 2.0)
308 |         return image
309 | 
310 | 
311 |     def parse_example_proto(self, example_serialized):
312 |         """Parses an Example proto containing a training example of an image.
313 | 
314 |         The output of the input_data_build_image_data_with_mask.py script is an
315 |         input_data set containing serialized Example protocol buffers. Each
316 |         Example proto contains the following fields:
317 | 
318 |           image/height: 750
319 |           image/width: 250
320 |           image/channels: 1
321 |           image/class/encoded: <PNG encoded mask string>
322 |           image/class/filename: 'my-image_mask.png'
323 |           image/class/channels: 1
324 |           image/format: 'PNG'
325 |           image/filename: 'my-image.jpg'
326 |           image/encoded: <PNG encoded image string>
327 | 
328 |         Args:
329 |           example_serialized: scalar Tensor tf.string containing a serialized
330 |             Example protocol buffer.
331 | 
332 |         Returns:
333 |           image_buffer: Tensor tf.string containing the contents of a PNG file.
334 |           label_buffer, img_name: Tensors tf.string with the PNG mask contents and the image name.
335 |         """
336 |         # Dense features in Example proto.
337 |         feature_map = {
338 |             'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
339 |             'image/class/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
340 |             'image/filename': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
341 |             'image/name': tf.FixedLenFeature([], dtype=tf.string, default_value=''),
342 |         }
343 | 
344 |         features = tf.parse_single_example(example_serialized, feature_map)
345 | 
346 |         return features['image/encoded'], features['image/class/encoded'], features['image/name']
347 | 
348 | 
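349 | # A minimal illustrative wiring of this class (batch size, image sizes and the
350 | # dataset path below are example values based on the README, not defaults of
351 | # this module):
352 | #
353 | # from segdec_data import InputData
354 | #
355 | # dataset = InputData('train', '/path/to/KolektorSDD-dilate=5/fold_0')  # hypothetical path
356 | # proc = NetInputProcessing(batch_size=2, num_preprocess_threads=1,
357 | #                           input_size=(1408, 512, 1),
358 | #                           mask_size=(1408, 512, 1))
359 | # images, masks, img_names = proc.add_inputs_nodes(dataset, train=True)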
--------------------------------------------------------------------------------
/input_data/input_data_build_image_data_with_mask.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Converts image data to TFRecords file format with Example protos.
16 | 
17 | The image data set is expected to reside in image files located in the
18 | following directory structure.
19 | 
20 |   data_dir/label_0/image0.jpeg
21 |   data_dir/label_0/image1.jpg
22 |   ...
23 |   data_dir/label_1/weird-image.jpeg
24 |   data_dir/label_1/my-image.jpeg
25 |   ...
26 | 
27 | where the sub-directory is the unique label associated with these images.
28 | 
29 | This TensorFlow script converts the training and evaluation data into
30 | a sharded data set consisting of TFRecord files
31 | 
32 |   train_directory/train-00000-of-01024
33 |   train_directory/train-00001-of-01024
34 |   ...
35 |   train_directory/train-01023-of-01024
36 | 
37 | and
38 | 
39 |   validation_directory/validation-00000-of-00128
40 |   validation_directory/validation-00001-of-00128
41 |   ...
42 |   validation_directory/validation-00127-of-00128
43 | 
44 | where we have selected 1024 and 128 shards for each data set. Each record
45 | within the TFRecord file is a serialized Example proto. The Example proto
46 | contains the following fields:
47 | 
48 |   image/encoded: string containing PNG encoded image
49 |   image/height: integer, image height in pixels
50 |   image/width: integer, image width in pixels
51 |   image/name: string, absolute path of the image or a name produced by naming_fn
52 |   image/channels: integer, specifying the number of channels
53 |   image/format: string, specifying the format, always 'PNG'
54 | 
55 |   image/filename: string containing the basename of the image file
56 |   e.g. 'my-image.jpg'
57 |   image/class/encoded: string containing PNG encoded class mask
58 | 
59 | """
60 | from __future__ import absolute_import
61 | from __future__ import division
62 | from __future__ import print_function
63 | 
64 | from datetime import datetime
65 | import os
66 | import random
67 | import sys
68 | import threading
69 | 
70 | import numpy as np
71 | import tensorflow as tf
72 | 
73 | from scipy.ndimage.morphology import binary_dilation
74 | from scipy.ndimage import generate_binary_structure, iterate_structure
75 | 
76 | from cStringIO import StringIO
77 | from PIL import Image
78 | 
79 | def _int64_feature(value):
80 |     """Wrapper for inserting int64 features into Example proto."""
81 |     if not isinstance(value, list):
82 |         value = [value]
83 |     return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
84 | 
85 | 
86 | def _bytes_feature(value):
87 |     """Wrapper for inserting bytes features into Example proto."""
88 |     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
89 | 
90 | 
91 | def _convert_to_example(filename, image_buffer, img_channels, mask_filename, mask_buffer, mask_channels, image_format, height, width, naming_fn=None):
92 |     """Build an Example proto for an example.
93 | 94 | Args: 95 | filename: string, path to an image file, e.g., '/path/to/example.JPG' 96 | image_buffer: string, PNG/JPEG encoding of RGB image 97 | mask_filename: string, path to groundtruth mask file 98 | mask_buffer: string, PNG/JPEG encoding of mask image 99 | image_format: string, format of encoded images, e.g., PNG or JPEG 100 | height: integer, image height in pixels 101 | width: integer, image width in pixels 102 | Returns: 103 | Example proto 104 | """ 105 | 106 | feat = { 107 | 'image/height': _int64_feature(height), 108 | 'image/width': _int64_feature(width), 109 | 'image/channels': _int64_feature(img_channels), 110 | 'image/format': _bytes_feature(tf.compat.as_bytes(image_format)), 111 | 'image/filename': _bytes_feature(tf.compat.as_bytes(os.path.basename(filename))), 112 | 'image/name': _bytes_feature( 113 | tf.compat.as_bytes(naming_fn(filename) if naming_fn is not None else os.path.abspath(filename))), 114 | 'image/encoded': _bytes_feature(tf.compat.as_bytes(image_buffer))} 115 | 116 | if mask_buffer is not None: 117 | feat['image/class/encoded'] = _bytes_feature(tf.compat.as_bytes(mask_buffer)) 118 | feat['image/class/channels'] = _int64_feature(mask_channels) 119 | 120 | if mask_filename is not None: 121 | feat['image/class/filename'] = _bytes_feature(tf.compat.as_bytes(os.path.basename(mask_filename))) 122 | 123 | 124 | example = tf.train.Example(features=tf.train.Features(feature=feat)) 125 | return example 126 | 127 | 128 | def _process_image(filename, out_format, resize = None, dilate = None, require_binary_output = False): 129 | """Process a single image file. 130 | 131 | Args: 132 | filename: string, path to an image file e.g., '/path/to/example.JPG'. 133 | out_format: string, output format type e.g., 'PNG', 'JPEG' 134 | Returns: 135 | image_buffer: string, encoding of image in out_format 136 | height: integer, image height in pixels. 137 | width: integer, image width in pixels. 138 | """ 139 | # Read the image file. 140 | with tf.gfile.FastGFile(filename, 'rb') as f: 141 | raw_image_data = f.read() 142 | 143 | # Convert any format to PNG for consistency. 
144 |     pil_img = Image.open(StringIO(raw_image_data))
145 | 
146 |     # dilate the image if requested - create a structuring element of the appropriate size
147 |     if dilate is not None:
148 |         dilation_se = iterate_structure(generate_binary_structure(2, 1), int((dilate - 1) / 2))
149 |         im = binary_dilation(np.array(pil_img), structure=dilation_se)
150 |         pil_img = Image.fromarray(np.uint8(im) * 255)
151 | 
152 |     if resize is not None:
153 |         pil_img = pil_img.resize(resize[::-1])  # NOTE: use reversed order of resize to make input consistent with tensorflow
154 | 
155 |     # if the output must be binary then binarize to remove interpolation values introduced by the resize
156 |     if require_binary_output:
157 |         im = (np.array(pil_img) > 0)
158 |         pil_img = Image.fromarray(np.uint8(im) * 255)
159 | 
160 |     image_data = StringIO()
161 |     pil_img.save(image_data, out_format)
162 | 
163 |     height = pil_img.size[1]
164 |     width = pil_img.size[0]
165 |     if pil_img.mode in ['RGBA', 'CMYK']:
166 |         num_channels = 4
167 |     elif pil_img.mode in ['RGB', 'LAB', 'HSV', 'YCbCr']:
168 |         num_channels = 3
169 |     else:
170 |         num_channels = 1
171 | 
172 |     return image_data.getvalue(), height, width, num_channels
173 | 
174 | 
175 | def _process_image_files_batch(image_format, thread_index, ranges, name, filenames, masks, num_shards, output_directory, resize=None, naming_fn=None, dilate=None, require_binary_output=False, export_images=False):
176 |     """Processes and saves list of images as TFRecord in 1 thread.
177 | 
178 |     Args:
179 |       image_format: string, output format type e.g., 'PNG', 'JPEG'
180 |       thread_index: integer, unique batch index within [0, len(ranges)).
181 |       ranges: list of pairs of integers specifying the ranges of each batch to
182 |         analyze in parallel.
183 |       name: string, unique identifier specifying the data set
184 |       filenames: list of strings; each string is a path to an image file
185 |       masks: list of strings; each string is a path to a groundtruth mask file
186 |       num_shards: integer number of shards for this data set.
187 |       output_directory: string, folder where the TFRecord shards are written.
188 |     """
189 |     # Each thread produces N shards where N = int(num_shards / num_threads).
190 |     # For instance, if num_shards = 128, and the num_threads = 2, then the first
191 |     # thread would produce shards [0, 64).
192 |     num_threads = len(ranges)
193 |     assert not num_shards % num_threads
194 |     num_shards_per_batch = int(num_shards / num_threads)
195 | 
196 |     shard_ranges = np.linspace(ranges[thread_index][0],
197 |                                ranges[thread_index][1],
198 |                                num_shards_per_batch + 1).astype(int)
199 |     num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]
200 | 
201 |     counter = 0
202 |     for s in range(num_shards_per_batch):
203 |         # Generate a sharded version of the file name, e.g.
'train-00002-of-00010' 204 | shard = thread_index * num_shards_per_batch + s 205 | output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards) 206 | output_file = os.path.join(output_directory, output_filename) 207 | writer = tf.python_io.TFRecordWriter(output_file) 208 | 209 | export_folder_imgs = None 210 | export_folder_masks = None 211 | if export_images: 212 | export_folder = os.path.join(output_directory, 'export') 213 | export_folder_imgs = os.path.join(export_folder, 'imgs') 214 | export_folder_masks = os.path.join(export_folder, 'masks') 215 | 216 | try: 217 | os.makedirs(export_folder_imgs) 218 | except: 219 | pass 220 | 221 | try: 222 | os.makedirs(export_folder_masks) 223 | except: 224 | pass 225 | 226 | 227 | shard_counter = 0 228 | files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int) 229 | for i in files_in_shard: 230 | filename = filenames[i] 231 | mask_filename = masks[i] 232 | 233 | try: 234 | image_buffer, img_height, img_width, img_channels = _process_image(filename, image_format, resize) 235 | except Exception as e: 236 | print(e) 237 | print('SKIPPED: Unexpected eror while decoding %s.' % filename) 238 | continue 239 | 240 | try: 241 | mask_buffer, mask_height, mask_width, mask_channels = _process_image(mask_filename, image_format, resize, dilate,require_binary_output=require_binary_output) 242 | 243 | except Exception as e: 244 | print('WARNING: No mask found for %s - using empty mask instead' % filename) 245 | 246 | # Generate dummy mask 247 | pil_img = Image.fromarray(np.zeros([img_height, img_width], dtype=np.uint8)) 248 | image_data = StringIO() 249 | pil_img.save(image_data, image_format) 250 | 251 | mask_buffer = image_data.getvalue() 252 | mask_height = pil_img.size[1] 253 | mask_width = pil_img.size[0] 254 | mask_channels = 1 255 | 256 | assert img_height == mask_height 257 | assert img_width == mask_width 258 | 259 | example = _convert_to_example(filename, image_buffer, img_channels, mask_filename, mask_buffer, mask_channels, image_format, img_height, img_width, naming_fn=naming_fn) 260 | writer.write(example.SerializeToString()) 261 | shard_counter += 1 262 | counter += 1 263 | 264 | # if export_images is set then we need to export image and mask as raw image into appropriate folder 265 | if export_images and export_folder_imgs is not None and export_folder_masks is not None: 266 | 267 | # get name of sample using img name and its last folder parent 268 | part_name = os.path.basename(os.path.dirname(filename)) 269 | part_id = os.path.splitext(os.path.basename(filename))[0] 270 | 271 | export_img_name = str.format("{0}_{1}.{2}",part_name, part_id, image_format.lower() ) 272 | export_mask_name = str.format("{0}_{1}_label.{2}", part_name, part_id, image_format.lower()) 273 | 274 | # export both img and its mask 275 | Image.open(StringIO(image_buffer)).save(os.path.join(export_folder_imgs, export_img_name)) 276 | Image.open(StringIO(mask_buffer)).save(os.path.join(export_folder_masks, export_mask_name)) 277 | 278 | 279 | if not counter % 1000: 280 | print('%s [thread %d]: Processed %d of %d images in thread batch.' % 281 | (datetime.now(), thread_index, counter, num_files_in_thread)) 282 | sys.stdout.flush() 283 | 284 | writer.close() 285 | print('%s [thread %d]: Wrote %d images to %s' % 286 | (datetime.now(), thread_index, shard_counter, output_file)) 287 | sys.stdout.flush() 288 | shard_counter = 0 289 | print('%s [thread %d]: Wrote %d images to %d shards.' 
% 290 | (datetime.now(), thread_index, counter, num_files_in_thread)) 291 | sys.stdout.flush() 292 | 293 | 294 | def _process_image_files(name, filenames, masks, output_directory, num_shards, num_threads, resize = None, naming_fn = None, dilate = None, require_binary_output = False, export_images = False): 295 | """Process and save list of images as TFRecord of Example protos. 296 | 297 | Args: 298 | name: string, unique identifier specifying the data set 299 | filenames: list of strings; each string is a path to an image file 300 | masks: list of strings; each string is a path to an groundtruth mask file 301 | output_directory: string, output folder path 302 | num_shards: integer number of shards for this data set. 303 | num_threads: integer number of threads to use, NOTE: must be num_shards % num_threads == 0 304 | """ 305 | if masks is not None: 306 | assert len(filenames) == len(masks) 307 | 308 | # Break all images into batches with a [ranges[i][0], ranges[i][1]]. 309 | spacing = np.linspace(0, len(filenames), num_threads + 1).astype(np.int) 310 | ranges = [] 311 | for i in range(len(spacing) - 1): 312 | ranges.append([spacing[i], spacing[i + 1]]) 313 | 314 | # Launch a thread for each batch. 315 | print('Launching %d threads for spacings: %s' % (num_threads, ranges)) 316 | sys.stdout.flush() 317 | 318 | # Create a mechanism for monitoring when all threads are finished. 319 | coord = tf.train.Coordinator() 320 | 321 | image_format = 'PNG' 322 | 323 | threads = [] 324 | for thread_index in range(len(ranges)): 325 | args = (image_format, thread_index, ranges, name, filenames, masks, num_shards, output_directory, resize, naming_fn, dilate, require_binary_output, export_images) 326 | _process_image_files_batch(*args) 327 | #t = threading.Thread(target=_process_image_files_batch, args=args) 328 | #t.start() 329 | #threads.append(t) 330 | 331 | # Wait for all the threads to terminate. 332 | coord.join(threads) 333 | print('%s: Finished writing all %d images in data set.' % 334 | (datetime.now(), len(filenames))) 335 | sys.stdout.flush() 336 | 337 | 338 | def _find_image_files(data_dir_list, data_ext, mask_pattern, ignore_non_positive_masks = False): 339 | """Build a list of all images files and labels in the data set. 340 | 341 | Args: 342 | data_dir_list: array of string, list of paths to the root directory of images. 343 | 344 | Assumes that the image data set resides in JPEG files located in 345 | the following directory structure. 346 | 347 | data_dir[0]/another-image.JPEG 348 | data_dir[1]/my-image.jpg 349 | 350 | with corresponding mask files in the same folders 351 | 352 | data_dir[0]/another-image_mask.png 353 | data_dir[1]/my-image_mask.png 354 | 355 | data_ext: string, extension of images 356 | mask_pattern: tuple of string, with mask_pattern[0] string replace pattern 357 | and mask_pattern[1] replace string, e.g., mask_pattern = ('.jpg', '_mask.png') 358 | ignore_non_positive_masks: boolean, ignores files that have only zero mask values 359 | 360 | Returns: 361 | filenames: list of strings; each string is a path to an image file. 362 | mask_filenames: list of strings; each string is a path to an image mask file. 363 | """ 364 | 365 | filenames = [] 366 | mask_filenames = [] 367 | 368 | # Construct the list of JPEG files and labels. 369 | for data_dir in data_dir_list: 370 | print('Determining list of input files and labels from %s.' 
% data_dir) 371 | 372 | jpeg_file_path = '%s/*%s' % (data_dir, data_ext) 373 | matching_files = tf.gfile.Glob(jpeg_file_path) 374 | 375 | # remove files that are actually labels (masks) 376 | matching_files = [f for f in matching_files if not f.endswith(mask_pattern[1])] 377 | 378 | filenames.extend(matching_files) 379 | 380 | # Find corresponding mask files 381 | matching_files_masks = [filename.replace(mask_pattern[0], mask_pattern[1]) for filename in matching_files] 382 | 383 | mask_filenames.extend(matching_files_masks) 384 | 385 | if ignore_non_positive_masks: 386 | select = [os.path.exists(m) and np.any(Image.open(m)) for m in mask_filenames] # a missing mask counts as all-zero; this also keeps select aligned with filenames 387 | filenames = [f for f, s in zip(filenames, select) if s] 388 | mask_filenames = [m for m, s in zip(mask_filenames, select) if s] 389 | 390 | # Shuffle the ordering of all image files in order to guarantee 391 | # random ordering of the images with respect to label in the 392 | # saved TFRecord files. Make the randomization repeatable. 393 | shuffled_index = list(range(len(filenames))) 394 | random.seed(12345) 395 | random.shuffle(shuffled_index) 396 | 397 | filenames = [filenames[i] for i in shuffled_index] 398 | mask_filenames = [mask_filenames[i] for i in shuffled_index] 399 | 400 | print('Found %d image files across all folders.' % (len(filenames))) 401 | 402 | return filenames, mask_filenames 403 | 404 | 405 | def _process_dataset(name, directory_list, data_extension, mask_patterns, output_directory, num_shards, num_threads, resize = None, ignore_non_positive_masks = False, naming_fn = None, dilate = None, require_binary_output = False, export_images = False): 406 | """Process a complete data set and save it as a TFRecord. 407 | 408 | Args: 409 | name: string, unique identifier specifying the data set. 410 | directory_list: list of strings, root paths to the image folders. 411 | num_shards: integer number of shards for this data set. 412 | output_directory: string, folder where the TFRecord files are written. 413 | """ 414 | filenames, masks = _find_image_files(directory_list, data_extension, mask_patterns, ignore_non_positive_masks) 415 | _process_image_files(name, filenames, masks, output_directory, num_shards, num_threads, resize=resize, naming_fn=naming_fn, dilate=dilate, require_binary_output=require_binary_output, export_images=export_images) 416 | 417 | -------------------------------------------------------------------------------- /segdec_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Small library that points to the SegDec data set.
16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from input_data.dataset import Dataset 22 | 23 | 24 | class InputData(Dataset): 25 | """GOSTOP data set.""" 26 | 27 | def __init__(self, subset, data_dir): 28 | super(InputData, self).__init__('SegDecData', subset, data_dir) 29 | 30 | def available_subsets(self): 31 | """Returns the list of available subsets.""" 32 | subsets = super(InputData, self).available_subsets() 33 | subsets.append('all') 34 | return subsets 35 | 36 | def num_classes(self): 37 | """Returns the number of classes in the data set.""" 38 | return 2 39 | 40 | def num_examples_per_epoch(self): 41 | """Returns the number of examples in the data set.""" 42 | return super(InputData, self).num_examples_per_epoch() 43 | 44 | 45 | def download_message(self): 46 | print('Missing data.') 47 | 48 | -------------------------------------------------------------------------------- /segdec_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import re 7 | 8 | import tensorflow as tf 9 | 10 | from tensorflow.contrib import keras 11 | from tensorflow.contrib import layers 12 | from tensorflow.contrib.framework.python.ops import arg_scope 13 | from tensorflow.contrib.layers.python.layers import layers as layers_lib 14 | from tensorflow.contrib.layers.python.layers import initializers 15 | from tensorflow.contrib.layers.python.layers import utils 16 | from tensorflow.python.ops import variable_scope 17 | 18 | class SegDecModel(object): 19 | 20 | # If a model is trained using multiple GPUs, prefix all Op names with tower_name 21 | # to differentiate the operations. Note that this prefix is removed from the 22 | # names of the summaries when visualizing a model. 23 | TOWER_NAME = 'tower' 24 | 25 | # Batch normalization. Constant governing the exponential moving average of 26 | # the 'global' mean and variance for all activations. 27 | BATCHNORM_MOVING_AVERAGE_DECAY = 0.9997 28 | 29 | # The decay to use for the moving average. 
30 | MOVING_AVERAGE_DECAY = 0.9999 31 | 32 | DECISION_NET_NONE = 0 33 | DECISION_NET_LOGISTIC = 1 34 | DECISION_NET_FULL = 2 35 | 36 | def __init__(self, 37 | use_corss_entropy_seg_net=True, 38 | positive_weight=1, 39 | decision_net=DECISION_NET_NONE, 40 | decision_positive_weight=1, 41 | load_from_seg_only_net=False): 42 | 43 | # weight for positive samples in segmentation net 44 | self.positive_weight = positive_weight 45 | 46 | # weight for positive samples in decision net 47 | self.decision_positive_weight = decision_positive_weight 48 | 49 | if decision_net == SegDecModel.DECISION_NET_NONE: 50 | self.decision_net_fn = lambda net, net_prob_mat: None 51 | elif decision_net == SegDecModel.DECISION_NET_LOGISTIC: 52 | self.decision_net_fn = self.get_decision_net_simple 53 | elif decision_net == SegDecModel.DECISION_NET_FULL: 54 | self.decision_net_fn = self.get_decision_net 55 | 56 | self.use_corss_entropy_seg_net = use_corss_entropy_seg_net 57 | 58 | # this is only used when loading from a pre-trained segmentation-only network that did not 59 | # have the decision net layers present at the same time 60 | self.load_from_seg_only_net = load_from_seg_only_net 61 | 62 | 63 | 64 | 65 | def get_inference(self, inputs, num_classes, for_training=False, restore_logits=True, scope=None): 66 | """ Build model 67 | 68 | 69 | Args: 70 | inputs: Images returned from inputs() or distorted_inputs(). 71 | num_classes: number of classes 72 | for_training: If set to `True`, build the inference model for training. 73 | Kernels that operate differently for inference during training 74 | e.g. dropout, are appropriately configured. 75 | restore_logits: whether or not the logits layers should be restored. 76 | Useful for fine-tuning a model with different num_classes. 77 | scope: optional prefix string identifying the ImageNet tower. 78 | 79 | Returns: 80 | net_prob_mat: 4-D float Tensor with the segmentation score map. 81 | decision_net: output of the decision net (None when it is disabled), plus the endpoints dict. 82 | """ 83 | 84 | 85 | with variable_scope.variable_scope(scope, 'SegDecNet', [inputs]) as sc: 86 | end_points_collection = sc.original_name_scope + '_end_points' 87 | # Collect outputs for conv2d, max_pool2d 88 | with arg_scope( 89 | [layers.conv2d, layers.fully_connected, layers_lib.max_pool2d, layers.batch_norm], 90 | outputs_collections=end_points_collection): 91 | 92 | # Apply specific parameters to all conv2d layers (to use batch norm and relu - relu is by default) 93 | with arg_scope([layers.conv2d, layers.fully_connected], 94 | weights_initializer=lambda shape, dtype=tf.float32, partition_info=None: tf.random_normal(shape, mean=0, stddev=0.01, dtype=dtype), 95 | biases_initializer=None, 96 | normalizer_fn=layers.batch_norm, 97 | normalizer_params={'center': True, 98 | 'scale': True, 99 | #'is_training': for_training, # we disable this to do feature normalization (but requires batch size=1) 100 | 'decay': self.BATCHNORM_MOVING_AVERAGE_DECAY, # Decay for the moving averages. 101 | 'epsilon': 0.001, # epsilon to prevent 0s in variance.
102 | }): 103 | 104 | net = layers_lib.repeat(inputs, 2, layers.conv2d, 32, [5, 5], scope='conv1') 105 | 106 | net = layers_lib.max_pool2d(net, [2, 2], scope='pool1') 107 | 108 | net = layers_lib.repeat(net, 3, layers.conv2d, 64, [5, 5], scope='conv2') 109 | 110 | net = layers_lib.max_pool2d(net, [2, 2], scope='pool2') 111 | 112 | net = layers_lib.repeat(net, 4, layers.conv2d, 64, [5, 5], scope='conv3') 113 | 114 | net = layers_lib.max_pool2d(net, [2, 2], scope='pool3') 115 | 116 | net = layers.conv2d(net, 1024, [15, 15], padding='SAME', scope='conv4') 117 | 118 | net_prob_mat = layers.conv2d(net, 1, [1, 1], scope='conv5', 119 | activation_fn=None) 120 | 121 | decision_net = self.decision_net_fn(net, tf.nn.relu(net_prob_mat)) 122 | 123 | # Convert end_points_collection into an end_point dict. 124 | endpoints = utils.convert_collection_to_dict(end_points_collection) 125 | 126 | 127 | 128 | # Add summaries for viewing model statistics on TensorBoard. 129 | self._activation_summaries(endpoints) 130 | 131 | return net_prob_mat, decision_net, endpoints 132 | 133 | def get_decision_net_simple(self, net, net_prob_mat): 134 | 135 | avg_output = tf.reduce_mean(net_prob_mat, axis=[1, 2], keep_dims=True) # global average pool, kept 4-D for the 1x1 conv below 136 | max_output = tf.reduce_max(net_prob_mat, axis=[1, 2], keep_dims=True) # global max pool (the keras Global*Pooling2D layers used originally return 2-D tensors, which breaks the axis-3 concat below) 137 | 138 | decision_net = tf.concat([avg_output, max_output], 3) 139 | 140 | decision_net = layers.conv2d(decision_net, 1, [1, 1], scope='decision6', 141 | normalizer_fn=None, 142 | weights_initializer=initializers.xavier_initializer_conv2d(False), 143 | biases_initializer=tf.constant_initializer(0), 144 | activation_fn=None) 145 | 146 | return decision_net 147 | 148 | def get_decision_net(self, net, net_prob_mat): 149 | 150 | with tf.name_scope('decision'): 151 | 152 | decision_net = tf.concat([net, net_prob_mat], axis=3) 153 | 154 | decision_net = layers_lib.max_pool2d(decision_net, [2, 2], scope='decision/pool4') 155 | 156 | decision_net = layers.conv2d(decision_net, 8, [5, 5], padding='SAME', scope='decision/conv6') 157 | 158 | decision_net = layers_lib.max_pool2d(decision_net, [2, 2], scope='decision/pool5') 159 | 160 | decision_net = layers.conv2d(decision_net, 16, [5, 5], padding='SAME', scope='decision/conv7') 161 | 162 | decision_net = layers_lib.max_pool2d(decision_net, [2, 2], scope='decision/pool6') 163 | 164 | decision_net = layers.conv2d(decision_net, 32, [5, 5], scope='decision/conv8') 165 | 166 | with tf.name_scope('decision/global_avg_pool'): 167 | avg_decision_net = keras.layers.GlobalAveragePooling2D()(decision_net) 168 | 169 | with tf.name_scope('decision/global_max_pool'): 170 | max_decision_net = keras.layers.GlobalMaxPooling2D()(decision_net) 171 | 172 | with tf.name_scope('decision/global_avg_pool'): 173 | avg_prob_net = keras.layers.GlobalAveragePooling2D()(net_prob_mat) 174 | 175 | with tf.name_scope('decision/global_max_pool'): 176 | max_prob_net = keras.layers.GlobalMaxPooling2D()(net_prob_mat) 177 | 178 | # adding avg_prob_net and max_prob_net may not be needed, but it doesn't hurt 179 | decision_net = tf.concat([avg_decision_net, max_decision_net, avg_prob_net, max_prob_net], axis=1) 180 | 181 | decision_net = layers.fully_connected(decision_net, 1, scope='decision/FC9', 182 | normalizer_fn=None, 183 | biases_initializer=tf.constant_initializer(0), 184 | activation_fn=None) 185 | return decision_net 186 | 187 | 188 | def get_loss(self, net_model, masks, batch_size=None, return_segmentation_net=True, return_decision_net=True, output_resolution_reduction=8): 189 | """Adds all losses for the model.
190 | 191 | Note the final loss is not returned. Instead, the list of losses is collected 192 | by slim.losses. The losses are accumulated in tower_loss() and summed to 193 | calculate the total loss. 194 | 195 | Args: 196 | net_model: tuple of (net, decision_net, endpoints) as returned by 197 | get_inference(). 198 | masks: groundtruth masks; 4-D tensor of shape [batch_size, height, width, 1] 199 | batch_size: integer 200 | """ 201 | if not batch_size: 202 | raise Exception("Missing batch_size") 203 | 204 | net, decision_net, endpoints = net_model 205 | 206 | if output_resolution_reduction > 1: 207 | mask_blur_kernel = [output_resolution_reduction*2+1, output_resolution_reduction*2+1] 208 | masks = layers_lib.avg_pool2d(masks, mask_blur_kernel, stride=output_resolution_reduction, padding='SAME', scope='pool_mask', outputs_collections='tower_0/_end_points') 209 | 210 | if self.use_corss_entropy_seg_net is False: 211 | masks = tf.greater(masks, tf.constant(0.5)) 212 | 213 | 214 | predictions = net 215 | 216 | tf.summary.image('prediction', predictions) 217 | 218 | l1 = None 219 | l2 = None 220 | 221 | if return_segmentation_net: 222 | if self.positive_weight > 1: 223 | pos_pixels = tf.less(tf.constant(0.0), masks) 224 | neg_pixels = tf.greater_equal(tf.constant(0.0), masks) 225 | 226 | num_pos_pixels = tf.cast(tf.count_nonzero(pos_pixels), dtype=tf.float32) 227 | num_neg_pixels = tf.cast(tf.count_nonzero(neg_pixels), dtype=tf.float32) 228 | 229 | pos_pixels = tf.cast(pos_pixels, dtype=tf.float32) 230 | neg_pixels = tf.cast(neg_pixels, dtype=tf.float32) 231 | 232 | positive_weight = tf.cond(num_pos_pixels > tf.constant(0, dtype=tf.float32), 233 | lambda: tf.multiply(tf.div(num_neg_pixels, num_pos_pixels), 234 | tf.constant(self.positive_weight, dtype=tf.float32)), 235 | lambda: tf.constant(self.positive_weight, dtype=tf.float32)) 236 | 237 | positive_weight = tf.reshape(positive_weight, [1]) 238 | 239 | # weight positive samples more !! 240 | weights = tf.add(neg_pixels, tf.multiply(pos_pixels, positive_weight)) 241 | 242 | # normalize weights so that the sum of weights is always equal to the num of elements 243 | N = tf.constant(weights.shape[1]._value * weights.shape[2]._value, dtype=tf.float32) 244 | 245 | factor = tf.reduce_sum(weights, axis=[1, 2]) 246 | factor = tf.divide(N, factor) 247 | 248 | weights = tf.multiply(weights, tf.reshape(factor, [-1, 1, 1, 1])) 249 | 250 | if self.use_corss_entropy_seg_net is False: 251 | l1 = tf.losses.mean_squared_error(masks, predictions, weights=weights) 252 | else: 253 | l1 = tf.losses.sigmoid_cross_entropy(logits=predictions, multi_class_labels=masks, weights=weights) # NOTE: weights were added but not tested yet !!
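# (when self.positive_weight <= 1 no per-pixel re-weighting is applied and the plain losses below are used)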
254 | else: 255 | if self.use_corss_entropy_seg_net is False: 256 | l1 = tf.losses.mean_squared_error(masks, predictions) 257 | else: 258 | l1 = tf.losses.sigmoid_cross_entropy(logits=predictions, multi_class_labels=masks) 259 | 260 | 261 | if return_decision_net: 262 | with tf.name_scope('decision'): 263 | masks = tf.cast(masks, tf.float32) 264 | label = tf.minimum(tf.reduce_sum(masks, [1, 2, 3]), tf.constant(1.0)) 265 | 266 | if len(decision_net.shape) == 2: 267 | decision_net = tf.squeeze(decision_net, [1]) 268 | elif len(decision_net.shape) == 4: 269 | decision_net = tf.squeeze(decision_net, [1, 2, 3]) 270 | else: 271 | raise Exception("Only 2 or 4 dimensional output expected for decision_net") 272 | 273 | decision_net = tf.reshape(decision_net, [-1, 1]) 274 | label = tf.reshape(label, [-1, 1]) 275 | 276 | l2 = tf.losses.sigmoid_cross_entropy(logits=decision_net, multi_class_labels=label, weights=self.decision_positive_weight) 277 | 278 | return [l1, l2] 279 | 280 | 281 | 282 | 283 | 284 | def _activation_summary(self, x): 285 | """Helper to create summaries for activations. 286 | 287 | Creates a summary that provides a histogram of activations. 288 | Creates a summary that measures the sparsity of activations. 289 | 290 | Args: 291 | x: Tensor 292 | """ 293 | # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training 294 | # session. This helps the clarity of presentation on tensorboard. 295 | tensor_name = re.sub('%s_[0-9]*/' % self.TOWER_NAME, '', x.op.name) 296 | tf.summary.histogram(tensor_name + '/activations', x) 297 | tf.summary.scalar(tensor_name + '/sparsity', tf.nn.zero_fraction(x)) 298 | 299 | 300 | def _activation_summaries(self, endpoints): 301 | with tf.name_scope('summaries'): 302 | for act in endpoints.values(): 303 | self._activation_summary(act) 304 | 305 | 306 | def restore(self, session, model_checkpoint_path, variables_to_restore = None, load_from_seg_only_net=False): 307 | 308 | if variables_to_restore is None: 309 | variables_to_restore = tf.trainable_variables() # + tf.moving_average_variables() -- required only in TF r1.1 310 | 311 | # this is only used when loading from a pre-trained segmentation-only network that did not 312 | # have the decision net layers present at the same time 313 | if load_from_seg_only_net: 314 | variables_to_restore = [v for v in variables_to_restore if v.name.count('decision') == 0] 315 | 316 | saver = tf.train.Saver(variables_to_restore) 317 | try: 318 | saver.restore(session, model_checkpoint_path) 319 | 320 | except: # fall back to restoring without the decision-net variables 321 | # remove decision variables if cannot load them 322 | if type(variables_to_restore) is dict: 323 | variables_to_restore = [variables_to_restore[v] for v in variables_to_restore.keys() if v.find('decision') < 0] 324 | else: 325 | variables_to_restore = [v for v in variables_to_restore if v.name.find('decision') < 0] 326 | 327 | saver = tf.train.Saver(variables_to_restore) 328 | 329 | saver.restore(session, model_checkpoint_path) 330 | -------------------------------------------------------------------------------- /segdec_print_eval.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | 4 | from sklearn.metrics import precision_recall_curve, roc_curve, auc, average_precision_score 5 | 6 | def calc_confusion_mat(D, Y): 7 | FP = (D != Y) & (Y.astype(np.bool) == False) 8 | FN = (D != Y) & (Y.astype(np.bool) == True) 9 | TN = (D == Y) & (Y.astype(np.bool) == False) 10 | TP = (D == Y) & (Y.astype(np.bool) == True) 11 
| 12 | return FP, FN, TN, TP 13 | 14 | def get_performance_eval(P, Y): 15 | precision_, recall_, thresholds = precision_recall_curve(Y.astype(np.int32), P) 16 | FPR, TPR, _ = roc_curve(Y.astype(np.int32), P) 17 | AUC = auc(FPR, TPR) 18 | AP = average_precision_score(Y.astype(np.int32), P) 19 | 20 | f_measure = 2 * (precision_ * recall_) / (precision_ + recall_ + 0.0000000001) 21 | 22 | best_idx = np.argmax(f_measure) 23 | 24 | # pick the threshold that maximizes the F-measure 25 | thr = thresholds[best_idx] 26 | 27 | FP, FN, TN, TP = calc_confusion_mat(P >= thr, Y) 28 | 29 | FP_, FN_, TN_, TP_ = calc_confusion_mat(P >= thresholds[np.where(recall_ >= 1)], Y) # threshold at recall == 1, i.e., at FN = 0 30 | 31 | F_measure = (2 * TP.sum()) / float(2 * TP.sum() + FP.sum() + FN.sum()) 32 | 33 | return TP, FP, FN, TN, TP_, FP_, FN_, TN_, F_measure, AUC, AP 34 | 35 | def evaluate_decision(data_dir, folds_list = [0,1,2]): 36 | 37 | PD_decision_net = None 38 | 39 | num_params_list = [] 40 | 41 | for f in folds_list: 42 | if f >= 0: 43 | fold_name = 'fold_%d' % f 44 | else: 45 | fold_name = '' 46 | 47 | sample_outcomes = np.load(os.path.join(data_dir, fold_name, 'test', 'results_decision_net.npy')) 48 | 49 | if len(sample_outcomes) > 0: 50 | PD_decision_net = np.concatenate((PD_decision_net, sample_outcomes)) if PD_decision_net is not None else sample_outcomes 51 | 52 | num_params_filename = os.path.join(data_dir, fold_name, 'test', 'decision_net_num_params.npy') 53 | if os.path.exists(num_params_filename): 54 | num_params_list.append(np.load(num_params_filename)) 55 | 56 | results = None 57 | 58 | if PD_decision_net is not None: 59 | 60 | TP, FP, FN, TN, TP_, FP_, FN_, TN_, F_measure, AUC, AP = get_performance_eval(PD_decision_net[:,0], PD_decision_net[:,1]) 61 | 62 | print("AP: %.03f, FP/FN: %d/%d, FP@FN=0: %d" % (AP, FP.sum(), FN.sum(), FP_.sum())) 63 | 64 | results = {'TP': TP.sum(), 65 | 'FP': FP.sum(), 66 | 'FN': FN.sum(), 67 | 'TN': TN.sum(), 68 | 'FP@FN=0': FP_.sum(), 69 | 'f-measure': F_measure, 70 | 'AUC': AUC, 71 | 'AP': AP} 72 | 73 | return results 74 | 75 | 76 | if __name__ == "__main__": 77 | 78 | evaluate_decision(sys.argv[1], folds_list = [0,1,2]) 79 | 80 | -------------------------------------------------------------------------------- /segdec_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import copy 6 | import math 7 | import os.path 8 | import re 9 | import time 10 | from datetime import datetime 11 | 12 | import numpy as np 13 | import pylab as plt 14 | from tensorflow.contrib import slim 15 | 16 | from input_data.image_processing import NetInputProcessing 17 | 18 | 19 | class SegDecTrain(object): 20 | # Constants dictating the learning rate schedule. 21 | RMSPROP_DECAY = 0.9 # Decay term for RMSProp. 22 | RMSPROP_MOMENTUM = 0.9 # Momentum in RMSProp. 23 | RMSPROP_EPSILON = 1.0 # Epsilon term for RMSProp.
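# (NOTE: the RMSProp constants above are inherited from the original Inception trainer;
# the optimizer constructed in train() below is a plain GradientDescentOptimizer,
# so they remain unused unless the optimizer is swapped back to RMSProp)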
24 | 25 | def __init__(self, model, storage_dir, run_string, image_size, batch_size, 26 | learning_rate = 0.01, 27 | max_epochs = 1000, 28 | max_steps = 10000000, 29 | num_gpus = 1, 30 | visible_device_list = None, 31 | num_preprocess_threads = 1, 32 | pretrained_model_checkpoint_path = None, 33 | train_segmentation_net = True, 34 | train_decision_net = False, 35 | use_random_rotation=False, 36 | ensure_posneg_balance=True): 37 | 38 | self.model = model 39 | 40 | run_train_string = run_string[0] if type(run_string) is tuple else run_string 41 | run_eval_string = run_string[1] if type(run_string) is tuple else run_string 42 | 43 | self.visible_device_list = visible_device_list 44 | self.batch_size = batch_size 45 | self.train_dir = os.path.join(storage_dir, 'segdec_train', run_train_string) # Directory where to write event logs and checkpoints. 46 | self.eval_dir = os.path.join(storage_dir, 'segdec_eval', run_eval_string) 47 | 48 | # The number of learning batch iterations is min(self.max_steps, self.max_epochs * num_batches_per_epoch) 49 | self.max_steps = max_steps # Number of batches to run. 50 | self.max_epochs = max_epochs # Number of epochs to run. 51 | 52 | # Flags governing the hardware employed for running TensorFlow. 53 | self.num_gpus = num_gpus # How many GPUs to use. 54 | self.log_device_placement = False # Whether to log device placement. 55 | 56 | self.num_preprocess_threads = num_preprocess_threads 57 | # Flags governing the type of training. 58 | self.fine_tune = False # If set, randomly initialize the final layer of weights in order to train the network on a new task. 59 | self.pretrained_model_checkpoint_path = pretrained_model_checkpoint_path # If specified, restore this pretrained model before beginning any training. 60 | 61 | self.initial_learning_rate = learning_rate # Initial learning rate. 62 | self.decay_steps = 0 # no decay by default 63 | self.learning_rate_decay_factor = 1 64 | 65 | self.TOWER_NAME = "tower" 66 | 67 | # Batch normalization. Constant governing the exponential moving average of 68 | # the 'global' mean and variance for all activations. 69 | self.BATCHNORM_MOVING_AVERAGE_DECAY = 0.9997 70 | 71 | # The decay to use for the moving average. 72 | self.MOVING_AVERAGE_DECAY = 0.9999 73 | 74 | # Override the number of preprocessing threads to account for the increased 75 | # number of GPU towers. 76 | input_num_preprocess_threads = self.num_preprocess_threads * self.num_gpus 77 | 78 | self.input = NetInputProcessing(batch_size=self.batch_size, 79 | num_preprocess_threads=input_num_preprocess_threads, 80 | input_size=image_size, 81 | mask_size=(image_size[0], image_size[1], 1), 82 | use_random_rotation=use_random_rotation, 83 | ensure_posneg_balance=ensure_posneg_balance) 84 | 85 | self.train_segmentation_net = train_segmentation_net 86 | self.train_decision_net = train_decision_net 87 | 88 | assert self.batch_size == 1, "Only batch_size=1 is allowed due to the way the batch_norm is used to normalize features in testing !!!" 89 | 90 | self.loss_print_step = 11 91 | self.summary_step = 110 92 | self.checkpoint_step = 10007 93 | 94 | def _tower_loss(self, images, masks, num_classes, scope, reuse_variables=None): 95 | """Calculate the total loss on a single tower running the ImageNet model. 96 | 97 | We perform 'batch splitting'. This means that we cut up a batch across 98 | multiple GPUs. For instance, if the batch size = 32 and num_gpus = 2, 99 | then each tower will operate on a batch of 16 images. 100 | 101 | Args: 102 | images: Images.
4D tensor of size [batch_size, FLAGS.image_size, 103 | FLAGS.image_size, 3]. 104 | masks: groundtruth masks; 4-D tensor of shape [batch_size, height, width, 1]. 105 | num_classes: number of classes 106 | scope: unique prefix string identifying the ImageNet tower, e.g. 107 | 'tower_0'. 108 | 109 | Returns: 110 | Tensor of shape [] containing the total loss for a batch of data 111 | """ 112 | # When fine-tuning a model, we do not restore the logits but instead we 113 | # randomly initialize the logits. The number of classes in the output of the 114 | # logit is the number of classes in specified Dataset. 115 | restore_logits = not self.fine_tune 116 | 117 | # Build inference Graph. 118 | with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables): 119 | net_model = self.model.get_inference(images, num_classes, for_training=True, 120 | restore_logits=restore_logits, 121 | scope=scope) 122 | 123 | # Build the portion of the Graph calculating the losses. Note that we will 124 | # assemble the total_loss using a custom function below. 125 | split_batch_size = images.get_shape().as_list()[0] 126 | self.model.get_loss(net_model, masks, 127 | batch_size=split_batch_size, 128 | return_segmentation_net=self.train_segmentation_net, 129 | return_decision_net=self.train_decision_net) 130 | 131 | # Assemble all of the losses for the current tower only. 132 | losses = tf.get_collection(tf.GraphKeys.LOSSES, scope) 133 | 134 | # Calculate the total loss for the current tower. 135 | regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) 136 | total_loss = tf.add_n(losses + regularization_losses, name='total_loss') 137 | 138 | # Compute the moving average of all individual losses and the total loss. 139 | loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') 140 | loss_averages_op = loss_averages.apply(losses + [total_loss]) 141 | 142 | # Attach a scalar summary to all individual losses and the total loss; do the 143 | # same for the averaged version of the losses. 144 | for l in losses + [total_loss]: 145 | # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training 146 | # session. This helps the clarity of presentation on TensorBoard. 147 | loss_name = re.sub('%s_[0-9]*/' % self.TOWER_NAME, '', l.op.name) 148 | # Name each loss as '(raw)' and name the moving average version of the loss 149 | # as the original loss name. 150 | tf.summary.scalar(loss_name + ' (raw)', l) 151 | tf.summary.scalar(loss_name, loss_averages.average(l)) 152 | 153 | with tf.control_dependencies([loss_averages_op]): 154 | total_loss = tf.identity(total_loss) 155 | return total_loss 156 | 157 | 158 | def _average_gradients(self, tower_grads): 159 | """Calculate the average gradient for each shared variable across all towers. 160 | 161 | Note that this function provides a synchronization point across all towers. 162 | 163 | Args: 164 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 165 | is over individual gradients. The inner list is over the gradient 166 | calculation for each tower. 167 | Returns: 168 | List of pairs of (gradient, variable) where the gradient has been averaged 169 | across all towers. 170 | """ 171 | average_grads = [] 172 | for grad_and_vars in zip(*tower_grads): 173 | # Note that each grad_and_vars looks like the following: 174 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 175 | grads = [] 176 | for g, _ in grad_and_vars: 177 | # Add 0 dimension to the gradients to represent the tower.
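# (each tower contributes one (gradient, variable) pair per shared variable; stacking
# the tower gradients along a new axis 0 lets the reduce_mean below compute
# grad = (1/num_towers) * sum_i grad_i for every variable)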
178 | expanded_g = tf.expand_dims(g, 0) 179 | 180 | # Append on a 'tower' dimension which we will average over below. 181 | grads.append(expanded_g) 182 | 183 | # Average over the 'tower' dimension. 184 | grad = tf.concat(axis=0, values=grads) 185 | grad = tf.reduce_mean(grad, 0) 186 | 187 | # Keep in mind that the Variables are redundant because they are shared 188 | # across towers. So .. we will just return the first tower's pointer to 189 | # the Variable. 190 | v = grad_and_vars[0][1] 191 | grad_and_var = (grad, v) 192 | average_grads.append(grad_and_var) 193 | return average_grads 194 | 195 | 196 | def train(self, dataset): 197 | """Train on input_data for a number of steps.""" 198 | with tf.Graph().as_default(), tf.device('/cpu:0'): 199 | # Create a variable to count the number of train() calls. This equals the 200 | # number of batches processed * FLAGS.num_gpus. 201 | global_step = tf.get_variable( 202 | 'global_step', [], 203 | initializer=tf.constant_initializer(0), trainable=False) 204 | 205 | # Calculate the learning rate schedule. 206 | 207 | # Decay the learning rate exponentially based on the number of steps. 208 | if self.decay_steps > 0: 209 | lr = tf.train.exponential_decay(self.initial_learning_rate, 210 | global_step, 211 | self.decay_steps, 212 | self.learning_rate_decay_factor, 213 | staircase=True) 214 | else: 215 | lr = self.initial_learning_rate 216 | 217 | # Create an optimizer that performs gradient descent. 218 | opt = tf.train.GradientDescentOptimizer(lr) 219 | 220 | # Get images and labels for ImageNet and split the batch across GPUs. 221 | assert self.batch_size % self.num_gpus == 0, ( 222 | 'Batch size must be divisible by number of GPUs') 223 | 224 | images, masks, _ = self.input.add_inputs_nodes(dataset, True) 225 | 226 | 227 | input_summaries = copy.copy(tf.get_collection(tf.GraphKeys.SUMMARIES)) 228 | 229 | # Number of classes in the Dataset label set plus 1. 230 | # Label 0 is reserved for an (unused) background class. 231 | num_classes = dataset.num_classes() + 1 232 | 233 | # Split the batch of images and labels for towers. 234 | images_splits = tf.split(axis=0, num_or_size_splits=self.num_gpus, value=images) 235 | masks_splits = tf.split(axis=0, num_or_size_splits=self.num_gpus, value=masks) 236 | 237 | # Calculate the gradients for each model tower. 238 | tower_grads = [] 239 | reuse_variables = None 240 | for i in range(self.num_gpus): 241 | with tf.device('/gpu:%d' % i): 242 | with tf.name_scope('%s_%d' % (self.TOWER_NAME, i)) as scope: 243 | # Force all Variables to reside on the CPU. 244 | with slim.arg_scope([slim.variable], device='/cpu:0'): 245 | # Calculate the loss for one tower of the ImageNet model. This 246 | # function constructs the entire ImageNet model but shares the 247 | # variables across all towers. 248 | loss = self._tower_loss(images_splits[i], masks_splits[i], num_classes, 249 | scope, reuse_variables) 250 | 251 | # Reuse variables for the next tower. 252 | reuse_variables = True 253 | 254 | # Retain the summaries from the final tower. 255 | summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope) 256 | 257 | # Retain the Batch Normalization updates operations only from the 258 | # final tower. Ideally, we should grab the updates from all towers 259 | # but these stats accumulate extremely fast so we can ignore the 260 | # other stats from the other towers without significant detriment. 
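# (layers.batch_norm registers its moving-mean/variance update ops in
# tf.GraphKeys.UPDATE_OPS; they are grouped into the final train op further below)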
261 | batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope) 262 | 263 | # Calculate the gradients for the batch of data on this ImageNet 264 | # tower. 265 | grads = opt.compute_gradients(loss) 266 | 267 | # Keep track of the gradients across all towers. 268 | tower_grads.append(grads) 269 | 270 | variables_to_average = (tf.trainable_variables() + 271 | tf.moving_average_variables()) 272 | 273 | # if decision_net is not trained then remove all gradients for the decision net 274 | if self.train_decision_net is False: 275 | tower_grads = [[g for g in tg if g[1].name.find('decision') < 0] for tg in tower_grads] 276 | 277 | variables_to_average = [v for v in variables_to_average if v.name.find('decision') < 0] 278 | 279 | # if segmentation_net is not trained then remove all gradients for the segmentation net, 280 | # i.e. we assume all variables NOT flagged as decision net are segmentation net 281 | if self.train_segmentation_net is False: 282 | tower_grads = [[g for g in tg if g[1].name.find('decision') >= 0] for tg in tower_grads] 283 | 284 | # We must calculate the mean of each gradient. Note that this is the 285 | # synchronization point across all towers. 286 | grads = self._average_gradients(tower_grads) 287 | 288 | # Apply the gradients to adjust the shared variables. 289 | apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) 290 | 291 | # Track the moving averages of all trainable variables. 292 | # Note that we maintain a "double-average" of the BatchNormalization 293 | # global statistics. This is more complicated than it needs to be but we employ 294 | # this for backward-compatibility with our previous models. 295 | variable_averages = tf.train.ExponentialMovingAverage(self.MOVING_AVERAGE_DECAY, global_step) 296 | 297 | # Another possibility is to use tf.slim.get_variables(). 298 | variables_averages_op = variable_averages.apply(variables_to_average) 299 | 300 | # Group all updates into a single train op. 301 | batchnorm_updates_op = tf.group(*batchnorm_updates) 302 | train_op = tf.group(apply_gradient_op, variables_averages_op, 303 | batchnorm_updates_op) 304 | 305 | # Add summaries and visualization 306 | 307 | 308 | # Add histograms for trainable variables. 309 | for var in tf.trainable_variables(): 310 | summaries.append(tf.summary.histogram(var.op.name, var)) 311 | 312 | # Add weight visualization 313 | weight_variables = [v for v in tf.global_variables() if v.name.find('/weights') >= 0] 314 | 315 | for c in ['conv1_1','conv1_2', 316 | 'conv2_1', 'conv2_2', 'conv2_3', 317 | 'conv3_1', 'conv3_2', 'conv3_3', 'conv3_4']: 318 | with tf.name_scope(c): 319 | w = [v for v in weight_variables if v.name.find('/' + c + '/') >= 0] 320 | w = w[0] 321 | 322 | x_min = tf.reduce_min(w) 323 | x_max = tf.reduce_max(w) 324 | ww = (w - x_min) / (x_max - x_min) 325 | 326 | ww_t = tf.transpose(ww, [3, 0, 1, 2]) 327 | ww_t = tf.reshape(ww_t[:,:,:,0], [int(ww_t.shape[0]), int(ww_t.shape[1]), int(ww_t.shape[2]), 1]) 328 | tf.summary.image(c, ww_t, max_outputs=10) 329 | 330 | summaries.extend(tf.get_collection(tf.GraphKeys.SUMMARIES, c)) 331 | 332 | # Add summaries for the input processing and global_step. 333 | summaries.extend(input_summaries) 334 | 335 | # Add a summary to track the learning rate. 336 | summaries.append(tf.summary.scalar('learning_rate', lr)) 337 | 338 | # Add histograms for gradients.
339 | for grad, var in grads: 340 | if grad is not None: 341 | summaries.append( 342 | tf.summary.histogram(var.op.name + '/gradients', grad)) 343 | 344 | summaries = tf.get_collection(tf.GraphKeys.SUMMARIES) 345 | # Create a saver. 346 | saver = tf.train.Saver(tf.global_variables()) 347 | 348 | # Build the summary operation from the last tower summaries. 349 | summary_op = tf.summary.merge(summaries) 350 | 351 | 352 | # Build an initialization operation to run below. 353 | init = tf.global_variables_initializer() 354 | 355 | # Start running operations on the Graph. allow_soft_placement must be set to 356 | # True to build towers on GPU, as some of the ops do not have GPU 357 | # implementations. 358 | c = tf.ConfigProto(allow_soft_placement=True, 359 | log_device_placement=self.log_device_placement) 360 | if self.visible_device_list is not None: 361 | c.gpu_options.visible_device_list = self.visible_device_list 362 | c.gpu_options.allow_growth = True 363 | 364 | sess = tf.Session(config=c) 365 | sess.run(init) 366 | 367 | # restore weights from previous model 368 | if self.pretrained_model_checkpoint_path is not None: 369 | ckpt = tf.train.get_checkpoint_state(self.pretrained_model_checkpoint_path) 370 | if ckpt is None: 371 | raise Exception('No valid saved model found in ' + self.pretrained_model_checkpoint_path) 372 | 373 | self.model.restore(sess, ckpt.model_checkpoint_path) 374 | 375 | # Start the queue runners. 376 | tf.train.start_queue_runners(sess=sess) 377 | 378 | summary_writer = tf.summary.FileWriter( 379 | self.train_dir, 380 | graph=sess.graph) 381 | 382 | num_steps = min(int(self.max_epochs * dataset.num_examples_per_epoch() / self.batch_size), 383 | self.max_steps) 384 | 385 | prev_duration = None 386 | 387 | for step in range(num_steps): 388 | 389 | run_nodes = [train_op, loss] 390 | 391 | if step % self.summary_step == 0: 392 | run_nodes = [train_op, loss, summary_op] 393 | 394 | start_time = time.time() 395 | output_vals = sess.run(run_nodes) 396 | duration = time.time() - start_time 397 | 398 | if prev_duration is None: 399 | prev_duration = duration 400 | 401 | loss_value = output_vals[1] 402 | 403 | assert not np.isnan(loss_value), 'Model diverged with loss = NaN' 404 | 405 | if step % self.loss_print_step == 0: 406 | examples_per_sec = self.batch_size / float(prev_duration) 407 | format_str = ('%s: step %d, loss = %.5f (%.1f examples/sec; %.3f ' 408 | 'sec/batch)') 409 | print(format_str % (datetime.now(), step, loss_value, 410 | examples_per_sec, prev_duration)) 411 | 412 | if step % self.summary_step == 0: 413 | summary_str = output_vals[2] 414 | summary_writer.add_summary(summary_str, step) 415 | 416 | # Save the model checkpoint periodically. 417 | if step % self.checkpoint_step == 0 or (step + 1) == num_steps: 418 | checkpoint_path = os.path.join(self.train_dir, 'model.ckpt') 419 | saver.save(sess, checkpoint_path, global_step=step) 420 | 421 | prev_duration = duration 422 | 423 | def _eval_once(self, eval_dir, variables_to_restore, net_op, decision_op, images_op, labels_op, img_names_op, num_examples, plot_results=True): 424 | """Runs Eval once. 425 | 426 | Args: 427 | eval_dir: folder where the evaluation results are written 428 | variables_to_restore: list/dict of variables to restore from the checkpoint 429 | net_op: net operation with the segmentation prediction 430 | decision_op: decision net operation with the per-image score (may be None)
431 | """ 432 | c = tf.ConfigProto() 433 | if self.visible_device_list is not None: 434 | c.gpu_options.visible_device_list = self.visible_device_list 435 | c.gpu_options.allow_growth = True 436 | with tf.Session(config=c) as sess: 437 | ckpt = tf.train.get_checkpoint_state(self.train_dir) 438 | if ckpt and ckpt.model_checkpoint_path: 439 | 440 | model_checkpoint_path = ckpt.model_checkpoint_path 441 | 442 | # Restore from checkpoint; prepend train_dir when the stored path is relative. 443 | if not os.path.isabs(model_checkpoint_path): 444 | model_checkpoint_path = os.path.join(self.train_dir, model_checkpoint_path) 445 | 446 | self.model.restore(sess, model_checkpoint_path, variables_to_restore) 447 | 448 | # Assuming model_checkpoint_path looks something like: 449 | # /my-favorite-path/imagenet_train/model.ckpt-0, 450 | # extract global_step from it. 451 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 452 | print('Successfully loaded model from %s at step=%s.' % 453 | (ckpt.model_checkpoint_path, global_step)) 454 | else: 455 | print('No checkpoint file found') 456 | return 457 | 458 | # Start the queue runners. 459 | coord = tf.train.Coordinator() 460 | try: 461 | threads = [] 462 | for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): 463 | threads.extend(qr.create_threads(sess, coord=coord, daemon=True, 464 | start=True)) 465 | 466 | num_iter = int(math.ceil(num_examples / self.batch_size)) 467 | 468 | # Counts the number of correct predictions. 469 | samples_outcome = [] 470 | samples_names = [] 471 | samples_speed_eval = [] 472 | 473 | total_sample_count = num_iter * self.batch_size 474 | step = 0 475 | 476 | print('%s: starting evaluation on (%s).' % (datetime.now(), '')) 477 | start_time = time.time() 478 | while step < num_iter and not coord.should_stop(): 479 | start_time_run = time.time() 480 | if decision_op is None: 481 | predictions, image, label, img_name = sess.run([net_op, images_op, labels_op, img_names_op]) 482 | else: 483 | predictions, decision, image, label, img_name = sess.run([net_op, decision_op, images_op, labels_op, img_names_op]) 484 | 485 | decision = 1.0/(1+np.exp(-np.squeeze(decision))) # sigmoid of the decision-net logit 486 | 487 | # if we use the sigmoid cross-entropy loss, then we need to apply a sigmoid to the predictions, 488 | # since this is normally handled by the loss, which is not used in inference 489 | if self.model.use_corss_entropy_seg_net: 490 | predictions = 1.0/(1+np.exp(-predictions)) 491 | 492 | end_time_run = time.time() 493 | 494 | name = str(img_name[0]).replace("/", "_") 495 | samples_names.append(name) 496 | 497 | np.save(str.format("{0}/result_{2}.npy", eval_dir, step, name), predictions) 498 | np.save(str.format("{0}/result_{2}_gt.npy", eval_dir, step, name), label) 499 | 500 | if plot_results: 501 | plt.figure(1) 502 | plt.clf() 503 | plt.subplot(1, 3, 1) 504 | plt.title('Input image') 505 | plt.imshow(image[0, :, :, 0], cmap="gray") 506 | 507 | plt.subplot(1, 3, 2) 508 | plt.title('Groundtruth') 509 | plt.imshow(label[0, :, :, 0], cmap="gray") 510 | 511 | plt.subplot(1, 3, 3) 512 | if decision_op is None: 513 | plt.title('Output/prediction') 514 | else: 515 | plt.title(str.format('Output/prediction: {0}', decision)) 516 | 517 | # display max 518 | vmax_value = max(1, predictions.max()) 519 | 520 | plt.imshow((predictions[0, :, :, 0] > 0) * predictions[0, :, :, 0], cmap="jet", vmax=vmax_value) 521 | plt.suptitle(str(img_name[0])) 522 | 523 | plt.show(block=0) 524 | 525 | out_prefix = '' 526 | 527 | if decision_op is not None: 528 | out_prefix = '%.3f_' % decision 529 | 530 | 
plt.savefig(str.format("{0}/{1}result_{2}.pdf", eval_dir, out_prefix, name), bbox_inches='tight') 531 | 532 | samples_speed_eval.append(end_time_run - start_time_run) 533 | 534 | if decision_op is None: 535 | pass 536 | else: 537 | samples_outcome.append((decision, np.max(label))) 538 | 539 | step += 1 540 | if step % 20 == 0: 541 | duration = time.time() - start_time 542 | sec_per_batch = duration / 20.0 543 | examples_per_sec = self.batch_size / sec_per_batch 544 | print('%s: [%d batches out of %d] (%.1f examples/sec; %.3f' 545 | 'sec/batch)' % (datetime.now(), step, num_iter, 546 | examples_per_sec, sec_per_batch)) 547 | start_time = time.time() 548 | 549 | if len(samples_outcome) > 0: 550 | from sklearn.metrics import precision_recall_curve, roc_curve, auc 551 | 552 | samples_outcome = np.matrix(np.array(samples_outcome)) 553 | 554 | idx = np.argsort(samples_outcome[:,0],axis=0) 555 | idx = idx[::-1] 556 | samples_outcome = np.squeeze(samples_outcome[idx, :]) 557 | samples_names = np.array(samples_names) 558 | samples_names = samples_names[idx] 559 | 560 | np.save(str.format("{0}/samples_outcome.npy", eval_dir), samples_outcome) 561 | np.save(str.format("{0}/samples_names.npy", eval_dir), samples_names) 562 | 563 | P = np.sum(samples_outcome[:, 1]) 564 | 565 | TP = np.cumsum(samples_outcome[:, 1] == 1).astype(np.float32).T 566 | FP = np.cumsum(samples_outcome[:, 1] == 0).astype(np.float32).T 567 | 568 | recall = TP / P 569 | precision = TP / (TP + FP) 570 | 571 | f_measure = 2 * np.multiply(recall, precision) / (recall + precision) 572 | 573 | 574 | idx = np.argmax(f_measure) 575 | 576 | best_f_measure = f_measure[idx] 577 | best_thr = samples_outcome[idx,0] 578 | best_FP = FP[idx] 579 | best_FN = P - TP[idx] 580 | 581 | precision_, recall_, thresholds = precision_recall_curve(samples_outcome[:, 1], samples_outcome[:, 0]) 582 | FPR, TPR, _ = roc_curve(samples_outcome[:, 1], samples_outcome[:, 0]) 583 | AUC = auc(FPR,TPR) 584 | AP = auc(recall_, precision_) 585 | 586 | print('AUC=%f, and AP=%f, with best thr=%f at f-measure=%.3f and FP=%d, FN=%d' % (AUC, AP, best_thr, best_f_measure, best_FP, best_FN)) 587 | 588 | plt.figure(1) 589 | plt.clf() 590 | plt.plot(recall, precision) 591 | plt.title('Average Precision=%.4f' % AP) 592 | plt.xlabel('Recall') 593 | plt.ylabel('Precision') 594 | plt.savefig(str.format("{0}/precision-recall.pdf", eval_dir), bbox_inches='tight') 595 | 596 | plt.figure(1) 597 | plt.clf() 598 | plt.plot(FPR, TPR) 599 | plt.title('AUC=%.4f' % AUC) 600 | plt.xlabel('False positive rate') 601 | plt.ylabel('True positive rate') 602 | plt.savefig(str.format("{0}/ROC.pdf", eval_dir), bbox_inches='tight') 603 | 604 | 605 | 606 | 607 | except Exception as e: # pylint: disable=broad-except 608 | coord.request_stop(e) 609 | 610 | coord.request_stop() 611 | coord.join(threads, stop_grace_period_secs=10) 612 | 613 | return samples_outcome,samples_names, samples_speed_eval 614 | 615 | def evaluate(self, dataset, run_once = True, eval_interval_secs = 5, plot_results=True): 616 | """Evaluate model on Dataset for a number of steps.""" 617 | with tf.Graph().as_default(): 618 | # Get images and labels from the input_data. 619 | images, labels, img_names = self.input.add_inputs_nodes(dataset, False) 620 | 621 | # Number of classes in the Dataset label set plus 1. 622 | # Label 0 is reserved for an (unused) background class. 623 | num_classes = dataset.num_classes() + 1 624 | 625 | # Build a Graph that computes the logits predictions from the 626 | # inference model. 
627 | with tf.name_scope('%s_%d' % (self.TOWER_NAME, 0)) as scope: 628 | net, decision, _ = self.model.get_inference(images, num_classes, scope=scope) 629 | 630 | # Restore the moving average version of the learned variables for eval. 631 | variable_averages = tf.train.ExponentialMovingAverage(self.model.MOVING_AVERAGE_DECAY) 632 | variables_to_restore = variable_averages.variables_to_restore() 633 | 634 | eval_dir = os.path.join(self.eval_dir, dataset.subset) 635 | try: 636 | os.makedirs(eval_dir) 637 | except: 638 | pass 639 | 640 | while True: 641 | samples_outcome,samples_names, samples_speed_eval = self._eval_once(eval_dir, variables_to_restore, net, decision, images, labels, img_names, dataset.num_examples_per_epoch(),plot_results) 642 | if run_once: 643 | break 644 | time.sleep(eval_interval_secs) 645 | 646 | num_params = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]) 647 | 648 | return samples_outcome,samples_names, samples_speed_eval,num_params 649 | 650 | 651 | import tensorflow as tf 652 | 653 | from segdec_model import SegDecModel 654 | from segdec_data import InputData 655 | 656 | if __name__ == '__main__': 657 | 658 | import argparse, glob, shutil 659 | 660 | def str2bool(v): 661 | return v.lower() in ("yes", "true", "t", "1") 662 | 663 | parser = argparse.ArgumentParser() 664 | 665 | # add boolean parser to allow using 'false' in arguments 666 | parser.register('type', 'bool', str2bool) 667 | 668 | parser.add_argument('--folds',type=str, help="Comma delimited list of ints identifying which folds to use.") 669 | parser.add_argument('--gpu', type=str, help="Comma delimited list of ints identifying which GPU ids to use.") 670 | parser.add_argument('--storage_dir', help='Path to your storage dir where segdec_train (tensorboard info) and segdec_eval (results) will be stored.', 671 | type=str, 672 | default='/opt/workspace/host_storage_hdd/') 673 | parser.add_argument('--dataset_dir', help='Path to your input_data dirs.', 674 | type=str, 675 | default='/opt/workspace/host_storage_hdd/') 676 | parser.add_argument('--datasets', help='Comma delimited list of input_data names to use, e.g., "Dataset1,Dataset2".', 677 | type=str, default=','.join(['KolektorSDD'])) 678 | parser.add_argument('--name_prefix',type=str, default=None) 679 | parser.add_argument('--train_subset', type=str, default="train_pos") 680 | parser.add_argument('--pretrained_model', type=str, default=None) 681 | parser.add_argument('--pretrained_main_folder', type=str, default=None) 682 | 683 | parser.add_argument('--size_height', type=int, default=2*704) 684 | parser.add_argument('--size_width', type=int, default=2*256) 685 | 686 | parser.add_argument('--seg_net_type', type=str, default='MSE') 687 | 688 | parser.add_argument('--input_rotation', type='bool', default=False) 689 | 690 | parser.add_argument('--with_seg_net', type='bool', default=True) 691 | parser.add_argument('--with_decision_net', type='bool', default=False) 692 | parser.add_argument('--lr', type=float, default=0) 693 | parser.add_argument('--max_steps', type=int, default=6600) 694 | 695 | parser.add_argument('--channels', type=int, default=1) 696 | parser.add_argument('--pos_weights', type=float, default=1) 697 | 698 | parser.add_argument('--ensure_posneg_balance', type='bool', default=True) 699 | 700 | args = parser.parse_args() 701 | 702 | main_storage_dir = args.storage_dir 703 | main_dataset_folder = args.dataset_dir 704 | dataset_list = args.datasets.split(",") 705 | train_subset = args.train_subset 706 | 
pretrained_model = args.pretrained_model 707 | pretrained_main_folder = args.pretrained_main_folder 708 | pos_weights = args.pos_weights 709 | ensure_posneg_balance = args.ensure_posneg_balance 710 | 711 | size_height = args.size_height 712 | size_width = args.size_width 713 | channels = args.channels 714 | 715 | seg_net_type = args.seg_net_type 716 | 717 | input_rotation = args.input_rotation 718 | 719 | with_seg_net = args.with_seg_net 720 | with_decision_net = args.with_decision_net 721 | 722 | max_steps = args.max_steps 723 | lr = args.lr 724 | 725 | if seg_net_type == 'MSE': 726 | lr_val = 0.005 727 | use_corss_entropy_seg_net = False 728 | elif seg_net_type == 'ENTROPY': 729 | lr_val = 0.1 730 | use_corss_entropy_seg_net = True 731 | else: 732 | raise Exception('Unknown SEG-NET type; allowed only: \'MSE\' or \'ENTROPY\'') 733 | 734 | 735 | if lr > 0: 736 | lr_val = lr 737 | 738 | folds = [int(i) for i in args.folds.split(",")] 739 | for i in folds: 740 | if i >= 0: 741 | fold_name = 'fold_%d' % i 742 | else: 743 | fold_name = '' 744 | 745 | for d in dataset_list: 746 | 747 | run_name = os.path.join(d, fold_name if args.name_prefix is None else os.path.join(args.name_prefix, fold_name)) 748 | 749 | dataset_folder = os.path.join(main_dataset_folder, d) 750 | print("running", dataset_folder, run_name) 751 | 752 | if with_decision_net is False: 753 | # use bigger lr for the sigmoid cross-entropy loss 754 | net_model = SegDecModel(decision_net=SegDecModel.DECISION_NET_NONE, 755 | use_corss_entropy_seg_net=use_corss_entropy_seg_net, 756 | positive_weight=pos_weights) 757 | else: 758 | # use lr=0.005 for the mean squared error loss 759 | net_model = SegDecModel(decision_net=SegDecModel.DECISION_NET_FULL, 760 | use_corss_entropy_seg_net=use_corss_entropy_seg_net, 761 | positive_weight=pos_weights) 762 | current_pretrained_model = pretrained_model 763 | 764 | if current_pretrained_model is None and pretrained_main_folder is not None: 765 | current_pretrained_model = os.path.join(pretrained_main_folder, fold_name) 766 | 767 | train = SegDecTrain(net_model, 768 | storage_dir=main_storage_dir, 769 | run_string=run_name, 770 | image_size=(size_height, size_width, channels), # NOTE: size should be divisible by 16 !!!
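# batch_size is fixed to 1: the assert in SegDecTrain.__init__ enforces this because
# batch norm is used to normalize features per image at test time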
771 | batch_size=1, 772 | learning_rate=lr_val, 773 | max_steps=max_steps, 774 | max_epochs=1200, 775 | visible_device_list=args.gpu, 776 | pretrained_model_checkpoint_path=current_pretrained_model, 777 | train_segmentation_net=with_seg_net, 778 | train_decision_net=with_decision_net, 779 | use_random_rotation=input_rotation, 780 | ensure_posneg_balance=ensure_posneg_balance) 781 | 782 | dataset_fold_folder = os.path.join(dataset_folder, fold_name) 783 | 784 | # Run training 785 | train.train(InputData(train_subset, dataset_fold_folder)) 786 | 787 | if with_decision_net: 788 | # Run evaluation on test data 789 | samples_outcome_test, samples_names_test, samples_speed_eval, num_params = train.evaluate(InputData('test', dataset_fold_folder)) 790 | 791 | np.save(os.path.join(main_storage_dir, 'segdec_eval', run_name, 'test', 'results_decision_net.npy'), samples_outcome_test) 792 | np.save(os.path.join(main_storage_dir, 'segdec_eval', run_name, 'test', 'results_decision_net_names.npy'), samples_names_test) 793 | np.save(os.path.join(main_storage_dir, 'segdec_eval', run_name, 'test', 'results_decision_net_speed_eval.npy'), samples_speed_eval) 794 | 795 | # Copy results from the test dir of the specific fold into a common folder for this input_data 796 | src_dir = os.path.join(main_storage_dir, 'segdec_eval', run_name, 'test') 797 | dst_dir = os.path.join(main_storage_dir, 'segdec_eval', d if args.name_prefix is None else os.path.join(d, args.name_prefix)) 798 | if not os.path.isdir(dst_dir): os.makedirs(dst_dir) # make sure the destination folder exists, otherwise shutil.copy would write to a file named like the folder 799 | for src_file in glob.glob(os.path.join(src_dir, '*.pdf')): shutil.copy(src_file, dst_dir) 800 | --------------------------------------------------------------------------------
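For orientation, below is a minimal sketch (not part of the repository) of the same two-stage procedure that the command-line block in segdec_train.py drives, written directly against the modules above; all paths, fold ids and image sizes are illustrative placeholders:

import os
from segdec_model import SegDecModel
from segdec_data import InputData
from segdec_train import SegDecTrain

storage_dir = '/opt/workspace/host_storage_hdd/'                      # hypothetical storage root
dataset_dir = '/opt/workspace/host_storage_hdd/KolektorSDD/fold_0'    # hypothetical fold folder

# Stage 1: train the segmentation net alone (ENTROPY variant => cross-entropy loss, lr=0.1);
# train_segmentation_net=True and train_decision_net=False are the constructor defaults
seg_model = SegDecModel(decision_net=SegDecModel.DECISION_NET_NONE, use_corss_entropy_seg_net=True)
seg_train = SegDecTrain(seg_model, storage_dir=storage_dir, run_string='KolektorSDD/fold_0',
                        image_size=(1408, 512, 1),  # height/width must be divisible by 16
                        batch_size=1, learning_rate=0.1, max_steps=6600)
seg_train.train(InputData('train_pos', dataset_dir))

# Stage 2: freeze the segmentation layers and train only the decision net,
# warm-starting from the stage-1 checkpoint folder
dec_model = SegDecModel(decision_net=SegDecModel.DECISION_NET_FULL, use_corss_entropy_seg_net=True)
dec_train = SegDecTrain(dec_model, storage_dir=storage_dir, run_string='KolektorSDD/fold_0_decision',
                        image_size=(1408, 512, 1), batch_size=1, learning_rate=0.1, max_steps=6600,
                        pretrained_model_checkpoint_path=os.path.join(storage_dir, 'segdec_train', 'KolektorSDD/fold_0'),
                        train_segmentation_net=False, train_decision_net=True)
dec_train.train(InputData('train_pos', dataset_dir))

# evaluation returns per-image (decision score, label) pairs plus timing and the parameter count
outcomes, names, speed, num_params = dec_train.evaluate(InputData('test', dataset_dir))

The per-fold result files written by the __main__ block can then be aggregated across folds with segdec_print_eval.evaluate_decision.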