├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── deployment.xml
│   ├── garbage_calssify-by-resnet50.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── other.xml
├── README.md
├── __init__.py
├── __pycache__
│   ├── data_gen.cpython-36.pyc
│   ├── img_gen.cpython-36.pyc
│   ├── predict_local.cpython-36.pyc
│   ├── random_eraser.cpython-36.pyc
│   ├── save_model.cpython-36.pyc
│   ├── train.cpython-36.pyc
│   ├── utils.cpython-36.pyc
│   └── warmup_cosine_decay_scheduler.cpython-36.pyc
├── data_process.py
├── garbage_classify
│   └── garbage_classify_rule.json
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── model.cpython-36.pyc
│   │   └── resnet50.cpython-36.pyc
│   ├── model.py
│   ├── resnet50.py
│   └── resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
├── predict_local.py
├── run.py
├── static
│   └── images
│       ├── banana.jpg
│       ├── test.jpg
│       └── test1.jpg
├── templates
│   └── predict.html
├── test_image
│   ├── banana.jpg
│   ├── box.jpg
│   ├── cigar.jpg
│   ├── egg.jpg
│   ├── img_14052.jpg
│   ├── test.jpg
│   └── test1.jpg
├── tools
│   ├── __pycache__
│   │   ├── data_gen.cpython-36.pyc
│   │   └── utils.cpython-36.pyc
│   ├── data_gen.py
│   ├── img_gen.py
│   └── utils.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | output_model/best.h5
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/.idea/garbage_calssify-by-resnet50.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
14 |
15 |
16 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 环境初始化
2 | - Python 3
3 | - 安装 Web 框架 Flask：`pip3 install flask`
4 | - 安装 TensorFlow、Keras 等依赖：
5 | > - `pip3 install tensorflow==1.13.1`
6 | > - `pip3 install keras==2.3.1`
7 | - 安装图像处理依赖：`pip3 install pillow requests opencv-python`
8 |
9 |
10 | # 运行
11 | 1. 命令 `python3 train.py` 开启训练
12 | 2. 命令 `python3 predict_local.py` 开启输入图片测试
13 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/__pycache__/data_gen.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/data_gen.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/img_gen.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/img_gen.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/predict_local.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/predict_local.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/random_eraser.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/random_eraser.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/save_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/save_model.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/train.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/train.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/warmup_cosine_decay_scheduler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/warmup_cosine_decay_scheduler.cpython-36.pyc
--------------------------------------------------------------------------------
/data_process.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from PIL import Image
3 | import requests as req
4 | from io import BytesIO
5 | import numpy as np
6 | import os
7 | import math
8 | import codecs
9 | import random
10 | from models.resnet50 import preprocess_input
def _load_rgb_array(fp, size=256):
    """Open an image (path or file-like), resize it to size x size, return an RGB uint8 array.

    The source image is opened in a ``with`` block so its file handle is closed
    as soon as the resized copy exists.
    """
    with Image.open(fp) as src:
        rgb = src.resize((size, size)).convert('RGB')
    return np.array(rgb)


def _random_crops(img, n_crops=10, crop_size=224):
    """Return ``n_crops`` random ``crop_size`` x ``crop_size`` crops of ``img``.

    Each crop is passed through ``preprocess_input`` (from models.resnet50).
    For the default 256x256 input the crop offsets are drawn from [0, 32]
    inclusive, so every crop lies fully inside the image.
    """
    max_i = img.shape[0] - crop_size
    max_j = img.shape[1] - crop_size
    crops = []
    for _ in range(n_crops):
        i = random.randint(0, max_i)
        j = random.randint(0, max_j)
        crops.append(preprocess_input(img[i:i + crop_size, j:j + crop_size]))
    return crops


# Load image data from a local path.
def preprocess_img(img_path, img_size):
    """Load a local image and return 10 random, preprocessed 224x224 crops.

    Args:
        img_path: Path of the image file.
        img_size: Unused; kept for call compatibility (the image is always
            resized to 256x256 and cropped to 224x224).

    Returns:
        A list of 10 preprocessed crops, or ``0`` if anything fails (the error
        is printed). NOTE(review): callers presumably test for 0; confirm
        before switching this to None/raise.
    """
    try:
        return _random_crops(_load_rgb_array(img_path))
    except Exception as e:
        print('发生了异常data_process:', e)
        return 0


# Fetch image data from a URL.
def preprocess_img_from_Url(img_path, img_size, timeout=10):
    """Download an image and return 10 random, preprocessed 224x224 crops.

    Args:
        img_path: URL of the image.
        img_size: Unused; kept for call compatibility.
        timeout: Seconds to wait for the HTTP request. Without a timeout a
            stalled server would block the caller indefinitely.

    Returns:
        A list of 10 preprocessed crops, or ``0`` if the download, the HTTP
        status, or decoding fails (the error is printed).
    """
    try:
        response = req.get(img_path, timeout=timeout)
        # Fail fast on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        return _random_crops(_load_rgb_array(BytesIO(response.content)))
    except Exception as e:
        print('发生了异常data_process:', e)
        return 0
