├── README.md
├── TestModel.py
├── LoadImage.py
├── Judge.py
├── PacketToImage.py
├── TenClassificationCNN.py
└── LICENSE
/README.md:
--------------------------------------------------------------------------------
# TrafficClassification
Using a 2D CNN to classify network traffic.
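
This repository converts packets from `.pcap` captures into 32×32 grayscale images and trains a small two-layer CNN (TensorFlow 1.x) on them.

## Pipeline (sketch)

1. `PacketToImage.py` turns each packet (first 1024 bytes, zero-padded) into a 32×32 grayscale image and splits them 9:1 into train/test sets.
2. `TenClassificationCNN.py` trains the CNN on the generated images; per-class accuracy is logged to TensorBoard and an Excel sheet.
3. `TestModel.py` / `Judge.py` restore the saved checkpoint to predict a single image or to plot per-class ROC curves.

The `E:\...` paths are hard-coded in the scripts; adjust them to your environment. A minimal sketch of the first step, using the same values as the `__main__` block of `PacketToImage.py`:

```python
from PacketToImage import ReadFile

# Convert up to 1000 packets per class into 32x32 grayscale train/test images
ReadFile("E:\\CompletePCAPs",
         ['voipbuster', 'facebook', 'email', 'netflix', 'hangouts',
          'icq', 'youtube', 'skype', 'vimeo', 'spotify'],
         1000,
         "E:\\Flow classification\\TransformImage")
```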
--------------------------------------------------------------------------------
/TestModel.py:
--------------------------------------------------------------------------------
# Test the trained model / predict the class of a single image
import tensorflow as tf
import cv2 as cv
import numpy as np

# Directory holding the saved model
modelpath = "E:\\Flow classification\\TransformImage\\model(10)"
# Directory holding the test set
testimagepath = "E:\\Flow classification\\TransformImage\\test"
# Name of the meta-graph file
modelname = 'model.ckpt-100.meta'
# Image used for the test
testimagename = 'facebook_test279.jpg'
# Number of classes
kindclass = 10
# Experiment number
exp_count = 44
# Class-name -> label mapping written during training
resultpath = 'E:\\Flow classification\\TransformImage\\Result\\result{}.txt'.format(exp_count)

with tf.Session().as_default() as sess:
    labeldicts = {}
    # Read the class-name -> label-index mapping
    with open(resultpath, 'r') as f:
        for line in f.readlines():
            line = line.strip('\n')
            labeltemp = line.split(' ')
            classname = labeltemp[0]
            classflagnum = int(labeltemp[-1])
            labeldicts.update({classname: classflagnum})
    # Restore the trained model
    saver = tf.train.import_meta_graph(modelpath + "\\" + modelname)
    saver.restore(sess, tf.train.latest_checkpoint(modelpath))
    # Read the test image as grayscale and reshape it to [1, 32, 32, 1]
    image = cv.imread(testimagepath + '\\' + testimagename, flags=0)
    image = image[np.newaxis, :, :, np.newaxis]
    pred = tf.get_collection('pred_network')[0]
    prednum = tf.argmax(pred, 1)
    graph = tf.get_default_graph()
    x1 = graph.get_operation_by_name('input/x').outputs[0]
    keep_prob1 = graph.get_operation_by_name('input/keep_prob').outputs[0]
    result1 = sess.run(prednum, feed_dict={x1: image, keep_prob1: 1.0})
    for onelabel in labeldicts.items():
        if onelabel[1] == result1[0]:
            print("Predicted class:", onelabel[0])

--------------------------------------------------------------------------------
/LoadImage.py:
--------------------------------------------------------------------------------
# Load images and labels for training or testing
import os
import random
import cv2 as cv
import numpy as np
import re
import copy

classkind = 10
labeldict = {}


def train_image(path, classname=None, labeldicts=None):
    global labeldict
    count = 0
    templabels = [0 for i in range(classkind)]
    images = []
    labels = []
    if labeldicts is not None:
        labeldict = labeldicts
    imagenamelist = []
    if classname is None:
        # All jpg files; lower() makes the extension check case-insensitive
        imagenamelist = [path + "\\" + name for name in os.listdir(path)
                         if name.lower().endswith('jpg')]
    else:
        # Only files of the requested class
        imagenamelist = [path + "\\" + name for name in os.listdir(path)
                         if name.lower().endswith('jpg') and name.lower().startswith(classname)]
    random.shuffle(imagenamelist)  # shuffle the file order
    for i in imagenamelist:
        image = cv.imread(i, flags=0)    # read the image as grayscale
        image = image[:, :, np.newaxis]  # add a channel dimension
        images.append(image)
        pattern = re.compile('^[a-z]+')
        vpnpattern = re.compile('(vpn_[a-z]+)')
        name = i.split('\\')[-1]
        # The class name is the file-name prefix, e.g. "facebook_train1.jpg" or "vpn_skype_test3.jpg"
        if name.startswith('vpn'):
            name = vpnpattern.findall(name.lower())[0]
        else:
            name = pattern.findall(name.lower())[0]
        if name in labeldict:
            label = labeldict[name]
            labels.append(label)
            count += 1
        else:
            # First occurrence of this class: give it the next free one-hot position
            labellength = len(labeldict)
            templabel = copy.deepcopy(templabels)
            templabel[labellength] = 1
            labeldict.update({name: templabel})
            label = templabel
            labels.append(label)
            count += 1
    images = np.array(images)
    labels = np.array(labels)
    if classname is not None:
        return images, labels
    else:
        return images, labels, labeldict
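
# A minimal usage sketch (not part of the original module; the path below is the one
# hard-coded in TenClassificationCNN.py and is only an assumption here). Running the module
# directly shows the shapes that train_image produces.
if __name__ == '__main__':
    imgs, labs, mapping = train_image("E:\\Flow classification\\TransformImage\\train")
    print(imgs.shape)   # (N, 32, 32, 1) grayscale images
    print(labs.shape)   # (N, 10) one-hot labels
    print(mapping)      # class name -> one-hot label list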
--------------------------------------------------------------------------------
/Judge.py:
--------------------------------------------------------------------------------
# Compute evaluation metrics (ROC curves) for the trained model
import tensorflow as tf
from LoadImage import train_image
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn.metrics import roc_curve, auc
from scipy import interp  # removed in newer SciPy; np.interp is the replacement
import copy


# Directory holding the saved model
modelpath = "E:\\Flow classification\\TransformImage\\model(10)"
# Directory holding the test set
testimagepath = "E:\\Flow classification\\TransformImage\\test\\"
# Name of the meta-graph file
modelname = 'model.ckpt-100.meta'
# Experiment number
exp_count = 44
# Number of classes
kindclass = 10
# Class-name -> label mapping written during training
resultpath = 'E:\\Flow classification\\TransformImage\\Result\\result{}.txt'.format(exp_count)


# A ROC curve shows classifier performance over all classification thresholds
def ROC(classdict, testlabel, pred):
    fpr = dict()
    tpr = dict()
    thresholds = dict()
    roc_auc = dict()
    for i in range(len(classdict)):
        # roc_curve computes the fpr/tpr pairs for every threshold of class i (one-vs-rest)
        fpr[i], tpr[i], thresholds[i] = roc_curve(testlabel[:, i].ravel(), pred[:, i].ravel())
        roc_auc[i] = auc(fpr[i], tpr[i])
    # Macro average: interpolate all per-class curves onto a common fpr grid
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(len(classdict))]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(len(classdict)):
        mean_tpr += interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= len(classdict)
    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    # Line width for the plots
    lw = 2
    # Plot the macro-average curve
    plt.figure(figsize=(8, 6))
    plt.plot(fpr["macro"], tpr["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'
                   ''.format(roc_auc["macro"]),
             color='navy', linestyle=':', linewidth=4)
    # Curve colours
    if kindclass == 20:
        colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'pink', 'crimson', 'orchid', 'purple',
                        'indigo', 'black', 'slategray', 'blue', 'darkslateblue', 'yellow', 'red', 'cyan',
                        'orange', 'tan', 'brown', 'olive', 'gold'])
    else:
        colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'pink', 'crimson', 'orchid', 'purple',
                        'indigo', 'black', 'slategray'])
    # Plot one curve per class
    for i, color in zip(range(len(classdict)), colors):
        name = ''
        for kv in classdict.items():
            if kv[1] == i:
                name = kv[0]
        plt.plot(fpr[i], tpr[i], color=color, lw=lw,
                 label='ROC curve of class {0} (area = {1:0.2f})'
                       ''.format(name, roc_auc[i]))
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Multi-class receiver operating characteristic')
    font = {'size': 7}
    plt.legend(loc="lower right", prop=font)
    plt.show()


with tf.Session().as_default() as sess:
    classifidicts = {}
    classdict = {}
    labelexa = [0 for i in range(kindclass)]
    with open(resultpath, 'r') as f:
        for line in f.readlines():
            line = line.strip('\n')
            labeltemp = line.split(' ')
            classname = labeltemp[0]
            classflagnum = int(labeltemp[-1])
            classflaglist = copy.deepcopy(labelexa)
            classflaglist[classflagnum] = 1
            classifidicts.update({classname: classflaglist})
            classdict.update({classname: classflagnum})
    # Restore the model
    saver = tf.train.import_meta_graph(modelpath + '\\' + modelname)
    saver.restore(sess, tf.train.latest_checkpoint(modelpath))
    # Fetch the prediction op and the input placeholders
    pred = tf.get_collection('pred_network')[0]
    graph = tf.get_default_graph()
    x1 = graph.get_operation_by_name('input/x').outputs[0]
    keep_prob1 = graph.get_operation_by_name('input/keep_prob').outputs[0]
    test_x, test_y, _ = train_image(testimagepath, labeldicts=classifidicts)
    # Run the network on the whole test set
    predict = sess.run(pred, feed_dict={x1: test_x, keep_prob1: 1.0})
    # Plot ROC curves from the predictions
    ROC(classdict, test_y, predict)
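
# An optional extra metric (a sketch, not in the original script): with the same one-hot test
# labels and softmax outputs, scikit-learn can also report per-class precision/recall/F1.
def _classification_report(testlabel, pred, classdict):
    from sklearn.metrics import classification_report
    names = [name for name, idx in sorted(classdict.items(), key=lambda kv: kv[1])]
    return classification_report(testlabel.argmax(axis=1), pred.argmax(axis=1), target_names=names)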
--------------------------------------------------------------------------------
/PacketToImage.py:
--------------------------------------------------------------------------------
# Convert captured packets (.pcap files) into images
import struct            # convert between byte strings and numbers
import scipy.misc as sm  # save arrays as images (needs scipy < 1.2 with Pillow; imageio.imwrite is the modern replacement)
import numpy as np
import os                # path handling
import shutil            # file copying


def TransformToImage(pixelList, classname, step, imagepath):
    '''
    :param pixelList: pixel list of the image
    :param classname: name of the class
    :param step: current step (packet index, starting at 1)
    :param imagepath: directory in which the images are saved
    :return:
    '''
    traincount = (step // 10) * 9 + (step % 10)  # index used to split training and test sets
    if step % 10 != 0:  # training sample
        newPixels = np.reshape(pixelList, (32, 32))  # reshape gives the array a new shape without changing its data
        if os.path.exists(imagepath + "\\train"):  # save as an image
            sm.imsave(imagepath + "\\train\\" + classname + "_train" + str(traincount) + '.jpg', newPixels)
        else:
            os.makedirs(imagepath + "\\train")  # create the directory recursively
            sm.imsave(imagepath + "\\train\\" + classname + "_train" + str(traincount) + '.jpg', newPixels)
    else:  # every tenth packet goes to the test set
        newPixels = np.reshape(pixelList, (32, 32))
        if os.path.exists(imagepath + "\\test"):
            sm.imsave(imagepath + "\\test\\" + classname + "_test" + str(traincount) + '.jpg', newPixels)
        else:
            os.makedirs(imagepath + "\\test")
            sm.imsave(imagepath + "\\test\\" + classname + "_test" + str(traincount) + '.jpg', newPixels)


def FillImage(step, classname, count, imagepath):
    '''
    :param step: step reached so far
    :param classname: name of the class
    :param count: target number of steps
    :param imagepath: directory in which the images are saved
    :return: number of copies made
    '''
    temp = 0
    new_train_num = (step // 10) * 9 + (step % 10)
    mult = (((count // 10) * 9) // new_train_num) - 1
    quot = ((count // 10) * 9) % new_train_num
    for i in range(1, mult + 1):
        for j in range(1, new_train_num + 1):
            current_step = new_train_num * i + j
            # Copy an existing training image to pad the class up to the target size
            shutil.copyfile(imagepath + "\\train\\" + classname + "_train" + str(j) + ".jpg",
                            imagepath + "\\train\\" + classname + "_train" + str(current_step) + ".jpg")
            temp += 1
    for z in range(1, quot + 1):
        current_step = (mult + 1) * new_train_num + z
        shutil.copyfile(imagepath + "\\train\\" + classname + "_train" + str(z) + ".jpg",
                        imagepath + "\\train\\" + classname + "_train" + str(current_step) + ".jpg")
        temp += 1
    return temp


def ReadFile(path, classification, count, imagepath):
    '''
    :param path: directory holding the pcap files
    :param classification: list of class names
    :param count: maximum number of packets to convert per class
    :param imagepath: directory in which the images are saved
    :return: nothing; at most count images per class are written to imagepath
    '''
    num = 0
    # Add the VPN variant of each class
    vpn_class = map(lambda x: "vpn_" + x, classification)
    classification = classification + list(vpn_class)

    # To use only the VPN classes instead:
    # classification = list(vpn_class)

    for classname in classification:
        step = 0
        temp = 0
        total = 0
        for onepcap in os.listdir(path):  # walk the directory
            if onepcap.startswith(classname) and onepcap.endswith(".pcap"):
                with open(path + "\\" + onepcap, 'rb') as f:  # read the file
                    data = f.read()
                pcap_header = {}
                # The pcap global header is a fixed 24 bytes
                pcap_header['magic_number'] = data[0:4]
                pcap_header['version_major'] = data[4:6]
                pcap_header['version_minor'] = data[6:8]
                pcap_header['thiszone'] = data[8:12]
                pcap_header['sigfigs'] = data[12:16]
                pcap_header['snaplen'] = data[16:20]
                pcap_header['linktype'] = data[20:24]
                # print(pcap_header)
                pcap_packet_header = {}
                # Packet records
                i = 24
                while i < len(data):
                    pcap_packet_header['GMTtime'] = data[i:i + 4]
                    pcap_packet_header['MicroTime'] = data[i + 4:i + 8]
                    pcap_packet_header['caplen'] = data[i + 8:i + 12]
                    pcap_packet_header['len'] = data[i + 12:i + 16]
                    # Length of the captured data that follows this 16-byte record header
                    # ('I' is an unsigned int; unpack returns a tuple)
                    packet_len = struct.unpack('I', pcap_packet_header['caplen'])[0]
                    if packet_len <= 1024:
                        pixels = np.zeros(1024)  # array of 1024 pixel values
                        packet_pixel = [pixel for pixel in data[i + 16:i + 16 + packet_len]]
                        pixels[0:len(packet_pixel)] = packet_pixel
                    else:
                        pixels = np.zeros(1024)
                        packet_pixel = [pixel for pixel in data[i + 16:i + 16 + 1024]]
                        pixels[0:len(packet_pixel)] = packet_pixel
                    step += 1
                    num += 1
                    TransformToImage(pixels, classname, step, imagepath)
                    i = i + packet_len + 16
                    if step >= count:
                        break
                print(onepcap, step)
            if step >= count:
                total = int(count * 0.9)
                break
        if step < count:
            temp = FillImage(step, classname, count, imagepath)
            total = (step // 10) * 9 + step % 10 + temp
            num += temp
        print('Saved ' + str(total) + ' images for class ' + classname + ', of which ' + str(temp) + ' are padded copies')
    print('Saved ' + str(num) + ' training/test images in total')


if __name__ == '__main__':
    path = "E:\\CompletePCAPs"
    imagepath = "E:\\Flow classification\\TransformImage"
    classification = ['voipbuster', 'facebook', 'email', 'netflix', 'hangouts',
                      'icq', 'youtube', 'skype', 'vimeo', 'spotify']
    ReadFile(path, classification, 1000, imagepath)
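
# A self-contained sketch of the record layout ReadFile relies on (illustration only, not part
# of the original script): a pcap file is a 24-byte global header followed by records, each
# with a 16-byte record header whose third field (caplen) gives the length of the captured
# bytes that follow it.
def _demo_parse_one_record():
    global_header = struct.pack('<IHHiIII', 0xa1b2c3d4, 2, 4, 0, 0, 65535, 1)
    payload = bytes(range(6))                                # a fake 6-byte "packet"
    record_header = struct.pack('<IIII', 0, 0, len(payload), len(payload))
    data = global_header + record_header + payload
    i = 24                                                   # skip the global header
    caplen = struct.unpack('<I', data[i + 8:i + 12])[0]      # captured length
    assert data[i + 16:i + 16 + caplen] == payload
    return caplen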
--------------------------------------------------------------------------------
/TenClassificationCNN.py:
--------------------------------------------------------------------------------
# Train the 10-class CNN
import tensorflow as tf
import os
import time
from openpyxl import Workbook       # Excel handling
from openpyxl import load_workbook
from LoadImage import train_image

# Training-set directory
datapath = "E:\\Flow classification\\TransformImage\\train"
# Test-set directory
testpath = "E:\\Flow classification\\TransformImage\\test"
# Directory for the class/label mapping
resultpath = "E:\\Flow classification\\TransformImage\\Result"
# Experiment number
exp_count = 44
# Directory for the Excel result sheets
result_excel = "E:\\Flow classification\\TransformImage\\ExpResult"
# Model name
modelname = "model(10)"
# Batch size
batch_size = 15
# Number of batches per epoch
n_batch = len(os.listdir(datapath)) // batch_size
# Class-name -> label mapping
labeldict = {}


# TensorBoard summaries for a variable
def variable_summaries(var):
    with tf.name_scope("summaries"):  # name scopes avoid naming collisions
        mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', mean)                              # mean
        with tf.name_scope("stddev"):
            stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))  # square root of the variance
        tf.summary.scalar('stddev', stddev)                          # standard deviation
        tf.summary.scalar('max', tf.reduce_max(var))                 # maximum
        tf.summary.scalar('min', tf.reduce_min(var))                 # minimum
        tf.summary.histogram('histogram', var)                       # histogram


# Weight initialisation
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)  # truncated normal distribution
    return tf.Variable(initial)


# Bias initialisation
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)  # constant bias
    return tf.Variable(initial)


# Convolution layers
def conv2d_same(x, W):
    # x: input tensor of shape [batch, in_height, in_width, in_channels]
    # W: filter/kernel tensor of shape [filter_height, filter_width, in_channels, out_channels]
    # strides[0] = strides[3] = 1; strides[1]/strides[2] are the steps in the x/y directions
    # padding is "SAME" (keep the border) or "VALID"
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def conv2d_valid(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="VALID")


# Pooling layer
def max_pool_2x2(x):
    # Arguments: input (feature map), window size ksize=[1, x, y, 1], strides, padding
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [None, 32, 32, 1], name='x')
    y = tf.placeholder(tf.float32, [None, 10], name='y')
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')

with tf.name_scope('learning_rate'):
    LR = tf.Variable(1e-4, dtype=tf.float32)
    tf.summary.scalar("lr", LR)

# Reshape x into a 4-D tensor [batch, in_height, in_width, in_channels]
with tf.name_scope('image_reshape'):
    x_image = tf.reshape(x, [-1, 32, 32, 1])
    tf.summary.image("input", x_image)

with tf.name_scope('Conv_layer1'):
    # Weights and biases of the first convolutional layer
    W_conv1 = weight_variable([3, 3, 1, 4])  # 3x3 kernels, 4 filters over 1 input channel
    b_conv1 = bias_variable([4])             # one bias per filter
    # Convolve x_image with the weights, add the bias, then apply the ReLU activation
    with tf.name_scope('w_plus_b1'):
        res_conv1 = conv2d_valid(x_image, W_conv1) + b_conv1
        tf.summary.histogram('res_conv1', res_conv1)
    h_conv1 = tf.nn.relu(res_conv1, name='conv1_relu')
    h_pool1 = max_pool_2x2(h_conv1)  # max pooling

with tf.name_scope('Conv_layer2'):
    # Weights and biases of the second convolutional layer
    W_conv2 = weight_variable([3, 3, 4, 8])  # 3x3 kernels, 8 filters over 4 input channels
    b_conv2 = bias_variable([8])
    # Convolve h_pool1 with the weights, add the bias, then apply the ReLU activation
    with tf.name_scope('w_plus_b2'):
        res_conv2 = conv2d_same(h_pool1, W_conv2) + b_conv2
        tf.summary.histogram("res_conv2", res_conv2)
    h_conv2 = tf.nn.relu(res_conv2)
    h_pool2 = max_pool_2x2(h_conv2)  # max pooling

# A 32x32 image becomes 30x30 after the first (VALID) convolution and 15x15 after the first pooling;
# the second (SAME) convolution keeps 15x15 and the second pooling reduces it to 8x8.
# The result is therefore 8 feature maps of size 8x8.
# First fully connected layer
with tf.name_scope('FC1'):
    W_fc1 = weight_variable([8 * 8 * 8, 2048])  # 8*8*8 = 512 inputs, 2048 neurons
    b_fc1 = bias_variable([2048])               # 2048 biases

    # Flatten the output of the second pooling layer
    h_pool2_flat = tf.reshape(h_pool2, [-1, 8 * 8 * 8])
    # Output of the first fully connected layer
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    # keep_prob is the probability that a neuron's output is kept
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

with tf.name_scope('FC2'):
    # Second fully connected layer: 10-way classification
    W_fc2 = weight_variable([2048, 10])
    b_fc2 = bias_variable([10])
    with tf.name_scope('Pr'):
        # Compute the logits and the softmax output
        logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
        prediction = tf.nn.softmax(logits)

# Saver for the model; expose the prediction op so it can be restored later
saver = tf.train.Saver()
tf.add_to_collection('pred_network', prediction)
# Cross-entropy loss (softmax_cross_entropy_with_logits expects the raw logits, not the softmax output)
with tf.name_scope('loss'):
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
    tf.summary.scalar("loss", cross_entropy)
# Optimise with AdamOptimizer
with tf.name_scope('train'):
    train_step = tf.train.AdamOptimizer(LR).minimize(cross_entropy)
# Per-sample correctness as a boolean list
with tf.name_scope('accuracy'):
    with tf.name_scope("correct_prediction"):
        correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
    # Accuracy
    with tf.name_scope('accuracy'):
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

# Merge all summaries
merged = tf.summary.merge_all()

with tf.Session() as sess:
    # Initialise the variables
    sess.run(tf.global_variables_initializer())
    train_writer = tf.summary.FileWriter('logs/train/', sess.graph)
    start_time = time.time()
    if os.path.exists(result_excel):
        print("Result directory already exists")
    else:
        os.makedirs(result_excel)
        print("Result directory created")
    start_make_img_time = time.time()
    # Load the training set and its labels
    trainimages, trainlabels, labeldict = train_image(datapath)
    end_make_img_time = time.time()
    print("Image loading took {}".format(end_make_img_time - start_make_img_time))
    for epoch in range(1, 101):
        end_make_img_time = time.time()
        for batch in range(n_batch):
            trainimage = trainimages[batch * batch_size:batch_size * (batch + 1)]
            trainlabel = trainlabels[batch * batch_size:batch_size * (batch + 1)]
            summary, _ = sess.run([merged, train_step], feed_dict={x: trainimage, y: trainlabel, keep_prob: 0.7})
        end_train_time = time.time()
        train_writer.add_summary(summary, epoch)
        print("Training took " + str(end_train_time - end_make_img_time))
        class_count = []
        class_acc = []
        row = []
        row.append(epoch)
        if epoch == 1:
            row_title = ["Epoch"] + list(labeldict.keys()) + ['Overall accuracy']
            # Write the class-name -> label-index mapping to a text file
            if os.path.exists(resultpath):
                with open(resultpath + "\\" + "result" + str(exp_count) + ".txt", 'w') as f:
                    for line in labeldict.items():
                        ph = line[1].index(max(line[1]))
                        line = line[0] + ' ' + str(ph) + "\n"
                        f.writelines(line)
            else:
                os.makedirs(resultpath)
                with open(resultpath + "\\" + "result" + str(exp_count) + ".txt", 'w') as f:
                    for line in labeldict.items():
                        ph = line[1].index(max(line[1]))
                        line = line[0] + ' ' + str(ph) + "\n"
                        f.writelines(line)
            # Create the Excel sheet that tracks per-class accuracy over the epochs
            wb = Workbook()
            sheet = wb.active
            sheet.append(row_title)
            wb.save(result_excel + "\\result" + str(exp_count) + ".xlsx")
        for classname in labeldict.keys():
            # Per-class test summaries
            test_writer = tf.summary.FileWriter('logs/test/' + classname + '/')
            testfaceimage, testlabel = train_image(testpath, classname=classname)
            class_count.append(len(testlabel))
            end_make_test_img_time = time.time()
            summary, one_class_acc = sess.run([merged, accuracy],
                                              feed_dict={x: testfaceimage, y: testlabel, keep_prob: 1.0})
            end_test_time = time.time()
            # Add to the same record
            test_writer.add_summary(summary, epoch)
            test_writer.flush()
            print("Iter " + str(epoch) + " class " + classname + " test accuracy " + str(one_class_acc) +
                  ", test time " + str(end_test_time - end_make_test_img_time))
            row.append(one_class_acc)
            class_acc.append(one_class_acc)
        testfaceimage, testlabel, _ = train_image(testpath)
        test_writer = tf.summary.FileWriter('logs/test/total/')
        summary, totalacc = sess.run([merged, accuracy], feed_dict={x: testfaceimage, y: testlabel, keep_prob: 1.0})
        # Add to the same record
        test_writer.add_summary(summary, epoch)
        test_writer.flush()
        row.append(totalacc)
        print("Iter " + str(epoch) + " overall test accuracy " + str(totalacc))
        wb = load_workbook(result_excel + "\\result" + str(exp_count) + ".xlsx")
        sheet = wb.active
        sheet.append(row)
        wb.save(result_excel + "\\result" + str(exp_count) + ".xlsx")
        # Save the model
        modelsavepath = "E:\\Flow classification\\TransformImage\\{}\\".format(modelname)
        if os.path.exists(modelsavepath):
            saver.save(sess, modelsavepath + "model.ckpt", epoch)
        else:
            os.makedirs(modelsavepath)
            saver.save(sess, modelsavepath + "model.ckpt", epoch)
    train_writer.close()
    test_writer.close()
    end_time = time.time()
    print("Total time " + str(end_time - start_time))
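
# A small sanity-check sketch (not part of the original script): it reproduces the feature-map
# size arithmetic described above and confirms the 8*8*8 flatten size used by the FC1 layer.
def _check_feature_map_sizes():
    import math
    size = 32
    size = size - 3 + 1          # conv1: VALID, 3x3 kernel -> 30
    size = math.ceil(size / 2)   # pool1: 2x2, stride 2, SAME -> 15
    # conv2 uses SAME padding, so the size stays 15
    size = math.ceil(size / 2)   # pool2: 2x2, stride 2, SAME -> 8
    assert size * size * 8 == 8 * 8 * 8
    return size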
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | Apache License
 2 | Version 2.0, January 2004
 3 | http://www.apache.org/licenses/
 4 | 
 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
 6 | 
 7 | 1. Definitions.
 8 | 
 9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship.
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------