├── README.md ├── prediction.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | 训练集、测试集图片应按照【正样本、负样本】顺序命名,且从0开始命名( 比如训练集中 0.jpg-4.jpg为正样本 5.jpg-9.jpg为负样本) 2 | -------------------------------------------------------------------------------- /prediction.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib.image as mpimg 3 | import numpy as np 4 | from sklearn.svm import SVR 5 | from skimage import feature as skft 6 | import torch 7 | from train import photo, loadPicture, texture_detect 8 | 9 | 10 | # 预测 11 | def prediction(n,path,path1): 12 | """ 13 | n:验证图片数量 14 | path:验证数据集路径 15 | path1:处理后的数据集存放路径 16 | """ 17 | train_hist = torch.load('train.mat') 18 | train_label = torch.load('label.mat') 19 | photo(n, path, path1) 20 | mdata, mlabel = loadPicture(n, path1) 21 | hist = texture_detect(n, 1, mdata) 22 | svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1) 23 | model = torch.load('1.pkl') 24 | clf = model.fit(train_hist, train_label) 25 | 26 | predicted = model.predict(hist) 27 | print(predicted) 28 | 29 | 30 | prediction() 31 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib.image as mpimg 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from sklearn.multiclass import OneVsRestClassifier 6 | from sklearn.svm import SVR 7 | from skimage import feature as skft 8 | from sklearn.ensemble import GradientBoostingClassifier 9 | from sklearn.grid_search import GridSearchCV 10 | from sklearn import metrics 11 | import pandas as pd 12 | import torch 13 | 14 | 15 | # ROI提取正样本 16 | def positive_roi(par, path, task): 17 | """ 18 | par:bbox文件 19 | path:需要存储正样本的路径 20 | task:原图 21 | """ 22 | x1 = [] # 左 23 | x2 = [] 24 | y1 = [] # 上 25 | y2 = [] 26 | img = cv2.imread(task) 27 | 28 | with open(par) as txt: # 打开文件 29 | for line in txt: 30 | line = line.strip() # 去除多余空格 31 | line1 = line.split() 32 | line2 = list(map(lambda x: float(x), line1)) # 将list中的字符串转化为数字 33 | y1.append(round(line2[2])) 34 | y2.append(round(line2[4])) 35 | x1.append(round(line2[1])) 36 | x2.append(round(line2[3])) 37 | 38 | for a, b, c, d, i in zip(y1, y2, x1, x2, range(0, 24)): 39 | roi = img[a:b, c:d] # 上下,左右 40 | # roi[150:500, 0:1000] = 0 #添加黑边 41 | cv2.imwrite(path + '%d.jpg' % i, roi) 42 | 43 | 44 | # 处理数据集 45 | # 训练集和测试集要按照正负样本的顺序命名(从0.jpg开始) 46 | def photo(n, path0, path1): 47 | """ 48 | n:训练集、测试集或验证集所含图片总数 49 | path0:训练集、测试集或验证集所在路径 50 | path1:处理后训练集、测试集或验证集存放路径 51 | """ 52 | for i in range(n): 53 | img = cv2.imread(path0 + '%d.jpg' % i) 54 | img1 = cv2.resize(img, (512, 512)) # resize成(512,512)大小 55 | # print(img1) 56 | img2 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY) # 转为灰色图像(此方法只支持二维图像) 57 | # print(img2) 58 | # print(img2.shape) 59 | cv2.imwrite(path1 + '%d.tiff' % i, img2) # 以tiff格式存储 60 | 61 | 62 | # 导入数据集 63 | def loadPicture(n, path): 64 | """ 65 | n:训练集、测试集或验证集所含图片总数 66 | path:处理后训练集、测试集或验证集存放路径 67 | """ 68 | index = 0 69 | data = np.zeros((n, 512, 512)) 70 | label = np.zeros((n)) 71 | for i in np.arange(n): 72 | image = mpimg.imread(path + str(i) + '.tiff') 73 | vdata = np.zeros((512, 512)) 74 | vdata[0:image.shape[0], 0:image.shape[1]] = image 75 | if i < (n / 2): 76 | data[index, :, :] = vdata 77 | label[index] = 0 78 | index += 1 79 | else: 80 | data[index, :, :] = vdata 81 | label[index] = 1 82 | index += 1 83 | return data, label 84 | 85 | 86 | # 使用LBP方法提取图像的纹理特征 87 | def texture_detect(n, radius, data): 88 | """ 89 | n:训练集、测试集或验证集所含图片总数 90 | radius:半径 91 | data:需要提取的数据 92 | """ 93 | n_point = radius * 8 94 | hist = np.zeros((n, 256)) 95 | for i in np.arange(n): 96 | # 使用LBP方法提取图像的纹理特征 97 | lbp = skft.local_binary_pattern(data[i], n_point, radius, 'default') 98 | # 统计图像的直方图 99 | max_bins = int(lbp.max() + 1) 100 | # hist size:256 101 | hist[i], _ = np.histogram(lbp, normed=True, bins=max_bins, range=(0, max_bins)) 102 | return hist 103 | 104 | 105 | if __name__ == '__main__': 106 | photo(50, 'train1/', 'train/') 107 | photo(10, 'qwe/', 'test/') 108 | train_data, train_label = loadPicture(50, 'train/') 109 | test_data, test_label = loadPicture(10, 'test/') 110 | train_hist = texture_detect(50, 1, train_data) 111 | test_hist = texture_detect(10, 1, test_data) 112 | svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1) 113 | 114 | model = OneVsRestClassifier(svr_rbf, -1) 115 | clf = model.fit(train_hist, train_label) 116 | 117 | # 测试模型 118 | clf.score(test_hist, test_label) 119 | 120 | # 保存模型及参数 121 | torch.save(model, '1.pkl') 122 | torch.save(train_hist, 'train.mat') 123 | torch.save(train_label, 'label.mat') 124 | --------------------------------------------------------------------------------