├── 4b8bb1117771485bb723bc39ca5b84c4.png ├── 4c1f113fbd2e4f84b7f9fa57f2e13114.png ├── 558a4c1d66c3482fad5cff9d55bfa3f9.png ├── 5641a46afa874bd29027e60b1bdd3773.png ├── 5eb0c2fd5907425f9df795218dc5090b.png ├── 7a1bc762d9e5484fbc3e70d7a5a7e901.png ├── a603d8d63e4346128b70ddd85954c364.png ├── a9883c5b261e41eea4d0b9cd139d9cc0.png ├── bad338d7a8634d448caa840008a62e82.png ├── cbdba346194e42deb692ce53e0814b48.png ├── d459c7c632db4c28af196afd9512f20c.png ├── eded326eb3a54f6da1be4bed3790858c.png ├── _data.py ├── model.py ├── predict.py ├── ConfusionMatrix.py ├── ui.py ├── fit.py └── README.md /4b8bb1117771485bb723bc39ca5b84c4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-SE-VGG16-BN-131-Fruits-Veggies-Classification-Color-Variety-Grading/HEAD/4b8bb1117771485bb723bc39ca5b84c4.png -------------------------------------------------------------------------------- /4c1f113fbd2e4f84b7f9fa57f2e13114.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-SE-VGG16-BN-131-Fruits-Veggies-Classification-Color-Variety-Grading/HEAD/4c1f113fbd2e4f84b7f9fa57f2e13114.png -------------------------------------------------------------------------------- /558a4c1d66c3482fad5cff9d55bfa3f9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-SE-VGG16-BN-131-Fruits-Veggies-Classification-Color-Variety-Grading/HEAD/558a4c1d66c3482fad5cff9d55bfa3f9.png -------------------------------------------------------------------------------- /5641a46afa874bd29027e60b1bdd3773.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-SE-VGG16-BN-131-Fruits-Veggies-Classification-Color-Variety-Grading/HEAD/5641a46afa874bd29027e60b1bdd3773.png -------------------------------------------------------------------------------- /5eb0c2fd5907425f9df795218dc5090b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-SE-VGG16-BN-131-Fruits-Veggies-Classification-Color-Variety-Grading/HEAD/5eb0c2fd5907425f9df795218dc5090b.png -------------------------------------------------------------------------------- /7a1bc762d9e5484fbc3e70d7a5a7e901.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-SE-VGG16-BN-131-Fruits-Veggies-Classification-Color-Variety-Grading/HEAD/7a1bc762d9e5484fbc3e70d7a5a7e901.png -------------------------------------------------------------------------------- /a603d8d63e4346128b70ddd85954c364.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-SE-VGG16-BN-131-Fruits-Veggies-Classification-Color-Variety-Grading/HEAD/a603d8d63e4346128b70ddd85954c364.png -------------------------------------------------------------------------------- /a9883c5b261e41eea4d0b9cd139d9cc0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-SE-VGG16-BN-131-Fruits-Veggies-Classification-Color-Variety-Grading/HEAD/a9883c5b261e41eea4d0b9cd139d9cc0.png -------------------------------------------------------------------------------- /bad338d7a8634d448caa840008a62e82.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-SE-VGG16-BN-131-Fruits-Veggies-Classification-Color-Variety-Grading/HEAD/bad338d7a8634d448caa840008a62e82.png -------------------------------------------------------------------------------- /cbdba346194e42deb692ce53e0814b48.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-SE-VGG16-BN-131-Fruits-Veggies-Classification-Color-Variety-Grading/HEAD/cbdba346194e42deb692ce53e0814b48.png -------------------------------------------------------------------------------- /d459c7c632db4c28af196afd9512f20c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-SE-VGG16-BN-131-Fruits-Veggies-Classification-Color-Variety-Grading/HEAD/d459c7c632db4c28af196afd9512f20c.png -------------------------------------------------------------------------------- /eded326eb3a54f6da1be4bed3790858c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qunshansj/Improved-SE-VGG16-BN-131-Fruits-Veggies-Classification-Color-Variety-Grading/HEAD/eded326eb3a54f6da1be4bed3790858c.png -------------------------------------------------------------------------------- /_data.py: -------------------------------------------------------------------------------- 1 | 2 | -----data 3 | |-----train 4 | | |-----class1 5 | | |-----class2 6 | | |-----... 7 | | 8 | |-----val 9 | | |-----class1 10 | | |-----class2 11 | | |-----... 12 | 13 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | python 2 | 3 | 4 | class VGG(nn.Module): 5 | def __init__(self, features, num_classes=1000, init_weights=False): 6 | super(VGG, self).__init__() 7 | self.features = features 8 | self.classifier = nn.Sequential( 9 | nn.Linear(512*7*7, 4096), 10 | nn.ReLU(True), 11 | nn.Dropout(p=0.5), 12 | nn.Linear(4096, 4096), 13 | nn.ReLU(True), 14 | nn.Dropout(p=0.5), 15 | nn.Linear(4096, num_classes) 16 | ) 17 | if init_weights: 18 | self._initialize_weights() 19 | 20 | def forward(self, x): 21 | x = self.features(x) 22 | x = torch.flatten(x, start_dim=1) 23 | x = self.classifier(x) 24 | return x 25 | 26 | def _initialize_weights(self): 27 | for m in self.modules(): 28 | if isinstance(m, nn.Conv2d): 29 | nn.init.xavier_uniform_(m.weight) 30 | if m.bias is not None: 31 | nn.init.constant_(m.bias, 0) 32 | elif isinstance(m, nn.Linear): 33 | nn.init.xavier_uniform_(m.weight) 34 | nn.init.constant_(m.bias, 0) 35 | 36 | 37 | def make_features(cfg: list): 38 | layers = [] 39 | in_channels = 3 40 | for v in cfg: 41 | if v == "M": 42 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 43 | else: 44 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 45 | layers += [conv2d, nn.ReLU(True)] 46 | in_channels = v 47 | return nn.Sequential(*layers) 48 | 49 | 50 | cfgs = { 51 | 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 52 | 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 53 | 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 54 | 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 55 | } 56 | 57 | 58 | def vgg(model_name="vgg16", **kwargs): 59 | assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name) 60 | cfg = cfgs[model_name] 61 | 62 | model = VGG(make_features(cfg), **kwargs) 63 | return model 64 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | python 2 | import os 3 | import json 4 | 5 | import torch 6 | from PIL import Image 7 | from torchvision import transforms 8 | import matplotlib.pyplot as plt 9 | 10 | from model import vgg 11 | 12 | class ImageClassifier: 13 | def __init__(self, model_name, num_classes, weights_path, json_path): 14 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | self.data_transform = transforms.Compose( 16 | [transforms.Resize((224, 224)), 17 | transforms.ToTensor(), 18 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 19 | self.model = vgg(model_name=model_name, num_classes=num_classes).to(self.device) 20 | self.weights_path = weights_path 21 | self.json_path = json_path 22 | 23 | def load_image(self, img_path): 24 | assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) 25 | img = Image.open(img_path) 26 | plt.imshow(img) 27 | img = self.data_transform(img) 28 | img = torch.unsqueeze(img, dim=0) 29 | return img 30 | 31 | def load_class_indict(self): 32 | assert os.path.exists(self.json_path), "file: '{}' dose not exist.".format(self.json_path) 33 | with open(self.json_path, "r") as f: 34 | class_indict = json.load(f) 35 | return class_indict 36 | 37 | def load_model_weights(self): 38 | assert os.path.exists(self.weights_path), "file: '{}' dose not exist.".format(self.weights_path) 39 | self.model.load_state_dict(torch.load(self.weights_path, map_location=self.device)) 40 | 41 | def predict(self, img): 42 | self.model.eval() 43 | with torch.no_grad(): 44 | output = torch.squeeze(self.model(img.to(self.device))).cpu() 45 | predict = torch.softmax(output, dim=0) 46 | predict_cla = torch.argmax(predict).numpy() 47 | return predict_cla, predict 48 | 49 | def show_result(self, class_indict, predict_cla, predict): 50 | print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], 51 | predict[predict_cla].numpy()) 52 | plt.title(print_res) 53 | for i in range(len(predict)): 54 | print("class: {:10} prob: {:.3}".format(class_indict[str(i)], 55 | predict[i].numpy())) 56 | plt.show() 57 | 58 | def classify_image(self, img_path): 59 | img = self.load_image(img_path) 60 | class_indict = self.load_class_indict() 61 | self.load_model_weights() 62 | predict_cla, predict = self.predict(img) 63 | self.show_result(class_indict, predict_cla, predict) 64 | 65 | 66 | -------------------------------------------------------------------------------- /ConfusionMatrix.py: -------------------------------------------------------------------------------- 1 | python 2 | 3 | 4 | class ConfusionMatrix(object): 5 | """ 6 | 注意,如果显示的图像不全,是matplotlib版本问题 7 | 本例程使用matplotlib-3.2.1(windows and ubuntu)绘制正常 8 | 需要额外安装prettytable库 9 | """ 10 | def __init__(self, num_classes: int, labels: list): 11 | self.matrix = np.zeros((num_classes, num_classes)) 12 | self.num_classes = num_classes 13 | self.labels = labels 14 | 15 | def update(self, preds, labels): 16 | for p, t in zip(preds, labels): 17 | self.matrix[p, t] += 1 18 | 19 | def summary(self): 20 | # calculate accuracy 21 | sum_TP = 0 22 | for i in range(self.num_classes): 23 | sum_TP += self.matrix[i, i] 24 | acc = sum_TP / np.sum(self.matrix) 25 | print("the model accuracy is ", acc) 26 | 27 | # precision, recall, specificity 28 | table = PrettyTable() 29 | table.field_names = ["", "Precision", "Recall", "Specificity"] 30 | for i in range(self.num_classes): 31 | TP = self.matrix[i, i] 32 | FP = np.sum(self.matrix[i, :]) - TP 33 | FN = np.sum(self.matrix[:, i]) - TP 34 | TN = np.sum(self.matrix) - TP - FP - FN 35 | Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0. 36 | Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0. 37 | Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0. 38 | table.add_row([self.labels[i], Precision, Recall, Specificity]) 39 | print(table) 40 | 41 | def plot(self): 42 | matrix = self.matrix 43 | print(matrix) 44 | plt.figure(figsize=(40, 40), dpi=100) # 设置画布的大小和dpi,为了使图片更加清晰 45 | plt.imshow(matrix, cmap=plt.cm.Blues) 46 | 47 | 48 | # 设置x轴坐标label 49 | plt.xticks(range(self.num_classes), self.labels, rotation=45) 50 | # 设置y轴坐标label 51 | plt.yticks(range(self.num_classes), self.labels) 52 | # 显示colorbar 53 | plt.colorbar() 54 | plt.xlabel('True Labels') 55 | plt.ylabel('Predicted Labels') 56 | plt.title('Confusion matrix') 57 | 58 | # 在图中标注数量/概率信息 59 | thresh = matrix.max() / 2 60 | for x in range(self.num_classes): 61 | for y in range(self.num_classes): 62 | # 注意这里的matrix[y, x]不是matrix[x, y] 63 | info = int(matrix[y, x]) 64 | plt.text(x, y, info, 65 | verticalalignment='center', 66 | horizontalalignment='center', 67 | color="white" if info > thresh else "black") 68 | plt.tight_layout() 69 | plt.savefig('./confusion_matrix.png', format='png') # 保存图像为png格式 70 | plt.show() 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /ui.py: -------------------------------------------------------------------------------- 1 | python 2 | 3 | 4 | class FruitDetector: 5 | def __init__(self): 6 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 7 | self.data_transform = transforms.Compose( 8 | [transforms.Resize((224, 224)), 9 | transforms.ToTensor(), 10 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 11 | self.json_path = './class_indices.json' 12 | self.weights_path = "./vgg16Net.pth" 13 | self.class_indict = None 14 | self.model = None 15 | 16 | def load_class_indict(self): 17 | assert os.path.exists(self.json_path), "file: '{}' dose not exist.".format(self.json_path) 18 | with open(self.json_path, "r") as f: 19 | self.class_indict = json.load(f) 20 | 21 | def load_model(self): 22 | assert os.path.exists(self.weights_path), "file: '{}' dose not exist.".format(self.weights_path) 23 | self.model = vgg(model_name="vgg16", num_classes=131).to(self.device) 24 | self.model.load_state_dict(torch.load(self.weights_path, map_location=self.device)) 25 | self.model.eval() 26 | 27 | def detect_image(self, img_path): 28 | assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) 29 | img = Image.open(img_path) 30 | plt.imshow(img) 31 | img = self.data_transform(img) 32 | img = torch.unsqueeze(img, dim=0) 33 | with torch.no_grad(): 34 | output = torch.squeeze(self.model(img.to(self.device))).cpu() 35 | predict = torch.softmax(output, dim=0) 36 | predict_cla = torch.argmax(predict).numpy() 37 | print_res = "class: {} prob: {:.3}".format(self.class_indict[str(predict_cla)], 38 | predict[predict_cla].numpy()) 39 | plt.title(print_res) 40 | plt.savefig('./save.png', format='png') 41 | show = cv2.imread('./save.png') 42 | return print_res, show 43 | 44 | def detect_video(self, video_path): 45 | capture = cv2.VideoCapture(video_path) 46 | while True: 47 | _, image = capture.read() 48 | if image is None: 49 | break 50 | cv2.imwrite('./save.png', image) 51 | img_path = './save.png' 52 | assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) 53 | img = Image.open(img_path) 54 | plt.imshow(img) 55 | img = self.data_transform(img) 56 | img = torch.unsqueeze(img, dim=0) 57 | with torch.no_grad(): 58 | output = torch.squeeze(self.model(img.to(self.device))).cpu() 59 | predict = torch.softmax(output, dim=0) 60 | predict_cla = torch.argmax(predict).numpy() 61 | print_res = "class: {} prob: {:.3}".format(self.class_indict[str(predict_cla)], 62 | predict[predict_cla].numpy()) 63 | plt.title(print_res) 64 | plt.savefig('./save.png', format='png') 65 | show = cv2.imread('./save.png') 66 | yield print_res, show 67 | 68 | 69 | 70 | class Ui_MainWindow(object): 71 | def setupUi(self, MainWindow): 72 | # ... 73 | 74 | def retranslateUi(self, MainWindow): 75 | # ... 76 | 77 | def openfile2(self): 78 | # ... 79 | 80 | def handleCalc4(self): 81 | # ... 82 | 83 | def openfile(self): 84 | # ... 85 | 86 | def handleCalc3(self): 87 | # ... 88 | 89 | def printf(self, text): 90 | # ... 91 | 92 | def showimg(self, img): 93 | # ... 94 | 95 | def click_1(self): 96 | # ... 97 | 98 | 99 | if __name__ == "__main__": 100 | app = QtWidgets.QApplication(sys.argv) 101 | MainWindow = QtWidgets.QMainWindow() 102 | ui = Ui_MainWindow() 103 | ui.setupUi(MainWindow) 104 | MainWindow.show() 105 | sys.exit(app.exec_()) 106 | -------------------------------------------------------------------------------- /fit.py: -------------------------------------------------------------------------------- 1 | python 2 | 3 | # SE模块 4 | class SELayer(nn.Module): 5 | def __init__(self, channel, reduction=16): 6 | super(SELayer, self).__init__() 7 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # 全局平均池化 8 | # 两个全连接层,分别进行降维和升维 9 | self.fc = nn.Sequential( 10 | nn.Linear(channel, channel // reduction, bias=False), 11 | nn.ReLU(inplace=True), 12 | nn.Linear(channel // reduction, channel, bias=False), 13 | nn.Sigmoid() 14 | ) 15 | 16 | def forward(self, x): 17 | b, c, _, _ = x.size() 18 | y = self.avg_pool(x).view(b, c) 19 | y = self.fc(y).view(b, c, 1, 1) 20 | return x * y # 加权 21 | 22 | # SE-VGGConv模块 23 | class SEVGGConv(nn.Module): 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): 25 | super(SEVGGConv, self).__init__() 26 | # 卷积 + BN + ReLU + SE 27 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 28 | self.bn = nn.BatchNorm2d(out_channels) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.se = SELayer(out_channels) 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | x = self.bn(x) 35 | x = self.relu(x) 36 | x = self.se(x) 37 | return x 38 | 39 | # SE-VGG16-BN网络 40 | class SEVGG16BN(nn.Module): 41 | def __init__(self, num_classes=131): 42 | super(SEVGG16BN, self).__init__() 43 | # 根据VGG16结构定义网络层 44 | self.features = nn.Sequential( 45 | # 输入 224x224x3 46 | SEVGGConv(3, 64, kernel_size=3, padding=1), 47 | SEVGGConv(64, 64, kernel_size=3, padding=1), 48 | nn.MaxPool2d(2, 2), # 56x56x128 49 | SEVGGConv(64, 128, kernel_size=3, padding=1), 50 | SEVGGConv(128, 128, kernel_size=3, padding=1), 51 | nn.MaxPool2d(2, 2), # 28x28x256 52 | SEVGGConv(128, 256, kernel_size=3, padding=1), 53 | SEVGGConv(256, 256, kernel_size=3, padding=1), 54 | SEVGGConv(256, 256, kernel_size=3, padding=1), 55 | nn.MaxPool2d(2, 2), # 14x14x512 56 | SEVGGConv(256, 512, kernel_size=3, padding=1), 57 | SEVGGConv(512, 512, kernel_size=3, padding=1), 58 | SEVGGConv(512, 512, kernel_size=3, padding=1), 59 | nn.MaxPool2d(2, 2), # 7x7x512 60 | SEVGGConv(512, 512, kernel_size=3, padding=1), 61 | SEVGGConv(512, 512, kernel_size=3, padding=1), 62 | SEVGGConv(512, 512, kernel_size=3, padding=1), 63 | nn.MaxPool2d(2, 2), # 1x1x4096 64 | ) 65 | # 定义分类器部分 66 | self.classifier = nn.Sequential( 67 | nn.Linear(512 * 7 * 7, 4096), 68 | nn.ReLU(True), 69 | nn.Dropout(), 70 | nn.Linear(4096, 4096), 71 | nn.ReLU(True), 72 | nn.Dropout(), 73 | nn.Linear(4096, num_classes), 74 | ) 75 | 76 | def forward(self, x): 77 | x = self.features(x) 78 | x = x.view(x.size(0), -1) 79 | x = self.classifier(x) 80 | return x 81 | 82 | # 多损失函数融合 83 | class CombinedLoss(nn.Module): 84 | def __init__(self, alpha=0.1): 85 | super(CombinedLoss, self).__init__() 86 | self.alpha = alpha # 融合系数 87 | self.cross_entropy = nn.CrossEntropyLoss() 88 | # 中心损失可以用triplet loss或者contrastive loss代替 89 | 90 | def forward(self, outputs, labels): 91 | # 此处需要根据具体情况添加中心损失的计算 92 | loss = self.cross_entropy(outputs, labels) 93 | return loss 94 | 95 | # 封装为类 96 | class SEVGG16BNCombinedLoss(nn.Module): 97 | def __init__(self, num_classes=131, alpha=0.1): 98 | super(SEVGG16BNCombinedLoss, self).__init__() 99 | self.model = SEVGG16BN(num_classes) 100 | self.loss = CombinedLoss(alpha) 101 | 102 | def forward(self, x, labels): 103 | outputs = self.model(x) 104 | loss = self.loss(outputs, labels) 105 | return loss 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 基于改进SE-VGG16-BN的131种水果图像分类系统(颜色、品种分级) 2 | 3 | # 1.研究背景与意义 4 | 5 | 6 | 7 | 随着计算机视觉和机器学习的快速发展,图像分类成为了一个热门的研究领域。在许多实际应用中,如农业、食品安全和市场调研等领域,对水果图像进行准确分类和品种分级具有重要意义。然而,由于水果的形状、颜色和纹理等特征的多样性,以及光照条件和拍摄角度的变化,水果图像分类面临着许多挑战。 8 | 9 | 目前,基于深度学习的图像分类方法已经取得了显著的成果。其中,卷积神经网络(CNN)是一种非常有效的方法,可以自动学习图像的特征表示。然而,传统的CNN模型在处理水果图像分类时仍然存在一些问题。首先,传统的CNN模型对于水果图像中的颜色信息没有充分利用,导致分类准确率较低。其次,传统的CNN模型对于水果图像中的品种分级任务并不擅长,无法提供细粒度的分类结果。 10 | 11 | 因此,本研究旨在基于改进的SE-VGG16-BN模型,实现对131种水果图像的准确分类和品种分级。具体来说,本研究将从以下几个方面进行改进和优化: 12 | 13 | 首先,本研究将引入注意力机制,以增强模型对水果图像中的颜色信息的感知能力。通过学习图像中不同区域的重要性权重,模型可以更好地捕捉到水果图像中的颜色特征,从而提高分类准确率。 14 | 15 | 其次,本研究将引入细粒度分类的方法,以实现对水果图像的品种分级。通过在模型中增加额外的分类层,可以将水果图像分为更多的细粒度类别,从而提供更具体和详细的分类结果。 16 | 17 | 最后,本研究将对数据集进行充分的预处理和增强,以提高模型的鲁棒性和泛化能力。通过对数据集进行旋转、缩放和平移等操作,可以增加模型对不同光照条件和拍摄角度的适应能力,从而提高分类的准确性和稳定性。 18 | 19 | 本研究的意义在于提供了一种基于改进的SE-VGG16-BN模型的水果图像分类系统,可以在农业、食品安全和市场调研等领域中得到广泛应用。通过准确分类和品种分级,可以帮助农民和市场调研人员更好地了解水果的品质和市场需求,从而提高农产品的质量和市场竞争力。此外,本研究还可以为其他图像分类任务提供借鉴和参考,推动深度学习在计算机视觉领域的发展。 20 | 21 | # 2.图片演示 22 | ![在这里插入图片描述](bad338d7a8634d448caa840008a62e82.png#pic_center) 23 | ![在这里插入图片描述](5eb0c2fd5907425f9df795218dc5090b.png#pic_center) 24 | ![在这里插入图片描述](a9883c5b261e41eea4d0b9cd139d9cc0.png#pic_center) 25 | 26 | # 3.视频演示 27 | [基于改进SE-VGG16-BN的131种水果蔬菜图像分类系统(颜色、品种分级)_哔哩哔哩_bilibili](https://www.bilibili.com/video/BV1Uc411d7az/?spm_id_from=333.999.0.0&vd_source=bc9aec86d164b67a7004b996143742dc) 28 | 29 | # 4.数据集和训练参数设定 30 | [AAAI提供的蔬菜水果数据集](https://afdian.net/item/542e72f87a0c11ee87f552540025c377)包含了131个分类,包含了常见的所有的蔬菜和水果类型,并且根据颜色和类型进行了分级划分。 31 | ![在这里插入图片描述](cbdba346194e42deb692ce53e0814b48.png) 32 | 我们需要将数据集整理为以下结构: 33 | ``` 34 | -----data 35 | |-----train 36 | | |-----class1 37 | | |-----class2 38 | | |-----... 39 | | 40 | |-----val 41 | | |-----class1 42 | | |-----class2 43 | | |-----... 44 | 45 | ``` 46 | (1)为提高训练的效果,加快网络模型的收敛,对两个数据集的花卉图片按照保持长宽比的方式归一化,归一化后的尺寸为224×224×3. 47 | (2)将数据增强后的每类花卉图片数的70%划分为训练集,剩余30%作为测试集. 48 | (3)训练时保留VGG16经 ImageNet 预训练产生的用于特征提取的参数,SE单元模块中用于放缩参数r设置为文献[8]的作者所推荐的16,其余参数均使用正态分布随机值进行初始化. 49 | (4)采用随机梯度下降法来优化模型, batchsize设置为32, epoch设为3000,学习率设为0.001,动量因子设为0.9,权重衰减设为0.000 5. 50 | (5)为了防止过拟合,SE-VGG16 网络模型第6段的两个全连接层的dropout 设置为0.5. 51 | (6)多损失函数融合公式中入参数的值设置为0.5. 52 | 53 | 54 | # 5.核心代码讲解 55 | 56 | #### 5.1 ConfusionMatrix.py 57 | 58 | ```python 59 | 60 | 61 | class ConfusionMatrix(object): 62 | """ 63 | 注意,如果显示的图像不全,是matplotlib版本问题 64 | 本例程使用matplotlib-3.2.1(windows and ubuntu)绘制正常 65 | 需要额外安装prettytable库 66 | """ 67 | def __init__(self, num_classes: int, labels: list): 68 | self.matrix = np.zeros((num_classes, num_classes)) 69 | self.num_classes = num_classes 70 | self.labels = labels 71 | 72 | def update(self, preds, labels): 73 | for p, t in zip(preds, labels): 74 | self.matrix[p, t] += 1 75 | 76 | def summary(self): 77 | # calculate accuracy 78 | sum_TP = 0 79 | for i in range(self.num_classes): 80 | sum_TP += self.matrix[i, i] 81 | acc = sum_TP / np.sum(self.matrix) 82 | print("the model accuracy is ", acc) 83 | 84 | # precision, recall, specificity 85 | table = PrettyTable() 86 | table.field_names = ["", "Precision", "Recall", "Specificity"] 87 | for i in range(self.num_classes): 88 | TP = self.matrix[i, i] 89 | FP = np.sum(self.matrix[i, :]) - TP 90 | FN = np.sum(self.matrix[:, i]) - TP 91 | TN = np.sum(self.matrix) - TP - FP - FN 92 | Precision = round(TP / (TP + FP), 3) if TP + FP != 0 else 0. 93 | Recall = round(TP / (TP + FN), 3) if TP + FN != 0 else 0. 94 | Specificity = round(TN / (TN + FP), 3) if TN + FP != 0 else 0. 95 | table.add_row([self.labels[i], Precision, Recall, Specificity]) 96 | print(table) 97 | 98 | def plot(self): 99 | matrix = self.matrix 100 | print(matrix) 101 | plt.figure(figsize=(40, 40), dpi=100) # 设置画布的大小和dpi,为了使图片更加清晰 102 | plt.imshow(matrix, cmap=plt.cm.Blues) 103 | 104 | 105 | # 设置x轴坐标label 106 | plt.xticks(range(self.num_classes), self.labels, rotation=45) 107 | # 设置y轴坐标label 108 | plt.yticks(range(self.num_classes), self.labels) 109 | # 显示colorbar 110 | plt.colorbar() 111 | plt.xlabel('True Labels') 112 | plt.ylabel('Predicted Labels') 113 | plt.title('Confusion matrix') 114 | 115 | # 在图中标注数量/概率信息 116 | thresh = matrix.max() / 2 117 | for x in range(self.num_classes): 118 | for y in range(self.num_classes): 119 | # 注意这里的matrix[y, x]不是matrix[x, y] 120 | info = int(matrix[y, x]) 121 | plt.text(x, y, info, 122 | verticalalignment='center', 123 | horizontalalignment='center', 124 | color="white" if info > thresh else "black") 125 | plt.tight_layout() 126 | plt.savefig('./confusion_matrix.png', format='png') # 保存图像为png格式 127 | plt.show() 128 | 129 | 130 | 131 | ``` 132 | 133 | 该程序文件名为ConfusionMatrix.py,主要功能是计算和绘制混淆矩阵。 134 | 135 | 程序首先导入了所需的库和模块,包括os、json、torch、transforms、datasets、numpy、tqdm、matplotlib和PrettyTable。 136 | 137 | 然后定义了一个名为ConfusionMatrix的类,该类有以下几个方法: 138 | - `__init__(self, num_classes: int, labels: list)`:初始化方法,接收分类数和标签列表作为参数,创建一个大小为(num_classes, num_classes)的零矩阵,并保存分类数和标签列表。 139 | - `update(self, preds, labels)`:更新混淆矩阵的方法,接收预测结果和真实标签作为参数,根据预测结果和真实标签更新混淆矩阵。 140 | - `summary(self)`:计算并打印模型的准确率、精确度、召回率和特异度。 141 | - `plot(self)`:绘制混淆矩阵图像,并保存为png格式。 142 | 143 | 接下来是主程序部分,首先判断是否有可用的GPU,然后定义了数据的预处理方法和数据集路径。 144 | 145 | 然后创建了一个验证数据集的DataLoader,并加载了预训练的vgg模型权重。 146 | 147 | 接着读取了类别标签的json文件,并保存了标签列表。 148 | 149 | 然后创建了一个ConfusionMatrix对象,并将模型设置为评估模式。 150 | 151 | 在没有梯度的情况下,遍历验证数据集,对每个验证数据进行模型推理,并更新混淆矩阵。 152 | 153 | 最后调用ConfusionMatrix对象的plot方法绘制混淆矩阵图像,并调用summary方法打印模型的准确率、精确度、召回率和特异度。 154 | 155 | #### 5.2 fit.py 156 | 157 | ```python 158 | 159 | # SE模块 160 | class SELayer(nn.Module): 161 | def __init__(self, channel, reduction=16): 162 | super(SELayer, self).__init__() 163 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # 全局平均池化 164 | # 两个全连接层,分别进行降维和升维 165 | self.fc = nn.Sequential( 166 | nn.Linear(channel, channel // reduction, bias=False), 167 | nn.ReLU(inplace=True), 168 | nn.Linear(channel // reduction, channel, bias=False), 169 | nn.Sigmoid() 170 | ) 171 | 172 | def forward(self, x): 173 | b, c, _, _ = x.size() 174 | y = self.avg_pool(x).view(b, c) 175 | y = self.fc(y).view(b, c, 1, 1) 176 | return x * y # 加权 177 | 178 | # SE-VGGConv模块 179 | class SEVGGConv(nn.Module): 180 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): 181 | super(SEVGGConv, self).__init__() 182 | # 卷积 + BN + ReLU + SE 183 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 184 | self.bn = nn.BatchNorm2d(out_channels) 185 | self.relu = nn.ReLU(inplace=True) 186 | self.se = SELayer(out_channels) 187 | 188 | def forward(self, x): 189 | x = self.conv(x) 190 | x = self.bn(x) 191 | x = self.relu(x) 192 | x = self.se(x) 193 | return x 194 | 195 | # SE-VGG16-BN网络 196 | class SEVGG16BN(nn.Module): 197 | def __init__(self, num_classes=131): 198 | super(SEVGG16BN, self).__init__() 199 | # 根据VGG16结构定义网络层 200 | self.features = nn.Sequential( 201 | # 输入 224x224x3 202 | SEVGGConv(3, 64, kernel_size=3, padding=1), 203 | SEVGGConv(64, 64, kernel_size=3, padding=1), 204 | nn.MaxPool2d(2, 2), # 56x56x128 205 | SEVGGConv(64, 128, kernel_size=3, padding=1), 206 | SEVGGConv(128, 128, kernel_size=3, padding=1), 207 | nn.MaxPool2d(2, 2), # 28x28x256 208 | SEVGGConv(128, 256, kernel_size=3, padding=1), 209 | SEVGGConv(256, 256, kernel_size=3, padding=1), 210 | SEVGGConv(256, 256, kernel_size=3, padding=1), 211 | nn.MaxPool2d(2, 2), # 14x14x512 212 | SEVGGConv(256, 512, kernel_size=3, padding=1), 213 | SEVGGConv(512, 512, kernel_size=3, padding=1), 214 | SEVGGConv(512, 512, kernel_size=3, padding=1), 215 | nn.MaxPool2d(2, 2), # 7x7x512 216 | SEVGGConv(512, 512, kernel_size=3, padding=1), 217 | SEVGGConv(512, 512, kernel_size=3, padding=1), 218 | SEVGGConv(512, 512, kernel_size=3, padding=1), 219 | nn.MaxPool2d(2, 2), # 1x1x4096 220 | ) 221 | # 定义分类器部分 222 | self.classifier = nn.Sequential( 223 | nn.Linear(512 * 7 * 7, 4096), 224 | nn.ReLU(True), 225 | nn.Dropout(), 226 | nn.Linear(4096, 4096), 227 | nn.ReLU(True), 228 | nn.Dropout(), 229 | nn.Linear(4096, num_classes), 230 | ) 231 | 232 | def forward(self, x): 233 | x = self.features(x) 234 | x = x.view(x.size(0), -1) 235 | x = self.classifier(x) 236 | return x 237 | 238 | # 多损失函数融合 239 | class CombinedLoss(nn.Module): 240 | def __init__(self, alpha=0.1): 241 | super(CombinedLoss, self).__init__() 242 | self.alpha = alpha # 融合系数 243 | self.cross_entropy = nn.CrossEntropyLoss() 244 | # 中心损失可以用triplet loss或者contrastive loss代替 245 | 246 | def forward(self, outputs, labels): 247 | # 此处需要根据具体情况添加中心损失的计算 248 | loss = self.cross_entropy(outputs, labels) 249 | return loss 250 | 251 | # 封装为类 252 | class SEVGG16BNCombinedLoss(nn.Module): 253 | def __init__(self, num_classes=131, alpha=0.1): 254 | super(SEVGG16BNCombinedLoss, self).__init__() 255 | self.model = SEVGG16BN(num_classes) 256 | self.loss = CombinedLoss(alpha) 257 | 258 | def forward(self, x, labels): 259 | outputs = self.model(x) 260 | loss = self.loss(outputs, labels) 261 | return loss 262 | ``` 263 | 264 | 这个程序文件是一个用于图像分类的深度学习模型。它定义了一个名为SEVGG16BN的类,该类继承自nn.Module。这个类包含了一个特征提取部分和一个分类器部分。 265 | 266 | 特征提取部分使用了SEVGGConv模块,它是一个包含了卷积、批归一化、ReLU激活函数和SE模块的组合。SE模块是一种注意力机制,用于对特征图进行加权。特征提取部分共有5个阶段,每个阶段包含多个SEVGGConv模块和一个最大池化层。 267 | 268 | 分类器部分是一个全连接神经网络,它将特征图展平后经过多个线性层和ReLU激活函数得到最终的分类结果。 269 | 270 | 此外,程序文件还定义了一个CombinedLoss类,用于计算多个损失函数的融合损失。目前只实现了交叉熵损失函数,中心损失函数可以根据具体情况进行替换。 271 | 272 | 总体来说,这个程序文件定义了一个使用SE模块和VGG16结构的图像分类模型,并提供了多个损失函数的融合功能。 273 | 274 | #### 5.3 model.py 275 | 276 | ```python 277 | 278 | 279 | class VGG(nn.Module): 280 | def __init__(self, features, num_classes=1000, init_weights=False): 281 | super(VGG, self).__init__() 282 | self.features = features 283 | self.classifier = nn.Sequential( 284 | nn.Linear(512*7*7, 4096), 285 | nn.ReLU(True), 286 | nn.Dropout(p=0.5), 287 | nn.Linear(4096, 4096), 288 | nn.ReLU(True), 289 | nn.Dropout(p=0.5), 290 | nn.Linear(4096, num_classes) 291 | ) 292 | if init_weights: 293 | self._initialize_weights() 294 | 295 | def forward(self, x): 296 | x = self.features(x) 297 | x = torch.flatten(x, start_dim=1) 298 | x = self.classifier(x) 299 | return x 300 | 301 | def _initialize_weights(self): 302 | for m in self.modules(): 303 | if isinstance(m, nn.Conv2d): 304 | nn.init.xavier_uniform_(m.weight) 305 | if m.bias is not None: 306 | nn.init.constant_(m.bias, 0) 307 | elif isinstance(m, nn.Linear): 308 | nn.init.xavier_uniform_(m.weight) 309 | nn.init.constant_(m.bias, 0) 310 | 311 | 312 | def make_features(cfg: list): 313 | layers = [] 314 | in_channels = 3 315 | for v in cfg: 316 | if v == "M": 317 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 318 | else: 319 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 320 | layers += [conv2d, nn.ReLU(True)] 321 | in_channels = v 322 | return nn.Sequential(*layers) 323 | 324 | 325 | cfgs = { 326 | 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 327 | 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 328 | 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 329 | 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 330 | } 331 | 332 | 333 | def vgg(model_name="vgg16", **kwargs): 334 | assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name) 335 | cfg = cfgs[model_name] 336 | 337 | model = VGG(make_features(cfg), **kwargs) 338 | return model 339 | ``` 340 | 341 | 该程序文件名为model.py,是一个实现VGG模型的代码文件。 342 | 343 | 该文件首先导入了torch.nn和torch模块,并定义了一个model_urls字典,包含了VGG模型的预训练权重的下载链接。 344 | 345 | 接下来定义了一个VGG类,继承自nn.Module类。该类的构造函数接受一个features参数和一个num_classes参数,用于构建VGG模型的特征提取部分和分类器部分。构造函数还有一个可选的init_weights参数,用于控制是否初始化模型的权重。 346 | 347 | VGG类的forward方法定义了模型的前向传播过程。首先将输入通过特征提取部分,然后将输出展平为一维张量,最后通过分类器部分得到最终的输出。 348 | 349 | VGG类还定义了一个私有方法_initialize_weights,用于初始化模型的权重。 350 | 351 | 接下来定义了一个make_features函数,用于根据给定的配置列表构建VGG模型的特征提取部分。该函数根据配置列表中的值来选择添加卷积层或最大池化层,并使用ReLU激活函数。 352 | 353 | 最后定义了一个vgg函数,用于创建VGG模型。该函数接受一个model_name参数,用于选择VGG模型的配置。根据model_name从cfgs字典中获取对应的配置列表,然后调用make_features函数构建特征提取部分,并将其传入VGG类的构造函数中创建模型。 354 | 355 | 总结起来,该程序文件实现了VGG模型的构建和前向传播过程,并提供了预训练权重的下载链接。 356 | 357 | #### 5.4 predict.py 358 | 359 | ```python 360 | import os 361 | import json 362 | 363 | import torch 364 | from PIL import Image 365 | from torchvision import transforms 366 | import matplotlib.pyplot as plt 367 | 368 | from model import vgg 369 | 370 | class ImageClassifier: 371 | def __init__(self, model_name, num_classes, weights_path, json_path): 372 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 373 | self.data_transform = transforms.Compose( 374 | [transforms.Resize((224, 224)), 375 | transforms.ToTensor(), 376 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 377 | self.model = vgg(model_name=model_name, num_classes=num_classes).to(self.device) 378 | self.weights_path = weights_path 379 | self.json_path = json_path 380 | 381 | def load_image(self, img_path): 382 | assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) 383 | img = Image.open(img_path) 384 | plt.imshow(img) 385 | img = self.data_transform(img) 386 | img = torch.unsqueeze(img, dim=0) 387 | return img 388 | 389 | def load_class_indict(self): 390 | assert os.path.exists(self.json_path), "file: '{}' dose not exist.".format(self.json_path) 391 | with open(self.json_path, "r") as f: 392 | class_indict = json.load(f) 393 | return class_indict 394 | 395 | def load_model_weights(self): 396 | assert os.path.exists(self.weights_path), "file: '{}' dose not exist.".format(self.weights_path) 397 | self.model.load_state_dict(torch.load(self.weights_path, map_location=self.device)) 398 | 399 | def predict(self, img): 400 | self.model.eval() 401 | with torch.no_grad(): 402 | output = torch.squeeze(self.model(img.to(self.device))).cpu() 403 | predict = torch.softmax(output, dim=0) 404 | predict_cla = torch.argmax(predict).numpy() 405 | return predict_cla, predict 406 | 407 | def show_result(self, class_indict, predict_cla, predict): 408 | print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], 409 | predict[predict_cla].numpy()) 410 | plt.title(print_res) 411 | for i in range(len(predict)): 412 | print("class: {:10} prob: {:.3}".format(class_indict[str(i)], 413 | predict[i].numpy())) 414 | plt.show() 415 | 416 | def classify_image(self, img_path): 417 | img = self.load_image(img_path) 418 | class_indict = self.load_class_indict() 419 | self.load_model_weights() 420 | predict_cla, predict = self.predict(img) 421 | self.show_result(class_indict, predict_cla, predict) 422 | 423 | 424 | ``` 425 | 426 | 这个程序文件名为predict.py,它的功能是使用预训练的VGG模型对一张图片进行分类预测。程序的主要流程如下: 427 | 428 | 1. 导入所需的库和模块,包括os、json、torch、PIL、transforms和matplotlib.pyplot。 429 | 2. 定义了一个main函数作为程序的入口。 430 | 3. 判断是否有可用的GPU设备,如果有则使用cuda:0作为设备,否则使用cpu作为设备。 431 | 4. 定义了一个数据转换的操作,包括将图片调整为224x224大小、转换为Tensor格式、进行归一化操作。 432 | 5. 加载待预测的图片,并进行数据转换操作。 433 | 6. 读取类别标签的映射文件class_indices.json。 434 | 7. 创建一个VGG模型实例,并指定模型名称为vgg16,类别数为131。 435 | 8. 加载预训练好的模型权重文件vgg16Net.pth。 436 | 9. 将模型设置为评估模式。 437 | 10. 使用模型对图片进行预测,得到预测结果。 438 | 11. 打印预测结果,并在图像上显示预测结果。 439 | 12. 主函数调用main函数,开始执行程序。 440 | 441 | 442 | 443 | #### 5.5 ui.py 444 | 445 | ```python 446 | 447 | 448 | class FruitDetector: 449 | def __init__(self): 450 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 451 | self.data_transform = transforms.Compose( 452 | [transforms.Resize((224, 224)), 453 | transforms.ToTensor(), 454 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 455 | self.json_path = './class_indices.json' 456 | self.weights_path = "./vgg16Net.pth" 457 | self.class_indict = None 458 | self.model = None 459 | 460 | def load_class_indict(self): 461 | assert os.path.exists(self.json_path), "file: '{}' dose not exist.".format(self.json_path) 462 | with open(self.json_path, "r") as f: 463 | self.class_indict = json.load(f) 464 | 465 | def load_model(self): 466 | assert os.path.exists(self.weights_path), "file: '{}' dose not exist.".format(self.weights_path) 467 | self.model = vgg(model_name="vgg16", num_classes=131).to(self.device) 468 | self.model.load_state_dict(torch.load(self.weights_path, map_location=self.device)) 469 | self.model.eval() 470 | 471 | def detect_image(self, img_path): 472 | assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) 473 | img = Image.open(img_path) 474 | plt.imshow(img) 475 | img = self.data_transform(img) 476 | img = torch.unsqueeze(img, dim=0) 477 | with torch.no_grad(): 478 | output = torch.squeeze(self.model(img.to(self.device))).cpu() 479 | predict = torch.softmax(output, dim=0) 480 | predict_cla = torch.argmax(predict).numpy() 481 | print_res = "class: {} prob: {:.3}".format(self.class_indict[str(predict_cla)], 482 | predict[predict_cla].numpy()) 483 | plt.title(print_res) 484 | plt.savefig('./save.png', format='png') 485 | show = cv2.imread('./save.png') 486 | return print_res, show 487 | 488 | def detect_video(self, video_path): 489 | capture = cv2.VideoCapture(video_path) 490 | while True: 491 | _, image = capture.read() 492 | if image is None: 493 | break 494 | cv2.imwrite('./save.png', image) 495 | img_path = './save.png' 496 | assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) 497 | img = Image.open(img_path) 498 | plt.imshow(img) 499 | img = self.data_transform(img) 500 | img = torch.unsqueeze(img, dim=0) 501 | with torch.no_grad(): 502 | output = torch.squeeze(self.model(img.to(self.device))).cpu() 503 | predict = torch.softmax(output, dim=0) 504 | predict_cla = torch.argmax(predict).numpy() 505 | print_res = "class: {} prob: {:.3}".format(self.class_indict[str(predict_cla)], 506 | predict[predict_cla].numpy()) 507 | plt.title(print_res) 508 | plt.savefig('./save.png', format='png') 509 | show = cv2.imread('./save.png') 510 | yield print_res, show 511 | 512 | 513 | 514 | class Ui_MainWindow(object): 515 | def setupUi(self, MainWindow): 516 | # ... 517 | 518 | def retranslateUi(self, MainWindow): 519 | # ... 520 | 521 | def openfile2(self): 522 | # ... 523 | 524 | def handleCalc4(self): 525 | # ... 526 | 527 | def openfile(self): 528 | # ... 529 | 530 | def handleCalc3(self): 531 | # ... 532 | 533 | def printf(self, text): 534 | # ... 535 | 536 | def showimg(self, img): 537 | # ... 538 | 539 | def click_1(self): 540 | # ... 541 | 542 | 543 | if __name__ == "__main__": 544 | app = QtWidgets.QApplication(sys.argv) 545 | MainWindow = QtWidgets.QMainWindow() 546 | ui = Ui_MainWindow() 547 | ui.setupUi(MainWindow) 548 | MainWindow.show() 549 | sys.exit(app.exec_()) 550 | ``` 551 | 552 | 这个程序文件是一个基于PyQt5的水果识别系统设计。程序中包含了图像识别和视频识别的功能。 553 | 554 | 程序首先导入了所需的库和模块,包括os、json、torch、PIL、transforms、matplotlib、cv2、numpy等。然后导入了自定义的模型vgg。 555 | 556 | 接下来定义了一个名为det的函数,用于进行图像或视频的识别。在函数中,首先根据设备是否支持GPU选择运行设备,然后定义了数据的预处理操作。接着读取了类别标签文件class_indices.json,并加载了预训练的模型vgg16Net.pth。然后根据输入的信息判断是图像还是视频,如果是图像,则读取图像并进行预测;如果是视频,则读取视频的每一帧并进行预测。最后将预测结果显示在界面上,并保存预测结果的图像。 557 | 558 | 程序还定义了一个名为Thread_1的线程类,用于在后台运行det函数。该线程类继承自QThread,并重写了run方法,在run方法中调用了det函数。 559 | 560 | 接下来定义了一个名为Ui_MainWindow的类,用于创建主窗口界面。在该类中,定义了界面的布局和控件,并绑定了相应的事件处理函数。 561 | 562 | 最后,在主程序中创建了一个Qt应用程序对象app,并创建了主窗口对象MainWindow和Ui_MainWindow对象ui。然后调用ui的setupUi方法设置主窗口的界面,并显示主窗口。 563 | 564 | 整个程序的功能是通过界面上的按钮来选择图像或视频文件,并进行识别。识别结果会显示在界面上,并保存预测结果的图像。 565 | 566 | # 6.系统整体结构 567 | 568 | 整体功能和构架概述: 569 | 570 | 该程序是一个基于改进SE-VGG16-BN的131种水果图像分类系统,具有颜色和品种分级的功能。它包含了多个文件,每个文件负责不同的功能模块。 571 | 572 | | 文件名 | 功能概述 | 573 | | --- | --- | 574 | | ConfusionMatrix.py | 计算和绘制混淆矩阵,以及打印模型的准确率、精确度、召回率和特异度。 | 575 | | fit.py | 定义了SEVGG16BN模型的特征提取部分和分类器部分,用于图像分类。 | 576 | | model.py | 实现了VGG模型的构建和前向传播过程,提供了预训练权重的下载链接。 | 577 | | predict.py | 使用预训练的VGG模型对单张图片进行分类预测。 | 578 | | train.py | 训练模型的脚本,包括数据预处理、模型定义、损失函数和优化器的设置、训练和验证过程等。 | 579 | | ui.py | 基于PyQt5的水果识别系统设计,包含图像识别和视频识别的功能。 | 580 | 581 | 582 | # 7.SENet网络 583 | SENet是最后一届ImageNet分类任务的冠军.SENet[8]的本质是采用通道注意力机制,通过深度学习的方式自动获取图像各个特征通道的权重,以增强有用特征并抑制无用特征.SENet的核心模块是squeeze-and-excitation (SE),主要分3个步骤对特征进行重标定,如图所示. 584 | ![在这里插入图片描述](a603d8d63e4346128b70ddd85954c364.png) 585 | 某种程度上具有全局的感受野,它表征着在特征通道上响应的全局分布. 586 | (2)Excitationl101即Fex操作,通过两个全连接层先降维后升维对squeeze操作的结果进行非线性变换,来为每个特征通道生成权值,该权值表示特征通道之间的相关性,如式(3)所示: 587 | s = Fex(z,W)=o(g(z, W))= o(W6(Wz))(3)其中,Wz是第一个全连接层操作,W1的维度是C/r×C,z的维度是1×1×C,因此Wz的输出维度是1×1×C/r,即通过r(维度的缩放因子)进行了降维,然后采用ReLU激活.第2个全连接层是将上个全连接层的输出乘以W,其中W的维度是C×C/r,因此最终输出的维度为1×1×C,即同squeeze操作输出的维度相同,最后再经过Sigmoid函数用于获取各通道归一化后的权重,得到维度为1×1×C的s, s用来表示通过之前两个全连接层的非线性变换学习到的第C个特征图的权值.Excitation这种先降维后升维的操作一方面降低模型复杂度,使网络具有更好的非线性,另一方面更好的拟合通道间复杂的相关性,提升了模型泛化能力. 588 | (3) Reweight"即 Fscale操作,将上一步excitation操作得到的权值s通过乘法逐通道加权到原始的特征上,完成在通道维度上的对原始特征的重标定。总之,SE模块通过自动学习,以获取每个特征通道的重要程度,然后根据各个通道的重要程度,一方面去提升有用的特征,另一方面抑制对当前任务不相关或作用不大的特征. 589 | 590 | # 8. VGG16网络 591 | VGGNet荣膺2014年ImageNet图像分类第2名的好成绩,其中VGG16是VGGNet中分类性能最好的网络之一,其网络结构如图所示. 592 | 593 | ![在这里插入图片描述](4c1f113fbd2e4f84b7f9fa57f2e13114.png) 594 | (1) VGG16 网络可分为6段,即5段卷积加1段全连接,其中5段卷积包含13个卷积层,1段全连接指网络最后的3个全连接层,因此VGG16网络总共有13+3=16层. 595 | (2)5段卷积用以提取低、中、高各层的图像特征,每一段有2或3个卷积层.为了增加网络的非线性、防止梯度消失、减少过拟合以及提高网络训练的速度,各卷积层后均采用ReLU激活函数.为利于捕捉细节变化,获得更好的非线性效果并减少参数数量,每个卷积层均采用3×3的卷积核,使得网络结构更加简洁,在必要时3×3卷积核的堆叠还可以替代5×5、7×7等较大的卷积核. 596 | (3)5段卷积的尾部均接有一个最大池化层,该池化层采用2×2的池化核,能够减小卷积层参数误差造成估计值均值的偏移,更容易捕捉图像和梯度的变化,有利于保留纹理等细节信息. 597 | (4) VGG16 网络的最后一段是3个全连接层,全连接层中的每一个节点都与上一层每个节点连接,把前一层的输出特征综合起来,起到分类器的作用. 598 | 总之, VGG16网络的深度为16层,这种较深的网络通过逐层的抽象,能够不断学习由低到高各层的特征,具有更强的非线性表达能力,能表达更为丰富的特征,拟合更为复杂的输入特征.另外, VGG16 网络最开始采用64个3×3卷积核,随着网络的加深,卷积核数量逐渐从64,增加到128、256、512,因此使其具有较大的网络宽度,宽度的增加能使网络各层学习到更为丰富的颜色、纹理等特征. 599 | 600 | # 9.改进后的VGG16网络 601 | #### SE-VGGConv模块 602 | SE模块的最大特点在于其内部采用常见的池化及全连接层,因此具有很强的通用性,可以方便的嵌入到其他常见的网络模型中.[谷歌大脑提出在VGG网络模型的卷积层之后加入SE视觉注意力单元](https://mbd.pub/o/bread/ZZaTmJlw),如图所示. 603 | 如前所述,在VGG网络的卷积层后,首先经过一个GAP全局平均池化层,即图1中的squeeze操作,用于获取通道级的全局特征.然后进入第一个FC层进行降维操作,用ReLU函数激活后进入第2个FC层进行升维操作, Sigmoid函数用于获取各通道归一化后的权重. Scalel8131(即 reweight 操作)将归一化后的权重加权到每个原始通道的特征之上,实现了在通道维度上的对原始特征的重标定. 604 | ![在这里插入图片描述](d459c7c632db4c28af196afd9512f20c.png) 605 | 606 | #### SE-VGG16-BN网络模型 607 | VGG16因其具有较好的深度及宽度,在图像分类的应用具有一定的优势,但对具有类间相似性高,类内差异性大以及存在复杂背景干扰的花卉分类,其准确率还有待提高.因此,在VGG16的基础上引入BN层及SE视觉注意力单元,可以充分提取蔬菜水果分类任务中类间相似性高、类内差异较大的敏感特征,从而提高蔬菜水果分类的准确率. 608 | 在VGG16加入BN层及SE视觉注意力单元后的网络结构如图4所示. 609 | 图对VGG16 网络做了如下改进:前5段卷积的每个卷积层中均加入SE视觉注意力单元、BN层和ReLU激活函数.其中SE单元用于学习各通道的重要程度,从而增强有用特征并抑制无用特征.BN层(batchnormalization)的作用是加快网络的训练和收敛的速度,防止梯度爆炸及梯度消失,使模型会变得更加稳定;ReLU激活函数[4]能增强网络的非线性、防止梯度消失、减少过拟合并提高网络训练的速度.为防止过拟合,第6段的两个FC层后面均加入 dropout. Softmax用于最终的分类,由于本文采用的蔬菜水果数据集有131种花卉,因此输出的是原始图片对应于131类蔬菜水果的概率. 610 | ![在这里插入图片描述](558a4c1d66c3482fad5cff9d55bfa3f9.png) 611 | 612 | #### 多损失函数融合 613 | 614 | 交叉嫡损失函数(cross-entropy cost function)经常用于分类任务中,起着控制模型的总体走势的作用,交叉嫡损失函数的定义如下: 615 | ![在这里插入图片描述](4b8bb1117771485bb723bc39ca5b84c4.png) 616 | 其中, n是批处理样本数,x为输入,y为标签值,y表示实际输出. 617 | 中心损失函数(center loss)优势是可以学习类内距离更小的特征,从而减小类内的差异,并能在一定程度上增大类间差异性,从而提高分类的准确率,中心损失函数的定义如下: 618 | ![在这里插入图片描述](7a1bc762d9e5484fbc3e70d7a5a7e901.png) 619 | 其中,n是批处理样本数,x;表示y类别的第i个特征,C,表示i类特征的中心值. 620 | 将交叉嫡损失函数和中心损失函数进行融合,并将其推广至多层神经网络,假设输出神经元的期望值是y={t,y.y3.…},y={,y2',y…},则融合后的计算公式如下: 621 | ![在这里插入图片描述](5641a46afa874bd29027e60b1bdd3773.png) 622 | 其中,融合系数⒉的取值范围是O-1,通过多损失函数的融合,放大了花卉的类间距离,缩小了类内距离,加快了网络的收敛速度,进一步提高了分类的效率和准确率. 623 | 624 | 625 | 626 | # 10.系统整合 627 | 628 | 下图[完整源码&数据集&环境部署视频教程&自定义UI界面](https://s.xiaocichang.com/s/07bffe) 629 | 630 | ![在这里插入图片描述](eded326eb3a54f6da1be4bed3790858c.png#pic_center) 631 | 632 | 633 | 参考博客[《基于改进SE-VGG16-BN的131种水果蔬菜图像分类系统(颜色、品种分级)》](https://mbd.pub/o/qunshan/work) 634 | 635 | 636 | # 11.参考文献 637 | --- 638 | [1][杨旺功](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%9D%A8%E6%97%BA%E5%8A%9F%22),[淮永建](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%B7%AE%E6%B0%B8%E5%BB%BA%22).[多层特征融合及兴趣区域的花卉图像分类](https://d.wanfangdata.com.cn/periodical/hebgcdxxb202104021)[J].[哈尔滨工程大学学报](https://sns.wanfangdata.com.cn/perio/hebgcdxxb).2021,(4).DOI:10.11990/jheu.201912064 . 639 | 640 | [2][严春满](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E4%B8%A5%E6%98%A5%E6%BB%A1%22),[王铖](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E7%8E%8B%E9%93%96%22).[卷积神经网络模型发展及应用](https://d.wanfangdata.com.cn/periodical/jsjkxyts202101003)[J].[计算机科学与探索](https://sns.wanfangdata.com.cn/perio/jsjkxyts).2021,(1).DOI:10.3778/j.issn.1673-9418.2008016 . 641 | 642 | [3][吴丽娜](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E5%90%B4%E4%B8%BD%E5%A8%9C%22),[王林山](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E7%8E%8B%E6%9E%97%E5%B1%B1%22).[改进的LeNet-5模型在花卉识别中的应用](https://d.wanfangdata.com.cn/periodical/jsjgcysj202003040)[J].[计算机工程与设计](https://sns.wanfangdata.com.cn/perio/jsjgcysj).2020,(3).DOI:10.16208/j.issn1000-7024.2020.03.040 . 643 | 644 | [4][李克文](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%9D%8E%E5%85%8B%E6%96%87%22),[李新宇](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%9D%8E%E6%96%B0%E5%AE%87%22).[基于SENet改进的Faster R-CNN行人检测模型](https://d.wanfangdata.com.cn/periodical/jsjxtyy202004041)[J].[计算机系统应用](https://sns.wanfangdata.com.cn/perio/jsjxtyy).2020,(4).DOI:10.15888/j.cnki.csa.007321 . 645 | 646 | [5][李昊玥](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%9D%8E%E6%98%8A%E7%8E%A5%22),[陈桂芬](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E9%99%88%E6%A1%82%E8%8A%AC%22),[裴傲](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E8%A3%B4%E5%82%B2%22).[基于改进Mask R-CNN的奶牛个体识别方法研究](https://d.wanfangdata.com.cn/periodical/hnnydxxb202006018)[J].[华南农业大学学报](https://sns.wanfangdata.com.cn/perio/hnnydxxb).2020,(6).DOI:10.7671/j.issn.1001-411X.202003030 . 647 | 648 | [6][孟庆宽](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E5%AD%9F%E5%BA%86%E5%AE%BD%22),[张漫](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E5%BC%A0%E6%BC%AB%22),[杨晓霞](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%9D%A8%E6%99%93%E9%9C%9E%22),等.[基于轻量卷积结合特征信息融合的玉米幼苗与杂草识别](https://d.wanfangdata.com.cn/periodical/nyjxxb202012026)[J].[农业机械学报](https://sns.wanfangdata.com.cn/perio/nyjxxb).2020,(12).DOI:10.6041/j.issn.1000-1298.2020.12.026 . 649 | 650 | [7][尹红](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E5%B0%B9%E7%BA%A2%22),[符祥](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E7%AC%A6%E7%A5%A5%22),[曾接贤](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22%E6%9B%BE%E6%8E%A5%E8%B4%A4%22),等.[选择性卷积特征融合的花卉图像分类](https://d.wanfangdata.com.cn/periodical/zgtxtxxb-a201905009)[J].[中国图象图形学报](https://sns.wanfangdata.com.cn/perio/zgtxtxxb-a).2019,(5). 651 | 652 | [8]佚名.[FuSENet: fused squeeze-and-excitation network for spectral-spatial hyperspectral image classification](https://d.wanfangdata.com.cn/periodical/5b5902437efe2f1ff07528583f0a2780)[J].IET image processing.2020,14(8).1653-1661.DOI:10.1049/iet-ipr.2019.1462 . 653 | 654 | [9][Cibuk, Musa](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22Cibuk%2C%20Musa%22),[Budak, Umit](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22Budak%2C%20Umit%22),[Guo, Yanhui](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22Guo%2C%20Yanhui%22),等.[Efficient deep features selections and classification for flower species recognition](https://d.wanfangdata.com.cn/periodical/60e8c7a309055caba77b38d5933dc842)[J].Measurement.2019.1377-13.DOI:10.1016/j.measurement.2019.01.041 . 655 | 656 | [10][Xiaoling Xia](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22Xiaoling%20Xia%22),[Cui Xu](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22Cui%20Xu%22),[Bing Nan](https://s.wanfangdata.com.cn/paper?q=%E4%BD%9C%E8%80%85:%22Bing%20Nan%22).Inception-v3 for flower classification[C]. 657 | 658 | 659 | --- 660 | #### 如果您需要更详细的【源码和环境部署教程】,除了通过【系统整合】小节的链接获取之外,还可以通过邮箱以下途径获取: 661 | #### 1.请先在GitHub上为该项目点赞(Star),编辑一封邮件,附上点赞的截图、项目的中文描述概述(About)以及您的用途需求,发送到我们的邮箱 662 | #### sharecode@yeah.net 663 | #### 2.我们收到邮件后会定期根据邮件的接收顺序将【完整源码和环境部署教程】发送到您的邮箱。 664 | #### 【免责声明】本文来源于用户投稿,如果侵犯任何第三方的合法权益,可通过邮箱联系删除。 --------------------------------------------------------------------------------