├── .idea
├── deployment.xml
├── misc.xml
├── modules.xml
├── src.iml
└── workspace.xml
├── Groupnormalization.py
├── Network.py
├── README.md
├── Save_path.py
├── __init__.py
├── aug.py
├── data_gen.py
├── deploy_scripts
├── config.json
└── customize_service.py
├── efficientnet
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-36.pyc
│ ├── __version__.cpython-36.pyc
│ ├── keras.cpython-36.pyc
│ ├── model.cpython-36.pyc
│ ├── preprocessing.cpython-36.pyc
│ └── tfkeras.cpython-36.pyc
├── __version__.py
├── keras.py
├── model.py
├── preprocessing.py
└── tfkeras.py
├── eval.py
├── mean_std.py
├── mean_std.txt
├── models
├── __init__.py
└── resnet50.py
├── pip-requirements.txt
├── run.py
├── save_model.py
├── train.py
├── train.txt
├── tta_wrapper
├── __init__.py
├── __version__.py
├── augmentation.py
├── functional.py
├── layers.py
└── wrappers.py
└── warmup_cosine_decay_scheduler.py
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/src.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
89 |
90 |
91 |
92 | save_weights_only
93 | Image
94 | ModelCheckpoint
95 | multiprocessing
96 | WEIGHTS_PATH
97 | _preprocess_symbolic_input
98 | WEIGHTS_PATH_NO_TOP
99 | CLASS_INDEX_PATH
100 | crossentropy
101 | file
102 | input_size
103 | classes
104 | Conv2dBn
105 | BaseSequence
106 | train_data_dir
107 | self.train
108 | restore_model_path
109 | FLAGS.train_local
110 | deploy_script_path
111 | groups
112 | self.groups
113 | pooling
114 | data_flow
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 | true
156 | DEFINITION_ORDER
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 | 1566114479120
299 |
300 |
301 | 1566114479120
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 | file://$PROJECT_DIR$/train.py
368 | 1
369 |
370 |
371 |
372 | file://$PROJECT_DIR$/Helpers/Inception_B.py
373 | 7
374 |
375 |
376 |
377 | file://G:/Google下载/tta_wrapper-master/setup.py
378 | 6
379 |
380 |
381 |
382 |
383 |
384 |
385 |
386 |
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
397 |
398 |
399 |
400 |
401 |
402 |
403 |
404 |
405 |
406 |
407 |
408 |
409 |
410 |
411 |
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
433 |
434 |
435 |
436 |
437 |
438 |
439 |
440 |
441 |
442 |
443 |
444 |
445 |
446 |
447 |
448 |
449 |
450 |
451 |
452 |
453 |
454 |
455 |
456 |
457 |
458 |
459 |
460 |
461 |
462 |
463 |
464 |
465 |
466 |
467 |
468 |
469 |
470 |
471 |
472 |
473 |
474 |
475 |
476 |
477 |
478 |
479 |
480 |
481 |
482 |
483 |
484 |
485 |
486 |
487 |
488 |
489 |
490 |
491 |
492 |
493 |
494 |
495 |
496 |
497 |
498 |
499 |
500 |
501 |
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 |
511 |
512 |
513 |
514 |
515 |
516 |
517 |
518 |
519 |
520 |
521 |
522 |
523 |
524 |
525 |
526 |
527 |
528 |
529 |
530 |
531 |
532 |
533 |
534 |
535 |
536 |
537 |
538 |
539 |
540 |
541 |
542 |
543 |
544 |
545 |
546 |
547 |
548 |
549 |
550 |
551 |
552 |
553 |
554 |
555 |
556 |
557 |
558 |
559 |
560 |
561 |
562 |
563 |
564 |
565 |
566 |
567 |
568 |
569 |
570 |
571 |
572 |
573 |
574 |
575 |
576 |
577 |
578 |
579 |
580 |
581 |
582 |
583 |
584 |
585 |
586 |
587 |
588 |
589 |
590 |
591 |
592 |
593 |
594 |
595 |
596 |
597 |
598 |
599 |
600 |
601 |
602 |
603 |
604 |
605 |
606 |
607 |
608 |
609 |
610 |
611 |
612 |
613 |
614 |
615 |
616 |
617 |
618 |
619 |
620 |
621 |
622 |
623 |
624 |
625 |
626 |
627 |
628 |
629 |
630 |
631 |
632 |
633 |
634 |
635 |
636 |
637 |
638 |
639 |
640 |
641 |
642 |
643 |
--------------------------------------------------------------------------------
/Groupnormalization.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["KERAS_BACKEND"] = "tensorflow"
4 |
5 | from keras.engine import Layer, InputSpec
6 | from keras import initializers
7 | from keras import regularizers
8 | from keras import constraints
9 | from keras import backend as K
10 |
11 | from keras.utils.generic_utils import get_custom_objects
12 |
13 |
class GroupNormalization(Layer):
    """Group normalization layer.

    Group Normalization divides the channels into groups and computes, within
    each group, the mean and variance used for normalization. During training
    the statistics are computed per sample, so the result does not depend on
    the batch size, and accuracy is stable over a wide range of batch sizes.

    Unlike the reference formulation, this layer also keeps per-group moving
    averages of the statistics (`moving_mean` / `moving_variance`, both of
    shape `(groups,)`) and normalizes with them, instead of with per-sample
    statistics, in the inference phase.

    # Arguments
        groups: Integer, the number of groups for Group Normalization. The
            size of `axis` must be a multiple of it.
        axis: Integer, the channel axis. `-1` or `3` means channels-last
            input, which is transposed to channels-first internally; any
            other value is treated as channels-first (channels on axis 1).
        momentum: Momentum for the moving mean and the moving variance.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is not used.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        moving_mean_initializer: Initializer for the moving mean.
        moving_variance_initializer: Initializer for the moving variance.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        4D tensor whose non-batch dimensions are statically known (they are
        used to reshape the input into groups). Use the keyword argument
        `input_shape` (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Group Normalization](https://arxiv.org/abs/1803.08494)
    """

    def __init__(self,
                 groups=32,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-5,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 moving_mean_initializer='zeros',
                 moving_variance_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(GroupNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.groups = groups
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = initializers.get(moving_variance_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Validate the channel count and create gamma/beta and moving stats.

        Raises:
            ValueError: if the channel axis is undefined, smaller than
                `groups`, or not a multiple of `groups`.
        """
        dim = input_shape[self.axis]

        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')

        if dim < self.groups:
            raise ValueError('Number of groups (' + str(self.groups) + ') cannot be '
                             'more than the number of channels (' +
                             str(dim) + ').')

        if dim % self.groups != 0:
            # The channel count has to split evenly into the groups.
            raise ValueError('Number of channels (' + str(dim) + ') must be a '
                             'multiple of the number of groups (' +
                             str(self.groups) + ').')

        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})

        # gamma/beta are applied after call() has moved the channels to
        # axis 1, so they are stored with a channels-first broadcast shape.
        param_shape = (1, dim, 1, 1)

        if self.scale:
            self.gamma = self.add_weight(shape=param_shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None

        if self.center:
            self.beta = self.add_weight(shape=param_shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None

        # One running statistic per group. The tracked non-trainable weights
        # themselves are used and updated (previously they were replaced by
        # fresh, untracked variables, so the stored statistics were never
        # updated, saved or restored). call() reshapes them for broadcasting.
        self.moving_mean = self.add_weight(
            shape=(self.groups,),
            name="moving_mean",
            initializer=self.moving_mean_initializer,
            trainable=False)

        self.moving_variance = self.add_weight(
            shape=(self.groups,),
            name="moving_variance",
            initializer=self.moving_variance_initializer,
            trainable=False)

        self.built = True

    def call(self, inputs, training=None, **kwargs):
        """Normalize `inputs` per group.

        Training uses the per-sample group statistics (and schedules the
        moving-average updates); inference uses the moving statistics.
        """
        G = self.groups
        channels_last = self.axis in {-1, 3}

        # transpose: [bs, h, w, c] -> [bs, c, h, w]
        if channels_last:
            x = K.permute_dimensions(inputs, (0, 3, 1, 2))
        else:
            x = inputs

        _, C, H, W = K.int_shape(x)
        x = K.reshape(x, (-1, G, C // G, H, W))

        # Per-sample, per-group statistics with shape (bs, G, 1, 1, 1).
        gn_mean = K.mean(x, axis=[2, 3, 4], keepdims=True)
        gn_variance = K.var(x, axis=[2, 3, 4], keepdims=True)

        def _affine(normalized):
            """Ungroup, apply the optional gamma/beta and restore the layout."""
            outputs = K.reshape(normalized, (-1, C, H, W))
            if self.gamma is not None:
                outputs = outputs * self.gamma
            if self.beta is not None:
                outputs = outputs + self.beta
            # transpose: [bs, c, h, w] -> [bs, h, w, c]
            if channels_last:
                outputs = K.permute_dimensions(outputs, (0, 2, 3, 1))
            return outputs

        def gn_inference():
            # Inference phase: normalize with the moving per-group statistics.
            broadcast_shape = (1, G, 1, 1, 1)
            mean = K.reshape(self.moving_mean, broadcast_shape)
            variance = K.reshape(self.moving_variance, broadcast_shape)
            return _affine((x - mean) / K.sqrt(variance + self.epsilon))

        if training in {0, False}:
            return gn_inference()

        outputs = _affine((x - gn_mean) / K.sqrt(gn_variance + self.epsilon))

        # Average the per-sample statistics over the batch so the update value
        # has the same (groups,) shape as the moving-average weights.
        batch_mean = K.reshape(K.mean(gn_mean, axis=0), (G,))
        batch_variance = K.reshape(K.mean(gn_variance, axis=0), (G,))
        # Register the updates against the layer's own input tensor: Keras
        # collects input-conditional updates with get_updates_for(<node
        # inputs>), so keying them on a derived tensor would drop them.
        self.add_update([K.moving_average_update(self.moving_mean,
                                                 batch_mean,
                                                 self.momentum),
                         K.moving_average_update(self.moving_variance,
                                                 batch_variance,
                                                 self.momentum)],
                        inputs)

        return K.in_train_phase(outputs,
                                gn_inference,
                                training=training)

    def get_config(self):
        config = {
            'groups': self.groups,
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(GroupNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape
228 |
229 |
# Make the layer resolvable by class name when deserializing models
# (e.g. keras.models.load_model) without passing custom_objects explicitly.
get_custom_objects()['GroupNormalization'] = GroupNormalization
--------------------------------------------------------------------------------
/Network.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | from glob import glob
4 | import numpy as np
5 | from keras import backend
6 | from keras.models import Model
7 | from keras.optimizers import adam, Nadam, SGD
8 | from keras.callbacks import TensorBoard, Callback
9 | from data_gen import data_flow
10 | from keras.layers import Dense, Input, Dropout, Activation,GlobalAveragePooling2D,LeakyReLU,BatchNormalization
11 | from keras.layers import concatenate,Concatenate,multiply, LocallyConnected2D, Lambda,Conv2D,GlobalMaxPooling2D,Flatten
12 | from keras.layers.core import Reshape
13 | from keras.layers import multiply
14 | import keras as ks
15 | from keras.models import load_model
16 | from keras.applications.xception import Xception
17 | from keras.applications.inception_v3 import InceptionV3
18 | from models.resnet50 import ResNet50
19 | from keras.layers import Flatten, Dense, AveragePooling2D
20 | from keras.models import Sequential
21 | from keras.utils import multi_gpu_model
22 | from Groupnormalization import GroupNormalization
23 | from keras_efficientnets import EfficientNetB5
24 | from keras_efficientnets import EfficientNetB4
25 | import efficientnet.keras as efn
# ResNet50
# NOTE: this module defines `model_fn` several times; only the last
# definition in the file stays bound to the name after import.
def model_fn(FLAGS, objective, optimizer, metrics):
    """Build and compile a classifier on a frozen, ImageNet-pretrained ResNet50.

    The backbone (from `models.resnet50`) is wrapped with `multi_gpu_model`
    for 4 GPUs and every layer of that wrapper is frozen; a Flatten + softmax
    Dense head with `FLAGS.num_classes` units is added on top.

    Args:
        FLAGS: config object; `input_size` and `num_classes` are read.
        objective: loss passed to `compile`.
        optimizer: optimizer passed to `compile`.
        metrics: metrics passed to `compile`.

    Returns:
        The compiled keras `Model`.
    """
    side = FLAGS.input_size
    backbone = multi_gpu_model(
        ResNet50(weights="imagenet",
                 include_top=False,
                 pooling=None,
                 input_shape=(side, side, 3),
                 classes=FLAGS.num_classes),
        4)

    # Freeze the whole (multi-GPU wrapped) backbone; only the head trains.
    for frozen in backbone.layers:
        frozen.trainable = False

    flat = Flatten()(backbone.output)
    probabilities = Dense(FLAGS.num_classes, activation='softmax')(flat)

    classifier = Model(inputs=backbone.input, outputs=probabilities)
    classifier.compile(loss=objective, optimizer=optimizer, metrics=metrics)
    return classifier
45 |
# SE-ResNet50
# NOTE: this module defines `model_fn` several times; only the last
# definition in the file stays bound to the name after import.
def model_fn(FLAGS, objective, optimizer, metrics, se_reduction=16):
    """Build and compile ResNet50 followed by a squeeze-and-excitation block.

    Args:
        FLAGS: config object; `input_size` and `num_classes` are read.
        objective: loss passed to `compile`.
        optimizer: optimizer passed to `compile`.
        metrics: metrics passed to `compile`.
        se_reduction: channel reduction ratio of the SE bottleneck
            (default 16, as before).

    Returns:
        The compiled keras `Model` with a sigmoid output layer.
    """
    inputs_dim = Input(shape=(FLAGS.input_size, FLAGS.input_size, 3))
    # pooling=None: the SE block below needs the 4D feature map. (This used
    # to pass the builtin `max`, which is not the string 'max'; only the
    # no-pooling behaviour is compatible with the layers that follow.)
    x = ResNet50(weights="imagenet",
                 include_top=False,
                 pooling=None,
                 input_shape=(FLAGS.input_size, FLAGS.input_size, 3),
                 classes=FLAGS.num_classes)(inputs_dim)

    # Channel count of the backbone output (2048 for ResNet50).
    channels = backend.int_shape(x)[-1]

    # Squeeze: global context per channel.
    squeeze = GlobalAveragePooling2D()(x)

    # Excitation: bottleneck MLP producing per-channel weights in (0, 1).
    excitation = Dense(units=channels // se_reduction)(squeeze)
    excitation = Activation('relu')(excitation)
    excitation = Dense(units=channels)(excitation)
    excitation = Activation('sigmoid')(excitation)
    excitation = Reshape((1, 1, channels))(excitation)

    # Re-weight the feature map channel-wise.
    scale = multiply([x, excitation])

    x = GlobalAveragePooling2D()(scale)
    fc2 = Dense(FLAGS.num_classes)(x)
    # Note: the output activation here is sigmoid (not softmax).
    fc2 = Activation('sigmoid')(fc2)
    model = Model(inputs=inputs_dim, outputs=fc2)
    model.compile(loss=objective, optimizer=optimizer, metrics=metrics)
    return model
75 |
# EfficientNet
# NOTE: this module defines `model_fn` several times; only the last
# definition in the file stays bound to the name after import.
def model_fn(FLAGS, objective, optimizer, metrics):
    """Build and compile an EfficientNet-B3 classifier replicated on 4 GPUs.

    The notop backbone is created without weights, then loaded from the
    hard-coded weights file below. A GlobalAveragePooling2D + Dropout(0.3) +
    softmax Dense head with `FLAGS.num_classes` units is added on top.

    Args:
        FLAGS: config object; `input_size` and `num_classes` are read.
        objective: loss passed to `compile`.
        optimizer: optimizer passed to `compile`.
        metrics: metrics passed to `compile`.

    Returns:
        The compiled `multi_gpu_model` wrapper.
    """
    model = efn.EfficientNetB3(weights=None,
                               include_top=False,
                               input_shape=(FLAGS.input_size, FLAGS.input_size, 3),
                               classes=FLAGS.num_classes)
    model.load_weights('/home/work/user-job-dir/src/efficientnet-b3_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5')

    # NOTE: a loop here used to "replace" BatchNormalization with
    # GroupNormalization via `model.layers[i] = GroupNormalization(...)`.
    # That only edits the Python list of layers; the graph behind
    # `model.output` keeps using the BN layers, so it had no effect on the
    # returned model and was removed. Actually switching to GN requires
    # rebuilding the network with GN layers in place of BN.

    x = model.output
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.3)(x)
    predictions = Dense(FLAGS.num_classes, activation='softmax')(x)
    # `inputs`/`outputs` are the Keras 2 keyword names (`input`/`output`
    # are deprecated aliases).
    model = Model(inputs=model.input, outputs=predictions)
    model = multi_gpu_model(model, 4)
    model.compile(loss=objective, optimizer=optimizer, metrics=metrics)
    return model
# Xception
# This is the last `model_fn` defined in this module, so it is the one bound
# to the name after import.
def model_fn(FLAGS, objective, optimizer, metrics):
    """Build and compile an Xception classifier.

    The notop backbone is created without weights, then loaded from the
    hard-coded weights file below. A GlobalAveragePooling2D + softmax Dense
    head with `FLAGS.num_classes` units is added on top.

    Args:
        FLAGS: config object; `input_size` and `num_classes` are read.
        objective: loss passed to `compile`.
        optimizer: optimizer passed to `compile`.
        metrics: metrics passed to `compile`.

    Returns:
        The compiled keras `Model`.
    """
    # pooling=None: the GlobalAveragePooling2D below needs the 4D feature map.
    # (This used to pass the builtin `max`, which is not the string 'max' and
    # therefore also selected no pooling.)
    Xception_notop = Xception(include_top=False,
                              weights=None,
                              input_tensor=None,
                              input_shape=(FLAGS.input_size, FLAGS.input_size, 3),
                              pooling=None)

    Xception_notop.load_weights('/home/work/user-job-dir/src/xception_weights_tf_dim_ordering_tf_kernels_notop.h5')
    output = Xception_notop.output
    output = GlobalAveragePooling2D()(output)
    output = Dense(FLAGS.num_classes, activation='softmax')(output)
    Xception_model = Model(inputs=Xception_notop.input, outputs=output)
    # Xception_model = multi_gpu_model(Xception_model, 4)
    Xception_model.compile(loss=objective, optimizer=optimizer, metrics=metrics)
    return Xception_model

#######################################################################SE-Xception
# Reference variant (not active). Re-enabling it also needs
# inputs_dim = Input(shape=(FLAGS.input_size, FLAGS.input_size, 3))
# Xception_notop = Xception_notop(inputs_dim)
# squeeze = GlobalAveragePooling2D()(Xception_notop)
# excitation = Dense(units=2048 // 16)(squeeze)
# excitation = Activation('relu')(excitation)
# excitation = Dense(units=2048)(excitation)
# excitation = Activation('sigmoid')(excitation)
# excitation = Reshape((1, 1, 2048))(excitation)
#
# scale = multiply([Xception_notop, excitation])
# x = GlobalAveragePooling2D()(scale)
# x = Dropout(0.3)(x)
# fc2 = Dense(FLAGS.num_classes)(x)
# fc2 = Activation('sigmoid')(fc2)  # note: sigmoid output here
# model = Model(inputs=inputs_dim, outputs=fc2)
# # model = multi_gpu_model(model, 4)
# model.compile(loss=objective, optimizer=optimizer, metrics=metrics)
# return model
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 前言
2 |
3 | #### 本文介绍的分类方式可能比较繁琐,因为它是采用华为云比赛的提交模式进行的。简洁的分类版本点击这里:https://github.com/wusaifei/HWCC_image_classification
4 |
5 | 1.图像分类的更多tricks(注意力机制 keras,TensorFlow和pytorch 版本等):[图像分类比赛tricks:“华为云杯”2019人工智能创新应用大赛](https://zhuanlan.zhihu.com/p/98740628)
6 |
7 | 2.大家如果对目标检测比赛比较感兴趣的话,可以看一下我这篇对目标检测比赛tricks的详细介绍:[目标检测比赛中的tricks(已更新更多代码解析)](https://zhuanlan.zhihu.com/p/102817180)
8 |
9 | 3.目标检测比赛笔记:[目标检测比赛笔记](https://zhuanlan.zhihu.com/p/137567177)
10 |
11 | 4.如果对换脸技术比较感兴趣的同学可以点击这里:[deepfakes/faceswap:换脸技术详细教程,手把手教学,简单快速上手!!](https://zhuanlan.zhihu.com/p/376853800)
12 |
13 | 5.在日常调参的摸爬滚打中,参考了不少他人的调参经验,也积累了自己的一些有效调参方法,慢慢总结整理如下。希望对新晋算法工程师有所助力呀~:[写给新手炼丹师:2021版调参上分手册](https://zhuanlan.zhihu.com/p/376068083)
14 |
15 | 6.[深度学习中不同类型卷积的综合介绍:2D卷积、3D卷积、转置卷积、扩张卷积、可分离卷积、扁平卷积、分组卷积、随机分组卷积、逐点分组卷积等](https://zhuanlan.zhihu.com/p/366744794)
16 |
17 | 7.分类必备知识:[Softmax函数和Sigmoid函数的区别与联系](https://zhuanlan.zhihu.com/p/356976844)、[深度学习中学习率和batchsize对模型准确率的影响](https://zhuanlan.zhihu.com/p/277487038)、[准确率(Precision)、召回率(Recall)、F值(F-Measure)、平均正确率,IoU](https://zhuanlan.zhihu.com/p/101101207)、[利用python一层一层可视化卷积神经网络,以ResNet50为例](https://zhuanlan.zhihu.com/p/101038013)
18 |
19 | 8.[pytorch笔记:Efficientnet微调](https://zhuanlan.zhihu.com/p/102467338)
20 |
21 | 9.[keras, TensorFlow中加入注意力机制](https://zhuanlan.zhihu.com/p/99260231)、[pytorch中加入注意力机制(CBAM),以ResNet为例。解析到底要不要用ImageNet预训练?如何加预训练参数?](https://zhuanlan.zhihu.com/p/99261200)
22 |
23 |
24 |
25 |
26 | # 增添内容
27 |
28 | #### 已修改为可以在本地运行。
29 |
30 | 修改方法:
31 |
32 | 1.`save_model.py`、`train.py`、`eval.py`、`run.py`中的`moxing.framework.file`函数全部换成了`os.path`和`shutil.copy`函数,因为本地Python环境中没有moxing框架(它是华为云ModelArts平台提供的库)。
33 |
34 | 2.注释掉`run.py`文件里面的下面几行代码:
35 |
36 | # FLAGS.tmp = os.path.join(FLAGS.local_data_root, 'tmp/')
37 | # print(FLAGS.tmp)
38 | # if not os.path.exists(FLAGS.tmp):
39 | # os.mkdir(FLAGS.tmp)
40 |
41 | #### 本文档后面增添了SVM分类器、决策树分类器和随机森林分类器的用法。
42 |
43 | # 运行环境
44 |
45 | >python3.6
46 |
47 | >tensorflow 1.13.1
48 |
49 | >keras 2.2.4
50 |
51 | 使用更新版本的库可能无法成功运行。
52 |
53 | # garbage_classify
54 | ## 赛题背景
55 | 比赛链接:[华为云人工智能大赛·垃圾分类挑战杯](https://developer.huaweicloud.com/competition/competitions/1000007620/introduction)
56 |
57 | 如今,垃圾分类已成为社会热点话题。其实在2019年4月26日,我国住房和城乡建设部等部门就发布了《关于在全国地级及以上城市全面开展生活垃圾分类工作的通知》,决定自2019年起在全国地级及以上城市全面启动生活垃圾分类工作。到2020年底,46个重点城市基本建成生活垃圾分类处理系统。
58 |
59 | 人工垃圾分类投放是垃圾处理的第一环节,但能够处理海量垃圾的环节是垃圾处理厂。然而,目前国内的垃圾处理厂基本都是采用人工流水线分拣的方式进行垃圾分拣,存在工作环境恶劣、劳动强度大、分拣效率低等缺点。在海量垃圾面前,人工分拣只能分拣出极有限的一部分可回收垃圾和有害垃圾,绝大多数垃圾只能进行填埋,带来了极大的资源浪费和环境污染危险。
60 |
61 | 随着深度学习技术在视觉领域的应用和发展,让我们看到了利用AI来自动进行垃圾分类的可能,通过摄像头拍摄垃圾图片,检测图片中垃圾的类别,从而可以让机器自动进行垃圾分拣,极大地提高垃圾分拣效率。
62 |
63 | 因此,华为云面向社会各界精英人士举办了本次垃圾分类竞赛,希望共同探索垃圾分类的AI技术,为垃圾分类这个利国利民的国家大计贡献自己的一份智慧。
64 |
65 | ## 赛题说明
66 | 本赛题采用深圳市垃圾分类标准,赛题任务是对垃圾图片进行分类,即首先识别出垃圾图片中物品的类别(比如易拉罐、果皮等),然后查询垃圾分类规则,输出该垃圾图片中物品属于可回收物、厨余垃圾、有害垃圾和其他垃圾中的哪一种。
67 | 模型输出格式示例:
68 |
69 | {
70 |
71 | " result ": "可回收物/易拉罐"
72 |
73 | }
74 |
75 | ## 垃圾种类40类
76 |
77 | {
78 | "0": "其他垃圾/一次性快餐盒",
79 | "1": "其他垃圾/污损塑料",
80 | "2": "其他垃圾/烟蒂",
81 | "3": "其他垃圾/牙签",
82 | "4": "其他垃圾/破碎花盆及碟碗",
83 | "5": "其他垃圾/竹筷",
84 | "6": "厨余垃圾/剩饭剩菜",
85 | "7": "厨余垃圾/大骨头",
86 | "8": "厨余垃圾/水果果皮",
87 | "9": "厨余垃圾/水果果肉",
88 | "10": "厨余垃圾/茶叶渣",
89 | "11": "厨余垃圾/菜叶菜根",
90 | "12": "厨余垃圾/蛋壳",
91 | "13": "厨余垃圾/鱼骨",
92 | "14": "可回收物/充电宝",
93 | "15": "可回收物/包",
94 | "16": "可回收物/化妆品瓶",
95 | "17": "可回收物/塑料玩具",
96 | "18": "可回收物/塑料碗盆",
97 | "19": "可回收物/塑料衣架",
98 | "20": "可回收物/快递纸袋",
99 | "21": "可回收物/插头电线",
100 | "22": "可回收物/旧衣服",
101 | "23": "可回收物/易拉罐",
102 | "24": "可回收物/枕头",
103 | "25": "可回收物/毛绒玩具",
104 | "26": "可回收物/洗发水瓶",
105 | "27": "可回收物/玻璃杯",
106 | "28": "可回收物/皮鞋",
107 | "29": "可回收物/砧板",
108 | "30": "可回收物/纸板箱",
109 | "31": "可回收物/调料瓶",
110 | "32": "可回收物/酒瓶",
111 | "33": "可回收物/金属食品罐",
112 | "34": "可回收物/锅",
113 | "35": "可回收物/食用油桶",
114 | "36": "可回收物/饮料瓶",
115 | "37": "有害垃圾/干电池",
116 | "38": "有害垃圾/软膏",
117 | "39": "有害垃圾/过期药物"
118 | }
119 | ## efficientNet默认参数
120 |
121 | (width_coefficient, depth_coefficient, resolution, dropout_rate)
122 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
123 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
124 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
125 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
126 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
127 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
128 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
129 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
130 |
131 | efficientNet的论文地址:https://arxiv.org/pdf/1905.11946.pdf
132 |
133 |
134 | ## 代码解析
135 | ### BaseLine改进
136 | 1.使用多种模型进行对比实验,[ResNet50](https://arxiv.org/pdf/1512.03385.pdf), [SE-ResNet50](https://arxiv.org/abs/1709.01507), [Xception](https://arxiv.org/abs/1610.02357), SE-Xception, [efficientNetB5](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)。
137 |
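138 | 2.使用[组归一化(GroupNormalization)](https://arxiv.org/abs/1803.08494)代替[批量归一化(batch_normalization)](https://arxiv.org/abs/1502.03167),以解决batch_size过小时准确率下降的问题:当batch_size小于16时,BN的错误率
139 | 会逐渐上升(代码见`train.py`)。注意:下面的写法只替换了`model.layers`列表中的元素,并不会改变已经构建好的计算图,模型实际上仍在使用BN;要真正使用GN,需要在构建网络时就用`GroupNormalization`层代替BN层。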
140 |
141 |
142 | for i, layer in enumerate(model.layers):
143 | if "batch_normalization" in layer.name:
144 | model.layers[i] = GroupNormalization(groups=32, axis=-1, epsilon=0.00001)
145 |
146 | 3.[NAdam优化器](http://cs229.stanford.edu/proj2015/054_report.pdf)
147 |
148 |
149 | optimizer = Nadam(lr=FLAGS.learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08, schedule_decay=0.004)
150 |
151 | 4.自定义学习率-[SGDR余弦退火学习率](https://arxiv.org/abs/1608.03983)
152 |
153 |
154 | sample_count = len(train_sequence) * FLAGS.batch_size
155 | epochs = FLAGS.max_epochs
156 | warmup_epoch = 5
157 | batch_size = FLAGS.batch_size
158 | learning_rate_base = FLAGS.learning_rate
159 | total_steps = int(epochs * sample_count / batch_size)
160 | warmup_steps = int(warmup_epoch * sample_count / batch_size)
161 |
162 | warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=learning_rate_base,
163 | total_steps=total_steps,
164 | warmup_learning_rate=0,
165 | warmup_steps=warmup_steps,
166 | hold_base_rate_steps=0,
167 | )
168 |
169 | 5.数据增强:随机水平翻转、随机垂直翻转、以一定概率随机旋转90°、180°、270°、随机crop(0-10%)等(详细代码请看`aug.py`和`data_gen.py`)
170 |
171 | def img_aug(self, img):
172 | data_gen = ImageDataGenerator()
173 | dic_parameter = {'flip_horizontal': random.choice([True, False]),
174 | 'flip_vertical': random.choice([True, False]),
175 | 'theta': random.choice([0, 0, 0, 90, 180, 270])
176 | }
177 |
178 |
179 | img_aug = data_gen.apply_transform(img, transform_parameters=dic_parameter)
180 | return img_aug
181 |
182 |
183 | from imgaug import augmenters as iaa
184 | import imgaug as ia
185 |
186 | def augumentor(image):
187 | sometimes = lambda aug: iaa.Sometimes(0.5, aug)
188 | seq = iaa.Sequential(
189 | [
190 | iaa.Fliplr(0.5),
191 | iaa.Flipud(0.5),
192 | iaa.Affine(rotate=(-10, 10)),
193 | sometimes(iaa.Crop(percent=(0, 0.1), keep_size=True)),
194 | ],
195 | random_order=True
196 | )
197 |
198 |
199 | image_aug = seq.augment_image(image)
200 |
201 | return image_aug
202 |
203 |
204 | 6.标签平滑`data_gen.py`
205 |
206 |
207 | def smooth_labels(y, smooth_factor=0.1):
208 | assert len(y.shape) == 2
209 | if 0 <= smooth_factor <= 1:
210 | # label smoothing ref: https://www.robots.ox.ac.uk/~vgg/rg/papers/reinception.pdf
211 | y *= 1 - smooth_factor
212 | y += smooth_factor / y.shape[1]
213 | else:
214 | raise Exception(
215 | 'Invalid label smoothing factor: ' + str(smooth_factor))
216 | return y
217 |
218 | 7.数据归一化:用`Save_path.py`得到所有图像的位置信息,再用`mean_std.py`计算所有图像各通道的均值和标准差
219 |
220 |
221 | normMean = [0.56719673 0.5293289 0.48351972]
222 | normStd = [0.20874391 0.21455203 0.22451781]
223 |
224 |
225 | img = np.asarray(img, np.float32) / 255.0
226 | mean = [0.56719673, 0.5293289, 0.48351972]
227 | std = [0.20874391, 0.21455203, 0.22451781]
228 | img[..., 0] -= mean[0]
229 | img[..., 1] -= mean[1]
230 | img[..., 2] -= mean[2]
231 | img[..., 0] /= std[0]
232 | img[..., 1] /= std[1]
233 | img[..., 2] /= std[2]
234 |
235 | ## 各部分代码解析
236 |
237 | * `deploy_scripts`——推理文件,需要修改
238 |
239 | 1.self.input_size = 456
240 |
241 |
242 | 2. def _inference(self, data):
243 | """
244 | model inference function
245 | Here are a inference example of resnet, if you use another model, please modify this function
246 | """
247 | img = data[self.input_key_1]
248 | img = img[np.newaxis, :, :, :] # the input tensor shape of resnet is [?, 224, 224, 3]
249 | img = np.asarray(img, np.float32) / 255.0
250 | mean = [0.56719673, 0.5293289, 0.48351972]
251 | std = [0.20874391, 0.21455203, 0.22451781]
252 | img[..., 0] -= mean[0]
253 | img[..., 1] -= mean[1]
254 | img[..., 2] -= mean[2]
255 | img[..., 0] /= std[0]
256 | img[..., 1] /= std[1]
257 | img[..., 2] /= std[2]
258 | pred_score = self.sess.run([self.output_score], feed_dict={self.input_images: img})
259 | if pred_score is not None:
260 | pred_label = np.argmax(pred_score[0], axis=1)[0]
261 | result = {'result': self.label_id_name_dict[str(pred_label)]}
262 | else:
263 | result = {'result': 'predict score is None'}
264 | return result
265 |
266 |
267 | * `aug.py`——图像增强代码(`imgaug`函数)
268 |
269 | * `data_gen.py`——数据预处理代码,包括数据增强、标签平滑以及train和val的划分
270 |
271 | * `eval.py`——模型评估代码
272 |
273 | * `Groupnormalization.py`——组归一化
274 |
275 | * `mean_std.py`——计算图像的均值和标准差
276 |
277 | * `Network.py`——ResNet50, SE-ResNet50, Xception, SE-Xception, efficientNetB5
278 |
279 | * `run.py`——运行代码
280 |
281 | * `save_model.py`——保存模型
282 |
283 | * `Save_path.py`——图像位置信息
284 |
285 | * `train.py`——训练网络部分,包括网络,loss, optimizer等
286 |
287 | * `warmup_cosine_decay_scheduler.py`——余弦退火学习率
288 |
289 | * `pip-requirements.txt`——安装其他所需的库,安装命令为:`pip install -r pip-requirements.txt`
290 |
291 | ## 使用
292 | ### 前期准备
293 | * 克隆此存储库
294 |
295 |
296 |
297 | git clone https://github.com/wusaifei/garbage_classify.git
298 |
299 |
300 |
301 | * [垃圾分类数据集下载地址(此链接已失效,请使用下面的百度网盘链接下载)](https://modelarts-competitions.obs.cn-north-1.myhuaweicloud.com/garbage_classify/dataset/garbage_classify.zip)
302 |
303 | * 垃圾分类数据集下载地址链接:https://pan.baidu.com/s/11xp0jBKAitU8r0_RWVpX1Q , 提取码:jqa1
304 |
305 | * 扩充数据集:链接:https://pan.baidu.com/s/1SulD2MqZx_U891JXeI2-2g ,提取码:epgs
306 |
307 |
308 | ### 运行
309 | * 运行`Save_path.py`得到图像的位置信息
310 | * 运行`mean_std.py`得到图像的均值和方差
311 | * `run.py`——训练
312 |
313 |
314 | python run.py --data_url='./garbage_classify/train_data' --train_url='./model_snapshots' --deploy_script_path='./deploy_scripts'
315 |
316 |
317 | * `run.py`——保存为pb模型
318 |
319 |
320 |
321 | python run.py --mode=save_pb --deploy_script_path='./deploy_scripts' --freeze_weights_file_path='./model_snapshots/weights_024_0.9470.h5' --num_classes=40
322 |
323 |
324 |
325 | * `run.py`——模型评估
326 |
327 |
328 |
329 | python run.py --mode=eval --eval_pb_path='./model_snapshots/model' --test_data_url='./garbage_classify/train_data'
330 |
331 | ## 增添SVM分类器
332 |
333 |
334 | > 当模型训练完之后,用训练好的模型预测训练数据,并将它们保存在数组中。然后放到SVC中进行训练,最后将训练好的分类器对抽取的测试数据特征进行分类。
335 |
336 |
337 | 代码如下:
338 |
339 | target_pre_con = []
340 | target_con = []
341 | for i, data in tqdm(enumerate(trian_dataloaders_dict['all_data'])):
342 |
343 | input, target = data
344 | input, target = input.to(device), target.to(device)
345 | target_pre = model(input)
346 |
347 | target_pre = target_pre.cpu()
348 | target = target.cpu()
349 |
350 | target_pre = target_pre.detach().numpy()
351 | target = target.detach().numpy()
352 |
353 | target_pre_con.extend(target_pre)
354 | target_con.extend(target)
355 |
356 | target_pre_con = np.asarray(target_pre_con)
357 | target_con = np.asarray(target_con)
358 |
359 | print(target_pre_con.shape)
360 | print(target_con.shape)
361 | # 提取特征用clf:svm
362 | clf = SVC(kernel='rbf', gamma='auto')
363 | clf.fit(target_pre_con, target_con)
364 |
365 | for i, (input, filepath) in tqdm(enumerate(test_loader)):
366 | # print(input.shape[1])
367 | with torch.no_grad():
368 | image_var = input.to(device)
369 | y_pred = model(image_var)
370 | label = y_pred.cpu().data.numpy()
371 | # 提取特征用clf分类
372 | label = clf.predict(label)
373 | labels.append(label)
374 |
375 |
376 | ## 决策树分类器和随机森林分类器
377 |
378 | 只需要将clf换成`DecisionTreeClassifier()`或`RandomForestClassifier()`即可。
379 |
380 | ```
381 | from sklearn.tree import DecisionTreeClassifier
382 |
383 | from sklearn.ensemble import RandomForestClassifier
384 |
385 | clf = DecisionTreeClassifier()
386 |
387 | clf = RandomForestClassifier()
388 | ```
389 |
390 |
391 |
392 | ## 实验结果
393 |
394 | * 网络的改进:`ResNet50-0.689704`,`SE-ResNet50-0.83259`,`Xception-0.879003`,`EfficientNetB5-0.924113`(无数据增强)
395 |
396 | * 数据增强:由0.924113提升到0.934721
397 |
398 | * 标签平滑、数据归一化处理,以及将学习率策略由`ReduceLROnPlateau`换成`WarmUpCosineDecayScheduler`:最终准确率在95%左右
399 |
400 |
401 | 大家也可以在分类代码中增加测试时增强,详细代码在`tta_wrapper`文件夹里面,里面有详细的介绍和测试用例。
402 |
403 | # 后续
404 |
405 | 1. 增添模型融合(投票)。
406 |
407 | 2. 测试时增强。
408 |
409 | 3. [Cutout](https://github.com/uoguelph-mlrg/Cutout), [Mixup](https://github.com/facebookresearch/mixup-cifar10), [CutMix](https://github.com/clovaai/CutMix-PyTorch)等数据增强策略。
410 |
411 | 4. 标签平滑。
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 |
424 |
425 |
426 |
427 |
--------------------------------------------------------------------------------
/Save_path.py:
--------------------------------------------------------------------------------
1 | import os
# Root directory of the training data; every file found below it is listed.
dress = "../garbage_classify-master/datasets/garbage_classify/train_data/"

# Write one path per line (directory joined with file name) into train.txt.
with open("train.txt", "w") as out_file:
    for walk_root, _subdirs, walk_files in os.walk(dress):
        out_file.writelines(os.path.join(walk_root, name) + "\n" for name in walk_files)
8 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/aug.py:
--------------------------------------------------------------------------------
1 | from imgaug import augmenters as iaa
2 | import imgaug as ia
3 |
def augumentor(image):
    """Return a randomly augmented version of *image*.

    The augmenters below are applied in random order: horizontal flip
    (p=0.5), vertical flip (p=0.5), an affine rotation sampled from
    [-10, 10] degrees and, with probability 0.5, a crop of up to 10% that is
    resized back to the original size (``keep_size=True``).
    """
    flips = [iaa.Fliplr(0.5), iaa.Flipud(0.5)]
    rotation = iaa.Affine(rotate=(-10, 10))
    random_crop = iaa.Sometimes(0.5, iaa.Crop(percent=(0, 0.1), keep_size=True))
    pipeline = iaa.Sequential(flips + [rotation, random_crop], random_order=True)
    return pipeline.augment_image(image)
--------------------------------------------------------------------------------
/data_gen.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import math
4 | import codecs
5 | import random
6 | import numpy as np
7 | from glob import glob
8 | from PIL import Image
9 | import cv2
10 | from keras.utils import np_utils, Sequence
11 | from sklearn.model_selection import train_test_split
12 | from keras.preprocessing.image import ImageDataGenerator
13 | from aug import augumentor
14 |
class BaseSequence(Sequence):
    """Basic data-stream generator that yields one batch per index.

    A ``BaseSequence`` can be passed directly as the ``generator`` argument of
    ``fit_generator``. Keras wraps it in a multi-process data stream and
    guarantees that, within one epoch, no sample is fetched twice even when
    several worker processes are used.

    Args:
        img_paths: Sequence of image file paths.
        labels: 2-D array-like of label vectors, one row per image.
        batch_size: Number of samples per batch; the last batch may be smaller.
        img_size: ``[size, size]``; images are fitted into a square of this size.
        train: If True, random augmentation is applied in ``preprocess_img``.
    """

    # Per-channel (R, G, B) mean and std of the training images after scaling
    # to [0, 1]. NOTE(review): presumably produced by mean_std.py; keep in sync
    # with the deployment service.
    MEAN = np.array([0.56719673, 0.5293289, 0.48351972], dtype=np.float32)
    STD = np.array([0.20874391, 0.21455203, 0.22451781], dtype=np.float32)

    def __init__(self, img_paths, labels, batch_size, img_size, train=False):
        assert len(img_paths) == len(labels), "len(img_paths) must equal to len(lables)"
        assert img_size[0] == img_size[1], "img_size[0] must equal to img_size[1]"
        # Keep the path (column 0) and label vector (columns 1:) in one array so
        # a single shuffle keeps them aligned. The labels are turned into strings
        # explicitly (implicit number-to-string promotion is deprecated in
        # NumPy) and parsed back to float32 in __getitem__.
        paths_column = np.array(img_paths).reshape(len(img_paths), 1)
        self.x_y = np.hstack((paths_column, np.asarray(labels).astype(str)))
        self.batch_size = batch_size
        self.img_size = img_size
        self.train = train

    def __len__(self):
        """Number of batches per epoch (the last batch may be partial)."""
        return math.ceil(len(self.x_y) / self.batch_size)

    @staticmethod
    def center_img(img, size=None, fill_value=255):
        """Center *img* on a square background of side *size*.

        The background is filled with *fill_value* and has the same dtype as
        *img*. This means float input is no longer truncated to uint8.
        *size* defaults to the longer side of *img*.
        """
        h, w = img.shape[:2]
        if size is None:
            size = max(h, w)
        shape = (size, size) + img.shape[2:]
        background = np.full(shape, fill_value, dtype=img.dtype)
        center_x = (size - w) // 2
        center_y = (size - h) // 2
        background[center_y:center_y + h, center_x:center_x + w] = img
        return background

    def img_aug(self, img):
        """Alternative (currently unused) augmentation: random flips and a
        rotation by 0/90/180/270 degrees via Keras ``ImageDataGenerator``."""
        data_gen = ImageDataGenerator()
        dic_parameter = {'flip_horizontal': random.choice([True, False]),
                         'flip_vertical': random.choice([True, False]),
                         'theta': random.choice([0, 0, 0, 90, 180, 270])
                         }
        img_aug = data_gen.apply_transform(img, transform_parameters=dic_parameter)
        return img_aug

    def preprocess_img(self, img_path):
        """Load, (optionally) augment, square-pad and normalize one image.

        The steps are:

        1. Resize so the longer side equals ``img_size[0]``.
        2. In training mode, augment.
        3. Center on a white square canvas.
        4. Scale to [0, 1] and standardize per RGB channel.

        Returns:
            float32 array of shape ``(img_size[0], img_size[0], 3)``.
        """
        # Use a context manager so the underlying file handle is closed.
        with Image.open(img_path) as raw:
            resize_scale = self.img_size[0] / max(raw.size[:2])
            new_size = (int(raw.size[0] * resize_scale), int(raw.size[1] * resize_scale))
            img = np.array(raw.resize(new_size).convert('RGB'))  # uint8, RGB

        # Augment while the data is still uint8.
        if self.train:
            img = augumentor(img)

        # Pad to a square BEFORE normalizing. The old code normalized first, so
        # the float pixels were written into a uint8 canvas and truncated or
        # wrapped to small integers.
        img = self.center_img(img, self.img_size[0])

        img = img.astype(np.float32) / 255.0
        return (img - self.MEAN) / self.STD

    def __getitem__(self, idx):
        """Return batch *idx* as ``(images, labels)`` with float32 labels."""
        batch_x = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 0]
        batch_y = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 1:]
        batch_x = np.array([self.preprocess_img(img_path) for img_path in batch_x])
        batch_y = np.array(batch_y).astype(np.float32)
        return batch_x, batch_y

    def on_epoch_end(self):
        """Shuffle the (path, label) rows together at the end of every epoch."""
        np.random.shuffle(self.x_y)
124 |
125 | # 标签平滑
def smooth_labels(y, smooth_factor=0.1):
    """Apply label smoothing to a 2-D array of one-hot label vectors.

    Each entry becomes ``y * (1 - smooth_factor) + smooth_factor / num_classes``
    (ref: https://www.robots.ox.ac.uk/~vgg/rg/papers/reinception.pdf).

    Args:
        y: Array-like of shape ``(num_samples, num_classes)``.
        smooth_factor: Smoothing strength, must lie in ``[0, 1]``.

    Returns:
        A new floating-point array. Float input keeps its dtype, integer/bool
        input is promoted, and *y* itself is left unmodified.

    Raises:
        AssertionError: if *y* is not 2-D.
        Exception: if *smooth_factor* is outside ``[0, 1]``.
    """
    y = np.asarray(y)
    assert len(y.shape) == 2
    if not 0 <= smooth_factor <= 1:
        raise Exception(
            'Invalid label smoothing factor: ' + str(smooth_factor))
    # Work on a floating-point copy. The previous in-place update mutated the
    # caller's array and raised a casting error for integer one-hot input.
    smoothed = np.array(y, dtype=np.result_type(y.dtype, np.float32))
    smoothed *= 1 - smooth_factor
    smoothed += smooth_factor / smoothed.shape[1]
    return smoothed
136 |
def data_flow(train_data_dir, batch_size, num_classes, input_size):
    """Build the training and validation ``BaseSequence`` objects.

    Each ``*.txt`` file in *train_data_dir* should contain one line of the form
    ``"<image file name>, <integer label>"``. Files whose first line does not
    split into exactly two fields are reported and skipped. Labels are one-hot
    encoded to *num_classes* and label-smoothed. 10% of the samples are held
    out for validation with a fixed seed.

    Returns:
        ``(train_sequence, validation_sequence)``.

    Raises:
        ValueError: if no valid label file is found.
    """
    # Sort instead of shuffling: train_test_split below already shuffles with a
    # fixed random_state. Randomizing the input order first made the
    # train/validation split differ on every run.
    label_files = sorted(glob(os.path.join(train_data_dir, '*.txt')))
    img_paths = []
    labels = []
    for file_path in label_files:
        with open(file_path, 'r', encoding='utf-8') as f:
            line = f.readline()
        line_split = line.strip().split(', ')
        if len(line_split) != 2:
            print('%s contain error lable' % os.path.basename(file_path))
            continue
        img_name, label = line_split
        img_paths.append(os.path.join(train_data_dir, img_name))
        labels.append(int(label))

    if not img_paths:
        raise ValueError('no valid label files found in %s' % train_data_dir)

    labels = np_utils.to_categorical(labels, num_classes)
    labels = smooth_labels(labels)

    train_img_paths, validation_img_paths, train_labels, validation_labels = \
        train_test_split(img_paths, labels, test_size=0.1, random_state=0)
    print('total samples: %d, training samples: %d, validation samples: %d' % (
        len(img_paths), len(train_img_paths), len(validation_img_paths)))

    train_sequence = BaseSequence(train_img_paths, train_labels, batch_size,
                                  [input_size, input_size], True)
    validation_sequence = BaseSequence(validation_img_paths, validation_labels, batch_size,
                                       [input_size, input_size], False)
    return train_sequence, validation_sequence
180 |
181 |
if __name__ == '__main__':
    # Debug helper: dump a few training batches (with augmentation) to ./debug
    # so the preprocessing pipeline can be inspected by eye. The old version
    # referenced undefined names and called data_flow with too few arguments.
    import argparse

    parser = argparse.ArgumentParser(description='Dump sample training batches to ./debug')
    parser.add_argument('--train_data_dir', required=True)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_classes', type=int, default=40)
    parser.add_argument('--input_size', type=int, default=456)
    args = parser.parse_args()

    os.makedirs('./debug', exist_ok=True)
    train_sequence, validation_sequence = data_flow(
        args.train_data_dir, args.batch_size, args.num_classes, args.input_size)
    batch_index = min(5, len(train_sequence) - 1)
    for round_index in range(1, 4):
        batch_data, batch_label = train_sequence[batch_index]
        for index, data in enumerate(batch_data):
            # Images are standardized floats. A per-image min-max rescale to
            # uint8 is enough for viewing.
            lo, hi = data.min(), data.max()
            img = ((data - lo) / max(hi - lo, 1e-8) * 255).astype(np.uint8)
            class_id = int(np.argmax(batch_label[index]))
            suffix = '' if round_index == 1 else '_%d' % round_index
            Image.fromarray(img).save('./debug/%d%s_%d.jpg' % (index, suffix, class_id))
        train_sequence.on_epoch_end()
    print('end')
204 | print('end')
205 |
--------------------------------------------------------------------------------
/deploy_scripts/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_algorithm": "image_classification",
3 | "model_type": "TensorFlow",
4 | "runtime": "python3.6",
5 | "metrics": {
6 | "f1": 0,
7 | "accuracy": 0.6253,
8 | "precision": 0,
9 | "recall": 0
10 | },
11 | "apis": [
12 | {
13 | "procotol": "http",
14 | "url": "/",
15 | "method": "post",
16 | "request": {
17 | "Content-type": "multipart/form-data",
18 | "data": {
19 | "type": "object",
20 | "properties": {
21 | "input_img": {"type": "file"}
22 | },
23 | "required": ["input_img"]
24 | }
25 | },
26 | "response": {
27 | "Content-type": "multipart/form-data",
28 | "data": {
29 | "type": "object",
30 | "properties": {
31 | "result": {"type": "string"}
32 | },
33 | "required": ["result"]
34 | }
35 | }
36 | }
37 | ],
38 | "dependencies": [
39 | {
40 | "installer": "pip",
41 | "packages": [
42 | {
43 | "package_name": "Pillow",
44 | "package_version": "5.0.0",
45 | "restraint": "ATLEAST"
46 | }
47 | ]
48 | }
49 | ]
50 | }
--------------------------------------------------------------------------------
/deploy_scripts/customize_service.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import numpy as np
3 | from PIL import Image
4 | import tensorflow as tf
5 | from collections import OrderedDict
6 | from tensorflow.python.saved_model import tag_constants
7 | from model_service.tfserving_model_service import TfServingBaseService
8 |
9 |
class garbage_classify_service(TfServingBaseService):
    """Serving wrapper that classifies one uploaded garbage image.

    Loads a TensorFlow SavedModel from *model_path* and implements the
    ``_preprocess`` -> ``_inference`` -> ``_postprocess`` hooks of
    ``TfServingBaseService``.
    """

    # Per-channel (R, G, B) mean / std applied to images scaled to [0, 1].
    # NOTE(review): must match the statistics used by the training pipeline.
    MEAN = np.array([0.56719673, 0.5293289, 0.48351972], dtype=np.float32)
    STD = np.array([0.20874391, 0.21455203, 0.22451781], dtype=np.float32)

    def __init__(self, model_name, model_path):
        # these three parameters are no need to modify
        self.model_name = model_name
        self.model_path = model_path
        self.signature_key = 'predict_images'

        self.input_size = 456  # the input image size of the model

        # add the input and output key of your pb model here,
        # these keys are defined when you save a pb file
        self.input_key_1 = 'input_img'
        self.output_key_1 = 'output_score'
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.get_default_graph().as_default():
            self.sess = tf.Session(graph=tf.Graph(), config=config)
            meta_graph_def = tf.saved_model.loader.load(self.sess, [tag_constants.SERVING], self.model_path)
            self.signature = meta_graph_def.signature_def

            # define input and out tensor of your model here
            input_images_tensor_name = self.signature[self.signature_key].inputs[self.input_key_1].name
            output_score_tensor_name = self.signature[self.signature_key].outputs[self.output_key_1].name
            self.input_images = self.sess.graph.get_tensor_by_name(input_images_tensor_name)
            self.output_score = self.sess.graph.get_tensor_by_name(output_score_tensor_name)

        self.label_id_name_dict = \
            {
                "0": "其他垃圾/一次性快餐盒",
                "1": "其他垃圾/污损塑料",
                "2": "其他垃圾/烟蒂",
                "3": "其他垃圾/牙签",
                "4": "其他垃圾/破碎花盆及碟碗",
                "5": "其他垃圾/竹筷",
                "6": "厨余垃圾/剩饭剩菜",
                "7": "厨余垃圾/大骨头",
                "8": "厨余垃圾/水果果皮",
                "9": "厨余垃圾/水果果肉",
                "10": "厨余垃圾/茶叶渣",
                "11": "厨余垃圾/菜叶菜根",
                "12": "厨余垃圾/蛋壳",
                "13": "厨余垃圾/鱼骨",
                "14": "可回收物/充电宝",
                "15": "可回收物/包",
                "16": "可回收物/化妆品瓶",
                "17": "可回收物/塑料玩具",
                "18": "可回收物/塑料碗盆",
                "19": "可回收物/塑料衣架",
                "20": "可回收物/快递纸袋",
                "21": "可回收物/插头电线",
                "22": "可回收物/旧衣服",
                "23": "可回收物/易拉罐",
                "24": "可回收物/枕头",
                "25": "可回收物/毛绒玩具",
                "26": "可回收物/洗发水瓶",
                "27": "可回收物/玻璃杯",
                "28": "可回收物/皮鞋",
                "29": "可回收物/砧板",
                "30": "可回收物/纸板箱",
                "31": "可回收物/调料瓶",
                "32": "可回收物/酒瓶",
                "33": "可回收物/金属食品罐",
                "34": "可回收物/锅",
                "35": "可回收物/食用油桶",
                "36": "可回收物/饮料瓶",
                "37": "有害垃圾/干电池",
                "38": "有害垃圾/软膏",
                "39": "有害垃圾/过期药物"
            }

    def center_img(self, img, size=None, fill_value=255):
        """Center *img* on a square uint8 background filled with *fill_value*.

        *size* defaults to the longer side of *img*.
        """
        h, w = img.shape[:2]
        if size is None:
            size = max(h, w)
        shape = (size, size) + img.shape[2:]
        background = np.full(shape, fill_value, np.uint8)
        center_x = (size - w) // 2
        center_y = (size - h) // 2
        background[center_y:center_y + h, center_x:center_x + w] = img
        return background

    def preprocess_img(self, img):
        """Resize a PIL image so its longer side is ``input_size`` and pad it
        onto a white ``input_size`` x ``input_size`` square.

        Returns a uint8 array in RGB channel order. RGB is kept deliberately
        because the MEAN/STD applied in ``_inference`` are per-RGB-channel
        statistics, and the training generator (data_gen.py) feeds RGB images.
        The former ``img[:, :, ::-1]`` BGR flip made serving inputs disagree
        with both.
        """
        resize_scale = self.input_size / max(img.size[:2])
        img = img.resize((int(img.size[0] * resize_scale), int(img.size[1] * resize_scale)))
        img = img.convert('RGB')
        img = np.array(img)
        return self.center_img(img, self.input_size)

    def _preprocess(self, data):
        """Decode every uploaded file and preprocess it, keyed by input name."""
        preprocessed_data = {}
        for k, v in data.items():
            for file_name, file_content in v.items():
                img = Image.open(file_content)
                img = self.preprocess_img(img)
                preprocessed_data[k] = img
        return preprocessed_data

    def _inference(self, data):
        """Normalize the preprocessed image, run the model and map the arg-max
        class id to its label name."""
        img = data[self.input_key_1]
        img = img[np.newaxis, :, :, :]  # add batch axis: (1, input_size, input_size, 3)
        img = np.asarray(img, np.float32) / 255.0
        img = (img - self.MEAN) / self.STD
        pred_score = self.sess.run([self.output_score], feed_dict={self.input_images: img})
        if pred_score is not None:
            pred_label = np.argmax(pred_score[0], axis=1)[0]
            result = {'result': self.label_id_name_dict[str(pred_label)]}
        else:
            result = {'result': 'predict score is None'}
        return result

    def _postprocess(self, data):
        """Return the inference result unchanged."""
        return data
148 |
--------------------------------------------------------------------------------
/efficientnet/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The TensorFlow Authors, Pavel Yakubovskiy. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | import functools
17 |
# Fallback submodules used when a caller does not inject them explicitly.
_KERAS_BACKEND = None
_KERAS_LAYERS = None
_KERAS_MODELS = None
_KERAS_UTILS = None


def get_submodules_from_kwargs(kwargs):
    """Extract the Keras ``backend``/``layers``/``models``/``utils`` submodules.

    Missing entries fall back to the module-level ``_KERAS_*`` defaults.

    Args:
        kwargs: Mapping that may only contain the keys ``'backend'``,
            ``'layers'``, ``'models'`` and ``'utils'``.

    Returns:
        Tuple ``(backend, layers, models, utils)``.

    Raises:
        TypeError: if *kwargs* contains any other key.
    """
    backend = kwargs.get('backend', _KERAS_BACKEND)
    layers = kwargs.get('layers', _KERAS_LAYERS)
    models = kwargs.get('models', _KERAS_MODELS)
    utils = kwargs.get('utils', _KERAS_UTILS)
    for key in kwargs.keys():
        if key not in ['backend', 'layers', 'models', 'utils']:
            # Format the message. Passing ``key`` as a second argument (as
            # before) produced an args tuple instead of a readable message.
            raise TypeError('Invalid keyword argument: %s' % (key,))
    return backend, layers, models, utils
33 |
34 |
def _modules_injector(keras_module, func):
    """Wrap *func* so every call receives *keras_module*'s submodules as kwargs."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        kwargs.update(
            backend=keras_module.backend,
            layers=keras_module.layers,
            models=keras_module.models,
            utils=keras_module.utils,
        )
        return func(*args, **kwargs)

    return wrapper


def inject_keras_modules(func):
    """Decorate *func* so it is always called with standalone-``keras`` submodules.

    ``keras`` is imported when the decorator is applied, not on every call.
    """
    import keras
    return _modules_injector(keras, func)


def inject_tfkeras_modules(func):
    """Decorate *func* so it is always called with ``tensorflow.keras`` submodules.

    ``tensorflow.keras`` is imported when the decorator is applied.
    """
    import tensorflow.keras as tfkeras
    return _modules_injector(tfkeras, func)
59 |
60 |
def init_keras_custom_objects():
    """Register ``swish`` and ``FixedDropout`` in standalone Keras' custom-object
    registry so they can be resolved by name (e.g. when loading saved models)."""
    import keras
    from . import model

    registry_entries = dict(
        swish=inject_keras_modules(model.get_swish)(),
        FixedDropout=inject_keras_modules(model.get_dropout)(),
    )
    keras.utils.generic_utils.get_custom_objects().update(registry_entries)


def init_tfkeras_custom_objects():
    """Register ``swish`` and ``FixedDropout`` in ``tensorflow.keras``' custom-object
    registry so they can be resolved by name."""
    import tensorflow.keras as tfkeras
    from . import model

    registry_entries = dict(
        swish=inject_tfkeras_modules(model.get_swish)(),
        FixedDropout=inject_tfkeras_modules(model.get_dropout)(),
    )
    tfkeras.utils.get_custom_objects().update(registry_entries)
83 |
--------------------------------------------------------------------------------
/efficientnet/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wusaifei/garbage_classify/107b0c02499828e72978b6fe0aa704ecb9457d30/efficientnet/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/efficientnet/__pycache__/__version__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wusaifei/garbage_classify/107b0c02499828e72978b6fe0aa704ecb9457d30/efficientnet/__pycache__/__version__.cpython-36.pyc
--------------------------------------------------------------------------------
/efficientnet/__pycache__/keras.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wusaifei/garbage_classify/107b0c02499828e72978b6fe0aa704ecb9457d30/efficientnet/__pycache__/keras.cpython-36.pyc
--------------------------------------------------------------------------------
/efficientnet/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wusaifei/garbage_classify/107b0c02499828e72978b6fe0aa704ecb9457d30/efficientnet/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/efficientnet/__pycache__/preprocessing.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wusaifei/garbage_classify/107b0c02499828e72978b6fe0aa704ecb9457d30/efficientnet/__pycache__/preprocessing.cpython-36.pyc
--------------------------------------------------------------------------------
/efficientnet/__pycache__/tfkeras.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wusaifei/garbage_classify/107b0c02499828e72978b6fe0aa704ecb9457d30/efficientnet/__pycache__/tfkeras.cpython-36.pyc
--------------------------------------------------------------------------------
/efficientnet/__version__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The TensorFlow Authors, Pavel Yakubovskiy. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
# Version components: (major, minor, patch-with-prerelease-suffix).
VERSION = (1, 0, '0b3')

# Dotted version string built from VERSION, e.g. "1.0.0b3".
__version__ = '.'.join(str(part) for part in VERSION)
18 |
19 |
--------------------------------------------------------------------------------
/efficientnet/keras.py:
--------------------------------------------------------------------------------
1 | from . import inject_keras_modules, init_keras_custom_objects
2 | from . import model
3 |
4 | from .preprocessing import center_crop_and_resize
5 |
# Public EfficientNet builders bound to the standalone ``keras`` package.
_inject = inject_keras_modules

EfficientNetB0 = _inject(model.EfficientNetB0)
EfficientNetB1 = _inject(model.EfficientNetB1)
EfficientNetB2 = _inject(model.EfficientNetB2)
EfficientNetB3 = _inject(model.EfficientNetB3)
EfficientNetB4 = _inject(model.EfficientNetB4)
EfficientNetB5 = _inject(model.EfficientNetB5)
EfficientNetB6 = _inject(model.EfficientNetB6)
EfficientNetB7 = _inject(model.EfficientNetB7)

preprocess_input = _inject(model.preprocess_input)

del _inject

# Register swish / FixedDropout with Keras at import time.
init_keras_custom_objects()
18 |
--------------------------------------------------------------------------------
/efficientnet/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The TensorFlow Authors, Pavel Yakubovskiy, Björn Barz. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains definitions for EfficientNet model.
16 |
17 | [1] Mingxing Tan, Quoc V. Le
18 | EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks.
19 | ICML'19, https://arxiv.org/abs/1905.11946
20 | """
21 |
22 | # Code of this model implementation is mostly written by
23 | # Björn Barz ([@Callidior](https://github.com/Callidior))
24 |
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 |
29 | import os
30 | import json
31 | import math
32 | import string
33 | import collections
34 | import numpy as np
35 |
36 | from six.moves import xrange
37 | from keras_applications.imagenet_utils import _obtain_input_shape
38 | from keras_applications.imagenet_utils import decode_predictions
39 | from keras_applications.imagenet_utils import preprocess_input as _preprocess_input
40 |
41 | from . import get_submodules_from_kwargs
42 |
# Keras submodules used by the layer builders in this module (e.g.
# mb_conv_block reads ``backend`` and ``layers``). They start as None.
# NOTE(review): presumably rebound from injected kwargs by the model-building
# functions defined further down; confirm.
backend = None
layers = None
models = None
keras_utils = None

# URL prefix of the release that hosts the pretrained weight files.
BASE_WEIGHTS_PATH = (
    'https://github.com/Callidior/keras-applications/'
    'releases/download/efficientnet/')

# Per-variant pair of 64-hex-digit file hashes for the pretrained weights.
# NOTE(review): presumably (with-top, no-top) weight files; confirm against the loader.
WEIGHTS_HASHES = {
    'efficientnet-b0': ('163292582f1c6eaca8e7dc7b51b01c61'
                        '5b0dbc0039699b4dcd0b975cc21533dc',
                        'c1421ad80a9fc67c2cc4000f666aa507'
                        '89ce39eedb4e06d531b0c593890ccff3'),
    'efficientnet-b1': ('d0a71ddf51ef7a0ca425bab32b7fa7f1'
                        '6043ee598ecee73fc674d9560c8f09b0',
                        '75de265d03ac52fa74f2f510455ba64f'
                        '9c7c5fd96dc923cd4bfefa3d680c4b68'),
    'efficientnet-b2': ('bb5451507a6418a574534aa76a91b106'
                        'f6b605f3b5dde0b21055694319853086',
                        '433b60584fafba1ea3de07443b74cfd3'
                        '2ce004a012020b07ef69e22ba8669333'),
    'efficientnet-b3': ('03f1fba367f070bd2545f081cfa7f3e7'
                        '6f5e1aa3b6f4db700f00552901e75ab9',
                        'c5d42eb6cfae8567b418ad3845cfd63a'
                        'a48b87f1bd5df8658a49375a9f3135c7'),
    'efficientnet-b4': ('98852de93f74d9833c8640474b2c698d'
                        'b45ec60690c75b3bacb1845e907bf94f',
                        '7942c1407ff1feb34113995864970cd4'
                        'd9d91ea64877e8d9c38b6c1e0767c411'),
    'efficientnet-b5': ('30172f1d45f9b8a41352d4219bf930ee'
                        '3339025fd26ab314a817ba8918fefc7d',
                        '9d197bc2bfe29165c10a2af8c2ebc675'
                        '07f5d70456f09e584c71b822941b1952'),
    'efficientnet-b6': ('f5270466747753485a082092ac9939ca'
                        'a546eb3f09edca6d6fff842cad938720',
                        '1d0923bb038f2f8060faaf0a0449db4b'
                        '96549a881747b7c7678724ac79f427ed'),
    'efficientnet-b7': ('876a41319980638fa597acbbf956a82d'
                        '10819531ff2dcb1a52277f10c7aefa1a',
                        '60b56ff3a8daccc8d96edfd40b204c11'
                        '3e51748da657afd58034d54d3cec2bac')
}

# Hyper-parameters describing one stage of MBConv blocks (consumed by
# mb_conv_block via attributes such as kernel_size, strides and se_ratio).
BlockArgs = collections.namedtuple('BlockArgs', [
    'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
    'expand_ratio', 'id_skip', 'strides', 'se_ratio'
])
# Make every field optional with a default of None.
# defaults will be a public argument for namedtuple in Python 3.7
# https://docs.python.org/3/library/collections.html#collections.namedtuple
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)

# Default stage configuration: 7 stages of MBConv blocks.
# NOTE(review): presumably the unscaled baseline that round_filters /
# round_repeats scale per model variant; confirm in the builder.
DEFAULT_BLOCKS_ARGS = [
    BlockArgs(kernel_size=3, num_repeat=1, input_filters=32, output_filters=16,
              expand_ratio=1, id_skip=True, strides=[1, 1], se_ratio=0.25),
    BlockArgs(kernel_size=3, num_repeat=2, input_filters=16, output_filters=24,
              expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25),
    BlockArgs(kernel_size=5, num_repeat=2, input_filters=24, output_filters=40,
              expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25),
    BlockArgs(kernel_size=3, num_repeat=3, input_filters=40, output_filters=80,
              expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25),
    BlockArgs(kernel_size=5, num_repeat=3, input_filters=80, output_filters=112,
              expand_ratio=6, id_skip=True, strides=[1, 1], se_ratio=0.25),
    BlockArgs(kernel_size=5, num_repeat=4, input_filters=112, output_filters=192,
              expand_ratio=6, id_skip=True, strides=[2, 2], se_ratio=0.25),
    BlockArgs(kernel_size=3, num_repeat=1, input_filters=192, output_filters=320,
              expand_ratio=6, id_skip=True, strides=[1, 1], se_ratio=0.25)
]

# Serialized initializer config for conv / depthwise-conv kernels.
CONV_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 2.0,
        'mode': 'fan_out',
        # EfficientNet actually uses an untruncated normal distribution for
        # initializing conv layers, but keras.initializers.VarianceScaling use
        # a truncated distribution.
        # We decided against a custom initializer for better serializability.
        'distribution': 'normal'
    }
}

# Serialized initializer config for dense-layer kernels.
DENSE_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 1. / 3.,
        'mode': 'fan_out',
        'distribution': 'uniform'
    }
}
133 |
134 |
def preprocess_input(x, **kwargs):
    """Preprocess *x* with keras_applications' ``preprocess_input`` in
    ``'torch'`` mode.

    Extra keyword arguments (e.g. injected Keras submodules) are forwarded
    unchanged.
    """
    normalized = _preprocess_input(x, mode='torch', **kwargs)
    return normalized
137 |
138 |
def get_swish(**kwargs):
    """Build a ``swish`` activation bound to the Keras backend in *kwargs*.

    Returns:
        A function named ``swish`` computing ``x * sigmoid(x)``.
    """
    backend, _, _, _ = get_submodules_from_kwargs(kwargs)

    def swish(x):
        """Swish activation function: x * sigmoid(x).
        Reference: [Searching for Activation Functions](https://arxiv.org/abs/1710.05941)
        """
        on_tensorflow = backend.backend() == 'tensorflow'
        if on_tensorflow:
            try:
                # TF's native op has a more memory-efficient gradient.
                return backend.tf.nn.swish(x)
            except AttributeError:
                pass  # no native swish available: use the generic formula below
        return x * backend.sigmoid(x)

    return swish
156 |
157 |
def get_dropout(**kwargs):
    """Return a ``FixedDropout`` layer class bound to the Keras modules in *kwargs*.

    ``FixedDropout`` works around tf.keras not accepting ``None`` entries in
    ``noise_shape``. Each ``None`` is replaced by the matching dimension of the
    input's symbolic shape. The class is created inside this factory because
    the base ``layers.Dropout`` is only known once the Keras modules are
    supplied.

    Issue:
        https://github.com/tensorflow/tensorflow/issues/30946
    """
    backend, layers, _, _ = get_submodules_from_kwargs(kwargs)

    class FixedDropout(layers.Dropout):
        def _get_noise_shape(self, inputs):
            if self.noise_shape is None:
                return None

            dynamic_shape = backend.shape(inputs)
            return tuple(
                dim if dim is not None else dynamic_shape[axis]
                for axis, dim in enumerate(self.noise_shape)
            )

    return FixedDropout
179 |
180 |
def round_filters(filters, width_coefficient, depth_divisor):
    """Scale *filters* by *width_coefficient* and round to a multiple of *depth_divisor*.

    Rounding is to the nearest multiple (halves round up). The result is never
    below *depth_divisor*. If rounding lost more than 10% of the scaled count,
    one extra *depth_divisor* is added.
    """
    scaled = filters * width_coefficient
    nearest = int(scaled + depth_divisor / 2) // depth_divisor * depth_divisor
    rounded = max(depth_divisor, nearest)
    # Make sure that round down does not go down by more than 10%.
    if rounded < 0.9 * scaled:
        rounded += depth_divisor
    return int(rounded)
191 |
192 |
def round_repeats(repeats, depth_coefficient):
    """Return ``repeats * depth_coefficient`` rounded up to the next int."""
    scaled = repeats * depth_coefficient
    return int(math.ceil(scaled))
197 |
198 |
def mb_conv_block(inputs, block_args, activation, drop_rate=None, prefix='', ):
    """Mobile Inverted Residual Bottleneck.

    Builds, in order:
      - an optional 1x1 expansion conv (only when ``expand_ratio != 1``);
      - a depthwise conv;
      - an optional squeeze-and-excitation branch;
      - a 1x1 projection conv.
    The expansion, depthwise and projection convs are each followed by
    BatchNormalization; the expansion and depthwise convs are also followed
    by ``activation``. An identity shortcut is added only when ``id_skip``
    is set, all strides are 1 and the input/output filter counts match.

    Uses the module-level ``backend``, ``layers``, ``models`` and
    ``keras_utils`` names, which ``EfficientNet`` assigns through a
    ``global`` statement before it builds any block.

    # Arguments
        inputs: input tensor.
        block_args: block configuration. Fields read here: input_filters,
            output_filters, expand_ratio, kernel_size, strides, se_ratio,
            id_skip.
        activation: activation used by the Activation layers and by the
            SE reduce conv.
        drop_rate: dropout rate on the residual branch before the shortcut
            add. Applied only when truthy and > 0, with
            ``noise_shape=(None, 1, 1, 1)``, i.e. one mask value per sample.
        prefix: string prepended to every layer name (layer names must stay
            stable for weight loading).

    # Returns
        Output tensor.
    """

    # SE is enabled only for a ratio in (0, 1].
    has_se = (block_args.se_ratio is not None) and (0 < block_args.se_ratio <= 1)
    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

    # workaround over non working dropout with None in noise_shape in tf.keras
    Dropout = get_dropout(
        backend=backend,
        layers=layers,
        models=models,
        utils=keras_utils
    )

    # Expansion phase
    filters = block_args.input_filters * block_args.expand_ratio
    if block_args.expand_ratio != 1:
        x = layers.Conv2D(filters, 1,
                          padding='same',
                          use_bias=False,
                          kernel_initializer=CONV_KERNEL_INITIALIZER,
                          name=prefix + 'expand_conv')(inputs)
        x = layers.BatchNormalization(axis=bn_axis, name=prefix + 'expand_bn')(x)
        x = layers.Activation(activation, name=prefix + 'expand_activation')(x)
    else:
        x = inputs

    # Depthwise Convolution
    x = layers.DepthwiseConv2D(block_args.kernel_size,
                               strides=block_args.strides,
                               padding='same',
                               use_bias=False,
                               depthwise_initializer=CONV_KERNEL_INITIALIZER,
                               name=prefix + 'dwconv')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=prefix + 'bn')(x)
    x = layers.Activation(activation, name=prefix + 'activation')(x)

    # Squeeze and Excitation phase
    if has_se:
        # The reduced width is relative to the block's *input* filters, not the expanded width.
        num_reduced_filters = max(1, int(
            block_args.input_filters * block_args.se_ratio
        ))
        se_tensor = layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x)

        # Restore a 1x1 spatial map so 1x1 convs can act on the pooled vector.
        target_shape = (1, 1, filters) if backend.image_data_format() == 'channels_last' else (filters, 1, 1)
        se_tensor = layers.Reshape(target_shape, name=prefix + 'se_reshape')(se_tensor)
        se_tensor = layers.Conv2D(num_reduced_filters, 1,
                                  activation=activation,
                                  padding='same',
                                  use_bias=True,
                                  kernel_initializer=CONV_KERNEL_INITIALIZER,
                                  name=prefix + 'se_reduce')(se_tensor)
        se_tensor = layers.Conv2D(filters, 1,
                                  activation='sigmoid',
                                  padding='same',
                                  use_bias=True,
                                  kernel_initializer=CONV_KERNEL_INITIALIZER,
                                  name=prefix + 'se_expand')(se_tensor)
        if backend.backend() == 'theano':
            # For the Theano backend, we have to explicitly make
            # the excitation weights broadcastable.
            pattern = ([True, True, True, False] if backend.image_data_format() == 'channels_last'
                       else [True, False, True, True])
            se_tensor = layers.Lambda(
                lambda x: backend.pattern_broadcast(x, pattern),
                name=prefix + 'se_broadcast')(se_tensor)
        # Channel-wise rescaling of the feature map by the sigmoid gates.
        x = layers.multiply([x, se_tensor], name=prefix + 'se_excite')

    # Output phase
    x = layers.Conv2D(block_args.output_filters, 1,
                      padding='same',
                      use_bias=False,
                      kernel_initializer=CONV_KERNEL_INITIALIZER,
                      name=prefix + 'project_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=prefix + 'project_bn')(x)
    # The residual add is only valid when the spatial size and channel count are unchanged.
    if block_args.id_skip and all(
            s == 1 for s in block_args.strides
    ) and block_args.input_filters == block_args.output_filters:
        if drop_rate and (drop_rate > 0):
            x = Dropout(drop_rate,
                        noise_shape=(None, 1, 1, 1),
                        name=prefix + 'drop')(x)
        x = layers.add([x, inputs], name=prefix + 'add')

    return x
284 |
285 |
def EfficientNet(width_coefficient,
                 depth_coefficient,
                 default_resolution,
                 dropout_rate=0.2,
                 drop_connect_rate=0.2,
                 depth_divisor=8,
                 blocks_args=DEFAULT_BLOCKS_ARGS,
                 model_name='efficientnet',
                 include_top=True,
                 weights='imagenet',
                 input_tensor=None,
                 input_shape=None,
                 pooling=None,
                 classes=1000,
                 **kwargs):
    """Instantiates the EfficientNet architecture using given scaling coefficients.
    Optionally loads weights pre-trained on ImageNet.
    Note that the data format convention used by the model is
    the one specified in your Keras config at `~/.keras/keras.json`.
    # Arguments
        width_coefficient: float, scaling coefficient for network width
            (applied to every filter count through `round_filters`).
        depth_coefficient: float, scaling coefficient for network depth
            (applied to every block's repeat count through `round_repeats`).
        default_resolution: int, default input image size.
        dropout_rate: float, dropout rate before final classifier layer.
        drop_connect_rate: float, maximum dropout rate at skip connections.
            The k-th built block (counting from 0) of N depth-scaled blocks
            uses `drop_connect_rate * k / N`.
        depth_divisor: int, filter counts are rounded to multiples of it.
        blocks_args: A list of BlockArgs to construct block modules.
        model_name: string, model name. Also selects the ImageNet weight
            file name and its entry in `WEIGHTS_HASHES`.
        include_top: whether to include the fully-connected
            layer at the top of the network.
        weights: one of `None` (random initialization),
              'imagenet' (pre-training on ImageNet),
              or the path to the weights file to be loaded.
        input_tensor: optional Keras tensor
            (i.e. output of `layers.Input()`)
            to use as image input for the model.
        input_shape: optional shape tuple, only to be specified
            if `include_top` is False.
            It should have exactly 3 inputs channels.
        pooling: optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional layer.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional layer, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.
        classes: optional number of classes to classify images
            into, only to be specified if `include_top` is True, and
            if no `weights` argument is specified.
        **kwargs: Keras submodules forwarded to `get_submodules_from_kwargs`
            and `get_swish` (presumably `backend`, `layers`, `models`,
            `utils`; confirm against the package `__init__`).
    # Returns
        A Keras model instance.
    # Raises
        ValueError: in case of invalid argument for `weights`,
            or invalid input shape.
    """
    # mb_conv_block reads these module-level names, so publish them globally.
    global backend, layers, models, keras_utils
    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

    if not (weights in {'imagenet', None} or os.path.exists(weights)):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                         ' as true, `classes` should be 1000')

    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=default_resolution,
                                      min_size=32,
                                      data_format=backend.image_data_format(),
                                      require_flatten=include_top,
                                      weights=weights)

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    else:
        if not backend.is_keras_tensor(input_tensor):
            img_input = layers.Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
    activation = get_swish(**kwargs)

    # Build stem
    x = img_input
    x = layers.Conv2D(round_filters(32, width_coefficient, depth_divisor), 3,
                      strides=(2, 2),
                      padding='same',
                      use_bias=False,
                      kernel_initializer=CONV_KERNEL_INITIALIZER,
                      name='stem_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
    x = layers.Activation(activation, name='stem_activation')(x)

    # Build blocks
    # Count the blocks that are actually built (after depth scaling), so the
    # per-block drop rate grows linearly up to `drop_connect_rate`. Summing
    # the unscaled repeats would push it above that bound whenever
    # depth_coefficient > 1 (e.g. about 0.67 for B7).
    num_blocks_total = sum(round_repeats(block_args.num_repeat, depth_coefficient)
                           for block_args in blocks_args)
    block_num = 0
    for idx, block_args in enumerate(blocks_args):
        assert block_args.num_repeat > 0
        # Update block input and output filters based on depth multiplier.
        block_args = block_args._replace(
            input_filters=round_filters(block_args.input_filters,
                                        width_coefficient, depth_divisor),
            output_filters=round_filters(block_args.output_filters,
                                         width_coefficient, depth_divisor),
            num_repeat=round_repeats(block_args.num_repeat, depth_coefficient))

        # The first block needs to take care of stride and filter size increase.
        drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
        x = mb_conv_block(x, block_args,
                          activation=activation,
                          drop_rate=drop_rate,
                          prefix='block{}a_'.format(idx + 1))
        block_num += 1
        if block_args.num_repeat > 1:
            # Repeated blocks keep the channel count and use stride 1.
            # pylint: disable=protected-access
            block_args = block_args._replace(
                input_filters=block_args.output_filters, strides=[1, 1])
            # pylint: enable=protected-access
            for bidx in range(block_args.num_repeat - 1):
                drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
                block_prefix = 'block{}{}_'.format(
                    idx + 1,
                    string.ascii_lowercase[bidx + 1]
                )
                x = mb_conv_block(x, block_args,
                                  activation=activation,
                                  drop_rate=drop_rate,
                                  prefix=block_prefix)
                block_num += 1

    # Build top
    x = layers.Conv2D(round_filters(1280, width_coefficient, depth_divisor), 1,
                      padding='same',
                      use_bias=False,
                      kernel_initializer=CONV_KERNEL_INITIALIZER,
                      name='top_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
    x = layers.Activation(activation, name='top_activation')(x)
    if include_top:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        if dropout_rate and dropout_rate > 0:
            x = layers.Dropout(dropout_rate, name='top_dropout')(x)
        x = layers.Dense(classes,
                         activation='softmax',
                         kernel_initializer=DENSE_KERNEL_INITIALIZER,
                         name='probs')(x)
    else:
        if pooling == 'avg':
            x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        elif pooling == 'max':
            x = layers.GlobalMaxPooling2D(name='max_pool')(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = keras_utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input

    # Create model.
    model = models.Model(inputs, x, name=model_name)

    # Load weights.
    if weights == 'imagenet':
        if include_top:
            file_name = model_name + '_weights_tf_dim_ordering_tf_kernels_autoaugment.h5'
            file_hash = WEIGHTS_HASHES[model_name][0]
        else:
            file_name = model_name + '_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5'
            file_hash = WEIGHTS_HASHES[model_name][1]
        weights_path = keras_utils.get_file(file_name,
                                            BASE_WEIGHTS_PATH + file_name,
                                            cache_subdir='models',
                                            file_hash=file_hash)
        model.load_weights(weights_path)
    elif weights is not None:
        model.load_weights(weights)

    return model
474 |
475 |
def EfficientNetB0(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    # B0 scaling: width x1.0, depth x1.0, 224 px input, top dropout 0.2.
    return EfficientNet(width_coefficient=1.0,
                        depth_coefficient=1.0,
                        default_resolution=224,
                        dropout_rate=0.2,
                        model_name='efficientnet-b0',
                        include_top=include_top,
                        weights=weights,
                        input_tensor=input_tensor,
                        input_shape=input_shape,
                        pooling=pooling,
                        classes=classes,
                        **kwargs)
489 |
490 |
def EfficientNetB1(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    # B1 scaling: width x1.0, depth x1.1, 240 px input, top dropout 0.2.
    return EfficientNet(width_coefficient=1.0,
                        depth_coefficient=1.1,
                        default_resolution=240,
                        dropout_rate=0.2,
                        model_name='efficientnet-b1',
                        include_top=include_top,
                        weights=weights,
                        input_tensor=input_tensor,
                        input_shape=input_shape,
                        pooling=pooling,
                        classes=classes,
                        **kwargs)
504 |
505 |
def EfficientNetB2(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    # B2 scaling: width x1.1, depth x1.2, 260 px input, top dropout 0.3.
    return EfficientNet(width_coefficient=1.1,
                        depth_coefficient=1.2,
                        default_resolution=260,
                        dropout_rate=0.3,
                        model_name='efficientnet-b2',
                        include_top=include_top,
                        weights=weights,
                        input_tensor=input_tensor,
                        input_shape=input_shape,
                        pooling=pooling,
                        classes=classes,
                        **kwargs)
519 |
520 |
def EfficientNetB3(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    # B3 scaling: width x1.2, depth x1.4, 300 px input, top dropout 0.3.
    return EfficientNet(width_coefficient=1.2,
                        depth_coefficient=1.4,
                        default_resolution=300,
                        dropout_rate=0.3,
                        model_name='efficientnet-b3',
                        include_top=include_top,
                        weights=weights,
                        input_tensor=input_tensor,
                        input_shape=input_shape,
                        pooling=pooling,
                        classes=classes,
                        **kwargs)
534 |
535 |
def EfficientNetB4(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    # B4 scaling: width x1.4, depth x1.8, 380 px input, top dropout 0.4.
    return EfficientNet(width_coefficient=1.4,
                        depth_coefficient=1.8,
                        default_resolution=380,
                        dropout_rate=0.4,
                        model_name='efficientnet-b4',
                        include_top=include_top,
                        weights=weights,
                        input_tensor=input_tensor,
                        input_shape=input_shape,
                        pooling=pooling,
                        classes=classes,
                        **kwargs)
549 |
550 |
def EfficientNetB5(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    # B5 scaling: width x1.6, depth x2.2, 456 px input, top dropout 0.4.
    return EfficientNet(width_coefficient=1.6,
                        depth_coefficient=2.2,
                        default_resolution=456,
                        dropout_rate=0.4,
                        model_name='efficientnet-b5',
                        include_top=include_top,
                        weights=weights,
                        input_tensor=input_tensor,
                        input_shape=input_shape,
                        pooling=pooling,
                        classes=classes,
                        **kwargs)
564 |
565 |
def EfficientNetB6(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    # B6 scaling: width x1.8, depth x2.6, 528 px input, top dropout 0.5.
    return EfficientNet(width_coefficient=1.8,
                        depth_coefficient=2.6,
                        default_resolution=528,
                        dropout_rate=0.5,
                        model_name='efficientnet-b6',
                        include_top=include_top,
                        weights=weights,
                        input_tensor=input_tensor,
                        input_shape=input_shape,
                        pooling=pooling,
                        classes=classes,
                        **kwargs)
579 |
580 |
def EfficientNetB7(include_top=True,
                   weights='imagenet',
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    # B7 scaling: width x2.0, depth x3.1, 600 px input, top dropout 0.5.
    return EfficientNet(width_coefficient=2.0,
                        depth_coefficient=3.1,
                        default_resolution=600,
                        dropout_rate=0.5,
                        model_name='efficientnet-b7',
                        include_top=include_top,
                        weights=weights,
                        input_tensor=input_tensor,
                        input_shape=input_shape,
                        pooling=pooling,
                        classes=classes,
                        **kwargs)
594 |
595 |
# Every B0-B7 constructor shares the generic EfficientNet docstring.
for _variant in (EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3,
                 EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7):
    _variant.__doc__ = EfficientNet.__doc__
del _variant
604 |
--------------------------------------------------------------------------------
/efficientnet/preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The TensorFlow Authors, Pavel Yakubovskiy. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | import numpy as np
16 | from skimage.transform import resize
17 |
# Interpolation name -> spline `order` argument for skimage.transform.resize.
MAP_INTERPOLATION_TO_ORDER = dict(
    nearest=0,
    bilinear=1,
    biquadratic=2,
    bicubic=3,
)
24 |
25 |
def center_crop_and_resize(image, image_size, crop_padding=32, interpolation="bicubic"):
    """Take a centred square crop of ``image`` and resize it to ``image_size`` x ``image_size``.

    The crop side is ``image_size / (image_size + crop_padding)`` of the
    shorter image side (truncated to int). The crop is resized with the
    skimage spline order mapped from ``interpolation``, with
    ``preserve_range=True`` so pixel values are not rescaled.
    Raises AssertionError for non-2D/3D input or an unknown interpolation name.
    """
    assert image.ndim in {2, 3}
    assert interpolation in MAP_INTERPOLATION_TO_ORDER

    height, width = image.shape[:2]
    crop_size = int((image_size / (image_size + crop_padding)) * min(height, width))
    top = (height - crop_size + 1) // 2
    left = (width - crop_size + 1) // 2
    cropped = image[top:top + crop_size, left:left + crop_size]

    return resize(
        cropped,
        (image_size, image_size),
        order=MAP_INTERPOLATION_TO_ORDER[interpolation],
        preserve_range=True,
    )
50 |
--------------------------------------------------------------------------------
/efficientnet/tfkeras.py:
--------------------------------------------------------------------------------
1 | from . import inject_tfkeras_modules, init_tfkeras_custom_objects
2 | from . import model
3 |
4 | from .preprocessing import center_crop_and_resize
5 |
# tf.keras entry points: each constructor from `model` is wrapped by
# `inject_tfkeras_modules`. Presumably the wrapper supplies the tf.keras
# backend/layers/models/utils submodules as keyword arguments.
# NOTE(review): confirm in the package __init__.
EfficientNetB0 = inject_tfkeras_modules(model.EfficientNetB0)
EfficientNetB1 = inject_tfkeras_modules(model.EfficientNetB1)
EfficientNetB2 = inject_tfkeras_modules(model.EfficientNetB2)
EfficientNetB3 = inject_tfkeras_modules(model.EfficientNetB3)
EfficientNetB4 = inject_tfkeras_modules(model.EfficientNetB4)
EfficientNetB5 = inject_tfkeras_modules(model.EfficientNetB5)
EfficientNetB6 = inject_tfkeras_modules(model.EfficientNetB6)
EfficientNetB7 = inject_tfkeras_modules(model.EfficientNetB7)

# The input-preprocessing helper gets the same wrapping.
preprocess_input = inject_tfkeras_modules(model.preprocess_input)

# Runs at import time. Presumably registers this package's custom objects
# (e.g. the swish activation / FixedDropout) with tf.keras so saved models
# can be reloaded. NOTE(review): verify against init_tfkeras_custom_objects.
init_tfkeras_custom_objects()
18 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import shutil
4 | import codecs
5 | import numpy as np
6 | from glob import glob
7 |
8 | from PIL import Image
9 | import tensorflow as tf
10 | from keras import backend
11 | from keras.optimizers import adam, Nadam
12 |
13 | from tensorflow.python.saved_model import tag_constants
14 |
15 | from train import model_fn
16 | from save_model import load_weights
17 |
18 | backend.set_image_data_format('channels_last')
19 |
20 |
def center_img(img, size=None, fill_value=255):
    """
    Paste ``img`` into the centre of a square uint8 canvas filled with ``fill_value``.

    ``size`` defaults to the longer image side. Any trailing (channel) axes
    of ``img`` are kept. When the padding is odd, the extra row/column goes
    to the bottom/right.
    """
    rows, cols = img.shape[:2]
    side = max(rows, cols) if size is None else size
    canvas = np.full((side, side) + img.shape[2:], fill_value, np.uint8)
    top = (side - rows) // 2
    left = (side - cols) // 2
    canvas[top:top + rows, left:left + cols] = img
    return canvas
34 |
35 |
def preprocess_img(img_path, img_size):
    """
    Load an image and letterbox it into an ``img_size`` x ``img_size`` BGR uint8 array.

    Steps:
      - resize so the longer side becomes (about) ``img_size``, keeping the
        aspect ratio;
      - convert to RGB, then reverse the channels to BGR;
      - centre the result on a white square canvas with ``center_img``.

    You can add your special preprocess method here.
    """
    # The context manager closes the file handle. Pillow can keep it open
    # after loading, e.g. for multi-frame formats.
    with Image.open(img_path) as src:
        resize_scale = img_size / max(src.size[:2])
        # Clamp to 1 px: for very thin images int() could round one side to 0,
        # which Image.resize rejects.
        new_size = (max(1, int(src.size[0] * resize_scale)),
                    max(1, int(src.size[1] * resize_scale)))
        resized = src.resize(new_size)
        img = np.array(resized.convert('RGB'))
    img = img[:, :, ::-1]  # RGB -> BGR
    img = center_img(img, img_size)
    return img
49 |
50 |
def load_test_data(FLAGS):
    """Load the labelled test images described by ``*.txt`` files in ``FLAGS.test_data_local``.

    The first line of each label file must be ``"<image file name>, <int label>"``.
    Files that do not split into exactly two fields are reported and skipped.

    Returns:
        ``(img_names, test_data, test_labels)``. ``test_data`` is a uint8
        array of shape ``(len(img_names), FLAGS.input_size, FLAGS.input_size, 3)``
        whose i-th row belongs to ``img_names[i]`` / ``test_labels[i]``.
    """
    # Sorted so that evaluation order and error reports are reproducible.
    label_files = sorted(glob(os.path.join(FLAGS.test_data_local, '*.txt')))
    test_data = np.zeros((len(label_files), FLAGS.input_size, FLAGS.input_size, 3),
                         dtype=np.uint8)
    img_names = []
    test_labels = []
    for file_path in label_files:
        with codecs.open(file_path, 'r', 'utf-8') as f:
            line = f.readline()
        line_split = line.strip().split(', ')
        if len(line_split) != 2:
            print('%s contain error lable' % os.path.basename(file_path))
            continue
        label = int(line_split[1])
        # Fill rows compactly: indexing by the label-file position would
        # leave a garbage row for each skipped file and misalign every
        # later image with its name and label.
        row = len(img_names)
        test_data[row] = preprocess_img(os.path.join(FLAGS.test_data_local, line_split[0]), FLAGS.input_size)
        img_names.append(line_split[0])
        test_labels.append(label)
    return img_names, test_data[:len(img_names)], test_labels
68 |
69 |
def test_single_h5(FLAGS, h5_weights_path):
    """Evaluate one ``.h5`` weights file on the local test set.

    Builds the model with ``model_fn``, loads ``h5_weights_path`` into it and
    predicts every test image. It then prints the accuracy and writes
    ``<weights file name>_accuracy.txt`` (misclassified files plus accuracy)
    next to the weights file. Returns early, with a message, if the path is
    not a file or no valid test image was found.
    """
    if not os.path.isfile(h5_weights_path):
        print('%s is not a h5 weights file path' % h5_weights_path)
        return
    optimizer = Nadam(lr=FLAGS.learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08, schedule_decay=0.004)
    objective = 'categorical_crossentropy'
    metrics = ['accuracy']
    model = model_fn(FLAGS, objective, optimizer, metrics)
    # Load the file under evaluation. When called from test_batch_h5,
    # FLAGS.eval_weights_path is the *directory*, not this file.
    load_weights(model, h5_weights_path)
    img_names, test_data, test_labels = load_test_data(FLAGS)
    if not img_names:
        print('no valid test images found in %s' % FLAGS.test_data_local)
        return
    predictions = model.predict(test_data, verbose=0)

    right_count = 0
    error_infos = []
    for img_name, pred, test_label in zip(img_names, predictions, test_labels):
        pred_label = np.argmax(pred, axis=0)
        if pred_label == test_label:
            right_count += 1
        else:
            error_infos.append('%s, %s, %s\n' % (img_name, test_label, pred_label))

    accuracy = right_count / len(img_names)
    print('accuracy: %s' % accuracy)
    result_file_name = os.path.join(os.path.dirname(h5_weights_path),
                                    '%s_accuracy.txt' % os.path.basename(h5_weights_path))
    with open(result_file_name, 'w') as f:
        f.write('# predict error files\n')
        f.write('####################################\n')
        f.write('file_name, true_label, pred_label\n')
        f.writelines(error_infos)
        f.write('####################################\n')
        f.write('accuracy: %s\n' % accuracy)
    print('end')
104 |
105 |
def test_batch_h5(FLAGS):
    """
    Evaluate every ``*.h5`` weights file in the directory ``FLAGS.eval_weights_path``.

    Files are processed in sorted order; each one is passed to test_single_h5.
    """
    # `glob` is imported as the function itself (`from glob import glob`),
    # so `glob.glob(...)` would raise AttributeError.
    file_paths = sorted(glob(os.path.join(FLAGS.eval_weights_path, '*.h5')))
    for file_path in file_paths:
        test_single_h5(FLAGS, file_path)
113 |
114 |
def _load_pb_model(model_dir, config, signature_key, input_key, output_key):
    """Load one SavedModel into a private graph and session.

    Returns ``(session, input_tensor, output_tensor)``. Both tensors come from
    ``signature_key`` and belong to the session's own graph. The session is
    closed again if loading fails.
    """
    graph = tf.Graph()
    sess = tf.Session(graph=graph, config=config)
    try:
        with graph.as_default():
            meta_graph_def = tf.saved_model.loader.load(sess, [tag_constants.SERVING], model_dir)
        signature = meta_graph_def.signature_def[signature_key]
        input_tensor = graph.get_tensor_by_name(signature.inputs[input_key].name)
        output_tensor = graph.get_tensor_by_name(signature.outputs[output_key].name)
    except Exception:
        sess.close()
        raise
    return sess, input_tensor, output_tensor


def _eval_pb_session(sess, input_tensor, output_tensor, img_names, test_data, test_labels,
                     result_file_name):
    """Score test images one at a time with ``sess``, then print and write an accuracy report."""
    right_count = 0
    error_infos = []
    for index, img in enumerate(test_data):
        img = img[np.newaxis, :, :, :]
        pred_score = sess.run([output_tensor], feed_dict={input_tensor: img})
        if pred_score is None:
            print('pred_score is None')
            continue
        pred_label = np.argmax(pred_score[0], axis=1)[0]
        test_label = test_labels[index]
        if pred_label == test_label:
            right_count += 1
        else:
            error_infos.append('%s, %s, %s\n' % (img_names[index], test_label, pred_label))
    accuracy = right_count / len(img_names)
    print('accuracy: %s' % accuracy)
    with open(result_file_name, 'w') as f:
        f.write('# predict error files\n')
        f.write('####################################\n')
        f.write('file_name, true_label, pred_label\n')
        f.writelines(error_infos)
        f.write('####################################\n')
        f.write('accuracy: %s\n' % accuracy)


def test_single_model(FLAGS):
    """Evaluate two exported SavedModels on the local test set.

    The models are read from the ``model1`` and ``model2`` sub-directories of
    ``FLAGS.eval_pb_path``. A missing sub-directory falls back to the root
    directory itself. Each model is loaded into its own graph/session and
    scored separately. Reports are written to ``accuracy1.txt`` and
    ``accuracy2.txt`` under ``FLAGS.eval_pb_path``.
    """
    img_names, test_data, test_labels = load_test_data(FLAGS)
    if not img_names:
        print('no valid test images found in %s' % FLAGS.test_data_local)
        return

    # NOTE(review): 's3//' looks like a typo for 's3://', and shutil.copytree
    # cannot read object storage anyway. Kept unchanged so remote paths are
    # still handed straight to TensorFlow as before.
    is_remote = FLAGS.eval_pb_path.startswith('s3//')
    if is_remote:
        pb_model_dir = '/cache/tmp/model'
        if os.path.exists(pb_model_dir):
            shutil.rmtree(pb_model_dir)
        shutil.copytree(FLAGS.eval_pb_path, pb_model_dir)
    else:
        pb_model_dir = FLAGS.eval_pb_path

    config = tf.ConfigProto(allow_soft_placement=True)
    sessions = []
    try:
        try:
            for sub_name in ('model1', 'model2'):
                sub_dir = os.path.join(pb_model_dir, sub_name)
                model_dir = sub_dir if os.path.isdir(sub_dir) else pb_model_dir
                sessions.append(_load_pb_model(model_dir, config, 'predict_images',
                                               'input_img', 'output_score'))
        finally:
            # Remove the local copy only after *both* models are loaded;
            # deleting it after the first load broke the second one.
            if is_remote:
                shutil.rmtree(pb_model_dir)

        for number, (sess, input_tensor, output_tensor) in enumerate(sessions, 1):
            # Each session is fed its own graph's tensors.
            result_file_name = os.path.join(FLAGS.eval_pb_path, 'accuracy%d.txt' % number)
            _eval_pb_session(sess, input_tensor, output_tensor,
                             img_names, test_data, test_labels, result_file_name)
    finally:
        for sess, _, _ in sessions:
            sess.close()
    print('end')
206 |
207 |
def eval_model(FLAGS):
    """Dispatch evaluation according to which path flag is set.

    A non-empty ``FLAGS.eval_weights_path`` takes precedence. A directory
    evaluates every ``.h5`` file inside it; anything else is evaluated as a
    single ``.h5`` file. Otherwise a non-empty ``FLAGS.eval_pb_path`` evaluates
    exported SavedModels. If both are empty, nothing happens.
    """
    weights_path = FLAGS.eval_weights_path
    if weights_path == '':
        if FLAGS.eval_pb_path != '':
            test_single_model(FLAGS)
        return
    if os.path.isdir(weights_path):
        test_batch_h5(FLAGS)
    else:
        test_single_h5(FLAGS, weights_path)
216 |
--------------------------------------------------------------------------------
/mean_std.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | import cv2
4 | import random
5 | import os
6 |
# Calculate per-channel mean and std of the training images.
# Each line of train.txt is the path of one image (mind the trailing "\n").
# Pixels are scaled to [0, 1]. The reported std is the mean of per-image
# standard deviations, not the pooled std of the whole dataset.
path = 'train.txt'
means = [0, 0, 0]
stdevs = [0, 0, 0]

num_imgs = 0
with open(path, 'r') as f:
    lines = f.readlines()
# random.shuffle(lines)
for index, line in enumerate(lines, 1):
    print('{}/{}'.format(index, len(lines)))
    # strip() instead of slicing off the last character: the final line may
    # have no trailing newline, and Windows line endings leave a '\r'.
    img_path = line.strip()
    if not img_path:
        continue
    print(img_path)
    img = cv2.imread(img_path)
    # cv2.imread returns None (it does not raise) for missing or unreadable files.
    if img is None:
        print('skip unreadable image: %s' % img_path)
        continue
    if img.ndim != 3 or img.shape[2] < 3:
        print('IndexError:此处图像出现错误, 但是不影响均值和方差的计算。')
        continue
    img = img.astype(np.float32) / 255.
    for i in range(3):
        means[i] += img[:, :, i].mean()
        stdevs[i] += img[:, :, i].std()
    # Count only images that actually contributed to the sums.
    num_imgs += 1
print(num_imgs)
if num_imgs == 0:
    raise SystemExit('no readable images listed in %s' % path)
# cv2 loads images in BGR order; reverse to report RGB.
means.reverse()
stdevs.reverse()

means = np.asarray(means) / num_imgs
stdevs = np.asarray(stdevs) / num_imgs

print("normMean = {}".format(means))
print("normStd = {}".format(stdevs))
print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))
47 |
--------------------------------------------------------------------------------
/mean_std.txt:
--------------------------------------------------------------------------------
1 | normMean = [0.56719673 0.5293289 0.48351972]
2 | normStd = [0.20874391 0.21455203 0.22451781]
3 | transforms.Normalize(normMean = [0.56719673 0.5293289 0.48351972], normStd = [0.20874391 0.21455203 0.22451781])
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wusaifei/garbage_classify/107b0c02499828e72978b6fe0aa704ecb9457d30/models/__init__.py
--------------------------------------------------------------------------------
/models/resnet50.py:
--------------------------------------------------------------------------------
1 | """Enables dynamic setting of underlying Keras module.
2 | """
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import os
8 | import json
9 | import warnings
10 | import numpy as np
11 |
12 | from keras.applications import backend
13 | from keras.applications import layers
14 | from keras.applications import models
15 | from keras.applications import utils
16 |
# Default Keras submodules. get_submodules_from_kwargs falls back to these
# when a caller does not pass its own `backend` / `layers` / `models` / `utils`.
_KERAS_BACKEND = backend
_KERAS_LAYERS = layers
_KERAS_MODELS = models
_KERAS_UTILS = utils

# Presumably a lazily populated cache of the ImageNet class index fetched
# from CLASS_INDEX_PATH. NOTE(review): the code that fills it is not
# visible in this chunk.
CLASS_INDEX = None
CLASS_INDEX_PATH = ('https://modelarts-competitions.obs.cn-north-1.myhuaweicloud.com/'
                    'model_zoo/resnet/imagenet_class_index.json')

# Global tensor of imagenet mean for preprocessing symbolic inputs
_IMAGENET_MEAN = None

# Download URLs for pretrained ResNet50 weights. Judging by the file names,
# the first includes the classifier top and the second ("notop") does not.
WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/'
                'releases/download/v0.2/'
                'resnet50_weights_tf_dim_ordering_tf_kernels.h5')
WEIGHTS_PATH_NO_TOP = ('https://modelarts-competitions.obs.cn-north-1.myhuaweicloud.com/'
                       'model_zoo/resnet/'
                       'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
35 |
def get_submodules_from_kwargs(kwargs):
    """Resolve the Keras submodules to use from a keyword-argument dict.

    # Arguments
        kwargs: dict that may contain 'backend', 'layers', 'models' and
            'utils'. Missing entries fall back to the keras.applications
            submodules imported by this module.

    # Returns
        Tuple ``(backend, layers, models, utils)``.

    # Raises
        TypeError: if ``kwargs`` contains any other key.
    """
    for key in kwargs:
        if key not in ('backend', 'layers', 'models', 'utils'):
            # Interpolate the key into the message. Passing it as a second
            # argument would make the exception message a tuple.
            raise TypeError('Invalid keyword argument: %s' % key)
    backend = kwargs.get('backend', _KERAS_BACKEND)
    layers = kwargs.get('layers', _KERAS_LAYERS)
    models = kwargs.get('models', _KERAS_MODELS)
    utils = kwargs.get('utils', _KERAS_UTILS)
    return backend, layers, models, utils
45 |
46 |
def correct_pad(backend, inputs, kernel_size):
    """Returns a tuple for zero-padding for 2D convolution with downsampling.

    # Arguments
        backend: Keras backend module; its ``image_data_format()`` and
            ``int_shape()`` are used.
        inputs: input tensor whose two spatial dimensions are inspected.
        kernel_size: An integer or tuple/list of 2 integers.

    # Returns
        A tuple ``((top, bottom), (left, right))``. Each leading pad is one
        smaller than ``kernel_size // 2`` when that spatial size is even or
        unknown.
    """
    img_dim = 2 if backend.image_data_format() == 'channels_first' else 1
    input_size = backend.int_shape(inputs)[img_dim:(img_dim + 2)]

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    if input_size[0] is None:
        adjust = (1, 1)
    else:
        # Treat each unknown (None) dimension as even instead of failing on None % 2.
        adjust = tuple(1 if dim is None else 1 - dim % 2 for dim in input_size)

    correct = (kernel_size[0] // 2, kernel_size[1] // 2)

    return ((correct[0] - adjust[0], correct[0]),
            (correct[1] - adjust[1], correct[1]))
72 |
# Version string of this module.
__version__ = '1.0.7'
74 |
75 |
def _preprocess_numpy_input(x, data_format, mode, **kwargs):
    """Preprocess a NumPy array holding one image (3D) or a batch of images (4D).

    # Arguments
        x: Input array, 3D or 4D.
        data_format: 'channels_first' selects the channel axis 0 (3D) or
            1 (4D). Any other value puts channels on the last axis.
        mode: One of "caffe", "tf" or "torch".
            - caffe: reverse channels RGB->BGR, then subtract the ImageNet
              per-channel means (no scaling).
            - tf: map pixel values with x / 127.5 - 1.
            - torch: divide by 255, then subtract the ImageNet mean and
              divide by the ImageNet std per channel.
        kwargs: optional Keras submodules for get_submodules_from_kwargs.

    Non-floating input is first cast to ``backend.floatx()``. Floating input
    is updated in place: the arithmetic also writes through the
    channel-reversed view in caffe mode. Callers must not expect ``x`` to be
    left unchanged.

    # Returns
        Preprocessed Numpy array.
    """
    backend, _, _, _ = get_submodules_from_kwargs(kwargs)
    if not issubclass(x.dtype.type, np.floating):
        x = x.astype(backend.floatx(), copy=False)

    if mode == 'tf':
        x /= 127.5
        x -= 1.
        return x

    channels_first = data_format == 'channels_first'

    if mode == 'torch':
        x /= 255.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        # 'RGB'->'BGR': flip the channel axis (a view, not a copy).
        if not channels_first:
            x = x[..., ::-1]
        elif x.ndim == 3:
            x = x[::-1, ...]
        else:
            x = x[:, ::-1, ...]
        mean = [103.939, 116.779, 123.68]
        std = None

    # Index pieces placed before/after the channel number for the current layout.
    if not channels_first:
        lead, trail = (Ellipsis,), ()
    elif x.ndim == 3:
        lead, trail = (), (slice(None), slice(None))
    else:
        lead, trail = (slice(None),), (slice(None), slice(None))

    # Zero-center by mean pixel
    for channel in range(3):
        x[lead + (channel,) + trail] -= mean[channel]
    if std is not None:
        for channel in range(3):
            x[lead + (channel,) + trail] /= std[channel]
    return x
149 |
150 |
def _preprocess_symbolic_input(x, data_format, mode, **kwargs):
    """Preprocess a symbolic tensor holding one image (3D) or a batch (4D).

    # Arguments
        x: Input tensor, 3D or 4D.
        data_format: Data format of the image tensor.
        mode: One of "caffe", "tf" or "torch".
            - caffe: will convert the images from RGB to BGR,
                then will zero-center each color channel with
                respect to the ImageNet dataset,
                without scaling.
            - tf: computes `x / 127.5 - 1` (maps 0..255 to -1..1).
            - torch: will scale pixels between 0 and 1 and then
                will normalize each channel with respect to the
                ImageNet dataset.
            Any value other than "tf" or "torch" behaves like "caffe".

    # Returns
        Preprocessed tensor.
    """
    backend, _, _, _ = get_submodules_from_kwargs(kwargs)

    if mode == 'tf':
        x /= 127.5
        x -= 1.
        return x

    if mode == 'torch':
        x /= 255.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        if data_format == 'channels_first':
            # 'RGB'->'BGR'
            if backend.ndim(x) == 3:
                x = x[::-1, ...]
            else:
                x = x[:, ::-1, ...]
        else:
            # 'RGB'->'BGR'
            x = x[..., ::-1]
        mean = [103.939, 116.779, 123.68]
        std = None

    # Build the (negated) mean constant for this call's mode. The previous
    # module-level cache kept whichever mode's mean was seen first, so a
    # later call in a different mode subtracted the wrong mean.
    mean_tensor = backend.constant(-np.array(mean))

    # Zero-center by mean pixel (bias_add adds along the channel axis).
    if backend.dtype(x) != backend.dtype(mean_tensor):
        x = backend.bias_add(
            x, backend.cast(mean_tensor, backend.dtype(x)),
            data_format=data_format)
    else:
        x = backend.bias_add(x, mean_tensor, data_format)
    if std is not None:
        if data_format == 'channels_first':
            # Shape (3, 1, 1) so the division broadcasts over the channel
            # axis instead of the trailing (width) axis.
            std = [[[s]] for s in std]
        x /= std
    return x
210 |
211 |
def preprocess_input(x, data_format=None, mode='caffe', **kwargs):
    """Preprocess a tensor or NumPy array encoding a batch of images.

    NumPy arrays go to `_preprocess_numpy_input`; anything else is treated
    as a symbolic tensor and goes to `_preprocess_symbolic_input`.

    # Arguments
        x: Input Numpy or symbolic tensor, 3D or 4D.
            The preprocessed data is written over the input data
            if the data types are compatible. To avoid this
            behaviour, `numpy.copy(x)` can be used.
        data_format: Data format of the image tensor/array. Defaults to
            `backend.image_data_format()` when None.
        mode: One of "caffe", "tf" or "torch" (see the helper functions).

    # Returns
        Preprocessed tensor or Numpy array.

    # Raises
        ValueError: In case of unknown `data_format` argument.
    """
    backend, _, _, _ = get_submodules_from_kwargs(kwargs)

    fmt = backend.image_data_format() if data_format is None else data_format
    if fmt not in {'channels_first', 'channels_last'}:
        raise ValueError('Unknown data_format ' + str(fmt))

    handler = (_preprocess_numpy_input if isinstance(x, np.ndarray)
               else _preprocess_symbolic_input)
    return handler(x, data_format=fmt, mode=mode, **kwargs)
251 |
252 |
def decode_predictions(preds, top=5, **kwargs):
    """Decode the predictions of an ImageNet model.

    The class index JSON is downloaded once (into the parent directory of
    this file) and cached in the module-level `CLASS_INDEX`.

    # Arguments
        preds: Numpy tensor encoding a batch of predictions, shape
            (samples, 1000).
        top: Integer, how many top-guesses to return. Values <= 0 yield
            empty lists.

    # Returns
        A list of lists of top class prediction tuples
        `(class_name, class_description, score)`, sorted by descending
        score. One list of tuples per sample in batch input.

    # Raises
        ValueError: In case of invalid shape of the `pred` array
            (must be 2D).
    """
    global CLASS_INDEX

    _, _, _, keras_utils = get_submodules_from_kwargs(kwargs)

    if len(preds.shape) != 2 or preds.shape[1] != 1000:
        raise ValueError('`decode_predictions` expects '
                         'a batch of predictions '
                         '(i.e. a 2D array of shape (samples, 1000)). '
                         'Found array with shape: ' + str(preds.shape))
    if CLASS_INDEX is None:
        fpath = keras_utils.get_file(
            'imagenet_class_index.json',
            CLASS_INDEX_PATH,
            cache_subdir='models',
            file_hash='c2c37ea517e94d9795004a39431a14cb',
            cache_dir=os.path.join(os.path.dirname(__file__), '..'))
        with open(fpath) as f:
            CLASS_INDEX = json.load(f)
    results = []
    for pred in preds:
        # Sort descending, then take the first `top`. The former
        # `argsort()[-top:]` returned every class when top == 0.
        top_indices = pred.argsort()[::-1][:max(top, 0)]
        result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices]
        result.sort(key=lambda entry: entry[2], reverse=True)
        results.append(result)
    return results
294 |
295 |
296 | def _obtain_input_shape(input_shape,
297 | default_size,
298 | min_size,
299 | data_format,
300 | require_flatten,
301 | weights=None):
302 | """Internal utility to compute/validate a model's input shape.
303 |
304 | # Arguments
305 | input_shape: Either None (will return the default network input shape),
306 | or a user-provided shape to be validated.
307 | default_size: Default input width/height for the model.
308 | min_size: Minimum input width/height accepted by the model.
309 | data_format: Image data format to use.
310 | require_flatten: Whether the model is expected to
311 | be linked to a classifier via a Flatten layer.
312 | weights: One of `None` (random initialization)
313 | or 'imagenet' (pre-training on ImageNet).
314 | If weights='imagenet' input channels must be equal to 3.
315 |
316 | # Returns
317 | An integer shape tuple (may include None entries).
318 |
319 | # Raises
320 | ValueError: In case of invalid argument values.
321 | """
322 | if weights != 'imagenet' and input_shape and len(input_shape) == 3:
323 | if data_format == 'channels_first':
324 | if input_shape[0] not in {1, 3}:
325 | warnings.warn(
326 | 'This model usually expects 1 or 3 input channels. '
327 | 'However, it was passed an input_shape with ' +
328 | str(input_shape[0]) + ' input channels.')
329 | default_shape = (input_shape[0], default_size, default_size)
330 | else:
331 | if input_shape[-1] not in {1, 3}:
332 | warnings.warn(
333 | 'This model usually expects 1 or 3 input channels. '
334 | 'However, it was passed an input_shape with ' +
335 | str(input_shape[-1]) + ' input channels.')
336 | default_shape = (default_size, default_size, input_shape[-1])
337 | else:
338 | if data_format == 'channels_first':
339 | default_shape = (3, default_size, default_size)
340 | else:
341 | default_shape = (default_size, default_size, 3)
342 | if weights == 'imagenet' and require_flatten:
343 | if input_shape is not None:
344 | if input_shape != default_shape:
345 | raise ValueError('When setting `include_top=True` '
346 | 'and loading `imagenet` weights, '
347 | '`input_shape` should be ' +
348 | str(default_shape) + '.')
349 | return default_shape
350 | if input_shape:
351 | if data_format == 'channels_first':
352 | if input_shape is not None:
353 | if len(input_shape) != 3:
354 | raise ValueError(
355 | '`input_shape` must be a tuple of three integers.')
356 | if input_shape[0] != 3 and weights == 'imagenet':
357 | raise ValueError('The input must have 3 channels; got '
358 | '`input_shape=' + str(input_shape) + '`')
359 | if ((input_shape[1] is not None and input_shape[1] < min_size) or
360 | (input_shape[2] is not None and input_shape[2] < min_size)):
361 | raise ValueError('Input size must be at least ' +
362 | str(min_size) + 'x' + str(min_size) +
363 | '; got `input_shape=' +
364 | str(input_shape) + '`')
365 | else:
366 | if input_shape is not None:
367 | if len(input_shape) != 3:
368 | raise ValueError(
369 | '`input_shape` must be a tuple of three integers.')
370 | if input_shape[-1] != 3 and weights == 'imagenet':
371 | raise ValueError('The input must have 3 channels; got '
372 | '`input_shape=' + str(input_shape) + '`')
373 | if ((input_shape[0] is not None and input_shape[0] < min_size) or
374 | (input_shape[1] is not None and input_shape[1] < min_size)):
375 | raise ValueError('Input size must be at least ' +
376 | str(min_size) + 'x' + str(min_size) +
377 | '; got `input_shape=' +
378 | str(input_shape) + '`')
379 | else:
380 | if require_flatten:
381 | input_shape = default_shape
382 | else:
383 | if data_format == 'channels_first':
384 | input_shape = (3, None, None)
385 | else:
386 | input_shape = (None, None, 3)
387 | if require_flatten:
388 | if None in input_shape:
389 | raise ValueError('If `include_top` is True, '
390 | 'you should specify a static `input_shape`. '
391 | 'Got `input_shape=' + str(input_shape) + '`')
392 | return input_shape
393 |
394 |
# Keras submodules read by identity_block() and conv_block(). They start as
# None and are assigned (via `global`) by ResNet50() before it builds any
# block, so calling the block helpers without going through ResNet50() fails.
backend = None
layers = None
models = None
keras_utils = None
399 |
400 |
def identity_block(input_tensor, kernel_size, filters, stage, block):
    """Residual block whose shortcut is the unchanged input tensor.

    Main path: 1x1 conv -> BN -> ReLU -> kxk conv ('same') -> BN -> ReLU ->
    1x1 conv -> BN. It is then added to `input_tensor` and passed through
    a final ReLU. Uses the module-level `backend` and `layers`.

    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of
            middle conv layer at main path
        filters: list of 3 integers, the filters of 3 conv layer at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names

    # Returns
        Output tensor for the block.
    """
    filters1, filters2, filters3 = filters
    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
    suffix = str(stage) + block + '_branch'
    conv_prefix = 'res' + suffix
    bn_prefix = 'bn' + suffix

    # (name tag, filters, kernel, padding) for the three main-path convs.
    main_path = (
        ('2a', filters1, (1, 1), 'valid'),
        ('2b', filters2, kernel_size, 'same'),
        ('2c', filters3, (1, 1), 'valid'),
    )
    out = input_tensor
    for tag, n_filters, size, pad in main_path:
        out = layers.Conv2D(n_filters, size,
                            padding=pad,
                            kernel_initializer='he_normal',
                            name=conv_prefix + tag)(out)
        out = layers.BatchNormalization(axis=bn_axis, name=bn_prefix + tag)(out)
        # The last conv's ReLU comes after the residual addition instead.
        if tag != '2c':
            out = layers.Activation('relu')(out)

    out = layers.add([out, input_tensor])
    return layers.Activation('relu')(out)
444 |
445 |
def conv_block(input_tensor,
               kernel_size,
               filters,
               stage,
               block,
               strides=(2, 2)):
    """Residual block with a strided 1x1 conv (+ BN) on the shortcut.

    Uses the module-level `backend` and `layers`.

    # Arguments
        input_tensor: input tensor
        kernel_size: default 3, the kernel size of
            middle conv layer at main path
        filters: list of 3 integers, the filters of 3 conv layer at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
        strides: Strides for the first conv layer in the block and for the
            shortcut conv.

    # Returns
        Output tensor for the block.

    Note that from stage 3,
    the first conv layer at main path is with strides=(2, 2)
    And the shortcut should have strides=(2, 2) as well
    """
    filters1, filters2, filters3 = filters
    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
    suffix = str(stage) + block + '_branch'
    conv_prefix = 'res' + suffix
    bn_prefix = 'bn' + suffix

    def conv_bn(tensor, n_filters, size, tag, **conv_kwargs):
        # Conv2D + BatchNormalization named 'res<..><tag>' / 'bn<..><tag>'.
        tensor = layers.Conv2D(n_filters, size,
                               kernel_initializer='he_normal',
                               name=conv_prefix + tag,
                               **conv_kwargs)(tensor)
        return layers.BatchNormalization(axis=bn_axis,
                                         name=bn_prefix + tag)(tensor)

    main = conv_bn(input_tensor, filters1, (1, 1), '2a', strides=strides)
    main = layers.Activation('relu')(main)
    main = conv_bn(main, filters2, kernel_size, '2b', padding='same')
    main = layers.Activation('relu')(main)
    main = conv_bn(main, filters3, (1, 1), '2c')

    shortcut = conv_bn(input_tensor, filters3, (1, 1), '1', strides=strides)

    merged = layers.add([main, shortcut])
    return layers.Activation('relu')(merged)
504 |
505 |
def ResNet50(include_top=True,
             weights='imagenet',
             input_tensor=None,
             input_shape=None,
             pooling=None,
             classes=1000,
             **kwargs):
    """Instantiates the ResNet50 architecture.

    Optionally loads weights pre-trained on ImageNet.
    Note that the data format convention used by the model is
    the one specified in your Keras config at `~/.keras/keras.json`.

    # Arguments
        include_top: whether to include the fully-connected
            layer at the top of the network.
        weights: one of `None` (random initialization),
              'imagenet' (pre-training on ImageNet),
              or the path to the weights file to be loaded.
              ImageNet weights are fetched from `WEIGHTS_PATH` /
              `WEIGHTS_PATH_NO_TOP` via `keras_utils.get_file`, with
              cache_subdir 'models' under the parent directory of this file.
        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
            to use as image input for the model.
        input_shape: optional shape tuple, only to be specified
            if `include_top` is False (otherwise the input shape
            has to be `(224, 224, 3)` (with `channels_last` data format)
            or `(3, 224, 224)` (with `channels_first` data format).
            It should have exactly 3 inputs channels,
            and width and height should be no smaller than 32.
            E.g. `(200, 200, 3)` would be one valid value.
        pooling: Optional pooling mode for feature extraction
            when `include_top` is `False`.
            - `None` means that the output of the model will be
                the 4D tensor output of the
                last convolutional block.
            - `avg` means that global average pooling
                will be applied to the output of the
                last convolutional block, and thus
                the output of the model will be a 2D tensor.
            - `max` means that global max pooling will
                be applied.
        classes: optional number of classes to classify images
            into, only to be specified if `include_top` is True, and
            if no `weights` argument is specified.
        **kwargs: optional `backend`, `layers`, `models` and `utils`
            submodules; any other key raises TypeError
            (see `get_submodules_from_kwargs`).

    # Returns
        A Keras model instance.

    # Raises
        ValueError: in case of invalid argument for `weights`,
            or invalid input shape.
    """
    # Publish the submodules as module globals: identity_block() and
    # conv_block() read `backend` and `layers` from module scope.
    global backend, layers, models, keras_utils
    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

    if not (weights in {'imagenet', None} or os.path.exists(weights)):
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                         ' as true, `classes` should be 1000')

    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=224,
                                      min_size=32,
                                      data_format=backend.image_data_format(),
                                      require_flatten=include_top,
                                      weights=weights)

    # Wrap a raw (non-Keras) tensor in an Input layer; reuse Keras tensors.
    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    else:
        if not backend.is_keras_tensor(input_tensor):
            img_input = layers.Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor
    if backend.image_data_format() == 'channels_last':
        bn_axis = 3
    else:
        bn_axis = 1

    # Stem: pad 3, 7x7/2 conv, BN, ReLU, pad 1, 3x3/2 max pool.
    x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
    x = layers.Conv2D(64, (7, 7),
                      strides=(2, 2),
                      padding='valid',
                      kernel_initializer='he_normal',
                      name='conv1')(x)
    x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = layers.Activation('relu')(x)
    x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)

    # Stage 2 keeps the spatial size (strides=(1, 1)); stages 3-5 start with
    # conv_block's default strides=(2, 2). Blocks per stage: 3, 4, 6, 3.
    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    if include_top:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
    else:
        if pooling == 'avg':
            x = layers.GlobalAveragePooling2D()(x)
        elif pooling == 'max':
            x = layers.GlobalMaxPooling2D()(x)
        else:
            warnings.warn('The output shape of `ResNet50(include_top=False)` '
                          'has been changed since Keras 2.2.0.')

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = keras_utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input
    # Create model.
    model = models.Model(inputs, x, name='resnet50')

    # Load weights.
    if weights == 'imagenet':
        if include_top:
            weights_path = keras_utils.get_file(
                'resnet50_weights_tf_dim_ordering_tf_kernels.h5',
                WEIGHTS_PATH,
                cache_subdir='models',
                md5_hash='a7b3fe01876f51b976af0dea6bc144eb',
                cache_dir=os.path.join(os.path.dirname(__file__), '..'))
        else:
            weights_path = keras_utils.get_file(
                'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5',
                WEIGHTS_PATH_NO_TOP,
                cache_subdir='models',
                md5_hash='a268eb855778b3df3c7506639542a6af',
                cache_dir=os.path.join(os.path.dirname(__file__), '..'))
        model.load_weights(weights_path)
        # The weight files are in TF kernel layout ("tf_dim_ordering"); the
        # Theano backend needs its kernels converted after loading.
        if backend.backend() == 'theano':
            keras_utils.convert_all_kernels_in_model(model)
    elif weights is not None:
        model.load_weights(weights)

    return model
664 |
665 |
666 |
--------------------------------------------------------------------------------
/pip-requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.15.0
2 | keras_efficientnets
3 |
4 |
5 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''
3 | 基于resnet50实现的垃圾分类代码
4 | 使用方法:
5 | (1)训练
6 | cd {run.py所在目录}
7 | python run.py --data_url='../datasets/garbage_classify/train_data' --train_url='./model_snapshots' --deploy_script_path='./deploy_scripts'
8 | (2)转pb
9 | cd {run.py所在目录}
10 | python run.py --mode=save_pb --deploy_script_path='./deploy_scripts' --freeze_weights_file_path='../model_snapshots/weights_000_0.9811.h5' --num_classes=40
11 | (3)评价
12 | cd {run.py所在目录}
13 | python run.py --mode=eval --eval_pb_path='../model_snapshots/model' --test_data_url='../datasets/garbage_classify/train_data'
14 | '''
15 | import os
# Expose only GPUs 0 and 1. This overwrites any CUDA_VISIBLE_DEVICES value
# already set in the environment. It runs before `import tensorflow` below,
# presumably so that it takes effect before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
17 | import tensorflow as tf
18 | import shutil
19 |
# Command-line flags. main() dispatches on --mode; check_args() validates the
# flags relevant to each mode.
tf.app.flags.DEFINE_string('mode', 'train', 'optional: train, save_pb, eval')
tf.app.flags.DEFINE_string('local_data_root', '/cache/',
                           'a directory used for transfer data between local path and OBS path')
# params for train
tf.app.flags.DEFINE_string('data_url', '', 'the training data path')
tf.app.flags.DEFINE_string('restore_model_path', '',
                           'a history model you have trained, you can load it and continue trainging')
tf.app.flags.DEFINE_string('train_url', '', 'the path to save training outputs')
tf.app.flags.DEFINE_integer('keep_weights_file_num', 20,
                            'the max num of weights files keeps, if set -1, means infinity')
# The default 0 is rejected by check_args(), so --num_classes must always be given.
tf.app.flags.DEFINE_integer('num_classes', 0, 'the num of classes which your task should classify')
tf.app.flags.DEFINE_integer('input_size', 456, 'the input image size of the model')
tf.app.flags.DEFINE_integer('batch_size', 8, '')
tf.app.flags.DEFINE_float('learning_rate',1e-4, '')
tf.app.flags.DEFINE_integer('max_epochs', 30, '')

# params for save pb
tf.app.flags.DEFINE_string('deploy_script_path', '',
                           'a path which contain config.json and customize_service.py, '
                           'if it is set, these two scripts will be copied to {train_url}/model directory')
tf.app.flags.DEFINE_string('freeze_weights_file_path', '',
                           'if it is set, the specified h5 weights file will be converted as a pb model, '
                           'only valid when {mode}=save_pb')

# params for evaluation
tf.app.flags.DEFINE_string('eval_weights_path', '', 'weights file path need to be evaluate')
tf.app.flags.DEFINE_string('eval_pb_path', '', 'a directory which contain pb file needed to be evaluate')
tf.app.flags.DEFINE_string('test_data_url', '', 'the test data path on obs')

# Local working paths. main() overwrites data_local, train_local and
# test_data_local; `tmp` is only referenced by commented-out code in main().
tf.app.flags.DEFINE_string('data_local', '', 'the train data path on local')
tf.app.flags.DEFINE_string('train_local', '', 'the training output results on local')
tf.app.flags.DEFINE_string('test_data_local', '', 'the test data path on local')
tf.app.flags.DEFINE_string('tmp', '', 'a temporary path on local')

FLAGS = tf.app.flags.FLAGS
55 |
56 |
def check_args(FLAGS):
    """Validate the command-line flags for the selected mode.

    As a side effect, in ``train`` mode it creates ``FLAGS.train_url`` with
    ``os.mkdir`` when that path does not exist yet.

    # Arguments
        FLAGS: the parsed ``tf.app.flags.FLAGS`` (or any object exposing the
            same attributes).

    # Raises
        Exception: describing the first flag that is missing, inconsistent
            or points to a path that does not exist.
    """
    if FLAGS.mode not in ['train', 'save_pb', 'eval']:
        raise Exception('FLAGS.mode error, should be train, save_pb or eval')
    # The message promises a positive number, so reject negative values as
    # well as the default 0.
    if FLAGS.num_classes <= 0:
        raise Exception('FLAGS.num_classes error, '
                        'should be a positive number associated with your classification task')

    if FLAGS.mode == 'train':
        if FLAGS.data_url == '':
            raise Exception('you must specify FLAGS.data_url')
        if not os.path.exists(FLAGS.data_url):
            raise Exception('FLAGS.data_url: %s is not exist' % FLAGS.data_url)
        if FLAGS.restore_model_path != '' and (not os.path.exists(FLAGS.restore_model_path)):
            raise Exception('FLAGS.restore_model_path: %s is not exist' % FLAGS.restore_model_path)
        if os.path.isdir(FLAGS.restore_model_path):
            raise Exception('FLAGS.restore_model_path must be a file path, not a directory, %s' % FLAGS.restore_model_path)
        if FLAGS.train_url == '':
            raise Exception('you must specify FLAGS.train_url')
        elif not os.path.exists(FLAGS.train_url):
            os.mkdir(FLAGS.train_url)
        if FLAGS.deploy_script_path != '' and (not os.path.exists(FLAGS.deploy_script_path)):
            raise Exception('FLAGS.deploy_script_path: %s is not exist' % FLAGS.deploy_script_path)
        if FLAGS.deploy_script_path != '' and os.path.exists(FLAGS.train_url + '/model'):
            raise Exception(FLAGS.train_url +
                            '/model is already exist, only one model directoty is allowed to exist')
        if FLAGS.test_data_url != '' and (not os.path.exists(FLAGS.test_data_url)):
            raise Exception('FLAGS.test_data_url: %s is not exist' % FLAGS.test_data_url)

    if FLAGS.mode == 'save_pb':
        if FLAGS.deploy_script_path == '' or FLAGS.freeze_weights_file_path == '':
            raise Exception('you must specify FLAGS.deploy_script_path '
                            'and FLAGS.freeze_weights_file_path when you want to save pb')
        if not os.path.exists(FLAGS.deploy_script_path):
            raise Exception('FLAGS.deploy_script_path: %s is not exist' % FLAGS.deploy_script_path)
        if not os.path.isdir(FLAGS.deploy_script_path):
            raise Exception('FLAGS.deploy_script_path must be a directory, not a file path, %s' % FLAGS.deploy_script_path)
        if not os.path.exists(FLAGS.freeze_weights_file_path):
            raise Exception('FLAGS.freeze_weights_file_path: %s is not exist' % FLAGS.freeze_weights_file_path)
        if os.path.isdir(FLAGS.freeze_weights_file_path):
            raise Exception('FLAGS.freeze_weights_file_path must be a file path, not a directory, %s ' % FLAGS.freeze_weights_file_path)
        # The pb export writes a `model` directory next to the weights file.
        if os.path.exists(FLAGS.freeze_weights_file_path.rsplit('/', 1)[0] + '/model'):
            raise Exception('a model directory is already exist in ' + FLAGS.freeze_weights_file_path.rsplit('/', 1)[0]
                            + ', please rename or remove the model directory ')

    if FLAGS.mode == 'eval':
        if FLAGS.eval_weights_path == '' and FLAGS.eval_pb_path == '':
            raise Exception('you must specify FLAGS.eval_weights_path '
                            'or FLAGS.eval_pb_path when you want to evaluate a model')
        if FLAGS.eval_weights_path != '' and FLAGS.eval_pb_path != '':
            raise Exception('you must specify only one of FLAGS.eval_weights_path '
                            'and FLAGS.eval_pb_path when you want to evaluate a model')
        if FLAGS.eval_weights_path != '' and (not os.path.exists(FLAGS.eval_weights_path)):
            raise Exception('FLAGS.eval_weights_path: %s is not exist' % FLAGS.eval_weights_path)
        if FLAGS.eval_pb_path != '':
            if not os.path.exists(FLAGS.eval_pb_path):
                raise Exception('FLAGS.eval_pb_path: %s is not exist' % FLAGS.eval_pb_path)
            # Check the pb directory layout only when a pb path was given.
            # Running it unconditionally rejected every weights-only
            # evaluation, because os.path.isdir('') is False.
            if not os.path.isdir(FLAGS.eval_pb_path) or (not FLAGS.eval_pb_path.endswith('model')):
                raise Exception('FLAGS.eval_pb_path must be a directory named model '
                                'which contain saved_model.pb and variables, %s' % FLAGS.eval_pb_path)
        if FLAGS.test_data_url == '':
            raise Exception('you must specify FLAGS.test_data_url when you want to evaluate a model')
        if not os.path.exists(FLAGS.test_data_url):
            raise Exception('FLAGS.test_data_url: %s is not exist' % FLAGS.test_data_url)
119 |
120 |
def main(argv=None):
    """Entry point called by ``tf.app.run()``.

    Validates the flags, resolves local working copies of the data and
    output locations, then dispatches to training, pb export or evaluation
    according to ``FLAGS.mode``.

    # Arguments
        argv: not used.
    """
    check_args(FLAGS)

    def stage_input(url, subdir, skip_msg):
        # Local paths are used in place. An s3:// directory is copied once
        # into local_data_root/<subdir>; an existing copy is reused.
        if not url.startswith('s3://'):
            return url
        local_path = os.path.join(FLAGS.local_data_root, subdir)
        if os.path.exists(local_path):
            print(skip_msg % local_path)
        else:
            shutil.copytree(url, local_path)
        return local_path

    FLAGS.data_local = stage_input(
        FLAGS.data_url, 'train_data/',
        'FLAGS.data_local: %s is already exist, skip copy')

    # Training outputs only need a local directory; nothing is copied in.
    if FLAGS.train_url.startswith('s3://'):
        FLAGS.train_local = os.path.join(FLAGS.local_data_root, 'model_snapshots/')
        if not os.path.exists(FLAGS.train_local):
            os.mkdir(FLAGS.train_local)
    else:
        FLAGS.train_local = FLAGS.train_url

    FLAGS.test_data_local = stage_input(
        FLAGS.test_data_url, 'test_data/',
        'FLAGS.test_data_local: %s is already exist, skip copy')

    # Import lazily so each mode only loads the modules it needs.
    if FLAGS.mode == 'train':
        from train import train_model
        train_model(FLAGS)
    elif FLAGS.mode == 'save_pb':
        from save_model import load_weights_save_pb
        load_weights_save_pb(FLAGS)
    elif FLAGS.mode == 'eval':
        from eval import eval_model
        eval_model(FLAGS)
164 |
if __name__ == '__main__':
    # tf.app.run() parses the command-line flags defined above and then
    # calls main().
    tf.app.run()
167 |
--------------------------------------------------------------------------------
/save_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 |
4 | import tensorflow as tf
5 | from keras import backend
6 | from keras.optimizers import adam, Nadam
7 | import shutil
8 |
9 |
10 | from train import model_fn
11 |
12 |
def load_weights(model, weighs_file_path):
    """Load HDF5 weights into ``model`` from a local or OBS (``s3://``) path.

    An ``s3://`` file is first copied into ``/cache/tmp/``, loaded from
    there, and the temporary copy is then deleted. If ``weighs_file_path``
    is not a file, only a failure message is printed and the model is left
    unchanged (best-effort behaviour kept from the original).

    # Arguments
        model: object exposing ``load_weights(path)`` (e.g. a Keras model).
        weighs_file_path: path to the weights file.
    """
    if not os.path.isfile(weighs_file_path):
        print('load weights failed! Please check weighs_file_path')
        return

    print('load weights from %s' % weighs_file_path)
    if weighs_file_path.startswith('s3://'):
        local_dir = '/cache/tmp/'
        # The cache directory may not exist yet; copyfile would fail.
        os.makedirs(local_dir, exist_ok=True)
        local_path = local_dir + weighs_file_path.rsplit('/', 1)[1]
        shutil.copyfile(weighs_file_path, local_path)
        try:
            model.load_weights(local_path)
        finally:
            # Remove the temporary copy even if loading fails.
            os.remove(local_path)
    else:
        model.load_weights(weighs_file_path)
    print('load weights success')
27 |
28 |
def save_pb_model(FLAGS, model):
    """Export ``model`` as a TF SavedModel directory named ``model``.

    The SavedModel is written locally and copied to OBS when the
    destination is an ``s3://`` path. ``config.json`` and
    ``customize_service.py`` from ``FLAGS.deploy_script_path`` are then
    copied into the destination ``model`` directory.

    Destination by mode:
        - train: local ``FLAGS.train_local``, final ``FLAGS.train_url``.
        - save_pb: the directory holding ``FLAGS.freeze_weights_file_path``
          (staged through ``/cache/tmp`` when it is an ``s3://`` path).

    # Raises
        ValueError: if ``FLAGS.mode`` is neither 'train' nor 'save_pb'.
    """
    if FLAGS.mode == 'train':
        pb_save_dir_local = FLAGS.train_local
        pb_save_dir_obs = FLAGS.train_url
    elif FLAGS.mode == 'save_pb':
        freeze_weights_file_dir = FLAGS.freeze_weights_file_path.rsplit('/', 1)[0]
        if freeze_weights_file_dir.startswith('s3://'):
            pb_save_dir_local = '/cache/tmp'
            pb_save_dir_obs = freeze_weights_file_dir
        else:
            pb_save_dir_local = freeze_weights_file_dir
            pb_save_dir_obs = pb_save_dir_local
    else:
        # Previously this fell through to an UnboundLocalError.
        raise ValueError('save_pb_model only supports FLAGS.mode train or save_pb, '
                         'got %s' % FLAGS.mode)

    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={'input_img': model.input}, outputs={'output_score': model.output})
    builder = tf.saved_model.builder.SavedModelBuilder(os.path.join(pb_save_dir_local, 'model'))
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess=backend.get_session(),
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'predict_images': signature,
        },
        legacy_init_op=legacy_init_op)
    builder.save()
    print('save pb to local path success')

    if pb_save_dir_obs.startswith('s3://'):
        # `model` is a directory (saved_model.pb + variables/), so copy it
        # recursively. shutil.copyfile only handles single files.
        shutil.copytree(os.path.join(pb_save_dir_local, 'model'),
                        os.path.join(pb_save_dir_obs, 'model'))
        print('copy pb to %s success' % pb_save_dir_obs)

    shutil.copyfile(os.path.join(FLAGS.deploy_script_path, 'config.json'),
                    os.path.join(pb_save_dir_obs, 'model/config.json'))
    shutil.copyfile(os.path.join(FLAGS.deploy_script_path, 'customize_service.py'),
                    os.path.join(pb_save_dir_obs, 'model/customize_service.py'))
    if os.path.exists(os.path.join(pb_save_dir_obs, 'model/config.json')) and \
            os.path.exists(os.path.join(pb_save_dir_obs, 'model/customize_service.py')):
        print('copy config.json and customize_service.py success')
    else:
        print('copy config.json and customize_service.py failed')
70 |
71 |
def load_weights_save_pb(FLAGS):
    """Rebuild the training model, load frozen weights into it and export it.

    The model is compiled exactly as in training (Nadam, categorical
    cross-entropy, accuracy) so that ``model_fn`` produces the same graph the
    weights in ``FLAGS.freeze_weights_file_path`` were saved from.
    """
    net = model_fn(
        FLAGS,
        'categorical_crossentropy',
        Nadam(lr=FLAGS.learning_rate, beta_1=0.9, beta_2=0.999,
              epsilon=1e-08, schedule_decay=0.004),
        ['accuracy'],
    )
    load_weights(net, FLAGS.freeze_weights_file_path)
    save_pb_model(FLAGS, net)
79 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | from glob import glob
4 | import numpy as np
5 | from keras import backend
6 | from keras.models import Model
7 | from keras.optimizers import adam, Nadam, SGD
8 | from keras.callbacks import TensorBoard, Callback
9 | import shutil
10 | from data_gen import data_flow
11 | from keras.layers import Dense, Input, Dropout, Activation,GlobalAveragePooling2D,LeakyReLU,BatchNormalization
12 | from keras.layers import concatenate,Concatenate,multiply, LocallyConnected2D, Lambda,Conv2D,GlobalMaxPooling2D,Flatten
13 | from keras.layers.core import Reshape
14 | from keras.layers import multiply
15 | import keras as ks
16 | from keras.models import load_model
17 | from keras.applications.xception import Xception
18 | from keras.applications.inception_v3 import InceptionV3
19 | from models.resnet50 import ResNet50
20 | from keras.layers import Flatten, Dense, AveragePooling2D
21 | from keras.models import Sequential
22 | from keras.utils import multi_gpu_model
23 | from Groupnormalization import GroupNormalization
24 | from keras_efficientnets import EfficientNetB5
25 | from keras_efficientnets import EfficientNetB4
26 | import multiprocessing
27 | import efficientnet.keras as efn
28 | from warmup_cosine_decay_scheduler import WarmUpCosineDecayScheduler
29 |
# model_fn builds inputs as (input_size, input_size, 3), i.e. channel axis
# last, so pin the Keras backend to that layout before any model is built.
backend.set_image_data_format(data_format='channels_last')
def model_fn(FLAGS, objective, optimizer, metrics):
    """Build and compile the EfficientNet-B5 image classifier.

    The B5 backbone (without its top) is loaded from a local ``notop`` weights
    file.  A global-average-pool, dropout(0.4) and softmax head are added, the
    model is replicated over several GPUs and then compiled.

    Args:
        FLAGS: run configuration.  Reads ``input_size`` and ``num_classes``.
            Optionally reads ``pretrained_weights_path`` (defaults to the
            previously hard-coded job path) and ``num_gpus`` (defaults to 4).
        objective: loss passed to ``model.compile``.
        optimizer: optimizer instance passed to ``model.compile``.
        metrics: metric list passed to ``model.compile``.

    Returns:
        The compiled Keras model.  It is wrapped by ``multi_gpu_model`` when
        more than one GPU is requested.
    """
    base_model = EfficientNetB5(weights=None,
                                include_top=False,
                                input_shape=(FLAGS.input_size, FLAGS.input_size, 3),
                                classes=FLAGS.num_classes,
                                # The original passed the builtin ``max`` here,
                                # which is neither 'avg' nor 'max'. Presumably
                                # no pooling was applied, which is consistent
                                # with the GlobalAveragePooling2D below needing
                                # a 4-D feature map. Make that explicit.
                                pooling=None)
    base_model.load_weights(getattr(FLAGS, 'pretrained_weights_path',
                                    '/home/work/user-job-dir/src/efficientnet-b5_notop.h5'))

    # NOTE: an earlier loop assigned GroupNormalization objects into
    # ``model.layers``. That does not rewire the already-built graph, and the
    # classifier below is built from the backbone's output tensor, so the
    # BatchNormalization layers stayed in use. The no-op loop was removed.
    x = GlobalAveragePooling2D()(base_model.output)
    x = Dropout(0.4)(x)
    predictions = Dense(FLAGS.num_classes, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=predictions)

    # multi_gpu_model requires at least 2 GPUs; a single GPU uses the plain model.
    num_gpus = getattr(FLAGS, 'num_gpus', 4)
    if num_gpus > 1:
        model = multi_gpu_model(model, num_gpus)
    model.compile(loss=objective, optimizer=optimizer, metrics=metrics)
    return model
52 |
class LossHistory(Callback):
    """Keras callback that records losses and checkpoints weights every epoch.

    At the end of each epoch the callback:

    * appends ``loss``/``val_loss`` to ``self.losses``/``self.val_losses``;
    * saves the weights to ``FLAGS.train_local`` as
      ``weights_<epoch>_<val_acc>.h5``;
    * copies that file to ``FLAGS.train_url`` when it is an ``s3://`` URL;
    * keeps only the ``FLAGS.keep_weights_file_num`` newest local ``.h5``
      files, ordered by ctime.  A value of ``-1`` disables pruning.  Remote
      copies are never pruned.
    """

    def __init__(self, FLAGS):
        super(LossHistory, self).__init__()
        self.FLAGS = FLAGS

    def on_train_begin(self, logs=None):
        self.losses = []
        self.val_losses = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))

        # Older Keras logs 'val_acc'; newer versions use 'val_accuracy'.
        # Fall back to NaN rather than crashing the '%.4f' format on None.
        val_acc = logs.get('val_acc', logs.get('val_accuracy'))
        if val_acc is None:
            val_acc = float('nan')
        file_name = 'weights_%03d_%.4f.h5' % (epoch, val_acc)

        save_path = os.path.join(self.FLAGS.train_local, file_name)
        self.model.save_weights(save_path)
        if self.FLAGS.train_url.startswith('s3://'):
            shutil.copyfile(save_path, os.path.join(self.FLAGS.train_url, file_name))
        print('save weights file', save_path)

        self._prune_old_weights()

    def _prune_old_weights(self):
        """Delete all but the newest ``keep_weights_file_num`` local .h5 files."""
        keep = self.FLAGS.keep_weights_file_num
        if keep <= -1:
            return
        weights_files = glob(os.path.join(self.FLAGS.train_local, '*.h5'))
        if len(weights_files) <= keep:
            return
        # Newest first; everything after the first `keep` entries is stale.
        # The previous version sorted the list but never removed anything.
        weights_files.sort(key=lambda file_name: os.stat(file_name).st_ctime, reverse=True)
        for stale_path in weights_files[keep:]:
            os.remove(stale_path)
77 |
78 |
def _restore_weights(model, restore_model_path):
    """Load *restore_model_path* into *model*, staging ``s3://`` files in /cache/tmp."""
    if restore_model_path.startswith('s3://'):
        # Assumes os/shutil are s3-aware in this runtime (not shown here).
        os.makedirs('/cache/tmp', exist_ok=True)
        local_path = '/cache/tmp/' + restore_model_path.rsplit('/', 1)[1]
        shutil.copyfile(restore_model_path, local_path)
        try:
            model.load_weights(local_path)
        finally:
            # Remove the staged copy even if loading fails.
            os.remove(local_path)
    else:
        model.load_weights(restore_model_path)
    print("LOAD OK!!!")


def _build_warmup_scheduler(FLAGS, steps_per_epoch, warmup_epoch=5):
    """Create the linear warm-up + cosine-decay learning-rate callback.

    The warm-up is capped at ``max_epochs - 1`` epochs.
    cosine_decay_with_warmup rejects a warm-up longer than training, and a
    warm-up equal to it leaves no decay steps, so short runs used to fail.
    """
    warmup_epoch = min(warmup_epoch, max(FLAGS.max_epochs - 1, 0))
    total_steps = int(FLAGS.max_epochs * steps_per_epoch)
    warmup_steps = int(warmup_epoch * steps_per_epoch)
    return WarmUpCosineDecayScheduler(learning_rate_base=FLAGS.learning_rate,
                                      total_steps=total_steps,
                                      warmup_learning_rate=0,
                                      warmup_steps=warmup_steps,
                                      hold_base_rate_steps=0)


def _evaluate_on_test_set(FLAGS, model):
    """Predict the test set, print top-1 accuracy and write metric.json."""
    print('test dataset predicting...')
    from eval import load_test_data
    img_names, test_data, test_labels = load_test_data(FLAGS)
    # NOTE(review): `preprocess_input` is not imported anywhere in train.py,
    # so this line raises NameError as written. Import the preprocessing
    # function that matches what data_gen applies to training images — TODO
    # confirm which one.
    test_data = preprocess_input(test_data)
    predictions = model.predict(test_data, verbose=0)

    right_count = sum(1 for pred, label in zip(predictions, test_labels)
                      if np.argmax(pred, axis=0) == label)
    # Avoid ZeroDivisionError on an empty test set.
    accuracy = right_count / len(img_names) if len(img_names) > 0 else 0.0
    print('accuracy: %0.4f' % accuracy)

    metric_file_name = os.path.join(FLAGS.train_local, 'metric.json')
    metric_file_content = '{"total_metric": {"total_metric_values": {"accuracy": %0.4f}}}' % accuracy
    with open(metric_file_name, "w") as f:
        f.write(metric_file_content + '\n')


def train_model(FLAGS):
    """Train the classifier, then optionally export it and score a test set.

    Steps:

    1. Build the data generators and the compiled model.
    2. Optionally restore weights from ``FLAGS.restore_model_path``.
    3. Train with per-epoch checkpointing, TensorBoard logging and a warm-up +
       cosine-decay learning-rate schedule.
    4. Export a SavedModel when ``FLAGS.deploy_script_path`` is set.
    5. Evaluate on the test set when ``FLAGS.test_data_url`` is set.
    """
    train_sequence, validation_sequence = data_flow(FLAGS.data_local, FLAGS.batch_size,
                                                    FLAGS.num_classes, FLAGS.input_size)

    optimizer = Nadam(lr=FLAGS.learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08,
                      schedule_decay=0.004)
    objective = 'categorical_crossentropy'
    metrics = ['accuracy']
    model = model_fn(FLAGS, objective, optimizer, metrics)

    if FLAGS.restore_model_path != '' and os.path.exists(FLAGS.restore_model_path):
        _restore_weights(model, FLAGS.restore_model_path)
    if not os.path.exists(FLAGS.train_local):
        os.makedirs(FLAGS.train_local)

    tensorBoard = TensorBoard(log_dir='../log_file/')
    warm_up_lr = _build_warmup_scheduler(FLAGS, len(train_sequence))
    history = LossHistory(FLAGS)
    model.fit_generator(
        train_sequence,
        steps_per_epoch=len(train_sequence),
        epochs=FLAGS.max_epochs,
        verbose=1,
        callbacks=[history, tensorBoard, warm_up_lr],
        validation_data=validation_sequence,
        max_queue_size=10,
        workers=int(multiprocessing.cpu_count() * 0.7),
        use_multiprocessing=True,
        shuffle=True,
    )
    print('training done!')

    if FLAGS.deploy_script_path != '':
        from save_model import save_pb_model
        save_pb_model(FLAGS, model)

    if FLAGS.test_data_url != '':
        _evaluate_on_test_set(FLAGS, model)
    print('end')
162 |
--------------------------------------------------------------------------------
/tta_wrapper/__init__.py:
--------------------------------------------------------------------------------
1 | from .wrappers import tta_segmentation
2 | from .wrappers import tta_classification
--------------------------------------------------------------------------------
/tta_wrapper/__version__.py:
--------------------------------------------------------------------------------
# Package version as a (major, minor, patch) tuple and its dotted string form.
VERSION = (0, 0, 1)

__version__ = '.'.join(str(part) for part in VERSION)
--------------------------------------------------------------------------------
/tta_wrapper/augmentation.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from . import functional as F
3 |
4 |
class Augmentation(object):
    """Cartesian product of test-time transforms and their inverse chains.

    Each keyword argument names a transform from ``transforms``, and its value
    is that transform's parameter(s).  Every transform's parameter list is
    normalised by its ``prepare`` method.  Every combination of parameters
    becomes one augmented copy of the input.

    Attributes:
        forward_aug / forward_params: transforms in argument order, plus one
            parameter tuple per combination.
        backward_aug / backward_params: the same transforms' ``backward``
            methods in reverse order, with each parameter tuple reversed to
            match.
        n_transforms: number of combinations (augmented copies).
    """

    transforms = {
        'h_flip': F.HFlip(),
        'v_flip': F.VFlip(),
        'rotation': F.Rotate(),
        'h_shift': F.HShift(),
        'v_shift': F.VShift(),
        'contrast': F.Contrast(),
        'add': F.Add(),
        'mul': F.Multiply(),
    }

    def __init__(self, **params):
        super().__init__()

        names = list(params)
        selected = [Augmentation.transforms[name] for name in names]

        # Normalise each transform's values (adds its identity value), then
        # enumerate every combination across transforms.
        per_transform_values = [
            transform.prepare(params[name]) for transform, name in zip(selected, names)
        ]
        combos = list(itertools.product(*per_transform_values))

        self.forward_aug = [transform.forward for transform in selected]
        self.forward_params = combos

        # Undo in reverse order, so each parameter tuple is reversed too.
        self.backward_aug = [transform.backward for transform in reversed(selected)]
        self.backward_params = [combo[::-1] for combo in combos]

        self.n_transforms = len(combos)

    @property
    def forward(self):
        return self.forward_aug, self.forward_params

    @property
    def backward(self):
        return self.backward_aug, self.backward_params
45 |
--------------------------------------------------------------------------------
/tta_wrapper/functional.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
class DualTransform:
    """Base class for a test-time transform that has an inverse.

    Subclasses set ``identity_param``, the parameter value that leaves the
    input unchanged.  They implement ``forward``, and ``backward`` when the
    transform must be undone on the model output.
    """

    identity_param = None

    def prepare(self, params):
        """Return *params* as a new list that also contains ``identity_param``.

        Accepts None (no extra values), a single value, a tuple or a list.
        The caller's list is copied, never mutated (the previous version
        appended the identity value to the caller's own list).
        """
        if params is None:
            params = []
        elif isinstance(params, (list, tuple)):
            params = list(params)
        else:
            params = [params]

        if self.identity_param not in params:
            params.append(self.identity_param)
        return params

    def forward(self, image, param):
        raise NotImplementedError

    def backward(self, image, param):
        raise NotImplementedError
25 |
26 |
class SingleTransform(DualTransform):
    """Transform that is applied to the input only.

    Its inverse is the identity: ``backward`` hands the output back untouched.
    """

    def backward(self, image, param):
        unchanged = image
        return unchanged
31 |
32 |
class HFlip(DualTransform):
    """Left-right flip; flipping again undoes it, so ``backward`` == ``forward``."""

    identity_param = 0

    def prepare(self, params):
        """Return ``[1, 0]`` (flip and no-flip) when enabled, else ``[0]``.

        Any truthy value enables the flip.  The previous ``== True`` /
        ``== False`` checks returned None for values like None, which then
        broke ``itertools.product`` in Augmentation.
        """
        return [1, 0] if params else [0]

    def forward(self, image, param):
        return tf.image.flip_left_right(image) if param else image

    def backward(self, image, param):
        return self.forward(image, param)
48 |
49 |
class VFlip(DualTransform):
    """Up-down flip; flipping again undoes it, so ``backward`` == ``forward``."""

    identity_param = 0

    def prepare(self, params):
        """Return ``[1, 0]`` (flip and no-flip) when enabled, else ``[0]``.

        Any truthy value enables the flip.  The previous ``== True`` /
        ``== False`` checks returned None for values like None, which then
        broke ``itertools.product`` in Augmentation.
        """
        return [1, 0] if params else [0]

    def forward(self, image, param):
        return tf.image.flip_up_down(image) if param else image

    def backward(self, image, param):
        return self.forward(image, param)
65 |
66 |
class Rotate(DualTransform):
    """Rotation by whole quarter turns via ``tf.image.rot90``."""

    identity_param = 0

    def forward(self, image, angle):
        # Bring negative angles up by one full turn before counting quarter turns.
        if angle < 0:
            angle = angle + 360
        quarter_turns = angle // 90
        return tf.image.rot90(image, quarter_turns)

    def backward(self, image, angle):
        # Rotating by the opposite angle undoes the forward rotation.
        return self.forward(image, -angle)
77 |
78 |
class HShift(DualTransform):
    """Horizontal shift: roll one image along its width axis (with wrap-around).

    TTA applies transforms to one image at a time, laid out (H, W, C), so the
    width axis is 1.  The previous version rolled axis 0 (rows), which is a
    vertical shift.
    """

    identity_param = 0

    def forward(self, image, param):
        return tf.manip.roll(image, param, axis=1)

    def backward(self, image, param):
        return tf.manip.roll(image, -param, axis=1)


class VShift(DualTransform):
    """Vertical shift: roll one image along its height axis (with wrap-around).

    TTA applies transforms to one image at a time, laid out (H, W, C), so the
    height axis is 0.  The previous version rolled axis 1 (columns), which is
    a horizontal shift.
    """

    identity_param = 0

    def forward(self, image, param):
        return tf.manip.roll(image, param, axis=0)

    def backward(self, image, param):
        return tf.manip.roll(image, -param, axis=0)
99 |
100 |
class Contrast(SingleTransform):
    """Contrast adjustment; a factor of 1 is the identity (no change)."""

    identity_param = 1

    def forward(self, image, param):
        adjusted = tf.image.adjust_contrast(image, param)
        return adjusted
107 |
108 |
class Add(SingleTransform):
    """Add a constant offset to every pixel; 0 is the identity."""

    identity_param = 0

    def forward(self, image, param):
        shifted = image + param
        return shifted
115 |
116 |
class Multiply(SingleTransform):
    """Scale every pixel by a constant factor; 1 is the identity."""

    identity_param = 1

    def forward(self, image, param):
        scaled = image * param
        return scaled
123 |
124 |
def gmean(x):
    """Geometric mean of *x* over axis 0, keeping that axis with size 1.

    Computed as ``exp(mean(log(x)))`` rather than ``prod(x) ** (1 / n)``.  The
    product of many probabilities underflows in float32.  Also, the old code
    read ``n`` from the static shape, which fails when that dimension is
    unknown.  Both forms agree mathematically for positive inputs, and a zero
    entry still yields 0.
    """
    return tf.exp(tf.reduce_mean(tf.log(x), axis=0, keepdims=True))
130 |
131 |
def mean(x):
    """Arithmetic mean of *x* over axis 0, keeping that axis with size 1."""
    averaged = tf.reduce_mean(x, axis=0, keepdims=True)
    return averaged
134 |
135 |
def max(x):
    """Element-wise maximum of *x* over axis 0, keeping that axis with size 1.

    Shadows the builtin ``max`` inside this module; callers use it as ``F.max``.
    """
    largest = tf.reduce_max(x, axis=0, keepdims=True)
    return largest
138 |
--------------------------------------------------------------------------------
/tta_wrapper/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from keras.layers import Layer
3 |
4 | from . import functional as F
5 |
6 |
class Repeat(Layer):
    """
    Keras layer that replicates the first sample of a batch ``n`` times.
        input_shape = (1, H, W, C)
        output_shape = (N, H, W, C)
    """
    def __init__(self, n, **kwargs):
        super().__init__(**kwargs)
        self.n = n

    def call(self, x):
        first_sample = x[0]
        copies = [first_sample for _ in range(self.n)]
        return tf.stack(copies, axis=0)

    def compute_output_shape(self, input_shape):
        return (self.n,) + tuple(input_shape[1:])
22 |
23 |
class TTA(Layer):
    """Keras layer applying the i-th parameter combination to the i-th image.

    Args:
        functions: transform callables ``f(image, param)`` applied in order.
        params: one parameter tuple per image, aligned with ``functions``.
    """

    def __init__(self, functions, params):
        super().__init__()
        self.functions = functions
        self.params = params

    def _apply_chain(self, image, combo):
        """Feed *image* through every transform with its matching parameter."""
        for transform, value in zip(self.functions, combo):
            image = transform(image, value)
        return image

    def apply_transforms(self, images):
        outputs = [self._apply_chain(images[idx], combo)
                   for idx, combo in enumerate(self.params)]
        return tf.stack(outputs, 0)

    def call(self, images):
        return self.apply_transforms(images)
42 |
43 |
class Merge(Layer):
    """Keras layer that reduces the stacked augmented predictions to one.

    Args:
        type: ``'mean'``, ``'gmean'`` or ``'max'``.  This is the reduction
            over axis 0 (the augmentation axis) from ``tta_wrapper.functional``.

    Raises:
        ValueError: for any other ``type``.  This is now checked at
            construction, not first on the call.
    """

    _VALID_TYPES = ('mean', 'gmean', 'max')

    def __init__(self, type):
        super().__init__()
        if type not in self._VALID_TYPES:
            raise ValueError(f'Wrong merge type {type}')
        self.type = type

    def merge(self, x):
        if self.type == 'mean':
            return F.mean(x)
        if self.type == 'gmean':
            return F.gmean(x)
        if self.type == 'max':
            return F.max(x)
        # Report the configured value; the original formatted the builtin
        # `type` here, producing "<class 'type'>".
        raise ValueError(f'Wrong merge type {self.type}')

    def call(self, x):
        return self.merge(x)

    def compute_output_shape(self, input_shape):
        return (1, *input_shape[1:])
65 |
66 |
--------------------------------------------------------------------------------
/tta_wrapper/wrappers.py:
--------------------------------------------------------------------------------
1 | from keras.models import Model
2 | from keras.layers import Input
3 |
4 | from .layers import Repeat, TTA, Merge
5 | from .augmentation import Augmentation
6 |
7 |
# Shared argument reference appended to the tta_* wrappers' docstrings.
# Parameter names below must match the wrappers' signatures (h_shift/v_shift).
doc = """
    IMPORTANT constraints:
        1) model has to have 1 input and 1 output
        2) inference batch_size = 1
        3) image height == width if rotate augmentation is used

    Args:
        model: instance of Keras model
        h_flip: (bool) horizontal flip
        v_flip: (bool) vertical flip
        h_shift: (list of int) list of horizontal shifts (e.g. [10, -10])
        v_shift: (list of int) list of vertical shifts (e.g. [10, -10])
        rotation: (list of int) list of angles (deg) for rotation in range [0, 360),
            should be divisible by 90 deg (e.g. [90, 180, 270])
        contrast: (list of float) values for contrast adjustment
        add: (list of int or float) values to add on image (e.g. [-10, 10])
        mul: (list of float) values to multiply image on (e.g. [0.9, 1.1])
        merge: one of 'mean', 'gmean' and 'max' - mode of merging augmented
            predictions together.

    Returns:
        Keras Model instance

"""
32 |
def tta_segmentation(model,
                     h_flip=False,
                     v_flip=False,
                     h_shift=None,
                     v_shift=None,
                     rotation=None,
                     contrast=None,
                     add=None,
                     mul=None,
                     merge='mean'):
    """
    Wrap a segmentation model with test-time augmentation.

    The single input image is copied once per augmentation combination.
    Each copy is transformed and predicted, the prediction masks are mapped
    back through the inverse transforms, and the results are merged.
    """
    # Keyword order fixes the order in which transforms are chained.
    augmentation = Augmentation(h_flip=h_flip,
                                v_flip=v_flip,
                                h_shift=h_shift,
                                v_shift=v_shift,
                                rotation=rotation,
                                contrast=contrast,
                                add=add,
                                mul=mul,
                                )

    batch_shape = (1,) + tuple(model.input.shape.as_list()[1:])

    image_input = Input(batch_shape=batch_shape)
    repeated = Repeat(augmentation.n_transforms)(image_input)
    augmented = TTA(*augmentation.forward)(repeated)
    predicted = model(augmented)
    restored = TTA(*augmentation.backward)(predicted)
    merged = Merge(merge)(restored)
    return Model(image_input, merged)
68 |
69 |
def tta_classification(model,
                       h_flip=False,
                       v_flip=False,
                       h_shift=None,
                       v_shift=None,
                       rotation=None,
                       contrast=None,
                       add=None,
                       mul=None,
                       merge='mean'):
    """
    Wrap a classification model with test-time augmentation.

    The single input image is copied once per augmentation combination.
    Each copy is transformed and predicted, and the class scores are merged
    directly (no inverse transform is needed for per-image scores).
    """
    # Keyword order fixes the order in which transforms are chained.
    augmentation = Augmentation(h_flip=h_flip,
                                v_flip=v_flip,
                                h_shift=h_shift,
                                v_shift=v_shift,
                                rotation=rotation,
                                contrast=contrast,
                                add=add,
                                mul=mul,
                                )

    batch_shape = (1,) + tuple(model.input.shape.as_list()[1:])

    image_input = Input(batch_shape=batch_shape)
    repeated = Repeat(augmentation.n_transforms)(image_input)
    augmented = TTA(*augmentation.forward)(repeated)
    scores = model(augmented)
    merged = Merge(merge)(scores)
    return Model(image_input, merged)
104 |
105 |
# Append the shared constraints/argument reference to both public wrappers.
for _tta_wrapper in (tta_classification, tta_segmentation):
    _tta_wrapper.__doc__ += doc
del _tta_wrapper
108 |
--------------------------------------------------------------------------------
/warmup_cosine_decay_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tensorflow import keras
3 | from keras import backend as K
4 |
5 |
def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
    """Cosine decay schedule with warm up period.

    Cosine annealing learning rate as described in:
      Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts.
      ICLR 2017. https://arxiv.org/abs/1608.03983
    The learning rate grows linearly from warmup_learning_rate to
    learning_rate_base over warmup_steps.  It is optionally held at
    learning_rate_base for hold_base_rate_steps, then follows a cosine decay
    that reaches 0 at total_steps.  It is 0 after total_steps.

    Arguments:
        global_step {int} -- global step.
        learning_rate_base {float} -- base learning rate.
        total_steps {int} -- total number of training steps.

    Keyword Arguments:
        warmup_learning_rate {float} -- initial learning rate for warm up. (default: {0.0})
        warmup_steps {int} -- number of warmup steps. (default: {0})
        hold_base_rate_steps {int} -- Optional number of steps to hold base learning rate
                                    before decaying. (default: {0})
    Returns:
      a float (as a 0-d numpy array) representing learning rate.

    Raises:
      ValueError: if warmup_learning_rate is larger than learning_rate_base,
        or if warmup_steps is larger than total_steps.
    """

    if total_steps < warmup_steps:
        raise ValueError('total_steps must be larger or equal to '
                         'warmup_steps.')
    # When warm-up (+ hold) spans the whole schedule there are no decay steps.
    # Clamp the denominator to avoid ZeroDivisionError.  In that case the
    # cosine value is only ever used at global_step == total_steps, where it
    # evaluates to learning_rate_base.
    decay_steps = max(total_steps - warmup_steps - hold_base_rate_steps, 1)
    learning_rate = 0.5 * learning_rate_base * (1 + np.cos(
        np.pi *
        (global_step - warmup_steps - hold_base_rate_steps
         ) / float(decay_steps)))
    if hold_base_rate_steps > 0:
        learning_rate = np.where(global_step > warmup_steps + hold_base_rate_steps,
                                 learning_rate, learning_rate_base)
    if warmup_steps > 0:
        if learning_rate_base < warmup_learning_rate:
            raise ValueError('learning_rate_base must be larger or equal to '
                             'warmup_learning_rate.')
        slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
        warmup_rate = slope * global_step + warmup_learning_rate
        learning_rate = np.where(global_step < warmup_steps, warmup_rate,
                                 learning_rate)
    return np.where(global_step > total_steps, 0.0, learning_rate)
58 |
59 |
class WarmUpCosineDecayScheduler(keras.callbacks.Callback):
    """Keras callback driving the optimizer LR with cosine_decay_with_warmup.

    Before every batch, the learning rate for the current global step is
    written into ``model.optimizer.lr``.  After every batch, the step counter
    advances and the LR actually in effect is appended to ``learning_rates``.
    """

    def __init__(self,
                 learning_rate_base,
                 total_steps,
                 global_step_init=0,
                 warmup_learning_rate=0.0,
                 warmup_steps=0,
                 hold_base_rate_steps=0,
                 verbose=0):
        """Configure the schedule.

        Arguments:
            learning_rate_base {float} -- base learning rate.
            total_steps {int} -- total number of training steps.

        Keyword Arguments:
            global_step_init {int} -- starting step, e.g. when resuming from a checkpoint.
            warmup_learning_rate {float} -- learning rate at step 0 of the warm up. (default: {0.0})
            warmup_steps {int} -- length of the linear warm up in steps. (default: {0})
            hold_base_rate_steps {int} -- steps to stay at the base rate before decaying. (default: {0})
            verbose {int} -- 0: silent, 1: print the rate set for each batch. (default: {0})
        """
        super(WarmUpCosineDecayScheduler, self).__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.global_step = global_step_init
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.hold_base_rate_steps = hold_base_rate_steps
        self.verbose = verbose
        self.learning_rates = []

    def on_batch_begin(self, batch, logs=None):
        scheduled = cosine_decay_with_warmup(
            global_step=self.global_step,
            learning_rate_base=self.learning_rate_base,
            total_steps=self.total_steps,
            warmup_learning_rate=self.warmup_learning_rate,
            warmup_steps=self.warmup_steps,
            hold_base_rate_steps=self.hold_base_rate_steps,
        )
        K.set_value(self.model.optimizer.lr, scheduled)
        if self.verbose > 0:
            print('\nBatch %05d: setting learning rate to %s.' % (self.global_step + 1, scheduled))

    def on_batch_end(self, batch, logs=None):
        self.global_step += 1
        self.learning_rates.append(K.get_value(self.model.optimizer.lr))
113 |
if __name__ == '__main__':
    # Demo: train a toy model with the scheduler and plot the recorded LR curve.
    from keras.models import Sequential
    from keras.layers import Dense

    # Create a model.
    model = Sequential()
    model.add(Dense(32, activation='relu', input_dim=100))
    model.add(Dense(10, activation='softmax'))
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # Number of training samples.
    sample_count = 12608

    # Total epochs to train.
    epochs = 50

    # Number of warmup epochs.
    warmup_epoch = 10

    # Training batch size, set small value here for demonstration purpose.
    batch_size = 16

    # Base learning rate after warmup.
    learning_rate_base = 0.0001

    # Batches per epoch as run by model.fit (last partial batch included).
    steps_per_epoch = int(np.ceil(sample_count / batch_size))

    total_steps = int(epochs * sample_count / batch_size)

    # Compute the number of warmup batches.
    warmup_steps = int(warmup_epoch * sample_count / batch_size)

    # Generate dummy data.
    data = np.random.random((sample_count, 100))
    labels = np.random.randint(10, size=(sample_count, 1))

    # Convert labels to categorical one-hot encoding.
    one_hot_labels = keras.utils.to_categorical(labels, num_classes=10)

    # Create the Learning rate scheduler.
    warm_up_lr = WarmUpCosineDecayScheduler(learning_rate_base=learning_rate_base,
                                            total_steps=total_steps,
                                            warmup_learning_rate=4e-06,
                                            warmup_steps=warmup_steps,
                                            hold_base_rate_steps=5,
                                            )

    # Train the model, iterating on the data in batches of `batch_size` samples.
    model.fit(data, one_hot_labels, epochs=epochs, batch_size=batch_size,
              verbose=0, callbacks=[warm_up_lr])

    import matplotlib.pyplot as plt
    plt.plot(warm_up_lr.learning_rates)
    plt.xlabel('Step', fontsize=20)
    plt.ylabel('lr', fontsize=20)
    plt.axis([0, total_steps, 0, learning_rate_base * 1.1])
    # The x axis counts steps, so tick every 5 epochs' worth of steps.  The old
    # ticks at 0..epochs-1 were all crowded next to the origin.
    plt.xticks(np.arange(0, total_steps + 1, steps_per_epoch * 5))
    plt.grid()
    plt.title('Cosine decay with warmup', fontsize=20)
    plt.show()
178 |
--------------------------------------------------------------------------------