├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── backbone.py ├── cifar_demo.py ├── config.py ├── evaluate.py ├── faiss_utils.py ├── layers.py ├── lbl-IvS Demo.ipynb ├── log └── events.out.tfevents.1547478727.localhost.localdomain ├── pic ├── acc.jpg ├── loss.jpg ├── lr.jpg ├── output_21_0.png └── output_8_0.png ├── pipeline.py ├── reader.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # keras-lbl-IvS 2 | 3 | 论文地址:[Large-scale Bisample Learning on ID vs. Spot Face Recognition](https://arxiv.org/pdf/1806.03018.pdf) 4 | 5 | 工程地址:[keras-lbl-IvS](https://github.com/yizt/keras-lbl-IvS) 6 | 7 | ​ keras实现人证比对论文《Large-scale Bisample Learning on ID vs. Spot Face Recognition》的核心思想; 当类别数非常大时(超过百万),GPU显存可能无法装载权重参数;可以使用支配原型概念,每个step将相关的类别一块训练,step开始时将当前mini-batch的类别原型权重加载到GPU中,而不用加载所有的类别权重参数; step结束时更新并保存当前训练类别的原型权重。 8 | 9 | ​ 由于目前没有公开证件数据集;这里使用cifar10和cifar100数据集做简单的模拟测试;先说结果,直接使用cifar10数据集训练,使用imprint生成初始的cifar100数据集的类别权重;在初始的权重中测试精度为20.8%,可以说是一个不错的初始化状态;每次使用30个类别做支配原型训练,训练的精度约为80%;使用训练后的原型权重再整体测试cifar100类,最终精度为57.4%; 效果仍然不太理想。一方面支配原型训练优化应该更加困难;另一方面也没有测试更多的超参数选择(支配训练类别数、支配队列大小、候选队列大小、margin、scale、batch-size、lr等等)。 10 | 11 | ​ 但以上测试结果仍然能够说明此方法是可行的;至少在类别数特别大时(GPU无法装下),使用支配原型权重概念,在训练时动态加载当前训练类别的权重是一个不错的方法。 12 | 13 | ​ 欢迎关注、反馈问题和沟通交流... ... 
14 | 15 | 下图为支配原型训练的acc,loss和lr信息: 16 | 17 | 18 | 19 | ​ 以下是测试过程,目录如下: 20 | 21 | [TOC] 22 | 23 | ## 依赖环境 24 | 25 | * python 3.6 26 | * tensorflow-gpu 1.9.0 27 | * keras 2.2.0 28 | * faiss 1.4.0 29 | * h5py 2.8.0 30 | 31 | 32 | 33 | ## 初始化环境 34 | 35 | 36 | ```python 37 | import tensorflow as tf 38 | import numpy as np 39 | import h5py 40 | import keras 41 | from keras.layers import Input 42 | from keras.models import Model 43 | from keras.optimizers import SGD 44 | from keras.datasets import cifar10,cifar100 45 | import keras.backend as K 46 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, TensorBoard 47 | import sys 48 | import os 49 | import time 50 | import argparse 51 | 52 | if __package__ is None: 53 | sys.path 54 | sys.path.insert(0, '/opt/github/py_data_mining/face') 55 | __package__ = "keras-lbl-IvS" 56 | 57 | from .config import Config 58 | from .utils import random_select, get_weights, update_weights, update_queue 59 | from .faiss_utils import get_index, update_multi 60 | ``` 61 | 62 | /root/anaconda3/envs/keras/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88 63 | return f(*args, **kwds) 64 | /root/anaconda3/envs/keras/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88 65 | return f(*args, **kwds) 66 | Using TensorFlow backend. 
67 | 68 | 69 | ### 配置信息 70 | 71 | 72 | ```python 73 | class CifarConfig(Config): 74 | # 输入信息 75 | input_shape = (32, 32, 3) 76 | num_classes = 100 # cifar100 77 | train_num_classes = 30 # 支配原型训练的类别数 78 | 79 | # 训练参数 80 | batch_size = 256 81 | 82 | # 原型权重保存 83 | prototype_weights_hdf5 = '/tmp/cifar_prototype_weights.hdf5' 84 | prototype_weights_dataset = 'cifar_prototype_weights_set' 85 | pw_h5_file = None # hdf5 File文件 86 | 87 | index = None # 保存原型类别的faiss索引 88 | 89 | dominant_queue = None # 支配队列 90 | candidate_queue = None # 候选队列 91 | 92 | dominant_queue_num = 20 # 支配队列大小 93 | candidate_queue_num = 50 # 候选队列大小 94 | 95 | config = CifarConfig() 96 | ``` 97 | 98 | 99 | ```python 100 | def set_gpu_growth(): 101 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 102 | cfg = tf.ConfigProto() 103 | cfg.gpu_options.allow_growth = True 104 | session = tf.Session(config=cfg) 105 | K.set_session(session) 106 | 107 | set_gpu_growth() 108 | ``` 109 | 110 | 111 | ```python 112 | import matplotlib.pyplot as plt 113 | def img_show(X, rows, cols): 114 | plt.figure(figsize=(cols * 1.5, rows)) 115 | for i, img in enumerate(X[:rows * cols]): 116 | plt.subplot(rows, cols, 1 + i) 117 | plt.imshow(img) 118 | plt.show() 119 | ``` 120 | 121 | ## 源数据集训练 122 | 123 | 使用cifar10 作为源数据集 124 | 125 | ### cifar10 数据集加载、可视化 126 | 127 | 128 | ```python 129 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 130 | ``` 131 | 132 | 133 | ```python 134 | img_show(x_train,2,10) 135 | ``` 136 | 137 | 138 | ![png](pic/output_8_0.png) 139 | 140 | 141 | ### 骨干网络定义 142 | 143 | 骨干网络,提取特征 144 | 145 | 146 | ```python 147 | # 骨干网络,提取特征 148 | import keras.layers as layers 149 | def cifar_base_net(inputs): 150 | """ 151 | 骨干网络、用于提取特征 152 | :param inputs: keras Input对象 153 | :return: 154 | """ 155 | x = layers.Conv2D(32, (3, 3), padding='same', name='conv_1')(inputs) 156 | x = layers.BatchNormalization(axis=3, name='bn_conv1')(x) 157 | x = layers.Activation('relu')(x) 158 | x = layers.Conv2D(32, (3, 3), 
name='conv_2')(x) 159 | x = layers.Activation('relu')(x) 160 | x = layers.MaxPooling2D(pool_size=(2, 2))(x) 161 | x = layers.Dropout(0.25)(x) 162 | 163 | x = layers.Conv2D(64, (3, 3), padding='same', name='conv_3')(x) 164 | x = layers.Activation('relu')(x) 165 | x = layers.Conv2D(64, (3, 3), name='conv_4')(x) 166 | x = layers.Activation('relu')(x) 167 | x = layers.MaxPooling2D(pool_size=(2, 2))(x) 168 | x = layers.Dropout(0.25)(x) 169 | 170 | x = layers.Conv2D(512, (3, 3), name='conv_5')(x) 171 | x = layers.Activation('relu')(x) 172 | x = layers.GlobalAvgPool2D()(x) 173 | 174 | return x 175 | ``` 176 | 177 | ### AMSoftmax分类层定义 178 | 179 | 关于AMSoftmax的详细信息见论文:[Additive Margin Softmax for Face Verification](https://arxiv.org/abs/1801.05599) 180 | 181 | 182 | ```python 183 | from keras.layers import Layer 184 | class DenseWithAMSoftmaxLoss(Layer): 185 | 186 | def __init__(self, num_classes, m=0.35, scale=30, **kwargs): 187 | """ 188 | 根据Additive Margin Softmaxfor Face Verification论文设置m和scale 189 | :param num_classes: 类别数 190 | :param m: margin间隔 191 | :param scale: 尺度 192 | :param kwargs: 193 | """ 194 | self.output_dim = num_classes 195 | self.m = m 196 | self.scale = scale 197 | super(DenseWithAMSoftmaxLoss, self).__init__(**kwargs) 198 | 199 | def build(self, input_shape): 200 | # 添加可训练参数 201 | self.kernel = self.add_weight(name='kernel', 202 | shape=(input_shape[1], self.output_dim), 203 | initializer='glorot_normal', 204 | trainable=True) 205 | 206 | def call(self, inputs): 207 | self.inputs = inputs 208 | # L2归一化 209 | self.w_norm = K.tf.nn.l2_normalize(self.kernel, 0, 1e-10) 210 | self.x_norm = K.tf.nn.l2_normalize(self.inputs, 1, 1e-10) 211 | self.logits = K.dot(self.x_norm, self.w_norm) 212 | return self.logits 213 | 214 | def compute_output_shape(self, input_shape): 215 | return (input_shape[0], # batch_size 216 | self.output_dim) 217 | 218 | def loss_dense(self, y_true, y_pred): 219 | """ 220 | AmSoftmax Loss 221 | :param y_true: 真实类别 [N,num_classes] 222 | 
:param y_pred: 预测类别[N,num_classes] 223 | :return: 224 | """ 225 | y_true = K.cast(y_true, 'float32') 226 | thelta_m = y_pred - y_true * self.m # cosine(thelta)-m ; y_true 就相当于mask 227 | 228 | return K.categorical_crossentropy(y_true, self.scale * thelta_m, from_logits=True) 229 | 230 | ``` 231 | 232 | ### 构建模型 233 | 234 | 235 | ```python 236 | from keras.models import Model 237 | from keras.layers import Dense 238 | from keras import Input 239 | 240 | def source_model(img_inputs,num_class): 241 | features=cifar_base_net(img_inputs) 242 | dense=DenseWithAMSoftmaxLoss(num_class,m=0.2,scale=24.,name='amsoftmax') # 使用自定义的AMSoftMax分类层 243 | output=dense(features) 244 | m = Model(img_inputs,output) 245 | m.compile(optimizer=keras.optimizers.Adam(lr=0.001,decay=1e-6), 246 | loss=dense.loss_dense, 247 | metrics=['accuracy']) 248 | return m 249 | 250 | img_input=Input(shape=(32,32,3)) 251 | m = source_model(img_input,10) # cifar10 一共10类 252 | m.summary() 253 | ``` 254 | 255 | WARNING:tensorflow:From /root/anaconda3/envs/keras/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3216: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version. 256 | Instructions for updating: 257 | 258 | Future major versions of TensorFlow will allow gradients to flow 259 | into the labels input on backprop by default. 260 | 261 | See @{tf.nn.softmax_cross_entropy_with_logits_v2}. 
262 | 263 | _________________________________________________________________ 264 | Layer (type) Output Shape Param # 265 | ================================================================= 266 | input_1 (InputLayer) (None, 32, 32, 3) 0 267 | _________________________________________________________________ 268 | conv_1 (Conv2D) (None, 32, 32, 32) 896 269 | _________________________________________________________________ 270 | bn_conv1 (BatchNormalization (None, 32, 32, 32) 128 271 | _________________________________________________________________ 272 | activation_1 (Activation) (None, 32, 32, 32) 0 273 | _________________________________________________________________ 274 | conv_2 (Conv2D) (None, 30, 30, 32) 9248 275 | _________________________________________________________________ 276 | activation_2 (Activation) (None, 30, 30, 32) 0 277 | _________________________________________________________________ 278 | max_pooling2d_1 (MaxPooling2 (None, 15, 15, 32) 0 279 | _________________________________________________________________ 280 | dropout_1 (Dropout) (None, 15, 15, 32) 0 281 | _________________________________________________________________ 282 | conv_3 (Conv2D) (None, 15, 15, 64) 18496 283 | _________________________________________________________________ 284 | activation_3 (Activation) (None, 15, 15, 64) 0 285 | _________________________________________________________________ 286 | conv_4 (Conv2D) (None, 13, 13, 64) 36928 287 | _________________________________________________________________ 288 | activation_4 (Activation) (None, 13, 13, 64) 0 289 | _________________________________________________________________ 290 | max_pooling2d_2 (MaxPooling2 (None, 6, 6, 64) 0 291 | _________________________________________________________________ 292 | dropout_2 (Dropout) (None, 6, 6, 64) 0 293 | _________________________________________________________________ 294 | conv_5 (Conv2D) (None, 4, 4, 512) 295424 295 | 
_________________________________________________________________ 296 | activation_5 (Activation) (None, 4, 4, 512) 0 297 | _________________________________________________________________ 298 | global_average_pooling2d_1 ( (None, 512) 0 299 | _________________________________________________________________ 300 | amsoftmax (DenseWithAMSoftma (None, 10) 5120 301 | ================================================================= 302 | Total params: 366,240 303 | Trainable params: 366,176 304 | Non-trainable params: 64 305 | _________________________________________________________________ 306 | 307 | 308 | ### 训练模型 309 | 310 | 311 | ```python 312 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 313 | m.fit(x_train, keras.utils.to_categorical(y_train,10), 314 | batch_size=64, 315 | epochs=10, 316 | validation_data=(x_test, keras.utils.to_categorical(y_test,10)), 317 | shuffle=True) 318 | ``` 319 | 320 | Train on 50000 samples, validate on 10000 samples 321 | Epoch 1/10 322 | 50000/50000 [==============================] - 8s 153us/step - loss: 5.6485 - acc: 0.4251 - val_loss: 4.8240 - val_acc: 0.5428 323 | Epoch 2/10 324 | 50000/50000 [==============================] - 6s 129us/step - loss: 4.1562 - acc: 0.6204 - val_loss: 3.5912 - val_acc: 0.6834 325 | Epoch 3/10 326 | 50000/50000 [==============================] - 7s 136us/step - loss: 3.4796 - acc: 0.6930 - val_loss: 3.1658 - val_acc: 0.7277 327 | Epoch 4/10 328 | 50000/50000 [==============================] - 7s 137us/step - loss: 3.0975 - acc: 0.7322 - val_loss: 2.9639 - val_acc: 0.7396 329 | Epoch 5/10 330 | 50000/50000 [==============================] - 7s 135us/step - loss: 2.8483 - acc: 0.7540 - val_loss: 3.1497 - val_acc: 0.7251 331 | Epoch 6/10 332 | 50000/50000 [==============================] - 7s 131us/step - loss: 2.6558 - acc: 0.7757 - val_loss: 3.0899 - val_acc: 0.7297 333 | Epoch 7/10 334 | 50000/50000 [==============================] - 6s 128us/step - loss: 2.5169 - acc: 0.7885 - 
val_loss: 2.6846 - val_acc: 0.7680 335 | Epoch 8/10 336 | 50000/50000 [==============================] - 7s 136us/step - loss: 2.4056 - acc: 0.7990 - val_loss: 2.8013 - val_acc: 0.7623 337 | Epoch 9/10 338 | 50000/50000 [==============================] - 7s 136us/step - loss: 2.3017 - acc: 0.8082 - val_loss: 2.4712 - val_acc: 0.7900 339 | Epoch 10/10 340 | 50000/50000 [==============================] - 7s 131us/step - loss: 2.2125 - acc: 0.8167 - val_loss: 2.4762 - val_acc: 0.7867 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | ```python 352 | # 保存模型 353 | m.save('/tmp/source.weights.h5') 354 | ``` 355 | 356 | ## 初始评估 357 | 358 | 使用在源数据集cifar10上训练的模型,来评估在cifar100上的精度;评估方式是通过imprint构建初始的分类权重; 359 | imprint详见[Low-Shot Learning with Imprinted Weights](https://arxiv.org/abs/1712.07136)论文,相关代码见[keras-imprinting](https://github.com/yizt/keras-imprinting). 360 | 361 | 362 | ### cifar100 数据加载、可视化 363 | 364 | 365 | ```python 366 | (x_train, y_train), (x_test, y_test) = cifar100.load_data() 367 | ``` 368 | 369 | 370 | ```python 371 | img_show(x_train,2,10) 372 | ``` 373 | 374 | 375 | ![png](pic/output_21_0.png) 376 | 377 | 378 | ### 初始原型权重 379 | 380 | 使用imprint方式将cifar100的类别原型权重保存到hdf5中(虽然这里只有100类,但是也按照类别上百万的方式来实现) 381 | 382 | 383 | ```python 384 | # 类别分组 385 | def label_id_map(labels, num_class): 386 | """ 387 | 将图像按照类别分组 388 | """ 389 | # 初始化 390 | label_id_dict = dict() 391 | for i in range(num_class): 392 | label_id_dict[i] = [] 393 | 394 | # 逐个图像归类 395 | for img_id in range(len(labels)): 396 | label_id_dict[labels[img_id]].append(img_id) 397 | return label_id_dict 398 | label_id_dict = label_id_map(y_train[:,0], config.num_classes) # y_train (50000,1) 399 | ``` 400 | 401 | 402 | ```python 403 | # 初始化原型权重 404 | def init_prototype(pretrained_weights, X, label_id_dict, num_class): 405 | inputs = Input(batch_shape=(config.batch_size,) + config.input_shape) 406 | features = cifar_base_net(inputs) 407 | model = Model(inputs, features) 408 | 
model.load_weights(pretrained_weights, by_name=True) 409 | # 原型权重一份放到hdf5,一份存放到faiss中(faiss中保留的不是精准的) 410 | if os.path.exists(config.prototype_weights_hdf5): 411 | os.remove(config.prototype_weights_hdf5) 412 | f_label = h5py.File(config.prototype_weights_hdf5, 'w') 413 | label_feature = f_label.create_dataset(config.prototype_weights_dataset, 414 | shape=(num_class, 512), dtype='f') 415 | 416 | # 逐个类别处理 417 | for label in range(num_class): 418 | # 获取某个label的所有图像,并使用模型预测图像的特征,最后求均值作为label的原型权重 419 | image_ids = label_id_dict[label] # 图像id 420 | images = X[image_ids] # 图像数据 421 | features = model.predict(np.asarray(images)) # 输出特征 422 | features = keras.utils.np_utils.normalize(features) # 归一化 423 | features = np.mean(features, axis=0) # 求均值 424 | features = keras.utils.np_utils.normalize(features) # 再次归一化; 是二维的 425 | 426 | # 赋值给hdf5 427 | label_feature[label] = features[0] # (1,d) 转为 (d,) 428 | # 每1w次,刷写到磁盘 429 | if label % 1000 == 0: 430 | f_label.flush() 431 | print("{} init_prototype 完成:{}".format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), 432 | label)) 433 | f_label.flush() 434 | # 最后关闭文件 435 | f_label.close() 436 | ``` 437 | 438 | 439 | ```python 440 | init_prototype('/tmp/source.weights.h5',x_train, 441 | label_id_dict, 442 | config.num_classes) 443 | ``` 444 | 445 | 2019-01-14 23:11:57 init_prototype 完成:0 446 | 447 | 448 | ### 开始评估 449 | 450 | 451 | ```python 452 | # 定义评估函数 453 | from keras.layers import Dense 454 | def evaluate(pretrained_weights, X, y, num_class, use_amsoftmax=False): 455 | if use_amsoftmax: 456 | dense = DenseWithAMSoftmaxLoss(num_class, m=0.2, scale=24.,name='other') # name='other' 457 | else: 458 | dense = layers.Dense(num_class, use_bias=False, activation='softmax') 459 | # 构建模型 460 | inputs = Input(shape=config.input_shape) 461 | features = cifar_base_net(inputs) 462 | outputs = dense(features) 463 | m = Model(inputs, outputs) 464 | 465 | m.load_weights(pretrained_weights, by_name=True) 466 | 467 | # 设置分类层的权重 468 | f_label 
= h5py.File(config.prototype_weights_hdf5, 'r') 469 | weights = f_label[config.prototype_weights_dataset] # HDF5 dataset object 470 | m.layers[-1].set_weights([np.transpose(weights[:])]) # 使用weights[:]转为numpy 471 | # 最后关闭文件 472 | f_label.close() 473 | #m.summary() 474 | 475 | # 评估预测 476 | y_pred = m.predict(X) 477 | y_pred = np.argmax(y_pred, axis=-1) 478 | 479 | print("acc:{}".format(np.mean(np.equal(y,y_pred)))) 480 | ``` 481 | 482 | 483 | ```python 484 | evaluate('/tmp/source.weights.h5', x_test,y_test[:,0],config.num_classes,True) 485 | ``` 486 | 487 | acc:0.2081 488 | 489 | 490 | 可以看到直接使用imprint方式初始化类别权重参数,在cifar100的精度超过20%;是随机精度1%的20倍; 491 | 492 | ## 支配原型训练cifar100 493 | 494 | ### 获取原型 495 | 496 | 获取保存在hdf5中的初始原型权重;并将所有的权重加载到faiss的索引中,用于后续构造支配队列和候选队列 497 | 498 | 499 | ```python 500 | def get_prototype(deal_batch=1000): 501 | f_label = h5py.File(config.prototype_weights_hdf5, 'r+') 502 | dset = f_label[config.prototype_weights_dataset] 503 | length = len(dset) 504 | index = get_index(512) 505 | # 逐个类别处理,更新faiss index 506 | for batch_no in range(length // deal_batch): 507 | start = batch_no * deal_batch 508 | end = (batch_no + 1) * deal_batch 509 | features = dset[start:end] 510 | update_multi(index, features, np.arange(start, end)) 511 | # 处理不能整除的情况 512 | if not length % deal_batch == 0: 513 | start = length - length % deal_batch 514 | end = length 515 | features = dset[start:end] 516 | update_multi(index, features, np.arange(start, end)) 517 | return f_label, index 518 | ``` 519 | 520 | 521 | ```python 522 | config.pw_h5_file, config.index = get_prototype(1000) 523 | ``` 524 | 525 | ### 初始化队列 526 | 527 | 根据原型权重构造支配队列和候选队列; 528 | 529 | 530 | ```python 531 | def init_queue(index, weights_set, num_class, dq_num, cq_num): 532 | """ 533 | 初始化候选队列和支配队列 534 | :param index: 535 | :param weights_set: h5py dataset对象 536 | :param num_class: 537 | :param dq_num: 538 | :param cq_num: 539 | :return: 540 | """ 541 | data, candidate_label_idx = 
index.search(weights_set[:num_class], cq_num + 1) 542 | candidate_label_idx = candidate_label_idx[:, 1:] # 543 | dominant_label_idx = candidate_label_idx[:, :dq_num] # 候选队列包含支配队列 544 | 545 | # 转为字典类型 546 | dominant_queue = dict(enumerate(dominant_label_idx)) 547 | candidate_queue = dict(enumerate(candidate_label_idx)) 548 | 549 | return dominant_queue, candidate_queue 550 | ``` 551 | 552 | 553 | ```python 554 | config.dominant_queue, config.candidate_queue = init_queue(config.index, 555 | config.pw_h5_file[config.prototype_weights_dataset], 556 | config.num_classes, 557 | config.dominant_queue_num, 558 | config.candidate_queue_num) 559 | ``` 560 | 561 | ### 支配原型分类层定义 562 | 563 | 关键点: 564 | a) 将需要输出的tensor保存在权重参数中 565 | b) 使用K.tf.control_dependencies保证动态加载原型权重参数和输出预测结果到不训练的权重中的操作会加入到计算图中 566 | 567 | 568 | ```python 569 | class DenseWithDPSoftmaxLoss(Layer): 570 | def __init__(self, num_class, bacth_size, m=0.35, scale=30, **kwargs): 571 | self.output_dim = num_class 572 | self.m = m 573 | self.scale = scale 574 | self.batch_size = bacth_size 575 | super(DenseWithDPSoftmaxLoss, self).__init__(**kwargs) 576 | 577 | def build(self, input_shape): 578 | self.kernel = self.add_weight(name='kernel', 579 | shape=(input_shape[0][1], self.output_dim), # (嵌入维度,num_class) 580 | dtype=K.floatx(), 581 | initializer='glorot_normal', 582 | constraint=keras.constraints.unit_norm(), 583 | trainable=True) 584 | self.current_selected_labels = self.add_weight(name='labels', 585 | shape=(self.batch_size, 1), 586 | initializer='glorot_normal', 587 | trainable=False) 588 | self.y_pred = self.add_weight(name='pred', 589 | shape=(self.batch_size, self.output_dim), 590 | initializer='glorot_normal', 591 | trainable=False) 592 | 593 | def call(self, inputs, **kwargs): 594 | # 将当前step类别的权重,赋值给tensor 595 | weights_assign_op = K.tf.assign(self.kernel, 596 | K.transpose(inputs[1][:self.output_dim]), 597 | name='assign_weights') 598 | label_assign_op = K.tf.assign(self.current_selected_labels, 
inputs[2], name='assign_labels') 599 | with K.tf.control_dependencies([weights_assign_op, label_assign_op]): 600 | # L2归一化 601 | self.w_norm = K.tf.nn.l2_normalize(self.kernel, 0, 1e-10) 602 | self.x_norm = K.tf.nn.l2_normalize(inputs[0], 1, 1e-10) 603 | self.logits = K.dot(self.x_norm, self.w_norm) 604 | return self.logits 605 | 606 | def compute_output_shape(self, input_shape): 607 | return (input_shape[0][0], 608 | self.output_dim) 609 | # return [(input_shape[0][0], self.output_dim), 610 | # (self.output_dim, input_shape[0][1])] 611 | 612 | def loss(self, y_true, y_pred): 613 | """ 614 | DPAmSoftmax Loss 615 | :param y_true: 真实类别 [N,num_classes] 616 | :param y_pred: 预测类别[N,num_classes] 617 | :return: 618 | """ 619 | # 首先将预测值保持到权重中 620 | pred_assign_op = K.tf.assign(self.y_pred, 621 | y_pred, 622 | name='assign_pred') 623 | with K.tf.control_dependencies([pred_assign_op]): 624 | y_true = K.cast(y_true, 'float32') 625 | thelta_m = y_pred - y_true * self.m # cosine(thelta)-m ; y_true 就相当于mask 626 | losses = K.categorical_crossentropy(y_true, self.scale * thelta_m, from_logits=True) 627 | 628 | return losses 629 | ``` 630 | 631 | ### 构建模型 632 | 633 | 634 | ```python 635 | def target_model(pretrained_weights, batch_size,train_num_class, input_shape): 636 | """ 637 | 目标数据集训练模型 638 | :param pretrained_weights: 源数据集上训练的特征提取层的权重 639 | :param batch_size: 640 | :param input_shape: 641 | :return: 642 | """ 643 | # 构建模型 644 | inputs = Input(batch_shape=(batch_size,) + input_shape) 645 | weights_inputs = Input(batch_shape=(train_num_class, 512)) 646 | label_inputs = Input(batch_shape=(batch_size, 1)) # 至少是二维的, 647 | features = cifar_base_net(inputs) 648 | dense = DenseWithDPSoftmaxLoss(train_num_class,batch_size,m=0.2,scale=24.) 
649 | outputs = dense([features, weights_inputs, label_inputs]) 650 | m = Model([inputs, weights_inputs, label_inputs], outputs) 651 | 652 | m.load_weights(pretrained_weights, by_name=True) 653 | # 固定参数 654 | #for l in m.layers[:-1]: 655 | # l.trainable = False 656 | m.compile(loss=dense.loss, 657 | optimizer=SGD(lr=0.2, momentum=0, decay=1e-4), 658 | metrics=['accuracy']) 659 | 660 | m.summary() 661 | return m 662 | ``` 663 | 664 | 665 | ```python 666 | m=target_model('/tmp/source.weights.h5', 667 | config.batch_size,config.train_num_classes, config.input_shape) 668 | ``` 669 | 670 | __________________________________________________________________________________________________ 671 | Layer (type) Output Shape Param # Connected to 672 | ================================================================================================== 673 | input_4 (InputLayer) (64, 32, 32, 3) 0 674 | __________________________________________________________________________________________________ 675 | conv_1 (Conv2D) (64, 32, 32, 32) 896 input_4[0][0] 676 | __________________________________________________________________________________________________ 677 | bn_conv1 (BatchNormalization) (64, 32, 32, 32) 128 conv_1[0][0] 678 | __________________________________________________________________________________________________ 679 | activation_16 (Activation) (64, 32, 32, 32) 0 bn_conv1[0][0] 680 | __________________________________________________________________________________________________ 681 | conv_2 (Conv2D) (64, 30, 30, 32) 9248 activation_16[0][0] 682 | __________________________________________________________________________________________________ 683 | activation_17 (Activation) (64, 30, 30, 32) 0 conv_2[0][0] 684 | __________________________________________________________________________________________________ 685 | max_pooling2d_7 (MaxPooling2D) (64, 15, 15, 32) 0 activation_17[0][0] 686 | 
__________________________________________________________________________________________________ 687 | dropout_7 (Dropout) (64, 15, 15, 32) 0 max_pooling2d_7[0][0] 688 | __________________________________________________________________________________________________ 689 | conv_3 (Conv2D) (64, 15, 15, 64) 18496 dropout_7[0][0] 690 | __________________________________________________________________________________________________ 691 | activation_18 (Activation) (64, 15, 15, 64) 0 conv_3[0][0] 692 | __________________________________________________________________________________________________ 693 | conv_4 (Conv2D) (64, 13, 13, 64) 36928 activation_18[0][0] 694 | __________________________________________________________________________________________________ 695 | activation_19 (Activation) (64, 13, 13, 64) 0 conv_4[0][0] 696 | __________________________________________________________________________________________________ 697 | max_pooling2d_8 (MaxPooling2D) (64, 6, 6, 64) 0 activation_19[0][0] 698 | __________________________________________________________________________________________________ 699 | dropout_8 (Dropout) (64, 6, 6, 64) 0 max_pooling2d_8[0][0] 700 | __________________________________________________________________________________________________ 701 | conv_5 (Conv2D) (64, 4, 4, 512) 295424 dropout_8[0][0] 702 | __________________________________________________________________________________________________ 703 | activation_20 (Activation) (64, 4, 4, 512) 0 conv_5[0][0] 704 | __________________________________________________________________________________________________ 705 | global_average_pooling2d_4 (Glo (64, 512) 0 activation_20[0][0] 706 | __________________________________________________________________________________________________ 707 | input_5 (InputLayer) (30, 512) 0 708 | __________________________________________________________________________________________________ 709 | input_6 (InputLayer) (64, 1) 0 710 | 
def generator(X, label_id_dict, dominant_queue, num_class, train_num_class, batch_size):
    """
    Endless generator of training mini-batches.

    Each step samples ``train_num_class // 3`` classes at random, fills up to
    ``train_num_class`` classes from the dominant queues of those classes,
    and then samples ``batch_size`` images from the chosen classes (every
    chosen class contributes at least one image).

    :param X: images, indexable by an array of image ids
    :param label_id_dict: dict mapping class id to the image ids of that class
    :param dominant_queue: dict mapping class id to its dominant class ids
    :param num_class: total number of classes (100 in this demo)
    :param train_num_class: number of classes loaded per training step
    :param batch_size: mini-batch size
    :return: yields ``([images, prototype weights, real labels], one-hot targets)``
    :raises ValueError: if ``train_num_class`` exceeds ``batch_size`` or there
        are not enough classes to sample from
    """
    if train_num_class > batch_size:
        # Every selected class gets one image and the weights are padded up to
        # batch_size rows, so a smaller batch cannot work.
        raise ValueError("train_num_class ({}) must not exceed batch_size ({})".format(
            train_num_class, batch_size))

    while True:
        # Two-level sampling: first draw train_num_class // 3 classes without
        # replacement, then draw dominant classes of those.
        sample_labels = np.random.choice(num_class, train_num_class // 3, replace=False)
        selected_labels = set(sample_labels)

        # Collect the dominant classes of all selected classes
        all_current_dominant_labels = set()
        for label in selected_labels:
            all_current_dominant_labels.update(dominant_queue[label])
        all_current_dominant_labels -= selected_labels

        needed = train_num_class - len(selected_labels)
        dominant_pool = list(all_current_dominant_labels)
        if len(dominant_pool) >= needed:
            selected_dominant_labels = np.random.choice(dominant_pool, needed, replace=False)
        else:
            # Too few distinct dominant classes (short or heavily overlapping
            # queues): keep all of them and top up with random other classes
            # instead of crashing in np.random.choice.
            taken = selected_labels | all_current_dominant_labels
            spare = [c for c in range(num_class) if c not in taken]
            extra = np.random.choice(spare, needed - len(dominant_pool), replace=False)
            selected_dominant_labels = dominant_pool + list(extra)

        # Sorted list of every class trained in this step
        all_selected_labels = sorted(selected_labels | set(selected_dominant_labels))

        # Sample images: one per class first ...
        selected_image_labels = [[random_select(label_id_dict[label]), label]
                                 for label in all_selected_labels]
        # ... then random (class, image) pairs until the mini-batch is full
        while len(selected_image_labels) < batch_size:
            label = random_select(all_selected_labels)
            selected_image_labels.append([random_select(label_id_dict[label]), label])

        selected_image_labels = np.asarray(selected_image_labels)
        # Copy of the real class ids before they are remapped below
        real_selected_labels = np.array(selected_image_labels[:, 1])

        # Map real class ids to their position in the sorted class list;
        # every id is present in the list, so searchsorted is an exact lookup.
        current_selected_labels = np.asarray(all_selected_labels)
        selected_image_labels[:, 1] = np.searchsorted(current_selected_labels,
                                                      real_selected_labels)

        # Prototype weights of the classes trained in this step
        current_weights = get_weights(config.pw_h5_file[config.prototype_weights_dataset],
                                      current_selected_labels)
        # Zero-pad the weights up to batch_size rows
        current_weights = np.pad(current_weights,
                                 ((0, batch_size - train_num_class), (0, 0)),
                                 'constant')

        # Load the images
        images = X[selected_image_labels[:, 0]]
        # Real labels of the mini-batch as a column vector
        real_selected_labels = np.expand_dims(real_selected_labels, axis=1)
        # Targets are always in the range 0 .. train_num_class - 1
        yield [images,
               current_weights,
               real_selected_labels], keras.utils.to_categorical(selected_image_labels[:, 1],
                                                                 train_num_class)


class ExportWeights(keras.callbacks.Callback):
    """
    After every training step, export the trained prototype weights to HDF5
    and update the dominant queues from the predictions.
    """

    def on_batch_end(self, batch, logs=None):
        # Assumes the last layer of the model is the DP softmax layer
        layer = self.model.layers[-1]
        # Weights after this step, the labels fed in this step, the predictions
        trained_weights, current_trained_labels, y_pred = layer.get_weights()[:3]
        trained_weights = np.transpose(trained_weights)
        # L2-normalise every row
        trained_weights = trained_weights / np.linalg.norm(trained_weights, ord=2, axis=1, keepdims=True)

        # 2-D -> 1-D; the builtin int replaces the np.int alias that was
        # removed in NumPy 1.24.
        current_trained_labels = np.asarray(current_trained_labels[:, 0], dtype=int)
        # Update the prototype weights
        update_weights(config.pw_h5_file,
                       config.pw_h5_file[config.prototype_weights_dataset],
                       trained_weights,
                       np.unique(current_trained_labels))
        # Update the dominant queues according to the predictions
        update_queue(config.dominant_queue,
                     config.candidate_queue,
                     y_pred,
                     current_trained_labels)


def get_call_back():
    """
    Build the training callbacks.

    :return: list ``[ReduceLROnPlateau, ExportWeights, TensorBoard]``
    """
    # Reduce the learning rate when the monitored 'acc' metric stops improving
    lr_reducer = ReduceLROnPlateau(monitor='acc',
                                   factor=np.sqrt(0.1),
                                   cooldown=0,
                                   patience=5,
                                   min_lr=0.5e-6)
    log = TensorBoard(log_dir='log')
    export_weights = ExportWeights()

    return [lr_reducer, export_weights, log]


print("开始训练模型... ...")
gen = generator(x_train,
                label_id_dict,
                config.dominant_queue,
                config.num_classes,
                config.train_num_classes,  # number of classes per training step
                config.batch_size)
evaluate('/tmp/source.weights.h5',
         x_test, y_test[:, 0], config.num_classes, True)
# Run the generator on the main thread (workers=0, no multiprocessing).
# The ExportWeights callback updates the dominant queues in this process after
# every step. A generator running in a worker process only sees a copy of
# them, and prefetched batches would carry prototype weights from before the
# latest export.
m.fit_generator(gen,
                callbacks=get_call_back(),
                steps_per_epoch=len(x_train) // config.batch_size,
                epochs=200,
                workers=0,
                use_multiprocessing=False,
                # validation_data=next(gen),
                verbose=1
                )
loss: 4.4187 - acc: 0.6185 912 | Epoch 17/200 913 | 195/195 [==============================] - 4s 21ms/step - loss: 4.3984 - acc: 0.6248 914 | Epoch 18/200 915 | 195/195 [==============================] - 4s 21ms/step - loss: 4.3396 - acc: 0.6281 916 | Epoch 19/200 917 | 195/195 [==============================] - 4s 21ms/step - loss: 4.2386 - acc: 0.6382 918 | Epoch 20/200 919 | 195/195 [==============================] - 4s 21ms/step - loss: 4.2609 - acc: 0.6366 920 | Epoch 21/200 921 | 195/195 [==============================] - 4s 21ms/step - loss: 4.2040 - acc: 0.6425 922 | Epoch 22/200 923 | 195/195 [==============================] - 4s 21ms/step - loss: 4.1078 - acc: 0.6521 924 | Epoch 23/200 925 | 195/195 [==============================] - 4s 21ms/step - loss: 4.1030 - acc: 0.6492 926 | Epoch 24/200 927 | 195/195 [==============================] - 4s 21ms/step - loss: 4.0578 - acc: 0.6584 928 | Epoch 25/200 929 | 195/195 [==============================] - 4s 21ms/step - loss: 3.9726 - acc: 0.6668 930 | Epoch 26/200 931 | 195/195 [==============================] - 4s 21ms/step - loss: 3.9568 - acc: 0.6678 932 | Epoch 27/200 933 | 195/195 [==============================] - 4s 21ms/step - loss: 3.8982 - acc: 0.6756 934 | Epoch 28/200 935 | 195/195 [==============================] - 4s 21ms/step - loss: 3.8731 - acc: 0.6754 936 | Epoch 29/200 937 | 195/195 [==============================] - 4s 21ms/step - loss: 3.8896 - acc: 0.6719 938 | Epoch 30/200 939 | 195/195 [==============================] - 4s 21ms/step - loss: 3.8002 - acc: 0.6818 940 | Epoch 31/200 941 | 195/195 [==============================] - 4s 21ms/step - loss: 3.7603 - acc: 0.6842 942 | Epoch 32/200 943 | 195/195 [==============================] - 4s 21ms/step - loss: 3.7401 - acc: 0.6894 944 | Epoch 33/200 945 | 195/195 [==============================] - 4s 21ms/step - loss: 3.7214 - acc: 0.6924 946 | Epoch 34/200 947 | 195/195 [==============================] - 4s 21ms/step - loss: 3.6610 - acc: 
0.6975 948 | Epoch 35/200 949 | 195/195 [==============================] - 4s 21ms/step - loss: 3.6877 - acc: 0.6937 950 | Epoch 36/200 951 | 195/195 [==============================] - 4s 21ms/step - loss: 3.6174 - acc: 0.7023 952 | Epoch 37/200 953 | 195/195 [==============================] - 4s 21ms/step - loss: 3.6541 - acc: 0.6981 954 | Epoch 38/200 955 | 195/195 [==============================] - 4s 21ms/step - loss: 3.5960 - acc: 0.7030 956 | Epoch 39/200 957 | 195/195 [==============================] - 4s 21ms/step - loss: 3.5589 - acc: 0.7082 958 | Epoch 40/200 959 | 195/195 [==============================] - 4s 21ms/step - loss: 3.5304 - acc: 0.7080 960 | Epoch 41/200 961 | 195/195 [==============================] - 4s 21ms/step - loss: 3.5466 - acc: 0.7083 962 | Epoch 42/200 963 | 195/195 [==============================] - 4s 21ms/step - loss: 3.4622 - acc: 0.7191 964 | Epoch 43/200 965 | 195/195 [==============================] - 4s 21ms/step - loss: 3.4591 - acc: 0.7161 966 | Epoch 44/200 967 | 195/195 [==============================] - 4s 21ms/step - loss: 3.4526 - acc: 0.7163 968 | Epoch 45/200 969 | 195/195 [==============================] - 4s 21ms/step - loss: 3.4540 - acc: 0.7162 970 | Epoch 46/200 971 | 195/195 [==============================] - 4s 21ms/step - loss: 3.4770 - acc: 0.7142 972 | Epoch 47/200 973 | 195/195 [==============================] - 4s 21ms/step - loss: 3.3777 - acc: 0.7260 974 | Epoch 48/200 975 | 195/195 [==============================] - 4s 21ms/step - loss: 3.3948 - acc: 0.7224 976 | Epoch 49/200 977 | 195/195 [==============================] - 4s 21ms/step - loss: 3.3268 - acc: 0.7286 978 | Epoch 50/200 979 | 195/195 [==============================] - 4s 21ms/step - loss: 3.3755 - acc: 0.7261 980 | Epoch 51/200 981 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2830 - acc: 0.7326 982 | Epoch 52/200 983 | 195/195 [==============================] - 4s 21ms/step - loss: 3.3017 - acc: 0.7325 984 | Epoch 
53/200 985 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2776 - acc: 0.7352 986 | Epoch 54/200 987 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2682 - acc: 0.7354 988 | Epoch 55/200 989 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2782 - acc: 0.7355 990 | Epoch 56/200 991 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2301 - acc: 0.7388 992 | Epoch 57/200 993 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2260 - acc: 0.7384 994 | Epoch 58/200 995 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2353 - acc: 0.7393 996 | Epoch 59/200 997 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1699 - acc: 0.7463 998 | Epoch 60/200 999 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1867 - acc: 0.7443 1000 | Epoch 61/200 1001 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1513 - acc: 0.7471 1002 | Epoch 62/200 1003 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1651 - acc: 0.7463 1004 | Epoch 63/200 1005 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1378 - acc: 0.7481 1006 | Epoch 64/200 1007 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1572 - acc: 0.7446 1008 | Epoch 65/200 1009 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1101 - acc: 0.7506 1010 | Epoch 66/200 1011 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1224 - acc: 0.7497 1012 | Epoch 67/200 1013 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1320 - acc: 0.7495 1014 | Epoch 68/200 1015 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0740 - acc: 0.7520 1016 | Epoch 69/200 1017 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0902 - acc: 0.7509 1018 | Epoch 70/200 1019 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0828 - acc: 0.7523 1020 | 
Epoch 71/200 1021 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0511 - acc: 0.7560 1022 | Epoch 72/200 1023 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0450 - acc: 0.7584 1024 | Epoch 73/200 1025 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0774 - acc: 0.7532 1026 | Epoch 74/200 1027 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0108 - acc: 0.7608 1028 | Epoch 75/200 1029 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0656 - acc: 0.7544 1030 | Epoch 76/200 1031 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0506 - acc: 0.7548 1032 | Epoch 77/200 1033 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0210 - acc: 0.7609 1034 | Epoch 78/200 1035 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9919 - acc: 0.7619 1036 | Epoch 79/200 1037 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0089 - acc: 0.7601 1038 | Epoch 80/200 1039 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9677 - acc: 0.7643 1040 | Epoch 81/200 1041 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9317 - acc: 0.7677 1042 | Epoch 82/200 1043 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9385 - acc: 0.7662 1044 | Epoch 83/200 1045 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9403 - acc: 0.7667 1046 | Epoch 84/200 1047 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9502 - acc: 0.7661 1048 | Epoch 85/200 1049 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9566 - acc: 0.7648 1050 | Epoch 86/200 1051 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9040 - acc: 0.7686 1052 | Epoch 87/200 1053 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9097 - acc: 0.7696 1054 | Epoch 88/200 1055 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9770 - 
acc: 0.7648 1056 | Epoch 89/200 1057 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8816 - acc: 0.7713 1058 | Epoch 90/200 1059 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9009 - acc: 0.7690 1060 | Epoch 91/200 1061 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8859 - acc: 0.7689 1062 | Epoch 92/200 1063 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9151 - acc: 0.7679 1064 | Epoch 93/200 1065 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8468 - acc: 0.7735 1066 | Epoch 94/200 1067 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8806 - acc: 0.7716 1068 | Epoch 95/200 1069 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8690 - acc: 0.7727 1070 | Epoch 96/200 1071 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8607 - acc: 0.7728 1072 | Epoch 97/200 1073 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8648 - acc: 0.7743 1074 | Epoch 98/200 1075 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8162 - acc: 0.7769 1076 | Epoch 99/200 1077 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8316 - acc: 0.7759 1078 | Epoch 100/200 1079 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8379 - acc: 0.7771 1080 | Epoch 101/200 1081 | 195/195 [==============================] - 4s 21ms/step - loss: 2.7826 - acc: 0.7825 1082 | Epoch 102/200 1083 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8175 - acc: 0.7772 1084 | Epoch 103/200 1085 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8053 - acc: 0.7807 1086 | Epoch 104/200 1087 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8092 - acc: 0.7777 1088 | Epoch 105/200 1089 | 195/195 [==============================] - 4s 21ms/step - loss: 2.7802 - acc: 0.7807 1090 | Epoch 106/200 1091 | 195/195 [==============================] - 4s 
21ms/step - loss: 2.7878 - acc: 0.7812 1092 | Epoch 107/200 1093 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6700 - acc: 0.7913 1094 | Epoch 108/200 1095 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6969 - acc: 0.7908 1096 | Epoch 109/200 1097 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6961 - acc: 0.7897 1098 | Epoch 110/200 1099 | 195/195 [==============================] - 4s 21ms/step - loss: 2.7121 - acc: 0.7865 1100 | Epoch 111/200 1101 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6620 - acc: 0.7920 1102 | Epoch 112/200 1103 | 195/195 [==============================] - 4s 21ms/step - loss: 2.7139 - acc: 0.7876 1104 | Epoch 113/200 1105 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6802 - acc: 0.7916 1106 | Epoch 114/200 1107 | 195/195 [==============================] - 4s 21ms/step - loss: 2.7220 - acc: 0.7861 1108 | Epoch 115/200 1109 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6849 - acc: 0.7881 1110 | Epoch 116/200 1111 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6786 - acc: 0.7906 1112 | Epoch 117/200 1113 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6823 - acc: 0.7901 1114 | Epoch 118/200 1115 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6086 - acc: 0.7976 1116 | Epoch 119/200 1117 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6388 - acc: 0.7937 1118 | Epoch 120/200 1119 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6250 - acc: 0.7950 1120 | Epoch 121/200 1121 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6180 - acc: 0.7965 1122 | Epoch 122/200 1123 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6573 - acc: 0.7924 1124 | Epoch 123/200 1125 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6354 - acc: 0.7954 1126 | Epoch 124/200 1127 | 195/195 
[==============================] - 4s 21ms/step - loss: 2.6108 - acc: 0.8001 1128 | Epoch 125/200 1129 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6199 - acc: 0.7964 1130 | Epoch 126/200 1131 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6237 - acc: 0.7963 1132 | Epoch 127/200 1133 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5821 - acc: 0.8002 1134 | Epoch 128/200 1135 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6052 - acc: 0.7995 1136 | Epoch 129/200 1137 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5955 - acc: 0.7979 1138 | Epoch 130/200 1139 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6088 - acc: 0.7991 1140 | Epoch 131/200 1141 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5897 - acc: 0.8001 1142 | Epoch 132/200 1143 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6078 - acc: 0.7975 1144 | Epoch 133/200 1145 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6425 - acc: 0.7951 1146 | Epoch 134/200 1147 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6338 - acc: 0.7943 1148 | Epoch 135/200 1149 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6003 - acc: 0.7984 1150 | Epoch 136/200 1151 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5735 - acc: 0.8002 1152 | Epoch 137/200 1153 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6034 - acc: 0.7975 1154 | Epoch 138/200 1155 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6422 - acc: 0.7954 1156 | Epoch 139/200 1157 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6546 - acc: 0.7932 1158 | Epoch 140/200 1159 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6455 - acc: 0.7960 1160 | Epoch 141/200 1161 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6205 - acc: 0.7969 
1162 | Epoch 142/200 1163 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6054 - acc: 0.7963 1164 | Epoch 143/200 1165 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5764 - acc: 0.8005 1166 | Epoch 144/200 1167 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5725 - acc: 0.8014 1168 | Epoch 145/200 1169 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5743 - acc: 0.8000 1170 | Epoch 146/200 1171 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6077 - acc: 0.7979 1172 | Epoch 147/200 1173 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6226 - acc: 0.7988 1174 | Epoch 148/200 1175 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5762 - acc: 0.8017 1176 | Epoch 149/200 1177 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5979 - acc: 0.7970 1178 | Epoch 150/200 1179 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6239 - acc: 0.7968 1180 | Epoch 151/200 1181 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6141 - acc: 0.7958 1182 | Epoch 152/200 1183 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5842 - acc: 0.8003 1184 | Epoch 153/200 1185 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5660 - acc: 0.8022 1186 | Epoch 154/200 1187 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5796 - acc: 0.8012 1188 | Epoch 155/200 1189 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6097 - acc: 0.7964 1190 | Epoch 156/200 1191 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6196 - acc: 0.7968 1192 | Epoch 157/200 1193 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6238 - acc: 0.7960 1194 | Epoch 158/200 1195 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5960 - acc: 0.7987 1196 | Epoch 159/200 1197 | 195/195 [==============================] - 4s 
21ms/step - loss: 2.6156 - acc: 0.7973 1198 | Epoch 160/200 1199 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6116 - acc: 0.7976 1200 | Epoch 161/200 1201 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6379 - acc: 0.7962 1202 | Epoch 162/200 1203 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6249 - acc: 0.7946 1204 | Epoch 163/200 1205 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5953 - acc: 0.7966 1206 | Epoch 164/200 1207 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6325 - acc: 0.7950 1208 | Epoch 165/200 1209 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6523 - acc: 0.7930 1210 | Epoch 166/200 1211 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5807 - acc: 0.8010 1212 | Epoch 167/200 1213 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6433 - acc: 0.7930 1214 | Epoch 168/200 1215 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5426 - acc: 0.8030 1216 | Epoch 169/200 1217 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6110 - acc: 0.7972 1218 | Epoch 170/200 1219 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6340 - acc: 0.7957 1220 | Epoch 171/200 1221 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6350 - acc: 0.7963 1222 | Epoch 172/200 1223 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5893 - acc: 0.7995 1224 | Epoch 173/200 1225 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5882 - acc: 0.7984 1226 | Epoch 174/200 1227 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6131 - acc: 0.7979 1228 | Epoch 175/200 1229 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5740 - acc: 0.8005 1230 | Epoch 176/200 1231 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6203 - acc: 0.7960 1232 | Epoch 177/200 1233 | 195/195 
[==============================] - 4s 21ms/step - loss: 2.6040 - acc: 0.7974 1234 | Epoch 178/200 1235 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6296 - acc: 0.7949 1236 | Epoch 179/200 1237 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6019 - acc: 0.7982 1238 | Epoch 180/200 1239 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6176 - acc: 0.7980 1240 | Epoch 181/200 1241 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6353 - acc: 0.7956 1242 | Epoch 182/200 1243 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6288 - acc: 0.7968 1244 | Epoch 183/200 1245 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6437 - acc: 0.7945 1246 | Epoch 184/200 1247 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6047 - acc: 0.7967 1248 | Epoch 185/200 1249 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6041 - acc: 0.7991 1250 | Epoch 186/200 1251 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6201 - acc: 0.7970 1252 | Epoch 187/200 1253 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6390 - acc: 0.7957 1254 | Epoch 188/200 1255 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6322 - acc: 0.7966 1256 | Epoch 189/200 1257 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6007 - acc: 0.7981 1258 | Epoch 190/200 1259 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5960 - acc: 0.7989 1260 | Epoch 191/200 1261 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5726 - acc: 0.8006 1262 | Epoch 192/200 1263 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5944 - acc: 0.7993 1264 | Epoch 193/200 1265 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6384 - acc: 0.7954 1266 | Epoch 194/200 1267 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6452 - acc: 0.7952 
1268 | Epoch 195/200 1269 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6178 - acc: 0.7959 1270 | Epoch 196/200 1271 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6292 - acc: 0.7968 1272 | Epoch 197/200 1273 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6361 - acc: 0.7950 1274 | Epoch 198/200 1275 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6028 - acc: 0.7972 1276 | Epoch 199/200 1277 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6190 - acc: 0.7978 1278 | Epoch 200/200 1279 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6297 - acc: 0.7947 1280 | 1281 | 1282 | 1283 | 1284 | 1285 | 1286 | ```python 1287 | # 保存模型 1288 | m.save('/tmp/target.weights.h5') 1289 | ``` 1290 | 1291 | 1292 | ```python 1293 | # 最后关闭hdf5 1294 | config.pw_h5_file.close() 1295 | ``` 1296 | 1297 | ## 再次评估 1298 | 1299 | 1300 | ```python 1301 | evaluate('/tmp/target.weights.h5', 1302 | x_test,y_test[:,0],config.num_classes,True) 1303 | ``` 1304 | 1305 | acc:0.5741 1306 | 虽然精度还不是很高(需要测试更多的超参数选择),但是可以看到从来都没使用100个类别一块训练;最终在100个类别的评估精度为0.5741 -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: __init__.py 4 | Description : 论文《Large-scale Bisample Learning on ID vs. 
def resnet50(inputs):
    """Build the ResNet-50 style backbone and return its pooled features.

    The stem is a single 3x3, stride-1 convolution (no zero padding layer),
    followed by batch normalisation on axis 3, ReLU and a 3x3 max pool with
    stride 2.  Four residual stages follow; within a stage every bottleneck
    block uses the same width for its three convolutions.

    :param inputs: input image tensor
    :return: output of the last residual stage after global average pooling
    """
    channel_axis = 3

    net = layers.Conv2D(64, (3, 3),
                        strides=(1, 1),
                        padding='same',
                        name='conv1')(inputs)
    net = layers.BatchNormalization(axis=channel_axis, name='bn_conv1')(net)
    net = layers.Activation('relu')(net)
    net = layers.MaxPooling2D((3, 3), strides=(2, 2))(net)

    # (stage, width, block ids, extra kwargs for the stage's leading conv_block).
    # Only stage 2 passes explicit strides; the other stages use the
    # conv_block default.
    stage_plan = (
        (2, 64, 'abc', {'strides': (1, 1)}),
        (3, 128, 'abcd', {}),
        (4, 256, 'abcdef', {}),
        (5, 512, 'abc', {}),
    )
    for stage, width, block_ids, head_kwargs in stage_plan:
        net = conv_block(net, 3, [width] * 3,
                         stage=stage, block=block_ids[0], **head_kwargs)
        for block_id in block_ids[1:]:
            net = identity_block(net, 3, [width] * 3,
                                 stage=stage, block=block_id)

    return layers.GlobalAveragePooling2D()(net)
"keras-lbl-IvS" 27 | 28 | from .config import Config 29 | from .utils import random_select, get_weights, update_weights, update_queue 30 | from .faiss_utils import get_index, update_multi 31 | from .reader import get_mslm_infos, load_img 32 | from .layers import DenseWithDPSoftmaxLoss 33 | 34 | 35 | class DenseWithAMSoftmaxLoss(Layer): 36 | 37 | def __init__(self, num_classes, m=0.35, scale=30, **kwargs): 38 | """ 39 | 根据AdditiveMarginSoftmaxforFaceVerification论文设置m和scale 40 | :param num_classes: 类别数 41 | :param m: margin间隔 42 | :param scale: 尺度 43 | :param kwargs: 44 | """ 45 | self.output_dim = num_classes 46 | self.m = m 47 | self.scale = scale 48 | super(DenseWithAMSoftmaxLoss, self).__init__(**kwargs) 49 | 50 | def build(self, input_shape): 51 | # 添加可训练参数 52 | self.kernel = self.add_weight(name='kernel', 53 | shape=(input_shape[1], self.output_dim), 54 | initializer='glorot_normal', 55 | trainable=True) 56 | 57 | def call(self, inputs): 58 | self.inputs = inputs 59 | # L2归一化 60 | self.w_norm = K.tf.nn.l2_normalize(self.kernel, 0, 1e-10) 61 | self.x_norm = K.tf.nn.l2_normalize(self.inputs, 1, 1e-10) 62 | self.logits = K.dot(self.x_norm, self.w_norm) 63 | return self.logits 64 | 65 | def compute_output_shape(self, input_shape): 66 | return (input_shape[0], # batch_size 67 | self.output_dim) 68 | 69 | def loss(self, y_true, y_pred): 70 | """ 71 | AmSoftmax Loss 72 | :param y_true: 真实类别 [N,num_classes] 73 | :param y_pred: 预测类别[N,num_classes] 74 | :return: 75 | """ 76 | y_true = K.cast(y_true, 'float32') 77 | thelta_m = y_pred - y_true * self.m # cosine(thelta)-m ; y_true 就相当于mask 78 | return K.categorical_crossentropy(y_true, self.scale * thelta_m, from_logits=True) 79 | 80 | def loss_sparse(self,y_true,y_pred): 81 | y_true = y_true[:, 0] # 非常重要,默认是二维的 82 | y_true_mask = K.one_hot(K.tf.cast(y_true, dtype='int32'), self.output_dim) 83 | cosine_m = y_pred - y_true_mask * self.m # cosine-m 84 | losses = K.sparse_categorical_crossentropy(target=K.cast(y_true,'int32'), 
85 | output=cosine_m * self.scale, 86 | from_logits=True) 87 | return losses 88 | 89 | 90 | class CifarConfig(Config): 91 | # 输入信息 92 | input_shape = (32, 32, 3) 93 | num_classes = 100 # cifar100 94 | 95 | # 训练参数 96 | batch_size = 30 97 | 98 | # 原型权重保存 99 | prototype_weights_hdf5 = '/tmp/cifar_prototype_weights.hdf5' 100 | prototype_weights_dataset = 'cifar_prototype_weights_set' 101 | pw_h5_file = None # hdf5 File文件 102 | 103 | index = None # 保存原型类别的faiss索引 104 | 105 | dominant_queue = None # 支配队列 106 | candidate_queue = None # 候选队列 107 | 108 | dominant_queue_num = 10 # 支配队列大小 109 | candidate_queue_num = 30 # 候选队列大小 110 | 111 | 112 | config = CifarConfig() 113 | 114 | 115 | def set_gpu_growth(): 116 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 117 | cfg = tf.ConfigProto() 118 | cfg.gpu_options.allow_growth = True 119 | session = tf.Session(config=cfg) 120 | K.set_session(session) 121 | 122 | 123 | class DenseWithDPSoftmaxLoss(Layer): 124 | def __init__(self, num_class, m=0.35, scale=30, **kwargs): 125 | self.output_dim = num_class 126 | self.margin = m 127 | self.scale = scale 128 | # self.current_selected_labels = K.ones(shape=(num_class, 1)) 129 | super(DenseWithDPSoftmaxLoss, self).__init__(**kwargs) 130 | 131 | def build(self, input_shape): 132 | self.kernel = self.add_weight(name='kernel', 133 | shape=(input_shape[0][1], self.output_dim), # (嵌入维度,num_class) 134 | dtype=K.floatx(), 135 | initializer='glorot_normal', 136 | trainable=True) 137 | self.current_selected_labels = self.add_weight(name='labels', 138 | shape=(self.output_dim, 1), 139 | initializer='glorot_normal', 140 | trainable=False) 141 | self.y_pred = self.add_weight(name='pred', 142 | shape=(self.output_dim, self.output_dim), 143 | initializer='glorot_normal', 144 | trainable=False) 145 | 146 | def call(self, inputs, **kwargs): 147 | # 将当前step类别的权重,赋值给tensor 148 | weights_assign_op = K.tf.assign(self.kernel, 149 | K.transpose(inputs[1]), 150 | name='assign_weights') 151 | label_assign_op = 
def label_id_map(labels, num_class):
    """
    Group image ids by class label.

    :param labels: class label of every image, indexed by image id; either a
        1-D sequence or an (N, 1) array such as the label arrays returned by
        ``keras.datasets`` loaders
    :param num_class: number of classes; every class in ``range(num_class)``
        gets an entry, even when it has no images
    :return: dict mapping class label -> list of image ids, in ascending order
    :raises ValueError: if ``labels`` is neither 1-D nor of shape (N, 1)
    :raises KeyError: if a label lies outside ``range(num_class)``
    """
    labels = np.asarray(labels)
    # An (N, 1) column would yield one-element arrays, which are unhashable
    # and therefore unusable as dict keys; flatten it to (N,).
    if labels.ndim == 2 and labels.shape[1] == 1:
        labels = labels[:, 0]
    if labels.ndim != 1:
        raise ValueError(
            "labels must be 1-D or of shape (N, 1), got shape {}".format(labels.shape))

    label_id_dict = {label: [] for label in range(num_class)}
    for img_id, label in enumerate(labels):
        label_id_dict[int(label)].append(img_id)
    return label_id_dict
def generator(X, y, label_id_dict, dominant_queue, num_class, batch_size):
    """
    Endless generator of training mini-batches.

    Sampling is two-level: first ``batch_size // 3`` distinct positive classes
    are drawn, then the batch is filled with classes taken from the dominant
    queues of those positives.  No class appears twice in a batch.

    :param X: image array, indexed by image id
    :param y: labels; not used here, kept so the signature stays unchanged
    :param label_id_dict: class label -> list of image ids
    :param dominant_queue: class label -> dominant class labels
    :param num_class: number of classes
    :param batch_size: number of samples (and therefore classes) per batch
    :return: yields ``([images, prototype_weights, labels], targets)`` where
        ``labels`` has shape (batch_size, 1) and ``targets`` is
        ``np.arange(batch_size)``
    :raises RuntimeError: if the dominant queues of the sampled positives do
        not contain enough unused classes to fill the batch
    """
    while True:
        # Level 1: positive classes, sampled without replacement.
        sample_labels = np.random.choice(num_class, batch_size // 3, replace=False)

        selected_labels = set(sample_labels)  # classes already in this batch
        # [image id, class label] of every sample picked for this step
        selected_image_labels = [[random_select(label_id_dict[label]), label]
                                 for label in sample_labels]

        # Level 2: fill up with dominant classes of the positives.
        # Positives whose dominant queue still has unused classes; tracking
        # them avoids retrying forever on an exhausted queue.
        open_labels = list(sample_labels)
        while len(selected_image_labels) < batch_size:
            if not open_labels:
                raise RuntimeError(
                    "cannot fill a batch of {} samples: the dominant queues of the "
                    "sampled classes hold too few distinct classes".format(batch_size))
            label = random_select(open_labels)
            candidates = [dq for dq in dominant_queue[label] if dq not in selected_labels]
            if not candidates:
                open_labels.remove(label)
                continue
            dq_label = random_select(candidates)
            selected_labels.add(dq_label)
            selected_image_labels.append([random_select(label_id_dict[dq_label]), dq_label])

        selected_image_labels = np.asarray(selected_image_labels)
        current_image_ids = selected_image_labels[:, 0]
        current_selected_labels = selected_image_labels[:, 1]
        current_weights = get_weights(config.pw_h5_file[config.prototype_weights_dataset],
                                      current_selected_labels)

        # Images are looked up by image id (column 0), not by class label.
        images = X[current_image_ids]
        # The label input of the model is 2-D: (batch_size, 1).
        current_selected_labels = np.expand_dims(current_selected_labels, axis=1)

        # Within a step the classes are renumbered, so targets are always 0..batch_size-1.
        yield [images,
               current_weights,
               current_selected_labels], np.arange(batch_size)
def main(args):
    """
    Run the CIFAR dominant-prototype pipeline.

    :param args: object with a ``stages`` attribute; ``'init'`` computes and
        stores the initial prototype weights, ``'train'`` trains the model
        against them
    """
    # Imported here because the module only imports keras.datasets inside its
    # ``__main__`` section.  CifarConfig declares 100 classes, hence CIFAR-100.
    from keras.datasets import cifar100

    K.clear_session()
    set_gpu_growth()

    (x_train, y_train), _ = cifar100.load_data()
    # Group image ids by class; labels arrive as an (N, 1) array.
    label_id_dict = label_id_map(y_train[:, 0], config.num_classes)

    # Stage: initialise the prototype weights.
    if 'init' in args.stages:
        # NOTE(review): init_prototype extracts features with cifar_base_net,
        # while training below uses config.backbone - confirm both are meant
        # to load config.backbone_weights.
        init_prototype(config.backbone_weights,
                       x_train,
                       label_id_dict,
                       config.num_classes)
        print("初始化原型权重完成... ...")

    # Stage: training.
    if 'train' in args.stages:
        config.pw_h5_file, config.index = get_prototype(1000)
        try:
            print("获取原型权重完成... ...")

            # Initialise the dominant and candidate queues.
            config.dominant_queue, config.candidate_queue = init_queue(
                config.index,
                config.pw_h5_file[config.prototype_weights_dataset],
                config.num_classes,
                config.dominant_queue_num,
                config.candidate_queue_num)
            print("初始化队列完成... ...")

            # Build the model.
            inputs = Input(batch_shape=(config.batch_size,) + config.input_shape)
            weights_inputs = Input(batch_shape=(config.batch_size, 512))
            label_inputs = Input(batch_shape=(config.batch_size, 1))  # must be at least 2-D
            features = config.backbone(inputs)
            dense = DenseWithDPSoftmaxLoss(config.batch_size)  # batch size acts as class count
            outputs = dense([features, weights_inputs, label_inputs])
            m = Model([inputs, weights_inputs, label_inputs], outputs)
            m.summary()
            m.load_weights(config.backbone_weights, by_name=True)
            m.compile(loss=dense.loss,
                      optimizer=SGD(lr=0.01, momentum=0.9, decay=0.0005),
                      metrics=['accuracy'])

            # Train.
            print("开始训练模型... ...")
            gen = generator(x_train,
                            y_train,
                            label_id_dict,
                            config.dominant_queue,
                            config.num_classes,
                            config.batch_size)

            m.fit_generator(gen,
                            callbacks=get_call_back(),
                            # integer step count: a quarter of the data per epoch
                            steps_per_epoch=len(x_train) // 4 // config.batch_size,
                            epochs=10,
                            use_multiprocessing=False,
                            verbose=1,
                            validation_data=next(gen)
                            )
        finally:
            # Always release the prototype HDF5 file, even if training fails.
            config.pw_h5_file.close()
def main():
    """
    Evaluate the trained backbone with the prototype weights as classifier.

    Builds backbone + bias-free softmax ``Dense`` layer, loads
    ``config.lbl_weights`` by layer name, replaces the classifier kernel with
    the transposed prototype matrix from the HDF5 file and prints the scores
    returned by ``evaluate_generator`` over 10 random batches of 200 images.
    """
    # Build the model.  The batch dimension is left unspecified because the
    # evaluation batches (200) differ from config.batch_size.
    inputs = Input(shape=config.input_shape)
    features = config.backbone(inputs)
    # Final classification layer; softmax is required because the loss below
    # does not operate on logits.
    dense = Dense(config.num_classes,
                  use_bias=False,
                  activation='softmax')
    outputs = dense(features)
    m = Model(inputs, outputs)
    m.compile(optimizer='SGD',
              loss=keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
    m.load_weights(config.lbl_weights, by_name=True)

    # Use the prototype weights as the classifier kernel; the context manager
    # closes the HDF5 file even if set_weights raises.
    with h5py.File(config.prototype_weights_hdf5, 'r') as f_label:
        weights = f_label[config.prototype_weights_dataset]  # HDF5 dataset object
        m.layers[-1].set_weights([np.transpose(weights[:])])  # weights[:] -> numpy array
    m.summary()

    # Load the image metadata.
    images_info, _ = get_mslm_infos(config.ms1m_annotation_file, config.ms1m_img_dir)
    # Keep only classes below num_classes (useful for small test runs).
    images_info = [info for info in images_info
                   if int(info['label']) < config.num_classes]

    # Evaluate.
    scores = m.evaluate_generator(gen_data(images_info, 200),
                                  steps=10,
                                  verbose=1)
    print(scores)
class DenseWithDPSoftmaxLoss(Layer):
    """Cosine classification layer whose kernel is overwritten on every call.

    The layer takes three inputs: ``inputs[0]`` is L2-normalised along axis 1
    and multiplied with the kernel, ``inputs[1]`` is transposed and assigned to
    the kernel, ``inputs[2]`` is assigned to ``current_selected_labels``.
    The loss additionally stores the predictions in the ``y_pred`` weight.
    All three weights are therefore readable through ``get_weights()``.
    """

    def __init__(self, num_class, m=0.35, scale=30, **kwargs):
        """
        :param num_class: output dimension of the layer
        :param m: margin subtracted from the target-class cosine in ``loss``
        :param scale: factor applied to the margin-adjusted cosines in ``loss``
        :param kwargs: forwarded to ``keras.layers.Layer``
        """
        self.output_dim = num_class
        self.margin = m
        self.scale = scale
        # self.current_selected_labels = K.ones(shape=(num_class, 1))
        super(DenseWithDPSoftmaxLoss, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the kernel plus two non-trainable buffers (labels, predictions)."""
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[0][1], self.output_dim),  # (embedding dim, num_class)
                                      dtype=K.floatx(),
                                      initializer='glorot_normal',
                                      trainable=True)
        # Buffer that receives inputs[2] on every call.
        self.current_selected_labels = self.add_weight(name='labels',
                                                       shape=(self.output_dim, 1),
                                                       initializer='glorot_normal',
                                                       trainable=False)
        # Buffer that receives y_pred in loss(); its (num_class, num_class) shape
        # means the assignment only fits when the batch size equals num_class.
        self.y_pred = self.add_weight(name='pred',
                                      shape=(self.output_dim, self.output_dim),
                                      initializer='glorot_normal',
                                      trainable=False)

    def call(self, inputs, **kwargs):
        """Assign the incoming weights/labels, then return the cosine logits."""
        # Assign the weights of the classes used in the current step to the kernel.
        # NOTE(review): K.tf presumably requires a Keras version whose backend
        # module exposes ``tf`` - confirm against the pinned Keras version.
        weights_assign_op = K.tf.assign(self.kernel,
                                        K.transpose(inputs[1]),
                                        name='assign_weights')
        label_assign_op = K.tf.assign(self.current_selected_labels, inputs[2], name='assign_labels')
        # The ops created below carry a control dependency on both assignments.
        with K.tf.control_dependencies([weights_assign_op, label_assign_op]):
            self.x_norm = K.l2_normalize(inputs[0], axis=1)
            self.kernel_norm = K.l2_normalize(self.kernel, axis=0)
            self.logit = K.dot(self.x_norm, self.kernel_norm)
        return self.logit

    def compute_output_shape(self, input_shape):
        """Output is (batch size of the first input, num_class)."""
        return (input_shape[0][0],
                self.output_dim)
        # return [(input_shape[0][0], self.output_dim),
        #         (self.output_dim, input_shape[0][1])]

    def loss(self, y_true, y_pred):
        """
        Margin softmax loss; also copies ``y_pred`` into the ``pred`` weight.

        :param y_true: class ids; column 0 of the incoming 2-D tensor is used
        :param y_pred: cosine logits produced by ``call``
        :return: per-sample sparse categorical cross-entropy
        """
        # First store the predictions in the weight buffer.
        pred_assign_op = K.tf.assign(self.y_pred,
                                     y_pred,
                                     name='assign_pred')
        with K.tf.control_dependencies([pred_assign_op]):
            y_true = y_true[:, 0]  # important: the target is 2-D by default
            y_true_mask = K.one_hot(K.tf.cast(y_true, dtype='int32'), self.output_dim)
            cosine_m = y_pred - y_true_mask * self.margin  # cosine - m on the target class
            losses = K.sparse_categorical_crossentropy(target=y_true,
                                                       output=cosine_m * self.scale,
                                                       from_logits=True)

        return losses
