├── CVA.png ├── LICENSE ├── data_preprocess.py ├── evaluate.py ├── preprocessor.py ├── readme.md ├── requirements.txt ├── results ├── results_ISCX.txt └── results_USTC.txt ├── train.py └── utils.py /CVA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YasodGinige/TrafficGPT/894d1d3993414737148d27ca8d55f0ea3a9d9f4b/CVA.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yasod 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /data_preprocess.py: -------------------------------------------------------------------------------- 1 | from preprocessor import Data_Preprocess 2 | import numpy as np 3 | import pandas as pd 4 | import os 5 | from sklearn.utils import shuffle 6 | from sklearn.model_selection import train_test_split 7 | import _pickle as cPickle 8 | import gc 9 | import argparse 10 | import sys 11 | 12 | def preprocess(args): 13 | data_path = args.data_path 14 | dataset = args.dataset 15 | 16 | if not os.path.exists("./temp_dir"): 17 | os.makedirs("temp_dir") 18 | 19 | Data_preprocessor = Data_Preprocess() 20 | Data_preprocessor.preprocess_dataset(data_path, dataset) 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser(description="data preprocessing") 24 | parser.add_argument("--data_path", type=str, default='./data', help="path to the datasets") 25 | parser.add_argument("--dataset", type=str, default='AWF', help="Dataset name") 26 | args = parser.parse_args() 27 | preprocess(args) 28 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pandas as pd 4 | from tqdm.auto import tqdm 5 | import sys 6 | import random 7 | import gc 8 | import os 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils.data import Dataset, DataLoader 12 | from sklearn.metrics import classification_report, accuracy_score 13 | from utils import DatasetCreator, GPT2_collator 14 | from transformers import (set_seed, 15 | TrainingArguments, 16 | Trainer, 17 | GPT2Config, 18 | GPT2Tokenizer, 19 | AdamW, 20 | get_linear_schedule_with_warmup, 21 | 
GPT2ForSequenceClassification) 22 | 23 | random.seed(42) 24 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | torch.cuda.empty_cache() 26 | 27 | def pre_process(dataset): 28 | dataset['trace'] = dataset['text'] 29 | return dataset 30 | 31 | def get_labels(file): 32 | df = pd.read_csv(file) 33 | return np.array(df['target']) 34 | 35 | def predict(model, dataloader, device): 36 | model.eval() 37 | predictions_labels = [] 38 | 39 | for ind,batch in enumerate(tqdm(dataloader, total=len(dataloader))): 40 | batch = {k:v.type(torch.long).to(device) for k,v in batch.items()} 41 | with torch.no_grad(): 42 | outputs = model(**batch) 43 | _, logits = outputs[:2] 44 | if ind == 0: 45 | predictions_labels = logits.to('cpu').numpy() 46 | else: 47 | predictions_labels = np.concatenate((predictions_labels, logits.to('cpu').numpy()), axis=0) 48 | return predictions_labels 49 | 50 | 51 | def claculate_mean_vectors(NB_CLASSES, model_predictions, y_train): 52 | for i in range(NB_CLASSES): 53 | variable_name = f"Mean_{i}" 54 | locals()[variable_name]=np.array([0] * NB_CLASSES) 55 | count=[0]*NB_CLASSES 56 | txt_O = "Mean_{Class1:.0f}" 57 | Means={} 58 | for i in range(NB_CLASSES): 59 | Means[txt_O.format(Class1=i)]=np.array([0]*NB_CLASSES) 60 | 61 | for i in range(len(model_predictions)): 62 | k=np.argmax(model_predictions[i]) 63 | if (np.argmax(model_predictions[i])==y_train[i]): 64 | Means[txt_O.format(Class1=y_train[i])]=Means[txt_O.format(Class1=y_train[i])] + model_predictions[i] 65 | count[y_train[i]]+=1 66 | 67 | Mean_Vectors=[] 68 | for i in range(NB_CLASSES): 69 | Means[txt_O.format(Class1=i)]=Means[txt_O.format(Class1=i)]/count[i] 70 | Mean_Vectors.append(Means[txt_O.format(Class1=i)]) 71 | 72 | Mean_Vectors=np.array(Mean_Vectors) 73 | return Mean_Vectors 74 | 75 | def calculate_thresholds(NB_CLASSES, model_predictions, y_valid, Mean_Vectors, K_number, TH_value): 76 | 77 | txt_1 = "Dist_{Class1:.0f}" 78 | Distances={} 79 | for i in range(NB_CLASSES): 80 | Distances[txt_1.format(Class1=i)]=[] 81 | 82 | Indexes=[] 83 | for i in range(NB_CLASSES): 84 | Indexes.append([]) 85 | 86 | Values={} 87 | for i in range(NB_CLASSES): 88 | Values[i]=[0]*NB_CLASSES 89 | 90 | for i in range(len(model_predictions)): 91 | if (y_valid[i]==np.argmax(model_predictions[i])): 92 | dist = np.linalg.norm(Mean_Vectors[y_valid[i]]-model_predictions[i]) 93 | for k in range(NB_CLASSES): 94 | if k!=int(y_valid[i]): 95 | Values[y_valid[i]][k]+=np.linalg.norm(Mean_Vectors[k]-model_predictions[i])-dist 96 | 97 | for i in range(NB_CLASSES): 98 | Tot=0 99 | for l in range(K_number): 100 | Min=min(Values[i]) 101 | Tot+=Min 102 | Indexes[i].append(Values[i].index(Min)) 103 | Values[i][Values[i].index(Min)]=1000000 104 | 105 | Indexes=np.array(Indexes) 106 | 107 | 108 | txt_1 = "Dist_{Class1:.0f}" 109 | Distances={} 110 | for i in range(NB_CLASSES): 111 | Distances[txt_1.format(Class1=i)]=[] 112 | 113 | for i in range(len(model_predictions)): 114 | if (y_valid[i]==np.argmax(model_predictions[i])): 115 | dist = np.linalg.norm(Mean_Vectors[y_valid[i]]-model_predictions[i]) 116 | Distances[txt_1.format(Class1=y_valid[i])].append(dist) 117 | 118 | TH=[0]*NB_CLASSES 119 | for j in range(NB_CLASSES): 120 | Distances[txt_1.format(Class1=j)].sort() 121 | Dist=Distances[txt_1.format(Class1=j)] 122 | try: 123 | TH[j]=Dist[int(len(Dist)*TH_value)] 124 | except: 125 | if j == 0: 126 | TH[j] = 10 127 | else: 128 | TH[j] = TH[j-1] 129 | 130 | Threasholds_1=np.array(TH) 131 | print("Thresholds for method 1 calculated") 132 | 
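    # Summary of the three K-LND thresholding schemes computed in this function:
    #   * Method 1 (K-LND1): for each class c, the threshold is the TH_value-quantile of the
    #     Euclidean distances between Mean_Vectors[c] and the logits of correctly classified
    #     validation samples; at test time, a sample lying farther than this from its predicted
    #     class mean is treated as open-set.
    #   * Method 2 (K-LND2): per sample, sum the margins (distance to each of the K_number most
    #     confusable other class means, stored in Indexes, minus the distance to the predicted
    #     class mean); the threshold is the (1 - TH_value)-quantile of these sums, and small sums
    #     indicate poor separation from the other classes, i.e. likely open-set traffic.
    #   * Method 3 (K-LND3): use the ratio (distance to the predicted class mean) / (summed
    #     distances to the K_number most confusable class means); the threshold is the
    #     TH_value-quantile of these ratios, and larger ratios indicate open-set traffic.
    # For example, the test-time rule for method 1 (see final_classification) is, per logits vector:
    #   d = np.argmax(logits)
    #   is_open = np.linalg.norm(logits - Mean_Vectors[d]) > Threasholds_1[d]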
133 | 134 | txt_1 = "Dist_{Class1:.0f}" 135 | Distances={} 136 | for i in range(NB_CLASSES): 137 | Distances[txt_1.format(Class1=i)]=[] 138 | 139 | for i in range(len(model_predictions)): 140 | if (y_valid[i]==np.argmax(model_predictions[i])): 141 | dist = np.linalg.norm(Mean_Vectors[y_valid[i]]-model_predictions[i]) 142 | Tot=0 143 | for k in range(NB_CLASSES): 144 | if k!=int(y_valid[i]) and k in Indexes[y_valid[i]]: 145 | Tot+=(np.linalg.norm(Mean_Vectors[k]-model_predictions[i])-dist) 146 | Distances[txt_1.format(Class1=y_valid[i])].append(Tot) 147 | 148 | TH=[0]*NB_CLASSES 149 | for j in range(NB_CLASSES): 150 | Distances[txt_1.format(Class1=j)].sort() 151 | Dist=Distances[txt_1.format(Class1=j)] 152 | try: 153 | TH[j]=Dist[int(len(Dist)*(1-TH_value))] 154 | except: 155 | if j == 0: 156 | TH[j] = 10 157 | else: 158 | TH[j] = TH[j-1] 159 | 160 | Threasholds_2=np.array(TH) 161 | print("Thresholds for method 2 calculated") 162 | 163 | 164 | txt_1 = "Dist_{Class1:.0f}" 165 | Distances={} 166 | for i in range(NB_CLASSES): 167 | Distances[txt_1.format(Class1=i)]=[] 168 | 169 | for i in range(len(model_predictions)): 170 | if (y_valid[i]==np.argmax(model_predictions[i])): 171 | dist = np.linalg.norm(Mean_Vectors[y_valid[i]]-model_predictions[i]) 172 | Tot=0 173 | for k in range(NB_CLASSES): 174 | if k!=int(y_valid[i]) and k in Indexes[y_valid[i]]: 175 | Tot+=np.linalg.norm(Mean_Vectors[k]-model_predictions[i]) 176 | Tot=dist/Tot 177 | Distances[txt_1.format(Class1=y_valid[i])].append(Tot) 178 | 179 | TH=[0]*NB_CLASSES 180 | for j in range(NB_CLASSES): 181 | Distances[txt_1.format(Class1=j)].sort() 182 | Dist=Distances[txt_1.format(Class1=j)] 183 | try: 184 | TH[j]=Dist[int(len(Dist)*TH_value)] 185 | except: 186 | if j == 0: 187 | TH[j] = 10 188 | else: 189 | TH[j] = TH[j-1] 190 | 191 | Threasholds_3=np.array(TH) 192 | print("Thresholds for method 3 calculated") 193 | 194 | return Threasholds_1, Threasholds_2, Threasholds_3, Indexes 195 | 196 | 197 | def print_results(prediction_classes, prediction_classes_open, y_test, y_open, NB_CLASSES, KLND_type, dataset_name): 198 | 199 | y_test = y_test.astype(int) 200 | y_open = y_open.astype(int) 201 | 202 | acc_Close = accuracy_score(prediction_classes, y_test[:len(prediction_classes)]) 203 | print('Test accuracy Normal model_Closed_set :', acc_Close) 204 | 205 | acc_Open = accuracy_score(prediction_classes_open, y_open[:len(prediction_classes_open)]) 206 | print('Test accuracy Normal model_Open_set :', acc_Open) 207 | 208 | y_test=y_test[:len(prediction_classes)] 209 | y_open=y_open[:len(prediction_classes_open)] 210 | 211 | Matrix=[] 212 | for i in range(NB_CLASSES+1): 213 | Matrix.append(np.zeros(NB_CLASSES+1)) 214 | 215 | for i in range(len(y_test)): 216 | Matrix[y_test[i]][prediction_classes[i]]+=1 217 | 218 | for i in range(len(y_open)): 219 | Matrix[y_open[i]][prediction_classes_open[i]]+=1 220 | 221 | 222 | print("\n", "Micro") 223 | F1_Score_micro=Micro_F1(Matrix, NB_CLASSES) 224 | print("Average Micro F1_Score: ", F1_Score_micro) 225 | 226 | print("\n", "Macro") 227 | F1_Score_macro=Macro_F1(Matrix, NB_CLASSES) 228 | print("Average Macro F1_Score: ", F1_Score_macro) 229 | 230 | text_file = open("./results/results_"+ dataset_name +".txt", "a") 231 | 232 | text_file.write('########' + KLND_type + '#########\n') 233 | text_file.write('Test accuracy Normal model_Closed_set :'+ str(acc_Close) + '\n') 234 | text_file.write('Test accuracy Normal model_Open_set :'+ str(acc_Open) + '\n') 235 | text_file.write("Average Micro F1_Score: " + 
str(F1_Score_micro) + '\n') 236 | text_file.write("Average Macro F1_Score: " + str(F1_Score_macro) + '\n') 237 | text_file.write('\n') 238 | text_file.close() 239 | 240 | 241 | def final_classification(NB_CLASSES, model_predictions_test, model_predictions_open, y_test, y_open, Mean_vectors, Indexes, Threasholds_1, Threasholds_2, Threasholds_3, dataset_name): 242 | 243 | 244 | print("\n", "############## Distance Method 1 #################################") 245 | prediction_classes=[] 246 | for i in range(len(model_predictions_test)): 247 | 248 | d=np.argmax(model_predictions_test[i], axis=0) 249 | if np.linalg.norm(model_predictions_test[i]-Mean_vectors[d])>Threasholds_1[d]: 250 | prediction_classes.append(NB_CLASSES) 251 | 252 | else: 253 | prediction_classes.append(d) 254 | prediction_classes=np.array(prediction_classes) 255 | 256 | prediction_classes_open=[] 257 | for i in range(len(model_predictions_open)): 258 | 259 | d=np.argmax(model_predictions_open[i], axis=0) 260 | if np.linalg.norm(model_predictions_open[i]-Mean_vectors[d])>Threasholds_1[d]: 261 | prediction_classes_open.append(NB_CLASSES) 262 | else: 263 | prediction_classes_open.append(d) 264 | prediction_classes_open=np.array(prediction_classes_open) 265 | print_results(prediction_classes, prediction_classes_open, y_test, y_open, NB_CLASSES, 'K-LND1', dataset_name) 266 | 267 | print("\n", "############## Distance Method 2 #################################") 268 | prediction_classes=[] 269 | for i in range(len(model_predictions_test)): 270 | d=np.argmax(model_predictions_test[i], axis=0) 271 | dist=np.linalg.norm(Mean_vectors[d]-model_predictions_test[i]) 272 | Tot=0 273 | for k in range(NB_CLASSES): 274 | if k!=d: 275 | Tot+=np.linalg.norm(Mean_vectors[k]-model_predictions_test[i])-dist 276 | 277 | if TotThreasholds_3[d]: 314 | prediction_classes.append(NB_CLASSES) 315 | 316 | else: 317 | prediction_classes.append(d) 318 | 319 | prediction_classes=np.array(prediction_classes) 320 | 321 | prediction_classes_open=[] 322 | for i in range(len(model_predictions_open)): 323 | d=np.argmax(model_predictions_open[i], axis=0) 324 | dist=np.linalg.norm(Mean_vectors[d]-model_predictions_open[i]) 325 | Tot=0 326 | for k in range(NB_CLASSES): 327 | if k!=int(d) and k in Indexes[d]: 328 | Tot+=np.linalg.norm(Mean_vectors[k]-model_predictions_open[i]) 329 | Tot=dist/Tot 330 | if Tot>Threasholds_3[d]: 331 | prediction_classes_open.append(NB_CLASSES) 332 | 333 | else: 334 | prediction_classes_open.append(d) 335 | 336 | prediction_classes_open=np.array(prediction_classes_open) 337 | print_results(prediction_classes, prediction_classes_open, y_test, y_open, NB_CLASSES, 'K-LND3', dataset_name) 338 | 339 | 340 | def Micro_F1(Matrix, NB_CLASSES): 341 | epsilon = 1e-8 342 | TP = 0 343 | FP = 0 344 | TN = 0 345 | 346 | for k in range(NB_CLASSES): 347 | TP += Matrix[k][k] 348 | FP += (np.sum(Matrix, axis=0)[k] - Matrix[k][k]) 349 | TN += (np.sum(Matrix, axis=1)[k] - Matrix[k][k]) 350 | 351 | Micro_Prec = TP / (TP + FP) 352 | Micro_Rec = TP / (TP + TN) 353 | print("Micro_Prec:", Micro_Prec) 354 | print("Micro_Rec:", Micro_Rec) 355 | Micro_F1 = 2 * Micro_Prec * Micro_Rec / (Micro_Rec + Micro_Prec + epsilon) 356 | 357 | return Micro_F1 358 | 359 | def Macro_F1(Matrix, NB_CLASSES): 360 | epsilon = 1e-8 361 | F1s = np.zeros(NB_CLASSES) 362 | 363 | for k in range(NB_CLASSES): 364 | TP = Matrix[k][k] 365 | FP = np.sum(Matrix[:, k]) - TP 366 | FN = np.sum(Matrix[k, :]) - TP 367 | 368 | precision = TP / (TP + FP + epsilon) 369 | recall = TP / (TP + FN + 
epsilon) 370 | F1s[k] = 2 * precision * recall / (precision + recall + epsilon) 371 | 372 | macro_F1 = np.mean(F1s) 373 | print("Per-class F1s:", F1s) 374 | print("Macro F1:", macro_F1) 375 | return macro_F1 376 | 377 | 378 | def main(args): 379 | max_len = args.max_len 380 | batch_size = args.batch_size 381 | epochs = args.epochs 382 | num_labels = args.num_labels 383 | K_number = args.K_number 384 | TH_value = args.TH_value 385 | dataset = args.dataset 386 | 387 | print('Loading gpt-2 model') 388 | model_config = GPT2Config.from_pretrained(pretrained_model_name_or_path='gpt2', num_labels=num_labels) 389 | 390 | print('Loading tokenizer...') 391 | tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path='gpt2') 392 | tokenizer.padding_side = "right" 393 | tokenizer.pad_token = tokenizer.eos_token 394 | 395 | print('Loading model...') 396 | model = GPT2ForSequenceClassification.from_pretrained(pretrained_model_name_or_path='gpt2', config=model_config) 397 | model.resize_token_embeddings(len(tokenizer)) 398 | model.config.pad_token_id = model.config.eos_token_id 399 | model = nn.DataParallel(model) 400 | 401 | gpt2_collator = GPT2_collator(tokenizer=tokenizer, max_seq_len=max_len) 402 | optimizer = AdamW(model.parameters(), lr = 5e-5, eps = 1e-8, weight_decay=0.01) 403 | 404 | model.load_state_dict(torch.load('./trained_models/trained_gpt_' + dataset + '.pth')) 405 | model.to(device) 406 | 407 | train_dataset = pd.read_csv('./temp_dir/train.csv') 408 | start_index = int(len(train_dataset) * 0.6) 409 | train_subset = train_dataset[start_index:] 410 | train_processed = pre_process(train_dataset) 411 | train_data = DatasetCreator(train_processed, train=False) 412 | train_eval_dataloader = DataLoader(train_data, batch_size=32, shuffle=False, collate_fn=gpt2_collator) 413 | 414 | train_predictions = predict(model, train_eval_dataloader, device) 415 | y_train = get_labels('./temp_dir/train.csv') 416 | del train_data, train_dataset, train_processed, train_eval_dataloader 417 | gc.collect() 418 | 419 | valid_dataset = pd.read_csv('./temp_dir/valid.csv') 420 | y_valid = get_labels('./temp_dir/valid.csv') 421 | if dataset == 'DC': 422 | valid_dataset = pd.concat([train_subset, valid_dataset], ignore_index=True) 423 | y_valid = np.concatenate((y_valid,y_train[start_index:]), axis=0) 424 | 425 | valid_processed = pre_process(valid_dataset) 426 | valid_data = DatasetCreator(valid_processed, train=False) 427 | valid_eval_dataloader = DataLoader(valid_data, batch_size=32, shuffle=False, collate_fn=gpt2_collator) 428 | 429 | valid_predictions = predict(model, valid_eval_dataloader, device) 430 | del valid_data, valid_dataset, valid_processed, valid_eval_dataloader 431 | gc.collect() 432 | 433 | test_dataset = pd.read_csv('./temp_dir/test.csv') 434 | test_processed = pre_process(test_dataset) 435 | test_data = DatasetCreator(test_processed, train=False) 436 | test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False, collate_fn=gpt2_collator) 437 | 438 | test_predictions = predict(model, test_dataloader, device) 439 | y_test = get_labels('./temp_dir/test.csv') 440 | del test_data, test_dataset, test_processed, test_dataloader 441 | gc.collect() 442 | 443 | open_dataset = pd.read_csv('./temp_dir/open.csv') 444 | open_processed = pre_process(open_dataset) 445 | open_data = DatasetCreator(open_processed, train=False) 446 | open_dataloader = DataLoader(open_data, batch_size=32, shuffle=False, collate_fn=gpt2_collator) 447 | 448 | open_predictions = predict(model, open_dataloader, 
device) 449 | y_open = get_labels('./temp_dir/open.csv') 450 | del open_data, open_dataset, open_processed, open_dataloader 451 | y_open = np.array([num_labels]*len(y_open)) 452 | 453 | if not os.path.exists('./results'): 454 | os.makedirs('./results') 455 | Mean_Vectors = claculate_mean_vectors(num_labels, train_predictions, y_train) 456 | Threasholds_1, Threasholds_2, Threasholds_3, Indexes = calculate_thresholds(num_labels, valid_predictions, y_valid, Mean_Vectors, K_number, TH_value) 457 | final_classification(num_labels, test_predictions, open_predictions, y_test, y_open, Mean_Vectors, Indexes, Threasholds_1, Threasholds_2, Threasholds_3, dataset) 458 | 459 | 460 | if __name__ == "__main__": 461 | parser = argparse.ArgumentParser(description="Train GPT-2 model with sequence classification") 462 | parser.add_argument("--max_len", type=int, default=1000, help="Max length of the text for input") 463 | parser.add_argument("--batch_size", type=int, default=12, help="Batch size for training") 464 | parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train") 465 | parser.add_argument("--num_labels", type=int, default=120, help="Number of labels for classification") 466 | parser.add_argument("--K_number", type=int, default=50, help="K nearest naibours") 467 | parser.add_argument("--TH_value", type=float, default=0.8, help="Threshold value for distances") 468 | parser.add_argument("--dataset", type=str, default='DC', help="Dataset name") 469 | 470 | args = parser.parse_args() 471 | main(args) 472 | -------------------------------------------------------------------------------- /preprocessor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | from sklearn.utils import shuffle 5 | from sklearn.model_selection import train_test_split 6 | import _pickle as cPickle 7 | import gc 8 | 9 | 10 | class Data_Preprocess(): 11 | def __init__(self): 12 | self.train_data = None 13 | self.valid_data = None 14 | self.test_data = None 15 | self.open_data = None 16 | self.name_list = ['train', 'valid', 'test', 'open'] 17 | 18 | @staticmethod 19 | def create_pattern(x, y): 20 | max_pat_len = 32 21 | pat = '' 22 | while x !=0: 23 | if x.svg)]() 5 | 6 | **Note:** 7 | - ⭐ **Please leave a STAR if you like this project!** ⭐ 8 | - If you are using this work for academic purposes, please cite our [paper](https://dl.acm.org/doi/abs/10.1145/3674213.3674217). 9 | - If you find any incorrect / inappropriate / outdated content, please kindly consider opening an issue or a PR. 10 | 11 |
12 | ![Overall architecture](./CVA.png) 13 |
14 | 15 | In this repository, we guide you in setting up the TrafficGPT project in a local environment and reproducing the results. TrafficGPT is a novel traffic analysis attack that leverages GPT-2, a popular LLM, to enhance feature extraction, thereby improving 16 | the open-set performance of downstream classification. We use five existing encrypted traffic datasets to show how GPT-2-based feature extraction improves the open-set performance of traffic 17 | analysis attacks. For open-set classification, we use the K-LND, OpenMax, and Background-class methods, and show that the K-LND methods have higher performance overall. 18 | 19 | **Datasets:** [AWF](https://arxiv.org/abs/1708.06376), [DF](https://arxiv.org/abs/1801.02265), [DC](https://www.semanticscholar.org/paper/Deep-Content%3A-Unveiling-Video-Streaming-Content-Li-Huang/f9feff95bc1d68674d5db426053f417bd2c8786b), [USTC](https://drive.google.com/file/d/1F09zxln9iFg2HWoqc6m4LKFhYK7cDQv_/view), [CSTNet-tls](https://drive.google.com/drive/folders/1JSsYmevkxQFanoKOi_i1ooA6pH3s9sDr) 20 | 21 | **Open-set methods** 22 | - [K-LND methods](https://github.com/ThiliniDahanayaka/Open-Set-Traffic-Classification) 23 | - OpenMax 24 | - Background class 25 | 26 | # Using TrafficGPT 27 | 28 | First, clone the git repo and install the requirements. 29 | ``` 30 | git clone https://github.com/YasodGinige/TrafficGPT.git 31 | cd TrafficGPT 32 | pip install -r requirements.txt 33 | ``` 34 | Next, download the dataset and place it in the data directory. 35 | ``` 36 | gdown https://drive.google.com/uc?id=1-MVfxyHdQeUguBmYrIIw1jhMVSqxXQgO 37 | unzip data.zip 38 | ``` 39 | 40 | Then, preprocess the dataset you want to train and evaluate. Here, the dataset name should be DF, AWF, DC, USTC, or CSTNet. 41 | ``` 42 | python3 data_preprocess.py --data_path ./data --dataset <dataset_name> 43 | ``` 44 | To train the model, run the appropriate command for the dataset: 45 | ``` 46 | python3 train.py --max_len 1024 --batch_size 12 --epochs 3 --num_labels 60 --dataset DF 47 | python3 train.py --max_len 1024 --batch_size 12 --epochs 3 --num_labels 200 --dataset AWF 48 | python3 train.py --max_len 1024 --batch_size 12 --epochs 3 --num_labels 4 --dataset DC 49 | python3 train.py --max_len 1024 --batch_size 12 --epochs 3 --num_labels 12 --dataset USTC 50 | python3 train.py --max_len 1024 --batch_size 12 --epochs 3 --num_labels 75 --dataset CSTNet 51 | ``` 52 | To evaluate, run the appropriate command for the dataset: 53 | ``` 54 | python3 evaluate.py --max_len 1024 --batch_size 12 --epochs 3 --num_labels 60 --K_number 30 --TH_value 0.8 --dataset DF 55 | python3 evaluate.py --max_len 1024 --batch_size 12 --epochs 3 --num_labels 200 --K_number 50 --TH_value 0.9 --dataset AWF 56 | python3 evaluate.py --max_len 1024 --batch_size 12 --epochs 3 --num_labels 4 --K_number 4 --TH_value 0.9 --dataset DC 57 | python3 evaluate.py --max_len 1024 --batch_size 12 --epochs 3 --num_labels 12 --K_number 5 --TH_value 0.8 --dataset USTC 58 | python3 evaluate.py --max_len 1024 --batch_size 12 --epochs 5 --num_labels 75 --K_number 20 --TH_value 0.8 --dataset CSTNet 59 | ``` 60 | 61 | # Citations 62 | If you are using this work for academic purposes, please cite our [paper](https://dl.acm.org/doi/abs/10.1145/3674213.3674217).
63 | ``` 64 | @inproceedings{ginige2024trafficgpt, 65 | title={TrafficGPT: An LLM Approach for Open-Set Encrypted Traffic Classification}, 66 | author={Ginige, Yasod and Dahanayaka, Thilini and Seneviratne, Suranga}, 67 | booktitle={Proceedings of the Asian Internet Engineering Conference 2024}, 68 | pages={26--35}, 69 | year={2024} 70 | } 71 | ``` 72 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | pandas 4 | tqdm 5 | scikit-learn 6 | transformers 7 | gdown 8 | -------------------------------------------------------------------------------- /results/results_ISCX.txt: -------------------------------------------------------------------------------- 1 | K-LND1 2 | 3 | [95. 0. 0. 0. 0. 1. 0. 0. 0. 0. 37.] 4 | [ 0. 298. 0. 0. 0. 0. 0. 0. 0. 0. 82.] 5 | [ 0. 0. 378. 1. 0. 0. 0. 0. 0. 0. 110.] 6 | [ 0. 0. 0. 277. 0. 2. 0. 0. 0. 0. 86.] 7 | [ 0. 0. 0. 0. 386. 0. 0. 0. 0. 0. 73.] 8 | [ 1. 0. 0. 0. 0. 30. 0. 0. 0. 0. 16.] 9 | [ 0. 0. 0. 0. 0. 0. 374. 0. 0. 0. 107.] 10 | [ 0. 0. 0. 0. 0. 0. 0. 391. 0. 0. 108.] 11 | [ 0. 0. 0. 0. 0. 0. 0. 0. 178. 0. 62.] 12 | [ 0. 0. 0. 0. 0. 0. 0. 1. 0. 382. 88.] 13 | [ 3. 1. 0. 2296. 2. 2. 0. 0. 0. 0. 31575.] 14 | 15 | Test accuracy Normal model_Closed_set :0.7825476992143658 16 | Test accuracy Normal model_Open_set :0.9319932701673603 17 | Average Micro F1_Score: 0.6438877938410408 18 | Average Macro F1_Score: 0.7900214548184425 19 | 20 | {0: {'precision': 0.9595959595959596, 'recall': 0.7142857142857143, 'f1': 0.8189655172413793}, 1: {'precision': 0.9966555183946488, 'recall': 0.7842105263157895, 'f1': 0.8777614138438882}, 2: {'precision': 1.0, 'recall': 0.7730061349693251, 'f1': 0.8719723183391004}, 3: {'precision': 0.10761460761460762, 'recall': 0.7589041095890411, 'f1': 0.18849948962232052}, 4: {'precision': 0.9948453608247423, 'recall': 0.840958605664488, 'f1': 0.9114521841794568}, 5: {'precision': 0.8571428571428571, 'recall': 0.6382978723404256, 'f1': 0.7317073170731707}, 6: {'precision': 1.0, 'recall': 0.7775467775467776, 'f1': 0.8748538011695907}, 7: {'precision': 0.9974489795918368, 'recall': 0.7835671342685371, 'f1': 0.8776655443322111}, 8: {'precision': 1.0, 'recall': 0.7416666666666667, 'f1': 0.8516746411483254}, 9: {'precision': 1.0, 'recall': 0.8110403397027601, 'f1': 0.895662368112544}, 10: {'precision': 0.976224338362602, 'recall': 0.9319932701673603, 'f1': 0.9535961825951708}} 21 | 22 | 23 | 24 | K-LND2 25 | 26 | [122. 0. 0. 2. 0. 4. 0. 0. 0. 0. 5.] 27 | [ 2. 376. 0. 0. 0. 1. 0. 0. 0. 0. 1.] 28 | [ 3. 0. 471. 4. 0. 2. 0. 0. 0. 0. 9.] 29 | [ 4. 0. 5. 351. 0. 3. 0. 0. 0. 0. 2.] 30 | [ 0. 1. 0. 0. 445. 0. 0. 0. 0. 0. 13.] 31 | [ 7. 1. 0. 1. 0. 37. 0. 0. 0. 0. 1.] 32 | [ 0. 0. 0. 0. 0. 0. 478. 1. 0. 2. 0.] 33 | [ 0. 2. 0. 0. 0. 0. 0. 497. 0. 0. 0.] 34 | [ 0. 0. 0. 2. 0. 0. 0. 0. 234. 0. 4.] 35 | [ 0. 0. 0. 0. 0. 0. 0. 2. 0. 469. 0.] 36 | [ 0. 0. 0. 2252. 0. 2. 0. 0. 0. 0. 31625.] 
37 | 38 | Test accuracy Normal model_Closed_set :0.9764309764309764 39 | Test accuracy Normal model_Open_set :0.93346911065852 40 | Average Micro F1_Score: 0.7446239387930865 41 | Average Macro F1_Score: 0.8832125135439283 42 | 43 | {0: {'precision': 0.8840579710144928, 'recall': 0.9172932330827067, 'f1': 0.9003690036900369}, 1: {'precision': 0.9894736842105263, 'recall': 0.9894736842105263, 'f1': 0.9894736842105263}, 2: {'precision': 0.9894957983193278, 'recall': 0.9631901840490797, 'f1': 0.9761658031088083}, 3: {'precision': 0.13437978560490046, 'recall': 0.9616438356164384, 'f1': 0.23580786026200876}, 4: {'precision': 1.0, 'recall': 0.9694989106753813, 'f1': 0.9845132743362832}, 5: {'precision': 0.7551020408163265, 'recall': 0.7872340425531915, 'f1': 0.7708333333333333}, 6: {'precision': 1.0, 'recall': 0.9937629937629938, 'f1': 0.9968717413972888}, 7: {'precision': 0.994, 'recall': 0.9959919839679359, 'f1': 0.994994994994995}, 8: {'precision': 1.0, 'recall': 0.975, 'f1': 0.9873417721518987}, 9: {'precision': 0.9957537154989384, 'recall': 0.9957537154989384, 'f1': 0.9957537154989384}, 10: {'precision': 0.9988945041061276, 'recall': 0.93346911065852, 'f1': 0.9650742306107813}} 44 | 45 | 46 | 47 | K-LND3 48 | 49 | [123. 0. 0. 0. 0. 4. 0. 0. 0. 0. 6.] 50 | [ 3. 376. 0. 0. 0. 1. 0. 0. 0. 0. 0.] 51 | [ 4. 0. 470. 4. 0. 2. 0. 0. 0. 0. 9.] 52 | [ 4. 0. 5. 348. 0. 3. 0. 0. 0. 0. 5.] 53 | [ 0. 1. 0. 0. 432. 0. 0. 0. 0. 0. 26.] 54 | [ 7. 1. 0. 0. 0. 37. 0. 0. 0. 0. 2.] 55 | [ 0. 0. 0. 0. 0. 0. 471. 0. 0. 0. 10.] 56 | [ 0. 2. 0. 0. 0. 0. 0. 484. 0. 0. 13.] 57 | [ 0. 0. 0. 1. 0. 0. 0. 0. 230. 0. 9.] 58 | [ 0. 0. 0. 0. 0. 0. 0. 1. 0. 437. 33.] 59 | [ 0. 0. 0. 2255. 0. 2. 0. 0. 0. 0. 31622.] 60 | 61 | Test accuracy Normal model_Closed_set :0.9562289562289562 62 | Test accuracy Normal model_Open_set :0.9333805602290505 63 | Average Micro F1_Score: 0.7351164749912454 64 | Average Macro F1_Score: 0.8751538961844633 65 | 66 | {0: {'precision': 0.8723404255319149, 'recall': 0.924812030075188, 'f1': 0.8978102189781022}, 1: {'precision': 0.9894736842105263, 'recall': 0.9894736842105263, 'f1': 0.9894736842105263}, 2: {'precision': 0.9894736842105263, 'recall': 0.9611451942740287, 'f1': 0.9751037344398341}, 3: {'precision': 0.1334355828220859, 'recall': 0.9534246575342465, 'f1': 0.23410696266397582}, 4: {'precision': 1.0, 'recall': 0.9411764705882353, 'f1': 0.9696969696969697}, 5: {'precision': 0.7551020408163265, 'recall': 0.7872340425531915, 'f1': 0.7708333333333333}, 6: {'precision': 1.0, 'recall': 0.9792099792099792, 'f1': 0.9894957983193277}, 7: {'precision': 0.9979381443298969, 'recall': 0.969939879759519, 'f1': 0.9837398373983739}, 8: {'precision': 1.0, 'recall': 0.9583333333333334, 'f1': 0.9787234042553191}, 9: {'precision': 1.0, 'recall': 0.9278131634819533, 'f1': 0.9625550660792952}, 10: {'precision': 0.9964392626437687, 'recall': 0.9333805602290505, 'f1': 0.9638796598286951}} -------------------------------------------------------------------------------- /results/results_USTC.txt: -------------------------------------------------------------------------------- 1 | ########K-LND1######### 2 | [367. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 133.] 3 | [ 0. 372. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 128.] 4 | [ 0. 0. 383. 0. 0. 0. 0. 0. 0. 0. 0. 0. 117.] 5 | [ 0. 0. 0. 376. 0. 0. 0. 0. 0. 0. 0. 0. 124.] 6 | [ 0. 0. 0. 0. 384. 0. 0. 0. 0. 0. 0. 0. 116.] 7 | [ 0. 0. 0. 0. 0. 393. 0. 0. 0. 0. 0. 0. 107.] 8 | [ 0. 0. 0. 0. 0. 0. 392. 0. 0. 0. 0. 0. 108.] 9 | [ 0. 0. 0. 0. 0. 0. 0. 354. 0. 0. 0. 0. 146.] 10 | [ 0. 0. 0. 0. 
0. 0. 0. 0. 370. 0. 0. 0. 130.] 11 | [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 384. 0. 0. 116.] 12 | [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 369. 0. 130.] 13 | [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 386. 114.] 14 | [ 113. 0. 0. 34. 0. 58. 700. 0. 2. 2895. 0. 0. 16698.] 15 | 16 | Test accuracy Normal model_Closed_set :0.755 17 | Test accuracy Normal model_Open_set :0.8145365853658536 18 | Average Micro F1_Score: 0.6321077185679012 19 | Average Macro F1_Score: 0.7571018721851352 20 | 21 | {0: {'precision': 0.7645833333333333, 'recall': 0.734, 'f1': 0.7489795918367346}, 1: {'precision': 1.0, 'recall': 0.744, 'f1': 0.8532110091743119}, 2: {'precision': 1.0, 'recall': 0.766, 'f1': 0.8674971687429218}, 3: {'precision': 0.9170731707317074, 'recall': 0.752, 'f1': 0.8263736263736263}, 4: {'precision': 1.0, 'recall': 0.768, 'f1': 0.8687782805429864}, 5: {'precision': 0.8713968957871396, 'recall': 0.786, 'f1': 0.8264984227129337}, 6: {'precision': 0.358974358974359, 'recall': 0.784, 'f1': 0.49246231155778897}, 7: {'precision': 1.0, 'recall': 0.708, 'f1': 0.82903981264637}, 8: {'precision': 0.9946236559139785, 'recall': 0.74, 'f1': 0.8486238532110091}, 9: {'precision': 0.11707317073170732, 'recall': 0.768, 'f1': 0.2031746031746032}, 10: {'precision': 1.0, 'recall': 0.738, 'f1': 0.8492520138089759}, 11: {'precision': 1.0, 'recall': 0.772, 'f1': 0.871331828442438}, 12: {'precision': 0.9191390983651676, 'recall': 0.8145365853658536, 'f1': 0.8636822096361239}} 22 | 23 | 24 | 25 | ########K-LND2######### 26 | [500. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 27 | [ 0. 500. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 28 | [ 0. 0. 500. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 29 | [ 0. 0. 0. 500. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 30 | [ 0. 0. 0. 0. 498. 0. 2. 0. 0. 0. 0. 0. 0.] 31 | [ 0. 0. 0. 0. 0. 500. 0. 0. 0. 0. 0. 0. 0.] 32 | [ 0. 0. 0. 0. 0. 0. 496. 3. 0. 1. 0. 0. 0.] 33 | [ 0. 0. 0. 0. 0. 0. 0. 500. 0. 0. 0. 0. 0.] 34 | [ 0. 0. 0. 0. 0. 0. 0. 0. 500. 0. 0. 0. 0.] 35 | [ 0. 0. 0. 0. 1. 0. 3. 0. 0. 495. 1. 0. 0.] 36 | [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 4. 496. 0. 0.] 37 | [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 500. 0.] 38 | [ 55. 0. 0. 53. 0. 42. 872. 0. 0. 2809. 39 | 0. 0. 16669.] 40 | 41 | Test accuracy Normal model_Closed_set :0.9975 42 | Test accuracy Normal model_Open_set :0.8131219512195121 43 | Average Micro F1_Score: 0.7561114222399336 44 | Average Macro F1_Score: 0.8863134188783387 45 | 46 | {0: {'precision': 0.9009009009009009, 'recall': 1.0, 'f1': 0.947867298578199}, 1: {'precision': 1.0, 'recall': 1.0, 'f1': 1.0}, 2: {'precision': 1.0, 'recall': 1.0, 'f1': 1.0}, 3: {'precision': 0.9041591320072333, 'recall': 1.0, 'f1': 0.949667616334283}, 4: {'precision': 0.9979959919839679, 'recall': 0.996, 'f1': 0.9969969969969971}, 5: {'precision': 0.922509225092251, 'recall': 1.0, 'f1': 0.9596928982725529}, 6: {'precision': 0.36125273124544793, 'recall': 0.992, 'f1': 0.5296316070475174}, 7: {'precision': 0.9940357852882704, 'recall': 1.0, 'f1': 0.9970089730807578}, 8: {'precision': 1.0, 'recall': 1.0, 'f1': 1.0}, 9: {'precision': 0.14959202175883954, 'recall': 0.99, 'f1': 0.25991073772643736}, 10: {'precision': 0.9979879275653923, 'recall': 0.992, 'f1': 0.9949849548645938}, 11: {'precision': 1.0, 'recall': 1.0, 'f1': 1.0}, 12: {'precision': 1.0, 'recall': 0.8131219512195121, 'f1': 0.8969302375635609}} 47 | 48 | 49 | ########K-LND3######### 50 | [500. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 51 | [ 0. 500. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 52 | [ 0. 0. 500. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 53 | [ 0. 0. 0. 500. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 54 | [ 0. 0. 0. 0. 498. 0. 2. 0. 0. 0. 
0. 0. 0.] 55 | [ 0. 0. 0. 0. 0. 500. 0. 0. 0. 0. 0. 0. 0.] 56 | [ 0. 0. 0. 0. 0. 0. 496. 3. 0. 1. 0. 0. 0.] 57 | [ 0. 0. 0. 0. 0. 0. 0. 500. 0. 0. 0. 0. 0.] 58 | [ 0. 0. 0. 0. 0. 0. 0. 0. 500. 0. 0. 0. 0.] 59 | [ 0. 0. 0. 0. 1. 0. 3. 0. 0. 495. 1. 0. 0.] 60 | [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 4. 489. 0. 7.] 61 | [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 500. 0.] 62 | [ 104. 0. 0. 34. 0. 56. 743. 0. 2. 2818. 0. 0. 16743.] 63 | 64 | Test accuracy Normal model_Closed_set :0.9963333333333333 65 | Test accuracy Normal model_Open_set :0.8167317073170731 66 | Average Micro F1_Score: 0.7591111063945578 67 | Average Macro F1_Score: 0.8856565734621363 68 | 69 | {0: {'precision': 0.8278145695364238, 'recall': 1.0, 'f1': 0.9057971014492754}, 1: {'precision': 1.0, 'recall': 1.0, 'f1': 1.0}, 2: {'precision': 1.0, 'recall': 1.0, 'f1': 1.0}, 3: {'precision': 0.9363295880149812, 'recall': 1.0, 'f1': 0.9671179883945841}, 4: {'precision': 0.9979959919839679, 'recall': 0.996, 'f1': 0.9969969969969971}, 5: {'precision': 0.8992805755395683, 'recall': 1.0, 'f1': 0.9469696969696969}, 6: {'precision': 0.3987138263665595, 'recall': 0.992, 'f1': 0.5688073394495413}, 7: {'precision': 0.9940357852882704, 'recall': 1.0, 'f1': 0.9970089730807578}, 8: {'precision': 0.9960159362549801, 'recall': 1.0, 'f1': 0.998003992015968}, 9: {'precision': 0.1491862567811935, 'recall': 0.99, 'f1': 0.25929806181246723}, 10: {'precision': 0.9979591836734694, 'recall': 0.978, 'f1': 0.9878787878787879}, 11: {'precision': 1.0, 'recall': 1.0, 'f1': 1.0}, 12: {'precision': 0.9995820895522388, 'recall': 0.8167317073170731, 'f1': 0.8989530201342281}} -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pandas as pd 4 | from tqdm.auto import tqdm 5 | import os 6 | import random 7 | import gc 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import Dataset, DataLoader 11 | from sklearn.metrics import classification_report, accuracy_score 12 | from utils import DatasetCreator, GPT2_collator 13 | from transformers import (set_seed, 14 | TrainingArguments, 15 | Trainer, 16 | GPT2Config, 17 | GPT2Tokenizer, 18 | AdamW, 19 | get_linear_schedule_with_warmup, 20 | GPT2ForSequenceClassification) 21 | 22 | random.seed(42) 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | torch.cuda.empty_cache() 25 | 26 | 27 | def pre_process(dataset): 28 | dataset['trace'] = dataset['text'] 29 | return dataset 30 | 31 | def get_labels(file): 32 | df = pd.read_csv(file) 33 | return np.array(df['target']) 34 | 35 | # Function for training 36 | def train(model, dataloader, optimizer, scheduler, device): 37 | model.train() 38 | predictions_labels = [] 39 | true_labels = [] 40 | total_loss = 0 41 | 42 | for batch in tqdm(dataloader, total=len(dataloader)): 43 | true_labels += batch['labels'].numpy().flatten().tolist() 44 | batch = {k:v.type(torch.long).to(device) for k,v in batch.items()} 45 | optimizer.zero_grad() 46 | outputs = model(**batch) 47 | loss, logits = outputs[:2] 48 | if loss.dim() > 0: 49 | loss = loss.mean() 50 | 51 | total_loss += loss.item() 52 | loss.backward() 53 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 54 | optimizer.step() 55 | scheduler.step() 56 | predictions_labels += logits.argmax(axis=-1).flatten().tolist() 57 | avg_epoch_loss = total_loss / len(dataloader) 58 | return predictions_labels, true_labels, avg_epoch_loss 59 | 60 | 
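# Note: the model is wrapped in nn.DataParallel in main() below, so when more than one GPU is
# visible, outputs[0] is a vector of per-replica losses rather than a scalar; the
# `if loss.dim() > 0: loss = loss.mean()` guard in train() and validate() reduces it to a single
# scalar before backpropagation (in train()) and loss accumulation.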
# Function for validation 61 | def validate(model, dataloader, device): 62 | model.eval() 63 | predictions_labels = [] 64 | true_labels = [] 65 | total_loss = 0 66 | 67 | for batch in tqdm(dataloader, total=len(dataloader)): 68 | true_labels += batch['labels'].numpy().flatten().tolist() 69 | batch = {k:v.type(torch.long).to(device) for k,v in batch.items()} 70 | with torch.no_grad(): 71 | outputs = model(**batch) 72 | loss, logits = outputs[:2] 73 | if loss.dim() > 0: 74 | loss = loss.mean() 75 | 76 | total_loss += loss.item() 77 | predictions_labels += logits.argmax(axis=-1).flatten().tolist() 78 | avg_epoch_loss = total_loss / len(dataloader) 79 | return predictions_labels, true_labels, avg_epoch_loss 80 | 81 | def predict(dataloader, device): 82 | global model 83 | model.eval() 84 | predictions_labels = [] 85 | 86 | for ind,batch in enumerate(tqdm(dataloader, total=len(dataloader))): 87 | batch = {k:v.type(torch.long).to(device) for k,v in batch.items()} 88 | with torch.no_grad(): 89 | outputs = model(**batch) 90 | _, logits = outputs[:2] 91 | if ind == 0: 92 | predictions_labels = logits.to('cpu').numpy() 93 | else: 94 | predictions_labels = np.concatenate((predictions_labels, logits.to('cpu').numpy()), axis=0) 95 | return predictions_labels 96 | 97 | 98 | def main(args): 99 | max_len = args.max_len 100 | batch_size = args.batch_size 101 | epochs = args.epochs 102 | num_labels = args.num_labels 103 | dataset = args.dataset 104 | 105 | if not os.path.exists("./trained_models"): 106 | os.makedirs("trained_models") 107 | 108 | train_dataset = pd.read_csv('./temp_dir/train.csv') 109 | val_dataset = pd.read_csv('./temp_dir/valid.csv') 110 | 111 | print('Loading gpt-2 model') 112 | model_config = GPT2Config.from_pretrained(pretrained_model_name_or_path='gpt2', num_labels=num_labels) 113 | 114 | print('Loading tokenizer...') 115 | tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path='gpt2') 116 | tokenizer.padding_side = "right" 117 | tokenizer.pad_token = tokenizer.eos_token 118 | 119 | print('Loading model...') 120 | model = GPT2ForSequenceClassification.from_pretrained(pretrained_model_name_or_path='gpt2', config=model_config) 121 | model.resize_token_embeddings(len(tokenizer)) 122 | model.config.pad_token_id = model.config.eos_token_id 123 | model = nn.DataParallel(model) 124 | model.to(device) 125 | 126 | gpt2_collator = GPT2_collator(tokenizer=tokenizer, max_seq_len=max_len) 127 | 128 | # Prepare training data 129 | processed_data = pre_process(train_dataset) 130 | train_data = DatasetCreator(processed_data, train=True) 131 | train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=gpt2_collator) 132 | 133 | # Prepare validation data 134 | val_processed = pre_process(val_dataset) 135 | val_data = DatasetCreator(val_processed, train=True) 136 | val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True, collate_fn=gpt2_collator) 137 | 138 | optimizer = AdamW(model.parameters(), lr = 5e-5, eps = 1e-8, weight_decay=0.01) 139 | total_steps = len(train_dataloader) * epochs 140 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = 0, num_training_steps = total_steps) 141 | loss = [] 142 | accuracy = [] 143 | val_loss_list = [] 144 | val_accuracy_list = [] 145 | 146 | for epoch in tqdm(range(epochs)): 147 | train_labels, true_labels, train_loss = train(model, train_dataloader, optimizer, scheduler, device) 148 | train_acc = accuracy_score(true_labels, train_labels) 149 | print('epoch: %.2f train accuracy 
%.2f' % (epoch, train_acc)) 150 | loss.append(train_loss) 151 | accuracy.append(train_acc) 152 | 153 | val_labels, val_true_labels, val_loss = validate(model, val_dataloader, device) 154 | val_acc= accuracy_score(val_true_labels, val_labels) 155 | print('epoch: %.2f validation accuracy %.2f' % (epoch, val_acc)) 156 | val_loss_list.append(val_loss) 157 | val_accuracy_list.append(val_acc) 158 | 159 | torch.save(model.state_dict(), './trained_models/trained_gpt_' + dataset + '.pth') 160 | 161 | 162 | if __name__ == "__main__": 163 | parser = argparse.ArgumentParser(description="Train GPT-2 model with sequence classification") 164 | parser.add_argument("--max_len", type=int, default=1024, help="Max length of the text for input") 165 | parser.add_argument("--batch_size", type=int, default=12, help="Batch size for training") 166 | parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train") 167 | parser.add_argument("--num_labels", type=int, default=120, help="Number of labels for classification") 168 | parser.add_argument("--dataset", type=str, default='AWF', help="Dataset name") 169 | 170 | args = parser.parse_args() 171 | main(args) 172 | 173 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import numpy as np 4 | import pandas as pd 5 | import os 6 | from sklearn.utils import shuffle 7 | from sklearn.model_selection import train_test_split 8 | import _pickle as cPickle 9 | import gc 10 | from transformers import (set_seed, 11 | TrainingArguments, 12 | Trainer, 13 | GPT2Config, 14 | GPT2Tokenizer, 15 | AdamW, 16 | get_linear_schedule_with_warmup, 17 | GPT2ForSequenceClassification) 18 | 19 | 20 | class DatasetCreator(Dataset): 21 | def __init__(self, processed_data, train): 22 | self.data = processed_data 23 | self.train = train 24 | 25 | def __len__(self): 26 | return len(self.data) 27 | 28 | def __getitem__(self, index): 29 | line = self.data.iloc[index] 30 | if self.train: 31 | return {'text': line['trace'], 'label': line['target']} 32 | else: 33 | return {'text': line['trace'], 'label': 0} 34 | 35 | class GPT2_collator(object): 36 | def __init__(self, tokenizer, max_seq_len=None): 37 | self.tokenizer = tokenizer 38 | self.max_seq_len = max_seq_len 39 | return 40 | 41 | def __call__(self, sequences): 42 | texts = [sequence['text'] for sequence in sequences] 43 | labels = [int(sequence['label']) for sequence in sequences] 44 | inputs = self.tokenizer(text=texts, 45 | return_tensors='pt', 46 | padding=True, 47 | truncation=True, 48 | max_length=self.max_seq_len) 49 | inputs.update({'labels': torch.tensor(labels)}) 50 | return inputs 51 | 52 | class Data_Preprocess(): 53 | def __init__(self): 54 | self.train_data = None 55 | self.valid_data = None 56 | self.test_data = None 57 | self.open_data = None 58 | self.name_list = ['train', 'valid', 'test', 'open'] 59 | 60 | @staticmethod 61 | def create_pattern(x, y): 62 | max_pat_len = 32 63 | pat = '' 64 | while x !=0: 65 | if x