├── Figures
│   ├── plot.png
│   ├── valset.png
│   ├── train+val.png
│   ├── trainset.png
│   ├── graph_run=.png
│   ├── inception-v3.png
│   └── Screenshot from 2018-10-07 09-48-43.png
├── Code
│   ├── plot.py
│   └── plant_disease.py
└── README.md

/Figures/plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xungeer29/AI-Challenger-Plant-Disease-Recognition/HEAD/Figures/plot.png
--------------------------------------------------------------------------------
/Figures/valset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xungeer29/AI-Challenger-Plant-Disease-Recognition/HEAD/Figures/valset.png
--------------------------------------------------------------------------------
/Figures/train+val.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xungeer29/AI-Challenger-Plant-Disease-Recognition/HEAD/Figures/train+val.png
--------------------------------------------------------------------------------
/Figures/trainset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xungeer29/AI-Challenger-Plant-Disease-Recognition/HEAD/Figures/trainset.png
--------------------------------------------------------------------------------
/Figures/graph_run=.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xungeer29/AI-Challenger-Plant-Disease-Recognition/HEAD/Figures/graph_run=.png
--------------------------------------------------------------------------------
/Figures/inception-v3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xungeer29/AI-Challenger-Plant-Disease-Recognition/HEAD/Figures/inception-v3.png
--------------------------------------------------------------------------------
/Figures/Screenshot from 2018-10-07 09-48-43.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xungeer29/AI-Challenger-Plant-Disease-Recognition/HEAD/Figures/Screenshot from 2018-10-07 09-48-43.png
--------------------------------------------------------------------------------
/Code/plot.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-

import json
from matplotlib import pyplot as plt
import numpy as np

train_label_dir = '/media/gfx/GFX/DATASET/Plant_Disease_Recognition/ai_challenger_pdr2018_train_annotations_20181021.json'
val_label_dir = '/media/gfx/GFX/DATASET/Plant_Disease_Recognition/ai_challenger_pdr2018_validation_annotations_20181021.json'
save_dir = './'

classes = 61

# training set
with open(train_label_dir, 'r') as f_train:
    image_label_list_train = json.load(f_train)

num_per_label_train = np.zeros(classes, dtype=np.int32)
labels_train = []
for index in range(len(image_label_list_train)):
    image_label_train = image_label_list_train[index]
    label_train = image_label_train['disease_class']
    num_per_label_train[int(label_train)] += 1
    labels_train.append(int(label_train))

print('trainset:{}'.format(num_per_label_train))
plt.hist(labels_train, bins=classes)
plt.xlabel('label')
plt.ylabel('num_per_label')
plt.title('Plant-Disease-Recognition-Trainset')
plt.show()
# validation set
with open(val_label_dir, 'r') as f_val:
    image_label_list_val = json.load(f_val)

num_per_label_val = np.zeros(classes, dtype=np.int32)
labels_val = []
for index in range(len(image_label_list_val)):
    image_label_val = image_label_list_val[index]
    label_val = image_label_val['disease_class']
    num_per_label_val[int(label_val)] += 1
    labels_val.append(int(label_val))

print('validation set:{}'.format(num_per_label_val))
plt.hist(labels_val, bins=classes)
plt.xlabel('label')
plt.ylabel('num_per_label')
plt.title('Plant-Disease-Recognition-Validation Set')
plt.show()

# train set + val set
image_label_list = image_label_list_train + image_label_list_val
num_per_label = np.zeros(classes, dtype=np.int32)
labels = []
for index in range(len(image_label_list)):
    image_label_dict = image_label_list[index]
    label = image_label_dict['disease_class']
    num_per_label[int(label)] += 1
    labels.append(int(label))

print('trainset + validation set:{}'.format(num_per_label))
plt.hist(labels, bins=classes)
plt.xlabel('label')
plt.ylabel('num_per_label')
plt.title('Plant-Disease-Recognition-Trainset and Validation set')
plt.show()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AI-Challenger-Plant-Disease-Recognition
## Crop Disease Detection
See the accompanying [CSDN post](https://blog.csdn.net/qq_40859461/article/details/84199358#commentsedit) for details.

## Environment
python==2.7

tensorflow==1.2.1

## Usage
* Edit the paths in plot.py and run it to plot histograms of the label distribution.
* Download the pre-trained [Inception-V3](https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip) model.
* Edit the input, output, and pre-trained model paths in plant_disease.py.
* Run `python plant_disease.py` inside the Code directory.
* After training finishes, the trained parameters are used to predict the testA set, and a JSON file ready for direct submission is generated.

## Open-source solutions from other participants
* [spytensor/plants_disease_detection](https://github.com/spytensor/plants_disease_detection)
  * Framework: pytorch
  * Final score: 0.875

* [bochuanwu/Agricultural-Disease-Classification](https://github.com/bochuanwu/Agricultural-Disease-Classification)
  * Framework: keras
  * Final score: 0.88658
## Label list
| Label ID | Label Name |
|----------|--------------|
| 0 | apple healthy(苹果健康) |
| 1 | Apple_Scab general(苹果黑星病一般) |
| 2 | Apple_Scab serious(苹果黑星病严重) |
| 3 | Apple Frogeye Spot(苹果灰斑病) |
| 4 | Cedar Apple Rust general(苹果雪松锈病一般) |
| 5 | Cedar Apple Rust serious(苹果雪松锈病严重) |
| 6 | Cherry healthy(樱桃健康) |
| 7 | Cherry_Powdery Mildew general(樱桃白粉病一般) |
| 8 | Cherry_Powdery Mildew serious(樱桃白粉病严重) |
| 9 | Corn healthy(玉米健康) |
| 10 | Cercospora zeaemaydis Tehon and Daniels general(玉米灰斑病一般) |
| 11 | Cercospora zeaemaydis Tehon and Daniels serious(玉米灰斑病严重) |
| 12 | Puccinia polysora general(玉米锈病一般) |
| 13 | Puccinia polysora serious(玉米锈病严重) |
| 14 | Corn Curvularia leaf spot fungus general(玉米叶斑病一般) |
| 15 | Corn Curvularia leaf spot fungus serious(玉米叶斑病严重) |
| 16 | Maize dwarf mosaic virus(玉米花叶病毒病) |
| 17 | Grape healthy(葡萄健康) |
| 18 | Grape Black Rot Fungus general(葡萄黑腐病一般) |
| 19 | Grape Black Rot Fungus serious(葡萄黑腐病严重) |
| 20 | Grape Black Measles Fungus general(葡萄轮斑病一般) |
| 21 | Grape Black Measles Fungus serious(葡萄轮斑病严重) |
| 22 | Grape Leaf Blight Fungus general(葡萄褐斑病一般) |
| 23 | Grape Leaf Blight Fungus serious(葡萄褐斑病严重) |
| 24 | Citrus healthy(柑桔健康) |
| 25 | Citrus Greening June general(柑桔黄龙病一般) |
| 26 | Citrus Greening June serious(柑桔黄龙病严重) |
| 27 | Peach healthy(桃健康) |
| 28 | Peach_Bacterial Spot general(桃疮痂病一般) |
| 29 | Peach_Bacterial Spot serious(桃疮痂病严重) |
| 30 | Pepper healthy(辣椒健康) |
| 31 | Pepper scab general(辣椒疮痂病一般) |
| 32 | Pepper scab serious(辣椒疮痂病严重) |
| 33 | Potato healthy(马铃薯健康) |
| 34 | Potato_Early Blight Fungus general(马铃薯早疫病一般) |
| 35 | Potato_Early Blight Fungus serious(马铃薯早疫病严重) |
| 36 | Potato_Late Blight Fungus general(马铃薯晚疫病一般) |
| 37 | Potato_Late Blight Fungus serious(马铃薯晚疫病严重) |
| 38 | Strawberry healthy(草莓健康) |
| 39 | Strawberry_Scorch general(草莓叶枯病一般) |
| 40 | Strawberry_Scorch serious(草莓叶枯病严重) |
| 41 | Tomato healthy(番茄健康) |
| 42 | Tomato powdery mildew general(番茄白粉病一般) |
| 43 | Tomato powdery mildew serious(番茄白粉病严重) |
| 44 | Tomato Bacterial Spot Bacteria general(番茄疮痂病一般) |
| 45 | Tomato Bacterial Spot Bacteria serious(番茄疮痂病严重) |
| 46 | Tomato_Early Blight Fungus general(番茄早疫病一般) |
| 47 | Tomato_Early Blight Fungus serious(番茄早疫病严重) |
| 48 | Tomato_Late Blight Water Mold general(番茄晚疫病菌一般) |
| 49 | Tomato_Late Blight Water Mold serious(番茄晚疫病菌严重) |
| 50 | Tomato_Leaf Mold Fungus general(番茄叶霉病一般) |
| 51 | Tomato_Leaf Mold Fungus serious(番茄叶霉病严重) |
| 52 | Tomato Target Spot Bacteria general(番茄斑点病一般) |
| 53 | Tomato Target Spot Bacteria serious(番茄斑点病严重) |
| 54 | Tomato_Septoria Leaf Spot Fungus general(番茄斑枯病一般) |
| 55 | Tomato_Septoria Leaf Spot Fungus serious(番茄斑枯病严重) |
| 56 | Tomato Spider Mite Damage general(番茄红蜘蛛损伤一般) |
| 57 | Tomato Spider Mite Damage serious(番茄红蜘蛛损伤严重) |
| 58 | Tomato YLCV Virus general(番茄黄化曲叶病毒病一般) |
| 59 | Tomato YLCV Virus serious(番茄黄化曲叶病毒病严重) |
| 60 | Tomato ToMV(番茄花叶病毒病) |


### Submission format
Results must be saved as a JSON file. For every photo it contains the image id and the predicted class id, in the following format:
```json
[
    {
        "image_id": "72e0dfb8d1460203d90ce46bdc0a0fb7b84a665a.jpg",
        "disease_class":1
    },
    ...
]
```
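
A minimal sketch of producing such a file, assuming a hypothetical `predictions` dict that maps image file names to predicted class ids (plant_disease.py writes its own `test_result.json` in this format):

```python
# -*- coding:utf-8 -*-
import json

# hypothetical predictions: image file name -> predicted disease_class id
predictions = {
    '72e0dfb8d1460203d90ce46bdc0a0fb7b84a665a.jpg': 1,
}

# convert to the list-of-dicts layout expected by the evaluation server
result = [{'image_id': name, 'disease_class': int(label)}
          for name, label in predictions.items()]

with open('submission.json', 'w') as f:
    json.dump(result, f)
```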
--------------------------------------------------------------------------------
/Code/plant_disease.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-

import json
import tensorflow as tf
import os
import numpy as np
import random
from tensorflow.python.platform import gfile

# 32739 images
TRAIN_IMAGES_DIR = '/data2/gaofuxun/data/Plant_Disease_Recognition/\
ai_challenger_pdr2018_trainingset_20180905/AgriculturalDisease_trai\
ningset/images/'

# 4982 images
VAL_IMAGES_DIR = '/data2/gaofuxun/data/Plant_Disease_Recognition/\
ai_challenger_pdr2018_validationset_20180905/AgriculturalDisease_\
validationset/images/'

# 4959 images
TEST_IMAGES_DIR = '/data2/gaofuxun/data/Plant_Disease_Recognition/\
AgriculturalDisease_testA/images/'

JSON_TRAIN = '/data2/gaofuxun/data/Plant_Disease_Recognition/\
ai_challenger_pdr2018_trainingset_20180905/AgriculturalDisease_trai\
ningset/AgriculturalDisease_train_annotations.json'

JSON_VAL = '/data2/gaofuxun/data/Plant_Disease_Recognition/\
ai_challenger_pdr2018_validationset_20180905/AgriculturalDisease_\
validationset/AgriculturalDisease_validation_annotations.json'

log_dir = './log/'
model_dir = './model/'
bottleneck_path = '/data2/gaofuxun/data/Plant_Disease_Recognition/\
bottleneck/'
bottleneck_train_dir = 'train/'
bottleneck_val_dir = 'val/'
bottleneck_test_dir = 'test/'

if not os.path.exists(log_dir):
    os.makedirs(log_dir)
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
if not os.path.exists(bottleneck_path+bottleneck_train_dir):
    os.makedirs(bottleneck_path+bottleneck_train_dir)
if not os.path.exists(bottleneck_path+bottleneck_val_dir):
    os.makedirs(bottleneck_path+bottleneck_val_dir)
if not os.path.exists(bottleneck_path+bottleneck_test_dir):
    os.makedirs(bottleneck_path+bottleneck_test_dir)


BATCH_SIZE = 64  # use a power of two
# IMAGE_SIZE = 128  # not needed: the Inception-v3 graph handles the input size itself
LEARNING_RATE = 0.01
# learning-rate decay factor
DECAY_RATE = 0.9
# number of steps between decays
DECAY_STEPS = 100
STEPS = 140000
CLASSES = 61

BOTTLENECK_TENSOR_SIZE = 2048
# name of the bottleneck tensor
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'  # [1, 1, 1, 2048]
# name of the image input tensor
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'

# path of the frozen Inception-v3 graph
INCEPTION_DIR = './inception_v3/tensorflow_inception_graph.pb'
"""
Run one image through Inception-v3 and return its feature map.
"""
def run_bottleneck_on_image(sess, image, image_data_tensor, bottleneck_tensor):
    bottleneck_values = sess.run(bottleneck_tensor,
                                 {image_data_tensor: image})
    print('bottleneck shape: {}'.format(bottleneck_values.shape))
    # squeeze the 4-D tensor into a 1-D feature vector
    bottleneck_values = np.squeeze(bottleneck_values)

    return bottleneck_values

"""
Get the feature vector of an image after the Inception-v3 bottleneck.
If it has not been cached yet, compute it and save it first.
input:
    --image_name: image file name
    --image_path: image directory
    --category: train / val / test
    --jpeg_data_tensor: input tensor of the frozen Inception-v3 graph
return:
    --bottleneck_values: feature vector produced by the Inception-v3 bottleneck
"""
def get_or_create_bottleneck(sess, image_name, image_path, category,
                             jpeg_data_tensor, bottleneck_tensor):
    if not image_name+'.txt' in os.listdir(bottleneck_path+category+'/'):
        image_dir = image_path + image_name
        image = gfile.FastGFile(image_dir, 'rb').read()
        image = tf.image.decode_jpeg(image)
        print(image.eval())
        # apply data augmentation
        image = data_argument(image)
        # visualize the augmented image; tf.summary.image expects a 4-D batch
        # and the keyword is max_outputs (max_images belongs to the old API)
        tf.summary.image('data_argument', tf.expand_dims(image, 0), max_outputs=9)
        # JPEG_DATA_TENSOR_NAME ('DecodeJpeg/contents:0') expects raw JPEG bytes,
        # so the augmented tensor is evaluated and re-encoded before the forward pass
        image = sess.run(tf.image.encode_jpeg(image))
        bottleneck_values = run_bottleneck_on_image(sess, image,
                                                    jpeg_data_tensor, bottleneck_tensor)

        # write the feature vector to a txt file
        bottleneck_string = ','.join(str(x) for x in bottleneck_values)
        with open(bottleneck_path+category+'/'+image_name+'.txt', 'w') as f:
            f.write(bottleneck_string)
    else:
        try:
            # read the cached feature vector directly from file
            # bottleneck_txt = bottleneck_path+category+'/'+image_name+'.txt'
            # bottleneck_txt = bottleneck_txt.decode('unicode-escape')
            with open(bottleneck_path+category+'/'+image_name+'.txt', 'r') as f:
                bottleneck_string = f.read()
            bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
        except:
            print('Failed to read feature file, re-computing: {}'.format(image_name))
            os.remove(bottleneck_path+category+'/'+image_name+'.txt')
            bottleneck_values = get_or_create_bottleneck(sess, image_name, image_path,
                                                         category, jpeg_data_tensor,
                                                         bottleneck_tensor)

    return bottleneck_values


"""
Randomly sample one batch of data.
Input:
    --json_dir: path of the annotation json file
    --image_path: image directory, depends on the train / val / test mode
    --category: train / val / test
    --jpeg_data_tensor: input layer of the Inception graph
    --bottleneck_tensor: bottleneck output layer of the Inception graph
return:
    --bottlenecks: bottleneck feature vectors of one batch of images
    --groundtruths: ground-truth labels of that batch
"""
def get_batch_images(sess, json_dir, image_path, category, jpeg_data_tensor,
                     bottleneck_tensor):
    with open(json_dir, 'r') as f:
        image_label_list = json.load(f)

    bottlenecks = []
    groundtruths = []
    for _ in range(BATCH_SIZE):
        # randomly pick one image/label dict
        index = random.randrange(len(image_label_list))
        image_label_dict = image_label_list[index]

        # feature vector of the image after Inception
        image_name = image_label_dict['image_id']
        image_name = image_name.encode('utf-8')
        bottleneck_values = get_or_create_bottleneck(sess, image_name,
            image_path, category, jpeg_data_tensor, bottleneck_tensor)
        bottlenecks.append(bottleneck_values)

        # one-hot encode the label
        label = image_label_dict['disease_class']
        groundtruth = np.zeros(CLASSES, dtype=np.float32)
        groundtruth[int(label)] = 1.0
        groundtruths.append(groundtruth)

    return bottlenecks, groundtruths

"""
Data augmentation.
"""
def data_argument(image):
    try:
        # random crop to 80%-90% of the original height/width
        rate = np.random.randint(8, 10)
        image_size = tf.shape(image).eval()
        image = tf.random_crop(image, [int(image_size[0]*rate*0.1),
                                       int(image_size[1]*rate*0.1), 3])
    except:
        print(tf.shape(image).eval())
    # random horizontal flip
    image = tf.image.random_flip_left_right(image)
    # random vertical flip (the op is random_flip_up_down, not random_up_down)
    image = tf.image.random_flip_up_down(image)
    # rotate by 90 degrees a random number of times
    image = tf.image.rot90(image, np.random.randint(1, 4))
    # zero mean / unit variance: per_image_whitening is now
    # tf.image.per_image_standardization; it is left disabled here because the
    # frozen Inception graph normalizes its own input and a standardized float
    # image could not be re-encoded as JPEG in get_or_create_bottleneck()
    # image = tf.image.per_image_standardization(image)

    return image
"""
Load the whole validation set to compute accuracy.
"""
def get_val_bottlenecks(sess, json_dir, image_path, jpeg_data_tensor,
                        bottleneck_tensor):
    with open(json_dir, 'r') as f:
        image_label_list = json.load(f)

    bottlenecks = []
    groundtruths = []
    category = 'val'
    for index in range(len(image_label_list)):
        image_label_dict = image_label_list[index]

        # feature vector of the image after Inception
        image_name = image_label_dict['image_id']
        image_name = image_name.encode('utf-8')
        bottleneck_values = get_or_create_bottleneck(sess, image_name,
            image_path, category, jpeg_data_tensor, bottleneck_tensor)
        bottlenecks.append(bottleneck_values)

        # one-hot encode the label
        label = image_label_dict['disease_class']
        groundtruth = np.zeros(CLASSES, dtype=np.float32)
        groundtruth[int(label)] = 1.0
        groundtruths.append(groundtruth)

    return bottlenecks, groundtruths

def get_test_bottlenecks(sess, image_path, jpeg_data_tensor, bottleneck_tensor):
    bottlenecks = []
    images_name = []
    category = 'test'
    for image_name in os.listdir(image_path):
        bottleneck_values = get_or_create_bottleneck(sess, image_name,
            image_path, category, jpeg_data_tensor, bottleneck_tensor)
        bottlenecks.append(bottleneck_values)
        images_name.append(image_name)

    return bottlenecks, images_name

"""
fine-tuning:
a few hand-built fully connected layers to improve classification.
"""
def model(bottleneck_input):
    with tf.name_scope('fc1') as scope:
        weights = tf.Variable(tf.truncated_normal(
            [BOTTLENECK_TENSOR_SIZE, 1024], stddev=0.001))
        biases = tf.get_variable(name='biases1', shape=[1024],
                                 initializer=tf.constant_initializer(0.1))
        fc1 = tf.nn.relu(tf.matmul(bottleneck_input, weights) + biases, name='fc1')
        tf.summary.scalar('fc1'+'/sparsity', tf.nn.zero_fraction(fc1))
        tf.summary.histogram('fc1'+'/activations', fc1)

    with tf.name_scope('fc2') as scope:
        weights = tf.Variable(tf.truncated_normal([1024, 512], stddev=0.001))
        biases = tf.get_variable(name='biases2', shape=[512],
                                 initializer=tf.constant_initializer(0.1))
        fc2 = tf.nn.relu(tf.matmul(fc1, weights)+biases, name='fc2')
        tf.summary.scalar('fc2'+'/sparsity', tf.nn.zero_fraction(fc2))
        tf.summary.histogram('fc2'+'/activations', fc2)

    with tf.name_scope('softmax_linear') as scope:
        weights = tf.Variable(tf.truncated_normal([512, CLASSES], stddev=1/512.0))
        biases = tf.get_variable(name='biases3', shape=[CLASSES],
                                 initializer=tf.constant_initializer(0.0))
        logits = tf.matmul(fc2, weights)+biases
        final_tensor = tf.nn.softmax(logits)
        # tf.summary.scalar('softmax_linear'+'/sparsity', final_tensor)
        tf.summary.histogram('softmax_linear'+'/activations', final_tensor)

    return logits, final_tensor

"""
AM-Softmax
input:
    --embeddings: L2-normalized feature vectors from the previous layer
    --label_onehot: one-hot ground truth for a batch
    --args: size of the previous layer
    --nrof_classes: number of classes
return:
    --adjust_theta: margin-adjusted, scaled logits
"""
def AM_logits_compute(embeddings, label_onehot, args, nrof_classes):
    m = 0.35
    s = 30
    with tf.name_scope('AM_logits'):
        kernel = tf.get_variable(name='kernel', dtype=tf.float32,
                                 shape=[args, nrof_classes],
                                 initializer=tf.contrib.layers.xavier_initializer(uniform=False))
        kernel_norm = tf.nn.l2_normalize(kernel, 0, 1e-10, name='kernel_norm')
        cos_theta = tf.matmul(embeddings, kernel_norm)
        cos_theta = tf.clip_by_value(cos_theta, -1, 1)  # for numerical stability
        phi = cos_theta - m
        # label_onehot = tf.one_hot(label_batch, nrof_classes)
        adjust_theta = s * tf.where(tf.equal(label_onehot, 1), phi, cos_theta)

    return adjust_theta
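
# NOTE: the adjusted logits returned above are meant to be fed into the standard
# softmax cross-entropy, so for a sample whose ground-truth class is y the loss is
#   L = -log( exp(s*(cos(theta_y) - m)) /
#             (exp(s*(cos(theta_y) - m)) + sum_{j != y} exp(s*cos(theta_j))) )
# where cos(theta_j) is the cosine between the L2-normalized embedding and the
# normalized j-th column of `kernel`, m = 0.35 is the additive margin and
# s = 30 is the scale.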


def main(_):
    # load the frozen Inception-v3 graph
    with gfile.FastGFile(INCEPTION_DIR, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # fetch the image input tensor JPEG_DATA_TENSOR_NAME and the bottleneck
    # tensor BOTTLENECK_TENSOR_NAME
    bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(graph_def,
        return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

    # define the network inputs
    bottleneck_input = tf.placeholder(tf.float32,
        [None, BOTTLENECK_TENSOR_SIZE], name='BottleneckInputPlaceholder')
    groundtruth_input = tf.placeholder(tf.float32,
        [None, CLASSES], name='GroundTruthPlaceholder')

    # accuracy got stuck at 5%-10% with this deeper head, so it is disabled
    # logits, final_tensor = model(bottleneck_input)

    # a single fully connected layer as the classifier
    with tf.name_scope('fc1'):
        weights = tf.Variable(tf.truncated_normal(
            [BOTTLENECK_TENSOR_SIZE, CLASSES], stddev=0.001))
        biases = tf.Variable(tf.zeros([CLASSES]))
        logits = tf.matmul(bottleneck_input, weights) + biases
        tf.summary.histogram('fc1'+'/pre_activation', logits)
        final_tensor = tf.nn.softmax(logits)
        tf.summary.histogram('fc1'+'/activation', final_tensor)

    # AM-Softmax in place of plain softmax (currently disabled)
    # embeddings = tf.nn.l2_normalize(bottleneck_input, 1, 1e-10, name='embeddings')
    # AM_logits = AM_logits_compute(embeddings, groundtruth_input,
    #                               BOTTLENECK_TENSOR_SIZE, CLASSES)

    # cross-entropy loss
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        logits=logits, labels=groundtruth_input,
        name='cross_entropy_per_example')
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    tf.summary.scalar('cross_entropy', cross_entropy_mean)
    # exponentially decaying learning rate
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(learning_rate=LEARNING_RATE,
                                    global_step=global_step,
                                    decay_steps=DECAY_STEPS,
                                    decay_rate=DECAY_RATE,
                                    staircase=False)
    tf.summary.scalar('learning_rate', lr)
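    # With staircase=False the decayed rate follows
    #   lr = LEARNING_RATE * DECAY_RATE ** (global_step / DECAY_STEPS)
    # e.g. 0.01 * 0.9 ** (1000 / 100) ~= 0.0035 after 1000 training steps.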

    # optimizer; global_step is passed so that it is incremented every step and
    # the exponential decay above actually takes effect
    train_step = tf.train.GradientDescentOptimizer(lr).minimize(
        cross_entropy_mean, global_step=global_step)

    # compute accuracy
    with tf.name_scope('evaluation'):
        correct_prediction = tf.equal(tf.argmax(logits, 1),  # final_tensor
                                      tf.argmax(groundtruth_input, 1))
        evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('evaluation', evaluation_step)

    merged = tf.summary.merge_all()
    saver = tf.train.Saver(max_to_keep=1)
    max_acc = 0.7
    with tf.Session() as sess:
        writer = tf.summary.FileWriter(log_dir, tf.get_default_graph())
        tf.global_variables_initializer().run()

        for step in range(STEPS):
            # fetch one batch of training data
            train_bottlenecks, train_groundtruth = get_batch_images(
                sess, JSON_TRAIN, TRAIN_IMAGES_DIR, 'train',
                jpeg_data_tensor, bottleneck_tensor)
            _, summary = sess.run([train_step, merged],
                feed_dict={bottleneck_input: train_bottlenecks,
                           groundtruth_input: train_groundtruth})

            # every 100 steps, report accuracy on one randomly sampled
            # training batch
            if step%100 == 0 or step + 1 == STEPS:
                cross_val_bottlenecks, cross_val_groundtruth = get_batch_images(
                    sess, JSON_TRAIN, TRAIN_IMAGES_DIR, 'train',
                    jpeg_data_tensor, bottleneck_tensor)
                cross_val_accuracy = sess.run(evaluation_step,
                    feed_dict={bottleneck_input: cross_val_bottlenecks,
                               groundtruth_input: cross_val_groundtruth})
                print('Step %d: Accuracy on random sampled %d examples = %.3f%%'
                      % (step, BATCH_SIZE, cross_val_accuracy*100))

                # keep only the most accurate model so far
                if cross_val_accuracy > max_acc:
                    max_acc = cross_val_accuracy
                    saver.save(sess, './model/plant_disease.ckpt', global_step=step)

                writer.add_summary(summary, step)

            if step%1000 == 0 or step + 1 == STEPS:
                # evaluate accuracy on the full validation set
                val_bottlenecks, val_groundtruth = get_val_bottlenecks(sess,
                    JSON_VAL, VAL_IMAGES_DIR, jpeg_data_tensor, bottleneck_tensor)
                val_accuracy = sess.run(evaluation_step, feed_dict={
                    bottleneck_input: val_bottlenecks, groundtruth_input: val_groundtruth})
                print('Validation accuracy = %.3f%%' % (val_accuracy * 100))

        # predict the test set and save the result in the submittable json format
        test_bottlenecks, test_images = get_test_bottlenecks(sess,
            TEST_IMAGES_DIR, jpeg_data_tensor, bottleneck_tensor)
        predict_test = sess.run(logits,  # final_tensor
            feed_dict={bottleneck_input: test_bottlenecks})
        predict_test = tf.argmax(predict_test, 1)

        # predict_test = np.squeeze(predict_test)
        # bottleneck_string = ','.join(str(x) for x in predict_test)
        # print(bottleneck_string)

        result = []
        for index in range(len(test_images)):
            single = {}
            single["disease_class"] = int(predict_test[index].eval())
            single["image_id"] = test_images[index]
            result.append(single)
        # print(result)
        # with open('./test_result.txt', 'w') as f:
        #     f.write(result)
        # write to json
        with open('./test_result.json', 'w') as f:
            f.write(json.dumps(result))

        writer.close()

if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------