def generator(images_info, label_id_dict, dominant_queue, num_class, batch_size):
    """
    Endless generator of training mini-batches.

    Sampling is two-level: first ``batch_size // 2`` distinct positive classes
    are drawn, then the batch is filled with classes taken from the dominant
    queues of those positives.  No class appears twice in a batch.

    :param images_info: image metadata, indexed by image id; each entry
        provides an ``'img_path'``
    :param label_id_dict: class label -> list of image ids
    :param dominant_queue: class label -> dominant class labels
    :param num_class: number of classes
    :param batch_size: number of samples (and therefore classes) per batch
    :return: yields ``([images, prototype_weights, labels], targets)`` where
        ``labels`` has shape (batch_size, 1) and ``targets`` is
        ``np.arange(batch_size)``
    :raises RuntimeError: if the dominant queues of the sampled positives do
        not contain enough unused classes to fill the batch
    """
    while True:
        # Level 1: positive classes, sampled without replacement.
        sample_labels = np.random.choice(num_class, batch_size // 2, replace=False)

        selected_labels = set(sample_labels)  # classes already in this batch
        # [image id, class label] of every sample picked for this step
        selected_image_labels = [[random_select(label_id_dict[label]), label]
                                 for label in sample_labels]

        # Level 2: fill up with dominant classes of the positives.
        # Positives whose dominant queue still has unused classes; tracking
        # them avoids retrying forever on an exhausted queue.
        open_labels = list(sample_labels)
        while len(selected_image_labels) < batch_size:
            if not open_labels:
                raise RuntimeError(
                    "cannot fill a batch of {} samples: the dominant queues of the "
                    "sampled classes hold too few distinct classes".format(batch_size))
            label = random_select(open_labels)
            candidates = [dq for dq in dominant_queue[label] if dq not in selected_labels]
            if not candidates:
                open_labels.remove(label)
                continue
            dq_label = random_select(candidates)
            selected_labels.add(dq_label)
            selected_image_labels.append([random_select(label_id_dict[dq_label]), dq_label])

        selected_image_labels = np.asarray(selected_image_labels)
        current_selected_labels = selected_image_labels[:, 1]
        current_weights = get_weights(config.pw_h5_file[config.prototype_weights_dataset],
                                      current_selected_labels)

        # Load the images of the selected image ids.
        images = np.asarray([load_img(images_info[image_id]['img_path'])
                             for image_id, _ in selected_image_labels])
        # The label input of the model is 2-D: (batch_size, 1).
        current_selected_labels = np.expand_dims(current_selected_labels, axis=1)

        # Within a step the classes are renumbered, so targets are always 0..batch_size-1.
        yield [images,
               current_weights,
               current_selected_labels], np.arange(batch_size)
94 | 初始化候选队列和支配队列 95 | :param index: 96 | :param weights_set: h5py dataset对象 97 | :param num_class: 98 | :param dq_num: 99 | :param cq_num: 100 | :return: 101 | """ 102 | data, candidate_label_idx = index.search(weights_set[:num_class], cq_num) 103 | dominant_label_idx = candidate_label_idx[:, :dq_num] # 候选队列包含支配队列 104 | 105 | # 转为字典类型 106 | dominant_queue = dict(enumerate(dominant_label_idx)) 107 | candidate_queue = dict(enumerate(candidate_label_idx)) 108 | 109 | return dominant_queue, candidate_queue 110 | 111 | 112 | def init_prototype(images_info, label_id_dict, num_class): 113 | inputs = Input(batch_shape=(config.batch_size,) + config.input_shape) 114 | features = config.backbone(inputs) 115 | model = Model(inputs, features) 116 | model.load_weights(config.backbone_weights, by_name=True) 117 | # 原型权重一份放到hdf5,一份存放到faiss中(faiss中保留的不是精准的) 118 | f_label = h5py.File(config.prototype_weights_hdf5, 'w') 119 | label_feature = f_label.create_dataset(config.prototype_weights_dataset, 120 | shape=(num_class, 512), dtype='f') 121 | 122 | # 逐个类别处理 123 | for label in range(num_class): 124 | # 获取某个label的所有图像,并使用模型预测图像的特征,最后求均值作为label的原型权重 125 | image_ids = label_id_dict[label] # 图像id 126 | images = [load_img(images_info[image_id]['img_path']) for image_id in image_ids] # 图像数据 127 | features = model.predict(np.asarray(images)) # 输出特征 128 | features = keras.utils.np_utils.normalize(features) # 归一化 129 | features = np.mean(features, axis=0) # 求均值 130 | features = keras.utils.np_utils.normalize(features) # 再次归一化; 是二维的 131 | 132 | # 赋值给hdf5 133 | label_feature[label] = features[0] # (1,d) 转为 (d,) 134 | # 每1w次,刷写到磁盘 135 | if label % 500 == 0: 136 | f_label.flush() 137 | print("{} init_prototype 完成:{}".format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), 138 | label)) 139 | # 最后关闭文件 140 | f_label.close() 141 | 142 | 143 | def get_prototype(deal_batch=1000): 144 | f_label = h5py.File(config.prototype_weights_hdf5, 'r+') 145 | dset = f_label[config.prototype_weights_dataset] 
146 | length = len(dset) 147 | index = get_index(512) 148 | # 逐个类别处理,更新faiss index 149 | for batch_no in range(length // deal_batch): 150 | start = batch_no * deal_batch 151 | end = (batch_no + 1) * deal_batch 152 | features = dset[start:end] 153 | update_multi(index, features, np.arange(start, end)) 154 | # 处理不能整除的情况 155 | if not length % deal_batch == 0: 156 | start = length - length % deal_batch 157 | end = length 158 | features = dset[start:end] 159 | update_multi(index, features, np.arange(start, end)) 160 | return f_label, index 161 | 162 | 163 | def label_id_map(images_info, num_class): 164 | """ 165 | 将图像按照类别分组 166 | :param images_info: 图像字典{'img_path': 图像路径,'label': 类别,'img_id':图像id} 167 | :param num_class: 类别数 168 | :return: 169 | """ 170 | # 初始化 171 | label_id_dict = dict() 172 | for i in range(num_class): 173 | label_id_dict[i] = [] 174 | 175 | # 逐个图像归类 176 | for i in range(len(images_info)): 177 | label = int(images_info[i]['label']) 178 | img_id = images_info[i]['img_id'] 179 | label_id_dict[label].append(img_id) 180 | 181 | return label_id_dict 182 | 183 | 184 | class ExportWeights(keras.callbacks.Callback): 185 | def on_batch_end(self, batch, logs=None): 186 | layer = self.model.layers[-1] 187 | 188 | trained_weights, current_trained_labels, y_pred = layer.get_weights()[:3] 189 | trained_weights = np.transpose(trained_weights) 190 | current_trained_labels = np.asarray(current_trained_labels[:, 0], dtype=np.int) 191 | # print("\n input_weights:{}".format(config.current_input[1][0][:10])) 192 | # print("\n trained_weights:{}".format(trained_weights[0][:10])) 193 | # current_selected_labels = config.current_selected_labels 194 | # print("current_selected_labels:{}".format(current_selected_labels)) 195 | # print("current_trained_labels:{}".format(current_trained_labels)) 196 | update_multi(config.index, trained_weights, current_trained_labels) # 更新faiss index 197 | update_weights(config.pw_h5_file, 198 | 
config.pw_h5_file[config.prototype_weights_dataset], 199 | trained_weights, 200 | current_trained_labels) 201 | # 以下更新支配队列,根据预测结果更新支配队列; 202 | # todo : 需要使用更加高效的方式来获取输出结果,使用predict相当于又做了一次前向传播(已解决,使用权重保存) 203 | # y_pred = self.model.predict(config.current_input) 204 | # print("y_pred:{}".format(y_pred)) 205 | update_queue(config.dominant_queue, 206 | config.candidate_queue, 207 | y_pred, 208 | current_trained_labels) 209 | 210 | 211 | def get_call_back(): 212 | """ 213 | 定义call back 214 | :return: 215 | """ 216 | checkpoint = ModelCheckpoint(filepath='/tmp/lbl-IvS.{epoch:03d}.h5', 217 | monitor='val_acc', 218 | verbose=1, 219 | save_best_only=False) 220 | 221 | # 验证误差没有提升 222 | lr_reducer = ReduceLROnPlateau(monitor='val_loss', 223 | factor=np.sqrt(0.1), 224 | cooldown=1, 225 | patience=1, 226 | min_lr=0) 227 | 228 | log = TensorBoard(log_dir='log') 229 | 230 | export_weights = ExportWeights() 231 | 232 | return [checkpoint, lr_reducer, export_weights, log] 233 | 234 | 235 | def main(args): 236 | K.clear_session() 237 | set_gpu_growth() 238 | # 获取图像元数据信息 239 | images_info, label_set = get_mslm_infos(config.ms1m_annotation_file, config.ms1m_img_dir) 240 | # 过滤小于num_classes的类别;测试用 241 | images_info = filter(lambda x: int(x['label']) < config.num_classes, images_info) 242 | images_info = dict(enumerate(images_info)) # 转为字典 243 | 244 | label_id_dict = label_id_map(images_info, config.num_classes) # 类别对应的 245 | 246 | # 初始化原型权重 247 | if 'init' in args.stages: 248 | init_prototype(images_info, 249 | label_id_dict, 250 | config.num_classes) 251 | print("初始化原型权重完成... ...") 252 | 253 | # 训练阶段 254 | if 'train' in args.stages: 255 | config.pw_h5_file, config.index = get_prototype(1000) 256 | print("获取原型权重完成... 
...") 257 | # current_weights = get_weights(config.pw_h5_file[config.prototype_weights_dataset], 258 | # np.arange(config.batch_size)) 259 | # print(current_weights.shape) 260 | # print(type(current_weights)) 261 | # print(current_weights[0]) 262 | 263 | # 初始化队列 264 | config.dominant_queue, config.candidate_queue = init_queue(config.index, 265 | config.pw_h5_file[config.prototype_weights_dataset], 266 | config.num_classes, 267 | config.dominant_queue_num, 268 | config.candidate_queue_num) 269 | print("初始化队列完成... ...") 270 | # 构建模型 271 | inputs = Input(batch_shape=(config.batch_size,) + config.input_shape) 272 | weights_inputs = Input(batch_shape=(config.batch_size, 512)) 273 | label_inputs = Input(batch_shape=(config.batch_size, 1)) # 至少是二维的, 274 | features = config.backbone(inputs) 275 | dense = DenseWithDPSoftmaxLoss(config.batch_size) # batch-size当做类别数 276 | outputs = dense([features, weights_inputs, label_inputs]) 277 | m = Model([inputs, weights_inputs, label_inputs], outputs) 278 | m.summary() 279 | m.load_weights(config.backbone_weights, by_name=True) 280 | m.compile(loss=dense.loss, optimizer=SGD(lr=0.01, momentum=0.9, decay=0.0005), metrics=['accuracy']) 281 | 282 | # 训练模型 283 | print("开始训练模型... 
...") 284 | gen = generator(images_info, 285 | label_id_dict, 286 | config.dominant_queue, 287 | config.num_classes, 288 | config.batch_size) 289 | 290 | m.fit_generator(gen, 291 | callbacks=get_call_back(), 292 | steps_per_epoch=len(images_info) / 4 // config.batch_size, 293 | epochs=10, 294 | use_multiprocessing=False, 295 | verbose=1, 296 | validation_data=next(gen) 297 | ) 298 | 299 | # 最后关闭hdf5 300 | config.pw_h5_file.close() 301 | 302 | 303 | if __name__ == '__main__': 304 | parse = argparse.ArgumentParser() 305 | parse.add_argument("--stages", type=str, nargs='+', default=['train'], help="stage: init、train") 306 | argments = parse.parse_args(sys.argv[1:]) 307 | # print(argments) 308 | # useage: python pipeline.py --stages init train | python pipeline.py --stages train 309 | main(argments) 310 | -------------------------------------------------------------------------------- /reader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File Name: reader 4 | Description : 读取图像信息 5 | Author : mick.yi 6 | date: 2018/12/26 7 | """ 8 | import os 9 | import codecs 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | def get_mslm_infos(annotation_file, img_dir): 14 | """ 15 | 读取mslm数据集信息 16 | :param annotation_file: 标注文件路径 17 | :param img_dir: 图像存放路径 18 | :return: 19 | """ 20 | with codecs.open(annotation_file, 'r', 'utf-8') as f: 21 | lines = f.readlines() 22 | img_infos = [] 23 | 24 | label_set = set() 25 | for id, line in enumerate(lines): 26 | img_name, label = line.split('\t') 27 | img_info = dict() 28 | img_info['img_path'] = os.path.join(img_dir, img_name) 29 | img_info['label'] = label 30 | img_info['img_id'] = id # 增加图像id编号 31 | img_infos.append(img_info) 32 | label_set.add(label) 33 | 34 | return img_infos, label_set 35 | 36 | 37 | def load_img(img_path): 38 | img = plt.imread(img_path) 39 | return img[:, :, :3] 40 | 
# ------------------------------------------------------------------------------
# /utils.py
# ------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
   File Name:     utils
   Description :  Queue, sampling and prototype-weight helpers
   Author :       mick.yi
   date:          2019/1/4
"""
import numpy as np


def enqueue(np_array, elem):
    """
    Push an element to the head of a fixed-length queue; the tail element is dropped.

    The queue is modified in place and also returned.
    :param np_array: the queue
    :param elem: element to add
    :return: the same queue object
    """
    if len(np_array) == 0:
        return np_array  # a zero-length queue cannot hold anything
    # Copy the source slice so the shift never reads elements it has already
    # overwritten, whatever the NumPy version does with overlapping slices.
    np_array[1:] = np_array[:-1].copy()
    np_array[0] = elem
    return np_array


def random_select(ids):
    """
    Pick one id uniformly at random.
    :param ids: sequence of ids, (N,)
    :return: one element of ids
    :raises ValueError: if ids is empty
    """
    if len(ids) == 0:
        raise ValueError("random_select() needs a non-empty sequence")
    return ids[np.random.choice(len(ids))]


def update_weights(h5_file, h5_dataset, weights, labels):
    """
    Update the prototype weights stored in hdf5.
    :param h5_file: hdf5 file of the prototype weights (flushed after the write)
    :param h5_dataset: dataset of the prototype weights inside the file
    :param weights: new weights, array-like (Batch, Dim)
    :param labels: class label of every row of weights, array-like (Batch,)
    :return: None
    Note: h5py only accepts index lists in strictly increasing order, so the
          labels are sorted and de-duplicated first; when a label occurs more
          than once its last row wins, like a row-by-row assignment would do.
    """
    labels = np.asarray(labels)
    weights = np.asarray(weights)
    if labels.size == 0:
        return  # nothing to write

    # np.unique returns the sorted distinct labels; searching the reversed
    # array gives the position of the *last* occurrence of each of them.
    unique_labels, reversed_pos = np.unique(labels[::-1], return_index=True)
    last_pos = len(labels) - 1 - reversed_pos

    h5_dataset[unique_labels.tolist()] = weights[last_pos]
    h5_file.flush()


def get_weights(h5_dataset, labels):
    """
    Read the prototype weights of the given classes.
    :param h5_dataset: dataset of the prototype weights, indexable by class label
    :param labels: class labels, in the order the weights should be returned
    :return: numpy array with one row per label
    """
    return np.asarray([h5_dataset[label] for label in labels])


def update_queue(dominant_queue, candidate_queue, predict, current_labels):
    """
    Update the dominant queues from the predictions of one batch.

    Sample i is predicted correctly when the arg-max of its row is i. For a
    wrong prediction the predicted class is pushed to the head of the true
    class's dominant queue, provided it is in the candidate queue of the true
    class and not in its dominant queue yet.
    :param dominant_queue: dominant queue, dict class -> array; updated in place
    :param candidate_queue: candidate queue, dict class -> array
    :param predict: predictions, numpy array (Batch, train_num_class)
    :param current_labels: real class of every position in the batch, (Batch,)
    :return: None
    """
    predicted_positions = np.argmax(predict, axis=-1)
    for position, predicted_position in enumerate(predicted_positions):
        if predicted_position == position:
            continue  # correct prediction, nothing to learn from

        true_label = current_labels[position]
        # map the position inside the batch back to the real class label
        predicted_label = current_labels[predicted_position]
        dominant = dominant_queue[true_label]
        if predicted_label in dominant or predicted_label not in candidate_queue[true_label]:
            continue

        dominant_queue[true_label] = enqueue(dominant, predicted_label)