├── .gitignore
├── DL
│   ├── dl.py
│   └── training
│       └── GCseq25.csv
├── Deep_Learning.pdf
├── ML
│   ├── ml.py
│   └── training
│       └── GCstats.csv
├── README.md
├── class_results.csv
├── final_results.csv
├── imgs
│   ├── Ensemble.png
│   ├── Improvement.png
│   ├── Metrics.png
│   ├── PerSNIEnsemble.png
│   ├── RFvsEnsemble.png
│   ├── classes.png
│   ├── directionality.png
│   └── perSniBest.png
├── main.py
├── notebook.ipynb
├── preliminary_results.csv
├── preprocessing
│   ├── create_pcap_stat.py
│   ├── pytcpdump.py
│   ├── pytcpdump.pyc
│   ├── pytcpdump_utils.py
│   └── pytcpdump_utils.pyc
├── references.txt
├── requirements.txt
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | pcaps
2 | .*
3 | imgs

--------------------------------------------------------------------------------
/DL/dl.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import accuracy_score
3 | from keras.preprocessing import text, sequence
4 | from keras.preprocessing.text import Tokenizer
5 | from keras.utils import to_categorical
6 | from keras.models import Model, Input, Sequential
7 | from keras.layers import GRU, LSTM, Embedding, Dense, TimeDistributed, Bidirectional, Activation, Dropout
8 | from sklearn.model_selection import train_test_split
9 | from keras.metrics import categorical_accuracy
10 | from keras import backend as K
11 | import tensorflow as tf
12 | from keras import optimizers
13 | import pandas as pd
14 | from keras.utils.np_utils import to_categorical
15 | from sklearn.model_selection import KFold
16 | from keras.callbacks import EarlyStopping
17 | import matplotlib.pyplot as plt
18 | import autosklearn.classification
19 | from keras.layers import *
20 | from sklearn.preprocessing import MinMaxScaler
21 | from sklearn.preprocessing import Normalizer
22 | from sklearn.preprocessing import StandardScaler
23 | 
24 | BATCH_SIZE = 64
25 | EPOCHS = 100 # use early stopping
26 | FOLDS = 10
27 | SEQ_LEN = 25
28 | NUM_ROWS = -1 # set to -1 to use all data
29 | MIN_CONNECTIONS_LIST = [100]
30 | 
31 | def read_csv(file_path, has_header=True):
32 |     with open(file_path) as f:
33 |         if has_header: f.readline()
34 |         data = []
35 |         for line in f:
36 |             line = line.strip().split(",")
37 |             data.append([x for x in line])
38 |         return data
39 | 
40 | ####################################################
41 | # Filter for SNIs meeting min connection threshold
42 | ####################################################
43 | def data_load_and_filter(datasetfile, min_connections):
44 |     dataset = read_csv(datasetfile)
45 | 
46 |     # Use only the first NUM_ROWS rows when a positive limit is set
47 |     dataset = dataset[:NUM_ROWS] if NUM_ROWS > 0 else dataset
48 | 
49 |     # packet sizes
50 |     X1 = np.array([z[1:SEQ_LEN + 1] for z in dataset])
51 | 
52 |     # payload sizes
53 |     X2 = np.array([z[SEQ_LEN + 1:2*SEQ_LEN + 1] for z in dataset])
54 | 
55 |     # inter-arrival times
56 |     X3 = np.array([z[2*SEQ_LEN + 1:3*SEQ_LEN + 1] for z in dataset])
57 |     X3 = X3.astype(float)
58 |     X3[np.where(X3 != 0 )] = np.log(X3[np.where(X3 != 0 )])
59 | 
60 |     # direction
61 |     X4 = np.array([z[3*SEQ_LEN + 1:4*SEQ_LEN + 1] for z in dataset])
62 | 
63 |     y = np.array([z[0] for z in dataset])
64 |     print("Shape of X1 =", np.shape(X1))
65 |     print("Shape of X2 =", np.shape(X2))
66 |     print("Shape of X3 =", np.shape(X3))
67 |     print("Shape of X4 =", np.shape(X4))
68 |     print("Shape of y =", np.shape(y))
69 | 
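    # The block below keeps only SNIs that appear more than min_connections
    # times; a minimal standalone sketch of the same idea (toy labels assumed):
    #
    #   y = np.array(["a.com", "a.com", "b.com"])
    #   snis, counts = np.unique(y, return_counts=True)  # ["a.com","b.com"], [2, 1]
    #   keep = snis[counts > 1]                          # ["a.com"]
    #   mask = np.isin(y, keep)                          # [True, True, False]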
") 71 | snis, counts = np.unique(y, return_counts=True) 72 | above_min_conns = list() 73 | 74 | for i in range(len(counts)): 75 | if counts[i] > min_connections: 76 | above_min_conns.append(snis[i]) 77 | 78 | print("Filtering done. SNI classes remaining: ", len(above_min_conns)) 79 | indices = np.isin(y, above_min_conns) 80 | X1 = X1[indices] 81 | X2 = X2[indices] 82 | X3 = X3[indices] 83 | X4 = X4[indices] 84 | y = y[indices] 85 | 86 | print("Filtered shape of X1 =", np.shape(X1)) 87 | print("Filtered shape of X2 =", np.shape(X2)) 88 | print("Filtered shape of X3 =", np.shape(X3)) 89 | print("Filtered shape of X4 =", np.shape(X4)) 90 | print("Filtered shape of y =", np.shape(y)) 91 | 92 | ##### BASIC PARAMETERS ##### 93 | n_samples = np.shape(X1)[0] 94 | time_steps = np.shape(X1)[1] # we have a time series of 100 payload sizes 95 | n_features = 1 # 1 feature which is packet size 96 | 97 | ##### CREATES MAPPING FROM SNI STRING TO INT ##### 98 | class_map = {sni:i for i, sni in enumerate(np.unique(y))} 99 | rev_class_map = {val: key for key, val in class_map.items()} 100 | 101 | n_labels = len(class_map) 102 | 103 | ##### CHANGE Y TO PD SO ITS EASIER TO MAP ##### 104 | y_pd = pd.DataFrame(y) 105 | y_pd = y_pd[0].map(class_map) 106 | 107 | ##### DUPLICATE Y LABELS, WE WILL NEED THIS LATER ##### 108 | y = y_pd.values.reshape(n_samples,) 109 | 110 | return X1, X2, X3, X4, y, time_steps, n_features, n_labels, rev_class_map 111 | 112 | ######################################################### 113 | ###### USE RNN TO CLASSIFY PACKET SEQUENCES -> SNI ###### 114 | ######################################################### 115 | def DLClassification(X_train, X_test, y_train, y_test,time_steps, n_features, n_labels, dropout): 116 | X_train = np.stack([X_train], axis=2) 117 | X_test = np.stack([X_test], axis=2) 118 | 119 | # if you dont have newest keras version, you might have to remove restore_best_weights = True 120 | early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=1, mode='min') 121 | model = Sequential() 122 | model.add(Conv1D(200, 3, activation='relu', input_shape=(time_steps, n_features))) 123 | model.add(BatchNormalization()) 124 | model.add(Conv1D(400, 3, activation='relu')) 125 | model.add(BatchNormalization()) 126 | model.add(GRU(200)) 127 | model.add(Dropout(dropout)) 128 | model.add(Dense(200, activation='sigmoid')) 129 | model.add(Dropout(dropout)) 130 | model.add(Dense(n_labels, activation='softmax')) 131 | model.compile(loss='sparse_categorical_crossentropy',optimizer='adam', metrics=['acc']) 132 | model.summary() 133 | model.fit(X_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=1, shuffle=False, validation_data=(X_test, y_test), callbacks = [early_stopping]) 134 | return model.predict(X_test) 135 | 136 | #*********************************************************************************** 137 | # autosklearn classifier to find the best achievable accuracy 138 | #*********************************************************************************** 139 | def auto_sklearn_classification(X_train, X_test, y_train, y_test): 140 | cls = autosklearn.classification.AutoSklearnClassifier(time_left_for_this_task=300, per_run_time_limit=90, ml_memory_limit=50000) 141 | cls.fit(X_train, y_train) 142 | print(cls.sprint_statistics()) 143 | print(cls.show_models()) 144 | predictions = cls.predict(X_test) 145 | accuracy = accuracy_score(y_test, predictions) 146 | return accuracy 147 | 148 | 149 | if __name__ == "__main__": 150 | datasetfile = 
"training/GCseq25.csv" 151 | 152 | kf = KFold(n_splits=FOLDS, shuffle=True) 153 | 154 | # try a variety of min conn settings for graph 155 | accuracies = [] 156 | for min_connections in MIN_CONNECTIONS_LIST: 157 | X1, X2, X3, X4, y, time_steps, n_features, n_labels, rev_class_map = data_load_and_filter(datasetfile, min_connections) 158 | 159 | total_nn1, total_nn2, total_nn3, total_nn123, total_cls = 0, 0, 0, 0, 0 160 | for train_index, test_index in kf.split(X1): 161 | 162 | X1_train, X1_test = X1[train_index], X1[test_index] # Packet sizes 163 | X2_train, X2_test = X2[train_index], X2[test_index] # Payload sizes 164 | X3_train, X3_test = X3[train_index], X3[test_index] # Inter-Arrival Times 165 | 166 | # Directional features not used! 167 | # X4_train, X4_test = X4[train_index], X4[test_index] 168 | 169 | y_train, y_test = y[train_index], y[test_index] 170 | 171 | # CNN-RNN for Packet Size 172 | predictions1 = DLClassification(X1_train, X1_test, y_train, y_test, time_steps, n_features, n_labels, dropout=0.0) 173 | 174 | # CNN-RNN for Payload Size 175 | predictions2 = DLClassification(X2_train, X2_test, y_train, y_test, time_steps, n_features, n_labels, dropout=0.0) 176 | 177 | # CNN-RNN for Inter-Arrival times 178 | predictions3 = DLClassification(X3_train, X3_test, y_train, y_test, time_steps, n_features, n_labels, dropout=0.25) 179 | 180 | nn_acc1 = 1. * np.sum([np.argmax(x) for x in predictions1] == y_test) / len(y_test) 181 | print("CNN-RNN Packet ACCURACY: %s"%(nn_acc1)) 182 | 183 | nn_acc2 = 1. * np.sum([np.argmax(x) for x in predictions2] == y_test) / len(y_test) 184 | print("CNN-RNN Payload ACCURACY: %s"%(nn_acc2)) 185 | 186 | nn_acc3 = 1. * np.sum([np.argmax(x) for x in predictions3] == y_test) / len(y_test) 187 | print("CNN-RNN IAT ACCURACY: %s"%(nn_acc3)) 188 | 189 | # Ensemble CNN-RNN 190 | predictions123 = (predictions1 * (1.0/3) + predictions2 * (1.0/3) + predictions3 * (1.0/3)) 191 | nn_acc123 = 1. * np.sum([np.argmax(x) for x in predictions123] == y_test) / len(y_test) 192 | print("Ensemble CNN-RNN ACCURACY: %s"%(nn_acc123)) 193 | 194 | total_nn1+= nn_acc1 195 | total_nn2+= nn_acc2 196 | total_nn3+= nn_acc3 197 | total_nn123+= nn_acc123 198 | 199 | # Uncomment for auto sklearn results on sequence features 200 | # cls_acc = auto_sklearn_classification(X_train, X_test, y_train, y_test) 201 | # print("Auto sklearn Accuracy: %s "%(cls_acc)) 202 | # total_cls += cls_acc 203 | 204 | # Uncomment to run once 205 | # FOLDS = 1 206 | # break 207 | 208 | total_nn1 = 1. * total_nn1 / FOLDS 209 | total_nn2 = 1. * total_nn2 / FOLDS 210 | total_nn3 = 1. * total_nn3 / FOLDS 211 | total_nn123 = 1. * total_nn123 / FOLDS 212 | total_cls = 1. 
--------------------------------------------------------------------------------
/Deep_Learning.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/niloofarbayat/NetworkClassification/4445bd0b86f2c2ca088f5e008138b1fec7abad4c/Deep_Learning.pdf

--------------------------------------------------------------------------------
/ML/ml.py:
--------------------------------------------------------------------------------
1 | # Random Forest versus autosklearn classifier
2 | #***********************************************************************************
3 | 
4 | import numpy as np
5 | import sklearn.ensemble, sklearn.model_selection
6 | from sklearn.decomposition import PCA
7 | import matplotlib.pyplot as plt
8 | from mpl_toolkits.mplot3d import Axes3D
9 | from sklearn.ensemble import RandomForestClassifier
10 | from sklearn.model_selection import KFold
11 | from sklearn.metrics import accuracy_score
12 | import autosklearn.classification
13 | import pandas
14 | from sklearn.metrics import classification_report
15 | import csv
16 | from collections import defaultdict
17 | 
18 | MIN_CONNECTIONS_LIST = [100]
19 | FOLDS = 10
20 | NUM_ROWS = -1 # set to -1 to use all data
21 | 
22 | def read_csv(file_path, has_header=True):
23 |     with open(file_path) as f:
24 |         if has_header: f.readline()
25 |         data = []
26 |         for line in f:
27 |             line = line.strip().split(",")
28 |             data.append([x for x in line])
29 |         return data
30 | 
31 | #***********************************************************************************
32 | # Filter for SNIs meeting min connection threshold
33 | #***********************************************************************************
34 | def data_load_and_filter(datasetfile, min_connections):
35 |     dataset = read_csv(datasetfile)
36 | 
37 |     # Use only the first NUM_ROWS rows when a positive limit is set
38 |     dataset = dataset[:NUM_ROWS] if NUM_ROWS > 0 else dataset
39 | 
40 |     X = np.array([z[1:] for z in dataset])
41 |     y = np.array([z[0] for z in dataset])
42 |     print("Shape of X =", np.shape(X))
43 |     print("Shape of y =", np.shape(y))
44 | 
45 |     print("Entering min connections filter section! ")
46 |     snis, counts = np.unique(y, return_counts=True)
47 |     above_min_conns = list()
48 | 
49 |     for i in range(len(counts)):
50 |         if counts[i] > min_connections:
51 |             above_min_conns.append(snis[i])
52 | 
53 |     print("Filtering done. SNI classes remaining: ", len(above_min_conns))
54 |     indices = np.isin(y, above_min_conns)
55 |     X = X[indices]
56 |     y = y[indices]
57 | 
58 |     print("Filtered shape of X =", np.shape(X))
59 |     print("Filtered shape of y =", np.shape(y))
60 | 
61 |     # cast to float, which auto_sklearn needs to work
62 |     X = X.astype(np.float)
63 |     return X, y
64 | 
65 | #***********************************************************************************
66 | # SNI prediction using Random Forest Classifier
67 | #***********************************************************************************
68 | def MLClassification(X_train, X_test, y_train, y_test):
69 |     rf = RandomForestClassifier(n_estimators=250, n_jobs=10)
70 |     rf.fit(X_train, y_train)
71 |     predictions = rf.predict(X_test)
72 | 
73 |     report = []
74 |     report_str = classification_report(y_test, predictions)
75 |     for row in report_str.split("\n"):
76 |         parsed_row = [x for x in row.split(" ") if len(x) > 0]
77 |         if len(parsed_row) > 0:
78 |             report.append(parsed_row)
79 | 
80 |     # pull accuracy, precision, recall, and F1-score out of the report
81 |     accuracy = accuracy_score(predictions, y_test)
82 |     precision = float(report[-1][-4]) # last row is "avg / total"; take the three columns before support
83 |     recall = float(report[-1][-3])
84 |     f1_score = float(report[-1][-2])
85 |     return accuracy, precision, recall, f1_score
86 | 
87 | 
88 | #***********************************************************************************
89 | # Autosklearn classifier to find the best achievable accuracy
90 | #***********************************************************************************
91 | def auto_sklearn_classification(X_train, X_test, y_train, y_test):
92 |     cls = autosklearn.classification.AutoSklearnClassifier(time_left_for_this_task=300, per_run_time_limit=90, ml_memory_limit=10000)
93 |     cls.fit(X_train, y_train)
94 |     predictions = cls.predict(X_test)
95 | 
96 |     report = []
97 |     report_str = classification_report(y_test, predictions)
98 |     for row in report_str.split("\n"):
99 |         parsed_row = [x for x in row.split(" ") if len(x) > 0]
100 |         if len(parsed_row) > 0:
101 |             report.append(parsed_row)
102 | 
103 |     # pull accuracy, precision, recall, and F1-score out of the report
104 |     accuracy = accuracy_score(predictions, y_test)
105 |     precision = float(report[-1][-4])
106 |     recall = float(report[-1][-3])
107 |     f1_score = float(report[-1][-2])
108 |     return accuracy, precision, recall, f1_score
109 | 
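# Minimal usage sketch (synthetic data assumed; GCstats.csv rows carry 42
# statistical feature columns after the sni label):
#
#   X = np.random.rand(200, 42)
#   y = np.random.choice(["a.com", "b.com"], size=200)
#   acc, prec, rec, f1 = MLClassification(X[:150], X[150:], y[:150], y[150:])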
110 | if __name__ == "__main__":
111 | 
112 |     datasetfile = "training/GCstats.csv"
113 | 
114 |     kf = KFold(n_splits=FOLDS, shuffle=True)
115 |     for min_connections in MIN_CONNECTIONS_LIST:
116 |         X, y = data_load_and_filter(datasetfile, min_connections)
117 |         total_rf, total_cls = [0,0,0,0], [0,0,0,0]
118 | 
119 |         for train_index, test_index in kf.split(X):
120 |             X_train, X_test = X[train_index], X[test_index]
121 |             y_train, y_test = y[train_index], y[test_index]
122 |             accuracy, precision, recall, f1_score = MLClassification(X_train, X_test, y_train, y_test)
123 |             print("Random Forest ACCURACY: %s"%(accuracy))
124 |             total_rf[0] += accuracy
125 |             total_rf[1] += precision
126 |             total_rf[2] += recall
127 |             total_rf[3] += f1_score
128 | 
129 |             accuracy, precision, recall, f1_score = auto_sklearn_classification(X_train, X_test, y_train, y_test)
130 |             print("Auto sklearn ACCURACY: %s "%(accuracy))
131 | 
132 |             total_cls[0] += accuracy
133 |             total_cls[1] += precision
134 |             total_cls[2] += recall
135 |             total_cls[3] += f1_score
136 | 
137 |             # Uncomment to run once
138 |             # FOLDS = 1
139 |             # break
140 | 
141 |         print("AVG Random Forest: %s, AVG Auto-Sklearn: %s "%(1. * total_rf[0] / FOLDS, 1. * total_cls[0] / FOLDS))

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NetworkClassification
2 | ### Overview
3 | Machine learning (Random Forest) and deep learning (CNN-RNN) classification of HTTPS flows by SNI, using statistical and sequence features extracted from pcap captures.
4 | 
5 | ### Requirements
6 | - Specified in `requirements.txt`
7 | - Python 3.6
8 | 
9 | ### Preprocessing
10 | 
11 | ##### `create_pcap_stat.py`
12 | 
13 | Specifications:
14 | 
15 | 1. Path to pcap file(s)
16 | 2. Path for output csv file with statistical features
17 | 3. Path for output csv file with sequence features
18 | 
19 | Run with `python3.6 create_pcap_stat.py`.
20 | 
21 | The program loads pcap files into memory by iterating through the HTTPS flows and grouping packets by connection (see pytcpdump).
22 | It calculates statistical features that include the following:
23 | 
24 | 1. Packet size: num, 25th, 50th, 75th, max, avg, var (remote->local, local->remote, combined)
25 | 2. Inter-arrival time: 25th, 50th, 75th (remote->local, local->remote, combined)
26 | 3. Payload size: 25th, 50th, 75th, max, avg, var (remote->local, local->remote)
27 | 
28 | It also extracts zero-padded sequence features (specify the first n packets to use):
29 | 
30 | 1. Packet sizes
31 | 2. Payload sizes
32 | 3. Inter-arrival times
33 | 4. Directionality
34 | 
35 | ##### `pytcpdump.py`
36 | 
37 | Loads one or more pcap files into memory by iterating through the HTTPS flows and grouping packets by connection (TCP only).
38 | Stores the following attributes for each connection in a cache:
39 | 
40 | 1. SNI
41 | 2. Accumulated bytes
42 | 3. Arrival times
43 | 4. Packet sizes
44 | 5. Payload sizes
45 | 
46 | ##### `pytcpdump_utils.py`
47 | Utility functions for pytcpdump, including functions for parsing the connection id, locating the IP header, etc.
48 | 
49 | ### ML
50 | The ML folder contains `ml.py`, which runs Random Forest classification on statistical features from
51 | the TCP handshake. A high-level summary:
52 | 
53 | 1. Read the CSV (`training/GCstats.csv`) of packet/payload/inter-arrival-time statistical features.
54 | 2. Filter for SNIs meeting a minimum number of connections.
55 | 3. Create a Random Forest classifier and run 10-fold cross-validation for accuracy.
56 | 4. (Optional) Create an Auto-Sklearn classifier and run 10-fold cross-validation for accuracy.
57 | 
58 | ### DL
59 | The DL folder contains `dl.py`, which predicts the SNI from sequence data. A high-level summary:
60 | 
61 | 1. Read the CSV (`training/GCseq25.csv`) of packet/payload/inter-arrival-time sequence features.
62 | 2. Filter for SNIs meeting a minimum number of connections.
63 | 3. Create three CNN-RNNs (one per feature sequence) and run 10-fold cross-validation.
64 | 4. Get accuracy results for each CNN-RNN, as well as for an ensemble classifier.
65 | 5. (Optional) Create an Auto-Sklearn classifier and run 10-fold cross-validation for accuracy.
66 | 
67 | ### `main.py`
68 | Reads in the CSV files (`ML/training/GCstats.csv`, `DL/training/GCseq25.csv`) and filters for SNIs meeting a minimum
69 | number of connections. Creates the following classifiers:
70 | 
71 | 1. Random Forest
72 | 2. Baseline RNN trained on packet size sequences
73 | 3. CNN-RNN trained on packet size sequences
74 | 4. CNN-RNN trained on payload size sequences
75 | 5. CNN-RNN trained on inter-arrival time sequences
76 | 6. Ensemble CNN-RNN
77 | 7. Ensemble CNN-RNN + Random Forest
78 | 
79 | Writes accuracy results to `final_results.csv`. (Optional) Can also write per-SNI class results to `class_results.csv`. The ensembling step is sketched below.
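The ensembling in `main.py` amounts to soft voting over class-probability matrices; a minimal sketch of that step (synthetic probabilities assumed, weights as in `main.py`):

```python
import numpy as np

rng = np.random.RandomState(0)
n_test, n_labels = 4, 3

# stand-ins for the three CNN-RNN softmax outputs and rf.predict_proba(X_test)
p1, p2, p3, p_rf = (rng.dirichlet(np.ones(n_labels), n_test) for _ in range(4))

p_dl = (p1 + p2 + p3) / 3.0        # Ensemble CNN-RNN
p_all = 0.5 * p_rf + 0.5 * p_dl    # Ensemble CNN-RNN + Random Forest
y_hat = np.argmax(p_all, axis=1)   # predicted SNI class per flow
```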
80 | 
81 | 

--------------------------------------------------------------------------------
/class_results.csv:
--------------------------------------------------------------------------------
1 | ,ascii.jp,assets.adobedtm.com,beacon.krxd.net,d.adroll.com,facebook.com,google.com,google.fr,mc.yandex.ru,nexus.ensighten.com,pixel.quantserve.com,secure.adnxs.com,ssl.gstatic.com,tags.tiqcdn.com
2 | Random Forest,1,0.983471074,1,0.98951049,0.99047619,1,0.987421384,0.985714286,0.948275862,1,1,0.981595092,0.979452055
3 | Baseline RNN,1,0.884297521,1,0.835664336,0.980952381,0.98540146,0.974842767,0.992857143,0.939655172,0.987261146,0.965250965,0.969325153,0.924657534
4 | Packet CNN-RNN,1,0.776859504,0.990384615,0.835664336,0.914285714,0.99270073,0.993710692,0.957142857,0.99137931,0.974522293,0.94980695,0.975460123,0.993150685
5 | Payload CNN-RNN,1,0.958677686,0.990384615,0.842657343,0.980952381,0.99270073,0.993710692,0.992857143,0.982758621,0.98089172,0.992277992,0.969325153,0.993150685
6 | IAT CNN-RNN,0.945121951,0.983471074,0.980769231,0.968531469,0.99047619,0.98540146,0.981132075,0.921428571,0.965517241,0.98089172,0.922779923,0.760736196,0.945205479
7 | Ensemble CNN-RNN,1,0.983471074,1,0.846153846,0.99047619,0.99270073,0.993710692,0.992857143,1,0.98089172,0.992277992,0.969325153,0.993150685
8 | Ensemble RF + CNN-RNN,1,0.991735537,1,0.986013986,0.99047619,0.99270073,0.993710692,1,0.99137931,1,1,0.993865031,0.993150685

--------------------------------------------------------------------------------
/final_results.csv:
--------------------------------------------------------------------------------
1 | model,min connections,accuracy,precision,recall,f1_score
2 | Random Forest,100,0.9254768392370573,0.93,0.93,0.92
3 | Packet,100,0.8461852861035423,0.87,0.85,0.84
4 | Payload,100,0.8426430517711172,0.87,0.84,0.84
5 | IAT,100,0.7310626702997275,0.75,0.73,0.72
6 | Ensemble,100,0.943732970027248,0.95,0.94,0.94
7 | Ensemble + Domain,100,0.9467302452316076,0.95,0.95,0.95
8 | RF + Domain,100,0.9282016348773842,0.93,0.93,0.93

--------------------------------------------------------------------------------
/imgs/Ensemble.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/niloofarbayat/NetworkClassification/4445bd0b86f2c2ca088f5e008138b1fec7abad4c/imgs/Ensemble.png

--------------------------------------------------------------------------------
/imgs/Improvement.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/niloofarbayat/NetworkClassification/4445bd0b86f2c2ca088f5e008138b1fec7abad4c/imgs/Improvement.png

--------------------------------------------------------------------------------
/imgs/Metrics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/niloofarbayat/NetworkClassification/4445bd0b86f2c2ca088f5e008138b1fec7abad4c/imgs/Metrics.png

--------------------------------------------------------------------------------
/imgs/PerSNIEnsemble.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/niloofarbayat/NetworkClassification/4445bd0b86f2c2ca088f5e008138b1fec7abad4c/imgs/PerSNIEnsemble.png

--------------------------------------------------------------------------------
/imgs/RFvsEnsemble.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/niloofarbayat/NetworkClassification/4445bd0b86f2c2ca088f5e008138b1fec7abad4c/imgs/RFvsEnsemble.png

--------------------------------------------------------------------------------
/imgs/classes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/niloofarbayat/NetworkClassification/4445bd0b86f2c2ca088f5e008138b1fec7abad4c/imgs/classes.png

--------------------------------------------------------------------------------
/imgs/directionality.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/niloofarbayat/NetworkClassification/4445bd0b86f2c2ca088f5e008138b1fec7abad4c/imgs/directionality.png

--------------------------------------------------------------------------------
/imgs/perSniBest.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/niloofarbayat/NetworkClassification/4445bd0b86f2c2ca088f5e008138b1fec7abad4c/imgs/perSniBest.png

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import accuracy_score
2 | from keras.preprocessing import text, sequence
3 | from keras.preprocessing.text import Tokenizer
4 | from keras.utils import to_categorical
5 | from keras.models import Model, Input, Sequential
6 | from keras.layers import GRU, LSTM, Embedding, Dense, TimeDistributed, Bidirectional, Activation, Dropout
7 | from sklearn.model_selection import train_test_split
8 | from keras.metrics import categorical_accuracy
9 | from keras import backend as K
10 | import tensorflow as tf
11 | from keras import optimizers
12 | import pandas as pd
13 | from keras.utils.np_utils import to_categorical
14 | from sklearn.model_selection import KFold
15 | from keras.callbacks import EarlyStopping
16 | import matplotlib.pyplot as plt
17 | from keras.layers import *
18 | from sklearn.preprocessing import MinMaxScaler
19 | from sklearn.preprocessing import Normalizer
20 | from sklearn.preprocessing import StandardScaler
21 | from sklearn.ensemble import RandomForestClassifier
22 | import csv
23 | from collections import defaultdict
24 | from utils import *
25 | 
26 | BATCH_SIZE = 64
27 | EPOCHS = 100 # use early stopping
28 | SEQ_LEN = 25
29 | NUM_ROWS = -1 # set to -1 to use all data
30 | MIN_CONNECTIONS_LIST = [100] # try a variety of min conn settings for model
31 | 
32 | #########################################################
33 | # RANDOM FOREST FOR ML CLASSIFICATION
34 | #########################################################
35 | def MLClassification(X_train, X_test, y_train, y_test):
36 |     rf = RandomForestClassifier(n_estimators=250, n_jobs=10)
37 |     rf.fit(X_train, y_train)
38 |     return rf.predict_proba(X_test)
39 | 
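# Note: unlike ML/ml.py, this MLClassification returns the full class-probability
# matrix (shape (n_test, n_labels)) so the Random Forest scores can later be
# averaged with the CNN-RNN softmax outputs.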
40 | #########################################################
41 | # BEST CNN-RNN FOR SEQUENCE CLASSIFICATION
42 | #########################################################
43 | def DLClassification(X_train, X_test, y_train, y_test, time_steps, n_features, n_labels):
44 |     # if you don't have the newest Keras version, you might have to remove restore_best_weights=True
45 |     early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=1, mode='min', restore_best_weights=True)
46 |     model = Sequential()
47 |     model.add(Conv1D(200, 3, activation='relu', input_shape=(time_steps, n_features)))
48 |     model.add(BatchNormalization())
49 |     model.add(Conv1D(400, 3, activation='relu'))
50 |     model.add(BatchNormalization())
51 |     model.add(GRU(200))
52 |     model.add(Dropout(0.1))
53 |     model.add(Dense(200, activation='sigmoid'))
54 |     model.add(Dropout(0.1))
55 |     model.add(Dense(n_labels, activation='softmax'))
56 |     model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['acc'])
57 |     model.summary()
58 |     model.fit(X_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=1, shuffle=True, validation_split=0.05, callbacks=[early_stopping])
59 |     return model.predict(X_test)
60 | 
61 | if __name__ == "__main__":
62 |     statistics = [["model", "min connections", "accuracy", "precision", "recall", "f1_score"]]
63 |     for min_connections in MIN_CONNECTIONS_LIST:
64 |         datasetfile = "DL/training/GCseq25.csv"
65 |         X, y = data_load_and_filter(datasetfile, min_connections, NUM_ROWS)
66 |         X1, X2, X3, X4, y, time_steps, n_labels, rev_class_map = process_dl_features(X, y, SEQ_LEN)
67 | 
68 |         datasetfile = "ML/training/GCstats.csv"
69 |         X, _ = data_load_and_filter(datasetfile, min_connections, NUM_ROWS)
70 | 
71 |         # Classifiers to test!
72 |         stats = {}
73 |         for model in ["Random Forest", "Packet", "Payload", "IAT", "Ensemble", "Ensemble + Domain", "RF + Domain"]:
74 |             stats[model] = [0,0,0,0]
75 | 
76 |         FOLDS = 20
77 |         kf = KFold(n_splits=FOLDS, shuffle=True)
78 | 
79 |         # Perform K-fold cross-validation (FOLDS splits)
80 |         for train_index, test_index in kf.split(X1):
81 | 
82 |             # Statistical features
83 |             X_train, X_test = X[train_index], X[test_index]
84 | 
85 |             # Packet features
86 |             X1_train, X1_test = X1[train_index], X1[test_index]
87 | 
88 |             # Payload features
89 |             X2_train, X2_test = X2[train_index], X2[test_index]
90 | 
91 |             # Inter-Arrival Time features
92 |             X3_train, X3_test = X3[train_index], X3[test_index]
93 | 
94 |             # Directional features
95 |             X4_train, X4_test = X4[train_index], X4[test_index]
96 | 
97 |             # Labels
98 |             y_train, y_test = y[train_index], y[test_index]
99 | 
100 |             print("Training size: ", len(y_train))
101 |             print("Test set size: ", len(y_test))
102 | 
103 |             # Random Forest classifier
104 |             predictions_rf = MLClassification(X_train, X_test, y_train, y_test)
105 |             stats = update_stats(stats, "Random Forest", predictions_rf, y_test)
106 |             print("Random Forest ACCURACY: %s"%(stats["Random Forest"][0]))
107 | 
108 |             # domain expertise
109 |             _, freqs = np.unique(y_test, return_counts=True)
110 |             domain_rf = domain_expertise(predictions_rf, freqs, y_test)
111 |             stats = update_stats(stats, "RF + Domain", domain_rf, y_test)
112 |             print("RF + Domain ACCURACY: %s"%(stats["RF + Domain"][0]))
113 | 
114 | 
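            # Each sequence model below gets two channels per time step: a value
            # sequence (sizes or times) plus the direction sequence X4, so
            # np.stack([X1_train, X4_train], axis=2) turns two (samples, SEQ_LEN)
            # arrays into one (samples, SEQ_LEN, 2) input tensor.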
115 |             # Create 3D input arrays (batch_size, time_steps, n_features = 2)
116 |             X_train = np.stack([X1_train, X4_train], axis=2)
117 |             X_test = np.stack([X1_test, X4_test], axis=2)
118 | 
119 |             predictions_1 = DLClassification(X_train, X_test, y_train, y_test, time_steps, 2, n_labels)
120 |             stats = update_stats(stats, "Packet", predictions_1, y_test)
121 |             print("Packet ACCURACY: %s"%(stats["Packet"][0]))
122 | 
123 |             # Create 3D input arrays (batch_size, time_steps, n_features = 2)
124 |             X_train = np.stack([X2_train, X4_train], axis=2)
125 |             X_test = np.stack([X2_test, X4_test], axis=2)
126 | 
127 |             predictions_2 = DLClassification(X_train, X_test, y_train, y_test, time_steps, 2, n_labels)
128 |             stats = update_stats(stats, "Payload", predictions_2, y_test)
129 |             print("Payload ACCURACY: %s"%(stats["Payload"][0]))
130 | 
131 |             # Create 3D input arrays (batch_size, time_steps, n_features = 2)
132 |             X_train = np.stack([X3_train, X4_train], axis=2)
133 |             X_test = np.stack([X3_test, X4_test], axis=2)
134 | 
135 |             predictions_3 = DLClassification(X_train, X_test, y_train, y_test, time_steps, 2, n_labels)
136 |             stats = update_stats(stats, "IAT", predictions_3, y_test)
137 |             print("IAT ACCURACY: %s"%(stats["IAT"][0]))
138 | 
139 |             predictions_123 = (predictions_1 + predictions_2 + predictions_3) / 3.0
140 |             predictions_123_rf = (predictions_rf * 0.5 + predictions_123 * 0.5)
141 |             stats = update_stats(stats, "Ensemble", predictions_123_rf, y_test)
142 |             # predictions_123 is the DL-only ensemble; predictions_123_rf also mixes in Random Forest
143 |             print("Ensemble ACCURACY: %s"%(stats["Ensemble"][0]))
144 | 
145 |             # domain expertise
146 |             domain_ensemble = domain_expertise(predictions_123_rf, freqs, y_test)
147 |             stats = update_stats(stats, "Ensemble + Domain", domain_ensemble, y_test)
148 |             print("Ensemble + Domain ACCURACY: %s"%(stats["Ensemble + Domain"][0]))
149 | 
150 |             # Uncomment below to get per-SNI accuracy
151 |             # output_class_accuracies(rev_class_map, y_test, predictions_rf, predictions_1, predictions_2, predictions_3, predictions_123, predictions_123_rf)
152 | 
153 |             # Running a single fold only; comment out the next two lines for full cross-validation
154 |             FOLDS = 1
155 |             break
156 | 
157 |         for model, stats in stats.items():
158 |             statistics.append([model, min_connections] + [1. * x / FOLDS for x in stats])
159 | 
160 |         with open('final_results.csv', 'a') as file:
161 |             wr = csv.writer(file)
162 |             for statistic in statistics:
163 |                 wr.writerow(statistic)
164 |             statistics = []
165 | 
166 | 

--------------------------------------------------------------------------------
/preliminary_results.csv:
--------------------------------------------------------------------------------
1 | ,25,50,75,100,125,150,175,200,225,250
2 | random forest,0.856004622,0.884260714,0.905921539,0.911257011,0.927944293,0.927669383,0.970734106,0.973661305,0.968511995,0.993581043
3 | Auto Sklearn (rf features),0.833739837,0.876564278,0.890921886,0.886839899,0.914031621,0.919245283,0.957264957,0.956439394,0.950561798,0.990498812
4 | DL baseline,0.334007297,0.456290392,0.606363353,0.640907953,0.71246457,0.711608396,0.813095364,0.832097421,0.734850693,0.776559778
5 | Auto Sklearn (dl features),0.72601626,0.790671217,0.831808586,0.8248114,0.862648221,0.870283019,0.952136752,0.946969697,0.959550562,0.976247031

--------------------------------------------------------------------------------
/preprocessing/create_pcap_stat.py:
--------------------------------------------------------------------------------
1 | import subprocess, struct, time, select, threading, os, sys, traceback, itertools, collections
2 | import pytcpdump
3 | import re
4 | import tldextract
5 | import numpy as np
6 | import math
7 | 
8 | #***********************************************************************************
9 | # Header for csv with statistical features
10 | #***********************************************************************************
11 | def stat_head():
12 |     return "sni,CSPktNum,CSPktsize25,CSPktSize50,CSPktSize75,CSPktSizeMax,CSPktSizeAvg,CSPktSizeVar,CSPaysize25,CSPaySize50,CSPaySize75,CSPaySizeMax,CSPaySizeAvg,CSPaySizeVar,CSiat25,CSiat50,CSiat75,SCPktNum,SCPktsize25,SCPktSize50,SCPktSize75,SCPktSizeMax,SCPktSizeAvg,SCPktSizeVar,SCPaysize25,SCPaySize50,SCPaySize75,SCPaySizeMax,SCPaySizeAvg,SCPaySizeVar,SCiat25,SCiat50,SCiat75,PktNum,Pktsize25,PktSize50,PktSize75,PktSizeMax,PktSizeAvg,PktSizeVar,iat25,iat50,iat75\n"
13 | 
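# stat_head() names 42 feature columns after the sni label: 16 for the first
# direction written by stat_create (packet count, 6 packet-size stats,
# 6 payload-size stats, 3 inter-arrival percentiles), 16 for the reverse
# direction, and 10 for both directions combined.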
14 | #***********************************************************************************
15 | # Header for csv with sequence features
16 | #***********************************************************************************
17 | def sequence_head(n):
18 |     return "sni," + ','.join([str(i) for i in range(1, 4*n + 1)]) + "\n" # four length-n sequences per row
19 | 
20 | #***********************************************************************************
21 | # Get features for packets/payloads (25th, 50th, 75th) percentiles, max, mean, var
22 | #***********************************************************************************
23 | def stat_calc(x, iat=False):
24 |     if len(x)==0:
25 |         return [str(a) for a in [0,0,0,0,0,0]]
26 |     if len(x)==1:
27 |         return [str(a) for a in [x[0], x[0], x[0], x[0], x[0], 0]]
28 |     x = sorted(x)
29 |     p25,p50,p75 = get_percentiles(x)
30 |     return [str(a) for a in [p25,p50,p75,max(x),np.mean(x),np.var(x)]]
31 | 
32 | #***********************************************************************************
33 | # Helper function to get percentiles
34 | #***********************************************************************************
35 | def get_percentiles(x):
36 |     return x[int(round((len(x)-1)/4.0))], x[int(round((len(x)-1)/2.0))], x[int(round((len(x)-1)*3/4.0))]
37 | 
38 | #***********************************************************************************
39 | # Helper function to combine seconds/microseconds timestamps
40 | #***********************************************************************************
41 | def combine_at(sec, usec):
42 |     l = len(sec)
43 |     return [sec[i]+usec[i]*1e-6 for i in range(l)]
44 | 
45 | #***********************************************************************************
46 | # Get features for inter-arrival times (25th, 50th, 75th) percentiles
47 | #***********************************************************************************
48 | def stat_prepare_iat(t):
49 |     l = len(t)
50 |     iat = [t[i+1]-t[i] for i in range(l-1)]
51 |     if len(iat)==0:
52 |         return [str(a) for a in [0,0,0]]
53 |     if len(iat)==1:
54 |         return [str(a) for a in [iat[0], iat[0], iat[0]]]
55 |     p25,p50,p75 = get_percentiles(iat)
56 |     return [str(a) for a in [p25,p50,p75]]
57 | 
58 | #***********************************************************************************
59 | # Get statistical features from tcp packet sequences
60 | #***********************************************************************************
61 | def stat_create(data,filename,first_n_packets):
62 |     with open(filename,'w') as f:
63 |         f.write(stat_head())
64 |         for id in data:
65 |             item=data[id]
66 | 
67 |             sni=SNIModificationbyone(item[0])
68 | 
69 |             # exclude unknown domains
70 |             if sni == 'unknown' or sni == 'unknown.':
71 |                 continue
72 | 
73 |             line=[sni]
74 | 
75 |             # remote->local features
76 |             # col 1: packet count
77 |             # cols 2-7: packet-size stats
78 |             # cols 8-13: payload-size stats
79 |             # cols 14-16: inter-arrival time stats
80 |             line+=[str(len(item[4][0]))]
81 |             line+=stat_calc(item[4][0])
82 |             line+=stat_calc(item[5][0])
83 |             arrival1=combine_at(item[2][0], item[3][0])
84 |             line+=stat_prepare_iat(arrival1)
85 | 
86 |             # local->remote
87 |             # col 17: packet count
88 |             # cols 18-23: packet-size stats
89 |             # cols 24-29: payload-size stats
90 |             # cols 30-32: inter-arrival time stats
91 |             line+=[str(len(item[4][1]))]
92 |             line+=stat_calc(item[4][1])
93 |             line+=stat_calc(item[5][1])
94 |             arrival2=combine_at(item[2][1], item[3][1])
95 |             line+=stat_prepare_iat(arrival2)
96 | 
97 |             # both directions combined
98 |             # col 33: packet count; cols 34-39: packet-size stats
99 |             # cols 40-42: inter-arrival time stats
100 |             line+=[str(len(item[4][1]) + len(item[4][0]))]
101 |             line+=stat_calc(item[4][1] + item[4][0])
102 |             line+=stat_prepare_iat(sorted(arrival1 + arrival2))
103 | 
104 |             line= ','.join(line)
105 |             f.write(line)
106 |             f.write('\n')
107 | 
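# get_percentiles indexes a pre-sorted list rather than interpolating;
# e.g. for x = [1, 2, 3, 4] it returns
#   (x[round(0.75)], x[round(1.5)], x[round(2.25)]) = (x[1], x[2], x[2]) = (2, 3, 3)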
108 | #***********************************************************************************
109 | # Create sequence features from tcp packet sequences
110 | #***********************************************************************************
111 | def sequence_create(data, filename, first_n_packets):
112 |     with open(filename,'w') as f:
113 |         f.write(sequence_head(first_n_packets))
114 |         counter = 0
115 |         skipped = 0
116 |         for id in data:
117 |             item=data[id]
118 |             sni=SNIModificationbyone(item[0])
119 | 
120 |             # exclude unknown domains
121 |             counter = counter + 1
122 |             if sni == 'unknown' or sni == 'unknown.':
123 |                 skipped = skipped + 1
124 |                 continue
125 | 
126 |             line=[sni]
127 | 
128 |             # Calculate arrival times in seconds for local->remote and remote->local
129 |             arrival1=combine_at(item[2][0], item[3][0])
130 |             arrival2=combine_at(item[2][1], item[3][1])
131 | 
132 |             # Sort all packets by arrival times to get the sequence in the correct order
133 |             packets = zip(arrival1 + arrival2, list(item[4][0]) + list(item[4][1]))
134 |             packets = [str(x) for _,x in sorted(packets)]
135 | 
136 |             # Zero padding for sequences that are too short
137 |             if len(packets) < first_n_packets:
138 |                 packets = [str(0)]*(first_n_packets - len(packets)) + packets
139 | 
140 |             line+=packets[0:first_n_packets]
141 | 
142 |             # Sort all payloads by arrival times to get the sequence in the correct order
143 |             payloads = zip(arrival1 + arrival2, list(item[5][0]) + list(item[5][1]))
144 |             payloads = [str(x) for _,x in sorted(payloads)]
145 | 
146 |             # Zero padding for sequences that are too short
147 |             if len(payloads) < first_n_packets:
148 |                 payloads = [str(0)]*(first_n_packets - len(payloads)) + payloads
149 | 
150 |             line+=payloads[0:first_n_packets]
151 | 
152 |             # Sort all arrival times to derive inter-arrival times in the correct order
153 |             arrivals = sorted(arrival1 + arrival2)
154 |             iat = [str(0)] + [str(arrivals[i+1]-arrivals[i]) for i in range(len(arrivals)-1)]
155 | 
156 |             # Zero padding for sequences that are too short
157 |             if len(iat) < first_n_packets:
158 |                 iat = [str(0)]*(first_n_packets - len(iat)) + iat
159 | 
160 |             line+=iat[0:first_n_packets]
161 | 
162 |             # Sort all directions by arrival times to get the direction sequence in the correct order (-1, 1, 0)
163 |             # remote -> local = -1
164 |             # local -> remote = 1
165 |             # padding = 0
166 |             direction = zip(arrival1 + arrival2, [-1]*len(item[5][0]) + [1]*len(item[5][1]))
167 |             direction = [str(x) for _,x in sorted(direction)]
168 | 
169 |             # Zero padding for direction sequences that are too short
170 |             if len(direction) < first_n_packets:
171 |                 direction = [str(0)]*(first_n_packets - len(direction)) + direction
172 | 
173 |             line+=direction[0:first_n_packets]
174 | 
175 |             line= ','.join(line)
176 |             f.write(line)
177 |             f.write('\n')
178 |     print("Skipped percentage: ", 1. * skipped / counter)
179 | 
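# Each row written above is: sni, then four length-n blocks in order
# (packet sizes, payload sizes, inter-arrival times, directions), each
# pre-padded with zeros; e.g. with first_n_packets = 3 and a two-packet flow:
#
#   "example.com,0,66,1514,0,0,1460,0,0,0.012,0,-1,1"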
180 | #***********************************************************************************
181 | # Parts of this function borrowed from the following paper:
182 | #
183 | # Multi-Level identification Framework to Identify HTTPS Services
184 | # Author: Wazen Shbair,
185 | # University of Lorraine,
186 | # France
187 | # wazen.shbair@gmail.com
188 | # January, 2017
189 | #
190 | # SNI modification for the sub-domain parts only
191 | #***********************************************************************************
192 | def SNIModificationbyone(sni):
193 |     temp = tldextract.extract(sni.encode().decode())
194 |     x = re.sub("\d+", "", temp.subdomain) # remove numbers
195 |     x = re.sub("[-,.]", "", x) # remove dashes, commas, and dots
196 |     x = re.sub("www", "", x) # remove www
197 |     if len(x) > 0:
198 |         newsni = x + "." + temp.domain + "." + temp.suffix # reconstruct the sni
199 |     else:
200 |         newsni = temp.domain + "." + temp.suffix
201 | 
202 |     return newsni
203 | 
204 | #***********************************************************************************
205 | # Inputs
206 | # 1. pcap file (filtered for SSL)
207 | # 2. output file for statistical features
208 | # 3. output file for sequence features
209 | #***********************************************************************************
210 | if __name__ == "__main__":
211 |     pcap_file = ['../pcaps/GCDay1SSL.pcap', '../pcaps/GCDay2SSL.pcap','../pcaps/GCDay3SSL.pcap',
212 |                  '../pcaps/GCDay4SSL.pcap','../pcaps/GCDay5SSL.pcap','../pcaps/GCDay6SSL.pcap',
213 |                  '../pcaps/GCDay7SSL.pcap','../pcaps/GCDay8SSL.pcap','../pcaps/GCDay9SSL.pcap',
214 |                  '../pcaps/GCDay10SSL.pcap','../pcaps/GCDay11SSL.pcap','../pcaps/GCDay12SSL.pcap']
215 |     output_file_stats = '../ML/training/GCstats.csv'
216 |     output_file_seqs = '../DL/training/GCseq25.csv'
217 |     for fname in pcap_file:
218 |         print ('process', fname)
219 |         pytcpdump.process_file(fname)
220 |         print (fname,"finished, kept",len(pytcpdump.cache.cache),'records')
221 | 
222 |     stat_create(pytcpdump.cache.cache, output_file_stats, first_n_packets=25)
223 |     sequence_create(pytcpdump.cache.cache, output_file_seqs, first_n_packets=25)

--------------------------------------------------------------------------------
/preprocessing/pytcpdump.py:
--------------------------------------------------------------------------------
1 | #***********************************************************************************
2 | # Replicates what the tcpdump terminal command does in Python using appropriate filters.
3 | # For instance, it filters for HTTPS handshake packets and extracts the SNI from them.
4 | #***********************************************************************************
5 | 
6 | 
7 | 
8 | import subprocess, struct, time, select, threading, os, sys, traceback, itertools, math, collections
9 | from ctypes import *
10 | from pytcpdump_utils import *
11 | 
12 | cache = None
13 | 
14 | #***********************************************************************************
15 | # Pcap file header
16 | #***********************************************************************************
17 | class pcap_hdr_s(Structure):
18 |     _fields_ = [('magic', c_uint32),
19 |                 ('v1', c_uint16),
20 |                 ('v2', c_uint16),
21 |                 ('zone', c_uint32),
22 |                 ('sigfigs', c_uint32),
23 |                 ('snaplen', c_uint32),
24 |                 ('network', c_uint32)]
25 | 
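# For a little-endian pcap file the global header above starts with magic
# 0xa1b2c3d4 (network == 1 for Ethernet captures); pcaprec_hdr_s below then
# precedes every packet record, with 'len' the bytes saved in the file and
# 'olen' the original length on the wire.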
26 | #***********************************************************************************
27 | # Packet headers
28 | #***********************************************************************************
29 | class pcaprec_hdr_s(Structure):
30 |     _fields_ = [('sec', c_uint32),
31 |                 ('usec', c_uint32),
32 |                 ('len', c_uint32),
33 |                 ('olen', c_uint32)]
34 | 
35 | #***********************************************************************************
36 | # Cache stores data for each TCP handshake.
37 | #
38 | # Key:
39 | # - connection ID
40 | #
41 | # Values:
42 | # - TCP data dictionary
43 | #
44 | # 0:str:hostname,
45 | # 1:[int,int]:accumulated bytes[remote->local, local->remote],
46 | # 2:list[[int],[int]]:arrival seconds,
47 | # 3:list[[int],[int]]:arrival micro-seconds,
48 | # 4:list[[int],[int]]:packet size
49 | # 5:list[[int],[int]]:tcp payload size
50 | #***********************************************************************************
51 | class Cache:
52 |     def __init__(self):
53 |         self.lock = threading.RLock()
54 |         self.cache = collections.OrderedDict()
55 | 
56 |     def set_hostname(self, key, hostname):
57 |         with self.lock:
58 |             item = self.pop(key)
59 |             item[0] = hostname
60 |             self.cache[key] = item
61 | 
62 |     def update(self, key, fromLocal, pcaprec_hdr, payload_len):
63 |         with self.lock:
64 |             item = self.pop(key)
65 |             item[1][fromLocal]+=pcaprec_hdr.olen
66 |             item[2][fromLocal].append(pcaprec_hdr.sec)
67 |             item[3][fromLocal].append(pcaprec_hdr.usec)
68 |             item[4][fromLocal].append(pcaprec_hdr.olen)
69 |             item[5][fromLocal].append(payload_len)
70 |             self.cache[key] = item
71 | 
72 |     def pop(self, key):
73 |         with self.lock:
74 |             try:
75 |                 return self.cache.pop(key)
76 |             except KeyError:
77 |                 return ['unknown',[0,0],[[],[]],[[],[]],[[],[]],[[],[]]]
78 | 
79 | #***********************************************************************************
80 | # Search for the SNI in the TLS handshake payload
81 | #***********************************************************************************
82 | def sni_pos(data):
83 |     pos=0
84 |     while 1:
85 |         pos=data.find(b'.',pos)
86 |         if pos<0:
87 |             return -1,-1
88 | 
89 |         # get start of domain
90 |         pos1=pos
91 |         while isDomainChar(data[pos1-1:pos1]):
92 |             pos1-=1
93 | 
94 |         # get end of domain
95 |         pos2=pos
96 |         while 1:
97 |             pos2+=1
98 |             if pos2==len(data):
99 |                 break
100 |             if not isDomainChar(data[pos2:pos2+1]):
101 |                 break
102 | 
103 |         # min domain length
104 |         if pos2-pos1>=5:
105 |             if ord(data[pos1-1:pos1])==0: # data[pos1-2:pos1] should be the len
106 |                 pos1+=1
107 |             if (ord(data[pos1-2:pos1-1])<<8)+ord(data[pos1-1:pos1]) == pos2-pos1:
108 |                 return pos1, pos2
109 | 
110 |         # try again
111 |         pos=pos2
112 | 
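# sni_pos accepts a candidate name only when the two bytes preceding it form
# its big-endian length, as in the ClientHello SNI extension; e.g. for
# b'\x00\x0bexample.com' the length 0x000b == 11 == len("example.com"),
# so (pos1, pos2) bracket exactly the host name.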
113 | #***********************************************************************************
114 | # The meat and potatoes:
115 | # 1. Iterate through the pcap file
116 | # 2. Get connection id
117 | # 3. Add packet to connection id cache
118 | # 4. Search for hostname in the TLS payload
119 | #***********************************************************************************
120 | def process_file(filename):
121 |     global cache
122 |     if not cache:
123 |         cache = Cache()
124 |     pcap_hdr = pcap_hdr_s()
125 |     pkt_hdr = pcaprec_hdr_s()
126 |     pkt_hdr_size = sizeof(pkt_hdr)
127 |     with open(filename,'rb') as f:
128 |         f.readinto(pcap_hdr)
129 |         counter = 0
130 |         while 1:
131 |             counter = counter + 1
132 |             if f.readinto(pkt_hdr)!=pkt_hdr_size:
133 |                 break
134 | 
135 |             # get next packet
136 |             data = bytes(f.read(pkt_hdr.len))
137 |             ip_pos = get_ip_pos(data) # IP layer
138 | 
139 |             # Invalid IP shouldn't happen, but just in case...
140 |             if ip_pos<0:
141 |                 print("Invalid IP: ", ip_pos)
142 |                 continue
143 | 
144 |             tcp_pos = ip_pos + ((ord(data[ip_pos:ip_pos+1]) & 0x0f)<<2)
145 |             ip_proto = ord(data[ip_pos+9:ip_pos+10])
146 | 
147 |             # Must be TCP!
148 |             if ip_proto != 0x06:
149 |                 print("Skipping unexpected connection: ", ip_proto)
150 |                 continue
151 | 
152 |             conn_id = data[ip_pos+ 12: ip_pos+20] + data[tcp_pos : tcp_pos+4]
153 |             conn_id, fromLocal = unify_conn_id(conn_id)
154 |             tcp_load_pos = tcp_pos + (( ord(data[tcp_pos+12:tcp_pos+13]) & 0xf0 )>>2)
155 |             cache.update(conn_id, fromLocal, pkt_hdr, pkt_hdr.olen-tcp_load_pos)
156 |             if (tcp_load_pos + 5 < len(data)) and (data[tcp_load_pos:tcp_load_pos+1] == b'\x16'): # TLS handshake record
157 |                 pos1, pos2 = sni_pos(data)
158 |                 if pos1 > 0:
159 |                     cache.set_hostname(conn_id, data[pos1:pos2].decode())
160 | 

--------------------------------------------------------------------------------
/preprocessing/pytcpdump.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/niloofarbayat/NetworkClassification/4445bd0b86f2c2ca088f5e008138b1fec7abad4c/preprocessing/pytcpdump.pyc

--------------------------------------------------------------------------------
/preprocessing/pytcpdump_utils.py:
--------------------------------------------------------------------------------
1 | #***********************************************************************************
2 | # Valid domain characters
3 | #***********************************************************************************
4 | def isDomainChar(c):
5 |     if c.isalpha() or c.isdigit() or c==b'-' or c==b'_' or c==b'.' or c==b'/':
6 |         return True
7 |     return False
8 | 
9 | #***********************************************************************************
10 | # Reverses the format dst_ip+dport+src_ip+sport to src_ip+sport+dst_ip+dport
11 | #***********************************************************************************
12 | def conn_id_reverse(id):
13 |     return id[4:8] + id[0:4] + id[10:12] + id[8:10]
14 | 
15 | #***********************************************************************************
16 | # Checks for remote vs local
17 | #***********************************************************************************
18 | def isLocalAddress(s):
19 |     return s.startswith(b'\x0a') or s.startswith(b'\xc0\xa8') # 10.x.x.x, 192.168.x.x
20 | 
21 | #***********************************************************************************
22 | # Gets connection id
23 | # unify both direct (src_ip+sport+dst_ip+dport) and reverse (dst_ip+dport+src_ip+sport)
24 | # network traffic
25 | #***********************************************************************************
26 | def unify_conn_id(conn_id):
27 |     l = isLocalAddress(conn_id[0:4])
28 |     r = isLocalAddress(conn_id[4:8])
29 |     if (l==r):
30 |         if conn_id[0:4] < conn_id[4:8]:
31 |             return conn_id, 1
32 |         else:
33 |             return conn_id_reverse(conn_id), 0
34 |     # else l!=r
35 |     if not l:
36 |         return conn_id_reverse(conn_id), 0
37 |     return conn_id, 1
38 | 
39 | #***********************************************************************************
40 | # Gets ip position
41 | #***********************************************************************************
42 | def get_ip_pos(data):
43 |     ip_pos = 14
44 |     if data[12:14] == b'\x81\x00':
45 |         ip_pos = 18
46 |     if data[ip_pos-2:ip_pos-1]==b'\x08' and data[ip_pos-1:ip_pos]==b'\x00':
47 |         return ip_pos
48 |     else:
49 |         return -1 # not ip protocol
50 | 
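# The 12-byte conn_id built in pytcpdump.process_file is
# src_ip(4) + dst_ip(4) + sport(2) + dport(2); unify_conn_id above flips it
# with conn_id_reverse so both directions of a flow share one cache key, and
# the returned flag is 1 when the packet was sent by the address that comes
# first in the unified key (local -> remote in the usual case).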
--------------------------------------------------------------------------------
/preprocessing/pytcpdump_utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/niloofarbayat/NetworkClassification/4445bd0b86f2c2ca088f5e008138b1fec7abad4c/preprocessing/pytcpdump_utils.pyc

--------------------------------------------------------------------------------
/references.txt:
--------------------------------------------------------------------------------
1 | @inproceedings{shbair2016multi,
2 |   title={A multi-level framework to identify https services},
3 |   author={Shbair, W. M. and Cholez, T. and Francois, J. and Chrisment, I.},
4 |   booktitle={Network Operations and Management Symposium (NOMS), 2016 IEEE/IFIP},
5 |   pages={240--248},
6 |   year={2016},
7 |   organization={IEEE}
8 | }
9 | 
10 | @article{nguyen2008survey,
11 |   title={A survey of techniques for internet traffic classification using machine learning},
12 |   author={Nguyen, T. T. and Armitage, G.},
13 |   journal={IEEE Communications Surveys \& Tutorials},
14 |   volume={10},
15 |   number={4},
16 |   pages={56--76},
17 |   year={2008},
18 |   publisher={IEEE}
19 | }
20 | 
21 | @article{fadlullah2017state,
22 |   title={State-of-the-art deep learning: Evolving machine intelligence toward tomorrow’s intelligent network traffic control systems},
23 |   author={Fadlullah, Z. and Tang, F. and Mao, B. and Kato, N. and Akashi, O. and Inoue, T. and Mizutani, K.},
24 |   journal={IEEE Communications Surveys \& Tutorials},
25 |   volume={19},
26 |   number={4},
27 |   pages={2432--2455},
28 |   year={2017},
29 |   publisher={IEEE}
30 | }
31 | 
32 | @misc{kolton2011system,
33 |   title={System to enable detecting attacks within encrypted traffic},
34 |   author={Kolton, D. and Stav, A. and Wexler, A. and Frydman, A. E. and Zahavi, Y.},
35 |   year={2011},
36 |   month=feb # "~22",
37 |   publisher={Google Patents},
38 |   note={US Patent 7,895,652}
39 | }
40 | 
41 | @misc{SNI,
42 |   title={Transport layer security (TLS) extensions: Extension
43 |   definitions (RFC 6066)},
44 |   author={D. Eastlake},
45 |   year={2011}
46 | }
47 | 
48 | @misc{TMG,
49 |   title={Configuring HTTPS inspection with forefront threat
50 |   management gateway (TMG)},
51 |   author={R. Hicks},
52 |   howpublished = {\url{http://techgenix.com/Configuring-HTTPS-Inspection-Forefront-Threat-Management-Gateway-TMG-2010/}}
53 | }
54 | 
55 | @article{Moore:2005:ITC:1071690.1064220,
56 |   author = {Moore, Andrew W. and Zuev, Denis},
57 |   title = {Internet Traffic Classification Using Bayesian Analysis Techniques},
58 |   journal = {SIGMETRICS Perform. Eval. Rev.},
59 |   issue_date = {June 2005},
60 |   volume = {33},
61 |   number = {1},
62 |   month = jun,
63 |   year = {2005},
64 |   issn = {0163-5999},
65 |   pages = {50--60},
66 |   numpages = {11},
67 |   url = {http://doi.acm.org/10.1145/1071690.1064220},
68 |   doi = {10.1145/1071690.1064220},
69 |   acmid = {1064220},
70 |   publisher = {ACM},
71 |   address = {New York, NY, USA},
72 |   keywords = {flow classification, internet traffic, traffic identification},
73 | }
74 | 
75 | @INPROCEEDINGS{6147705,
76 |   author={Y. Okada and S. Ata and N. Nakamura and Y. Nakahira and I. Oka},
77 |   booktitle={2011 10th International Conference on Machine Learning and Applications and Workshops},
78 |   title={Comparisons of Machine Learning Algorithms for Application Identification of Encrypted Traffic},
79 |   year={2011},
80 |   volume={2},
81 |   number={},
82 |   pages={358-361},
83 |   keywords={Bayes methods;computer network management;computer network security;cryptography;decision trees;learning (artificial intelligence);support vector machines;machine learning algorithms;application identification;traffic encryption;network operators;network management;bandwidth control;traffic security;encrypted packets;estimated features method;EFM;support vector machine;Naive Bayes Kernel Estimation;decision tree;Encryption;Accuracy;Support vector machines;Training data;Monitoring;Machine learning algorithms},
84 |   doi={10.1109/ICMLA.2011.162},
85 |   ISSN={},
86 |   month={Dec},}
87 | 
88 | @misc{alexa,
89 |   title={The top 500 sites on the web},
90 |   year={2018},
91 |   howpublished = {\url{https://www.alexa.com/topsites}}
92 | }
93 | 
94 | @inproceedings{chen2010side,
95 |   title={Side-channel leaks in web applications: A reality today, a challenge tomorrow},
96 |   author={Chen, Shuo and Wang, Rui and Wang, XiaoFeng and Zhang, Kehuan},
97 |   booktitle={2010 IEEE Symposium on Security and Privacy},
98 |   pages={191--206},
99 |   year={2010},
100 |   organization={IEEE}
101 | }
102 | 
103 | @inproceedings{shbair2015efficiently,
104 |   title={Efficiently bypassing SNI-based HTTPS filtering},
105 |   author={Shbair, Wazen M and Cholez, Thibault and Goichot, Antoine and Chrisment, Isabelle},
106 |   booktitle={Integrated Network Management (IM), 2015 IFIP/IEEE International Symposium on},
107 |   pages={990--995},
108 |   year={2015},
109 |   organization={IEEE}
110 | }
111 | 
112 | 
113 | @misc{ESNI,
114 |   title={ESNI: A Privacy-Protecting Upgrade to HTTPS},
115 |   author={Seth Schoen},
116 |   howpublished = {\url{https://www.eff.org/deeplinks/2018/09/esni-privacy-protecting-upgrade-https}},
117 |   journal = {EFF DeepLinks Blog},
118 |   year={2018}
119 | }
120 | 
121 | @misc{domain-fronting,
122 |   title={Don't panic about domain fronting, an SNI fix is getting hacked out},
123 |   author={Thomas Claburn},
124 |   howpublished = {\url{https://www.theregister.co.uk/2018/07/17/encrypted_server_names/}},
125 |   journal = {The Register},
126 |   year={2018}
127 | }
128 | 
129 | @misc{shbair2016,
130 |   author = {Wazen Shbair and Thibault Cholez and Jerome Francois and Isabelle Chrisment},
131 |   title = {HTTPS Websites Dataset},
132 |   howpublished={\url{http://betternet.lhs.loria.fr/datasets/https/}},
133 |   year = {2016}
134 | }
135 | 
136 | @INPROCEEDINGS{5356534,
137 |   author={R. Alshammari and A. N. Zincir-Heywood},
138 |   booktitle={2009 IEEE Symposium on Computational Intelligence for Security and Defense Applications},
139 |   title={Machine learning based encrypted traffic classification: Identifying SSH and Skype},
140 |   year={2009},
141 |   volume={},
142 |   number={},
143 |   pages={1-8},
144 |   keywords={cryptography;learning (artificial intelligence);support vector machines;telecommunication traffic;machine learning;encrypted traffic classification;secure shell;Skype;traffic classification;adaboost;support vector machine;Naïve Bayesian;RIPPER;flow based features;C4.5 based approach;Machine learning;Cryptography;Telecommunication traffic;Traffic control;Payloads;Bayesian methods;Robustness;Support vector machines;Support vector machine classification;Financial management},
145 |   doi={10.1109/CISDA.2009.5356534},
146 |   ISSN={2329-6267},
147 |   month={July},}
148 | 
149 | 
150 | @misc{rfc5246,
151 |   author = {T. Dierks},
152 |   title = {The transport layer security (TLS) protocol version 1.2 (RFC
153 |   5246)},
154 |   howpublished={\url{https://tools.ietf.org/html/rfc5246}},
155 |   year = {2008}
156 | }
157 | 
158 | 
159 | @inproceedings{naylor2014cost,
160 |   title={The cost of the S in HTTPS},
161 |   author={Naylor, David and Finamore, Alessandro and Leontiadis, Ilias and Grunenberger, Yan and Mellia, Marco and Munaf{\`o}, Maurizio and Papagiannaki, Konstantina and Steenkiste, Peter},
162 |   booktitle={Proceedings of the 10th ACM International on Conference on emerging Networking Experiments and Technologies},
163 |   pages={133--140},
164 |   year={2014},
165 |   organization={ACM}
166 | }
167 | 
168 | @misc{man-in-middle,
169 |   author = {Tanmay Patange},
170 |   title = {How to defend yourself against MITM or Man-in-the-middle attack},
171 |   howpublished={\url{https://hackerspace.kinja.com/how-to-defend-yourself-against-mitm-or-man-in-the-middl-1461796382}},
172 |   year = {2013}
173 | }
174 | 
175 | 
176 | @incollection{NIPS2015_5872,
177 |   title = {Efficient and Robust Automated Machine Learning},
178 |   author = {Feurer, Matthias and Klein, Aaron and Eggensperger, Katharina and
179 |   Springenberg, Jost and Blum, Manuel and Hutter, Frank},
180 |   booktitle = {Advances in Neural Information Processing Systems 28},
181 |   editor = {C. Cortes and N. D. Lawrence and D. D. Lee and M. Sugiyama and R. Garnett},
182 |   pages = {2962--2970},
183 |   year = {2015},
184 |   publisher = {Curran Associates, Inc.},
185 |   url = {http://papers.nips.cc/paper/5872-efficient-and-robust-automated-machine-learning.pdf}
186 | }
187 | 
188 | @ARTICLE{8026581,
189 |   author={M. Lopez-Martin and B. Carro and A. Sanchez-Esguevillas and J. Lloret},
190 | journal={IEEE Access},
191 | title={Network Traffic Classifier With Convolutional and Recurrent Neural Networks for Internet of Things},
192 | year={2017},
193 | volume={5},
194 | number={},
195 | pages={18042-18050},
196 | keywords={Internet of Things;learning (artificial intelligence);recurrent neural nets;telecommunication computing;telecommunication traffic;network traffic classifier;convolutional networks;recurrent neural networks;NTC;current network monitoring systems;network service;communication flow;current network flow;traffic volume;heterogeneous devices;IoT traffic;convolutional neural network;Internet of Things networks;CNN;deep learning models;Ports (Computers);Telecommunication traffic;Feature extraction;Recurrent neural networks;Machine learning;Payloads;Biological neural networks;Convolutional neural network;deep learning;network traffic classification;recurrent neural network},
197 | doi={10.1109/ACCESS.2017.2747560},
198 | ISSN={2169-3536},
199 | month={},}
200 |
201 | @inproceedings{krizhevsky2012imagenet,
202 | title={ImageNet classification with deep convolutional neural networks},
203 | author={Krizhevsky, Alex and Sutskever, Ilya and Hinton, Geoffrey E},
204 | booktitle={Advances in neural information processing systems},
205 | pages={1097--1105},
206 | year={2012}
207 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.6.1
2 | alabaster==0.7.12
3 | astor==0.7.1
4 | auto-sklearn==0.4.1
5 | Babel==2.6.0
6 | certifi==2018.11.29
7 | chardet==3.0.4
8 | ConfigSpace==0.4.7
9 | cycler==0.10.0
10 | Cython==0.29
11 | docutils==0.14
12 | gast==0.2.0
13 | grpcio==1.16.1
14 | h5py==2.8.0
15 | idna==2.7
16 | imagesize==1.1.0
17 | Jinja2==2.10
18 | joblib==0.13.0
19 | Keras==2.2.4
20 | Keras-Applications==1.0.6
21 | Keras-Preprocessing==1.0.5
22 | kiwisolver==1.0.1
23 | liac-arff==2.3.1
24 | lockfile==0.12.2
25 | Markdown==3.0.1
26 | MarkupSafe==1.1.0
27 | matplotlib==3.0.2
28 | nose==1.3.7
29 | numpy==1.15.4
30 | packaging==18.0
31 | pandas==0.23.4
32 | protobuf==3.6.1
33 | psutil==5.4.8
34 | Pygments==2.3.0
35 | pynisher==0.4.2
36 | pyparsing==2.3.0
37 | pyrfr==0.7.4
38 | python-dateutil==2.7.5
39 | pytz==2018.7
40 | PyYAML==3.13
41 | requests==2.20.1
42 | scikit-learn==0.19.2
43 | scipy==1.1.0
44 | six==1.11.0
45 | sklearn==0.0
46 | smac==0.8.0
47 | snowballstemmer==1.2.1
48 | Sphinx==1.8.2
49 | sphinx-rtd-theme==0.4.2
50 | sphinxcontrib-websupport==1.1.0
51 | tensorboard==1.12.0
52 | tensorflow==1.12.0
53 | termcolor==1.1.0
54 | typing==3.6.6
55 | urllib3==1.24.1
56 | Werkzeug==0.14.1
57 | xgboost==0.7.post3
58 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import csv
4 | from sklearn.metrics import classification_report
5 |
6 | #***********************************************************************************
7 | # Utility function to write accuracies per class to a CSV, for a variety of classifiers
8 | #***********************************************************************************
9 | def output_class_accuracies(y_test, rev_class_map, predictions_rf, predictions1, predictions2, predictions3, predictions123, predictions123rf):
10 |     classes = []
11 |     accuracies_rf = []
12 |     accuracies1 = []
13 |     accuracies2 = []
14 |     accuracies3 = []
15 |     accuracies123 = []
16 |     accuracies123rf = []
17 |
18 |     snis = np.unique(y_test)
19 |     for sni in snis:
20 |         indices = np.where(y_test == sni)
21 |         correct_rf = np.sum([np.argmax(x) for x in predictions_rf[indices]] == y_test[indices])
22 |         correct1 = np.sum([np.argmax(x) for x in predictions1[indices]] == y_test[indices])
23 |         correct2 = np.sum([np.argmax(x) for x in predictions2[indices]] == y_test[indices])
24 |         correct3 = np.sum([np.argmax(x) for x in predictions3[indices]] == y_test[indices])
25 |         correct123 = np.sum([np.argmax(x) for x in predictions123[indices]] == y_test[indices])
26 |         correct123rf = np.sum([np.argmax(x) for x in predictions123rf[indices]] == y_test[indices])
27 |
28 |         classes.append(rev_class_map[sni])
29 |         accuracies_rf.append(1. * correct_rf / len(indices[0]))
30 |         accuracies1.append(1. * correct1 / len(indices[0]))
31 |         accuracies2.append(1. * correct2 / len(indices[0]))
32 |         accuracies3.append(1. * correct3 / len(indices[0]))
33 |         accuracies123.append(1. * correct123 / len(indices[0]))
34 |         accuracies123rf.append(1. * correct123rf / len(indices[0]))
35 |
36 |     with open('class_results.csv', 'w') as file:
37 |         wr = csv.writer(file)
38 |         wr.writerow([' '] + classes)
39 |         wr.writerow(['Random Forest'] + accuracies_rf)
40 |         wr.writerow(['Packet CNN-RNN'] + accuracies1)
41 |         wr.writerow(['Payload CNN-RNN'] + accuracies2)
42 |         wr.writerow(['IAT CNN-RNN'] + accuracies3)
43 |         wr.writerow(['Ensemble CNN-RNN'] + accuracies123)
44 |         wr.writerow(['Ensemble RF + CNN-RNN'] + accuracies123rf)
45 |
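# A minimal usage sketch for output_class_accuracies (an annotation, not part of
# the repo): the function referenced y_test without receiving it, so the missing
# parameter is added to the signature above. Every name below is a made-up
# stand-in for the real model score arrays.
#
#     import numpy as np
#     rng = np.random.RandomState(0)
#     n_samples, n_labels = 6, 3
#     y_test = rng.randint(0, n_labels, size=n_samples)            # int SNI labels
#     rev_class_map = {i: "sni_%d" % i for i in range(n_labels)}   # int -> SNI string
#     preds = [rng.rand(n_samples, n_labels) for _ in range(6)]    # six models' scores
#     output_class_accuracies(y_test, rev_class_map, *preds)       # writes class_results.csv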
46 | def read_csv(file_path, has_header=True):
47 |     with open(file_path) as f:
48 |         if has_header: f.readline()
49 |         data = []
50 |         for line in f:
51 |             line = line.strip().split(",")
52 |             data.append(line)
53 |         return data
54 |
55 |
56 | #***********************************************************************************
57 | # Filter the data set using the minimum connections filter
58 | #***********************************************************************************
59 | def data_load_and_filter(datasetfile, min_connections, NUM_ROWS=-1):
60 |     dataset = read_csv(datasetfile)
61 |
62 |     # Use the first NUM_ROWS rows if requested; -1 means keep all rows
63 |     if NUM_ROWS > 0: dataset = dataset[:NUM_ROWS]
64 |
65 |     X = np.array([z[1:] for z in dataset])
66 |     y = np.array([z[0] for z in dataset])
67 |     print("Shape of X =", np.shape(X))
68 |     print("Shape of y =", np.shape(y))
69 |
70 |     print("Entering min connections filter section!")
71 |     snis, counts = np.unique(y, return_counts=True)
72 |     above_min_conns = list()
73 |
74 |     for i in range(len(counts)):
75 |         if counts[i] > min_connections:
76 |             above_min_conns.append(snis[i])
77 |
78 |     print("Filtering done. SNI classes remaining: ", len(above_min_conns))
79 |     indices = np.isin(y, above_min_conns)
80 |     X = X[indices]
81 |     y = y[indices]
82 |
83 |     print("Filtered shape of X =", np.shape(X))
84 |     print("Filtered shape of y =", np.shape(y))
85 |
86 |     # cast to float: auto-sklearn (and the DL models) need numeric input
87 |     X = X.astype(float)
88 |     return X, y
89 |
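# A loading sketch (annotation, not repo code). The path mirrors
# DL/training/GCseq25.csv from the tree above, and min_connections=100 matches
# MIN_CONNECTIONS_LIST in dl.py. With the NUM_ROWS guard above, the default of
# -1 now keeps every row instead of silently dropping the last one.
#
#     X, y = data_load_and_filter("DL/training/GCseq25.csv", min_connections=100)
#     # X: (n_samples, 4*SEQ_LEN) float feature matrix, y: SNI label strings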
90 | #***********************************************************************************
91 | # Function to separate the input data into packet, payload, IAT, and direction sequences
92 | #
93 | # len(X[0]) = 100
94 | # X_1...X_25 = Packet Sizes
95 | # X_26...X_50 = Payload Sizes
96 | # X_51...X_75 = Inter-Arrival Times
97 | # X_76...X_100 = Directional Features
98 | #***********************************************************************************
99 | def process_dl_features(X, y, SEQ_LEN=25):
100 |     # packet, payload, IAT, direction
101 |     X1 = X[:,:SEQ_LEN]
102 |     X2 = X[:,SEQ_LEN:2*SEQ_LEN]
103 |     X3 = X[:,2*SEQ_LEN:3*SEQ_LEN]
104 |     X4 = X[:,3*SEQ_LEN:4*SEQ_LEN]
105 |
106 |     X3[np.where(X3 != 0)] = np.log(X3[np.where(X3 != 0)])  # log-scale nonzero inter-arrival times
107 |
108 |     print("Filtered shape of X1 =", np.shape(X1))
109 |     print("Filtered shape of X2 =", np.shape(X2))
110 |     print("Filtered shape of X3 =", np.shape(X3))
111 |     print("Filtered shape of X4 =", np.shape(X4))
112 |     print("Filtered shape of y =", np.shape(y))
113 |
114 |     ##### BASIC PARAMETERS #####
115 |     n_samples = np.shape(X1)[0]
116 |     time_steps = np.shape(X1)[1]  # each sample is a time series of SEQ_LEN values
117 |
118 |     ##### CREATES MAPPING FROM SNI STRING TO INT #####
119 |     class_map = {sni:i for i, sni in enumerate(np.unique(y))}
120 |     rev_class_map = {val: key for key, val in class_map.items()}
121 |
122 |     n_labels = len(class_map)
123 |
124 |     ##### CHANGE Y TO PD SO ITS EASIER TO MAP #####
125 |     y_pd = pd.DataFrame(y)
126 |     y_pd = y_pd[0].map(class_map)
127 |
128 |     ##### FLATTEN MAPPED LABELS BACK TO A 1-D ARRAY #####
129 |     y = y_pd.values.reshape(n_samples,)
130 |
131 |     return X1, X2, X3, X4, y, time_steps, n_labels, rev_class_map
132 |
133 |
134 | #***********************************************************************************
135 | # Function to save sklearn report on precision, recall, F1-Score into a dictionary
136 | #***********************************************************************************
137 | def update_stats(stats, model, predictions, y_test):
138 |     report = classification_report(y_test, [np.argmax(x) for x in predictions])
139 |
140 |     report_list = []
141 |     for row in report.split("\n"):
142 |         parsed_row = [x for x in row.split(" ") if len(x) > 0]
143 |         if len(parsed_row) > 0:
144 |             report_list.append(parsed_row)
145 |
146 |     # save accuracy, precision, recall, F1-Score; the summary row ends in "precision recall f1-score support", so index from the end
147 |     stats[model][0] += float(1. * np.sum([np.argmax(x) for x in predictions] == y_test) / len(y_test))
148 |     stats[model][1] += float(report_list[-1][-4])
149 |     stats[model][2] += float(report_list[-1][-3])
150 |     stats[model][3] += float(report_list[-1][-2])
151 |
152 |     return stats
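# A k-fold accumulation sketch (annotation, not repo code): each stats entry is
# a running sum of [accuracy, precision, recall, F1] that update_stats extends
# once per fold; divide by the fold count at the end. The "CNN-RNN" key and the
# trained `model` are hypothetical; the np.stack call mirrors the reshaping in
# dl.py.
#
#     from sklearn.model_selection import KFold
#     stats = {"CNN-RNN": [0., 0., 0., 0.]}
#     kf = KFold(n_splits=10)
#     for train_idx, test_idx in kf.split(X1):
#         predictions = model.predict(np.stack([X1[test_idx]], axis=2))
#         stats = update_stats(stats, "CNN-RNN", predictions, y[test_idx])
#     averages = [v / kf.get_n_splits() for v in stats["CNN-RNN"]]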
153 |
154 | def get_freq_vector(predictions, frequencies):
155 |     pred_freqs = np.zeros(np.shape(frequencies))
156 |     for pred in predictions:
157 |         pred_freqs[np.argmax(pred)] += 1
158 |
159 |     return 1. * pred_freqs / np.sum(pred_freqs)
160 |
161 | def domain_expertise(predictions, frequencies, y_test, iterations=10000):
162 |     alpha = 0.001
163 |     frequencies = 1. * frequencies / np.sum(frequencies)
164 |     pred_freqs = get_freq_vector(predictions, frequencies)
165 |     coefficients = np.ones(np.shape(frequencies))
166 |
167 |     for i in range(iterations):
168 |         #print("ACC: ", float(1. * np.sum([np.argmax(x) for x in predictions * coefficients] == y_test) / len(y_test)))
169 |         #print("DIFF: ", np.max(pred_freqs - frequencies))
170 |         pred_freqs = get_freq_vector(predictions * coefficients, frequencies)
171 |         index = np.argmax(pred_freqs - frequencies)  # most over-predicted class
172 |         coefficients[index] -= alpha                 # damp its scores slightly
173 |
174 |     return predictions * coefficients
175 |
--------------------------------------------------------------------------------
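# A closing sketch for domain_expertise (annotation, not repo code): the loop
# above greedily finds the most over-predicted class and shrinks its scores by
# alpha until the predicted class mix approaches the training distribution.
# y_train, predictions, and y_test are hypothetical, and np.unique only lines
# up with the score columns if every label 0..n_labels-1 occurs in y_train.
#
#     import numpy as np
#     _, frequencies = np.unique(y_train, return_counts=True)
#     adjusted = domain_expertise(predictions, frequencies, y_test)
#     accuracy = np.mean([np.argmax(p) for p in adjusted] == y_test)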