├── README.md ├── MAML-FSIDS_CIC-DDoS2019-5shot ├── special_grads.py ├── utils.py ├── data_generator_test.py ├── data_generator_train.py ├── main5-5test.py ├── main5-5train.py └── maml.py └── FSIDS-IoT_Data_process ├── Data_extraction ├── CSE-CIC-IDS2018.py ├── CIC-IDS2017.py ├── NSL-KDD.py ├── NUSW-NB15.py └── CIC-DDoS2019.py ├── Data_process ├── CSE-CIC-IDS2018.py ├── image_covert.py ├── CIC-IDS2017.py ├── NSL-KDD.py ├── UNSW-NB15.py ├── CIC-DDoS2019.py └── Multithreading.py └── Data_check.py /README.md: -------------------------------------------------------------------------------- 1 | # MAML-CNN-FSIDS-IoT 2 | A Few-shot Based Model-Agnostic Meta-Learning for Intrusion Detection in Security of Internet of Things 3 | 4 | ### Dataset available: 5 | https://github.com/Chaomeng-Lu/Dataset-FSIDS-IoT 6 | 7 | If you use this code or the dataset for research, please cite the paper "A Few-shot Based Model-Agnostic Meta-Learning for Intrusion Detection in Security of Internet of Things". 8 | For more information about this project, see https://doi.org/10.1109/JIOT.2023.3283408. 9 | -------------------------------------------------------------------------------- /MAML-FSIDS_CIC-DDoS2019-5shot/special_grads.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | """ Code for second derivatives not implemented in TensorFlow library. """ 3 | from tensorflow.python.framework import ops 4 | from tensorflow.python.ops import array_ops 5 | from tensorflow.python.ops import gen_nn_ops 6 | 7 | 8 | @ops.RegisterGradient("MaxPoolGrad") 9 | def _MaxPoolGradGrad(op, grad): 10 | gradient = gen_nn_ops._max_pool_grad(op.inputs[0], op.outputs[0], 11 | grad, op.get_attr("ksize"), op.get_attr("strides"), 12 | padding=op.get_attr("padding"), data_format=op.get_attr("data_format")) 13 | gradgrad1 = array_ops.zeros(shape=array_ops.shape(op.inputs[1]), dtype=gradient.dtype) 14 | gradgrad2 = array_ops.zeros(shape=array_ops.shape(op.inputs[2]), dtype=gradient.dtype) 15 | return (gradient, gradgrad1, gradgrad2) 16 | -------------------------------------------------------------------------------- /FSIDS-IoT_Data_process/Data_extraction/CSE-CIC-IDS2018.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | 4 | path = 'F:/GraduateStudy/DataSet/archive(CSE-CIC-IDS2018)/' 5 | 6 | csv_data = pd.read_csv(path+'03-02-2018.csv') # read the data 7 | 8 | # print(csv_data.shape) 9 | # print(csv_data.columns) 10 | # sys.exit(0) 11 | 12 | labels_values_counts = csv_data['Label'].value_counts() 13 | labels_values = labels_values_counts.index 14 | 15 | path1 = 'F:/GraduateStudy/DataSet/!A-IDSFS/CSE-CIC-IDS2018/03-02_' 16 | 17 | n = 5000 # number of rows to extract per label 18 | 19 | for labels_value in labels_values: 20 | df_sample = csv_data[csv_data['Label'] == labels_value] 21 | sample_count = df_sample['Label'].value_counts() 22 | if sample_count[0] >= 5000: 23 | df_sample = df_sample.sample(n) 24 | filepath = path1 + labels_value + '_extract'+str(n)+'.csv' # output file path 25 | else: 26 | print("数据量不足"+str(n)) 27 | filepath = path1 + labels_value + '_extract' + str(sample_count[0]) + '.csv' # output file path 28 | df_columns = pd.DataFrame([list(csv_data.columns)]) 29 | df_columns.to_csv(filepath, mode='w', header=False, index=0) 30 | df_sample.to_csv(filepath, mode='a', header=False, index=0) 31 | 32 | # labels = csv_data['Label'].value_counts() 33 | # 34 | # print(labels) --------------------------------------------------------------------------------
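The per-label sampling loop above is repeated almost verbatim in each of the Data_extraction scripts (only the input files, the label column name and the output prefix change). A minimal consolidation sketch (not part of the repository; the helper name and the usage lines are illustrative, while the 5000-row budget and the file-naming scheme follow the scripts):

import pandas as pd

def extract_per_label(csv_data, label_col, out_prefix, n=5000):
    # Write up to n randomly sampled rows of every label value to its own CSV file,
    # mirroring the per-dataset scripts in Data_extraction/.
    for label_value, count in csv_data[label_col].value_counts().items():
        df_sample = csv_data[csv_data[label_col] == label_value]
        if count >= n:
            df_sample = df_sample.sample(n)
            out_path = '{}{}_extract{}.csv'.format(out_prefix, label_value, n)
        else:
            print('Fewer than {} rows available for label {}'.format(n, label_value))
            out_path = '{}{}_extract{}.csv'.format(out_prefix, label_value, count)
        df_sample.to_csv(out_path, index=False)  # header and rows written in one call

# Hypothetical usage, following the CSE-CIC-IDS2018 script above:
# data = pd.read_csv(path + '03-02-2018.csv')
# extract_per_label(data, 'Label', path1, n=5000)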
/FSIDS-IoT_Data_process/Data_extraction/CIC-IDS2017.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | 4 | path = 'F:/GraduateStudy/DataSet/CIC-IDS2017/MachineLearningCVE/' 5 | 6 | csv_data = pd.read_csv(path+'Wednesday-workingHours.pcap_ISCX.csv') # 读取数据 7 | 8 | # print(csv_data.shape) 9 | # print(csv_data.columns) 10 | # labels = csv_data[' Label'].value_counts() 11 | # print(labels) 12 | # sys.exit(0) 13 | 14 | labels_values_counts = csv_data[' Label'].value_counts() 15 | labels_values = labels_values_counts.index 16 | 17 | path1 = 'F:/GraduateStudy/DataSet/!A-IDSFS/CIC-IDS2017/Wednesday_' 18 | 19 | n = 5000 # 抽取数据量 20 | 21 | for labels_value in labels_values: 22 | df_sample = csv_data[csv_data[' Label'] == labels_value] 23 | sample_count = df_sample[' Label'].value_counts() 24 | if sample_count[0] >= 5000: 25 | df_sample = df_sample.sample(n) 26 | filepath = path1 + labels_value + '_extract'+str(n)+'.csv' # 写入数据 27 | else: 28 | print("数据量不足"+str(n)) 29 | filepath = path1 + labels_value + '_extract' + str(sample_count[0]) + '.csv' # 写入数据 30 | df_columns = pd.DataFrame([list(csv_data.columns)]) 31 | df_columns.to_csv(filepath, mode='w', header=False, index=0) 32 | df_sample.to_csv(filepath, mode='a', header=False, index=0) 33 | 34 | # labels = csv_data['Label'].value_counts() 35 | # 36 | # print(labels) -------------------------------------------------------------------------------- /FSIDS-IoT_Data_process/Data_extraction/NSL-KDD.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | 4 | path = 'F:/GraduateStudy/DataSet/Z2--NSL-KDD/' 5 | 6 | csv_data1 = pd.read_csv(path+'KDDTest+.csv') # 读取数据 7 | csv_data2 = pd.read_csv(path+'KDDTrain+.csv') # 读取数据 8 | csv_data = pd.concat([csv_data1, csv_data2]) 9 | # csv_data.to_csv('1.csv') 10 | # sys.exit(0) 11 | print(csv_data.shape) 12 | # labels = csv_data['attack_cat'].value_counts() 13 | # print(labels) 14 | 15 | csv_data = csv_data.drop(['r'], axis=1) # 去除无价值属性 16 | 17 | print(csv_data.shape) 18 | labels_values_counts = csv_data['label'].value_counts() 19 | labels_values = labels_values_counts.index 20 | 21 | path1 = 'F:/GraduateStudy/DataSet/!A-IDSFS/NSL-KDD/' 22 | 23 | n = 5000 # 抽取数据量 24 | 25 | for labels_value in labels_values: 26 | df_sample = csv_data[csv_data['label'] == labels_value] 27 | sample_count = df_sample['label'].value_counts() 28 | if sample_count[0] >= 5000: 29 | df_sample = df_sample.sample(n) 30 | filepath = path1 + labels_value + '_extract'+str(n)+'.csv' # 写入数据 31 | else: 32 | print("数据量不足"+str(n)) 33 | filepath = path1 + labels_value + '_extract' + str(sample_count[0]) + '.csv' # 写入数据 34 | df_columns = pd.DataFrame([list(csv_data.columns)]) 35 | df_columns.to_csv(filepath, mode='w', header=False, index=0) 36 | df_sample.to_csv(filepath, mode='a', header=False, index=0) 37 | 38 | # labels = csv_data['Label'].value_counts() 39 | # 40 | # print(labels) -------------------------------------------------------------------------------- /FSIDS-IoT_Data_process/Data_extraction/NUSW-NB15.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | 4 | path = 'F:/GraduateStudy/DataSet/UNSW-NB15/' 5 | 6 | csv_data1 = pd.read_csv(path+'UNSW_NB15_testing-set.csv') # 读取数据 7 | csv_data2 = pd.read_csv(path+'UNSW_NB15_training-set.csv') # 读取数据 8 | csv_data = pd.concat([csv_data1, csv_data2]) 9 | # csv_data.to_csv('1.csv') 10 | 
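# The official UNSW-NB15 training and testing splits are merged above so that the per-class sampling below draws from all records available for each attack_cat value.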
# sys.exit(0) 11 | print(csv_data.shape) 12 | # print(csv_data.columns) 13 | # labels = csv_data['attack_cat'].value_counts() 14 | # print(labels) 15 | # sys.exit(0) 16 | 17 | csv_data = csv_data.drop(['id', 'label'], axis=1) # 去除无价值属性 18 | 19 | print(csv_data.shape) 20 | labels_values_counts = csv_data['attack_cat'].value_counts() 21 | labels_values = labels_values_counts.index 22 | 23 | path1 = 'F:/GraduateStudy/DataSet/!A-IDSFS/UNSW-NB15/' 24 | 25 | n = 5000 # 抽取数据量 26 | 27 | for labels_value in labels_values: 28 | df_sample = csv_data[csv_data['attack_cat'] == labels_value] 29 | sample_count = df_sample['attack_cat'].value_counts() 30 | if sample_count[0] >= 5000: 31 | df_sample = df_sample.sample(n) 32 | filepath = path1 + labels_value + '_extract'+str(n)+'.csv' # 写入数据 33 | else: 34 | print("数据量不足"+str(n)) 35 | filepath = path1 + labels_value + '_extract' + str(sample_count[0]) + '.csv' # 写入数据 36 | df_columns = pd.DataFrame([list(csv_data.columns)]) 37 | df_columns.to_csv(filepath, mode='w', header=False, index=0) 38 | df_sample.to_csv(filepath, mode='a', header=False, index=0) 39 | 40 | # labels = csv_data['Label'].value_counts() 41 | # 42 | # print(labels) -------------------------------------------------------------------------------- /FSIDS-IoT_Data_process/Data_process/CSE-CIC-IDS2018.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import matplotlib 6 | import os 7 | 8 | path = 'G:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for MAML/IDSFS_Original/CSE-CIC-IDS2018/' 9 | path1 = 'G:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for MAML/RIDS_28_28_RGBA/CSE-CIC-IDS2018/' 10 | 11 | files = os.listdir(path) 12 | for file in files: 13 | pathf = os.path.join(path, file) 14 | data = pd.read_csv(pathf) # 读取数据 15 | data = np.array(data) 16 | few_data = data[:20, 1:-1] 17 | few_data = np.delete(few_data, 1, axis=1) 18 | s = 1 19 | patht = os.path.join(path1, file) 20 | if not os.path.isdir(patht): 21 | os.makedirs(patht) 22 | # 特征数 = 77 23 | for sample in few_data: 24 | # print(sample) 25 | n = 28 26 | matrix = np.zeros((n,n)) 27 | # print(matrix) 28 | # matrix[0][0] = sample[0] 29 | i = 0 30 | for j in range(n*n-10): 31 | if (j+1)%10 == 6: 32 | # print((j+1)//n,(j+1)%n) 33 | p, q = (j+1)//n,(j+1)%n 34 | matrix[p][q] = sample[i] 35 | i = i+1 36 | # print(matrix) 37 | # sys.exit(0) 38 | # plt.imshow(matrix, plt.cm.gray) #生成灰度图像 39 | # plt.imshow(matrix) 40 | # plt.axis('off') 41 | # plt.savefig(patht+'/pic-'+str(s)+'.png', dpi=22.8, bbox_inches='tight') #plt.savefig('./img/pic-{}.png'.format(epoch + 1)) 42 | 43 | # 矩阵转图像方法二 44 | matplotlib.image.imsave(patht + '/pic-' + str(s) + '.png', matrix) 45 | s = s+1 -------------------------------------------------------------------------------- /MAML-FSIDS_CIC-DDoS2019-5shot/utils.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | """ Utility functions. 
""" 3 | import numpy as np 4 | import os 5 | import random 6 | import tensorflow as tf 7 | 8 | from tensorflow.contrib.layers.python import layers as tf_layers 9 | from tensorflow.python.platform import flags 10 | 11 | FLAGS = flags.FLAGS 12 | 13 | 14 | ## Image helper 15 | def get_images(paths, labels, nb_samples=None, shuffle=True): 16 | if nb_samples is not None: 17 | sampler = lambda x: random.sample(x, nb_samples) 18 | else: 19 | sampler = lambda x: x 20 | images = [(i, os.path.join(path, image)) \ 21 | for i, path in zip(labels, paths) \ 22 | for image in sampler(os.listdir(path))] 23 | if shuffle: 24 | random.shuffle(images) 25 | return images 26 | 27 | 28 | ## Network helpers 29 | def conv_block(x, weight, bias, reuse, scope): 30 | # conv 31 | x = tf.nn.conv2d(x, weight, [1, 1, 1, 1], 'SAME') + bias 32 | # batch norm 33 | x = tf_layers.batch_norm(x, activation_fn=tf.nn.relu, reuse=reuse, scope=scope) 34 | # pooling 35 | x = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID') 36 | return x 37 | 38 | 39 | def normalize(inp, activation, reuse, scope): 40 | if FLAGS.norm == 'batch_norm': 41 | return tf_layers.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 42 | elif FLAGS.norm == 'layer_norm': 43 | return tf_layers.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 44 | elif FLAGS.norm == 'None': 45 | if activation is not None: 46 | return activation(inp) 47 | else: 48 | return inp 49 | 50 | 51 | ## Loss functions 52 | def mse(pred, label): 53 | pred = tf.reshape(pred, [-1]) 54 | label = tf.reshape(label, [-1]) 55 | return tf.reduce_mean(tf.square(pred - label)) 56 | 57 | 58 | def xent(pred, label): 59 | # Note - with tf version <=0.12, this loss has incorrect 2nd derivatives 60 | return tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label) / FLAGS.update_batch_size 61 | -------------------------------------------------------------------------------- /FSIDS-IoT_Data_process/Data_process/image_covert.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | # im = Image.open("F:/GraduateStudy/TensorflowProject/MAML/maml-master/data/IDSFS_Test/CIC-DDoS2019/BENIGN/pic-1.png") 4 | # im = Image.open("F:/GraduateStudy/TensorflowProject/MAML/maml-master/data/omniglot_resized/Alphabet_of_the_Magi/character01/0709_01.png") 5 | # print(im.getbands()) 6 | # from PIL import Image 7 | # import numpy as np 8 | # 9 | # img = Image.open('F:/GraduateStudy/TensorflowProject/MAML/maml-master/data/IDSFS_Test/CIC-DDoS2019/BENIGN/pic-1.png').convert('1') 10 | # print(img.getbands()) # ('P',) 这种是有彩色的,而L是没有彩色的 11 | # img.save('F:/GraduateStudy/TensorflowProject/MAML/maml-master/data/IDSFS_Test/CIC-DDoS2019/BENIGN/pic-1-0.png') # 转换后的进行保存 12 | 13 | import os 14 | 15 | path = "J:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for MAML/RIDS_28_28_RGBA/UNSW-NB15/" 16 | path1 = "J:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for MAML/RIDS_84_84_RGB/UNSW-NB15/" 17 | 18 | 19 | folders = os.listdir(path) 20 | # files = os.listdir(path) 21 | # print(files) 22 | for folder in folders: 23 | folder_path = os.path.join(path, folder) 24 | files = os.listdir(folder_path) 25 | # save_path = path1+folder 26 | save_path = os.path.join(path1, folder) 27 | if not os.path.isdir(save_path): 28 | os.makedirs(save_path) 29 | for pic in files: 30 | # print(pic) 31 | img = Image.open(os.path.join(folder_path, pic)).convert('RGB') 32 | img = img.resize((84, 84), 
resample=Image.LANCZOS) 33 | print(img.getbands()) # ('P',) 这种是有彩色的,而L是没有彩色的 34 | print(img.size) 35 | 36 | # file_name, file_extend = os.path.splitext(pic) 37 | # print(file_name,file_extend) 38 | # pic_new = os.path.join(os.path.abspath(save_path), file_name + '.jpg') 39 | 40 | # pic_new = os.path.join(os.path.abspath(save_path), pic) 41 | pic_new = os.path.join(save_path, pic) 42 | 43 | img.save(pic_new) 44 | -------------------------------------------------------------------------------- /FSIDS-IoT_Data_process/Data_process/CIC-IDS2017.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | import os 7 | 8 | path = 'G:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for MAML/IDSFS_Original/CIC-IDS2017/' 9 | path1 = 'G:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for MAML/RIDS_28_28_RGBA/CIC-IDS2017/' 10 | 11 | files = os.listdir(path) 12 | for file in files: 13 | pathf = os.path.join(path, file) 14 | data = pd.read_csv(pathf) # 读取数据 15 | data = np.array(data) 16 | few_data = data[:20, 1:-1] 17 | # print(few_data.shape) 18 | # print(few_data[0, :]) 19 | # sys.exit(0) 20 | # print(few_data.shape) 21 | # print(few_data[0,:]) 22 | # sample = few_data[17,:] 23 | patht = os.path.join(path1, file) 24 | if not os.path.isdir(patht): 25 | os.makedirs(patht) 26 | s = 1 27 | # 特征数 = 77 28 | for sample in few_data: 29 | # print(sample) 30 | n = 28 31 | matrix = np.zeros((n,n)) 32 | # print(matrix) 33 | # matrix[0][0] = sample[0] 34 | i = 0 35 | for j in range(n*n-10): 36 | if (j+1)%10 == 6: 37 | # print((j+1)//n,(j+1)%n) 38 | p, q = (j+1)//n,(j+1)%n 39 | matrix[p][q] = sample[i] 40 | i = i+1 41 | 42 | # plt.imshow(matrix, plt.cm.gray) #生成灰度图像 43 | # np.savetxt('Text/result' + str(s) + '.txt', matrix) 44 | # 矩阵转图像方法二 45 | matplotlib.image.imsave(patht + '/pic-' + str(s) + '.png', matrix) 46 | 47 | # plt.imshow(matrix) 48 | # plt.axis('off') 49 | # plt.savefig(patht+'pic-'+str(s)+'.png', dpi=7.3, bbox_inches='tight') #plt.savefig('./img/pic-{}.png'.format(epoch + 1)) 50 | # plt.savefig(patht + '/pic-' + str(s) + '.png', dpi=22.8, bbox_inches='tight', 51 | # pad_inches=0) # plt.savefig('./img/pic-{}.png'.format(epoch + 1)) 52 | s = s+1 -------------------------------------------------------------------------------- /FSIDS-IoT_Data_process/Data_process/NSL-KDD.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import matplotlib 6 | import os 7 | 8 | path = 'J:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for MAML/IDSFS_Original/NSL-KDD/' 9 | path1 = 'J:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for MAML/RIDS_28_28_RGBA/NSL-KDD/' 10 | 11 | files = os.listdir(path) 12 | for file in files: 13 | pathf = os.path.join(path, file) 14 | data = pd.read_csv(pathf) # 读取数据 15 | data = np.array(data) 16 | # print(data.shape) 17 | # sys.exit(0) 18 | # data = data[:,:-1] 19 | few_data = data[:20, :-1] 20 | few_data = np.delete(few_data, 1, axis=1) 21 | few_data = np.delete(few_data, 1, axis=1) 22 | few_data = np.delete(few_data, 1, axis=1) 23 | # print(few_data.shape) 24 | # print(few_data[0,:]) 25 | # sys.exit(0) 26 | # sample = few_data[17,:] 27 | s = 1 28 | patht = os.path.join(path1, file) 29 | if not os.path.isdir(patht): 30 | 
os.makedirs(patht) 31 | # 特征数 = 38 32 | for sample in few_data: 33 | # print(sample) 34 | n = 28 35 | matrix = np.zeros((n,n)) 36 | # print(matrix) 37 | # matrix[0][0] = sample[0] 38 | i = 0 39 | for j in range(n*n-24): 40 | if (j+1)%20 == 18: 41 | # print((j+1)//n,(j+1)%n) 42 | p, q = (j+1)//n,(j+1)%n 43 | matrix[p][q] = sample[i] 44 | i = i+1 45 | # print(i) 46 | # print(matrix) 47 | # sys.exit(0) 48 | # plt.imshow(matrix, plt.cm.gray) #生成灰度图像 49 | # plt.imshow(matrix) 50 | # plt.xticks([]) # 去掉横坐标值 51 | # plt.yticks([]) # 去掉纵坐标值 52 | # plt.savefig(patht+'/pic-'+str(s)+'.png', dpi=22.8, bbox_inches='tight') #plt.savefig('./img/pic-{}.png'.format(epoch + 1)) 53 | 54 | # 矩阵转图像方法二 55 | matplotlib.image.imsave(patht + '/pic-' + str(s) + '.png', matrix) 56 | s = s+1 -------------------------------------------------------------------------------- /FSIDS-IoT_Data_process/Data_process/UNSW-NB15.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from random import random 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import matplotlib 8 | import os 9 | 10 | path = 'J:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for MAML/IDSFS_Original/UNSW-NB15/' 11 | path1 = 'J:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for MAML/RIDS_28_28_RGBA/UNSW-NB15/' 12 | 13 | files = os.listdir(path) 14 | for file in files: 15 | pathf = os.path.join(path, file) 16 | data = pd.read_csv(pathf) # 读取数据 17 | data = np.array(data) 18 | # print(data.shape) 19 | 20 | # data = data[:,:-1] 21 | few_data = data[:20, :-1] 22 | few_data = np.delete(few_data, 1, axis=1) 23 | few_data = np.delete(few_data, 1, axis=1) 24 | few_data = np.delete(few_data, 1, axis=1) 25 | # print(few_data.shape) 26 | # print(few_data[0,:]) 27 | # sys.exit(0) 28 | # sample = few_data[17,:] 29 | s = 1 30 | patht = os.path.join(path1, file) 31 | if not os.path.isdir(patht): 32 | os.makedirs(patht) 33 | # 特征数 = 39 34 | for sample in few_data: 35 | # print(sample) 36 | n = 28 37 | matrix = np.zeros((n,n)) 38 | # print(matrix) 39 | # matrix[0][0] = sample[0] 40 | i = 0 41 | for j in range(n*n): 42 | if (j+1)%20 == 18: 43 | # print((j+1)//n,(j+1)%n) 44 | p, q = (j+1)//n,(j+1)%n 45 | matrix[p][q] = sample[i] 46 | i = i+1 47 | # print(matrix) 48 | # print(i) 49 | # sys.exit(0) 50 | # plt.imshow(matrix, plt.cm.gray) #生成灰度图像 51 | # plt.imshow(matrix) 52 | # plt.axis('off') 53 | # plt.xticks([]) # 去掉横坐标值 54 | # plt.yticks([]) # 去掉纵坐标值 55 | # plt.savefig(patht+'/pic-'+str(s)+'.png', dpi=22.8, bbox_inches='tight') #plt.savefig('./img/pic-{}.png'.format(epoch + 1)) 56 | 57 | # 矩阵转图像方法二 58 | matplotlib.image.imsave(patht + '/pic-' + str(s) + '.png', matrix) 59 | s = s+1 -------------------------------------------------------------------------------- /FSIDS-IoT_Data_process/Data_check.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | # csv_data = pd.read_csv('F:/GraduateStudy/DataSet/Z2--NSL-KDD/KDDTest+.txt') # 读取训练数据 4 | 5 | # print(csv_data.shape) 6 | # print(csv_data.columns) 7 | # labels = csv_data['LABEL'].value_counts() 8 | # print(labels) 9 | # 10 | # df_sample1 = csv_data[csv_data['LABEL'] == 'Normal flow'] 11 | # df_sample1 = df_sample1.sample(100000) 12 | # 13 | # df_sample2 = csv_data[csv_data['LABEL'] == 'Denial of Service R-U-Dead-Yet'] 14 | # df_sample2 = df_sample2.sample(100000) 15 | # 16 | # df_sample3 = csv_data[csv_data['LABEL'] == 'Denial of 
Service Slowloris'] 17 | # df_sample3 = df_sample3.sample(100000) 18 | # df_sample4 = pd.concat([df_sample1, df_sample2]) 19 | # df_sample = pd.concat([df_sample3, df_sample4]) 20 | # 21 | # filepath = 'F:/GraduateStudy/DataSet/SIMARGL2021/dataset-part2-reduce.csv' 22 | # df_columns = pd.DataFrame([list(csv_data.columns)]) 23 | # df_columns.to_csv(filepath, mode='w', header=False, index=0) 24 | # df_sample.to_csv(filepath, mode='a', header=False, index=0) 25 | 26 | #kddcup99 27 | col_names = ["duration", "protocol_type", "service", "flag", "src_bytes", 28 | 29 | "dst_bytes", "land", "wrong_fragment", "urgent", "hot", "num_failed_logins", 30 | 31 | "logged_in", "num_compromised", "root_shell", "su_attempted", "num_root", 32 | "num_file_creations", "num_shells", "num_access_files", "num_outbound_cmds", 33 | 34 | "is_host_login", "is_guest_login", "count", "srv_count", "serror_rate", 35 | "srv_serror_rate", "rerror_rate", "srv_rerror_rate", "same_srv_rate", 36 | 37 | "diff_srv_rate", "srv_diff_host_rate", "dst_host_count", "dst_host_srv_count", 38 | "dst_host_same_srv_rate", "dst_host_diff_srv_rate", "dst_host_same_src_port_rate", 39 | "dst_host_srv_diff_host_rate", "dst_host_serror_rate", "dst_host_srv_serror_rate", 40 | 41 | "dst_host_rerror_rate", "dst_host_srv_rerror_rate", "label", "r"] # 42个标识 42 | 43 | csv_data = pd.read_csv('F:/GraduateStudy/DataSet/Z2--NSL-KDD/KDDTest+.txt', names = col_names) # 读取训练数据 44 | 45 | print(csv_data.shape) 46 | print(csv_data.columns) 47 | labels = csv_data['label'].value_counts() 48 | print(labels) 49 | 50 | csv_data.to_csv("F:/GraduateStudy/DataSet/Z2--NSL-KDD/KDDTest+.csv", index=0)#另存为csv文件 -------------------------------------------------------------------------------- /FSIDS-IoT_Data_process/Data_process/CIC-DDoS2019.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | import os 7 | 8 | path = 'J:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for MAML/IDSFS_Original/CIC-DDoS2019/' 9 | path1 = 'J:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for MAML/RIDS_28_28_RGBA/CIC-DDoS2019/' 10 | 11 | files = os.listdir(path) 12 | for file in files: 13 | pathf = os.path.join(path, file) 14 | data = pd.read_csv(pathf) # 读取数据 15 | data = np.array(data) 16 | # print(data.shape) 17 | # data = data[:,:-1] 18 | few_data = data[:20,7:-1] 19 | few_data = np.delete(few_data, -2, axis=1) 20 | s = 1 21 | patht = os.path.join(path1, file) 22 | if not os.path.isdir(patht): 23 | os.makedirs(patht) 24 | # 特征数 = 78 25 | for sample in few_data: 26 | # print(sample) 27 | n = 28 28 | matrix = np.zeros((n,n)) 29 | # print(matrix) 30 | # matrix[0][0] = sample[0] 31 | i = 0 32 | for j in range(n*n-5): 33 | if (j+1)%10 == 6: 34 | # print((j+1)//n,(j+1)%n) 35 | p, q = (j+1)//n,(j+1)%n 36 | matrix[p][q] = sample[i] 37 | i = i+1 38 | # print(matrix) 39 | # temp = pd.DataFrame(matrix) 40 | # temp.to_csv('result1.csv', header=0, index=0) 41 | # np.savetxt("result1.txt", matrix) 42 | # sys.exit(0) 43 | # plt.imshow(matrix, plt.cm.gray) #生成灰度图像 44 | 45 | # 矩阵转图像方法三 46 | # import cv2 47 | # import numpy as np 48 | # cv2.imwrite(patht+'/pic-'+str(s)+'.png', matrix) 49 | 50 | # 矩阵转图像方法二 # 51 | matplotlib.image.imsave(patht+'/pic-'+str(s)+'.png', matrix) 52 | 53 | # path55 = 'J:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for 
MAML/RIDS_28_28_RGBA/CIC-DDoS2019/01-12_DrDoS_DNS_extract5000.csv/pic-19.png' 54 | # ar = matplotlib.image.imread(patht+'/pic-'+str(s)+'.png', format=None) 55 | # ar = matplotlib.image.imread(path55, format=None) 56 | # for i in range(4): 57 | # ar1 = ar[:, :, i] 58 | # temp = pd.DataFrame(ar1) 59 | # temp.to_csv('result2-'+ str(i) +'.csv', header=0, index=0) 60 | 61 | # np.savetxt("result2.txt", ar.reshape(1,-1)) 62 | 63 | #矩阵转图像方法一 64 | # plt.imshow(matrix) 65 | # plt.axis('off') 66 | # plt.savefig(patht+'/pic-'+str(s)+'.png', dpi=22.8, bbox_inches='tight', pad_inches=0) #plt.savefig('./img/pic-{}.png'.format(epoch + 1)) 67 | s = s+1 68 | sys.exit(0) -------------------------------------------------------------------------------- /FSIDS-IoT_Data_process/Data_extraction/CIC-DDoS2019.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | 4 | path = 'I:/Temp/临时3/CIC-DDoS2019-reduce/01-12/' 5 | 6 | #01-12 7 | csv_data0 = pd.read_csv(path+'DrDoS_DNS.csv') # 读取数据 8 | csv_data1 = pd.read_csv(path+'DrDoS_LDAP.csv') # 读取数据 9 | csv_data2 = pd.read_csv(path+'DrDoS_MSSQL.csv') # 读取数据 10 | csv_data3 = pd.read_csv(path+'DrDoS_NetBIOS.csv') # 读取数据 11 | csv_data4 = pd.read_csv(path+'DrDoS_NTP.csv') # 读取数据 12 | csv_data5 = pd.read_csv(path+'DrDoS_SNMP.csv') # 读取数据 13 | csv_data6 = pd.read_csv(path+'DrDoS_SSDP.csv') # 读取数据 14 | csv_data7 = pd.read_csv(path+'DrDoS_UDP.csv') # 读取数据 15 | csv_data8 = pd.read_csv(path+'Syn.csv') # 读取数据 16 | csv_data9 = pd.read_csv(path+'TFTP.csv') # 读取数据 17 | csv_data10 = pd.read_csv(path+'UDPLag.csv') # 读取数据 18 | csv_data = pd.concat([csv_data1, csv_data2]) 19 | csv_data = pd.concat([csv_data, csv_data3]) 20 | csv_data = pd.concat([csv_data, csv_data0]) 21 | csv_data = pd.concat([csv_data, csv_data4]) 22 | csv_data = pd.concat([csv_data, csv_data5]) 23 | csv_data = pd.concat([csv_data, csv_data6]) 24 | csv_data = pd.concat([csv_data, csv_data7]) 25 | csv_data = pd.concat([csv_data, csv_data8]) 26 | csv_data = pd.concat([csv_data, csv_data9]) 27 | csv_data = pd.concat([csv_data, csv_data10]) 28 | 29 | #03-11 30 | # csv_data1 = pd.read_csv(path+'LDAP.csv') # 读取数据 31 | # csv_data2 = pd.read_csv(path+'MSSQL.csv') # 读取数据 32 | # csv_data3 = pd.read_csv(path+'NetBIOS.csv') # 读取数据 33 | # csv_data4 = pd.read_csv(path+'Portmap.csv') # 读取数据 34 | # csv_data5 = pd.read_csv(path+'Syn.csv') # 读取数据 35 | # csv_data6 = pd.read_csv(path+'UDP.csv') # 读取数据 36 | # csv_data7 = pd.read_csv(path+'UDPLag.csv') # 读取数据 37 | # csv_data = pd.concat([csv_data1, csv_data2]) 38 | # csv_data = pd.concat([csv_data, csv_data3]) 39 | # csv_data = pd.concat([csv_data, csv_data4]) 40 | # csv_data = pd.concat([csv_data, csv_data5]) 41 | # csv_data = pd.concat([csv_data, csv_data6]) 42 | # csv_data = pd.concat([csv_data, csv_data7]) 43 | 44 | print(csv_data.shape) 45 | # print(csv_data.columns) 46 | labels = csv_data[' Label'].value_counts() 47 | print(labels) 48 | 49 | # csv_data = csv_data.drop(['Unnamed: 0'], axis=1) # 去除无价值属性 50 | 51 | labels_values_counts = csv_data[' Label'].value_counts() 52 | labels_values = labels_values_counts.index 53 | 54 | path1 = 'I:/Temp/临时3/CIC-DDos2019FS/01-12/01-12_' 55 | 56 | n = 5000 # 抽取数据量 57 | 58 | for labels_value in labels_values: 59 | df_sample = csv_data[csv_data[' Label'] == labels_value] 60 | sample_count = df_sample[' Label'].value_counts() 61 | if sample_count[0] >= 5000: 62 | df_sample = df_sample.sample(n) 63 | filepath = path1 + labels_value + '_extract'+str(n)+'.csv' # 写入数据 64 | else: 65 | 
print("数据量不足"+str(n)) 66 | filepath = path1 + labels_value + '_extract' + str(sample_count[0]) + '.csv' # 写入数据 67 | df_columns = pd.DataFrame([list(csv_data.columns)]) 68 | df_columns.to_csv(filepath, mode='w', header=False, index=0) 69 | df_sample.to_csv(filepath, mode='a', header=False, index=0) 70 | 71 | # labels = csv_data['Label'].value_counts() 72 | # 73 | # print(labels) -------------------------------------------------------------------------------- /FSIDS-IoT_Data_process/Data_process/Multithreading.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import sys 3 | import pandas as pd 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import os 7 | 8 | class DataSource: 9 | def __init__(self, few_data, startLine=0, maxcount=None): 10 | # self.dataFileName = few_data 11 | self.startLine = startLine # 第一行行号为1 12 | self.line_index = startLine # 当前读取位置 13 | self.maxcount = maxcount # 读取最大行数 14 | self.lock = threading.RLock() # 同步锁 15 | self.__data__ = few_data 16 | 17 | # self.__data__ = open(self.dataFileName, 'r', encoding= 'utf-8') 18 | for i in range(self.startLine): 19 | l = self.__data__.readline() 20 | 21 | def getLine(self): 22 | self.lock.acquire() 23 | try: 24 | if self.maxcount is None or self.line_index < (self.startLine + self.maxcount): 25 | line = self.__data__[self.line_index] 26 | # print('line:',line) 27 | # sys.exit(0) 28 | if len(line): 29 | self.line_index += 1 30 | return True, line 31 | else: 32 | return False, None 33 | else: 34 | return False, None 35 | 36 | except Exception as e: 37 | return False, "处理出错:" + e.args 38 | finally: 39 | self.lock.release() 40 | 41 | # def __del__(self): 42 | # if not self.__data__.closed: 43 | # self.__data__.close() 44 | # print("关闭数据源:", self.dataFileName) 45 | 46 | 47 | def process(worker_id, datasource, patht): 48 | count = 0 49 | while True: 50 | status, data1 = datasource.getLine() 51 | # print(data) 52 | # sys.exit(0) 53 | if status: 54 | print(">>> 线程[%d] 获得数据, 正在处理……" % worker_id) 55 | n = 28 56 | matrix = np.zeros((n, n)) 57 | # print(matrix) 58 | # matrix[0][0] = sample[0] 59 | i = 0 60 | for j in range(n * n - 5): 61 | if (j + 1) % 10 == 6: 62 | # print((j+1)//n,(j+1)%n) 63 | p, q = (j + 1) // n, (j + 1) % n 64 | # sys.exit(0) 65 | matrix[p][q] = data1[i] 66 | i = i + 1 67 | plt.imshow(matrix) 68 | plt.axis('off') 69 | plt.savefig(patht + '/pic-' + str(count) + '.png', dpi=22.8, bbox_inches='tight', 70 | pad_inches=0) # plt.savefig('./img/pic-{}.png'.format(epoch + 1)) 71 | print(">>> 线程[%d] 处理数据 完成" % worker_id) 72 | count += 1 73 | else: 74 | break # 退出循环 75 | print(">>> 线程[%d] 结束, 共处理[%d]条数据" % (worker_id, count)) 76 | 77 | def main(): 78 | path = 'G:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/IDSFS-all5000extraction-new/CIC-DDoS2019/' 79 | path1 = 'G:/Desktop 2022.06.08 File Backup/GraduateStudy/DataSet/Processed data for G-CNN/RIDS_84_84_RGBA/CIC-DDoS2019/' 80 | 81 | files = os.listdir(path) 82 | for file in files: 83 | pathf = os.path.join(path, file) 84 | data = pd.read_csv(pathf) # 读取数据 85 | data = np.array(data) 86 | # print(data.shape) 87 | # data = data[:,:-1] 88 | few_data = data[:, 7:-1] 89 | few_data = np.delete(few_data, -2, axis=1) 90 | # np.savetxt(file + '.csv', few_data) 91 | patht = os.path.join(path1, file) 92 | if not os.path.isdir(patht): 93 | os.makedirs(patht) 94 | datasource = DataSource(few_data) 95 | workercount = 10 # 开启的线程数,注意:并非越多越快哦 96 | workers = [] 97 | for i in range(workercount): 98 | worker = 
threading.Thread(target=process, args=(i+1, datasource, patht)) 99 | worker.start() 100 | workers.append(worker) 101 | 102 | for worker in workers: 103 | worker.join() 104 | 105 | if __name__ == "__main__": 106 | main() -------------------------------------------------------------------------------- /MAML-FSIDS_CIC-DDoS2019-5shot/data_generator_test.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import numpy as np 3 | import os, sys 4 | import random 5 | import tensorflow as tf 6 | import tqdm 7 | import pickle 8 | 9 | def get_images(paths, labels, nb_samples=None, shuffle=True): 10 | if nb_samples is not None: 11 | sampler = lambda x: random.sample(x, nb_samples) 12 | else: 13 | sampler = lambda x: x 14 | images = [(i, os.path.join(path, image)) \ 15 | for i, path in zip(labels, paths) \ 16 | for image in sampler(os.listdir(path))] 17 | if shuffle: 18 | random.shuffle(images) 19 | return images 20 | 21 | class DataGenerator: 22 | """ 23 | Data Generator capable of generating batches of sinusoid or Omniglot data. 24 | A "class" is considered a class of omniglot digits or a particular sinusoid function. 25 | """ 26 | 27 | def __init__(self, nway, kshot, kquery, meta_batchsz, total_batch_num = 20000): 28 | """ 29 | 30 | :param nway: 31 | :param kshot: 32 | :param kquery: 33 | :param meta_batchsz: 34 | """ 35 | self.meta_batchsz = meta_batchsz 36 | # number of images to sample per class 37 | self.nimg = kshot + kquery 38 | self.nway = nway 39 | self.imgsz = (84, 84) 40 | self.total_batch_num = total_batch_num 41 | self.dim_input = np.prod(self.imgsz) * 3 # 21168 42 | self.dim_output = nway 43 | 44 | metatrain_folder = './FS-IDS/train' 45 | metaval_folder = './FS-IDS/test' 46 | 47 | self.metatrain_folders = [os.path.join(metatrain_folder, label) \ 48 | for label in os.listdir(metatrain_folder) \ 49 | if os.path.isdir(os.path.join(metatrain_folder, label)) \ 50 | ] 51 | self.metaval_folders = [os.path.join(metaval_folder, label) \ 52 | for label in os.listdir(metaval_folder) \ 53 | if os.path.isdir(os.path.join(metaval_folder, label)) \ 54 | ] 55 | self.rotations = [0] 56 | 57 | 58 | print('metatrain_folder:', self.metatrain_folders[:2]) 59 | print('metaval_folders:', self.metaval_folders[:2]) 60 | 61 | 62 | def make_data_tensor(self, training=True): 63 | """ 64 | 65 | :param training: 66 | :return: 67 | """ 68 | if training: 69 | folders = self.metatrain_folders 70 | num_total_batches = self.total_batch_num 71 | else: 72 | folders = self.metaval_folders 73 | num_total_batches = 600 74 | 75 | 76 | if training and os.path.exists('filelist.pkl'): 77 | 78 | labels = np.arange(self.nway).repeat(self.nimg).tolist() 79 | with open('filelist.pkl', 'rb') as f: 80 | all_filenames = pickle.load(f) 81 | print('load episodes from file, len:', len(all_filenames)) 82 | 83 | else: # test or not existed. 84 | 85 | # 16 in one class, 16*5 in one task 86 | # [task1_0_img0, task1_0_img15, task1_1_img0,] 87 | all_filenames = [] 88 | for _ in tqdm.tqdm(range(num_total_batches), 'generating episodes'): # 200000 89 | # from image folder sample 5 class randomly 90 | sampled_folders = random.sample(folders, self.nway) 91 | random.shuffle(sampled_folders) 92 | # sample 16 images from selected folders, and each with label 0-4, (0/1..., path), orderly, no shuffle! 
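# NOTE: the 16-per-class / 80-per-task figures in these comments appear to be left over from a 1-shot configuration; with kshot=5 and kquery=15 (see main5-5test.py), self.nimg = 20 and each task holds 5 * 20 = 100 images.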
93 | # len: 5 * 16 94 | labels_and_images = get_images(sampled_folders, range(self.nway), nb_samples=self.nimg, shuffle=False) 95 | 96 | # make sure the above isn't randomized order 97 | labels = [li[0] for li in labels_and_images] 98 | filenames = [li[1] for li in labels_and_images] 99 | all_filenames.extend(filenames) 100 | 101 | if training: # only save for training. 102 | with open('filelist.pkl', 'wb') as f: 103 | pickle.dump(all_filenames,f) 104 | print('save all file list to filelist.pkl') 105 | 106 | # make queue for tensorflow to read from 107 | print('creating pipeline ops') 108 | filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False) 109 | image_reader = tf.WholeFileReader() 110 | _, image_file = image_reader.read(filename_queue) 111 | 112 | image = tf.image.decode_jpeg(image_file, channels=3) 113 | # tensorflow format: N*H*W*C 114 | image.set_shape((self.imgsz[0], self.imgsz[1], 3)) 115 | # reshape(image, [84*84*3]) 116 | image = tf.reshape(image, [self.dim_input]) 117 | # convert to range(0,1) 118 | image = tf.cast(image, tf.float32) / 255.0 119 | 120 | examples_per_batch = self.nway * self.nimg # 5*16 121 | # batch here means batch of meta-learning, including 4 tasks = 4*80 122 | batch_image_size = self.meta_batchsz * examples_per_batch # 4* 80 123 | 124 | print('batching images') 125 | images = tf.train.batch( 126 | [image], 127 | batch_size=batch_image_size, # 4*80 128 | num_threads= self.meta_batchsz, 129 | capacity= 256 + 3 * batch_image_size, # 256 + 3* 4*80 130 | ) 131 | 132 | all_image_batches, all_label_batches = [], [] 133 | print('manipulating images to be right order') 134 | # images contains current batch, namely 4 task, 4* 80 135 | for i in range(self.meta_batchsz): # 4 136 | # current task, 80 images 137 | image_batch = images[i * examples_per_batch:(i + 1) * examples_per_batch] 138 | 139 | # as all labels of all task are the same, which is 0,0,..1,1,..2,2,..3,3,..4,4... 
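# The loop below re-shuffles the class order independently for every image position k: e.g. with nway=5 and nimg=20, class_idxs shuffled to [2, 0, 4, 1, 3] at k=0 gathers images [40, 0, 80, 20, 60], i.e. one image per class in a random class order.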
140 | label_batch = tf.convert_to_tensor(labels) 141 | new_list, new_label_list = [], [] 142 | # for each image from 0 to 15 in all 5 class 143 | for k in range(self.nimg): # 16 144 | class_idxs = tf.range(0, self.nway) # 0-4 145 | class_idxs = tf.random_shuffle(class_idxs) 146 | # it will cope with 5 images parallelly 147 | # [0, 16, 32, 48, 64] or [1, 17, 33, 49, 65] 148 | true_idxs = class_idxs * self.nimg + k 149 | new_list.append(tf.gather(image_batch, true_idxs)) 150 | 151 | new_label_list.append(tf.gather(label_batch, true_idxs)) 152 | 153 | # [80, 84*84*3] 154 | new_list = tf.concat(new_list, 0) # has shape [self.num_classes*self.num_samples_per_class, self.dim_input] 155 | # [80] 156 | new_label_list = tf.concat(new_label_list, 0) 157 | all_image_batches.append(new_list) 158 | all_label_batches.append(new_label_list) 159 | 160 | # [4, 80, 84*84*3] 161 | all_image_batches = tf.stack(all_image_batches) 162 | # [4, 80] 163 | all_label_batches = tf.stack(all_label_batches) 164 | # [4, 80, 5] 165 | all_label_batches = tf.one_hot(all_label_batches, self.nway) 166 | 167 | print('image_b:', all_image_batches) 168 | print('label_onehot_b:', all_label_batches) 169 | 170 | return all_image_batches, all_label_batches 171 | 172 | -------------------------------------------------------------------------------- /MAML-FSIDS_CIC-DDoS2019-5shot/data_generator_train.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import numpy as np 3 | import os, sys 4 | import random 5 | import tensorflow as tf 6 | import tqdm 7 | import pickle 8 | 9 | def get_images(paths, labels, nb_samples=None, shuffle=True): 10 | if nb_samples is not None: 11 | sampler = lambda x: random.sample(x, nb_samples) 12 | else: 13 | sampler = lambda x: x 14 | images = [(i, os.path.join(path, image)) \ 15 | for i, path in zip(labels, paths) \ 16 | for image in sampler(os.listdir(path))] 17 | if shuffle: 18 | random.shuffle(images) 19 | return images 20 | 21 | class DataGenerator: 22 | """ 23 | Data Generator capable of generating batches of sinusoid or Omniglot data. 24 | A "class" is considered a class of omniglot digits or a particular sinusoid function. 
25 | """ 26 | 27 | def __init__(self, nway, kshot, kquery, meta_batchsz, total_batch_num = 20000): 28 | """ 29 | 30 | :param nway: 31 | :param kshot: 32 | :param kquery: 33 | :param meta_batchsz: 34 | """ 35 | self.meta_batchsz = meta_batchsz 36 | # number of images to sample per class 37 | self.nimg = kshot + kquery 38 | self.nway = nway 39 | self.imgsz = (84, 84) 40 | self.total_batch_num = total_batch_num 41 | self.dim_input = np.prod(self.imgsz) * 3 # 21168 42 | self.dim_output = nway 43 | 44 | metatrain_folder = './FS-IDS/train' 45 | metaval_folder = './FS-IDS/val' 46 | 47 | self.metatrain_folders = [os.path.join(metatrain_folder, label) \ 48 | for label in os.listdir(metatrain_folder) \ 49 | if os.path.isdir(os.path.join(metatrain_folder, label)) \ 50 | ] 51 | self.metaval_folders = [os.path.join(metaval_folder, label) \ 52 | for label in os.listdir(metaval_folder) \ 53 | if os.path.isdir(os.path.join(metaval_folder, label)) \ 54 | ] 55 | self.rotations = [0] 56 | 57 | 58 | print('metatrain_folder:', self.metatrain_folders[:2]) 59 | print('metaval_folders:', self.metaval_folders[:2]) 60 | 61 | 62 | def make_data_tensor(self, training=True): 63 | """ 64 | 65 | :param training: 66 | :return: 67 | """ 68 | if training: 69 | folders = self.metatrain_folders 70 | num_total_batches = self.total_batch_num 71 | else: 72 | folders = self.metaval_folders 73 | num_total_batches = 600 74 | 75 | 76 | if training and os.path.exists('filelist.pkl'): 77 | 78 | labels = np.arange(self.nway).repeat(self.nimg).tolist() 79 | with open('filelist.pkl', 'rb') as f: 80 | all_filenames = pickle.load(f) 81 | print('load episodes from file, len:', len(all_filenames)) 82 | 83 | else: # test or not existed. 84 | 85 | # 16 in one class, 16*5 in one task 86 | # [task1_0_img0, task1_0_img15, task1_1_img0,] 87 | all_filenames = [] 88 | for _ in tqdm.tqdm(range(num_total_batches), 'generating episodes'): # 200000 89 | # from image folder sample 5 class randomly 90 | sampled_folders = random.sample(folders, self.nway) 91 | random.shuffle(sampled_folders) 92 | # sample 16 images from selected folders, and each with label 0-4, (0/1..., path), orderly, no shuffle! 93 | # len: 5 * 16 94 | labels_and_images = get_images(sampled_folders, range(self.nway), nb_samples=self.nimg, shuffle=False) 95 | 96 | # make sure the above isn't randomized order 97 | labels = [li[0] for li in labels_and_images] 98 | filenames = [li[1] for li in labels_and_images] 99 | all_filenames.extend(filenames) 100 | 101 | if training: # only save for training. 
102 | with open('filelist.pkl', 'wb') as f: 103 | pickle.dump(all_filenames,f) 104 | print('save all file list to filelist.pkl') 105 | 106 | # make queue for tensorflow to read from 107 | print('creating pipeline ops') 108 | filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False) 109 | image_reader = tf.WholeFileReader() 110 | _, image_file = image_reader.read(filename_queue) 111 | 112 | image = tf.image.decode_jpeg(image_file, channels=3) 113 | # tensorflow format: N*H*W*C 114 | image.set_shape((self.imgsz[0], self.imgsz[1], 3)) 115 | # reshape(image, [84*84*3]) 116 | image = tf.reshape(image, [self.dim_input]) 117 | # convert to range(0,1) 118 | image = tf.cast(image, tf.float32) / 255.0 119 | 120 | examples_per_batch = self.nway * self.nimg # 5*16 121 | # batch here means batch of meta-learning, including 4 tasks = 4*80 122 | batch_image_size = self.meta_batchsz * examples_per_batch # 4* 80 123 | 124 | print('batching images') 125 | images = tf.train.batch( 126 | [image], 127 | batch_size=batch_image_size, # 4*80 128 | num_threads= self.meta_batchsz, 129 | capacity= 256 + 3 * batch_image_size, # 256 + 3* 4*80 130 | ) 131 | 132 | all_image_batches, all_label_batches = [], [] 133 | print('manipulating images to be right order') 134 | # images contains current batch, namely 4 task, 4* 80 135 | for i in range(self.meta_batchsz): # 4 136 | # current task, 80 images 137 | image_batch = images[i * examples_per_batch:(i + 1) * examples_per_batch] 138 | 139 | # as all labels of all task are the same, which is 0,0,..1,1,..2,2,..3,3,..4,4... 140 | label_batch = tf.convert_to_tensor(labels) 141 | new_list, new_label_list = [], [] 142 | # for each image from 0 to 15 in all 5 class 143 | for k in range(self.nimg): # 16 144 | class_idxs = tf.range(0, self.nway) # 0-4 145 | class_idxs = tf.random_shuffle(class_idxs) 146 | # it will cope with 5 images parallelly 147 | # [0, 16, 32, 48, 64] or [1, 17, 33, 49, 65] 148 | true_idxs = class_idxs * self.nimg + k 149 | new_list.append(tf.gather(image_batch, true_idxs)) 150 | 151 | new_label_list.append(tf.gather(label_batch, true_idxs)) 152 | 153 | # [80, 84*84*3] 154 | new_list = tf.concat(new_list, 0) # has shape [self.num_classes*self.num_samples_per_class, self.dim_input] 155 | # [80] 156 | new_label_list = tf.concat(new_label_list, 0) 157 | all_image_batches.append(new_list) 158 | all_label_batches.append(new_label_list) 159 | 160 | # [4, 80, 84*84*3] 161 | all_image_batches = tf.stack(all_image_batches) 162 | # [4, 80] 163 | all_label_batches = tf.stack(all_label_batches) 164 | # [4, 80, 5] 165 | all_label_batches = tf.one_hot(all_label_batches, self.nway) 166 | 167 | print('image_b:', all_image_batches) 168 | print('label_onehot_b:', all_label_batches) 169 | 170 | return all_image_batches, all_label_batches 171 | 172 | -------------------------------------------------------------------------------- /MAML-FSIDS_CIC-DDoS2019-5shot/main5-5test.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import os 3 | import numpy as np 4 | import argparse 5 | import random 6 | import tensorflow as tf 7 | 8 | from data_generator_test import DataGenerator 9 | from maml import MAML 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('-t', '--test', action='store_true', default=True, help='set for test, otherwise train') 14 | args = parser.parse_args() 15 | 16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 17 | os.environ["CUDA_DEVICE_ORDER"] = 
"PCI_BUS_ID" 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 19 | 20 | 21 | def train(model, saver, sess): 22 | """ 23 | 24 | :param model: 25 | :param saver: 26 | :param sess: 27 | :return: 28 | """ 29 | # write graph to tensorboard 30 | # tb = tf.summary.FileWriter(os.path.join('logs', 'mini'), sess.graph) 31 | prelosses, postlosses, preaccs, postaccs = [], [], [], [] 32 | best_acc = 0 33 | 34 | # train for meta_iteartion epoches 35 | for iteration in range(10000): 36 | # this is the main op 37 | ops = [model.meta_op] 38 | 39 | # add summary and print op 40 | if iteration % 100 == 0: 41 | ops.extend([model.summ_op, 42 | model.query_losses[0], model.query_losses[-1], 43 | model.query_accs[0], model.query_accs[-1]]) 44 | 45 | # run all ops 46 | result = sess.run(ops) 47 | 48 | # summary 49 | if iteration % 100 == 0: 50 | # summ_op 51 | # tb.add_summary(result[1], iteration) 52 | # query_losses[0] 53 | prelosses.append(result[2]) 54 | # query_losses[-1] 55 | postlosses.append(result[3]) 56 | # query_accs[0] 57 | preaccs.append(result[4]) 58 | # query_accs[-1] 59 | postaccs.append(result[5]) 60 | 61 | print(iteration, '\tloss:', np.mean(prelosses), '=>', np.mean(postlosses), 62 | '\t\tacc:', np.mean(preaccs), '=>', np.mean(postaccs)) 63 | prelosses, postlosses, preaccs, postaccs = [], [], [], [] 64 | 65 | # evaluation 66 | if iteration % 1000 == 0: 67 | # DO NOT write as a = b = [], in that case a=b 68 | # DO NOT use train variable as we have train func already. 69 | acc1s, acc2s = [], [] 70 | # sample 20 times to get more accurate statistics. 71 | for _ in range(100): 72 | acc1, acc2 = sess.run([model.test_query_accs[0], 73 | model.test_query_accs[-1]]) 74 | acc1s.append(acc1) 75 | acc2s.append(acc2) 76 | 77 | acc = np.mean(acc2s) 78 | print('>>>>\t\tValidation accs: ', np.mean(acc1s), acc, 'best:', best_acc, '\t\t<<<<') 79 | 80 | if acc - best_acc > 0.05 or acc > 0.4: 81 | saver.save(sess, os.path.join('ckpt', 'mini.mdl')) 82 | best_acc = acc 83 | print('saved into ckpt:', acc) 84 | 85 | 86 | def test(model, sess): 87 | 88 | np.random.seed(1) 89 | random.seed(1) 90 | 91 | # repeat test accuracy for 600 times 92 | test_accs = [] 93 | for i in range(600): 94 | if i % 100 == 1: 95 | print(i) 96 | # extend return None!!! 97 | ops = [model.test_support_acc] 98 | ops.extend(model.test_query_accs) 99 | result = sess.run(ops) 100 | test_accs.append(result) 101 | 102 | # [600, K+1] 103 | test_accs = np.array(test_accs) 104 | # [K+1] 105 | means = np.mean(test_accs, 0) 106 | stds = np.std(test_accs, 0) 107 | ci95 = 1.96 * stds / np.sqrt(600) 108 | 109 | print('[support_t0, query_t0 - \t\t\tK] ') 110 | print('mean:', means) 111 | print('stds:', stds) 112 | print('ci95:', ci95) 113 | 114 | 115 | 116 | def main(): 117 | training = not args.test 118 | kshot = 5 119 | kquery = 15 120 | nway = 5 121 | meta_batchsz = 4 122 | K = 5 123 | 124 | 125 | # kshot + kquery images per category, nway categories, meta_batchsz tasks. 
126 | db = DataGenerator(nway, kshot, kquery, meta_batchsz, 200000) 127 | 128 | if training: # only construct training model if needed 129 | # get the tensor 130 | # image_tensor: [4, 80, 84*84*3] 131 | # label_tensor: [4, 80, 5] 132 | image_tensor, label_tensor = db.make_data_tensor(training=True) 133 | 134 | # NOTICE: the image order in 80 images should like this now: 135 | # [label2, label1, label3, label0, label4, and then repeat by 15 times, namely one task] 136 | # support_x : [4, 1*5, 84*84*3] 137 | # query_x : [4, 15*5, 84*84*3] 138 | # support_y : [4, 5, 5] 139 | # query_y : [4, 15*5, 5] 140 | support_x = tf.slice(image_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_x') 141 | query_x = tf.slice(image_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_x') 142 | support_y = tf.slice(label_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_y') 143 | query_y = tf.slice(label_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_y') 144 | 145 | # construct test tensors. 146 | image_tensor, label_tensor = db.make_data_tensor(training=False) 147 | support_x_test = tf.slice(image_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_x_test') 148 | query_x_test = tf.slice(image_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_x_test') 149 | support_y_test = tf.slice(label_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_y_test') 150 | query_y_test = tf.slice(label_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_y_test') 151 | 152 | 153 | # 1. construct MAML model 154 | model = MAML(84, 3, 5) 155 | 156 | # construct metatrain_ and metaval_ 157 | if training: 158 | model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train') 159 | model.build(support_x_test, support_y_test, query_x_test, query_y_test, K, meta_batchsz, mode='eval') 160 | else: 161 | model.build(support_x_test, support_y_test, query_x_test, query_y_test, K + 5, meta_batchsz, mode='test') 162 | model.summ_op = tf.summary.merge_all() 163 | 164 | all_vars = filter(lambda x: 'meta_optim' not in x.name, tf.trainable_variables()) 165 | for p in all_vars: 166 | print(p) 167 | 168 | 169 | config = tf.ConfigProto() 170 | config.gpu_options.allow_growth = True 171 | sess = tf.InteractiveSession(config=config) 172 | # tf.global_variables() to save moving_mean and moving variance of batch norm 173 | # tf.trainable_variables() NOT include moving_mean and moving_variance. 174 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=5) 175 | 176 | # initialize, under interative session 177 | tf.global_variables_initializer().run() 178 | tf.train.start_queue_runners() 179 | 180 | if os.path.exists(os.path.join('ckpt', 'checkpoint')): 181 | # alway load ckpt both train and test. 
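# main5-5test.py defaults to --test=True, so this restore loads the checkpoint produced by main5-5train.py before the 600-episode meta-test runs.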
182 | model_file = tf.train.latest_checkpoint('ckpt') 183 | print("Restoring model weights from ", model_file) 184 | saver.restore(sess, model_file) 185 | 186 | 187 | if training: 188 | train(model, saver, sess) 189 | else: 190 | test(model, sess) 191 | 192 | 193 | if __name__ == "__main__": 194 | main() 195 | -------------------------------------------------------------------------------- /MAML-FSIDS_CIC-DDoS2019-5shot/main5-5train.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import os 3 | import numpy as np 4 | import argparse 5 | import random 6 | import tensorflow as tf 7 | 8 | from data_generator_train import DataGenerator 9 | from maml import MAML 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('-t', '--test', action='store_true', default=False, help='set for test, otherwise train') 14 | args = parser.parse_args() 15 | 16 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 17 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 19 | 20 | 21 | def train(model, saver, sess): 22 | """ 23 | 24 | :param model: 25 | :param saver: 26 | :param sess: 27 | :return: 28 | """ 29 | # write graph to tensorboard 30 | # tb = tf.summary.FileWriter(os.path.join('logs', 'mini'), sess.graph) 31 | prelosses, postlosses, preaccs, postaccs = [], [], [], [] 32 | best_acc = 0 33 | 34 | # train for meta_iteartion epoches 35 | for iteration in range(10000): 36 | # this is the main op 37 | ops = [model.meta_op] 38 | 39 | # add summary and print op 40 | if iteration % 100 == 0: 41 | ops.extend([model.summ_op, 42 | model.query_losses[0], model.query_losses[-1], 43 | model.query_accs[0], model.query_accs[-1]]) 44 | 45 | # run all ops 46 | result = sess.run(ops) 47 | 48 | # summary 49 | if iteration % 100 == 0: 50 | # summ_op 51 | # tb.add_summary(result[1], iteration) 52 | # query_losses[0] 53 | prelosses.append(result[2]) 54 | # query_losses[-1] 55 | postlosses.append(result[3]) 56 | # query_accs[0] 57 | preaccs.append(result[4]) 58 | # query_accs[-1] 59 | postaccs.append(result[5]) 60 | 61 | print(iteration, '\tloss:', np.mean(prelosses), '=>', np.mean(postlosses), 62 | '\t\tacc:', np.mean(preaccs), '=>', np.mean(postaccs)) 63 | prelosses, postlosses, preaccs, postaccs = [], [], [], [] 64 | 65 | # evaluation 66 | if iteration % 1000 == 0: 67 | # DO NOT write as a = b = [], in that case a=b 68 | # DO NOT use train variable as we have train func already. 69 | acc1s, acc2s = [], [] 70 | # sample 20 times to get more accurate statistics. 71 | for _ in range(100): 72 | acc1, acc2 = sess.run([model.test_query_accs[0], 73 | model.test_query_accs[-1]]) 74 | acc1s.append(acc1) 75 | acc2s.append(acc2) 76 | 77 | acc = np.mean(acc2s) 78 | print('>>>>\t\tValidation accs: ', np.mean(acc1s), acc, 'best:', best_acc, '\t\t<<<<') 79 | 80 | if acc - best_acc > 0.05 or acc > 0.4: 81 | saver.save(sess, os.path.join('ckpt', 'mini.mdl')) 82 | best_acc = acc 83 | print('saved into ckpt:', acc) 84 | 85 | 86 | def test(model, sess): 87 | 88 | np.random.seed(1) 89 | random.seed(1) 90 | 91 | # repeat test accuracy for 600 times 92 | test_accs = [] 93 | for i in range(600): 94 | if i % 100 == 1: 95 | print(i) 96 | # extend return None!!! 
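# i.e. list.extend() mutates in place and returns None, so ops is built with a separate extend() call rather than assigned from it.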
97 | ops = [model.test_support_acc] 98 | ops.extend(model.test_query_accs) 99 | result = sess.run(ops) 100 | test_accs.append(result) 101 | 102 | # [600, K+1] 103 | test_accs = np.array(test_accs) 104 | # [K+1] 105 | means = np.mean(test_accs, 0) 106 | stds = np.std(test_accs, 0) 107 | ci95 = 1.96 * stds / np.sqrt(600) 108 | 109 | print('[support_t0, query_t0 - \t\t\tK] ') 110 | print('mean:', means) 111 | print('stds:', stds) 112 | print('ci95:', ci95) 113 | 114 | 115 | 116 | def main(): 117 | training = not args.test 118 | kshot = 5 119 | kquery = 15 120 | nway = 5 121 | meta_batchsz = 4 122 | K = 5 123 | 124 | 125 | # kshot + kquery images per category, nway categories, meta_batchsz tasks. 126 | db = DataGenerator(nway, kshot, kquery, meta_batchsz, 200000) 127 | 128 | if training: # only construct training model if needed 129 | # get the tensor 130 | # image_tensor: [4, 80, 84*84*3] 131 | # label_tensor: [4, 80, 5] 132 | image_tensor, label_tensor = db.make_data_tensor(training=True) 133 | 134 | # NOTICE: the image order in 80 images should like this now: 135 | # [label2, label1, label3, label0, label4, and then repeat by 15 times, namely one task] 136 | # support_x : [4, 1*5, 84*84*3] 137 | # query_x : [4, 15*5, 84*84*3] 138 | # support_y : [4, 5, 5] 139 | # query_y : [4, 15*5, 5] 140 | support_x = tf.slice(image_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_x') 141 | query_x = tf.slice(image_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_x') 142 | support_y = tf.slice(label_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_y') 143 | query_y = tf.slice(label_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_y') 144 | 145 | # construct test tensors. 146 | image_tensor, label_tensor = db.make_data_tensor(training=False) 147 | support_x_test = tf.slice(image_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_x_test') 148 | query_x_test = tf.slice(image_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_x_test') 149 | support_y_test = tf.slice(label_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_y_test') 150 | query_y_test = tf.slice(label_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_y_test') 151 | 152 | 153 | # 1. construct MAML model 154 | model = MAML(84, 3, 5) 155 | 156 | # construct metatrain_ and metaval_ 157 | if training: 158 | model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train') 159 | model.build(support_x_test, support_y_test, query_x_test, query_y_test, K, meta_batchsz, mode='eval') 160 | else: 161 | model.build(support_x_test, support_y_test, query_x_test, query_y_test, K + 5, meta_batchsz, mode='test') 162 | model.summ_op = tf.summary.merge_all() 163 | 164 | all_vars = filter(lambda x: 'meta_optim' not in x.name, tf.trainable_variables()) 165 | for p in all_vars: 166 | print(p) 167 | 168 | 169 | config = tf.ConfigProto() 170 | config.gpu_options.allow_growth = True 171 | sess = tf.InteractiveSession(config=config) 172 | # tf.global_variables() to save moving_mean and moving variance of batch norm 173 | # tf.trainable_variables() NOT include moving_mean and moving_variance. 174 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=5) 175 | 176 | # initialize, under interative session 177 | tf.global_variables_initializer().run() 178 | tf.train.start_queue_runners() 179 | 180 | if os.path.exists(os.path.join('ckpt', 'checkpoint')): 181 | # alway load ckpt both train and test. 
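# For training runs this resumes from the best checkpoint saved during validation; remove the ckpt directory to start meta-training from scratch.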
182 | model_file = tf.train.latest_checkpoint('ckpt')
183 | print("Restoring model weights from ", model_file)
184 | saver.restore(sess, model_file)
185 | 
186 | 
187 | if training:
188 | train(model, saver, sess)
189 | else:
190 | test(model, sess)
191 | 
192 | 
193 | if __name__ == "__main__":
194 | main()
195 | 
--------------------------------------------------------------------------------
/MAML-FSIDS_CIC-DDoS2019-5shot/maml.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import sys
3 | 
4 | import numpy as np
5 | import tensorflow as tf
6 | 
7 | 
8 | class MAML:
9 | def __init__(self, d, c, nway, meta_lr=1e-3, train_lr=1e-2):
10 | """
11 | 
12 | :param d: input image height/width (84)
13 | :param c: input image channels (3)
14 | :param nway: number of classes per task
15 | :param meta_lr: outer-loop (meta) learning rate
16 | :param train_lr: inner-loop (task adaptation) learning rate
17 | """
18 | self.d = d
19 | self.c = c
20 | self.nway = nway
21 | self.meta_lr = meta_lr
22 | self.train_lr = train_lr
23 | 
24 | print('img shape:', self.d, self.d, self.c, 'meta-lr:', meta_lr, 'train-lr:', train_lr)
25 | 
26 | def build(self, support_xb, support_yb, query_xb, query_yb, K, meta_batchsz, mode='train'):
27 | """
28 | 
29 | :param support_xb: [b, setsz, 84*84*3]
30 | :param support_yb: [b, setsz, n-way]
31 | :param query_xb: [b, querysz, 84*84*3]
32 | :param query_yb: [b, querysz, n-way]
33 | :param K: number of inner-loop update steps
34 | :param meta_batchsz: number of tasks per meta-batch
35 | :param mode: train/eval/test; for training, we build the train and eval networks at the same time.
36 | :return:
37 | """
38 | # create or reuse network variables, not including batch_norm variables, so we need an extra reuse
39 | # mechanism to reuse batch_norm variables.
40 | self.weights = self.conv_weights()
41 | # TODO: meta-test is sort of test stage.
42 | training = True if mode == 'train' else False
43 | 
44 | def meta_task(inputs):
45 | """
46 | map_fn only supports a single argument, so we unpack the tuple here.
47 | :param support_x: [setsz, 84*84*3]
48 | :param support_y: [setsz, n-way]
49 | :param query_x: [querysz, 84*84*3]
50 | :param query_y: [querysz, n-way]
51 | :param training: training or not, for batch_norm
52 | :return:
53 | """
54 | support_x, support_y, query_x, query_y = inputs
55 | # record the query ops at each of the K update steps.
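# MAML structure of meta_task: the support set drives K inner-loop gradient steps on a
# task-specific copy of the weights (fast_weights, learning rate train_lr), while the
# query-set loss after each adapted step is recorded so the outer Adam optimizer
# (meta_lr) can later differentiate through the adaptation and update theta.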
56 | query_preds, query_losses, query_accs = [], [], []
57 | 
58 | # ==================================
59 | # REUSE True False
60 | # Not exist Error Create one
61 | # Existed reuse Error
62 | # ==================================
63 | # That is, to create a variable you must turn reuse off.
64 | support_pred = self.forward(support_x, self.weights, training)
65 | support_loss = tf.nn.softmax_cross_entropy_with_logits(logits=support_pred, labels=support_y)
66 | support_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),
67 | tf.argmax(support_y, axis=1))
68 | # compute gradients
69 | grads = tf.gradients(support_loss, list(self.weights.values()))
70 | # grad and variable dict
71 | gvs = dict(zip(self.weights.keys(), grads))
72 | 
73 | # theta_pi = theta - alpha * grads
74 | fast_weights = dict(zip(self.weights.keys(),
75 | [self.weights[key] - self.train_lr * gvs[key] for key in self.weights.keys()]))
76 | # use theta_pi to forward meta-test
77 | query_pred = self.forward(query_x, fast_weights, training)
78 | # meta-test loss
79 | query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
80 | # record T0 pred and loss for meta-test
81 | query_preds.append(query_pred)
82 | query_losses.append(query_loss)
83 | 
84 | # continue to build the T1-TK steps of the graph
85 | for _ in range(1, K):
86 | # T_k loss on meta-train
87 | # we need the meta-train loss to fine-tune the task and the meta-test loss to update theta
88 | loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.forward(support_x, fast_weights, training),
89 | labels=support_y)
90 | # compute gradients
91 | grads = tf.gradients(loss, list(fast_weights.values()))
92 | # compose grad and variable dict
93 | gvs = dict(zip(fast_weights.keys(), grads))
94 | # update theta_pi according to the gradients
95 | fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.train_lr * gvs[key]
96 | for key in fast_weights.keys()]))
97 | # forward on theta_pi
98 | query_pred = self.forward(query_x, fast_weights, training)
99 | # we need to accumulate all meta-test losses to update theta
100 | query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
101 | query_preds.append(query_pred)
102 | query_losses.append(query_loss)
103 | 
104 | 
105 | # compute every step's accuracy on the query set
106 | for i in range(K):
107 | query_accs.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(query_preds[i], dim=1), axis=1),
108 | tf.argmax(query_y, axis=1)))
109 | # we only use the first-step support ops (support_pred & support_loss) and ignore the support ops
110 | # at steps 1:K-1.
111 | # however, we return the pred/loss/acc ops of the query set at every time step.
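# meta_task is mapped over the meta-batch with tf.map_fn below, so every entry of this
# result gains a leading task dimension: e.g. support_loss becomes [meta_batchsz, setsz]
# and each element of query_losses becomes [meta_batchsz, querysz].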
112 | result = [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
113 | 
114 | return result
115 | 
116 | # return: [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
117 | out_dtype = [tf.float32, tf.float32, tf.float32, [tf.float32] * K, [tf.float32] * K, [tf.float32] * K]
118 | result = tf.map_fn(meta_task, elems=(support_xb, support_yb, query_xb, query_yb),
119 | dtype=out_dtype, parallel_iterations=meta_batchsz, name='map_fn')
120 | support_pred_tasks, support_loss_tasks, support_acc_tasks, \
121 | query_preds_tasks, query_losses_tasks, query_accs_tasks = result
122 | 
123 | 
124 | if mode == 'train':
125 | # average loss
126 | self.support_loss = support_loss = tf.reduce_sum(support_loss_tasks) / meta_batchsz
127 | # [avgloss_t1, avgloss_t2, ..., avgloss_K]
128 | self.query_losses = query_losses = [tf.reduce_sum(query_losses_tasks[j]) / meta_batchsz
129 | for j in range(K)]
130 | # average accuracy
131 | self.support_acc = support_acc = tf.reduce_sum(support_acc_tasks) / meta_batchsz
132 | # average accuracies
133 | self.query_accs = query_accs = [tf.reduce_sum(query_accs_tasks[j]) / meta_batchsz
134 | for j in range(K)]
135 | 
136 | # # add batch_norm ops before meta_op
137 | # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
138 | # with tf.control_dependencies(update_ops):
139 | # # TODO: the update_ops must be put before tf.train.AdamOptimizer,
140 | # # otherwise it throws a "Not in same Frame" error.
141 | # meta_loss = tf.identity(self.query_losses[-1])
142 | 
143 | # meta-train optimizer
144 | optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim')
145 | # meta-train gradients; query_losses[-1] is the last-step query loss averaged over tasks.
146 | gvs = optimizer.compute_gradients(self.query_losses[-1])
147 | # meta-train gradient clipping
148 | gvs = [(tf.clip_by_norm(grad, 10), var) for grad, var in gvs]
149 | # update theta
150 | self.meta_op = optimizer.apply_gradients(gvs)
151 | 
152 | 
153 | else: # test & eval
154 | 
155 | # average loss
156 | self.test_support_loss = support_loss = tf.reduce_sum(support_loss_tasks) / meta_batchsz
157 | # [avgloss_t1, avgloss_t2, ..., avgloss_K]
158 | self.test_query_losses = query_losses = [tf.reduce_sum(query_losses_tasks[j]) / meta_batchsz
159 | for j in range(K)]
160 | # average accuracy
161 | self.test_support_acc = support_acc = tf.reduce_sum(support_acc_tasks) / meta_batchsz
162 | # average accuracies
163 | self.test_query_accs = query_accs = [tf.reduce_sum(query_accs_tasks[j]) / meta_batchsz
164 | for j in range(K)]
165 | 
166 | # NOTICE: every time the model is built, support_loss is added to the summary, but its value differs per mode.
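# The mode string prefixes every summary tag ('train:...', 'eval:...', 'test:...'), so the
# two build() calls made during training write to separate TensorBoard series instead of
# colliding on the same scalar name.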
167 | tf.summary.scalar(mode + ':support loss', support_loss) 168 | tf.summary.scalar(mode + ':support acc', support_acc) 169 | for j in range(K): 170 | tf.summary.scalar(mode + ':query loss, step ' + str(j + 1), query_losses[j]) 171 | tf.summary.scalar(mode + ':query acc, step ' + str(j + 1), query_accs[j]) 172 | 173 | 174 | 175 | 176 | def conv_weights(self): 177 | weights = {} 178 | 179 | conv_initializer = tf.contrib.layers.xavier_initializer_conv2d() 180 | fc_initializer = tf.contrib.layers.xavier_initializer() 181 | k = 3 182 | 183 | with tf.variable_scope('MAML', reuse= tf.AUTO_REUSE): 184 | weights['conv1'] = tf.get_variable('conv1w', [k, k, 3, 32], initializer=conv_initializer) 185 | weights['b1'] = tf.get_variable('conv1b', initializer=tf.zeros([32])) 186 | weights['conv2'] = tf.get_variable('conv2w', [k, k, 32, 32], initializer=conv_initializer) 187 | weights['b2'] = tf.get_variable('conv2b', initializer=tf.zeros([32])) 188 | weights['conv3'] = tf.get_variable('conv3w', [k, k, 32, 32], initializer=conv_initializer) 189 | weights['b3'] = tf.get_variable('conv3b', initializer=tf.zeros([32])) 190 | weights['conv4'] = tf.get_variable('conv4w', [k, k, 32, 32], initializer=conv_initializer) 191 | weights['b4'] = tf.get_variable('conv4b', initializer=tf.zeros([32])) 192 | 193 | # assumes max pooling 194 | weights['w5'] = tf.get_variable('fc1w', [32 * 5 * 5, self.nway], initializer=fc_initializer) 195 | weights['b5'] = tf.get_variable('fc1b', initializer=tf.zeros([self.nway])) 196 | 197 | 198 | return weights 199 | 200 | def conv_block(self, x, weight, bias, scope, training): 201 | """ 202 | build a block with conv2d->batch_norm->pooling 203 | :param x: 204 | :param weight: 205 | :param bias: 206 | :param scope: 207 | :param training: 208 | :return: 209 | """ 210 | # conv 211 | x = tf.nn.conv2d(x, weight, [1, 1, 1, 1], 'SAME', name=scope + '_conv2d') + bias 212 | # batch norm, activation_fn=tf.nn.relu, 213 | # NOTICE: must have tf.layers.batch_normalization 214 | # x = tf.contrib.layers.batch_norm(x, activation_fn=tf.nn.relu) 215 | with tf.variable_scope('MAML'): 216 | # train is set to True ALWAYS, please refer to https://github.com/cbfinn/maml/issues/9 217 | # when FLAGS.train=True, we still need to build evaluation network 218 | x = tf.layers.batch_normalization(x, training=True, name=scope + '_bn', reuse=tf.AUTO_REUSE) 219 | # relu 220 | x = tf.nn.relu(x, name=scope + '_relu') 221 | # pooling 222 | x = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID', name=scope + '_pool') 223 | return x 224 | 225 | 226 | def forward(self, x, weights, training): 227 | """ 228 | 229 | 230 | :param x: 231 | :param weights: 232 | :param training: 233 | :return: 234 | """ 235 | # [b, 84, 84, 3] 236 | x = tf.reshape(x, [-1, self.d, self.d, self.c], name='reshape1') 237 | 238 | hidden1 = self.conv_block(x, weights['conv1'], weights['b1'], 'conv0', training) 239 | hidden2 = self.conv_block(hidden1, weights['conv2'], weights['b2'], 'conv1', training) 240 | hidden3 = self.conv_block(hidden2, weights['conv3'], weights['b3'], 'conv2', training) 241 | hidden4 = self.conv_block(hidden3, weights['conv4'], weights['b4'], 'conv3', training) 242 | 243 | # get_shape is static shape, (5, 5, 5, 32) 244 | # print('flatten1:', hidden1.get_shape()) 245 | # print('flatten2:', hidden2.get_shape()) 246 | # print('flatten3:', hidden3.get_shape()) 247 | # print('flatten4:', hidden4.get_shape()) 248 | # flatten layer 249 | hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])], 
name='reshape2') 250 | 251 | # print('flatten4-r:', hidden4.get_shape()) 252 | # sys.exit(0) 253 | 254 | output = tf.add(tf.matmul(hidden4, weights['w5']), weights['b5'], name='fc1') 255 | 256 | return output 257 | 258 | --------------------------------------------------------------------------------
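A quick sanity check on the fc1w shape above: with 84x84 inputs, each conv block keeps the spatial size ('SAME' 3x3 convolution) and then halves it ('VALID' 2x2 max-pooling), so four blocks give 84 -> 42 -> 21 -> 10 -> 5 and the flatten feeds 32 * 5 * 5 features into the final layer. A minimal standalone sketch (not part of the repository) that reproduces this arithmetic:

def flattened_size(d=84, channels=32, blocks=4):
    # 'SAME' 3x3 conv preserves d; 'VALID' 2x2 max-pool with stride 2 gives floor(d/2)
    for _ in range(blocks):
        d = d // 2
    return channels * d * d

assert flattened_size() == 32 * 5 * 5   # 84 -> 42 -> 21 -> 10 -> 5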