├── .gitignore ├── .idea ├── .gitignore ├── deployment.xml ├── garbage_calssify-by-resnet50.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── other.xml ├── README.md ├── __init__.py ├── __pycache__ ├── data_gen.cpython-36.pyc ├── img_gen.cpython-36.pyc ├── predict_local.cpython-36.pyc ├── random_eraser.cpython-36.pyc ├── save_model.cpython-36.pyc ├── train.cpython-36.pyc ├── utils.cpython-36.pyc └── warmup_cosine_decay_scheduler.cpython-36.pyc ├── data_process.py ├── garbage_classify └── garbage_classify_rule.json ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── model.cpython-36.pyc │ └── resnet50.cpython-36.pyc ├── model.py ├── resnet50.py └── resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5 ├── predict_local.py ├── run.py ├── static └── images │ ├── banana.jpg │ ├── test.jpg │ └── test1.jpg ├── templates └── predict.html ├── test_image ├── banana.jpg ├── box.jpg ├── cigar.jpg ├── egg.jpg ├── img_14052.jpg ├── test.jpg └── test1.jpg ├── tools ├── __pycache__ │ ├── data_gen.cpython-36.pyc │ └── utils.cpython-36.pyc ├── data_gen.py ├── img_gen.py └── utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | output_model/best.h5 -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /.idea/garbage_calssify-by-resnet50.iml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 23 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #环境初始化 2 | - python3 3 | - 安装框架flask`pip3 install flask` 4 | - 安装tensorflow,keras等依赖 5 | > - `pip3 install tensorflow==1.13.1` 6 | > - `pip3 install keras==2.3.1 ` 7 | 8 | 9 | #运行 10 | - 1.命令`python3 train.py`开启训练 11 | - 2.命令`python3 predict_local.py`开启输入图片测试 12 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /__pycache__/data_gen.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/data_gen.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/img_gen.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/img_gen.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/predict_local.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/predict_local.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/random_eraser.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/random_eraser.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/save_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/save_model.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/train.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/train.cpython-36.pyc 
-------------------------------------------------------------------------------- /__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/warmup_cosine_decay_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/__pycache__/warmup_cosine_decay_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /data_process.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from PIL import Image 3 | import requests as req 4 | from io import BytesIO 5 | import numpy as np 6 | import os 7 | import math 8 | import codecs 9 | import random 10 | from models.resnet50 import preprocess_input 11 | # 本地路径获取图片信息 12 | def preprocess_img(img_path,img_size): 13 | try: 14 | img = Image.open(img_path) 15 | # if img.format: 16 | # resize_scale = img_size / max(img.size[:2]) 17 | # img = img.resize((int(img.size[0] * resize_scale), int(img.size[1] * resize_scale))) 18 | img = img.resize((256, 256)) 19 | img = img.convert('RGB') 20 | # img.show() 21 | img = np.array(img) 22 | imgs = [] 23 | for _ in range(10): 24 | i = random.randint(0, 32) 25 | j = random.randint(0, 32) 26 | imgg = img[i:i + 224, j:j + 224] 27 | imgg = preprocess_input(imgg) 28 | imgs.append(imgg) 29 | return imgs 30 | except Exception as e: 31 | print('发生了异常data_process:', e) 32 | return 0 33 | 34 | 35 | 36 | 37 | # url获取图片数组信息 38 | def preprocess_img_from_Url(img_path,img_size): 39 | try: 40 | response = req.get(img_path) 41 | img = Image.open(BytesIO(response.content)) 42 | img 
= img.resize((256, 256)) 43 | img = img.convert('RGB') 44 | # img.show() 45 | img = np.array(img) 46 | imgs = [] 47 | for _ in range(10): 48 | i = random.randint(0, 32) 49 | j = random.randint(0, 32) 50 | imgg = img[i:i + 224, j:j + 224] 51 | imgg = preprocess_input(imgg) 52 | imgs.append(imgg) 53 | return imgs 54 | except Exception as e: 55 | print('发生了异常data_process:', e) 56 | return 0 -------------------------------------------------------------------------------- /garbage_classify/garbage_classify_rule.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "其他垃圾/一次性快餐盒", 3 | "1": "其他垃圾/污损塑料", 4 | "2": "其他垃圾/烟蒂", 5 | "3": "其他垃圾/牙签", 6 | "4": "其他垃圾/破碎花盆及碟碗", 7 | "5": "其他垃圾/竹筷", 8 | "6": "厨余垃圾/剩饭剩菜", 9 | "7": "厨余垃圾/大骨头", 10 | "8": "厨余垃圾/水果果皮", 11 | "9": "厨余垃圾/水果果肉", 12 | "10": "厨余垃圾/茶叶渣", 13 | "11": "厨余垃圾/菜叶菜根", 14 | "12": "厨余垃圾/蛋壳", 15 | "13": "厨余垃圾/鱼骨", 16 | "14": "可回收物/充电宝", 17 | "15": "可回收物/包", 18 | "16": "可回收物/化妆品瓶", 19 | "17": "可回收物/塑料玩具", 20 | "18": "可回收物/塑料碗盆", 21 | "19": "可回收物/塑料衣架", 22 | "20": "可回收物/快递纸袋", 23 | "21": "可回收物/插头电线", 24 | "22": "可回收物/旧衣服", 25 | "23": "可回收物/易拉罐", 26 | "24": "可回收物/枕头", 27 | "25": "可回收物/毛绒玩具", 28 | "26": "可回收物/洗发水瓶", 29 | "27": "可回收物/玻璃杯", 30 | "28": "可回收物/皮鞋", 31 | "29": "可回收物/砧板", 32 | "30": "可回收物/纸板箱", 33 | "31": "可回收物/调料瓶", 34 | "32": "可回收物/酒瓶", 35 | "33": "可回收物/金属食品罐", 36 | "34": "可回收物/锅", 37 | "35": "可回收物/食用油桶", 38 | "36": "可回收物/饮料瓶", 39 | "37": "有害垃圾/干电池", 40 | "38": "有害垃圾/软膏", 41 | "39": "有害垃圾/过期药物" 42 | } -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/models/__init__.py -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/models/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet50.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/models/__pycache__/resnet50.cpython-36.pyc -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | from keras.layers import Input 3 | from keras.layers.normalization import BatchNormalization 4 | from keras.layers import Conv2D, MaxPool2D, Dropout,Activation 5 | from keras.layers import Dense,Lambda,Add,GlobalAveragePooling2D,ZeroPadding2D,Multiply 6 | from keras import Model 7 | from keras import regularizers 8 | from keras.layers import Conv1D,MaxPool1D,LSTM,ZeroPadding2D 9 | from keras.layers import GlobalMaxPool1D,Permute,MaxPooling2D 10 | from keras.layers import GRU,TimeDistributed,Flatten, LeakyReLU,ELU 11 | 12 | # MODEL 13 | WEIGHT_DECAY = 0.0001 #0.00001 14 | REDUCTION_RATIO = 4 15 | BLOCK_NUM = 1 16 | # DROPOUT= 0.5 17 | # Resblcok 18 | def res_conv_block(x,filters,strides,name): 19 | filter1,filer2,filter3 = filters 20 | # block a 21 | x = Conv2D(filter1,(1,1),strides=strides,kernel_initializer='he_normal', 22 | 
kernel_regularizer = regularizers.l2(l=WEIGHT_DECAY),name=f'{name}_conva')(x) 23 | x = BatchNormalization(name=f'{name}_bna')(x) 24 | x = Activation('relu',name=f'{name}_relua')(x) 25 | # block b 26 | x = Conv2D(filer2,(3,3),padding='same',kernel_regularizer = regularizers.l2(l=WEIGHT_DECAY),name=f'{name}_convb')(x) 27 | x = BatchNormalization(name=f'{name}_bnb')(x) 28 | x = Activation('relu',name=f'{name}_relub')(x) 29 | # block c 30 | x = Conv2D(filter3,(1,1),name=f'{name}_convc',kernel_regularizer = regularizers.l2(l=WEIGHT_DECAY))(x) 31 | x = BatchNormalization(name=f'{name}_bnc')(x) 32 | # shortcut 33 | shortcut = Conv2D(filter3,(1,1),strides=strides,name=f'{name}_shcut',kernel_regularizer = regularizers.l2(l=WEIGHT_DECAY))(x) 34 | shortcut = BatchNormalization(name=f'{name}_stbn')(x) 35 | x = Add(name=f'{name}_add')([x,shortcut]) 36 | x = Activation('relu',name=f'{name}_relu')(x) 37 | return x 38 | 39 | 40 | # ResNet 41 | def ResNet_50(input_shape): 42 | x_in = Input(input_shape,name='input') 43 | 44 | x = Conv2D(64,(7,7),strides=(2,2),padding='same',name='conv1',kernel_regularizer = regularizers.l2(l=WEIGHT_DECAY))(x_in) 45 | x = BatchNormalization(name="bn1")(x) 46 | x = Activation('relu')(x) 47 | x = MaxPool2D((3,3),strides=(2,2),padding='same',name='pool1')(x) 48 | 49 | x = res_conv_block(x,(64,64,256),(1,1),name='block1') 50 | x = res_conv_block(x,(64,64,256),(1,1),name='block2') 51 | x = res_conv_block(x,(64,64,256),(1,1),name='block3') 52 | 53 | x = res_conv_block(x,(128,128,512),(1,1),name='block4') 54 | x = res_conv_block(x,(128,128,512),(1,1),name='block5') 55 | x = res_conv_block(x,(128,128,512),(1,1),name='block6') 56 | x = res_conv_block(x,(128,128,512),(2,2),name='block7') 57 | 58 | x = Conv2D(512,(x.shape[1].value,1),name='fc6')(x) 59 | x = BatchNormalization(name="bn_fc6")(x) 60 | x = Activation('relu',name='relu_fc6')(x) 61 | # avgpool 62 | # x = GlobalAveragePooling2D(name='avgPool')(x) 63 | x = Lambda(lambda y: 
K.mean(y,axis=[1,2]),name='avgpool')(x) 64 | 65 | model = Model(inputs=[x_in],outputs=[x],name='ResCNN') 66 | # model.summary() 67 | return model 68 | 69 | def squeeze_excitation(x,reduction_ratio,name): 70 | out_dim = int(x.shape[-1].value) 71 | x = GlobalAveragePooling2D(name=f'{name}_squeeze')(x) 72 | x = Dense(out_dim//reduction_ratio,activation='relu',name=f'{name}_ex0')(x) 73 | x = Dense(out_dim,activation='sigmoid',name=f'{name}_ex1')(x) 74 | return x 75 | 76 | def conv_block(x,filters,kernal_size,stride,name,stage,i,padding='same'): 77 | x = Conv2D(filters,kernal_size,strides=stride,padding=padding,name=f'{name}_conv{stage}_{i}', 78 | kernel_regularizer = regularizers.l2(l=WEIGHT_DECAY))(x) 79 | x = BatchNormalization(name=f'{name}_bn{stage}_{i}')(x) 80 | if stage != 'c': 81 | # x = ELU(name=f'{name}_relu{stage}_{i}')(x) 82 | x = Activation('relu',name=f'{name}_relu{stage}_{i}')(x) 83 | return x 84 | 85 | 86 | def residual_block(x,outdim,stride,name): 87 | input_dim = int(x.shape[-1].value) 88 | shortcut = Conv2D(outdim,kernel_size=(1,1),strides=stride,name=f'{name}_scut_conv', 89 | kernel_regularizer = regularizers.l2(l=WEIGHT_DECAY))(x) 90 | shortcut = BatchNormalization(name=f'{name}_scut_norm')(shortcut) 91 | 92 | for i in range(BLOCK_NUM): 93 | if i>0 : 94 | stride = 1 95 | # x = Dropout(DROPOUT,name=f'{name}_drop{i-1}')(x) 96 | x = conv_block(x,outdim//4,(1,1),stride,name,'a',i,padding='valid') 97 | x = conv_block(x,outdim//4,(3,3),(1,1),name,'b',i,padding='same') 98 | x = conv_block(x,outdim,(1,1),(1,1),name,'c',i,padding='valid') 99 | # add SE 100 | x = Multiply(name=f'{name}_scale')([x,squeeze_excitation(x,REDUCTION_RATIO,name)]) 101 | x = Add(name=f'{name}_scut')([shortcut,x]) 102 | x = Activation('relu',name=f'{name}_relu')(x) 103 | return x 104 | 105 | 106 | # proposed model v4.0 timit libri 107 | def SE_ResNet(input_shape): 108 | # first layer 109 | x_in =Input(input_shape,name='input') 110 | 111 | x = ZeroPadding2D(padding=(3, 3), 
name='conv1_pad')(x_in) 112 | x = Conv2D(64, (7, 7), 113 | strides=(2, 2), 114 | padding='valid', 115 | kernel_initializer='he_normal', 116 | kernel_regularizer= regularizers.l2(WEIGHT_DECAY), 117 | name='conv1')(x) 118 | x = BatchNormalization(name='bn_conv1')(x) 119 | x = Activation('relu',name='relu1')(x) 120 | # x = ELU(name=f'relu1')(x) 121 | x = ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x) 122 | x = MaxPooling2D((3, 3), strides=(2, 2))(x) 123 | 124 | 125 | x = residual_block(x,outdim=256,stride=(2,2),name='block1') 126 | x = residual_block(x,outdim=256,stride=(2,2),name='block2') 127 | x = residual_block(x,outdim=256,stride=(2,2),name='block3') 128 | # x = residual_block(x,outdim=256,stride=(2,2),name='block4') 129 | 130 | x = residual_block(x,outdim=512,stride=(2,2),name='block5') 131 | x = residual_block(x,outdim=512,stride=(2,2),name='block6') 132 | # x = residual_block(x,outdim=512,stride=(2,2),name='block7') 133 | 134 | 135 | x = Flatten(name='flatten')(x) 136 | 137 | x = Dropout(0.5,name='drop1')(x) 138 | x = Dense(512,kernel_regularizer= regularizers.l2(WEIGHT_DECAY),name='fc1')(x) 139 | x = BatchNormalization(name='bn_fc1')(x) 140 | # x = ELU(name=f'relu_fc1')(x) 141 | x = Activation('relu',name=f'relu_fc1')(x) 142 | 143 | 144 | return Model(inputs=[x_in],outputs=[x],name='SEResNet') 145 | 146 | 147 | # vggvox1 148 | def conv_pool(x,layerid,filters,kernal_size,conv_strides,pool_size=None,pool_strides=None,pool=None): 149 | x = Conv2D(filters,kernal_size,strides= conv_strides,padding='same',name=f'conv{layerid}')(x) 150 | x = BatchNormalization(name=f'bn{layerid}')(x) 151 | x = Activation('relu',name=f'relu{layerid}')(x) 152 | if pool == 'max': 153 | x = MaxPool2D(pool_size,pool_strides,name=f'mpool{layerid}')(x) 154 | return x 155 | 156 | def vggvox1_cnn(input_shape): 157 | x_in = Input(input_shape,name='input') 158 | x = conv_pool(x_in,1,96,(7,7),(2,2),(3,3),(2,2),'max') 159 | x = conv_pool(x,2,256,(5,5),(2,2),(3,3),(2,2),'max') 160 | x = 
conv_pool(x,3,384,(3,3),(1,1)) 161 | x = conv_pool(x,4,256,(3,3),(1,1)) 162 | x = conv_pool(x,5,256,(3,3),(1,1),(5,3),(3,2),'max') 163 | # fc 6 164 | x = Conv2D(256,(9,1),name='fc6')(x) 165 | # apool6 166 | x = GlobalAveragePooling2D(name='avgPool')(x) 167 | # fc7 168 | x = Dense(512,name='fc7',activation='relu')(x) 169 | model = Model(inputs=[x_in],outputs=[x],name='vggvox1_cnn') 170 | return model 171 | 172 | # def vggvox1_cnn(input_shape): 173 | # x_in = Input(input_shape,name='input') 174 | # x = conv_pool(x_in,1,96,(7,7),(1,1),(3,3),(2,2),'max') 175 | # x = conv_pool(x,2,256,(5,5),(1,1),(3,3),(2,2),'max') 176 | # x = conv_pool(x,3,384,(3,3),(1,1)) 177 | # x = conv_pool(x,4,256,(3,3),(1,1)) 178 | # x = conv_pool(x,5,256,(3,3),(1,1),(5,3),(2,2),'max') 179 | # # fc 6 180 | # x = Conv2D(256,(x.shape[1].value,1),name='fc6')(x) 181 | # # apool6 182 | # x = GlobalAveragePooling2D(name='avgPool')(x) 183 | # # fc7 184 | # x = Dense(512,name='fc7',activation='relu')(x) 185 | # model = Model(inputs=[x_in],outputs=[x],name='vggvox1_cnn') 186 | # return model 187 | 188 | # deep speaker 189 | def clipped_relu(inputs): 190 | return Lambda(lambda y:K.minimum(K.maximum(y,0),20))(inputs) 191 | 192 | def identity_block(x_in,kernel_size,filters,name): 193 | x = Conv2D(filters,kernel_size=kernel_size,strides=(1,1), 194 | padding='same',kernel_regularizer=regularizers.l2(l=WEIGHT_DECAY), 195 | name=f'{name}_conva')(x_in) 196 | x = BatchNormalization(name=f'{name}_bn1')(x) 197 | x = clipped_relu(x) 198 | x = Conv2D(filters,kernel_size=kernel_size,strides=(1,1), 199 | padding='same',kernel_regularizer = regularizers.l2(l=WEIGHT_DECAY), 200 | name=f'{name}_convb')(x) 201 | x = BatchNormalization(name=f'{name}_bn2')(x) 202 | x = Add(name=f'{name}_add')([x,x_in]) 203 | x = clipped_relu(x) 204 | return x 205 | 206 | def Deep_speaker_model(input_shape): 207 | def conv_and_res_block(x_in,filters): 208 | x = Conv2D(filters,kernel_size=(5,5),strides=(2,2), 209 | 
padding='same',kernel_regularizer=regularizers.l2(l=WEIGHT_DECAY), 210 | name=f'conv_{filters}-s')(x_in) 211 | x = BatchNormalization(name=f'conv_{filters}-s_bn')(x) 212 | x = clipped_relu(x) 213 | for i in range(3): 214 | x = identity_block(x,kernel_size=(3,3),filters=filters,name=f'res{filters}_{i}') 215 | return x 216 | 217 | x_in = Input(input_shape,name='input') 218 | x = Permute((2,1,3),name='permute')(x_in) 219 | x = conv_and_res_block(x,64) 220 | x = conv_and_res_block(x,128) 221 | x = conv_and_res_block(x,256) 222 | x = conv_and_res_block(x,512) 223 | # average 224 | x = Lambda(lambda y: K.mean(y,axis=[1,2]),name='avgpool')(x) 225 | # affine 226 | x = Dense(512,name='affine')(x) 227 | x = Lambda(lambda y:K.l2_normalize(y,axis=1),name='ln')(x) 228 | model = Model(inputs=[x_in],outputs=[x],name='deepspeaker') 229 | return model 230 | 231 | # proposed model 232 | def Baseline_GRU(input_shape): 233 | # first layer 234 | x_in = Input(input_shape, name='input') 235 | x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', name='conv1')(x_in) 236 | x = BatchNormalization(name='bn1')(x) 237 | x = ELU(name='relu1')(x) 238 | x = MaxPool2D((2, 2), strides=(2, 2), padding='same', name='pool1')(x) 239 | 240 | x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', name='conv2')(x) 241 | x = BatchNormalization(name='bn2')(x) 242 | x = ELU(name='relu2')(x) 243 | x = MaxPool2D((2, 2), strides=(2, 2), padding='same', name='pool2')(x) 244 | 245 | x = TimeDistributed(Flatten(), name='timedis1')(x) 246 | x = GRU(512, return_sequences=True, name='gru1')(x) 247 | x = GRU(512, return_sequences=True, name='gru2')(x) 248 | x = GRU(512, return_sequences=False, name='gru4')(x) 249 | 250 | x = Dense(512, name='fc2', activation='relu')(x) 251 | x = BatchNormalization(name='fc_norm')(x) 252 | x = ELU(name='relu3')(x) 253 | 254 | return Model(inputs=[x_in], outputs=[x], name='Baseline_GRU') 255 | 256 | 257 | 258 | 259 | 260 | if __name__ == "__main__": 261 | 262 | # model = 
ResNet(c.INPUT_SHPE) 263 | model = vggvox1_cnn((299,40,1)) 264 | # model = Deep_speaker_model(c.INPUT_SHPE) 265 | # # model = SE_ResNet(c.INPUT_SHPE) 266 | # model = RWCNN_LSTM((59049,1)) 267 | print(model.summary()) 268 | -------------------------------------------------------------------------------- /models/resnet50.py: -------------------------------------------------------------------------------- 1 | """Enables dynamic setting of underlying Keras module. 2 | """ 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import json 9 | import warnings 10 | import numpy as np 11 | 12 | from keras.applications import backend 13 | from keras.applications import layers 14 | from keras.applications import models 15 | from keras.applications import utils 16 | 17 | _KERAS_BACKEND = backend 18 | _KERAS_LAYERS = layers 19 | _KERAS_MODELS = models 20 | _KERAS_UTILS = utils 21 | 22 | CLASS_INDEX = None 23 | CLASS_INDEX_PATH = ('https://modelarts-competitions.obs.cn-north-1.myhuaweicloud.com/' 24 | 'model_zoo/resnet/imagenet_class_index.json') 25 | 26 | # Global tensor of imagenet mean for preprocessin 27 | # g symbolic inputs 28 | _IMAGENET_MEAN = None 29 | 30 | WEIGHTS_PATH = ('https://github.com/fchollet/deep-learning-models/' 31 | 'releases/download/v0.2/' 32 | 'resnet50_weights_tf_dim_ordering_tf_kernels.h5') 33 | WEIGHTS_PATH_NO_TOP = ('https://modelarts-competitions.obs.cn-north-1.myhuaweicloud.com/' 34 | 'model_zoo/resnet/' 35 | 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5') 36 | 37 | def get_submodules_from_kwargs(kwargs): 38 | backend = kwargs.get('backend', _KERAS_BACKEND) 39 | layers = kwargs.get('layers', _KERAS_LAYERS) 40 | models = kwargs.get('models', _KERAS_MODELS) 41 | utils = kwargs.get('utils', _KERAS_UTILS) 42 | for key in kwargs.keys(): 43 | if key not in ['backend', 'layers', 'models', 'utils']: 44 | raise TypeError('Invalid keyword argument: %s', key) 45 | 
return backend, layers, models, utils 46 | 47 | 48 | def correct_pad(backend, inputs, kernel_size): 49 | """Returns a tuple for zero-padding for 2D convolution with downsampling. 50 | 51 | # Arguments 52 | input_size: An integer or tuple/list of 2 integers. 53 | kernel_size: An integer or tuple/list of 2 integers. 54 | 55 | # Returns 56 | A tuple. 57 | """ 58 | img_dim = 2 if backend.image_data_format() == 'channels_first' else 1 59 | input_size = backend.int_shape(inputs)[img_dim:(img_dim + 2)] 60 | 61 | if isinstance(kernel_size, int): 62 | kernel_size = (kernel_size, kernel_size) 63 | 64 | if input_size[0] is None: 65 | adjust = (1, 1) 66 | else: 67 | adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2) 68 | 69 | correct = (kernel_size[0] // 2, kernel_size[1] // 2) 70 | 71 | return ((correct[0] - adjust[0], correct[0]), 72 | (correct[1] - adjust[1], correct[1])) 73 | 74 | __version__ = '1.0.7' 75 | 76 | 77 | def _preprocess_numpy_input(x, data_format, mode, **kwargs): 78 | """Preprocesses a Numpy array encoding a batch of images. 79 | 80 | # Arguments 81 | x: Input array, 3D or 4D. 82 | data_format: Data format of the image array. 83 | mode: One of "caffe", "tf" or "torch". 84 | - caffe: will convert the images from RGB to BGR, 85 | then will zero-center each color channel with 86 | respect to the ImageNet dataset, 87 | without scaling. 88 | - tf: will scale pixels between -1 and 1, 89 | sample-wise. 90 | - torch: will scale pixels between 0 and 1 and then 91 | will normalize each channel with respect to the 92 | ImageNet dataset. 93 | 94 | # Returns 95 | Preprocessed Numpy array. 96 | """ 97 | backend, _, _, _ = get_submodules_from_kwargs(kwargs) 98 | if not issubclass(x.dtype.type, np.floating): 99 | x = x.astype(backend.floatx(), copy=False) 100 | 101 | 102 | if mode == 'tf': 103 | x /= 127.5 104 | x -= 1. 105 | 106 | return x 107 | 108 | if mode == 'torch': 109 | x /= 255. 
110 | mean = [0.485, 0.456, 0.406] 111 | std = [0.229, 0.224, 0.225] 112 | else: 113 | if data_format == 'channels_first': 114 | # 'RGB'->'BGR' 115 | if x.ndim == 3: 116 | x = x[::-1, ...] 117 | else: 118 | x = x[:, ::-1, ...] 119 | else: 120 | # 'RGB'->'BGR' 121 | x = x[..., ::-1] 122 | mean = [103.939, 116.779, 123.68] 123 | std = None 124 | 125 | # Zero-center by mean pixel 126 | if data_format == 'channels_first': 127 | if x.ndim == 3: 128 | x[0, :, :] -= mean[0] 129 | x[1, :, :] -= mean[1] 130 | x[2, :, :] -= mean[2] 131 | if std is not None: 132 | x[0, :, :] /= std[0] 133 | x[1, :, :] /= std[1] 134 | x[2, :, :] /= std[2] 135 | else: 136 | x[:, 0, :, :] -= mean[0] 137 | x[:, 1, :, :] -= mean[1] 138 | x[:, 2, :, :] -= mean[2] 139 | if std is not None: 140 | x[:, 0, :, :] /= std[0] 141 | x[:, 1, :, :] /= std[1] 142 | x[:, 2, :, :] /= std[2] 143 | else: 144 | x[..., 0] -= mean[0] 145 | x[..., 1] -= mean[1] 146 | x[..., 2] -= mean[2] 147 | if std is not None: 148 | x[..., 0] /= std[0] 149 | x[..., 1] /= std[1] 150 | x[..., 2] /= std[2] 151 | return x 152 | 153 | 154 | def _preprocess_symbolic_input(x, data_format, mode, **kwargs): 155 | """Preprocesses a tensor encoding a batch of images. 156 | 157 | # Arguments 158 | x: Input tensor, 3D or 4D. 159 | data_format: Data format of the image tensor. 160 | mode: One of "caffe", "tf" or "torch". 161 | - caffe: will convert the images from RGB to BGR, 162 | then will zero-center each color channel with 163 | respect to the ImageNet dataset, 164 | without scaling. 165 | - tf: will scale pixels between -1 and 1, 166 | sample-wise. 167 | - torch: will scale pixels between 0 and 1 and then 168 | will normalize each channel with respect to the 169 | ImageNet dataset. 170 | 171 | # Returns 172 | Preprocessed tensor. 173 | """ 174 | global _IMAGENET_MEAN 175 | 176 | backend, _, _, _ = get_submodules_from_kwargs(kwargs) 177 | 178 | if mode == 'tf': 179 | x /= 127.5 180 | x -= 1. 
181 | return x 182 | 183 | if mode == 'torch': 184 | x /= 255. 185 | mean = [0.485, 0.456, 0.406] 186 | std = [0.229, 0.224, 0.225] 187 | else: 188 | if data_format == 'channels_first': 189 | # 'RGB'->'BGR' 190 | if backend.ndim(x) == 3: 191 | x = x[::-1, ...] 192 | else: 193 | x = x[:, ::-1, ...] 194 | else: 195 | # 'RGB'->'BGR' 196 | x = x[..., ::-1] 197 | mean = [103.939, 116.779, 123.68] 198 | std = None 199 | 200 | if _IMAGENET_MEAN is None: 201 | _IMAGENET_MEAN = backend.constant(-np.array(mean)) 202 | 203 | # Zero-center by mean pixel 204 | if backend.dtype(x) != backend.dtype(_IMAGENET_MEAN): 205 | x = backend.bias_add( 206 | x, backend.cast(_IMAGENET_MEAN, backend.dtype(x)), 207 | data_format=data_format) 208 | else: 209 | x = backend.bias_add(x, _IMAGENET_MEAN, data_format) 210 | if std is not None: 211 | x /= std 212 | return x 213 | 214 | 215 | def preprocess_input(x, data_format=None, mode='caffe', **kwargs): 216 | """Preprocesses a tensor or Numpy array encoding a batch of images. 217 | 218 | # Arguments 219 | x: Input Numpy or symbolic tensor, 3D or 4D. 220 | The preprocessed data is written over the input data 221 | if the data types are compatible. To avoid this 222 | behaviour, `numpy.copy(x)` can be used. 223 | data_format: Data format of the image tensor/array. 224 | mode: One of "caffe", "tf" or "torch". 225 | - caffe: will convert the images from RGB to BGR, 226 | then will zero-center each color channel with 227 | respect to the ImageNet dataset, 228 | without scaling. 229 | - tf: will scale pixels between -1 and 1, 230 | sample-wise. 231 | - torch: will scale pixels between 0 and 1 and then 232 | will normalize each channel with respect to the 233 | ImageNet dataset. 234 | 235 | # Returns 236 | Preprocessed tensor or Numpy array. 237 | 238 | # Raises 239 | ValueError: In case of unknown `data_format` argument. 
240 | """ 241 | backend, _, _, _ = get_submodules_from_kwargs(kwargs) 242 | 243 | if data_format is None: 244 | data_format = backend.image_data_format() 245 | if data_format not in {'channels_first', 'channels_last'}: 246 | raise ValueError('Unknown data_format ' + str(data_format)) 247 | 248 | if isinstance(x, np.ndarray): 249 | return _preprocess_numpy_input(x, data_format=data_format, 250 | mode=mode, **kwargs) 251 | else: 252 | return _preprocess_symbolic_input(x, data_format=data_format, 253 | mode=mode, **kwargs) 254 | 255 | 256 | def decode_predictions(preds, top=5, **kwargs): 257 | """Decodes the prediction of an ImageNet model. 258 | 259 | # Arguments 260 | preds: Numpy tensor encoding a batch of predictions. 261 | top: Integer, how many top-guesses to return. 262 | 263 | # Returns 264 | A list of lists of top class prediction tuples 265 | `(class_name, class_description, score)`. 266 | One list of tuples per sample in batch input. 267 | 268 | # Raises 269 | ValueError: In case of invalid shape of the `pred` array 270 | (must be 2D). 271 | """ 272 | global CLASS_INDEX 273 | 274 | backend, _, _, keras_utils = get_submodules_from_kwargs(kwargs) 275 | 276 | if len(preds.shape) != 2 or preds.shape[1] != 1000: 277 | raise ValueError('`decode_predictions` expects ' 278 | 'a batch of predictions ' 279 | '(i.e. a 2D array of shape (samples, 1000)). 
' 280 | 'Found array with shape: ' + str(preds.shape)) 281 | if CLASS_INDEX is None: 282 | fpath = keras_utils.get_file( 283 | 'imagenet_class_index.json', 284 | CLASS_INDEX_PATH, 285 | cache_subdir='models', 286 | file_hash='c2c37ea517e94d9795004a39431a14cb', 287 | cache_dir=os.path.join(os.path.dirname(__file__), '..')) 288 | with open(fpath) as f: 289 | CLASS_INDEX = json.load(f) 290 | results = [] 291 | for pred in preds: 292 | top_indices = pred.argsort()[-top:][::-1] 293 | result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices] 294 | result.sort(key=lambda x: x[2], reverse=True) 295 | results.append(result) 296 | return results 297 | 298 | 299 | def _obtain_input_shape(input_shape, 300 | default_size, 301 | min_size, 302 | data_format, 303 | require_flatten, 304 | weights=None): 305 | """Internal utility to compute/validate a model's input shape. 306 | 307 | # Arguments 308 | input_shape: Either None (will return the default network input shape), 309 | or a user-provided shape to be validated. 310 | default_size: Default input width/height for the model. 311 | min_size: Minimum input width/height accepted by the model. 312 | data_format: Image data format to use. 313 | require_flatten: Whether the model is expected to 314 | be linked to a classifier via a Flatten layer. 315 | weights: One of `None` (random initialization) 316 | or 'imagenet' (pre-training on ImageNet). 317 | If weights='imagenet' input channels must be equal to 3. 318 | 319 | # Returns 320 | An integer shape tuple (may include None entries). 321 | 322 | # Raises 323 | ValueError: In case of invalid argument values. 324 | """ 325 | if weights != 'imagenet' and input_shape and len(input_shape) == 3: 326 | if data_format == 'channels_first': 327 | if input_shape[0] not in {1, 3}: 328 | warnings.warn( 329 | 'This model usually expects 1 or 3 input channels. 
' 330 | 'However, it was passed an input_shape with ' + 331 | str(input_shape[0]) + ' input channels.') 332 | default_shape = (input_shape[0], default_size, default_size) 333 | else: 334 | if input_shape[-1] not in {1, 3}: 335 | warnings.warn( 336 | 'This model usually expects 1 or 3 input channels. ' 337 | 'However, it was passed an input_shape with ' + 338 | str(input_shape[-1]) + ' input channels.') 339 | default_shape = (default_size, default_size, input_shape[-1]) 340 | else: 341 | if data_format == 'channels_first': 342 | default_shape = (3, default_size, default_size) 343 | else: 344 | default_shape = (default_size, default_size, 3) 345 | if weights == 'imagenet' and require_flatten: 346 | if input_shape is not None: 347 | if input_shape != default_shape: 348 | raise ValueError('When setting `include_top=True` ' 349 | 'and loading `imagenet` weights, ' 350 | '`input_shape` should be ' + 351 | str(default_shape) + '.') 352 | return default_shape 353 | if input_shape: 354 | if data_format == 'channels_first': 355 | if input_shape is not None: 356 | if len(input_shape) != 3: 357 | raise ValueError( 358 | '`input_shape` must be a tuple of three integers.') 359 | if input_shape[0] != 3 and weights == 'imagenet': 360 | raise ValueError('The input must have 3 channels; got ' 361 | '`input_shape=' + str(input_shape) + '`') 362 | if ((input_shape[1] is not None and input_shape[1] < min_size) or 363 | (input_shape[2] is not None and input_shape[2] < min_size)): 364 | raise ValueError('Input size must be at least ' + 365 | str(min_size) + 'x' + str(min_size) + 366 | '; got `input_shape=' + 367 | str(input_shape) + '`') 368 | else: 369 | if input_shape is not None: 370 | if len(input_shape) != 3: 371 | raise ValueError( 372 | '`input_shape` must be a tuple of three integers.') 373 | if input_shape[-1] != 3 and weights == 'imagenet': 374 | raise ValueError('The input must have 3 channels; got ' 375 | '`input_shape=' + str(input_shape) + '`') 376 | if ((input_shape[0] is 
not None and input_shape[0] < min_size) or 377 | (input_shape[1] is not None and input_shape[1] < min_size)): 378 | raise ValueError('Input size must be at least ' + 379 | str(min_size) + 'x' + str(min_size) + 380 | '; got `input_shape=' + 381 | str(input_shape) + '`') 382 | else: 383 | if require_flatten: 384 | input_shape = default_shape 385 | else: 386 | if data_format == 'channels_first': 387 | input_shape = (3, None, None) 388 | else: 389 | input_shape = (None, None, 3) 390 | if require_flatten: 391 | if None in input_shape: 392 | raise ValueError('If `include_top` is True, ' 393 | 'you should specify a static `input_shape`. ' 394 | 'Got `input_shape=' + str(input_shape) + '`') 395 | return input_shape 396 | 397 | 398 | backend = None 399 | layers = None 400 | models = None 401 | keras_utils = None 402 | 403 | 404 | def identity_block(input_tensor, kernel_size, filters, stage, block): 405 | """The identity block is the block that has no conv layer at shortcut. 406 | 407 | # Arguments 408 | input_tensor: input tensor 409 | kernel_size: default 3, the kernel size of 410 | middle conv layer at main path 411 | filters: list of integers, the filters of 3 conv layer at main path 412 | stage: integer, current stage label, used for generating layer names 413 | block: 'a','b'..., current block label, used for generating layer names 414 | 415 | # Returns 416 | Output tensor for the block. 
417 | """ 418 | filters1, filters2, filters3 = filters 419 | if backend.image_data_format() == 'channels_last': 420 | bn_axis = 3 421 | else: 422 | bn_axis = 1 423 | conv_name_base = 'res' + str(stage) + block + '_branch' 424 | bn_name_base = 'bn' + str(stage) + block + '_branch' 425 | 426 | x = layers.Conv2D(filters1, (1, 1), 427 | kernel_initializer='he_normal', 428 | name=conv_name_base + '2a')(input_tensor) 429 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) 430 | x = layers.Activation('relu')(x) 431 | 432 | x = layers.Conv2D(filters2, kernel_size, 433 | padding='same', 434 | kernel_initializer='he_normal', 435 | name=conv_name_base + '2b')(x) 436 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) 437 | x = layers.Activation('relu')(x) 438 | 439 | x = layers.Conv2D(filters3, (1, 1), 440 | kernel_initializer='he_normal', 441 | name=conv_name_base + '2c')(x) 442 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) 443 | 444 | x = layers.add([x, input_tensor]) 445 | x = layers.Activation('relu')(x) 446 | return x 447 | 448 | 449 | def conv_block(input_tensor, 450 | kernel_size, 451 | filters, 452 | stage, 453 | block, 454 | strides=(2, 2)): 455 | """A block that has a conv layer at shortcut. 456 | 457 | # Arguments 458 | input_tensor: input tensor 459 | kernel_size: default 3, the kernel size of 460 | middle conv layer at main path 461 | filters: list of integers, the filters of 3 conv layer at main path 462 | stage: integer, current stage label, used for generating layer names 463 | block: 'a','b'..., current block label, used for generating layer names 464 | strides: Strides for the first conv layer in the block. 465 | 466 | # Returns 467 | Output tensor for the block. 
468 | 469 | Note that from stage 3, 470 | the first conv layer at main path is with strides=(2, 2) 471 | And the shortcut should have strides=(2, 2) as well 472 | """ 473 | filters1, filters2, filters3 = filters 474 | if backend.image_data_format() == 'channels_last': 475 | bn_axis = 3 476 | else: 477 | bn_axis = 1 478 | conv_name_base = 'res' + str(stage) + block + '_branch' 479 | bn_name_base = 'bn' + str(stage) + block + '_branch' 480 | 481 | x = layers.Conv2D(filters1, (1, 1), strides=strides, 482 | kernel_initializer='he_normal', 483 | name=conv_name_base + '2a')(input_tensor) 484 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x) 485 | x = layers.Activation('relu')(x) 486 | 487 | x = layers.Conv2D(filters2, kernel_size, padding='same', 488 | kernel_initializer='he_normal', 489 | name=conv_name_base + '2b')(x) 490 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x) 491 | x = layers.Activation('relu')(x) 492 | 493 | x = layers.Conv2D(filters3, (1, 1), 494 | kernel_initializer='he_normal', 495 | name=conv_name_base + '2c')(x) 496 | x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x) 497 | 498 | shortcut = layers.Conv2D(filters3, (1, 1), strides=strides, 499 | kernel_initializer='he_normal', 500 | name=conv_name_base + '1')(input_tensor) 501 | shortcut = layers.BatchNormalization( 502 | axis=bn_axis, name=bn_name_base + '1')(shortcut) 503 | 504 | x = layers.add([x, shortcut]) 505 | x = layers.Activation('relu')(x) 506 | return x 507 | 508 | 509 | def ResNet50(include_top=True, 510 | weights='imagenet', 511 | input_tensor=None, 512 | input_shape=None, 513 | pooling=None, 514 | classes=1000, 515 | **kwargs): 516 | """Instantiates the ResNet50 architecture. 517 | 518 | Optionally loads weights pre-trained on ImageNet. 519 | Note that the data format convention used by the model is 520 | the one specified in your Keras config at `~/.keras/keras.json`. 
521 | 522 | # Arguments 523 | include_top: whether to include the fully-connected 524 | layer at the top of the network. 525 | weights: one of `None` (random initialization), 526 | 'imagenet' (pre-training on ImageNet), 527 | or the path to the weights file to be loaded. 528 | input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) 529 | to use as image input for the model. 530 | input_shape: optional shape tuple, only to be specified 531 | if `include_top` is False (otherwise the input shape 532 | has to be `(224, 224, 3)` (with `channels_last` data format) 533 | or `(3, 224, 224)` (with `channels_first` data format). 534 | It should have exactly 3 inputs channels, 535 | and width and height should be no smaller than 32. 536 | E.g. `(200, 200, 3)` would be one valid value. 537 | pooling: Optional pooling mode for feature extraction 538 | when `include_top` is `False`. 539 | - `None` means that the output of the model will be 540 | the 4D tensor output of the 541 | last convolutional block. 542 | - `avg` means that global average pooling 543 | will be applied to the output of the 544 | last convolutional block, and thus 545 | the output of the model will be a 2D tensor. 546 | - `max` means that global max pooling will 547 | be applied. 548 | classes: optional number of classes to classify images 549 | into, only to be specified if `include_top` is True, and 550 | if no `weights` argument is specified. 551 | 552 | # Returns 553 | A Keras model instance. 554 | 555 | # Raises 556 | ValueError: in case of invalid argument for `weights`, 557 | or invalid input shape. 
558 | """ 559 | global backend, layers, models, keras_utils 560 | backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs) 561 | 562 | if not (weights in {'imagenet', None} or os.path.exists(weights)): 563 | raise ValueError('The `weights` argument should be either ' 564 | '`None` (random initialization), `imagenet` ' 565 | '(pre-training on ImageNet), ' 566 | 'or the path to the weights file to be loaded.') 567 | 568 | if weights == 'imagenet' and include_top and classes != 1000: 569 | raise ValueError('If using `weights` as `"imagenet"` with `include_top`' 570 | ' as true, `classes` should be 1000') 571 | 572 | # Determine proper input shape 573 | input_shape = _obtain_input_shape(input_shape, 574 | default_size=224, 575 | min_size=32, 576 | data_format=backend.image_data_format(), 577 | require_flatten=include_top, 578 | weights=weights) 579 | 580 | if input_tensor is None: 581 | img_input = layers.Input(shape=input_shape) 582 | else: 583 | if not backend.is_keras_tensor(input_tensor): 584 | img_input = layers.Input(tensor=input_tensor, shape=input_shape) 585 | else: 586 | img_input = input_tensor 587 | if backend.image_data_format() == 'channels_last': 588 | bn_axis = 3 589 | else: 590 | bn_axis = 1 591 | 592 | x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input) 593 | x = layers.Conv2D(64, (7, 7), 594 | strides=(2, 2), 595 | padding='valid', 596 | kernel_initializer='he_normal', 597 | name='conv1')(x) 598 | x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x) 599 | x = layers.Activation('relu')(x) 600 | x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x) 601 | x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x) 602 | 603 | x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1)) 604 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='b') 605 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='c') 606 | 607 | x = conv_block(x, 3, [128, 128, 512], stage=3, block='a') 608 | x = 
identity_block(x, 3, [128, 128, 512], stage=3, block='b') 609 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='c') 610 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='d') 611 | 612 | x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a') 613 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b') 614 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c') 615 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d') 616 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e') 617 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f') 618 | 619 | x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a') 620 | x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b') 621 | x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c') 622 | 623 | if include_top: 624 | x = layers.GlobalAveragePooling2D(name='avg_pool')(x) 625 | x = layers.Dense(classes, activation='softmax', name='fc1000')(x) 626 | else: 627 | if pooling == 'avg': 628 | x = layers.GlobalAveragePooling2D()(x) 629 | elif pooling == 'max': 630 | x = layers.GlobalMaxPooling2D()(x) 631 | else: 632 | warnings.warn('The output shape of `ResNet50(include_top=False)` ' 633 | 'has been changed since Keras 2.2.0.') 634 | 635 | # Ensure that the model takes into account 636 | # any potential predecessors of `input_tensor`. 637 | if input_tensor is not None: 638 | inputs = keras_utils.get_source_inputs(input_tensor) 639 | else: 640 | inputs = img_input 641 | # Create model. 642 | model = models.Model(inputs, x, name='resnet50') 643 | 644 | # Load weights. 
645 | if weights == 'imagenet': 646 | if include_top: 647 | weights_path = keras_utils.get_file( 648 | 'resnet50_weights_tf_dim_ordering_tf_kernels.h5', 649 | WEIGHTS_PATH, 650 | cache_subdir='models', 651 | md5_hash='a7b3fe01876f51b976af0dea6bc144eb', 652 | cache_dir=os.path.join(os.path.dirname(__file__), '..')) 653 | else: 654 | weights_path = keras_utils.get_file( 655 | 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5', 656 | WEIGHTS_PATH_NO_TOP, 657 | cache_subdir='models', 658 | md5_hash='a268eb855778b3df3c7506639542a6af', 659 | cache_dir=os.path.join(os.path.dirname(__file__), '..')) 660 | model.load_weights(weights_path) 661 | if backend.backend() == 'theano': 662 | keras_utils.convert_all_kernels_in_model(model) 663 | elif weights is not None: 664 | model.load_weights(weights) 665 | 666 | return model 667 | 668 | 669 | -------------------------------------------------------------------------------- /models/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/models/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5 -------------------------------------------------------------------------------- /predict_local.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import re 3 | import os 4 | import json 5 | from tools.data_gen import preprocess_img,preprocess_img_from_Url 6 | from models.resnet50 import ResNet50 7 | from keras.layers import Dense,Dropout,BatchNormalization,GlobalAveragePooling2D 8 | from keras.models import Model 9 | import numpy as np 10 | from keras import regularizers 11 | from tensorflow.python.keras.backend import set_session 12 | 13 | # import serial 14 | # OPTIONAL: control usage of GPU 15 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 17 
| config = tf.ConfigProto() 18 | config.gpu_options.allow_growth = True 19 | config.gpu_options.per_process_gpu_memory_fraction = 0.7 20 | sess = tf.Session(config=config) 21 | 22 | # 全局配置文件 23 | tf.app.flags.DEFINE_integer('num_classes', 40, '垃圾分类数目') 24 | tf.app.flags.DEFINE_integer('input_size', 224, '模型输入图片大小') 25 | tf.app.flags.DEFINE_integer('batch_size', 16, '图片批处理大小') 26 | 27 | FLAGS = tf.app.flags.FLAGS 28 | h5_weights_path = './output_model/best.h5' 29 | 30 | # 增加最后输出层 31 | def add_new_last_layer(base_model,num_classes): 32 | x = base_model.output 33 | x = GlobalAveragePooling2D(name='avg_pool')(x) 34 | x = Dropout(0.5,name='dropout1')(x) 35 | # x = Dense(1024,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc1')(x) 36 | # x = BatchNormalization(name='bn_fc_00')(x) 37 | x = Dense(512,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc2')(x) 38 | x = BatchNormalization(name='bn_fc_01')(x) 39 | x = Dropout(0.5,name='dropout2')(x) 40 | x = Dense(num_classes,activation='softmax')(x) 41 | model = Model(inputs=base_model.input,outputs=x) 42 | return model 43 | 44 | 45 | # 加载模型 46 | def model_fn(FLAGS): 47 | # K.set_learning_phase(0) 48 | # setup model 49 | base_model = ResNet50(weights="imagenet", 50 | include_top=False, 51 | pooling=None, 52 | input_shape=(FLAGS.input_size, FLAGS.input_size, 3), 53 | classes=FLAGS.num_classes) 54 | for layer in base_model.layers: 55 | layer.trainable = False 56 | 57 | # if FLAGS.mode == 'train': 58 | # K.set_learning_phase(1) 59 | model = add_new_last_layer(base_model,FLAGS.num_classes) 60 | 61 | # print(model.summary()) 62 | # print(model.layers[84].name) 63 | # exit() 64 | 65 | # Adam = adam(lr=FLAGS.learning_rate,clipnorm=0.001) 66 | model.compile(optimizer="adam",loss = 'categorical_crossentropy',metrics=['accuracy']) 67 | return model 68 | 69 | # 暴露模型初始化 70 | def init_artificial_neural_network(): 71 | set_session(sess) 72 | model = model_fn(FLAGS) 73 | 
model.load_weights(h5_weights_path, by_name=True) 74 | return model 75 | # model = model_fn(FLAGS) 76 | #model.load_weights(h5_weights_path, by_name=True) 77 | #return model 78 | 79 | 80 | # 测试图片 81 | def prediction_result_from_img(model,imgurl): 82 | """ 83 | 84 | :rtype: object 85 | """ 86 | # 加载分类数据 87 | with open("./garbage_classify/garbage_classify_rule.json", 'br') as load_f: 88 | load_dict = json.load(load_f) 89 | if re.match(r'^https?:/{2}\w.+$', imgurl): 90 | test_data = preprocess_img_from_Url(imgurl,FLAGS.input_size) 91 | else: 92 | test_data = preprocess_img(imgurl,FLAGS.input_size) 93 | tta_num = 5 94 | predictions = [0 * tta_num] 95 | for i in range(tta_num): 96 | x_test = test_data[i] 97 | x_test = x_test[np.newaxis, :, :, :] 98 | prediction = model.predict(x_test)[0] 99 | # print(prediction) 100 | predictions += prediction 101 | pred_label = np.argmax(predictions, axis=0) 102 | print('-------深度学习垃圾分类预测结果----------') 103 | print(pred_label) 104 | print(load_dict[str(pred_label)]) 105 | print('-------深度学习垃圾分类预测结果--------') 106 | return load_dict[str(pred_label)] 107 | 108 | 109 | 110 | if __name__ == "__main__": 111 | model = init_artificial_neural_network() 112 | while True: 113 | try: 114 | img_url = input("请输入图片地址:") 115 | print('您输入的图片地址为:' + img_url) 116 | res = prediction_result_from_img(model, img_url) 117 | except Exception as e: 118 | print('发生了异常:', e) 119 | 120 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from flask_sever import sess 2 | from predict_local import prediction_result_from_img, init_artificial_neural_network 3 | 4 | if __name__ == '__main__': 5 | model = init_artificial_neural_network(sess); 6 | while True: 7 | try: 8 | img_url = input("请输入图片地址:") 9 | print('您输入的图片地址为:' + img_url) 10 | res = prediction_result_from_img(model, img_url) 11 | except Exception as e: 12 | print('发生了异常:', e) 13 | 
-------------------------------------------------------------------------------- /static/images/banana.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/static/images/banana.jpg -------------------------------------------------------------------------------- /static/images/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/static/images/test.jpg -------------------------------------------------------------------------------- /static/images/test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/static/images/test1.jpg -------------------------------------------------------------------------------- /templates/predict.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 在线测试神经网络垃圾分类识别接口 6 | 7 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 |

深度学习之垃圾分类 API

17 |
18 |
19 |

20 | 在线测试图片. 21 |

22 |
23 |
24 |
选择文件
25 |
26 |
27 | 28 | 31 |
32 |
33 | 34 |
35 |
36 |

提交

37 |
38 |
39 | 40 |
41 | 42 |
43 | 44 | 45 |
46 |
47 |

http://10.11.2.17:5000/predict

48 |

POST 方式访问,上传文件的表单字段名:file

49 |

例如:file:"XXXXXXX"

50 |
51 |
52 | 53 | 54 | 55 | 56 |
57 | 58 | 59 | 81 | 82 | 85 | 86 | 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /test_image/banana.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/banana.jpg -------------------------------------------------------------------------------- /test_image/box.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/box.jpg -------------------------------------------------------------------------------- /test_image/cigar.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/cigar.jpg -------------------------------------------------------------------------------- /test_image/egg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/egg.jpg -------------------------------------------------------------------------------- /test_image/img_14052.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/img_14052.jpg -------------------------------------------------------------------------------- /test_image/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/test.jpg 
-------------------------------------------------------------------------------- /test_image/test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/test_image/test1.jpg -------------------------------------------------------------------------------- /tools/__pycache__/data_gen.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/tools/__pycache__/data_gen.cpython-36.pyc -------------------------------------------------------------------------------- /tools/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlantBiTree/Classify_localtrain/2bcb8af0b78ee69e370f60e919a1163700ad96b4/tools/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /tools/data_gen.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import math 4 | import codecs 5 | import random 6 | import numpy as np 7 | from glob import glob 8 | from PIL import Image 9 | import requests as req 10 | from io import BytesIO 11 | from keras.utils import np_utils, Sequence 12 | from sklearn.model_selection import train_test_split 13 | from collections import Counter 14 | from models.resnet50 import preprocess_input 15 | from tools.utils import Cutout 16 | from keras.preprocessing.image import ImageDataGenerator 17 | 18 | class BaseSequence(Sequence): 19 | """ 20 | 基础的数据流生成器,每次迭代返回一个batch 21 | BaseSequence可直接用于fit_generator的generator参数 22 | fit_generator会将BaseSequence再次封装为一个多进程的数据流生成器 23 | 而且能保证在多进程下的一个epoch中不会重复取相同的样本 24 | """ 25 | def __init__(self, img_paths, labels, batch_size, 
img_size,is_train): 26 | assert len(img_paths) == len(labels), "len(img_paths) must equal to len(lables)" 27 | assert img_size[0] == img_size[1], "img_size[0] must equal to img_size[1]" 28 | ## (?,41) 29 | self.x_y = np.hstack((np.array(img_paths).reshape(len(img_paths), 1), np.array(labels))) 30 | self.batch_size = batch_size 31 | self.img_size = img_size 32 | self.is_train = is_train 33 | if self.is_train: 34 | train_datagen = ImageDataGenerator( 35 | rotation_range = 30, # 图片随机转动角度 36 | width_shift_range = 0.2, #浮点数,图片宽度的某个比例,数据提升时图片水平偏移的幅度 37 | height_shift_range = 0.2, #浮点数,图片高度的某个比例,数据提升时图片竖直偏移的幅度 38 | shear_range = 0.2, # 剪切强度(逆时针方向的剪切变换角度) 39 | zoom_range = 0.2, # 随机缩放的幅度, 40 | horizontal_flip = True, # 随机水平翻转 41 | vertical_flip = True, # 随机竖直翻转 42 | fill_mode = 'nearest' 43 | ) 44 | self.train_datagen = train_datagen 45 | 46 | def __len__(self): 47 | return math.ceil(len(self.x_y) / self.batch_size) 48 | 49 | 50 | @staticmethod 51 | def center_img(img, size=None, fill_value=255): 52 | """ 53 | center img in a square background 54 | """ 55 | h, w = img.shape[:2] 56 | if size is None: 57 | size = max(h, w) 58 | # h,w,channel 59 | shape = (size, size) + img.shape[2:] 60 | background = np.full(shape, fill_value, np.uint8) 61 | center_x = (size - w) // 2 62 | center_y = (size - h) // 2 63 | background[center_y:center_y + h, center_x:center_x + w] = img 64 | return background 65 | 66 | def preprocess_img(self, img_path): 67 | """ 68 | image preprocessing 69 | you can add your special preprocess method here 70 | """ 71 | img = Image.open(img_path) 72 | img = img.resize((256,256)) 73 | img = img.convert('RGB') 74 | img = np.array(img) 75 | img = img[16:16+224,16:16+224] 76 | return img 77 | 78 | 79 | def cutout_img(self,img): 80 | cut_out = Cutout(n_holes=1,length=40) 81 | img = cut_out(img) 82 | return img 83 | 84 | 85 | def __getitem__(self, idx): 86 | 87 | # 图片路径 88 | batch_x = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 0] 89 | # 图片标签 90 | 
batch_y = self.x_y[idx * self.batch_size: (idx + 1) * self.batch_size, 1:] 91 | # 这里是像素数组 (224,224,3) 92 | 93 | batch_x = np.array([self.preprocess_img(img_path) for img_path in batch_x]) 94 | # smooth labels 95 | batch_y = np.array(batch_y).astype(np.float32)*(1-0.05)+0.05/40 96 | 97 | # # 训练集数据增强 98 | if self.is_train: 99 | indexs = np.random.choice([0,1,2],batch_x.shape[0],replace=True,p=[0.4,0.4,0.2]) 100 | mask_indexs = np.where(indexs==1) 101 | multi_indexs = np.where(indexs==2) 102 | 103 | if len(multi_indexs): 104 | # 数据增强 105 | multipy_batch_x = batch_x[multi_indexs] 106 | multipy_batch_y = batch_y[multi_indexs] 107 | 108 | train_datagenerator = self.train_datagen.flow(multipy_batch_x,multipy_batch_y,batch_size=self.batch_size) 109 | (multipy_batch_x,multipy_batch_y) = train_datagenerator.next() 110 | 111 | batch_x[multi_indexs] = multipy_batch_x 112 | batch_y[multi_indexs] = multipy_batch_y 113 | 114 | if len(mask_indexs[0]): 115 | # 随机遮挡 116 | mask_batch_x = batch_x[mask_indexs] 117 | mask_batch_y = batch_y[mask_indexs] 118 | mask_batch_x = np.array([self.cutout_img(img) for img in mask_batch_x]) 119 | 120 | batch_x[mask_indexs] = mask_batch_x 121 | batch_y[mask_indexs] = mask_batch_y 122 | 123 | 124 | # 预处理 125 | batch_x =np.array([preprocess_input(img) for img in batch_x]) 126 | 127 | # # plt 绘制图像时需要将其换成整型 128 | # for index,label in enumerate(batch_y): 129 | # print(np.argmax(label)) 130 | # plt.subplot(2,8,index+1) 131 | # plt.imshow(batch_x[index].astype(int)) 132 | # plt.show() 133 | # exit() 134 | 135 | return batch_x, batch_y 136 | 137 | def on_epoch_end(self): 138 | """Method called at the end of every epoch. 
139 | """ 140 | np.random.shuffle(self.x_y) 141 | 142 | # 加载训练文件到模型标签 143 | def data_flow(train_data_dir, batch_size, num_classes, input_size): 144 | label_files = glob(os.path.join(train_data_dir, '*.txt')) 145 | random.shuffle(label_files) 146 | img_paths = [] 147 | labels = [] 148 | for index, file_path in enumerate(label_files): 149 | with codecs.open(file_path, 'r', 'utf-8') as f: 150 | line = f.readline() 151 | line_split = line.strip().split(', ') 152 | if len(line_split) != 2: 153 | print('%s contain error lable' % os.path.basename(file_path)) 154 | continue 155 | img_name = line_split[0] 156 | label = int(line_split[1]) 157 | img_paths.append(os.path.join(train_data_dir, img_name)) 158 | labels.append(label) 159 | 160 | labels = np_utils.to_categorical(labels, num_classes) 161 | train_img_paths, validation_img_paths, train_labels, validation_labels = \ 162 | train_test_split(img_paths, labels, stratify=labels,test_size=0.15, random_state=0) 163 | print('total samples: %d, training samples: %d, validation samples: %d' % (len(img_paths), len(train_img_paths), len(validation_img_paths))) 164 | # 训练集随机增强图片 165 | train_sequence = BaseSequence(train_img_paths, train_labels, batch_size, [input_size, input_size],is_train=True) 166 | validation_sequence = BaseSequence(validation_img_paths, validation_labels, batch_size, [input_size, input_size],is_train=False) 167 | return train_sequence,validation_sequence 168 | 169 | 170 | 171 | 172 | # 本地路径获取图片信息 173 | def preprocess_img(img_path,img_size): 174 | try: 175 | img = Image.open(img_path) 176 | # if img.format: 177 | # resize_scale = img_size / max(img.size[:2]) 178 | # img = img.resize((int(img.size[0] * resize_scale), int(img.size[1] * resize_scale))) 179 | img = img.resize((256, 256)) 180 | img = img.convert('RGB') 181 | # img.show() 182 | img = np.array(img) 183 | imgs = [] 184 | for _ in range(10): 185 | i = random.randint(0, 32) 186 | j = random.randint(0, 32) 187 | imgg = img[i:i + 224, j:j + 224] 188 | imgg = 
preprocess_input(imgg) 189 | imgs.append(imgg) 190 | return imgs 191 | except Exception as e: 192 | print('发生了异常data_process:', e) 193 | return 0 194 | 195 | 196 | 197 | 198 | # url获取图片数组信息 199 | def preprocess_img_from_Url(img_path,img_size): 200 | try: 201 | response = req.get(img_path) 202 | img = Image.open(BytesIO(response.content)) 203 | img = img.resize((256, 256)) 204 | img = img.convert('RGB') 205 | # img.show() 206 | img = np.array(img) 207 | imgs = [] 208 | for _ in range(10): 209 | i = random.randint(0, 32) 210 | j = random.randint(0, 32) 211 | imgg = img[i:i + 224, j:j + 224] 212 | imgg = preprocess_input(imgg) 213 | imgs.append(imgg) 214 | return imgs 215 | except Exception as e: 216 | print('发生了异常data_process:', e) 217 | return 0 218 | 219 | 220 | 221 | 222 | # 加载测试数据 223 | def load_test_data(FLAGS): 224 | label_files = glob(os.path.join(FLAGS.test_data_local,"*.txt")) 225 | test_data = [] 226 | img_names = [] 227 | test_labels = [] 228 | for index, file_path in enumerate(label_files): 229 | with codecs.open(file_path,'r','utf-8') as f: 230 | line = f.readline() 231 | line_split = line.strip().split(',') 232 | img_names.append(line_split[0]) 233 | # 处理图片 234 | img_path = os.path.join(FLAGS.test_data_local,line_split[0]) 235 | img = preprocess_img(img_path,FLAGS.input_size) 236 | test_data.append(preprocess_img(img_path,FLAGS.input_size)) 237 | test_labels.append(int(line_split[1])) 238 | print(Counter(test_labels)) 239 | # test_data = np.array(test_data) 240 | return img_names,test_data,test_labels 241 | 242 | 243 | 244 | # if __name__ == '__main__': 245 | # 246 | # train_data_dir = './garbage_classify/train_data/' 247 | # batch_size = 16 248 | # num_classes = 40 249 | # input_size = 224 250 | # # shape= (224,224,3) label=(16,40) 251 | # train_sequence, validation_sequence = data_flow(train_data_dir, batch_size,num_classes,input_size) 252 | # batch_data, bacth_label = train_sequence.__getitem__(5) 253 | # # print(train_sequence.shape) 254 | # 
print(batch_data[0]) -------------------------------------------------------------------------------- /tools/img_gen.py: -------------------------------------------------------------------------------- 1 | import os 2 | from keras.preprocessing.image import ImageDataGenerator,array_to_img,img_to_array,load_img 3 | from glob import glob 4 | import codecs 5 | from PIL import Image 6 | from collections import Counter 7 | from tqdm import tqdm 8 | import numpy as np 9 | import shutil 10 | 11 | def center_img(img,size=None,fill_value=255): 12 | h,w = img.shape[:2] 13 | if size is None: 14 | size = max(h,w) 15 | shape = (size,size) + img.shape[2:] 16 | background = np.full(shape,fill_value,np.uint8) 17 | center_x = (size-w)//2 18 | center_y = (size-h)//2 19 | background[center_y:center_y+h,center_x:center_x+w] = img 20 | return background 21 | 22 | def precess_imge(img_path,img_size): 23 | img = Image.open(img_path) 24 | resize_scale = img_size / max(img.size[:2]) 25 | img = img.resize((int(img.size[0]*resize_scale),int(img.size[1]*resize_scale))) 26 | img = img.convert('RGB') 27 | img = np.array(img) 28 | img = img[:,:,::-1] 29 | img = center_img(img,img_size) 30 | return img 31 | 32 | # 加载图片和标签 33 | def load_dataset(train_data_dir): 34 | label_files = glob(os.path.join(train_data_dir,'*.txt')) 35 | labels = [] 36 | img_data = [] 37 | for file_path in tqdm(label_files): 38 | with codecs.open(file_path,'r','utf-8') as f: 39 | line = f.readline() 40 | line_split = line.strip().split(', ') 41 | if len(line_split) != 2: 42 | print('%s contain error lable' % os.path.basename(file_path)) 43 | continue 44 | label = int(line_split[1]) 45 | img_path = os.path.join(train_data_dir,line_split[0]) 46 | # img = precess_imge(img_path,FLAGS.input_size) 47 | img_data.append(img_path) 48 | labels.append(label) 49 | print(sorted(Counter(labels).items(),key=lambda d:d[0],reverse=False)) 50 | # print(Counter(labels)) 51 | 52 | return img_data,labels 53 | 54 | 55 | # 将图片归类到文件夹中 56 | def 
write_list_to_dir(dstPath,img_paths,labels,is_split = False): 57 | if not os.path.exists(dstPath): 58 | os.makedirs(dstPath) 59 | index = 0 60 | for label in tqdm(labels): 61 | label = str(label) 62 | savedir = os.path.join(dstPath,label) 63 | if not is_split: 64 | savedir = os.path.join(savedir,label) 65 | if not os.path.exists(savedir): 66 | os.makedirs(savedir) 67 | img = Image.open(img_paths[index]) 68 | imgsavePath = os.path.join(savedir,os.path.basename(img_paths[index])) 69 | img.save(imgsavePath) 70 | index += 1 71 | 72 | def write_split_dataset(dstPath,img_paths,labels): 73 | if not os.path.exists(dstPath): 74 | os.makedirs(dstPath) 75 | index = 0 76 | for label in tqdm(labels): 77 | # 写入图片 78 | img_name = os.path.basename(img_paths[index])[:-4] 79 | img = Image.open(img_paths[index]) 80 | imgsavePath = os.path.join(dstPath,os.path.basename(img_paths[index])) 81 | img.save(imgsavePath) 82 | # 写入标签 83 | txt = img_name + '.txt' 84 | with open(os.path.join(dstPath,txt),'w') as f: 85 | f.write(img_name+'.jpg'+', '+str(labels[index])) 86 | index += 1 87 | 88 | 89 | 90 | 91 | # amplify 表示扩充的倍数 class_num表示需要扩充的类别 92 | def increase_img(img_path,class_num,save_path,amplify_ratio=1): 93 | train_datagen = ImageDataGenerator( 94 | rotation_range = 30, # 图片随机转动角度 95 | width_shift_range = 0.2, #浮点数,图片宽度的某个比例,数据提升时图片水平偏移的幅度 96 | height_shift_range = 0.2, #浮点数,图片高度的某个比例,数据提升时图片竖直偏移的幅度 97 | shear_range = 0.2, # 剪切强度(逆时针方向的剪切变换角度) 98 | zoom_range = 0.2, # 随机缩放的幅度, 99 | horizontal_flip = True, # 随机水平翻转 100 | vertical_flip = True, # 随机竖直翻转 101 | fill_mode = 'nearest' 102 | ) 103 | # # 验证集合不用增强 104 | # val_datagen = ImageDataGenerator(rescale=1./255) 105 | # 原图片统计 106 | imgs = os.listdir(os.path.join(img_path,str(class_num))) 107 | print(len(imgs)) 108 | # 清空保存文件夹下的所有文件 109 | shutil.rmtree(save_path) 110 | os.makedirs(save_path) 111 | # 迭代器 112 | train_datagenator = train_datagen.flow_from_directory(img_path,shuffle=True, 113 | 
save_to_dir=save_path,batch_size=1,target_size=(224,224),save_prefix='img',save_format='jpg') 114 | # 生成图片和txt 115 | for i in range(int(len(imgs)*amplify_ratio)): 116 | train_datagenator.next() 117 | img_names = os.listdir(save_path) 118 | img_names_txt_list = [x[:-4]+'.txt' for x in img_names] 119 | for index,txt in enumerate(img_names_txt_list): 120 | with open(os.path.join(save_path,txt),'w') as f: 121 | f.write(img_names[index]+', '+str(class_num)) 122 | 123 | # train_datagenerator = train_datagen.flow_from_directory('./garbage_classify/data_set/train_dir/',target_size=(224,224),classes=classes,batch_size=16,seed=0) 124 | # # print(train_datagenerator.class_indices) 125 | # for data_batch,label_batch in train_datagenerator: 126 | # for index,label in enumerate(label_batch): 127 | # print(np.argmax(label)) 128 | # plt.subplot(1,16,index+1) 129 | # plt.imshow(data_batch[index]) 130 | # plt.show() 131 | # exit() 132 | 133 | # for i in range(5): 134 | # train_datagenator.next() 135 | 136 | 137 | def increase_train_img(train_path,class_num,amplify_ratio): 138 | img_paths,labels = load_dataset(train_path) 139 | img_path = f'./garbage_classify/data_set/{class_num}/' 140 | save_path = './garbage_classify/increase_img/' 141 | increase_img(img_path,class_num,save_path,amplify_ratio) 142 | 143 | 144 | 145 | 146 | 147 | if __name__ == "__main__": 148 | train_data_dir = './garbage_classify/train_data_save/' 149 | # dstPath = './garbage_classify/data_set' 150 | # num_classes = 40 151 | img_paths,labels = load_dataset(train_data_dir) 152 | # 不划分数据集 153 | # write_list_to_dir(dstPath,img_paths,labels) 154 | 155 | # 划分验证集和训练集 156 | # train_img_paths, validation_img_paths, train_labels, validation_labels = train_test_split(img_paths,labels,stratify=labels,test_size=0.25,random_state=0) 157 | # print('total samples: %d, training samples: %d, validation samples: %d' % (len(img_paths), len(train_img_paths), len(validation_img_paths))) 158 | # 
write_split_dataset('./garbage_classify/splitDataset/train',train_img_paths,train_labels) 159 | # write_split_dataset('./garbage_classify/splitDataset/val',validation_img_paths,validation_labels) 160 | # train_dstPath = './garbage_classify/splitDataset/train' 161 | # val_dstPath = './garbage_classify/splitDataset/val' 162 | # write_list_to_dir(train_dstPath,train_img_paths,train_labels,True) 163 | # write_list_to_dir(val_dstPath,validation_img_paths,validation_labels,True) 164 | 165 | 166 | # # 增广数据集 167 | # train_path = './garbage_classify/splitDataset/train' 168 | # class_num = 0 169 | # amplify_ratio = 4 170 | # img_paths,labels = load_dataset(train_path) 171 | # increase_train_img(train_path,class_num,amplify_ratio) 172 | 173 | 174 | 175 | 176 | 177 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | ## Cutout 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | import matplotlib.pyplot as plt 6 | 7 | class Cutout(object): 8 | """Randomly mask out one or more patches from an image. 9 | Args: 10 | n_holes (int): Number of patches to cut out of each image. 11 | length (int): The length (in pixels) of each square patch. 12 | """ 13 | def __init__(self, n_holes, length): 14 | self.n_holes = n_holes 15 | self.length = length 16 | 17 | def __call__(self,img): 18 | h = img.shape[0] 19 | w = img.shape[1] 20 | c = img.shape[2] 21 | 22 | mask = np.ones((h,w,c),np.float32) 23 | 24 | for n in range(self.n_holes): 25 | y = np.random.randint(h) 26 | x = np.random.randint(w) 27 | # 截取函数 遮挡部分不能超过图片的一半 28 | y1 = np.clip(y - self.length//2,0,h) 29 | y2 = np.clip(y+self.length//2,0,h) 30 | x1 = np.clip(x-self.length//2,0,w) 31 | x2 = np.clip(x+self.length//2,0,w) 32 | 33 | mask[y1:y2,x1:x2,:] = 0. 
34 | 35 | # mask = tf.convert_to_tensor(mask) 36 | # mask = tf.reshape(mask,img.shape) 37 | img = img*mask 38 | 39 | return img 40 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | from keras.optimizers import adam 4 | from tools.data_gen import data_flow 5 | from models.resnet50 import ResNet50 6 | from keras.layers import Dense,Dropout,BatchNormalization,GlobalAveragePooling2D 7 | from keras.models import \ 8 | Model 9 | from keras.callbacks import ReduceLROnPlateau,ModelCheckpoint,EarlyStopping 10 | from keras import regularizers 11 | 12 | 13 | 14 | # 设备控制台输出配置 15 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 16 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 17 | config = tf.compat.v1.ConfigProto() 18 | config.gpu_options.allow_growth = True 19 | config.gpu_options.per_process_gpu_memory_fraction = 0.8 20 | sess = tf.compat.v1.Session(config=config) 21 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = '99' 22 | 23 | # 全局配置文件 24 | tf.app.flags.DEFINE_string('test_data_local', './test_data', '测试图片文件夹') 25 | tf.app.flags.DEFINE_string('data_local', './garbage_classify/train_data', '训练图片文件夹') 26 | tf.app.flags.DEFINE_integer('num_classes', 40, '垃圾分类数目') 27 | tf.app.flags.DEFINE_integer('input_size', 224, '模型输入图片大小') 28 | tf.app.flags.DEFINE_integer('batch_size', 16, '图片批处理大小') 29 | tf.app.flags.DEFINE_float('learning_rate',1e-4, '学习率') 30 | tf.app.flags.DEFINE_integer('max_epochs', 4, '轮次') 31 | tf.app.flags.DEFINE_string('train_local', './output_model', '训练输出文件夹') 32 | tf.app.flags.DEFINE_integer('keep_weights_file_num', 20, '如果设置为-1,则文件保持的最大权重数表示无穷大') 33 | FLAGS = tf.app.flags.FLAGS 34 | 35 | ## test_acc = 0.78 36 | def add_new_last_layer(base_model,num_classes): 37 | x = base_model.output 38 | x = GlobalAveragePooling2D(name='avg_pool')(x) 39 | x = Dropout(0.5,name='dropout1')(x) 40 | # x = 
Dense(1024,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc1')(x) 41 | # x = BatchNormalization(name='bn_fc_00')(x) 42 | x = Dense(512,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc2')(x) 43 | x = BatchNormalization(name='bn_fc_01')(x) 44 | x = Dropout(0.5,name='dropout2')(x) 45 | x = Dense(num_classes,activation='softmax')(x) 46 | model = Model(inputs=base_model.input,outputs=x) 47 | return model 48 | 49 | #模型微调 50 | def setup_to_finetune(FLAGS,model,layer_number=149): 51 | # K.set_learning_phase(0) 52 | for layer in model.layers[:layer_number]: 53 | layer.trainable = False 54 | # K.set_learning_phase(1) 55 | for layer in model.layers[layer_number:]: 56 | layer.trainable = True 57 | # Adam = adam(lr=FLAGS.learning_rate,clipnorm=0.001) 58 | Adam = adam(lr=FLAGS.learning_rate,decay=0.0005) 59 | model.compile(optimizer=Adam,loss='categorical_crossentropy',metrics=['accuracy']) 60 | 61 | 62 | #模型初始化设置 63 | def model_fn(FLAGS): 64 | # K.set_learning_phase(0) 65 | # 引入初始化resnet50模型 66 | base_model = ResNet50(weights="imagenet", 67 | include_top=False, 68 | pooling=None, 69 | input_shape=(FLAGS.input_size, FLAGS.input_size, 3), 70 | classes=FLAGS.num_classes) 71 | for layer in base_model.layers: 72 | layer.trainable = False 73 | model = add_new_last_layer(base_model,FLAGS.num_classes) 74 | model.compile(optimizer="adam",loss = 'categorical_crossentropy',metrics=['accuracy']) 75 | return model 76 | def train_model(FLAGS): 77 | # 训练数据构建 78 | train_sequence, validation_sequence = data_flow(FLAGS.data_local, FLAGS.batch_size, 79 | FLAGS.num_classes, FLAGS.input_size) 80 | model = model_fn(FLAGS) 81 | history_tl = model.fit_generator( 82 | train_sequence, 83 | steps_per_epoch = len(train_sequence), 84 | epochs = FLAGS.max_epochs, 85 | verbose = 1, 86 | validation_data = validation_sequence, 87 | max_queue_size = 10, 88 | shuffle=True 89 | ) 90 | #模型微调 91 | setup_to_finetune(FLAGS,model) 92 | history_tl = model.fit_generator( 93 | 
train_sequence, 94 | steps_per_epoch = len(train_sequence), 95 | epochs = FLAGS.max_epochs*5, 96 | verbose = 1, 97 | callbacks = [ 98 | ModelCheckpoint('./output_model/best.h5', 99 | monitor='val_loss', save_best_only=True, mode='min'), 100 | ReduceLROnPlateau(monitor='val_loss', factor=0.1, 101 | patience=10, mode='min'), 102 | EarlyStopping(monitor='val_loss', patience=10), 103 | ], 104 | validation_data = validation_sequence, 105 | max_queue_size = 10, 106 | shuffle=True 107 | ) 108 | print('training done!') 109 | 110 | 111 | if __name__ == "__main__": 112 | train_model(FLAGS) 113 | --------------------------------------------------------------------------------