├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── backbone.py
├── cifar_demo.py
├── config.py
├── evaluate.py
├── faiss_utils.py
├── layers.py
├── lbl-IvS Demo.ipynb
├── log
└── events.out.tfevents.1547478727.localhost.localdomain
├── pic
├── acc.jpg
├── loss.jpg
├── lr.jpg
├── output_21_0.png
└── output_8_0.png
├── pipeline.py
├── reader.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # keras-lbl-IvS
2 |
3 | 论文地址:[Large-scale Bisample Learning on ID vs. Spot Face Recognition](https://arxiv.org/pdf/1806.03018.pdf)
4 |
5 | 工程地址:[keras-lbl-IvS](https://github.com/yizt/keras-lbl-IvS)
6 |
7 | keras实现人证比对论文《Large-scale Bisample Learning on ID vs. Spot Face Recognition》的核心思想; 当类别数非常大时(超过百万),GPU显存可能无法装载权重参数;可以使用支配原型概念,每个step将相关的类别一块训练,step开始时将当前mini-batch的类别原型权重加载到GPU中,而不用加载所有的类别权重参数; step结束时更新并保存当前训练类别的原型权重。
8 |
9 | 由于目前没有公开证件数据集;这里使用cifar10和cifar100数据集做简单的模拟测试;先说结果,直接使用cifar10数据集训练,使用imprint生成初始的cifar100数据集的类别权重;在初始的权重中测试精度为20.8%,可以说是一个不错的初始化状态;每次使用30个类别做支配原型训练,训练的精度约为80%;使用训练后的原型权重再整体测试cifar100类,最终精度为57.4%; 效果仍然不太理想。一方面支配原型训练优化应该更加困难;另一方面也没有测试更多的超参数选择(支配训练类别数、支配队列大小、候选队列大小、margin、scale、batch-size、lr等等)。
10 |
11 | 但以上测试结果仍然能够说明此方法是可行的;至少在类别数特别大时(GPU无法装下),使用支配原型权重概念,在训练时动态加载当前训练类别的权重是一个不错的方法。
12 |
13 | 欢迎关注、反馈问题和沟通交流... ...
14 |
15 | 下图为支配原型训练的acc,loss和lr信息:
16 |
17 | 

18 |
19 | 以下是测试过程,目录如下:
20 |
21 | [TOC]
22 |
23 | ## 依赖环境
24 |
25 | * python 3.6
26 | * tensorflow-gpu 1.9.0
27 | * keras 2.2.0
28 | * faiss 1.4.0
29 | * h5py 2.8.0
30 |
31 |
32 |
33 | ## 初始化环境
34 |
35 |
36 | ```python
37 | import tensorflow as tf
38 | import numpy as np
39 | import h5py
40 | import keras
41 | from keras.layers import Input
42 | from keras.models import Model
43 | from keras.optimizers import SGD
44 | from keras.datasets import cifar10,cifar100
45 | import keras.backend as K
46 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, TensorBoard
47 | import sys
48 | import os
49 | import time
50 | import argparse
51 |
52 | if __package__ is None:
53 | sys.path
54 | sys.path.insert(0, '/opt/github/py_data_mining/face')
55 | __package__ = "keras-lbl-IvS"
56 |
57 | from .config import Config
58 | from .utils import random_select, get_weights, update_weights, update_queue
59 | from .faiss_utils import get_index, update_multi
60 | ```
61 |
62 | /root/anaconda3/envs/keras/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
63 | return f(*args, **kwds)
64 | /root/anaconda3/envs/keras/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88
65 | return f(*args, **kwds)
66 | Using TensorFlow backend.
67 |
68 |
69 | ### 配置信息
70 |
71 |
72 | ```python
class CifarConfig(Config):
    # Settings for the CIFAR dominant-prototype experiment.

    # Input information
    input_shape = (32, 32, 3)
    num_classes = 100  # cifar100
    train_num_classes = 30  # number of classes used for dominant-prototype training

    # Training parameters
    batch_size = 256

    # Where the prototype weights are saved
    prototype_weights_hdf5 = '/tmp/cifar_prototype_weights.hdf5'
    prototype_weights_dataset = 'cifar_prototype_weights_set'
    pw_h5_file = None  # HDF5 File object

    index = None  # faiss index that stores the class prototypes

    dominant_queue = None  # dominant queue
    candidate_queue = None  # candidate queue

    dominant_queue_num = 20  # dominant queue size
    candidate_queue_num = 50  # candidate queue size
94 |
95 | config = CifarConfig()
96 | ```
97 |
98 |
99 | ```python
def set_gpu_growth():
    """Make only GPU "0" visible and register a TF session with allow_growth enabled in Keras."""
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    proto = tf.ConfigProto()
    proto.gpu_options.allow_growth = True
    K.set_session(tf.Session(config=proto))
106 |
107 | set_gpu_growth()
108 | ```
109 |
110 |
111 | ```python
112 | import matplotlib.pyplot as plt
def img_show(X, rows, cols):
    """
    Draw the first ``rows * cols`` images of ``X`` on a rows-by-cols grid.

    :param X: indexable collection of images accepted by ``plt.imshow``
    :param rows: number of grid rows
    :param cols: number of grid columns
    """
    plt.figure(figsize=(cols * 1.5, rows))
    shown = X[:rows * cols]
    # subplot positions are 1-based
    for position, image in enumerate(shown, start=1):
        plt.subplot(rows, cols, position)
        plt.imshow(image)
    plt.show()
119 | ```
120 |
121 | ## 源数据集训练
122 |
123 | 使用cifar10 作为源数据集
124 |
125 | ### cifar10 数据集加载、可视化
126 |
127 |
128 | ```python
129 | (x_train, y_train), (x_test, y_test) = cifar10.load_data()
130 | ```
131 |
132 |
133 | ```python
134 | img_show(x_train,2,10)
135 | ```
136 |
137 |
138 | 
139 |
140 |
141 | ### 骨干网络定义
142 |
143 | 骨干网络,提取特征
144 |
145 |
146 | ```python
147 | # 骨干网络,提取特征
148 | import keras.layers as layers
def cifar_base_net(inputs):
    """
    Backbone network used to extract features.

    :param inputs: Keras Input object
    :return: output tensor of the final global-average-pooling layer
    """
    # Stage 1: two 32-filter convolutions, then pooling and dropout
    net = layers.Conv2D(32, (3, 3), padding='same', name='conv_1')(inputs)
    net = layers.BatchNormalization(axis=3, name='bn_conv1')(net)
    net = layers.Activation('relu')(net)
    net = layers.Conv2D(32, (3, 3), name='conv_2')(net)
    net = layers.Activation('relu')(net)
    net = layers.MaxPooling2D(pool_size=(2, 2))(net)
    net = layers.Dropout(0.25)(net)

    # Stage 2: two 64-filter convolutions, then pooling and dropout
    net = layers.Conv2D(64, (3, 3), padding='same', name='conv_3')(net)
    net = layers.Activation('relu')(net)
    net = layers.Conv2D(64, (3, 3), name='conv_4')(net)
    net = layers.Activation('relu')(net)
    net = layers.MaxPooling2D(pool_size=(2, 2))(net)
    net = layers.Dropout(0.25)(net)

    # Head: 512-filter convolution collapsed by global average pooling
    net = layers.Conv2D(512, (3, 3), name='conv_5')(net)
    net = layers.Activation('relu')(net)
    return layers.GlobalAvgPool2D()(net)
175 | ```
176 |
177 | ### AMSoftmax分类层定义
178 |
179 | 关于AMSoftmax的详细信息见论文:[Additive Margin Softmax for Face Verification](https://arxiv.org/abs/1801.05599)
180 |
181 |
182 | ```python
183 | from keras.layers import Layer
class DenseWithAMSoftmaxLoss(Layer):
    """
    Bias-free classification layer that outputs cosine logits and provides
    the matching additive-margin softmax (AM-Softmax) loss via ``loss_dense``.
    """

    def __init__(self, num_classes, m=0.35, scale=30, **kwargs):
        """
        Defaults for m and scale follow the paper
        "Additive Margin Softmax for Face Verification".

        :param num_classes: number of classes
        :param m: additive margin
        :param scale: scale factor applied to the logits inside the loss
        :param kwargs: forwarded to the Keras ``Layer`` constructor
        """
        self.output_dim = num_classes
        self.m = m
        self.scale = scale
        super(DenseWithAMSoftmaxLoss, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the trainable kernel of shape (input_dim, num_classes)."""
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='glorot_normal',
                                      trainable=True)
        # The Keras custom-layer guide asks for the base build() to be called last.
        super(DenseWithAMSoftmaxLoss, self).build(input_shape)

    def call(self, inputs):
        """Return the dot product of the L2-normalised inputs and kernel columns."""
        self.inputs = inputs
        # L2 normalisation: kernel along axis 0 (per class column), inputs along axis 1
        self.w_norm = K.tf.nn.l2_normalize(self.kernel, 0, 1e-10)
        self.x_norm = K.tf.nn.l2_normalize(self.inputs, 1, 1e-10)
        self.logits = K.dot(self.x_norm, self.w_norm)
        return self.logits

    def compute_output_shape(self, input_shape):
        return (input_shape[0],  # batch_size
                self.output_dim)

    def loss_dense(self, y_true, y_pred):
        """
        AM-Softmax loss.

        :param y_true: ground-truth classes [N, num_classes]
        :param y_pred: predicted logits [N, num_classes]
        :return: categorical cross-entropy computed from the scaled, margin-adjusted logits
        """
        y_true = K.cast(y_true, 'float32')
        # cos(theta) - m; y_true acts as the mask that selects the target class
        margin_logits = y_pred - y_true * self.m

        return K.categorical_crossentropy(y_true, self.scale * margin_logits, from_logits=True)

    def get_config(self):
        """Return the constructor arguments so a saved model can rebuild this layer."""
        layer_config = {'num_classes': self.output_dim,
                        'm': self.m,
                        'scale': self.scale}
        base_config = super(DenseWithAMSoftmaxLoss, self).get_config()
        return dict(list(base_config.items()) + list(layer_config.items()))
229 |
230 | ```
231 |
232 | ### 构建模型
233 |
234 |
235 | ```python
236 | from keras.models import Model
237 | from keras.layers import Dense
238 | from keras import Input
239 |
def source_model(img_inputs, num_class):
    """
    Build and compile the source-dataset classifier.

    :param img_inputs: Keras Input tensor for the images
    :param num_class: number of classes of the AM-Softmax layer
    :return: compiled Keras Model
    """
    backbone_features = cifar_base_net(img_inputs)
    # custom AM-Softmax classification layer
    am_head = DenseWithAMSoftmaxLoss(num_class, m=0.2, scale=24., name='amsoftmax')
    model = Model(img_inputs, am_head(backbone_features))
    optimizer = keras.optimizers.Adam(lr=0.001, decay=1e-6)
    model.compile(optimizer=optimizer,
                  loss=am_head.loss_dense,
                  metrics=['accuracy'])
    return model
249 |
250 | img_input=Input(shape=(32,32,3))
251 | m = source_model(img_input,10) # cifar10 一共10类
252 | m.summary()
253 | ```
254 |
255 | WARNING:tensorflow:From /root/anaconda3/envs/keras/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3216: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
256 | Instructions for updating:
257 |
258 | Future major versions of TensorFlow will allow gradients to flow
259 | into the labels input on backprop by default.
260 |
261 | See @{tf.nn.softmax_cross_entropy_with_logits_v2}.
262 |
263 | _________________________________________________________________
264 | Layer (type) Output Shape Param #
265 | =================================================================
266 | input_1 (InputLayer) (None, 32, 32, 3) 0
267 | _________________________________________________________________
268 | conv_1 (Conv2D) (None, 32, 32, 32) 896
269 | _________________________________________________________________
270 | bn_conv1 (BatchNormalization (None, 32, 32, 32) 128
271 | _________________________________________________________________
272 | activation_1 (Activation) (None, 32, 32, 32) 0
273 | _________________________________________________________________
274 | conv_2 (Conv2D) (None, 30, 30, 32) 9248
275 | _________________________________________________________________
276 | activation_2 (Activation) (None, 30, 30, 32) 0
277 | _________________________________________________________________
278 | max_pooling2d_1 (MaxPooling2 (None, 15, 15, 32) 0
279 | _________________________________________________________________
280 | dropout_1 (Dropout) (None, 15, 15, 32) 0
281 | _________________________________________________________________
282 | conv_3 (Conv2D) (None, 15, 15, 64) 18496
283 | _________________________________________________________________
284 | activation_3 (Activation) (None, 15, 15, 64) 0
285 | _________________________________________________________________
286 | conv_4 (Conv2D) (None, 13, 13, 64) 36928
287 | _________________________________________________________________
288 | activation_4 (Activation) (None, 13, 13, 64) 0
289 | _________________________________________________________________
290 | max_pooling2d_2 (MaxPooling2 (None, 6, 6, 64) 0
291 | _________________________________________________________________
292 | dropout_2 (Dropout) (None, 6, 6, 64) 0
293 | _________________________________________________________________
294 | conv_5 (Conv2D) (None, 4, 4, 512) 295424
295 | _________________________________________________________________
296 | activation_5 (Activation) (None, 4, 4, 512) 0
297 | _________________________________________________________________
298 | global_average_pooling2d_1 ( (None, 512) 0
299 | _________________________________________________________________
300 | amsoftmax (DenseWithAMSoftma (None, 10) 5120
301 | =================================================================
302 | Total params: 366,240
303 | Trainable params: 366,176
304 | Non-trainable params: 64
305 | _________________________________________________________________
306 |
307 |
308 | ### 训练模型
309 |
310 |
311 | ```python
312 | (x_train, y_train), (x_test, y_test) = cifar10.load_data()
313 | m.fit(x_train, keras.utils.to_categorical(y_train,10),
314 | batch_size=64,
315 | epochs=10,
316 | validation_data=(x_test, keras.utils.to_categorical(y_test,10)),
317 | shuffle=True)
318 | ```
319 |
320 | Train on 50000 samples, validate on 10000 samples
321 | Epoch 1/10
322 | 50000/50000 [==============================] - 8s 153us/step - loss: 5.6485 - acc: 0.4251 - val_loss: 4.8240 - val_acc: 0.5428
323 | Epoch 2/10
324 | 50000/50000 [==============================] - 6s 129us/step - loss: 4.1562 - acc: 0.6204 - val_loss: 3.5912 - val_acc: 0.6834
325 | Epoch 3/10
326 | 50000/50000 [==============================] - 7s 136us/step - loss: 3.4796 - acc: 0.6930 - val_loss: 3.1658 - val_acc: 0.7277
327 | Epoch 4/10
328 | 50000/50000 [==============================] - 7s 137us/step - loss: 3.0975 - acc: 0.7322 - val_loss: 2.9639 - val_acc: 0.7396
329 | Epoch 5/10
330 | 50000/50000 [==============================] - 7s 135us/step - loss: 2.8483 - acc: 0.7540 - val_loss: 3.1497 - val_acc: 0.7251
331 | Epoch 6/10
332 | 50000/50000 [==============================] - 7s 131us/step - loss: 2.6558 - acc: 0.7757 - val_loss: 3.0899 - val_acc: 0.7297
333 | Epoch 7/10
334 | 50000/50000 [==============================] - 6s 128us/step - loss: 2.5169 - acc: 0.7885 - val_loss: 2.6846 - val_acc: 0.7680
335 | Epoch 8/10
336 | 50000/50000 [==============================] - 7s 136us/step - loss: 2.4056 - acc: 0.7990 - val_loss: 2.8013 - val_acc: 0.7623
337 | Epoch 9/10
338 | 50000/50000 [==============================] - 7s 136us/step - loss: 2.3017 - acc: 0.8082 - val_loss: 2.4712 - val_acc: 0.7900
339 | Epoch 10/10
340 | 50000/50000 [==============================] - 7s 131us/step - loss: 2.2125 - acc: 0.8167 - val_loss: 2.4762 - val_acc: 0.7867
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 | ```python
352 | # 保存模型
353 | m.save('/tmp/source.weights.h5')
354 | ```
355 |
356 | ## 初始评估
357 |
358 | 使用在源数据集cifar10上训练的模型,来评估在cifar100上的精度;评估方式是通过imprint构建初始的分类权重;
359 | imprint详见[Low-Shot Learning with Imprinted Weights](https://arxiv.org/abs/1712.07136)论文,相关代码见[keras-imprinting](https://github.com/yizt/keras-imprinting).
360 |
361 |
362 | ### cifar100 数据加载、可视化
363 |
364 |
365 | ```python
366 | (x_train, y_train), (x_test, y_test) = cifar100.load_data()
367 | ```
368 |
369 |
370 | ```python
371 | img_show(x_train,2,10)
372 | ```
373 |
374 |
375 | 
376 |
377 |
378 | ### 初始原型权重
379 |
380 | 使用imprint方式将cifar100的类别原型权重保存到hdf5中(虽然这里只有100类,但是也按照类别上百万的方式来实现)
381 |
382 |
383 | ```python
384 | # 类别分组
def label_id_map(labels, num_class):
    """
    Group image indices by class label.

    :param labels: iterable of integer class labels, one per image
    :param num_class: number of classes; every label must be in range(num_class)
    :return: dict mapping each class id in range(num_class) to the list of
             positions in ``labels`` that carry that class (empty list if none)
    :raises KeyError: if a label is not in range(num_class)
    """
    # Pre-create every class so classes without images still map to an empty list.
    label_id_dict = {label: [] for label in range(num_class)}

    # Assign each image index to its class.
    for img_id, label in enumerate(labels):
        label_id_dict[label].append(img_id)
    return label_id_dict
398 | label_id_dict = label_id_map(y_train[:,0], config.num_classes) # y_train (5000,1)
399 | ```
400 |
401 |
402 | ```python
403 | # 初始化原型权重
def init_prototype(pretrained_weights, X, label_id_dict, num_class):
    """
    Compute an initial prototype weight per class and write them to HDF5.

    For every class the backbone features of its images are L2-normalised,
    averaged, and normalised again; the result is stored as that class's row in
    ``config.prototype_weights_dataset`` of ``config.prototype_weights_hdf5``.

    :param pretrained_weights: weights file loaded into the backbone by layer name
    :param X: image array, indexable by a list of image ids
    :param label_id_dict: dict mapping class id -> list of image ids
    :param num_class: number of classes (rows of the dataset)
    """
    inputs = Input(batch_shape=(config.batch_size,) + config.input_shape)
    embeddings = cifar_base_net(inputs)
    model = Model(inputs, embeddings)
    model.load_weights(pretrained_weights, by_name=True)
    # Take the feature width from the backbone instead of hard-coding it.
    feature_dim = model.output_shape[-1]

    # Start from a fresh prototype file.
    if os.path.exists(config.prototype_weights_hdf5):
        os.remove(config.prototype_weights_hdf5)

    # The context manager closes the file even if prediction or writing fails.
    with h5py.File(config.prototype_weights_hdf5, 'w') as f_label:
        label_feature = f_label.create_dataset(config.prototype_weights_dataset,
                                               shape=(num_class, feature_dim), dtype='f')

        # Process the classes one by one.
        for label in range(num_class):
            image_ids = label_id_dict[label]  # image ids of this class
            images = X[image_ids]  # image data
            feats = model.predict(np.asarray(images))  # backbone features
            feats = keras.utils.np_utils.normalize(feats)  # normalise each feature
            prototype = np.mean(feats, axis=0)  # average over the class
            # normalise again; the result is two-dimensional, (1, d)
            prototype = keras.utils.np_utils.normalize(prototype)

            label_feature[label] = prototype[0]  # (1, d) -> (d,)
            # Flush to disk and report progress every 1000 classes.
            if label % 1000 == 0:
                f_label.flush()
                print("{} init_prototype 完成:{}".format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                                                       label))
        f_label.flush()
436 | ```
437 |
438 |
439 | ```python
440 | init_prototype('/tmp/source.weights.h5',x_train,
441 | label_id_dict,
442 | config.num_classes)
443 | ```
444 |
445 | 2019-01-14 23:11:57 init_prototype 完成:0
446 |
447 |
448 | ### 开始评估
449 |
450 |
451 | ```python
452 | # 定义评估函数
453 | from keras.layers import Dense
def evaluate(pretrained_weights, X, y, num_class, use_amsoftmax=False):
    """
    Evaluate the stored prototype weights as a classifier on top of the backbone.

    :param pretrained_weights: weights file loaded into the model by layer name
    :param X: images to classify
    :param y: ground-truth class ids, compared element-wise with the predictions
    :param num_class: number of classes of the classification layer
    :param use_amsoftmax: use DenseWithAMSoftmaxLoss instead of a bias-free softmax Dense layer
    :return: accuracy (also printed as ``acc:<value>``)
    """
    if use_amsoftmax:
        # name='other' differs from the source model's 'amsoftmax' layer, so
        # load_weights(by_name=True) does not match this layer.
        dense = DenseWithAMSoftmaxLoss(num_class, m=0.2, scale=24., name='other')
    else:
        dense = layers.Dense(num_class, use_bias=False, activation='softmax')
    # Build the model.
    inputs = Input(shape=config.input_shape)
    features = cifar_base_net(inputs)
    outputs = dense(features)
    m = Model(inputs, outputs)

    m.load_weights(pretrained_weights, by_name=True)

    # Read the prototypes; the context manager closes the file even on error.
    with h5py.File(config.prototype_weights_hdf5, 'r') as f_label:
        # [:] copies the HDF5 dataset into a numpy array before the file closes
        weights = f_label[config.prototype_weights_dataset][:]
    # Prototypes are stored one row per class; the layer kernel is (dim, num_class).
    m.layers[-1].set_weights([np.transpose(weights)])

    # Predict and score.
    y_pred = np.argmax(m.predict(X), axis=-1)
    acc = np.mean(np.equal(y, y_pred))

    print("acc:{}".format(acc))
    return acc
480 | ```
481 |
482 |
483 | ```python
484 | evaluate('/tmp/source.weights.h5', x_test,y_test[:,0],config.num_classes,True)
485 | ```
486 |
487 | acc:0.2081
488 |
489 |
490 | 可以看到直接使用imprint方式初始化类别权重参数,在cifar100的精度超过20%;是随机精度1%的20倍;
491 |
492 | ## 支配原型训练cifar100
493 |
494 | ### 获取原型
495 |
496 | 获取保存在hdf5中的初始原型权重;并将所有的权重加载到faiss的索引中,用于后续构造支配队列和候选队列
497 |
498 |
499 | ```python
def get_prototype(deal_batch=1000):
    """
    Open the prototype HDF5 file and load all prototypes into a new index.

    :param deal_batch: number of prototype rows passed to ``update_multi`` per call
    :return: (f_label, index) - the HDF5 file, left open in 'r+' mode for the
             caller, and the index built by ``get_index``
    """
    f_label = h5py.File(config.prototype_weights_hdf5, 'r+')
    try:
        dset = f_label[config.prototype_weights_dataset]
        length = len(dset)
        # Use the stored feature width instead of a hard-coded 512
        # (get_index presumably takes the vector dimension - confirm in faiss_utils).
        index = get_index(dset.shape[1])
        # Walk the dataset in chunks; the last chunk may be shorter than deal_batch.
        for start in range(0, length, deal_batch):
            end = min(start + deal_batch, length)
            update_multi(index, dset[start:end], np.arange(start, end))
    except Exception:
        # Do not leave the file open when building the index fails.
        f_label.close()
        raise
    return f_label, index
518 | ```
519 |
520 |
521 | ```python
522 | config.pw_h5_file, config.index = get_prototype(1000)
523 | ```
524 |
525 | ### 初始化队列
526 |
527 | 根据原型权重构造支配队列和候选队列;
528 |
529 |
530 | ```python
def init_queue(index, weights_set, num_class, dq_num, cq_num):
    """
    Initialise the candidate queue and the dominant queue.

    :param index: object whose ``search(vectors, k)`` returns a pair; the second
                  item is a 2-D array of neighbour ids, one row per query
    :param weights_set: h5py dataset object holding the prototype weights
    :param num_class: number of classes (leading rows of weights_set) to query
    :param dq_num: dominant queue size
    :param cq_num: candidate queue size
    :return: (dominant_queue, candidate_queue), dicts keyed by row number
    """
    # Ask for one extra neighbour and discard the first column of every row.
    _, neighbour_ids = index.search(weights_set[:num_class], cq_num + 1)
    candidate_rows = neighbour_ids[:, 1:]
    # The dominant queue is the head of the candidate queue.
    dominant_rows = candidate_rows[:, :dq_num]

    # Convert to dicts keyed by row number.
    candidate_queue = {row_no: row for row_no, row in enumerate(candidate_rows)}
    dominant_queue = {row_no: row for row_no, row in enumerate(dominant_rows)}
    return dominant_queue, candidate_queue
550 | ```
551 |
552 |
553 | ```python
# Seed the dominant and candidate queues from the indexed prototype weights.
config.dominant_queue, config.candidate_queue = init_queue(
    index=config.index,
    weights_set=config.pw_h5_file[config.prototype_weights_dataset],
    num_class=config.num_classes,
    dq_num=config.dominant_queue_num,
    cq_num=config.candidate_queue_num,
)
559 | ```
560 |
561 | ### 支配原型分类层定义
562 |
563 | 关键点:
564 | a) 将需要输出的tensor保存在权重参数中
565 | b) 使用K.tf.control_dependencies保证“动态加载原型权重参数”和“将预测结果输出到不训练的权重中”这两个操作会被加入到计算图中
566 |
567 |
568 | ```python
class DenseWithDPSoftmaxLoss(Layer):
    """
    Cosine classification layer whose class weights are fed in as an input.

    The layer is called with three inputs: ``[features, weights, labels]``.
    On every call the fed weights are assigned to the kernel and the fed
    labels to a non-trainable weight. The layer outputs the dot product of
    the L2-normalised features and the L2-normalised kernel. ``loss`` stores
    the predictions in another non-trainable weight and applies an additive
    margin to the true class before a scaled softmax cross-entropy.
    """

    def __init__(self, num_class, bacth_size, m=0.35, scale=30, **kwargs):
        """
        :param num_class: number of output classes (stored as ``output_dim``)
        :param bacth_size: batch size; fixes the shape of the label and prediction weights
        :param m: margin subtracted from the true-class logit in ``loss``
        :param scale: factor applied to the logits in ``loss``
        :param kwargs: forwarded to the base ``Layer`` constructor
        """
        self.output_dim = num_class
        self.m = m
        self.scale = scale
        self.batch_size = bacth_size
        super(DenseWithDPSoftmaxLoss, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the trainable kernel and the two non-trainable holder weights."""
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[0][1], self.output_dim),  # (embedding dim, num_class)
                                      dtype=K.floatx(),
                                      initializer='glorot_normal',
                                      constraint=keras.constraints.unit_norm(),
                                      trainable=True)
        # Holder for the labels fed with the current step (filled in ``call``).
        self.current_selected_labels = self.add_weight(name='labels',
                                                       shape=(self.batch_size, 1),
                                                       initializer='glorot_normal',
                                                       trainable=False)
        # Holder for the predictions of the current step (filled in ``loss``).
        self.y_pred = self.add_weight(name='pred',
                                      shape=(self.batch_size, self.output_dim),
                                      initializer='glorot_normal',
                                      trainable=False)

    def call(self, inputs, **kwargs):
        """
        :param inputs: list ``[features, weights, labels]``
        :return: logits, the product of the normalised features and the normalised kernel
        """
        # Assign the class weights of the current step to the kernel; only the
        # first ``output_dim`` rows of the fed weights are used.
        weights_assign_op = K.tf.assign(self.kernel,
                                        K.transpose(inputs[1][:self.output_dim]),
                                        name='assign_weights')
        label_assign_op = K.tf.assign(self.current_selected_labels, inputs[2], name='assign_labels')
        # Ops created inside this context depend on the two assignments, which
        # pulls the assignments into the graph.
        with K.tf.control_dependencies([weights_assign_op, label_assign_op]):
            # L2 normalisation
            self.w_norm = K.tf.nn.l2_normalize(self.kernel, 0, 1e-10)
            self.x_norm = K.tf.nn.l2_normalize(inputs[0], 1, 1e-10)
            self.logits = K.dot(self.x_norm, self.w_norm)
        return self.logits

    def compute_output_shape(self, input_shape):
        """Return ``(batch, output_dim)``, with the batch size taken from the first input."""
        return (input_shape[0][0],
                self.output_dim)
        # return [(input_shape[0][0], self.output_dim),
        #         (self.output_dim, input_shape[0][1])]

    def loss(self, y_true, y_pred):
        """
        DPAmSoftmax Loss
        :param y_true: true classes [N,num_classes]
        :param y_pred: predicted classes [N,num_classes]
        :return: per-sample categorical cross-entropy
        """
        # First store the predictions in the non-trainable weight.
        pred_assign_op = K.tf.assign(self.y_pred,
                                     y_pred,
                                     name='assign_pred')
        with K.tf.control_dependencies([pred_assign_op]):
            y_true = K.cast(y_true, 'float32')
            thelta_m = y_pred - y_true * self.m  # cosine(theta) - m ; y_true acts as the mask
            losses = K.categorical_crossentropy(y_true, self.scale * thelta_m, from_logits=True)

        return losses
629 | ```
630 |
631 | ### 构建模型
632 |
633 |
634 | ```python
def target_model(pretrained_weights, batch_size, train_num_class, input_shape):
    """
    Build and compile the model that is trained on the target dataset.

    :param pretrained_weights: weights of the feature-extraction layers trained
                               on the source dataset; loaded by layer name
    :param batch_size: fixed batch size of the image input and the label input
    :param train_num_class: number of classes handled by the classification layer
    :param input_shape: shape of a single image, without the batch axis
    :return: the compiled model
    """
    # Three inputs: images, class weights and the labels in use.
    image_in = Input(batch_shape=(batch_size,) + input_shape)
    prototype_in = Input(batch_shape=(train_num_class, 512))
    label_in = Input(batch_shape=(batch_size, 1))  # must be at least two-dimensional

    embedding = cifar_base_net(image_in)
    dp_layer = DenseWithDPSoftmaxLoss(train_num_class, batch_size, m=0.2, scale=24.)
    logits = dp_layer([embedding, prototype_in, label_in])
    model = Model([image_in, prototype_in, label_in], logits)

    model.load_weights(pretrained_weights, by_name=True)
    # No layer is frozen: every layer keeps its default trainable setting.
    model.compile(loss=dp_layer.loss,
                  optimizer=SGD(lr=0.2, momentum=0, decay=1e-4),
                  metrics=['accuracy'])

    model.summary()
    return model
662 | ```
663 |
664 |
665 | ```python
# Build the dominant-prototype model on top of the source-dataset weights.
m = target_model(pretrained_weights='/tmp/source.weights.h5',
                 batch_size=config.batch_size,
                 train_num_class=config.train_num_classes,
                 input_shape=config.input_shape)
668 | ```
669 |
670 | __________________________________________________________________________________________________
671 | Layer (type) Output Shape Param # Connected to
672 | ==================================================================================================
673 | input_4 (InputLayer) (64, 32, 32, 3) 0
674 | __________________________________________________________________________________________________
675 | conv_1 (Conv2D) (64, 32, 32, 32) 896 input_4[0][0]
676 | __________________________________________________________________________________________________
677 | bn_conv1 (BatchNormalization) (64, 32, 32, 32) 128 conv_1[0][0]
678 | __________________________________________________________________________________________________
679 | activation_16 (Activation) (64, 32, 32, 32) 0 bn_conv1[0][0]
680 | __________________________________________________________________________________________________
681 | conv_2 (Conv2D) (64, 30, 30, 32) 9248 activation_16[0][0]
682 | __________________________________________________________________________________________________
683 | activation_17 (Activation) (64, 30, 30, 32) 0 conv_2[0][0]
684 | __________________________________________________________________________________________________
685 | max_pooling2d_7 (MaxPooling2D) (64, 15, 15, 32) 0 activation_17[0][0]
686 | __________________________________________________________________________________________________
687 | dropout_7 (Dropout) (64, 15, 15, 32) 0 max_pooling2d_7[0][0]
688 | __________________________________________________________________________________________________
689 | conv_3 (Conv2D) (64, 15, 15, 64) 18496 dropout_7[0][0]
690 | __________________________________________________________________________________________________
691 | activation_18 (Activation) (64, 15, 15, 64) 0 conv_3[0][0]
692 | __________________________________________________________________________________________________
693 | conv_4 (Conv2D) (64, 13, 13, 64) 36928 activation_18[0][0]
694 | __________________________________________________________________________________________________
695 | activation_19 (Activation) (64, 13, 13, 64) 0 conv_4[0][0]
696 | __________________________________________________________________________________________________
697 | max_pooling2d_8 (MaxPooling2D) (64, 6, 6, 64) 0 activation_19[0][0]
698 | __________________________________________________________________________________________________
699 | dropout_8 (Dropout) (64, 6, 6, 64) 0 max_pooling2d_8[0][0]
700 | __________________________________________________________________________________________________
701 | conv_5 (Conv2D) (64, 4, 4, 512) 295424 dropout_8[0][0]
702 | __________________________________________________________________________________________________
703 | activation_20 (Activation) (64, 4, 4, 512) 0 conv_5[0][0]
704 | __________________________________________________________________________________________________
705 | global_average_pooling2d_4 (Glo (64, 512) 0 activation_20[0][0]
706 | __________________________________________________________________________________________________
707 | input_5 (InputLayer) (30, 512) 0
708 | __________________________________________________________________________________________________
709 | input_6 (InputLayer) (64, 1) 0
710 | __________________________________________________________________________________________________
711 | dense_with_dp_softmax_loss_1 (D (64, 30) 17344 global_average_pooling2d_4[0][0]
712 | input_5[0][0]
713 | input_6[0][0]
714 | ==================================================================================================
715 | Total params: 378,464
716 | Trainable params: 376,416
717 | Non-trainable params: 2,048
718 | __________________________________________________________________________________________________
719 |
720 | 可以看到支配原型训练模型的最后的分类层输出是30个类别
721 |
722 |
723 |
724 | ### 构建生成器
725 |
726 |
727 | ```python
def generator(X, label_id_dict, dominant_queue, num_class, train_num_class, batch_size):
    """
    Endless generator of training mini-batches.

    Each step samples ``train_num_class // 3`` classes, fills up to
    ``train_num_class`` classes from the dominant queues of those classes,
    then samples images: one per selected class first, then random ones until
    the batch holds ``batch_size`` images.

    :param X: images, indexable by an array of image ids
    :param label_id_dict: dict mapping a class to the image ids of that class
    :param dominant_queue: dominant queue; maps a class to its dominant classes
    :param num_class: total number of classes (100 here)
    :param train_num_class: number of classes loaded per training step
    :param batch_size: number of images per mini-batch
    :return: yields ``([images, weights, real_labels], one_hot_labels)``;
             ``weights`` is zero-padded to ``batch_size`` rows,
             ``real_labels`` has shape ``(batch_size, 1)`` and the one-hot
             labels index into the sorted list of classes of the step
    :raises ValueError: from ``np.random.choice`` when the dominant queues do
                        not offer enough distinct classes to fill the step
    """
    num_positive = train_num_class // 3
    while True:
        # Two-level sampling: positive classes first, without replacement ...
        positive_labels = set(np.random.choice(num_class, num_positive, replace=False))

        # ... then dominant classes of the positives, excluding the positives.
        dominant_pool = set()
        for label in positive_labels:
            dominant_pool.update(dominant_queue[label])
        dominant_pool -= positive_labels
        # Sorting the pool makes the draw independent of set iteration order.
        dominant_labels = np.random.choice(sorted(dominant_pool),
                                           train_num_class - len(positive_labels),
                                           replace=False)
        all_selected_labels = sorted(positive_labels | set(dominant_labels))

        # Sample images: every selected class contributes one image first.
        samples = [[random_select(label_id_dict[label]), label]
                   for label in all_selected_labels]
        # Then draw random (class, image) pairs until the batch is full.
        while len(samples) < batch_size:
            label = random_select(all_selected_labels)
            samples.append([random_select(label_id_dict[label]), label])

        samples = np.asarray(samples)
        image_ids = samples[:, 0]
        real_selected_labels = samples[:, 1]

        # Map every real label to its position in the sorted class list of
        # this step; the list is sorted and duplicate-free, so a binary
        # search returns exactly that position.
        current_selected_labels = np.asarray(all_selected_labels)
        step_labels = np.searchsorted(current_selected_labels, real_selected_labels)

        current_weights = get_weights(config.pw_h5_file[config.prototype_weights_dataset],
                                      current_selected_labels)
        # Zero-pad the weight rows up to batch_size (presumably so that all
        # model inputs share the same first dimension - confirm).
        current_weights = np.pad(current_weights,
                                 ((0, batch_size - train_num_class), (0, 0)),
                                 'constant')

        images = X[image_ids]
        real_selected_labels = np.expand_dims(real_selected_labels, axis=1)
        # Step labels always lie in 0 .. train_num_class - 1.
        yield [images,
               current_weights,
               real_selected_labels], keras.utils.to_categorical(step_labels, train_num_class)
800 |
801 | ```
802 |
803 | ### 定义保存权重的CallBack
804 |
805 | 在训练的每个step后,将此次训练的类别原型权重导出并保存到hdf5中,同时根据预测结果更新支配队列
806 |
807 |
808 | ```python
class ExportWeights(keras.callbacks.Callback):
    """
    After every batch, write the trained class weights back to the HDF5
    prototype store and update the dominant queue from the predictions.
    """

    def on_batch_end(self, batch, logs=None):
        """
        :param batch: index of the batch that just finished (unused)
        :param logs: metric logs supplied by Keras (unused)
        """
        layer = self.model.layers[-1]
        # First three weights of the last layer: the trained kernel, the real
        # class labels of this step and the predictions.
        trained_weights, current_trained_labels, y_pred = layer.get_weights()[:3]
        # Transpose so that each row holds the weights of one class, then
        # L2-normalise every row.
        trained_weights = np.transpose(trained_weights)
        trained_weights = trained_weights / np.linalg.norm(trained_weights, ord=2, axis=1, keepdims=True)

        # 2-D -> 1-D. Builtin ``int`` replaces the ``np.int`` alias, which
        # newer NumPy releases no longer provide.
        current_trained_labels = np.asarray(current_trained_labels[:, 0], dtype=int)
        # Update the prototype weights of the classes present in this step.
        update_weights(config.pw_h5_file,
                       config.pw_h5_file[config.prototype_weights_dataset],
                       trained_weights,
                       np.unique(current_trained_labels))
        # Update the dominant queue according to the predictions.
        update_queue(config.dominant_queue,
                     config.candidate_queue,
                     y_pred,
                     current_trained_labels)
831 |
832 | ```
833 |
834 |
835 | ```python
# Define the callbacks
def get_call_back():
    """
    Create the callbacks used during training.

    :return: list ``[lr_reducer, export_weights, log]``: learning-rate
             reduction on the ``'acc'`` metric, per-batch weight export and
             TensorBoard logging into the ``log`` directory
    """
    return [
        # Shrink the learning rate by sqrt(0.1) when accuracy stops improving.
        ReduceLROnPlateau(monitor='acc',
                          factor=np.sqrt(0.1),
                          cooldown=0,
                          patience=5,
                          min_lr=0.5e-6),
        ExportWeights(),
        TensorBoard(log_dir='log'),
    ]
848 | ```
849 |
850 | ### 训练支配原型模型
851 |
852 |
853 | ```python
print("开始训练模型... ...")
# Mini-batch generator over the training images.
gen = generator(X=x_train,
                label_id_dict=label_id_dict,
                dominant_queue=config.dominant_queue,
                num_class=config.num_classes,
                train_num_class=config.train_num_classes,  # classes trained per step
                batch_size=config.batch_size)
# Evaluate the source weights on the test set before training starts.
evaluate('/tmp/source.weights.h5', x_test, y_test[:, 0], config.num_classes, True)
863 | ```
864 |
865 | 开始训练模型... ...
866 |
867 |
868 |
869 | ```python
# Keep the generator inside the training process. It reads
# config.pw_h5_file and the dominant queue, which the ExportWeights callback
# updates after every batch; a worker process would operate on its own copy
# of that state and would not see those updates.
m.fit_generator(gen,
                callbacks=get_call_back(),
                steps_per_epoch=len(x_train) // config.batch_size,
                epochs=200,
                use_multiprocessing=False,
                workers=1,
                verbose=1)
878 | ```
879 |
880 | Epoch 1/200
881 | 195/195 [==============================] - 5s 23ms/step - loss: 8.6228 - acc: 0.0701
882 | Epoch 2/200
883 | 195/195 [==============================] - 4s 21ms/step - loss: 6.9006 - acc: 0.3078
884 | Epoch 3/200
885 | 195/195 [==============================] - 4s 21ms/step - loss: 6.2487 - acc: 0.4090
886 | Epoch 4/200
887 | 195/195 [==============================] - 4s 21ms/step - loss: 5.9130 - acc: 0.4527
888 | Epoch 5/200
889 | 195/195 [==============================] - 4s 21ms/step - loss: 5.7382 - acc: 0.4725
890 | Epoch 6/200
891 | 195/195 [==============================] - 4s 21ms/step - loss: 5.5785 - acc: 0.4917
892 | Epoch 7/200
893 | 195/195 [==============================] - 4s 21ms/step - loss: 5.3747 - acc: 0.5178
894 | Epoch 8/200
895 | 195/195 [==============================] - 4s 21ms/step - loss: 5.2205 - acc: 0.5345
896 | Epoch 9/200
897 | 195/195 [==============================] - 4s 21ms/step - loss: 5.0547 - acc: 0.5533
898 | Epoch 10/200
899 | 195/195 [==============================] - 4s 21ms/step - loss: 4.9496 - acc: 0.5629
900 | Epoch 11/200
901 | 195/195 [==============================] - 4s 21ms/step - loss: 4.9175 - acc: 0.5664
902 | Epoch 12/200
903 | 195/195 [==============================] - 4s 21ms/step - loss: 4.7262 - acc: 0.5885
904 | Epoch 13/200
905 | 195/195 [==============================] - 4s 21ms/step - loss: 4.6691 - acc: 0.5940
906 | Epoch 14/200
907 | 195/195 [==============================] - 4s 21ms/step - loss: 4.5670 - acc: 0.6050
908 | Epoch 15/200
909 | 195/195 [==============================] - 4s 21ms/step - loss: 4.4895 - acc: 0.6137
910 | Epoch 16/200
911 | 195/195 [==============================] - 4s 21ms/step - loss: 4.4187 - acc: 0.6185
912 | Epoch 17/200
913 | 195/195 [==============================] - 4s 21ms/step - loss: 4.3984 - acc: 0.6248
914 | Epoch 18/200
915 | 195/195 [==============================] - 4s 21ms/step - loss: 4.3396 - acc: 0.6281
916 | Epoch 19/200
917 | 195/195 [==============================] - 4s 21ms/step - loss: 4.2386 - acc: 0.6382
918 | Epoch 20/200
919 | 195/195 [==============================] - 4s 21ms/step - loss: 4.2609 - acc: 0.6366
920 | Epoch 21/200
921 | 195/195 [==============================] - 4s 21ms/step - loss: 4.2040 - acc: 0.6425
922 | Epoch 22/200
923 | 195/195 [==============================] - 4s 21ms/step - loss: 4.1078 - acc: 0.6521
924 | Epoch 23/200
925 | 195/195 [==============================] - 4s 21ms/step - loss: 4.1030 - acc: 0.6492
926 | Epoch 24/200
927 | 195/195 [==============================] - 4s 21ms/step - loss: 4.0578 - acc: 0.6584
928 | Epoch 25/200
929 | 195/195 [==============================] - 4s 21ms/step - loss: 3.9726 - acc: 0.6668
930 | Epoch 26/200
931 | 195/195 [==============================] - 4s 21ms/step - loss: 3.9568 - acc: 0.6678
932 | Epoch 27/200
933 | 195/195 [==============================] - 4s 21ms/step - loss: 3.8982 - acc: 0.6756
934 | Epoch 28/200
935 | 195/195 [==============================] - 4s 21ms/step - loss: 3.8731 - acc: 0.6754
936 | Epoch 29/200
937 | 195/195 [==============================] - 4s 21ms/step - loss: 3.8896 - acc: 0.6719
938 | Epoch 30/200
939 | 195/195 [==============================] - 4s 21ms/step - loss: 3.8002 - acc: 0.6818
940 | Epoch 31/200
941 | 195/195 [==============================] - 4s 21ms/step - loss: 3.7603 - acc: 0.6842
942 | Epoch 32/200
943 | 195/195 [==============================] - 4s 21ms/step - loss: 3.7401 - acc: 0.6894
944 | Epoch 33/200
945 | 195/195 [==============================] - 4s 21ms/step - loss: 3.7214 - acc: 0.6924
946 | Epoch 34/200
947 | 195/195 [==============================] - 4s 21ms/step - loss: 3.6610 - acc: 0.6975
948 | Epoch 35/200
949 | 195/195 [==============================] - 4s 21ms/step - loss: 3.6877 - acc: 0.6937
950 | Epoch 36/200
951 | 195/195 [==============================] - 4s 21ms/step - loss: 3.6174 - acc: 0.7023
952 | Epoch 37/200
953 | 195/195 [==============================] - 4s 21ms/step - loss: 3.6541 - acc: 0.6981
954 | Epoch 38/200
955 | 195/195 [==============================] - 4s 21ms/step - loss: 3.5960 - acc: 0.7030
956 | Epoch 39/200
957 | 195/195 [==============================] - 4s 21ms/step - loss: 3.5589 - acc: 0.7082
958 | Epoch 40/200
959 | 195/195 [==============================] - 4s 21ms/step - loss: 3.5304 - acc: 0.7080
960 | Epoch 41/200
961 | 195/195 [==============================] - 4s 21ms/step - loss: 3.5466 - acc: 0.7083
962 | Epoch 42/200
963 | 195/195 [==============================] - 4s 21ms/step - loss: 3.4622 - acc: 0.7191
964 | Epoch 43/200
965 | 195/195 [==============================] - 4s 21ms/step - loss: 3.4591 - acc: 0.7161
966 | Epoch 44/200
967 | 195/195 [==============================] - 4s 21ms/step - loss: 3.4526 - acc: 0.7163
968 | Epoch 45/200
969 | 195/195 [==============================] - 4s 21ms/step - loss: 3.4540 - acc: 0.7162
970 | Epoch 46/200
971 | 195/195 [==============================] - 4s 21ms/step - loss: 3.4770 - acc: 0.7142
972 | Epoch 47/200
973 | 195/195 [==============================] - 4s 21ms/step - loss: 3.3777 - acc: 0.7260
974 | Epoch 48/200
975 | 195/195 [==============================] - 4s 21ms/step - loss: 3.3948 - acc: 0.7224
976 | Epoch 49/200
977 | 195/195 [==============================] - 4s 21ms/step - loss: 3.3268 - acc: 0.7286
978 | Epoch 50/200
979 | 195/195 [==============================] - 4s 21ms/step - loss: 3.3755 - acc: 0.7261
980 | Epoch 51/200
981 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2830 - acc: 0.7326
982 | Epoch 52/200
983 | 195/195 [==============================] - 4s 21ms/step - loss: 3.3017 - acc: 0.7325
984 | Epoch 53/200
985 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2776 - acc: 0.7352
986 | Epoch 54/200
987 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2682 - acc: 0.7354
988 | Epoch 55/200
989 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2782 - acc: 0.7355
990 | Epoch 56/200
991 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2301 - acc: 0.7388
992 | Epoch 57/200
993 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2260 - acc: 0.7384
994 | Epoch 58/200
995 | 195/195 [==============================] - 4s 21ms/step - loss: 3.2353 - acc: 0.7393
996 | Epoch 59/200
997 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1699 - acc: 0.7463
998 | Epoch 60/200
999 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1867 - acc: 0.7443
1000 | Epoch 61/200
1001 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1513 - acc: 0.7471
1002 | Epoch 62/200
1003 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1651 - acc: 0.7463
1004 | Epoch 63/200
1005 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1378 - acc: 0.7481
1006 | Epoch 64/200
1007 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1572 - acc: 0.7446
1008 | Epoch 65/200
1009 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1101 - acc: 0.7506
1010 | Epoch 66/200
1011 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1224 - acc: 0.7497
1012 | Epoch 67/200
1013 | 195/195 [==============================] - 4s 21ms/step - loss: 3.1320 - acc: 0.7495
1014 | Epoch 68/200
1015 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0740 - acc: 0.7520
1016 | Epoch 69/200
1017 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0902 - acc: 0.7509
1018 | Epoch 70/200
1019 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0828 - acc: 0.7523
1020 | Epoch 71/200
1021 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0511 - acc: 0.7560
1022 | Epoch 72/200
1023 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0450 - acc: 0.7584
1024 | Epoch 73/200
1025 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0774 - acc: 0.7532
1026 | Epoch 74/200
1027 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0108 - acc: 0.7608
1028 | Epoch 75/200
1029 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0656 - acc: 0.7544
1030 | Epoch 76/200
1031 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0506 - acc: 0.7548
1032 | Epoch 77/200
1033 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0210 - acc: 0.7609
1034 | Epoch 78/200
1035 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9919 - acc: 0.7619
1036 | Epoch 79/200
1037 | 195/195 [==============================] - 4s 21ms/step - loss: 3.0089 - acc: 0.7601
1038 | Epoch 80/200
1039 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9677 - acc: 0.7643
1040 | Epoch 81/200
1041 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9317 - acc: 0.7677
1042 | Epoch 82/200
1043 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9385 - acc: 0.7662
1044 | Epoch 83/200
1045 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9403 - acc: 0.7667
1046 | Epoch 84/200
1047 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9502 - acc: 0.7661
1048 | Epoch 85/200
1049 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9566 - acc: 0.7648
1050 | Epoch 86/200
1051 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9040 - acc: 0.7686
1052 | Epoch 87/200
1053 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9097 - acc: 0.7696
1054 | Epoch 88/200
1055 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9770 - acc: 0.7648
1056 | Epoch 89/200
1057 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8816 - acc: 0.7713
1058 | Epoch 90/200
1059 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9009 - acc: 0.7690
1060 | Epoch 91/200
1061 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8859 - acc: 0.7689
1062 | Epoch 92/200
1063 | 195/195 [==============================] - 4s 21ms/step - loss: 2.9151 - acc: 0.7679
1064 | Epoch 93/200
1065 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8468 - acc: 0.7735
1066 | Epoch 94/200
1067 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8806 - acc: 0.7716
1068 | Epoch 95/200
1069 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8690 - acc: 0.7727
1070 | Epoch 96/200
1071 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8607 - acc: 0.7728
1072 | Epoch 97/200
1073 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8648 - acc: 0.7743
1074 | Epoch 98/200
1075 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8162 - acc: 0.7769
1076 | Epoch 99/200
1077 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8316 - acc: 0.7759
1078 | Epoch 100/200
1079 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8379 - acc: 0.7771
1080 | Epoch 101/200
1081 | 195/195 [==============================] - 4s 21ms/step - loss: 2.7826 - acc: 0.7825
1082 | Epoch 102/200
1083 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8175 - acc: 0.7772
1084 | Epoch 103/200
1085 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8053 - acc: 0.7807
1086 | Epoch 104/200
1087 | 195/195 [==============================] - 4s 21ms/step - loss: 2.8092 - acc: 0.7777
1088 | Epoch 105/200
1089 | 195/195 [==============================] - 4s 21ms/step - loss: 2.7802 - acc: 0.7807
1090 | Epoch 106/200
1091 | 195/195 [==============================] - 4s 21ms/step - loss: 2.7878 - acc: 0.7812
1092 | Epoch 107/200
1093 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6700 - acc: 0.7913
1094 | Epoch 108/200
1095 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6969 - acc: 0.7908
1096 | Epoch 109/200
1097 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6961 - acc: 0.7897
1098 | Epoch 110/200
1099 | 195/195 [==============================] - 4s 21ms/step - loss: 2.7121 - acc: 0.7865
1100 | Epoch 111/200
1101 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6620 - acc: 0.7920
1102 | Epoch 112/200
1103 | 195/195 [==============================] - 4s 21ms/step - loss: 2.7139 - acc: 0.7876
1104 | Epoch 113/200
1105 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6802 - acc: 0.7916
1106 | Epoch 114/200
1107 | 195/195 [==============================] - 4s 21ms/step - loss: 2.7220 - acc: 0.7861
1108 | Epoch 115/200
1109 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6849 - acc: 0.7881
1110 | Epoch 116/200
1111 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6786 - acc: 0.7906
1112 | Epoch 117/200
1113 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6823 - acc: 0.7901
1114 | Epoch 118/200
1115 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6086 - acc: 0.7976
1116 | Epoch 119/200
1117 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6388 - acc: 0.7937
1118 | Epoch 120/200
1119 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6250 - acc: 0.7950
1120 | Epoch 121/200
1121 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6180 - acc: 0.7965
1122 | Epoch 122/200
1123 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6573 - acc: 0.7924
1124 | Epoch 123/200
1125 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6354 - acc: 0.7954
1126 | Epoch 124/200
1127 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6108 - acc: 0.8001
1128 | Epoch 125/200
1129 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6199 - acc: 0.7964
1130 | Epoch 126/200
1131 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6237 - acc: 0.7963
1132 | Epoch 127/200
1133 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5821 - acc: 0.8002
1134 | Epoch 128/200
1135 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6052 - acc: 0.7995
1136 | Epoch 129/200
1137 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5955 - acc: 0.7979
1138 | Epoch 130/200
1139 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6088 - acc: 0.7991
1140 | Epoch 131/200
1141 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5897 - acc: 0.8001
1142 | Epoch 132/200
1143 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6078 - acc: 0.7975
1144 | Epoch 133/200
1145 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6425 - acc: 0.7951
1146 | Epoch 134/200
1147 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6338 - acc: 0.7943
1148 | Epoch 135/200
1149 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6003 - acc: 0.7984
1150 | Epoch 136/200
1151 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5735 - acc: 0.8002
1152 | Epoch 137/200
1153 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6034 - acc: 0.7975
1154 | Epoch 138/200
1155 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6422 - acc: 0.7954
1156 | Epoch 139/200
1157 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6546 - acc: 0.7932
1158 | Epoch 140/200
1159 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6455 - acc: 0.7960
1160 | Epoch 141/200
1161 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6205 - acc: 0.7969
1162 | Epoch 142/200
1163 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6054 - acc: 0.7963
1164 | Epoch 143/200
1165 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5764 - acc: 0.8005
1166 | Epoch 144/200
1167 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5725 - acc: 0.8014
1168 | Epoch 145/200
1169 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5743 - acc: 0.8000
1170 | Epoch 146/200
1171 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6077 - acc: 0.7979
1172 | Epoch 147/200
1173 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6226 - acc: 0.7988
1174 | Epoch 148/200
1175 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5762 - acc: 0.8017
1176 | Epoch 149/200
1177 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5979 - acc: 0.7970
1178 | Epoch 150/200
1179 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6239 - acc: 0.7968
1180 | Epoch 151/200
1181 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6141 - acc: 0.7958
1182 | Epoch 152/200
1183 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5842 - acc: 0.8003
1184 | Epoch 153/200
1185 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5660 - acc: 0.8022
1186 | Epoch 154/200
1187 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5796 - acc: 0.8012
1188 | Epoch 155/200
1189 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6097 - acc: 0.7964
1190 | Epoch 156/200
1191 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6196 - acc: 0.7968
1192 | Epoch 157/200
1193 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6238 - acc: 0.7960
1194 | Epoch 158/200
1195 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5960 - acc: 0.7987
1196 | Epoch 159/200
1197 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6156 - acc: 0.7973
1198 | Epoch 160/200
1199 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6116 - acc: 0.7976
1200 | Epoch 161/200
1201 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6379 - acc: 0.7962
1202 | Epoch 162/200
1203 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6249 - acc: 0.7946
1204 | Epoch 163/200
1205 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5953 - acc: 0.7966
1206 | Epoch 164/200
1207 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6325 - acc: 0.7950
1208 | Epoch 165/200
1209 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6523 - acc: 0.7930
1210 | Epoch 166/200
1211 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5807 - acc: 0.8010
1212 | Epoch 167/200
1213 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6433 - acc: 0.7930
1214 | Epoch 168/200
1215 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5426 - acc: 0.8030
1216 | Epoch 169/200
1217 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6110 - acc: 0.7972
1218 | Epoch 170/200
1219 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6340 - acc: 0.7957
1220 | Epoch 171/200
1221 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6350 - acc: 0.7963
1222 | Epoch 172/200
1223 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5893 - acc: 0.7995
1224 | Epoch 173/200
1225 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5882 - acc: 0.7984
1226 | Epoch 174/200
1227 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6131 - acc: 0.7979
1228 | Epoch 175/200
1229 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5740 - acc: 0.8005
1230 | Epoch 176/200
1231 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6203 - acc: 0.7960
1232 | Epoch 177/200
1233 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6040 - acc: 0.7974
1234 | Epoch 178/200
1235 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6296 - acc: 0.7949
1236 | Epoch 179/200
1237 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6019 - acc: 0.7982
1238 | Epoch 180/200
1239 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6176 - acc: 0.7980
1240 | Epoch 181/200
1241 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6353 - acc: 0.7956
1242 | Epoch 182/200
1243 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6288 - acc: 0.7968
1244 | Epoch 183/200
1245 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6437 - acc: 0.7945
1246 | Epoch 184/200
1247 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6047 - acc: 0.7967
1248 | Epoch 185/200
1249 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6041 - acc: 0.7991
1250 | Epoch 186/200
1251 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6201 - acc: 0.7970
1252 | Epoch 187/200
1253 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6390 - acc: 0.7957
1254 | Epoch 188/200
1255 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6322 - acc: 0.7966
1256 | Epoch 189/200
1257 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6007 - acc: 0.7981
1258 | Epoch 190/200
1259 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5960 - acc: 0.7989
1260 | Epoch 191/200
1261 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5726 - acc: 0.8006
1262 | Epoch 192/200
1263 | 195/195 [==============================] - 4s 21ms/step - loss: 2.5944 - acc: 0.7993
1264 | Epoch 193/200
1265 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6384 - acc: 0.7954
1266 | Epoch 194/200
1267 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6452 - acc: 0.7952
1268 | Epoch 195/200
1269 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6178 - acc: 0.7959
1270 | Epoch 196/200
1271 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6292 - acc: 0.7968
1272 | Epoch 197/200
1273 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6361 - acc: 0.7950
1274 | Epoch 198/200
1275 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6028 - acc: 0.7972
1276 | Epoch 199/200
1277 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6190 - acc: 0.7978
1278 | Epoch 200/200
1279 | 195/195 [==============================] - 4s 21ms/step - loss: 2.6297 - acc: 0.7947
1280 |
1281 |
1282 |
1283 |
1284 |
1285 |
1286 | ```python
1287 | # 保存模型
1288 | m.save('/tmp/target.weights.h5')
1289 | ```
1290 |
1291 |
1292 | ```python
1293 | # 最后关闭hdf5
1294 | config.pw_h5_file.close()
1295 | ```
1296 |
1297 | ## 再次评估
1298 |
1299 |
1300 | ```python
1301 | evaluate('/tmp/target.weights.h5',
1302 | x_test,y_test[:,0],config.num_classes,True)
1303 | ```
1304 |
1305 | acc:0.5741
1306 | 虽然精度还不是很高(需要测试更多的超参数选择),但是可以看到从来都没使用100个类别一块训练;最终在100个类别的评估精度为0.5741
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: __init__.py
4 | Description : 论文《Large-scale Bisample Learning on ID vs. Spot Face Recognition》keras实现
5 | Author : mick.yi
6 | date: 2019/1/2
7 | """
--------------------------------------------------------------------------------
/backbone.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: backbone
4 | Description : 骨干网络
5 | Author : mick.yi
6 | date: 2018/12/21
7 | """
8 |
9 | import keras
10 | from keras import layers
11 | from keras_applications.resnet50 import identity_block, conv_block
12 |
13 |
def resnet50(inputs):
    """Build a ResNet-50 style feature extractor on top of ``inputs``.

    The stem is a 3x3 stride-1 convolution followed by batch normalisation,
    ReLU and a 3x3/stride-2 max pooling. Stages 2 to 5 are each one
    ``conv_block`` followed by several ``identity_block`` calls, and the
    result is globally average-pooled.

    :param inputs: input image tensor (batch normalisation runs over axis 3)
    :return: tensor produced by the final global average pooling layer
    """
    bn_axis = 3

    # Stem
    stem = layers.Conv2D(64, (3, 3),
                         strides=(1, 1),
                         padding='same',
                         name='conv1')(inputs)
    stem = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(stem)
    stem = layers.Activation('relu')(stem)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(stem)

    # (stage, filters, names of the identity blocks, extra conv_block kwargs)
    # Only stage 2 overrides the strides of its conv_block.
    stage_specs = (
        (2, [64, 64, 64], 'bc', {'strides': (1, 1)}),
        (3, [128, 128, 128], 'bcd', {}),
        (4, [256, 256, 256], 'bcdef', {}),
        (5, [512, 512, 512], 'bc', {}),
    )
    for stage, filters, identity_names, conv_kwargs in stage_specs:
        x = conv_block(x, 3, filters, stage=stage, block='a', **conv_kwargs)
        for block_name in identity_names:
            x = identity_block(x, 3, filters, stage=stage, block=block_name)

    return layers.GlobalAveragePooling2D()(x)
60 |
61 |
def cifar_base_model(inputs):
    """Build a small convolutional feature extractor.

    Two groups of (3x3 conv + ReLU) x 2, each followed by 2x2 max pooling and
    dropout 0.25 (32 filters, then 64), then Flatten, Dense(512), ReLU and
    dropout 0.5.

    :param inputs: input image tensor
    :return: output tensor of the final dropout layer
    """

    def conv_relu(tensor, filters, **conv_kwargs):
        # One 3x3 convolution followed by a ReLU activation.
        tensor = layers.Conv2D(filters, (3, 3), **conv_kwargs)(tensor)
        return layers.Activation('relu')(tensor)

    x = inputs
    for filters in (32, 64):
        x = conv_relu(x, filters, padding='same')
        x = conv_relu(x, filters)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)
        x = layers.Dropout(0.25)(x)

    x = layers.Flatten()(x)
    x = layers.Dense(512)(x)
    x = layers.Activation('relu')(x)
    return layers.Dropout(0.5)(x)
82 |
--------------------------------------------------------------------------------
/cifar_demo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: cifar_demo
4 | Description :
5 | Author : mick.yi
6 | date: 2019/1/11
7 | """
8 | import tensorflow as tf
9 | import numpy as np
10 | import h5py
11 | import keras
12 | from keras.layers import Input, Dense, Layer
13 | from keras import layers
14 | from keras.models import Model
15 | from keras.optimizers import SGD
16 | import keras.backend as K
17 | from keras.callbacks import ReduceLROnPlateau, TensorBoard
18 | import sys
19 | import os
20 | import time
21 | import matplotlib.pyplot as plt
22 |
23 | if __package__ is None:
24 | sys.path
25 | sys.path.insert(0, '/opt/github/py_data_mining/face')
26 | __package__ = "keras-lbl-IvS"
27 |
28 | from .config import Config
29 | from .utils import random_select, get_weights, update_weights, update_queue
30 | from .faiss_utils import get_index, update_multi
31 | from .reader import get_mslm_infos, load_img
32 | from .layers import DenseWithDPSoftmaxLoss
33 |
34 |
class DenseWithAMSoftmaxLoss(Layer):
    """Bias-free classification layer that outputs cosine similarities.

    ``call`` L2-normalises both the inputs and the kernel and returns their
    dot product. The two loss methods subtract the margin ``m`` from the
    true-class entry, multiply by ``scale`` and apply cross-entropy on logits.
    """

    def __init__(self, num_classes, m=0.35, scale=30, **kwargs):
        """
        m and scale follow the "Additive Margin Softmax for Face Verification" paper.
        :param num_classes: number of classes (output dimension)
        :param m: margin
        :param scale: scale factor applied to the logits in the losses
        :param kwargs: forwarded to ``Layer.__init__``
        """
        self.output_dim = num_classes
        self.m = m
        self.scale = scale
        super(DenseWithAMSoftmaxLoss, self).__init__(**kwargs)

    def build(self, input_shape):
        # Single trainable kernel of shape (input_shape[1], num_classes)
        kernel_shape = (input_shape[1], self.output_dim)
        self.kernel = self.add_weight(name='kernel',
                                      shape=kernel_shape,
                                      initializer='glorot_normal',
                                      trainable=True)

    def call(self, inputs):
        self.inputs = inputs
        # L2 normalisation: kernel along axis 0, inputs along axis 1
        self.w_norm = K.tf.nn.l2_normalize(self.kernel, 0, 1e-10)
        self.x_norm = K.tf.nn.l2_normalize(self.inputs, 1, 1e-10)
        self.logits = K.dot(self.x_norm, self.w_norm)
        return self.logits

    def compute_output_shape(self, input_shape):
        batch_size = input_shape[0]
        return (batch_size, self.output_dim)

    def loss(self, y_true, y_pred):
        """
        AM-Softmax loss for one-hot targets.
        :param y_true: true classes [N,num_classes]
        :param y_pred: layer output [N,num_classes]
        :return: categorical cross-entropy computed from the margin-adjusted, scaled logits
        """
        one_hot = K.cast(y_true, 'float32')
        # y_true acts as a mask: the margin is removed from the true class only
        margin_logits = y_pred - one_hot * self.m
        return K.categorical_crossentropy(one_hot, self.scale * margin_logits, from_logits=True)

    def loss_sparse(self, y_true, y_pred):
        """
        AM-Softmax loss for integer class targets.
        :param y_true: class ids, two-dimensional; only column 0 is used
        :param y_pred: layer output [N,num_classes]
        :return: sparse categorical cross-entropy computed from the margin-adjusted, scaled logits
        """
        class_ids = K.cast(y_true[:, 0], 'int32')  # important: targets arrive as 2-D
        true_mask = K.one_hot(class_ids, self.output_dim)
        margin_logits = y_pred - true_mask * self.m
        return K.sparse_categorical_crossentropy(target=class_ids,
                                                 output=margin_logits * self.scale,
                                                 from_logits=True)
88 |
89 |
class CifarConfig(Config):
    """Settings for the CIFAR demo; overrides and extends ``Config``."""
    # Input information
    input_shape = (32, 32, 3)
    num_classes = 100  # cifar100

    # Training parameters
    batch_size = 30

    # Where the prototype weights are stored
    prototype_weights_hdf5 = '/tmp/cifar_prototype_weights.hdf5'
    prototype_weights_dataset = 'cifar_prototype_weights_set'
    pw_h5_file = None  # open hdf5 File object (set at runtime)

    index = None  # faiss index holding the prototype of each class

    dominant_queue = None  # dominant queue
    candidate_queue = None  # candidate queue

    dominant_queue_num = 10  # dominant queue size
    candidate_queue_num = 30  # candidate queue size


# Module-wide configuration instance used by the functions below
config = CifarConfig()
113 |
114 |
def set_gpu_growth():
    """Restrict the process to GPU 0 and install a Keras session with ``allow_growth`` enabled."""
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    K.set_session(tf.Session(config=session_config))
121 |
122 |
class DenseWithDPSoftmaxLoss(Layer):
    """Cosine-similarity classification layer whose kernel is overwritten from an input.

    The layer is called with three inputs ``[features, weights, labels]``.
    ``call`` copies the transposed ``weights`` input into the kernel and the
    ``labels`` input into a non-trainable weight, then returns the dot product
    of the L2-normalised features and the L2-normalised kernel. ``loss`` also
    stores the predictions in a non-trainable weight, so that all three can be
    read back with ``get_weights()``.
    """

    def __init__(self, num_class, m=0.35, scale=30, **kwargs):
        """
        :param num_class: output dimension of the layer
        :param m: margin subtracted from the true-class entry in ``loss``
        :param scale: factor applied to the logits in ``loss``
        :param kwargs: forwarded to ``Layer.__init__``
        """
        self.output_dim = num_class
        self.margin = m
        self.scale = scale
        # self.current_selected_labels = K.ones(shape=(num_class, 1))
        super(DenseWithDPSoftmaxLoss, self).__init__(**kwargs)

    def build(self, input_shape):
        # input_shape is a list of shapes; input_shape[0] belongs to the features input
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[0][1], self.output_dim),  # (embedding dim, num_class)
                                      dtype=K.floatx(),
                                      initializer='glorot_normal',
                                      trainable=True)
        # Non-trainable slot that receives the labels input on every call
        self.current_selected_labels = self.add_weight(name='labels',
                                                       shape=(self.output_dim, 1),
                                                       initializer='glorot_normal',
                                                       trainable=False)
        # Non-trainable slot that receives the predictions inside loss()
        self.y_pred = self.add_weight(name='pred',
                                      shape=(self.output_dim, self.output_dim),
                                      initializer='glorot_normal',
                                      trainable=False)

    def call(self, inputs, **kwargs):
        # Assign the weights of the classes used in the current step to the kernel
        weights_assign_op = K.tf.assign(self.kernel,
                                        K.transpose(inputs[1]),
                                        name='assign_weights')
        label_assign_op = K.tf.assign(self.current_selected_labels, inputs[2], name='assign_labels')
        # The ops below are created under a control dependency on both assignments
        with K.tf.control_dependencies([weights_assign_op, label_assign_op]):
            self.x_norm = K.l2_normalize(inputs[0], axis=1)
            self.kernel_norm = K.l2_normalize(self.kernel, axis=0)
            self.logit = K.dot(self.x_norm, self.kernel_norm)
        return self.logit

    def compute_output_shape(self, input_shape):
        # (batch size of the features input, num_class)
        return (input_shape[0][0],
                self.output_dim)
        # return [(input_shape[0][0], self.output_dim),
        #         (self.output_dim, input_shape[0][1])]

    def loss(self, y_true, y_pred):
        """
        Margin-adjusted, scaled sparse categorical cross-entropy on logits.
        :param y_true: class ids, two-dimensional; only column 0 is used
        :param y_pred: output of ``call``
        :return: loss tensor
        """
        # First store the predictions in the non-trainable weight
        pred_assign_op = K.tf.assign(self.y_pred,
                                     y_pred,
                                     name='assign_pred')
        with K.tf.control_dependencies([pred_assign_op]):
            y_true = y_true[:, 0]  # very important: y_true is two-dimensional by default
            y_true_mask = K.one_hot(K.tf.cast(y_true, dtype='int32'), self.output_dim)
            cosine_m = y_pred - y_true_mask * self.margin  # cosine-m
            losses = K.sparse_categorical_crossentropy(target=y_true,
                                                       output=cosine_m * self.scale,
                                                       from_logits=True)

        return losses
178 |
179 |
def cifar_base_net(inputs):
    """Build the CIFAR feature extractor used by the demo.

    Five named 3x3 convolutions (``conv_1`` .. ``conv_5``) with ReLU
    activations, batch normalisation after the first convolution, two
    max-pooling + dropout stages, and a final global average pooling.

    :param inputs: input image tensor
    :return: output tensor of the global average pooling layer
    """
    # Layers are created and applied in exactly this order.
    stack = [
        layers.Conv2D(32, (3, 3), padding='same', name='conv_1'),
        layers.BatchNormalization(axis=3, name='bn_conv1'),
        layers.Activation('relu'),
        layers.Conv2D(32, (3, 3), name='conv_2'),
        layers.Activation('relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.25),

        layers.Conv2D(64, (3, 3), padding='same', name='conv_3'),
        layers.Activation('relu'),
        layers.Conv2D(64, (3, 3), name='conv_4'),
        layers.Activation('relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.25),

        layers.Conv2D(512, (3, 3), name='conv_5'),
        layers.Activation('relu'),
        layers.GlobalAvgPool2D(),
    ]

    x = inputs
    for layer in stack:
        x = layer(x)
    return x
201 |
202 |
def label_id_map(labels, num_class):
    """
    Group image ids by class label.
    :param labels: sequence where ``labels[i]`` is the class of image ``i``
    :param num_class: number of classes; keys ``0 .. num_class-1`` are always present
    :return: dict mapping each class to the list of its image ids, in ascending order
    """
    # Start with an empty bucket for every class
    label_id_dict = {label: [] for label in range(num_class)}

    # File every image id under its label
    for img_id, label in enumerate(labels):
        label_id_dict[label].append(img_id)
    return label_id_dict
216 |
217 |
def init_prototype(pretrained_weights, X, label_id_dict, num_class):
    """
    Compute one prototype vector per class and write them to the prototype hdf5 file.

    For every class, the features of all its images are predicted, L2-normalised,
    averaged, normalised again, and stored as row ``label`` of the dataset.

    :param pretrained_weights: weights file loaded (by layer name) into the feature network
    :param X: image array, indexable by a list of image ids
    :param label_id_dict: dict mapping each class to the list of its image ids
    :param num_class: number of classes
    :return: None
    """
    # The batch dimension is left unspecified: predict() feeds batches whose size
    # differs from config.batch_size, which a fixed batch_shape would reject.
    inputs = Input(shape=config.input_shape)
    model = Model(inputs, cifar_base_net(inputs))
    model.load_weights(pretrained_weights, by_name=True)

    # The with-block guarantees the file is closed even if prediction fails.
    with h5py.File(config.prototype_weights_hdf5, 'w') as f_label:
        label_feature = f_label.create_dataset(config.prototype_weights_dataset,
                                               shape=(num_class, 512), dtype='f')

        # Process the classes one by one
        for label in range(num_class):
            image_ids = label_id_dict[label]  # image ids of this class
            images = X[image_ids]  # image data
            features = model.predict(np.asarray(images))  # predicted features
            features = keras.utils.np_utils.normalize(features)  # normalise each feature
            features = np.mean(features, axis=0)  # average over the images
            features = keras.utils.np_utils.normalize(features)  # normalise again; result is 2-D

            # (1,d) -> (d,)
            label_feature[label] = features[0]
            # Flush to disk and report progress every 1000 classes
            if label % 1000 == 0:
                f_label.flush()
                print("{} init_prototype 完成:{}".format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                                                       label))
247 |
248 |
def get_prototype(deal_batch=1000):
    """
    Open the prototype hdf5 file and load all prototypes into a new faiss index.
    :param deal_batch: number of rows pushed into the index per call
    :return: (open h5py File, index); the caller is responsible for closing the file
    """
    f_label = h5py.File(config.prototype_weights_hdf5, 'r+')
    dset = f_label[config.prototype_weights_dataset]
    total = len(dset)
    index = get_index(512)
    # Walk the dataset in chunks; the last chunk may be shorter than deal_batch
    for start in range(0, total, deal_batch):
        end = min(start + deal_batch, total)
        update_multi(index, dset[start:end], np.arange(start, end))
    return f_label, index
267 |
268 |
def init_queue(index, weights_set, num_class, dq_num, cq_num):
    """
    Initialise the candidate queue and the dominant queue.
    :param index: object whose ``search(vectors, k)`` returns ``(distances, ids)``
    :param weights_set: h5py dataset (or array) of prototype vectors
    :param num_class: number of leading rows of ``weights_set`` to query with
    :param dq_num: dominant queue size
    :param cq_num: candidate queue size
    :return: (dominant_queue, candidate_queue), dicts keyed by row number
    """
    _, neighbour_ids = index.search(weights_set[:num_class], cq_num)

    # The dominant queue is the first dq_num entries of the candidate queue
    candidate_queue = {row: ids for row, ids in enumerate(neighbour_ids)}
    dominant_queue = {row: ids for row, ids in enumerate(neighbour_ids[:, :dq_num])}

    return dominant_queue, candidate_queue
287 |
288 |
def generator(X, y, label_id_dict, dominant_queue, num_class, batch_size):
    """
    Training sample generator.

    Every batch contains ``batch_size`` images of ``batch_size`` distinct
    classes: ``batch_size // 3`` randomly drawn positive classes plus classes
    taken from their dominant queues.

    :param X: images, indexable by an array of image ids
    :param y: labels (not used by this function)
    :param label_id_dict: dict mapping each class to the list of its image ids
    :param dominant_queue: dict mapping each class to its dominant classes
    :param num_class: number of classes
    :param batch_size: batch size
    :return: yields ``([images, weights, labels], np.arange(batch_size))``
    """
    while True:
        # Two-level sampling: first draw batch_size//3 positive classes without replacement
        sample_labels = np.random.choice(num_class, batch_size // 3, replace=False)

        selected_labels = set(sample_labels)  # classes already used in this batch
        selected_image_labels = []  # [image id, class] pairs of this batch
        # First pick one image for every positive class
        for label in sample_labels:
            selected_image_id = random_select(label_id_dict[label])
            selected_image_labels.append([selected_image_id, label])

        # Then add dominant classes until the mini-batch is full
        while len(selected_image_labels) < batch_size:
            # Randomly pick one of the positive classes
            label = random_select(sample_labels)
            # Only dominant classes that are not in the batch yet are eligible.
            # Filtering up front (instead of re-drawing until an unused class
            # shows up) cannot spin forever when the queue is exhausted.
            candidates = [c for c in dominant_queue[label] if c not in selected_labels]
            if not candidates:
                # Dominant queue of this class is used up: fall back to any unused class
                candidates = [c for c in range(num_class) if c not in selected_labels]
            if not candidates:
                raise ValueError("batch_size {} exceeds the number of classes {}".format(batch_size,
                                                                                     num_class))
            dq_label = random_select(candidates)
            selected_labels.add(dq_label)
            # Pick an image of the dominant class
            selected_image_id = random_select(label_id_dict[dq_label])
            selected_image_labels.append([selected_image_id, dq_label])

        selected_image_labels = np.asarray(selected_image_labels)
        current_selected_image_ids = selected_image_labels[:, 0]
        current_selected_labels = selected_image_labels[:, 1]
        current_weights = get_weights(config.pw_h5_file[config.prototype_weights_dataset],
                                      current_selected_labels)

        # Load the images by IMAGE id (column 0), not by class label
        images = X[current_selected_image_ids]
        # Labels are fed to the model as a 2-D array
        current_selected_labels = np.expand_dims(current_selected_labels, axis=1)

        yield [images,
               current_weights,
               current_selected_labels], np.arange(batch_size)  # targets are always 0~batch_size-1
337 |
338 |
class ExportWeights(keras.callbacks.Callback):
    """Callback that, after every batch, writes the trained prototypes back and updates the queues."""

    def on_batch_end(self, batch, logs=None):
        """
        Read the first three weights of the model's last layer (trained kernel,
        labels of the current step, predictions) and pass them to
        ``update_weights`` and ``update_queue``.
        :param batch: batch index (unused)
        :param logs: metric dict supplied by Keras (unused)
        """
        layer = self.model.layers[-1]
        # Kernel after this step, class labels trained in this step, predictions
        trained_weights, current_trained_labels, y_pred = layer.get_weights()[:3]
        trained_weights = np.transpose(trained_weights)
        # 2-D -> 1-D; the builtin int replaces np.int, which was only an alias
        # for it and has been removed from NumPy
        current_trained_labels = np.asarray(current_trained_labels[:, 0], dtype=int)
        # Write the trained weights back
        update_weights(config.pw_h5_file,
                       config.pw_h5_file[config.prototype_weights_dataset],
                       trained_weights,
                       current_trained_labels)
        # Update the dominant queue from the predictions
        update_queue(config.dominant_queue,
                     config.candidate_queue,
                     y_pred,
                     current_trained_labels)
356 |
357 |
def get_call_back():
    """
    Build the training callbacks.
    :return: list ``[ReduceLROnPlateau, ExportWeights, TensorBoard]``
    """
    # Reduce the learning rate when the training accuracy stops improving
    reduce_lr = ReduceLROnPlateau(monitor='acc',
                                  factor=np.sqrt(0.1),
                                  cooldown=1,
                                  patience=1,
                                  min_lr=0)
    tensorboard = TensorBoard(log_dir='log')
    return [reduce_lr, ExportWeights(), tensorboard]
376 |
377 |
def source_model(img_inputs, num_class, use_amsoftmax=False):
    """
    Build and compile the source-domain classifier on top of ``cifar_base_net``.
    :param img_inputs: Keras input tensor for the images
    :param num_class: number of output classes
    :param use_amsoftmax: use ``DenseWithAMSoftmaxLoss`` and its loss instead of a
        softmax ``Dense`` layer with categorical cross-entropy
    :return: compiled Keras model (Adam optimizer, accuracy metric)
    """
    print("use_amsoftmax:{}".format(use_amsoftmax))
    if not use_amsoftmax:
        head = layers.Dense(num_class, use_bias=False, activation='softmax', name='dense')
        loss_fn = K.categorical_crossentropy
    else:
        head = DenseWithAMSoftmaxLoss(num_class, m=0.2, scale=24., name='amsoftmax')
        loss_fn = head.loss

    model = Model(img_inputs, head(cifar_base_net(img_inputs)))
    model.compile(optimizer=keras.optimizers.Adam(lr=0.001, decay=1e-6),
                  loss=loss_fn,
                  metrics=['accuracy'])
    return model
396 |
397 |
def evaluate(pretrained_weights, X, y, num_class, use_amsoftmax=False):
    """
    Evaluate classification accuracy, using the stored prototypes as the classifier weights.
    :param pretrained_weights: weights file loaded into the model (``by_name=False``)
    :param X: images to classify
    :param y: true class of each image, compared element-wise with the arg-max prediction
    :param num_class: number of classes
    :param use_amsoftmax: use ``DenseWithAMSoftmaxLoss`` as the last layer instead of a softmax ``Dense``
    :return: None; prints the model summary and ``acc:<accuracy>``
    """
    if use_amsoftmax:
        dense = DenseWithAMSoftmaxLoss(num_class, m=0.2, scale=24., name='other')
    else:
        dense = layers.Dense(num_class, use_bias=False, activation='softmax')
    # Build the model
    inputs = Input(shape=config.input_shape)
    features = cifar_base_net(inputs)
    outputs = dense(features)
    m = Model(inputs, outputs)

    m.load_weights(pretrained_weights, by_name=False)

    # Read the prototypes; the with-block closes the file even if reading fails
    with h5py.File(config.prototype_weights_hdf5, 'r') as f_label:
        weights = f_label[config.prototype_weights_dataset][:]  # [:] converts to numpy
    # Use the prototypes as the weights of the classification layer
    m.layers[-1].set_weights([np.transpose(weights)])
    m.summary()

    # Predict and score
    y_pred = m.predict(X)
    y_pred = np.argmax(y_pred, axis=-1)

    print("acc:{}".format(np.mean(y == y_pred)))
424 |
425 |
def img_show(X, rows, cols):
    """
    Show the first ``rows * cols`` images of ``X`` in a grid.
    :param X: sequence of images accepted by ``plt.imshow``
    :param rows: number of grid rows
    :param cols: number of grid columns
    """
    plt.figure(figsize=(cols * 1.5, rows))
    shown = X[:rows * cols]
    # Subplot positions are 1-based
    for position, img in enumerate(shown, start=1):
        plt.subplot(rows, cols, position)
        plt.imshow(img)
    plt.show()
432 |
433 |
def main(args):
    """
    Run the prototype initialisation and/or the training stage on CIFAR-100.
    :param args: object with a ``stages`` attribute; the 'init' and 'train' stages are recognised
    """
    # Imported locally: at module level keras.datasets is only imported inside
    # the __main__ block, so the name would otherwise be undefined here.
    from keras.datasets import cifar100

    K.clear_session()
    set_gpu_growth()

    # config.num_classes is 100, so the 100-class dataset is required
    (x_train, y_train), (x_test, y_test) = cifar100.load_data()
    # Group image ids by class; load_data returns labels of shape (N, 1)
    label_id_dict = label_id_map(y_train[:, 0], config.num_classes)

    # Initialise the prototype weights
    if 'init' in args.stages:
        # NOTE(review): init_prototype builds cifar_base_net while training below
        # uses config.backbone - confirm both are meant to share these weights.
        init_prototype(config.backbone_weights,
                       x_train,
                       label_id_dict,
                       config.num_classes)
        print("初始化原型权重完成... ...")

    # Training stage
    if 'train' in args.stages:
        config.pw_h5_file, config.index = get_prototype(1000)
        try:
            print("获取原型权重完成... ...")

            # Initialise the queues
            config.dominant_queue, config.candidate_queue = init_queue(
                config.index,
                config.pw_h5_file[config.prototype_weights_dataset],
                config.num_classes,
                config.dominant_queue_num,
                config.candidate_queue_num)
            print("初始化队列完成... ...")

            # Build the model
            inputs = Input(batch_shape=(config.batch_size,) + config.input_shape)
            weights_inputs = Input(batch_shape=(config.batch_size, 512))
            label_inputs = Input(batch_shape=(config.batch_size, 1))  # must be at least 2-D
            features = config.backbone(inputs)
            dense = DenseWithDPSoftmaxLoss(config.batch_size)  # batch size acts as the class count
            outputs = dense([features, weights_inputs, label_inputs])
            m = Model([inputs, weights_inputs, label_inputs], outputs)
            m.summary()
            m.load_weights(config.backbone_weights, by_name=True)
            m.compile(loss=dense.loss,
                      optimizer=SGD(lr=0.01, momentum=0.9, decay=0.0005),
                      metrics=['accuracy'])

            # Train the model
            print("开始训练模型... ...")
            gen = generator(x_train,
                            y_train,
                            label_id_dict,
                            config.dominant_queue,
                            config.num_classes,
                            config.batch_size)

            m.fit_generator(gen,
                            callbacks=get_call_back(),
                            # integer step count (a quarter of the images per epoch)
                            steps_per_epoch=len(x_train) // 4 // config.batch_size,
                            epochs=10,
                            use_multiprocessing=False,
                            verbose=1,
                            validation_data=next(gen)
                            )
        finally:
            # Always close the hdf5 file, even when training fails
            config.pw_h5_file.close()
496 |
497 |
if __name__ == '__main__':
    # Train the source model on CIFAR-10.
    # Usage: first command-line argument '1' enables AM-Softmax; when the
    # argument is missing the plain softmax classifier is used instead of
    # failing with an IndexError.
    img_input = Input(shape=(32, 32, 3))
    use_amsoftmax = len(sys.argv) > 1 and sys.argv[1] == '1'
    m = source_model(img_input, 10, use_amsoftmax)

    m.summary()

    from keras.datasets import cifar10, cifar100

    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    # One-hot targets for the categorical losses
    y_train = keras.utils.to_categorical(y_train, 10)
    y_test = keras.utils.to_categorical(y_test, 10)
    m.fit(x_train, y_train,
          batch_size=64,
          epochs=10,
          validation_data=(x_test, y_test),
          # Keep a checkpoint whenever the validation accuracy improves
          callbacks=[keras.callbacks.ModelCheckpoint(filepath='/tmp/source.weights.{epoch:03d}.h5',
                                                     monitor='val_acc',
                                                     verbose=1,
                                                     save_best_only=True)],
          shuffle=True)
    m.save('/tmp/source.weights.h5')
521 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: config
4 | Description : 配置文件
5 | Author : mick.yi
6 | date: 2019/1/3
7 | """
8 | from .backbone import resnet50
9 |
10 |
class Config(object):
    """Default settings: dataset paths, input shape, training parameters and prototype/queue state."""
    #################
    # basenet stage
    #################
    # MS1M face dataset: annotation file and image directory
    ms1m_annotation_file = '/home/dataset/face_recognize/ms1m_112_112.label'
    ms1m_img_dir = '/home/dataset/face_recognize/ms1m_112_112'
    # Input information
    input_shape = (112, 112, 3)
    num_classes = 2000  # can be lowered for testing; 85164 (presumably the full class count - TODO confirm)

    # Network structure
    def backbone(self, inputs):
        """Return the feature tensor produced by ``resnet50`` for ``inputs``."""
        return resnet50(inputs)

    # Training parameters
    batch_size = 48

    lr = 0.1
    # Learning rate per key; keys are presumably training steps - verify against the caller
    learning_rate_schedule = {
        0: 1 * lr,
        160000: 0.1 * lr,
        240000: 0.01 * lr,
        280000: 0.001 * lr
    }

    backbone_weights = '/tmp/docface.basenet.002.h5'

    # Other parameters
    prototype_weights_hdf5 = '/tmp/prototype_weights.hdf5'
    prototype_weights_dataset = 'prototype_weights_set'
    pw_h5_file = None  # open hdf5 File object (set at runtime)

    index = None  # faiss index holding the prototype of each class

    dominant_queue = None  # dominant queue
    candidate_queue = None  # candidate queue

    dominant_queue_num = 10  # dominant queue size
    candidate_queue_num = 30  # candidate queue size

    # Evaluation stage parameters
    lbl_weights = '/tmp/lbl-IvS.010.h5'


# Current configuration
current_config = Config()
58 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: evaluate
4 | Description : lbl-IvS评估
5 | Author : mick.yi
6 | date: 2019/1/10
7 | """
8 | import keras
9 | from keras import Input
10 | from keras.layers import Dense
11 | from keras.models import Model
12 | import numpy as np
13 | import sys
14 | import os
15 | import h5py
16 |
17 | if __package__ is None:
18 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
19 | __package__ = "keras-lbl-IvS"
20 |
21 | from .config import current_config as config
22 | from .reader import get_mslm_infos, load_img
23 |
24 |
def gen_data(images_info, batch_size):
    """
    Endless generator of random evaluation batches.
    :param images_info: sequence of dicts with the keys 'img_path' and 'label'
    :param batch_size: number of distinct entries drawn per batch
    :return: yields ``(images, labels)`` as numpy arrays
    """
    total = len(images_info)
    while True:
        # Draw batch_size distinct entries
        picked = [images_info[i] for i in np.random.choice(total, batch_size, replace=False)]
        images = np.asarray([load_img(info['img_path']) for info in picked])
        labels = np.asarray([int(info['label']) for info in picked])
        yield images, labels
35 |
36 |
def main():
    """Evaluate the trained model on MS1M, using the stored prototypes as classifier weights, and print the scores."""
    # Build the model. The batch dimension is left unspecified because the
    # evaluation batches (200 images) differ from config.batch_size.
    inputs = Input(shape=config.input_shape)
    features = config.backbone(inputs)
    dense = Dense(config.num_classes,
                  use_bias=False,
                  activation='softmax')  # final classification layer; softmax is needed because the loss is not from_logits
    outputs = dense(features)
    m = Model(inputs, outputs)
    m.compile(optimizer='SGD',
              loss=keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
    m.load_weights(config.lbl_weights, by_name=True)

    # Read the prototypes; the with-block makes sure the file is closed
    with h5py.File(config.prototype_weights_hdf5, 'r') as f_label:
        weights = f_label[config.prototype_weights_dataset][:]  # [:] converts to numpy
    # Use the prototypes as the weights of the classification layer
    m.layers[-1].set_weights([np.transpose(weights)])
    m.summary()

    # Load the data
    images_info, label_set = get_mslm_infos(config.ms1m_annotation_file, config.ms1m_img_dir)
    # Keep only the classes below num_classes (for testing)
    images_info = [info for info in images_info if int(info['label']) < config.num_classes]

    # Evaluate
    scores = m.evaluate_generator(gen_data(images_info, 200),
                                  steps=10,
                                  verbose=1)
    print(scores)


if __name__ == '__main__':
    main()
72 |
--------------------------------------------------------------------------------
/faiss_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: faiss_utils
4 | Description : faiss工具类
5 | Author : mick.yi
6 | date: 2019/1/4
7 | """
8 | import faiss
9 | import numpy as np
10 |
11 |
def get_index(dimension):
    """
    Create a flat L2 faiss index wrapped in an ``IndexIDMap``.
    :param dimension: dimension of the indexed vectors
    :return: the ``IndexIDMap`` instance
    """
    return faiss.IndexIDMap(faiss.IndexFlatL2(dimension))
16 |
17 |
def update_multi(index, vectors, ids):
    """
    Replace the vectors stored under ``ids`` in ``index``.
    :param index: object providing ``remove_ids`` and ``add_with_ids``
    :param vectors: numpy array with one vector per id
    :param ids: numpy array of ids, in the same order as ``vectors``
    :return: None
    Note: ValueError: array is not C-contiguous
    """
    order = np.argsort(ids)
    sorted_ids = ids[order]
    # Remove first, then add again in ascending id order
    index.remove_ids(sorted_ids)
    index.add_with_ids(vectors[order], sorted_ids)
31 |
32 |
def update_one(index, vector, label_id):
    """
    Replace the single vector stored under ``label_id``.
    :param index: index passed through to ``update_multi``
    :param vector: one vector; a leading axis is added before the update
    :param label_id: id of the vector
    """
    update_multi(index,
                 np.expand_dims(vector, axis=0),
                 np.array([label_id]))
37 |
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: layers
4 | Description : keras 层
5 | Author : mick.yi
6 | date: 2019/1/2
7 | """
8 | from keras.layers import Layer
9 | from keras import backend as K
10 |
11 |
class DenseWithDPSoftmaxLoss(Layer):
    """Cosine-similarity classification layer whose kernel is overwritten from an input.

    The layer is called with three inputs ``[features, weights, labels]``.
    ``call`` copies the transposed ``weights`` input into the kernel and the
    ``labels`` input into a non-trainable weight, then returns the dot product
    of the L2-normalised features and the L2-normalised kernel. ``loss`` also
    stores the predictions in a non-trainable weight, so that all three can be
    read back with ``get_weights()``.
    """

    def __init__(self, num_class, m=0.35, scale=30, **kwargs):
        """
        :param num_class: output dimension of the layer
        :param m: margin subtracted from the true-class entry in ``loss``
        :param scale: factor applied to the logits in ``loss``
        :param kwargs: forwarded to ``Layer.__init__``
        """
        self.output_dim = num_class
        self.margin = m
        self.scale = scale
        # self.current_selected_labels = K.ones(shape=(num_class, 1))
        super(DenseWithDPSoftmaxLoss, self).__init__(**kwargs)

    def build(self, input_shape):
        # input_shape is a list of shapes; input_shape[0] belongs to the features input
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[0][1], self.output_dim),  # (embedding dim, num_class)
                                      dtype=K.floatx(),
                                      initializer='glorot_normal',
                                      trainable=True)
        # Non-trainable slot that receives the labels input on every call
        self.current_selected_labels = self.add_weight(name='labels',
                                                       shape=(self.output_dim, 1),
                                                       initializer='glorot_normal',
                                                       trainable=False)
        # Non-trainable slot that receives the predictions inside loss()
        self.y_pred = self.add_weight(name='pred',
                                      shape=(self.output_dim, self.output_dim),
                                      initializer='glorot_normal',
                                      trainable=False)

    def call(self, inputs, **kwargs):
        # Assign the weights of the classes used in the current step to the kernel
        weights_assign_op = K.tf.assign(self.kernel,
                                        K.transpose(inputs[1]),
                                        name='assign_weights')
        label_assign_op = K.tf.assign(self.current_selected_labels, inputs[2], name='assign_labels')
        # The ops below are created under a control dependency on both assignments
        with K.tf.control_dependencies([weights_assign_op, label_assign_op]):
            self.x_norm = K.l2_normalize(inputs[0], axis=1)
            self.kernel_norm = K.l2_normalize(self.kernel, axis=0)
            self.logit = K.dot(self.x_norm, self.kernel_norm)
        return self.logit

    def compute_output_shape(self, input_shape):
        # (batch size of the features input, num_class)
        return (input_shape[0][0],
                self.output_dim)
        # return [(input_shape[0][0], self.output_dim),
        #         (self.output_dim, input_shape[0][1])]

    def loss(self, y_true, y_pred):
        """
        Margin-adjusted, scaled sparse categorical cross-entropy on logits.
        :param y_true: class ids, two-dimensional; only column 0 is used
        :param y_pred: output of ``call``
        :return: loss tensor
        """
        # First store the predictions in the non-trainable weight
        pred_assign_op = K.tf.assign(self.y_pred,
                                     y_pred,
                                     name='assign_pred')
        with K.tf.control_dependencies([pred_assign_op]):
            y_true = y_true[:, 0]  # very important: y_true is two-dimensional by default
            y_true_mask = K.one_hot(K.tf.cast(y_true, dtype='int32'), self.output_dim)
            cosine_m = y_pred - y_true_mask * self.margin  # cosine-m
            losses = K.sparse_categorical_crossentropy(target=y_true,
                                                       output=cosine_m * self.scale,
                                                       from_logits=True)

        return losses
67 |
--------------------------------------------------------------------------------
/log/events.out.tfevents.1547478727.localhost.localdomain:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-lbl-IvS/3f98b698c56ae40954b4920da167f7c9e32024c8/log/events.out.tfevents.1547478727.localhost.localdomain
--------------------------------------------------------------------------------
/pic/acc.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-lbl-IvS/3f98b698c56ae40954b4920da167f7c9e32024c8/pic/acc.jpg
--------------------------------------------------------------------------------
/pic/loss.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-lbl-IvS/3f98b698c56ae40954b4920da167f7c9e32024c8/pic/loss.jpg
--------------------------------------------------------------------------------
/pic/lr.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-lbl-IvS/3f98b698c56ae40954b4920da167f7c9e32024c8/pic/lr.jpg
--------------------------------------------------------------------------------
/pic/output_21_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-lbl-IvS/3f98b698c56ae40954b4920da167f7c9e32024c8/pic/output_21_0.png
--------------------------------------------------------------------------------
/pic/output_8_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yizt/keras-lbl-IvS/3f98b698c56ae40954b4920da167f7c9e32024c8/pic/output_8_0.png
--------------------------------------------------------------------------------
/pipeline.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: pipeline
4 | Description : 处理流程
5 | Author : mick.yi
6 | date: 2019/1/3
7 | """
8 | import tensorflow as tf
9 | import numpy as np
10 | import h5py
11 | import keras
12 | from keras.layers import Input
13 | from keras.models import Model
14 | from keras.optimizers import SGD
15 | import keras.backend as K
16 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, TensorBoard
17 | import sys
18 | import os
19 | import time
20 | import argparse
21 |
# When the module has no package context, put the parent directory on sys.path
# and set __package__ before the relative imports that follow.
# NOTE(review): presumably this is meant to let the file run directly as a
# script; confirm that the hyphenated name resolves as a package.
if __package__ is None:
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
    __package__ = "keras-lbl-IvS"
25 |
26 | from .config import current_config as config
27 | from .utils import random_select, get_weights, update_weights, update_queue
28 | from .faiss_utils import get_index, update_multi
29 | from .reader import get_mslm_infos, load_img
30 | from .layers import DenseWithDPSoftmaxLoss
31 |
32 |
def set_gpu_growth():
    """
    Restrict the process to GPU 0 and install a Keras session whose
    TensorFlow config has ``gpu_options.allow_growth`` enabled.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True
    K.set_session(tf.Session(config=session_config))
39 |
40 |
def generator(images_info, label_id_dict, dominant_queue, num_class, batch_size):
    """
    Endless generator of training mini-batches.

    :param images_info: image metadata, indexable by image id
    :param label_id_dict: mapping from class label to the ids of its images
    :param dominant_queue: mapping from class label to its dominant classes
    :param num_class: number of classes
    :param batch_size: batch_size
    :return: yields ``([images, weights, labels], np.arange(batch_size))``
    """
    num_positive = batch_size // 2
    while True:
        # Two-level sampling: first draw batch_size/2 positive classes
        # without replacement, then fill up with their dominant classes.
        positive_labels = np.random.choice(num_class, num_positive, replace=False)
        used_labels = set(positive_labels)  # every class already in this batch

        # One random image per positive class, stored as [image id, label]
        pairs = [[random_select(label_id_dict[lbl]), lbl] for lbl in positive_labels]

        # Add dominant classes until the mini-batch is full
        while len(pairs) < batch_size:
            anchor = random_select(positive_labels)
            # Redraw until the dominant class has not been used in this batch.
            # NOTE(review): this never terminates if every class in the
            # anchor's dominant queue is already selected.
            candidate = random_select(dominant_queue[anchor])
            while candidate in used_labels:
                candidate = random_select(dominant_queue[anchor])
            used_labels.add(candidate)
            pairs.append([random_select(label_id_dict[candidate]), candidate])

        pairs = np.asarray(pairs)
        batch_labels = pairs[:, 1]
        batch_weights = get_weights(config.pw_h5_file[config.prototype_weights_dataset],
                                    batch_labels)

        # Load the images of the batch
        images = np.asarray([load_img(images_info[img_id]['img_path']) for img_id in pairs[:, 0]])

        # The training target is always 0 ~ batch_size-1
        yield [images,
               batch_weights,
               np.expand_dims(batch_labels, axis=1)], np.arange(batch_size)
90 |
91 |
def init_queue(index, weights_set, num_class, dq_num, cq_num):
    """
    Initialise the candidate queues and the dominant queues.

    For every class, the ``cq_num`` search results returned by ``index``
    form its candidate queue; the first ``dq_num`` of them form its dominant
    queue.

    :param index: object exposing ``search(vectors, k)`` that returns a
        ``(distances, ids)`` pair
    :param weights_set: h5py dataset object holding the prototype weights
    :param num_class: number of classes
    :param dq_num: length of each dominant queue
    :param cq_num: length of each candidate queue
    :return: ``(dominant_queue, candidate_queue)``, two dicts keyed by class
        index whose values are arrays of class indices
    """
    _, candidate_label_idx = index.search(weights_set[:num_class], cq_num)
    # Copy the leading columns: a plain slice would be a view, and the in-place
    # update of a dominant queue would then overwrite the candidate queue too.
    dominant_label_idx = candidate_label_idx[:, :dq_num].copy()

    # Convert to dicts keyed by class index
    dominant_queue = dict(enumerate(dominant_label_idx))
    candidate_queue = dict(enumerate(candidate_label_idx))

    return dominant_queue, candidate_queue
110 |
111 |
def init_prototype(images_info, label_id_dict, num_class):
    """
    Compute the initial prototype weight of every class and store it in hdf5.

    The prototype of a class is the normalised mean of the normalised
    backbone features of all its images.

    :param images_info: image metadata, indexable by image id
    :param label_id_dict: mapping from class label to the ids of its images
    :param num_class: number of classes
    :return: None
    """
    inputs = Input(batch_shape=(config.batch_size,) + config.input_shape)
    model = Model(inputs, config.backbone(inputs))
    model.load_weights(config.backbone_weights, by_name=True)

    # The context manager closes the file even if a class fails part-way through.
    with h5py.File(config.prototype_weights_hdf5, 'w') as f_label:
        label_feature = f_label.create_dataset(config.prototype_weights_dataset,
                                               shape=(num_class, 512), dtype='f')

        # Process the classes one by one
        for label in range(num_class):
            image_ids = label_id_dict[label]
            images = [load_img(images_info[image_id]['img_path']) for image_id in image_ids]
            features = model.predict(np.asarray(images))
            features = keras.utils.np_utils.normalize(features)
            prototype = np.mean(features, axis=0)
            # Normalising again yields a two-dimensional (1, d) array
            prototype = keras.utils.np_utils.normalize(prototype)

            label_feature[label] = prototype[0]  # (1, d) -> (d,)

            # Flush to disk and report progress every 500 classes
            if label % 500 == 0:
                f_label.flush()
                print("{} init_prototype 完成:{}".format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                                                       label))
141 |
142 |
def get_prototype(deal_batch=1000):
    """
    Open the prototype hdf5 file and load all prototypes into a new index.

    :param deal_batch: number of prototypes passed to the index per call
    :return: ``(f_label, index)``; the hdf5 file is returned open
    """
    f_label = h5py.File(config.prototype_weights_hdf5, 'r+')
    dset = f_label[config.prototype_weights_dataset]
    length = len(dset)
    index = get_index(512)

    full_batches, remainder = divmod(length, deal_batch)

    # Full batches first
    for batch_no in range(full_batches):
        start = batch_no * deal_batch
        end = start + deal_batch
        update_multi(index, dset[start:end], np.arange(start, end))

    # Then the tail when the length is not a multiple of deal_batch
    if remainder != 0:
        start = length - remainder
        update_multi(index, dset[start:length], np.arange(start, length))

    return f_label, index
161 |
162 |
def label_id_map(images_info, num_class):
    """
    Group image ids by class label.

    :param images_info: image records indexable by 0..len-1, each of the form
        {'img_path': image path, 'label': class, 'img_id': image id}
    :param num_class: number of classes
    :return: dict mapping every label in ``range(num_class)`` to the list of
        its image ids (empty list when the class has no image)
    """
    label_id_dict = {label: [] for label in range(num_class)}

    for position in range(len(images_info)):
        info = images_info[position]
        label_id_dict[int(info['label'])].append(info['img_id'])

    return label_id_dict
182 |
183 |
class ExportWeights(keras.callbacks.Callback):
    """
    Callback that, after every batch, writes the trained prototype weights
    back to the faiss index and the hdf5 file and updates the dominant queues.
    """

    def on_batch_end(self, batch, logs=None):
        """
        Export the state of the last layer of the model.

        :param batch: index of the batch (unused)
        :param logs: metric values of the batch (unused)
        """
        layer = self.model.layers[-1]

        # The first three weights of the last layer are read as: kernel,
        # labels of the current step, predictions of the current step.
        trained_weights, current_trained_labels, y_pred = layer.get_weights()[:3]
        trained_weights = np.transpose(trained_weights)
        # Builtin int replaces the ``np.int`` alias, which NumPy 1.24 removed.
        current_trained_labels = np.asarray(current_trained_labels[:, 0], dtype=int)

        update_multi(config.index, trained_weights, current_trained_labels)  # update the faiss index
        update_weights(config.pw_h5_file,
                       config.pw_h5_file[config.prototype_weights_dataset],
                       trained_weights,
                       current_trained_labels)
        # Update the dominant queues from the predictions. They are read from
        # the layer weights, so no extra forward pass (model.predict) is needed.
        update_queue(config.dominant_queue,
                     config.candidate_queue,
                     y_pred,
                     current_trained_labels)
209 |
210 |
def get_call_back():
    """
    Build the training callbacks.

    :return: list of callbacks in the order checkpoint, learning-rate
        reducer, weight exporter, TensorBoard logger
    """
    return [
        # Save a checkpoint after every epoch
        ModelCheckpoint(filepath='/tmp/lbl-IvS.{epoch:03d}.h5',
                        monitor='val_acc',
                        verbose=1,
                        save_best_only=False),
        # Lower the learning rate when the validation loss stops improving
        ReduceLROnPlateau(monitor='val_loss',
                          factor=np.sqrt(0.1),
                          cooldown=1,
                          patience=1,
                          min_lr=0),
        ExportWeights(),
        TensorBoard(log_dir='log'),
    ]
233 |
234 |
def main(args):
    """
    Run the requested pipeline stages.

    :param args: parsed command-line arguments; ``args.stages`` lists the
        stages to run ('init' and/or 'train')
    :return: None
    """
    K.clear_session()
    set_gpu_growth()
    # Read the image metadata
    images_info, _ = get_mslm_infos(config.ms1m_annotation_file, config.ms1m_img_dir)
    # Keep only the classes below num_classes; used for testing
    images_info = [info for info in images_info if int(info['label']) < config.num_classes]
    # Renumber the kept images so that 'img_id' equals the key used below;
    # otherwise ids of filtered-out images leave gaps and lookups by img_id
    # hit the wrong record or raise KeyError.
    for new_id, info in enumerate(images_info):
        info['img_id'] = new_id
    images_info = dict(enumerate(images_info))

    label_id_dict = label_id_map(images_info, config.num_classes)  # label -> image ids

    # Initialise the prototype weights
    if 'init' in args.stages:
        init_prototype(images_info,
                       label_id_dict,
                       config.num_classes)
        print("初始化原型权重完成... ...")

    # Training stage
    if 'train' in args.stages:
        config.pw_h5_file, config.index = get_prototype(1000)
        try:
            print("获取原型权重完成... ...")

            # Initialise the queues
            config.dominant_queue, config.candidate_queue = init_queue(
                config.index,
                config.pw_h5_file[config.prototype_weights_dataset],
                config.num_classes,
                config.dominant_queue_num,
                config.candidate_queue_num)
            print("初始化队列完成... ...")

            # Build the model
            inputs = Input(batch_shape=(config.batch_size,) + config.input_shape)
            weights_inputs = Input(batch_shape=(config.batch_size, 512))
            label_inputs = Input(batch_shape=(config.batch_size, 1))  # must be at least 2-D
            features = config.backbone(inputs)
            dense = DenseWithDPSoftmaxLoss(config.batch_size)  # batch size acts as the class count
            outputs = dense([features, weights_inputs, label_inputs])
            m = Model([inputs, weights_inputs, label_inputs], outputs)
            m.summary()
            m.load_weights(config.backbone_weights, by_name=True)
            m.compile(loss=dense.loss, optimizer=SGD(lr=0.01, momentum=0.9, decay=0.0005), metrics=['accuracy'])

            # Train the model
            print("开始训练模型... ...")
            gen = generator(images_info,
                            label_id_dict,
                            config.dominant_queue,
                            config.num_classes,
                            config.batch_size)

            m.fit_generator(gen,
                            callbacks=get_call_back(),
                            # integer division keeps the step count an int
                            steps_per_epoch=len(images_info) // 4 // config.batch_size,
                            epochs=10,
                            use_multiprocessing=False,
                            verbose=1,
                            validation_data=next(gen)
                            )
        finally:
            # Always close the hdf5 file, even when training fails
            config.pw_h5_file.close()
301 |
302 |
# Script entry point: parse the requested stages and run the pipeline.
if __name__ == '__main__':
    parse = argparse.ArgumentParser()
    # --stages takes one or more values and defaults to ['train']
    parse.add_argument("--stages", type=str, nargs='+', default=['train'], help="stage: init、train")
    argments = parse.parse_args(sys.argv[1:])
    # print(argments)
    # usage: python pipeline.py --stages init train | python pipeline.py --stages train
    main(argments)
310 |
--------------------------------------------------------------------------------
/reader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: reader
4 | Description : 读取图像信息
5 | Author : mick.yi
6 | date: 2018/12/26
7 | """
8 | import os
9 | import codecs
10 | import matplotlib.pyplot as plt
11 |
12 |
def get_mslm_infos(annotation_file, img_dir):
    """
    Read the mslm dataset annotations.

    Every non-blank line of the annotation file has the form
    ``<image name>\\t<label>``.

    :param annotation_file: path of the annotation file
    :param img_dir: directory containing the images
    :return: ``(img_infos, label_set)``; ``img_infos`` is a list of dicts
        {'img_path', 'label', 'img_id'} where ``img_id`` is the position of
        the record in the list, and ``label_set`` is the set of label strings
    :raises ValueError: if a non-blank line does not hold exactly two
        tab-separated fields
    """
    with codecs.open(annotation_file, 'r', 'utf-8') as f:
        lines = f.readlines()

    img_infos = []
    label_set = set()
    for line in lines:
        line = line.rstrip('\r\n')
        # Tolerate blank lines, e.g. a trailing empty line at the end of the file
        if not line.strip():
            continue
        img_name, label = line.split('\t')
        # Strip the label so the same class never appears both as "5\n" and "5"
        label = label.strip()
        img_infos.append({
            'img_path': os.path.join(img_dir, img_name),
            'label': label,
            'img_id': len(img_infos),  # image id
        })
        label_set.add(label)

    return img_infos, label_set
35 |
36 |
def load_img(img_path):
    """
    Read an image and keep its first three channels.

    :param img_path: path of the image file
    :return: the image array restricted to ``[:, :, :3]``
    """
    return plt.imread(img_path)[:, :, :3]
40 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | File Name: utils
4 | Description :
5 | Author : mick.yi
6 | date: 2019/1/4
7 | """
8 | import numpy as np
9 |
10 |
def enqueue(np_array, elem):
    """
    Push an element onto the front of a fixed-length queue, in place.

    Every existing element moves one position towards the tail and the last
    one is dropped.

    :param np_array: the queue
    :param elem: the element to add
    :return: the same queue object, after modification
    """
    all_but_last = np_array[:-1]
    np_array[1:] = all_but_last
    np_array[0] = elem
    return np_array
21 |
22 |
def random_select(ids):
    """
    Pick one element uniformly at random.

    :param ids: sequence of ids, (N,)
    :return: the element at a position drawn with ``np.random.choice``
    """
    return ids[np.random.choice(len(ids))]
31 |
32 |
33 | # def to_train_label(train_label):
34 |
35 |
def update_weights(h5_file, h5_dataset, weights, labels):
    """
    Write updated prototype weights into the hdf5 dataset and flush the file.

    :param h5_file: hdf5 file holding the prototype weights
    :param h5_dataset: dataset of the prototype weights inside the file
    :param weights: weights to write, numpy array (Batch, Dim)
    :param labels: class labels the weights belong to
    :return: None
    Notes: the rows are written in increasing label order because the
        original author hit "TypeError: Indexing elements must be in
        increasing order"; the labels are passed as a list because of
        "TypeError: PointSelection __getitem__ only works with bool arrays".
    """
    order = np.argsort(labels)
    sorted_labels = list(labels[order])
    sorted_weights = weights[order]
    h5_dataset[sorted_labels] = sorted_weights
    h5_file.flush()
52 |
53 |
def get_weights(h5_dataset, labels):
    """
    Fetch the prototype weights of the given labels.

    :param h5_dataset: dataset indexable by a single label
    :param labels: labels to read, in the desired output order
    :return: numpy array holding one dataset row per label
    """
    rows = []
    for label in labels:
        rows.append(h5_dataset[label])
    return np.asarray(rows)
57 |
58 |
def update_queue(dominant_queue, candidate_queue, predict, current_labels):
    """
    Update the dominant queues from the predictions of the current batch.

    :param dominant_queue: dominant queues, keyed by class label
    :param candidate_queue: candidate queues, keyed by class label
    :param predict: predicted scores, numpy array (Batch, train_num_class)
    :param current_labels: real class labels of the batch (Batch,)
    :return: None; ``dominant_queue`` is modified in place
    """
    predicted_positions = np.argmax(predict, axis=-1)
    for row, position in enumerate(predicted_positions):
        true_label = current_labels[row]
        dominants = dominant_queue[true_label]
        candidates = candidate_queue[true_label]
        predicted_label = current_labels[position]

        # Only a wrong prediction can change the dominant queue
        if position == row:
            continue
        # Nothing to do when the predicted class is already dominant
        if predicted_label in dominants:
            continue
        # Only classes from the candidate queue may be promoted
        if predicted_label not in candidates:
            continue

        dominant_queue[true_label] = enqueue(dominants, predicted_label)
79 |
--------------------------------------------------------------------------------