├── README.md
└── code
    ├── __init__.py
    ├── __pycache__
    │   ├── canshu.cpython-37.pyc
    │   ├── model.cpython-37.pyc
    │   └── yuchuli.cpython-37.pyc
    ├── canshu.py
    ├── gaiming.py
    ├── model.py
    ├── test.py
    ├── xunlian.py
    ├── yuchuli.py
    └── zengjiashuju.py

/README.md:
--------------------------------------------------------------------------------
1 | # VGG19_Fruit_recognition
2 | 
3 | Fruit recognition based on VGG19. Fruit classes: banana, durian, mangosteen, pear, and persimmon. Validation accuracy: 95.51%; test accuracy: 92.00%.
4 | 
5 | ### File description:
6 | 
7 | |- model: stores the model
8 | 
9 | |_ model1.h5: the saved model
10 | 
11 | |- sample: stores the dataset
12 | 
13 | |_ test: test set
14 | 
15 | |_ train: training set (the validation set is split from it in yuchuli.py)
16 | 
17 | |- canshu.py: parameter settings
18 | 
19 | |- gaiming.py: batch-renames the image files
20 | 
21 | |- model.py: model definition
22 | 
23 | |- test.py: testing
24 | 
25 | |- xunlian.py: training
26 | 
27 | |- yuchuli.py: data loading and preprocessing
28 | 
29 | |_ zengjiashuju.py: dataset augmentation (image flipping)
30 | 
31 | ### Dataset:
32 | 
33 | Dataset download link: https://pan.baidu.com/s/1aw_cmqhYXEBp3eprVtY-KA
34 | Extraction code: 9e89
35 | 
36 | ### Model:
37 | 
38 | Model download link: https://pan.baidu.com/s/13Cu6VMv9cE_filsfjmWKsA
39 | Extraction code: fwi5
40 | 
41 | ### Steps:
42 | 
43 | ![image-20220701141631928](C:\Users\64228\AppData\Roaming\Typora\typora-user-images\image-20220701141631928.png)
44 | 
45 | ### Training and test results:
46 | 
47 | ![image-20220701141730589](C:\Users\64228\AppData\Roaming\Typora\typora-user-images\image-20220701141730589.png)
48 | 
49 | ![image-20220701141736236](C:\Users\64228\AppData\Roaming\Typora\typora-user-images\image-20220701141736236.png)
--------------------------------------------------------------------------------
/code/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time :2021/12/4 17:47
3 | # @File :__init__.py
4 | # @Software :PyCharm
5 | # @Project :
6 | # @Content :
7 | 
8 | 
9 | import numpy as np
10 | import pandas as pd
11 | import matplotlib.pyplot as plt
12 | import tensorflow as tf
13 | from tensorflow import keras
14 | from tensorflow.keras import layers, optimizers, datasets
15 | 
--------------------------------------------------------------------------------
/code/__pycache__/canshu.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CZH332287662/VGG19_Fruit_recognition/7ebd419a0f00353174c52d58a5f259095864aa13/code/__pycache__/canshu.cpython-37.pyc
--------------------------------------------------------------------------------
/code/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CZH332287662/VGG19_Fruit_recognition/7ebd419a0f00353174c52d58a5f259095864aa13/code/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/code/__pycache__/yuchuli.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CZH332287662/VGG19_Fruit_recognition/7ebd419a0f00353174c52d58a5f259095864aa13/code/__pycache__/yuchuli.cpython-37.pyc
--------------------------------------------------------------------------------
/code/canshu.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time :2021/12/6 16:39
3 | # @File :canshu.py
4 | # @Software :PyCharm
5 | # @Project :参数信息
6 | # @Content :可变参数信息
7 | 
8 | 
9 | """路径参数"""
10 | train_img_dir = "./sample/train"   #训练集路径
11 | test_img_dir = "./sample/test"   #测试集路径
12 | model_save_dir = "./model"   #模型保存路径
13 | 
14 | 
15 | """种类参数"""
16 | labels = 
['Pear','Banana','Persimmon','Mangosteen','Durian']#, 'Carambola','Apple', 'Mango']#, 'Pear', 'Pomegranate'] #种类标签 17 | num = len(labels) #种类数 18 | 19 | 20 | """图片参数""" 21 | image_width = 64 #图片统一宽度 22 | image_height = 64 #图片统一高度 23 | 24 | 25 | """训练参数""" 26 | train_ratio = 0.9 #训练集比例 27 | train_batch_size = 128 #训练集批次大小 28 | val_batch_size = 128 #验证集批次大小 29 | epochs = 60 #训练次数 30 | lr = 0.0001 #学习率 0.000005 31 | 32 | 33 | 34 | 35 | # """路径参数""" 36 | # train_img_dir = "./sample/train" #训练集路径 37 | # test_img_dir = "./sample/test" #测试集路径 38 | # model_save_dir = "./model" #模型保存路径 39 | # 40 | # 41 | # """种类参数""" 42 | # labels = ['Apple', 'Banana', 'Carambola', 'Kiwi', 'Mango']#, 'Orange', 'Pear','Persimmon', 'Plum', 'Pomegranate'] #种类标签 43 | # num = len(labels) #种类数 44 | # 45 | # 46 | # """图片参数""" 47 | # image_width = 64 #图片统一宽度 48 | # image_height = 64 #图片统一高度 49 | # 50 | # 51 | # """训练参数""" 52 | # train_ratio = 0.9 #训练集比例 53 | # train_batch_size = 128 #训练集批次大小 54 | # val_batch_size = 128 #验证集批次大小 55 | # epochs = 250 #训练次数 56 | # lr = 0.00005 #学习率 0.000002 -------------------------------------------------------------------------------- /code/gaiming.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time :2021/12/4 21:46 3 | # @File :gaiming.py 4 | # @Software :PyCharm 5 | # @Project :修改图片文件名称 6 | # @Content : 7 | 8 | import os 9 | import canshu 10 | 11 | 12 | path = canshu.train_img_dir #文件路径 13 | labels = canshu.labels #种类名称 14 | 15 | 16 | 17 | for label in labels: 18 | j=0 19 | file_list = os.listdir(path + "/"+ label) # 获取指定目录下的所有图片文件名 20 | #print(file_list)m 21 | for file_name in file_list: #逐个获取图片文件名 22 | old_path = path + "/" + label + "/" + file_name #获取当前文件名路径 23 | #print(old_path) 24 | new_path = path + "/" + label + "/" + label + "_" +str(j) + ".jpg" #新文件名及路径 25 | print(new_path) 26 | j += 1 27 | os.rename(old_path,new_path) #修改文件名 28 | 29 | -------------------------------------------------------------------------------- /code/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time :2021/12/4 17:47 3 | # @File :model.py 4 | # @Software :PyCharm 5 | # @Project :模型 6 | # @Content :模型函数 7 | 8 | import tensorflow as tf 9 | from tensorflow import keras 10 | from tensorflow.keras import layers 11 | 12 | #模型构建 13 | def model_CNN(num): 14 | # model = keras.models.Sequential([ 15 | # 16 | # layers.Conv2D(32, kernel_size=(3,3), strides=(1,1), activation='relu',padding="same"), 17 | # layers.BatchNormalization(), 18 | # layers.MaxPool2D(2,2), 19 | # 20 | # layers.Conv2D(64, kernel_size=(3,3), strides=(1, 1), activation='relu', padding="same"), 21 | # layers.BatchNormalization(), 22 | # layers.MaxPool2D(2,2), 23 | # 24 | # layers.Conv2D(128, kernel_size=(3, 3), strides=(1, 1), activation='relu', padding="same"), 25 | # layers.BatchNormalization(), 26 | # layers.MaxPool2D(2,2), 27 | # 28 | # layers.Conv2D(128, kernel_size=(3, 3), strides=(1, 1), activation='relu', padding="same"), 29 | # layers.BatchNormalization(), 30 | # layers.MaxPool2D(2,2), 31 | # 32 | # layers.Conv2D(256, kernel_size=(3, 3), strides=(1, 1), activation='relu', padding="same"), 33 | # layers.BatchNormalization(), 34 | # layers.MaxPool2D(2, 2), 35 | # 36 | # layers.Flatten(), 37 | # layers.Dense(128, activation='relu'), 38 | # layers.Dropout(0.5), 39 | # 40 | # layers.Dense(64, activation='relu'), 41 | # layers.Dropout(0.5), 42 | # 43 | # layers.Dense(num, activation='softmax') 44 | 
# ]) 45 | # return model 46 | 47 | #VGG13 48 | # model = keras.Sequential([ 49 | # layers.Conv2D(64, 3, strides=(1, 1), padding="same", activation="relu"), 50 | # layers.Conv2D(64, 3, strides=(1, 1), padding="same", activation="relu"), 51 | # layers.MaxPool2D(2, 2), 52 | # layers.Conv2D(128, 3, strides=(1, 1), padding="same", activation="relu"), 53 | # layers.Conv2D(128, 3, strides=(1, 1), padding="same", activation="relu"), 54 | # layers.MaxPool2D(2, 2), 55 | # layers.Conv2D(256, 3, strides=(1, 1), padding="same", activation="relu"), 56 | # layers.Conv2D(256, 3, strides=(1, 1), padding="same", activation="relu"), 57 | # layers.MaxPool2D(2, 2), 58 | # layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 59 | # layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 60 | # layers.MaxPool2D(2, 2), 61 | # layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 62 | # layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 63 | # layers.MaxPool2D(2, 2), 64 | # layers.Flatten(), 65 | # layers.Dense(256, activation="relu"), 66 | # layers.Dense(128, activation="relu"), 67 | # layers.Dense(num, activation="softmax") 68 | # ]) 69 | # return model 70 | 71 | 72 | # # VGG19 73 | # model = keras.Sequential([ 74 | # layers.Conv2D(64, 3, strides=(1, 1), padding="same", activation="relu"), 75 | # layers.Conv2D(64, 3, strides=(1, 1), padding="same", activation="relu"), 76 | # layers.MaxPool2D(2, 2), 77 | # layers.Conv2D(128, 3, strides=(1, 1), padding="same", activation="relu"), 78 | # layers.Conv2D(128, 3, strides=(1, 1), padding="same", activation="relu"), 79 | # layers.MaxPool2D(2, 2), 80 | # layers.Conv2D(256, 3, strides=(1, 1), padding="same", activation="relu"), 81 | # layers.Conv2D(256, 3, strides=(1, 1), padding="same", activation="relu"), 82 | # layers.Conv2D(256, 3, strides=(1, 1), padding="same", activation="relu"), 83 | # layers.Conv2D(256, 3, strides=(1, 1), padding="same", activation="relu"), 84 | # layers.MaxPool2D(2, 2), 85 | # layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 86 | # layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 87 | # layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 88 | # layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 89 | # layers.MaxPool2D(2, 2), 90 | # layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 91 | # layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 92 | # layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 93 | # layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 94 | # layers.MaxPool2D(2, 2), 95 | # layers.Flatten(), 96 | # layers.Dense(256, activation="relu"),#256 97 | # layers.Dense(128, activation="relu"), 98 | # layers.Dense(num, activation="softmax") 99 | # ]) 100 | # return model 101 | 102 | # VGG19 103 | model = keras.Sequential([ 104 | layers.Conv2D(64, 3, strides=(1, 1), padding="same", activation="relu"), 105 | layers.Conv2D(64, 3, strides=(1, 1), padding="same", activation="relu"), 106 | layers.MaxPool2D(2, 2), 107 | layers.Conv2D(128, 3, strides=(1, 1), padding="same", activation="relu"), 108 | layers.Conv2D(128, 3, strides=(1, 1), padding="same", activation="relu"), 109 | layers.MaxPool2D(2, 2), 110 | layers.Conv2D(256, 3, strides=(1, 1), padding="same", activation="relu"), 111 | layers.Conv2D(256, 3, strides=(1, 1), padding="same", activation="relu"), 112 | layers.Conv2D(256, 3, 
strides=(1, 1), padding="same", activation="relu"), 113 | layers.Conv2D(256, 3, strides=(1, 1), padding="same", activation="relu"), 114 | layers.MaxPool2D(2, 2), 115 | layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 116 | layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 117 | layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 118 | layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 119 | layers.MaxPool2D(2, 2), 120 | layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 121 | layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 122 | layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 123 | layers.Conv2D(512, 3, strides=(1, 1), padding="same", activation="relu"), 124 | layers.MaxPool2D(2, 2), 125 | layers.Flatten(), 126 | layers.Dense(256, activation="relu"), # 256 127 | layers.Dense(128, activation="relu"), 128 | layers.Dense(num, activation="softmax") 129 | ]) 130 | return model 131 | -------------------------------------------------------------------------------- /code/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time :2021/12/6 19:46 3 | # @File :test.py 4 | # @Software :PyCharm 5 | # @Project : 6 | # @Content : 7 | 8 | 9 | import os 10 | import numpy as np 11 | import pathlib 12 | import canshu 13 | import tensorflow as tf 14 | 15 | 16 | #测试集数据获取并预处理 17 | class TestDataSet: 18 | def __init__(self): 19 | self.num = canshu.num #种类数 20 | 21 | self.image_paths = self.read_path(canshu.test_img_dir) #获取所有图片文件路径(含文件名) 22 | print("测试集大小:{}".format(len(self.image_paths))) 23 | 24 | 25 | #数据测试打包并预处理 26 | def build(self): 27 | """ 28 | :return: 预处理后的测试集 29 | """ 30 | #测试集 31 | img_test_list = self.preprocess_img_all(self.image_paths) #加载图片并预处理 32 | #print(img_test_list) 33 | test_image_ds = tf.data.Dataset.from_tensor_slices(img_test_list) #测试集数据打包 34 | test_all_images_labels = [self.text2vec(pathlib.Path(path).name.split("_")[0]) for path in self.image_paths] #获取独热编码 35 | #print(test_all_images_labels) 36 | test_label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(test_all_images_labels, tf.int32)) #测试集标签打包 37 | test_set = tf.data.Dataset.zip((test_image_ds, test_label_ds)) #合并数据和标签 38 | test_set = test_set.batch(25, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE) 39 | return test_set 40 | 41 | #所有图片预处理 42 | def preprocess_img_all(self, paths): 43 | """ 44 | :param paths: 照片路径(含文件名) 45 | :return: 灰度后的图片矩阵 46 | """ 47 | img_list = [] #存储图片矩阵 48 | for path in paths: #逐个图片获取 49 | #print(path) 50 | image = tf.io.read_file(path) #读取图片文件 51 | image = tf.image.decode_png(image, channels=3) #解码图像所需的彩色通道数目,输出RGB图像 52 | image = tf.image.resize(image, [canshu.image_height,canshu.image_width]) #调整图片大小 53 | image = 2 * tf.cast(image, dtype=tf.float32) / 255. 
- 1 #归一化 54 | r, g, b = image[:, :, 0], image[:, :, 1], image[:, :, 2] #获取各个通道像素值 55 | image = 0.2989 * r + 0.5870 * g + 0.1140 * b #灰度处理 56 | image = tf.expand_dims(image, axis=2) #增加维度,灰度后维度降低 57 | img_list.append(image) 58 | return img_list 59 | 60 | 61 | #标签转为独热编码 62 | def text2vec(self,text): 63 | """ 64 | :param text: 标签值 65 | :return: 独热编码 66 | """ 67 | text_num = canshu.labels.index(text) #获取标签列表的下标索引作为标签的唯一数字标识 68 | vector = np.zeros(self.num) #初始化独热编码矩阵 69 | vector[text_num] = 1 #对应位置置一 70 | #print(vector) 71 | return vector 72 | 73 | 74 | #获取所有图片文件路径(含文件名) 75 | def read_path(self,paths): 76 | """ 77 | :param paths: 文件夹路径 78 | :return: 所有图片文件 79 | """ 80 | file_all = [] #存储图片文件路径 81 | for file_name in os.listdir(paths): #逐个图片获取 82 | full_path = os.path.abspath(os.path.join(paths, file_name)) #合并路径 83 | if file_name.endswith('.jpg') or file_name.endswith('.bmp') or file_name.endswith('.png'): #文件后缀是否为'.jpg'、'.bmp'、'.png'(是否为图片文件) 84 | file_all.append(full_path) 85 | #print(file_all) 86 | return file_all 87 | 88 | if __name__ == "__main__": 89 | test_data = TestDataSet().build() 90 | model = tf.keras.models.load_model("./model/model1.h5") 91 | out = model.evaluate(test_data) 92 | print(out) -------------------------------------------------------------------------------- /code/xunlian.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time :2021/12/4 22:11 3 | # @File :xunlian.py 4 | # @Software :PyCharm 5 | # @Project :训练 6 | # @Content :对模型进行训练、保存以及训练期间的数据进行可视化展示 7 | 8 | 9 | 10 | import matplotlib.pyplot as plt 11 | import tensorflow as tf 12 | from tensorflow.keras import optimizers,losses 13 | from model import model_CNN 14 | from yuchuli import ImageDataSet 15 | import canshu 16 | import matplotlib 17 | matplotlib.rcParams["font.family"]="SimHei" 18 | matplotlib.rcParams["font.sans-serif"] = "SimHei" 19 | 20 | 21 | #训练 22 | def Train(): 23 | train_data,val_data = ImageDataSet().build() #数据获取并预处理 24 | 25 | #初始化模型 26 | model_save_dir = canshu.model_save_dir #训练后的模型保存地址 27 | model = model_CNN(canshu.num) #模型初始化 28 | 29 | #模型装配 30 | model.compile( 31 | optimizer=optimizers.Adam(lr=canshu.lr), #自适应估计随机梯度下降法 32 | loss=losses.categorical_crossentropy, #y值为独热编码形式的交叉熵损失函数 33 | metrics=['accuracy'] #以精确度作为衡量指标 34 | ) 35 | history = model.fit(train_data,epochs=canshu.epochs,validation_data=val_data) #喂数据 36 | model.save(model_save_dir+"/model1.h5") #模型保存 37 | 38 | #绘制曲线 39 | plt.figure(figsize=(10,8)) 40 | plt.subplot(221) 41 | plt.plot(history.history["loss"],color="r") 42 | plt.xlabel("训练次数(训练集)") 43 | plt.ylabel("损失") 44 | plt.title("训练次数-损失曲线(训练集)") 45 | plt.subplot(222) 46 | plt.plot(history.history["accuracy"],color="r") 47 | plt.ylim(0,1) 48 | plt.xlabel("训练次数(训练集)") 49 | plt.ylabel("精确度") 50 | plt.title("训练次数-精确度曲线(训练集)") 51 | plt.subplot(223) 52 | plt.plot(history.history["val_loss"], color="r") 53 | plt.xlabel("训练次数(验证集)") 54 | plt.ylabel("损失") 55 | plt.title("训练次数-损失曲线(验证集)") 56 | plt.subplot(224) 57 | plt.plot(history.history["val_accuracy"], color="r") 58 | plt.ylim(0,1) 59 | plt.xlabel("训练次数(验证集)") 60 | plt.ylabel("精确度") 61 | plt.title("训练次数-精确度曲线(验证集)") 62 | plt.show() 63 | 64 | 65 | if __name__ == '__main__': 66 | #使用gpu训练 67 | gpu_ok = tf.test.is_gpu_available() 68 | print("use GPU", gpu_ok) 69 | 70 | # 训练 71 | Train() -------------------------------------------------------------------------------- /code/yuchuli.py: -------------------------------------------------------------------------------- 1 | # 
-*- coding: utf-8 -*- 2 | # @Time :2021/12/4 18:19 3 | # @File :yuchuli.py 4 | # @Software :PyCharm 5 | # @Project :数据加载和预处理 6 | # @Content : 7 | """通过read_path方法读取文件夹中的所有子文件夹的图片文件路径,对路径进行打乱,随后划分训练集和验证集, 8 | 通过preprocess_img_all方法进行图片的加载以及预处理(调整图片大小、归一化、灰度并增加维度), 9 | 通过text2vec方法将图片标签进行独热编码转换,随后将图片矩阵数据和独热编码标签打包并分批""" 10 | 11 | import os 12 | import numpy as np 13 | import tensorflow as tf 14 | import pathlib 15 | import random 16 | import canshu 17 | 18 | #数据获取并预处理 19 | class ImageDataSet: 20 | def __init__(self): 21 | self.num = canshu.num #种类数 22 | self.train_batch_size = canshu.train_batch_size #训练批次大小 23 | self.val_batch_size = canshu.val_batch_size #验证批次大小 24 | 25 | data_root = pathlib.Path(canshu.train_img_dir) #创建训练图片路径对象 26 | image_paths = list(data_root.glob("*")) #遍历匹配所有(*)目录并存放到列表中 27 | #print("path:{}".format(image_paths)) 28 | image_paths = self.read_path(image_paths) #获取所有图片文件路径(含文件名) 29 | print(image_paths,type(image_paths)) 30 | random.shuffle(image_paths) #随机打乱 31 | 32 | #划分训练集验证集 33 | self.train_image_paths = image_paths[:int(canshu.train_ratio * len(image_paths))] #训练集 34 | self.val_image_paths = image_paths[int(canshu.train_ratio * len(image_paths)):] #验证集 35 | print("训练集大小:{}\t验证集大小:{}".format(len(self.train_image_paths),len(self.val_image_paths))) 36 | 37 | 38 | #数据训练前处理并打包 39 | def build(self): 40 | """ 41 | :return: 预处理后的训练集和验证集 42 | """ 43 | #训练集 44 | img_train_list = self.preprocess_img_all(self.train_image_paths) #加载图片并预处理 45 | #print(img_train_list) 46 | train_image_ds = tf.data.Dataset.from_tensor_slices(img_train_list) #训练集数据打包 47 | train_all_images_labels = [self.text2vec(pathlib.Path(path).name.split("_")[0]) for path in self.train_image_paths] #获取独热编码 48 | #print(train_all_images_labels) 49 | train_label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(train_all_images_labels, tf.int32)) #训练集标签打包 50 | train_set = tf.data.Dataset.zip((train_image_ds, train_label_ds)) #合并数据和标签 51 | train_set = train_set.batch(self.train_batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE) #训练前处理操作 52 | 53 | #验证集 54 | img_val_list = self.preprocess_img_all(self.val_image_paths) 55 | val_image_ds = tf.data.Dataset.from_tensor_slices(img_val_list) 56 | val_all_images_labels = [self.text2vec(pathlib.Path(path).name.split("_")[0]) for path in self.val_image_paths] 57 | val_label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(val_all_images_labels, tf.int32)) 58 | val_set = tf.data.Dataset.zip((val_image_ds, val_label_ds)) 59 | val_set = val_set.batch(self.val_batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE) 60 | return train_set, val_set 61 | 62 | 63 | #所有图片预处理 64 | def preprocess_img_all(self, paths): 65 | """ 66 | :param paths: 照片路径(含文件名) 67 | :return: 灰度后的图片矩阵 68 | """ 69 | img_list = [] #存储图片矩阵 70 | for path in paths: #逐个图片获取 71 | #print(path) 72 | image = tf.io.read_file(path) #读取图片文件 73 | image = tf.image.decode_png(image, channels=3) #解码图像所需的彩色通道数目,输出RGB图像 74 | image = tf.image.resize(image, [canshu.image_height,canshu.image_width]) #调整图片大小 75 | image = 2 * tf.cast(image, dtype=tf.float32) / 255. 
- 1 #归一化
76 |             r, g, b = image[:, :, 0], image[:, :, 1], image[:, :, 2] #获取各个通道像素值
77 |             image = 0.2989 * r + 0.5870 * g + 0.1140 * b #灰度处理
78 |             image = tf.expand_dims(image, axis=2) #增加维度,灰度后维度降低
79 |             img_list.append(image)
80 |         return img_list
81 | 
82 | 
83 |     #标签转为独热编码
84 |     def text2vec(self,text):
85 |         """
86 |         :param text: 标签值
87 |         :return: 独热编码
88 |         """
89 |         text_num = canshu.labels.index(text) #获取标签列表的下标索引作为标签的唯一数字标识
90 |         vector = np.zeros(self.num) #初始化独热编码矩阵
91 |         vector[text_num] = 1 #对应位置置一
92 |         #print(vector)
93 |         return vector
94 | 
95 | 
96 |     #获取所有图片文件路径(含文件名)
97 |     def read_path(self,paths):
98 |         """
99 |         :param paths: 父文件夹路径
100 |         :return: 所有子孙图片文件
101 |         """
102 |         file_all = [] #存储图片文件路径
103 |         for i in range(len(paths)): #逐个路径获取
104 |             for file_name in os.listdir(paths[i]): #逐个获取指定目录下的所有子目录和文件名
105 |                 full_path = os.path.abspath(os.path.join(paths[i],file_name))#合并路径
106 |                 if os.path.isdir(full_path): #判断路径对象是否为文件夹
107 |                     #如果是文件夹,递归调用获取其子文件
108 |                     files = self.read_path([full_path]) #递归(方法需通过self调用,且参数为路径列表)
109 |                     file_all.extend(files) #追加子文件夹中的图片文件列表
110 |                 else: #路径对象为文件
111 |                     if file_name.endswith('.jpg') or file_name.endswith('.bmp') or file_name.endswith('.png'): #文件后缀是否为'.jpg'、'.bmp'、'.png'(是否为图片文件)
112 |                         file_all.append(full_path)
113 |         #print(file_all)
114 |         return file_all
115 | 
116 | if __name__ == "__main__":
117 |     ImageDataSet().build()
--------------------------------------------------------------------------------
/code/zengjiashuju.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time :2021/12/7 17:10
3 | # @File :zengjiashuju.py
4 | # @Software :PyCharm
5 | # @Project :增加数据集(图片翻转)
6 | # @Content :
7 | 
8 | import os
9 | import cv2
10 | import canshu
11 | 
12 | 
13 | 
14 | #水平垂直翻转
15 | def fanzhuan(file_list,file_num):
16 |     for file_name in file_list: #逐个获取图片文件名
17 |         file_path = path + "/" + label + "/" + file_name #获取当前文件名路径
18 |         img1 = cv2.imread(file_path) #读取图片
19 |         img2 = cv2.flip(img1, 0) #垂直翻转
20 |         img3 = cv2.flip(img1, 1) #水平翻转
21 |         cv2.imshow("yuantu", img1)
22 |         cv2.imshow("chuizhi", img2)
23 |         cv2.imshow("shuiping", img3)
24 |         cv2.waitKey(10)
25 |         new_path_chuizhi = path + "/" + label + "/" + label + "_" + str(file_num) + ".jpg" #垂直新图片文件名和路径
26 |         new_path_shuiping = path + "/" + label + "/" + label + "_" + str(file_num+1) + ".jpg" #水平新图片文件名和路径
27 |         cv2.imwrite(new_path_chuizhi, img2)
28 |         cv2.imwrite(new_path_shuiping,img3)
29 |         file_num += 2
30 | 
31 | 
32 | if __name__ == "__main__":
33 |     path = canshu.train_img_dir # 文件夹路径
34 |     label = input("文件夹名称:")
35 |     file_list = os.listdir(path + "/" + label) # 获取指定目录下的所有图片文件名
36 |     file_num = len(file_list) # 文件个数
37 | 
38 | 
39 |     fanzhuan(file_list,file_num) #翻转
40 |     cv2.destroyAllWindows()
41 | 
--------------------------------------------------------------------------------
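
The Sequential network in model.py is defined without an explicit input layer, so its layer shapes are only fixed the first time data is passed through it. The short sketch below is not a repository file; it is a minimal illustration, assuming it is run from the code/ directory, of how the VGG19 architecture can be built for the input this project actually feeds it (a 64x64 grayscale image, shape (64, 64, 1) after preprocessing) and inspected:

# Minimal inspection sketch (not a repository file); assumes the working directory is code/.
from model import model_CNN
import canshu

model = model_CNN(canshu.num)                                                 # 5 output classes
model.build(input_shape=(None, canshu.image_height, canshu.image_width, 1))  # batch x 64 x 64 x 1 (grayscale)
model.summary()                                                               # per-layer output shapes and parameter counts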
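
xunlian.py trains the model and test.py evaluates it on the whole test set, but there is no script for classifying a single image. The sketch below shows how that could look, assuming the saved ./model/model1.h5 and the label order in canshu.py; it repeats the preprocessing used in yuchuli.py and test.py (resize to 64x64, scale to [-1, 1], convert to grayscale, add a channel axis). The script name and the example image path are assumptions, not files that exist in the repository.

# yuce.py - hypothetical single-image prediction sketch (not a repository file).
import numpy as np
import tensorflow as tf
import canshu

def preprocess(path):
    image = tf.io.read_file(path)                                              # read the image file
    image = tf.image.decode_png(image, channels=3)                             # decode to RGB (same call as yuchuli.py)
    image = tf.image.resize(image, [canshu.image_height, canshu.image_width])  # resize to 64x64
    image = 2 * tf.cast(image, dtype=tf.float32) / 255. - 1                    # scale to [-1, 1]
    r, g, b = image[:, :, 0], image[:, :, 1], image[:, :, 2]
    image = 0.2989 * r + 0.5870 * g + 0.1140 * b                               # grayscale
    image = tf.expand_dims(image, axis=2)                                      # (64, 64, 1)
    return tf.expand_dims(image, axis=0)                                       # add batch axis -> (1, 64, 64, 1)

if __name__ == "__main__":
    model = tf.keras.models.load_model("./model/model1.h5")                    # model saved by xunlian.py
    probs = model.predict(preprocess("./sample/test/Banana_0.jpg"))            # example path, assumed
    print(canshu.labels[int(np.argmax(probs[0]))])                             # predicted class name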
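
test.py reports only the overall loss and accuracy from model.evaluate. If a per-class breakdown of the 92.00% test accuracy is wanted, something along the lines of the following sketch could be used; it reuses the TestDataSet pipeline from test.py and is, again, an assumed addition rather than code from the repository.

# Hypothetical per-class accuracy sketch (not a repository file); run from the code/ directory.
import numpy as np
import tensorflow as tf
import canshu
from test import TestDataSet

test_data = TestDataSet().build()                                  # batched (image, one-hot label) pairs
model = tf.keras.models.load_model("./model/model1.h5")
correct = np.zeros(canshu.num)
total = np.zeros(canshu.num)
for images, labels in test_data:
    preds = np.argmax(model.predict(images, verbose=0), axis=1)    # predicted class indices
    truth = np.argmax(labels.numpy(), axis=1)                      # true class indices from one-hot labels
    for t, p in zip(truth, preds):
        total[t] += 1
        correct[t] += int(t == p)
for name, c, n in zip(canshu.labels, correct, total):
    print(f"{name}: {c / max(n, 1):.2%} ({int(c)}/{int(n)})")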