├── 模型验证 ├── 说明.md ├── 图片 │ ├── 月季.jpg │ ├── 杜鹃.jpg │ ├── 桃花.jpg │ ├── 樱花.jpg │ ├── 玫瑰.jpg │ ├── 睡莲.jpg │ ├── 雏菊.jpg │ ├── 勋章菊.jpg │ ├── 向日葵.jpg │ ├── 康乃馨.jpg │ ├── 旱金莲.jpg │ ├── 毛地黄.jpg │ ├── 水仙花.jpg │ ├── 牵牛花.jpg │ ├── 紫罗兰.jpg │ ├── 肿柄菊.jpg │ ├── 花菱草.jpg │ ├── 茉莉花.jpg │ ├── 蒲公英.jpg │ └── 郁金香.jpg └── 图片测试.py ├── upload ├── 123.jpg ├── 456.jpg ├── 789.jpg ├── image-20200323134803440.png ├── image-20200323161939465.png ├── image-20200323162157922.png ├── image-20200323162309627.png ├── image-20200323184552442.jpg ├── image-20200323225509766.png ├── image-20200323230153110.png ├── image-20200323231914637.jpg ├── image-20200323231935143.jpg └── image-20200423135159846.png ├── 模型训练 ├── VGG16 │ ├── 说明.md │ ├── VGG16_test.py │ └── VGG16_train.py └── AlexNet │ ├── 说明.md │ ├── AlexNet_test.py │ └── AlexNet_train.py ├── data └── 说明.md ├── flask ├── img │ └── 这个文件夹里面放测试的图片.txt └── app.py ├── 百度接口实现花卉识别 ├── 说明.md ├── 获取token.py └── 花卉识别.py ├── .gitignore ├── 数据扩展程序 ├── 2.上下翻转.py ├── 1.镜像翻转.py └── 3.椒盐噪声.py ├── 数据来源 ├── 说明.md ├── 爬虫.py └── 分离数据.py ├── 数据切分 └── data_divide.py └── README.md /模型验证/说明.md: -------------------------------------------------------------------------------- 1 | 这里需要把训练好的VGG模型复制到本文件夹内 -------------------------------------------------------------------------------- /upload/123.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/123.jpg -------------------------------------------------------------------------------- /upload/456.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/456.jpg -------------------------------------------------------------------------------- /upload/789.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/789.jpg -------------------------------------------------------------------------------- /模型验证/图片/月季.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/月季.jpg -------------------------------------------------------------------------------- /模型验证/图片/杜鹃.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/杜鹃.jpg -------------------------------------------------------------------------------- /模型验证/图片/桃花.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/桃花.jpg -------------------------------------------------------------------------------- /模型验证/图片/樱花.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/樱花.jpg -------------------------------------------------------------------------------- /模型验证/图片/玫瑰.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/玫瑰.jpg -------------------------------------------------------------------------------- /模型验证/图片/睡莲.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/睡莲.jpg -------------------------------------------------------------------------------- /模型验证/图片/雏菊.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/雏菊.jpg -------------------------------------------------------------------------------- /模型验证/图片/勋章菊.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/勋章菊.jpg -------------------------------------------------------------------------------- /模型验证/图片/向日葵.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/向日葵.jpg -------------------------------------------------------------------------------- /模型验证/图片/康乃馨.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/康乃馨.jpg -------------------------------------------------------------------------------- /模型验证/图片/旱金莲.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/旱金莲.jpg -------------------------------------------------------------------------------- /模型验证/图片/毛地黄.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/毛地黄.jpg -------------------------------------------------------------------------------- /模型验证/图片/水仙花.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/水仙花.jpg -------------------------------------------------------------------------------- /模型验证/图片/牵牛花.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/牵牛花.jpg -------------------------------------------------------------------------------- /模型验证/图片/紫罗兰.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/紫罗兰.jpg -------------------------------------------------------------------------------- /模型验证/图片/肿柄菊.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/肿柄菊.jpg -------------------------------------------------------------------------------- /模型验证/图片/花菱草.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/花菱草.jpg -------------------------------------------------------------------------------- /模型验证/图片/茉莉花.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/茉莉花.jpg -------------------------------------------------------------------------------- /模型验证/图片/蒲公英.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/蒲公英.jpg -------------------------------------------------------------------------------- /模型验证/图片/郁金香.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/模型验证/图片/郁金香.jpg -------------------------------------------------------------------------------- /模型训练/VGG16/说明.md: -------------------------------------------------------------------------------- 1 | 训练200个epoch的模型已上传百度网盘。 2 | 3 | 链接:https://pan.baidu.com/s/1dHo8kmY9zbgIAy2Rp8NTZg 提取码:5vw1 -------------------------------------------------------------------------------- /data/说明.md: -------------------------------------------------------------------------------- 1 | 由于git上传限制,使用的数据已保存于百度网盘,如有需要,请自行下载。 2 | 3 | 链接:https://pan.baidu.com/s/1Ku7xOEp0yc8n2Kb9RM6GGQ 提取码:8jc5 -------------------------------------------------------------------------------- /flask/img/这个文件夹里面放测试的图片.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/flask/img/这个文件夹里面放测试的图片.txt -------------------------------------------------------------------------------- /百度接口实现花卉识别/说明.md: -------------------------------------------------------------------------------- 1 | 这个主要是**尝尝鲜**,看看大公司花卉识别的效果是怎么样的。 2 | 3 | 相关资料可以查看我的[博客](https://juejin.im/post/5e43f2af6fb9a07c91100bb3) -------------------------------------------------------------------------------- /upload/image-20200323134803440.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/image-20200323134803440.png -------------------------------------------------------------------------------- /upload/image-20200323161939465.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/image-20200323161939465.png -------------------------------------------------------------------------------- /upload/image-20200323162157922.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/image-20200323162157922.png -------------------------------------------------------------------------------- /upload/image-20200323162309627.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/image-20200323162309627.png -------------------------------------------------------------------------------- /upload/image-20200323184552442.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/image-20200323184552442.jpg -------------------------------------------------------------------------------- /upload/image-20200323225509766.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/image-20200323225509766.png -------------------------------------------------------------------------------- /upload/image-20200323230153110.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/image-20200323230153110.png -------------------------------------------------------------------------------- /upload/image-20200323231914637.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/image-20200323231914637.jpg -------------------------------------------------------------------------------- /upload/image-20200323231935143.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/image-20200323231935143.jpg -------------------------------------------------------------------------------- /upload/image-20200423135159846.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SJcun/flower_recognition/HEAD/upload/image-20200423135159846.png -------------------------------------------------------------------------------- /百度接口实现花卉识别/获取token.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import requests 3 | 4 | # client_id 为官网获取的AK, client_secret 为官网获取的SK 5 | host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=VUPQKszA6SZFPzrLX53XSMFL&client_secret=kpNHiA9saiGsw3E4wmCeq6iVHMSKdbxq' 6 | response = requests.get(host) 7 | if response: 8 | print(response.json()['access_token']) #['access_token'] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled class file 2 | *.class 3 | 4 | # Log file 5 | *.log 6 | 7 | # BlueJ files 8 | *.ctxt 9 | 10 | # Mobile Tools for Java (J2ME) 11 | .mtj.tmp/ 12 | 13 | # Package Files # 14 | *.jar 15 | *.war 16 | *.nar 17 | *.ear 18 | *.tar.gz 19 | *.rar 20 | *.zip 21 | *.pkl 22 | 23 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 24 | hs_err_pid* 25 | -------------------------------------------------------------------------------- /百度接口实现花卉识别/花卉识别.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import base64 3 | 4 | ''' 5 | #植物识别 6 | ''' 7 | 8 | request_url = "https://aip.baidubce.com/rest/2.0/image-classify/v1/plant" 9 | 10 | path='【图片地址】' 11 | # 二进制方式打开图片文件 12 | f = open(path, 'rb') 13 | img = base64.b64encode(f.read()) 14 | 15 | params = {"image":img} 16 | access_token = '【请求的access_token】' 17 | request_url = request_url + "?access_token=" + access_token 18 | headers = {'content-type': 'application/x-www-form-urlencoded'} 19 | response = requests.post(request_url, data=params, headers=headers) 20 | if response: 21 | print (response.json()) 22 | 23 | -------------------------------------------------------------------------------- /模型训练/AlexNet/说明.md: -------------------------------------------------------------------------------- 1 | train文件是训练程序,test文件是测试程序 2 | 3 | 模型直接调用的pytorch内保存的AlexNet模型,我们要在这个基础上训练的话(因为它的参数并不适合我们的花卉分类),要更改原模型的全连接层。 4 | 5 | ```python 6 | #选择模型 7 | net = models.alexnet() 8 | net.classifier = nn.Sequential( 9 | nn.Dropout(), 10 | nn.Linear(256 * 6 * 6, 4096), 11 | nn.ReLU(inplace=True), 12 | nn.Dropout(), 13 | nn.Linear(4096, 4096), 14 | nn.ReLU(inplace=True), 15 | nn.Linear(4096, 20), #20代表我们要训练的花卉类别为20 16 | ) 17 | ``` 18 | 19 | 在程序中,我虽然读取了**验证集**的数据,但是我并没有使用它! 20 | 21 | 经过500个epoch的效果是最好的,测试集准确率可以达到83% 22 | 23 | 500个epoch训练模型参数已上传百度网盘,如有需要,请自行下载。 24 | 25 | 链接:https://pan.baidu.com/s/1JEyVr226qOJN8aKgHdpr9A 提取码:kfen -------------------------------------------------------------------------------- /数据扩展程序/2.上下翻转.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from PIL import Image 4 | #import matplotlib.pyplot as plt 5 | 6 | flowe_classes = ['peach_blossom','Jasminum','Matthiola', 7 | 'Rosa','Rhododendron','Dianthus','Cerasus','Narcissus','Pharbitis','Gazania', 8 | 'Eschscholtzia','Tithonia'] 9 | 10 | for name in flowe_classes: 11 | a = os.listdir('work/data/'+name) 12 | count = 1 13 | print(name) 14 | for x in a: 15 | 16 | img=Image.open('work/data/'+name+'/'+x) 17 | dst=img.transpose(Image.FLIP_TOP_BOTTOM)#上下互换 18 | newname='work/data/'+name+'/'+name+'1_'+str(count)+'.jpg' 19 | dst=dst.convert('RGB') 20 | dst.save(newname) 21 | count += 1 22 | -------------------------------------------------------------------------------- /数据扩展程序/1.镜像翻转.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from PIL import Image 4 | #import matplotlib.pyplot as plt 5 | 6 | flowe_classes = ['daisy','dandelion','roses','sunflowers','tulips','Nymphaea', 7 | 'Tropaeolum_majus','Digitalis_purpurea','peach_blossom','Jasminum','Matthiola', 8 | 'Rosa','Rhododendron','Dianthus','Cerasus','Narcissus','Pharbitis','Gazania', 9 | 'Eschscholtzia','Tithonia'] 10 | 11 | 12 | for name in flowe_classes: 13 | a = os.listdir('work/data/'+name) 14 | count = 1 15 | print(name) 16 | for x in a: 17 | #print(x) 18 | 19 | img=Image.open('work/data/'+name+'/'+x) 20 | dst=img.transpose(Image.FLIP_LEFT_RIGHT)#左右互换 21 | newname='work/data/'+name+'/'+name+'_'+str(count)+'.jpg' 22 | dst=dst.convert('RGB') 23 | dst.save(newname) 24 | count += 1 25 | -------------------------------------------------------------------------------- /数据来源/说明.md: -------------------------------------------------------------------------------- 1 | **这个文件夹里放了我的数据来源,你们想扩展数据集的话可以自行扩展** 2 | 3 | 由于git的上传限制,大文件无法被上传,现在已经放到了百度网盘,请根据需要自行下载。 4 | 5 | 链接:https://pan.baidu.com/s/1iqFRI0UYHfVnMFTsGbtx-w 提取码:wu9s 6 | 7 | 我目前使用的是20种花卉数据,日后肯定会增多。 8 | 9 | 数据来源主要取决于3个方面: 10 | 11 | - 5种花卉数据集,每类花卉包含600张到900张不等的图片 12 | - 来源于Oxford 102 Flowers数据集,该数据集包含102类英国花卉数据,每个类别包含 40 到 258 张图像 13 | - 最后一部分来源于百度图片,使用python程序批量采集花卉图像数据 14 | 15 | 这个时候就有人问了,为什么不直接用Oxford 102 Flowers,它可是有102种花卉啊!!还加别的数据集干什么?这个得解释一下,首先呢,它每种花卉的数据量比较小,直接拿来训练的话肯定会过拟合;另一方面就是最后我是要出软件的,要拿手机去真实的拍照的,英国花卉数据集里面的花我都没见过,像什么肿柄菊、毛地黄、旱金莲什么的,所以我后来用的爬虫在百度爬了几种常见的花卉,像桃花、月季、康乃馨之类的。 16 | 17 | 5种花卉的数据集包含雏菊、蒲公英、玫瑰花、向日葵和郁金香。 18 | 19 | 关于Oxford 102 Flowers数据集,它的原始数据把所有图像放在了一个名为jpg的文件夹内,并且还有两个用于分类花卉的文件,我在文件夹放了**数据切分**的python程序,他会把数据集按照相应文件且分为训练集、验证集和测试集。遗憾的是,我并不知道每种花的真实名称,所以采用的统一的名称(c1, c2, ……)。如果想知道花卉名称,建议使用百度接口识别一下,相关博客可以参考[百度接口实现花卉识别](https://juejin.im/post/5e43f2af6fb9a07c91100bb3),过程比较繁琐,这也是我只选了这个数据集其中几种花卉的原因。 20 | 21 | 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /数据来源/爬虫.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import re 3 | import requests 4 | 5 | 6 | def dowmloadPic(html, keyword): 7 | pic_url = re.findall('"objURL":"(.*?)",', html, re.S) 8 | i = 1 9 | print('找到关键词:' + keyword + '的图片,现在开始下载图片...') 10 | for each in pic_url: 11 | print('正在下载第' + str(i) + '张图片,图片地址:' + str(each)) 12 | try: 13 | pic = requests.get(each, timeout=10) 14 | except requests.exceptions.ConnectionError: 15 | print('【错误】当前图片无法下载') 16 | continue 17 | 18 | dir = '1_data/Eschscholtzia/'+str(i) + '.jpg' 19 | fp = open(dir, 'wb') 20 | fp.write(pic.content) 21 | fp.close() 22 | i += 1 23 | 24 | 25 | if __name__ == '__main__': 26 | word = input("Input key word: ") 27 | url = 'http://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word=' + word + '&ct=201326592&v=flip' 28 | result = requests.get(url) 29 | dowmloadPic(result.text, word) -------------------------------------------------------------------------------- /数据扩展程序/3.椒盐噪声.py: -------------------------------------------------------------------------------- 1 | 2 | #coding=utf-8 3 | import os 4 | from PIL import Image 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | flowe_classes = ['peach_blossom','Jasminum','Matthiola', 10 | 'Rosa','Rhododendron','Dianthus','Cerasus','Narcissus','Pharbitis','Gazania', 11 | 'Eschscholtzia','Tithonia'] 12 | 13 | for name in flowe_classes: 14 | a=os.listdir('work/data/'+name) 15 | count=1 16 | print(name) 17 | for x in a: 18 | oldname='work/data/'+name+'/'+x 19 | img=np.array(Image.open(oldname)) 20 | #随机生成5000个椒盐 21 | rows,cols,dims=img.shape 22 | for i in range(5000): 23 | x=np.random.randint(0,rows) 24 | y=np.random.randint(0,cols) 25 | img[x,y,:]=255 26 | img.flags.writeable = True # 将数组改为读写模式 27 | dst=Image.fromarray(np.uint8(img)) 28 | newname='work/data/'+name+'/'+name+'2_'+str(count)+'.jpg' 29 | dst=dst.convert('RGB') 30 | dst.save(newname) 31 | count+=1 -------------------------------------------------------------------------------- /数据切分/data_divide.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 将原有数据集划分为训练集、验证集和测试集 3 | ''' 4 | 5 | 6 | import os 7 | import random 8 | #import shutil 9 | from shutil import copy2 10 | 11 | #比例 12 | scale = [0.6, 0.2, 0.2] 13 | 14 | #类别 15 | classes = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips','Nymphaea','Tropaeolum_majus','Digitalis_purpurea','peach_blossom', 16 | 'Jasminum','Matthiola','Rosa','Rhododendron','Dianthus','Cerasus','Narcissus','Pharbitis','Gazania','Eschscholtzia','Tithonia'] 17 | 18 | 19 | for each in classes: 20 | datadir_normal = "work/data/"+each+"/" #原文件夹 21 | 22 | all_data = os.listdir(datadir_normal)#(图片文件夹) 23 | num_all_data = len(all_data) 24 | print(each+ "类图片数量: " + str(num_all_data) ) 25 | index_list = list(range(num_all_data)) 26 | #print(index_list) 27 | random.shuffle(index_list) 28 | num = 0 29 | 30 | trainDir = "work/new_data/train/"+each#(将训练集放在这个文件夹下) 31 | if not os.path.exists(trainDir): #如果不存在这个文件夹,就创造一个 32 | os.makedirs(trainDir) 33 | 34 | validDir = "work/new_data/val/"+each#(将验证集放在这个文件夹下) 35 | if not os.path.exists(validDir): 36 | os.makedirs(validDir) 37 | 38 | testDir = "work/new_data/test/"+each#(将测试集放在这个文件夹下) 39 | if not os.path.exists(testDir): 40 | os.makedirs(testDir) 41 | 42 | for i in index_list: 43 | fileName = os.path.join(datadir_normal, all_data[i]) 44 | if num < num_all_data*scale[0]: 45 | #print(str(fileName)) 46 | copy2(fileName, trainDir) 47 | elif num>num_all_data*scale[0] and num < num_all_data*(scale[0]+scale[1]): 48 | #print(str(fileName)) 49 | copy2(fileName, validDir) 50 | else: 51 | copy2(fileName, testDir) 52 | num += 1 -------------------------------------------------------------------------------- /模型训练/AlexNet/AlexNet_test.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 测试 3 | ''' 4 | 5 | #这个是读取训练好的模型 6 | #再测试 7 | import torch 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from torch import nn, optim 11 | from torch.utils.data import DataLoader 12 | import torchvision 13 | from torch.autograd import Variable 14 | import torchvision.transforms as transforms 15 | 16 | import model 17 | 18 | #类别 19 | #classes = ('daisy', 'dandelion', 'roses', 'sunflowers', 'tulips') 20 | 21 | # 定义一些超参数 22 | batch_size = 100 #批大小 23 | 24 | #数据预处理 25 | data_transform = transforms.Compose([ 26 | transforms.Resize((224,224), 2), #对图像大小统一 27 | transforms.RandomHorizontalFlip(), #图像翻转 28 | transforms.ToTensor(), 29 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ #图像归一化 30 | 0.229, 0.224, 0.225]) 31 | ]) 32 | 33 | test_dataset = torchvision.datasets.ImageFolder(root='work/data/test/', transform=data_transform) 34 | test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle=True, num_workers=0) 35 | 36 | #花卉类别 37 | data_classes = test_dataset.classes 38 | 39 | #选择CPU还是GPU的操作 40 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 41 | 42 | net = model.AlexNet() 43 | net.load_state_dict(torch.load('alexnet_flower_500.pkl')) 44 | #model.eval() 45 | 46 | correct = 0 47 | total = 0 48 | with torch.no_grad(): 49 | for data in test_loader: 50 | images, labels = data 51 | #images, labels = images.to(device), labels.to(device) 52 | images.to(device), labels.to(device) 53 | images, labels = Variable(images), Variable(labels) 54 | 55 | outputs = net(images) 56 | _, predicted = torch.max(outputs.data, 1) 57 | total += labels.size(0) 58 | correct += (predicted == labels).sum().item() 59 | 60 | print('Accuracy of the network on the test images: %d %%' % ( 61 | 100 * correct / total)) 62 | 63 | -------------------------------------------------------------------------------- /模型训练/VGG16/VGG16_test.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | ''' 4 | 测试 5 | ''' 6 | 7 | #这个是读取训练好的模型 8 | #再测试 9 | import torch 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | from torchvision import models 13 | from torch import nn, optim 14 | from torch.utils.data import DataLoader 15 | import torchvision 16 | from torch.autograd import Variable 17 | import torchvision.transforms as transforms 18 | 19 | #import VGG16_model 20 | 21 | # 定义一些超参数 22 | batch_size = 100 #批大小 23 | 24 | #数据预处理 25 | data_transform = transforms.Compose([ 26 | transforms.Resize((224,224), 2), #对图像大小统一 27 | transforms.RandomHorizontalFlip(), #图像翻转 28 | transforms.ToTensor(), 29 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ #图像归一化 30 | 0.229, 0.224, 0.225]) 31 | ]) 32 | 33 | test_dataset = torchvision.datasets.ImageFolder(root='work/data/test/', transform=data_transform) 34 | test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle=True, num_workers=0) 35 | 36 | data_classes = test_dataset.classes 37 | 38 | #选择CPU还是GPU的操作 39 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 40 | 41 | #选择模型 42 | net = models.vgg16() 43 | net.classifier = nn.Sequential(nn.Linear(25088, 4096), #vgg16 44 | nn.ReLU(), 45 | nn.Dropout(p=0.5), 46 | nn.Linear(4096, 4096), 47 | nn.ReLU(), 48 | nn.Dropout(p=0.5), 49 | nn.Linear(4096, 20)) 50 | #net = VGG16_model.VGG16() 51 | net.load_state_dict(torch.load("VGG16_flower_200.pkl")) 52 | net.to(device) 53 | #net.eval() 54 | 55 | correct = 0 56 | total = 0 57 | with torch.no_grad(): 58 | for data in test_loader: 59 | images, labels = data 60 | images, labels = images.to(device), labels.to(device) 61 | images, labels = Variable(images), Variable(labels) 62 | 63 | outputs = net(images) 64 | _, predicted = torch.max(outputs.data, 1) 65 | total += labels.size(0) 66 | correct += (predicted == labels).sum().item() 67 | 68 | print('Accuracy of the network on the test images: %d %%' % ( 69 | 100 * correct / total)) 70 | 71 | -------------------------------------------------------------------------------- /模型验证/图片测试.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from PIL import Image 4 | import torchvision.transforms as transforms 5 | from torchvision import models #人家的模型 6 | from torch.autograd import Variable 7 | import torch 8 | #from torchvision.datasets import ImageFolder 9 | from torch import nn 10 | #import VGG16_model 11 | 12 | 13 | #数据预处理 14 | data_transform = transforms.Compose([ 15 | transforms.Resize((224,224), 2), #对图像大小统一 16 | transforms.ToTensor(), 17 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ #图像归一化 18 | 0.229, 0.224, 0.225]) 19 | ]) 20 | 21 | #类别 22 | #这个类别是我在训练的过程输出的训练集的类别,是按照训练的顺序排列的 23 | data_classes = ['Cerasus', 'Dianthus', 'Digitalis_purpurea', 'Eschscholtzia', 24 | 'Gazania', 'Jasminum', 'Matthiola', 'Narcissus', 'Nymphaea', 25 | 'Pharbitis', 'Rhododendron', 'Rosa', 'Tithonia', 'Tropaeolum_majus', 26 | 'daisy', 'dandelion', 'peach_blossom', 'roses', 'sunflowers', 'tulips'] 27 | 28 | #读取数据 29 | img = Image.open('./图片/向日葵.jpg') 30 | img=data_transform(img)#这里经过转换后输出的input格式是[C,H,W],网络输入还需要增加一维批量大小B 31 | img = img.unsqueeze(0)#增加一维,输出的img格式为[1,C,H,W] 32 | 33 | #类别 34 | #train_dataset = ImageFolder(root='work/data/train/',transform=data_transform) 35 | #data_classes = train_dataset.classes 36 | 37 | #选择CPU还是GPU的操作 38 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 39 | 40 | #选择模型 41 | 42 | net = models.vgg16() 43 | net.classifier = nn.Sequential(nn.Linear(25088, 4096), #vgg16 44 | nn.ReLU(), 45 | nn.Dropout(p=0.5), 46 | nn.Linear(4096, 4096), 47 | nn.ReLU(), 48 | nn.Dropout(p=0.5), 49 | nn.Linear(4096, 20)) 50 | 51 | 52 | #读取参数 53 | net.load_state_dict(torch.load("VGG16_flower_200.pkl",map_location=torch.device('cpu'))) 54 | net.eval() 55 | net.to(device) 56 | 57 | img = Variable(img) 58 | score = net(img)#将图片输入网络得到输出 59 | probability = nn.functional.softmax(score,dim=1)#计算softmax,即该图片属于各类的概率 60 | max_value,index = torch.max(probability,1)#找到最大概率对应的索引号,该图片即为该索引号对应的类别 61 | print() 62 | print("识别为'{}'的概率为{}".format(data_classes[index.item()],max_value.item())) 63 | 64 | 65 | #pytorch网络输入图片的格式是[B,C,H,W],分别为batch(每批送入网络的图片数量),图片通道数,图片高,图片宽。 66 | #通过PIL的Image读取的图片是一个图片对象,可以进行裁剪翻转等torchvision.transforms变换。 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /flask/app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request 2 | 3 | app = Flask(__name__) 4 | 5 | ''' 6 | 卷积神经网络相关程序 7 | ''' 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from torchvision import models #人家的模型 11 | from torch.autograd import Variable 12 | import torch 13 | from torch import nn 14 | 15 | #数据预处理 16 | data_transform = transforms.Compose([ 17 | transforms.Resize((224,224), 2), #对图像大小统一 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ #图像归一化 20 | 0.229, 0.224, 0.225]) 21 | ]) 22 | 23 | #类别 24 | data_classes = ['Cerasus', 'Dianthus', 'Digitalis_purpurea', 'Eschscholtzia', 25 | 'Gazania', 'Jasminum', 'Matthiola', 'Narcissus', 'Nymphaea', 26 | 'Pharbitis', 'Rhododendron', 'Rosa', 'Tithonia', 'Tropaeolum_majus', 27 | 'daisy', 'dandelion', 'peach_blossom', 'roses', 'sunflowers', 'tulips'] 28 | 29 | #选择CPU还是GPU的操作 30 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 31 | 32 | #选择模型 33 | net = models.vgg16() 34 | net.classifier = nn.Sequential(nn.Linear(25088, 4096), #vgg16 35 | nn.ReLU(), 36 | nn.Dropout(p=0.5), 37 | nn.Linear(4096, 4096), 38 | nn.ReLU(), 39 | nn.Dropout(p=0.5), 40 | nn.Linear(4096, 20)) 41 | 42 | net.load_state_dict(torch.load("VGG16_flower_200.pkl",map_location=torch.device('cpu'))) 43 | net.eval() 44 | net.to(device) 45 | 46 | ''' 47 | flask相关程序 48 | ''' 49 | @app.route('/inference') 50 | def inference(): 51 | 52 | im_url = request.args.get('url') 53 | 54 | #读取数据 55 | img = Image.open(im_url) 56 | img=data_transform(img)#这里经过转换后输出的input格式是[C,H,W],网络输入还需要增加一维批量大小B 57 | img = img.unsqueeze(0)#增加一维,输出的img格式为[1,C,H,W] 58 | 59 | img = Variable(img) 60 | score = net(img)#将图片输入网络得到输出 61 | probability = nn.functional.softmax(score,dim=1)#计算softmax,即该图片属于各类的概率 62 | max_value,index = torch.max(probability,1)#找到最大概率对应的索引号,该图片即为该索引号对应的类别 63 | 64 | return str(data_classes[index.item()]) 65 | #print() 66 | #print("识别为'{}'的概率为{}".format(data_classes[index.item()],max_value.item())) 67 | 68 | 69 | #return 'inference' 70 | 71 | if __name__ == '__main__': 72 | app.run() 73 | 74 | 75 | -------------------------------------------------------------------------------- /模型训练/VGG16/VGG16_train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 改成用现成的模型 3 | ''' 4 | import torch 5 | from torchvision import models 6 | from torch import nn, optim 7 | from torch.utils.data import DataLoader 8 | #import torchvision 9 | from torch.autograd import Variable 10 | import torchvision.transforms as transforms 11 | from torchvision.datasets import ImageFolder 12 | 13 | #import model 14 | 15 | # 定义一些超参数 16 | batch_size = 100 #批大小 17 | learning_rate = 0.001 18 | num_epoches = 15000 19 | 20 | #数据预处理 21 | data_transform = transforms.Compose([ 22 | #transforms.Scale((224,224), 2), #对图像大小统一 23 | transforms.Resize([224, 224], 2), 24 | transforms.RandomHorizontalFlip(), #图像翻转 25 | transforms.ToTensor(), 26 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ #图像归一化 27 | 0.229, 0.224, 0.225]) 28 | ]) 29 | 30 | #获取数据集 31 | #训练集 32 | train_dataset = ImageFolder(root='work/data/train/',transform=data_transform) 33 | train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True) 34 | #验证集 35 | val_dataset = ImageFolder(root='work/data/val/', transform=data_transform) 36 | val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle=True) 37 | 38 | #类别 39 | data_classes = train_dataset.classes 40 | 41 | #选择模型 42 | net = models.vgg16() 43 | net.classifier = nn.Sequential(nn.Linear(25088, 4096), #vgg16 44 | nn.ReLU(), 45 | nn.Dropout(p=0.5), 46 | nn.Linear(4096, 4096), 47 | nn.ReLU(), 48 | nn.Dropout(p=0.5), 49 | nn.Linear(4096, 20)) 50 | 51 | 52 | #损失函数和优化器 53 | criterion = nn.CrossEntropyLoss() 54 | optimizer = optim.SGD(net.parameters(), lr=learning_rate) 55 | 56 | #选择CPU还是GPU的操作 57 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 58 | 59 | net.to(device) 60 | 61 | #开始训练 62 | for epoch in range(num_epoches): 63 | 64 | running_loss = 0. 65 | #batch_size = 100 66 | 67 | for i, data in enumerate(train_loader): 68 | 69 | inputs, labels = data 70 | inputs, labels = inputs.to(device), labels.to(device) 71 | inputs, labels = Variable(inputs), Variable(labels) 72 | 73 | optimizer.zero_grad() 74 | 75 | outputs = net(inputs) 76 | loss = criterion(outputs, labels) 77 | loss.backward() 78 | optimizer.step() 79 | 80 | running_loss += loss.item() 81 | print('[%d, %5d] loss: %.4f' %(epoch + 1, (i+1)*batch_size, loss.item())) 82 | 83 | 84 | print('Finished Training') 85 | 86 | #保存模型 87 | torch.save(net.state_dict(), 'flower_2.pkl') 88 | 89 | 90 | -------------------------------------------------------------------------------- /模型训练/AlexNet/AlexNet_train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 改成用现成的模型 3 | ''' 4 | import torch 5 | from torchvision import models 6 | from torch import nn, optim 7 | from torch.utils.data import DataLoader 8 | #import torchvision 9 | from torch.autograd import Variable 10 | import torchvision.transforms as transforms 11 | from torchvision.datasets import ImageFolder 12 | 13 | #import model 14 | 15 | # 定义一些超参数 16 | batch_size = 100 #批大小 17 | learning_rate = 0.001 #学习率 18 | num_epoches = 15000 #期望训练次数 19 | 20 | #数据预处理 21 | data_transform = transforms.Compose([ 22 | #transforms.Scale((224,224), 2), #对图像大小统一 23 | transforms.Resize([224, 224], 2), 24 | transforms.RandomHorizontalFlip(), #图像翻转 25 | transforms.ToTensor(), 26 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[ #图像归一化 27 | 0.229, 0.224, 0.225]) 28 | ]) 29 | 30 | #获取数据集 31 | #训练集 32 | train_dataset = ImageFolder(root='work/data/train/',transform=data_transform) 33 | train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True) 34 | #验证集 35 | val_dataset = ImageFolder(root='work/data/val/', transform=data_transform) 36 | val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle=True) 37 | 38 | #类别 39 | data_classes = train_dataset.classes 40 | 41 | #选择模型 42 | net = models.alexnet() 43 | net.classifier = nn.Sequential( 44 | nn.Dropout(), 45 | nn.Linear(256 * 6 * 6, 4096), 46 | nn.ReLU(inplace=True), 47 | nn.Dropout(), 48 | nn.Linear(4096, 4096), 49 | nn.ReLU(inplace=True), 50 | nn.Linear(4096, 20), 51 | ) 52 | 53 | 54 | #net.load_state_dict(torch.load('alexnet_flower_450.pkl')) 55 | 56 | #损失函数和优化器 57 | criterion = nn.CrossEntropyLoss() 58 | optimizer = optim.SGD(net.parameters(), lr=learning_rate) 59 | 60 | #选择CPU还是GPU的操作 61 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 62 | 63 | net.to(device) 64 | 65 | #开始训练 66 | for epoch in range(num_epoches): 67 | 68 | running_loss = 0. 69 | 70 | for i, data in enumerate(train_loader): 71 | 72 | inputs, labels = data 73 | inputs, labels = inputs.to(device), labels.to(device) 74 | inputs, labels = Variable(inputs), Variable(labels) 75 | 76 | optimizer.zero_grad() 77 | 78 | outputs = net(inputs) 79 | loss = criterion(outputs, labels) 80 | loss.backward() 81 | optimizer.step() 82 | 83 | running_loss += loss.item() 84 | print('[%d, %5d] loss: %.4f' %(epoch + 1, (i+1)*batch_size, loss.item())) 85 | 86 | ''' 87 | #可以每隔一段时间保存一次模型参数 88 | if((epoch+1)%100 == 0) : 89 | torch.save(net.state_dict(), 'alexnet_flower_'+str(epoch+1)+'.pkl') 90 | if((epoch+1)%20 == 0) : 91 | print('是否继续训练?') 92 | a=input() 93 | if (a=='y'): 94 | continue 95 | else: 96 | break 97 | ''' 98 | 99 | print('Finished Training') 100 | 101 | #保存模型参数 102 | torch.save(net.state_dict(), 'flower.pkl') 103 | -------------------------------------------------------------------------------- /数据来源/分离数据.py: -------------------------------------------------------------------------------- 1 | import scipy.io 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | import shutil 6 | 7 | ########取出 imagelabels 文件的值############ 8 | 9 | imagelabels_path='imagelabels.mat' 10 | labels = scipy.io.loadmat(imagelabels_path) 11 | labels = np.array(labels['labels'][0])-1 12 | 13 | ######## 取出 flower dataset: train test valid 数据id标识 ######## 14 | setid_path='setid.mat' 15 | setid = scipy.io.loadmat(setid_path) 16 | 17 | validation = np.array(setid['valid'][0]) - 1 18 | np.random.shuffle(validation) 19 | 20 | train = np.array(setid['trnid'][0]) - 1 21 | np.random.shuffle(train) 22 | 23 | test=np.array(setid['tstid'][0]) -1 24 | np.random.shuffle(test) 25 | ######## flower data path 数据保存路径 ######## 26 | flower_dir = list() 27 | 28 | ######## flower data dirs 生成保存数据的绝对路径和名称 ######## 29 | for img in os.listdir("jpg"): 30 | 31 | ######## flower data ######## 32 | flower_dir.append(os.path.join("jpg", img)) 33 | 34 | ######## flower data dirs sort 数据的绝对路径和名称排序 从小到大 ######## 35 | flower_dir.sort() 36 | 37 | #print(flower_dir) 38 | 39 | #####生成flower data train的分类数据 ####### 40 | des_folder_train="prepare_pic\\train" 41 | for tid in train: 42 | ######## open image and get label ######## 43 | img=Image.open(flower_dir[tid]) 44 | #print(flower_dir[tid]) 45 | ######## resize img ####### 46 | img = img.resize((256, 256),Image.ANTIALIAS) 47 | lable=labels[tid] 48 | #print(lable) 49 | 50 | path=flower_dir[tid] 51 | #print("path:",path) 52 | 53 | base_path=os.path.basename(path) 54 | #print("base_path:",base_path) 55 | ######类别目录路径 56 | classes="c"+str(lable) 57 | class_path=os.path.join(des_folder_train,classes) 58 | # 没有这个文件夹,就创造这个文件夹 59 | if not os.path.exists(class_path): 60 | os.makedirs(class_path) 61 | 62 | #print("class_path:",class_path) 63 | despath=os.path.join(class_path,base_path) 64 | #print("despath:",despath) 65 | img.save(despath) 66 | 67 | 68 | #####生成flower data validation的分类数据 ####### 69 | des_folder_validation="prepare_pic\\validation" 70 | 71 | for tid in validation: 72 | ######## open image and get label ######## 73 | img=Image.open(flower_dir[tid]) 74 | #print(flower_dir[tid]) 75 | img = img.resize((256, 256),Image.ANTIALIAS) 76 | lable=labels[tid] 77 | #print(lable) 78 | path=flower_dir[tid] 79 | print("path:",path) 80 | base_path=os.path.basename(path) 81 | print("base_path:",base_path) 82 | classes="c"+str(lable) 83 | class_path=os.path.join(des_folder_validation,classes) 84 | # 没有这个文件夹,就创造这个文件夹 85 | if not os.path.exists(class_path): 86 | 87 | os.makedirs(class_path) 88 | print("class_path:",class_path) 89 | despath=os.path.join(class_path,base_path) 90 | print("despath:",despath) 91 | img.save(despath) 92 | 93 | 94 | #####生成flower data test的分类数据 ####### 95 | des_folder_test="prepare_pic\\test" 96 | for tid in test: 97 | ######## open image and get label ######## 98 | img=Image.open(flower_dir[tid]) 99 | #print(flower_dir[tid]) 100 | img = img.resize((256, 256),Image.ANTIALIAS) 101 | lable=labels[tid] 102 | #print(lable) 103 | path=flower_dir[tid] 104 | print("path:",path) 105 | base_path=os.path.basename(path) 106 | print("base_path:",base_path) 107 | classes="c"+str(lable) 108 | class_path=os.path.join(des_folder_test,classes) 109 | # 没有这个文件夹,就创造这个文件夹 110 | if not os.path.exists(class_path): 111 | os.makedirs(class_path) 112 | print("class_path:",class_path) 113 | despath=os.path.join(class_path,base_path) 114 | print("despath:",despath) 115 | img.save(despath) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 基于移动终端的花卉识别系统 2 | 3 | (**文章图片加载有问题的话可以参考**[博客](https://juejin.im/post/5e7d7ac1518825736f6400c1)) 4 | 5 | #### 介绍 6 | 7 |   python开发的分类器,java开发的安卓软件 8 | 9 |   现在我想把这两部分分到两个仓库中,**本仓库是花卉分类器**。 10 | 11 |   花卉分类器使用语言:Python,使用深度学习框架:PyTorch,方法:训练卷积神经网络 12 | 13 |   关于PyTorch的基本用法可以参考博客:https://juejin.im/post/5e759ae3e51d4526f76edcd3 14 | 15 |   更多内容请关注博客:https://juejin.im/user/5ddfe924e51d4532da11c157 16 | 17 | #### 数据集 18 | 19 |   data文件夹内存放了我使用的20种花卉数据集。日后会继续扩增。 20 | 21 |   近期发现了一个不错的图片网站,可以自行扩充数据(https://www.ivsky.com/) 22 | 23 |   数据来源主要取决于3个方面: 24 | 25 | - 5种花卉数据集,每类花卉包含600张到900张不等的图片 26 | - 来源于Oxford 102 Flowers数据集,该数据集包含102类英国花卉数据,每个类别包含 40 到 258 张图像 27 | - 最后一部分来源于百度图片,使用python程序批量采集花卉图像数据 28 | 29 |   有些花卉的name是我自己写的,采用的是花卉的学名,通常是拉丁文。 30 | 31 |   我选用的20种花卉数据如下所示: 32 | 33 | | 编号 | name | 名称 | 数量 | 34 | | :--: | :----------------: | :----: | :--: | 35 | | 1 | daisy | 雏菊 | 633 | 36 | | 2 | dandelion | 蒲公英 | 898 | 37 | | 3 | roses | 玫瑰花 | 641 | 38 | | 4 | sunflowers | 向日葵 | 699 | 39 | | 5 | tulips | 郁金香 | 799 | 40 | | 6 | Nymphaea | 睡莲 | 226 | 41 | | 7 | Tropaeolum_majus | 旱金莲 | 196 | 42 | | 8 | Digitalis_purpurea | 毛地黄 | 190 | 43 | | 9 | peach_blossom | 桃花 | 55 | 44 | | 10 | Jasminum | 茉莉花 | 60 | 45 | | 11 | Matthiola | 紫罗兰 | 54 | 46 | | 12 | Rosa | 月季 | 54 | 47 | | 13 | Rhododendron | 杜鹃花 | 57 | 48 | | 14 | Dianthus | 康乃馨 | 48 | 49 | | 15 | Cerasus | 樱花 | 50 | 50 | | 16 | Narcissus | 水仙花 | 52 | 51 | | 17 | Pharbitis | 牵牛花 | 46 | 52 | | 18 | Gazania | 勋章菊 | 108 | 53 | | 19 | Eschscholtzia | 花菱草 | 82 | 54 | | 20 | Tithonia | 肿柄菊 | 47 | 55 | 56 |   花卉样式: 57 | 58 | ![image-20200323134803440](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/image-20200323134803440.png) 59 | 60 | #### 数据扩展 61 | 62 |   收集到的每种花卉数量不是很多,而像樱花、水仙花等都是每类50张左右,数据量过少,若直接拿去训练模型的话,正确率不会太高,且会发生严重的过拟合。 63 | 64 |   目前使用的数据扩展方法分为三种:镜像翻转、上下翻转和椒盐噪声。 65 | 66 |   **镜像翻转**:将图片左右翻转,生成新的数据 67 | 68 | ![image-20200323161939465](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/image-20200323161939465.png) 69 | 70 |   **上下翻转**:将图片上下翻转,生成新的数据 71 | 72 | ![image-20200323162157922](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/image-20200323162157922.png) 73 | 74 |   **椒盐噪声**:为图片增加噪声,生成新的数据 75 | 76 | ![image-20200323162309627](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/image-20200323162309627.png) 77 | 78 |   扩展后的花卉数量如下所示: 79 | 80 | | 编号 | name | 名称 | 数量 | 增量后数量 | 81 | | :--: | :----------------: | :----: | :--: | :--------: | 82 | | 1 | daisy | 雏菊 | 633 | 2496 | 83 | | 2 | dandelion | 蒲公英 | 898 | 3588 | 84 | | 3 | roses | 玫瑰花 | 641 | 2400 | 85 | | 4 | sunflowers | 向日葵 | 699 | 2796 | 86 | | 5 | tulips | 郁金香 | 799 | 3196 | 87 | | 6 | Nymphaea | 睡莲 | 226 | 1808 | 88 | | 7 | Tropaeolum_majus | 旱金莲 | 196 | 1568 | 89 | | 8 | Digitalis_purpurea | 毛地黄 | 190 | 1360 | 90 | | 9 | peach_blossom | 桃花 | 55 | 440 | 91 | | 10 | Jasminum | 茉莉花 | 60 | 480 | 92 | | 11 | Matthiola | 紫罗兰 | 54 | 432 | 93 | | 12 | Rosa | 月季 | 54 | 432 | 94 | | 13 | Rhododendron | 杜鹃花 | 57 | 456 | 95 | | 14 | Dianthus | 康乃馨 | 48 | 384 | 96 | | 15 | Cerasus | 樱花 | 50 | 400 | 97 | | 16 | Narcissus | 水仙花 | 52 | 416 | 98 | | 17 | Pharbitis | 牵牛花 | 46 | 368 | 99 | | 18 | Gazania | 勋章菊 | 108 | 464 | 100 | | 19 | Eschscholtzia | 花菱草 | 82 | 656 | 101 | | 20 | Tithonia | 肿柄菊 | 47 | 376 | 102 | 103 | #### 数据切分 104 | 105 |   数据集准备好了,要切分为训练集、验证集和测试集。 106 | 107 |   在PyTorch的torchvision包内有一个关于计算机视觉的数据读取类`ImageFolder`,它的调用方式是torchvision.datasets.ImageFolder,主要功能是读取图片数据,且要求图片是下图这种存放方式。 108 | 109 | ![123](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/123.jpg) 110 | 111 |   然后这样来调用类: 112 | 113 | ```python 114 | train_dataset = ImageFolder(root='./data/train/',transform=data_transform) 115 | ``` 116 | 117 |   root表示根目录,transform表示数据预处理方式。 118 | 119 |   这种方式将train目录下的cat和dog文件夹内的所有图片作为训练集,而文件夹名cat和dog作为标签数据进行训练。 120 | 121 |   因此我们就要像ImageFolder要求的那样切分数据集。 122 | 123 | ![image-20200323184552442](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/image-20200323184552442.jpg) 124 | 125 |   我切分的比例是3:1:1。实际上,如果不想切分出验证集的话,可以将验证集的代码部分注掉,直接使用训练集和测试集也是可以的。 126 | 127 | ```python 128 | #比例 129 | scale = [0.6, 0.2, 0.2] 130 | ``` 131 | 132 |   至此,数据部分准备完成了。 133 | 134 | #### 模型训练 135 | 136 |   目前采用的是AlexNet和VGG16两种网络,其实两种网络比较相似,不同的是VGG16较于AlexNet更“深” 137 | 138 |   AlexNet网络结构如下: 139 | 140 | ![456](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/456.jpg) 141 | 142 |   VGG16网络结构如下: 143 | 144 | ![789](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/789.jpg) 145 | 146 |   二者相比较,VGG16准确率更高一些,可见更深的网络对于提高准确率有一定的帮助。 147 | 148 |   AlexNet训练过程中的准确率变化如下: 149 | 150 | ![image-20200323225509766](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/image-20200323225509766.png) 151 | 152 |   VGG16经历200个epoch训练的准确率变化如下: 153 | 154 | ![image-20200323230153110](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/image-20200323230153110.png) 155 | 156 |   AlexNet经历了500个epoch训练后最终能达到83%的准确率 157 | 158 |   VGG16经历了200个epoch训练后最终能达到90%的正确率 159 | 160 | #### 模型验证 161 | 162 |   除了验证测试集以外,还可以用图片去验证模型的训练效果。 163 | 164 |   选用的是验证效果比较好的VGG16网络,读取的参数是200个epoch训练后的参数 165 | 166 | ![image-20200323231914637](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/image-20200323231914637.jpg) 167 | 168 | ![image-20200323231935143](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/image-20200323231935143.jpg) 169 | 170 |   可以看到,测试的效果还是非常好的,模型可以非常准确的判断花卉的种类。 171 | 172 | > 一个补充 173 | 174 | 如果你恰好有个云服务器,又想做一个web服务器的话,可以尝试flask框架(当然在本地也可以使用flask,不过这个就没有多大意义了) 175 | 176 | 按照`flask`文件夹中的程序,在服务器上运行之后,然后打开一个新网页,输入`IP:端口?图片地址`就可以做识别了。 177 | 178 | ![image-20200423135159846](https://gitee.com/Sjcun/flower_recognition/raw/master/upload/image-20200423135159846.png) 179 | 180 | 其中`sjcup.cn`是我的一个域名,这里可以替换为自己服务器的`公网IP` 181 | 182 | 另外还有一个坑就是图片名称不可为中文名称,否则会检测不到 183 | 184 | **公网IP无法访问**的问题可以根据[博客](https://juejin.im/post/5ea06b2151882573947254d4)做一些修改 185 | 186 | #### 下一步计划 187 | 188 | - 扩增数据集,可以识别更多类别的花卉 189 | - 采用新的网络训练,如Inception V3 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | --------------------------------------------------------------------------------