--------------------------------------------------------------------------------
/garbage_classify/garbage_classify_rule.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "其他垃圾/一次性快餐盒",
3 | "1": "其他垃圾/污损塑料",
4 | "2": "其他垃圾/烟蒂",
5 | "3": "其他垃圾/牙签",
6 | "4": "其他垃圾/破碎花盆及碟碗",
7 | "5": "其他垃圾/竹筷",
8 | "6": "厨余垃圾/剩饭剩菜",
9 | "7": "厨余垃圾/大骨头",
10 | "8": "厨余垃圾/水果果皮",
11 | "9": "厨余垃圾/水果果肉",
12 | "10": "厨余垃圾/茶叶渣",
13 | "11": "厨余垃圾/菜叶菜根",
14 | "12": "厨余垃圾/蛋壳",
15 | "13": "厨余垃圾/鱼骨",
16 | "14": "可回收物/充电宝",
17 | "15": "可回收物/包",
18 | "16": "可回收物/化妆品瓶",
19 | "17": "可回收物/塑料玩具",
20 | "18": "可回收物/塑料碗盆",
21 | "19": "可回收物/塑料衣架",
22 | "20": "可回收物/快递纸袋",
23 | "21": "可回收物/插头电线",
24 | "22": "可回收物/旧衣服",
25 | "23": "可回收物/易拉罐",
26 | "24": "可回收物/枕头",
27 | "25": "可回收物/毛绒玩具",
28 | "26": "可回收物/洗发水瓶",
29 | "27": "可回收物/玻璃杯",
30 | "28": "可回收物/皮鞋",
31 | "29": "可回收物/砧板",
32 | "30": "可回收物/纸板箱",
33 | "31": "可回收物/调料瓶",
34 | "32": "可回收物/酒瓶",
35 | "33": "可回收物/金属食品罐",
36 | "34": "可回收物/锅",
37 | "35": "可回收物/食用油桶",
38 | "36": "可回收物/饮料瓶",
39 | "37": "有害垃圾/干电池",
40 | "38": "有害垃圾/软膏",
41 | "39": "有害垃圾/过期药物"
42 | }
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/models/__init__.py
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/models/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet50.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/models/__pycache__/resnet50.cpython-36.pyc
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import keras.backend as K
2 | from keras.layers import Input
3 | from keras.layers.normalization import BatchNormalization
4 | from keras.layers import Conv2D, MaxPool2D, Dropout,Activation
5 | from keras.layers import Dense,Lambda,Add,GlobalAveragePooling2D,ZeroPadding2D,Multiply
6 | from keras import Model
7 | from keras import regularizers
8 | from keras.layers import Conv1D,MaxPool1D,LSTM,ZeroPadding2D
9 | from keras.layers import GlobalMaxPool1D,Permute,MaxPooling2D
10 | from keras.layers import GRU,TimeDistributed,Flatten, LeakyReLU,ELU
11 |
# ---- model hyper-parameters -------------------------------------------------
# L2 factor handed to every regularizers.l2(...) below (alternative noted: 1e-5).
WEIGHT_DECAY = 1e-4
# Bottleneck ratio of the squeeze-and-excitation Dense layer (channels // ratio).
REDUCTION_RATIO = 4
# How many bottleneck conv stacks residual_block chains together.
BLOCK_NUM = 1
# Residual bottleneck block (used by ResNet_50)
def res_conv_block(x, filters, strides, name):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convs plus a projection shortcut.

    Args:
        x: Input tensor.
        filters: ``(filter1, filter2, filter3)`` output channels of the 1x1
            reduce conv, the 3x3 conv and the 1x1 expand conv.
        strides: Strides of the first 1x1 conv and of the shortcut
            projection, so both branches end up with the same spatial size.
        name: Prefix for every layer name created here.

    Returns:
        Output tensor with ``filter3`` channels.
    """
    filter1, filter2, filter3 = filters
    block_input = x

    # block a: 1x1 reduce; carries the stride
    x = Conv2D(filter1, (1, 1), strides=strides, kernel_initializer='he_normal',
               kernel_regularizer=regularizers.l2(l=WEIGHT_DECAY), name=f'{name}_conva')(block_input)
    x = BatchNormalization(name=f'{name}_bna')(x)
    x = Activation('relu', name=f'{name}_relua')(x)
    # block b: 3x3 with same padding keeps the spatial size
    x = Conv2D(filter2, (3, 3), padding='same',
               kernel_regularizer=regularizers.l2(l=WEIGHT_DECAY), name=f'{name}_convb')(x)
    x = BatchNormalization(name=f'{name}_bnb')(x)
    x = Activation('relu', name=f'{name}_relub')(x)
    # block c: 1x1 expand to filter3 channels; activation comes after the add
    x = Conv2D(filter3, (1, 1), name=f'{name}_convc',
               kernel_regularizer=regularizers.l2(l=WEIGHT_DECAY))(x)
    x = BatchNormalization(name=f'{name}_bnc')(x)

    # shortcut: project the *block input* (not the main-branch output) and
    # normalize that projection, so the Add is a real residual connection
    # rather than x + BN(x) with a dangling shortcut conv.
    shortcut = Conv2D(filter3, (1, 1), strides=strides, name=f'{name}_shcut',
                      kernel_regularizer=regularizers.l2(l=WEIGHT_DECAY))(block_input)
    shortcut = BatchNormalization(name=f'{name}_stbn')(shortcut)

    x = Add(name=f'{name}_add')([x, shortcut])
    x = Activation('relu', name=f'{name}_relu')(x)
    return x
38 |
39 |
# ResNet-style backbone built from res_conv_block
def ResNet_50(input_shape):
    """Build the 'ResCNN' backbone: stem, seven bottleneck blocks, fc6, mean pool.

    Despite the name this stacks 7 res_conv_blocks (three 256-channel, four
    512-channel; only the last one strides by 2), not a standard ResNet-50.

    Args:
        input_shape: Input shape without the batch axis.

    Returns:
        A keras Model named 'ResCNN' whose output is the mean of the fc6
        activations over the two spatial axes (512 features).
    """
    inputs = Input(input_shape, name='input')

    # stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool
    net = Conv2D(64, (7, 7), strides=(2, 2), padding='same', name='conv1',
                 kernel_regularizer=regularizers.l2(l=WEIGHT_DECAY))(inputs)
    net = BatchNormalization(name="bn1")(net)
    net = Activation('relu')(net)
    net = MaxPool2D((3, 3), strides=(2, 2), padding='same', name='pool1')(net)

    # (filters, strides) for block1 .. block7
    stage_specs = ([((64, 64, 256), (1, 1))] * 3
                   + [((128, 128, 512), (1, 1))] * 3
                   + [((128, 128, 512), (2, 2))])
    for idx, (filters, strides) in enumerate(stage_specs, start=1):
        net = res_conv_block(net, filters, strides, name=f'block{idx}')

    # fc6: kernel height equals the feature-map height, collapsing that axis to 1
    net = Conv2D(512, (net.shape[1].value, 1), name='fc6')(net)
    net = BatchNormalization(name="bn_fc6")(net)
    net = Activation('relu', name='relu_fc6')(net)
    net = Lambda(lambda t: K.mean(t, axis=[1, 2]), name='avgpool')(net)

    return Model(inputs=[inputs], outputs=[net], name='ResCNN')
68 |
def squeeze_excitation(x, reduction_ratio, name):
    """Return squeeze-and-excitation channel weights for ``x``.

    Global-average-pools ``x`` (squeeze), then applies Dense(C // reduction_ratio,
    relu) and Dense(C, sigmoid) (excitation). The result is one gate in (0, 1)
    per channel; the caller (e.g. residual_block) multiplies it onto ``x``.
    """
    channels = int(x.shape[-1].value)
    squeezed = GlobalAveragePooling2D(name=f'{name}_squeeze')(x)
    hidden = Dense(channels // reduction_ratio, activation='relu', name=f'{name}_ex0')(squeezed)
    return Dense(channels, activation='sigmoid', name=f'{name}_ex1')(hidden)
75 |
def conv_block(x, filters, kernal_size, stride, name, stage, i, padding='same'):
    """Conv2D -> BatchNorm, then ReLU unless ``stage`` is 'c'.

    Layers are named ``{name}_conv{stage}_{i}``, ``{name}_bn{stage}_{i}`` and
    ``{name}_relu{stage}_{i}``.
    """
    suffix = f'{stage}_{i}'
    out = Conv2D(filters, kernal_size, strides=stride, padding=padding,
                 name=f'{name}_conv{suffix}',
                 kernel_regularizer=regularizers.l2(l=WEIGHT_DECAY))(x)
    out = BatchNormalization(name=f'{name}_bn{suffix}')(out)
    if stage == 'c':
        # last conv of the bottleneck: residual_block applies ReLU after the add
        return out
    return Activation('relu', name=f'{name}_relu{suffix}')(out)
84 |
85 |
def residual_block(x, outdim, stride, name):
    """SE-ResNet bottleneck block with a 1x1 projection shortcut.

    Args:
        x: Input tensor.
        outdim: Output channels; the bottleneck convs use ``outdim // 4``.
        stride: Stride of the shortcut projection and of the first 1x1 conv
            of the first stack (later stacks use stride 1).
        name: Prefix for every layer name created here.

    Returns:
        ReLU(shortcut + SE-scaled bottleneck output), with ``outdim`` channels.
    """
    # The shortcut is always projected, so the input channel count need not
    # equal outdim (an unused input_dim lookup was removed: it could raise when
    # the channel dimension is unknown and served no purpose).
    shortcut = Conv2D(outdim, kernel_size=(1, 1), strides=stride, name=f'{name}_scut_conv',
                      kernel_regularizer=regularizers.l2(l=WEIGHT_DECAY))(x)
    shortcut = BatchNormalization(name=f'{name}_scut_norm')(shortcut)

    for i in range(BLOCK_NUM):
        if i > 0:
            # only the first stack downsamples
            stride = 1
        x = conv_block(x, outdim // 4, (1, 1), stride, name, 'a', i, padding='valid')
        x = conv_block(x, outdim // 4, (3, 3), (1, 1), name, 'b', i, padding='same')
        x = conv_block(x, outdim, (1, 1), (1, 1), name, 'c', i, padding='valid')

    # squeeze-and-excitation gating, then the residual add and final ReLU
    x = Multiply(name=f'{name}_scale')([x, squeeze_excitation(x, REDUCTION_RATIO, name)])
    x = Add(name=f'{name}_scut')([shortcut, x])
    x = Activation('relu', name=f'{name}_relu')(x)
    return x
104 |
105 |
# SE-ResNet feature extractor (originally labelled "proposed model v4.0 timit libri")
def SE_ResNet(input_shape):
    """Build an SE-ResNet feature extractor that outputs a 512-d ReLU feature.

    Stem: zero-pad -> 7x7/2 conv -> BN -> ReLU -> zero-pad -> 3x3/2 max-pool.
    Body: five SE residual blocks, each with stride 2 (three with 256 channels,
    two with 512). Head: Flatten -> Dropout(0.5) -> Dense(512) -> BN -> ReLU.

    Args:
        input_shape: Input shape without the batch axis.

    Returns:
        A keras Model named 'SEResNet'.
    """
    inputs = Input(input_shape, name='input')

    net = ZeroPadding2D(padding=(3, 3), name='conv1_pad')(inputs)
    net = Conv2D(64, (7, 7),
                 strides=(2, 2),
                 padding='valid',
                 kernel_initializer='he_normal',
                 kernel_regularizer=regularizers.l2(WEIGHT_DECAY),
                 name='conv1')(net)
    net = BatchNormalization(name='bn_conv1')(net)
    net = Activation('relu', name='relu1')(net)
    net = ZeroPadding2D(padding=(1, 1), name='pool1_pad')(net)
    net = MaxPooling2D((3, 3), strides=(2, 2))(net)

    # Block names skip block4; the 256- and 512-channel stages are block1-3 and block5-6.
    for block_name, channels in (('block1', 256), ('block2', 256), ('block3', 256),
                                 ('block5', 512), ('block6', 512)):
        net = residual_block(net, outdim=channels, stride=(2, 2), name=block_name)

    net = Flatten(name='flatten')(net)
    net = Dropout(0.5, name='drop1')(net)
    net = Dense(512, kernel_regularizer=regularizers.l2(WEIGHT_DECAY), name='fc1')(net)
    net = BatchNormalization(name='bn_fc1')(net)
    net = Activation('relu', name='relu_fc1')(net)

    return Model(inputs=[inputs], outputs=[net], name='SEResNet')
145 |
146 |
# VGGVox (v1) building block
def conv_pool(x, layerid, filters, kernal_size, conv_strides, pool_size=None, pool_strides=None, pool=None):
    """Conv2D('same') -> BatchNorm -> ReLU, followed by MaxPool2D when ``pool == 'max'``.

    Layers are named conv{layerid}, bn{layerid}, relu{layerid} and mpool{layerid}.
    The max-pool uses Keras' default 'valid' padding.
    """
    out = Conv2D(filters, kernal_size, strides=conv_strides, padding='same', name=f'conv{layerid}')(x)
    out = BatchNormalization(name=f'bn{layerid}')(out)
    out = Activation('relu', name=f'relu{layerid}')(out)
    if pool != 'max':
        return out
    return MaxPool2D(pool_size, pool_strides, name=f'mpool{layerid}')(out)
155 |
def vggvox1_cnn(input_shape):
    """Build a VGGVox-style CNN producing a 512-d ReLU embedding.

    Five conv_pool stages, then fc6 = Conv2D(256, (9, 1)), global average
    pooling and fc7 = Dense(512, relu).

    NOTE: the fc6 kernel height is fixed at 9, which equals the pool5 output
    height for a 512-row input (512 -> 256 -> 127 -> 64 -> 31 -> 9). Inputs
    whose feature maps become smaller than a pooling window or the fc6 kernel
    fail while the graph is being built.

    Args:
        input_shape: Input shape without the batch axis, e.g. (512, 300, 1).

    Returns:
        A keras Model named 'vggvox1_cnn'.
    """
    inputs = Input(input_shape, name='input')

    # (layerid, filters, kernel, conv_strides, pool_size, pool_strides, pool)
    stage_specs = (
        (1, 96, (7, 7), (2, 2), (3, 3), (2, 2), 'max'),
        (2, 256, (5, 5), (2, 2), (3, 3), (2, 2), 'max'),
        (3, 384, (3, 3), (1, 1), None, None, None),
        (4, 256, (3, 3), (1, 1), None, None, None),
        (5, 256, (3, 3), (1, 1), (5, 3), (3, 2), 'max'),
    )
    net = inputs
    for spec in stage_specs:
        net = conv_pool(net, *spec)

    net = Conv2D(256, (9, 1), name='fc6')(net)
    net = GlobalAveragePooling2D(name='avgPool')(net)
    embedding = Dense(512, name='fc7', activation='relu')(net)
    return Model(inputs=[inputs], outputs=[embedding], name='vggvox1_cnn')
171 |
172 | # def vggvox1_cnn(input_shape):
173 | # x_in = Input(input_shape,name='input')
174 | # x = conv_pool(x_in,1,96,(7,7),(1,1),(3,3),(2,2),'max')
175 | # x = conv_pool(x,2,256,(5,5),(1,1),(3,3),(2,2),'max')
176 | # x = conv_pool(x,3,384,(3,3),(1,1))
177 | # x = conv_pool(x,4,256,(3,3),(1,1))
178 | # x = conv_pool(x,5,256,(3,3),(1,1),(5,3),(2,2),'max')
179 | # # fc 6
180 | # x = Conv2D(256,(x.shape[1].value,1),name='fc6')(x)
181 | # # apool6
182 | # x = GlobalAveragePooling2D(name='avgPool')(x)
183 | # # fc7
184 | # x = Dense(512,name='fc7',activation='relu')(x)
185 | # model = Model(inputs=[x_in],outputs=[x],name='vggvox1_cnn')
186 | # return model
187 |
# Deep Speaker building blocks
def clipped_relu(inputs):
    """Apply min(max(x, 0), 20) element-wise through an auto-named Lambda layer."""
    def _clip(t):
        return K.minimum(K.maximum(t, 0), 20)

    return Lambda(_clip)(inputs)
191 |
def identity_block(x_in, kernel_size, filters, name):
    """Two same-padded stride-1 convs with BN and clipped ReLU plus an identity skip.

    The skip adds ``x_in`` unchanged, so ``x_in`` must already have ``filters``
    channels for the Add to be valid.
    """
    def _conv(t, suffix):
        return Conv2D(filters, kernel_size=kernel_size, strides=(1, 1),
                      padding='same', kernel_regularizer=regularizers.l2(l=WEIGHT_DECAY),
                      name=f'{name}_conv{suffix}')(t)

    y = clipped_relu(BatchNormalization(name=f'{name}_bn1')(_conv(x_in, 'a')))
    y = BatchNormalization(name=f'{name}_bn2')(_conv(y, 'b'))
    return clipped_relu(Add(name=f'{name}_add')([y, x_in]))
205 |
def Deep_speaker_model(input_shape):
    """Build a Deep Speaker-style ResCNN with an L2-normalized 512-d output.

    The first two non-batch axes are swapped, then four stages (64, 128, 256,
    512 filters) each apply a 5x5/2 conv + BN + clipped ReLU followed by three
    identity blocks. The result is averaged over the spatial axes, projected by
    Dense(512) and L2-normalized.
    """
    def conv_and_res_block(x_in, filters):
        prefix = f'conv_{filters}-s'
        y = Conv2D(filters, kernel_size=(5, 5), strides=(2, 2),
                   padding='same', kernel_regularizer=regularizers.l2(l=WEIGHT_DECAY),
                   name=prefix)(x_in)
        y = clipped_relu(BatchNormalization(name=f'{prefix}_bn')(y))
        for i in range(3):
            y = identity_block(y, kernel_size=(3, 3), filters=filters, name=f'res{filters}_{i}')
        return y

    inputs = Input(input_shape, name='input')
    net = Permute((2, 1, 3), name='permute')(inputs)
    for width in (64, 128, 256, 512):
        net = conv_and_res_block(net, width)

    net = Lambda(lambda t: K.mean(t, axis=[1, 2]), name='avgpool')(net)
    net = Dense(512, name='affine')(net)
    net = Lambda(lambda t: K.l2_normalize(t, axis=1), name='ln')(net)
    return Model(inputs=[inputs], outputs=[net], name='deepspeaker')
230 |
# Conv + GRU baseline model
def Baseline_GRU(input_shape):
    """Build a two-conv front-end followed by three stacked GRUs and a 512-unit head.

    Each conv stage is Conv2D(64, 3x3) -> BN -> ELU -> 2x2/2 max-pool. The
    result is fed to TimeDistributed(Flatten()), i.e. axis 1 is treated as the
    time axis, then GRU(512) x3 (only the last returns a single vector).
    """
    inputs = Input(input_shape, name='input')

    net = inputs
    for idx in (1, 2):
        net = Conv2D(64, (3, 3), strides=(1, 1), padding='same', name=f'conv{idx}')(net)
        net = BatchNormalization(name=f'bn{idx}')(net)
        net = ELU(name=f'relu{idx}')(net)
        net = MaxPool2D((2, 2), strides=(2, 2), padding='same', name=f'pool{idx}')(net)

    net = TimeDistributed(Flatten(), name='timedis1')(net)
    for gru_name, keep_sequence in (('gru1', True), ('gru2', True), ('gru4', False)):
        net = GRU(512, return_sequences=keep_sequence, name=gru_name)(net)

    # NOTE(review): fc2 already applies ReLU, and BN + ELU follow it; confirm
    # this double activation is intended (SE_ResNet uses a linear Dense there).
    net = Dense(512, name='fc2', activation='relu')(net)
    net = BatchNormalization(name='fc_norm')(net)
    net = ELU(name='relu3')(net)

    return Model(inputs=[inputs], outputs=[net], name='Baseline_GRU')
255 |
256 |
257 |
258 |
259 |
if __name__ == "__main__":
    # Smoke test: build one model and print its layer summary.
    # vggvox1_cnn's fixed (9, 1) fc6 kernel needs a pool5 height of 9, which a
    # 512-row input produces; a (299, 40, 1) input is too small for its pooling
    # windows and fails while the graph is being built.
    # Other builders in this module: ResNet_50, SE_ResNet, Deep_speaker_model,
    # Baseline_GRU.
    model = vggvox1_cnn((512, 300, 1))
    # summary() prints the table itself and returns None, so don't print() it.
    model.summary()
268 |
--------------------------------------------------------------------------------
/models/resnet50.py:
--------------------------------------------------------------------------------
1 | """Enables dynamic setting of underlying Keras module.
2 | """
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import os
8 | import json
9 | import warnings
10 | import numpy as np
11 |
12 | from keras.applications import backend
13 | from keras.applications import layers
14 | from keras.applications import models
15 | from keras.applications import utils
16 |
17 | _KERAS_BACKEND = backend
18 | _KERAS_LAYERS = layers
19 | _KERAS_MODELS = models
20 | _KERAS_UTILS = utils
21 |
22 | CLASS_INDEX = None
23 | CLASS_INDEX_PATH = ('https://modelarts-competitions.obs.cn-north-1.myhuaweicloud.com/'
24 | 'model_zoo/resnet/imagenet_class_index.json')
25 |
26 | # Global tensor of imagenet mean for preprocessin
27 | # g symbolic inputs
28 | _IMAGENET_MEAN = None
29 |
30 | WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/'
31 | 'releases/download/v0.2/'
32 | 'resnet50_weights_tf_dim_ordering_tf_kernels.h5')
33 | WEIGHTS_PATH_NO_TOP = ('https://modelarts-competitions.obs.cn-north-1.myhuaweicloud.com/'
34 | 'model_zoo/resnet/'
35 | 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
36 |
def get_submodules_from_kwargs(kwargs):
    """Resolve which Keras submodules to use from a kwargs dict.

    # Arguments
        kwargs: dict that may contain 'backend', 'layers', 'models' and/or
            'utils'; missing entries fall back to the module-level defaults.

    # Returns
        Tuple ``(backend, layers, models, utils)``.

    # Raises
        TypeError: if kwargs contains any other key.
    """
    allowed = ('backend', 'layers', 'models', 'utils')
    for key in kwargs:
        if key not in allowed:
            # Format the key into the message; passing it as a second
            # exception argument left a literal "%s" in the error text.
            raise TypeError('Invalid keyword argument: %s' % key)
    backend = kwargs.get('backend', _KERAS_BACKEND)
    layers = kwargs.get('layers', _KERAS_LAYERS)
    models = kwargs.get('models', _KERAS_MODELS)
    utils = kwargs.get('utils', _KERAS_UTILS)
    return backend, layers, models, utils
46 |
47 |
def correct_pad(backend, inputs, kernel_size):
    """Returns a tuple for zero-padding for 2D convolution with downsampling.

    For each spatial axis the padding is ``(k // 2 - adjust, k // 2)``, where
    ``adjust`` is 1 when that axis has an even or unknown (None) size and 0
    when it is odd.

    # Arguments
        backend: Keras backend module; ``image_data_format()`` selects the
            spatial axes and ``int_shape()`` reads their static sizes.
        inputs: Input tensor whose spatial size determines the padding.
        kernel_size: An integer or tuple/list of 2 integers.

    # Returns
        A tuple ``((top, bottom), (left, right))``.
    """
    img_dim = 2 if backend.image_data_format() == 'channels_first' else 1
    input_size = backend.int_shape(inputs)[img_dim:(img_dim + 2)]

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    # Handle each axis on its own: a shape with only one unknown axis used to
    # raise TypeError (None % 2) or force the known axis to the "even" case.
    adjust = tuple(1 if size is None else 1 - size % 2 for size in input_size)
    correct = (kernel_size[0] // 2, kernel_size[1] // 2)

    return ((correct[0] - adjust[0], correct[0]),
            (correct[1] - adjust[1], correct[1]))


__version__ = '1.0.7'
75 |
76 |
def _preprocess_numpy_input(x, data_format, mode, **kwargs):
    """Normalize a NumPy image array (a single 3D image or a 4D batch).

    # Arguments
        x: Input array. Non-float arrays are cast to ``backend.floatx()``;
            float arrays are modified in place.
        data_format: 'channels_first' or 'channels_last'.
        mode: 'tf' scales pixels to [-1, 1]; 'torch' scales to [0, 1] and
            then normalizes each channel with the ImageNet mean/std; any other
            value ('caffe') flips RGB to BGR and subtracts the ImageNet BGR
            mean without scaling.
        **kwargs: optional Keras submodules (see get_submodules_from_kwargs).

    # Returns
        The preprocessed array.
    """
    backend, _, _, _ = get_submodules_from_kwargs(kwargs)
    if not issubclass(x.dtype.type, np.floating):
        x = x.astype(backend.floatx(), copy=False)

    if mode == 'tf':
        x /= 127.5
        x -= 1.
        return x

    if mode == 'torch':
        x /= 255.
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    else:
        # 'RGB' -> 'BGR' by reversing the channel axis
        if data_format != 'channels_first':
            x = x[..., ::-1]
        elif x.ndim == 3:
            x = x[::-1, ...]
        else:
            x = x[:, ::-1, ...]
        mean, std = [103.939, 116.779, 123.68], None

    # Index template selecting channel c: lead + (c,) + tail
    if data_format == 'channels_first':
        lead = () if x.ndim == 3 else (slice(None),)
        tail = (slice(None), slice(None))
    else:
        lead, tail = (Ellipsis,), ()

    # Zero-center every channel first, then (torch mode only) scale by std.
    for c in range(3):
        x[lead + (c,) + tail] -= mean[c]
    if std is not None:
        for c in range(3):
            x[lead + (c,) + tail] /= std[c]
    return x
152 |
153 |
def _preprocess_symbolic_input(x, data_format, mode, **kwargs):
    """Preprocesses a tensor encoding a batch of images.

    # Arguments
        x: Input tensor, 3D or 4D.
        data_format: Data format of the image tensor.
        mode: One of "caffe", "tf" or "torch".
            - caffe: will convert the images from RGB to BGR,
                then will zero-center each color channel with
                respect to the ImageNet dataset,
                without scaling.
            - tf: will scale pixels between -1 and 1,
                sample-wise.
            - torch: will scale pixels between 0 and 1 and then
                will normalize each channel with respect to the
                ImageNet dataset.
        **kwargs: optional Keras submodules (see get_submodules_from_kwargs).

    # Returns
        Preprocessed tensor.
    """
    global _IMAGENET_MEAN

    backend, _, _, _ = get_submodules_from_kwargs(kwargs)

    if mode == 'tf':
        x /= 127.5
        x -= 1.
        return x

    if mode == 'torch':
        x /= 255.
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        if data_format == 'channels_first':
            # 'RGB'->'BGR'
            if backend.ndim(x) == 3:
                x = x[::-1, ...]
            else:
                x = x[:, ::-1, ...]
        else:
            # 'RGB'->'BGR'
            x = x[..., ::-1]
        mean = [103.939, 116.779, 123.68]
        std = None

    # Build the negated mean for *this* call and mode. Reusing a constant
    # cached on the first call applied the wrong mean after switching between
    # 'caffe' and 'torch', and tied the constant to the graph that was active
    # at that time (unusable after the graph/session is reset).
    mean_tensor = backend.constant(-np.array(mean))
    _IMAGENET_MEAN = mean_tensor  # kept populated for any external reader

    # Zero-center by mean pixel
    if backend.dtype(x) != backend.dtype(mean_tensor):
        mean_tensor = backend.cast(mean_tensor, backend.dtype(x))
    x = backend.bias_add(x, mean_tensor, data_format=data_format)
    if std is not None:
        if data_format == 'channels_first':
            # Shape the divisor as (C, 1, 1) so it broadcasts over the channel
            # axis instead of the trailing (width) axis.
            std = [[[s]] for s in std]
        x /= std
    return x
213 |
214 |
def preprocess_input(x, data_format=None, mode='caffe', **kwargs):
    """Preprocess a NumPy array or a symbolic tensor holding image(s).

    # Arguments
        x: NumPy array or symbolic tensor, 3D or 4D. When the dtype allows
            it the input is overwritten in place; pass ``numpy.copy(x)`` to
            keep the original.
        data_format: 'channels_first' or 'channels_last'; None means the
            backend's ``image_data_format()``.
        mode: "caffe" (RGB->BGR, subtract ImageNet BGR mean, no scaling),
            "tf" (scale to [-1, 1]) or "torch" (scale to [0, 1], then
            normalize each channel with ImageNet mean/std).
        **kwargs: optional Keras submodules (see get_submodules_from_kwargs).

    # Returns
        The preprocessed array or tensor.

    # Raises
        ValueError: for an unknown ``data_format``.
    """
    backend, _, _, _ = get_submodules_from_kwargs(kwargs)

    data_format = data_format if data_format is not None else backend.image_data_format()
    if data_format not in ('channels_first', 'channels_last'):
        raise ValueError('Unknown data_format ' + str(data_format))

    # NumPy arrays and symbolic tensors take separate code paths.
    handler = _preprocess_numpy_input if isinstance(x, np.ndarray) else _preprocess_symbolic_input
    return handler(x, data_format=data_format, mode=mode, **kwargs)
254 |
255 |
def decode_predictions(preds, top=5, **kwargs):
    """Turn ImageNet model outputs into top-k ``(class_name, description, score)`` lists.

    # Arguments
        preds: 2D NumPy array of shape (samples, 1000).
        top: How many top-scoring classes to return per sample.
        **kwargs: optional Keras submodules (see get_submodules_from_kwargs);
            only ``utils.get_file`` is used.

    # Returns
        One list per sample, each sorted by descending score.

    # Raises
        ValueError: if ``preds`` is not 2D with 1000 columns.
    """
    global CLASS_INDEX

    _, _, _, keras_utils = get_submodules_from_kwargs(kwargs)

    if len(preds.shape) != 2 or preds.shape[1] != 1000:
        raise ValueError('`decode_predictions` expects '
                         'a batch of predictions '
                         '(i.e. a 2D array of shape (samples, 1000)). '
                         'Found array with shape: ' + str(preds.shape))

    if CLASS_INDEX is None:
        # Fetched once per process; get_file stores it under
        # <this file's dir>/../models/ and reuses that copy afterwards.
        fpath = keras_utils.get_file(
            'imagenet_class_index.json',
            CLASS_INDEX_PATH,
            cache_subdir='models',
            file_hash='c2c37ea517e94d9795004a39431a14cb',
            cache_dir=os.path.join(os.path.dirname(__file__), '..'))
        with open(fpath) as f:
            CLASS_INDEX = json.load(f)

    def _decode_row(pred):
        best = pred.argsort()[-top:][::-1]
        row = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in best]
        row.sort(key=lambda item: item[2], reverse=True)
        return row

    return [_decode_row(pred) for pred in preds]
297 |
298 |
def _obtain_input_shape(input_shape,
                        default_size,
                        min_size,
                        data_format,
                        require_flatten,
                        weights=None):
    """Compute or validate the input shape of a model.

    # Arguments
        input_shape: None (use the default network shape) or a user-supplied
            shape tuple to validate.
        default_size: default input width/height of the model.
        min_size: smallest accepted input width/height.
        data_format: `'channels_first'` or anything else (treated as
            channels-last).
        require_flatten: whether the model ends in a fixed-size classifier,
            which requires a fully static input shape.
        weights: `None` or `'imagenet'`; with `'imagenet'` the input must
            have exactly 3 channels.

    # Returns
        An integer shape tuple, possibly containing `None` entries.

    # Raises
        ValueError: for invalid argument values.
    """
    channels_first = data_format == 'channels_first'
    # Index of the channel axis and of the two spatial axes for this layout.
    channel_axis = 0 if channels_first else -1
    spatial_axes = (1, 2) if channels_first else (0, 1)
    use_imagenet = weights == 'imagenet'

    # Non-imagenet 3D shapes may use a custom channel count (with a warning
    # if it is neither 1 nor 3); otherwise the default has 3 channels.
    if not use_imagenet and input_shape and len(input_shape) == 3:
        n_channels = input_shape[channel_axis]
        if n_channels not in {1, 3}:
            warnings.warn(
                'This model usually expects 1 or 3 input channels. '
                'However, it was passed an input_shape with ' +
                str(n_channels) + ' input channels.')
    else:
        n_channels = 3

    if channels_first:
        default_shape = (n_channels, default_size, default_size)
    else:
        default_shape = (default_size, default_size, n_channels)

    # Pretrained classifier head: only the default shape is allowed.
    if use_imagenet and require_flatten:
        if input_shape is not None and input_shape != default_shape:
            raise ValueError('When setting `include_top=True` '
                             'and loading `imagenet` weights, '
                             '`input_shape` should be ' +
                             str(default_shape) + '.')
        return default_shape

    if not input_shape:
        if require_flatten:
            input_shape = default_shape
        elif channels_first:
            input_shape = (3, None, None)
        else:
            input_shape = (None, None, 3)
    else:
        if len(input_shape) != 3:
            raise ValueError(
                '`input_shape` must be a tuple of three integers.')
        if input_shape[channel_axis] != 3 and use_imagenet:
            raise ValueError('The input must have 3 channels; got '
                             '`input_shape=' + str(input_shape) + '`')
        too_small = any(input_shape[axis] is not None and
                        input_shape[axis] < min_size
                        for axis in spatial_axes)
        if too_small:
            raise ValueError('Input size must be at least ' +
                             str(min_size) + 'x' + str(min_size) +
                             '; got `input_shape=' +
                             str(input_shape) + '`')

    if require_flatten and None in input_shape:
        raise ValueError('If `include_top` is True, '
                         'you should specify a static `input_shape`. '
                         'Got `input_shape=' + str(input_shape) + '`')
    return input_shape
396 |
397 |
# Keras submodules used by identity_block/conv_block. They start as None and
# are rebound (via `global`) inside ResNet50() from get_submodules_from_kwargs.
backend = layers = models = keras_utils = None
402 |
403 |
def identity_block(input_tensor, kernel_size, filters, stage, block):
    """Residual block whose shortcut is the unchanged input tensor.

    # Arguments
        input_tensor: input tensor.
        kernel_size: kernel size of the middle conv layer on the main path
            (3 in ResNet50).
        filters: three integers, the filter counts of the 3 main-path convs.
        stage: integer stage label, used to build layer names.
        block: block label ('a', 'b', ...), used to build layer names.

    # Returns
        Output tensor of the block.
    """
    filters1, filters2, filters3 = filters
    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
    name_tail = str(stage) + block + '_branch'
    conv_name_base = 'res' + name_tail
    bn_name_base = 'bn' + name_tail

    def conv_then_bn(tensor, n_filters, size, branch, **conv_kwargs):
        # he_normal Conv2D followed by BatchNormalization, named per branch.
        tensor = layers.Conv2D(n_filters, size,
                               kernel_initializer='he_normal',
                               name=conv_name_base + branch,
                               **conv_kwargs)(tensor)
        return layers.BatchNormalization(axis=bn_axis,
                                         name=bn_name_base + branch)(tensor)

    main = conv_then_bn(input_tensor, filters1, (1, 1), '2a')
    main = layers.Activation('relu')(main)
    main = conv_then_bn(main, filters2, kernel_size, '2b', padding='same')
    main = layers.Activation('relu')(main)
    main = conv_then_bn(main, filters3, (1, 1), '2c')

    # Identity shortcut: add the untouched input, then the final ReLU.
    merged = layers.add([main, input_tensor])
    return layers.Activation('relu')(merged)
447 |
448 |
def conv_block(input_tensor,
               kernel_size,
               filters,
               stage,
               block,
               strides=(2, 2)):
    """Residual block whose shortcut is a strided 1x1 conv + BN projection.

    # Arguments
        input_tensor: input tensor.
        kernel_size: kernel size of the middle conv layer on the main path
            (3 in ResNet50).
        filters: three integers, the filter counts of the 3 main-path convs.
        stage: integer stage label, used to build layer names.
        block: block label ('a', 'b', ...), used to build layer names.
        strides: strides of the first main-path conv and of the shortcut conv.

    # Returns
        Output tensor of the block.

    ResNet50 uses strides=(2, 2) from stage 3 on (for both the first main-path
    conv and the shortcut) and strides=(1, 1) in stage 2.
    """
    filters1, filters2, filters3 = filters
    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
    name_tail = str(stage) + block + '_branch'
    conv_name_base = 'res' + name_tail
    bn_name_base = 'bn' + name_tail

    def conv_then_bn(tensor, n_filters, size, branch, **conv_kwargs):
        # he_normal Conv2D followed by BatchNormalization, named per branch.
        tensor = layers.Conv2D(n_filters, size,
                               kernel_initializer='he_normal',
                               name=conv_name_base + branch,
                               **conv_kwargs)(tensor)
        return layers.BatchNormalization(axis=bn_axis,
                                         name=bn_name_base + branch)(tensor)

    main = conv_then_bn(input_tensor, filters1, (1, 1), '2a', strides=strides)
    main = layers.Activation('relu')(main)
    main = conv_then_bn(main, filters2, kernel_size, '2b', padding='same')
    main = layers.Activation('relu')(main)
    main = conv_then_bn(main, filters3, (1, 1), '2c')

    # Projection shortcut so its shape matches the main path.
    shortcut = conv_then_bn(input_tensor, filters3, (1, 1), '1',
                            strides=strides)

    merged = layers.add([main, shortcut])
    return layers.Activation('relu')(merged)
507 |
508 |
def ResNet50(include_top=True,
             weights='imagenet',
             input_tensor=None,
             input_shape=None,
             pooling=None,
             classes=1000,
             **kwargs):
    """Build the ResNet50 architecture, optionally with ImageNet weights.

    The data format follows the Keras backend setting
    (`backend.image_data_format()`).

    # Arguments
        include_top: whether to append the global-average-pool + softmax
            `fc1000` classifier.
        weights: `None` (random init), `'imagenet'` (download pretrained
            weights), or a path to a weights file to load.
        input_tensor: optional Keras tensor to use as the image input.
        input_shape: optional shape tuple; only meaningful when
            `include_top` is False (otherwise it must be `(224, 224, 3)`
            channels-last or `(3, 224, 224)` channels-first). Width and height
            must be at least 32.
        pooling: when `include_top` is False: `None` keeps the 4D output of
            the last block, `'avg'` / `'max'` apply global average / max
            pooling (2D output).
        classes: number of output classes; only used when `include_top` is
            True, and must be 1000 with `weights='imagenet'`.

    # Returns
        A Keras `Model` named `'resnet50'`.

    # Raises
        ValueError: for an invalid `weights` value or input shape.
    """
    global backend, layers, models, keras_utils
    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

    weights_ok = weights in {'imagenet', None} or os.path.exists(weights)
    if not weights_ok:
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization), `imagenet` '
                         '(pre-training on ImageNet), '
                         'or the path to the weights file to be loaded.')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as `"imagenet"` with `include_top`'
                         ' as true, `classes` should be 1000')

    data_format = backend.image_data_format()
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=224,
                                      min_size=32,
                                      data_format=data_format,
                                      require_flatten=include_top,
                                      weights=weights)

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    elif backend.is_keras_tensor(input_tensor):
        img_input = input_tensor
    else:
        img_input = layers.Input(tensor=input_tensor, shape=input_shape)
    bn_axis = 3 if data_format == 'channels_last' else 1

    # Stem: 7x7/2 conv, BN, ReLU, 3x3/2 max-pool.
    x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
    x = layers.Conv2D(64, (7, 7),
                      strides=(2, 2),
                      padding='valid',
                      kernel_initializer='he_normal',
                      name='conv1')(x)
    x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = layers.Activation('relu')(x)
    x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)

    # (stage, filters, block labels, strides of the leading conv_block).
    # The first label gets a conv_block, the rest identity_blocks; layers are
    # created in exactly this order.
    stage_plan = (
        (2, [64, 64, 256], 'abc', (1, 1)),
        (3, [128, 128, 512], 'abcd', (2, 2)),
        (4, [256, 256, 1024], 'abcdef', (2, 2)),
        (5, [512, 512, 2048], 'abc', (2, 2)),
    )
    for stage, stage_filters, block_labels, first_strides in stage_plan:
        x = conv_block(x, 3, stage_filters, stage=stage,
                       block=block_labels[0], strides=first_strides)
        for label in block_labels[1:]:
            x = identity_block(x, 3, stage_filters, stage=stage, block=label)

    if include_top:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
    elif pooling == 'avg':
        x = layers.GlobalAveragePooling2D()(x)
    elif pooling == 'max':
        x = layers.GlobalMaxPooling2D()(x)
    else:
        warnings.warn('The output shape of `ResNet50(include_top=False)` '
                      'has been changed since Keras 2.2.0.')

    # Include any layers feeding `input_tensor` so the model is self-contained.
    if input_tensor is None:
        inputs = img_input
    else:
        inputs = keras_utils.get_source_inputs(input_tensor)
    model = models.Model(inputs, x, name='resnet50')

    if weights == 'imagenet':
        if include_top:
            file_name = 'resnet50_weights_tf_dim_ordering_tf_kernels.h5'
            origin = WEIGHTS_PATH
            checksum = 'a7b3fe01876f51b976af0dea6bc144eb'
        else:
            file_name = 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
            origin = WEIGHTS_PATH_NO_TOP
            checksum = 'a268eb855778b3df3c7506639542a6af'
        weights_path = keras_utils.get_file(
            file_name,
            origin,
            cache_subdir='models',
            md5_hash=checksum,
            cache_dir=os.path.join(os.path.dirname(__file__), '..'))
        model.load_weights(weights_path)
        if backend.backend() == 'theano':
            keras_utils.convert_all_kernels_in_model(model)
    elif weights is not None:
        model.load_weights(weights)

    return model
667 |
668 |
669 |
--------------------------------------------------------------------------------
/models/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/models/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
--------------------------------------------------------------------------------
/predict_local.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import re
3 | import os
4 | import json
5 | from tools.data_gen import preprocess_img,preprocess_img_from_Url
6 | from models.resnet50 import ResNet50
7 | from keras.layers import Dense,Dropout,BatchNormalization,GlobalAveragePooling2D
8 | from keras.models import Model
9 | import numpy as np
10 | from keras import regularizers
11 | from tensorflow.python.keras.backend import set_session
12 |
# Optional GPU setup: expose only device "0" and raise TensorFlow's C++ log
# threshold. NOTE(review): this runs after `import tensorflow`, so it may not
# affect messages emitted during the import itself.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0", 'TF_CPP_MIN_LOG_LEVEL': '3'})
# Shared session: grow GPU memory on demand, capped at 70% of the device.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.7
sess = tf.Session(config=config)
21 |
# Global configuration flags. The help strings are user-facing and kept as-is:
# num_classes = number of garbage classes, input_size = model input image
# size, batch_size = image batch size.
_define_int = tf.app.flags.DEFINE_integer
_define_int('num_classes', 40, '垃圾分类数目')
_define_int('input_size', 224, '模型输入图片大小')
_define_int('batch_size', 16, '图片批处理大小')

FLAGS = tf.app.flags.FLAGS
# Fine-tuned weights loaded by init_artificial_neural_network().
h5_weights_path = './output_model/best.h5'
29 |
# Append the classification head on top of the backbone.
def add_new_last_layer(base_model, num_classes):
    """Attach a pooled, regularised dense head with a softmax output.

    :param base_model: Keras model whose `output` is a 4D feature map.
    :param num_classes: number of softmax outputs.
    :return: new `Model` from `base_model.input` to the softmax output.
    """
    # Layer names must stay stable: weights are later loaded by name.
    head = [
        GlobalAveragePooling2D(name='avg_pool'),
        Dropout(0.5, name='dropout1'),
        Dense(512, activation='relu',
              kernel_regularizer=regularizers.l2(0.0001), name='fc2'),
        BatchNormalization(name='bn_fc_01'),
        Dropout(0.5, name='dropout2'),
        Dense(num_classes, activation='softmax'),
    ]
    features = base_model.output
    for layer in head:
        features = layer(features)
    return Model(inputs=base_model.input, outputs=features)
43 |
44 |
# Build the ResNet50-based classifier with a frozen backbone and a new head.
def model_fn(FLAGS):
    """Create and compile the garbage-classification model.

    :param FLAGS: flag object providing `input_size` and `num_classes`.
    :return: Keras model compiled with Adam, categorical cross-entropy and
        accuracy.
    """
    side = FLAGS.input_size
    backbone = ResNet50(weights="imagenet",
                        include_top=False,
                        pooling=None,
                        input_shape=(side, side, 3),
                        classes=FLAGS.num_classes)
    # Freeze every backbone layer; only the head added below is trainable.
    for backbone_layer in backbone.layers:
        backbone_layer.trainable = False

    network = add_new_last_layer(backbone, FLAGS.num_classes)
    network.compile(optimizer="adam",
                    loss='categorical_crossentropy',
                    metrics=['accuracy'])
    return network
68 |
# Public entry point: build the network and load the trained weights.
def init_artificial_neural_network(session=None):
    """Build the classifier and load the fine-tuned weights.

    :param session: optional TensorFlow session to register with Keras via
        `set_session`; defaults to this module's `sess`. Accepting it keeps
        callers such as run.py, which passes its own session, working.
    :return: Keras model with weights from `h5_weights_path` loaded by
        layer name.
    """
    set_session(sess if session is None else session)
    model = model_fn(FLAGS)
    model.load_weights(h5_weights_path, by_name=True)
    return model
78 |
79 |
80 | # 测试图片
81 | def prediction_result_from_img(model,imgurl):
82 | """
83 |
84 | :rtype: object
85 | """
86 | # 加载分类数据
87 | with open("./garbage_classify/garbage_classify_rule.json", 'br') as load_f:
88 | load_dict = json.load(load_f)
89 | if re.match(r'^https?:/{2}\w.+$', imgurl):
90 | test_data = preprocess_img_from_Url(imgurl,FLAGS.input_size)
91 | else:
92 | test_data = preprocess_img(imgurl,FLAGS.input_size)
93 | tta_num = 5
94 | predictions = [0 * tta_num]
95 | for i in range(tta_num):
96 | x_test = test_data[i]
97 | x_test = x_test[np.newaxis, :, :, :]
98 | prediction = model.predict(x_test)[0]
99 | # print(prediction)
100 | predictions += prediction
101 | pred_label = np.argmax(predictions, axis=0)
102 | print('-------深度学习垃圾分类预测结果----------')
103 | print(pred_label)
104 | print(load_dict[str(pred_label)])
105 | print('-------深度学习垃圾分类预测结果--------')
106 | return load_dict[str(pred_label)]
107 |
108 |
109 |
if __name__ == "__main__":
    model = init_artificial_neural_network()
    while True:
        try:
            img_url = input("请输入图片地址:")
        except (EOFError, KeyboardInterrupt):
            # stdin closed or user interrupted: stop instead of spinning
            # forever on repeated EOFError.
            break
        print('您输入的图片地址为:' + img_url)
        try:
            prediction_result_from_img(model, img_url)
        except Exception as e:
            # Report per-image failures and keep serving further requests.
            print('发生了异常:', e)
119 |
120 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | from flask_sever import sess
2 | from predict_local import prediction_result_from_img, init_artificial_neural_network
3 |
if __name__ == '__main__':
    model = init_artificial_neural_network(sess)
    while True:
        try:
            img_url = input("请输入图片地址:")
        except (EOFError, KeyboardInterrupt):
            # stdin closed or user interrupted: stop instead of spinning
            # forever on repeated EOFError.
            break
        print('您输入的图片地址为:' + img_url)
        try:
            prediction_result_from_img(model, img_url)
        except Exception as e:
            # Report per-image failures and keep serving further requests.
            print('发生了异常:', e)
13 |
--------------------------------------------------------------------------------
/static/images/banana.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/static/images/banana.jpg
--------------------------------------------------------------------------------
/static/images/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/static/images/test.jpg
--------------------------------------------------------------------------------
/static/images/test1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/static/images/test1.jpg
--------------------------------------------------------------------------------
/templates/predict.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | 在线测试神经网络垃圾分类识别接口
6 |
7 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
深度学习之垃圾分类api
17 |
18 |
19 |
20 | 在线测试图片.
21 |
22 |
36 |
提交
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
http://10.11.2.17:5000/predict
48 |
post方式访问:上传文件:file
49 |
例如:file:"XXXXXXX"
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
81 |
82 |
85 |
86 |
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
/test_image/banana.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/banana.jpg
--------------------------------------------------------------------------------
/test_image/box.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/box.jpg
--------------------------------------------------------------------------------
/test_image/cigar.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/cigar.jpg
--------------------------------------------------------------------------------
/test_image/egg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/egg.jpg
--------------------------------------------------------------------------------
/test_image/img_14052.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/img_14052.jpg
--------------------------------------------------------------------------------
/test_image/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/test.jpg
--------------------------------------------------------------------------------
/test_image/test1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/test1.jpg
--------------------------------------------------------------------------------
/tools/__pycache__/data_gen.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/tools/__pycache__/data_gen.cpython-36.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/tools/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/tools/data_gen.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import math
4 | import codecs
5 | import random
6 | import numpy as np
7 | from glob import glob
8 | from PIL import Image
9 | import requests as req
10 | from io import BytesIO
11 | from keras.utils import np_utils, Sequence
12 | from sklearn.model_selection import train_test_split
13 | from collections import Counter
14 | from models.resnet50 import preprocess_input
15 | from tools.utils import Cutout
16 | from keras.preprocessing.image import ImageDataGenerator
17 |
class BaseSequence(Sequence):
    """Basic data-stream generator that returns one batch per index.

    A BaseSequence can be passed directly as the `generator` argument of
    `fit_generator`. Per the original author, Keras wraps it in a
    multi-process data stream and guarantees that no sample is drawn twice
    within one epoch.
    """

    # Label smoothing: y_smooth = y * (1 - eps) + eps / num_classes.
    label_smoothing = 0.05
    # Pixels trimmed from each side when center-cropping in preprocess_img
    # (for a 224 crop the image is resized to 224 + 2 * 16 = 256).
    crop_margin = 16

    def __init__(self, img_paths, labels, batch_size, img_size, is_train):
        assert len(img_paths) == len(labels), "len(img_paths) must equal to len(lables)"
        assert img_size[0] == img_size[1], "img_size[0] must equal to img_size[1]"
        # (n, 1 + num_classes): column 0 holds the image path, the rest the
        # one-hot label. NOTE: mixing str paths with numeric labels makes
        # numpy store every cell as a string; __getitem__ converts back.
        self.x_y = np.hstack((np.array(img_paths).reshape(len(img_paths), 1), np.array(labels)))
        self.batch_size = batch_size
        self.img_size = img_size
        self.is_train = is_train
        if self.is_train:
            self.train_datagen = ImageDataGenerator(
                rotation_range=30,        # random rotation, in degrees
                width_shift_range=0.2,    # horizontal shift, fraction of width
                height_shift_range=0.2,   # vertical shift, fraction of height
                shear_range=0.2,          # shear intensity (counter-clockwise angle)
                zoom_range=0.2,           # random zoom range
                horizontal_flip=True,     # random horizontal flips
                vertical_flip=True,       # random vertical flips
                fill_mode='nearest'
            )

    def __len__(self):
        return math.ceil(len(self.x_y) / self.batch_size)

    @staticmethod
    def center_img(img, size=None, fill_value=255):
        """Center `img` on a square `fill_value` background of side `size`
        (default: the longer image side)."""
        h, w = img.shape[:2]
        if size is None:
            size = max(h, w)
        shape = (size, size) + img.shape[2:]
        background = np.full(shape, fill_value, np.uint8)
        center_x = (size - w) // 2
        center_y = (size - h) // 2
        background[center_y:center_y + h, center_x:center_x + w] = img
        return background

    def preprocess_img(self, img_path):
        """Load an image as RGB and center-crop it to `img_size`.

        The image is first resized to `img_size + 2 * crop_margin` square,
        then the central `img_size` x `img_size` region is returned as a
        uint8 array.
        """
        size = self.img_size[0]
        margin = self.crop_margin
        full = size + 2 * margin
        with Image.open(img_path) as source:
            resized = source.resize((full, full)).convert('RGB')
        pixels = np.array(resized)
        return pixels[margin:margin + size, margin:margin + size]

    def cutout_img(self, img):
        """Apply tools.utils.Cutout with one hole of length 40."""
        cut_out = Cutout(n_holes=1, length=40)
        return cut_out(img)

    def __getitem__(self, idx):
        rows = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size]
        # Column 0: image paths -> pixel arrays of shape (size, size, 3).
        batch_x = np.array([self.preprocess_img(img_path) for img_path in rows[:, 0]])
        # Remaining columns: one-hot labels (stored as strings, see __init__).
        batch_y = rows[:, 1:].astype(np.float32)
        num_classes = batch_y.shape[1]
        batch_y = batch_y * (1 - self.label_smoothing) + self.label_smoothing / num_classes

        if self.is_train:
            # Per sample: 0 = untouched (40%), 1 = cutout (40%),
            # 2 = geometric augmentation via ImageDataGenerator (20%).
            choice = np.random.choice([0, 1, 2], batch_x.shape[0], replace=True, p=[0.4, 0.4, 0.2])
            mask_idx = np.where(choice == 1)[0]
            multi_idx = np.where(choice == 2)[0]

            if multi_idx.size:
                augmenter = self.train_datagen.flow(batch_x[multi_idx], batch_y[multi_idx],
                                                    batch_size=self.batch_size)
                # flow() may permute the samples, but x and y stay paired.
                aug_x, aug_y = augmenter.next()
                batch_x[multi_idx] = aug_x
                batch_y[multi_idx] = aug_y

            if mask_idx.size:
                batch_x[mask_idx] = np.array([self.cutout_img(img) for img in batch_x[mask_idx]])

        # Model-specific input normalisation.
        batch_x = np.array([preprocess_input(img) for img in batch_x])
        return batch_x, batch_y

    def on_epoch_end(self):
        """Shuffle the (path, label) rows at the end of every epoch."""
        np.random.shuffle(self.x_y)
141 |
# Read the label files and build the training / validation sequences.
def data_flow(train_data_dir, batch_size, num_classes, input_size):
    """Build train/validation BaseSequences from `*.txt` label files.

    Each label file's first line must be `"<image name>, <int label>"`;
    malformed files are reported and skipped.

    :return: `(train_sequence, validation_sequence)`, an 85/15 stratified
        split.
    """
    # Deterministic file order: train_test_split below shuffles with a fixed
    # random_state, so the same data always yields the same split.
    label_files = sorted(glob(os.path.join(train_data_dir, '*.txt')))
    img_paths = []
    labels = []
    for file_path in label_files:
        with codecs.open(file_path, 'r', 'utf-8') as f:
            line = f.readline()
        line_split = line.strip().split(', ')
        if len(line_split) != 2:
            print('%s contain error lable' % os.path.basename(file_path))
            continue
        img_name = line_split[0]
        label = int(line_split[1])
        img_paths.append(os.path.join(train_data_dir, img_name))
        labels.append(label)

    labels = np_utils.to_categorical(labels, num_classes)
    train_img_paths, validation_img_paths, train_labels, validation_labels = \
        train_test_split(img_paths, labels, stratify=labels, test_size=0.15, random_state=0)
    print('total samples: %d, training samples: %d, validation samples: %d' % (len(img_paths), len(train_img_paths), len(validation_img_paths)))
    # Only the training sequence applies random augmentation.
    train_sequence = BaseSequence(train_img_paths, train_labels, batch_size, [input_size, input_size], is_train=True)
    validation_sequence = BaseSequence(validation_img_paths, validation_labels, batch_size, [input_size, input_size], is_train=False)
    return train_sequence, validation_sequence
168 |
169 |
170 |
171 |
def _random_crops(img, img_size, n_crops=10, margin=32):
    """Take random square crops of a PIL image for test-time augmentation.

    The image is resized to `(img_size + margin)` square and converted to RGB.
    `n_crops` crops of `img_size` x `img_size` are then taken at random
    offsets in `[0, margin]`, and each crop is passed through
    `preprocess_input`.
    """
    full = img_size + margin
    pixels = np.array(img.resize((full, full)).convert('RGB'))
    crops = []
    for _ in range(n_crops):
        i = random.randint(0, margin)
        j = random.randint(0, margin)
        crops.append(preprocess_input(pixels[i:i + img_size, j:j + img_size]))
    return crops


# Load an image from a local path and return TTA crops.
def preprocess_img(img_path, img_size):
    """Return 10 preprocessed random crops of the image at `img_path`.

    :return: list of crops, or `0` if anything fails. The sentinel is kept
        for backward compatibility; the error is printed.
    """
    try:
        with Image.open(img_path) as img:
            return _random_crops(img, img_size)
    except Exception as e:
        print('发生了异常data_process:', e)
        return 0


# Download an image from a URL and return TTA crops.
def preprocess_img_from_Url(img_path, img_size, timeout=10):
    """Return 10 preprocessed random crops of the image at URL `img_path`.

    NOTE(review): the URL comes from user input, so this fetches arbitrary
    URLs. Restrict hosts if it is ever exposed as a network service.

    :param timeout: seconds to wait for the HTTP server before giving up.
    :return: list of crops, or `0` if anything fails (download, HTTP error
        status, decoding). The error is printed.
    """
    try:
        response = req.get(img_path, timeout=timeout)
        response.raise_for_status()
        with Image.open(BytesIO(response.content)) as img:
            return _random_crops(img, img_size)
    except Exception as e:
        print('发生了异常data_process:', e)
        return 0
218 |
219 |
220 |
221 |
# Load the labelled test set.
def load_test_data(FLAGS):
    """Read `*.txt` label files in `FLAGS.test_data_local` and preprocess
    each referenced image.

    Each label file's first line is `"<image name>,<int label>"`.

    :return: `(img_names, test_data, test_labels)`, where each `test_data`
        entry is the result of `preprocess_img` (a list of crops, or 0 on
        failure).
    """
    label_files = glob(os.path.join(FLAGS.test_data_local, "*.txt"))
    test_data = []
    img_names = []
    test_labels = []
    for file_path in label_files:
        with codecs.open(file_path, 'r', 'utf-8') as f:
            line = f.readline()
        line_split = line.strip().split(',')
        img_name = line_split[0]
        img_names.append(img_name)
        img_path = os.path.join(FLAGS.test_data_local, img_name)
        # Preprocess exactly once per image; the crops are random, so a
        # second call would be both wasted work and a different sample.
        test_data.append(preprocess_img(img_path, FLAGS.input_size))
        test_labels.append(int(line_split[1]))
    print(Counter(test_labels))
    return img_names, test_data, test_labels
241 |
242 |
243 |
244 | # if __name__ == '__main__':
245 | #
246 | # train_data_dir = './garbage_classify/train_data/'
247 | # batch_size = 16
248 | # num_classes = 40
249 | # input_size = 224
250 | # # shape= (224,224,3) label=(16,40)
251 | # train_sequence, validation_sequence = data_flow(train_data_dir, batch_size,num_classes,input_size)
252 | # batch_data, bacth_label = train_sequence.__getitem__(5)
253 | # # print(train_sequence.shape)
254 | # print(batch_data[0])
--------------------------------------------------------------------------------
/tools/img_gen.py:
--------------------------------------------------------------------------------
1 | import os
2 | from keras.preprocessing.image import ImageDataGenerator,array_to_img,img_to_array,load_img
3 | from glob import glob
4 | import codecs
5 | from PIL import Image
6 | from collections import Counter
7 | from tqdm import tqdm
8 | import numpy as np
9 | import shutil
10 |
def center_img(img, size=None, fill_value=255):
    """Place ``img`` in the middle of a square ``uint8`` canvas.

    Args:
        img: Array whose first two axes are height and width. Any trailing
            axes (e.g. channels) are carried over to the canvas.
        size: Side length of the square canvas. Defaults to the larger of the
            image's height and width.
        fill_value: Value used for the padding around the image.

    Returns:
        ``uint8`` array of shape ``(size, size) + img.shape[2:]``.
    """
    height, width = img.shape[:2]
    side = max(height, width) if size is None else size
    canvas = np.full((side, side) + img.shape[2:], fill_value, np.uint8)
    # Integer offsets put any odd leftover pixel on the bottom/right edge.
    top = (side - height) // 2
    left = (side - width) // 2
    canvas[top:top + height, left:left + width] = img
    return canvas
21 |
def precess_imge(img_path, img_size):
    """Load an image and letterbox it into an ``img_size`` square.

    The image is scaled so its longer side becomes ``img_size`` (aspect ratio
    kept) and converted to RGB. Its channel order is then reversed, presumably
    RGB -> BGR for the model's preprocessing (TODO confirm). Finally it is
    centred on a white square canvas via ``center_img``.

    Args:
        img_path: Path of the image file.
        img_size: Side length of the output square.

    Returns:
        ``uint8`` array of shape ``(img_size, img_size, 3)``.
    """
    # The context manager closes the source file even if resize/convert raises.
    with Image.open(img_path) as src:
        resize_scale = img_size / max(src.size[:2])
        # Clamp to 1 pixel: very thin images would otherwise truncate a side
        # to 0, which PIL's resize rejects.
        new_size = (max(1, int(src.size[0] * resize_scale)),
                    max(1, int(src.size[1] * resize_scale)))
        resized = src.resize(new_size).convert('RGB')
    arr = np.array(resized)
    arr = arr[:, :, ::-1]
    return center_img(arr, img_size)
31 |
# Load image paths and labels
def load_dataset(train_data_dir):
    """Collect image paths and integer labels from ``*.txt`` label files.

    Each ``.txt`` file in ``train_data_dir`` is read for its first line,
    expected to be ``<image file name>, <integer label>``. A bare comma
    without the space is accepted too. Files without exactly two fields, or
    whose label is not an integer, are reported and skipped. The per-label
    counts are printed, sorted by label.

    Args:
        train_data_dir: Directory containing the images and their label files.

    Returns:
        Tuple ``(img_data, labels)``: image paths (joined with
        ``train_data_dir``) and the matching integer labels.
    """
    label_files = glob(os.path.join(train_data_dir, '*.txt'))
    labels = []
    img_data = []
    for file_path in tqdm(label_files):
        with codecs.open(file_path, 'r', 'utf-8') as f:
            line = f.readline()
        # Split on ',' and strip each field so both "a.jpg, 3" (the format
        # written by this module) and "a.jpg,3" are accepted.
        line_split = [part.strip() for part in line.strip().split(',')]
        if len(line_split) != 2:
            print('%s contain error lable' % os.path.basename(file_path))
            continue
        try:
            label = int(line_split[1])
        except ValueError:
            # A malformed label should skip this file, not abort the whole load.
            print('%s contain error lable' % os.path.basename(file_path))
            continue
        img_data.append(os.path.join(train_data_dir, line_split[0]))
        labels.append(label)
    # Labels are unique keys, so sorting the items orders them by label.
    print(sorted(Counter(labels).items()))
    return img_data, labels
53 |
54 |
# Sort images into per-label folders
def write_list_to_dir(dstPath, img_paths, labels, is_split=False):
    """Re-save each image under a directory named after its label.

    Layout:
        * ``is_split=False``: ``dstPath/<label>/<label>/<file name>``. The extra
          level matches what ``increase_train_img``/``increase_img`` read
          (they list ``data_set/<label>/<label>``).
        * ``is_split=True``: ``dstPath/<label>/<file name>``.

    Images are opened and saved with PIL under their original base name.

    Args:
        dstPath: Output root directory (created if missing).
        img_paths: Source image paths.
        labels: Labels parallel to ``img_paths``.
        is_split: Selects the single-level ``<label>/`` layout.

    Raises:
        ValueError: If ``img_paths`` and ``labels`` differ in length.
    """
    if len(img_paths) != len(labels):
        raise ValueError('img_paths and labels must have the same length')
    os.makedirs(dstPath, exist_ok=True)
    for img_path, label in tqdm(zip(img_paths, labels), total=len(labels)):
        label = str(label)
        savedir = os.path.join(dstPath, label)
        if not is_split:
            savedir = os.path.join(savedir, label)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(savedir, exist_ok=True)
        # Close the source file once the copy has been written.
        with Image.open(img_path) as img:
            img.save(os.path.join(savedir, os.path.basename(img_path)))
71 |
def write_split_dataset(dstPath, img_paths, labels):
    """Write each image plus a ``<stem>.txt`` label file into ``dstPath``.

    Each label file contains ``<image file name>, <label>``, the format
    ``load_dataset`` reads back. The file name written is the one the image
    is actually saved under.

    Args:
        dstPath: Output directory (created if missing).
        img_paths: Source image paths.
        labels: Labels parallel to ``img_paths``.

    Raises:
        ValueError: If ``img_paths`` and ``labels`` differ in length.
    """
    if len(img_paths) != len(labels):
        raise ValueError('img_paths and labels must have the same length')
    os.makedirs(dstPath, exist_ok=True)
    for img_path, label in tqdm(zip(img_paths, labels), total=len(labels)):
        file_name = os.path.basename(img_path)
        # splitext (rather than [:-4]) also handles extensions such as ".jpeg".
        img_name = os.path.splitext(file_name)[0]
        # Write the image.
        with Image.open(img_path) as img:
            img.save(os.path.join(dstPath, file_name))
        # Write the label, referencing the file that was actually saved
        # instead of assuming a ".jpg" extension.
        with open(os.path.join(dstPath, img_name + '.txt'), 'w', encoding='utf-8') as f:
            f.write(file_name + ', ' + str(label))
87 |
88 |
89 |
90 |
# amplify_ratio: how many times to enlarge the class; class_num: the class to augment
def increase_img(img_path, class_num, save_path, amplify_ratio=1):
    """Generate augmented images for one class, each with a label file.

    The number of source images is counted in ``img_path/<class_num>``. Then
    ``int(count * amplify_ratio)`` augmented 224x224 JPEGs (prefix ``img``)
    are produced by Keras ``flow_from_directory(img_path, ...)`` with
    ``save_to_dir=save_path``. ``save_path`` is wiped and recreated first.
    Finally, for every generated ``<stem>.jpg`` a ``<stem>.txt`` containing
    ``<stem>.jpg, <class_num>`` is written next to it.

    Args:
        img_path: Directory passed to ``flow_from_directory`` (it must contain
            class sub-directories, including ``str(class_num)``).
        class_num: Label written into every generated label file.
        save_path: Output directory. Any existing contents are deleted.
        amplify_ratio: Multiplier on the source image count.
    """
    train_datagen = ImageDataGenerator(
        rotation_range=30,  # random rotation range, in degrees
        width_shift_range=0.2,  # horizontal shift as a fraction of width
        height_shift_range=0.2,  # vertical shift as a fraction of height
        shear_range=0.2,  # shear intensity (counter-clockwise shear angle)
        zoom_range=0.2,  # random zoom range
        horizontal_flip=True,  # random horizontal flips
        vertical_flip=True,  # random vertical flips
        fill_mode='nearest'
    )
    # Count the original images of this class.
    imgs = os.listdir(os.path.join(img_path, str(class_num)))
    print(len(imgs))
    # Empty the output folder. It may not exist yet on the first run, and
    # rmtree would raise FileNotFoundError in that case.
    if os.path.exists(save_path):
        shutil.rmtree(save_path)
    os.makedirs(save_path)
    # batch_size=1, so each call to next() produces and saves one image.
    train_datagenator = train_datagen.flow_from_directory(
        img_path, shuffle=True, save_to_dir=save_path, batch_size=1,
        target_size=(224, 224), save_prefix='img', save_format='jpg')
    for _ in range(int(len(imgs) * amplify_ratio)):
        train_datagenator.next()
    # Write a label file for every generated image.
    for img_name in os.listdir(save_path):
        stem, ext = os.path.splitext(img_name)
        if ext.lower() != '.jpg':
            continue
        with open(os.path.join(save_path, stem + '.txt'), 'w') as f:
            f.write(img_name + ', ' + str(class_num))
122 |
123 | # train_datagenerator = train_datagen.flow_from_directory('./garbage_classify/data_set/train_dir/',target_size=(224,224),classes=classes,batch_size=16,seed=0)
124 | # # print(train_datagenerator.class_indices)
125 | # for data_batch,label_batch in train_datagenerator:
126 | # for index,label in enumerate(label_batch):
127 | # print(np.argmax(label))
128 | # plt.subplot(1,16,index+1)
129 | # plt.imshow(data_batch[index])
130 | # plt.show()
131 | # exit()
132 |
133 | # for i in range(5):
134 | # train_datagenator.next()
135 |
136 |
def increase_train_img(train_path, class_num, amplify_ratio,
                       data_set_dir='./garbage_classify/data_set/',
                       save_path='./garbage_classify/increase_img/'):
    """Augment one class stored in the per-class layout of ``write_list_to_dir``.

    Calls ``increase_img`` with ``data_set_dir/<class_num>`` as the
    ``flow_from_directory`` root and ``save_path`` as the output directory.

    Args:
        train_path: Retained for backward compatibility and no longer read.
            Loading the dataset from it produced results that were never used.
        class_num: Class to augment.
        amplify_ratio: Multiplier on the number of source images.
        data_set_dir: Root of the ``<label>/<label>/`` layout.
        save_path: Output directory. Its contents are replaced.
    """
    img_path = os.path.join(data_set_dir, str(class_num))
    increase_img(img_path, class_num, save_path, amplify_ratio)
142 |
143 |
144 |
145 |
146 |
if __name__ == "__main__":
    # Active step: load the raw training set and print its per-class counts.
    train_data_dir = './garbage_classify/train_data_save/'
    img_paths, labels = load_dataset(train_data_dir)

    # Optional steps below; uncomment the one you need.
    #
    # (1) No split: sort every image into dstPath/<label>/<label>/
    # dstPath = './garbage_classify/data_set'
    # num_classes = 40
    # write_list_to_dir(dstPath,img_paths,labels)
    #
    # (2) Split into training and validation sets
    # train_img_paths, validation_img_paths, train_labels, validation_labels = train_test_split(img_paths,labels,stratify=labels,test_size=0.25,random_state=0)
    # print('total samples: %d, training samples: %d, validation samples: %d' % (len(img_paths), len(train_img_paths), len(validation_img_paths)))
    # write_split_dataset('./garbage_classify/splitDataset/train',train_img_paths,train_labels)
    # write_split_dataset('./garbage_classify/splitDataset/val',validation_img_paths,validation_labels)
    # train_dstPath = './garbage_classify/splitDataset/train'
    # val_dstPath = './garbage_classify/splitDataset/val'
    # write_list_to_dir(train_dstPath,train_img_paths,train_labels,True)
    # write_list_to_dir(val_dstPath,validation_img_paths,validation_labels,True)
    #
    # (3) Augment one class of the dataset
    # train_path = './garbage_classify/splitDataset/train'
    # class_num = 0
    # amplify_ratio = 4
    # img_paths,labels = load_dataset(train_path)
    # increase_train_img(train_path,class_num,amplify_ratio)
172 |
173 |
174 |
175 |
176 |
177 |
--------------------------------------------------------------------------------
/tools/utils.py:
--------------------------------------------------------------------------------
1 | ## Cutout
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 | import matplotlib.pyplot as plt
6 |
class Cutout(object):
    """Zero out random square-ish patches of an image (Cutout augmentation).

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): Nominal side length (in pixels) of each patch. The
            patch spans ``length // 2`` pixels either side of its centre, so
            odd lengths give a side of ``length - 1``. Patches are clipped at
            the image border.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return ``img`` multiplied by a 0/1 ``float32`` mask.

        ``img`` must be indexable as ``(height, width, channels)``. Patch
        centres are drawn uniformly with ``np.random.randint``, row first,
        then column.
        """
        height, width, channels = img.shape[0], img.shape[1], img.shape[2]
        half = self.length // 2
        mask = np.ones((height, width, channels), np.float32)

        for _ in range(self.n_holes):
            cy = np.random.randint(height)
            cx = np.random.randint(width)
            # Clip the patch to the image bounds.
            top, bottom = np.clip([cy - half, cy + half], 0, height)
            left, right = np.clip([cx - half, cx + half], 0, width)
            mask[top:bottom, left:right, :] = 0.

        return img * mask
40 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | from keras.optimizers import adam
4 | from tools.data_gen import data_flow
5 | from models.resnet50 import ResNet50
6 | from keras.layers import Dense,Dropout,BatchNormalization,GlobalAveragePooling2D
7 | from keras.models import \
8 | Model
9 | from keras.callbacks import ReduceLROnPlateau,ModelCheckpoint,EarlyStopping
10 | from keras import regularizers
11 |
12 |
13 |
# Device and console-logging configuration
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Set the log level before the session is created, so TensorFlow's C++
# logging during GPU initialisation is already covered (TF may read this
# value only once).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '99'
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.8
sess = tf.compat.v1.Session(config=config)
# Keras only uses this session, and therefore the GPU options above, once it
# is registered as the backend session. Otherwise Keras creates its own.
from keras import backend as K
K.set_session(sess)
22 |
# Global configuration: command-line flags with their defaults
_flags = tf.app.flags
# Data locations
_flags.DEFINE_string('test_data_local', './test_data', '测试图片文件夹')
_flags.DEFINE_string('data_local', './garbage_classify/train_data', '训练图片文件夹')
# Model / input shape
_flags.DEFINE_integer('num_classes', 40, '垃圾分类数目')
_flags.DEFINE_integer('input_size', 224, '模型输入图片大小')
# Optimisation
_flags.DEFINE_integer('batch_size', 16, '图片批处理大小')
_flags.DEFINE_float('learning_rate', 1e-4, '学习率')
_flags.DEFINE_integer('max_epochs', 4, '轮次')
# Output
_flags.DEFINE_string('train_local', './output_model', '训练输出文件夹')
_flags.DEFINE_integer('keep_weights_file_num', 20, '如果设置为-1,则文件保持的最大权重数表示无穷大')
FLAGS = _flags.FLAGS
34 |
# Author's note: test_acc = 0.78 with this head.
def add_new_last_layer(base_model, num_classes):
    """Attach a classification head to ``base_model``'s output.

    Head: global average pooling -> dropout(0.5) -> Dense(512, relu, L2 1e-4)
    -> batch norm -> dropout(0.5) -> Dense(num_classes, softmax).

    Args:
        base_model: Keras model whose ``output`` feeds the head.
        num_classes: Number of output classes.

    Returns:
        A new ``Model`` from ``base_model.input`` to the softmax output.
    """
    # Layers are created in the same order as they are applied, so any
    # auto-generated layer names are unchanged.
    head_layers = [
        GlobalAveragePooling2D(name='avg_pool'),
        Dropout(0.5, name='dropout1'),
        Dense(512, activation='relu', kernel_regularizer=regularizers.l2(0.0001), name='fc2'),
        BatchNormalization(name='bn_fc_01'),
        Dropout(0.5, name='dropout2'),
        Dense(num_classes, activation='softmax'),
    ]
    features = base_model.output
    for layer in head_layers:
        features = layer(features)
    return Model(inputs=base_model.input, outputs=features)
48 |
# Fine-tuning setup
def setup_to_finetune(FLAGS, model, layer_number=149):
    """Freeze the first ``layer_number`` layers, unfreeze the rest, and recompile.

    Args:
        FLAGS: Object exposing ``learning_rate`` for the Adam optimiser.
        model: Keras model to modify in place.
        layer_number: Slice boundary in ``model.layers``. Layers before it are
            frozen, layers from it onward are trainable.
    """
    all_layers = model.layers
    frozen = all_layers[:layer_number]
    unfrozen = all_layers[layer_number:]
    for layer in frozen:
        layer.trainable = False
    for layer in unfrozen:
        layer.trainable = True
    optimizer = adam(lr=FLAGS.learning_rate, decay=0.0005)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
60 |
61 |
# Model construction
def model_fn(FLAGS):
    """Build a ResNet50 backbone with a new head and compile it.

    The backbone is loaded with ImageNet weights and no top, and all of its
    layers are frozen, so only the head from ``add_new_last_layer`` trains
    at first.

    Args:
        FLAGS: Object exposing ``input_size`` and ``num_classes``.

    Returns:
        The compiled Keras model (Adam, categorical cross-entropy, accuracy).
    """
    input_shape = (FLAGS.input_size, FLAGS.input_size, 3)
    # NOTE(review): `classes` is presumably only used when include_top=True;
    # confirm in models/resnet50.py.
    backbone = ResNet50(weights="imagenet",
                        include_top=False,
                        pooling=None,
                        input_shape=input_shape,
                        classes=FLAGS.num_classes)
    for layer in backbone.layers:
        layer.trainable = False
    model = add_new_last_layer(backbone, FLAGS.num_classes)
    model.compile(optimizer="adam", loss='categorical_crossentropy', metrics=['accuracy'])
    return model
def train_model(FLAGS):
    """Train in two phases: head only, then fine-tuning with checkpoints.

    Phase 1 trains for ``FLAGS.max_epochs`` epochs with the backbone frozen by
    ``model_fn``. Phase 2 calls ``setup_to_finetune`` and trains for up to
    ``FLAGS.max_epochs * 5`` epochs. During phase 2 the model with the lowest
    ``val_loss`` is saved to ``<FLAGS.train_local>/best.h5``.

    Args:
        FLAGS: Object exposing ``data_local``, ``batch_size``, ``num_classes``,
            ``input_size``, ``max_epochs``, ``train_local`` and the fields
            used by ``model_fn``/``setup_to_finetune``.
    """
    # Build the training data
    train_sequence, validation_sequence = data_flow(FLAGS.data_local, FLAGS.batch_size,
                                                    FLAGS.num_classes, FLAGS.input_size)
    # ModelCheckpoint cannot write the file if its directory does not exist.
    os.makedirs(FLAGS.train_local, exist_ok=True)
    best_model_path = os.path.join(FLAGS.train_local, 'best.h5')
    model = model_fn(FLAGS)
    model.fit_generator(
        train_sequence,
        steps_per_epoch=len(train_sequence),
        epochs=FLAGS.max_epochs,
        verbose=1,
        validation_data=validation_sequence,
        max_queue_size=10,
        shuffle=True
    )
    # Fine-tune the model
    setup_to_finetune(FLAGS, model)
    model.fit_generator(
        train_sequence,
        steps_per_epoch=len(train_sequence),
        epochs=FLAGS.max_epochs * 5,
        verbose=1,
        callbacks=[
            ModelCheckpoint(best_model_path,
                            monitor='val_loss', save_best_only=True, mode='min'),
            # NOTE(review): this and EarlyStopping share patience=10 on the
            # same metric, so they typically fire on the same epoch and little
            # or no training happens at the reduced learning rate. Consider a
            # smaller patience here.
            ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                              patience=10, mode='min'),
            EarlyStopping(monitor='val_loss', patience=10),
        ],
        validation_data=validation_sequence,
        max_queue_size=10,
        shuffle=True
    )
    print('training done!')
109 |
110 |
if __name__ == '__main__':
    # Entry point: train the new head, then fine-tune the upper backbone layers.
    train_model(FLAGS=FLAGS)
113 |
--------------------------------------------------------------------------------