├── .gitmodules
├── LICENSE
├── common
│   ├── __init__.py
│   ├── image_helper.py
│   ├── lr_scheduler.py
│   ├── model_loader.py
│   ├── tf_utils.py
│   └── utils.py
├── const.py
├── datasets
│   ├── __init__.py
│   ├── augmentation_factory.py
│   ├── data_wrapper_base.py
│   └── matting_data_wrapper.py
├── demo
│   └── demo.mp4
├── evaluate.py
├── factory
│   ├── __init__.py
│   ├── base.py
│   ├── losses_funcs.py
│   ├── matting_converter.py
│   └── matting_nets.py
├── figure
│   └── gradient_error_vs_latency.png
├── helper
│   ├── __init__.py
│   ├── base.py
│   ├── evaluator.py
│   └── trainer.py
├── matting_nets
│   ├── __init__.py
│   ├── deeplab_v3plus.py
│   ├── mmnet.py
│   └── mmnet_utils.py
├── metrics
│   ├── base.py
│   ├── manager.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── base_ops.py
│   │   ├── misc_ops.py
│   │   └── tensor_ops.py
│   ├── parser.py
│   └── summaries.py
├── nets
│   ├── __init__.py
│   └── mobilenet
├── readme.md
├── requirements
│   ├── py36-common.txt
│   ├── py36-cpu.txt
│   └── py36-gpu.txt
├── scripts
│   ├── train_deeplab_os16_dm0.5_256.sh
│   ├── train_mmnet_dm1.0_256.sh
│   ├── valid_deeplab_os16_dm0.5_256.sh
│   └── valid_mmnet_dm1.0_256.sh
└── train.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "models"]
2 | path = models
3 | url = https://github.com/tensorflow/models
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship.
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 |
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyperconnect/MMNet/aa423b598f38e051fcf46914cf8a7765458c09d2/common/__init__.py
--------------------------------------------------------------------------------
/common/image_helper.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | # The method below is borrowed from the official TensorFlow repository
5 | # https://github.com/tensorflow/tensorflow/blob/a0b8cee815100b805a24fedfa12b28139d24e7fe/tensorflow/python/ops/image_ops_impl.py
6 | def _verify_compatible_image_shapes(img1, img2):
7 | """Checks if two image tensors are compatible for applying SSIM or PSNR.
8 | This function checks if two sets of images have ranks at least 3, and if the
9 | last three dimensions match.
10 | Args:
11 | img1: Tensor containing the first image batch.
12 | img2: Tensor containing the second image batch.
13 | Returns:
14 | A tuple containing: the first tensor shape, the second tensor shape, and a
15 | list of control_flow_ops.Assert() ops implementing the checks.
16 | Raises:
17 | ValueError: When static shape check fails.
18 | """
19 | shape1 = img1.get_shape().with_rank_at_least(3)
20 | shape2 = img2.get_shape().with_rank_at_least(3)
21 | shape1[-3:].assert_is_compatible_with(shape2[-3:])
22 |
23 | if shape1.ndims is not None and shape2.ndims is not None:
24 | for dim1, dim2 in zip(reversed(shape1[:-3]), reversed(shape2[:-3])):
25 | if not (dim1 == 1 or dim2 == 1 or dim1.is_compatible_with(dim2)):
26 | raise ValueError(
27 | "Two images are not compatible: %s and %s" % (shape1, shape2))
28 |
29 | # Now assign shape tensors.
30 | shape1, shape2 = tf.shape_n([img1, img2])
31 |
32 | checks = []
33 | checks.append(tf.Assert(
34 | tf.greater_equal(tf.size(shape1), 3),
35 | [shape1, shape2], summarize=10))
36 | checks.append(tf.Assert(
37 | tf.reduce_all(tf.equal(shape1[-3:], shape2[-3:])),
38 | [shape1, shape2], summarize=10))
39 | return shape1, shape2, checks
40 |
41 |
42 | def sobel_gradient(img):
43 | """Take an image and compute its Sobel filter response.
44 |
45 | Args:
46 | img: Image Tensor. Either 3-D or 4-D.
47 |
48 | Returns:
49 | A Tensor concatenating the horizontal and vertical Sobel
50 | responses along the channel axis,
51 | so the number of channels is doubled.
52 |
53 | """
54 | num_channels = img.get_shape().as_list()[-1]
55 |
56 | # Load the filters, which can be reused
57 | with tf.variable_scope("misc/img_gradient", reuse=tf.AUTO_REUSE):
58 | filter_x = tf.constant([[-1/8, 0, 1/8],
59 | [-2/8, 0, 2/8],
60 | [-1/8, 0, 1/8]],
61 | name="sobel_x", dtype=tf.float32, shape=[3, 3, 1, 1])
62 | filter_x = tf.tile(filter_x, [1, 1, num_channels, 1])
63 |
64 | filter_y = tf.constant([[-1/8, -2/8, -1/8],
65 | [0, 0, 0],
66 | [1/8, 2/8, 1/8]],
67 | name="sobel_y", dtype=tf.float32, shape=[3, 3, 1, 1])
68 | filter_y = tf.tile(filter_y, [1, 1, num_channels, 1])
69 |
70 | # calculate
71 | grad_x = tf.nn.depthwise_conv2d(img, filter_x,
72 | strides=[1, 1, 1, 1], padding="VALID", name="grad_x")
73 |
74 | grad_y = tf.nn.depthwise_conv2d(img, filter_y,
75 | strides=[1, 1, 1, 1], padding="VALID", name="grad_y")
76 |
77 | grad_xy = tf.concat([grad_x, grad_y], axis=-1)
78 |
79 | return grad_xy
80 |
81 |
82 | def _first_deriviate_gaussian_filters(size, sigma):
83 | size = tf.convert_to_tensor(size, tf.int32)
84 | sigma = tf.convert_to_tensor(sigma, tf.float32)
85 | sigma2 = tf.square(sigma)
86 |
87 | coords = tf.cast(tf.range(size), sigma.dtype)
88 | coords -= tf.cast(size - 1, sigma.dtype) / 2.0
89 |
90 | g = tf.square(coords)
91 | g *= -0.5 / tf.square(sigma)
92 |
93 | g = tf.reshape(g, shape=[1, -1]) + tf.reshape(g, shape=[-1, 1])
94 | g = tf.reshape(g, shape=[1, -1]) # For tf.nn.softmax().
95 | g = tf.nn.softmax(g)
96 | g = tf.reshape(g, shape=[size, size])
97 |
98 | # https://cedar.buffalo.edu/~srihari/CSE555/Normal2.pdf
99 | # https://github.com/scipy/scipy/blob/v0.14.0/scipy/ndimage/filters.py#L179
100 | gx = -1 * tf.reshape(coords, shape=[1, -1]) * g / sigma2
101 | gy = -1 * tf.reshape(coords, shape=[-1, 1]) * g / sigma2
102 |
103 | # gx = tf.reshape(gx, shape=[1, -1]) # For tf.nn.softmax().
104 | # gy = tf.reshape(gy, shape=[1, -1]) # For tf.nn.softmax().
105 | # gx = tf.nn.softmax(gx)
106 | # gy = tf.nn.softmax(gy)
107 |
108 | return tf.reshape(gx, shape=[size, size, 1, 1]), tf.reshape(gy, shape=[size, size, 1, 1])
109 |
110 |
111 | def first_deriviate_gaussian_gradient(img, sigma):
112 | """Take an image and compute its first-derivative-of-Gaussian filter response.
113 | The current implementation assumes the input has a single channel.
114 | https://www.juew.org/publication/CVPR09_evaluation_final_HQ.pdf
115 |
116 | Args:
117 | img: Image Tensor. Either 3-D or 4-D.
118 |
119 | Returns:
120 | A Tensor concatenating the horizontal and vertical filter
121 | responses along the channel axis,
122 | so the number of channels is doubled.
123 | 124 | """ 125 | num_channels = img.get_shape().as_list()[-1] 126 | assert num_channels == 1 127 | 128 | # load filter which can be reused 129 | with tf.variable_scope("misc/img_gradient", reuse=tf.AUTO_REUSE): 130 | # truncate for 3 sigma 131 | half_width = int(3 * sigma + 0.5) 132 | size = 2 * half_width + 1 133 | 134 | filter_x, filter_y = _first_deriviate_gaussian_filters(size, sigma) 135 | 136 | # calculate 137 | grad_x = tf.nn.depthwise_conv2d(img, filter_x, 138 | strides=[1, 1, 1, 1], padding="VALID", name="grad_x") 139 | 140 | grad_y = tf.nn.depthwise_conv2d(img, filter_y, 141 | strides=[1, 1, 1, 1], padding="VALID", name="grad_y") 142 | 143 | grad_xy = tf.concat([grad_x, grad_y], axis=-1) 144 | return grad_xy 145 | -------------------------------------------------------------------------------- /common/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | def factory(args, log, global_step_from_checkpoint=None, global_step=None, dataset=None): 2 | learning_rate_scheduler = FixedLR(args, log) 3 | return learning_rate_scheduler 4 | 5 | 6 | def add_arguments(parser): 7 | FixedLR.add_arguments(parser) 8 | 9 | 10 | class FixedLR: 11 | def __init__( 12 | self, args, logger 13 | ): 14 | self.args = args 15 | self.logger = logger 16 | assert hasattr(self.args, "learning_rate") and isinstance(self.args.learning_rate, float) 17 | self.learning_rate = self.args.learning_rate 18 | 19 | self.placeholder = self.learning_rate 20 | self.should_feed_dict = False 21 | 22 | @staticmethod 23 | def add_arguments(parser): 24 | g_lr = parser.add_argument_group("Learning Rate Arguments") 25 | g_lr.add_argument("--learning_rate", default=1e-4, type=float, help="Initial learning rate for gradient update") 26 | -------------------------------------------------------------------------------- /common/model_loader.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import tensorflow as tf 4 | from tensorflow.python import pywrap_tensorflow 5 | from tensorflow.python.ops import control_flow_ops 6 | from tensorflow.python.platform import gfile 7 | 8 | from common.utils import format_text, get_logger 9 | 10 | 11 | class Pb(): 12 | @staticmethod 13 | def load( 14 | model_stempath: str, 15 | extension: str=".pb", 16 | ): 17 | """ Load weights from protobuf (*.pb) file and 18 | Args: 19 | model_stempath: Path without extension to protobuf model. 
20 | extension (optional): 21 | Returns: 22 | graph 23 | """ 24 | model_fullpath = str(model_stempath.with_suffix(extension)) 25 | with gfile.FastGFile(model_fullpath, "rb") as f: 26 | graph_def = tf.GraphDef() 27 | graph_def.ParseFromString(f.read()) 28 | 29 | with tf.Graph().as_default() as graph: 30 | tf.import_graph_def(graph_def, name="") 31 | 32 | return graph 33 | 34 | 35 | class Ckpt(): 36 | def __init__( 37 | self, 38 | session: tf.Session, 39 | variables_to_restore=None, 40 | include_scopes: str="", 41 | exclude_scopes: str="", 42 | ignore_missing_vars: bool=False, 43 | logger=None, 44 | ): 45 | self.session = session 46 | self.variables_to_restore = self._get_variables_to_restore( 47 | variables_to_restore, 48 | include_scopes, 49 | exclude_scopes, 50 | ) 51 | self.ignore_missing_vars = ignore_missing_vars 52 | self.logger = logger 53 | if logger is None: 54 | self.logger = get_logger("Ckpt Loader", None) 55 | 56 | # variables to save reusable info from previous load 57 | self.has_previous_info = False 58 | self.grouped_vars = {} 59 | self.placeholders = {} 60 | self.assign_op = None 61 | 62 | def _get_variables_to_restore( 63 | self, 64 | variables_to_restore=None, 65 | include_scopes: str="", 66 | exclude_scopes: str="", 67 | ): 68 | # variables_to_restore might be List or Dictionary. 69 | 70 | def split_strip(scopes: str): 71 | return list(filter(lambda x: len(x) > 0, [s.strip() for s in scopes.split(",")])) 72 | 73 | def starts_with(var, scopes: List) -> bool: 74 | return any([var.op.name.startswith(prefix) for prefix in scopes]) 75 | 76 | exclusions = split_strip(exclude_scopes) 77 | inclusions = split_strip(include_scopes) 78 | 79 | if variables_to_restore is None: 80 | variables_to_restore = tf.contrib.framework.get_variables_to_restore() 81 | 82 | filtered_variables_key = variables_to_restore 83 | if len(inclusions) > 0: 84 | filtered_variables_key = filter(lambda var: starts_with(var, inclusions), filtered_variables_key) 85 | filtered_variables_key = filter(lambda var: not starts_with(var, exclusions), filtered_variables_key) 86 | 87 | if isinstance(variables_to_restore, dict): 88 | variables_to_restore = { 89 | key: variables_to_restore[key] for key in filtered_variables_key 90 | } 91 | elif isinstance(variables_to_restore, list): 92 | variables_to_restore = list(filtered_variables_key) 93 | 94 | return variables_to_restore 95 | 96 | # Copied and revised code not to create duplicated 'assign' operations everytime it gets called. 97 | # https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/contrib/framework/python/ops/variables.py#L558 98 | def load(self, checkpoint_stempath): 99 | def get_variable_full_name(var): 100 | if var._save_slice_info: 101 | return var._save_slice_info.full_name 102 | else: 103 | return var.op.name 104 | 105 | if not self.has_previous_info: 106 | if isinstance(self.variables_to_restore, (tuple, list)): 107 | for var in self.variables_to_restore: 108 | ckpt_name = get_variable_full_name(var) 109 | if ckpt_name not in self.grouped_vars: 110 | self.grouped_vars[ckpt_name] = [] 111 | self.grouped_vars[ckpt_name].append(var) 112 | 113 | else: 114 | for ckpt_name, value in self.variables_to_restore.items(): 115 | if isinstance(value, (tuple, list)): 116 | self.grouped_vars[ckpt_name] = value 117 | else: 118 | self.grouped_vars[ckpt_name] = [value] 119 | 120 | # Read each checkpoint entry. Create a placeholder variable and 121 | # add the (possibly sliced) data from the checkpoint to the feed_dict. 
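# Note that the placeholders and the grouped assign op built below are cached through
# `has_previous_info`, so repeated load() calls (e.g. while iterating over a stream of
# checkpoints during evaluation) reuse the same graph nodes instead of creating
# duplicate assign operations each time.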
122 | reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_stempath) 123 | feed_dict = {} 124 | assign_ops = [] 125 | for ckpt_name in self.grouped_vars: 126 | if not reader.has_tensor(ckpt_name): 127 | log_str = f"Checkpoint is missing variable [{ckpt_name}]" 128 | if self.ignore_missing_vars: 129 | self.logger.warning(log_str) 130 | continue 131 | else: 132 | raise ValueError(log_str) 133 | ckpt_value = reader.get_tensor(ckpt_name) 134 | 135 | for var in self.grouped_vars[ckpt_name]: 136 | placeholder_name = f"placeholder/{var.op.name}" 137 | if self.has_previous_info: 138 | placeholder_tensor = self.placeholders[placeholder_name] 139 | else: 140 | placeholder_tensor = tf.placeholder( 141 | dtype=var.dtype.base_dtype, 142 | shape=var.get_shape(), 143 | name=placeholder_name) 144 | assign_ops.append(var.assign(placeholder_tensor)) 145 | self.placeholders[placeholder_name] = placeholder_tensor 146 | 147 | if not var._save_slice_info: 148 | if var.get_shape() != ckpt_value.shape: 149 | raise ValueError( 150 | f"Total size of new array must be unchanged for {ckpt_name} " 151 | f"lh_shape: [{str(ckpt_value.shape)}], rh_shape: [{str(var.get_shape())}]") 152 | 153 | feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape) 154 | else: 155 | slice_dims = zip(var._save_slice_info.var_offset, 156 | var._save_slice_info.var_shape) 157 | slice_dims = [(start, start + size) for (start, size) in slice_dims] 158 | slice_dims = [slice(*x) for x in slice_dims] 159 | slice_value = ckpt_value[slice_dims] 160 | slice_value = slice_value.reshape(var._save_slice_info.var_shape) 161 | feed_dict[placeholder_tensor] = slice_value 162 | 163 | if not self.has_previous_info: 164 | self.assign_op = control_flow_ops.group(*assign_ops) 165 | 166 | self.session.run(self.assign_op, feed_dict) 167 | 168 | if len(feed_dict) > 0: 169 | for key in feed_dict.keys(): 170 | self.logger.info(f"init from checkpoint > {key}") 171 | else: 172 | self.logger.info(f"No init from checkpoint") 173 | 174 | with format_text("cyan", attrs=["bold", "underline"]) as fmt: 175 | self.logger.info(fmt(f"Restore from {checkpoint_stempath}")) 176 | self.has_previous_info = True 177 | -------------------------------------------------------------------------------- /common/tf_utils.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from typing import Dict 3 | from functools import reduce 4 | from pathlib import Path 5 | from operator import mul 6 | 7 | import tensorflow as tf 8 | import pandas as pd 9 | import numpy as np 10 | import deprecation 11 | from termcolor import colored 12 | from tensorflow.contrib.training import checkpoints_iterator 13 | from common.utils import get_logger 14 | from common.utils import wait 15 | from PIL import Image 16 | 17 | import const 18 | 19 | 20 | def get_variables_to_train(trainable_scopes, logger): 21 | """Returns a list of variables to train. 22 | Returns: 23 | A list of variables to train by the optimizer. 
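For example (the scope names here are illustrative), trainable_scopes="MMNet/encoder,MMNet/decoder"
restricts training to variables whose names start with those prefixes; an empty
value returns every trainable variable.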
24 | """ 25 | if trainable_scopes is None or trainable_scopes == "": 26 | return tf.trainable_variables() 27 | else: 28 | scopes = [scope.strip() for scope in trainable_scopes.split(",")] 29 | 30 | variables_to_train = [] 31 | for scope in scopes: 32 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope) 33 | variables_to_train.extend(variables) 34 | 35 | for var in variables_to_train: 36 | logger.info("vars to train > {}".format(var.name)) 37 | 38 | return variables_to_train 39 | 40 | 41 | def show_models(logger): 42 | trainable_variables = set(tf.contrib.framework.get_variables(collection=tf.GraphKeys.TRAINABLE_VARIABLES)) 43 | all_variables = tf.contrib.framework.get_variables() 44 | trainable_vars = tf.trainable_variables() 45 | total_params = 0 46 | logger.info(colored(f">> Start of shoing all variables", "cyan", attrs=["bold"])) 47 | for v in all_variables: 48 | is_trainable = v in trainable_variables 49 | count_params = reduce(mul, v.get_shape().as_list(), 1) 50 | total_params += count_params 51 | color = "cyan" if is_trainable else "green" 52 | logger.info(colored(( 53 | f">> {v.name} {v.dtype} : {v.get_shape().as_list()}, {count_params} ... {total_params} " 54 | f"(is_trainable: {is_trainable})" 55 | ), color)) 56 | logger.info(colored( 57 | f">> End of showing all variables // Number of variables: {len(all_variables)}, " 58 | f"Number of trainable variables : {len(trainable_vars)}, " 59 | f"Total prod + sum of shape: {total_params}", 60 | "cyan", attrs=["bold"])) 61 | return total_params 62 | 63 | 64 | def ckpt_iterator(checkpoint_dir, min_interval_secs=0, timeout=None, timeout_fn=None, logger=None): 65 | for ckpt_path in checkpoints_iterator(checkpoint_dir, min_interval_secs, timeout, timeout_fn): 66 | if Path(ckpt_path).suffix == const.CKPT_EXTENSION: 67 | if logger is not None: 68 | logger.info(f"[skip] {ckpt_path}") 69 | else: 70 | yield ckpt_path 71 | 72 | 73 | class BestKeeper(object): 74 | def __init__( 75 | self, 76 | metric_with_modes, 77 | dataset_name, 78 | directory, 79 | logger=None, 80 | epsilon=0.00005, 81 | score_file="scores.tsv", 82 | metric_best: Dict={}, 83 | ): 84 | """Keep best model's checkpoint by each datasets & metrics 85 | 86 | Args: 87 | metric_with_modes: Dict, metric_name: mode 88 | if mode is 'min', then it means that minimum value is best, for example loss(MSE, MAE) 89 | if mode is 'max', then it means that maximum value is best, for example Accuracy, Precision, Recall 90 | dataset_name: str, dataset name on which metric be will be calculated 91 | directory: directory path for saving best model 92 | epsilon: float, threshold for measuring the new optimum, to only focus on significant changes. 
93 | Because sometimes early-stopping gives better generalization results
94 | """
95 | if logger is not None:
96 | self.log = logger
97 | else:
98 | self.log = get_logger("BestKeeper")
99 |
100 | self.score_file = score_file
101 | self.metric_best = metric_best if metric_best is not None else {}  # avoid sharing a mutable default dict across instances
102 |
103 | self.log.info(colored(f"Initialize BestKeeper: Monitor {dataset_name} & Save to {directory}",
104 | "yellow", attrs=["underline"]))
105 | self.log.info(f"{metric_with_modes}")
106 |
107 | self.x_better_than_y = {}
108 | self.directory = Path(directory)
109 | self.output_temp_dir = self.directory / f"{dataset_name}_best_keeper_temp"
110 |
111 | for metric_name, mode in metric_with_modes.items():
112 | if mode == "min":
113 | self.metric_best[metric_name] = self.load_metric_from_scores_tsv(
114 | directory / dataset_name / metric_name / score_file,
115 | metric_name,
116 | np.inf,
117 | )
118 | self.x_better_than_y[metric_name] = lambda x, y: np.less(x, y - epsilon)
119 | elif mode == "max":
120 | self.metric_best[metric_name] = self.load_metric_from_scores_tsv(
121 | directory / dataset_name / metric_name / score_file,
122 | metric_name,
123 | -np.inf,
124 | )
125 | self.x_better_than_y[metric_name] = lambda x, y: np.greater(x, y + epsilon)
126 | else:
127 | raise ValueError(f"Unsupported mode : {mode}")
128 |
129 | def load_metric_from_scores_tsv(
130 | self,
131 | full_path: Path,
132 | metric_name: str,
133 | default_value: float,
134 | ) -> float:
135 | def parse_scores(s: str):
136 | if len(s) > 0:
137 | return float(s)
138 | else:
139 | return default_value
140 |
141 | if full_path.exists():
142 | with open(full_path, "r") as f:
143 | header = f.readline().strip().split("\t")
144 | values = list(map(parse_scores, f.readline().strip().split("\t")))
145 | metric_index = header.index(metric_name)
146 |
147 | return values[metric_index]
148 | else:
149 | return default_value
150 |
151 | def monitor(self, dataset_name, eval_scores):
152 | metrics_keep = {}
153 | is_keep = False
154 | for metric_name in self.metric_best:
155 | score = eval_scores[metric_name]
156 | if self.x_better_than_y[metric_name](score, self.metric_best[metric_name]):
157 | old_score = self.metric_best[metric_name]
158 | self.metric_best[metric_name] = score
159 | metrics_keep[metric_name] = True
160 | is_keep = True
161 | self.log.info(colored("[KeepBest] {} {:.6f} -> {:.6f}, so keep it!".format(
162 | metric_name, old_score, score), "blue", attrs=["underline"]))
163 | else:
164 | metrics_keep[metric_name] = False
165 | return is_keep, metrics_keep
166 |
167 | def save_best(self, dataset_name, metrics_keep, ckpt_glob):
168 | for metric_name, is_keep in metrics_keep.items():
169 | if is_keep:
170 | keep_path = self.directory / Path(dataset_name) / Path(metric_name)
171 | self.keep_checkpoint(keep_path, ckpt_glob)
172 | self.keep_converted_files(keep_path)
173 |
174 | def save_scores(self, dataset_name, metrics_keep, eval_scores, meta_info=None):
175 | eval_scores_with_meta = eval_scores.copy()
176 | if meta_info is not None:
177 | eval_scores_with_meta.update(meta_info)
178 |
179 | for metric_name, is_keep in metrics_keep.items():
180 | if is_keep:
181 | keep_path = self.directory / Path(dataset_name) / Path(metric_name)
182 | if not keep_path.exists():
183 | keep_path.mkdir(parents=True)
184 | df = pd.DataFrame(pd.Series(eval_scores_with_meta)).sort_index().transpose()
185 | df.to_csv(keep_path / self.score_file, sep="\t", index=False, float_format="%.5f")
186 |
187 | @deprecation.deprecated(details="DO NOT USE.
It will be removed at further refactoring.") 188 | def save_images(self, dataset_name, images): 189 | num_images = images[0].shape[0] 190 | h, w = images[0].shape[1:3] 191 | keep_path = self.directory / dataset_name / "images" 192 | if not keep_path.exists(): 193 | keep_path.mkdir(parents=True) 194 | 195 | for nidx in range(num_images): 196 | _images = [Image.fromarray(img[nidx]) for img in images] 197 | 198 | merged_image = Image.new("RGB", (w * 3, h * ((len(_images)-1) // 3 + 1))) 199 | for i, img in enumerate(_images): 200 | row = i // 3 201 | col = i % 3 202 | 203 | merged_image.paste(img, (col*w, row*h, (col+1)*w, (row+1)*h)) 204 | merged_image.save(keep_path / f"img_{nidx}.jpg") 205 | 206 | def remove_old_best(self, dataset_name, metrics_keep): 207 | for metric_name, is_keep in metrics_keep.items(): 208 | if is_keep: 209 | keep_path = self.directory / Path(dataset_name) / Path(metric_name) 210 | # Remove old directory to save space 211 | if keep_path.exists(): 212 | shutil.rmtree(str(keep_path)) 213 | keep_path.mkdir(parents=True) 214 | 215 | def keep_checkpoint(self, keep_path, ckpt_glob): 216 | if not isinstance(keep_path, Path): 217 | keep_path = Path(keep_path) 218 | 219 | # .data-00000-of-00001, .meta, .index 220 | for ckpt_path in ckpt_glob.parent.glob(ckpt_glob.name): 221 | shutil.copy(str(ckpt_path), str(keep_path)) 222 | """ 223 | self.log.info(colored( 224 | "[KeepBest] copy {} -> {}".format(ckpt_path, keep_path), "blue", attrs=["underline"]) 225 | ) 226 | """ 227 | 228 | def keep_converted_files(self, keep_path): 229 | if not isinstance(keep_path, Path): 230 | keep_path = Path(keep_path) 231 | 232 | for path in self.output_temp_dir.glob("*"): 233 | if path.is_dir(): 234 | shutil.copytree(str(path), str(keep_path / path.name)) 235 | else: 236 | shutil.copy(str(path), str(keep_path / path.name)) 237 | 238 | def remove_temp_dir(self): 239 | if self.output_temp_dir.exists(): 240 | shutil.rmtree(str(self.output_temp_dir)) 241 | 242 | 243 | def resolve_checkpoint_path(checkpoint_path, log, is_training): 244 | if checkpoint_path is not None and Path(checkpoint_path).is_dir(): 245 | old_ckpt_path = checkpoint_path 246 | checkpoint_path = tf.train.latest_checkpoint(old_ckpt_path) 247 | if not is_training: 248 | def stop_checker(): 249 | return (tf.train.latest_checkpoint(old_ckpt_path) is not None) 250 | wait("There are no checkpoint file yet", stop_checker) # wait until checkpoint occurs 251 | checkpoint_path = tf.train.latest_checkpoint(old_ckpt_path) 252 | log.info(colored( 253 | "self.args.checkpoint_path updated: {} -> {}".format(old_ckpt_path, checkpoint_path), 254 | "yellow", attrs=["bold"])) 255 | else: 256 | log.info(colored("checkpoint_path is {}".format(checkpoint_path), "yellow", attrs=["bold"])) 257 | 258 | return checkpoint_path 259 | 260 | 261 | def get_global_step_from_checkpoint(checkpoint_path): 262 | """It is assumed that `checkpoint_path` is path to checkpoint file, not path to directory 263 | with checkpoint files. 
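For example, a checkpoint path ending in "model-15000" yields global step 15000,
since the step is parsed from the part of the file stem after the last "-".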
264 | In case checkpoint path is not defined, 0 is returned.""" 265 | if checkpoint_path is None or checkpoint_path == "": 266 | return 0 267 | else: 268 | if "-" in Path(checkpoint_path).stem: 269 | return int(Path(checkpoint_path).stem.split("-")[-1]) 270 | else: 271 | return 0 272 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import humanfriendly as hf 3 | import contextlib 4 | import argparse 5 | import logging 6 | import getpass 7 | import shutil 8 | import json 9 | import time 10 | from pathlib import Path 11 | from types import SimpleNamespace 12 | from datetime import datetime 13 | 14 | import tensorflow as tf 15 | import click 16 | from termcolor import colored 17 | 18 | 19 | def check_external_paths(paths, log): 20 | for p in paths: 21 | if not Path(p).exists(): 22 | log.info(colored("{} path does not exist!".format(p), "red")) 23 | time.sleep(1) 24 | 25 | 26 | def get_model_variables(exclude_prefix): 27 | if exclude_prefix == [""]: 28 | return tf.contrib.framework.get_model_variables() 29 | assert "" not in exclude_prefix 30 | 31 | model_variables = [] 32 | for var in tf.contrib.framework.get_model_variables(): 33 | is_include = True 34 | for substring in exclude_prefix: 35 | if substring in var.name: 36 | is_include = False 37 | if is_include: 38 | model_variables.append(var) 39 | 40 | return model_variables 41 | 42 | 43 | def dump_configuration(train_log_dir, config, filename="config.json"): 44 | if not Path(train_log_dir).exists(): 45 | Path(train_log_dir).mkdir(parents=True) 46 | 47 | if isinstance(config, argparse.Namespace): 48 | config = vars(config) 49 | elif isinstance(config, dict): 50 | config = config 51 | else: 52 | raise ValueError("Unsupported type for configuration: {}".format(type(config))) 53 | 54 | with Path(train_log_dir, filename).open("w") as f: 55 | json.dump(config, f) 56 | 57 | 58 | def update_train_dir(args): 59 | def replace_func(base_string, a, b): 60 | replaced_string = base_string.replace(a, b) 61 | print(colored("[update_train_dir] replace {} : {} -> {}".format(a, base_string, replaced_string), 62 | "yellow")) 63 | return replaced_string 64 | 65 | def make_placeholder(s: str, circumfix: str="%"): 66 | return circumfix + s.upper() + circumfix 67 | 68 | placeholder_mapping = { 69 | make_placeholder("DATE"): datetime.now().strftime("%y%m%d%H%M%S"), 70 | make_placeholder("USER"): getpass.getuser(), 71 | } 72 | 73 | for key, value in placeholder_mapping.items(): 74 | args.train_dir = replace_func(args.train_dir, key, value) 75 | 76 | unknown = "UNKNOWN" 77 | for key, value in vars(args).items(): 78 | key_placeholder = make_placeholder(key) 79 | if key_placeholder in args.train_dir: 80 | replace_value = value 81 | if isinstance(replace_value, str): 82 | if "/" in replace_value: 83 | replace_value = unknown 84 | elif isinstance(replace_value, list): 85 | replace_value = ",".join(map(str, replace_value)) 86 | elif isinstance(replace_value, float) or isinstance(replace_value, int): 87 | replace_value = str(replace_value) 88 | elif isinstance(replace_value, bool): 89 | replace_value = str(replace_value) 90 | else: 91 | replace_value = unknown 92 | args.train_dir = replace_func(args.train_dir, key_placeholder, replace_value) 93 | 94 | print(colored("[update_train_dir] final train_dir {}".format(args.train_dir), 95 | "yellow", attrs=["bold", "underline"])) 96 | 97 | 98 | def exit_handler(train_dir): 99 | if 
click.confirm("Do you want to delete {}?".format(train_dir), abort=True): 100 | shutil.rmtree(train_dir) 101 | print("... delete {} done!".format(train_dir)) 102 | 103 | 104 | def positive_int(value): 105 | ivalue = int(value) 106 | if ivalue <= 0: 107 | raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value) 108 | return ivalue 109 | 110 | 111 | def get_subparser_argument_list(parser, subparser_name): 112 | # Hack argparse and get subparser's arguments 113 | from argparse import _SubParsersAction 114 | argument_list = [] 115 | for sub_parser_action in filter(lambda x: isinstance(x, _SubParsersAction), parser._subparsers._actions): 116 | for action in sub_parser_action.choices[subparser_name]._actions: 117 | arg = action.option_strings[-1].replace("--", "") 118 | if arg == "help": 119 | continue 120 | if arg.startswith("no-"): 121 | continue 122 | argument_list.append(arg) 123 | return argument_list 124 | 125 | 126 | def get_logger(logger_name, log_file: Path=None, level=logging.DEBUG): 127 | # "log/data-pipe-{}.log".format(datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) 128 | logger = logging.getLogger(logger_name) 129 | 130 | if not logger.hasHandlers(): 131 | formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s > %(message)s") 132 | 133 | logger.setLevel(level) 134 | 135 | if log_file is not None: 136 | log_file.parent.mkdir(parents=True, exist_ok=True) 137 | fileHandler = logging.FileHandler(log_file, mode="w") 138 | fileHandler.setFormatter(formatter) 139 | logger.addHandler(fileHandler) 140 | 141 | streamHandler = logging.StreamHandler() 142 | streamHandler.setFormatter(formatter) 143 | logger.addHandler(streamHandler) 144 | 145 | return logger 146 | 147 | 148 | def format_timespan(duration): 149 | if duration < 10: 150 | readable_duration = "{:.1f} (ms)".format(duration * 1000) 151 | else: 152 | readable_duration = hf.format_timespan(duration) 153 | return readable_duration 154 | 155 | 156 | @contextlib.contextmanager 157 | def timer(name): 158 | st = time.time() 159 | yield 160 | print(" {} : {}".format(name, format_timespan(time.time() - st))) 161 | 162 | 163 | def timeit(method): 164 | def timed(*args, **kw): 165 | hf_timer = hf.Timer() 166 | result = method(*args, **kw) 167 | print(" {!r} ({!r}, {!r}) {}".format(method.__name__, args, kw, hf_timer.rounded)) 168 | return result 169 | return timed 170 | 171 | 172 | class Timer(object): 173 | def __init__(self, log): 174 | self.log = log 175 | 176 | @contextlib.contextmanager 177 | def __call__(self, name, log_func=None): 178 | """ 179 | Example. 
180 | timer = Timer(log)
181 | with timer("Some Routines"):
182 | routine1()
183 | routine2()
184 | """
185 | if log_func is None:
186 | log_func = self.log.info
187 |
188 | start = time.perf_counter()  # time.clock() was deprecated and removed in Python 3.8
189 | yield
190 | end = time.perf_counter()
191 | duration = end - start
192 | readable_duration = format_timespan(duration)
193 | log_func(f"{name} :: {readable_duration}")
194 |
195 |
196 | class TextFormatter(object):
197 | def __init__(self, color, attrs):
198 | self.color = color
199 | self.attrs = attrs
200 |
201 | def __call__(self, string):
202 | return colored(string, self.color, attrs=self.attrs)
203 |
204 |
205 | class LogFormatter(object):
206 | def __init__(self, log, color, attrs):
207 | self.log = log
208 | self.color = color
209 | self.attrs = attrs
210 |
211 | def __call__(self, string):
212 | return self.log(colored(string, self.color, attrs=self.attrs))
213 |
214 |
215 | @contextlib.contextmanager
216 | def format_text(color, attrs=None):
217 | yield TextFormatter(color, attrs)
218 |
219 |
220 | def format_text_fun(color, attrs=None):
221 | return TextFormatter(color, attrs)
222 |
223 |
224 | def format_log(log, color, attrs=None):
225 | return LogFormatter(log, color, attrs)
226 |
227 |
228 | @contextlib.contextmanager
229 | def smart_log(log, msg):
230 | log.info(f"{msg} started.")
231 | yield
232 | log.info(f"{msg} finished.")
233 |
234 |
235 | def setup_step1_mode(args):
236 | args.step_evaluation = 1
237 | args.step_validation = 1
238 | args.step_minimum_save = 0
239 | args.step_save_checkpoint = 1
240 | args.step_save_summaries = 1
241 |
242 | log = get_logger("utils")
243 | log.info(colored("Update step_evaluation, step_validation, step_save_checkpoint ... to 1",
244 | "yellow"))
245 |
246 |
247 | def wait(message, stop_checker_closure):
248 | assert callable(stop_checker_closure)
249 | st = time.time()
250 | while True:
251 | try:
252 | time_pass = hf.format_timespan(int(time.time() - st))
253 | sys.stdout.write(colored((
254 | f"{message}. Do you wanna wait? If not, then ctrl+c!
:: waiting time: {time_pass}\r" 255 | ), "yellow", attrs=["bold"])) 256 | sys.stdout.flush() 257 | time.sleep(1) 258 | if stop_checker_closure(): 259 | break 260 | except KeyboardInterrupt: 261 | break 262 | 263 | 264 | class MLNamespace(SimpleNamespace): 265 | def __init__(self, *args, **kwargs): 266 | for kwarg in kwargs.keys(): 267 | assert kwarg not in dir(self) 268 | super().__init__(*args, **kwargs) 269 | 270 | def unordered_values(self): 271 | return list(vars(self).values()) 272 | 273 | def __setitem__(self, key, value): 274 | setattr(self, key, value) 275 | -------------------------------------------------------------------------------- /const.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | TF_SESSION_CONFIG = tf.ConfigProto( 5 | gpu_options=tf.GPUOptions(allow_growth=True), 6 | log_device_placement=False, 7 | device_count={"GPU": 1}) 8 | 9 | CKPT_EXTENSION = ".my-ckpt" 10 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyperconnect/MMNet/aa423b598f38e051fcf46914cf8a7765458c09d2/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/augmentation_factory.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | from types import SimpleNamespace 3 | 4 | import tensorflow as tf 5 | 6 | 7 | _available_augmentation_methods_TF = [ 8 | "resize_random_scale_crop_flip_rotate", 9 | "resize_bilinear", 10 | ] 11 | _available_augmentation_methods = ( 12 | _available_augmentation_methods_TF + 13 | [ 14 | "no_augmentation", 15 | ] 16 | ) 17 | 18 | 19 | def expand_squeeze(aug_fun): 20 | def fun(image, output_height: int, output_width: int, channels, **kwargs): 21 | image = tf.expand_dims(image, 0) 22 | image = aug_fun(image, output_height, output_width, channels, **kwargs) 23 | return tf.squeeze(image) 24 | return fun 25 | 26 | 27 | @expand_squeeze 28 | def resize_bilinear(image, output_height: int, output_width: int, channels=3, **kwargs): 29 | # Resize the image to the specified height and width. 30 | image = tf.image.resize_bilinear(image, (output_height, output_width), align_corners=True) 31 | return image 32 | 33 | 34 | def _generate_rand(min_factor, max_factor, step_size): 35 | """Gets a random value. 36 | Args: 37 | min_factor: Minimum value. 38 | max_factor: Maximum value. 39 | step_size: The step size from minimum to maximum value. 40 | Returns: 41 | A random value selected between minimum and maximum value. 42 | Raises: 43 | ValueError: min_factor has unexpected value. 44 | """ 45 | if min_factor < 0 or min_factor > max_factor: 46 | raise ValueError("Unexpected value of min_factor.") 47 | if min_factor == max_factor: 48 | return tf.to_float(min_factor) 49 | # When step_size = 0, we sample the value uniformly from [min, max). 50 | if step_size == 0: 51 | return tf.random_uniform([1], 52 | minval=min_factor, 53 | maxval=max_factor) 54 | # When step_size != 0, we randomly select one discrete value from [min, max]. 
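# For example, min_factor=0.5, max_factor=2.0, step_size=0.5 gives num_steps=4 and
# candidate factors [0.5, 1.0, 1.5, 2.0], one of which is picked uniformly at random.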
55 | num_steps = int((max_factor - min_factor) / step_size + 1)
56 | scale_factors = tf.lin_space(min_factor, max_factor, num_steps)
57 | shuffled_scale_factors = tf.random_shuffle(scale_factors)
58 | return shuffled_scale_factors[0]
59 |
60 |
61 | def _scale_image(image, scale_height=1.0, scale_width=1.0):
62 | """Scales image.
63 | Args:
64 | image: Image with shape [height, width, 3].
65 | scale_height: The factor by which to scale the image height.
66 | scale_width: The factor by which to scale the image width.
67 | Returns:
68 | Scaled image.
69 | """
70 | # No random scaling if scale == 1.
71 | if scale_height == 1.0 and scale_width == 1.0:
72 | return image
73 |
74 | image_shape = tf.shape(image)
75 | new_dim = tf.to_int32([tf.to_float(image_shape[0]) * scale_height, tf.to_float(image_shape[1]) * scale_width])
76 | # Need squeeze and expand_dims because image interpolation takes
77 | # 4D tensors as input.
78 | image = tf.squeeze(tf.image.resize_bilinear(
79 | tf.expand_dims(image, 0),
80 | new_dim,
81 | align_corners=True), [0])
82 |
83 | return image
84 |
85 |
86 | def random_crop_flip(image, output_height: int, output_width: int, channels=3, **kwargs):
87 | image = tf.random_crop(image, (output_height, output_width, channels))
88 | image = tf.image.random_flip_left_right(image)
89 | return image
90 |
91 |
92 | def rotate_with_crop(image, args):
93 | if args.rotation_range == 0:
94 | return image
95 |
96 | rotation_amount_degree = tf.random_uniform(
97 | shape=[],
98 | minval=-args.rotation_range,
99 | maxval=args.rotation_range,
100 | )
101 | rotation_amount_radian = rotation_amount_degree * pi / 180.0
102 | image = tf.contrib.image.rotate(image, rotation_amount_radian)
103 |
104 | height_and_width = tf.shape(image)
105 | height, width = height_and_width[0], height_and_width[1]
106 | max_min_ratio = tf.cast(
107 | tf.maximum(height, width) / tf.minimum(height, width),
108 | dtype=tf.float32,
109 | )
110 | coef_inv = tf.abs(tf.cos(rotation_amount_radian)) + max_min_ratio * tf.abs(tf.sin(rotation_amount_radian))
111 | height_crop = tf.cast(tf.round(tf.cast(height, dtype=tf.float32) / coef_inv), dtype=tf.int32)
112 | width_crop = tf.cast(tf.round(tf.cast(width, dtype=tf.float32) / coef_inv), dtype=tf.int32)
113 | image = tf.image.crop_to_bounding_box(
114 | image,
115 | offset_height=(height - height_crop) // 2,
116 | offset_width=(width - width_crop) // 2,
117 | target_height=height_crop,
118 | target_width=width_crop,
119 | )
120 | image = tf.image.resize_bilinear([image], (height, width))[0]
121 | return image
122 |
123 |
124 | def resize_random_scale_crop_flip_rotate(image, output_height: int, output_width: int, channels=3, **kwargs):
125 | image = resize_bilinear(image, output_height, output_width, channels=channels)
126 |
127 | scale_height = _generate_rand(1.0, 1.15, 0.01)
128 | scale_width = _generate_rand(1.0, 1.15, 0.01)
129 | image = _scale_image(image, scale_height, scale_width)
130 | image = random_crop_flip(image, output_height, output_width, channels)
131 |
132 | do_rotate = tf.less(tf.random_uniform([], minval=0, maxval=1), 0.5)
133 | image = tf.cond(
134 | pred=do_rotate,
135 | true_fn=lambda: rotate_with_crop(image, SimpleNamespace(rotation_range=15)),
136 | false_fn=lambda: image,
137 | )
138 |
139 | return image
140 |
141 |
142 | def get_augmentation_fn(name):
143 | """Returns augmentation_fn(image, height, width, channels, **kwargs).
144 | Args:
145 | name: The name of the preprocessing function.
146 | Returns:
147 | augmentation_fn: A function that preprocesses a single image (pre-batch).
148 | It has the following signature:
149 | image = augmentation_fn(image, output_height, output_width, ...).
150 |
151 | Raises:
152 | ValueError: If preprocessing `name` is not recognized.
153 | """
154 | if name not in _available_augmentation_methods:
155 | raise ValueError(f"Augmentation name [{name}] was not recognized")
156 |
157 | def augmentation_fn(image, output_height: int, output_width: int, channels, **kwargs):
158 | return eval(name)(
159 | image, output_height, output_width, channels, **kwargs)
160 |
161 | return augmentation_fn
162 |
--------------------------------------------------------------------------------
/datasets/data_wrapper_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from abc import abstractmethod
3 | from pathlib import Path
4 | from typing import Tuple
5 | from typing import List
6 | from itertools import chain
7 |
8 | import tensorflow as tf
9 | from termcolor import colored
10 |
11 | import common.utils as utils
12 | from datasets.augmentation_factory import _available_augmentation_methods
13 | from datasets.augmentation_factory import get_augmentation_fn
14 |
15 |
16 | class DataWrapperBase(ABC):
17 | def __init__(
18 | self,
19 | args,
20 | dataset_split_name: str,
21 | is_training: bool,
22 | name: str,
23 | ):
24 | self.name = name
25 | self.args = args
26 | self.dataset_split_name = dataset_split_name
27 | self.is_training = is_training
28 |
29 | # args.inference is False by default
30 | self.need_label = not self.args.inference
31 | self.shuffle = self.args.shuffle
32 | self.supported_extensions = [".jpg", ".JPEG", ".png"]
33 |
34 | self.log = utils.get_logger(self.name, None)
35 | self.timer = utils.Timer(self.log)
36 | self.dataset_path = Path(self.args.dataset_path)
37 | self.dataset_path_with_split_name = self.dataset_path / self.dataset_split_name
38 |
39 | with utils.format_text("yellow", ["underline"]) as fmt:
40 | self.log.info(self.name)
41 | self.log.info(fmt(f"dataset_path_with_split_name: {self.dataset_path_with_split_name}"))
42 | self.log.info(fmt(f"dataset_split_name: {self.dataset_split_name}"))
43 |
44 | @property
45 | @abstractmethod
46 | def num_samples(self):
47 | pass
48 |
49 | @property
50 | def padded_original_image_dummy_shape(self):
51 | return (1, 1, 1)
52 |
53 | @property
54 | def padded_max_size(self):
55 | return 400
56 |
57 | def
resize_and_padding_before_augmentation(self, image, size): 58 | # If width > height, resize height to model's input height while preserving aspect ratio 59 | # If height > width, resize width to model's input width while preserving aspect ratio 60 | if self.args.debug_augmentation: 61 | assert size[0] == size[1], "resize_and_padding_before_augmentation only supports square target image" 62 | image = tf.expand_dims(image, 0) 63 | 64 | image_dims = tf.shape(image) 65 | height = image_dims[1] 66 | width = image_dims[2] 67 | 68 | min_size = min(*size) 69 | width_aspect = tf.maximum(min_size, tf.cast(width * min_size / height, dtype=tf.int32)) 70 | height_aspect = tf.maximum(min_size, tf.cast(height * min_size / width, dtype=tf.int32)) 71 | 72 | image = tf.image.resize_bilinear(image, (height_aspect, width_aspect)) 73 | image = image[:, :self.padded_max_size, :self.padded_max_size, :] 74 | 75 | # Pads the image on the bottom and right with zeros until it has dimensions target_height, target_width. 76 | image = tf.image.pad_to_bounding_box( 77 | image, 78 | offset_height=tf.maximum(self.padded_max_size-height_aspect, 0), 79 | offset_width=tf.maximum(self.padded_max_size-width_aspect, 0), 80 | target_height=self.padded_max_size, 81 | target_width=self.padded_max_size, 82 | ) 83 | 84 | image = tf.squeeze(image, 0) 85 | return image 86 | else: 87 | # Have to return some dummy tensor which have .get_shape() to tf.dataset 88 | return tf.constant(0, shape=self.padded_original_image_dummy_shape, dtype=tf.uint8, name="dummy") 89 | 90 | def augment_image(self, image, args, channels=3, **kwargs): 91 | aug_fn = get_augmentation_fn(args.augmentation_method) 92 | image = aug_fn(image, args.height, args.width, channels, **kwargs) 93 | return image 94 | 95 | @property 96 | def batch_size(self): 97 | try: 98 | return self._batch_size 99 | except AttributeError: 100 | self._batch_size = 0 101 | 102 | @batch_size.setter 103 | def batch_size(self, val): 104 | self._batch_size = val 105 | 106 | def get_all_images(self, image_path): 107 | if isinstance(image_path, list): 108 | img_gen = [] 109 | for p in image_path: 110 | for ext in self.supported_extensions: 111 | img_gen.append(Path(p).glob(f"*{ext}")) 112 | else: 113 | img_gen = [Path(image_path).glob(f"*{ext}") for ext in self.supported_extensions] 114 | 115 | return chain(*img_gen) 116 | 117 | def setup_dataset( 118 | self, 119 | placeholders: Tuple[tf.placeholder, tf.placeholder], 120 | batch_size: int=None, 121 | ): 122 | self.batch_size = self.args.batch_size if batch_size is None else batch_size 123 | 124 | dataset = tf.data.Dataset.from_tensor_slices(placeholders) 125 | dataset = dataset.map(self._parse_function, num_parallel_calls=self.args.num_threads).prefetch( 126 | self.args.prefetch_factor * self.batch_size) 127 | if self.is_training: 128 | dataset = dataset.repeat() 129 | if self.shuffle: 130 | dataset = dataset.shuffle(buffer_size=self.args.buffer_size) 131 | self.dataset = dataset.batch(self.batch_size) 132 | self.iterator = self.dataset.make_initializable_iterator() 133 | self.next_elem = self.iterator.get_next() 134 | 135 | def setup_iterator(self, 136 | session: tf.Session, 137 | placeholders: Tuple[tf.placeholder, tf.placeholder], 138 | variables: Tuple[tf.placeholder, tf.placeholder], 139 | ): 140 | assert len(placeholders) == len(variables), "Length of placeholders and variables differ!" 
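# Feeding the filename lists through placeholders (rather than baking them into the
# graph as constants) keeps them out of the serialized GraphDef, which is subject to
# the 2GB protobuf limit.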
141 |         with self.timer(colored("Initialize data iterator.", "yellow")):
142 |             session.run(self.iterator.initializer,
143 |                         feed_dict={placeholder: variable for placeholder, variable in zip(placeholders, variables)})
144 | 
145 |     def get_input_and_output_op(self):
146 |         return self.next_elem
147 | 
148 |     def __str__(self):
149 |         return f"path: {self.args.dataset_path}, split: {self.args.dataset_split_name}, data size: {self._num_samples}"
150 | 
151 |     def get_all_dataset_paths(self) -> List[str]:
152 |         if self.args.has_sub_dataset:
153 |             return sorted([p for p in self.dataset_path_with_split_name.glob("*/") if p.is_dir()])
154 |         else:
155 |             return [self.dataset_path_with_split_name]
156 | 
157 |     @staticmethod
158 |     def add_arguments(parser):
159 |         g_common = parser.add_argument_group("(DataWrapperBase) Common arguments for all data wrappers.")
160 |         g_common.add_argument("--dataset_path", required=True, type=str, help="The path of the dataset to load.")
161 |         g_common.add_argument("--dataset_split_name", required=True, type=str, nargs="*",
162 |                               help="The name of the train/test split. Supports multiple splits.")
163 | 
164 |         g_common.add_argument("--batch_size", default=32, type=utils.positive_int,
165 |                               help="The number of examples in a batch.")
166 |         g_common.add_argument("--no-shuffle", dest="shuffle", action="store_false")
167 |         g_common.add_argument("--shuffle", dest="shuffle", action="store_true")
168 |         g_common.set_defaults(shuffle=True)
169 | 
170 |         g_common.add_argument("--width", required=True, type=int)
171 |         g_common.add_argument("--height", required=True, type=int)
172 |         g_common.add_argument("--no-debug_augmentation", dest="debug_augmentation", action="store_false")
173 |         g_common.add_argument("--debug_augmentation", dest="debug_augmentation", action="store_true")
174 |         g_common.set_defaults(debug_augmentation=False)
175 |         g_common.add_argument("--max_padded_size", default=224, type=int,
176 |                               help=("We will resize & pad the original image "
177 |                                     "until it has dimensions (padded_size, padded_size). "
178 |                                     "Recommended to set this value to about width (or height) * 1.8 ~ 2."))
179 |         g_common.add_argument("--augmentation_method", type=str, required=True,
180 |                               choices=_available_augmentation_methods)
181 |         g_common.add_argument("--num_threads", default=8, type=int)
182 |         g_common.add_argument("--buffer_size", default=1000, type=int)
183 |         g_common.add_argument("--prefetch_factor", default=100, type=int)
184 | 
185 |         g_common.add_argument("--rotation_range", default=0, type=int,
186 |                               help="Maximum rotation angle in degrees: "
187 |                                    "the image is rotated by an angle "
188 |                                    "randomly chosen from [-rotation_range, rotation_range], "
189 |                                    "and then cropped appropriately to remove dark areas.\n"
190 |                                    "So, be aware that the rotation performs a certain kind of zooming.")
191 |         g_common.add_argument("--no-has_sub_dataset", dest="has_sub_dataset", action="store_false")
192 |         g_common.add_argument("--has_sub_dataset", dest="has_sub_dataset", action="store_true")
193 |         g_common.set_defaults(has_sub_dataset=False)
194 | 
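
# A minimal, self-contained sketch of the pipeline that `setup_dataset` builds
# above, written with raw tf.data for clarity. The filenames and the
# decode-only parse function are hypothetical; the real wrapper feeds full
# path lists through string placeholders when the iterator is initialized:
if __name__ == "__main__":
    filenames = tf.placeholder(tf.string, shape=[None])
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.map(
        lambda f: tf.image.decode_jpeg(tf.read_file(f), channels=3),
        num_parallel_calls=8,
    ).prefetch(100 * 32)                    # prefetch_factor * batch_size
    dataset = dataset.repeat()              # only when training
    dataset = dataset.shuffle(buffer_size=1000)
    dataset = dataset.batch(32)
    iterator = dataset.make_initializable_iterator()
    next_images = iterator.get_next()
    with tf.Session() as session:
        session.run(iterator.initializer,
                    feed_dict={filenames: ["/path/a.jpg", "/path/b.jpg"]})
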
-------------------------------------------------------------------------------- /datasets/matting_data_wrapper.py: --------------------------------------------------------------------------------
 1 | import random
 2 | from pathlib import Path
 3 | 
 4 | import tensorflow as tf
 5 | from termcolor import colored
 6 | 
 7 | from datasets.data_wrapper_base import DataWrapperBase
 8 | 
 9 | 
10 | class MattingDataWrapper(DataWrapperBase):
11 |     """ MattingDataWrapper
12 |     dataset_dir
13 |     |___name
14 |         |__mask
15 |         |__image
16 |     """
17 |     def __init__(
18 |         self,
19 |         args,
20 |         session,
21 |         dataset_split_name: str,
22 |         is_training: bool=False,
23 |         name: str="MattingDataWrapper",
24 |     ):
25 |         super().__init__(args, dataset_split_name, is_training, name=name)
26 | 
27 |         self.image_channels = 3
28 | 
29 |         self.IMAGE_DIR_NAME = "image"
30 |         self.MASK_DIR_NAME = "mask"
31 | 
32 |         self.image_label_dirs, self.mask_label_dirs = self.build_dataset_paths()
33 | 
34 |         self.setup()
35 | 
36 |         self.setup_dataset((self.image_placeholder, self.mask_placeholder))
37 |         self.setup_iterator(
38 |             session,
39 |             (self.image_placeholder, self.mask_placeholder),
40 |             (self.image_fullpathes, self.mask_fullpathes),
41 |         )
42 | 
43 |     def setup(self):
44 |         self.image_fullpathes = []
45 |         self.mask_fullpathes = []
46 | 
47 |         for image in sorted(self.get_all_images(self.image_label_dirs)):
48 |             imagepath = str(image)
49 |             self.image_fullpathes.append(imagepath)
50 |             maskpath = imagepath.replace(f"/{self.IMAGE_DIR_NAME}/", f"/{self.MASK_DIR_NAME}/")
51 |             assert Path(maskpath).exists(), f"{maskpath} not found"
52 |             self.mask_fullpathes.append(maskpath)
53 | 
54 |         self._num_samples = len(self.image_fullpathes)
55 |         self.log.info(colored(f"Num Data: {self._num_samples}", "red"))
56 | 
57 |         if self.shuffle:
58 |             shuffled_data = list(zip(self.image_fullpathes, self.mask_fullpathes))
59 |             random.shuffle(shuffled_data)
60 | 
61 |             self.image_fullpathes, self.mask_fullpathes = zip(*shuffled_data)
62 |             self.log.info(colored("Data shuffled!", "red"))
63 | 
64 |         self.image_placeholder = tf.placeholder(tf.string, len(self.image_fullpathes))
65 |         self.mask_placeholder = tf.placeholder(tf.string, len(self.mask_fullpathes))
66 | 
67 |     @property
68 |     def num_samples(self):
69 |         return self._num_samples
70 | 
71 |     @property
72 |     def padded_original_image_dummy_shape(self):
73 |         return (1, 1, 4)
74 | 
75 |     def build_dataset_paths(self):
76 |         dataset_paths = self.get_all_dataset_paths()
77 | 
78 |         image_label_dirs = []
79 |         mask_label_dirs = []
80 |         for dataset_path in dataset_paths:
81 |             image_dataset_dir = dataset_path / self.IMAGE_DIR_NAME
82 |             mask_dataset_dir = dataset_path / self.MASK_DIR_NAME
83 | 
84 |             image_label_dir = image_dataset_dir
85 |             mask_label_dir = mask_dataset_dir
86 | 
87 |             image_label_dirs.append(image_label_dir)
88 |             mask_label_dirs.append(mask_label_dir)
89 | 
90 |         return image_label_dirs, mask_label_dirs
91 | 
92 |     def _parse_function(self, imagename, maskname):
93 |         image = tf.image.decode_jpeg(tf.read_file(imagename), channels=self.image_channels)
94 |         if self.need_label:
95 |             mask = tf.image.decode_jpeg(tf.read_file(maskname), channels=1)
96 |         else:
97 |             mask = tf.zeros([self.args.height, self.args.width])
98 | 
99 |         comb = tf.concat([mask, image], axis=2)  # Concat trick for synchronizing locations of random crop
100 | 
101 |         comb_original = self.resize_and_padding_before_augmentation(comb, [self.args.height, self.args.width])
102 |         comb_augmented = self.augment_image(comb, self.args, channels=self.image_channels+1)
103 | 
104 |         if self.args.target_eval_shape:
105 |             mask_original = tf.image.resize_images(
106 |                 tf.cast(mask, tf.float32) / 255.0,
107 |                 self.args.target_eval_shape,
108 |             )
109 |         else:
110 |             mask_original = tf.squeeze(comb_original[:, :, 0:1], 2)  # mask_original = comb_original[:, :, 0]
111 | 
112 |         image_original = comb_original[:, :, 1:]
113 | 
114 |         mask_augmented, image_augmented = comb_augmented[:, :, 0:1], comb_augmented[:, :, 1:]
115 |         mask_augmented = tf.reshape(mask_augmented, 
[self.args.height, self.args.width]) / 255.0 116 | image_augmented = tf.reshape( 117 | image_augmented, 118 | [self.args.height, self.args.width, self.image_channels], 119 | ) 120 | 121 | return image_original, mask_original, image_augmented, mask_augmented 122 | 123 | @staticmethod 124 | def add_arguments(parser): 125 | pass 126 | -------------------------------------------------------------------------------- /demo/demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyperconnect/MMNet/aa423b598f38e051fcf46914cf8a7765458c09d2/demo/demo.mp4 -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import List 3 | 4 | import tensorflow as tf 5 | 6 | from common.tf_utils import ckpt_iterator 7 | import common.utils as utils 8 | import const 9 | from datasets.data_wrapper_base import DataWrapperBase 10 | from datasets.matting_data_wrapper import MattingDataWrapper 11 | from factory.base import CNNModel 12 | import factory.matting_nets as matting_nets 13 | from helper.base import Base 14 | from helper.evaluator import Evaluator 15 | from helper.evaluator import MattingEvaluator 16 | from metrics.base import MetricManagerBase 17 | 18 | 19 | def evaluate(args): 20 | evaluator = build_evaluator(args) 21 | log = utils.get_logger("EvaluateMatting") 22 | dataset_names = args.dataset_split_name 23 | 24 | if args.inference: 25 | for dataset_name in dataset_names: 26 | evaluator[dataset_name].inference(args.checkpoint_path) 27 | else: 28 | if args.valid_type == "once": 29 | for dataset_name in dataset_names: 30 | evaluator[dataset_name].evaluate_once(args.checkpoint_path) 31 | elif args.valid_type == "loop": 32 | current_evaluator = evaluator[dataset_names[0]] 33 | log.info(f"Start Loop: watching {current_evaluator.watch_path}") 34 | 35 | kwargs = { 36 | "min_interval_secs": 0, 37 | "timeout": None, 38 | "timeout_fn": None, 39 | "logger": log, 40 | } 41 | for ckpt_path in ckpt_iterator(current_evaluator.watch_path, **kwargs): 42 | log.info(f"[watch] {ckpt_path}") 43 | 44 | for dataset_name in dataset_names: 45 | evaluator[dataset_name].evaluate_once(ckpt_path) 46 | else: 47 | raise ValueError(f"Undefined valid_type: {args.valid_type}") 48 | 49 | 50 | def build_evaluator(args, evaluator_cls=MattingEvaluator): 51 | session = tf.Session(config=const.TF_SESSION_CONFIG) 52 | dataset_names = args.dataset_split_name 53 | 54 | dataset = MattingDataWrapper( 55 | args, 56 | session, 57 | dataset_names[0], 58 | is_training=False, 59 | ) 60 | images_original, masks_original, images, masks = dataset.get_input_and_output_op() 61 | model = eval("matting_nets.{}".format(args.model))(args, dataset) 62 | model.build( 63 | images_original=images_original, 64 | images=images, 65 | masks_original=masks_original, 66 | masks=masks, 67 | is_training=False, 68 | ) 69 | 70 | evaluator = { 71 | dataset_names[0]: evaluator_cls( 72 | model, 73 | session, 74 | args, 75 | dataset, 76 | dataset_names[0], 77 | ) 78 | } 79 | for dataset_name in dataset_names[1:]: 80 | assert False, "Evaluation of multiple dataset splits does not work." 
81 |         dataset = MattingDataWrapper(
82 |             args,
83 |             session,
84 |             dataset_name,
85 |             is_training=False,
86 |         )
87 | 
88 |         evaluator[dataset_name] = evaluator_cls(
89 |             model,
90 |             session,
91 |             args,
92 |             dataset,
93 |             dataset_name,
94 |         )
95 | 
96 |     return evaluator
97 | 
98 | 
99 | def parse_arguments(arguments: List[str]=None):
100 |     parser = argparse.ArgumentParser(description=__doc__)
101 |     subparsers = parser.add_subparsers(title="Model", description="")
102 | 
103 |     # -- * -- Common Arguments & Each Model's Arguments -- * --
104 |     CNNModel.add_arguments(parser, default_type="matting")
105 |     matting_nets.MattingNetModel.add_arguments(parser)
106 |     for class_name in matting_nets._available_nets:
107 |         subparser = subparsers.add_parser(class_name)
108 |         subparser.add_argument("--model", default=class_name, type=str, help="DO NOT FIX ME")
109 |         add_matting_net_arguments = eval("matting_nets.{}.add_arguments".format(class_name))
110 |         add_matting_net_arguments(subparser)
111 | 
112 |     Evaluator.add_arguments(parser)
113 |     Base.add_arguments(parser)
114 |     DataWrapperBase.add_arguments(parser)
115 |     MattingDataWrapper.add_arguments(parser)
116 |     MetricManagerBase.add_arguments(parser)
117 | 
118 |     args = parser.parse_args(arguments)
119 | 
120 |     model_arguments = utils.get_subparser_argument_list(parser, args.model)
121 |     args.model_arguments = model_arguments
122 | 
123 |     return args
124 | 
125 | 
126 | if __name__ == "__main__":
127 |     args = parse_arguments()
128 |     log = utils.get_logger("MattingEvaluator", None)
129 | 
130 |     log.info(args)
131 |     evaluate(args)
132 | 
-------------------------------------------------------------------------------- /factory/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyperconnect/MMNet/aa423b598f38e051fcf46914cf8a7765458c09d2/factory/__init__.py
-------------------------------------------------------------------------------- /factory/base.py: --------------------------------------------------------------------------------
 1 | from abc import ABC
 2 | from abc import abstractmethod
 3 | 
 4 | import tensorflow as tf
 5 | import tensorflow.contrib.slim as slim
 6 | 
 7 | import common.tf_utils as tf_utils
 8 | 
 9 | 
10 | class CNNModel(ABC):
11 |     def preprocess_images(self, images, preprocess_method, reuse=False):
12 |         with tf.variable_scope("preprocess", reuse=reuse):
13 |             if images.dtype == tf.uint8:
14 |                 images = tf.cast(images, tf.float32)
15 |             if preprocess_method == "preprocess_normalize":
16 |                 # -- * -- preprocess_normalize
17 |                 # Scale input images to the range [0, 1], i.e. the same scale as the masks
18 |                 images = tf.divide(images, tf.constant(255.0))
19 |             elif preprocess_method == "no_preprocessing":
20 |                 pass
21 |             else:
22 |                 raise ValueError("Unsupported preprocess_method: {}".format(preprocess_method))
23 | 
24 |         return images
25 | 
26 |     @staticmethod
27 |     def add_arguments(parser, default_type):
28 |         g_cnn = parser.add_argument_group("(CNNModel) Arguments")
29 |         assert default_type in ["matting", None]
30 | 
31 |         g_cnn.add_argument("--task_type", type=str, required=True,
32 |                            choices=[
33 |                                "matting",
34 |                            ])
35 |         g_cnn.add_argument("--num_classes", type=int, default=None,
36 |                            help=(
37 |                                "It is currently not used in multi-task learning, "
38 |                                "so it cannot be *required*"
39 |                            ))
40 |         g_cnn.add_argument("--checkpoint_path", default="", type=str)
41 | 
42 |         g_cnn.add_argument("--input_name", type=str, default="input/image")
43 |         g_cnn.add_argument("--input_batch_size", type=int, default=1)
g_cnn.add_argument("--output_name", type=str, required=True) 45 | g_cnn.add_argument("--output_type", type=str, help="mainly used in convert.py", required=True) 46 | 47 | g_cnn.add_argument("--no-use_fused_batchnorm", dest="use_fused_batchnorm", action="store_false") 48 | g_cnn.add_argument("--use_fused_batchnorm", dest="use_fused_batchnorm", action="store_true") 49 | g_cnn.set_defaults(use_fused_batchnorm=True) 50 | 51 | g_cnn.add_argument("--verbosity", default=0, type=int, 52 | help="If verbosity > 0, then summary batch_norm scalar metrics etc") 53 | g_cnn.add_argument("--preprocess_method", required=True, type=str, 54 | choices=["no_preprocessing", "preprocess_normalize"]) 55 | 56 | g_cnn.add_argument("--no-ignore_missing_vars", dest="ignore_missing_vars", action="store_false") 57 | g_cnn.add_argument("--ignore_missing_vars", dest="ignore_missing_vars", action="store_true") 58 | g_cnn.set_defaults(ignore_missing_vars=False) 59 | 60 | g_cnn.add_argument("--checkpoint_exclude_scopes", default="", type=str, 61 | help=("Prefix scopes that shoule be EXLUDED for restoring variables " 62 | "(comma separated)\n Usually Logits e.g. InceptionResnetV2/Logits/Logits, " 63 | "InceptionResnetV2/AuxLogits/Logits")) 64 | 65 | g_cnn.add_argument("--checkpoint_include_scopes", default="", type=str, 66 | help=("Prefix scopes that should be INCLUDED for restoring variables " 67 | "(comma separated)")) 68 | 69 | def build_finish(self, is_training, log): 70 | total_params = tf_utils.show_models(log) 71 | 72 | if self.args.verbosity >= 1: 73 | slim.model_analyzer.analyze_ops(tf.get_default_graph(), print_info=True) 74 | 75 | return total_params 76 | 77 | @abstractmethod 78 | def build_output(self): 79 | pass 80 | 81 | @property 82 | @abstractmethod 83 | def images(self): 84 | pass 85 | 86 | @property 87 | @abstractmethod 88 | def images_original(self): 89 | pass 90 | 91 | @property 92 | @abstractmethod 93 | def total_loss(self): 94 | pass 95 | 96 | @property 97 | @abstractmethod 98 | def model_loss(self): 99 | pass 100 | -------------------------------------------------------------------------------- /factory/losses_funcs.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.ops.nn_ops import softmax_cross_entropy_with_logits_v2 2 | import tensorflow as tf 3 | 4 | import common.utils as utils 5 | 6 | 7 | def alpha_loss(alpha_scores, masks_float32): 8 | return tf.losses.absolute_difference(alpha_scores, masks_float32) 9 | 10 | 11 | def comp_loss(alpha_scores, masks_float32, images): 12 | normalized_images = images / 255 13 | 14 | reconst_fg = tf.multiply(normalized_images, alpha_scores) 15 | true_fg = tf.multiply(normalized_images, masks_float32) 16 | 17 | return tf.losses.absolute_difference(reconst_fg, true_fg) 18 | 19 | 20 | def grad_loss(alpha_scores, masks_float32): 21 | filter_x = tf.constant([[-1/8, 0, 1/8], 22 | [-2/8, 0, 2/8], 23 | [-1/8, 0, 1/8]], 24 | name="sober_x", dtype=tf.float32, shape=[3, 3, 1, 1]) 25 | filter_y = tf.constant([[-1/8, -2/8, -1/8], 26 | [0, 0, 0], 27 | [1/8, 2/8, 1/8]], 28 | name="sober_y", dtype=tf.float32, shape=[3, 3, 1, 1]) 29 | filter_xy = tf.concat([filter_x, filter_y], axis=-1) 30 | 31 | grad_alpha = tf.nn.conv2d(alpha_scores, filter_xy, strides=[1, 1, 1, 1], padding="SAME") 32 | grad_masks = tf.nn.conv2d(masks_float32, filter_xy, strides=[1, 1, 1, 1], padding="SAME") 33 | 34 | return tf.losses.absolute_difference(grad_alpha, grad_masks) 35 | 36 | 37 | def kd_loss(scores, masks): 38 | logits = tf.reshape(scores, 
37 | def kd_loss(scores, masks):
38 |     logits = tf.reshape(scores, [-1, 2])  # (?, 2)
39 |     masks_foreground = tf.reshape(masks, [-1])  # foreground goes to value 1
40 |     labels = tf.stack([1 - masks_foreground, masks_foreground], axis=1)
41 |     loss = softmax_cross_entropy_with_logits_v2(
42 |         logits=logits,
43 |         labels=tf.stop_gradient(labels),
44 |     )
45 | 
46 |     cost = tf.reduce_mean(loss)
47 |     return cost
48 | 
-------------------------------------------------------------------------------- /factory/matting_converter.py: --------------------------------------------------------------------------------
 1 | from abc import ABC
 2 | from abc import abstractmethod
 3 | import tensorflow as tf
 4 | 
 5 | 
 6 | class ConverterBase(ABC):
 7 |     @classmethod
 8 |     @abstractmethod
 9 |     def convert(
10 |         cls,
11 |         logits: tf.Tensor,
12 |         output_name: str,
13 |         num_classes: int,
14 |     ):
15 |         raise NotImplementedError(f"convert() not defined in {cls.__class__.__name__}")
16 | 
17 | 
18 | class ProbConverter(ConverterBase):
19 |     @classmethod
20 |     def convert(
21 |         cls,
22 |         logits: tf.Tensor,
23 |         output_name: str,
24 |         num_classes: int,
25 |     ):
26 |         assert num_classes == 2
27 | 
28 |         softmax_scores = tf.contrib.layers.softmax(logits, scope="output/softmax")
29 |         # tf.identity to assign output_name
30 |         output = tf.identity(softmax_scores, name=output_name)
31 |         return output
32 | 
-------------------------------------------------------------------------------- /factory/matting_nets.py: --------------------------------------------------------------------------------
 1 | import tensorflow as tf
 2 | import tensorflow.contrib.slim as slim
 3 | 
 4 | import common.utils as utils
 5 | import factory.losses_funcs as losses_funcs
 6 | import matting_nets.deeplab_v3plus as deeplab
 7 | from matting_nets import mmnet
 8 | from factory.base import CNNModel
 9 | from factory.matting_converter import ProbConverter
10 | 
11 | 
12 | _available_nets = [
13 |     "MMNetModel",
14 |     "DeepLabModel",
15 | ]
16 | 
17 | 
18 | class MattingNetModel(CNNModel):
19 |     def __init__(self, args, dataset=None):
20 |         self.log = utils.get_logger("MattingNetModel")
21 |         self.args = args
22 |         self.dataset = dataset  # used to access data created in DataWrapper
23 | 
24 |     def build(
25 |         self,
26 |         images_original: tf.Tensor,
27 |         images: tf.Tensor,
28 |         masks_original: tf.Tensor,
29 |         masks: tf.Tensor,
30 |         is_training: bool,
31 |     ):
32 |         self._images_original = images_original
33 |         self._images = images
34 |         self._masks = masks
35 |         self._masks_original = masks_original
36 | 
37 |         # -- * -- build_output: it includes build_inference
38 |         inputs, logits, outputs, self.endpoints = self.build_output(
39 |             self.images,
40 |             is_training,
41 |             self.args.output_name,
42 |         )
43 |         self.prob_scores = outputs
44 | 
45 |         # -- * -- Build loss function: it can differ between models
46 |         self._total_loss, self._model_loss, self.endpoints_loss = self.build_loss(
47 |             inputs,
48 |             logits,
49 |             outputs,
50 |             self.masks,
51 |         )
52 | 
53 |         # -- * -- build image ops for summary
54 |         self._image_ops = self.build_image_ops(self.images, self.prob_scores, self.masks)
55 | 
56 |         self.total_params = self.build_finish(is_training, self.log)
57 | 
58 |     @property
59 |     def total_loss(self):
60 |         return self._total_loss
61 | 
62 |     @property
63 |     def model_loss(self):
64 |         return self._model_loss
65 | 
66 |     def build_output(
67 |         self,
68 |         images,
69 |         is_training,
70 |         output_name,
71 |     ):
72 |         inputs = self.build_images(images)
73 |         logits, endpoints = self.build_inference(
74 |             inputs,
75 |             is_training=is_training,
76 |         )
77 |         outputs = ProbConverter.convert(
78 | logits, 79 | output_name, 80 | self.args.num_classes, 81 | ) 82 | return inputs, logits, outputs, endpoints 83 | 84 | def build_images( 85 | self, 86 | images: tf.Tensor, 87 | ): 88 | images_preprocessed = self.preprocess_images( 89 | images, 90 | preprocess_method=self.args.preprocess_method, 91 | ) 92 | 93 | return images_preprocessed 94 | 95 | def build_inference(self, images, is_training=True): 96 | raise NotImplementedError 97 | 98 | def build_loss( 99 | self, 100 | inputs: tf.Tensor, 101 | logits: tf.Tensor, 102 | scores: tf.Tensor, 103 | masks: tf.Tensor, 104 | ): 105 | endpoints_loss = {} 106 | alpha_scores = scores[:, :, :, -1:] 107 | masks_float32 = tf.expand_dims(tf.cast(masks, dtype=tf.float32), axis=3) 108 | 109 | sum_lambda = 0. 110 | 111 | # alpha loss 112 | if self.args.lambda_alpha_loss > 0: 113 | endpoints_loss["alpha_loss"] = losses_funcs.alpha_loss( 114 | alpha_scores=alpha_scores, 115 | masks_float32=masks_float32, 116 | ) 117 | endpoints_loss["alpha_loss"] *= self.args.lambda_alpha_loss 118 | sum_lambda += self.args.lambda_alpha_loss 119 | 120 | # compositional loss 121 | if self.args.lambda_comp_loss > 0: 122 | endpoints_loss["comp_loss"] = losses_funcs.comp_loss( 123 | alpha_scores=alpha_scores, 124 | masks_float32=masks_float32, 125 | images=inputs, 126 | ) 127 | endpoints_loss["comp_loss"] *= self.args.lambda_comp_loss 128 | sum_lambda += self.args.lambda_comp_loss 129 | 130 | # gradient loss 131 | if self.args.lambda_grad_loss > 0: 132 | endpoints_loss["grad_loss"] = losses_funcs.grad_loss( 133 | alpha_scores=alpha_scores, 134 | masks_float32=masks_float32, 135 | ) 136 | endpoints_loss["grad_loss"] *= self.args.lambda_grad_loss 137 | sum_lambda += self.args.lambda_grad_loss 138 | 139 | # kd loss 140 | if self.args.lambda_kd_loss > 0: 141 | endpoints_loss["kd_loss"] = losses_funcs.kd_loss( 142 | scores=logits, 143 | masks=masks, 144 | ) 145 | endpoints_loss["kd_loss"] *= self.args.lambda_kd_loss 146 | sum_lambda += self.args.lambda_kd_loss 147 | 148 | # kd aux loss 149 | if self.args.lambda_aux_loss > 0: 150 | _, h, w, _ = self.encoded_map.get_shape().as_list() 151 | compress_mask = tf.image.resize_images( 152 | tf.expand_dims(masks, 3), 153 | [h, w], 154 | tf.image.ResizeMethod.NEAREST_NEIGHBOR 155 | ) 156 | compress_mask = tf.squeeze(compress_mask, 3) 157 | 158 | endpoints_loss["aux_loss"] = losses_funcs.kd_loss( 159 | scores=self.encoded_map, 160 | masks=compress_mask, 161 | ) 162 | endpoints_loss["aux_loss"] *= self.args.lambda_aux_loss 163 | sum_lambda += self.args.lambda_aux_loss 164 | 165 | model_loss = tf.add_n(list(endpoints_loss.values())) / sum_lambda 166 | 167 | if len(tf.losses.get_regularization_losses()) > 0: 168 | reg_loss = tf.add_n(tf.losses.get_regularization_losses()) 169 | else: 170 | reg_loss = tf.constant(0.) 
171 | 172 | endpoints_loss.update({ 173 | "regularize_loss": reg_loss, 174 | }) 175 | 176 | total_loss = model_loss + reg_loss 177 | return total_loss, model_loss, endpoints_loss 178 | 179 | @staticmethod 180 | def build_image_ops(images, prob_scores, masks): 181 | alpha_scores = prob_scores[:, :, :, -1:] 182 | binary_scores = tf.cast((alpha_scores > 0.5), tf.float32) 183 | masks_casted = tf.expand_dims(masks, -1) 184 | images_under_binary = images * binary_scores 185 | images_under_prob = images * alpha_scores 186 | 187 | image_ops = { 188 | "images": images, 189 | "prob_scores": alpha_scores * 255, 190 | "binary_scores": binary_scores * 255, 191 | "images_under_binary": images_under_binary, 192 | "images_under_prob": images_under_prob, 193 | "masks": masks_casted * 255, 194 | } 195 | 196 | return {k: tf.cast(op, tf.uint8) for k, op in image_ops.items()} 197 | 198 | @property 199 | def images_original(self): 200 | return self._images_original 201 | 202 | @property 203 | def images(self): 204 | return self._images 205 | 206 | @property 207 | def image_ops(self): 208 | return self._image_ops 209 | 210 | @property 211 | def masks_original(self): 212 | return self._masks_original 213 | 214 | @property 215 | def masks(self): 216 | return self._masks 217 | 218 | @staticmethod 219 | def add_arguments(parser): 220 | parser.add_argument("--lambda_alpha_loss", type=float, default=0.0) 221 | parser.add_argument("--lambda_comp_loss", type=float, default=0.0) 222 | parser.add_argument("--lambda_grad_loss", type=float, default=0.0) 223 | parser.add_argument("--lambda_kd_loss", type=float, default=0.0) 224 | parser.add_argument("--lambda_aux_loss", type=float, default=0.0) 225 | 226 | 227 | class MMNetModel(MattingNetModel): 228 | def __init__(self, args, dataset=None): 229 | super().__init__(args, dataset) 230 | 231 | def build_inference(self, images, is_training=True): 232 | with slim.arg_scope(mmnet.MMNet_arg_scope(use_fused_batchnorm=self.args.use_fused_batchnorm, 233 | weight_decay=self.args.weight_decay, 234 | dropout=self.args.dropout)): 235 | logit_scores, endpoints = mmnet.MMNet( 236 | images, 237 | is_training, 238 | depth_multiplier=self.args.width_multiplier, 239 | ) 240 | self.encoded_map = endpoints["aux_block"] 241 | return logit_scores, endpoints 242 | 243 | @staticmethod 244 | def add_arguments(parser): 245 | parser.add_argument("--width_multiplier", default=1.0, type=float) 246 | parser.add_argument("--weight_decay", type=float, default=0.0) 247 | parser.add_argument("--dropout", type=float, default=0.0) 248 | 249 | 250 | class DeepLabModel(MattingNetModel): 251 | def __init__(self, args, dataset=None): 252 | super(DeepLabModel, self).__init__(args, dataset) 253 | 254 | def build_inference(self, images, is_training=True): 255 | with slim.arg_scope(deeplab.DeepLab_arg_scope(weight_decay=self.args.weight_decay)): 256 | logit_scores, endpoints = deeplab.DeepLab(images, 257 | self.args, 258 | is_training=is_training, 259 | depth_multiplier=self.args.depth_multiplier, 260 | output_stride=self.args.output_stride) 261 | return logit_scores, endpoints 262 | 263 | @staticmethod 264 | def add_arguments(parser): 265 | parser.add_argument("--extractor", type=str, required=True, choices=deeplab.available_extractors) 266 | parser.add_argument("--weight_decay", type=float, default=0.00004) 267 | parser.add_argument("--depth_multiplier", default=1.0, type=float, help="MobileNetV2 depth_multiplier") 268 | parser.add_argument("--output_stride", default=16, type=int) 269 | 
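
# A minimal numeric sketch of how MattingNetModel.build_loss combines the
# terms above (hypothetical lambda weights and raw loss values): each active
# loss is scaled by its lambda, and the sum is normalized by the total lambda
# mass, so the magnitude stays comparable across configurations:
if __name__ == "__main__":
    lambdas = {"alpha_loss": 0.5, "grad_loss": 0.25, "kd_loss": 0.25}
    raw_losses = {"alpha_loss": 0.08, "grad_loss": 0.12, "kd_loss": 0.40}
    weighted = {key: lambdas[key] * raw_losses[key] for key in raw_losses}
    model_loss = sum(weighted.values()) / sum(lambdas.values())
    print(model_loss)  # (0.04 + 0.03 + 0.10) / 1.0 = 0.17
    # total_loss then adds the regularization loss on top, as in build_loss.
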
-------------------------------------------------------------------------------- /figure/gradient_error_vs_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyperconnect/MMNet/aa423b598f38e051fcf46914cf8a7765458c09d2/figure/gradient_error_vs_latency.png -------------------------------------------------------------------------------- /helper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyperconnect/MMNet/aa423b598f38e051fcf46914cf8a7765458c09d2/helper/__init__.py -------------------------------------------------------------------------------- /helper/base.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | from abc import ABC 6 | from abc import abstractmethod 7 | from termcolor import colored 8 | 9 | from common.utils import Timer 10 | from common.utils import get_logger 11 | from common.utils import format_log 12 | from common.utils import format_text 13 | from metrics.summaries import BaseSummaries 14 | 15 | 16 | class Base(ABC): 17 | def __init__(self): 18 | self.log = get_logger("Base") 19 | self.timer = Timer(self.log) 20 | 21 | def get_feed_dict( 22 | self, 23 | is_training: bool=False, 24 | ): 25 | feed_dict = dict() 26 | 27 | if is_training and self.lr_scheduler.should_feed_dict: 28 | feed_dict[self.lr_scheduler.placeholder] = self.lr_scheduler.learning_rate 29 | 30 | return feed_dict 31 | 32 | def check_batch_size( 33 | self, 34 | batch_size: int, 35 | terminate: bool=False, 36 | ): 37 | if batch_size != self.args.batch_size: 38 | self.log.info(colored(f"Batch size: required {self.args.batch_size}, obtained {batch_size}", "red")) 39 | if terminate: 40 | raise tf.errors.OutOfRangeError(None, None, "Finished looping dataset.") 41 | 42 | def build_iters_from_batch_size(self, num_samples, batch_size): 43 | iters = self.dataset.num_samples // self.args.batch_size 44 | num_ignored_samples = self.dataset.num_samples % self.args.batch_size 45 | if num_ignored_samples > 0: 46 | with format_text("red", attrs=["bold"]) as fmt: 47 | msg = ( 48 | f"Number of samples cannot be divided by batch_size, " 49 | f"so it ignores some data examples in evaluation: " 50 | f"{self.dataset.num_samples} % {self.args.batch_size} = {num_ignored_samples}" 51 | ) 52 | self.log.warning(fmt(msg)) 53 | return iters 54 | 55 | def build_basic_image_ops(self): 56 | images = { 57 | "images": self.model.images, 58 | } 59 | if hasattr(self.model, "images_original") and self.model.images_original is not None: 60 | images["images_before_augmentation"] = self.model.images_original 61 | 62 | if hasattr(self.model, "image_ops"): 63 | for key, op in self.model.image_ops.items(): 64 | images[key] = op 65 | 66 | return images 67 | 68 | @abstractmethod 69 | def build_evaluation_fetch_ops(self, do_eval): 70 | raise NotImplementedError 71 | 72 | def run_inference( 73 | self, 74 | global_step: int, 75 | iters: int=None, 76 | is_training: bool=False, 77 | do_eval: bool=True, 78 | ): 79 | """ 80 | Return: Dict[metric_key] -> np.array 81 | array is stacked values for all batches 82 | """ 83 | feed_dict = self.get_feed_dict(is_training=is_training) 84 | 85 | is_first_batch = True 86 | 87 | if iters is None: 88 | iters = self.build_iters_from_batch_size(self.dataset.num_samples, self.args.batch_size) 89 | 90 | merged_tensor_type_summaries = 
self.metric_manager.summary.get_merged_summaries(
91 |             collection_key_suffixes=[BaseSummaries.KEY_TYPES.DEFAULT],
92 |             is_tensor_summary=True
93 |         )
94 | 
95 |         fetch_ops = self.build_evaluation_fetch_ops(do_eval)
96 | 
97 |         aggregator = {key: list() for key in fetch_ops}
98 |         aggregator.update({
99 |             "batch_infer_time": list(),
100 |             "unit_infer_time": list(),
101 |         })
102 | 
103 |         for i in range(iters):
104 |             try:
105 |                 st = time.time()
106 | 
107 |                 is_running_summary = do_eval and is_first_batch and merged_tensor_type_summaries is not None
108 |                 if is_running_summary:
109 |                     fetch_ops_with_summary = {"summary": merged_tensor_type_summaries}
110 |                     fetch_ops_with_summary.update(fetch_ops)
111 | 
112 |                     fetch_vals = self.session.run(fetch_ops_with_summary, feed_dict=feed_dict)
113 | 
114 |                     # To avoid duplicating the session.run call, we also evaluate the merged summaries here.
115 |                     # Because we run multiple batches within a single global_step,
116 |                     # merged_summaries can have duplicated values.
117 |                     # So we write them only on the first session.run
118 |                     self.metric_manager.write_tensor_summaries(global_step, fetch_vals["summary"])
119 |                     is_first_batch = False
120 |                 else:
121 |                     fetch_vals = self.session.run(fetch_ops, feed_dict=feed_dict)
122 | 
123 |                 batch_infer_time = (time.time() - st) * 1000  # use milliseconds
124 | 
125 |                 # aggregate
126 |                 for key, fetch_val in fetch_vals.items():
127 |                     if key in aggregator:
128 |                         aggregator[key].append(fetch_val)
129 | 
130 |                 # add inference time
131 |                 aggregator["batch_infer_time"].append(batch_infer_time)
132 |                 aggregator["unit_infer_time"].append(batch_infer_time / self.args.batch_size)
133 | 
134 |             except tf.errors.OutOfRangeError:
135 |                 format_log(self.log.info, "yellow")(f"Reached end of the dataset.")
136 |                 break
137 |             except tf.errors.InvalidArgumentError as e:
138 |                 format_log(self.log.error, "red")(f"Invalid image is detected: {e}")
139 |                 continue
140 | 
141 |         aggregator = {k: np.vstack(v) for k, v in aggregator.items()}
142 |         return aggregator
143 | 
144 |     def run_evaluation(
145 |         self,
146 |         global_step: int,
147 |         iters: int=None,
148 |         is_training: bool=False,
149 |     ):
150 |         eval_dict = self.run_inference(global_step, iters, is_training, do_eval=True)
151 | 
152 |         non_tensor_data = self.build_non_tensor_data_from_eval_dict(eval_dict, step=global_step)
153 | 
154 |         self.metric_manager.evaluate_and_aggregate_metrics(step=global_step,
155 |                                                            non_tensor_data=non_tensor_data,
156 |                                                            eval_dict=eval_dict)
157 | 
158 |         eval_metric_dict = self.metric_manager.get_evaluation_result(step=global_step)
159 | 
160 |         return eval_metric_dict
161 | 
162 |     @staticmethod
163 |     def get_variables_to_restore(args, log=None, debug=False):
164 |         if args.use_ema:
165 |             ema = tf.train.ExponentialMovingAverage(decay=args.ema_decay)
166 |             variables_to_restore = ema.variables_to_restore()  # dictionary
167 |             variables_to_restore_to_print = variables_to_restore.values()
168 |         else:
169 |             variables_to_restore = tf.contrib.framework.get_variables_to_restore()
170 |             variables_to_restore_to_print = variables_to_restore
171 | 
172 |         if debug and log is not None:
173 |             for var in variables_to_restore_to_print:
174 |                 log.info("var > {} {}".format(var.name, var.get_shape().as_list()))
175 | 
176 |         return variables_to_restore
177 | 
178 |     @staticmethod
179 |     def add_arguments(parser):
180 |         g_base = parser.add_argument_group("Base")
181 | 
182 |         """`inference` argument is used for two purposes.
183 | 
184 |         1) To decide if labels should be provided, it is used
185 |         in `DataWrapperBase` constructor as:
186 |             need_label = not args.inference
187 | 
188 |         2) To decide if we should run inference or evaluation. It is used
189 |         in evaluate.py.
190 | 
191 |         For training we do not expect this argument to be used, therefore
192 |         it defaults to False.
193 |         """
194 |         g_base.add_argument("--evaluate", dest="inference", action="store_false")
195 |         g_base.add_argument("--inference", dest="inference", action="store_true")
196 |         g_base.set_defaults(inference=False)
197 | 
198 |         g_base.add_argument("--no-use_ema", dest="use_ema", action="store_false")
199 |         g_base.add_argument("--use_ema", dest="use_ema", action="store_true",
200 |                             help="Exponential Moving Average. It may take more memory.")
201 |         g_base.set_defaults(use_ema=False)
202 |         g_base.add_argument("--ema_decay", default=0.999, type=float,
203 |                             help=("Exponential Moving Average decay.\n"
204 |                                   "Reasonable values for decay are close to 1.0, typically "
205 |                                   "in the multiple-nines range: 0.999, 0.9999"))
206 |         g_base.add_argument("--evaluation_iterations", type=int, default=None)
207 |         g_base.add_argument("--target_eval_shape", type=int, nargs="*",
208 |                             help="[H, W] Resize the outputs to this value for evaluation.")
209 | 
210 | 
211 | class MattingBase(Base):
212 |     def build_evaluation_fetch_ops(self, do_eval):
213 |         if do_eval:
214 |             # delete 'images' since it is already added at trainer
215 |             fetch_ops = {
216 |                 "probs": self.model.prob_scores,
217 |                 "total_loss": self.model.total_loss,
218 |             }
219 |             fetch_ops.update(self.metric_tf_op)
220 |         else:
221 |             fetch_ops = {
222 |                 "images": self.model.images,
223 |                 "probs": self.model.prob_scores,
224 |             }
225 | 
226 |         return fetch_ops
227 | 
228 |     def build_misc_image_ops(self):
229 |         images = dict()
230 | 
231 |         if self.args.save_evaluation_image:
232 |             assert hasattr(self.model, "image_ops")
233 |             for key, op in self.model.image_ops.items():
234 |                 # squeeze is added because of PIL
235 |                 if op.get_shape().as_list()[-1] == 1:
236 |                     op = tf.squeeze(op, axis=-1)
237 |                 images[key] = op
238 | 
239 |         return images
240 | 
241 |     def build_basic_loss_ops(self):
242 |         losses = {
243 |             "total_loss": self.model.total_loss,
244 |             "model_loss": self.model.model_loss,
245 |         }
246 |         losses.update(self.model.endpoints_loss)
247 | 
248 |         return losses
249 | 
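
# A minimal, standalone sketch of the paired-flag pattern used throughout
# add_arguments above (--inference/--evaluate, --use_ema/--no-use_ema, ...):
# two actions write to one destination, and set_defaults picks the fallback.
# The parser built here is hypothetical:
if __name__ == "__main__":
    import argparse
    sketch_parser = argparse.ArgumentParser()
    sketch_parser.add_argument("--no-use_ema", dest="use_ema", action="store_false")
    sketch_parser.add_argument("--use_ema", dest="use_ema", action="store_true")
    sketch_parser.set_defaults(use_ema=False)
    print(sketch_parser.parse_args([]).use_ema)              # False (default)
    print(sketch_parser.parse_args(["--use_ema"]).use_ema)   # True
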
-------------------------------------------------------------------------------- /helper/evaluator.py: --------------------------------------------------------------------------------
 1 | from pathlib import Path
 2 | from abc import abstractmethod
 3 | import sys
 4 | 
 5 | import numpy as np
 6 | import tensorflow as tf
 7 | from scipy.misc import imsave
 8 | from tqdm import tqdm
 9 | 
10 | import common.tf_utils as tf_utils
11 | import metrics.manager as metric_manager
12 | from common.model_loader import Ckpt
13 | from common.utils import format_text
14 | from common.utils import get_logger
15 | from helper.base import Base
16 | from helper.base import MattingBase
17 | from metrics.summaries import BaseSummaries
18 | from metrics.summaries import Summaries
19 | 
20 | 
21 | class Evaluator(object):
22 |     _available_inference_output = None
23 | 
24 |     def __init__(self, model, session, args, dataset, dataset_name, name):
25 |         self.log = get_logger(name)
26 | 
27 |         self.model = model
28 |         self.session = session
29 |         self.args = args
30 |         self.dataset = dataset
31 |         self.dataset_name = dataset_name
32 | 
33 |         if Path(self.args.checkpoint_path).is_dir():
34 |             latest_checkpoint = tf.train.latest_checkpoint(self.args.checkpoint_path)
35 |             if latest_checkpoint is not None:
36 |                 self.args.checkpoint_path = latest_checkpoint
37 |             self.log.info(f"Got latest checkpoint and updated to: {self.args.checkpoint_path}")
38 | 
39 |         self.watch_path = self._build_watch_path()
40 | 
41 |         self.variables_to_restore = Base.get_variables_to_restore(args=self.args, log=self.log, debug=False)
42 |         self.session.run(tf.global_variables_initializer())
43 |         self.session.run(tf.local_variables_initializer())
44 | 
45 |         self.ckpt_loader = Ckpt(
46 |             session=session,
47 |             variables_to_restore=self.variables_to_restore,
48 |             include_scopes=args.checkpoint_include_scopes,
49 |             exclude_scopes=args.checkpoint_exclude_scopes,
50 |             ignore_missing_vars=args.ignore_missing_vars,
51 |         )
52 | 
53 |     @abstractmethod
54 |     def setup_metric_manager(self):
55 |         raise NotImplementedError
56 | 
57 |     @abstractmethod
58 |     def setup_metric_ops(self):
59 |         raise NotImplementedError
60 | 
61 |     @abstractmethod
62 |     def build_non_tensor_data_from_eval_dict(self, eval_dict, **kwargs):
63 |         raise NotImplementedError
64 | 
65 |     @abstractmethod
66 |     def setup_dataset_iterator(self):
67 |         raise NotImplementedError
68 | 
69 |     @abstractmethod
70 |     def save_inference_result(self, eval_dict, checkpoint_path):
71 |         raise NotImplementedError
72 | 
73 |     def _build_watch_path(self):
74 |         if Path(self.args.checkpoint_path).is_dir():
75 |             return Path(self.args.checkpoint_path)
76 |         else:
77 |             return Path(self.args.checkpoint_path).parent
78 | 
79 |     def build_evaluation_step(self, checkpoint_path):
80 |         if "-" in checkpoint_path and checkpoint_path.split("-")[-1].isdigit():
81 |             return int(checkpoint_path.split("-")[-1])
82 |         else:
83 |             return 0
84 | 
85 |     def build_checkpoint_paths(self, checkpoint_path):
86 |         checkpoint_glob = Path(checkpoint_path + "*")
87 |         checkpoint_path = Path(checkpoint_path)
88 | 
89 |         return checkpoint_glob, checkpoint_path
90 | 
91 |     def build_miscellaneous_path(self, name):
92 |         target_dir = self.watch_path / "miscellaneous" / self.dataset_name / name
93 | 
94 |         if not target_dir.exists():
95 |             target_dir.mkdir(parents=True)
96 | 
97 |         return target_dir
98 | 
99 |     def build_inference_path(self, checkpoint_path):
100 |         if not isinstance(checkpoint_path, Path):
101 |             checkpoint_path = Path(checkpoint_path)
102 |         if checkpoint_path.is_dir():
103 |             root_dir = checkpoint_path
104 |         else:
105 |             root_dir = checkpoint_path.parent
106 | 
107 |         output_parent_dir = root_dir / "inference" / checkpoint_path.name
108 |         if self.args.inference_output_dirname is not None:
109 |             output_dir = output_parent_dir / self.args.inference_output_dirname
110 |         else:
111 |             output_dir = output_parent_dir / self.args.inference_output
112 | 
113 |         return output_dir
114 | 
115 |     def setup_best_keeper(self):
116 |         metric_with_modes = self.metric_manager.get_best_keep_metric_with_modes()
117 |         self.log.debug(metric_with_modes)
118 |         self.best_keeper = tf_utils.BestKeeper(metric_with_modes,
119 |                                                self.dataset_name,
120 |                                                self.watch_path,
121 |                                                self.log)
122 | 
123 |     def inference(self, checkpoint_path):
124 |         assert not self.args.shuffle, "Current implementation of `inference` requires a non-shuffled dataset"
125 |         _data_num = self.dataset.num_samples
126 |         if _data_num % self.args.batch_size != 0:
127 |             with format_text("red", attrs=["bold"]) as fmt:
128 |                 self.log.warning(fmt(f"Among {_data_num} data, last {_data_num % self.args.batch_size} items will not"
129 |                                      f" be processed during inference."))
130 | 
131 |         if self.args.inference_output not 
in self._available_inference_output: 132 | raise ValueError(f"Inappropriate inference_output type for " 133 | f"{self.__class__.__name__}: {self.args.inference_output}.\n" 134 | f"Available outputs are {self._available_inference_output}") 135 | 136 | self.log.info("Inference started") 137 | self.setup_dataset_iterator() 138 | self.ckpt_loader.load(checkpoint_path) 139 | 140 | step = self.build_evaluation_step(checkpoint_path) 141 | checkpoint_glob, checkpoint_path = self.build_checkpoint_paths(checkpoint_path) 142 | self.session.run(tf.local_variables_initializer()) 143 | 144 | eval_dict = self.run_inference(step, is_training=False, do_eval=False) 145 | 146 | self.save_inference_result(eval_dict, checkpoint_path) 147 | 148 | def evaluate_once(self, checkpoint_path): 149 | self.log.info("Evaluation started") 150 | self.setup_dataset_iterator() 151 | self.ckpt_loader.load(checkpoint_path) 152 | 153 | step = self.build_evaluation_step(checkpoint_path) 154 | checkpoint_glob, checkpoint_path = self.build_checkpoint_paths(checkpoint_path) 155 | self.session.run(tf.local_variables_initializer()) 156 | 157 | eval_metric_dict = self.run_evaluation(step, is_training=False) 158 | best_keep_metric_dict = self.metric_manager.filter_best_keep_metric(eval_metric_dict) 159 | is_keep, metrics_keep = self.best_keeper.monitor(self.dataset_name, best_keep_metric_dict) 160 | 161 | if self.args.save_best_keeper: 162 | meta_info = { 163 | "step": step, 164 | "model_size": self.model.total_params, 165 | } 166 | self.best_keeper.remove_old_best(self.dataset_name, metrics_keep) 167 | self.best_keeper.save_best(self.dataset_name, metrics_keep, checkpoint_glob) 168 | self.best_keeper.remove_temp_dir() 169 | self.best_keeper.save_scores(self.dataset_name, metrics_keep, best_keep_metric_dict, meta_info) 170 | 171 | self.metric_manager.write_evaluation_summaries(step=step, 172 | collection_keys=[BaseSummaries.KEY_TYPES.DEFAULT]) 173 | self.metric_manager.log_metrics(step=step) 174 | 175 | self.log.info("Evaluation finished") 176 | 177 | if step >= self.args.max_step_from_restore: 178 | self.log.info("Evaluation stopped") 179 | sys.exit() 180 | 181 | def build_train_directory(self): 182 | if Path(self.args.checkpoint_path).is_dir(): 183 | return str(self.args.checkpoint_path) 184 | else: 185 | return str(Path(self.args.checkpoint_path).parent) 186 | 187 | @staticmethod 188 | def add_arguments(parser): 189 | g = parser.add_argument_group("(Evaluator) arguments") 190 | 191 | g.add_argument( 192 | "--inference_output", 193 | type=str, 194 | default="none", 195 | ) 196 | g.add_argument( 197 | "--inference_output_dirname", 198 | type=str, 199 | default=None, 200 | ) 201 | 202 | g.add_argument("--valid_type", default="loop", type=str, choices=["loop", "once"]) 203 | g.add_argument("--max_outputs", default=5, type=int) 204 | 205 | g.add_argument("--no-convert_to_pb", dest="convert_to_pb", action="store_false") 206 | g.add_argument("--convert_to_pb", dest="convert_to_pb", action="store_true") 207 | g.set_defaults(convert_to_pb=True) 208 | 209 | g.add_argument("--no-save_evaluation_image", dest="save_evaluation_image", action="store_false") 210 | g.add_argument("--save_evaluation_image", dest="save_evaluation_image", action="store_true") 211 | g.set_defaults(save_evaluation_image=False) 212 | 213 | g.add_argument("--no-save_best_keeper", dest="save_best_keeper", action="store_false") 214 | g.add_argument("--save_best_keeper", dest="save_best_keeper", action="store_true") 215 | g.set_defaults(save_best_keeper=True) 216 | 217 | 
g.add_argument("--max_step_from_restore", default=1e20, type=int) 218 | 219 | 220 | class MattingEvaluator(Evaluator, MattingBase): 221 | _available_inference_output = ["image_under_prob", "binary_mask", "prob_mask", "image_and_mask"] 222 | 223 | def __init__(self, model, session, args, dataset, dataset_name): 224 | super().__init__(model, session, args, dataset, dataset_name, "MattingEvaluator") 225 | 226 | self.setup_metric_manager() 227 | self.setup_metric_ops() 228 | self.setup_best_keeper() 229 | 230 | def setup_metric_manager(self): 231 | self.metric_manager = metric_manager.MattingMetricManager( 232 | is_training=False, 233 | save_evaluation_image=self.args.save_evaluation_image, 234 | exclude_metric_names=self.args.exclude_metric_names, 235 | summary=Summaries( 236 | session=self.session, 237 | train_dir=self.build_train_directory(), 238 | is_training=False, 239 | max_image_outputs=self.args.max_image_outputs 240 | ), 241 | ) 242 | 243 | def setup_metric_ops(self): 244 | losses = self.build_basic_loss_ops() 245 | summary_images = self.build_basic_image_ops() 246 | misc_images = self.build_misc_image_ops() 247 | 248 | self.metric_tf_op = self.metric_manager.build_metric_ops({ 249 | "dataset_split_name": self.dataset_name, 250 | "target_eval_shape": self.args.target_eval_shape, 251 | 252 | "losses": losses, 253 | "summary_images": summary_images, 254 | "misc_images": misc_images, 255 | "masks": self.model.masks, 256 | "masks_original": self.model.masks_original, 257 | "probs": self.model.prob_scores, 258 | }) 259 | 260 | def setup_dataset_iterator(self): 261 | self.dataset.setup_iterator( 262 | self.session, 263 | (self.dataset.image_placeholder, self.dataset.mask_placeholder), 264 | (self.dataset.image_fullpathes, self.dataset.mask_fullpathes), 265 | ) 266 | 267 | def build_non_tensor_data_from_eval_dict(self, eval_dict, **kwargs): 268 | return { 269 | "dataset_split_name": self.dataset_name, 270 | 271 | "batch_infer_time": eval_dict["batch_infer_time"], 272 | "unit_infer_time": eval_dict["unit_infer_time"], 273 | "misc_images": dict(filter(lambda x: x[0].startswith("misc_images/"), eval_dict.items())), 274 | "image_save_dir": self.build_miscellaneous_path("images"), 275 | } 276 | 277 | def save_inference_result(self, eval_dict, checkpoint_path): 278 | output_dir = self.build_inference_path(checkpoint_path) 279 | if not output_dir.is_dir(): 280 | output_dir.mkdir(parents=True) 281 | self.log.info(f"Make directory {output_dir}") 282 | 283 | self.log.info(f"Inference results will be saved under {output_dir}") 284 | 285 | predictions = eval_dict["probs"] 286 | images = eval_dict["images"] 287 | 288 | for prediction, image, image_fullpath in tqdm(zip(predictions, images, self.dataset.image_fullpathes)): 289 | prediction = prediction[:, :, 1:] 290 | imagename = Path(image_fullpath).name 291 | prediction_normed = (prediction - prediction.min()) / (prediction.max() - prediction.min()) 292 | if self.args.inference_output == "image_under_prob": 293 | _output = np.squeeze(np.expand_dims(prediction_normed, 0) * image) 294 | elif self.args.inference_output == "binary_mask": 295 | _output = np.squeeze(prediction_normed > 0.5) * 255 296 | elif self.args.inference_output == "prob_mask": 297 | _output = np.squeeze(prediction_normed) * 255 298 | elif self.args.inference_output == "image_and_mask": 299 | _output = np.concatenate([image, image * prediction, np.tile(prediction * 255, 3)], axis=0) 300 | _output = _output.round().astype(np.uint8) 301 | imsave(output_dir / imagename, _output) 302 | 303 | 
self.log.info(f"Inference results saved under {output_dir}") 304 | -------------------------------------------------------------------------------- /helper/trainer.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | from pathlib import Path 4 | from abc import ABC 5 | from abc import abstractmethod 6 | from typing import Dict 7 | 8 | import humanfriendly as hf 9 | import tensorflow as tf 10 | import numpy as np 11 | import tensorflow.contrib.slim as slim 12 | from termcolor import colored 13 | 14 | import common.tf_utils as tf_utils 15 | import common.utils as utils 16 | import common.lr_scheduler as lr_scheduler 17 | import metrics.manager as metric_manager 18 | from helper.base import MattingBase 19 | from common.model_loader import Ckpt 20 | from metrics.summaries import BaseSummaries 21 | from metrics.summaries import Summaries 22 | 23 | 24 | class TrainerBase(ABC): 25 | def __init__(self, model, session, args, dataset, dataset_name, name): 26 | self.model = model 27 | self.session = session 28 | self.args = args 29 | self.dataset = dataset 30 | self.dataset_name = dataset_name 31 | 32 | self.log = utils.get_logger(name) 33 | self.timer = utils.Timer(self.log) 34 | 35 | self.info_red = utils.format_log(self.log.info, "red") 36 | self.info_cyan = utils.format_log(self.log.info, "cyan") 37 | self.info_magenta = utils.format_log(self.log.info, "magenta") 38 | self.info_magenta_reverse = utils.format_log(self.log.info, "magenta", attrs=["reverse"]) 39 | self.info_cyan_underline = utils.format_log(self.log.info, "cyan", attrs=["underline"]) 40 | self.debug_red = utils.format_log(self.log.debug, "red") 41 | self._saver = None 42 | 43 | # used in `log_step_message` method 44 | self.last_loss = dict() 45 | 46 | @property 47 | def etc_fetch_namespace(self): 48 | return utils.MLNamespace( 49 | step_op="step_op", 50 | global_step="global_step", 51 | summary="summary", 52 | ) 53 | 54 | @property 55 | def loss_fetch_namespace(self): 56 | return utils.MLNamespace( 57 | total_loss="total_loss", 58 | model_loss="model_loss", 59 | ) 60 | 61 | @property 62 | def summary_fetch_namespace(self): 63 | return utils.MLNamespace( 64 | merged_summaries="merged_summaries", 65 | merged_verbose_summaries="merged_verbose_summaries", 66 | merged_first_n_summaries="merged_first_n_summaries", 67 | ) 68 | 69 | @property 70 | def after_fetch_namespace(self): 71 | return utils.MLNamespace( 72 | single_step="single_step", 73 | single_step_per_image="single_step_per_image", 74 | ) 75 | 76 | @property 77 | def saver(self): 78 | if self._saver is None: 79 | self._saver = tf.train.Saver(max_to_keep=self.args.max_to_keep) 80 | return self._saver 81 | 82 | @abstractmethod 83 | def setup_metric_manager(self): 84 | raise NotImplementedError 85 | 86 | @abstractmethod 87 | def setup_metric_ops(self): 88 | raise NotImplementedError 89 | 90 | @abstractmethod 91 | def build_non_tensor_data_from_eval_dict(self, eval_dict, **kwargs): 92 | raise NotImplementedError 93 | 94 | @abstractmethod 95 | def build_evaluate_iterations(self, iters: int): 96 | raise NotImplementedError 97 | 98 | def routine_experiment_summary(self): 99 | self.metric_manager.summary.setup_experiment(self.args) 100 | 101 | def setup_essentials(self, max_to_keep=5): 102 | self.no_op = tf.no_op() 103 | self.args.checkpoint_path = tf_utils.resolve_checkpoint_path( 104 | self.args.checkpoint_path, self.log, is_training=True 105 | ) 106 | self.train_dir_name = (Path.cwd() / 
Path(self.args.train_dir)).resolve()
107 | 
108 |         # We use this global step to shift boundaries for the piecewise_constant learning rate.
109 |         # We cannot use the global step from the checkpoint file before restoring from the checkpoint:
110 |         # for restoring, we need to initialize all operations, including the optimizer.
111 |         self.global_step_from_checkpoint = tf_utils.get_global_step_from_checkpoint(self.args.checkpoint_path)
112 |         self.global_step = tf.Variable(self.global_step_from_checkpoint, name="global_step", trainable=False)
113 | 
114 |         self.lr_scheduler = lr_scheduler.factory(
115 |             self.args, self.log, self.global_step_from_checkpoint, self.global_step, self.dataset
116 |         )
117 | 
118 |     def routine_restore_and_initialize(self, checkpoint_path=None):
119 |         """Read various loading methods for tensorflow
120 |         - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/learning.py#L121
121 |         """
122 |         if checkpoint_path is None:
123 |             checkpoint_path = self.args.checkpoint_path
124 |         var_names_to_values = getattr(self.model, "var_names_to_values", None)
125 |         self.session.run(tf.global_variables_initializer())
126 |         self.session.run(tf.local_variables_initializer())  # for metrics
127 | 
128 |         if var_names_to_values is not None:
129 |             init_assign_op, init_feed_dict = slim.assign_from_values(var_names_to_values)
130 |             # Create an initial assignment function.
131 |             self.session.run(init_assign_op, init_feed_dict)
132 |             self.log.info(colored("Restore from memory (usually weights from Caffe!)",
133 |                                   "cyan", attrs=["bold", "underline"]))
134 |         elif checkpoint_path == "" or checkpoint_path is None:
135 |             self.log.info(colored("Initialize global / local variables", "cyan", attrs=["bold", "underline"]))
136 |         else:
137 |             ckpt_loader = Ckpt(
138 |                 session=self.session,
139 |                 include_scopes=self.args.checkpoint_include_scopes,
140 |                 exclude_scopes=self.args.checkpoint_exclude_scopes,
141 |                 ignore_missing_vars=self.args.ignore_missing_vars,
142 |             )
143 |             ckpt_loader.load(checkpoint_path)
144 | 
145 |     def routine_logging_checkpoint_path(self):
146 |         self.log.info(colored("Watch Validation Through TensorBoard !", "yellow", attrs=["underline", "bold"]))
147 |         self.log.info(colored("--checkpoint_path {}".format(self.train_dir_name),
148 |                               "yellow", attrs=["underline", "bold"]))
149 | 
150 |     def build_optimizer(self, optimizer, learning_rate, momentum=None, decay=None, epsilon=None):
151 |         kwargs = {
152 |             "learning_rate": learning_rate
153 |         }
154 |         if momentum:
155 |             kwargs["momentum"] = momentum
156 |         if decay:
157 |             kwargs["decay"] = decay
158 |         if epsilon:
159 |             kwargs["epsilon"] = epsilon
160 | 
161 |         if optimizer == "gd":
162 |             opt = tf.train.GradientDescentOptimizer(**kwargs)
163 |             self.log.info("Use GradientDescentOptimizer")
164 |         elif optimizer == "adam":
165 |             opt = tf.train.AdamOptimizer(**kwargs)
166 |             self.log.info("Use AdamOptimizer")
167 |         elif optimizer == "mom":
168 |             opt = tf.train.MomentumOptimizer(**kwargs)
169 |             self.log.info("Use MomentumOptimizer")
170 |         elif optimizer == "rmsprop":
171 |             opt = tf.train.RMSPropOptimizer(**kwargs)
172 |             self.log.info("Use RMSPropOptimizer")
173 |         else:
174 |             self.log.error("Unknown optimizer: {}".format(optimizer))
175 |             raise NotImplementedError
176 |         return opt
177 | 
178 |     def build_train_op(self, total_loss, optimizer, trainable_scopes, global_step, gradient_multipliers=None):
179 |         # If you use `slim.batch_norm`, then you should include train_op in slim.
180 | # https://github.com/tensorflow/tensorflow/issues/1122#issuecomment-280325584 181 | variables_to_train = tf_utils.get_variables_to_train(trainable_scopes, logger=self.log) 182 | 183 | if variables_to_train: 184 | train_op = slim.learning.create_train_op( 185 | total_loss, 186 | optimizer, 187 | global_step=global_step, 188 | variables_to_train=variables_to_train, 189 | gradient_multipliers=gradient_multipliers, 190 | ) 191 | 192 | if self.args.use_ema: 193 | self.ema = tf.train.ExponentialMovingAverage(decay=self.args.ema_decay) 194 | 195 | with tf.control_dependencies([train_op]): 196 | train_op = self.ema.apply(variables_to_train) 197 | else: 198 | self.log.info("Empty variables_to_train") 199 | train_op = tf.no_op() 200 | 201 | return train_op 202 | 203 | def build_epoch(self, step): 204 | return (step * self.dataset.batch_size) / self.dataset.num_samples 205 | 206 | def build_info_step_message(self, info: Dict, float_format, delimiter: str=" / "): 207 | keys = list(info.keys()) 208 | desc = delimiter.join(keys) 209 | val = delimiter.join([str(float_format.format(info[k])) for k in keys]) 210 | return desc, val 211 | 212 | def build_duration_step_message(self, header: Dict, delimiter: str=" / "): 213 | def convert_to_string(number): 214 | type_number = type(number) 215 | if type_number == int or type_number == np.int32 or type_number == np.int64: 216 | return str(f"{number:8d}") 217 | elif type_number == float or type_number == np.float64: 218 | return str(f"{number:3.3f}") 219 | else: 220 | raise TypeError("Unrecognized type of input number") 221 | 222 | keys = list(header.keys()) 223 | header_desc = delimiter.join(keys) 224 | header_val = delimiter.join([convert_to_string(header[k]) for k in keys]) 225 | 226 | return header_desc, header_val 227 | 228 | def log_evaluation( 229 | self, dataset_name, epoch_from_restore, step_from_restore, global_step, eval_scores, 230 | ): 231 | self.info_cyan_underline( 232 | f"[{dataset_name}-Evaluation] global_step / step_from_restore / epoch_from_restore: " 233 | f"{global_step:8d} / {step_from_restore:5d} / {epoch_from_restore:3.3f}" 234 | ) 235 | self.metric_manager.log_metrics(global_step) 236 | 237 | def log_step_message(self, header, losses, speeds, comparative_loss, batch_size, is_training, tag=""): 238 | def get_loss_color(old_loss: float, new_loss: float): 239 | if old_loss < new_loss: 240 | return "red" 241 | else: 242 | return "green" 243 | 244 | def get_log_color(is_training: bool): 245 | if is_training: 246 | return {"color": "blue", 247 | "attrs": ["bold"]} 248 | else: 249 | return {"color": "yellow", 250 | "attrs": ["underline"]} 251 | 252 | self.last_loss.setdefault(tag, comparative_loss) 253 | loss_color = get_loss_color(self.last_loss.get(tag, 0), comparative_loss) 254 | self.last_loss[tag] = comparative_loss 255 | 256 | model_size = hf.format_size(self.model.total_params*4) 257 | total_params = hf.format_number(self.model.total_params) 258 | 259 | loss_desc, loss_val = self.build_info_step_message(losses, "{:7.4f}") 260 | header_desc, header_val = self.build_duration_step_message(header) 261 | speed_desc, speed_val = self.build_info_step_message(speeds, "{:4.0f}") 262 | 263 | with utils.format_text(loss_color) as fmt: 264 | loss_val_colored = fmt(loss_val) 265 | msg = ( 266 | f"[{tag}] {header_desc}: {header_val}\t" 267 | f"{speed_desc}: {speed_val} ({self.args.width},{self.args.height};{batch_size})\t" 268 | f"{loss_desc}: {loss_val_colored} " 269 | f"| {model_size} {total_params}") 270 | 271 | with 
utils.format_text(**get_log_color(is_training)) as fmt: 272 | self.log.info(fmt(msg)) 273 | 274 | def setup_trainer(self): 275 | self.setup_essentials(self.args.max_to_keep) 276 | self.optimizer = self.build_optimizer(self.args.optimizer, 277 | learning_rate=self.lr_scheduler.placeholder, 278 | momentum=self.args.momentum, 279 | decay=self.args.optimizer_decay, 280 | epsilon=self.args.optimizer_epsilon) 281 | self.train_op = self.build_train_op(total_loss=self.model.total_loss, 282 | optimizer=self.optimizer, 283 | trainable_scopes=self.args.trainable_scopes, 284 | global_step=self.global_step) 285 | self.routine_restore_and_initialize() 286 | 287 | def append_if_value_is_not_none(self, key_and_op, fetch_ops): 288 | if key_and_op[1] is not None: 289 | fetch_ops.append(key_and_op) 290 | 291 | def run_single_step(self, fetch_ops: Dict, feed_dict: Dict=None): 292 | st = time.time() 293 | fetch_vals = self.session.run(fetch_ops, feed_dict=feed_dict) 294 | step_time = (time.time() - st) * 1000 295 | step_time_per_image = step_time / self.dataset.batch_size 296 | 297 | fetch_vals[self.after_fetch_namespace.single_step] = step_time 298 | fetch_vals[self.after_fetch_namespace.single_step_per_image] = step_time_per_image 299 | 300 | return fetch_vals 301 | 302 | def log_summaries(self, fetch_vals): 303 | summary_keys = ( 304 | set(fetch_vals.keys()) - 305 | set(self.loss_fetch_namespace.unordered_values()) - 306 | set(self.etc_fetch_namespace.unordered_values()) - 307 | set(self.after_fetch_namespace.unordered_values()) 308 | ) 309 | if len(summary_keys) > 0: 310 | self.info_magenta(f"Above step includes saving {summary_keys} summaries to {self.train_dir_name}") 311 | 312 | def run_with_logging(self, summary_op, metric_op_dict, feed_dict): 313 | fetch_ops = { 314 | self.etc_fetch_namespace.step_op: self.train_op, 315 | self.etc_fetch_namespace.global_step: self.global_step, 316 | self.loss_fetch_namespace.total_loss: self.model.total_loss, 317 | self.loss_fetch_namespace.model_loss: self.model.model_loss, 318 | } 319 | if summary_op is not None: 320 | fetch_ops.update({self.etc_fetch_namespace.summary: summary_op}) 321 | if metric_op_dict is not None: 322 | fetch_ops.update(metric_op_dict) 323 | 324 | fetch_vals = self.run_single_step(fetch_ops=fetch_ops, feed_dict=feed_dict) 325 | 326 | global_step = fetch_vals[self.etc_fetch_namespace.global_step] 327 | step_from_restore = global_step - self.global_step_from_checkpoint 328 | epoch_from_restore = self.build_epoch(step_from_restore) 329 | 330 | self.log_step_message( 331 | {"GlobalStep": global_step, 332 | "StepFromRestore": step_from_restore, 333 | "EpochFromRestore": epoch_from_restore}, 334 | {"TotalLoss": fetch_vals[self.loss_fetch_namespace.total_loss], 335 | "ModelLoss": fetch_vals[self.loss_fetch_namespace.model_loss]}, 336 | {"SingleStepPerImage(ms)": fetch_vals[self.after_fetch_namespace.single_step_per_image], 337 | "SingleStep(ms)": fetch_vals[self.after_fetch_namespace.single_step]}, 338 | comparative_loss=fetch_vals[self.loss_fetch_namespace.total_loss], 339 | batch_size=self.dataset.batch_size, 340 | tag=self.dataset_name, 341 | is_training=True 342 | ) 343 | 344 | return fetch_vals, global_step, step_from_restore, epoch_from_restore 345 | 346 | def train(self, name: str="Training"): 347 | self.log.info(f"{name} started") 348 | 349 | global_step, step_from_restore, epoch_from_restore = 0, 0, 0 350 | while True: 351 | try: 352 | feed_dict = self.get_feed_dict(is_training=True) 353 | valid_collection_keys = [] 354 | 355 | # collect 
valid collection keys 356 | if step_from_restore >= self.args.step_min_summaries and \ 357 | step_from_restore % self.args.step_save_summaries == 0: 358 | valid_collection_keys.append(BaseSummaries.KEY_TYPES.DEFAULT) 359 | 360 | if step_from_restore % self.args.step_save_verbose_summaries == 0: 361 | valid_collection_keys.append(BaseSummaries.KEY_TYPES.VERBOSE) 362 | 363 | if step_from_restore <= self.args.step_save_first_n_summaries: 364 | valid_collection_keys.append(BaseSummaries.KEY_TYPES.FIRST_N) 365 | 366 | # merge it to single one 367 | summary_op = self.metric_manager.summary.get_merged_summaries( 368 | collection_key_suffixes=valid_collection_keys, 369 | is_tensor_summary=True 370 | ) 371 | 372 | # send metric op 373 | # run it only when evaluate 374 | if step_from_restore % self.args.step_evaluation == 0: 375 | metric_op_dict = self.metric_tf_op 376 | else: 377 | metric_op_dict = None 378 | 379 | # Session.Run! 380 | fetch_vals, global_step, step_from_restore, epoch_from_restore = self.run_with_logging( 381 | summary_op, metric_op_dict, feed_dict) 382 | self.log_summaries(fetch_vals) 383 | 384 | # Save 385 | if step_from_restore % self.args.step_save_checkpoint == 0: 386 | with self.timer(f"save checkpoint: {self.train_dir_name}", self.info_magenta_reverse): 387 | self.saver.save(self.session, 388 | str(Path(self.args.train_dir) / self.args.model), 389 | global_step=global_step) 390 | if self.args.write_pbtxt: 391 | tf.train.write_graph( 392 | self.session.graph_def, self.args.train_dir, self.args.model + ".pbtxt" 393 | ) 394 | 395 | if step_from_restore % self.args.step_evaluation == 0: 396 | self.evaluate(epoch_from_restore, step_from_restore, global_step, self.dataset_name) 397 | 398 | if epoch_from_restore >= self.args.max_epoch_from_restore: 399 | self.info_red(f"Reached {self.args.max_epoch_from_restore} epochs from restore.") 400 | break 401 | 402 | if step_from_restore >= self.args.max_step_from_restore: 403 | self.info_red(f"Reached {self.args.max_step_from_restore} steps from restore.") 404 | break 405 | 406 | if self.args.step1_mode: 407 | break 408 | 409 | global_step += 1 410 | step_from_restore = global_step - self.global_step_from_checkpoint 411 | epoch_from_restore = self.build_epoch(step_from_restore) 412 | except tf.errors.InvalidArgumentError as e: 413 | utils.format_log(self.log.error, "red")(f"Invalid image is detected: {e}") 414 | continue 415 | 416 | self.log.info(f"{name} finished") 417 | 418 | def evaluate( 419 | self, 420 | epoch_from_restore: float, 421 | step_from_restore: int, 422 | global_step: int, 423 | dataset_name: str, 424 | iters: int=None, 425 | ): 426 | # Update learning rate based on validation loss will be implemented in new class 427 | # where Trainer and Evaluator will share information. Currently, `ReduceLROnPlateau` 428 | # would not work correctly. 
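        # A hypothetical sketch of that flow, assuming the two classes come to share
        # state (only `get_evaluation_result` exists today; the rest is illustrative):
        #     eval_scores = self.metric_manager.get_evaluation_result(step=global_step)
        #     self.lr_scheduler.update_on_start_of_evaluation()  # would consume eval_scores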
429 | # self.lr_scheduler.update_on_start_of_evaluation() 430 | 431 | # calculate number of iterations 432 | iters = self.build_evaluate_iterations(iters) 433 | 434 | with self.timer(f"run_evaluation (iterations: {iters})", self.info_cyan): 435 | # evaluate metrics 436 | eval_dict = self.run_inference(global_step, iters=iters, is_training=True) 437 | 438 | non_tensor_data = self.build_non_tensor_data_from_eval_dict(eval_dict) 439 | 440 | self.metric_manager.evaluate_and_aggregate_metrics(step=global_step, 441 | non_tensor_data=non_tensor_data, 442 | eval_dict=eval_dict) 443 | 444 | self.metric_manager.write_evaluation_summaries(step=global_step, 445 | collection_keys=[BaseSummaries.KEY_TYPES.DEFAULT]) 446 | 447 | self.log_evaluation(dataset_name, epoch_from_restore, step_from_restore, global_step, eval_scores=None) 448 | 449 | 450 | @staticmethod 451 | def add_arguments(parser, name: str="TrainerBase"): 452 | g_optimize = parser.add_argument_group(f"({name}) Optimizer Arguments") 453 | g_optimize.add_argument("--optimizer", default="adam", type=str, 454 | choices=["gd", "adam", "mom", "rmsprop"], 455 | help="name of optimizer") 456 | g_optimize.add_argument("--momentum", default=None, type=float) 457 | g_optimize.add_argument("--optimizer_decay", default=None, type=float) 458 | g_optimize.add_argument("--optimizer_epsilon", default=None, type=float) 459 | 460 | g_rst = parser.add_argument_group(f"({name}) Saver(Restore) Arguments") 461 | g_rst.add_argument("--trainable_scopes", default="", type=str, 462 | help=( 463 | "Prefix scopes for training variables (comma separated)\n" 464 | "Usually Logits e.g. InceptionResnetV2/Logits/Logits,InceptionResnetV2/AuxLogits/Logits" 465 | "For default value, trainable_scopes='' means training 'all' variable" 466 | "If you don't want to train(e.g. 
validation only), " 467 | "you should give unmatched random string" 468 | )) 469 | 470 | g_options = parser.add_argument_group(f"({name}) Training options(step, batch_size, path) Arguments") 471 | 472 | g_options.add_argument("--no-write_pbtxt", dest="write_pbtxt", action="store_false") 473 | g_options.add_argument("--write_pbtxt", dest="write_pbtxt", action="store_true", 474 | help="write_pbtxt model parameters") 475 | g_options.set_defaults(write_pbtxt=True) 476 | g_options.add_argument("--train_dir", required=True, type=str, 477 | help="Directory where to write event logs and checkpoint.") 478 | g_options.add_argument("--step_save_summaries", default=10, type=int) 479 | g_options.add_argument("--step_save_verbose_summaries", default=2000, type=int) 480 | g_options.add_argument("--step_save_first_n_summaries", default=30, type=int) 481 | g_options.add_argument("--step_save_checkpoint", default=500, type=int) 482 | g_options.add_argument("--step_evaluation", default=500, type=utils.positive_int) 483 | 484 | g_options.add_argument("--max_to_keep", default=5, type=utils.positive_int) 485 | g_options.add_argument("--max_outputs", default=5, type=utils.positive_int) 486 | g_options.add_argument("--max_epoch_from_restore", default=50000, type=float, 487 | help=( 488 | "max epoch(1 epoch = whole data): " 489 | "Default value for max_epoch is ImageNet Resnet50 max epoch counts" 490 | )) 491 | g_options.add_argument("--step_min_summaries", default=0, type=int) 492 | g_options.add_argument("--max_step_from_restore", default=sys.maxsize, type=int, 493 | help="Stop training when reaching given step value.") 494 | g_options.add_argument("--tag", default="tag", type=str, help="tag for folder name") 495 | g_options.add_argument("--no-testmode", dest="testmode", action="store_false") 496 | g_options.add_argument("--testmode", dest="testmode", action="store_true", 497 | help="If testmode, ask deleting train_dir when you exit the process") 498 | g_options.set_defaults(testmode=False) 499 | g_options.add_argument("--no-debug", dest="debug", action="store_false") 500 | g_options.add_argument("--debug", dest="debug", action="store_true", help="Debug model parameters") 501 | g_options.set_defaults(debug=False) 502 | 503 | g_options.add_argument("--no-step1_mode", dest="step1_mode", action="store_false") 504 | g_options.add_argument("--step1_mode", dest="step1_mode", action="store_true") 505 | g_options.set_defaults(step1_mode=False) 506 | 507 | g_options.add_argument("--no-save_evaluation_image", dest="save_evaluation_image", action="store_false") 508 | g_options.add_argument("--save_evaluation_image", dest="save_evaluation_image", action="store_true") 509 | g_options.set_defaults(save_evaluation_image=False) 510 | 511 | lr_scheduler.add_arguments(parser) 512 | 513 | 514 | class MattingTrainer(TrainerBase, MattingBase): 515 | def __init__(self, model, session, args, dataset, dataset_name): 516 | super().__init__(model, session, args, dataset, dataset_name, "MattingTrainer") 517 | 518 | self.setup_trainer() 519 | self.setup_metric_manager() 520 | self.setup_metric_ops() 521 | 522 | self.routine_experiment_summary() 523 | self.routine_logging_checkpoint_path() 524 | 525 | def build_evaluate_iterations(self, iters): 526 | if iters is not None: 527 | iters = iters 528 | elif self.args.evaluation_iterations is not None: 529 | iters = self.args.evaluation_iterations 530 | else: 531 | iters = 10 # a default value we used 532 | 533 | return iters 534 | 535 | def setup_metric_manager(self): 536 | self.metric_manager = 
metric_manager.MattingMetricManager( 537 | is_training=True, 538 | save_evaluation_image=self.args.save_evaluation_image, 539 | exclude_metric_names=self.args.exclude_metric_names, 540 | summary=Summaries( 541 | session=self.session, 542 | train_dir=self.args.train_dir, 543 | is_training=True, 544 | max_image_outputs=self.args.max_image_outputs 545 | ), 546 | ) 547 | 548 | def setup_metric_ops(self): 549 | losses = self.build_basic_loss_ops() 550 | summary_images = self.build_basic_image_ops() 551 | misc_images = self.build_misc_image_ops() 552 | 553 | self.metric_tf_op = self.metric_manager.build_metric_ops({ 554 | "dataset_split_name": self.dataset_name, 555 | "target_eval_shape": self.args.target_eval_shape, 556 | 557 | "losses": losses, 558 | "summary_images": summary_images, 559 | "misc_images": misc_images, 560 | "masks": self.model.masks, 561 | "masks_original": self.model.masks_original, 562 | "probs": self.model.prob_scores, 563 | }) 564 | 565 | def build_non_tensor_data_from_eval_dict(self, eval_dict, **kwargs): 566 | return { 567 | "dataset_split_name": self.dataset_name, 568 | 569 | "batch_infer_time": eval_dict["batch_infer_time"], 570 | "unit_infer_time": eval_dict["unit_infer_time"], 571 | "misc_images": None, 572 | "image_save_dir": None, 573 | } 574 | 575 | @staticmethod 576 | def add_arguments(parser): 577 | g_etc = parser.add_argument_group("(MattingTrainer) ETC") 578 | # EVALUATE 579 | g_etc.add_argument("--first_n", default=10, type=utils.positive_int, help="Argument for tf.Print") 580 | -------------------------------------------------------------------------------- /matting_nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyperconnect/MMNet/aa423b598f38e051fcf46914cf8a7765458c09d2/matting_nets/__init__.py -------------------------------------------------------------------------------- /matting_nets/deeplab_v3plus.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from nets.mobilenet import mobilenet_v2 as mobilenet_v2_slim 3 | 4 | slim = tf.contrib.slim 5 | 6 | 7 | available_extractors = [ 8 | "mobilenet_v2", 9 | ] 10 | 11 | 12 | def build_encoder(network, inputs, is_training, depth_multiplier=None, output_stride=16): 13 | if network == "mobilenet_v2": 14 | return mobilenet_v2_slim.mobilenet_base(inputs, 15 | conv_defs=mobilenet_v2_slim.V2_DEF, 16 | depth_multiplier=depth_multiplier, 17 | final_endpoint="layer_18", 18 | output_stride=output_stride, 19 | is_training=is_training) 20 | else: 21 | raise NotImplementedError 22 | 23 | 24 | def get_decoder_end_point(network): 25 | if network == "mobilenet_v2": 26 | return "layer_4/depthwise_output" 27 | else: 28 | raise NotImplementedError 29 | 30 | 31 | def aspp(inputs, depth=256, rates=[3, 6, 9]): 32 | with tf.variable_scope("aspp"): 33 | branches = [] 34 | 35 | with tf.variable_scope("aspp_1x1conv"): 36 | net = slim.conv2d(inputs, num_outputs=depth, kernel_size=1, stride=1) 37 | branches.append(net) 38 | 39 | for rate in rates: 40 | with tf.variable_scope(f"aspp_atrous{rate}"): 41 | net = slim.separable_conv2d(inputs, num_outputs=None, kernel_size=3, stride=1, 42 | depth_multiplier=1, rate=rate, scope="depthwise_conv") 43 | net = slim.conv2d(net, num_outputs=depth, kernel_size=1, stride=1, scope="pointwise_conv") 44 | branches.append(net) 45 | 46 | with tf.variable_scope("aspp_pool"): 47 | net = tf.reduce_mean(inputs, [1, 2], keep_dims=True, name="global_pool") 48 | net = 
slim.conv2d(net, num_outputs=depth, kernel_size=1) 49 | net = tf.image.resize_bilinear(net, size=inputs.get_shape()[1:3], align_corners=True) 50 | branches.append(net) 51 | 52 | with tf.variable_scope("concat"): 53 | concat_logits = tf.concat(branches, 3) 54 | concat_logits = slim.conv2d(concat_logits, num_outputs=depth, kernel_size=1, stride=1) 55 | concat_logits = slim.dropout(concat_logits, keep_prob=0.9) 56 | 57 | return concat_logits 58 | 59 | 60 | def decoder(args, endpoints, aspp, height, width, num_classes): 61 | with tf.variable_scope("decoder"): 62 | low_level_feature = endpoints[get_decoder_end_point(args.extractor)] 63 | net = slim.conv2d(low_level_feature, num_outputs=48, kernel_size=1, scope="feature_projection") 64 | decoder_features = [net, aspp] 65 | 66 | for i in range(len(decoder_features)): 67 | decoder_features[i] = tf.image.resize_bilinear(decoder_features[i], size=[height, width], 68 | align_corners=True) 69 | 70 | net = tf.concat(decoder_features, 3) 71 | for i in range(2): 72 | net = slim.separable_conv2d(net, num_outputs=None, kernel_size=3, stride=1, 73 | depth_multiplier=1, scope=f"decoder_conv{i}_depthwise") 74 | net = slim.conv2d(net, num_outputs=256, kernel_size=1, stride=1, scope=f"decoder_conv{i}_pointwise") 75 | 76 | net = slim.conv2d(net, num_outputs=num_classes, kernel_size=1, stride=1, 77 | normalizer_fn=None, activation_fn=None, scope="logits") 78 | 79 | return net 80 | 81 | 82 | def DeepLab(inputs, args, is_training, depth_multiplier=None, output_stride=16, scope="DeepLab"): 83 | with tf.variable_scope(scope, "DeepLab", [inputs]): 84 | with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training): 85 | features, endpoints = build_encoder(args.extractor, inputs, is_training, depth_multiplier, output_stride) 86 | if args.extractor == "mobilenet_v2": 87 | atrous_rates = [] 88 | result = aspp(features, rates=atrous_rates) 89 | result = slim.conv2d(result, num_outputs=args.num_classes, kernel_size=1, stride=1, 90 | normalizer_fn=None, activation_fn=None, scope="logits") 91 | else: 92 | raise NotImplementedError 93 | 94 | result = tf.image.resize_bilinear(result, 95 | size=[args.height, args.width], 96 | align_corners=True) 97 | return result, endpoints 98 | 99 | 100 | def DeepLab_arg_scope(weight_decay=0.00004, 101 | batch_norm_decay=0.9997, 102 | batch_norm_epsilon=1e-5, 103 | batch_norm_scale=True, 104 | weights_initializer_stddev=0.09, 105 | activation_fn=tf.nn.relu, 106 | regularize_depthwise=False, 107 | use_fused_batchnorm=True): 108 | batch_norm_params = { 109 | "decay": batch_norm_decay, 110 | "epsilon": batch_norm_epsilon, 111 | "scale": batch_norm_scale, 112 | "fused": use_fused_batchnorm, 113 | } 114 | if regularize_depthwise: 115 | depthwise_regularizer = slim.l2_regularizer(weight_decay) 116 | else: 117 | depthwise_regularizer = None 118 | 119 | with slim.arg_scope( 120 | [slim.conv2d, slim.separable_conv2d], 121 | weights_initializer=tf.truncated_normal_initializer(stddev=weights_initializer_stddev), 122 | activation_fn=activation_fn, 123 | normalizer_fn=slim.batch_norm): 124 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 125 | with slim.arg_scope( 126 | [slim.conv2d], 127 | weights_regularizer=slim.l2_regularizer(weight_decay)): 128 | with slim.arg_scope( 129 | [slim.separable_conv2d], 130 | weights_regularizer=depthwise_regularizer) as scope: 131 | return scope 132 | -------------------------------------------------------------------------------- /matting_nets/mmnet.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from matting_nets.mmnet_utils import quantizable_separable_convolution2d 4 | 5 | slim = tf.contrib.slim 6 | separable_conv = quantizable_separable_convolution2d 7 | 8 | 9 | def multiply_depth(depth, depth_multiplier, min_depth=8, divisor=8): 10 | multiplied_depth = round(depth * depth_multiplier) 11 | divisible_depth = (multiplied_depth + divisor // 2) // divisor * divisor 12 | return max(min_depth, divisible_depth) 13 | 14 | 15 | def init_block(inputs, depth, depth_multiplier, name): 16 | depth = multiply_depth(depth, depth_multiplier) 17 | with tf.variable_scope(name): 18 | net = slim.conv2d(inputs, num_outputs=depth, kernel_size=3, stride=2, scope="conv") 19 | return net 20 | 21 | 22 | def encoder_block(inputs, expanded_depth, output_depth, depth_multiplier, rates, stride, name, 23 | activation_fn=tf.identity): 24 | expanded_depth = multiply_depth(expanded_depth, depth_multiplier) 25 | output_depth = multiply_depth(output_depth, depth_multiplier) 26 | 27 | with tf.variable_scope(name): 28 | convs = [] 29 | for i, rate in enumerate(rates): 30 | with tf.variable_scope(f"branch{i}"): 31 | conv = slim.conv2d(inputs, num_outputs=expanded_depth, kernel_size=1, stride=1, scope="pointwise_conv") 32 | if stride > 1: 33 | conv = separable_conv(conv, num_outputs=None, kernel_size=3, stride=stride, depth_multiplier=1, 34 | scope="depthwise_conv_stride") 35 | conv = separable_conv(conv, num_outputs=None, kernel_size=3, stride=1, depth_multiplier=1, 36 | rate=rate, scope="depthwise_conv_dilation") 37 | convs.append(conv) 38 | 39 | with tf.variable_scope("merge"): 40 | if len(convs) > 1: 41 | net = tf.concat(convs, axis=-1) 42 | else: 43 | net = convs[0] 44 | 45 | net = slim.conv2d(net, num_outputs=output_depth, kernel_size=1, stride=1, activation_fn=activation_fn, 46 | scope="pointwise_conv") 47 | net = slim.dropout(net, scope="dropout") 48 | 49 | return net 50 | 51 | 52 | def decoder_block(inputs, shortcut_input, compressed_depth, shortcut_depth, depth_multiplier, num_of_resize, name): 53 | compressed_depth = multiply_depth(compressed_depth, depth_multiplier) 54 | if shortcut_depth is not None: 55 | shortcut_depth = multiply_depth(shortcut_depth, depth_multiplier) 56 | 57 | with tf.variable_scope(name): 58 | net = slim.conv2d(inputs, num_outputs=compressed_depth, kernel_size=1, stride=1, scope="pointwise_conv") 59 | 60 | for i in range(num_of_resize): 61 | resize_shape = [v * 2**(i+1) for v in inputs.get_shape().as_list()[1:3]] 62 | net = tf.image.resize_bilinear(net, size=resize_shape, name=f"resize_bilinear_{i}") 63 | 64 | if shortcut_input is not None: 65 | with tf.variable_scope("shortcut"): 66 | shortcut = separable_conv(shortcut_input, num_outputs=None, kernel_size=3, 67 | stride=1, depth_multiplier=1, rate=1, scope="depthwise_conv") 68 | shortcut = slim.conv2d(shortcut, num_outputs=shortcut_depth, kernel_size=1, stride=1, 69 | scope="pointwise_conv") 70 | net = tf.concat([net, shortcut], axis=-1, name="concat") 71 | 72 | return net 73 | 74 | 75 | def final_block(inputs, num_outputs, name): 76 | with tf.variable_scope(name): 77 | net = slim.conv2d(inputs, num_outputs=num_outputs, kernel_size=1, stride=1, 78 | activation_fn=None, normalizer_fn=None, scope="pointwise_conv") 79 | return net 80 | 81 | 82 | def MMNet(inputs, is_training, depth_multiplier=1.0, scope="MMNet"): 83 | endpoints = {} 84 | with tf.variable_scope(scope, "MMNet", [inputs]): 85 | with 
slim.arg_scope([slim.batch_norm], is_training=is_training, activation_fn=None): 86 | with slim.arg_scope([slim.dropout], is_training=is_training): 87 | endpoints["init_block"] = init_block(inputs, 32, depth_multiplier, "init_block") 88 | 89 | # encoder do downsampling 90 | endpoints["enc_block0"] = encoder_block(endpoints["init_block"], 16, 16, depth_multiplier, 91 | [1, 2, 4, 8], 2, "enc_block0") 92 | endpoints["enc_block1"] = encoder_block(endpoints["enc_block0"], 16, 24, depth_multiplier, 93 | [1, 2, 4, 8], 1, "enc_block1") 94 | endpoints["enc_block2"] = encoder_block(endpoints["enc_block1"], 24, 24, depth_multiplier, 95 | [1, 2, 4, 8], 1, "enc_block2") 96 | endpoints["enc_block3"] = encoder_block(endpoints["enc_block2"], 24, 24, depth_multiplier, 97 | [1, 2, 4, 8], 1, "enc_block3") 98 | endpoints["enc_block4"] = encoder_block(endpoints["enc_block3"], 32, 40, depth_multiplier, 99 | [1, 2, 4], 2, "enc_block4") 100 | endpoints["enc_block5"] = encoder_block(endpoints["enc_block4"], 64, 40, depth_multiplier, 101 | [1, 2, 4], 1, "enc_block5") 102 | endpoints["enc_block6"] = encoder_block(endpoints["enc_block5"], 64, 40, depth_multiplier, 103 | [1, 2, 4], 1, "enc_block6") 104 | endpoints["enc_block7"] = encoder_block(endpoints["enc_block6"], 64, 40, depth_multiplier, 105 | [1, 2, 4], 1, "enc_block7") 106 | endpoints["enc_block8"] = encoder_block(endpoints["enc_block7"], 80, 80, depth_multiplier, 107 | [1, 2], 2, "enc_block8") 108 | endpoints["enc_block9"] = encoder_block(endpoints["enc_block8"], 120, 80, depth_multiplier, 109 | [1, 2], 1, "enc_block9", tf.nn.relu6) 110 | 111 | endpoints["dec_block0"] = decoder_block(endpoints["enc_block9"], endpoints["enc_block4"], 112 | 64, 64, depth_multiplier, 1, "dec_block0") 113 | endpoints["dec_block1"] = decoder_block(endpoints["dec_block0"], endpoints["enc_block0"], 114 | 40, 40, depth_multiplier, 1, "dec_block1") 115 | 116 | endpoints["dec_block2"] = encoder_block(endpoints["dec_block1"], 40, 40, depth_multiplier, 117 | [1, 2, 4], 1, "dec_block2") 118 | endpoints["dec_block3"] = encoder_block(endpoints["dec_block2"], 40, 40, depth_multiplier, 119 | [1, 2, 4], 1, "dec_block3") 120 | 121 | endpoints["dec_block4"] = decoder_block(endpoints["dec_block3"], None, 122 | 16, None, depth_multiplier, 2, "dec_block4") 123 | 124 | # Final Deconvolution 125 | endpoints["final_block"] = final_block(endpoints["dec_block4"], num_outputs=2, name="final_block") 126 | 127 | # aux output 128 | endpoints["aux_block"] = final_block(endpoints["enc_block9"], num_outputs=2, name="aux_block") 129 | 130 | return endpoints["final_block"], endpoints 131 | 132 | 133 | def MMNet_arg_scope(use_fused_batchnorm=True, 134 | regularize_depthwise=False, 135 | weight_decay=0.0, 136 | dropout=0.0): 137 | if regularize_depthwise and weight_decay != 0.0: 138 | depthwise_regularizer = slim.l2_regularizer(weight_decay) 139 | else: 140 | depthwise_regularizer = None 141 | 142 | with slim.arg_scope([slim.conv2d, separable_conv], activation_fn=tf.nn.relu6, normalizer_fn=slim.batch_norm): 143 | with slim.arg_scope([slim.batch_norm], fused=use_fused_batchnorm, scale=True): 144 | with slim.arg_scope([slim.conv2d], weights_regularizer=slim.l2_regularizer(weight_decay)): 145 | with slim.arg_scope([separable_conv], weights_regularizer=depthwise_regularizer): 146 | with slim.arg_scope([slim.dropout], keep_prob=1 - dropout) as scope: 147 | return scope 148 | -------------------------------------------------------------------------------- /matting_nets/mmnet_utils.py: 
--------------------------------------------------------------------------------
 1 | # Copied and modified the code in tensorflow slim
 2 | # https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/contrib/layers/python/layers/layers.py
 3 | 
 4 | # Implemented quantizable_separable_convolution2d by modifying slim.separable_convolution2d: Line 192~243
 5 | # Changed operation order for dilated depthwise convolution
 6 | # Before: [SpaceToBatchND] -> [DepthwiseConv2dNative] -> [BatchToSpaceND] -> [normalize] -> [activate]
 7 | # After : [SpaceToBatchND] -> [DepthwiseConv2dNative] -> [normalize] -> [activate] -> [BatchToSpaceND]
 8 | 
 9 | from tensorflow.contrib.framework.python.ops import add_arg_scope
10 | from tensorflow.contrib.framework.python.ops import variables
11 | from tensorflow.contrib.layers.python.layers import initializers
12 | from tensorflow.contrib.layers.python.layers import utils
13 | from tensorflow.python.eager import context
14 | from tensorflow.python.framework import constant_op
15 | from tensorflow.python.framework import dtypes
16 | from tensorflow.python.framework import function
17 | from tensorflow.python.framework import ops
18 | from tensorflow.python.framework import sparse_tensor
19 | from tensorflow.python.framework import tensor_shape
20 | from tensorflow.python.layers import base
21 | from tensorflow.python.layers import convolutional as convolutional_layers
22 | from tensorflow.python.layers import core as core_layers
23 | from tensorflow.python.layers import normalization as normalization_layers
24 | from tensorflow.python.layers import pooling as pooling_layers
25 | from tensorflow.python.ops import array_ops
26 | from tensorflow.python.ops import check_ops
27 | from tensorflow.python.ops import init_ops
28 | from tensorflow.python.ops import linalg_ops
29 | from tensorflow.python.ops import math_ops
30 | from tensorflow.python.ops import nn
31 | from tensorflow.python.ops import nn_ops
32 | from tensorflow.python.ops import sparse_ops
33 | from tensorflow.python.ops import standard_ops
34 | from tensorflow.python.ops import variable_scope
35 | from tensorflow.python.ops import variables as tf_variables
36 | from tensorflow.python.training import moving_averages
37 | 
38 | DATA_FORMAT_NCHW = 'NCHW'
39 | DATA_FORMAT_NHWC = 'NHWC'
40 | DATA_FORMAT_NCDHW = 'NCDHW'
41 | DATA_FORMAT_NDHWC = 'NDHWC'
42 | 
43 | 
44 | @add_arg_scope
45 | def quantizable_separable_convolution2d(
46 |     inputs,
47 |     num_outputs,
48 |     kernel_size,
49 |     depth_multiplier,
50 |     stride=1,
51 |     padding='SAME',
52 |     data_format=DATA_FORMAT_NHWC,
53 |     rate=1,
54 |     activation_fn=nn.relu,
55 |     normalizer_fn=None,
56 |     normalizer_params=None,
57 |     weights_initializer=initializers.xavier_initializer(),
58 |     weights_regularizer=None,
59 |     biases_initializer=init_ops.zeros_initializer(),
60 |     biases_regularizer=None,
61 |     reuse=None,
62 |     variables_collections=None,
63 |     outputs_collections=None,
64 |     trainable=True,
65 |     scope=None):
66 |   """Adds a depth-separable 2D convolution with an optional batch_norm layer.
67 |   This op first performs a depthwise convolution that acts separately on
68 |   channels, creating a variable called `depthwise_weights`. If `num_outputs`
69 |   is not None, it adds a pointwise convolution that mixes channels, creating a
70 |   variable called `pointwise_weights`. Then, if `normalizer_fn` is None,
71 |   it adds bias to the result, creating a variable called 'biases', otherwise,
72 |   the `normalizer_fn` is applied. It finally applies an activation function
73 |   to produce the end result.
74 |   Args:
75 |     inputs: A tensor of size [batch_size, height, width, channels].
76 |     num_outputs: The number of pointwise convolution output filters. If it is
77 |       None, then we skip the pointwise convolution stage.
78 |     kernel_size: A list of length 2: [kernel_height, kernel_width] of
79 |       the filters. Can be an int if both values are the same.
80 |     depth_multiplier: The number of depthwise convolution output channels for
81 |       each input channel. The total number of depthwise convolution output
82 |       channels will be equal to `num_filters_in * depth_multiplier`.
83 |     stride: A list of length 2: [stride_height, stride_width], specifying the
84 |       depthwise convolution stride. Can be an int if both strides are the same.
85 |     padding: One of 'VALID' or 'SAME'.
86 |     data_format: A string. `NHWC` (default) and `NCHW` are supported.
87 |     rate: A list of length 2: [rate_height, rate_width], specifying the dilation
88 |       rates for atrous convolution. Can be an int if both rates are the same.
89 |       If any value is larger than one, then both stride values need to be one.
90 |     activation_fn: Activation function. The default value is a ReLU function.
91 |       Explicitly set it to None to skip it and maintain a linear activation.
92 |     normalizer_fn: Normalization function to use instead of `biases`. If
93 |       `normalizer_fn` is provided then `biases_initializer` and
94 |       `biases_regularizer` are ignored and `biases` are not created nor added.
95 |       Defaults to None for no normalizer function.
96 |     normalizer_params: Normalization function parameters.
97 |     weights_initializer: An initializer for the weights.
98 |     weights_regularizer: Optional regularizer for the weights.
99 |     biases_initializer: An initializer for the biases. If None skip biases.
100 |     biases_regularizer: Optional regularizer for the biases.
101 |     reuse: Whether or not the layer and its variables should be reused. To be
102 |       able to reuse the layer, its scope must be given.
103 |     variables_collections: Optional list of collections for all the variables or
104 |       a dictionary containing a different list of collections per variable.
105 |     outputs_collections: Collection to add the outputs to.
106 |     trainable: Whether or not the variables should be trainable.
107 |     scope: Optional scope for variable_scope.
108 |   Returns:
109 |     A `Tensor` representing the output of the operation.
110 |   Raises:
111 |     ValueError: If `data_format` is invalid.
112 |   """
113 |   if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
114 |     raise ValueError('data_format has to be either NCHW or NHWC.')
115 |   layer_variable_getter = _build_variable_getter({
116 |       'bias': 'biases',
117 |       'depthwise_kernel': 'depthwise_weights',
118 |       'pointwise_kernel': 'pointwise_weights'
119 |   })
120 | 
121 |   with variable_scope.variable_scope(
122 |       scope,
123 |       'SeparableConv2d', [inputs],
124 |       reuse=reuse,
125 |       custom_getter=layer_variable_getter) as sc:
126 |     inputs = ops.convert_to_tensor(inputs)
127 | 
128 |     df = ('channels_first'
129 |           if data_format and data_format.startswith('NC') else 'channels_last')
130 |     if num_outputs is not None:
131 |       # Apply separable conv using the SeparableConvolution2D layer.
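      # (Note: this branch is the unmodified slim path, with depthwise and pointwise
      # convolutions fused into one layer object. The quantization-friendly reordering
      # described at the top of this file only applies to the num_outputs=None branch
      # further below, which is how mmnet.py invokes this function, e.g.:
      #     separable_conv(conv, num_outputs=None, kernel_size=3, stride=1,
      #                    depth_multiplier=1, rate=rate, scope="depthwise_conv_dilation"))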
132 | layer = convolutional_layers.SeparableConvolution2D( 133 | filters=num_outputs, 134 | kernel_size=kernel_size, 135 | strides=stride, 136 | padding=padding, 137 | data_format=df, 138 | dilation_rate=utils.two_element_tuple(rate), 139 | activation=None, 140 | depth_multiplier=depth_multiplier, 141 | use_bias=not normalizer_fn and biases_initializer, 142 | depthwise_initializer=weights_initializer, 143 | pointwise_initializer=weights_initializer, 144 | bias_initializer=biases_initializer, 145 | depthwise_regularizer=weights_regularizer, 146 | pointwise_regularizer=weights_regularizer, 147 | bias_regularizer=biases_regularizer, 148 | activity_regularizer=None, 149 | trainable=trainable, 150 | name=sc.name, 151 | dtype=inputs.dtype.base_dtype, 152 | _scope=sc, 153 | _reuse=reuse) 154 | outputs = layer.apply(inputs) 155 | 156 | # Add variables to collections. 157 | _add_variable_to_collections(layer.depthwise_kernel, 158 | variables_collections, 'weights') 159 | _add_variable_to_collections(layer.pointwise_kernel, 160 | variables_collections, 'weights') 161 | if layer.bias is not None: 162 | _add_variable_to_collections(layer.bias, variables_collections, 163 | 'biases') 164 | 165 | if normalizer_fn is not None: 166 | normalizer_params = normalizer_params or {} 167 | outputs = normalizer_fn(outputs, **normalizer_params) 168 | else: 169 | # Actually apply depthwise conv instead of separable conv. 170 | dtype = inputs.dtype.base_dtype 171 | kernel_h, kernel_w = utils.two_element_tuple(kernel_size) 172 | stride_h, stride_w = utils.two_element_tuple(stride) 173 | num_filters_in = utils.channel_dimension( 174 | inputs.get_shape(), df, min_rank=4) 175 | weights_collections = utils.get_variable_collections( 176 | variables_collections, 'weights') 177 | 178 | depthwise_shape = [kernel_h, kernel_w, num_filters_in, depth_multiplier] 179 | depthwise_weights = variables.model_variable( 180 | 'depthwise_weights', 181 | shape=depthwise_shape, 182 | dtype=dtype, 183 | initializer=weights_initializer, 184 | regularizer=weights_regularizer, 185 | trainable=trainable, 186 | collections=weights_collections) 187 | strides = [1, 1, stride_h, 188 | stride_w] if data_format.startswith('NC') else [ 189 | 1, stride_h, stride_w, 1 190 | ] 191 | 192 | # DIFFERING PART START 193 | 194 | input = ops.convert_to_tensor(inputs, name="tensor_in") 195 | filter = ops.convert_to_tensor(depthwise_weights, name="filter_in") 196 | if rate is None: 197 | rate = [1, 1] 198 | 199 | def op(input_converted, _, padding): 200 | outputs = nn_ops.depthwise_conv2d_native( 201 | input=input_converted, 202 | filter=filter, 203 | strides=strides, 204 | padding=padding, 205 | data_format=data_format) 206 | 207 | num_outputs = depth_multiplier * num_filters_in 208 | 209 | if normalizer_fn is not None: 210 | normalizer_params_ = normalizer_params or {} 211 | outputs = normalizer_fn(outputs, **normalizer_params_) 212 | else: 213 | if biases_initializer is not None: 214 | biases_collections = utils.get_variable_collections( 215 | variables_collections, 'biases') 216 | biases = variables.model_variable( 217 | 'biases', 218 | shape=[ 219 | num_outputs, 220 | ], 221 | dtype=dtype, 222 | initializer=biases_initializer, 223 | regularizer=biases_regularizer, 224 | trainable=trainable, 225 | collections=biases_collections) 226 | outputs = nn.bias_add(outputs, biases, data_format=data_format) 227 | 228 | if activation_fn is not None: 229 | outputs = activation_fn(outputs) 230 | 231 | return outputs 232 | 233 | outputs = nn_ops.with_space_to_batch( 234 | 
          input=input,
235 |           filter_shape=array_ops.shape(filter),
236 |           dilation_rate=utils.two_element_tuple(rate),
237 |           padding=padding,
238 |           data_format=data_format,
239 |           op=op)
240 | 
241 |     return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
242 | 
243 | # DIFFERING PART END
244 | 
245 | 
246 | def _build_variable_getter(rename=None):
247 |   """Build a model variable getter that respects scope getter and renames."""
248 | 
249 |   # VariableScope will nest the getters
250 |   def layer_variable_getter(getter, *args, **kwargs):
251 |     kwargs['rename'] = rename
252 |     return _model_variable_getter(getter, *args, **kwargs)
253 | 
254 |   return layer_variable_getter
255 | 
256 | 
257 | def _model_variable_getter(getter,
258 |                            name,
259 |                            shape=None,
260 |                            dtype=None,
261 |                            initializer=None,
262 |                            regularizer=None,
263 |                            trainable=True,
264 |                            collections=None,
265 |                            caching_device=None,
266 |                            partitioner=None,
267 |                            rename=None,
268 |                            use_resource=None,
269 |                            **_):
270 |   """Getter that uses model_variable for compatibility with core layers."""
271 |   short_name = name.split('/')[-1]
272 |   if rename and short_name in rename:
273 |     name_components = name.split('/')
274 |     name_components[-1] = rename[short_name]
275 |     name = '/'.join(name_components)
276 |   return variables.model_variable(
277 |       name,
278 |       shape=shape,
279 |       dtype=dtype,
280 |       initializer=initializer,
281 |       regularizer=regularizer,
282 |       collections=collections,
283 |       trainable=trainable,
284 |       caching_device=caching_device,
285 |       partitioner=partitioner,
286 |       custom_getter=getter,
287 |       use_resource=use_resource)
288 | 
289 | 
290 | def _add_variable_to_collections(variable, collections_set, collections_name):
291 |   """Adds variable (or all its parts) to all collections with that name."""
292 |   collections = utils.get_variable_collections(collections_set,
293 |                                                collections_name) or []
294 |   variables_list = [variable]
295 |   if isinstance(variable, tf_variables.PartitionedVariable):
296 |     variables_list = [v for v in variable]
297 |   for collection in collections:
298 |     for var in variables_list:
299 |       if var not in ops.get_collection(collection):
300 |         ops.add_to_collection(collection, var)
301 | 
--------------------------------------------------------------------------------
/metrics/base.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC
 2 | from collections import defaultdict
 3 | 
 4 | from common.utils import get_logger
 5 | from common.utils import format_text
 6 | from common.utils import timer
 7 | 
 8 | 
 9 | class DataStructure(ABC):
10 |     """
11 |     Defines an inner data structure.
12 |     Subclasses should define `_keys`
13 |     """
14 |     _keys = None
15 | 
16 |     def __init__(self, data):
17 |         keys = self.__class__.get_keys()
18 |         data_keys = data.keys()
19 | 
20 |         if set(keys) != set(data_keys):
21 |             raise ValueError(f"Keys defined in `_keys ({list(keys)})`"
22 |                              f" should appear in "
23 |                              f"`data ({list(data_keys)})`")
24 |         for k in keys:
25 |             setattr(self, k, data[k])
26 | 
27 |     def __str__(self):
28 |         return f"<{self.__class__.__name__}: {self.to_dict()}>"
29 | 
30 |     def __repr__(self):
31 |         return str(self)
32 | 
33 |     def to_dict(self):
34 |         return {k: getattr(self, k) for k in self._keys}
35 | 
36 |     @classmethod
37 |     def get_keys(cls):
38 |         return cls._keys
39 | 
40 | 
41 | class MetricAggregator:
42 |     def __init__(self):
43 |         self.step = None
44 |         self.metrics_with_result = None
45 | 
46 |         self.init(-1)
47 | 
48 |     def init(self, step):
49 |         self.step = step
50 |         self.metrics_with_result = dict()
51 | 
52 |     def aggregate(self, metric, 
metric_result): 53 | assert metric not in self.metrics_with_result 54 | self.metrics_with_result[metric] = metric_result 55 | 56 | def iterate_metrics(self): 57 | for metric, metric_result in self.metrics_with_result.items(): 58 | yield metric, metric_result 59 | 60 | def iterate_all(self): 61 | for metric, metric_result in self.iterate_metrics(): 62 | for metric_key, value in metric_result.items(): 63 | yield metric, metric_key, value 64 | 65 | def get_collection_summary_dict(self): 66 | # for_summary: Dict[str, Dict[MetricOp, List[Tuple(metric_key, tensor_op)]]] 67 | # collection_key -> metric -> List(summary_key, value) 68 | for_summary = defaultdict(lambda: defaultdict(list)) 69 | for metric, metric_result in self.metrics_with_result.items(): 70 | if metric.is_for_summary: 71 | for metric_key, value in metric_result.items(): 72 | for_summary[metric.summary_collection_key][metric].append((metric_key, value)) 73 | 74 | return for_summary 75 | 76 | def get_tensor_metrics(self): 77 | """ 78 | Get metric that would be fetched for session run. 79 | """ 80 | tensor_metrics = dict() 81 | for metric, metric_result in self.metrics_with_result.items(): 82 | if metric.is_tensor_metric: 83 | for metric_key, value in metric_result.items(): 84 | tensor_metrics[metric_key] = value 85 | 86 | return tensor_metrics 87 | 88 | def get_logs(self): 89 | logs = dict() 90 | for metric, metric_result in self.metrics_with_result.items(): 91 | if metric.is_for_log: 92 | for metric_key, value in metric_result.items(): 93 | if isinstance(value, str): 94 | msg = f"> {metric_key}\n{value}" 95 | else: 96 | msg = f"> {metric_key} : {value}" 97 | logs[metric_key] = msg 98 | 99 | return logs 100 | 101 | 102 | class MetricManagerBase(ABC): 103 | _metric_input_data_parser = None 104 | 105 | def __init__(self, exclude_metric_names, summary): 106 | self.log = get_logger("Metrics") 107 | self.build_op_aggregator = MetricAggregator() 108 | self.eval_metric_aggregator = MetricAggregator() 109 | 110 | self.summary = summary 111 | self.exclude_metric_names = exclude_metric_names 112 | 113 | self.metric_ops = [] 114 | 115 | def register_metric(self, metric): 116 | # if metric is in exclude_metric_names ? 117 | if metric.__class__.__name__ in self.exclude_metric_names: 118 | self.log.info(f"{metric.__class__.__name__} is excluded by user setting.") 119 | return 120 | 121 | # assertion for this metric would be processable 122 | assert str(self._metric_input_data_parser) in map(lambda c: str(c), metric.valid_input_data_parsers), \ 123 | f"{metric.__class__.__name__} cannot be parsed by {self._metric_input_data_parser}" 124 | 125 | # add one 126 | self.metric_ops.append(metric) 127 | self.log.info(f"{metric.__class__.__name__} is added.") 128 | 129 | def register_metrics(self, metrics: list): 130 | for metric in metrics: 131 | self.register_metric(metric) 132 | 133 | def build_metric_ops(self, data): 134 | """ 135 | Define tensor metric operations 136 | 1. call `build_op` of metrics, i.e. add operations to graph 137 | 2. 
register summaries 138 | 139 | Return: Dict[str, Tensor] 140 | metric_key -> metric_op 141 | """ 142 | output_build_data = self._metric_input_data_parser.parse_build_data(data) 143 | 144 | # get metric tf ops 145 | for metric in self.metric_ops: 146 | try: 147 | metric_build_ops = metric.build_op(output_build_data) 148 | except TypeError as e: 149 | raise TypeError(f"[{metric}]: {e}") 150 | self.build_op_aggregator.aggregate(metric, metric_build_ops) 151 | 152 | # if value is not None, it means it is defined with tensor 153 | metric_tf_ops = self.build_op_aggregator.get_tensor_metrics() 154 | 155 | # register summary 156 | collection_summary_dict = self.build_op_aggregator.get_collection_summary_dict() 157 | self.summary.register_summaries(collection_summary_dict) 158 | self.summary.setup_merged_summaries() 159 | 160 | return metric_tf_ops 161 | 162 | # def evaluate_non_tensor_metric(self, data, step): 163 | def evaluate_and_aggregate_metrics(self, non_tensor_data, eval_dict, step): 164 | """ 165 | Run evaluation of non-tensor metrics 166 | Args: 167 | data: data passed from trainer / evaluator/ ... 168 | """ 169 | non_tensor_data = self._metric_input_data_parser.parse_non_tensor_data(non_tensor_data) 170 | 171 | # aggregate metrics 172 | self.eval_metric_aggregator.init(step) 173 | 174 | # evaluate all metrics 175 | for metric, metric_key_op_dict in self.build_op_aggregator.iterate_metrics(): 176 | if metric.is_tensor_metric: 177 | with timer(f"{metric}.expectation_of"): 178 | # already aggregated - tensor ops 179 | metric_result = dict() 180 | for metric_key in metric_key_op_dict: 181 | if metric_key in eval_dict: 182 | exp_value = metric.expectation_of(eval_dict[metric_key]) 183 | metric_result[metric_key] = exp_value 184 | else: 185 | with timer(f"{metric}.evaluate"): 186 | # need calculation - non tensor ops 187 | metric_result = metric.evaluate(non_tensor_data) 188 | 189 | self.eval_metric_aggregator.aggregate(metric, metric_result) 190 | 191 | def write_tensor_summaries(self, step, summary_value): 192 | self.summary.write(summary_value, step) 193 | 194 | def write_evaluation_summaries(self, step, collection_keys): 195 | assert step == self.eval_metric_aggregator.step, \ 196 | (f"step: {step} is different from aggregator's step: {self.eval_metric_aggregator.step}" 197 | f"`evaluate` function should be called before calling this function") 198 | 199 | collection_summary_dict = self.eval_metric_aggregator.get_collection_summary_dict() 200 | self.summary.write_evaluation_summaries(step=step, 201 | collection_keys=collection_keys, 202 | collection_summary_dict=collection_summary_dict) 203 | 204 | def log_metrics(self, step): 205 | """ 206 | Logging metrics that are evaluated. 207 | """ 208 | assert step == self.eval_metric_aggregator.step, \ 209 | (f"step: {step} is different from aggregator's step: {self.eval_metric_aggregator.step}" 210 | f"`evaluate` function should be called before calling this function") 211 | 212 | log_dicts = dict() 213 | log_dicts.update(self.eval_metric_aggregator.get_logs()) 214 | 215 | with format_text("green", ["bold"]) as fmt: 216 | for metric_key, log_str in log_dicts.items(): 217 | # self.log.info(fmt(metric_key)) 218 | self.log.info(fmt(log_str)) 219 | # self.info_cyan_underline(f"> {metric_key} : {eval_scores[metric_key]}") 220 | 221 | def get_evaluation_result(self, step): 222 | """ 223 | Retrun evaluation result regardless of metric type. 
224 | """ 225 | assert step == self.eval_metric_aggregator.step, \ 226 | (f"step: {step} is different from aggregator's step: {self.eval_metric_aggregator.step}" 227 | f"`evaluate` function should be called before calling this function") 228 | 229 | eval_dict = dict() 230 | for metric, metric_key, value in self.eval_metric_aggregator.iterate_all(): 231 | eval_dict[metric_key] = value 232 | 233 | return eval_dict 234 | 235 | def get_best_keep_metric_with_modes(self): 236 | metric_min_max_dict = dict() 237 | for metric, metric_key, _ in self.build_op_aggregator.iterate_all(): 238 | if metric.is_for_best_keep: 239 | metric_min_max_dict[metric_key] = metric.min_max_mode 240 | 241 | return metric_min_max_dict 242 | 243 | def filter_best_keep_metric(self, eval_metric_dict): 244 | best_keep_metric_dict = dict() 245 | for metric, metric_key, _ in self.build_op_aggregator.iterate_all(): 246 | if metric_key in eval_metric_dict and metric.is_for_best_keep: 247 | best_keep_metric_dict[metric_key] = eval_metric_dict[metric_key] 248 | 249 | return best_keep_metric_dict 250 | 251 | @staticmethod 252 | def add_arguments(parser): 253 | subparser = parser.add_argument_group(f"Metric Manager Arguments") 254 | subparser.add_argument("--exclude_metric_names", 255 | nargs="*", 256 | default=[], 257 | type=str, 258 | help="Name of metrics to be excluded") 259 | subparser.add_argument("--max_image_outputs", 260 | default=3, 261 | type=int, 262 | help="Number of maximum image outputs") 263 | -------------------------------------------------------------------------------- /metrics/manager.py: -------------------------------------------------------------------------------- 1 | import metrics.ops as mops 2 | import metrics.parser as parser 3 | from metrics.base import MetricManagerBase 4 | from metrics.summaries import Summaries 5 | 6 | 7 | class MattingMetricManager(MetricManagerBase): 8 | _metric_input_data_parser = parser.MattingDataParser 9 | 10 | def __init__(self, 11 | is_training: bool, 12 | save_evaluation_image: bool, 13 | exclude_metric_names: list, 14 | summary: Summaries): 15 | super().__init__(exclude_metric_names, summary) 16 | self.register_metrics([ 17 | # misc 18 | mops.InferenceTimeMetricOp(), 19 | # tensor ops 20 | mops.LossesMetricOp(), 21 | mops.ImageSummaryOp(), 22 | 23 | mops.MADMetricOp(), 24 | mops.GaussianGradMetricOp(), 25 | ]) 26 | 27 | if not is_training and save_evaluation_image: 28 | self.register_metrics([ 29 | mops.MiscImageRetrieveOp(), 30 | mops.MiscImageSaveOp(), 31 | ]) 32 | -------------------------------------------------------------------------------- /metrics/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_ops import MetricOpBase 2 | from .base_ops import TensorMetricOpBase 3 | from .base_ops import NonTensorMetricOpBase 4 | from .tensor_ops import * 5 | from .misc_ops import * 6 | -------------------------------------------------------------------------------- /metrics/ops/base_ops.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from common.utils import get_logger 4 | from metrics.summaries import BaseSummaries 5 | 6 | 7 | class MetricOpBase(ABC): 8 | MIN_MAX_CHOICES = ["min", "max", None] 9 | 10 | _meta_properties = [ 11 | "is_for_summary", 12 | "is_for_best_keep", 13 | "is_for_log", 14 | "valid_input_data_parsers", 15 | "summary_collection_key", 16 | "summary_value_type", 17 | "min_max_mode", 18 | ] 19 | _properties = dict() 
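    # Concrete subclasses fill `_properties` in; as a grounded example, the values
    # used by LossesMetricOp in metrics/ops/tensor_ops.py look like:
    #     _properties = {
    #         "is_for_summary": True,
    #         "is_for_best_keep": True,
    #         "is_for_log": True,
    #         "valid_input_data_parsers": [parser.MattingDataParser],
    #         "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
    #         "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
    #         "min_max_mode": "min",
    #     }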
20 | 
21 |     def __init__(self, **kwargs):
22 |         self.log = get_logger("MetricOp")
23 | 
24 |         # init by _properties
25 |         # custom values can be added as kwargs
26 |         for attr in self._meta_properties:
27 |             if attr in kwargs:
28 |                 setattr(self, attr, kwargs[attr])
29 |             else:
30 |                 setattr(self, attr, self._properties[attr])
31 | 
32 |         # assertion
33 |         assert self.min_max_mode in self.MIN_MAX_CHOICES
34 | 
35 |         if self.is_for_best_keep:
36 |             assert self.min_max_mode is not None
37 | 
38 |         if self.is_for_summary:
39 |             assert self.summary_collection_key in vars(BaseSummaries.KEY_TYPES).values()
40 | 
41 |     def __hash__(self):
42 |         return hash(str(self))
43 | 
44 |     def __eq__(self, other):
45 |         return str(self) == str(other)
46 | 
47 |     @property
48 |     def is_placeholder_summary(self):
49 |         assert self.is_for_summary, "DO NOT call `is_placeholder_summary` if this is not a summary metric"
50 |         return self.summary_value_type == BaseSummaries.VALUE_TYPES.PLACEHOLDER
51 | 
52 |     @property
53 |     @abstractmethod
54 |     def is_tensor_metric(self):
55 |         raise NotImplementedError
56 | 
57 |     @abstractmethod
58 |     def __str__(self):
59 |         raise NotImplementedError
60 | 
61 |     @abstractmethod
62 |     def build_op(self, data):
63 |         """ This method should be overridden to handle
64 |         all cases of `valid_input_data_parsers`
65 |         """
66 |         raise NotImplementedError
67 | 
68 | 
69 | class NonTensorMetricOpBase(MetricOpBase):
70 |     @property
71 |     def is_tensor_metric(self):
72 |         return False
73 | 
74 |     @abstractmethod
75 |     def evaluate(self, data):
76 |         """ This method should be overridden to handle
77 |         all cases of `valid_input_data_parsers`
78 |         """
79 |         raise NotImplementedError
80 | 
81 | 
82 | class TensorMetricOpBase(MetricOpBase):
83 |     @property
84 |     def is_tensor_metric(self):
85 |         return True
86 | 
87 |     @abstractmethod
88 |     def expectation_of(self, data):
89 |         """ When evaluation is done for a tensor metric, it has to re-calculate the expectation of
90 |         the aggregated metric values.
91 |         This function assumes that data is aggregated over all batches
92 |         and returns the proper expectation value.
93 |         """
94 |         raise NotImplementedError
95 | 
--------------------------------------------------------------------------------
/metrics/ops/misc_ops.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from PIL import Image
 3 | 
 4 | from metrics.ops.base_ops import NonTensorMetricOpBase
 5 | from metrics.ops.base_ops import TensorMetricOpBase
 6 | from metrics.summaries import BaseSummaries
 7 | import metrics.parser as parser
 8 | 
 9 | 
10 | class InferenceTimeMetricOp(NonTensorMetricOpBase):
11 |     """
12 |     Inference Time Metric. 
13 |     """
14 |     _properties = {
15 |         "is_for_summary": True,
16 |         "is_for_best_keep": False,
17 |         "is_for_log": True,
18 |         "valid_input_data_parsers": [
19 |             parser.MattingDataParser,
20 |         ],
21 |         "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
22 |         "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
23 |         "min_max_mode": None,
24 |     }
25 | 
26 |     def __str__(self):
27 |         return "inference_time_metric"
28 | 
29 |     def build_op(self,
30 |                  data):
31 |         return {
32 |             f"misc/batch_infer_time/{data.dataset_split_name}": None,
33 |             f"misc/unit_infer_time/{data.dataset_split_name}": None,
34 |         }
35 | 
36 |     def evaluate(self,
37 |                  data):
38 |         return {
39 |             f"misc/batch_infer_time/{data.dataset_split_name}": np.mean(data.batch_infer_time),
40 |             f"misc/unit_infer_time/{data.dataset_split_name}": np.mean(data.unit_infer_time),
41 |         }
42 | 
43 | 
44 | class MiscImageRetrieveOp(TensorMetricOpBase):
45 |     """ Image not recorded in summaries but to be retrieved.
46 |     """
47 |     _properties = {
48 |         "is_for_summary": False,
49 |         "is_for_best_keep": False,
50 |         "is_for_log": False,
51 |         "valid_input_data_parsers": [
52 |             parser.MattingDataParser,
53 |         ],
54 |         "summary_collection_key": None,
55 |         "summary_value_type": None,
56 |         "min_max_mode": None,
57 |     }
58 | 
59 |     def __str__(self):
60 |         return "misc_image_retrieve"
61 | 
62 |     def build_op(self, data):
63 |         result = dict()
64 | 
65 |         for name, op in data.misc_images.items():
66 |             key = f"misc_images/{data.dataset_split_name}/{name}"
67 |             result[key] = op
68 | 
69 |         return result
70 | 
71 |     def expectation_of(self, data):
72 |         # we don't aggregate this output
73 |         pass
74 | 
75 | 
76 | class MiscImageSaveOp(NonTensorMetricOpBase):
77 |     """ Image save
78 |     """
79 |     _properties = {
80 |         "is_for_summary": False,
81 |         "is_for_best_keep": False,
82 |         "is_for_log": True,
83 |         "valid_input_data_parsers": [
84 |             parser.MattingDataParser,
85 |         ],
86 |         "summary_collection_key": None,
87 |         "summary_value_type": None,
88 |         "min_max_mode": None,
89 |     }
90 | 
91 |     def __str__(self):
92 |         return "misc_image_save"
93 | 
94 |     def build_op(self, data):
95 |         return {
96 |             f"save_images/{data.dataset_split_name}": None
97 |         }
98 | 
99 |     def evaluate(self, data):
100 |         keys = []
101 |         images = []
102 |         for key, image in data.misc_images.items():
103 |             keys.append(key)
104 |             images.append(image)
105 | 
106 |         # meta info
107 |         num_data = images[0].shape[0]
108 |         h, w = images[0].shape[1:3]
109 | 
110 |         for nidx in range(num_data):
111 |             _images = [Image.fromarray(img[nidx]) for img in images]
112 | 
113 |             merged_image = Image.new("RGB", (w * ((len(_images)-1) // 3 + 1), h * 3))  # PIL size is (width, height)
114 |             for i, img in enumerate(_images):
115 |                 row = i // 3
116 |                 col = i % 3
117 | 
118 |                 merged_image.paste(img, (row*w, col*h, (row+1)*w, (col+1)*h))
119 |             merged_image.save(data.image_save_dir / f"img_{nidx}.jpg")
120 | 
121 |         # msg
122 |         msg = f"{num_data} images are saved under {data.image_save_dir.resolve()}"
123 | 
124 |         return {
125 |             f"save_images/{data.dataset_split_name}": msg
126 |         }
127 | 
--------------------------------------------------------------------------------
/metrics/ops/tensor_ops.py:
--------------------------------------------------------------------------------
 1 | import tensorflow as tf
 2 | import numpy as np
 3 | from overload import overload
 4 | 
 5 | import common.image_helper as image_helper
 6 | import metrics.parser as parser
 7 | from metrics.ops.base_ops import TensorMetricOpBase
 8 | from metrics.summaries import BaseSummaries
 9 | 
10 | 
11 | class LossesMetricOp(TensorMetricOpBase):
12 
| """ Loss Metric. 13 | """ 14 | _properties = { 15 | "is_for_summary": True, 16 | "is_for_best_keep": True, 17 | "is_for_log": True, 18 | "valid_input_data_parsers": [ 19 | parser.MattingDataParser, 20 | ], 21 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT, 22 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER, 23 | "min_max_mode": "min", 24 | } 25 | 26 | def __str__(self): 27 | return "losses" 28 | 29 | def build_op(self, data): 30 | result = dict() 31 | 32 | for loss_name, loss_op in data.losses.items(): 33 | key = f"metric_loss/{data.dataset_split_name}/{loss_name}" 34 | result[key] = loss_op 35 | 36 | return result 37 | 38 | def expectation_of(self, data: np.array): 39 | assert len(data.shape) == 2 40 | return np.mean(data) 41 | 42 | 43 | class ImageSummaryOp(TensorMetricOpBase): 44 | """ Image summary 45 | """ 46 | _properties = { 47 | "is_for_summary": True, 48 | "is_for_best_keep": False, 49 | "is_for_log": False, 50 | "valid_input_data_parsers": [ 51 | parser.MattingDataParser, 52 | ], 53 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT, 54 | "summary_value_type": BaseSummaries.VALUE_TYPES.IMAGE, 55 | "min_max_mode": None, 56 | } 57 | 58 | def __str__(self): 59 | return "summary_images" 60 | 61 | @overload 62 | def build_op(self, 63 | data: parser.MattingDataParser.OutputBuildData): 64 | result = dict() 65 | 66 | for summary_name, op in data.summary_images.items(): 67 | key = f"{summary_name}/{data.dataset_split_name}" 68 | result[key] = op 69 | 70 | return result 71 | 72 | def expectation_of(self, data): 73 | pass 74 | 75 | 76 | class MADMetricOp(TensorMetricOpBase): 77 | """ Mean Average Difference Metric. 78 | """ 79 | _properties = { 80 | "is_for_summary": True, 81 | "is_for_best_keep": True, 82 | "is_for_log": True, 83 | "valid_input_data_parsers": [ 84 | parser.MattingDataParser, 85 | ], 86 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT, 87 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER, 88 | "min_max_mode": "min", 89 | } 90 | 91 | def __str__(self): 92 | return "mad_metric" 93 | 94 | @overload 95 | def build_op(self, 96 | data: parser.MattingDataParser.OutputBuildData): 97 | def _calc(masks, alpha_scores): 98 | return tf.reduce_mean(tf.abs(masks - alpha_scores)) 99 | 100 | result = dict() 101 | 102 | result[f"MAD/{data.dataset_split_name}"] = _calc(data.masks, data.alpha_scores) 103 | 104 | if data.target_eval_shape: 105 | suffix = f"{data.target_eval_shape[0]}_{data.target_eval_shape[1]}" 106 | result[f"MAD/{data.dataset_split_name}/{suffix}"] = _calc(data.ts_masks, data.ts_alpha_scores) 107 | 108 | return result 109 | 110 | def expectation_of(self, data: np.array): 111 | assert len(data.shape) == 2 112 | return np.mean(data) 113 | 114 | 115 | class GaussianGradMetricOp(TensorMetricOpBase): 116 | """ Gradient Error Metric. 
117 | """ 118 | _properties = { 119 | "is_for_summary": True, 120 | "is_for_best_keep": True, 121 | "is_for_log": True, 122 | "valid_input_data_parsers": [ 123 | parser.MattingDataParser, 124 | ], 125 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT, 126 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER, 127 | "min_max_mode": "min", 128 | } 129 | 130 | def __str__(self): 131 | return "gaussian_grad_metric" 132 | 133 | @overload 134 | def build_op(self, 135 | data: parser.MattingDataParser.OutputBuildData): 136 | def _norm(x): 137 | return tf.sqrt(tf.reduce_sum(tf.square(x), axis=-1)) 138 | 139 | def _calc(masks, alpha_scores): 140 | grad_masks = image_helper.first_deriviate_gaussian_gradient( 141 | masks, 142 | sigma=1.4, 143 | ) 144 | grad_alpha_scores = image_helper.first_deriviate_gaussian_gradient( 145 | alpha_scores, 146 | sigma=1.4, 147 | ) 148 | 149 | metric = tf.reduce_mean(_norm(grad_masks - grad_alpha_scores)) 150 | 151 | return metric 152 | 153 | result = dict() 154 | result[f"GAUSS_GRAD/{data.dataset_split_name}"] = _calc(data.masks, data.alpha_scores) 155 | 156 | if data.target_eval_shape: 157 | suffix = f"{data.target_eval_shape[0]}_{data.target_eval_shape[1]}" 158 | result[f"GAUSS_GRAD/{data.dataset_split_name}/{suffix}"] = _calc(data.ts_masks, data.ts_alpha_scores) 159 | 160 | return result 161 | 162 | def expectation_of(self, data: np.array): 163 | assert len(data.shape) == 2 164 | return np.mean(data) 165 | -------------------------------------------------------------------------------- /metrics/parser.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, ABCMeta 2 | 3 | import tensorflow as tf 4 | 5 | from metrics.base import DataStructure 6 | 7 | 8 | class MetricDataParserBase(ABC): 9 | @classmethod 10 | def parse_build_data(cls, data): 11 | """ 12 | Args: 13 | data: dictionary which will be passed to InputBuildData 14 | """ 15 | data = cls._validate_build_data(data) 16 | data = cls._process_build_data(data) 17 | return data 18 | 19 | @classmethod 20 | def parse_non_tensor_data(cls, data): 21 | """ 22 | Args: 23 | data: dictionary which will be passed to InputDataStructure 24 | """ 25 | input_data = cls._validate_non_tensor_data(data) 26 | output_data = cls._process_non_tensor_data(input_data) 27 | return output_data 28 | 29 | @classmethod 30 | def _validate_build_data(cls, data): 31 | """ 32 | Specify assertions that tensor data should contains 33 | 34 | Args: 35 | data: dictionary 36 | Return: 37 | InputDataStructure 38 | """ 39 | return cls.InputBuildData(data) 40 | 41 | @classmethod 42 | def _validate_non_tensor_data(cls, data): 43 | """ 44 | Specify assertions that non-tensor data should contains 45 | 46 | Args: 47 | data: dictionary 48 | Return: 49 | InputDataStructure 50 | """ 51 | return cls.InputNonTensorData(data) 52 | 53 | """ 54 | Override these two functions if needed. 
55 | """ 56 | @classmethod 57 | def _process_build_data(cls, data): 58 | """ 59 | Process data in order to following metrics can use it 60 | 61 | Args: 62 | data: InputBuildData 63 | 64 | Return: 65 | OutputBuildData 66 | """ 67 | # default function is just passing data 68 | return cls.OutputBuildData(data.to_dict()) 69 | 70 | @classmethod 71 | def _process_non_tensor_data(cls, data): 72 | """ 73 | Process data in order to following metrics can use it 74 | 75 | Args: 76 | data: InputNonTensorData 77 | 78 | Return: 79 | OutputNonTensorData 80 | """ 81 | # default function is just passing data 82 | return cls.OutputNonTensorData(data.to_dict()) 83 | 84 | """ 85 | Belows should be implemented when inherit. 86 | """ 87 | class InputBuildData(DataStructure, metaclass=ABCMeta): 88 | pass 89 | 90 | class OutputBuildData(DataStructure, metaclass=ABCMeta): 91 | pass 92 | 93 | class InputNonTensorData(DataStructure, metaclass=ABCMeta): 94 | pass 95 | 96 | class OutputNonTensorData(DataStructure, metaclass=ABCMeta): 97 | pass 98 | 99 | 100 | class MattingDataParser(MetricDataParserBase): 101 | """ Matting parser 102 | """ 103 | class InputBuildData(DataStructure): 104 | _keys = [ 105 | "dataset_split_name", 106 | "target_eval_shape", # Tuple(int, int) | (height, width) 107 | 108 | "losses", # Dict | loss_key -> Tensor 109 | "summary_images", # Dict | summary_name -> Tensor 110 | "misc_images", # Dict | name -> Tensor 111 | "masks", # Tensor 112 | "masks_original", # Tensor 113 | "probs", # Tensor 114 | ] 115 | 116 | class OutputBuildData(DataStructure): 117 | _keys = [ 118 | "dataset_split_name", 119 | "target_eval_shape", 120 | 121 | "losses", 122 | "summary_images", 123 | "misc_images", 124 | "masks", 125 | "alpha_scores", # Tensor 126 | "binary_masks", 127 | "binary_alpha_scores", 128 | 129 | "ts_masks", 130 | "ts_alpha_scores", # Tensor 131 | "ts_binary_masks", 132 | "ts_binary_alpha_scores", 133 | ] 134 | 135 | class InputNonTensorData(DataStructure): 136 | _keys = [ 137 | "dataset_split_name", 138 | 139 | "batch_infer_time", 140 | "unit_infer_time", 141 | "misc_images", 142 | "image_save_dir", 143 | ] 144 | 145 | class OutputNonTensorData(DataStructure): 146 | _keys = [ 147 | "dataset_split_name", 148 | 149 | "batch_infer_time", 150 | "unit_infer_time", 151 | "misc_images", 152 | "image_save_dir", 153 | ] 154 | 155 | @classmethod 156 | def _process_build_data(cls, data): 157 | masks = tf.expand_dims(data.masks, axis=3) 158 | alpha_scores = data.probs[:, :, :, -1:] 159 | binary_masks = tf.greater_equal(masks, 0.5) 160 | binary_alpha_scores = tf.greater_equal(alpha_scores, 0.5) 161 | 162 | if data.target_eval_shape: 163 | ts_masks = tf.image.resize_bilinear(data.masks_original, data.target_eval_shape) 164 | ts_alpha_scores = tf.image.resize_bilinear(alpha_scores, data.target_eval_shape) 165 | 166 | ts_binary_masks = tf.greater_equal(ts_masks, 0.5) 167 | ts_binary_alpha_scores = tf.greater_equal(ts_alpha_scores, 0.5) 168 | else: 169 | ts_masks = None 170 | ts_alpha_scores = None 171 | 172 | ts_binary_masks = None 173 | ts_binary_alpha_scores = None 174 | 175 | return cls.OutputBuildData({ 176 | "dataset_split_name": data.dataset_split_name, 177 | "target_eval_shape": data.target_eval_shape, 178 | "losses": data.losses, 179 | "summary_images": data.summary_images, 180 | "misc_images": data.misc_images, 181 | 182 | "masks": masks, 183 | "alpha_scores": alpha_scores, 184 | "binary_masks": binary_masks, 185 | "binary_alpha_scores": binary_alpha_scores, 186 | 187 | "ts_masks": ts_masks, 188 | 
"ts_alpha_scores": ts_alpha_scores, 189 | "ts_binary_masks": ts_binary_masks, 190 | "ts_binary_alpha_scores": ts_binary_alpha_scores, 191 | }) 192 | 193 | @classmethod 194 | def _process_non_tensor_data(cls, data): 195 | return cls.OutputNonTensorData({ 196 | "dataset_split_name": data.dataset_split_name, 197 | "batch_infer_time": data.batch_infer_time, 198 | "unit_infer_time": data.unit_infer_time, 199 | "misc_images": data.misc_images, 200 | "image_save_dir": data.image_save_dir, 201 | }) 202 | -------------------------------------------------------------------------------- /metrics/summaries.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from types import SimpleNamespace 3 | from pathlib import Path 4 | 5 | import tensorflow as tf 6 | from overload import overload 7 | 8 | from common.utils import get_logger 9 | 10 | 11 | class BaseSummaries(ABC): 12 | KEY_TYPES = SimpleNamespace( 13 | DEFAULT="SUMMARY_DEFAULT", 14 | VERBOSE="SUMMARY_VERBOSE", 15 | FIRST_N="SUMMARY_FIRST_N", 16 | ) 17 | 18 | VALUE_TYPES = SimpleNamespace( 19 | SCALAR="SCALAR", 20 | IMAGE="IMAGE", 21 | HISTOGRAM="HISTOGRAM", 22 | PLACEHOLDER="PLACEHOLDER", 23 | NONE="NONE", # None is not used for summary 24 | ) 25 | 26 | def __init__(self, session, train_dir, is_training, 27 | max_image_outputs=None): 28 | self.log = get_logger("Summary") 29 | 30 | self.session = session 31 | self.train_dir = train_dir 32 | self.max_image_outputs = max_image_outputs 33 | self.merged_summaries = dict() 34 | 35 | self.summary_writer = None 36 | self._setup_summary_writer(is_training) 37 | 38 | def write(self, summary, global_step=0): 39 | self.summary_writer.add_summary(summary, global_step) 40 | 41 | def setup_experiment(self, config): 42 | """ 43 | Args: 44 | config: Namespace 45 | """ 46 | config = vars(config) 47 | 48 | sorted_config = [(k, str(v)) for k, v in sorted(config.items(), key=lambda x: x[0])] 49 | config = tf.summary.text("config", tf.convert_to_tensor(sorted_config)) 50 | 51 | config_val = self.session.run(config) 52 | self.write(config_val) 53 | self.summary_writer.add_graph(tf.get_default_graph()) # This is magic 54 | 55 | def register_summaries(self, collection_summary_dict): 56 | """ 57 | Args: 58 | collection_summary_dict: Dict[str, Dict[MetricOp, List[Tuple(metric_key, tensor_op)]]] 59 | collection_key -> metric -> List(summary_key, value) 60 | """ 61 | for collection_key_suffix, metric_dict in collection_summary_dict.items(): 62 | for metric, key_value_list in metric_dict.items(): 63 | for summary_key, value in key_value_list: 64 | self._routine_add_summary_op(summary_value_type=metric.summary_value_type, 65 | summary_key=summary_key, 66 | value=value, 67 | collection_key_suffix=collection_key_suffix) 68 | 69 | def setup_merged_summaries(self): 70 | for collection_key_suffix in vars(self.KEY_TYPES).values(): 71 | for collection_key in self._iterate_collection_keys(collection_key_suffix): 72 | merged_summary = tf.summary.merge_all(key=collection_key) 73 | self.merged_summaries[collection_key] = merged_summary 74 | 75 | def get_merged_summaries(self, collection_key_suffixes: list, is_tensor_summary: bool): 76 | summaries = [] 77 | 78 | for collection_key_suffix in collection_key_suffixes: 79 | collection_key = self._build_collection_key(collection_key_suffix, is_tensor_summary) 80 | summary = self.merged_summaries[collection_key] 81 | 82 | if summary is not None: 83 | summaries.append(summary) 84 | 85 | if len(summaries) == 0: 86 | return None 87 | elif 
len(summaries) == 1: 88 | return summaries[0] 89 | else: 90 | return tf.summary.merge(summaries) 91 | 92 | def write_evaluation_summaries(self, step, collection_keys, collection_summary_dict): 93 | """ 94 | Args: 95 | collection_summary_dict: Dict[str, Dict[MetricOp, List[Tuple(metric_key, tensor_op)]]] 96 | collection_key -> metric -> List(summary_key, value) 97 | collection_keys: List 98 | """ 99 | for collection_key_suffix, metric_dict in collection_summary_dict.items(): 100 | if collection_key_suffix in collection_keys: 101 | merged_summary_op = self.get_merged_summaries(collection_key_suffixes=[collection_key_suffix], 102 | is_tensor_summary=False) 103 | feed_dict = dict() 104 | 105 | for metric, key_value_list in metric_dict.items(): 106 | if metric.is_placeholder_summary: 107 | for summary_key, value in key_value_list: 108 | # placeholders are fed by name ("<name>:0"); see https://github.com/tensorflow/tensorflow/issues/3378 109 | placeholder_name = self._build_placeholder_name(summary_key) + ":0" 110 | feed_dict[placeholder_name] = value 111 | 112 | summary_value = self.session.run(merged_summary_op, feed_dict=feed_dict) 113 | self.write(summary_value, step) 114 | 115 | def _setup_summary_writer(self, is_training): 116 | summary_directory = self._build_summary_directory(is_training) 117 | self.log.info(f"Write summaries into: {summary_directory}") 118 | 119 | if is_training: 120 | self.summary_writer = tf.summary.FileWriter(summary_directory, self.session.graph) 121 | else: 122 | self.summary_writer = tf.summary.FileWriter(summary_directory) 123 | 124 | def _build_summary_directory(self, is_training): 125 | if is_training: 126 | return self.train_dir 127 | else: 128 | if Path(self.train_dir).is_dir(): 129 | summary_directory = (Path(self.train_dir) / Path("eval")).as_posix() 130 | else: 131 | summary_directory = (Path(self.train_dir).parent / Path("eval")).as_posix() 132 | 133 | if not Path(summary_directory).exists(): 134 | Path(summary_directory).mkdir(parents=True) 135 | 136 | return summary_directory 137 | 138 | def _routine_add_summary_op(self, summary_value_type, summary_key, value, collection_key_suffix): 139 | collection_key = self._build_collection_key(collection_key_suffix, summary_value_type) 140 | 141 | if summary_value_type == self.VALUE_TYPES.SCALAR: 142 | def register_fn(k, v): 143 | return tf.summary.scalar(k, v, collections=[collection_key]) 144 | # SCALAR summaries use the given tensor directly; placeholders are built only for PLACEHOLDER type 145 | 146 | elif summary_value_type == self.VALUE_TYPES.IMAGE: 147 | def register_fn(k, v): 148 | return tf.summary.image(k, v, max_outputs=self.max_image_outputs, collections=[collection_key]) 149 | 150 | elif summary_value_type == self.VALUE_TYPES.HISTOGRAM: 151 | def register_fn(k, v): 152 | return tf.summary.histogram(k, v, collections=[collection_key]) 153 | 154 | elif summary_value_type == self.VALUE_TYPES.PLACEHOLDER: 155 | def register_fn(k, v): 156 | return tf.summary.scalar(k, v, collections=[collection_key]) 157 | value = self._build_placeholder(summary_key) 158 | 159 | else: 160 | raise NotImplementedError 161 | 162 | register_fn(summary_key, value) 163 | 164 | @classmethod 165 | def _build_placeholder(cls, summary_key): 166 | name = cls._build_placeholder_name(summary_key) 167 | return tf.placeholder(tf.float32, [], name=name) 168 | 169 | @staticmethod 170 | def _build_placeholder_name(summary_key): 171 | return f"non_tensor_summary_placeholder/{summary_key}" 172 | 173 | # the two functions below should be classmethods but are defined as instance methods 174 | # because @overload does not handle classmethods correctly 175 | # @classmethod 176 
| @overload 177 | def _build_collection_key(self, collection_key_suffix, summary_value_type: str): 178 | if summary_value_type == self.VALUE_TYPES.PLACEHOLDER: 179 | prefix = "NON_TENSOR" 180 | else: 181 | prefix = "TENSOR" 182 | 183 | return f"{prefix}_{collection_key_suffix}" 184 | 185 | # @classmethod 186 | @_build_collection_key.add 187 | def _build_collection_key(self, collection_key_suffix, is_tensor_summary: bool): 188 | if not is_tensor_summary: 189 | prefix = "NON_TENSOR" 190 | else: 191 | prefix = "TENSOR" 192 | 193 | return f"{prefix}_{collection_key_suffix}" 194 | 195 | @classmethod 196 | def _iterate_collection_keys(cls, collection_key_suffix): 197 | for prefix in ["NON_TENSOR", "TENSOR"]: 198 | yield f"{prefix}_{collection_key_suffix}" 199 | 200 | 201 | class Summaries(BaseSummaries): 202 | pass 203 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyperconnect/MMNet/aa423b598f38e051fcf46914cf8a7765458c09d2/nets/__init__.py -------------------------------------------------------------------------------- /nets/mobilenet: -------------------------------------------------------------------------------- 1 | ../models/research/slim/nets/mobilenet -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## Towards Real-Time Automatic Portrait Matting on Mobile Devices 2 | 3 | We tackle the problem of automatic portrait matting on mobile devices. 4 | The proposed model is aimed at attaining real-time inference on mobile devices with minimal degradation of model performance. 5 | Our model MMNet, based on multi-branch dilated convolution with linear bottleneck blocks, outperforms the state-of-the-art model and is orders of magnitude faster. 6 | The model can be accelerated four times to attain 30 FPS on a Xiaomi Mi 5 device with a moderate increase in gradient error. 7 | Under the same conditions, our model has an order of magnitude fewer parameters than Mobile DeepLabv3 and is faster, while maintaining comparable performance. 8 | 9 |

10 | <img src="figure/gradient_error_vs_latency.png" alt="gradient_error_vs_latency"> 11 |

12 | 13 | The trade-off between gradient error and latency on a mobile device. 14 | Latency is measured using a Qualcomm Snapdragon 820 MSM8996 CPU. 15 | The size of each circle is proportional to the logarithm of the number of parameters used by the model. 16 | Different circles of Mobile DeepLabv3 are created by varying the output stride and width multiplier. 17 | The circles are marked with their width multiplier. 18 | Results using 128 x 128 inputs are marked with *; otherwise, inputs are 256 x 256. 19 | Notice that MMNet outperforms all other models, forming a Pareto front. 20 | The number of parameters for LDN+FB is not reported in their paper. 21 | 22 | 23 | ## Requirements 24 | 25 | - Python 3.6+ 26 | - Tensorflow 1.6 27 | 28 | ## Installation 29 | 30 | ``` 31 | git clone --recursive https://github.com/hyperconnect/MMNet.git 32 | pip3 install -r requirements/py36-gpu.txt 33 | ``` 34 | 35 | ## Dataset 36 | The dataset for training and evaluation has to follow the directory structure depicted below. 37 | To use names other than `train` and `test`, use the `--dataset_split_name` argument of *train.py* or *evaluate.py*. 38 | ``` 39 | dataset_directory 40 | |___ train 41 | | |__ mask 42 | | |__ image 43 | | 44 | |___ test 45 | |__ mask 46 | |__ image 47 | ``` 48 | 49 | 50 | ## Training 51 | In the `scripts` directory, you can find example scripts for training and evaluation of MMNet and Mobile DeepLabv3. 52 | Training scripts accept two arguments: `dataset path` and `train directory`. 53 | `dataset path` has to point to a directory with the structure described in the previous section. 54 | 55 | ### MMNet 56 | Training of MMNet with width multiplier 1.0 and input image size 256. 57 | 58 | ```bash 59 | ./scripts/train_mmnet_dm1.0_256.sh /path/to/dataset /path/to/training/directory 60 | ``` 61 | 62 | ### Mobile DeepLabv3 63 | Training of Mobile DeepLabv3 with output stride 16, depth multiplier 0.5, and input image size 256. 64 | 65 | ```bash 66 | ./scripts/train_deeplab_os16_dm0.5_256.sh /path/to/dataset /path/to/training/directory 67 | ``` 68 | 69 | 70 | 71 | ## Evaluation 72 | Evaluation scripts, like the training scripts, accept two arguments: `dataset path` and `train directory`. 73 | If the `train directory` argument points to a specific checkpoint file, only that checkpoint is evaluated; otherwise, the latest checkpoint is evaluated. 74 | It is recommended to run evaluation scripts alongside training scripts so that evaluation metrics are produced for every checkpoint; a combined example is sketched after the Demo section. 75 | 76 | ### MMNet 77 | 78 | ```bash 79 | ./scripts/valid_mmnet_dm1.0_256.sh /path/to/dataset /path/to/training/directory 80 | ``` 81 | 82 | ### Mobile DeepLabv3 83 | 84 | ```bash 85 | ./scripts/valid_deeplab_os16_dm0.5_256.sh /path/to/dataset /path/to/training/directory 86 | ``` 87 | 88 | ## Demo 89 | 90 | Refer to `demo/demo.mp4`. 
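As referenced in the Evaluation section, the training and evaluation scripts compose into a single workflow: the evaluation scripts pass `--valid_type loop`, so they keep re-evaluating checkpoints as training writes them. A minimal sketch, assuming a POSIX shell; the checkpoint file name in the last command is hypothetical:

```bash
# Train in the background while the evaluation script (--valid_type loop)
# keeps picking up the latest checkpoint from the training directory.
./scripts/train_mmnet_dm1.0_256.sh /path/to/dataset /path/to/training/directory &
./scripts/valid_mmnet_dm1.0_256.sh /path/to/dataset /path/to/training/directory

# Evaluate one specific checkpoint instead of the latest one
# (model.ckpt-30000 is a hypothetical file name).
./scripts/valid_mmnet_dm1.0_256.sh /path/to/dataset /path/to/training/directory/model.ckpt-30000
```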
91 | 92 | ## License 93 | 94 | [Apache License 2.0](LICENSE) 95 | -------------------------------------------------------------------------------- /requirements/py36-common.txt: -------------------------------------------------------------------------------- 1 | # ML 2 | numpy 3 | Cython 4 | scipy 5 | pandas 6 | tqdm 7 | scikit-learn 8 | 9 | # data 10 | termcolor 11 | pillow 12 | click 13 | matplotlib 14 | 15 | # devtools 16 | gpustat 17 | 18 | # Vision 19 | scikit-image 20 | opencv-python 21 | 22 | # Miscellaneous 23 | humanfriendly 24 | overload 25 | deprecation 26 | -------------------------------------------------------------------------------- /requirements/py36-cpu.txt: -------------------------------------------------------------------------------- 1 | # Deep learning 2 | tensorflow==1.6.0 3 | -r py36-common.txt -------------------------------------------------------------------------------- /requirements/py36-gpu.txt: -------------------------------------------------------------------------------- 1 | # Deep learning 2 | tensorflow-gpu==1.6.0 3 | -r py36-common.txt -------------------------------------------------------------------------------- /scripts/train_deeplab_os16_dm0.5_256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eux 3 | 4 | dataset_path=${1} 5 | train_dir=${2:-deeplab-traindir} 6 | 7 | python train.py \ 8 | --num_classes 2 \ 9 | --task_type matting \ 10 | --output_name output/score \ 11 | --output_type prob \ 12 | --width 256 \ 13 | --height 256 \ 14 | --train_dir ${train_dir} \ 15 | --batch_size 32 \ 16 | --dataset_path ${dataset_path} \ 17 | --dataset_split_name train \ 18 | --learning_rate 1e-4 \ 19 | --preprocess_method preprocess_normalize \ 20 | --no-use_fused_batchnorm \ 21 | --step_save_summaries 500 \ 22 | --step_save_checkpoint 500 \ 23 | --max_to_keep 3 \ 24 | --max_outputs 1 \ 25 | --augmentation_method resize_random_scale_crop_flip_rotate \ 26 | --max_epoch_from_restore 30000 \ 27 | --lambda_alpha_loss 1 \ 28 | --lambda_comp_loss 1 \ 29 | --lambda_grad_loss 1 \ 30 | --lambda_kd_loss 1 \ 31 | DeepLabModel \ 32 | --extractor mobilenet_v2 \ 33 | --weight_decay 4e-7 \ 34 | --depth_multiplier 0.5 \ 35 | --output_stride 16 \ 36 | -------------------------------------------------------------------------------- /scripts/train_mmnet_dm1.0_256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eux 3 | 4 | dataset_path=${1} 5 | train_dir=${2:-mmnet-traindir} 6 | 7 | python train.py \ 8 | --num_classes 2 \ 9 | --task_type matting \ 10 | --output_name output/score \ 11 | --output_type prob \ 12 | --width 256 \ 13 | --height 256 \ 14 | --train_dir ${train_dir} \ 15 | --batch_size 32 \ 16 | --dataset_path ${dataset_path} \ 17 | --dataset_split_name train \ 18 | --learning_rate 1e-4 \ 19 | --preprocess_method preprocess_normalize \ 20 | --no-use_fused_batchnorm \ 21 | --step_save_summaries 500 \ 22 | --step_save_checkpoint 500 \ 23 | --max_to_keep 3 \ 24 | --max_outputs 1 \ 25 | --augmentation_method resize_random_scale_crop_flip_rotate \ 26 | --max_epoch_from_restore 30000 \ 27 | --lambda_alpha_loss 1 \ 28 | --lambda_comp_loss 1 \ 29 | --lambda_grad_loss 1 \ 30 | --lambda_kd_loss 1 \ 31 | --lambda_aux_loss 1 \ 32 | MMNetModel \ 33 | --width_multiplier 1.0 \ 34 | --weight_decay 4e-7 \ 35 | -------------------------------------------------------------------------------- /scripts/valid_deeplab_os16_dm0.5_256.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eux 3 | 4 | dataset_path=${1} 5 | checkpoint_path=${2:-deeplab-traindir} 6 | 7 | python evaluate.py \ 8 | --num_classes 2 \ 9 | --task_type matting \ 10 | --output_name output/score \ 11 | --output_type prob \ 12 | --width 256 \ 13 | --height 256 \ 14 | --target_eval_shape 800 600 \ 15 | --no-save_evaluation_image \ 16 | --batch_size 5 \ 17 | --checkpoint_path ${checkpoint_path} \ 18 | --dataset_path ${dataset_path} \ 19 | --dataset_split_name test \ 20 | --convert_to_pb \ 21 | --preprocess_method preprocess_normalize \ 22 | --no-use_fused_batchnorm \ 23 | --valid_type loop \ 24 | --max_outputs 1 \ 25 | --augmentation_method resize_bilinear \ 26 | --lambda_alpha_loss 1 \ 27 | --lambda_comp_loss 1 \ 28 | --lambda_grad_loss 1 \ 29 | --lambda_kd_loss 1 \ 30 | DeepLabModel \ 31 | --extractor mobilenet_v2 \ 32 | --depth_multiplier 0.5 \ 33 | --output_stride 16 \ 34 | -------------------------------------------------------------------------------- /scripts/valid_mmnet_dm1.0_256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eux 3 | 4 | dataset_path=${1} 5 | checkpoint_path=${2:-mmnet-traindir} 6 | 7 | python evaluate.py \ 8 | --num_classes 2 \ 9 | --task_type matting \ 10 | --output_name output/score \ 11 | --output_type prob \ 12 | --width 256 \ 13 | --height 256 \ 14 | --target_eval_shape 800 600 \ 15 | --no-save_evaluation_image \ 16 | --batch_size 5 \ 17 | --checkpoint_path ${checkpoint_path} \ 18 | --dataset_path ${dataset_path} \ 19 | --dataset_split_name test \ 20 | --convert_to_pb \ 21 | --preprocess_method preprocess_normalize \ 22 | --no-use_fused_batchnorm \ 23 | --valid_type loop \ 24 | --max_outputs 1 \ 25 | --augmentation_method resize_bilinear \ 26 | --lambda_alpha_loss 1 \ 27 | --lambda_comp_loss 1 \ 28 | --lambda_grad_loss 1 \ 29 | --lambda_kd_loss 1 \ 30 | --lambda_aux_loss 1 \ 31 | MMNetModel \ 32 | --width_multiplier 1.0 \ 33 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import atexit 3 | from typing import List 4 | 5 | import tensorflow as tf 6 | 7 | import const 8 | import common.utils as utils 9 | import factory.matting_nets as matting_nets 10 | from datasets.data_wrapper_base import DataWrapperBase 11 | from datasets.matting_data_wrapper import MattingDataWrapper 12 | from factory.base import CNNModel 13 | from helper.base import Base 14 | from helper.trainer import MattingTrainer 15 | from helper.trainer import TrainerBase 16 | from metrics.base import MetricManagerBase 17 | 18 | 19 | def train(args): 20 | trainer = build_trainer(args) 21 | trainer.train() 22 | 23 | 24 | def build_trainer(args, trainer_cls=MattingTrainer): 25 | is_training = True 26 | session = tf.Session(config=const.TF_SESSION_CONFIG) 27 | 28 | # only one dataset split is assumed for training 29 | dataset_name = args.dataset_split_name[0] 30 | 31 | dataset = MattingDataWrapper( 32 | args, 33 | session, 34 | dataset_name, 35 | is_training=is_training, 36 | ) 37 | 38 | images_original, masks_original, images, masks = dataset.get_input_and_output_op() 39 | 40 | model = eval(f"matting_nets.{args.model}")(args, dataset) 41 | model.build( 42 | images_original=images_original, 43 | images=images, 44 | masks_original=masks_original, 45 | masks=masks, 46 | is_training=is_training, 47 | ) 48 | 49 
| trainer = trainer_cls( 50 | model, 51 | session, 52 | args, 53 | dataset, 54 | dataset_name, 55 | ) 56 | 57 | return trainer 58 | 59 | 60 | def parse_arguments(arguments: List[str] = None): 61 | parser = argparse.ArgumentParser(description=__doc__) 62 | subparsers = parser.add_subparsers(title="Model", description="") 63 | 64 | # -- * -- Common Arguments & Each Model's Arguments -- * -- 65 | CNNModel.add_arguments(parser, default_type="matting") 66 | matting_nets.MattingNetModel.add_arguments(parser) 67 | for class_name in matting_nets._available_nets: 68 | subparser = subparsers.add_parser(class_name) 69 | subparser.add_argument("--model", default=class_name, type=str, help="DO NOT FIX ME") 70 | add_matting_net_arguments = eval(f"matting_nets.{class_name}.add_arguments") 71 | add_matting_net_arguments(subparser) 72 | 73 | # -- * -- Parameters & Options for MattingTrainer -- * -- 74 | TrainerBase.add_arguments(parser) 75 | MattingTrainer.add_arguments(parser) 76 | Base.add_arguments(parser) 77 | DataWrapperBase.add_arguments(parser) 78 | MattingDataWrapper.add_arguments(parser) 79 | MetricManagerBase.add_arguments(parser) 80 | 81 | # -- Parse arguments 82 | args = parser.parse_args(arguments) 83 | 84 | # Hack: collect the chosen subparser's arguments and dynamically add them to args (Namespace); 85 | # they will be used by convert.py 86 | model_arguments = utils.get_subparser_argument_list(parser, args.model) 87 | args.model_arguments = model_arguments 88 | 89 | return args 90 | 91 | 92 | if __name__ == "__main__": 93 | args = parse_arguments() 94 | log = utils.get_logger("MattingTrainer", None) 95 | 96 | utils.update_train_dir(args) 97 | 98 | if args.testmode: 99 | atexit.register(utils.exit_handler, args.train_dir) 100 | 101 | if args.step1_mode: 102 | utils.setup_step1_mode(args) 103 | 104 | log.info(args) 105 | train(args) 106 | --------------------------------------------------------------------------------
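One usage note on `train.py`: since `parse_arguments` takes an explicit argument list, a run can also be launched programmatically instead of through the shell scripts. A minimal sketch, assuming a subset of the flags from `scripts/train_mmnet_dm1.0_256.sh`; both paths are placeholders, and the full scripts pass additional flags (loss weights, augmentation, checkpointing) that a real run may require:

```python
# Minimal programmatic sketch of a training run; mirrors a subset of
# scripts/train_mmnet_dm1.0_256.sh. Paths are placeholders.
from train import parse_arguments, train

args = parse_arguments([
    "--num_classes", "2",
    "--task_type", "matting",
    "--width", "256",
    "--height", "256",
    "--train_dir", "/path/to/training/directory",
    "--batch_size", "32",
    "--dataset_path", "/path/to/dataset",
    "--dataset_split_name", "train",
    "--learning_rate", "1e-4",
    "MMNetModel",                   # subparser name; selects the model class
    "--width_multiplier", "1.0",
])
# Note: the __main__ block also calls utils.update_train_dir(args) before training.
train(args)
```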