├── .gitignore ├── LICENSE ├── README.md ├── assets ├── grouping module.png ├── gvcnn_framework.png └── view-group-shape_architecture.png ├── data_utils ├── ObjFile.py ├── make_views_dir.py ├── obj2png.py └── off2obj.py ├── dataset_tools ├── _create_modelnet_tf_record_each.py ├── create_modelnet_tf_record.py └── dataset_util.py ├── eval.py ├── eval_data.py ├── nets ├── __init__.py ├── inception_utils.py ├── inception_v3.py ├── model.py ├── resnet_utils.py └── resnet_v2.py ├── train.py ├── train_data.py ├── unit_test.py ├── utils ├── _train_helper.py ├── downsize_modelnet.py └── train_utils.py └── val_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | *.pyc 107 | tfmodels/ 108 | _temp/ 109 | _nets/ 110 | pre-trained/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## GVCNN (Group-View Convolutional Neural Networks for 3D Shape Recognition) 2 | ![](assets/gvcnn_framework.png) 3 | ![](assets/view-group-shape_architecture.png) 4 | 5 | 6 | ## Data 7 | - Download the [ModelNet10-Class Orientation-aligned Subset](http://modelnet.cs.princeton.edu/) 8 | - Make .png view images by running the scripts below, in order: 9 | - data_utils/make_views_dir.py 10 | - data_utils/off2obj.py (requires `sudo apt install openctm-tools` for the `ctmconv` command it calls) 11 | - data_utils/obj2png.py 12 | - Alternatively, you can create a 2D dataset from 3D objects (.obj, .stl, and .off) using [BlenderPhong](https://github.com/WeiTang114/BlenderPhong). 13 | - Or downsize ModelNet40 (from https://drive.google.com/file/d/0B4v2jR3WsindMUE3N2xiLVpyLW8/view) to modelnet12/6-view. 14 | 15 | ## Quick Start 16 | - Make a group-view image TFRecord file: 17 | - dataset_tools/create_modelnet_tf_record.py 18 | - Run train.py 19 | 20 | ## TODO 21 | - balanced sampler 22 | 23 | ## References 24 | - http://openaccess.thecvf.com/content_cvpr_2018/papers/Feng_GVCNN_Group-View_Convolutional_CVPR_2018_paper.pdf 25 | - https://github.com/WeiTang114/MVCNN-TensorFlow 26 | - https://github.com/pclausen/obj2png 27 | 28 | -------------------------------------------------------------------------------- /assets/grouping module.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ace19-dev/gvcnn-tf/9555509c714e194c0b110b1aaee9f9bd356078cf/assets/grouping module.png -------------------------------------------------------------------------------- /assets/gvcnn_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ace19-dev/gvcnn-tf/9555509c714e194c0b110b1aaee9f9bd356078cf/assets/gvcnn_framework.png -------------------------------------------------------------------------------- /assets/view-group-shape_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ace19-dev/gvcnn-tf/9555509c714e194c0b110b1aaee9f9bd356078cf/assets/view-group-shape_architecture.png -------------------------------------------------------------------------------- /data_utils/ObjFile.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Jul 7 00:41:00 2018 4 | 5 | @author: Peter M.
Clausen, pclausen 6 | 7 | MIT License 8 | 9 | Copyright (c) 2018 pclausen 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy 12 | of this software and associated documentation files (the "Software"), to deal 13 | in the Software without restriction, including without limitation the rights 14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | copies of the Software, and to permit persons to whom the Software is 16 | furnished to do so, subject to the following conditions: 17 | 18 | The above copyright notice and this permission notice shall be included in all 19 | copies or substantial portions of the Software. 20 | 21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 27 | SOFTWARE. 28 | 29 | """ 30 | 31 | import re 32 | import numpy as np 33 | import matplotlib.pyplot as plt 34 | from mpl_toolkits.mplot3d import Axes3D 35 | 36 | RE=re.compile(r'/\d+') 37 | 38 | class ObjFile: 39 | """ 40 | >>> obj_file = '../obj/bun_zipper_res2.obj' 41 | >>> out_file = '../obj/bun_zipper_res2.png' 42 | >>> obj = ObjFile(obj_file) 43 | >>> len(obj.nodes)==8147 44 | True 45 | >>> len(obj.faces)==16301 46 | True 47 | >>> nmin,nmax=obj.MinMaxNodes() 48 | >>> np.allclose(nmin, np.array([-0.094572, 0. , -0.061874]) ) 49 | True 50 | >>> np.allclose(nmax, np.array([0.060935, 0.186643, 0.05869 ])) 51 | True 52 | >>> obj.Plot(out_file) 53 | """ 54 | 55 | def __init__(self, obj_file=None): 56 | self.nodes=None 57 | self.faces=None 58 | if obj_file: 59 | self.ObjParse(obj_file) 60 | 61 | def ObjInfo(self): 62 | print ("Num vertices : %d"%(len(self.nodes))) 63 | print ("Num faces : %d"%(len(self.faces))) 64 | nmin,nmax=self.MinMaxNodes() 65 | print ("Min/Max : %s %s"%(np.around(nmin,3), np.around(nmax,3) )) 66 | 67 | @staticmethod 68 | def MinMax3d(arr): 69 | nmin=1E9*np.ones((3)) 70 | nmax=-1E9*np.ones((3)) 71 | for a in arr: 72 | for i in range(3): 73 | nmin[i]=min(nmin[i],a[i]) 74 | nmax[i]=max(nmax[i],a[i]) 75 | return (nmin,nmax) 76 | 77 | def MinMaxNodes(self): 78 | return ObjFile.MinMax3d(self.nodes) 79 | 80 | 81 | def ObjParse(self, obj_file ): 82 | f=open(obj_file) 83 | lines=f.readlines() 84 | f.close() 85 | nodes=[] 86 | # add zero entry to get ids right 87 | nodes.append([.0,.0,.0]) 88 | faces=[] 89 | for line in lines: 90 | if 'v' == line[0]: 91 | v=line.split() 92 | nodes.append(ObjFile.ToFloats(v[1:])) 93 | if 'f' == line[0]: 94 | # remove /int 95 | line=re.sub(RE,'',line) 96 | f=line.split() 97 | faces.append(ObjFile.ToInts(f[1:])) 98 | 99 | self.nodes=np.array(nodes) 100 | self.faces=faces 101 | 102 | 103 | def ObjWrite(self, obj_file): 104 | f=open(obj_file, 'w') 105 | for n in self.nodes[1:]: # skip first dummy 'node' 106 | f.write('v ') 107 | for nn in n: 108 | f.write('%g '%(nn)) 109 | f.write('\n') 110 | for ff in self.faces: 111 | f.write('f ') 112 | for fff in ff: 113 | f.write('%d '%(fff)) 114 | f.write('\n') 115 | 116 | 117 | @staticmethod 118 | def ToFloats(n): 119 | if isinstance(n,list): 120 | v=[] 121 | for nn in n: 122 | v.append(float(nn)) 123 | return v 124 | else: 125 | return float(n) 126 | 127 | 
@staticmethod 128 | def ToInts(n): 129 | if isinstance(n,list): 130 | v=[] 131 | for nn in n: 132 | v.append(int(nn[:-2])) 133 | return v 134 | else: 135 | return int(n) 136 | 137 | @staticmethod 138 | def Normalize(v): 139 | v2=np.linalg.norm(v) 140 | if v2<0.000000001: 141 | return v 142 | else: 143 | return v/v2 144 | 145 | def QuadToTria(self): 146 | trifaces=[] 147 | for f in self.faces: 148 | if len(f)==3: 149 | trifaces.append(f) 150 | elif len(f)==4: 151 | f1=[f[0],f[1],f[2]] 152 | f2=[f[0],f[2],f[3]] 153 | trifaces.append(f1) 154 | trifaces.append(f2) 155 | return trifaces 156 | 157 | @staticmethod 158 | def ScaleVal(v,scale,minval=True): 159 | 160 | if minval: 161 | if v>0: 162 | return v*(1.-scale) 163 | else: 164 | return v*scale 165 | else: # maxval 166 | if v>0: 167 | return v*scale 168 | else: 169 | return v*(1.-scale) 170 | 171 | 172 | def Plot(self, output_file=None, elevation=None, azim=None,dpi=None,scale=None,animate=None): 173 | plt.ioff() 174 | tri=self.QuadToTria() 175 | fig = plt.figure() 176 | ax = fig.gca(projection='3d') 177 | ax.plot_trisurf(self.nodes[:,0],self.nodes[:,1],self.nodes[:,2], triangles=tri) 178 | ax.axis('off') 179 | fig.subplots_adjust(left=0, right=1, bottom=0, top=1) 180 | nmin,nmax=self.MinMaxNodes() 181 | if scale is not None: 182 | ax.set_xlim(ObjFile.ScaleVal(nmin[0],scale),ObjFile.ScaleVal(nmax[0],scale,False)) 183 | ax.set_ylim(ObjFile.ScaleVal(nmin[1],scale),ObjFile.ScaleVal(nmax[1],scale,False)) 184 | ax.set_zlim(ObjFile.ScaleVal(nmin[2],scale),ObjFile.ScaleVal(nmax[2],scale,False)) 185 | if elevation is not None and azim is not None: 186 | ax.view_init(elevation, azim) 187 | elif elevation is not None: 188 | ax.view_init(elevation, 30) 189 | elif azim is not None: 190 | ax.view_init(30, azim) 191 | else: 192 | ax.view_init(30, 30) 193 | 194 | if output_file: 195 | #fig.tight_layout() 196 | #fig.subplots_adjust(left=-0.2, bottom=-0.2, right=1.2, top=1.2, 197 | # wspace=0, hspace=0) 198 | #ax.autoscale_view(tight=True) 199 | #ax.autoscale(tight=True) 200 | #ax.margins(tight=True) 201 | plt.savefig(output_file,dpi=dpi,transparent=True) 202 | plt.close() 203 | else: 204 | if animate: 205 | # rotate the axes and update 206 | for elevation in np.linspace(-180,180,10): 207 | for azim in np.linspace(-180,180,10): 208 | print('--elevation {} --azim {}'.format(elevation, azim)) 209 | ax.view_init(elevation, azim) 210 | textvar=ax.text2D(0.05, 0.95, '--elevation {} --azim {}'.format(elevation, azim), transform=ax.transAxes) 211 | plt.draw() 212 | #plt.show() 213 | plt.pause(.5) 214 | textvar.remove() 215 | else: 216 | plt.show() 217 | 218 | if __name__ == '__main__': 219 | import doctest 220 | doctest.testmod() 221 | 222 | -------------------------------------------------------------------------------- /data_utils/make_views_dir.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | 5 | flags = tf.app.flags 6 | 7 | FLAGS = flags.FLAGS 8 | 9 | flags.DEFINE_string('source_dir', '/home/ace19/dl_data/ModelNet10', 10 | 'source dir') 11 | 12 | flags.DEFINE_string('target_dir', '/home/ace19/dl_data/modelnet10', 13 | 'target dir') 14 | 15 | flags.DEFINE_string('dataset_category', 'test', 16 | 'train or test') 17 | 18 | 19 | def main(unused_argv): 20 | root = os.listdir(FLAGS.source_dir) 21 | root.sort() 22 | 23 | for cls in root: 24 | if not os.path.isdir(os.path.join(FLAGS.source_dir, cls)): 25 | continue 26 | 27 | dataset = os.path.join(FLAGS.source_dir, cls, 
FLAGS.dataset_category) 28 | data_list = os.listdir(dataset) 29 | 30 | for f in data_list: 31 | if '.off' in f: 32 | target_dir = os.path.join(FLAGS.target_dir, cls, FLAGS.dataset_category) 33 | p = os.path.join(target_dir, f) 34 | os.makedirs(p) 35 | 36 | 37 | 38 | if __name__ == '__main__': 39 | tf.compat.v1.app.run() -------------------------------------------------------------------------------- /data_utils/obj2png.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Sat Jul 7 00:40:00 2018 3 | 4 | @author: Peter M. Clausen, pclausen 5 | 6 | MIT License 7 | 8 | Copyright (c) 2018 pclausen 9 | 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be included in all 18 | copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | SOFTWARE. 27 | 28 | """ 29 | 30 | import tensorflow as tf 31 | 32 | import os 33 | import data_utils.ObjFile 34 | import sys 35 | import os 36 | import glob 37 | 38 | 39 | flags = tf.app.flags 40 | 41 | FLAGS = flags.FLAGS 42 | 43 | 44 | flags.DEFINE_string('source_dir', '/home/ace19/dl_data/ModelNet10', 45 | 'Directory where .obj files is.') 46 | 47 | flags.DEFINE_string('target_dir', '/home/ace19/dl_data/modelnet', 48 | 'Output directory.') 49 | 50 | flags.DEFINE_string('target_file_ext', '.obj', 51 | 'target file extension') 52 | 53 | flags.DEFINE_string('dataset_category', 'train', 54 | 'train or test') 55 | 56 | flags.DEFINE_integer('num_views', 8, 'Number of views') 57 | flags.DEFINE_float('azim', 45, 'Azimuth angle of view in degrees.') 58 | flags.DEFINE_float('elevation', None, 'Elevation angle of view in degrees.') 59 | flags.DEFINE_string('quality', 'LOW', 'Image quality (HIGH,MEDIUM,LOW). Default: LOW') 60 | flags.DEFINE_float('scale', 0.9, 61 | 'Scale picture by descreasing boundaries. Lower than 1. 
gives a larger object.') 62 | flags.DEFINE_string('animate', None, 63 | 'Animate instead of creating picture file as animation, from elevation -180:180 and azim -180:180') 64 | 65 | # parser.add_argument("-v", "--view", 66 | # dest='view', 67 | # action='store_true', 68 | # help="View instead of creating picture file.") 69 | 70 | 71 | def main(unused_argv): 72 | tf.logging.set_verbosity(tf.logging.INFO) 73 | 74 | res={'HIGH':1200,'MEDIUM':600,'LOW':300} 75 | dpi=None 76 | if FLAGS.quality: 77 | if type(FLAGS.quality)==int: 78 | dpi=FLAGS.quality 79 | elif FLAGS.quality.upper() in res: 80 | dpi=res[FLAGS.quality.upper()] 81 | 82 | azim=None 83 | if FLAGS.azim is not None: 84 | azim=FLAGS.azim 85 | 86 | elevation=None 87 | if FLAGS.elevation is not None: 88 | elevation=FLAGS.elevation 89 | 90 | scale=None 91 | if FLAGS.scale: 92 | scale=FLAGS.scale 93 | 94 | animate=None 95 | if FLAGS.animate: 96 | animate=FLAGS.animate 97 | 98 | 99 | root = os.listdir(FLAGS.source_dir) 100 | root.sort() 101 | for cls in root: 102 | if not os.path.isdir(os.path.join(FLAGS.source_dir, cls)): 103 | continue 104 | 105 | dataset = os.path.join(FLAGS.source_dir, cls, FLAGS.dataset_category) 106 | files = os.listdir(dataset) 107 | files.sort() 108 | 109 | for objfile in files: 110 | obj_file_path = os.path.join(dataset, objfile) 111 | if os.path.isfile(obj_file_path) and '.obj' in objfile: 112 | target_path = objfile.replace('.obj','.off') 113 | # if FLAGS.outfile: 114 | # outfile=FLAGS.outfile 115 | # if FLAGS.view: 116 | # outfile=None 117 | # else: 118 | # print('Converting %s to %s'%(objfile, outfile)) 119 | 120 | ob = data_utils.ObjFile.ObjFile(obj_file_path) 121 | 122 | for i in range(FLAGS.num_views): 123 | new_output = objfile[:-4] + '.' + str(i) + '.png' 124 | print('Converting %s to %s' % (objfile, new_output)) 125 | outfile_path = os.path.join(FLAGS.target_dir, cls, FLAGS.dataset_category, 126 | target_path, new_output) 127 | ob.Plot(outfile_path, 128 | elevation=elevation, 129 | azim=azim*(i+1), 130 | dpi=dpi, 131 | scale=scale, 132 | animate=animate) 133 | # else: 134 | # print('File %s not found or not file type .obj'%objfile) 135 | # sys.exit(1) 136 | 137 | if __name__ == '__main__': 138 | tf.app.run() -------------------------------------------------------------------------------- /data_utils/off2obj.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import os 4 | 5 | flags = tf.app.flags 6 | 7 | FLAGS = flags.FLAGS 8 | 9 | 10 | flags.DEFINE_string('source_dir', '/home/ace19/dl_data/ModelNet10', 11 | 'source dir') 12 | 13 | flags.DEFINE_string('target_dir', '/home/ace19/dl_data/ModelNet10', 14 | 'target dir') 15 | 16 | flags.DEFINE_string('target_file_ext', '.obj', 17 | 'target file extension') 18 | 19 | flags.DEFINE_string('dataset_category', 'train', 20 | 'train or test') 21 | 22 | 23 | def main(unused_argv): 24 | tf.logging.set_verbosity(tf.logging.INFO) 25 | 26 | root = os.listdir(FLAGS.source_dir) 27 | root.sort() 28 | 29 | for cls in root: 30 | if not os.path.isdir(os.path.join(FLAGS.source_dir, cls)): 31 | continue 32 | 33 | dataset = os.path.join(FLAGS.source_dir, cls, FLAGS.dataset_category) 34 | off_files = os.listdir(dataset) 35 | off_files.sort() 36 | 37 | total = len(off_files) 38 | for i, file in enumerate(off_files): 39 | if i % 50 == 0: 40 | tf.logging.info('\n\ncompleted \'%s\': %d/%d' % (cls, i, total)) 41 | 42 | file_name = os.path.basename(file)[:-4] 43 | output_file_path = os.path.join(FLAGS.target_dir, cls, 
FLAGS.dataset_category, file_name + FLAGS.target_file_ext) 44 | cmd = 'ctmconv ' + os.path.join(dataset, file) + ' ' + output_file_path 45 | os.system(cmd) 46 | 47 | 48 | if __name__ == '__main__': 49 | tf.app.run() -------------------------------------------------------------------------------- /dataset_tools/_create_modelnet_tf_record_each.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert PatchCamelyon (PCam) dataset to TFRecord for classification. 3 | 4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import hashlib 10 | import io 11 | import logging 12 | import os 13 | import random 14 | import numpy as np 15 | 16 | import PIL.Image 17 | import tensorflow as tf 18 | 19 | from dataset_tools import dataset_util 20 | 21 | 22 | # RANDOM_SEED = 8045 23 | 24 | flags = tf.app.flags 25 | flags.DEFINE_string('dataset_dir', 26 | '/home/ace19/dl_data/modelnet12/view/classes', 27 | 'Root Directory to raw modelnet dataset.') 28 | flags.DEFINE_string('output_dir', 29 | '/home/ace19/dl_data/modelnet12/tfrecords', 30 | 'Path to output TFRecord') 31 | flags.DEFINE_string('dataset_category', 32 | 'test', 33 | 'dataset category, train|validate|test') 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | _FILE_PATTERN = 'modelnet12_%s_%s.tfrecord' 38 | 39 | 40 | def get_data_map_dict(label_to_index): 41 | label_map_dict = {} 42 | view_map_dict = {} 43 | 44 | cls_lst = os.listdir(FLAGS.dataset_dir) 45 | for i, cls in enumerate(cls_lst): 46 | if not os.path.isdir(os.path.join(FLAGS.dataset_dir, cls)): 47 | continue 48 | 49 | data_path = os.path.join(FLAGS.dataset_dir, cls, FLAGS.dataset_category) 50 | img_lst = os.listdir(data_path) 51 | for n, img in enumerate(img_lst): 52 | img_path = os.path.join(data_path, img) 53 | view_lst = os.listdir(img_path) 54 | views = [] 55 | for k, view in enumerate(view_lst): 56 | v_path = os.path.join(img_path, view) 57 | views.append(v_path) 58 | label_map_dict[img] = label_to_index[cls] 59 | view_map_dict[img] = views 60 | 61 | return label_map_dict, view_map_dict 62 | 63 | 64 | def dict_to_tf_example(image, 65 | label_map_dict=None, 66 | view_map_dict=None): 67 | """ 68 | Args: 69 | image: a single image name 70 | label_map_dict: A map from string label names to integers ids. 71 | image_subdirectory: String specifying subdirectory within the 72 | PCam dataset directory holding the actual image data. 73 | 74 | Returns: 75 | example: The converted tf.Example. 
76 | 77 | Raises: 78 | ValueError: if the image pointed to by image is not a valid PNG 79 | """ 80 | # full_path = os.path.join(dataset_directory, image_subdirectory, image_name) 81 | 82 | filenames = [] 83 | sourceids = [] 84 | encoded_pngs = [] 85 | widths = [] 86 | heights = [] 87 | formats = [] 88 | labels = [] 89 | keys = [] 90 | 91 | view_lst = view_map_dict[image] 92 | label = label_map_dict[image] 93 | for i, view_path in enumerate(view_lst): 94 | filenames.append(view_path.encode('utf8')) 95 | sourceids.append(view_path.encode('utf8')) 96 | with tf.gfile.GFile(view_path, 'rb') as fid: 97 | encoded_png = fid.read() 98 | encoded_pngs.append(encoded_png) 99 | encoded_png_io = io.BytesIO(encoded_png) 100 | image = PIL.Image.open(encoded_png_io) 101 | width, height = image.size 102 | widths.append(width) 103 | heights.append(height) 104 | 105 | format = image.format 106 | formats.append(format.encode('utf8')) 107 | if format!= 'PNG': 108 | raise ValueError('Image format not PNG') 109 | key = hashlib.sha256(encoded_png).hexdigest() 110 | keys.append(key.encode('utf8')) 111 | # labels.append(label) 112 | 113 | example = tf.train.Example(features=tf.train.Features(feature={ 114 | 'image/height': dataset_util.int64_list_feature(heights), 115 | 'image/width': dataset_util.int64_list_feature(widths), 116 | 'image/filename': dataset_util.bytes_list_feature(filenames), 117 | 'image/source_id': dataset_util.bytes_list_feature(sourceids), 118 | 'image/key/sha256': dataset_util.bytes_list_feature(keys), 119 | 'image/encoded': dataset_util.bytes_list_feature(encoded_pngs), 120 | 'image/format': dataset_util.bytes_list_feature(formats), 121 | 'image/label': dataset_util.int64_feature(label), 122 | # 'image/text': dataset_util.bytes_feature('label_text'.encode('utf8')) 123 | })) 124 | return example 125 | 126 | 127 | def main(_): 128 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 129 | 130 | options = tf.io.TFRecordOptions(tf.compat.v1.python_io.TFRecordCompressionType.GZIP) 131 | # writer = tf.python_io.TFRecordWriter(FLAGS.output_dir, options=options) 132 | 133 | dataset_lst = os.listdir(FLAGS.dataset_dir) 134 | dataset_lst.sort() 135 | label_to_index = {} 136 | for i, cls in enumerate(dataset_lst): 137 | cls_path = os.path.join(FLAGS.dataset_dir, cls) 138 | if os.path.isdir(cls_path): 139 | label_to_index[cls] = i 140 | 141 | label_map_dict, view_map_dict = get_data_map_dict(label_to_index) 142 | 143 | if not os.path.exists(FLAGS.output_dir): 144 | os.makedirs(FLAGS.output_dir) 145 | 146 | tf.compat.v1.logging.info('Reading from modelnet dataset.') 147 | cls_lst = os.listdir(FLAGS.dataset_dir) 148 | for i, label in enumerate(cls_lst): 149 | tfrecord_name = os.path.join(FLAGS.output_dir, 150 | _FILE_PATTERN % (FLAGS.dataset_category, label)) 151 | tf.compat.v1.logging.info('tfrecord name %s: ', tfrecord_name) 152 | writer = tf.io.TFRecordWriter(tfrecord_name, options=options) 153 | 154 | data_path = os.path.join(FLAGS.dataset_dir, label, FLAGS.dataset_category) 155 | if not os.path.isdir(data_path): 156 | continue 157 | img_lst = os.listdir(data_path) 158 | for idx, image in enumerate(img_lst): 159 | if idx % 100 == 0: 160 | tf.compat.v1.logging.info('On image %d of %d', idx, len(img_lst)) 161 | tf_example = dict_to_tf_example(image, label_map_dict, view_map_dict) 162 | writer.write(tf_example.SerializeToString()) 163 | 164 | writer.close() 165 | 166 | 167 | if __name__ == '__main__': 168 | tf.app.run() 169 | 
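The script above writes one GZIP-compressed TFRecord per class, named with the modelnet12_<category>_<class>.tfrecord pattern, where each tf.Example stores a whole view set: image/encoded, image/filename, image/height, and image/width are lists with one entry per view, and image/label is a single int64. The snippet below is a minimal sanity-check sketch, not code from the repository; it assumes TensorFlow 2.x eager execution, and the record path is a placeholder.

```python
import tensorflow as tf

# Placeholder path: one of the per-class records produced by the script above.
record_path = '/home/ace19/dl_data/modelnet12/tfrecords/modelnet12_test_airplane.tfrecord'

feature_spec = {
    # Per-view fields are stored as lists whose length equals the number of
    # views, so variable-length features keep this sketch independent of it.
    'image/encoded': tf.io.VarLenFeature(tf.string),
    'image/filename': tf.io.VarLenFeature(tf.string),
    'image/label': tf.io.FixedLenFeature([], tf.int64),
}

dataset = tf.data.TFRecordDataset(record_path, compression_type='GZIP')
for i, raw in enumerate(dataset.take(3)):
    example = tf.io.parse_single_example(raw, feature_spec)
    views = example['image/encoded'].values
    print('example %d: label=%d, views=%d'
          % (i, int(example['image/label']), int(tf.shape(views)[0])))
```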
-------------------------------------------------------------------------------- /dataset_tools/create_modelnet_tf_record.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert PatchCamelyon (PCam) dataset to TFRecord for classification. 3 | 4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import hashlib 10 | import io 11 | import logging 12 | import os 13 | import random 14 | import numpy as np 15 | 16 | import PIL.Image 17 | import tensorflow as tf 18 | 19 | from dataset_tools import dataset_util 20 | 21 | 22 | # RANDOM_SEED = 8045 23 | 24 | flags = tf.compat.v1.app.flags 25 | flags.DEFINE_string('dataset_dir', 26 | '/home/ace19/dl_data/modelnet5/view/classes', 27 | 'Root Directory to raw modelnet dataset.') 28 | flags.DEFINE_string('output_dir', 29 | '/home/ace19/dl_data/modelnet5', 30 | 'Path to output TFRecord') 31 | flags.DEFINE_string('dataset_category', 32 | 'train', 33 | 'dataset category, train|validate|test') 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | _FILE_PATTERN = 'modelnet%d_%dview_%s.record' 38 | 39 | filter = ['1','2','4','7','8','10'] 40 | 41 | 42 | def get_data_map_dict(label_to_index): 43 | label_map_dict = {} 44 | view_map_dict = {} 45 | 46 | cls_lst = os.listdir(FLAGS.dataset_dir) 47 | for i, cls in enumerate(cls_lst): 48 | if not os.path.isdir(os.path.join(FLAGS.dataset_dir, cls)): 49 | continue 50 | 51 | data_path = os.path.join(FLAGS.dataset_dir, cls, FLAGS.dataset_category) 52 | img_lst = os.listdir(data_path) 53 | for n, img in enumerate(img_lst): 54 | img_path = os.path.join(data_path, img) 55 | view_lst = os.listdir(img_path) 56 | views = [] 57 | for k, view in enumerate(view_lst): 58 | v_path = os.path.join(img_path, view) 59 | views.append(v_path) 60 | label_map_dict[img] = label_to_index[cls] 61 | view_map_dict[img] = views 62 | 63 | return label_map_dict, view_map_dict 64 | 65 | 66 | def dict_to_tf_example(image, 67 | label_map_dict=None, 68 | view_map_dict=None): 69 | """ 70 | Args: 71 | image: a single image name 72 | label_map_dict: A map from string label names to integers ids. 73 | image_subdirectory: String specifying subdirectory within the 74 | PCam dataset directory holding the actual image data. 75 | 76 | Returns: 77 | example: The converted tf.Example. 
78 | 79 | Raises: 80 | ValueError: if the image pointed to by image is not a valid PNG 81 | """ 82 | # full_path = os.path.join(dataset_directory, image_subdirectory, image_name) 83 | 84 | filenames = [] 85 | sourceids = [] 86 | encoded_pngs = [] 87 | widths = [] 88 | heights = [] 89 | formats = [] 90 | labels = [] 91 | keys = [] 92 | 93 | view_lst = view_map_dict[image] 94 | label = label_map_dict[image] 95 | for i, view_path in enumerate(view_lst): 96 | # Make fast training and verifying 97 | if view_path.split('/')[-1].split('.')[1] in filter: 98 | continue 99 | 100 | filenames.append(view_path.encode('utf8')) 101 | sourceids.append(view_path.encode('utf8')) 102 | with tf.io.gfile.GFile(view_path, 'rb') as fid: 103 | encoded_png = fid.read() 104 | encoded_pngs.append(encoded_png) 105 | encoded_png_io = io.BytesIO(encoded_png) 106 | image = PIL.Image.open(encoded_png_io) 107 | width, height = image.size 108 | widths.append(width) 109 | heights.append(height) 110 | 111 | format = image.format 112 | formats.append(format.encode('utf8')) 113 | if format!= 'PNG': 114 | raise ValueError('Image format not PNG') 115 | key = hashlib.sha256(encoded_png).hexdigest() 116 | keys.append(key.encode('utf8')) 117 | # labels.append(label) 118 | 119 | example = tf.train.Example(features=tf.train.Features(feature={ 120 | 'image/height': dataset_util.int64_list_feature(heights), 121 | 'image/width': dataset_util.int64_list_feature(widths), 122 | 'image/filename': dataset_util.bytes_list_feature(filenames), 123 | 'image/source_id': dataset_util.bytes_list_feature(sourceids), 124 | 'image/key/sha256': dataset_util.bytes_list_feature(keys), 125 | 'image/encoded': dataset_util.bytes_list_feature(encoded_pngs), 126 | 'image/format': dataset_util.bytes_list_feature(formats), 127 | 'image/label': dataset_util.int64_feature(label), 128 | # 'image/text': dataset_util.bytes_feature('label_text'.encode('utf8')) 129 | })) 130 | return example 131 | 132 | 133 | def main(_): 134 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 135 | 136 | dataset_lst = os.listdir(FLAGS.dataset_dir) 137 | dataset_lst.sort() 138 | 139 | tfrecord_name = os.path.join(FLAGS.output_dir, _FILE_PATTERN % 140 | (len(dataset_lst), 12-len(filter), FLAGS.dataset_category)) 141 | options = tf.io.TFRecordOptions(tf.compat.v1.io.TFRecordCompressionType.GZIP) 142 | writer = tf.io.TFRecordWriter(tfrecord_name, options=options) 143 | 144 | # dataset_lst = os.listdir(FLAGS.dataset_dir) 145 | # dataset_lst.sort() 146 | label_to_index = {} 147 | for i, cls in enumerate(dataset_lst): 148 | cls_path = os.path.join(FLAGS.dataset_dir, cls) 149 | if os.path.isdir(cls_path): 150 | label_to_index[cls] = i 151 | 152 | label_map_dict, view_map_dict = get_data_map_dict(label_to_index) 153 | 154 | tf.compat.v1.logging.info('Reading from modelnet dataset.') 155 | cls_lst = os.listdir(FLAGS.dataset_dir) 156 | for i, label in enumerate(cls_lst): 157 | data_path = os.path.join(FLAGS.dataset_dir, label, FLAGS.dataset_category) 158 | if not os.path.isdir(data_path): 159 | continue 160 | img_lst = os.listdir(data_path) 161 | for idx, image in enumerate(img_lst): 162 | if idx % 100 == 0: 163 | tf.compat.v1.logging.info('On image %d of %d', idx, len(img_lst)) 164 | tf_example = dict_to_tf_example(image, label_map_dict, view_map_dict) 165 | writer.write(tf_example.SerializeToString()) 166 | 167 | writer.close() 168 | 169 | 170 | if __name__ == '__main__': 171 | tf.compat.v1.app.run() 172 | 
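Unlike the per-class variant above, create_modelnet_tf_record.py writes a single GZIP-compressed file named via the modelnet<num_classes>_<num_views>view_<category>.record pattern and skips the view indices listed in `filter`, so each example keeps 6 of the 12 rendered views. The reader sketch below mirrors the fixed-length parsing in eval_data.py but is not code from the repository; it assumes TensorFlow 2.x, 6 views, and 299x299 inputs as in eval.py, and the record path is a placeholder.

```python
import tensorflow as tf

NUM_VIEWS = 6   # 12 source views minus the 6 indices dropped via `filter`
RECORD = '/home/ace19/dl_data/modelnet5/modelnet5_6view_train.record'  # placeholder

def decode(serialized):
    features = tf.io.parse_single_example(
        serialized,
        {
            'image/encoded': tf.io.FixedLenFeature([NUM_VIEWS], tf.string),
            'image/label': tf.io.FixedLenFeature([], tf.int64),
        })
    # Decode and resize every view, then stack into [NUM_VIEWS, H, W, 3].
    views = tf.stack([
        tf.image.resize(
            tf.image.decode_png(features['image/encoded'][i], channels=3),
            [299, 299])
        for i in range(NUM_VIEWS)
    ])
    return views, features['image/label']

dataset = (tf.data.TFRecordDataset(RECORD, compression_type='GZIP')
           .map(decode, num_parallel_calls=tf.data.AUTOTUNE)
           .batch(4))
for views, labels in dataset.take(1):
    print(views.shape, labels.numpy())   # e.g. (4, 6, 299, 299, 3)
```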
-------------------------------------------------------------------------------- /dataset_tools/dataset_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility functions for creating TFRecord data sets.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def int64_feature(value): 22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 23 | 24 | 25 | def int64_list_feature(value): 26 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 27 | 28 | 29 | def bytes_feature(value): 30 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 31 | 32 | 33 | def bytes_list_feature(value): 34 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 35 | 36 | 37 | def float_list_feature(value): 38 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 39 | 40 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Given C classes in the classification task, 2 | # the output of the last layer in our network architecture is a 3 | # vector with C elements, i.e., V = {v 1 , v 2 , · · · , v C }. Each 4 | # element represents the probability that the subject belongs 5 | # to that category. And the category with the largest value is 6 | # the category it belongs to 7 | 8 | 9 | import datetime 10 | import os 11 | 12 | import tensorflow as tf 13 | 14 | import eval_data 15 | from nets import model 16 | 17 | slim = tf.contrib.slim 18 | 19 | flags = tf.app.flags 20 | FLAGS = flags.FLAGS 21 | 22 | 23 | NUM_GROUP = 10 24 | 25 | # temporary constant 26 | MODELNET_EVAL_DATA_SIZE = 150 27 | 28 | 29 | # Dataset settings. 
30 | flags.DEFINE_string('dataset_path', '/home/ace19/dl_data/modelnet/test.record', 31 | 'Where the dataset reside.') 32 | 33 | flags.DEFINE_string('checkpoint_path', 34 | os.getcwd() + '/models', 35 | 'Directory where to read training checkpoints.') 36 | 37 | flags.DEFINE_integer('batch_size', 4, 'batch size') 38 | flags.DEFINE_integer('num_views', 6, 'number of views') 39 | flags.DEFINE_integer('height', 299, 'height') 40 | flags.DEFINE_integer('width', 299, 'width') 41 | flags.DEFINE_string('labels', 42 | 'airplane,bed,bookshelf,toilet,vase', 43 | 'number of classes') 44 | 45 | def main(unused_argv): 46 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 47 | 48 | labels = FLAGS.labels.split(',') 49 | num_classes = len(labels) 50 | 51 | # Define the model 52 | X = tf.compat.v1.placeholder(tf.float32, 53 | [None, FLAGS.num_views, FLAGS.height, FLAGS.width, 3], 54 | name='X') 55 | # final_X = tf.compat.v1.placeholder(tf.float32, 56 | # [FLAGS.num_views, None, 8, 8, 1536], 57 | # name='final_X') 58 | ground_truth = tf.compat.v1.placeholder(tf.int64, [None], name='ground_truth') 59 | is_training = tf.compat.v1.placeholder(tf.bool, name='is_training') 60 | dropout_keep_prob = tf.compat.v1.placeholder(tf.float32, name='dropout_keep_prob') 61 | # grouping_scheme = tf.placeholder(tf.bool, [NUM_GROUP, FLAGS.num_views]) 62 | # grouping_weight = tf.placeholder(tf.float32, [NUM_GROUP, 1]) 63 | g_scheme = tf.compat.v1.placeholder(tf.int32, [FLAGS.num_group, FLAGS.num_views]) 64 | g_weight = tf.compat.v1.placeholder(tf.float32, [FLAGS.num_group]) 65 | 66 | # # Grouping Module 67 | # d_scores, _, final_desc = model.discrimination_score(X, 68 | # num_classes, 69 | # is_training) 70 | 71 | # # GVCNN 72 | # logits, _ = model.gvcnn(final_X, 73 | # grouping_scheme, 74 | # grouping_weight, 75 | # num_classes, 76 | # is_training2, 77 | # dropout_keep_prob) 78 | 79 | # GVCNN 80 | view_scores, _, logits = model.gvcnn(X, 81 | num_classes, 82 | g_scheme, 83 | g_weight, 84 | is_training, 85 | dropout_keep_prob) 86 | 87 | # prediction = tf.nn.softmax(logits) 88 | # predicted_labels = tf.argmax(prediction, 1) 89 | 90 | # prediction = tf.argmax(logits, 1, name='prediction') 91 | # correct_prediction = tf.equal(prediction, ground_truth) 92 | # confusion_matrix = tf.confusion_matrix( 93 | # ground_truth, prediction, num_classes=num_classes) 94 | # accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 95 | prediction = tf.argmax(logits, 1, name='prediction') 96 | correct_prediction = tf.equal(prediction, ground_truth) 97 | confusion_matrix = tf.math.confusion_matrix(ground_truth, 98 | prediction, 99 | num_classes=num_classes) 100 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 101 | 102 | ################ 103 | # Prepare data 104 | ################ 105 | filenames = tf.compat.v1.placeholder(tf.string, shape=[]) 106 | eval_dataset = eval_data.Dataset(filenames, 107 | FLAGS.num_views, 108 | FLAGS.height, 109 | FLAGS.width, 110 | FLAGS.batch_size) 111 | iterator = eval_dataset.dataset.make_initializable_iterator() 112 | next_batch = iterator.get_next() 113 | 114 | sess_config = tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)) 115 | with tf.compat.v1.Session(config=sess_config) as sess: 116 | sess.run(tf.compat.v1.global_variables_initializer()) 117 | 118 | # Create a saver object which will save all the variables 119 | saver = tf.compat.v1.train.Saver() 120 | if FLAGS.checkpoint_path: 121 | if tf.gfile.IsDirectory(FLAGS.checkpoint_path): 122 | 
checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path) 123 | else: 124 | checkpoint_path = FLAGS.checkpoint_path 125 | saver.restore(sess, checkpoint_path) 126 | 127 | # global_step = checkpoint_path.split('/')[-1].split('-')[-1] 128 | 129 | # Get the number of training/validation steps per epoch 130 | batches = int(MODELNET_EVAL_DATA_SIZE / FLAGS.batch_size) 131 | if MODELNET_EVAL_DATA_SIZE % FLAGS.batch_size > 0: 132 | batches += 1 133 | 134 | ############## 135 | # prediction 136 | ############## 137 | start_time = datetime.datetime.now() 138 | tf.logging.info("Start prediction: %s" % start_time) 139 | 140 | eval_filenames = os.path.join(FLAGS.dataset_path) 141 | sess.run(iterator.initializer, feed_dict={filenames: eval_filenames}) 142 | 143 | count = 0; 144 | total_acc = 0 145 | total_conf_matrix = None 146 | for i in range(batches): 147 | batch_xs, batch_ys, _ = sess.run(next_batch) 148 | 149 | # # Sets up a graph with feeds and fetches for partial runs. 150 | # handle = sess.partial_run_setup([d_scores, final_desc, 151 | # accuracy, confusion_matrix], 152 | # [X, final_X, ground_truth, 153 | # grouping_scheme, grouping_weight, is_training, 154 | # is_training2, dropout_keep_prob]) 155 | # 156 | # scores, final = sess.partial_run(handle, 157 | # [d_scores, final_desc], 158 | # feed_dict={ 159 | # X: batch_xs, 160 | # is_training: False} 161 | # ) 162 | # schemes = model.grouping_scheme(scores, NUM_GROUP, FLAGS.num_views) 163 | # weights = model.grouping_weight(scores, schemes) 164 | # 165 | # # Run the graph with this batch of training data. 166 | # acc, conf_matrix = \ 167 | # sess.partial_run(handle, 168 | # [accuracy, confusion_matrix], 169 | # feed_dict={ 170 | # final_X: final, 171 | # ground_truth: batch_ys, 172 | # grouping_scheme: schemes, 173 | # grouping_weight: weights, 174 | # is_training2: False, 175 | # dropout_keep_prob: 1.0} 176 | # ) 177 | 178 | 179 | # Sets up a graph with feeds and fetches for partial run. 180 | handle = sess.partial_run_setup([view_scores, accuracy, confusion_matrix], 181 | [X, g_scheme, g_weight, 182 | ground_truth, is_training, dropout_keep_prob]) 183 | 184 | _view_scores = sess.partial_run(handle, 185 | [view_scores], 186 | feed_dict={ 187 | X: batch_xs, 188 | is_training: False, 189 | dropout_keep_prob: 1.0} 190 | ) 191 | _g_schemes = model.group_scheme(_view_scores, FLAGS.num_group, FLAGS.num_views) 192 | _g_weights = model.group_weight(_g_schemes) 193 | 194 | # Run the graph with this batch of training data. 
195 | acc, conf_matrix = \ 196 | sess.partial_run(handle, 197 | [accuracy, confusion_matrix], 198 | feed_dict={ 199 | ground_truth: batch_ys, 200 | g_scheme: _g_schemes, 201 | g_weight: _g_weights} 202 | ) 203 | 204 | total_acc += acc 205 | count += 1 206 | 207 | if total_conf_matrix is None: 208 | total_conf_matrix = conf_matrix 209 | else: 210 | total_conf_matrix += conf_matrix 211 | 212 | total_acc /= count 213 | tf.compat.v1.logging.info('Confusion Matrix:\n %s' % (total_conf_matrix)) 214 | tf.compat.v1.logging.info('Final test accuracy = %.3f%% (N=%d)' % 215 | (total_acc * 100, MODELNET_EVAL_DATA_SIZE)) 216 | 217 | end_time = datetime.datetime.now() 218 | tf.compat.v1.logging.info('End prediction: %s' % end_time) 219 | tf.compat.v1.logging.info('prediction waste time: %s' % (end_time - start_time)) 220 | 221 | 222 | if __name__ == '__main__': 223 | tf.app.run() 224 | -------------------------------------------------------------------------------- /eval_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | 7 | import tensorflow as tf 8 | 9 | 10 | MEAN=[0.485, 0.456, 0.406] 11 | STD=[0.229, 0.224, 0.225] 12 | 13 | 14 | class Dataset(object): 15 | """ 16 | Wrapper class around the new Tensorflows dataset pipeline. 17 | 18 | Handles loading, partitioning, and preparing training data. 19 | """ 20 | 21 | def __init__(self, tfrecord_path, num_views, height, width, batch_size): 22 | self.num_views = num_views 23 | self.resize_h = height 24 | self.resize_w = width 25 | 26 | self.dataset = tf.data.TFRecordDataset(tfrecord_path, 27 | compression_type='GZIP', 28 | num_parallel_reads=batch_size * 4) 29 | # dataset = dataset.map(self._parse_func, num_parallel_calls=8) 30 | # The map transformation takes a function and applies it to every element 31 | # of the dataset. 32 | self.dataset = self.dataset.map(self.decode, num_parallel_calls=8) 33 | # self.dataset = self.dataset.map(self.augment, num_parallel_calls=8) 34 | self.dataset = self.dataset.map(self.normalize, num_parallel_calls=8) 35 | 36 | # Prefetches a batch at a time to smooth out the time taken to load input 37 | # files for shuffling and processing. 38 | self.dataset = self.dataset.prefetch(buffer_size=batch_size) 39 | # The shuffle transformation uses a finite-sized buffer to shuffle elements 40 | # in memory. The parameter is the number of elements in the buffer. For 41 | # completely uniform shuffling, set the parameter to be the same as the 42 | # number of elements in the dataset. 43 | # self.dataset = self.dataset.shuffle(1000 + 3 * batch_size) 44 | self.dataset = self.dataset.repeat() 45 | self.dataset = self.dataset.batch(batch_size) 46 | 47 | 48 | def decode(self, serialized_example): 49 | """Parses an image and label from the given `serialized_example`.""" 50 | features = tf.io.parse_single_example( 51 | serialized_example, 52 | # Defaults are not specified since both keys are required. 
53 | features={ 54 | 'image/filename': tf.io.FixedLenFeature([self.num_views], tf.string), 55 | 'image/encoded': tf.io.FixedLenFeature([self.num_views], tf.string), 56 | 'image/label': tf.io.FixedLenFeature([], tf.int64), 57 | }) 58 | 59 | # Convert from a scalar string tensor to a float32 tensor with shape 60 | # image_decoded = tf.image.decode_png(features['image/encoded'], channels=3) 61 | # image = tf.image.resize_images(image_decoded, [self.resize_h, self.resize_w]) 62 | # 63 | # filename = features['image/filename'] 64 | 65 | images = [] 66 | filenames = [] 67 | img_lst = tf.unstack(features['image/encoded']) 68 | filename_lst = tf.unstack(features['image/filename']) 69 | for i, img in enumerate(img_lst): 70 | # Convert from a scalar string tensor to a float32 tensor with shape 71 | image_decoded = tf.image.decode_png(img, channels=3) 72 | image = tf.image.resize(image_decoded, [self.resize_h, self.resize_w]) 73 | images.append(image) 74 | filenames.append(filename_lst[i]) 75 | 76 | # Convert label from a scalar uint8 tensor to an int32 scalar. 77 | label = features['image/label'] 78 | 79 | return images, label, filenames 80 | 81 | 82 | def augment(self, images, label, filenames): 83 | """Placeholder for data augmentation.""" 84 | # OPTIONAL: Could reshape into a 28x28 image and apply distortions 85 | # here. Since we are not applying any distortions in this 86 | # example, and the next step expects the image to be flattened 87 | # into a vector, we don't bother. 88 | # img_lst = [] 89 | # img_tensor_lst = tf.unstack(images) 90 | # for i, image in enumerate(img_tensor_lst): 91 | # image = tf.image.central_crop(image, 0.85) 92 | # image = tf.image.random_flip_up_down(image) 93 | # image = tf.image.random_flip_left_right(image) 94 | # image = tf.image.rot90(image, k=random.randint(0, 4)) 95 | # paddings = tf.constant([[22, 22], [22, 22], [0, 0]]) # 299 96 | # image = tf.pad(image, paddings, "CONSTANT") 97 | # 98 | # img_lst.append(image) 99 | # 100 | # return img_lst, filenames 101 | return images, label, filenames 102 | 103 | 104 | def normalize(self, images, label, filenames): 105 | # input[channel] = (input[channel] - mean[channel]) / std[channel] 106 | img_lst = [] 107 | img_tensor_lst = tf.unstack(images) 108 | for i, image in enumerate(img_tensor_lst): 109 | img_lst.append(tf.cast(image, tf.float32) * (1. / 255) - 0.5) 110 | # img_lst.append(tf.div(tf.subtract(image, MEAN), STD)) 111 | 112 | return img_lst, label, filenames 113 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ace19-dev/gvcnn-tf/9555509c714e194c0b110b1aaee9f9bd356078cf/nets/__init__.py -------------------------------------------------------------------------------- /nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | Usage of arg scope: 17 | with slim.arg_scope(inception_arg_scope()): 18 | logits, end_points = inception.inception_v3(images, num_classes, 19 | is_training=is_training) 20 | """ 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import tensorflow as tf 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def inception_arg_scope(weight_decay=0.00004, 31 | use_batch_norm=True, 32 | batch_norm_decay=0.9997, 33 | batch_norm_epsilon=0.001, 34 | activation_fn=tf.nn.relu, 35 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS, 36 | batch_norm_scale=False): 37 | """Defines the default arg scope for inception models. 38 | Args: 39 | weight_decay: The weight decay to use for regularizing the model. 40 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 41 | batch_norm_decay: Decay for batch norm moving average. 42 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 43 | in batch norm. 44 | activation_fn: Activation function for conv2d. 45 | batch_norm_updates_collections: Collection for the update ops for 46 | batch norm. 47 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 48 | activations in the batch normalization layer. 49 | Returns: 50 | An `arg_scope` to use for the inception models. 51 | """ 52 | batch_norm_params = { 53 | # Decay for the moving averages. 54 | 'decay': batch_norm_decay, 55 | # epsilon to prevent 0s in variance. 56 | 'epsilon': batch_norm_epsilon, 57 | # collection containing update_ops. 58 | 'updates_collections': batch_norm_updates_collections, 59 | # use fused batch norm if possible. 60 | 'fused': None, 61 | 'scale': batch_norm_scale, 62 | } 63 | if use_batch_norm: 64 | normalizer_fn = slim.batch_norm 65 | normalizer_params = batch_norm_params 66 | else: 67 | normalizer_fn = None 68 | normalizer_params = {} 69 | # Set weight_decay for weights in Conv and FC layers. 70 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 71 | weights_regularizer=slim.l2_regularizer(weight_decay)): 72 | with slim.arg_scope( 73 | [slim.conv2d], 74 | weights_initializer=slim.variance_scaling_initializer(), 75 | activation_fn=activation_fn, 76 | normalizer_fn=normalizer_fn, 77 | normalizer_params=normalizer_params) as sc: 78 | return sc -------------------------------------------------------------------------------- /nets/inception_v3.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains the definition for inception v3 classification network.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import inception_utils 24 | 25 | slim = tf.contrib.slim 26 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 27 | 28 | 29 | def inception_v3_base(inputs, 30 | final_endpoint='Mixed_7c', 31 | min_depth=16, 32 | depth_multiplier=1.0, 33 | scope=None): 34 | """Inception model from http://arxiv.org/abs/1512.00567. 35 | Constructs an Inception v3 network from inputs to the given final endpoint. 36 | This method can construct the network up to the final inception block 37 | Mixed_7c. 38 | Note that the names of the layers in the paper do not correspond to the names 39 | of the endpoints registered by this function although they build the same 40 | network. 41 | Here is a mapping from the old_names to the new names: 42 | Old name | New name 43 | ======================================= 44 | conv0 | Conv2d_1a_3x3 45 | conv1 | Conv2d_2a_3x3 46 | conv2 | Conv2d_2b_3x3 47 | pool1 | MaxPool_3a_3x3 48 | conv3 | Conv2d_3b_1x1 49 | conv4 | Conv2d_4a_3x3 50 | pool2 | MaxPool_5a_3x3 51 | mixed_35x35x256a | Mixed_5b 52 | mixed_35x35x288a | Mixed_5c 53 | mixed_35x35x288b | Mixed_5d 54 | mixed_17x17x768a | Mixed_6a 55 | mixed_17x17x768b | Mixed_6b 56 | mixed_17x17x768c | Mixed_6c 57 | mixed_17x17x768d | Mixed_6d 58 | mixed_17x17x768e | Mixed_6e 59 | mixed_8x8x1280a | Mixed_7a 60 | mixed_8x8x2048a | Mixed_7b 61 | mixed_8x8x2048b | Mixed_7c 62 | Args: 63 | inputs: a tensor of size [batch_size, height, width, channels]. 64 | final_endpoint: specifies the endpoint to construct the network up to. It 65 | can be one of ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 66 | 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3', 67 | 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 68 | 'Mixed_6d', 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c']. 69 | min_depth: Minimum depth value (number of channels) for all convolution ops. 70 | Enforced when depth_multiplier < 1, and not an active constraint when 71 | depth_multiplier >= 1. 72 | depth_multiplier: Float multiplier for the depth (number of channels) 73 | for all convolution ops. The value must be greater than zero. Typical 74 | usage will be to set this value in (0, 1) to reduce the number of 75 | parameters or computation cost of the model. 76 | scope: Optional variable_scope. 77 | Returns: 78 | tensor_out: output tensor corresponding to the final_endpoint. 79 | end_points: a set of activations for external use, for example summaries or 80 | losses. 81 | Raises: 82 | ValueError: if final_endpoint is not set to one of the predefined values, 83 | or depth_multiplier <= 0 84 | """ 85 | # end_points will collect relevant activations for external use, for example 86 | # summaries or losses. 
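# A minimal usage sketch (illustrative only, not exercised elsewhere in this
# file): building the trunk up to a mid-level endpoint returns early from the
# chain of `if end_point == final_endpoint` checks below.
#
#   images = tf.placeholder(tf.float32, [None, 299, 299, 3])  # hypothetical input
#   with slim.arg_scope(inception_v3_arg_scope()):
#       net, end_points = inception_v3_base(images, final_endpoint='Mixed_6e')
#   # net is the 17x17x768 'Mixed_6e' activation; end_points also holds every
#   # earlier endpoint ('Conv2d_1a_3x3', ..., 'Mixed_6d') for summaries or losses.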
87 | end_points = {} 88 | 89 | if depth_multiplier <= 0: 90 | raise ValueError('depth_multiplier is not greater than zero.') 91 | depth = lambda d: max(int(d * depth_multiplier), min_depth) 92 | 93 | with tf.variable_scope(scope, 'InceptionV3', [inputs]): 94 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], 95 | stride=1, padding='VALID'): 96 | # 299 x 299 x 3 97 | end_point = 'Conv2d_1a_3x3' 98 | net = slim.conv2d(inputs, depth(32), [3, 3], stride=2, scope=end_point) 99 | end_points[end_point] = net 100 | if end_point == final_endpoint: return net, end_points 101 | # 149 x 149 x 32 102 | end_point = 'Conv2d_2a_3x3' 103 | net = slim.conv2d(net, depth(32), [3, 3], scope=end_point) 104 | end_points[end_point] = net 105 | if end_point == final_endpoint: return net, end_points 106 | # 147 x 147 x 32 107 | end_point = 'Conv2d_2b_3x3' 108 | net = slim.conv2d(net, depth(64), [3, 3], padding='SAME', scope=end_point) 109 | end_points[end_point] = net 110 | if end_point == final_endpoint: return net, end_points 111 | # 147 x 147 x 64 112 | end_point = 'MaxPool_3a_3x3' 113 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) 114 | end_points[end_point] = net 115 | if end_point == final_endpoint: return net, end_points 116 | # 73 x 73 x 64 117 | end_point = 'Conv2d_3b_1x1' 118 | net = slim.conv2d(net, depth(80), [1, 1], scope=end_point) 119 | end_points[end_point] = net 120 | if end_point == final_endpoint: return net, end_points 121 | # 73 x 73 x 80. 122 | end_point = 'Conv2d_4a_3x3' 123 | net = slim.conv2d(net, depth(192), [3, 3], scope=end_point) 124 | end_points[end_point] = net 125 | if end_point == final_endpoint: return net, end_points 126 | # 71 x 71 x 192. 127 | end_point = 'MaxPool_5a_3x3' 128 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) 129 | end_points[end_point] = net 130 | if end_point == final_endpoint: return net, end_points 131 | # 35 x 35 x 192. 132 | 133 | # Inception blocks 134 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], 135 | stride=1, padding='SAME'): 136 | # mixed: 35 x 35 x 256. 137 | end_point = 'Mixed_5b' 138 | with tf.variable_scope(end_point): 139 | with tf.variable_scope('Branch_0'): 140 | branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') 141 | with tf.variable_scope('Branch_1'): 142 | branch_1 = slim.conv2d(net, depth(48), [1, 1], scope='Conv2d_0a_1x1') 143 | branch_1 = slim.conv2d(branch_1, depth(64), [5, 5], 144 | scope='Conv2d_0b_5x5') 145 | with tf.variable_scope('Branch_2'): 146 | branch_2 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') 147 | branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], 148 | scope='Conv2d_0b_3x3') 149 | branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], 150 | scope='Conv2d_0c_3x3') 151 | with tf.variable_scope('Branch_3'): 152 | branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') 153 | branch_3 = slim.conv2d(branch_3, depth(32), [1, 1], 154 | scope='Conv2d_0b_1x1') 155 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 156 | end_points[end_point] = net 157 | if end_point == final_endpoint: return net, end_points 158 | 159 | # mixed_1: 35 x 35 x 288. 
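# Channel bookkeeping for the 35x35 blocks above and below (with the default
# depth_multiplier=1.0, depth(d) == d): Mixed_5b concatenates
# 64 + 64 + 96 + 32 = 256 channels, while Mixed_5c and Mixed_5d widen the
# pooled branch to 64, giving 64 + 64 + 96 + 64 = 288, which matches the
# "35 x 35 x 256/288" shape comments.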
160 | end_point = 'Mixed_5c' 161 | with tf.variable_scope(end_point): 162 | with tf.variable_scope('Branch_0'): 163 | branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') 164 | with tf.variable_scope('Branch_1'): 165 | branch_1 = slim.conv2d(net, depth(48), [1, 1], scope='Conv2d_0b_1x1') 166 | branch_1 = slim.conv2d(branch_1, depth(64), [5, 5], 167 | scope='Conv_1_0c_5x5') 168 | with tf.variable_scope('Branch_2'): 169 | branch_2 = slim.conv2d(net, depth(64), [1, 1], 170 | scope='Conv2d_0a_1x1') 171 | branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], 172 | scope='Conv2d_0b_3x3') 173 | branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], 174 | scope='Conv2d_0c_3x3') 175 | with tf.variable_scope('Branch_3'): 176 | branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') 177 | branch_3 = slim.conv2d(branch_3, depth(64), [1, 1], 178 | scope='Conv2d_0b_1x1') 179 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 180 | end_points[end_point] = net 181 | if end_point == final_endpoint: return net, end_points 182 | 183 | # mixed_2: 35 x 35 x 288. 184 | end_point = 'Mixed_5d' 185 | with tf.variable_scope(end_point): 186 | with tf.variable_scope('Branch_0'): 187 | branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') 188 | with tf.variable_scope('Branch_1'): 189 | branch_1 = slim.conv2d(net, depth(48), [1, 1], scope='Conv2d_0a_1x1') 190 | branch_1 = slim.conv2d(branch_1, depth(64), [5, 5], 191 | scope='Conv2d_0b_5x5') 192 | with tf.variable_scope('Branch_2'): 193 | branch_2 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') 194 | branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], 195 | scope='Conv2d_0b_3x3') 196 | branch_2 = slim.conv2d(branch_2, depth(96), [3, 3], 197 | scope='Conv2d_0c_3x3') 198 | with tf.variable_scope('Branch_3'): 199 | branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') 200 | branch_3 = slim.conv2d(branch_3, depth(64), [1, 1], 201 | scope='Conv2d_0b_1x1') 202 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 203 | end_points[end_point] = net 204 | if end_point == final_endpoint: return net, end_points 205 | 206 | # mixed_3: 17 x 17 x 768. 207 | end_point = 'Mixed_6a' 208 | with tf.variable_scope(end_point): 209 | with tf.variable_scope('Branch_0'): 210 | branch_0 = slim.conv2d(net, depth(384), [3, 3], stride=2, 211 | padding='VALID', scope='Conv2d_1a_1x1') 212 | with tf.variable_scope('Branch_1'): 213 | branch_1 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1') 214 | branch_1 = slim.conv2d(branch_1, depth(96), [3, 3], 215 | scope='Conv2d_0b_3x3') 216 | branch_1 = slim.conv2d(branch_1, depth(96), [3, 3], stride=2, 217 | padding='VALID', scope='Conv2d_1a_1x1') 218 | with tf.variable_scope('Branch_2'): 219 | branch_2 = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID', 220 | scope='MaxPool_1a_3x3') 221 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2]) 222 | end_points[end_point] = net 223 | if end_point == final_endpoint: return net, end_points 224 | 225 | # mixed4: 17 x 17 x 768. 
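# Grid-size reduction arithmetic for Mixed_6a above (depth_multiplier=1.0):
# the two stride-2 branches contribute 384 + 96 channels and the parallel
# max-pool passes the 288-channel input through unchanged, so the concat
# yields 384 + 96 + 288 = 768 channels at 17x17; the following Mixed_6b-6e
# blocks keep 768 as 192 + 192 + 192 + 192.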
226 | end_point = 'Mixed_6b' 227 | with tf.variable_scope(end_point): 228 | with tf.variable_scope('Branch_0'): 229 | branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') 230 | with tf.variable_scope('Branch_1'): 231 | branch_1 = slim.conv2d(net, depth(128), [1, 1], scope='Conv2d_0a_1x1') 232 | branch_1 = slim.conv2d(branch_1, depth(128), [1, 7], 233 | scope='Conv2d_0b_1x7') 234 | branch_1 = slim.conv2d(branch_1, depth(192), [7, 1], 235 | scope='Conv2d_0c_7x1') 236 | with tf.variable_scope('Branch_2'): 237 | branch_2 = slim.conv2d(net, depth(128), [1, 1], scope='Conv2d_0a_1x1') 238 | branch_2 = slim.conv2d(branch_2, depth(128), [7, 1], 239 | scope='Conv2d_0b_7x1') 240 | branch_2 = slim.conv2d(branch_2, depth(128), [1, 7], 241 | scope='Conv2d_0c_1x7') 242 | branch_2 = slim.conv2d(branch_2, depth(128), [7, 1], 243 | scope='Conv2d_0d_7x1') 244 | branch_2 = slim.conv2d(branch_2, depth(192), [1, 7], 245 | scope='Conv2d_0e_1x7') 246 | with tf.variable_scope('Branch_3'): 247 | branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') 248 | branch_3 = slim.conv2d(branch_3, depth(192), [1, 1], 249 | scope='Conv2d_0b_1x1') 250 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 251 | end_points[end_point] = net 252 | if end_point == final_endpoint: return net, end_points 253 | 254 | # mixed_5: 17 x 17 x 768. 255 | end_point = 'Mixed_6c' 256 | with tf.variable_scope(end_point): 257 | with tf.variable_scope('Branch_0'): 258 | branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') 259 | with tf.variable_scope('Branch_1'): 260 | branch_1 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1') 261 | branch_1 = slim.conv2d(branch_1, depth(160), [1, 7], 262 | scope='Conv2d_0b_1x7') 263 | branch_1 = slim.conv2d(branch_1, depth(192), [7, 1], 264 | scope='Conv2d_0c_7x1') 265 | with tf.variable_scope('Branch_2'): 266 | branch_2 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1') 267 | branch_2 = slim.conv2d(branch_2, depth(160), [7, 1], 268 | scope='Conv2d_0b_7x1') 269 | branch_2 = slim.conv2d(branch_2, depth(160), [1, 7], 270 | scope='Conv2d_0c_1x7') 271 | branch_2 = slim.conv2d(branch_2, depth(160), [7, 1], 272 | scope='Conv2d_0d_7x1') 273 | branch_2 = slim.conv2d(branch_2, depth(192), [1, 7], 274 | scope='Conv2d_0e_1x7') 275 | with tf.variable_scope('Branch_3'): 276 | branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') 277 | branch_3 = slim.conv2d(branch_3, depth(192), [1, 1], 278 | scope='Conv2d_0b_1x1') 279 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 280 | end_points[end_point] = net 281 | if end_point == final_endpoint: return net, end_points 282 | # mixed_6: 17 x 17 x 768. 
283 | end_point = 'Mixed_6d' 284 | with tf.variable_scope(end_point): 285 | with tf.variable_scope('Branch_0'): 286 | branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') 287 | with tf.variable_scope('Branch_1'): 288 | branch_1 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1') 289 | branch_1 = slim.conv2d(branch_1, depth(160), [1, 7], 290 | scope='Conv2d_0b_1x7') 291 | branch_1 = slim.conv2d(branch_1, depth(192), [7, 1], 292 | scope='Conv2d_0c_7x1') 293 | with tf.variable_scope('Branch_2'): 294 | branch_2 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1') 295 | branch_2 = slim.conv2d(branch_2, depth(160), [7, 1], 296 | scope='Conv2d_0b_7x1') 297 | branch_2 = slim.conv2d(branch_2, depth(160), [1, 7], 298 | scope='Conv2d_0c_1x7') 299 | branch_2 = slim.conv2d(branch_2, depth(160), [7, 1], 300 | scope='Conv2d_0d_7x1') 301 | branch_2 = slim.conv2d(branch_2, depth(192), [1, 7], 302 | scope='Conv2d_0e_1x7') 303 | with tf.variable_scope('Branch_3'): 304 | branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') 305 | branch_3 = slim.conv2d(branch_3, depth(192), [1, 1], 306 | scope='Conv2d_0b_1x1') 307 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 308 | end_points[end_point] = net 309 | if end_point == final_endpoint: return net, end_points 310 | 311 | # mixed_7: 17 x 17 x 768. 312 | end_point = 'Mixed_6e' 313 | with tf.variable_scope(end_point): 314 | with tf.variable_scope('Branch_0'): 315 | branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') 316 | with tf.variable_scope('Branch_1'): 317 | branch_1 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') 318 | branch_1 = slim.conv2d(branch_1, depth(192), [1, 7], 319 | scope='Conv2d_0b_1x7') 320 | branch_1 = slim.conv2d(branch_1, depth(192), [7, 1], 321 | scope='Conv2d_0c_7x1') 322 | with tf.variable_scope('Branch_2'): 323 | branch_2 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') 324 | branch_2 = slim.conv2d(branch_2, depth(192), [7, 1], 325 | scope='Conv2d_0b_7x1') 326 | branch_2 = slim.conv2d(branch_2, depth(192), [1, 7], 327 | scope='Conv2d_0c_1x7') 328 | branch_2 = slim.conv2d(branch_2, depth(192), [7, 1], 329 | scope='Conv2d_0d_7x1') 330 | branch_2 = slim.conv2d(branch_2, depth(192), [1, 7], 331 | scope='Conv2d_0e_1x7') 332 | with tf.variable_scope('Branch_3'): 333 | branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') 334 | branch_3 = slim.conv2d(branch_3, depth(192), [1, 1], 335 | scope='Conv2d_0b_1x1') 336 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 337 | end_points[end_point] = net 338 | if end_point == final_endpoint: return net, end_points 339 | 340 | # mixed_8: 8 x 8 x 1280. 
341 | end_point = 'Mixed_7a' 342 | with tf.variable_scope(end_point): 343 | with tf.variable_scope('Branch_0'): 344 | branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') 345 | branch_0 = slim.conv2d(branch_0, depth(320), [3, 3], stride=2, 346 | padding='VALID', scope='Conv2d_1a_3x3') 347 | with tf.variable_scope('Branch_1'): 348 | branch_1 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1') 349 | branch_1 = slim.conv2d(branch_1, depth(192), [1, 7], 350 | scope='Conv2d_0b_1x7') 351 | branch_1 = slim.conv2d(branch_1, depth(192), [7, 1], 352 | scope='Conv2d_0c_7x1') 353 | branch_1 = slim.conv2d(branch_1, depth(192), [3, 3], stride=2, 354 | padding='VALID', scope='Conv2d_1a_3x3') 355 | with tf.variable_scope('Branch_2'): 356 | branch_2 = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID', 357 | scope='MaxPool_1a_3x3') 358 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2]) 359 | end_points[end_point] = net 360 | if end_point == final_endpoint: return net, end_points 361 | # mixed_9: 8 x 8 x 2048. 362 | end_point = 'Mixed_7b' 363 | with tf.variable_scope(end_point): 364 | with tf.variable_scope('Branch_0'): 365 | branch_0 = slim.conv2d(net, depth(320), [1, 1], scope='Conv2d_0a_1x1') 366 | with tf.variable_scope('Branch_1'): 367 | branch_1 = slim.conv2d(net, depth(384), [1, 1], scope='Conv2d_0a_1x1') 368 | branch_1 = tf.concat(axis=3, values=[ 369 | slim.conv2d(branch_1, depth(384), [1, 3], scope='Conv2d_0b_1x3'), 370 | slim.conv2d(branch_1, depth(384), [3, 1], scope='Conv2d_0b_3x1')]) 371 | with tf.variable_scope('Branch_2'): 372 | branch_2 = slim.conv2d(net, depth(448), [1, 1], scope='Conv2d_0a_1x1') 373 | branch_2 = slim.conv2d( 374 | branch_2, depth(384), [3, 3], scope='Conv2d_0b_3x3') 375 | branch_2 = tf.concat(axis=3, values=[ 376 | slim.conv2d(branch_2, depth(384), [1, 3], scope='Conv2d_0c_1x3'), 377 | slim.conv2d(branch_2, depth(384), [3, 1], scope='Conv2d_0d_3x1')]) 378 | with tf.variable_scope('Branch_3'): 379 | branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') 380 | branch_3 = slim.conv2d( 381 | branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1') 382 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 383 | end_points[end_point] = net 384 | if end_point == final_endpoint: return net, end_points 385 | 386 | # mixed_10: 8 x 8 x 2048. 
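# Channel bookkeeping for the 8x8 blocks (depth_multiplier=1.0): Mixed_7a
# concatenates 320 + 192 + 768 (max-pool pass-through of the Mixed_6e output)
# = 1280 channels, and Mixed_7b above / Mixed_7c below expand to
# 320 + (384 + 384) + (384 + 384) + 192 = 2048 channels, matching the
# "8 x 8 x 1280/2048" comments.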
387 | end_point = 'Mixed_7c' 388 | with tf.variable_scope(end_point): 389 | with tf.variable_scope('Branch_0'): 390 | branch_0 = slim.conv2d(net, depth(320), [1, 1], scope='Conv2d_0a_1x1') 391 | with tf.variable_scope('Branch_1'): 392 | branch_1 = slim.conv2d(net, depth(384), [1, 1], scope='Conv2d_0a_1x1') 393 | branch_1 = tf.concat(axis=3, values=[ 394 | slim.conv2d(branch_1, depth(384), [1, 3], scope='Conv2d_0b_1x3'), 395 | slim.conv2d(branch_1, depth(384), [3, 1], scope='Conv2d_0c_3x1')]) 396 | with tf.variable_scope('Branch_2'): 397 | branch_2 = slim.conv2d(net, depth(448), [1, 1], scope='Conv2d_0a_1x1') 398 | branch_2 = slim.conv2d( 399 | branch_2, depth(384), [3, 3], scope='Conv2d_0b_3x3') 400 | branch_2 = tf.concat(axis=3, values=[ 401 | slim.conv2d(branch_2, depth(384), [1, 3], scope='Conv2d_0c_1x3'), 402 | slim.conv2d(branch_2, depth(384), [3, 1], scope='Conv2d_0d_3x1')]) 403 | with tf.variable_scope('Branch_3'): 404 | branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3') 405 | branch_3 = slim.conv2d( 406 | branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1') 407 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 408 | end_points[end_point] = net 409 | if end_point == final_endpoint: return net, end_points 410 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 411 | 412 | 413 | def inception_v3(inputs, 414 | num_classes=1000, 415 | is_training=True, 416 | dropout_keep_prob=0.8, 417 | min_depth=16, 418 | depth_multiplier=1.0, 419 | prediction_fn=slim.softmax, 420 | spatial_squeeze=True, 421 | reuse=None, 422 | create_aux_logits=True, 423 | scope='InceptionV3', 424 | global_pool=False): 425 | """Inception model from http://arxiv.org/abs/1512.00567. 426 | "Rethinking the Inception Architecture for Computer Vision" 427 | Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, 428 | Zbigniew Wojna. 429 | With the default arguments this method constructs the exact model defined in 430 | the paper. However, one can experiment with variations of the inception_v3 431 | network by changing arguments dropout_keep_prob, min_depth and 432 | depth_multiplier. 433 | The default image size used to train this network is 299x299. 434 | Args: 435 | inputs: a tensor of size [batch_size, height, width, channels]. 436 | num_classes: number of predicted classes. If 0 or None, the logits layer 437 | is omitted and the input features to the logits layer (before dropout) 438 | are returned instead. 439 | is_training: whether is training or not. 440 | dropout_keep_prob: the percentage of activation values that are retained. 441 | min_depth: Minimum depth value (number of channels) for all convolution ops. 442 | Enforced when depth_multiplier < 1, and not an active constraint when 443 | depth_multiplier >= 1. 444 | depth_multiplier: Float multiplier for the depth (number of channels) 445 | for all convolution ops. The value must be greater than zero. Typical 446 | usage will be to set this value in (0, 1) to reduce the number of 447 | parameters or computation cost of the model. 448 | prediction_fn: a function to get predictions out of logits. 449 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is of 450 | shape [B, 1, 1, C], where B is batch_size and C is number of classes. 451 | reuse: whether or not the network and its variables should be reused. To be 452 | able to reuse 'scope' must be given. 453 | create_aux_logits: Whether to create the auxiliary logits. 454 | scope: Optional variable_scope. 
455 | global_pool: Optional boolean flag to control the avgpooling before the 456 | logits layer. If false or unset, pooling is done with a fixed window 457 | that reduces default-sized inputs to 1x1, while larger inputs lead to 458 | larger outputs. If true, any input size is pooled down to 1x1. 459 | Returns: 460 | net: a Tensor with the logits (pre-softmax activations) if num_classes 461 | is a non-zero integer, or the non-dropped-out input to the logits layer 462 | if num_classes is 0 or None. 463 | end_points: a dictionary from components of the network to the corresponding 464 | activation. 465 | Raises: 466 | ValueError: if 'depth_multiplier' is less than or equal to zero. 467 | """ 468 | if depth_multiplier <= 0: 469 | raise ValueError('depth_multiplier is not greater than zero.') 470 | depth = lambda d: max(int(d * depth_multiplier), min_depth) 471 | 472 | with tf.variable_scope(scope, 'InceptionV3', [inputs], reuse=reuse) as scope: 473 | with slim.arg_scope([slim.batch_norm, slim.dropout], 474 | is_training=is_training): 475 | net, end_points = inception_v3_base( 476 | inputs, scope=scope, min_depth=min_depth, 477 | depth_multiplier=depth_multiplier) 478 | 479 | # Auxiliary Head logits 480 | if create_aux_logits and num_classes: 481 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], 482 | stride=1, padding='SAME'): 483 | aux_logits = end_points['Mixed_6e'] 484 | with tf.variable_scope('AuxLogits'): 485 | aux_logits = slim.avg_pool2d( 486 | aux_logits, [5, 5], stride=3, padding='VALID', 487 | scope='AvgPool_1a_5x5') 488 | aux_logits = slim.conv2d(aux_logits, depth(128), [1, 1], 489 | scope='Conv2d_1b_1x1') 490 | 491 | # Shape of feature map before the final layer. 492 | kernel_size = _reduced_kernel_size_for_small_input( 493 | aux_logits, [5, 5]) 494 | aux_logits = slim.conv2d( 495 | aux_logits, depth(768), kernel_size, 496 | weights_initializer=trunc_normal(0.01), 497 | padding='VALID', scope='Conv2d_2a_{}x{}'.format(*kernel_size)) 498 | aux_logits = slim.conv2d( 499 | aux_logits, num_classes, [1, 1], activation_fn=None, 500 | normalizer_fn=None, weights_initializer=trunc_normal(0.001), 501 | scope='Conv2d_2b_1x1') 502 | if spatial_squeeze: 503 | aux_logits = tf.squeeze(aux_logits, [1, 2], name='SpatialSqueeze') 504 | end_points['AuxLogits'] = aux_logits 505 | 506 | # Final pooling and prediction 507 | with tf.variable_scope('Logits'): 508 | if global_pool: 509 | # Global average pooling. 510 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='GlobalPool') 511 | end_points['global_pool'] = net 512 | else: 513 | # Pooling with a fixed kernel size. 
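# Worked example of the kernel reduction below (illustrative input sizes): for
# the default 299x299 input the 'Mixed_7c' map is 8x8, so the pooling kernel
# stays [8, 8] and AvgPool_1a yields a 1x1x2048 tensor; for a smaller 224x224
# input the final map is 5x5, so _reduced_kernel_size_for_small_input returns
# [5, 5] instead of attempting an oversized VALID pooling window.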
514 | kernel_size = _reduced_kernel_size_for_small_input(net, [8, 8]) 515 | net = slim.avg_pool2d(net, kernel_size, padding='VALID', 516 | scope='AvgPool_1a_{}x{}'.format(*kernel_size)) 517 | end_points['AvgPool_1a'] = net 518 | if not num_classes: 519 | return net, end_points 520 | # 1 x 1 x 2048 521 | net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b') 522 | end_points['PreLogits'] = net 523 | # 2048 524 | logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 525 | normalizer_fn=None, scope='Conv2d_1c_1x1') 526 | if spatial_squeeze: 527 | logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze') 528 | # 1000 529 | end_points['Logits'] = logits 530 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 531 | return logits, end_points 532 | inception_v3.default_image_size = 299 533 | 534 | 535 | def _reduced_kernel_size_for_small_input(input_tensor, kernel_size): 536 | """Define kernel size which is automatically reduced for small input. 537 | If the shape of the input images is unknown at graph construction time this 538 | function assumes that the input images are is large enough. 539 | Args: 540 | input_tensor: input tensor of size [batch_size, height, width, channels]. 541 | kernel_size: desired kernel size of length 2: [kernel_height, kernel_width] 542 | Returns: 543 | a tensor with the kernel size. 544 | TODO(jrru): Make this function work with unknown shapes. Theoretically, this 545 | can be done with the code below. Problems are two-fold: (1) If the shape was 546 | known, it will be lost. (2) inception.slim.ops._two_element_tuple cannot 547 | handle tensors that define the kernel size. 548 | shape = tf.shape(input_tensor) 549 | return = tf.stack([tf.minimum(shape[1], kernel_size[0]), 550 | tf.minimum(shape[2], kernel_size[1])]) 551 | """ 552 | shape = input_tensor.get_shape().as_list() 553 | if shape[1] is None or shape[2] is None: 554 | kernel_size_out = kernel_size 555 | else: 556 | kernel_size_out = [min(shape[1], kernel_size[0]), 557 | min(shape[2], kernel_size[1])] 558 | return kernel_size_out 559 | 560 | 561 | inception_v3_arg_scope = inception_utils.inception_arg_scope -------------------------------------------------------------------------------- /nets/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | import math 9 | 10 | from nets import inception_v3, resnet_v2 11 | 12 | slim = tf.contrib.slim 13 | 14 | 15 | # best group count for accuracy? 16 | def group_scheme(view_discrimination_score, num_group, num_views): 17 | ''' 18 | Note that 1 ≤ M ≤ N because there may exist sub-ranges 19 | that have no views falling into it. 
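Illustrative example (hypothetical scores, not taken from the paper): with
num_views = 8 and num_group = 10, discrimination scores of
[0.12, 0.17, 0.45, 0.48, 0.51, 0.73, 0.74, 0.98] map to rows
int(score * 10) = 1, 1, 4, 4, 5, 7, 7, 9 of the returned scheme, so only
five of the ten candidate score sub-ranges end up non-empty.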
20 | ''' 21 | schemes = np.full((num_group, num_views), 0, dtype=np.int) 22 | for idx, score in enumerate(view_discrimination_score[0]): 23 | schemes[int(score*10), idx] = 1 # 10 group 24 | 25 | return schemes 26 | 27 | 28 | def group_weight(g_schemes): 29 | num_group = g_schemes.shape[0] 30 | num_views = g_schemes.shape[1] 31 | 32 | weights = np.zeros(shape=(num_group), dtype=np.float32) 33 | for i in range(num_group): 34 | sum = 1 35 | for j in range(num_views): 36 | if g_schemes[i][j] == 1: 37 | sum += g_schemes[i][j] 38 | 39 | weights[i] = sum 40 | 41 | return weights 42 | 43 | 44 | def view_pooling(final_view_descriptors, group_scheme): 45 | 46 | ''' 47 | Intra-Group View Pooling 48 | 49 | Final view descriptors are source of view pooling with grouping scheme. 50 | 51 | Given the view descriptors and the generated grouping information, 52 | the objective here is to conduct intra-group 53 | view pooling towards a group level description. 54 | 55 | the views in the same group have the similar discrimination, 56 | which are assigned the same weight. 57 | 58 | :param group_scheme: shape [num_group, num_view] 59 | :param final_view_descriptors: 60 | :return: group_descriptors 61 | ''' 62 | group_descriptors = {} 63 | dummy = tf.ones_like(final_view_descriptors) 64 | 65 | scheme_list = tf.unstack(group_scheme) 66 | indices = [tf.squeeze(tf.where(elem), axis=1) for elem in scheme_list] 67 | for i, ind in enumerate(indices): 68 | pooled_view = tf.cond(tf.greater(tf.size(ind), 0), 69 | lambda: tf.gather(final_view_descriptors, ind), 70 | lambda: dummy) 71 | 72 | group_descriptors[i] = tf.reduce_max(pooled_view, axis=0) 73 | 74 | return group_descriptors 75 | 76 | 77 | def group_fusion(group_descriptors, group_weight): 78 | ''' 79 | To generate the shape level description, all these group 80 | level descriptors should be further combined. 81 | 82 | The groups containing more discriminative views contribute more to 83 | the final 3D shape descriptor D(S) than those containing less discriminative views. 84 | By using these hierarchical view-group-shape description framework, 85 | the important and discriminative visual content can be discovered in the group level, 86 | and thus emphasized in the shape descriptor accordingly. 87 | 88 | :param 89 | group_descriptors: dic {index: group_desc} 90 | group_weight: 91 | 92 | :return: 93 | ''' 94 | group_weight_list = tf.unstack(group_weight) 95 | numerator = [] 96 | for key, value in group_descriptors.items(): 97 | numerator.append(tf.multiply(group_weight_list[key], value)) 98 | 99 | denominator = tf.reduce_sum(group_weight_list) 100 | shape_descriptor = tf.div(tf.add_n(numerator), denominator) 101 | 102 | return shape_descriptor 103 | 104 | 105 | def gvcnn(inputs, num_classes, group_scheme, group_weight, 106 | is_training=True, dropout_keep_prob=0.8, reuse=tf.compat.v1.AUTO_REUSE): 107 | """ 108 | Raw View Descriptor Generation 109 | 110 | first part of the network (FCN) to get the raw descriptor in the view level. 111 | The “FCN” part is the top five convolutional layers of GoogLeNet. 112 | (mid-level representation) 113 | 114 | Extract the raw view descriptors. 115 | Compared with deeper CNN, shallow FCN could have more position information, 116 | which is needed for the followed grouping module and the deeper CNN will have 117 | the content information which could represent the view feature better. 
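Sketch of the grouping path implemented below (illustrative summary of the
code, not an exact formula from the paper): each view is scored in (0, 1) by
the GAP branch, group_scheme() buckets views by int(score * 10),
group_weight() assigns each bucket a weight, view_pooling() max-pools the
block4 descriptors of the views inside a bucket, and group_fusion() combines
the group descriptors as a weight-normalized sum,
D(S) = sum_j(w_j * d_j) / sum_j(w_j), which is then globally average-pooled
and classified by a dense layer.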
118 | 119 | Args: 120 | inputs: N x V x H x W x C tensor 121 | scope: 122 | """ 123 | view_discrimination_scores = [] 124 | final_view_descriptors = [] 125 | 126 | n_views = inputs.get_shape().as_list()[1] 127 | # transpose views: (NxVxHxWxC) -> (VxNxHxWxC) 128 | views = tf.transpose(inputs, perm=[1, 0, 2, 3, 4]) 129 | for index in range(n_views): 130 | batch_view = tf.gather(views, index) # N x H x W x C 131 | # with slim.arg_scope(inception_v3.inception_v3_arg_scope()): 132 | # _, end_points = inception_v3.inception_v3(batch_view, 133 | # num_classes=num_classes, 134 | # is_training=is_training, 135 | # dropout_keep_prob=dropout_keep_prob, 136 | # reuse=reuse) 137 | with slim.arg_scope(resnet_v2.resnet_arg_scope()): 138 | _, end_points = resnet_v2.resnet_v2_50(batch_view, 139 | num_classes=num_classes, 140 | is_training=is_training, 141 | reuse=reuse) 142 | 143 | # GAP layer to obtain the discrimination scores from raw view descriptors. 144 | raw = tf.keras.layers.GlobalAveragePooling2D()(end_points['resnet_v2_50/block3']) 145 | raw = tf.keras.layers.Dense(1)(raw) 146 | raw = tf.reduce_mean(raw) 147 | batch_view_score = tf.nn.sigmoid(tf.math.log(tf.abs(raw))) 148 | view_discrimination_scores.append(batch_view_score) 149 | final_view_descriptors.append(end_points['resnet_v2_50/block4']) 150 | 151 | # TODO: tuning point block. 152 | # ----------------------------- 153 | # Intra-Group View Pooling 154 | group_descriptors = view_pooling(final_view_descriptors, group_scheme) 155 | 156 | # Group Fusion 157 | shape_descriptor = group_fusion(group_descriptors, group_weight) 158 | # ----------------------------- 159 | 160 | # # test - simple pooling view 161 | # shape_descriptor = tf.reduce_max(final_view_descriptors, axis=0) 162 | 163 | net = tf.keras.layers.GlobalAveragePooling2D()(shape_descriptor) 164 | logits = tf.keras.layers.Dense(num_classes)(net) 165 | 166 | return view_discrimination_scores, shape_descriptor, logits 167 | 168 | 169 | def basic(inputs, 170 | num_classes, 171 | is_training=True, 172 | dropout_keep_prob=0.8, 173 | reuse=tf.compat.v1.AUTO_REUSE): 174 | ''' 175 | Args: 176 | inputs: N x V x H x W x C tensor 177 | scope: 178 | ''' 179 | final_view_descriptors = [] 180 | 181 | n_views = inputs.get_shape().as_list()[1] 182 | # transpose views: (NxVxHxWxC) -> (VxNxHxWxC) 183 | views = tf.transpose(inputs, perm=[1, 0, 2, 3, 4]) 184 | for index in range(n_views): 185 | batch_view = tf.gather(views, index) # N x H x W x C 186 | 187 | # with slim.arg_scope(inception_v3.inception_v3_arg_scope()): 188 | # logits, end_points = inception_v3.inception_v3(batch_view, 189 | # num_classes = num_classes, 190 | # is_training=is_training, 191 | # dropout_keep_prob=dropout_keep_prob, 192 | # reuse=reuse) 193 | # final_view_descriptors.append(end_points['Mixed_7c']) 194 | 195 | with slim.arg_scope(resnet_v2.resnet_arg_scope()): 196 | logits, end_points = resnet_v2.resnet_v2_50(batch_view, 197 | num_classes=num_classes, 198 | is_training=is_training, 199 | reuse=reuse) 200 | final_view_descriptors.append(end_points['resnet_v2_50/block4']) 201 | 202 | shape_descriptor = tf.reduce_max(final_view_descriptors, axis=0) 203 | net = tf.keras.layers.GlobalAveragePooling2D()(shape_descriptor) 204 | logits = tf.keras.layers.Dense(num_classes)(net) 205 | 206 | return shape_descriptor, logits 207 | -------------------------------------------------------------------------------- /nets/resnet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The 
TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains building blocks for various versions of Residual Networks. 16 | Residual networks (ResNets) were proposed in: 17 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 18 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 19 | More variants were introduced in: 20 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 21 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 22 | We can obtain different ResNet variants by changing the network depth, width, 23 | and form of residual unit. This module implements the infrastructure for 24 | building them. Concrete ResNet units and full ResNet networks are implemented in 25 | the accompanying resnet_v1.py and resnet_v2.py modules. 26 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current 27 | implementation we subsample the output activations in the last residual unit of 28 | each block, instead of subsampling the input activations in the first residual 29 | unit of each block. The two implementations give identical results but our 30 | implementation is more memory efficient. 31 | """ 32 | from __future__ import absolute_import 33 | from __future__ import division 34 | from __future__ import print_function 35 | 36 | import collections 37 | import tensorflow as tf 38 | 39 | slim = tf.contrib.slim 40 | 41 | 42 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 43 | """A named tuple describing a ResNet block. 44 | Its parts are: 45 | scope: The scope of the `Block`. 46 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 47 | returns another `Tensor` with the output of the ResNet unit. 48 | args: A list of length equal to the number of units in the `Block`. The list 49 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 50 | block to serve as argument to unit_fn. 51 | """ 52 | 53 | 54 | def subsample(inputs, factor, scope=None): 55 | """Subsamples the input along the spatial dimensions. 56 | Args: 57 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 58 | factor: The subsampling factor. 59 | scope: Optional variable_scope. 60 | Returns: 61 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 62 | input, either intact (if factor == 1) or subsampled (if factor > 1). 63 | """ 64 | if factor == 1: 65 | return inputs 66 | else: 67 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 68 | 69 | 70 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 71 | """Strided 2-D convolution with 'SAME' padding. 72 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 73 | 'VALID' padding. 
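Worked padding example (illustrative numbers): for kernel_size=7, stride=2,
rate=1 (the root 'conv1' used by the ResNets in this repository),
kernel_size_effective = 7 and pad_total = 6, so 3 rows/columns of zeros are
added on each side before the VALID convolution; with kernel_size=3 and
rate=2, kernel_size_effective = 3 + 2 * 1 = 5 and the padding becomes 2 on
each side.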
74 | Note that 75 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 76 | is equivalent to 77 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 78 | net = subsample(net, factor=stride) 79 | whereas 80 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 81 | is different when the input's height or width is even, which is why we add the 82 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 83 | Args: 84 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 85 | num_outputs: An integer, the number of output filters. 86 | kernel_size: An int with the kernel_size of the filters. 87 | stride: An integer, the output stride. 88 | rate: An integer, rate for atrous convolution. 89 | scope: Scope. 90 | Returns: 91 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 92 | the convolution output. 93 | """ 94 | if stride == 1: 95 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 96 | padding='SAME', scope=scope) 97 | else: 98 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 99 | pad_total = kernel_size_effective - 1 100 | pad_beg = pad_total // 2 101 | pad_end = pad_total - pad_beg 102 | inputs = tf.pad(inputs, 103 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 104 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 105 | rate=rate, padding='VALID', scope=scope) 106 | 107 | 108 | @slim.add_arg_scope 109 | def stack_blocks_dense(net, blocks, output_stride=None, 110 | store_non_strided_activations=False, 111 | outputs_collections=None): 112 | """Stacks ResNet `Blocks` and controls output feature density. 113 | First, this function creates scopes for the ResNet in the form of 114 | 'block_name/unit_1', 'block_name/unit_2', etc. 115 | Second, this function allows the user to explicitly control the ResNet 116 | output_stride, which is the ratio of the input to output spatial resolution. 117 | This is useful for dense prediction tasks such as semantic segmentation or 118 | object detection. 119 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 120 | factor of 2 when transitioning between consecutive ResNet blocks. This results 121 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 122 | half the nominal network stride (e.g., output_stride=4), then we compute 123 | responses twice. 124 | Control of the output feature density is implemented by atrous convolution. 125 | Args: 126 | net: A `Tensor` of size [batch, height, width, channels]. 127 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 128 | element is a ResNet `Block` object describing the units in the `Block`. 129 | output_stride: If `None`, then the output will be computed at the nominal 130 | network stride. If output_stride is not `None`, it specifies the requested 131 | ratio of input to output spatial resolution, which needs to be equal to 132 | the product of unit strides from the start up to some level of the ResNet. 133 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 134 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 135 | is equivalent to output_stride=24). 136 | store_non_strided_activations: If True, we compute non-strided (undecimated) 137 | activations at the last unit of each block and store them in the 138 | `outputs_collections` before subsampling them. 
This gives us access to 139 | higher resolution intermediate activations which are useful in some 140 | dense prediction problems but increases 4x the computation and memory cost 141 | at the last unit of each block. 142 | outputs_collections: Collection to add the ResNet block outputs. 143 | Returns: 144 | net: Output tensor with stride equal to the specified output_stride. 145 | Raises: 146 | ValueError: If the target output_stride is not valid. 147 | """ 148 | # The current_stride variable keeps track of the effective stride of the 149 | # activations. This allows us to invoke atrous convolution whenever applying 150 | # the next residual unit would result in the activations having stride larger 151 | # than the target output_stride. 152 | current_stride = 1 153 | 154 | # The atrous convolution rate parameter. 155 | rate = 1 156 | 157 | for block in blocks: 158 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 159 | block_stride = 1 160 | for i, unit in enumerate(block.args): 161 | if store_non_strided_activations and i == len(block.args) - 1: 162 | # Move stride from the block's last unit to the end of the block. 163 | block_stride = unit.get('stride', 1) 164 | unit = dict(unit, stride=1) 165 | 166 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 167 | # If we have reached the target output_stride, then we need to employ 168 | # atrous convolution with stride=1 and multiply the atrous rate by the 169 | # current unit's stride for use in subsequent layers. 170 | if output_stride is not None and current_stride == output_stride: 171 | net = block.unit_fn(net, rate=rate, **dict(unit, stride=1)) 172 | rate *= unit.get('stride', 1) 173 | 174 | else: 175 | net = block.unit_fn(net, rate=1, **unit) 176 | current_stride *= unit.get('stride', 1) 177 | if output_stride is not None and current_stride > output_stride: 178 | raise ValueError('The target output_stride cannot be reached.') 179 | 180 | # Collect activations at the block's end before performing subsampling. 181 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 182 | 183 | # Subsampling of the block's output activations. 184 | if output_stride is not None and current_stride == output_stride: 185 | rate *= block_stride 186 | else: 187 | net = subsample(net, block_stride) 188 | current_stride *= block_stride 189 | if output_stride is not None and current_stride > output_stride: 190 | raise ValueError('The target output_stride cannot be reached.') 191 | 192 | if output_stride is not None and current_stride != output_stride: 193 | raise ValueError('The target output_stride cannot be reached.') 194 | 195 | return net 196 | 197 | 198 | def resnet_arg_scope(weight_decay=0.0001, 199 | batch_norm_decay=0.997, 200 | batch_norm_epsilon=1e-5, 201 | batch_norm_scale=True, 202 | activation_fn=tf.nn.relu, 203 | use_batch_norm=True, 204 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS): 205 | """Defines the default ResNet arg scope. 206 | TODO(gpapan): The batch-normalization related default values above are 207 | appropriate for use in conjunction with the reference ResNet models 208 | released at https://github.com/KaimingHe/deep-residual-networks. When 209 | training ResNets from scratch, they might need to be tuned. 210 | Args: 211 | weight_decay: The weight decay to use for regularizing the model. 212 | batch_norm_decay: The moving average decay when estimating layer activation 213 | statistics in batch normalization. 
214 | batch_norm_epsilon: Small constant to prevent division by zero when 215 | normalizing activations by their variance in batch normalization. 216 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 217 | activations in the batch normalization layer. 218 | activation_fn: The activation function which is used in ResNet. 219 | use_batch_norm: Whether or not to use batch normalization. 220 | batch_norm_updates_collections: Collection for the update ops for 221 | batch norm. 222 | Returns: 223 | An `arg_scope` to use for the resnet models. 224 | """ 225 | batch_norm_params = { 226 | 'decay': batch_norm_decay, 227 | 'epsilon': batch_norm_epsilon, 228 | 'scale': batch_norm_scale, 229 | 'updates_collections': batch_norm_updates_collections, 230 | 'fused': None, # Use fused batch norm if possible. 231 | } 232 | 233 | with slim.arg_scope( 234 | [slim.conv2d], 235 | weights_regularizer=slim.l2_regularizer(weight_decay), 236 | weights_initializer=slim.variance_scaling_initializer(), 237 | activation_fn=activation_fn, 238 | normalizer_fn=slim.batch_norm if use_batch_norm else None, 239 | normalizer_params=batch_norm_params): 240 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 241 | # The following implies padding='SAME' for pool1, which makes feature 242 | # alignment easier for dense prediction tasks. This is also used in 243 | # https://github.com/facebook/fb.resnet.torch. However the accompanying 244 | # code of 'Deep Residual Learning for Image Recognition' uses 245 | # padding='VALID' for pool1. You can switch to that choice by setting 246 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 247 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 248 | return arg_sc -------------------------------------------------------------------------------- /nets/resnet_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for the preactivation form of Residual Networks. 16 | Residual networks (ResNets) were originally proposed in: 17 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 18 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 19 | The full preactivation 'v2' ResNet variant implemented in this module was 20 | introduced by: 21 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 22 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 23 | The key difference of the full preactivation 'v2' variant compared to the 24 | 'v1' variant in [1] is the use of batch normalization before every weight layer. 
25 | Typical use: 26 | from tensorflow.contrib.slim.nets import resnet_v2 27 | ResNet-101 for image classification into 1000 classes: 28 | # inputs has shape [batch, 224, 224, 3] 29 | with slim.arg_scope(resnet_v2.resnet_arg_scope()): 30 | net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False) 31 | ResNet-101 for semantic segmentation into 21 classes: 32 | # inputs has shape [batch, 513, 513, 3] 33 | with slim.arg_scope(resnet_v2.resnet_arg_scope()): 34 | net, end_points = resnet_v2.resnet_v2_101(inputs, 35 | 21, 36 | is_training=False, 37 | global_pool=False, 38 | output_stride=16) 39 | """ 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | 44 | import tensorflow as tf 45 | 46 | from nets import resnet_utils 47 | 48 | slim = tf.contrib.slim 49 | resnet_arg_scope = resnet_utils.resnet_arg_scope 50 | 51 | 52 | @slim.add_arg_scope 53 | def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1, 54 | outputs_collections=None, scope=None): 55 | """Bottleneck residual unit variant with BN before convolutions. 56 | This is the full preactivation residual unit variant proposed in [2]. See 57 | Fig. 1(b) of [2] for its definition. Note that we use here the bottleneck 58 | variant which has an extra bottleneck layer. 59 | When putting together two consecutive ResNet blocks that use this unit, one 60 | should use stride = 2 in the last unit of the first block. 61 | Args: 62 | inputs: A tensor of size [batch, height, width, channels]. 63 | depth: The depth of the ResNet unit output. 64 | depth_bottleneck: The depth of the bottleneck layers. 65 | stride: The ResNet unit's stride. Determines the amount of downsampling of 66 | the units output compared to its input. 67 | rate: An integer, rate for atrous convolution. 68 | outputs_collections: Collection to add the ResNet unit output. 69 | scope: Optional variable_scope. 70 | Returns: 71 | The ResNet unit's output. 72 | """ 73 | with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc: 74 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 75 | preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact') 76 | if depth == depth_in: 77 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 78 | else: 79 | shortcut = slim.conv2d(preact, depth, [1, 1], stride=stride, 80 | normalizer_fn=None, activation_fn=None, 81 | scope='shortcut') 82 | 83 | residual = slim.conv2d(preact, depth_bottleneck, [1, 1], stride=1, 84 | scope='conv1') 85 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 86 | rate=rate, scope='conv2') 87 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 88 | normalizer_fn=None, activation_fn=None, 89 | scope='conv3') 90 | 91 | output = shortcut + residual 92 | 93 | return slim.utils.collect_named_outputs(outputs_collections, 94 | sc.name, 95 | output) 96 | 97 | 98 | def resnet_v2(inputs, 99 | blocks, 100 | num_classes=None, 101 | is_training=True, 102 | global_pool=True, 103 | output_stride=None, 104 | include_root_block=True, 105 | spatial_squeeze=True, 106 | reuse=None, 107 | scope=None): 108 | """Generator for v2 (preactivation) ResNet models. 109 | This function generates a family of ResNet v2 models. See the resnet_v2_*() 110 | methods for specific model instantiations, obtained by selecting different 111 | block instantiations that produce ResNets of various depths. 
112 | Training for image classification on Imagenet is usually done with [224, 224] 113 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 114 | block for the ResNets defined in [1] that have nominal stride equal to 32. 115 | However, for dense prediction tasks we advise that one uses inputs with 116 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 117 | this case the feature maps at the ResNet output will have spatial shape 118 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 119 | and corners exactly aligned with the input image corners, which greatly 120 | facilitates alignment of the features to the image. Using as input [225, 225] 121 | images results in [8, 8] feature maps at the output of the last ResNet block. 122 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 123 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 124 | have nominal stride equal to 32 and a good choice in FCN mode is to use 125 | output_stride=16 in order to increase the density of the computed features at 126 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 127 | Args: 128 | inputs: A tensor of size [batch, height_in, width_in, channels]. 129 | blocks: A list of length equal to the number of ResNet blocks. Each element 130 | is a resnet_utils.Block object describing the units in the block. 131 | num_classes: Number of predicted classes for classification tasks. 132 | If 0 or None, we return the features before the logit layer. 133 | is_training: whether batch_norm layers are in training mode. 134 | global_pool: If True, we perform global average pooling before computing the 135 | logits. Set to True for image classification, False for dense prediction. 136 | output_stride: If None, then the output will be computed at the nominal 137 | network stride. If output_stride is not None, it specifies the requested 138 | ratio of input to output spatial resolution. 139 | include_root_block: If True, include the initial convolution followed by 140 | max-pooling, if False excludes it. If excluded, `inputs` should be the 141 | results of an activation-less convolution. 142 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is 143 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 144 | To use this parameter, the input images must be smaller than 300x300 145 | pixels, in which case the output logit layer does not contain spatial 146 | information and can be removed. 147 | reuse: whether or not the network and its variables should be reused. To be 148 | able to reuse 'scope' must be given. 149 | scope: Optional variable_scope. 150 | Returns: 151 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 152 | If global_pool is False, then height_out and width_out are reduced by a 153 | factor of output_stride compared to the respective height_in and width_in, 154 | else both height_out and width_out equal one. If num_classes is 0 or None, 155 | then net is the output of the last ResNet block, potentially after global 156 | average pooling. If num_classes is a non-zero integer, net contains the 157 | pre-softmax activations. 158 | end_points: A dictionary from components of the network to the corresponding 159 | activation. 160 | Raises: 161 | ValueError: If the target output_stride is not valid. 
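Worked size example (illustrative): for a [batch, 321, 321, 3] input with
output_stride=16 and global_pool=False, the returned feature maps have
spatial shape [(321 - 1) / 16 + 1, (321 - 1) / 16 + 1] = [21, 21]; with the
nominal stride of 32 (output_stride=None) the same input yields [11, 11].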
162 | """ 163 | with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse) as sc: 164 | end_points_collection = sc.original_name_scope + '_end_points' 165 | with slim.arg_scope([slim.conv2d, bottleneck, 166 | resnet_utils.stack_blocks_dense], 167 | outputs_collections=end_points_collection): 168 | with slim.arg_scope([slim.batch_norm], is_training=is_training): 169 | net = inputs 170 | if include_root_block: 171 | if output_stride is not None: 172 | if output_stride % 4 != 0: 173 | raise ValueError('The output_stride needs to be a multiple of 4.') 174 | output_stride /= 4 175 | # We do not include batch normalization or activation functions in 176 | # conv1 because the first ResNet unit will perform these. Cf. 177 | # Appendix of [2]. 178 | with slim.arg_scope([slim.conv2d], 179 | activation_fn=None, normalizer_fn=None): 180 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 181 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') 182 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) 183 | # This is needed because the pre-activation variant does not have batch 184 | # normalization or activation functions in the residual unit output. See 185 | # Appendix of [2]. 186 | net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm') 187 | # Convert end_points_collection into a dictionary of end_points. 188 | end_points = slim.utils.convert_collection_to_dict( 189 | end_points_collection) 190 | 191 | if global_pool: 192 | # Global average pooling. 193 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 194 | end_points['global_pool'] = net 195 | if num_classes: 196 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 197 | normalizer_fn=None, scope='logits') 198 | end_points[sc.name + '/logits'] = net 199 | if spatial_squeeze: 200 | net = tf.squeeze(net, [1, 2], name='SpatialSqueeze') 201 | end_points[sc.name + '/spatial_squeeze'] = net 202 | end_points['predictions'] = slim.softmax(net, scope='predictions') 203 | return net, end_points 204 | resnet_v2.default_image_size = 224 205 | 206 | 207 | def resnet_v2_block(scope, base_depth, num_units, stride): 208 | """Helper function for creating a resnet_v2 bottleneck block. 209 | Args: 210 | scope: The scope of the block. 211 | base_depth: The depth of the bottleneck layer for each unit. 212 | num_units: The number of units in the block. 213 | stride: The stride of the block, implemented as a stride in the last unit. 214 | All other units have stride=1. 215 | Returns: 216 | A resnet_v2 bottleneck block. 217 | """ 218 | return resnet_utils.Block(scope, bottleneck, [{ 219 | 'depth': base_depth * 4, 220 | 'depth_bottleneck': base_depth, 221 | 'stride': 1 222 | }] * (num_units - 1) + [{ 223 | 'depth': base_depth * 4, 224 | 'depth_bottleneck': base_depth, 225 | 'stride': stride 226 | }]) 227 | resnet_v2.default_image_size = 224 228 | 229 | 230 | def resnet_v2_50(inputs, 231 | num_classes=None, 232 | is_training=True, 233 | global_pool=True, 234 | output_stride=None, 235 | spatial_squeeze=True, 236 | reuse=None, 237 | scope='resnet_v2_50'): 238 | """ResNet-50 model of [1]. 
See resnet_v2() for arg and return description.""" 239 | blocks = [ 240 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2), 241 | resnet_v2_block('block2', base_depth=128, num_units=4, stride=2), 242 | resnet_v2_block('block3', base_depth=256, num_units=6, stride=2), 243 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1), 244 | ] 245 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training, 246 | global_pool=global_pool, output_stride=output_stride, 247 | include_root_block=True, spatial_squeeze=spatial_squeeze, 248 | reuse=reuse, scope=scope) 249 | resnet_v2_50.default_image_size = resnet_v2.default_image_size 250 | 251 | 252 | def resnet_v2_101(inputs, 253 | num_classes=None, 254 | is_training=True, 255 | global_pool=True, 256 | output_stride=None, 257 | spatial_squeeze=True, 258 | reuse=None, 259 | scope='resnet_v2_101'): 260 | """ResNet-101 model of [1]. See resnet_v2() for arg and return description.""" 261 | blocks = [ 262 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2), 263 | resnet_v2_block('block2', base_depth=128, num_units=4, stride=2), 264 | resnet_v2_block('block3', base_depth=256, num_units=23, stride=2), 265 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1), 266 | ] 267 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training, 268 | global_pool=global_pool, output_stride=output_stride, 269 | include_root_block=True, spatial_squeeze=spatial_squeeze, 270 | reuse=reuse, scope=scope) 271 | resnet_v2_101.default_image_size = resnet_v2.default_image_size 272 | 273 | 274 | def resnet_v2_152(inputs, 275 | num_classes=None, 276 | is_training=True, 277 | global_pool=True, 278 | output_stride=None, 279 | spatial_squeeze=True, 280 | reuse=None, 281 | scope='resnet_v2_152'): 282 | """ResNet-152 model of [1]. See resnet_v2() for arg and return description.""" 283 | blocks = [ 284 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2), 285 | resnet_v2_block('block2', base_depth=128, num_units=8, stride=2), 286 | resnet_v2_block('block3', base_depth=256, num_units=36, stride=2), 287 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1), 288 | ] 289 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training, 290 | global_pool=global_pool, output_stride=output_stride, 291 | include_root_block=True, spatial_squeeze=spatial_squeeze, 292 | reuse=reuse, scope=scope) 293 | resnet_v2_152.default_image_size = resnet_v2.default_image_size 294 | 295 | 296 | def resnet_v2_200(inputs, 297 | num_classes=None, 298 | is_training=True, 299 | global_pool=True, 300 | output_stride=None, 301 | spatial_squeeze=True, 302 | reuse=None, 303 | scope='resnet_v2_200'): 304 | """ResNet-200 model of [2]. 
See resnet_v2() for arg and return description.""" 305 | blocks = [ 306 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2), 307 | resnet_v2_block('block2', base_depth=128, num_units=24, stride=2), 308 | resnet_v2_block('block3', base_depth=256, num_units=36, stride=2), 309 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1), 310 | ] 311 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training, 312 | global_pool=global_pool, output_stride=output_stride, 313 | include_root_block=True, spatial_squeeze=spatial_squeeze, 314 | reuse=reuse, scope=scope) 315 | resnet_v2_200.default_image_size = resnet_v2.default_image_size -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | 4 | import tensorflow as tf 5 | 6 | import numpy as np 7 | 8 | import train_data 9 | import val_data 10 | from nets import model 11 | from utils import train_utils, _train_helper 12 | 13 | slim = tf.contrib.slim 14 | 15 | flags = tf.app.flags 16 | FLAGS = flags.FLAGS 17 | 18 | 19 | # Settings for logging. 20 | flags.DEFINE_string('train_logdir', './tfmodels', 21 | 'Where the checkpoint and logs are stored.') 22 | flags.DEFINE_string('ckpt_name_to_save', 'gvcnn.ckpt', 23 | 'Name to save checkpoint file') 24 | flags.DEFINE_integer('log_steps', 10, 25 | 'Display logging information at every log_steps.') 26 | flags.DEFINE_integer('save_interval_secs', 1200, 27 | 'How often, in seconds, we save the model to disk.') 28 | flags.DEFINE_boolean('save_summaries_images', False, 29 | 'Save sample inputs, labels, and semantic predictions as ' 30 | 'images to summary.') 31 | flags.DEFINE_string('summaries_dir', './tfmodels/train_logs', 32 | 'Where to save summary logs for TensorBoard.') 33 | 34 | flags.DEFINE_enum('learning_policy', 'poly', ['poly', 'step'], 35 | 'Learning rate policy for training.') 36 | flags.DEFINE_float('base_learning_rate', .001, 37 | 'The base learning rate for model training.') 38 | flags.DEFINE_float('learning_rate_decay_factor', 1e-3, 39 | 'The rate to decay the base learning rate.') 40 | flags.DEFINE_float('learning_rate_decay_step', .3000, 41 | 'Decay the base learning rate at a fixed step.') 42 | flags.DEFINE_float('learning_power', 0.9, 43 | 'The power value used in the poly learning policy.') 44 | flags.DEFINE_float('training_number_of_steps', 300000, 45 | 'The number of steps used for training.') 46 | flags.DEFINE_float('momentum', 0.9, 'The momentum value to use') 47 | 48 | flags.DEFINE_float('last_layer_gradient_multiplier', 1.0, 49 | 'The gradient multiplier for last layers, which is used to ' 50 | 'boost the gradient of last layers if the value > 1.') 51 | # Set to False if one does not want to re-use the trained classifier weights. 52 | flags.DEFINE_boolean('initialize_last_layer', True, 53 | 'Initialize the last layer.') 54 | flags.DEFINE_boolean('last_layers_contain_logits_only', False, 55 | 'Only consider logits as last layers or not.') 56 | flags.DEFINE_integer('slow_start_step', 0, 57 | 'Training model with small learning rate for few steps.') 58 | flags.DEFINE_float('slow_start_learning_rate', 1e-4, 59 | 'Learning rate employed during slow start.') 60 | 61 | # Settings for fine-tuning the network. 
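# For example, to warm-start from a pretrained ResNet-50 checkpoint while
# re-learning only the classifier, one might run (hypothetical paths and
# scope names, shown purely to illustrate the flags defined below):
#   python train.py \
#     --pre_trained_checkpoint=pre-trained/resnet_v2_50.ckpt \
#     --checkpoint_exclude_scopes=resnet_v2_50/logits \
#     --trainable_scopes=resnet_v2_50/logits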
62 | flags.DEFINE_string('saved_checkpoint_dir', 63 | # './tfmodels', 64 | None, 65 | 'Saved checkpoint dir.') 66 | flags.DEFINE_string('pre_trained_checkpoint', 67 | None, 68 | 'The pre-trained checkpoint in tensorflow format.') 69 | flags.DEFINE_string('checkpoint_exclude_scopes', 70 | None, 71 | 'Comma-separated list of scopes of variables to exclude ' 72 | 'when restoring from a checkpoint.') 73 | flags.DEFINE_string('trainable_scopes', 74 | None, 75 | 'Comma-separated list of scopes to filter the set of variables ' 76 | 'to train. By default, None would train all the variables.') 77 | flags.DEFINE_string('checkpoint_model_scope', 78 | None, 79 | 'Model scope in the checkpoint. None if the same as the trained model.') 80 | flags.DEFINE_string('model_name', 81 | 'resnet_v2_50', 82 | 'The name of the architecture to train.') 83 | flags.DEFINE_boolean('ignore_missing_vars', 84 | False, 85 | 'When restoring a checkpoint would ignore missing variables.') 86 | 87 | # Dataset settings. 88 | flags.DEFINE_string('dataset_dir', '/home/ace19/dl_data/modelnet5', 89 | 'Where the dataset reside.') 90 | 91 | flags.DEFINE_integer('how_many_training_epochs', 100, 92 | 'How many training loops to runs') 93 | 94 | flags.DEFINE_integer('batch_size', 4, 'batch size') 95 | flags.DEFINE_integer('val_batch_size', 4, 'val batch size') 96 | flags.DEFINE_integer('num_views', 6, 'number of views') 97 | flags.DEFINE_integer('num_group', 10, 'number of group') 98 | flags.DEFINE_integer('height', 299, 'height') 99 | flags.DEFINE_integer('width', 299, 'width') 100 | flags.DEFINE_string('labels', 101 | # 'airplane,bed,bookshelf,bottle,chair,monitor,sofa,table,toilet,vase', 102 | 'bottle,monitor,table,toilet,vase', 103 | 'number of classes') 104 | 105 | # check total count before training 106 | MODELNET_TRAIN_DATA_SIZE = 392+335+344+475+465 # 5 class 107 | MODELNET_VALIDATE_DATA_SIZE = 500 108 | 109 | 110 | 111 | def main(unused_argv): 112 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 113 | 114 | labels = FLAGS.labels.split(',') 115 | num_classes = len(labels) 116 | 117 | with tf.Graph().as_default() as graph: 118 | global_step = tf.compat.v1.train.get_or_create_global_step() 119 | 120 | # Define the model 121 | X = tf.compat.v1.placeholder(tf.float32, 122 | [None, FLAGS.num_views, FLAGS.height, FLAGS.width, 3], 123 | name='X') 124 | ground_truth = tf.compat.v1.placeholder(tf.int64, [None], name='ground_truth') 125 | is_training = tf.compat.v1.placeholder(tf.bool, name='is_training') 126 | dropout_keep_prob = tf.compat.v1.placeholder(tf.float32, name='dropout_keep_prob') 127 | g_scheme = tf.compat.v1.placeholder(tf.int32, [FLAGS.num_group, FLAGS.num_views]) 128 | g_weight = tf.compat.v1.placeholder(tf.float32, [FLAGS.num_group]) 129 | 130 | # GVCNN 131 | view_scores, _, logits = model.gvcnn(X, 132 | num_classes, 133 | g_scheme, 134 | g_weight, 135 | is_training, 136 | dropout_keep_prob) 137 | 138 | # # basic - for verification 139 | # _, logits = model.basic(X, 140 | # num_classes, 141 | # is_training, 142 | # dropout_keep_prob) 143 | 144 | # Define loss 145 | _loss = tf.losses.sparse_softmax_cross_entropy(labels=ground_truth, logits=logits) 146 | 147 | # Gather initial summaries. 
148 | summaries = set(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES)) 149 | 150 | prediction = tf.argmax(logits, 1, name='prediction') 151 | correct_prediction = tf.equal(prediction, ground_truth) 152 | confusion_matrix = tf.math.confusion_matrix(ground_truth, 153 | prediction, 154 | num_classes=num_classes) 155 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 156 | summaries.add(tf.compat.v1.summary.scalar('accuracy', accuracy)) 157 | 158 | # # Add summaries for model variables. 159 | # for model_var in slim.get_model_variables(): 160 | # summaries.add(tf.compat.v1.summary.histogram(model_var.op.name, model_var)) 161 | 162 | # Add summaries for losses. 163 | for loss in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.LOSSES): 164 | summaries.add(tf.compat.v1.summary.scalar('losses/%s' % loss.op.name, loss)) 165 | 166 | learning_rate = train_utils.get_model_learning_rate( 167 | FLAGS.learning_policy, FLAGS.base_learning_rate, 168 | FLAGS.learning_rate_decay_step, FLAGS.learning_rate_decay_factor, 169 | FLAGS.training_number_of_steps, FLAGS.learning_power, 170 | FLAGS.slow_start_step, FLAGS.slow_start_learning_rate) 171 | optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate, FLAGS.momentum) 172 | summaries.add(tf.compat.v1.summary.scalar('learning_rate', learning_rate)) 173 | 174 | total_loss, grads_and_vars = train_utils.optimize(optimizer) 175 | total_loss = tf.debugging.check_numerics(total_loss, 'Loss is inf or nan.') 176 | summaries.add(tf.compat.v1.summary.scalar('total_loss', total_loss)) 177 | 178 | # Gather update_ops. 179 | # These contain, for example, the updates for the batch_norm variables created by model. 180 | update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS) 181 | 182 | # Create gradient update op. 183 | update_ops.append(optimizer.apply_gradients(grads_and_vars, 184 | global_step=global_step)) 185 | update_op = tf.group(*update_ops) 186 | with tf.control_dependencies([update_op]): 187 | train_op = tf.identity(total_loss, name='train_op') 188 | 189 | 190 | ################ 191 | # Prepare data 192 | ################ 193 | filenames = tf.compat.v1.placeholder(tf.string, shape=[]) 194 | tr_dataset = train_data.Dataset(filenames, 195 | FLAGS.num_views, 196 | FLAGS.height, 197 | FLAGS.width, 198 | FLAGS.batch_size) 199 | iterator = tr_dataset.dataset.make_initializable_iterator() 200 | next_batch = iterator.get_next() 201 | 202 | # validation dateset 203 | val_dataset = val_data.Dataset(filenames, 204 | FLAGS.num_views, 205 | FLAGS.height, 206 | FLAGS.width, 207 | FLAGS.val_batch_size) # val_batch_size 208 | val_iterator = val_dataset.dataset.make_initializable_iterator() 209 | val_next_batch = val_iterator.get_next() 210 | 211 | sess_config = tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)) 212 | with tf.compat.v1.Session(config=sess_config) as sess: 213 | sess.run(tf.compat.v1.global_variables_initializer()) 214 | 215 | # Add the summaries. These contain the summaries 216 | # created by model and either optimize() or _gather_loss(). 217 | summaries |= set(tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SUMMARIES)) 218 | 219 | # Merge all summaries together. 
220 | summary_op = tf.compat.v1.summary.merge(list(summaries)) 221 | train_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir, graph) 222 | validation_writer = tf.compat.v1.summary.FileWriter(FLAGS.summaries_dir + '/validation', graph) 223 | 224 | # Create a saver object which will save all the variables 225 | saver = tf.compat.v1.train.Saver(keep_checkpoint_every_n_hours=1.0) 226 | if FLAGS.pre_trained_checkpoint: 227 | train_utils.restore_fn(FLAGS) 228 | 229 | if FLAGS.saved_checkpoint_dir: 230 | if tf.gfile.IsDirectory(FLAGS.saved_checkpoint_dir): 231 | checkpoint_path = tf.train.latest_checkpoint(FLAGS.saved_checkpoint_dir) 232 | else: 233 | checkpoint_path = FLAGS.saved_checkpoint_dir 234 | saver.restore(sess, checkpoint_path) 235 | 236 | start_epoch = 0 237 | # Get the number of training/validation steps per epoch 238 | tr_batches = int(MODELNET_TRAIN_DATA_SIZE / FLAGS.batch_size) 239 | if MODELNET_TRAIN_DATA_SIZE % FLAGS.batch_size > 0: 240 | tr_batches += 1 241 | val_batches = int(MODELNET_VALIDATE_DATA_SIZE / FLAGS.val_batch_size) 242 | if MODELNET_VALIDATE_DATA_SIZE % FLAGS.val_batch_size > 0: 243 | val_batches += 1 244 | 245 | # The filenames argument to the TFRecordDataset initializer can either be a string, 246 | # a list of strings, or a tf.Tensor of strings. 247 | training_filenames = os.path.join(FLAGS.dataset_dir, 'modelnet5_6view_train.record') 248 | validate_filenames = os.path.join(FLAGS.dataset_dir, 'modelnet5_6view_test.record') 249 | 250 | ################################### 251 | # Training loop. 252 | ################################### 253 | for num_epoch in range(start_epoch, FLAGS.how_many_training_epochs): 254 | print("-------------------------------------") 255 | print(" Epoch {} ".format(num_epoch)) 256 | print("-------------------------------------") 257 | 258 | sess.run(iterator.initializer, feed_dict={filenames: training_filenames}) 259 | for step in range(tr_batches): 260 | # Pull the image batch we'll use for training. 261 | train_batch_xs, train_batch_ys = sess.run(next_batch) 262 | 263 | # Sets up a graph with feeds and fetches for partial run. 264 | handle = sess.partial_run_setup([view_scores, learning_rate, 265 | # summary_op, top1_acc, loss, optimize_op, dummy], 266 | summary_op, accuracy, _loss, train_op], 267 | [X, ground_truth, g_scheme, g_weight, 268 | is_training, dropout_keep_prob]) 269 | 270 | _view_scores = sess.partial_run(handle, 271 | [view_scores], 272 | feed_dict={ 273 | X: train_batch_xs, 274 | is_training: True, 275 | dropout_keep_prob: 0.8} 276 | ) 277 | _g_schemes = model.group_scheme(_view_scores, FLAGS.num_group, FLAGS.num_views) 278 | _g_weights = model.group_weight(_g_schemes) 279 | 280 | # Run the graph with this batch of training data. 
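# (The partial_run below continues on the same handle, so the per-view CNN
# features computed in the first phase above are not recomputed; only the
# grouping scheme and weights derived in Python are fed back in to finish
# the forward pass and apply the gradient update.)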
281 | lr, train_summary, train_accuracy, train_loss, _ = \ 282 | sess.partial_run(handle, 283 | [learning_rate, summary_op, accuracy, _loss, train_op], 284 | feed_dict={ 285 | ground_truth: train_batch_ys, 286 | g_scheme: _g_schemes, 287 | g_weight: _g_weights} 288 | ) 289 | 290 | # for verification 291 | # lr, train_summary, train_accuracy, train_loss, _ = \ 292 | # sess.run([learning_rate, summary_op, accuracy, _loss, train_op], 293 | # feed_dict={ 294 | # X: train_batch_xs, 295 | # ground_truth: train_batch_ys, 296 | # is_training: True, 297 | # dropout_keep_prob: 0.8} 298 | # ) 299 | 300 | train_writer.add_summary(train_summary, num_epoch) 301 | tf.compat.v1.logging.info('Epoch #%d, Step #%d, rate %.6f, top1_acc %.3f%%, loss %.5f' % 302 | (num_epoch, step, lr, train_accuracy, train_loss)) 303 | 304 | 305 | ################################################### 306 | # Validate the model on the validation set 307 | ################################################### 308 | tf.compat.v1.logging.info('--------------------------') 309 | tf.compat.v1.logging.info(' Start validation ') 310 | tf.compat.v1.logging.info('--------------------------') 311 | 312 | total_val_losses = 0.0 313 | total_val_top1_acc = 0.0 314 | val_count = 0 315 | total_conf_matrix = None 316 | 317 | # Reinitialize val_iterator with the validation dataset 318 | sess.run(val_iterator.initializer, feed_dict={filenames: validate_filenames}) 319 | for step in range(val_batches): 320 | validation_batch_xs, validation_batch_ys = sess.run(val_next_batch) 321 | 322 | # Sets up a graph with feeds and fetches for partial run. 323 | handle = sess.partial_run_setup([view_scores, summary_op, 324 | accuracy, _loss, confusion_matrix], 325 | [X, g_scheme, g_weight, 326 | ground_truth, is_training, dropout_keep_prob]) 327 | 328 | _view_scores = sess.partial_run(handle, 329 | [view_scores], 330 | feed_dict={ 331 | X: validation_batch_xs, 332 | is_training: False, 333 | dropout_keep_prob: 1.0} 334 | ) 335 | _g_schemes = model.group_scheme(_view_scores, FLAGS.num_group, FLAGS.num_views) 336 | _g_weights = model.group_weight(_g_schemes) 337 | 338 | # Run the graph with this batch of training data. 339 | val_summary, val_accuracy, val_loss, conf_matrix = \ 340 | sess.partial_run(handle, 341 | [summary_op, accuracy, _loss, confusion_matrix], 342 | feed_dict={ 343 | ground_truth: validation_batch_ys, 344 | g_scheme: _g_schemes, 345 | g_weight: _g_weights} 346 | ) 347 | 348 | # for verification 349 | # val_summary, val_accuracy, val_loss, conf_matrix = \ 350 | # sess.run([summary_op, accuracy, _loss, confusion_matrix], 351 | # feed_dict={ 352 | # X: validation_batch_xs, 353 | # ground_truth: validation_batch_ys, 354 | # is_training: False, 355 | # dropout_keep_prob: 1.0} 356 | # ) 357 | 358 | validation_writer.add_summary(val_summary, num_epoch) 359 | 360 | total_val_losses += val_loss 361 | total_val_top1_acc += val_accuracy 362 | val_count += 1 363 | if total_conf_matrix is None: 364 | total_conf_matrix = conf_matrix 365 | else: 366 | total_conf_matrix += conf_matrix 367 | 368 | total_val_losses /= val_count 369 | total_val_top1_acc /= val_count 370 | 371 | tf.compat.v1.logging.info('Confusion Matrix:\n %s' % total_conf_matrix) 372 | tf.compat.v1.logging.info('Validation loss = %.5f' % total_val_losses) 373 | tf.compat.v1.logging.info('Validation accuracy = %.3f%% (N=%d)' % 374 | (total_val_top1_acc, MODELNET_VALIDATE_DATA_SIZE)) 375 | 376 | # Save the model checkpoint periodically. 
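# (One checkpoint named FLAGS.ckpt_name_to_save is written under
# FLAGS.train_logdir at the end of every epoch, with the epoch number used
# as the global-step suffix.)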
377 | if (num_epoch <= FLAGS.how_many_training_epochs-1): 378 | checkpoint_path = os.path.join(FLAGS.train_logdir, FLAGS.ckpt_name_to_save) 379 | tf.compat.v1.logging.info('Saving to "%s-%d"', checkpoint_path, num_epoch) 380 | saver.save(sess, checkpoint_path, global_step=num_epoch) 381 | 382 | 383 | if __name__ == '__main__': 384 | tf.compat.v1.logging.info('Creating train logdir: %s', FLAGS.train_logdir) 385 | tf.io.gfile.makedirs(FLAGS.train_logdir) 386 | 387 | tf.compat.v1.app.run() 388 | -------------------------------------------------------------------------------- /train_data.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import random 4 | 5 | 6 | MEAN=[0.485, 0.456, 0.406] 7 | STD=[0.229, 0.224, 0.225] 8 | 9 | 10 | class Dataset(object): 11 | """ 12 | Wrapper class around the new Tensorflows dataset pipeline. 13 | 14 | Handles loading, partitioning, and preparing training data. 15 | """ 16 | 17 | def __init__(self, tfrecord_path, num_views, height, width, batch_size=1): 18 | self.num_views = num_views 19 | self.resize_h = height 20 | self.resize_w = width 21 | 22 | self.dataset = tf.data.TFRecordDataset(tfrecord_path, 23 | compression_type='GZIP', 24 | num_parallel_reads=batch_size * 4) 25 | 26 | # self.dataset = self.dataset.map(self._parse_func, num_parallel_calls=8) 27 | # The map transformation takes a function and applies it to every element 28 | # of the dataset. 29 | self.dataset = self.dataset.map(self.decode, num_parallel_calls=8) 30 | self.dataset = self.dataset.map(self.augment, num_parallel_calls=8) 31 | self.dataset = self.dataset.map(self.normalize, num_parallel_calls=8) 32 | 33 | # Prefetches a batch at a time to smooth out the time taken to load input 34 | # files for shuffling and processing. 35 | self.dataset = self.dataset.prefetch(buffer_size=batch_size) 36 | # The shuffle transformation uses a finite-sized buffer to shuffle elements 37 | # in memory. The parameter is the number of elements in the buffer. For 38 | # completely uniform shuffling, set the parameter to be the same as the 39 | # number of elements in the dataset. 40 | self.dataset = self.dataset.shuffle(1000 + 3 * batch_size) 41 | self.dataset = self.dataset.repeat() 42 | self.dataset = self.dataset.batch(batch_size) 43 | 44 | 45 | def decode(self, serialized_example): 46 | """Parses an image and label from the given `serialized_example`.""" 47 | features = tf.io.parse_single_example( 48 | serialized_example, 49 | # Defaults are not specified since both keys are required. 50 | features={ 51 | # 'image/filename': tf.io.FixedLenFeature([], tf.string), 52 | 'image/encoded': tf.io.FixedLenFeature([self.num_views], tf.string), 53 | 'image/label': tf.io.FixedLenFeature([], tf.int64), 54 | }) 55 | 56 | images = [] 57 | # labels = [] 58 | img_lst = tf.unstack(features['image/encoded']) 59 | # lbl_lst = tf.unstack(features['image/label']) 60 | for i, img in enumerate(img_lst): 61 | # Convert from a scalar string tensor to a float32 tensor with shape 62 | image_decoded = tf.image.decode_png(img, channels=3) 63 | image = tf.image.resize(image_decoded, [self.resize_h, self.resize_w]) 64 | images.append(image) 65 | # labels.append(lbl_lst[i]) 66 | 67 | # Convert label from a scalar uint8 tensor to an int32 scalar. 
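# (In this implementation the label is simply read back as the int64 scalar
# parsed from the TFRecord; no explicit cast is applied.)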
68 | label = features['image/label'] 69 | 70 | return images, label 71 | 72 | def augment(self, images, label): 73 | """Placeholder for data augmentation.""" 74 | # OPTIONAL: Could reshape into a 28x28 image and apply distortions 75 | # here. Since we are not applying any distortions in this 76 | # example, and the next step expects the image to be flattened 77 | # into a vector, we don't bother. 78 | img_lst = [] 79 | img_tensor_lst = tf.unstack(images) 80 | for i, image in enumerate(img_tensor_lst): 81 | image = tf.image.random_flip_left_right(image) 82 | image = tf.image.random_flip_up_down(image) 83 | # image = tf.image.rot90(image, k=random.randint(0, 4)) 84 | image = tf.image.random_brightness(image, max_delta=1.1) 85 | # image = tf.image.random_contrast(image, lower=0.1, upper=1.1) 86 | # image = tf.image.random_hue(image, max_delta=0.04) 87 | # image = tf.image.random_saturation(image, lower=0.1, upper=1.1) 88 | # image = tf.image.resize(image, [self.resize_h, self.resize_w]) 89 | 90 | img_lst.append(image) 91 | 92 | return img_lst, label 93 | 94 | 95 | def normalize(self, images, label): 96 | # input[channel] = (input[channel] - mean[channel]) / std[channel] 97 | img_lst = [] 98 | img_tensor_lst = tf.unstack(images) 99 | for i, image in enumerate(img_tensor_lst): 100 | # image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 101 | img_lst.append(tf.cast(image, tf.float32) * (1. / 255) - 0.5) 102 | # img_lst.append(tf.div(tf.subtract(image, MEAN), STD)) 103 | 104 | return img_lst, label 105 | -------------------------------------------------------------------------------- /unit_test.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | 4 | 5 | def main(unused_argv): 6 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 7 | 8 | with tf.Graph().as_default() as graph: 9 | # x = tf.constant([[8, 1, 220, 55], [3, 4, 3, -1]]) 10 | # x_max = tf.reduce_max(x, axis=0) 11 | # 12 | # y = tf.constant([[9, 9, 300, 8], [5, 5, -3, 0]]) 13 | # x_y = tf.stack([x,y]) 14 | # 15 | # x_y_max = tf.reduce_max(x_y, axis=0) 16 | # _max = tf.math.maximum(x, y) 17 | 18 | final_view_descriptors = tf.constant([[8, 1, 220, 55], [3, 4, 3, -1], [54, 1, 6, -53], [-3, -4, 35, -1], [0, 34, 0, -23]]) 19 | group_scheme = tf.constant([[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 1, 1], [0, 0, 0, 0, 0]]) 20 | 21 | group_descriptors = {} 22 | dummy = tf.zeros_like(final_view_descriptors[0]) 23 | 24 | scheme_list = tf.unstack(group_scheme) 25 | indices = [tf.squeeze(tf.where(elem), axis=1) for elem in scheme_list] 26 | for i, ind in enumerate(indices): 27 | pooled_view = tf.cond(tf.greater(tf.size(ind), 0), 28 | lambda: tf.gather(final_view_descriptors, ind), 29 | lambda: tf.expand_dims(dummy, 0)) 30 | 31 | group_descriptors[i] = tf.reduce_mean(pooled_view, axis=0) 32 | 33 | sess_config = tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)) 34 | with tf.compat.v1.Session(config=sess_config) as sess: 35 | sess.run(tf.compat.v1.global_variables_initializer()) 36 | 37 | # print(sess.run(x_max)) 38 | # print(sess.run(x_y_max)) 39 | # print(sess.run(_max)) 40 | 41 | print(sess.run(group_descriptors)) 42 | 43 | 44 | 45 | if __name__ == '__main__': 46 | tf.compat.v1.app.run() -------------------------------------------------------------------------------- /utils/_train_helper.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | import 
tensorflow as tf 4 | 5 | 6 | def allreduce_grads(all_grads, average=True): 7 | """ 8 | REFERENCE : https://github.com/ppwwyyxx/tensorpack/blob/83e4e187af5765792408e7b7163efd4744d63628/tensorpack/graph_builder/utils.py 9 | All-reduce average the gradients among K devices. Results are broadcasted to all devices. 10 | Args: 11 | all_grads (K x N): List of list of gradients. N is the number of variables. 12 | average (bool): average gradients or not. 13 | Returns: 14 | K x N: same as input, but each grad is replaced by the average over K devices. 15 | """ 16 | # from tensorflow.contrib import nccl 17 | from tensorflow.python.ops import nccl_ops 18 | nr_tower = len(all_grads) 19 | if nr_tower == 1: 20 | return all_grads 21 | new_all_grads = [] # N x K 22 | for grads in zip(*all_grads): 23 | summed = nccl_ops.all_sum(grads) 24 | 25 | grads_for_devices = [] # K 26 | for g in summed: 27 | with tf.device(g.device): 28 | # tensorflow/benchmarks didn't average gradients 29 | if average: 30 | g = tf.multiply(g, 1.0 / nr_tower, name='allreduce_avg') 31 | grads_for_devices.append(g) 32 | new_all_grads.append(grads_for_devices) 33 | 34 | # transpose to K x N 35 | ret = list(zip(*new_all_grads)) 36 | return ret 37 | 38 | 39 | def split_grad_list(grad_list): 40 | """ 41 | Args: 42 | grad_list: K x N x 2 43 | Returns: 44 | K x N: gradients 45 | K x N: variables 46 | """ 47 | g = [] 48 | v = [] 49 | for tower in grad_list: 50 | g.append([x[0] for x in tower]) 51 | v.append([x[1] for x in tower]) 52 | return g, v 53 | 54 | 55 | def merge_grad_list(all_grads, all_vars): 56 | """ 57 | Args: 58 | all_grads (K x N): gradients 59 | all_vars(K x N): variables 60 | Return: 61 | K x N x 2: list of list of (grad, var) pairs 62 | """ 63 | return [list(zip(gs, vs)) for gs, vs in zip(all_grads, all_vars)] 64 | 65 | 66 | def get_post_init_ops(): 67 | """ 68 | Copy values of variables on GPU 0 to other GPUs. 
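    Variables are matched by name: for every variable under a 'towerK/' scope
    (K > 0), an assign op copies the value of the corresponding 'tower0/'
    variable, so all towers start from identical parameter values.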
69 | """ 70 | # literally all variables, because it's better to sync optimizer-internal variables as well 71 | all_vars = tf.compat.v1.global_variables() + tf.compat.v1.local_variables() 72 | var_by_name = dict([(v.name, v) for v in all_vars]) 73 | post_init_ops = [] 74 | for v in all_vars: 75 | if not v.name.startswith('tower'): 76 | continue 77 | if v.name.startswith('tower0'): 78 | # no need for copy to tower0 79 | continue 80 | # in this trainer, the master name doesn't have the towerx/ prefix 81 | split_name = v.name.split('/') 82 | prefix = split_name[0] 83 | realname = '/'.join(split_name[1:]) 84 | if prefix in realname: 85 | # logger.warning("variable {} has its prefix {} appears multiple times in its name!".format(v.name, prefix)) 86 | pass 87 | copy_from = var_by_name.get(v.name.replace(prefix, 'tower0')) 88 | if copy_from is not None: 89 | post_init_ops.append(v.assign(copy_from.read_value())) 90 | else: 91 | print("Cannot find {} in the graph!".format(realname)) 92 | 93 | print("'sync_variables_from_main_tower' includes {} operations.".format(len(post_init_ops))) 94 | return tf.group(*post_init_ops, name='sync_variables_from_main_tower') 95 | -------------------------------------------------------------------------------- /utils/downsize_modelnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import shutil 7 | 8 | import tensorflow as tf 9 | 10 | 11 | flags = tf.app.flags 12 | flags.DEFINE_string('dataset_dir', 13 | '/home/ace19/dl_data/modelnet12/view/classes', 14 | 'Root Directory to modelnet12 dataset.') 15 | # flags.DEFINE_string('output_path', 16 | # '/home/ace19/dl_data/modelnet10_sv', 17 | # 'Path to output') 18 | flags.DEFINE_string('dataset_category', 19 | 'test', 20 | 'dataset category, train|validate|test') 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | 25 | def main(_): 26 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 27 | 28 | cls_lst = os.listdir(FLAGS.dataset_dir) 29 | for i, label in enumerate(cls_lst): 30 | category_path = os.path.join(FLAGS.dataset_dir, label, FLAGS.dataset_category) 31 | off_lst = os.listdir(category_path) 32 | for off_file in off_lst: 33 | off_path = os.path.join(category_path, off_file) 34 | img_lst = os.listdir(off_path) 35 | for img in img_lst: 36 | img_path = os.path.join(off_path, img) 37 | if os.path.isfile(img_path): 38 | if not int(img.split('.')[1]) % 2 == 0: 39 | os.remove(img_path) 40 | 41 | if __name__ == '__main__': 42 | tf.compat.v1.app.run() -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Utility functions for training.""" 16 | 17 | import tensorflow as tf 18 | 19 | slim = tf.contrib.slim 20 | 21 | LOGITS_SCOPE_NAME = 'logits' 22 | MERGED_LOGITS_SCOPE = 'merged_logits' 23 | IMAGE_POOLING_SCOPE = 'image_pooling' 24 | ASPP_SCOPE = 'aspp' 25 | CONCAT_PROJECTION_SCOPE = 'concat_projection' 26 | DECODER_SCOPE = 'decoder' 27 | META_ARCHITECTURE_SCOPE = 'meta_architecture' 28 | 29 | 30 | def get_model_gradient_multipliers(last_layers, last_layer_gradient_multiplier): 31 | """Gets the gradient multipliers. 32 | 33 | The gradient multipliers will adjust the learning rates for model 34 | variables. For the task of semantic segmentation, the models are 35 | usually fine-tuned from the models trained on the task of image 36 | classification. To fine-tune the models, we usually set larger (e.g., 37 | 10 times larger) learning rate for the parameters of last layer. 38 | 39 | Args: 40 | last_layers: Scopes of last layers. 41 | last_layer_gradient_multiplier: The gradient multiplier for last layers. 42 | 43 | Returns: 44 | The gradient multiplier map with variables as key, and multipliers as value. 45 | """ 46 | gradient_multipliers = {} 47 | 48 | for var in slim.get_model_variables(): 49 | # Double the learning rate for biases. 50 | if 'biases' in var.op.name: 51 | gradient_multipliers[var.op.name] = 2. 52 | 53 | # Use larger learning rate for last layer variables. 54 | for layer in last_layers: 55 | if layer in var.op.name and 'biases' in var.op.name: 56 | gradient_multipliers[var.op.name] = 2 * last_layer_gradient_multiplier 57 | break 58 | elif layer in var.op.name: 59 | gradient_multipliers[var.op.name] = last_layer_gradient_multiplier 60 | break 61 | 62 | return gradient_multipliers 63 | 64 | 65 | def get_model_learning_rate( 66 | learning_policy, base_learning_rate, learning_rate_decay_step, 67 | learning_rate_decay_factor, training_number_of_steps, learning_power, 68 | slow_start_step, slow_start_learning_rate): 69 | """Gets model's learning rate. 70 | 71 | Computes the model's learning rate for different learning policy. 72 | Right now, only "step" and "poly" are supported. 73 | (1) The learning policy for "step" is computed as follows: 74 | current_learning_rate = base_learning_rate * 75 | learning_rate_decay_factor ^ (global_step / learning_rate_decay_step) 76 | See tf.compat.v1.train.exponential_decay for details. 77 | (2) The learning policy for "poly" is computed as follows: 78 | current_learning_rate = base_learning_rate * 79 | (1 - global_step / training_number_of_steps) ^ learning_power 80 | 81 | Args: 82 | learning_policy: Learning rate policy for training. 83 | base_learning_rate: The base learning rate for model training. 84 | learning_rate_decay_step: Decay the base learning rate at a fixed step. 85 | learning_rate_decay_factor: The rate to decay the base learning rate. 86 | training_number_of_steps: Number of steps for training. 87 | learning_power: Power used for 'poly' learning policy. 88 | slow_start_step: Training model with small learning rate for the first 89 | few steps. 90 | slow_start_learning_rate: The learning rate employed during slow start. 91 | 92 | Returns: 93 | Learning rate for the specified learning policy. 94 | 95 | Raises: 96 | ValueError: If learning policy is not recognized. 
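    For example, with the defaults used in train.py (base_learning_rate=0.001,
    learning_power=0.9, training_number_of_steps=300000), the 'poly' policy
    yields 0.001 * (1 - 150000 / 300000) ** 0.9 ~= 5.4e-4 at the halfway point
    and decays to 0 at step 300000.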
97 | """ 98 | global_step = tf.compat.v1.train.get_or_create_global_step() 99 | if learning_policy == 'step': 100 | learning_rate = tf.compat.v1.train.exponential_decay( 101 | base_learning_rate, 102 | global_step, 103 | learning_rate_decay_step, 104 | learning_rate_decay_factor, 105 | staircase=True) 106 | elif learning_policy == 'poly': 107 | learning_rate = tf.compat.v1.train.polynomial_decay( 108 | base_learning_rate, 109 | global_step, 110 | training_number_of_steps, 111 | end_learning_rate=0, 112 | power=learning_power) 113 | else: 114 | raise ValueError('Unknown learning policy.') 115 | 116 | # Employ small learning rate at the first few steps for warm start. 117 | return tf.where(global_step < slow_start_step, slow_start_learning_rate, 118 | learning_rate) 119 | 120 | 121 | def _gather_loss(regularization_losses, scope): 122 | """Gather the loss. 123 | 124 | Args: 125 | regularization_losses: Possibly empty list of regularization_losses 126 | to add to the losses. 127 | 128 | Returns: 129 | A tensor for the total loss. Can be None. 130 | """ 131 | 132 | # The return value. 133 | sum_loss = None 134 | # Individual components of the loss that will need summaries. 135 | loss = None 136 | regularization_loss = None 137 | 138 | # Compute and aggregate losses on the clone device. 139 | all_losses = [] 140 | losses = tf.compat.v1.get_collection(tf.GraphKeys.LOSSES, scope) 141 | if losses: 142 | loss = tf.add_n(losses, name='losses') 143 | # if num_clones > 1: 144 | # clone_loss = tf.div(clone_loss, 1.0 * num_clones, 145 | # name='scaled_clone_loss') 146 | all_losses.append(loss) 147 | if regularization_losses: 148 | regularization_loss = tf.add_n(regularization_losses, 149 | name='regularization_loss') 150 | all_losses.append(regularization_loss) 151 | if all_losses: 152 | sum_loss = tf.add_n(all_losses) 153 | 154 | # Add the summaries out of the clone device block. 155 | if loss is not None: 156 | tf.compat.v1.summary.scalar('/'.join(filter(None, 157 | ['Losses', 'loss'])), 158 | loss) 159 | if regularization_loss is not None: 160 | tf.compat.v1.summary.scalar('Losses/regularization_loss', regularization_loss) 161 | return sum_loss 162 | 163 | 164 | def _optimize(optimizer, regularization_losses, scope, **kwargs): 165 | """Compute losses and gradients. 166 | 167 | Args: 168 | optimizer: A tf.Optimizer object. 169 | regularization_losses: Possibly empty list of regularization_losses 170 | to add to the losses. 171 | **kwargs: Dict of kwarg to pass to compute_gradients(). 172 | 173 | Returns: 174 | A tuple (loss, grads_and_vars). 175 | - loss: A tensor for the total loss. Can be None. 176 | - grads_and_vars: List of (gradient, variable). Can be empty. 177 | """ 178 | sum_loss = _gather_loss(regularization_losses, scope) 179 | grad = None 180 | if sum_loss is not None: 181 | grad = optimizer.compute_gradients(sum_loss, **kwargs) 182 | return sum_loss, grad 183 | 184 | 185 | def _gradients(grad): 186 | """Calculate the sum gradient for each shared variable across all clones. 187 | 188 | This function assumes that the grad has been scaled appropriately by 189 | 1 / num_clones. 190 | 191 | Args: 192 | grad: A List of List of tuples (gradient, variable) 193 | 194 | Returns: 195 | tuples of (gradient, variable) 196 | """ 197 | sum_grads = [] 198 | for grad_and_vars in zip(*grad): 199 | # Note that each grad_and_vars looks like the following: 200 | # ((grad_var0_clone0, var0), ... 
(grad_varN_cloneN, varN)) 201 | grads = [] 202 | var = grad_and_vars[0][1] 203 | for g, v in grad_and_vars: 204 | assert v == var 205 | if g is not None: 206 | grads.append(g) 207 | if grads: 208 | if len(grads) > 1: 209 | sum_grad = tf.add_n(grads, name=var.op.name + '/sum_grads') 210 | else: 211 | sum_grad = grads[0] 212 | sum_grads.append((sum_grad, var)) 213 | 214 | return sum_grads 215 | 216 | 217 | def optimize(optimizer, scope=None, regularization_losses=None, **kwargs): 218 | """Compute losses and gradients 219 | 220 | # Note: The regularization_losses are added to losses. 221 | 222 | Args: 223 | optimizer: An `Optimizer` object. 224 | regularization_losses: Optional list of regularization losses. If None it 225 | will gather them from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]` to 226 | exclude them. 227 | **kwargs: Optional list of keyword arguments to pass to `compute_gradients`. 228 | 229 | Returns: 230 | A tuple (total_loss, grads_and_vars). 231 | - total_loss: A Tensor containing the average of the losses including 232 | the regularization loss. 233 | - grads_and_vars: A List of tuples (gradient, variable) containing the sum 234 | of the gradients for each variable. 235 | 236 | """ 237 | 238 | grads_and_vars = [] 239 | losses = [] 240 | if regularization_losses is None: 241 | regularization_losses = tf.compat.v1.get_collection( 242 | tf.GraphKeys.REGULARIZATION_LOSSES, scope) 243 | # with tf.name_scope(scope): 244 | loss, grad = _optimize(optimizer, 245 | regularization_losses, 246 | scope, 247 | **kwargs) 248 | if loss is not None: 249 | losses.append(loss) 250 | grads_and_vars.append(grad) 251 | # Only use regularization_losses for the first clone 252 | regularization_losses = None 253 | 254 | # Compute the total_loss summing all the losses. 255 | total_loss = tf.add_n(losses, name='total_loss') 256 | # Sum the gradients across clones. 257 | grads_and_vars = _gradients(grads_and_vars) 258 | 259 | return total_loss, grads_and_vars 260 | 261 | 262 | def get_extra_layer_scopes(last_layers_contain_logits_only=False): 263 | """Gets the scopes for extra layers. 264 | 265 | Args: 266 | last_layers_contain_logits_only: Boolean, True if only consider logits as 267 | the last layer (i.e., exclude ASPP module, decoder module and so on) 268 | 269 | Returns: 270 | A list of scopes for extra layers. 271 | """ 272 | if last_layers_contain_logits_only: 273 | return [LOGITS_SCOPE_NAME] 274 | else: 275 | return [ 276 | LOGITS_SCOPE_NAME, 277 | IMAGE_POOLING_SCOPE, 278 | ASPP_SCOPE, 279 | CONCAT_PROJECTION_SCOPE, 280 | DECODER_SCOPE, 281 | META_ARCHITECTURE_SCOPE, 282 | ] 283 | 284 | 285 | def edit_trainable_variables(removed): 286 | # gets a reference to the list containing the trainable variables 287 | trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES) 288 | variables_to_remove = [] 289 | for var in trainable_collection: 290 | # uses the attribute 'name' of the variable 291 | if var.name.startswith(removed): 292 | variables_to_remove.append(var) 293 | for rem in variables_to_remove: 294 | trainable_collection.remove(rem) 295 | 296 | 297 | def add_variables_summaries(learning_rate): 298 | summaries = [] 299 | for variable in slim.get_model_variables(): 300 | summaries.append(tf.summary.histogram(variable.op.name, variable)) 301 | summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate)) 302 | return summaries 303 | 304 | 305 | def restore_fn(flags): 306 | """Returns a function runs by the chief worker to warm-start the training. 
307 | Note that the init_fn is only runs when initializing the model during the very 308 | first global step. 309 | 310 | """ 311 | # if flags.tf_initial_checkpoint is None: 312 | # return None 313 | 314 | # Warn the user if a checkpoint exists in the train_dir. Then ignore. 315 | # if tf.compat.v1.train.latest_checkpoint(flags.train_dir): 316 | # tf.logging.info( 317 | # 'Ignoring --checkpoint_path because a checkpoint already exists in %s' 318 | # % flags.train_dir) 319 | # return None 320 | 321 | exclusions = [] 322 | if flags.checkpoint_exclude_scopes: 323 | exclusions = [scope.strip() 324 | for scope in flags.checkpoint_exclude_scopes.split(',')] 325 | 326 | variables_to_restore = [] 327 | for var in slim.get_model_variables(): 328 | excluded = False 329 | for exclusion in exclusions: 330 | if var.op.name.startswith(exclusion): 331 | excluded = True 332 | break 333 | if not excluded: 334 | variables_to_restore.append(var) 335 | # Change model scope if necessary. 336 | if flags.checkpoint_model_scope is not None: 337 | variables_to_restore = \ 338 | {var.op.name.replace(flags.model_name, 339 | flags.checkpoint_model_scope): var 340 | for var in variables_to_restore} 341 | 342 | tf.logging.info('++++++++++++++++++++') 343 | tf.logging.info('Fine-tuning from %s. Ignoring missing vars: %s' % 344 | (flags.pre_trained_checkpoint, flags.ignore_missing_vars)) 345 | slim.assign_from_checkpoint_fn(flags.pre_trained_checkpoint, 346 | variables_to_restore, 347 | ignore_missing_vars=flags.ignore_missing_vars) 348 | 349 | 350 | def get_variables_to_train(flags): 351 | """Returns a list of variables to train. 352 | 353 | Returns: 354 | A list of variables to train by the optimizer. 355 | """ 356 | if flags.trainable_scopes is None: 357 | # print(tf.trainable_variables()) 358 | return tf.trainable_variables() 359 | else: 360 | scopes = [scope.strip() for scope in flags.trainable_scopes.split(',')] 361 | 362 | variables_to_train = [] 363 | for scope in scopes: 364 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope) 365 | variables_to_train.extend(variables) 366 | return variables_to_train 367 | -------------------------------------------------------------------------------- /val_data.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | MEAN=[0.485, 0.456, 0.406] 5 | STD=[0.229, 0.224, 0.225] 6 | 7 | 8 | class Dataset(object): 9 | """ 10 | Wrapper class around the new Tensorflows dataset pipeline. 11 | 12 | Handles loading, partitioning, and preparing training data. 13 | """ 14 | 15 | def __init__(self, tfrecord_path, num_views, height, width, batch_size=1): 16 | self.num_views = num_views 17 | self.resize_h = height 18 | self.resize_w = width 19 | 20 | self.dataset = tf.data.TFRecordDataset(tfrecord_path, 21 | compression_type='GZIP', 22 | num_parallel_reads=batch_size * 4) 23 | 24 | # self.dataset = self.dataset.map(self._parse_func, num_parallel_calls=8) 25 | # The map transformation takes a function and applies it to every element 26 | # of the dataset. 27 | self.dataset = self.dataset.map(self.decode, num_parallel_calls=8) 28 | # self.dataset = self.dataset.map(self.augment, num_parallel_calls=8) 29 | self.dataset = self.dataset.map(self.normalize, num_parallel_calls=8) 30 | 31 | # Prefetches a batch at a time to smooth out the time taken to load input 32 | # files for shuffling and processing. 
33 | self.dataset = self.dataset.prefetch(buffer_size=batch_size) 34 | # The shuffle transformation uses a finite-sized buffer to shuffle elements 35 | # in memory. The parameter is the number of elements in the buffer. For 36 | # completely uniform shuffling, set the parameter to be the same as the 37 | # number of elements in the dataset. 38 | # self.dataset = self.dataset.shuffle(1000 + 3 * batch_size) 39 | self.dataset = self.dataset.repeat() 40 | self.dataset = self.dataset.batch(batch_size) 41 | 42 | 43 | def decode(self, serialized_example): 44 | """Parses an image and label from the given `serialized_example`.""" 45 | features = tf.io.parse_single_example( 46 | serialized_example, 47 | # Defaults are not specified since both keys are required. 48 | features={ 49 | # 'image/filename': tf.io.FixedLenFeature([], tf.string), 50 | 'image/encoded': tf.io.FixedLenFeature([self.num_views], tf.string), 51 | 'image/label': tf.io.FixedLenFeature([], tf.int64), 52 | }) 53 | 54 | images = [] 55 | # labels = [] 56 | img_lst = tf.unstack(features['image/encoded']) 57 | # lbl_lst = tf.unstack(features['image/label']) 58 | for i, img in enumerate(img_lst): 59 | # Convert from a scalar string tensor to a float32 tensor with shape 60 | image_decoded = tf.image.decode_png(img, channels=3) 61 | image = tf.image.resize(image_decoded, [self.resize_h, self.resize_w]) 62 | images.append(image) 63 | # labels.append(lbl_lst[i]) 64 | 65 | # Convert label from a scalar uint8 tensor to an int32 scalar. 66 | label = features['image/label'] 67 | 68 | return images, label 69 | 70 | def augment(self, images, label): 71 | """Placeholder for data augmentation.""" 72 | # OPTIONAL: Could reshape into a 28x28 image and apply distortions 73 | # here. Since we are not applying any distortions in this 74 | # example, and the next step expects the image to be flattened 75 | # into a vector, we don't bother. 76 | # img_lst = [] 77 | # img_tensor_lst = tf.unstack(images) 78 | # for i, image in enumerate(img_tensor_lst): 79 | # image = tf.image.random_flip_left_right(image) 80 | # image = tf.image.rot90(image, k=random.randint(0, 4)) 81 | # image = tf.image.random_brightness(image, max_delta=1.1) 82 | # image = tf.image.random_contrast(image, lower=0.9, upper=1.1) 83 | # # image = tf.image.random_hue(image, max_delta=0.04) 84 | # image = tf.image.random_saturation(image, lower=0.9, upper=1.1) 85 | # # image = tf.image.resize(image, [self.resize_h, self.resize_w]) 86 | # 87 | # img_lst.append(image) 88 | # 89 | # return img_lst, label 90 | return images, label 91 | 92 | 93 | def normalize(self, images, label): 94 | # input[channel] = (input[channel] - mean[channel]) / std[channel] 95 | img_lst = [] 96 | img_tensor_lst = tf.unstack(images) 97 | for i, image in enumerate(img_tensor_lst): 98 | # image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 99 | img_lst.append(tf.cast(image, tf.float32) * (1. / 255) - 0.5) 100 | # img_lst.append(tf.div(tf.subtract(image, MEAN), STD)) 101 | 102 | return img_lst, label 103 | --------------------------------------------------------------------------------