├── models
│   ├── gnn.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── ms.cpython-38.pyc
│   │   ├── birnn.cpython-38.pyc
│   │   ├── cnn.cpython-38.pyc
│   │   ├── fcn.cpython-38.pyc
│   │   ├── sesy.cpython-38.pyc
│   │   ├── lstmatt.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── r_bilstm_c.cpython-38.pyc
│   │   └── bilstm_dense.cpython-38.pyc
│   ├── fcn.py
│   ├── bilstm_dense.py
│   ├── birnn.py
│   ├── cnn.py
│   ├── r_bilstm_c.py
│   ├── ms.py
│   ├── sesy.py
│   └── lstmatt.py
├── run_multistages.sh
├── logger.py
├── LICENSE
├── configs
│   ├── stage1.json
│   ├── stage2.json
│   ├── stage3.json
│   ├── sesy.json
│   └── test.json
├── utils.py
├── processors
│   ├── process.py
│   └── graph_process.py
├── readme.md
├── dataset.py
├── incorporate.py
├── distil.py
└── main.py

/models/gnn.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/run_multistages.sh:
--------------------------------------------------------------------------------
1 | # MS_TL
2 | # step 1
3 | python main.py
4 | # step 2
5 | python distil.py
6 | # step 3
7 | python incorporate.py
--------------------------------------------------------------------------------
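The three steps above presumably correspond one-to-one to the stage configs under configs/: stage1.json fine-tunes the ft-bert teacher, stage2.json distils it into the ms-cnn / ms-birnn students, and stage3.json trains the combined ms-birnnC model. A minimal driver sketch in Python, assuming distil.py and incorporate.py accept the same --config_path flag that the readme documents for main.py (only main.py's usage is documented in this repo):

# Sketch only: the step-to-config mapping and the --config_path flag for
# distil.py / incorporate.py are assumptions, not taken from the repo scripts.
import subprocess

stages = [
    ("main.py", "configs/stage1.json"),         # step 1: fine-tune the BERT teacher
    ("distil.py", "configs/stage2.json"),       # step 2: distil into a lightweight student
    ("incorporate.py", "configs/stage3.json"),  # step 3: combine the stage-2 students (ms-birnnC)
]
for script, config in stages:
    subprocess.run(["python", script, "--config_path", config], check=True)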
/models/__pycache__/r_bilstm_c.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjs1224/TextSteganalysis/HEAD/models/__pycache__/r_bilstm_c.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/bilstm_dense.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yjs1224/TextSteganalysis/HEAD/models/__pycache__/bilstm_dense.cpython-38.pyc -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import logging 4 | 5 | 6 | class Logger(object): 7 | def __init__(self, log_file): 8 | self.logger = logging.getLogger() 9 | self.formatter = logging.Formatter(fmt='[%(asctime)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 10 | 11 | self.logger.setLevel(logging.INFO) 12 | self.logger.handlers = [] 13 | 14 | fh = logging.FileHandler(log_file, mode='w') 15 | fh.setLevel(logging.INFO) 16 | fh.setFormatter(self.formatter) 17 | self.logger.addHandler(fh) 18 | 19 | sh = logging.StreamHandler() 20 | sh.setLevel(logging.INFO) 21 | sh.setFormatter(self.formatter) 22 | self.logger.addHandler(sh) 23 | 24 | def info(self, text): 25 | self.logger.info(text) 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 yjs1224 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/stage1.json: -------------------------------------------------------------------------------- 1 | { 2 | "use_processor": true, 3 | "out_dir": "./test", 4 | "checkpoint": "stage1-best", 5 | "record_file": "record.txt", 6 | "model": "ft-bert", 7 | "class_num": 2, 8 | "tokenizer": false, 9 | "seed": 42, 10 | "gpuid": "1", 11 | "task_name": "steganalysis", 12 | "use_plm":true, 13 | 14 | "Dataset": { 15 | "name": "tweet", 16 | "stego_file": "data/stego.txt", 17 | "cover_file": "data/cover.txt", 18 | "csv_dir": "data", 19 | "resplit": true, 20 | "split_ratio": 0.8, 21 | "save_cache": false, 22 | "overwrite_cache": true 23 | }, 24 | "Training_with_Processor": { 25 | "num_train_epochs": 100, 26 | "learning_rate": 5e-5, 27 | "eval_and_save_steps": 100, 28 | "model_name_or_path": "prajjwal1/bert-tiny", 29 | "do_lower_case":true, 30 | "per_gpu_train_batch_size": 32, 31 | "per_gpu_eval_batch_size": 32, 32 | "n_gpu": 1, 33 | "max_steps": -1, 34 | "gradient_accumulation_steps": 1, 35 | "warmup_ratio": 0.06, 36 | "weight_decay": 0.01, 37 | "adam_epsilon": 1e-8, 38 | "max_grad_norm": 1.0, 39 | "logging_steps": -1, 40 | "evaluate_during_training": true, 41 | "save_only_best": true, 42 | "use_fixed_seq_length": true, 43 | "eval_all_checkpoints": true, 44 | "skip_evaluate_dev":false 45 | }, 46 | "FineTuneBERT": { 47 | "criteration": "CrossEntropyLoss" 48 | } 49 | } -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sklearn.metrics as metrics 3 | 4 | class MyDict(dict): 5 | __setattr__ = dict.__setitem__ 6 | # def __setattr__(self, key, value): 7 | # try: 8 | # self[key] = value 9 | # except: 10 | # raise AttributeError(key) 11 | # __getattr__ = dict.__getitem__ 12 | def __getattr__(self, item): 13 | try: 14 | return self[item] 15 | except: 16 | raise AttributeError(item) 17 | 18 | class Config(object): 19 | def __init__(self, config_path): 20 | configs = json.load(open(config_path, "r", encoding="utf-8")) 21 | self.configs = self.dictobj2obj(configs) 22 | self.configs.state_dict = configs 23 | 24 | def dictobj2obj(self, dictobj): 25 | if not isinstance(dictobj, dict): 26 | return dictobj 27 | d = MyDict() 28 | for k, v in dictobj.items(): 29 | d[k] = self.dictobj2obj(v) 30 | return d 31 | 32 | 33 | 34 | def get_configs(self): 35 | return self.configs 36 | 37 | 38 | def compute_metrics(task_name, preds, labels, stego_label=1): 39 | assert len(preds) == len(labels), f"Predictions and labels have mismatched lengths {len(preds)} and {len(labels)}" 40 | if task_name in ["steganalysis", "graph_steganalysis"]: 41 | return {"accuracy": metrics.accuracy_score(labels, preds), 42 | "macro_f1":metrics.f1_score(labels, preds, average="macro"), 43 | "precision":metrics.precision_score(labels, preds, pos_label=stego_label), 44 | "recall":metrics.recall_score(labels, preds, pos_label=stego_label), 45 | "f1_score":metrics.f1_score(labels, preds, pos_label=stego_label)} 46 | else: 47 | raise KeyError(task_name) 48 | -------------------------------------------------------------------------------- /configs/stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "use_processor": true, 3 | "out_dir": "./test", 4 | "checkpoint": "stage2-cnn-best", 5 | "record_file": "record.txt", 6 | "teacher_model": "ft-bert", 7 | "teacher_model_name_or_path": 
"test/bert-cnn-best", 8 | "student_model": "ms-cnn", 9 | "class_num": 2, 10 | "tokenizer": false, 11 | "seed": 42, 12 | "gpuid": "1", 13 | "task_name": "steganalysis", 14 | 15 | "Dataset": { 16 | "name": "tweet", 17 | "stego_file": "data/stego.txt", 18 | "cover_file": "data/cover.txt", 19 | "csv_dir": "data", 20 | "resplit": true, 21 | "split_ratio": 0.8, 22 | "save_cache": false, 23 | "overwrite_cache": true 24 | }, 25 | "Training_with_Processor": { 26 | "num_train_epochs": 10, 27 | "learning_rate": 1e-4, 28 | "eval_and_save_steps": 100, 29 | "model_name_or_path": "prajjwal1/bert-tiny", 30 | "do_lower_case":true, 31 | "per_gpu_train_batch_size": 32, 32 | "per_gpu_eval_batch_size": 64, 33 | "n_gpu": 1, 34 | "max_steps": -1, 35 | "gradient_accumulation_steps": 1, 36 | "warmup_ratio": 0.06, 37 | "weight_decay": 0.01, 38 | "adam_epsilon": 1e-8, 39 | "max_grad_norm": 1.0, 40 | "logging_steps": -1, 41 | "evaluate_during_training": true, 42 | "save_only_best": true, 43 | "use_fixed_seq_length": true, 44 | "eval_all_checkpoints": true, 45 | "skip_evaluate_dev":false, 46 | "teacher_criteration":"KLDivLoss", 47 | "distil_T": 2, 48 | "distil_alpha": 0.95 49 | }, 50 | "FineTuneBERT": { 51 | "criteration": "CrossEntropyLoss" 52 | }, 53 | "MSCNN": { 54 | "embed_size":300, 55 | "filter_num": 100, 56 | "dropout_rate": 0.5, 57 | "criteration":"CrossEntropyLoss" 58 | }, 59 | "MSBiRNN": { 60 | "embed_size":300, 61 | "hidden_size": 200, 62 | "num_layers": 1, 63 | "bidirectional": true, 64 | "criteration":"CrossEntropyLoss" 65 | } 66 | } -------------------------------------------------------------------------------- /configs/stage3.json: -------------------------------------------------------------------------------- 1 | { 2 | "use_processor": true, 3 | "out_dir": "./test", 4 | "checkpoint": "stage3-birnn-best", 5 | "record_file": "record.txt", 6 | "model":"ms-birnnC", 7 | "class_num": 2, 8 | "tokenizer": false, 9 | "seed": 42, 10 | "gpuid": "0", 11 | 12 | "task_name": "steganalysis", 13 | "Dataset": { 14 | "name": "tweet", 15 | "stego_file": "data/stego.txt", 16 | "cover_file": "data/cover.txt", 17 | "csv_dir": "data", 18 | "resplit": true, 19 | "split_ratio": 0.8, 20 | "save_cache": false, 21 | "overwrite_cache": true 22 | }, 23 | "Training_with_Processor": { 24 | "num_train_epochs": 1, 25 | "learning_rate": 1e-4, 26 | "eval_and_save_steps": 50, 27 | "model_name_or_path": "prajjwal1/bert-tiny", 28 | "do_lower_case":true, 29 | "per_gpu_train_batch_size": 32, 30 | "per_gpu_eval_batch_size": 64, 31 | "n_gpu": 1, 32 | "max_steps": -1, 33 | "gradient_accumulation_steps": 1, 34 | "warmup_ratio": 0.06, 35 | "weight_decay": 0.01, 36 | "adam_epsilon": 1e-8, 37 | "max_grad_norm": 1.0, 38 | "logging_steps": -1, 39 | "evaluate_during_training": true, 40 | "save_only_best": true, 41 | "use_fixed_seq_length": true, 42 | "eval_all_checkpoints": true, 43 | "skip_evaluate_dev":false, 44 | "teacher_criteration":"KLDivLoss", 45 | "distil_T": 2, 46 | "distil_alpha": 0.95 47 | }, 48 | "FineTuneBERT": { 49 | "criteration": "CrossEntropyLoss" 50 | }, 51 | "MSBiRNNC": { 52 | "dropout_rate": 0.5, 53 | "criteration":"CrossEntropyLoss", 54 | "bidirectional": true, 55 | "cnn_checkpoint": "test/stage2-cnn-best", 56 | "birnn_checkpoint": "test/stage2-birnn-best", 57 | "MSCNN": { 58 | "embed_size":300, 59 | "filter_num": 100, 60 | "dropout_rate": 0.5, 61 | "criteration":"CrossEntropyLoss" 62 | }, 63 | "MSBiRNN": { 64 | "embed_size":300, 65 | "hidden_size": 200, 66 | "num_layers": 1, 67 | "bidirectional": true, 68 | 
"criteration":"CrossEntropyLoss" 69 | } 70 | } 71 | } -------------------------------------------------------------------------------- /configs/sesy.json: -------------------------------------------------------------------------------- 1 | { 2 | "use_processor": true, 3 | "out_dir": "./test", 4 | "checkpoint": "bert-sesy-best", 5 | "record_file": "sesy-parl-record.txt", 6 | "model": "sesy", 7 | "class_num": 2, 8 | "tokenizer": false, 9 | "seed": 42, 10 | "gpuid": "0", 11 | "task_name": "graph_steganalysis", 12 | "use_plm":true, 13 | 14 | "Dataset": { 15 | "name": "tweet", 16 | "stego_file": "data/stego.txt", 17 | "cover_file": "data/cover.txt", 18 | "csv_dir": "data", 19 | "resplit": true, 20 | "split_ratio": 0.8, 21 | "save_cache": true, 22 | "overwrite_cache": true 23 | }, 24 | "Tokenizer": { 25 | "model_name_or_path": "bert-base-uncased" 26 | }, 27 | "Training_with_Processor": { 28 | "num_train_epochs": 10, 29 | "learning_rate": 5e-5, 30 | "eval_and_save_steps": 100, 31 | "model_name_or_path": "bert-base-uncased", 32 | "do_lower_case":true, 33 | "per_gpu_train_batch_size": 32, 34 | "per_gpu_eval_batch_size": 32, 35 | "n_gpu": 1, 36 | "max_steps": -1, 37 | "gradient_accumulation_steps": 1, 38 | "warmup_ratio": 0.06, 39 | "weight_decay": 0.01, 40 | "adam_epsilon": 1e-8, 41 | "max_grad_norm": 1.0, 42 | "logging_steps": -1, 43 | "evaluate_during_training": true, 44 | "save_only_best": true, 45 | "use_fixed_seq_length": true, 46 | "eval_all_checkpoints": true, 47 | "skip_evaluate_dev":false 48 | }, 49 | "Training": { 50 | "batch_size": 100, 51 | "epoch": 10, 52 | "learning_rate": 0.001, 53 | "early_stop": 50, 54 | "warmup_ratio": 0.06, 55 | "weight_decay": 0.01, 56 | "adam_epsilon": 1e-8 57 | }, 58 | "Vocabulary": { 59 | "word_drop": 0, 60 | "do_lower": true, 61 | "max_length": 60 62 | }, 63 | "SESY": { 64 | "clf": "cnn", 65 | "criteration": "CrossEntropyLoss", 66 | "strategy": "parl", 67 | "embed_size": 768, 68 | "hidden_dim": 128, 69 | "readout_size": 64, 70 | "gat_alpha": 0.2, 71 | "gat_heads": 8, 72 | "dropout_rate": 0.2, 73 | "TC_configs": { 74 | "cnn": { 75 | "filter_num": 128, 76 | "filter_size": [3, 4, 5] 77 | 78 | }, 79 | "rnn": { 80 | "cell":"bi-lstm", 81 | "hidden_dim": 256, 82 | "num_layers": 1 83 | }, 84 | "fc": { 85 | } 86 | } 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /models/fcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import BertPreTrainedModel,BertModel 4 | 5 | 6 | class TC_base(nn.Module): 7 | def __init__(self, in_features, class_num, dropout_rate,): 8 | super(TC_base, self).__init__() 9 | self.in_features = in_features 10 | self.dropout_prob = dropout_rate 11 | self.num_labels = class_num # 12 | self.dropout = nn.Dropout(self.dropout_prob) 13 | self.pool = nn.AdaptiveAvgPool1d(1) 14 | self.classifier = nn.Linear(self.in_features, self.num_labels) 15 | 16 | def forward(self,features): 17 | clf_input = self.pool(features.permute(0,2,1)).squeeze() 18 | logits = self.classifier(clf_input) 19 | return logits 20 | 21 | def extra_repr(self) -> str: 22 | return 'features {}->{},'.format( 23 | self.in_features, self.class_num 24 | ) 25 | 26 | 27 | class TC(nn.Module): 28 | def __init__(self, vocab_size, embed_size, class_num, dropout_rate, criteration="CrossEntropyLoss"): 29 | super(TC,self).__init__() 30 | self.embed_size= embed_size 31 | self.dropout_prob = dropout_rate 32 | self.num_labels = class_num 33 | 34 | 
self.embedding = nn.Embedding(vocab_size, self.embed_size,) 35 | self.classifier = TC_base(self.embed_size, self.num_labels,self.dropout_prob) 36 | 37 | if criteration == "CrossEntropyLoss": 38 | self.criteration = nn.CrossEntropyLoss() 39 | else: 40 | # default loss 41 | self.criteration = nn.CrossEntropyLoss() 42 | 43 | 44 | def forward(self,input_ids, labels, attention_mask=None,token_type_ids=None): 45 | clf_input = self.embedding(input_ids.long()) 46 | logits = self.classifier(clf_input) 47 | loss = self.criteration(logits, labels) 48 | return loss, logits 49 | 50 | 51 | class BERT_TC(BertPreTrainedModel): 52 | def __init__(self, config, **kwargs): 53 | super().__init__(config) 54 | self.class_num = kwargs["class_num"] 55 | self.dropout_rate = kwargs["dropout_rate"] 56 | self.embed_size = config.hidden_size # not kwags["embed_size"] 57 | self.plm_config = config 58 | 59 | self.bert = BertModel(self.plm_config) 60 | self.classifier = TC_base(self.embed_size, self.class_num, self.dropout_rate) 61 | if kwargs["criteration"] == "CrossEntropyLoss": 62 | self.criteration = nn.CrossEntropyLoss() 63 | else: 64 | # default loss 65 | self.criteration = nn.CrossEntropyLoss() 66 | 67 | def extra_repr(self) -> str: 68 | return 'bert word embedding dim:{}'.format( 69 | self.embed_size 70 | ) 71 | 72 | def forward(self, input_ids, labels, attention_mask, token_type_ids): 73 | outputs = self.bert(input_ids, attention_mask, token_type_ids) 74 | embedding = outputs[0] 75 | 76 | 77 | logits = self.classifier(embedding) 78 | loss = self.criteration(logits, labels) 79 | return loss, logits 80 | 81 | 82 | -------------------------------------------------------------------------------- /configs/test.json: -------------------------------------------------------------------------------- 1 | { 2 | "use_processor": true, 3 | "out_dir": "out_dir", 4 | "checkpoint": "best", 5 | "record_file": "record.txt", 6 | "model": "cnn", 7 | "class_num": 2, 8 | "tokenizer": false, 9 | "seed": 42, 10 | "gpuid": "0", 11 | "task_name": "steganalysis", 12 | "use_plm":true, 13 | "repeat_num":10, 14 | 15 | "Dataset": { 16 | "name": "tweet", 17 | "stego_file": "data/stego.txt", 18 | "cover_file": "data/cover.txt", 19 | "csv_dir": "data", 20 | "resplit": true, 21 | "split_ratio": 0.8, 22 | "save_cache": false, 23 | "overwrite_cache": true 24 | }, 25 | "Tokenizer": { 26 | "model_name_or_path": "bert-base-uncased" 27 | }, 28 | "Training_with_Processor": { 29 | "num_train_epochs": 1, 30 | "learning_rate": 1e-4, 31 | "eval_and_save_steps": 50, 32 | "model_name_or_path": "prajjwal1/bert-tiny", 33 | "do_lower_case":true, 34 | "per_gpu_train_batch_size": 40, 35 | "per_gpu_eval_batch_size": 100, 36 | "n_gpu": 1, 37 | "max_steps": -1, 38 | "gradient_accumulation_steps": 1, 39 | "warmup_ratio": 0.06, 40 | "weight_decay": 0.01, 41 | "adam_epsilon": 1e-8, 42 | "max_grad_norm": 1.0, 43 | "logging_steps": -1, 44 | "evaluate_during_training": true, 45 | "save_only_best": true, 46 | "use_fixed_seq_length": true, 47 | "eval_all_checkpoints": true, 48 | "skip_evaluate_dev":false 49 | }, 50 | "Training": { 51 | "batch_size": 100, 52 | "epoch": 10, 53 | "learning_rate": 0.001, 54 | "early_stop": 50, 55 | "warmup_ratio": 0.06, 56 | "weight_decay": 0.01, 57 | "adam_epsilon": 1e-8 58 | }, 59 | "Vocabulary": { 60 | "word_drop": 0, 61 | "do_lower": true, 62 | "max_length": 60 63 | }, 64 | 65 | "CNN": { 66 | "embed_size": 128, 67 | "filter_num": 128, 68 | "filter_size": [3, 4, 5], 69 | "dropout_rate": 0.2, 70 | "criteration": "CrossEntropyLoss" 71 | }, 72 | 
"RNN": { 73 | "cell":"bi-lstm", 74 | "embed_size": 128, 75 | "hidden_dim": 256, 76 | "num_layers": 1, 77 | "dropout_rate": 0.2, 78 | "criteration": "CrossEntropyLoss" 79 | }, 80 | "FCN": { 81 | "embed_size": 128, 82 | "dropout_rate": 0.2, 83 | "criteration": "CrossEntropyLoss" 84 | }, 85 | "LSTMATT": { 86 | "embed_size": 128, 87 | "hidden_dim": 256, 88 | "dropout_rate":0.2, 89 | "criteration": "CrossEntropyLoss", 90 | "bidirectional": true 91 | }, 92 | "RBiLSTMC": { 93 | "num_layers": 1, 94 | "kernel_sizes": [3,4,5], 95 | "kernel_num": 128, 96 | "embed_dim": 128, 97 | "hidden_dim": 256, 98 | "LSTM_dropout": 0.2, 99 | "CNN_dropout": 0.2, 100 | "Ci": 1, 101 | "criteration": "CrossEntropyLoss" 102 | }, 103 | "BiLSTMDENSE": { 104 | "num_layers": 1, 105 | "embed_dim": 256, 106 | "hidden_dim": 200, 107 | "dropout_rate": 0.2, 108 | "criteration": "CrossEntropyLoss" 109 | }, 110 | "SESY": { 111 | "clf": "cnn", 112 | "criteration": "CrossEntropyLoss", 113 | "strategy": "cas", 114 | "embed_size": 100, 115 | "hidden_dim": 128, 116 | "readout_size": 64, 117 | "gat_alpha": 0.2, 118 | "gat_heads": 8, 119 | "dropout_rate": 0.2, 120 | "TC_configs": { 121 | "cnn": { 122 | "filter_num": 128, 123 | "filter_size": [3, 4, 5] 124 | 125 | }, 126 | "rnn": { 127 | "cell":"bi-lstm", 128 | "hidden_dim": 256, 129 | "num_layers": 1 130 | }, 131 | "fc": { 132 | } 133 | } 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /models/bilstm_dense.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import BertPreTrainedModel, BertModel 4 | 5 | 6 | class TC_base(nn.Module): 7 | def __init__(self, in_features,class_num, num_layers,hidden_size,dropout_rate): 8 | super(TC_base, self).__init__() 9 | self.in_features= in_features 10 | self.class_num = class_num 11 | 12 | D = in_features 13 | C = class_num 14 | N = num_layers 15 | H = hidden_size 16 | 17 | self.lstm1 = nn.LSTM(D, H, num_layers=N, \ 18 | bidirectional=True, 19 | batch_first=True,dropout=dropout_rate) 20 | 21 | self.lstm2 = nn.LSTM(2 * H, H, num_layers=N, \ 22 | bidirectional=True, 23 | batch_first=True,dropout=dropout_rate) 24 | 25 | self.lstm3 = nn.LSTM(4 * H, H, num_layers=N, \ 26 | bidirectional=True, 27 | batch_first=True,dropout=dropout_rate) 28 | 29 | self.fc1 = nn.Linear(2 * H, C) 30 | 31 | def forward(self, features): 32 | out1, _ = self.lstm1(features) 33 | out2, _ = self.lstm2(out1) 34 | out3, _ = self.lstm3(torch.cat([out1, out2], 2)) 35 | out = torch.add(torch.add(out1, out2), out3) 36 | logits = self.fc1(out[:, -1, :]) 37 | return logits 38 | 39 | def extra_repr(self) -> str: 40 | return 'features {}->{},'.format( 41 | self.in_features, self.class_num 42 | ) 43 | 44 | 45 | class TC(nn.Module): 46 | def __init__(self, vocab_size, embed_dim,class_num, num_layers,hidden_dim,dropout_rate,criteration="CrossEntropyLoss"): 47 | super(TC, self).__init__() 48 | 49 | V = vocab_size 50 | D = embed_dim 51 | C = class_num 52 | N = num_layers 53 | H = hidden_dim 54 | 55 | self.embed = nn.Embedding(V, D) 56 | # self.classifier = TC_base(embed_dim,class_num,num_layers,hidden_size,dropout_rate) 57 | self.classifier = TC_base(D,C, N, H, dropout_rate) 58 | 59 | if criteration == "CrossEntropyLoss": 60 | self.criteration = nn.CrossEntropyLoss() 61 | else: 62 | # default loss 63 | self.criteration = nn.CrossEntropyLoss() 64 | 65 | def forward(self, input_ids,labels,attention_mask=None,token_type_ids=None): 66 | embedding = 
self.embed(input_ids) 67 | logits = self.classifier(embedding) 68 | loss = self.criteration(logits,labels) 69 | return loss,logits 70 | 71 | 72 | class BERT_TC(BertPreTrainedModel): 73 | def __init__(self, config, **kwargs): 74 | super().__init__(config) 75 | 76 | self.bert_config = config 77 | self.embed_dim = config.hidden_size 78 | self.class_num = kwargs["class_num"] 79 | self.num_layers = kwargs["num_layers"] 80 | self.hidden_size = kwargs["hidden_dim"] 81 | self.dropout_rate = kwargs["dropout_rate"] 82 | 83 | self.bert = BertModel(self.bert_config) 84 | self.classifier = TC_base(self.embed_dim,self.class_num,self.num_layers,self.hidden_size,self.dropout_rate) 85 | if kwargs["criteration"] == "CrossEntropyLoss": 86 | self.criteration = nn.CrossEntropyLoss() 87 | else: 88 | # default loss 89 | self.criteration = nn.CrossEntropyLoss() 90 | 91 | def extra_repr(self) -> str: 92 | return 'bert word embedding dim:{}'.format( 93 | self.embed_size 94 | ) 95 | 96 | def forward(self, input_ids, labels, attention_mask, token_type_ids): 97 | outputs = self.bert(input_ids, attention_mask, token_type_ids) 98 | embedding = outputs[0] 99 | logits = self.classifier(embedding) 100 | loss = self.criteration(logits, labels) 101 | return loss, logits 102 | -------------------------------------------------------------------------------- /models/birnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import BertPreTrainedModel,BertModel 4 | 5 | class TC_base(nn.Module): 6 | def __init__(self, cell, in_features, hidden_dim, num_layers, class_num, dropout_rate,): 7 | super(TC_base, self).__init__() 8 | self._cell = cell 9 | self.in_features=in_features 10 | self.class_num = class_num 11 | self.rnn = None 12 | if cell == 'rnn': 13 | self.rnn = nn.RNN(in_features, hidden_dim, num_layers, dropout=dropout_rate) 14 | out_hidden_dim = hidden_dim * num_layers 15 | elif cell == 'bi-rnn': 16 | self.rnn = nn.RNN(in_features, hidden_dim, num_layers, dropout=dropout_rate, bidirectional=True) 17 | out_hidden_dim = 2 * hidden_dim * num_layers 18 | elif cell == 'gru': 19 | self.rnn = nn.GRU(in_features, hidden_dim, num_layers, dropout=dropout_rate) 20 | out_hidden_dim = hidden_dim * num_layers 21 | elif cell == 'bi-gru': 22 | self.rnn = nn.GRU(in_features, hidden_dim, num_layers, dropout=dropout_rate, bidirectional=True) 23 | out_hidden_dim = 2 * hidden_dim * num_layers 24 | elif cell == 'lstm': 25 | self.rnn = nn.LSTM(in_features, hidden_dim, num_layers, dropout=dropout_rate) 26 | out_hidden_dim = 2 * hidden_dim * num_layers 27 | elif cell == 'bi-lstm': 28 | self.rnn = nn.LSTM(in_features, hidden_dim, num_layers, dropout=dropout_rate, bidirectional=True) 29 | out_hidden_dim = 4 * hidden_dim * num_layers 30 | else: 31 | raise Exception("no such rnn cell") 32 | self.output_layer = nn.Linear(out_hidden_dim, class_num) 33 | 34 | def forward(self,features): 35 | _ = features.permute(1, 0, 2) 36 | __, h_out = self.rnn(_) 37 | if self._cell in ["lstm", "bi-lstm"]: 38 | h_out = torch.cat([h_out[0], h_out[1]], dim=2) 39 | h_out = h_out.permute(1, 0, 2) 40 | h_out = h_out.reshape(-1, h_out.shape[1] * h_out.shape[2]) 41 | logits = self.output_layer(h_out) 42 | return logits 43 | 44 | def extra_repr(self) -> str: 45 | return 'features {}->{},'.format( 46 | self.in_features, self.class_num 47 | ) 48 | 49 | 50 | class TC(nn.Module): 51 | def __init__(self, cell, vocab_size, embed_size, hidden_dim, num_layers, class_num, dropout_rate, 
criteration="CrossEntropyLoss"): 52 | super(TC, self).__init__() 53 | self._cell = cell 54 | 55 | self.embedding = nn.Embedding(vocab_size, embed_size) 56 | self.classifier = TC_base(cell,embed_size,hidden_dim,num_layers,class_num,dropout_rate) 57 | if criteration == "CrossEntropyLoss": 58 | self.criteration = nn.CrossEntropyLoss() 59 | else: 60 | # default loss 61 | self.criteration = nn.CrossEntropyLoss() 62 | 63 | def forward(self, input_ids, labels,attention_mask=None,token_type_ids=None): 64 | x = input_ids.long() 65 | embedding = self.embedding(x) 66 | logits = self.classifier(embedding) 67 | loss = self.criteration(logits, labels) 68 | return loss,logits 69 | 70 | def extra_repr(self) -> str: 71 | return 'features {}->{},'.format( 72 | self.embed_size, self.class_num 73 | ) 74 | 75 | 76 | class BERT_TC(BertPreTrainedModel): 77 | def __init__(self, config, **kwargs): 78 | super().__init__(config) 79 | self.class_num = kwargs["class_num"] 80 | self.dropout_rate = kwargs["dropout_rate"] 81 | self.cell = kwargs["cell"] 82 | self.hidden_dim = kwargs["hidden_dim"] 83 | self.num_layers = kwargs["num_layers"] 84 | self.embed_size = config.hidden_size # not kwags["embed_size"] 85 | self.plm_config = config 86 | 87 | self.bert = BertModel(self.plm_config) 88 | self.classifier = TC_base(self.cell, self.embed_size, self.hidden_dim, self.num_layers, self.class_num, self.dropout_rate) 89 | if kwargs["criteration"] == "CrossEntropyLoss": 90 | self.criteration = nn.CrossEntropyLoss() 91 | else: 92 | # default loss 93 | self.criteration = nn.CrossEntropyLoss() 94 | 95 | 96 | def extra_repr(self) -> str: 97 | return 'bert word embedding dim:{}'.format( 98 | self.embed_size 99 | ) 100 | 101 | 102 | def forward(self,input_ids, labels, attention_mask, token_type_ids): 103 | outputs = self.bert(input_ids, attention_mask, token_type_ids) 104 | embedding = outputs[0] 105 | logits = self.classifier(embedding) 106 | loss = self.criteration(logits, labels) 107 | return loss, logits 108 | -------------------------------------------------------------------------------- /models/cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import BertPreTrainedModel,BertModel,DistilBertPreTrainedModel,DistilBertModel 4 | 5 | 6 | class TC_base(nn.Module): 7 | def __init__(self, in_features, filter_num, filter_size, class_num, dropout_rate,): 8 | super(TC_base, self).__init__() 9 | 10 | self.cnn_list = nn.ModuleList() 11 | for size in filter_size: 12 | self.cnn_list.append(nn.Conv1d(in_features, filter_num, size)) 13 | self.relu = nn.ReLU() 14 | self.max_pool = nn.AdaptiveMaxPool1d(1) 15 | self.dropout = nn.Dropout(dropout_rate) 16 | self.output_layer = nn.Linear(filter_num * len(filter_size), class_num) 17 | 18 | self.in_features = in_features 19 | self.class_num = class_num 20 | 21 | def forward(self, features): 22 | _ = features.permute(0, 2, 1) 23 | result = [] 24 | for self.cnn in self.cnn_list: 25 | __ = self.cnn(_) 26 | __ = self.relu(__) 27 | __ = self.max_pool(__) 28 | result.append(__.squeeze(dim=2)) 29 | 30 | _ = torch.cat(result, dim=1) 31 | _ = self.dropout(_) 32 | _ = self.output_layer(_) 33 | return _ 34 | 35 | def extra_repr(self) -> str: 36 | return 'features {}->{},'.format( 37 | self.in_features, self.class_num 38 | ) 39 | 40 | 41 | class TC(nn.Module): 42 | def __init__(self, vocab_size, embed_size, filter_num, filter_size, class_num, dropout_rate, criteration="CrossEntropyLoss",): 43 | super(TC, 
self).__init__() 44 | self.embedding = nn.Embedding(vocab_size, embed_size) 45 | self.classifier = TC_base(embed_size,filter_num,filter_size,class_num,dropout_rate) 46 | if criteration == "CrossEntropyLoss": 47 | self.criteration = nn.CrossEntropyLoss() 48 | else: 49 | # default loss 50 | self.criteration = nn.CrossEntropyLoss() 51 | 52 | 53 | def forward(self, input_ids, labels,attention_mask=None,token_type_ids=None): 54 | clf_input = self.embedding(input_ids.long()) 55 | logits = self.classifier(clf_input) 56 | loss = self.criteration(logits, labels) 57 | return loss, logits 58 | 59 | 60 | class BERT_TC(BertPreTrainedModel): 61 | def __init__(self, config, **kwargs): 62 | super().__init__(config) 63 | self.filter_size = kwargs["filter_size"] 64 | self.filter_num = kwargs["filter_num"] 65 | self.class_num = kwargs["class_num"] 66 | self.dropout_rate = kwargs["dropout_rate"] 67 | self.embed_size = config.hidden_size # not kwags["embed_size"] 68 | self.plm_config = config 69 | 70 | self.bert = BertModel(self.plm_config) 71 | self.classifier = TC_base(self.embed_size, self.filter_num,self.filter_size,self.class_num,self.dropout_rate) 72 | if kwargs["criteration"] == "CrossEntropyLoss": 73 | self.criteration = nn.CrossEntropyLoss() 74 | else: 75 | # default loss 76 | self.criteration = nn.CrossEntropyLoss() 77 | 78 | def extra_repr(self) -> str: 79 | return 'bert word embedding dim:{}'.format( 80 | self.embed_size 81 | ) 82 | 83 | 84 | def forward(self,input_ids, labels, attention_mask, token_type_ids): 85 | outputs = self.bert(input_ids,attention_mask,token_type_ids) 86 | embedding = outputs[0] 87 | logits = self.classifier(embedding) 88 | loss = self.criteration(logits, labels) 89 | return loss, logits 90 | 91 | 92 | class DistilBERT_TC(DistilBertPreTrainedModel): 93 | def __init__(self, config, **kwargs): 94 | super().__init__(config) 95 | self.filter_size = kwargs["filter_size"] 96 | self.filter_num = kwargs["filter_num"] 97 | self.class_num = kwargs["class_num"] 98 | self.dropout_rate = kwargs["dropout_rate"] 99 | self.embed_size = config.hidden_size # not kwags["embed_size"] 100 | self.plm_config = config 101 | 102 | self.bert = DistilBertModel(self.plm_config) 103 | self.classifier = TC_base(self.embed_size, self.filter_num,self.filter_size,self.class_num, self.dropout_rate) 104 | if kwargs["criteration"] == "CrossEntropyLoss": 105 | self.criteration = nn.CrossEntropyLoss() 106 | else: 107 | # default loss 108 | self.criteration = nn.CrossEntropyLoss() 109 | 110 | def extra_repr(self) -> str: 111 | return 'bert word embedding dim:{}'.format( 112 | self.embed_size 113 | ) 114 | 115 | 116 | def forward(self,input_ids, labels, attention_mask, token_type_ids): 117 | outputs = self.bert(input_ids,attention_mask) 118 | embedding = outputs[0] 119 | logits = self.classifier(embedding) 120 | loss = self.criteration(logits, labels) 121 | return loss, logits -------------------------------------------------------------------------------- /processors/process.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import TensorDataset 3 | import csv 4 | import os 5 | import json 6 | 7 | 8 | class InputExample(object): 9 | def __init__(self, sentence=None, label=None): 10 | self.sentence = sentence 11 | self.label = label 12 | 13 | 14 | class SeqInputFeatures(object): 15 | """A single set of features of data for the ABSA task""" 16 | 17 | def __init__(self, input_ids, input_mask, segment_ids, label_ids): 18 | self.input_ids = 
input_ids 19 | self.input_mask = input_mask 20 | self.segment_ids = segment_ids 21 | self.label_ids = label_ids 22 | 23 | 24 | class SteganalysisProcessor(object): 25 | def __init__(self, tokenizer): 26 | self.tokenizer = tokenizer 27 | self.max_seq_len = 128 28 | self.label_list = [0, 1] 29 | self.num_labels = 2 30 | self.label2id = {} 31 | self.id2label = {} 32 | for idx, label in enumerate(self.label_list): 33 | self.label2id[label] = idx 34 | self.id2label[idx] = label 35 | 36 | 37 | def get_examples(self, file_name): 38 | return self._create_examples( 39 | file_name=file_name 40 | ) 41 | 42 | 43 | def get_train_examples(self, dir): 44 | return self.get_examples(os.path.join(dir, "train.csv")) 45 | 46 | 47 | def get_dev_examples(self, dir): 48 | return self.get_examples(os.path.join(dir, "val.csv")) 49 | 50 | 51 | def get_test_examples(self, dir): 52 | return self.get_examples(os.path.join(dir, "test.csv")) 53 | 54 | 55 | def _create_examples(self, file_name): 56 | examples = [] 57 | file = file_name 58 | lines = csv.reader(open(file, 'r', encoding='utf-8')) 59 | for i, line in enumerate(lines): 60 | if i > 0: 61 | sentence = line[0].lower().strip() 62 | label_t = line[1].strip() 63 | if label_t == "0": 64 | label = 0 65 | if label_t == "1": 66 | label = 1 67 | examples.append(InputExample(sentence=sentence, label=label)) 68 | 69 | # dataset = self.convert_examples_to_features(examples) 70 | # return dataset, examples 71 | return examples 72 | 73 | 74 | def convert_examples_to_features(self, examples): 75 | features = [] 76 | for example in examples: 77 | inputs = self.tokenizer.encode_plus( 78 | example.sentence, 79 | add_special_tokens=True, 80 | max_length=self.max_seq_len, 81 | padding='max_length', 82 | return_attention_mask=True, 83 | return_token_type_ids=True, 84 | truncation=True 85 | ) 86 | input_ids = inputs["input_ids"] 87 | attention_mask = inputs["attention_mask"] 88 | token_type_ids = inputs["token_type_ids"] 89 | if example.label is not None: 90 | label_id = self.label2id[example.label] 91 | else: 92 | label_id = -1 93 | 94 | features.append( 95 | SeqInputFeatures(input_ids=input_ids, 96 | input_mask=attention_mask, 97 | segment_ids=token_type_ids, 98 | label_ids=label_id,)) 99 | 100 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 101 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 102 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 103 | all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long) 104 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids) 105 | return dataset 106 | # return features 107 | 108 | def write_preds(self, preds, ex_ids, out_dir): 109 | """Write predictions in SuperGLUE format.""" 110 | preds = preds[ex_ids] # sort just in case we got scrambled 111 | idx2label = {i: label for i, label in enumerate(self.get_labels())} 112 | with open(os.path.join(out_dir, "steganalysis.jsonl"), "w") as pred_fh: 113 | for idx, pred in enumerate(preds): 114 | pred_label = idx2label[int(pred)] 115 | pred_fh.write(f"{json.dumps({'idx': idx, 'label': pred_label})}\n") 116 | logger.info(f"Wrote {len(preds)} predictions to {out_dir}.") 117 | 118 | def get_labels(self): 119 | return self.label_list -------------------------------------------------------------------------------- /models/r_bilstm_c.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers import BertPreTrainedModel, BertModel 5 | 6 | class TC_base(nn.Module): 7 | def __init__(self,in_features, class_num, num_layers, hidden_size,Ci, kernel_num, kernel_sizes,LSTM_dropout,CNN_dropout): 8 | super(TC_base, self).__init__() 9 | self.in_features =in_features 10 | self.class_num = class_num 11 | D = in_features 12 | C = class_num 13 | N = num_layers 14 | H = hidden_size 15 | Ci = Ci 16 | Co = kernel_num 17 | Ks = kernel_sizes 18 | self.lstm = nn.LSTM(D, H, num_layers=N, \ 19 | bidirectional=True, 20 | batch_first=True, 21 | dropout=LSTM_dropout) 22 | 23 | self.conv1_D = nn.Conv2d(Ci, Co, (1, 2 * H)) 24 | 25 | self.convK_1 = nn.ModuleList( 26 | [nn.Conv2d(Co, Co, (K, 1)) for K in Ks]) 27 | 28 | self.conv3 = nn.Conv2d(Co, Co, (3, 1)) 29 | 30 | self.conv4 = nn.Conv2d(Co, Co, (3, 1), padding=(1, 0)) 31 | 32 | self.CNN_dropout = nn.Dropout(CNN_dropout) 33 | self.fc1 = nn.Linear(len(Ks) * Co, C) 34 | 35 | def forward(self, features): 36 | out, _ = self.lstm(features) # [batch_size, sen_len, H*2] 37 | x = out.unsqueeze(1) 38 | x = self.conv1_D(x) 39 | 40 | x = [F.relu(conv(x)) for conv in self.convK_1] 41 | x3 = [F.relu(self.conv3(i)) for i in x] 42 | x4 = [F.relu(self.conv4(i)) for i in x3] 43 | inception = [] 44 | for i in range(len(x4)): 45 | res = torch.add(x3[i], x4[i]) 46 | inception.append(res) 47 | 48 | x = [i.squeeze(3) for i in inception] 49 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] 50 | x = torch.cat(x, 1) 51 | 52 | x = self.CNN_dropout(x) 53 | logits = self.fc1(x) 54 | return logits 55 | 56 | def extra_repr(self) -> str: 57 | return 'features {}->{},'.format( 58 | self.in_features, self.class_num 59 | ) 60 | 61 | 62 | 63 | class TC(nn.Module): 64 | def __init__(self, embed_dim,num_layers,hidden_dim,class_num, 65 | kernel_num,kernel_sizes,vocab_size,LSTM_dropout,CNN_dropout,Ci=1, 66 | field=None,criteration="CrossEntropyLoss"): 67 | super(TC, self).__init__() 68 | 69 | V = vocab_size 70 | D = embed_dim 71 | C = class_num 72 | N = num_layers 73 | H = hidden_dim 74 | Ci = Ci 75 | Co = kernel_num 76 | Ks = kernel_sizes 77 | 78 | self.embed_A = nn.Embedding(V, D) 79 | self.embed_B = nn.Embedding(V, D) 80 | # self.embed_B.weight.data.copy_(field.vocab.vectors) 81 | self.classifier = TC_base(D,C,N,H,Ci,Co,Ks,LSTM_dropout,CNN_dropout) 82 | if criteration == "CrossEntropyLoss": 83 | self.criteration = nn.CrossEntropyLoss() 84 | else: 85 | # default loss 86 | self.criteration = nn.CrossEntropyLoss() 87 | 88 | 89 | def forward(self, input_ids, labels, attention_mask=None,token_type_ids=None): 90 | x= input_ids 91 | x_A = self.embed_A(x) # x [batch_size, sen_len, D] 92 | x_B = self.embed_B(x) 93 | x = torch.add(x_A, x_B) 94 | logits = self.classifier(x) 95 | loss = self.criteration(logits,labels) 96 | return loss,logits 97 | 98 | 99 | 100 | class BERT_TC(BertPreTrainedModel): 101 | def __init__(self, config, **kwargs): 102 | super().__init__(config) 103 | self.bert_config = config 104 | self.embed_dim = config.hidden_size 105 | self.class_num = kwargs["class_num"] 106 | self.num_layers = kwargs["num_layers"] 107 | self.hidden_size = kwargs["hidden_dim"] 108 | self.Ci = kwargs["Ci"] 109 | self.kernel_num = kwargs["kernel_num"] 110 | self.kernel_sizes = kwargs["kernel_sizes"] 111 | self.LSTM_dropout = kwargs["LSTM_dropout"] 112 | self.CNN_dropout = kwargs["CNN_dropout"] 113 | 114 | self.bert = BertModel(self.bert_config) 115 | self.classifier = 
TC_base(self.embed_dim,self.class_num,self.num_layers,self.hidden_size,self.Ci,self.kernel_num,self.kernel_sizes, 116 | self.LSTM_dropout,self.CNN_dropout) 117 | 118 | if kwargs["criteration"] == "CrossEntropyLoss": 119 | self.criteration = nn.CrossEntropyLoss() 120 | else: 121 | # default loss 122 | self.criteration = nn.CrossEntropyLoss() 123 | 124 | def extra_repr(self) -> str: 125 | return 'bert word embedding dim:{}'.format( 126 | self.embed_size 127 | ) 128 | 129 | def forward(self, input_ids, labels, attention_mask, token_type_ids): 130 | outputs = self.bert(input_ids, attention_mask, token_type_ids) 131 | embedding = outputs[0] 132 | logits = self.classifier(embedding) 133 | loss = self.criteration(logits, labels) 134 | return loss, logits -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # 文本隐写分析 (linguistic steganalysis) 2 | 3 | 4 | # NOTICE: 之前版本可能存在一些严重的bugs,请自己检查使用的版本。 Serious Bugs in older version (Please Check your versions): 5 | - #### 最好的模型可能会被较差的模型覆盖(已修复) 6 | - #### 训练集上的metircs缺失(已修复) 7 | - #### 以不正确的比率拆分train和val(已修复) 8 | - #### 添加了一些重要的注意事项, 见后文(使用此repo之前请仔细阅读) 9 | - #### The best model may be covered by worse model (Fixed) 10 | - #### metircs on train set are missed (Fixed) 11 | - #### split train and val in incorrect ratio (Fixed) 12 | - #### Add Important Attention below. (Carefully read it before using this repo) 13 | 14 | ### 感谢发现指出问题! 15 | ### Thanks for finding and reporting bugs! 16 | 17 | 18 | ## 测试环境 (environment for development) 19 | python 3.7 20 | 21 | transformers 4.12 22 | 23 | ## 包含(include) 24 | ### Non-Transformer-based 25 | - [TS-CSW: text steganalysis and hidden capacity estimation based on convolutional sliding windows (TS-CSW)](https://link.springer.com/article/10.1007/s11042-020-08716-w) 26 | - [TS-RNN: Text Steganalysis Based on Recurrent Neural Networks (TS-BiRNN)](https://ieeexplore.ieee.org/abstract/document/8727932) 27 | - [Linguistic Steganalysis via Densely Connected LSTM with Feature Pyramid (BiLISTM-Dense)](https://dl.acm.org/doi/abs/10.1145/3369412.3395067) 28 | - [A Fast and Efficient Text Steganalysis Method (TS-FCN)](https://ieeexplore.ieee.org/document/8653856) 29 | - [A Hybrid R-BILSTM-C Neural Network Based Text Steganalysis(R-BiLSTM-C)](https://ieeexplore.ieee.org/abstract/document/8903243) 30 | - [Linguistic Steganalysis With Graph Neural Networks (GNN) (waiting......) 
](https://ieeexplore.ieee.org/document/9364681) 31 | 32 | #### multi-stages 33 | - [Real-Time Text Steganalysis Based on Multi-Stage Transfer Learning (MS-TL)](https://ieeexplore.ieee.org/abstract/document/9484749/) 34 | ------ 35 | ### Transformer-based 36 | - [SeSy: Linguistic Steganalysis Framework Integrating Semantic and Syntactic Features (Sesy)](https://ieeexplore.ieee.org/abstract/document/9591452) 37 | - [High-Performance Linguistic Steganalysis, Capacity Estimation and Steganographic Positioning (BERT-LSTM-ATT)](https://link.springer.com/chapter/10.1007%2F978-3-030-69449-4_7) 38 | 39 | ## 使用样例(How to use) 40 | - 使用命令行:`python main.py --config_path your_config_file_path` 41 | - 例如:`python main.py --config_path ./configs/test.json` 42 | - 作为参考,./configs/test.json 里包含多种方法的超参数设置(存在冗余,根据自己需求删改),其中"use_plm"用来控制是否需要预训练语言模型,"model"表示使用哪种方法,根据需求修改"Training_with_Processor"或者"Training"参数,推荐前者即使用Processor进行预处理。默认情况下,我们只设置了**num_train_epochs=1, 在很多数据集下,是不能很好收敛的**。 43 | - 例如,如果想使用[TS-FCN](https://link.springer.com/article/10.1007/s11042-020-08716-w)方法,一个最简单的方法就是将test.json里的"model"改为"fcn","use_plm"改为false, 然后设置合适的学习率、epoch、batchsize 等等(在Training_with_Processor中) 44 | - Use Command: `python main.py --config_path your_config_file_path` 45 | - For example: `python main.py --config_path ./configs/test.json` 46 | - As a reference, ./configs/test.json contains the hyper-parameters settings of various methods (redundant, delete and modify up to your own setting). Among them, "use_plm" is used to control whether the pretrained language model (PLM) is used to embed words. "model" indicates which method to be use. Modify the "training_with_processor" or "training" parameters according to your needs. It is recommended that the former, namely using processor for pre-processing your data. By default, we only set **num_ Train_ Epochs = 1, where models can not converge well in many datasets **. 47 | - For example, if you want to try the [TS-FCN](https://link.springer.com/article/10.1007/s11042-020-08716-w) method, a very easy way is to set "model" as "fcn", set "use_plm" as false, and then set the appropriate learning rate, epoch, batchsize, etc. (in Training_with_Processor), based on "test.json". 48 | 49 | ## 注意事项(Attention!) 50 | - 之前的sesy config不是[论文](https://ieeexplore.ieee.org/abstract/document/9591452)中使用的config. 已修复(2023.03)。 51 | - 如果使用此代码进行不同方法的对比时,请仔细确认数据集的拆分方式是否相同。 52 | - 可以在config.json 的Dataset中设置自己的“split_ratio”,例如设置为0.8, 则会产生0.2的test.csv”, 0.2的“val.csv”和0.6的“train.csv” 53 | - 或者也可以将“stego_txt”和“cover_txt”设置为None,设置“resplit"为False,并将数据集放在“csv_dir”所指示的文件路径中。放置的数据集应根据对比实验的需要,事先自行拆分为“train.csv”、“val.csv”和“test.csv”。 54 | - Previous sesy config is not the configs used in paper [SESY](https://ieeexplore.ieee.org/abstract/document/9591452). We have fixed it. 55 | - When using this repo for camparation, please carefully confirm whether the spliting of data set is the same. 56 | - You can reset the "split_ratio" in Dataset config. For example, if set it to 0.8, you will get 0.2 "test.csv", 0.2 "val.csv", and 0.6 "train. csv". 57 | - Or, you can set stego_txt and cover_txt as None, set "resplit" as False, and put your dataset in path "csv_dir". The dataset should beforehand be split into "train.csv", "val.csv" and "test.csv" by yourself. 
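As a concrete sketch of the second option above, the snippet below builds `train.csv` / `val.csv` / `test.csv` yourself. The header names are arbitrary, because `processors/process.py` skips the first row and reads column 0 as the sentence and column 1 as the label; stego sentences are assumed to be labelled 1, matching the `stego_label=1` default in `utils.compute_metrics`, and the 0.6/0.2/0.2 proportions follow the description above (the exact shuffling done by `dataset.py` may differ).

```python
import csv
import random

def read_lines(path, label):
    with open(path, encoding="utf-8") as f:
        return [(line.strip(), label) for line in f if line.strip()]

def write_csv(path, rows):
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["sentence", "label"])  # header row; skipped by the processor
        writer.writerows(rows)

def prepare(cover_file, stego_file, csv_dir="data", split_ratio=0.8, seed=42):
    rows = read_lines(cover_file, 0) + read_lines(stego_file, 1)  # cover=0, stego=1
    random.Random(seed).shuffle(rows)
    n_test = int(len(rows) * (1 - split_ratio))              # e.g. 0.2 of the data for test
    write_csv(f"{csv_dir}/test.csv", rows[:n_test])
    write_csv(f"{csv_dir}/val.csv", rows[n_test:2 * n_test])  # another 0.2 for val
    write_csv(f"{csv_dir}/train.csv", rows[2 * n_test:])      # remaining 0.6 for train
```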
58 | 59 | 60 | ## Codes Reference 61 | 基础 [R-BiLSTM-C](https://ieeexplore.ieee.org/abstract/document/8903243) [BiLISTM-Dense](https://dl.acm.org/doi/abs/10.1145/3369412.3395067) [MS-TL](https://ieeexplore.ieee.org/abstract/document/9484749/) 模型的实现参考自[CAU-Tstega/Text-steganalysis](https://github.com/CAU-Tstega/Text-steganalysis) 62 | 63 | implements of [R-BiLSTM-C](https://ieeexplore.ieee.org/abstract/document/8903243) [BiLISTM-Dense](https://dl.acm.org/doi/abs/10.1145/3369412.3395067) [MS-TL](https://ieeexplore.ieee.org/abstract/document/9484749/) refer to [CAU-Tstega/Text-steganalysis](https://github.com/CAU-Tstega/Text-steganalysis) 64 | 65 | # 欢迎补充与讨论 (both supplement and discussion are grateful and helpful) 66 | -------------------------------------------------------------------------------- /models/ms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import BertPreTrainedModel,BertModel 4 | import os 5 | 6 | class MyBert(BertPreTrainedModel): 7 | def __init__(self, config, **kwargs): 8 | super().__init__(config) 9 | self.bert_config = config 10 | self.bert = BertModel(self.bert_config) 11 | self.class_num = kwargs["class_num"] 12 | self.classifier = nn.Linear(self.bert_config.hidden_size, self.class_num) 13 | if kwargs["criteration"] == "CrossEntropyLoss": 14 | self.criteration = nn.CrossEntropyLoss() 15 | else: 16 | # default loss 17 | self.criteration = nn.CrossEntropyLoss() 18 | 19 | def forward(self, input_ids, labels, attention_mask, token_type_ids): 20 | outputs = self.bert(input_ids, attention_mask, token_type_ids) 21 | logits = self.classifier(outputs[1]) 22 | loss = self.criteration(logits, labels) 23 | return loss, logits 24 | 25 | 26 | class CNN(nn.Module): 27 | def __init__(self, vocab_size, embed_size, filter_num, class_num, dropout_rate, criteration="CrossEntropyLoss",): 28 | super(CNN, self).__init__() 29 | self.vocab_size = vocab_size 30 | self.embed_size = embed_size 31 | self.class_num = class_num 32 | self.embedding = nn.Embedding(self.vocab_size, self.embed_size) 33 | self.conv = nn.Conv2d(1, filter_num, (3,self.embed_size)) 34 | self.dropout = nn.Dropout(dropout_rate) 35 | self.classifier = nn.Linear(filter_num, self.class_num) 36 | if criteration == "CrossEntropyLoss": 37 | self.criteration = nn.CrossEntropyLoss() 38 | else: 39 | # default loss 40 | self.criteration = nn.CrossEntropyLoss() 41 | 42 | def forward(self, input_ids, labels, attention_mask=None,token_type_ids=None): 43 | clf_input = self.embedding(input_ids).unsqueeze(3).permute(0,3,1,2) 44 | clf_input = self.conv(clf_input) 45 | clf_input = nn.functional.relu(clf_input).squeeze(3) 46 | clf_input = nn.functional.max_pool1d(clf_input, clf_input.size(2)).squeeze(2) 47 | clf_input = self.dropout(clf_input) 48 | logits = self.classifier(clf_input) 49 | loss = self.criteration(logits, labels) 50 | return loss, logits 51 | 52 | 53 | class BiRNN(nn.Module): 54 | def __init__(self, vocab_size, embed_size, hidden_size,num_layers, bidirectional, class_num, criteration="CrossEntropyLoss", ): 55 | super(BiRNN, self).__init__() 56 | self.vocab_size = vocab_size 57 | self.embed_size = embed_size 58 | self.hidden_size = hidden_size 59 | self.class_num = class_num 60 | self.embedding = nn.Embedding(self.vocab_size, self.embed_size) 61 | self.lstm = nn.LSTM(self.embed_size, self.hidden_size, num_layers=num_layers,bidirectional=bidirectional,batch_first=True ) 62 | self.classifier = nn.Linear(2*self.hidden_size, 
self.class_num) 63 | if criteration == "CrossEntropyLoss": 64 | self.criteration = nn.CrossEntropyLoss() 65 | else: 66 | # default loss 67 | self.criteration = nn.CrossEntropyLoss() 68 | 69 | def forward(self, input_ids, labels, attention_mask=None, token_type_ids=None): 70 | clf_input = self.embedding(input_ids) 71 | clf_input, _ = self.lstm(clf_input) 72 | logits = self.classifier(clf_input[:,-1,:]) 73 | loss = self.criteration(logits, labels) 74 | return loss, logits 75 | 76 | 77 | class BiRNN_C(nn.Module): 78 | def __init__(self, class_num, dropout_rate, criteration="CrossEntropyLoss", **kwargs): 79 | super(BiRNN_C, self).__init__() 80 | cnn = torch.load(os.path.join(kwargs["cnn_checkpoint"], "pytorch_model.bin")) 81 | rnn = torch.load(os.path.join(kwargs["birnn_checkpoint"], "pytorch_model.bin")) 82 | # self.vocab_size = vocab_size 83 | # self.embed_size = embed_size 84 | # self.hidden_size = hidden_size 85 | self.class_num = class_num 86 | self.embed_C = cnn.embedding 87 | self.embed_R = rnn.embedding 88 | self.lstm = rnn.lstm 89 | self.conv = cnn.conv 90 | # self.embed_C = nn.Embedding(self.vocab_size, self.embed_size) 91 | # self.embed_R = nn.Embedding(self.vocab_size, self.embed_size) 92 | # self.lstm = nn.LSTM(self.embed_size, self.hidden_size, num_layers=num_layers,bidirectional=bidirectional,batch_first=True) 93 | # self.conv = nn.Conv2d(1, filter_num, (3, self.embed_size)) 94 | self.dropout = nn.Dropout(dropout_rate) 95 | self.classifier = nn.Linear(cnn.classifier.in_features+rnn.classifier.in_features, self.class_num) 96 | if criteration == "CrossEntropyLoss": 97 | self.criteration = nn.CrossEntropyLoss() 98 | else: 99 | # default loss 100 | self.criteration = nn.CrossEntropyLoss() 101 | 102 | def forward(self, input_ids, labels, attention_mask=None, token_type_ids=None): 103 | rnn_input = self.embed_R(input_ids) 104 | rnn_input, _ = self.lstm(rnn_input) 105 | rnn_input = rnn_input[:,-1,:] 106 | 107 | cnn_input = self.embed_C(input_ids).unsqueeze(3).permute(0,3,1,2) 108 | cnn_input = self.conv(cnn_input) 109 | cnn_input = nn.functional.relu(cnn_input).squeeze(3) 110 | cnn_input = nn.functional.max_pool1d(cnn_input, cnn_input.size(2)).squeeze(2) 111 | 112 | clf_input = torch.cat([cnn_input, rnn_input],dim=1) 113 | clf_input = self.dropout(clf_input) 114 | logits = self.classifier(clf_input) 115 | loss = self.criteration(logits, labels) 116 | return loss, logits -------------------------------------------------------------------------------- /models/sesy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers import BertPreTrainedModel, BertModel 5 | 6 | 7 | class GATLayer(nn.Module): 8 | def __init__(self, in_features, out_features, dropout, alpha, concat=True, get_att=False): 9 | super(GATLayer, self).__init__() 10 | self.in_features = in_features 11 | self.out_features = out_features 12 | self.alpha = alpha 13 | self.concat = concat 14 | self.get_att = get_att 15 | self.dropout = dropout 16 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)).cuda()) 17 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 18 | self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)).cuda()) 19 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 20 | self.leakyrelu = nn.LeakyReLU(self.alpha) 21 | 22 | def forward(self, input, adj): 23 | h = torch.matmul(input, self.W) 24 | a_input = self._prepare_attentional_mechanism_input(h) 25 | e = 
self.leakyrelu(torch.matmul(a_input, self.a).squeeze(-1)) 26 | 27 | zero_vec = -9e15 * torch.ones_like(e) 28 | attention = torch.where(adj > 0, e, zero_vec) 29 | attention = F.softmax(attention, dim=-1) 30 | attention = F.dropout(attention, self.dropout, training=self.training) 31 | h_prime = torch.matmul(attention, h) 32 | if self.concat: 33 | return F.elu(h_prime) 34 | else: 35 | return h_prime 36 | 37 | def _prepare_attentional_mechanism_input(self, Wh): 38 | B, M, E = Wh.shape # (batch_zize, number_nodes, out_features) 39 | Wh_repeated_in_chunks = Wh.repeat_interleave(M, dim=1) # (B, M*M, E) 40 | Wh_repeated_alternating = Wh.repeat(1, M, 1) # (B, M*M, E) 41 | all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=-1) # (B, M*M,2E) 42 | return all_combinations_matrix.view(B, M, M, 2 * E) 43 | 44 | def __repr__(self): 45 | return self.__class__.__name__ + '(' + str(self.in_features) + '->' + str(self.out_features) + ')' 46 | 47 | 48 | class GAT(nn.Module): 49 | def __init__(self, n_feat, n_hid, out_features, dropout, alpha, n_heads): 50 | super(GAT, self).__init__() 51 | self.hidden = n_hid 52 | self.max_length = 128 53 | self.dropout = 0.1 54 | self.attentions = [GATLayer(n_feat, n_hid, dropout=self.dropout, alpha=alpha, concat=True, get_att=False) for _ 55 | in range(n_heads)] 56 | # self.attentions_adj = [GATLayer(n_feat, self.max_length, alpha=alpha, concat=True,get_att=True) for _ in range(n_heads)] 57 | for i, attention in enumerate(self.attentions): 58 | self.add_module('attention_{}'.format(i), attention) 59 | 60 | self.out_att = GATLayer(n_hid * n_heads, out_features, dropout=self.dropout, alpha=alpha, concat=False) 61 | 62 | def forward(self, x_input, adj): 63 | x = F.dropout(x_input, self.dropout, training=self.training) 64 | x = torch.cat([att(x, adj) for att in self.attentions], dim=-1) 65 | x = F.dropout(x, self.dropout, training=self.training) 66 | x = F.elu(self.out_att(x, adj)) 67 | return x 68 | 69 | 70 | class TC(nn.Module): 71 | def __init__(self,vocab_size,clf, TC_configs, embed_size,class_num, hidden_dim, readout_size, gat_alpha, gat_heads,dropout_rate,strategy="cas", criteration="CrossEntropyLoss",): 72 | super(TC, self).__init__() 73 | 74 | self.vocab_size = vocab_size 75 | self.embed_size = embed_size 76 | self.clf_name = clf 77 | self.dropout_rate = dropout_rate 78 | self.class_num = class_num 79 | self.hidden_dim = hidden_dim 80 | self.readout_size = readout_size 81 | self.gat_alpha = gat_alpha 82 | self.gat_heads = gat_heads 83 | self.strategy = strategy 84 | 85 | if self.clf_name == "cnn": 86 | from .cnn import TC_base 87 | self.clf_configs = TC_configs.cnn 88 | elif self.clf_name == "fc": 89 | from .fcn import TC_base 90 | self.clf_configs = TC_configs.fc 91 | elif self.clf_name == "rnn": 92 | from .birnn import TC_base 93 | self.clf_configs = TC_configs.rnn 94 | else: 95 | assert 0, "No such clf, only support cnn rnn & fc" 96 | self.embedding = nn.Embedding(self.vocab_size, self.embed_size) 97 | self.gat = GAT( 98 | n_feat=self.embed_size, 99 | n_hid=self.hidden_dim, 100 | out_features=self.readout_size, 101 | alpha=self.gat_alpha, 102 | n_heads=self.gat_heads, 103 | dropout=self.dropout_rate 104 | ) 105 | if self.strategy.lower() == "cas": 106 | self.clf_configs.in_features = self.readout_size 107 | elif self.strategy.lower() == "parl": 108 | self.clf_configs.in_features = self.readout_size + self.embed_size 109 | self.classifier = TC_base(**{**self.clf_configs,"class_num":self.class_num,"dropout_rate":self.dropout_rate}) 
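# The classifier's in_features follows the fusion strategy chosen above: "cas" feeds it
# the GAT read-out alone (readout_size features), while "parl" concatenates the read-out
# with the word embeddings in forward(), so in_features = readout_size + embed_size
# (64 + 768 = 832 with the settings in configs/sesy.json).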
110 | if criteration == "CrossEntropyLoss": 111 | self.criteration = nn.CrossEntropyLoss() 112 | else: 113 | # default loss 114 | self.criteration = nn.CrossEntropyLoss() 115 | 116 | # covert my dict to standard dict 117 | self.clf_configs = {**self.clf_configs} 118 | 119 | def forward(self, input_ids, labels, attention_mask=None, token_type_ids=None, graph=None): 120 | embedding = self.embedding(input_ids) 121 | gat_out = self.gat(embedding, graph) 122 | 123 | if self.strategy.lower() == "cas": 124 | logits = self.classifier(gat_out) 125 | elif self.strategy.lower() == "parl": 126 | logits = self.classifier(torch.cat([gat_out,embedding],dim=2)) 127 | loss = self.criteration(logits, labels) 128 | return loss, logits 129 | 130 | 131 | class BERT_TC(BertPreTrainedModel): 132 | def __init__(self, config, **kwargs): 133 | super().__init__(config) 134 | TC_configs = kwargs["TC_configs"] 135 | self.embed_size = config.hidden_size 136 | self.clf_name = kwargs["clf"] 137 | self.dropout_rate = kwargs["dropout_rate"] 138 | self.class_num = kwargs["class_num"] 139 | self.hidden_dim = kwargs["hidden_dim"] 140 | self.readout_size = kwargs["readout_size"] 141 | self.gat_alpha = kwargs["gat_alpha"] 142 | self.gat_heads = kwargs["gat_heads"] 143 | self.strategy = kwargs["strategy"] 144 | self.bert_config = config 145 | 146 | if self.clf_name == "cnn": 147 | from .cnn import TC_base 148 | self.clf_configs = TC_configs.cnn 149 | elif self.clf_name == "fc": 150 | from .fcn import TC_base 151 | self.clf_configs = TC_configs.fc 152 | elif self.clf_name == "rnn": 153 | from .birnn import TC_base 154 | self.clf_configs = TC_configs.rnn 155 | else: 156 | assert 0, "No such clf, only support cnn rnn & fc" 157 | 158 | self.bert = BertModel(config) 159 | self.gat = GAT( 160 | n_feat=self.embed_size, 161 | n_hid=self.hidden_dim, 162 | out_features=self.readout_size, 163 | alpha=self.gat_alpha, 164 | n_heads=self.gat_heads, 165 | dropout=self.dropout_rate 166 | ) 167 | if self.strategy.lower() == "cas": 168 | self.clf_configs.in_features = self.readout_size 169 | elif self.strategy.lower() == "parl": 170 | self.clf_configs.in_features = self.readout_size + self.embed_size 171 | self.classifier = TC_base( 172 | **{**self.clf_configs, "class_num": self.class_num, "dropout_rate": self.dropout_rate}) 173 | if kwargs["criteration"] == "CrossEntropyLoss": 174 | self.criteration = nn.CrossEntropyLoss() 175 | else: 176 | # default loss 177 | self.criteration = nn.CrossEntropyLoss() 178 | 179 | self.init_weights() 180 | 181 | def forward(self, input_ids, attention_mask=None, token_type_ids=None, 182 | graph=None, labels=None,): 183 | outputs = self.bert(input_ids=input_ids, 184 | attention_mask=attention_mask, 185 | token_type_ids=token_type_ids) 186 | embedding = outputs[0] # [batch_size, node,hidden_size] 187 | gat_out = self.gat(embedding, graph) 188 | if self.strategy.lower() == "cas": 189 | logits = self.classifier(gat_out) 190 | elif self.strategy.lower() == "parl": 191 | logits = self.classifier(torch.cat([gat_out,embedding],dim=2)) 192 | loss = self.criteration(logits, labels) 193 | return loss, logits 194 | -------------------------------------------------------------------------------- /models/lstmatt.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from transformers import BertPreTrainedModel,BertModel 6 | 7 | class Attention(nn.Module): 8 | def __init__(self, embed_dim, 
hidden_dim=None, out_dim=None, n_head=1, score_function='dot_product', dropout=0): 9 | super(Attention, self).__init__() 10 | if hidden_dim is None: 11 | hidden_dim = embed_dim // n_head 12 | if out_dim is None: 13 | out_dim = embed_dim 14 | self.embed_dim = embed_dim 15 | self.hidden_dim = hidden_dim 16 | self.n_head = n_head 17 | self.score_function = score_function 18 | self.w_k = nn.Linear(embed_dim, n_head * hidden_dim) 19 | self.w_q = nn.Linear(embed_dim, n_head * hidden_dim) 20 | self.proj = nn.Linear(n_head * hidden_dim, out_dim) 21 | self.dropout = nn.Dropout(dropout) 22 | 23 | if self.score_function == 'mlp': 24 | self.weight = nn.Parameter(torch.Tensor(hidden_dim * 2)) 25 | elif self.score_function == 'bi_linear': 26 | self.weight = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim)) 27 | else: 28 | self.register_parameter('weight', None) 29 | self.reset_parameters() 30 | 31 | def reset_parameters(self): 32 | stdv = 1. / math.sqrt(self.hidden_dim) 33 | if self.weight is not None: 34 | self.weight.data.uniform_(-stdv, stdv) 35 | 36 | def forward(self, k, q): 37 | if len(q.shape) == 2: 38 | q = torch.unsqueeze(q, dim=1) 39 | if len(k.shape) == 2: 40 | k = torch.unsqueeze(k, dim=1) 41 | mb_size = k.shape[0] 42 | k_len = k.shape[1] 43 | q_len = q.shape[1] 44 | # k: (?, k_len, embed_dim,) 45 | # q: (?, q_len, embed_dim,) 46 | # kx: (n_head*?, k_len, hidden_dim) 47 | # qx: (n_head*?, q_len, hidden_dim) 48 | # score: (n_head*?, q_len, k_len,) 49 | # output: (?, q_len, out_dim,) 50 | kx = self.w_k(k).view(mb_size, k_len, self.n_head, self.hidden_dim) 51 | kx = kx.permute(2, 0, 1, 3).contiguous().view(-1, k_len, self.hidden_dim) 52 | qx = self.w_q(q).view(mb_size, q_len, self.n_head, self.hidden_dim) 53 | qx = qx.permute(2, 0, 1, 3).contiguous().view(-1, q_len, self.hidden_dim) 54 | if self.score_function == 'dot_product': 55 | kt = kx.permute(0, 2, 1) 56 | score = torch.bmm(qx, kt) 57 | elif self.score_function == 'scaled_dot_product': 58 | kt = kx.permute(0, 2, 1) 59 | qkt = torch.bmm(qx, kt) 60 | score = torch.div(qkt, math.sqrt(self.hidden_dim)) 61 | elif self.score_function == 'mlp': 62 | kxx = torch.unsqueeze(kx, dim=1).expand(-1, q_len, -1, -1) 63 | qxx = torch.unsqueeze(qx, dim=2).expand(-1, -1, k_len, -1) 64 | kq = torch.cat((kxx, qxx), dim=-1) # (n_head*?, q_len, k_len, hidden_dim*2) 65 | # kq = torch.unsqueeze(kx, dim=1) + torch.unsqueeze(qx, dim=2) 66 | score = F.tanh(torch.matmul(kq, self.weight)) 67 | elif self.score_function == 'bi_linear': 68 | qw = torch.matmul(qx, self.weight) 69 | kt = kx.permute(0, 2, 1) 70 | score = torch.bmm(qw, kt) 71 | else: 72 | raise RuntimeError('invalid score_function') 73 | score = F.softmax(score, dim=-1) 74 | output = torch.bmm(score, kx) # (n_head*?, q_len, hidden_dim) 75 | output = torch.cat(torch.split(output, mb_size, dim=0), dim=-1) # (?, q_len, n_head*hidden_dim) 76 | output = self.proj(output) # (?, q_len, out_dim) 77 | output = self.dropout(output) 78 | return output, score 79 | 80 | 81 | class lstm(nn.Module): 82 | def __init__(self, input_size, hidden_size, bidirectional=True): 83 | super(lstm, self).__init__() 84 | self.input_size = input_size 85 | if bidirectional: 86 | self.hidden_size = hidden_size // 2 87 | else: 88 | self.hidden_size = hidden_size 89 | self.bidirectional = bidirectional 90 | 91 | self.LNx = nn.LayerNorm(4 * self.hidden_size) 92 | self.LNh = nn.LayerNorm(4 * self.hidden_size) 93 | self.LNc = nn.LayerNorm(self.hidden_size) 94 | self.Wx = nn.Linear(in_features=self.input_size, out_features=4 * self.hidden_size, 
bias=True) 95 | self.Wh = nn.Linear(in_features=self.hidden_size, out_features=4 * self.hidden_size, bias=True) 96 | 97 | def forward(self, x): 98 | def recurrence(xt, hidden): # enhanced with layer norm 99 | # input: input to the current cell 100 | htm1, ctm1 = hidden 101 | gates = self.LNx(self.Wx(xt)) + self.LNh(self.Wh(htm1)) 102 | it, ft, gt, ot = gates.chunk(4, 1) 103 | it = torch.sigmoid(it) 104 | ft = torch.sigmoid(ft) 105 | gt = torch.tanh(gt) 106 | ot = torch.sigmoid(ot) 107 | ct = (ft * ctm1) + (it * gt) 108 | ht = ot * torch.tanh(self.LNc(ct)) 109 | return ht, ct 110 | 111 | output = [] 112 | steps = range(x.size(1)) 113 | hidden = self.init_hidden(x.size(0)) 114 | inputs = x.transpose(0, 1) 115 | for t in steps: 116 | hidden = recurrence(inputs[t], hidden) 117 | output.append(hidden[0]) 118 | output = torch.stack(output, 0).transpose(0, 1) 119 | if self.bidirectional: 120 | hidden_b = self.init_hidden(x.size(0)) 121 | output_b = [] 122 | for t in steps[::-1]: 123 | hidden_b = recurrence(inputs[t], hidden_b) 124 | output_b.append(hidden_b[0]) 125 | output_b = output_b[::-1] 126 | output_b = torch.stack(output_b, 0).transpose(0, 1) 127 | output = torch.cat([output, output_b], dim=-1) 128 | return output 129 | 130 | def init_hidden(self, bs): 131 | h_0 = torch.zeros(bs, self.hidden_size).cuda() 132 | c_0 = torch.zeros(bs, self.hidden_size).cuda() 133 | return h_0, c_0 134 | 135 | 136 | class TC_base(nn.Module): 137 | def __init__(self,in_features, hidden_dim, class_num, dropout_rate,bidirectional): 138 | super(TC_base, self).__init__() 139 | self.in_features = in_features 140 | self.dropout_prob = dropout_rate 141 | self.num_labels = class_num 142 | self.hidden_size = hidden_dim 143 | self.bidirectional = bidirectional 144 | self.dropout = nn.Dropout(self.dropout_prob) 145 | self.lstm = lstm( 146 | input_size=self.in_features, 147 | hidden_size=self.hidden_size, 148 | bidirectional=True 149 | ) 150 | self.attn = Attention( 151 | embed_dim=self.hidden_size, 152 | hidden_dim=self.hidden_size, 153 | n_head=1, 154 | score_function='mlp', 155 | dropout=self.dropout_prob 156 | ) 157 | self.classifier = nn.Linear(self.hidden_size, self.num_labels) 158 | 159 | def forward(self, features, mask, input_ids_len): 160 | output = self.lstm(features) 161 | output = self.dropout(output) 162 | scc, scc1 = self.attn(output,output) 163 | t = input_ids_len.view( input_ids_len.size(0),1) 164 | scc_sen = torch.sum(scc,dim=2) 165 | scc_mean = torch.div(torch.sum(scc,dim=1),t) 166 | logits = self.classifier(scc_mean) 167 | return logits 168 | 169 | def extra_repr(self) -> str: 170 | return 'features {}->{},'.format( 171 | self.in_features, self.class_num 172 | ) 173 | 174 | 175 | class TC(nn.Module): 176 | def __init__(self, vocab_size, embed_size, hidden_dim, class_num, dropout_rate,bidirectional=True,criteration="CrossEntropyLoss"): 177 | super(TC,self).__init__() 178 | self.embed_size = embed_size 179 | self.dropout_prob = dropout_rate 180 | self.num_labels = class_num 181 | self.bidirectional = bidirectional 182 | self.hidden_size = hidden_dim 183 | self.embed = nn.Embedding(vocab_size, self.embed_size) 184 | self.classifier = TC_base(self.embed_size,self.hidden_size,self.num_labels,self.dropout_prob,self.bidirectional) 185 | if criteration == "CrossEntropyLoss": 186 | self.criteration = nn.CrossEntropyLoss() 187 | else: 188 | # default loss 189 | self.criteration = nn.CrossEntropyLoss() 190 | # self.it_weights() 191 | 192 | 193 | def forward(self, 
input_ids,labels,attention_mask=None,token_type_ids=None): 194 | input_ids_len = torch.sum(input_ids != 0, dim=-1).float() 195 | input_lstm = self.embed(input_ids.long())[0] 196 | mask = torch.ones_like(input_ids.long()) 197 | mask[input_ids.long() != 0 ] = 0 198 | logits = self.classifier(input_lstm,mask,input_ids_len) 199 | loss = self.criteration(logits,labels) 200 | return loss,logits 201 | 202 | 203 | class BERT_TC(BertPreTrainedModel): 204 | def __init__(self, config, **kwargs): 205 | super().__init__(config) 206 | self.bert_config = config 207 | self.bert = BertModel(self.bert_config) 208 | self.embed_size = config.hidden_size 209 | self.hidden_size = kwargs["hidden_dim"] 210 | self.num_labels = kwargs["class_num"] 211 | self.dropout_prob = kwargs["dropout_rate"] 212 | self.bidirectional = kwargs["bidirectional"] 213 | self.classifier = TC_base(self.embed_size, self.hidden_size, self.num_labels, self.dropout_prob, 214 | self.bidirectional) 215 | if kwargs["criteration"] == "CrossEntropyLoss": 216 | self.criteration = nn.CrossEntropyLoss() 217 | else: 218 | # default loss 219 | self.criteration = nn.CrossEntropyLoss() 220 | # self.it_weights() 221 | 222 | def forward(self,input_ids, labels, attention_mask, token_type_ids): 223 | input_ids_len = torch.sum(input_ids != 0, dim=-1).float() 224 | outputs = self.bert(input_ids, attention_mask, token_type_ids) 225 | embedding = outputs[0] 226 | mask = torch.ones_like(input_ids.long()) 227 | mask[input_ids.long() != 0 ] = 0 228 | logits = self.classifier(embedding,mask,input_ids_len) 229 | loss = self.criteration(logits,labels) 230 | return loss,logits 231 | 232 | 233 | def extra_repr(self) -> str: 234 | return 'bert word embedding dim:{}'.format( 235 | self.embed_size 236 | ) -------------------------------------------------------------------------------- /processors/graph_process.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import TensorDataset 4 | import csv 5 | import os 6 | import copy 7 | import spacy 8 | import stanza 9 | 10 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 11 | 12 | # from stanza.server import CoreNLPClient 13 | # corenlp = CoreNLPClient(annotators=["tokenize","ssplit","pos","lemma","depparse"],DEFAULT_ENDPOINT="http://localhost:9001") 14 | 15 | # Corenlp = stanza.Pipeline(lang="en",processors='tokenize,pos,lemma,depparse',tokenize_pretokenized=True) 16 | 17 | # def dependency_adj_matrix_test(text): 18 | # document = NLP(text) 19 | # seq_len = len(text.split()) 20 | # matrix = np.zeros((seq_len, seq_len)).astype('float32') 21 | # 22 | # for token in document: 23 | # if token.i < seq_len: 24 | # matrix[token.i][token.i] = 1 25 | # for child in token.children: 26 | # if child.i < seq_len: 27 | # matrix[token.i][child.i] = 1 28 | # return matrix 29 | 30 | 31 | def dependency_adj_matrix(Corenlp, text, is_bidirectional=True,is_self_loop=True): 32 | # matrix_tmp = dependency_adj_matrix_test(text) 33 | # seq_len = len(text.split()) 34 | # text ='the absence of violence and sex was refreshing' 35 | doc = Corenlp(text) 36 | words = [] 37 | ids = [] 38 | head_ids = [] 39 | relations = [] 40 | for sent in doc.sentences: 41 | for word in sent.words: 42 | ids.append(word.id) 43 | words.append(word.text) 44 | head_ids.append(word.head) 45 | relations.append(word.deprel) 46 | matrix_size = len(words)+1 47 | matrix = np.zeros((matrix_size,matrix_size)) 48 | relation_mat = [["NULL" for _ in range(matrix_size)] for __ in 
range(matrix_size)] 49 | # relation_mat = np.zeros((matrix_size,matrix_size), dtype=str) 50 | for id, head_id, relation_type in zip(ids,head_ids, relations): 51 | matrix[id][head_id] = 1 52 | relation_mat[id][head_id] = relation_type 53 | if is_bidirectional: 54 | matrix[head_id][id] = 1 55 | relation_mat[head_id][id] = relation_type 56 | if is_self_loop: 57 | matrix[id][id] = 1 58 | relation_mat[id][id] = "selfloop" 59 | return matrix, np.array(relation_mat) 60 | 61 | 62 | class InputExample(object): 63 | def __init__(self, sentence=None, label=None): 64 | self.sentence = sentence 65 | self.label = label 66 | # self.dependency = dependency 67 | 68 | 69 | class SeqInputFeatures(object): 70 | def __init__(self, input_ids,input_mask,segment_ids, label_ids,dependency_adj): 71 | self.input_ids = input_ids 72 | self.input_mask = input_mask 73 | self.segment_ids = segment_ids 74 | self.label_ids = label_ids 75 | self.dependency_adj = dependency_adj 76 | 77 | 78 | class GraphSteganalysisProcessor(object): 79 | def __init__(self, tokenizer): 80 | self.tokenizer = tokenizer 81 | self.order = 1 82 | self.max_seq_len = 128 83 | self.label_list = [0, 1] 84 | self.num_labels = 2 85 | self.label2id = {} 86 | self.id2label = {} 87 | self.cls_connect = True 88 | self.sep_connect = False 89 | self.use_stanza = True 90 | if self.use_stanza: 91 | self.tokenizer.add_special_tokens({"additional_special_tokens":['[unused0]']}) 92 | for idx, label in enumerate(self.label_list): 93 | self.label2id[label] = idx 94 | self.id2label[idx] = label 95 | 96 | def get_examples(self, file_name): 97 | return self._create_examples( 98 | file_name=file_name 99 | ) 100 | 101 | def get_train_examples(self, dir): 102 | return self.get_examples(os.path.join(dir, "train.csv")) 103 | 104 | def get_dev_examples(self, dir): 105 | return self.get_examples(os.path.join(dir, "val.csv")) 106 | 107 | def get_test_examples(self, dir): 108 | return self.get_examples(os.path.join(dir, "test.csv")) 109 | 110 | def _create_examples(self, file_name): 111 | examples = [] 112 | file = file_name 113 | lines = csv.reader(open(file, 'r', encoding='utf-8')) 114 | # scores = [] 115 | for i, line in enumerate(lines): 116 | if i > 0: 117 | sentence = line[0].lower().strip() 118 | label_t = line[1].strip() 119 | if label_t == "0": 120 | label = 0 121 | if label_t == "1": 122 | label = 1 123 | examples.append(InputExample(sentence=sentence, label=label)) 124 | 125 | # dataset = self.convert_examples_to_features(examples) 126 | return examples 127 | 128 | 129 | def merge(self, sentence, dependency): 130 | ## merge depedency of spacy format and bert format 131 | input_ids = self.tokenizer.encode_plus(sentence)["input_ids"][1:-1] 132 | padding_matrix_length = self.max_seq_len - len(input_ids)-2 133 | # new_dependency = np.zeros((len(input_ids),len(input_ids))) 134 | word_piece = [self.tokenizer.encode_plus(x)["input_ids"][1:-1] for x in sentence.split()] 135 | idx = 0 136 | dependency_list = dependency.tolist() 137 | for w_p in word_piece: 138 | if len(w_p) != 1: 139 | for row_idx in range(len(dependency_list)): 140 | row = dependency_list[row_idx] 141 | row = row[:idx]+ [row[idx]]*(len(w_p)-1) + row[idx:] 142 | dependency_list[row_idx] = copy.deepcopy(row) 143 | dependency_list = dependency_list[:idx]+ [dependency_list[idx]]*(len(w_p)-1)+dependency_list[idx:] 144 | idx += len(w_p) 145 | 146 | new_dependency = np.array(dependency_list) 147 | if new_dependency.shape[0] != new_dependency.shape[1]: 148 | print("error 1 sentence:%s"%sentence) 149 | if 
len(input_ids) != new_dependency.shape[0]: 150 | print("error 2 sentence: %s"%sentence) 151 | 152 | if padding_matrix_length <= 0: 153 | new_dependency = new_dependency[:self.max_seq_len-2, :self.max_seq_len-2] 154 | new_dependency = np.pad(new_dependency,((1,1),(1,1)), "constant") 155 | new_dependency[0,0]=1 156 | new_dependency[self.max_seq_len-1,self.max_seq_len-1]=1 157 | else: 158 | new_dependency = np.pad(new_dependency, ((1, padding_matrix_length+1), (1, padding_matrix_length+1)), 159 | 'constant') 160 | new_dependency[0,0]=1 161 | new_dependency[len(input_ids)+1,len(input_ids)+1]=1 162 | return new_dependency 163 | 164 | 165 | def convert_examples_to_features(self, examples): 166 | ''' 167 | only for bert tokenizer 168 | ''' 169 | 170 | Corenlp = stanza.Pipeline(lang="en", processors='tokenize,pos,lemma,depparse', tokenize_pretokenized=True) 171 | features = [] 172 | for example in examples: 173 | if self.use_stanza: 174 | # convert2stanzaformat 175 | # only for bert tokenizer 176 | inputs = self.tokenizer.encode_plus( 177 | "[unused0] " + example.sentence, 178 | add_special_tokens=True, 179 | max_length=self.max_seq_len, 180 | padding="max_length", 181 | truncation = True, 182 | return_attention_mask=True, 183 | return_token_type_ids=True 184 | ) 185 | else: 186 | inputs = self.tokenizer.encode_plus( 187 | example.sentence, 188 | add_special_tokens=True, 189 | max_length=self.max_seq_len, 190 | padding="max_length", 191 | truncation=True, 192 | return_attention_mask=True, 193 | return_token_type_ids=True 194 | ) 195 | input_ids = inputs["input_ids"] 196 | attention_mask = inputs["attention_mask"] 197 | token_type_ids = inputs["token_type_ids"] 198 | if example.label is not None: 199 | label_id = self.label2id[example.label] 200 | else: 201 | label_id = -1 202 | 203 | # convert2stanzaformat 204 | # only for bert tokenizer 205 | if self.use_stanza: 206 | input_ids_tmp = self.tokenizer.encode_plus(example.sentence)["input_ids"] 207 | words = [self.tokenizer.decode(id) for id in input_ids_tmp[1:-1]] 208 | dependency, _ = dependency_adj_matrix(Corenlp, " ".join(words)) 209 | # dependency, _ = dependency_adj_matrix(Corenlp=Corenlp, text=self.tokenizer.decode(input_ids_tmp[2:-1])) 210 | sep_position = len(input_ids_tmp) 211 | else: 212 | dependency, _ = dependency_adj_matrix(example.sentence) 213 | dependency_ = dependency 214 | for _ in range(self.order-1): 215 | dependency_ = np.matmul(dependency_, dependency) 216 | dependency_[dependency_>0] = 1 217 | # padding dependency 218 | padding_matrix_length = self.max_seq_len - dependency_.shape[0] - 2 219 | if padding_matrix_length <= 0: 220 | new_dependency = dependency_[:self.max_seq_len - 2, :self.max_seq_len - 2] 221 | constant_values = 1 if self.cls_connect else 0 222 | new_dependency = np.pad(new_dependency, ((1, 0),(1,0)), "constant",constant_values=constant_values) 223 | constant_values = 1 if self.sep_connect else 0 224 | new_dependency = np.pad(new_dependency, ((0,1), (0, 1)), "constant",constant_values=constant_values) 225 | new_dependency[0, 0] = 1 226 | new_dependency[self.max_seq_len - 1, self.max_seq_len - 1] = 1 227 | else: 228 | constant_values = 1 if self.cls_connect else 0 229 | new_dependency = np.pad(dependency_, 230 | ((1, 0), (1,0)), 231 | 'constant',constant_values=constant_values) 232 | constant_values = 1 if self.sep_connect else 0 233 | new_dependency = np.pad(new_dependency, 234 | ((0,1), (0, 1)), 235 | 'constant',constant_values=constant_values) 236 | constant_values = 0 237 | new_dependency = 
np.pad(new_dependency, 238 | ((0,padding_matrix_length), (0, padding_matrix_length)), 239 | 'constant',constant_values=constant_values) 240 | new_dependency[0, 0] = 1 241 | new_dependency[sep_position, sep_position] = 1 242 | # return new_dependency 243 | # dependency = dependency_ 244 | # dependency = self.merge(example.sentence,example.dependency) 245 | features.append( 246 | SeqInputFeatures(input_ids=input_ids, 247 | input_mask=attention_mask, 248 | segment_ids=token_type_ids, 249 | label_ids=label_id, 250 | dependency_adj=new_dependency)) 251 | 252 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 253 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 254 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 255 | all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long) 256 | all_dependency_matrix = torch.tensor([f.dependency_adj for f in features], dtype=torch.long) 257 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_dependency_matrix) 258 | return dataset 259 | 260 | 261 | def get_labels(self): 262 | return self.label_list 263 | 264 | 265 | if __name__ == '__main__': 266 | ''' 267 | function testing 268 | ''' 269 | 270 | # from transformers import AutoTokenizer 271 | # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 272 | # processor = GraphSteganalysisProcessor(tokenizer) 273 | # get_examples = processor.get_train_examples 274 | # _, examples = get_examples('../data') 275 | # processor.get_test_examples("../data") 276 | # processor.get_dev_examples("../data") -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import collections 3 | from sklearn.model_selection import train_test_split 4 | from transformers import AutoTokenizer 5 | 6 | 7 | class BertDataHelper(object): 8 | def __init__(self, raw, word_drop=5, ratio=0.8, use_label=True, use_length=False, tokenizer_config=None): 9 | assert (use_label and (len(raw) == 2)) or ((not use_label) and (len(raw) == 1)) 10 | self._word_drop = word_drop 11 | 12 | self.use_label = use_label 13 | self.use_length = use_length 14 | self.tokenizer_config= tokenizer_config 15 | self.train = None 16 | self.train_num = 0 17 | self.val = None 18 | self.val_num = None 19 | self.test = None 20 | self.test_num = 0 21 | if self.use_label: 22 | self.label_train = None 23 | self.label_test = None 24 | self.label_val = None 25 | if self.use_length: 26 | self.train_length = None 27 | self.test_length = None 28 | self.val_length = None 29 | self.max_sentence_length = 0 30 | self.min_sentence_length = 0 31 | 32 | self.vocab_size = 0 33 | self.vocab_size_raw = 0 34 | self.sentence_num = 0 35 | self.word_num = 0 36 | 37 | self.w2i = {} 38 | self.i2w = {} 39 | 40 | sentences = [] 41 | for _ in raw: 42 | sentences += _ 43 | 44 | self._build_vocabulary(sentences) 45 | corpus_length = None 46 | label = None 47 | if self.use_length: 48 | corpus, corpus_length = self._build_corpus(sentences) 49 | else: 50 | corpus = self._build_corpus(sentences) 51 | if self.use_label: 52 | label = self._build_label(raw) 53 | self._split(corpus, ratio, corpus_length=corpus_length, label=label) 54 | # self._split(corpus, ratio, corpus_length=corpus_length, label=label,sentences=sentences) 55 | 56 | def _build_label(self, raw): 57 | label = [0]*len(raw[0]) + 
[1]*len(raw[1]) 58 | return np.array(label) 59 | 60 | def _build_vocabulary(self, sentences): 61 | self.sentence_num = len(sentences) 62 | words = [] 63 | for sentence in sentences: 64 | words += sentence.split(' ') 65 | self.word_num = len(words) 66 | word_distribution = sorted(collections.Counter(words).items(), key=lambda x: x[1], reverse=True) 67 | self.vocab_size_raw = len(word_distribution) 68 | self.w2i['_PAD'] = 0 69 | self.w2i['_UNK'] = 1 70 | self.w2i['_BOS'] = 2 71 | self.w2i['_EOS'] = 3 72 | self.i2w[0] = '_PAD' 73 | self.i2w[1] = '_UNK' 74 | self.i2w[2] = '_BOS' 75 | self.i2w[3] = '_EOS' 76 | for (word, value) in word_distribution: 77 | if value > self._word_drop: 78 | self.w2i[word] = len(self.w2i) 79 | self.i2w[len(self.i2w)] = word 80 | self.vocab_size = len(self.i2w) 81 | 82 | def _build_corpus(self, sentences): 83 | self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_config.model_name_or_path) 84 | self.vocab_size = self.tokenizer.vocab_size 85 | corpus = self.tokenizer(sentences)["input_ids"] 86 | if self.use_length: 87 | corpus_length = np.array([len(i) for i in corpus]) 88 | self.max_sentence_length = corpus_length.max() 89 | self.min_sentence_length = corpus_length.min() 90 | return np.array(corpus), np.array(corpus_length) 91 | else: 92 | return np.array(corpus) 93 | 94 | def _split(self, corpus, ratio, corpus_length=None, label=None, sentences=None): 95 | self.train, self.test, self.label_train, self.label_test = train_test_split(corpus, label, test_size=1-ratio, shuffle=True, random_state=42) 96 | self.train, self.val, self.label_train, self.label_val = train_test_split(self.train, self.label_train, test_size=(1-ratio)/ratio) 97 | indices = list(range(self.sentence_num)) 98 | self.train_num = len(self.train) 99 | self.val_num = len(self.val) 100 | self.test_num = len(self.test) 101 | 102 | def _padding(self, batch_data): 103 | max_length = max([len(i) for i in batch_data]) 104 | for i in range(len(batch_data)): 105 | batch_data[i] += [self.tokenizer.pad_token_id] * (max_length - len(batch_data[i])) 106 | return np.array(list(batch_data)) 107 | 108 | def train_generator(self, batch_size, shuffle=True): 109 | indices = list(range(self.train_num)) 110 | if shuffle: 111 | np.random.shuffle(indices) 112 | while True: 113 | batch_indices = indices[0:batch_size] # 产生一个batch的index 114 | indices = indices[batch_size:] # 去掉本次index 115 | if len(batch_indices) == 0: 116 | return True 117 | batch_data = self.train[batch_indices] 118 | batch_data = self._padding(batch_data) 119 | result = [batch_data] 120 | if self.use_length: 121 | batch_length = self.train_length[batch_indices] 122 | result.append(batch_length) 123 | if self.use_label: 124 | batch_label = self.label_train[batch_indices] 125 | result.append(batch_label) 126 | yield tuple(result) 127 | 128 | def val_generator(self, batch_size, shuffle=True): 129 | indices = list(range(self.val_num)) 130 | if shuffle: 131 | np.random.shuffle(indices) 132 | while True: 133 | batch_indices = indices[0:batch_size] # 产生一个batch的index 134 | indices = indices[batch_size:] # 去掉本次index 135 | if len(batch_indices) == 0: 136 | return True 137 | batch_data = self.val[batch_indices] 138 | batch_data = self._padding(batch_data) 139 | result = [batch_data] 140 | if self.use_length: 141 | batch_length = self.val_length[batch_indices] 142 | result.append(batch_length) 143 | if self.use_label: 144 | batch_label = self.label_val[batch_indices] 145 | result.append(batch_label) 146 | yield tuple(result) 147 | 148 | 149 | def 
test_generator(self, batch_size, shuffle=True): 150 | indices = list(range(self.test_num)) 151 | if shuffle: 152 | np.random.shuffle(indices) 153 | while True: 154 | batch_indices = indices[0:batch_size] # 产生一个batch的index 155 | indices = indices[batch_size:] # 去掉本次index 156 | if len(batch_indices) == 0: 157 | return True 158 | batch_data = self.test[batch_indices] 159 | batch_data = self._padding(batch_data) 160 | result = [batch_data] 161 | if self.use_length: 162 | batch_length = self.test_length[batch_indices] 163 | result.append(batch_length) 164 | if self.use_label: 165 | batch_label = self.label_test[batch_indices] 166 | result.append(batch_label) 167 | yield tuple(result) 168 | pass 169 | 170 | 171 | class DataHelper(object): 172 | def __init__(self, raw, word_drop=5, ratio=0.8, use_label=False, use_length=False, do_lower=True, max_length=60): 173 | assert (use_label and (len(raw) == 2)) or ((not use_label) and (len(raw) == 1)) 174 | self._word_drop = word_drop 175 | self.use_label = use_label 176 | self.use_length = use_length 177 | self.do_lower = do_lower 178 | self.max_length = max_length 179 | self.train = None 180 | self.train_num = 0 181 | self.val = None 182 | self.val_num = None 183 | self.test = None 184 | self.test_num = 0 185 | if self.use_label: 186 | self.label_train = None 187 | self.label_test = None 188 | self.label_val = None 189 | if self.use_length: 190 | self.train_length = None 191 | self.test_length = None 192 | self.val_length = None 193 | self.max_sentence_length = 0 194 | self.min_sentence_length = 0 195 | 196 | self.vocab_size = 0 197 | self.vocab_size_raw = 0 198 | self.sentence_num = 0 199 | self.word_num = 0 200 | 201 | self.w2i = {} 202 | self.i2w = {} 203 | 204 | sentences = [] 205 | for _ in raw: 206 | sentences += _ 207 | 208 | self._build_vocabulary(sentences) 209 | corpus_length = None 210 | label = None 211 | if self.use_length: 212 | corpus, corpus_length = self._build_corpus(sentences) 213 | else: 214 | corpus = self._build_corpus(sentences) 215 | if self.use_label: 216 | label = self._build_label(raw) 217 | self._split(corpus, ratio, corpus_length=corpus_length, label=label) 218 | # self._split(corpus, ratio, corpus_length=corpus_length, label=label,sentences=sentences) 219 | 220 | def _build_label(self, raw): 221 | label = [0]*len(raw[0]) + [1]*len(raw[1]) 222 | return np.array(label) 223 | 224 | def _build_vocabulary(self, sentences): 225 | self.sentence_num = len(sentences) 226 | words = [] 227 | for sentence in sentences: 228 | if self.do_lower: 229 | words += sentence.lower().split(' ') 230 | else: 231 | words += sentence.split(' ') 232 | self.word_num = len(words) 233 | word_distribution = sorted(collections.Counter(words).items(), key=lambda x: x[1], reverse=True) 234 | self.vocab_size_raw = len(word_distribution) 235 | self.w2i['_PAD'] = 0 236 | self.w2i['_UNK'] = 1 237 | self.w2i['_BOS'] = 2 238 | self.w2i['_EOS'] = 3 239 | self.i2w[0] = '_PAD' 240 | self.i2w[1] = '_UNK' 241 | self.i2w[2] = '_BOS' 242 | self.i2w[3] = '_EOS' 243 | for (word, value) in word_distribution: 244 | if value > self._word_drop: 245 | self.w2i[word] = len(self.w2i) 246 | self.i2w[len(self.i2w)] = word 247 | self.vocab_size = len(self.i2w) 248 | 249 | def _build_corpus(self, sentences): 250 | def _transfer(word): 251 | try: 252 | return self.w2i[word] 253 | except: 254 | return self.w2i['_UNK'] 255 | sentences = [" ".join(sentence.split(" ")[:self.max_length-2]) for sentence in sentences] 256 | corpus = [[self.w2i["_BOS"]] + list(map(_transfer, sentence.split(' 
'))) + [self.w2i["_EOS"]] for sentence in sentences] 257 | if self.use_length: 258 | corpus_length = np.array([len(i) for i in corpus]) 259 | self.max_sentence_length = corpus_length.max() 260 | self.min_sentence_length = corpus_length.min() 261 | return np.array(corpus), np.array(corpus_length) 262 | else: 263 | return np.array(corpus) 264 | 265 | def _split(self, corpus, ratio, corpus_length=None, label=None, sentences=None): 266 | self.train, self.test, self.label_train, self.label_test = train_test_split(corpus, label, test_size=1-ratio, shuffle=True, random_state=42) 267 | self.train, self.val, self.label_train, self.label_val = train_test_split(self.train, self.label_train, test_size=(1-ratio)/ratio) 268 | indices = list(range(self.sentence_num)) 269 | # np.random.shuffle(indices) 270 | # self.train = corpus[indices[:int(self.sentence_num * ratio)]] 271 | self.train_num = len(self.train) 272 | self.val_num = len(self.val) 273 | # self.test = corpus[indices[int(self.sentence_num * ratio):]] 274 | self.test_num = len(self.test) 275 | # if sentences is not None: 276 | # sentences = np.array(sentences) 277 | # self.train_org = sentences[indices[:int(self.sentence_num * ratio)]] 278 | # self.test_org = sentences[indices[int(self.sentence_num * ratio):]] 279 | # if self.use_length: 280 | # self.train_length = corpus_length[indices[:int(self.sentence_num * ratio)]] 281 | # self.test_length = corpus_length[indices[int(self.sentence_num * ratio):]] 282 | # if self.use_label: 283 | # # self.label_train = label[indices[:int(self.sentence_num * ratio)]] 284 | # # self.label_test = label[indices[int(self.sentence_num*ratio):]] 285 | 286 | def _padding(self, batch_data): 287 | max_length = max([len(i) for i in batch_data]) 288 | for i in range(len(batch_data)): 289 | batch_data[i] += [self.w2i["_PAD"]] * (max_length - len(batch_data[i])) 290 | return np.array(list(batch_data)) 291 | 292 | def train_generator(self, batch_size, shuffle=True): 293 | indices = list(range(self.train_num)) 294 | if shuffle: 295 | np.random.shuffle(indices) 296 | while True: 297 | batch_indices = indices[0:batch_size] # 产生一个batch的index 298 | indices = indices[batch_size:] # 去掉本次index 299 | if len(batch_indices) == 0: 300 | return True 301 | batch_data = self.train[batch_indices] 302 | batch_data = self._padding(batch_data) 303 | result = [batch_data] 304 | if self.use_length: 305 | batch_length = self.train_length[batch_indices] 306 | result.append(batch_length) 307 | if self.use_label: 308 | batch_label = self.label_train[batch_indices] 309 | result.append(batch_label) 310 | yield tuple(result) 311 | 312 | def val_generator(self, batch_size, shuffle=True): 313 | indices = list(range(self.val_num)) 314 | if shuffle: 315 | np.random.shuffle(indices) 316 | while True: 317 | batch_indices = indices[0:batch_size] # 产生一个batch的index 318 | indices = indices[batch_size:] # 去掉本次index 319 | if len(batch_indices) == 0: 320 | return True 321 | batch_data = self.val[batch_indices] 322 | batch_data = self._padding(batch_data) 323 | result = [batch_data] 324 | if self.use_length: 325 | batch_length = self.val_length[batch_indices] 326 | result.append(batch_length) 327 | if self.use_label: 328 | batch_label = self.label_val[batch_indices] 329 | result.append(batch_label) 330 | yield tuple(result) 331 | 332 | 333 | def test_generator(self, batch_size, shuffle=True): 334 | indices = list(range(self.test_num)) 335 | if shuffle: 336 | np.random.shuffle(indices) 337 | while True: 338 | batch_indices = indices[0:batch_size] # 产生一个batch的index 
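# Reading note on the generator protocol shared by train_/val_/test_generator:
# the (optionally shuffled) index list is consumed batch_size entries at a time
# (the slice above takes one batch of indices, the line below drops them), and
# the method simply returns once the list is empty, which ends iteration for
# the caller. Each yielded item is a tuple of
# (padded_batch, [lengths if use_length], [labels if use_label]).
# A minimal usage sketch, assuming a DataHelper built with use_label=True and
# use_length=False and bound to a hypothetical name data_helper:
#   for batch in data_helper.test_generator(batch_size=32, shuffle=False):
#       texts, labels = batch[0], batch[-1]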
339 | indices = indices[batch_size:] # 去掉本次index 340 | if len(batch_indices) == 0: 341 | return True 342 | batch_data = self.test[batch_indices] 343 | batch_data = self._padding(batch_data) 344 | result = [batch_data] 345 | if self.use_length: 346 | batch_length = self.test_length[batch_indices] 347 | result.append(batch_length) 348 | if self.use_label: 349 | batch_label = self.label_test[batch_indices] 350 | result.append(batch_label) 351 | yield tuple(result) 352 | 353 | pass 354 | 355 | 356 | 357 | 358 | 359 | 360 | 361 | -------------------------------------------------------------------------------- /incorporate.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import random 3 | import dataset 4 | import numpy as np 5 | import logging 6 | import os 7 | import json 8 | import time 9 | import csv 10 | from tqdm import tqdm 11 | 12 | import torch 13 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 14 | 15 | from sklearn.model_selection import train_test_split 16 | from models import birnn as BiRNN,cnn as CNN,lstmatt as LSTMATT,fcn as FCN, r_bilstm_c as RBC,\ 17 | bilstm_dense as BLSTMDENSE,sesy as SESY, ms as MS_TL 18 | 19 | 20 | from transformers import ( 21 | AdamW, 22 | AutoTokenizer, 23 | get_linear_schedule_with_warmup, 24 | ) 25 | 26 | try: 27 | from torch.utils.tensorboard import SummaryWriter 28 | except ImportError: 29 | from tensorboardX import SummaryWriter 30 | 31 | task_metrics = {"steganalysis" : "accuracy", 32 | "graph_steganalysis" : "accuracy",} 33 | 34 | 35 | logger = logging.getLogger(__name__) 36 | time_stamp = "-".join(time.ctime().split()) 37 | 38 | 39 | def set_seed(seed): 40 | random.seed(seed) 41 | np.random.seed(seed) 42 | torch.manual_seed(seed) 43 | torch.cuda.manual_seed_all(seed) 44 | 45 | 46 | def load_model(Configs, VOCAB_SIZE=None, checkpoint=None): 47 | 48 | logger.info("----------------init model-----------------------") 49 | 50 | if Configs.model.lower() in ["ms-birnnc", "ms-birnn_c"]: 51 | Model_Configs = Configs.MSBiRNNC 52 | Model_Configs.vocab_size = VOCAB_SIZE 53 | model = MS_TL.BiRNN_C(**{**Model_Configs, "class_num": Configs.class_num, }) 54 | 55 | # Model_Configs = Configs.MSBiRNNC.MSCNN 56 | # Model_Configs.vocab_size = VOCAB_SIZE 57 | # cnn_sub_model = MS_TL.CNN(**{**Model_Configs, "class_num": Configs.class_num, }) 58 | # Model_Configs = Configs.MSBiRNNC.MSBiRNN 59 | # Model_Configs.vocab_size = VOCAB_SIZE 60 | # rnn_sub_model = MS_TL.BiRNN(**{**Model_Configs, "class_num": Configs.class_num, }) 61 | 62 | else: 63 | logger.error("no such model, exit") 64 | exit() 65 | if checkpoint is not None: 66 | logger.info("---------------------loading model from {}------------\n\n".format(checkpoint)) 67 | model = torch.load(os.path.join(checkpoint, "pytorch_model.bin")) 68 | 69 | logger.info("Student Model Configs") 70 | logger.info(json.dumps({**{"MODEL_TYPE": Configs.model}, **Model_Configs, })) 71 | model = model.to(Configs.device) 72 | return model 73 | 74 | 75 | def train( model, Configs, tokenizer): 76 | train_dataset = load_and_cache_examples(Configs.Dataset, Configs.task_name, tokenizer) 77 | Training_Configs = Configs.Training_with_Processor 78 | Training_Configs.train_batch_size = Training_Configs.per_gpu_train_batch_size * max(1, Training_Configs.n_gpu) 79 | train_sampler = RandomSampler(train_dataset) 80 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=Training_Configs.train_batch_size) 81 | 82 | if Training_Configs.max_steps > 
0: 83 | t_total = Training_Configs.max_steps 84 | Training_Configs.num_train_epochs = Training_Configs.max_steps // (len(train_dataloader) // Training_Configs.gradient_accumulation_steps) + 1 85 | else: 86 | t_total = len(train_dataloader) // Training_Configs.gradient_accumulation_steps * Training_Configs.num_train_epochs 87 | 88 | num_warmup_steps = int(Training_Configs.warmup_ratio * t_total) 89 | # Prepare optimizer and schedule (linear warmup and decay) 90 | no_decay = ["bias", "LayerNorm.weight"] 91 | optimizer_grouped_parameters = [ 92 | { 93 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 94 | "weight_decay": Training_Configs.weight_decay, 95 | }, 96 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 97 | ] 98 | optimizer = AdamW(optimizer_grouped_parameters, lr=Training_Configs.learning_rate, eps=Training_Configs.adam_epsilon) 99 | scheduler = get_linear_schedule_with_warmup( 100 | optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total 101 | ) 102 | 103 | # Check if saved optimizer or scheduler states exist 104 | if os.path.isfile(os.path.join(Training_Configs.model_name_or_path, "optimizer.pt")) and os.path.isfile( 105 | os.path.join(Training_Configs.model_name_or_path, "scheduler.pt") 106 | ): 107 | # Load in optimizer and scheduler states 108 | optimizer.load_state_dict(torch.load(os.path.join(Training_Configs.model_name_or_path, "optimizer.pt"))) 109 | scheduler.load_state_dict(torch.load(os.path.join(Training_Configs.model_name_or_path, "scheduler.pt"))) 110 | 111 | 112 | # Train! 113 | logger.info("***** Running training *****") 114 | logger.info(" Num examples = %d", len(train_dataset)) 115 | logger.info(" Num Epochs = %d", Training_Configs.num_train_epochs) 116 | logger.info(" Instantaneous batch size per GPU = %d", Training_Configs.per_gpu_train_batch_size) 117 | logger.info( 118 | " Total train batch size (w. 
parallel, distributed & accumulation) = %d", 119 | Training_Configs.train_batch_size 120 | * Training_Configs.gradient_accumulation_steps) 121 | logger.info(" Gradient Accumulation steps = %d", Training_Configs.gradient_accumulation_steps) 122 | logger.info(" Total optimization steps = %d", t_total) 123 | 124 | global_step = 0 125 | epochs_trained = 0 126 | steps_trained_in_current_epoch = 0 127 | 128 | # Check if continuing training from a checkpoint 129 | if os.path.exists(Training_Configs.model_name_or_path): 130 | # set global_step to global_step of last saved checkpoint from model path 131 | try: 132 | global_step = int(Training_Configs.model_name_or_path.split("-")[-1].split("/")[0]) 133 | except ValueError: 134 | global_step = 0 135 | 136 | epochs_trained = global_step // (len(train_dataloader) // Training_Configs.gradient_accumulation_steps) 137 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // Training_Configs.gradient_accumulation_steps) 138 | 139 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 140 | logger.info(" Continuing training from epoch %d", epochs_trained) 141 | logger.info(" Continuing training from global step %d", global_step) 142 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 143 | 144 | tr_loss, logging_loss = 0.0, 0.0 145 | model.zero_grad() 146 | train_iterator = range(epochs_trained, int(Training_Configs.num_train_epochs)) 147 | 148 | set_seed(Configs.seed) # Added here for reproducibility 149 | 150 | best_val_metric = None 151 | for epoch_n in train_iterator: 152 | epoch_iterator = tqdm(train_dataloader, desc=f"Epoch {epoch_n}", disable=False) 153 | for step, batch in enumerate(epoch_iterator): 154 | 155 | # Skip past any already trained steps if resuming training 156 | if steps_trained_in_current_epoch > 0: 157 | steps_trained_in_current_epoch -= 1 158 | continue 159 | 160 | model.train() 161 | batch = tuple(t.to(Configs.device) for t in batch) 162 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2], "labels": batch[3]} 163 | if Configs.task_name == "graph_steganalysis": 164 | inputs = {**inputs,"graph":batch[4]} 165 | 166 | outputs = model(**inputs) 167 | loss = outputs[0] 168 | 169 | if Training_Configs.n_gpu > 1: 170 | loss = loss.mean() # mean() to average on multi-gpu parallel training 171 | if Training_Configs.gradient_accumulation_steps > 1: 172 | loss = loss / Training_Configs.gradient_accumulation_steps 173 | 174 | loss.backward() 175 | 176 | tr_loss += loss.item() 177 | if (step + 1) % Training_Configs.gradient_accumulation_steps == 0: 178 | torch.nn.utils.clip_grad_norm_(model.parameters(), Training_Configs.max_grad_norm) 179 | optimizer.step() 180 | scheduler.step() # Update learning rate schedule 181 | model.zero_grad() 182 | global_step += 1 183 | 184 | logs = {} 185 | if Training_Configs.logging_steps > 0 and global_step % Training_Configs.logging_steps == 0: 186 | # Log metrics 187 | if Training_Configs.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 188 | results, _ , _ = evaluate( model, tokenizer, Configs, Configs.task_name, use_tqdm=False) 189 | for key, value in results.items(): 190 | eval_key = "eval_{}".format(key) 191 | logs[eval_key] = value 192 | 193 | loss_scalar = (tr_loss - logging_loss) / Training_Configs.logging_steps 194 | learning_rate_scalar = scheduler.get_last_lr()[0] 195 | logs["learning_rate"] = learning_rate_scalar 196 |
logs["avg_loss_since_last_log"] = loss_scalar 197 | logging_loss = tr_loss 198 | 199 | logging.info(json.dumps({**logs, **{"step": global_step}})) 200 | 201 | 202 | if ( Training_Configs.eval_and_save_steps > 0 and global_step % Training_Configs.eval_and_save_steps == 0) \ 203 | or (step+1==t_total): 204 | # evaluate 205 | results, _, _ = evaluate(model, tokenizer, Configs, Configs.task_name, use_tqdm=False) 206 | logger.info("------Next Evalset will be loaded from cached file------") 207 | Configs.Dataset.overwrite_cache = False 208 | for key, value in results.items(): 209 | logs[f"eval_{key}"] = value 210 | logger.info(json.dumps({**logs, **{"step": global_step}})) 211 | 212 | # save 213 | if Training_Configs.save_only_best: 214 | output_dirs = [os.path.join(Configs.out_dir, Configs.checkpoint)] 215 | else: 216 | output_dirs = [os.path.join(Configs.out_dir, f"checkpoint-{global_step}")] 217 | curr_val_metric = results[task_metrics[Configs.task_name]] 218 | if best_val_metric is None or curr_val_metric > best_val_metric: 219 | # check if best model so far 220 | logger.info("Congratulations, best model so far!") 221 | best_val_metric = curr_val_metric 222 | 223 | for output_dir in output_dirs: 224 | # in each dir, save model, tokenizer, args, optimizer, scheduler 225 | if not os.path.exists(output_dir): 226 | os.makedirs(output_dir) 227 | model_to_save = ( 228 | model.module if hasattr(model, "module") else model 229 | ) # Take care of distributed/parallel training 230 | logger.info("Saving model checkpoint to %s", output_dir) 231 | torch.save(model_to_save, os.path.join(output_dir, "pytorch_model.bin")) 232 | torch.save(Configs.state_dict, os.path.join(output_dir, "training_args.bin")) 233 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 234 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 235 | tokenizer.save_pretrained(output_dir) 236 | logger.info("\tSaved model checkpoint to %s", output_dir) 237 | 238 | 239 | if Training_Configs.max_steps > 0 and global_step > Training_Configs.max_steps: 240 | epoch_iterator.close() 241 | break 242 | if Training_Configs.max_steps > 0 and global_step > Training_Configs.max_steps: 243 | # train_iterator.close() 244 | break 245 | 246 | return global_step, tr_loss / global_step 247 | 248 | 249 | def evaluate(model, tokenizer, Configs, task_name, split="dev", prefix="", use_tqdm=True): 250 | Training_Configs = Configs.Training_with_Processor 251 | results = {} 252 | if task_name == "record": 253 | eval_dataset, eval_answers = load_and_cache_examples(Configs.Dataset, task_name, tokenizer, split=split) 254 | else: 255 | eval_dataset = load_and_cache_examples(Configs.Dataset, task_name, tokenizer, split=split) 256 | 257 | if not os.path.exists(Configs.out_dir): 258 | os.makedirs(Configs.out_dir) 259 | 260 | Training_Configs.eval_batch_size = Training_Configs.per_gpu_eval_batch_size * max(1, Training_Configs.n_gpu) 261 | # Note that DistributedSampler samples randomly 262 | eval_sampler = SequentialSampler(eval_dataset) 263 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=Training_Configs.eval_batch_size) 264 | 265 | # multi-gpu eval 266 | if Training_Configs.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): 267 | model = torch.nn.DataParallel(model) 268 | 269 | # Eval! 
270 | logger.info(f"***** Running evaluation: {prefix} on {task_name} {split} *****") 271 | logger.info(" Num examples = %d", len(eval_dataset)) 272 | logger.info(" Batch size = %d", Training_Configs.eval_batch_size) 273 | eval_loss = 0.0 274 | nb_eval_steps = 0 275 | preds = None 276 | out_label_ids = None 277 | ex_ids = None 278 | eval_dataloader = tqdm(eval_dataloader, desc="Evaluating") if use_tqdm else eval_dataloader 279 | for batch in eval_dataloader: 280 | model.eval() 281 | batch = tuple(t.to(Configs.device) for t in batch) 282 | guids = batch[-1] 283 | 284 | max_seq_length = batch[0].size(1) 285 | if Training_Configs.use_fixed_seq_length: # no dynamic sequence length 286 | batch_seq_length = max_seq_length 287 | else: 288 | batch_seq_length = torch.max(batch[-2], 0)[0].item() 289 | 290 | if batch_seq_length < max_seq_length: 291 | inputs = {"input_ids": batch[0][:, :batch_seq_length].contiguous(), 292 | "attention_mask": batch[1][:, :batch_seq_length].contiguous(), 293 | "token_type_ids":batch[2][:, :batch_seq_length].contiguous(), 294 | "labels": batch[3]} 295 | # inputs["token_type_ids"] = ( 296 | # batch[2][:, :batch_seq_length].contiguous() if Configs.model_type in ["bert", "xlnet", "albert"] else None 297 | # ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 298 | else: 299 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids":batch[2],"labels": batch[3]} 300 | 301 | if Configs.task_name == "graph_steganalysis": 302 | inputs = {**inputs,"graph":batch[4]} 303 | # inputs["token_type_ids"] = ( 304 | # batch[2][:, :batch_seq_length].contiguous() if Configs.model_type in ["bert", "xlnet", "albert"] else None 305 | # ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 306 | 307 | with torch.no_grad(): 308 | outputs = model(**inputs) 309 | tmp_eval_loss, logits = outputs[:2] 310 | 311 | eval_loss += tmp_eval_loss.mean().item() 312 | nb_eval_steps += 1 313 | if preds is None: 314 | preds = logits.detach().cpu().numpy() 315 | out_label_ids = inputs["labels"].detach().cpu().numpy() 316 | ex_ids = [guids.detach().cpu().numpy()] 317 | else: 318 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 319 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) 320 | ex_ids.append(guids.detach().cpu().numpy()) 321 | 322 | ex_ids = np.concatenate(ex_ids, axis=0) 323 | eval_loss = eval_loss / nb_eval_steps 324 | 325 | preds = np.argmax(preds, axis=1) 326 | 327 | result = utils.compute_metrics(task_name, preds, out_label_ids,) 328 | results.update(result) 329 | if prefix == "": 330 | return results, preds, ex_ids 331 | output_eval_file = os.path.join(Configs.out_dir, prefix, "eval_results.txt") 332 | with open(output_eval_file, "w") as writer: 333 | logger.info(f"***** {split} results: {prefix} *****") 334 | for key in sorted(result.keys()): 335 | logger.info(" %s = %s", key, str(result[key])) 336 | writer.write("%s = %s\n" % (key, str(result[key]))) 337 | 338 | return results, preds, ex_ids 339 | 340 | 341 | def load_and_cache_examples(Dataset_Configs, task, tokenizer, split="train"): 342 | if task == "steganalysis": 343 | from processors.process import SteganalysisProcessor as DataProcessor 344 | elif task == "graph_steganalysis": 345 | from processors.graph_process import GraphSteganalysisProcessor as DataProcessor 346 | 347 | processor = DataProcessor(tokenizer) 348 | # Load data features from cache or dataset file 349 | cached_tensors_file = os.path.join( 350 | 
Dataset_Configs.csv_dir, 351 | "tensors_{}_{}_{}".format( 352 | split, time_stamp, str(task), 353 | ), 354 | ) 355 | if os.path.exists(cached_tensors_file) and not Dataset_Configs.overwrite_cache: 356 | logger.info("Loading tensors from cached file %s", cached_tensors_file) 357 | start_time = time.time() 358 | dataset = torch.load(cached_tensors_file) 359 | logger.info("\tFinished loading tensors") 360 | logger.info(f"\tin {time.time() - start_time}s") 361 | 362 | else: 363 | # no cached tensors, process data from scratch 364 | logger.info("Creating features from dataset file at %s", Dataset_Configs.csv_dir) 365 | if split == "train": 366 | get_examples = processor.get_train_examples 367 | elif split == "dev": 368 | get_examples = processor.get_dev_examples 369 | elif split == "test": 370 | get_examples = processor.get_test_examples 371 | 372 | examples = get_examples(Dataset_Configs.csv_dir) 373 | dataset = processor.convert_examples_to_features(examples,) 374 | logger.info("\tFinished creating features") 375 | 376 | logger.info("\tFinished converting features into tensors") 377 | if Dataset_Configs.save_cache: 378 | logger.info("Saving features into cached file %s", cached_tensors_file) 379 | torch.save(dataset, cached_tensors_file) 380 | logger.info("\tFinished saving tensors") 381 | 382 | if task == "record" and split in ["dev", "test"]: 383 | answers = processor.get_answers(Dataset_Configs.csv_dir, split) 384 | return dataset, answers 385 | else: 386 | return dataset 387 | 388 | 389 | def main(Configs): 390 | Dataset_Configs = Configs.Dataset 391 | Configs.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 392 | os.makedirs(Configs.out_dir,exist_ok=True) 393 | set_seed(Configs.seed) 394 | 395 | # Setup logging 396 | logging.basicConfig( 397 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 398 | datefmt="%m/%d/%Y %H:%M:%S", 399 | level=logging.INFO, 400 | ) 401 | handler = logging.FileHandler(os.path.join(Configs.out_dir,time_stamp+"_log")) 402 | handler.setLevel(logging.INFO) 403 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 404 | handler.setFormatter(formatter) 405 | logger.addHandler(handler) 406 | 407 | logger.info("--------------Main Configs-------------------") 408 | logger.info(Configs) 409 | 410 | logger.info("--------------loading data-------------------") 411 | logger.info("Dataset Configs") 412 | logger.info(json.dumps(Dataset_Configs)) 413 | 414 | Configs.model_name_or_path = Configs.Training_with_Processor.model_name_or_path 415 | logger.info("\tload plm name or path from Training_with_Processor args") 416 | 417 | logger.info("-------------------------------------------------------------------------------------------------------") 418 | # prepare data 419 | if Configs.use_processor: 420 | # translate txt into csv 421 | if not Dataset_Configs.resplit and os.path.exists(Dataset_Configs.csv_dir) and \ 422 | os.path.exists(os.path.join(Dataset_Configs.csv_dir,"train.csv")) and \ 423 | os.path.exists(os.path.join(Dataset_Configs.csv_dir,"val.csv")) and \ 424 | os.path.exists(os.path.join(Dataset_Configs.csv_dir,"val.csv")): 425 | pass 426 | else: 427 | os.makedirs(Dataset_Configs.csv_dir, exist_ok=True) 428 | with open(Dataset_Configs.cover_file, 'r', encoding='utf-8') as f: 429 | covers = f.read().split("\n") 430 | covers = list(filter(lambda x: x not in ['', None], covers)) 431 | random.shuffle(covers) 432 | with open(Dataset_Configs.stego_file, 'r', encoding='utf-8') as f: 433 | stegos = 
f.read().split("\n") 434 | stegos = list(filter(lambda x: x not in ['', None], stegos)) 435 | random.shuffle(stegos) 436 | texts = covers+stegos 437 | labels = [0]*len(covers) + [1]*len(stegos) 438 | 439 | # val_ratio = (1-Dataset_Configs.split_ratio)/Dataset_Configs.split_ratio 440 | train_texts,test_texts,train_labels,test_labels = train_test_split(texts,labels,train_size=Dataset_Configs.split_ratio) 441 | train_texts,val_texts,train_labels,val_labels = train_test_split(train_texts, train_labels, train_size=Dataset_Configs.split_ratio) 442 | def write2file(X, Y, filename): 443 | with open(filename, "w", encoding="utf-8") as f: 444 | writer = csv.writer(f) 445 | writer.writerow(["text", "label"]) 446 | for x, y in zip(X, Y): 447 | writer.writerow([x, y]) 448 | write2file(train_texts,train_labels, os.path.join(Dataset_Configs.csv_dir,"train.csv")) 449 | write2file(val_texts, val_labels, os.path.join(Dataset_Configs.csv_dir, "val.csv")) 450 | write2file(test_texts, test_labels, os.path.join(Dataset_Configs.csv_dir, "test.csv")) 451 | tokenizer = AutoTokenizer.from_pretrained(Configs.model_name_or_path,) 452 | VOCAB_SIZE = tokenizer.vocab_size 453 | 454 | else: 455 | # not recommended 456 | raise NotImplementedError 457 | 458 | model = load_model(Configs, VOCAB_SIZE=VOCAB_SIZE) 459 | 460 | logger.info("--------------start training--------------------") 461 | 462 | if Configs.use_processor: 463 | # train_dataset = load_and_cache_examples(Dataset_Configs, Configs.task_name, tokenizer) # , evaluate=False) 464 | global_step, tr_loss = train(model, Configs, tokenizer) 465 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 466 | Training_Configs = Configs.Training_with_Processor 467 | 468 | checkpoints = [os.path.join(Configs.out_dir, Configs.checkpoint)] 469 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 470 | 471 | for checkpoint in checkpoints: 472 | prefix = checkpoint.split("/")[-1] 473 | student_model = load_model(Configs, VOCAB_SIZE=tokenizer.vocab_size,checkpoint=checkpoint) 474 | result, preds, ex_ids = evaluate(student_model, tokenizer, Configs, Configs.task_name, split="test", prefix=prefix) 475 | test_acc = result["accuracy"] 476 | test_precision = result["precision"] 477 | test_recall = result["recall"] 478 | test_Fscore = result["f1_score"] 479 | 480 | else: 481 | raise NotImplementedError 482 | 483 | record_file = Configs.record_file if Configs.record_file is not None else "record.txt" 484 | result_path = os.path.join(Configs.out_dir, time_stamp+"----"+record_file) 485 | with open(result_path, "w", encoding="utf-8") as f: 486 | f.write("test phase:\naccuracy\t{:.4f}\nprecision\t{:.4f}\nrecall\t{:.4f}\nf1_score\t{:.4f}" 487 | .format(test_acc*100,test_precision*100,test_recall*100,test_Fscore*100)) 488 | 489 | 490 | if __name__ == '__main__': 491 | import argparse 492 | parser = argparse.ArgumentParser(description="argument for generation") 493 | parser.add_argument("--config_path", type=str, default="./configs/stage3.json") 494 | args = parser.parse_args() 495 | Configs = utils.Config(args.config_path).get_configs() 496 | os.environ["CUDA_VISIBLE_DEVICES"] = Configs.gpuid 497 | main(Configs) -------------------------------------------------------------------------------- /distil.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import random 3 | import dataset 4 | import numpy as np 5 | import logging 6 | import os 7 | import json 8 | import time 9 | import csv 10 | from tqdm import tqdm 11 |
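# Module overview (reading aid): distil.py reuses the seeding, logging,
# time_stamp and task_metrics conventions of incorporate.py above, but it
# distinguishes a teacher from a student. load_teacher_model() restores a
# BERT-based classifier (ts-csw/CNN, BiRNN, FCN, LSTMATT, R-BiLSTM-C,
# BiLSTM-Dense, SESY, or a fine-tuned BERT) from
# Configs.teacher_model_name_or_path, while load_student_model() selects the
# corresponding lightweight student network.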
12 | import torch 13 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 14 | 15 | from sklearn.model_selection import train_test_split 16 | from models import birnn as BiRNN,cnn as CNN,lstmatt as LSTMATT,fcn as FCN, r_bilstm_c as RBC,\ 17 | bilstm_dense as BLSTMDENSE,sesy as SESY, ms as MS_TL 18 | 19 | 20 | from transformers import ( 21 | AdamW, 22 | AutoTokenizer, 23 | get_linear_schedule_with_warmup, 24 | ) 25 | 26 | try: 27 | from torch.utils.tensorboard import SummaryWriter 28 | except ImportError: 29 | from tensorboardX import SummaryWriter 30 | 31 | task_metrics = {"steganalysis" : "accuracy", 32 | "graph_steganalysis" : "accuracy",} 33 | 34 | 35 | logger = logging.getLogger(__name__) 36 | time_stamp = "-".join(time.ctime().split()) 37 | 38 | 39 | def set_seed(seed): 40 | random.seed(seed) 41 | np.random.seed(seed) 42 | torch.manual_seed(seed) 43 | torch.cuda.manual_seed_all(seed) 44 | 45 | 46 | def load_teacher_model(Configs, VOCAB_SIZE=None, checkpoint=None): 47 | # set model 48 | 49 | logger.info("----------------init model-----------------------") 50 | 51 | logger.info("-------------loading teacher model--------------\n\n") 52 | model_name_or_path = Configs.teacher_model_name_or_path 53 | 54 | if Configs.teacher_model.lower() in ["ts-csw", "cnn"]: 55 | Model_Configs = Configs.CNN 56 | model = CNN.BERT_TC.from_pretrained(model_name_or_path, 57 | **{**Model_Configs, "class_num": Configs.class_num, }) 58 | elif Configs.teacher_model.lower() in ["birnn"]: 59 | Model_Configs = Configs.RNN 60 | model = BiRNN.BERT_TC.from_pretrained(model_name_or_path, 61 | **{**Model_Configs, "class_num": Configs.class_num, }) 62 | elif Configs.teacher_model.lower() in ["fcn", "fc"]: 63 | Model_Configs = Configs.FCN 64 | model = FCN.BERT_TC.from_pretrained(model_name_or_path,**{**Model_Configs, "class_num": Configs.class_num, }) 65 | elif Configs.teacher_model.lower() in ["lstmatt"]: 66 | Model_Configs = Configs.LSTMATT 67 | model = LSTMATT.BERT_TC.from_pretrained(model_name_or_path, 68 | **{**Model_Configs, "class_num": Configs.class_num, }) 69 | elif Configs.teacher_model.lower() in ["r-bilstm-c", "r-b-c", "rbc", "rbilstmc"]: 70 | Model_Configs = Configs.RBiLSTMC 71 | model = RBC.BERT_TC.from_pretrained(model_name_or_path, 72 | **{**Model_Configs, "class_num": Configs.class_num, }) 73 | elif Configs.teacher_model.lower() in ["bilstmdense", "bilstm-dense", "bilstm_dense", "bi-lstm-dense"]: 74 | Model_Configs = Configs.BiLSTMDENSE 75 | model = BLSTMDENSE.BERT_TC.from_pretrained(model_name_or_path, 76 | **{**Model_Configs, "class_num": Configs.class_num, }) 77 | elif Configs.teacher_model.lower() in ["sesy"]: 78 | Model_Configs = Configs.SESY 79 | model = SESY.BERT_TC.from_pretrained(model_name_or_path, 80 | **{**Model_Configs, "class_num": Configs.class_num, }) 81 | elif Configs.teacher_model.lower() in ["ft-bert", "finetune-bert", "ftbert", "finetunebert"]: 82 | Model_Configs = Configs.FineTuneBERT 83 | model = MS_TL.MyBert.from_pretrained(model_name_or_path, 84 | **{**Model_Configs, "class_num": Configs.class_num, }) 85 | else: 86 | logger.error("no such model, exit") 87 | exit() 88 | 89 | logger.info("Teacher Model Configs") 90 | logger.info(json.dumps({**{"MODEL_TYPE": Configs.teacher_model}, **Model_Configs, })) 91 | model = model.to(Configs.device) 92 | return model 93 | 94 | 95 | def load_student_model(Configs, VOCAB_SIZE=None, checkpoint=None): 96 | 97 | logger.info("----------------init model-----------------------") 98 | if Configs.student_model.lower() in ["ts-csw", 
"cnn"]: 99 | Model_Configs = Configs.CNN 100 | Model_Configs.vocab_size = VOCAB_SIZE 101 | model = CNN.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 102 | elif Configs.student_model.lower() in ["birnn"]: 103 | Model_Configs = Configs.RNN 104 | Model_Configs.vocab_size = VOCAB_SIZE 105 | model = BiRNN.TC(**{**Model_Configs, "class_num": Configs.class_num}) 106 | elif Configs.student_model.lower() in ["fcn", "fc"]: 107 | Model_Configs = Configs.FCN 108 | Model_Configs.vocab_size = VOCAB_SIZE 109 | model = FCN.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 110 | elif Configs.student_model.lower() in ["lstmatt"]: 111 | Model_Configs = Configs.LSTMATT 112 | Model_Configs.vocab_size = VOCAB_SIZE 113 | model = LSTMATT.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 114 | elif Configs.student_model.lower() in ["r-bilstm-c", "r-b-c", "rbc", "rbilstmc"]: 115 | Model_Configs = Configs.RBiLSTMC 116 | Model_Configs.vocab_size = VOCAB_SIZE 117 | model = RBC.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 118 | elif Configs.student_model.lower() in ["bilstmdense", "bilstm-dense", "bilstm_dense", "bi-lstm-dense"]: 119 | Model_Configs = Configs.BiLSTMDENSE 120 | Model_Configs.vocab_size = VOCAB_SIZE 121 | model = BLSTMDENSE.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 122 | elif Configs.student_model.lower() in ["sesy"]: 123 | Model_Configs = Configs.SESY 124 | Model_Configs.vocab_size = VOCAB_SIZE 125 | model = SESY.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 126 | elif Configs.student_model.lower() in ["ms-cnn", "mscnn", "multistage-cnn"]: 127 | Model_Configs = Configs.MSCNN 128 | Model_Configs.vocab_size = VOCAB_SIZE 129 | model = MS_TL.CNN(**{**Model_Configs, "class_num": Configs.class_num, }) 130 | elif Configs.student_model.lower() in ["ms-rnn", "msrnn", "multistage-rnn","ms-birnn", "msbirnn", "multistage-birnn"]: 131 | Model_Configs = Configs.MSBiRNN 132 | Model_Configs.vocab_size = VOCAB_SIZE 133 | model = MS_TL.BiRNN(**{**Model_Configs, "class_num": Configs.class_num, }) 134 | elif Configs.model.lower() in ["gnn"]: 135 | raise NotImplemented 136 | else: 137 | logger.error("no such model, exit") 138 | exit() 139 | if checkpoint is not None: 140 | logger.info("---------------------loading model from {}------------\n\n".format(checkpoint)) 141 | model = torch.load(os.path.join(checkpoint, "pytorch_model.bin")) 142 | 143 | logger.info("Student Model Configs") 144 | logger.info(json.dumps({**{"MODEL_TYPE": Configs.student_model}, **Model_Configs, })) 145 | model = model.to(Configs.device) 146 | return model 147 | 148 | 149 | def train(teacher_model=None, model=None, Configs=None, tokenizer=None): 150 | train_dataset = load_and_cache_examples(Configs.Dataset, Configs.task_name, tokenizer) 151 | Training_Configs = Configs.Training_with_Processor 152 | Training_Configs.train_batch_size = Training_Configs.per_gpu_train_batch_size * max(1, Training_Configs.n_gpu) 153 | train_sampler = RandomSampler(train_dataset) 154 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=Training_Configs.train_batch_size) 155 | 156 | if teacher_model is not None: 157 | if Training_Configs.teacher_criteration == "KLDivLoss": 158 | teaching_criterion = torch.nn.KLDivLoss(reduction='batchmean') 159 | 160 | if Training_Configs.max_steps > 0: 161 | t_total = Training_Configs.max_steps 162 | Training_Configs.num_train_epochs = Training_Configs.max_steps // (len(train_dataloader) // Training_Configs.gradient_accumulation_steps) + 1 163 
| else: 164 | t_total = len(train_dataloader) // Training_Configs.gradient_accumulation_steps * Training_Configs.num_train_epochs 165 | 166 | num_warmup_steps = int(Training_Configs.warmup_ratio * t_total) 167 | # Prepare optimizer and schedule (linear warmup and decay) 168 | no_decay = ["bias", "LayerNorm.weight"] 169 | optimizer_grouped_parameters = [ 170 | { 171 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 172 | "weight_decay": Training_Configs.weight_decay, 173 | }, 174 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 175 | ] 176 | optimizer = AdamW(optimizer_grouped_parameters, lr=Training_Configs.learning_rate, eps=Training_Configs.adam_epsilon) 177 | scheduler = get_linear_schedule_with_warmup( 178 | optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total 179 | ) 180 | 181 | # Check if saved optimizer or scheduler states exist 182 | if os.path.isfile(os.path.join(Training_Configs.model_name_or_path, "optimizer.pt")) and os.path.isfile( 183 | os.path.join(Training_Configs.model_name_or_path, "scheduler.pt") 184 | ): 185 | # Load in optimizer and scheduler states 186 | optimizer.load_state_dict(torch.load(os.path.join(Training_Configs.model_name_or_path, "optimizer.pt"))) 187 | scheduler.load_state_dict(torch.load(os.path.join(Training_Configs.model_name_or_path, "scheduler.pt"))) 188 | 189 | 190 | # Train! 191 | logger.info("***** Running training *****") 192 | logger.info(" Num examples = %d", len(train_dataset)) 193 | logger.info(" Num Epochs = %d", Training_Configs.num_train_epochs) 194 | logger.info(" Instantaneous batch size per GPU = %d", Training_Configs.per_gpu_train_batch_size) 195 | logger.info( 196 | " Total train batch size (w. 
parallel, distributed & accumulation) = %d", 197 | Training_Configs.train_batch_size 198 | * Training_Configs.gradient_accumulation_steps) 199 | logger.info(" Gradient Accumulation steps = %d", Training_Configs.gradient_accumulation_steps) 200 | logger.info(" Total optimization steps = %d", t_total) 201 | 202 | global_step = 0 203 | epochs_trained = 0 204 | steps_trained_in_current_epoch = 0 205 | 206 | # Check if continuing training from a checkpoint 207 | if os.path.exists(Training_Configs.model_name_or_path): 208 | # set global_step to global_step of last saved checkpoint from model path 209 | try: 210 | global_step = int(Training_Configs.model_name_or_path.split("-")[-1].split("/")[0]) 211 | except ValueError: 212 | global_step = 0 213 | 214 | epochs_trained = global_step // (len(train_dataloader) // Training_Configs.gradient_accumulation_steps) 215 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // Training_Configs.gradient_accumulation_steps) 216 | 217 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 218 | logger.info(" Continuing training from epoch %d", epochs_trained) 219 | logger.info(" Continuing training from global step %d", global_step) 220 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 221 | 222 | tr_loss, logging_loss = 0.0, 0.0 223 | model.zero_grad() 224 | train_iterator = range(epochs_trained, int(Training_Configs.num_train_epochs)) 225 | 226 | set_seed(Configs.seed) # Added here for reproducibility 227 | 228 | best_val_metric = None 229 | for epoch_n in train_iterator: 230 | epoch_iterator = tqdm(train_dataloader, desc=f"Epoch {epoch_n}", disable=False) 231 | for step, batch in enumerate(epoch_iterator): 232 | 233 | # Skip past any already trained steps if resuming training 234 | if steps_trained_in_current_epoch > 0: 235 | steps_trained_in_current_epoch -= 1 236 | continue 237 | 238 | model.train() 239 | batch = tuple(t.to(Configs.device) for t in batch) 240 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2], "labels": batch[3]} 241 | if Configs.task_name == "graph_steganalysis": 242 | inputs = {**inputs,"graph":batch[4]} 243 | 244 | if teacher_model is not None: 245 | student_outputs = model(**inputs) 246 | teacher_outputs = teacher_model(**inputs) 247 | student_loss = student_outputs[0] 248 | student_logits = student_outputs[1] 249 | teacher_logits = teacher_outputs[1] 250 | outputs_S = torch.nn.functional.log_softmax(student_logits / Training_Configs.distil_T, dim=1) # temperature-scaled student log-probabilities 251 | outputs_T = torch.nn.functional.softmax(teacher_logits / Training_Configs.distil_T, dim=1) # temperature-scaled teacher probabilities 252 | teaching_loss = teaching_criterion(outputs_S, outputs_T)*Training_Configs.distil_T*Training_Configs.distil_T # KL soft-target loss, rescaled by T^2 253 | loss = student_loss*(1-Training_Configs.distil_alpha) + teaching_loss*Training_Configs.distil_alpha # blend hard-label loss and distillation loss 254 | else: 255 | outputs = model(**inputs) 256 | loss = outputs[0] 257 | 258 | if Training_Configs.n_gpu > 1: 259 | loss = loss.mean() # mean() to average on multi-gpu parallel training 260 | if Training_Configs.gradient_accumulation_steps > 1: 261 | loss = loss / Training_Configs.gradient_accumulation_steps 262 | 263 | loss.backward() 264 | 265 | tr_loss += loss.item() 266 | if (step + 1) % Training_Configs.gradient_accumulation_steps == 0: 267 | torch.nn.utils.clip_grad_norm_(model.parameters(), Training_Configs.max_grad_norm) 268 | optimizer.step() 269 | scheduler.step() # Update learning rate schedule 270 | model.zero_grad() 271 | global_step += 1 272 | 273 | 
logs = {} 274 | if Training_Configs.logging_steps > 0 and global_step % Training_Configs.logging_steps == 0: 275 | # Log metrics 276 | if Training_Configs.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 277 | results, _ , _ = evaluate( model, tokenizer, Configs, Configs.task_name, use_tqdm=False) 278 | for key, value in results.items(): 279 | eval_key = "eval_{}".format(key) 280 | logs[eval_key] = value 281 | 282 | loss_scalar = (tr_loss - logging_loss) / Training_Configs.logging_steps 283 | learning_rate_scalar = scheduler.get_last_lr()[0] 284 | logs["learning_rate"] = learning_rate_scalar 285 | logs["avg_loss_since_last_log"] = loss_scalar 286 | logging_loss = tr_loss 287 | 288 | logging.info(json.dumps({**logs, **{"step": global_step}})) 289 | 290 | 291 | if ( Training_Configs.eval_and_save_steps > 0 and global_step % Training_Configs.eval_and_save_steps == 0) \ 292 | or (step+1==t_total): 293 | # evaluate 294 | results, _, _ = evaluate(model, tokenizer, Configs, Configs.task_name, use_tqdm=False) 295 | logger.info("------Next Evalset will be loaded from cached file------") 296 | Configs.Dataset.overwrite_cache = False 297 | for key, value in results.items(): 298 | logs[f"eval_{key}"] = value 299 | logger.info(json.dumps({**logs, **{"step": global_step}})) 300 | 301 | # save 302 | if Training_Configs.save_only_best: 303 | output_dirs = [os.path.join(Configs.out_dir, Configs.checkpoint)] 304 | else: 305 | output_dirs = [os.path.join(Configs.out_dir, f"checkpoint-{global_step}")] 306 | curr_val_metric = results[task_metrics[Configs.task_name]] 307 | if best_val_metric is None or curr_val_metric > best_val_metric: 308 | # check if best model so far 309 | logger.info("Congratulations, best model so far!") 310 | best_val_metric = curr_val_metric 311 | 312 | for output_dir in output_dirs: 313 | # in each dir, save model, tokenizer, args, optimizer, scheduler 314 | if not os.path.exists(output_dir): 315 | os.makedirs(output_dir) 316 | model_to_save = ( 317 | model.module if hasattr(model, "module") else model 318 | ) # Take care of distributed/parallel training 319 | logger.info("Saving model checkpoint to %s", output_dir) 320 | 321 | torch.save(model_to_save, os.path.join(output_dir, "pytorch_model.bin")) 322 | torch.save(Configs.state_dict, os.path.join(output_dir, "training_args.bin")) 323 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 324 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 325 | tokenizer.save_pretrained(output_dir) 326 | logger.info("\tSaved model checkpoint to %s", output_dir) 327 | 328 | 329 | if Training_Configs.max_steps > 0 and global_step > Training_Configs.max_steps: 330 | epoch_iterator.close() 331 | break 332 | if Training_Configs.max_steps > 0 and global_step > Training_Configs.max_steps: 333 | # train_iterator.close() 334 | break 335 | 336 | return global_step, tr_loss / global_step 337 | 338 | 339 | def evaluate(model, tokenizer, Configs, task_name, split="dev", prefix="", use_tqdm=True): 340 | Training_Configs = Configs.Training_with_Processor 341 | results = {} 342 | if task_name == "record": 343 | eval_dataset, eval_answers = load_and_cache_examples(Configs.Dataset, task_name, tokenizer, split=split) 344 | else: 345 | eval_dataset = load_and_cache_examples(Configs.Dataset, task_name, tokenizer, split=split) 346 | 347 | if not os.path.exists(Configs.out_dir): 348 | os.makedirs(Configs.out_dir) 349 | 350 | Training_Configs.eval_batch_size = 
Training_Configs.per_gpu_eval_batch_size * max(1, Training_Configs.n_gpu) 351 | # Note that DistributedSampler samples randomly 352 | eval_sampler = SequentialSampler(eval_dataset) 353 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=Training_Configs.eval_batch_size) 354 | 355 | # multi-gpu eval 356 | if Training_Configs.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): 357 | model = torch.nn.DataParallel(model) 358 | 359 | # Eval! 360 | logger.info(f"***** Running evaluation: {prefix} on {task_name} {split} *****") 361 | logger.info(" Num examples = %d", len(eval_dataset)) 362 | logger.info(" Batch size = %d", Training_Configs.eval_batch_size) 363 | eval_loss = 0.0 364 | nb_eval_steps = 0 365 | preds = None 366 | out_label_ids = None 367 | ex_ids = None 368 | eval_dataloader = tqdm(eval_dataloader, desc="Evaluating") if use_tqdm else eval_dataloader 369 | for batch in eval_dataloader: 370 | model.eval() 371 | batch = tuple(t.to(Configs.device) for t in batch) 372 | guids = batch[-1] 373 | 374 | max_seq_length = batch[0].size(1) 375 | if Training_Configs.use_fixed_seq_length: # no dynamic sequence length 376 | batch_seq_length = max_seq_length 377 | else: 378 | batch_seq_length = torch.max(batch[-2], 0)[0].item() 379 | 380 | if batch_seq_length < max_seq_length: 381 | inputs = {"input_ids": batch[0][:, :batch_seq_length].contiguous(), 382 | "attention_mask": batch[1][:, :batch_seq_length].contiguous(), 383 | "token_type_ids":batch[2][:, :batch_seq_length].contiguous(), 384 | "labels": batch[3]} 385 | # inputs["token_type_ids"] = ( 386 | # batch[2][:, :batch_seq_length].contiguous() if Configs.model_type in ["bert", "xlnet", "albert"] else None 387 | # ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 388 | else: 389 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids":batch[2],"labels": batch[3]} 390 | 391 | if Configs.task_name == "graph_steganalysis": 392 | inputs = {**inputs,"graph":batch[4]} 393 | # inputs["token_type_ids"] = ( 394 | # batch[2][:, :batch_seq_length].contiguous() if Configs.model_type in ["bert", "xlnet", "albert"] else None 395 | # ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 396 | 397 | with torch.no_grad(): 398 | outputs = model(**inputs) 399 | tmp_eval_loss, logits = outputs[:2] 400 | 401 | eval_loss += tmp_eval_loss.mean().item() 402 | nb_eval_steps += 1 403 | if preds is None: 404 | preds = logits.detach().cpu().numpy() 405 | out_label_ids = inputs["labels"].detach().cpu().numpy() 406 | ex_ids = [guids.detach().cpu().numpy()] 407 | else: 408 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 409 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) 410 | ex_ids.append(guids.detach().cpu().numpy()) 411 | 412 | ex_ids = np.concatenate(ex_ids, axis=0) 413 | eval_loss = eval_loss / nb_eval_steps 414 | 415 | preds = np.argmax(preds, axis=1) 416 | 417 | result = utils.compute_metrics(task_name, preds, out_label_ids,) 418 | results.update(result) 419 | if prefix == "": 420 | return results, preds, ex_ids 421 | output_eval_file = os.path.join(Configs.out_dir, prefix, "eval_results.txt") 422 | with open(output_eval_file, "w") as writer: 423 | logger.info(f"***** {split} results: {prefix} *****") 424 | for key in sorted(result.keys()): 425 | logger.info(" %s = %s", key, str(result[key])) 426 | writer.write("%s = %s\n" % (key, str(result[key]))) 427 | 428 | return results, preds, ex_ids 429 | 430 | 431 | def 
load_and_cache_examples(Dataset_Configs, task, tokenizer, split="train"): 432 | if task == "steganalysis": 433 | from processors.process import SteganalysisProcessor as DataProcessor 434 | elif task == "graph_steganalysis": 435 | from processors.graph_process import GraphSteganalysisProcessor as DataProcessor 436 | 437 | processor = DataProcessor(tokenizer) 438 | # Load data features from cache or dataset file 439 | cached_tensors_file = os.path.join( 440 | Dataset_Configs.csv_dir, 441 | "tensors_{}_{}_{}".format( 442 | split, time_stamp, str(task), 443 | ), 444 | ) 445 | if os.path.exists(cached_tensors_file) and not Dataset_Configs.overwrite_cache: 446 | logger.info("Loading tensors from cached file %s", cached_tensors_file) 447 | start_time = time.time() 448 | dataset = torch.load(cached_tensors_file) 449 | logger.info("\tFinished loading tensors") 450 | logger.info(f"\tin {time.time() - start_time}s") 451 | 452 | else: 453 | # no cached tensors, process data from scratch 454 | logger.info("Creating features from dataset file at %s", Dataset_Configs.csv_dir) 455 | if split == "train": 456 | get_examples = processor.get_train_examples 457 | elif split == "dev": 458 | get_examples = processor.get_dev_examples 459 | elif split == "test": 460 | get_examples = processor.get_test_examples 461 | 462 | examples = get_examples(Dataset_Configs.csv_dir) 463 | dataset = processor.convert_examples_to_features(examples,) 464 | logger.info("\tFinished creating features") 465 | 466 | logger.info("\tFinished converting features into tensors") 467 | if Dataset_Configs.save_cache: 468 | logger.info("Saving features into cached file %s", cached_tensors_file) 469 | torch.save(dataset, cached_tensors_file) 470 | logger.info("\tFinished saving tensors") 471 | 472 | if task == "record" and split in ["dev", "test"]: 473 | answers = processor.get_answers(Dataset_Configs.csv_dir, split) 474 | return dataset, answers 475 | else: 476 | return dataset 477 | 478 | 479 | def main(Configs): 480 | Dataset_Configs = Configs.Dataset 481 | Configs.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 482 | os.makedirs(Configs.out_dir,exist_ok=True) 483 | set_seed(Configs.seed) 484 | 485 | # Setup logging 486 | logging.basicConfig( 487 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 488 | datefmt="%m/%d/%Y %H:%M:%S", 489 | level=logging.INFO, 490 | ) 491 | handler = logging.FileHandler(os.path.join(Configs.out_dir,time_stamp+"_log")) 492 | handler.setLevel(logging.INFO) 493 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 494 | handler.setFormatter(formatter) 495 | logger.addHandler(handler) 496 | 497 | logger.info("--------------Main Configs-------------------") 498 | logger.info(Configs) 499 | 500 | logger.info("--------------loading data-------------------") 501 | logger.info("Dataset Configs") 502 | logger.info(json.dumps(Dataset_Configs)) 503 | 504 | Configs.model_name_or_path = Configs.Training_with_Processor.model_name_or_path 505 | logger.info("\tload plm name or path from Training_with_Processor args") 506 | 507 | logger.info("-------------------------------------------------------------------------------------------------------") 508 | # prepare data 509 | if Configs.use_processor: 510 | # translate txt into csv 511 | if not Dataset_Configs.resplit and os.path.exists(Dataset_Configs.csv_dir) and \ 512 | os.path.exists(os.path.join(Dataset_Configs.csv_dir,"train.csv")) and \ 513 | os.path.exists(os.path.join(Dataset_Configs.csv_dir,"val.csv")) 
and \ 514 | os.path.exists(os.path.join(Dataset_Configs.csv_dir,"test.csv")): 515 | pass 516 | else: 517 | os.makedirs(Dataset_Configs.csv_dir, exist_ok=True) 518 | with open(Dataset_Configs.cover_file, 'r', encoding='utf-8') as f: 519 | covers = f.read().split("\n") 520 | covers = list(filter(lambda x: x not in ['', None], covers)) 521 | random.shuffle(covers) 522 | with open(Dataset_Configs.stego_file, 'r', encoding='utf-8') as f: 523 | stegos = f.read().split("\n") 524 | stegos = list(filter(lambda x: x not in ['', None], stegos)) 525 | random.shuffle(stegos) 526 | texts = covers+stegos 527 | labels = [0]*len(covers) + [1]*len(stegos) 528 | 529 | # val_ratio = (1-Dataset_Configs.split_ratio)/Dataset_Configs.split_ratio 530 | train_texts,test_texts,train_labels,test_labels = train_test_split(texts,labels,train_size=Dataset_Configs.split_ratio) 531 | train_texts,val_texts,train_labels,val_labels = train_test_split(train_texts, train_labels, train_size=Dataset_Configs.split_ratio) 532 | def write2file(X, Y, filename): 533 | with open(filename, "w", encoding="utf-8") as f: 534 | writer = csv.writer(f) 535 | writer.writerow(["text", "label"]) 536 | for x, y in zip(X, Y): 537 | writer.writerow([x, y]) 538 | write2file(train_texts,train_labels, os.path.join(Dataset_Configs.csv_dir,"train.csv")) 539 | write2file(val_texts, val_labels, os.path.join(Dataset_Configs.csv_dir, "val.csv")) 540 | write2file(test_texts, test_labels, os.path.join(Dataset_Configs.csv_dir, "test.csv")) 541 | tokenizer = AutoTokenizer.from_pretrained(Configs.model_name_or_path,) 542 | VOCAB_SIZE = tokenizer.vocab_size 543 | 544 | else: 545 | # not recommended 546 | raise NotImplementedError 547 | 548 | teacher_model = load_teacher_model(Configs, VOCAB_SIZE=VOCAB_SIZE) 549 | student_model = load_student_model(Configs, VOCAB_SIZE=VOCAB_SIZE) 550 | 551 | logger.info("--------------start training--------------------") 552 | 553 | if Configs.use_processor: 554 | # train_dataset = load_and_cache_examples(Dataset_Configs, Configs.task_name, tokenizer) # , evaluate=False) 555 | global_step, tr_loss = train(teacher_model, student_model, Configs, tokenizer) 556 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 557 | Training_Configs = Configs.Training_with_Processor 558 | 559 | checkpoints = [os.path.join(Configs.out_dir, Configs.checkpoint)] 560 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 561 | 562 | for checkpoint in checkpoints: 563 | # tokenizer = AutoTokenizer.from_pretrained(checkpoint, do_lower_case=Training_Configs.do_lower_case) 564 | prefix = checkpoint.split("/")[-1] 565 | student_model = load_student_model(Configs, VOCAB_SIZE=tokenizer.vocab_size,checkpoint=checkpoint) 566 | # if not Configs.use_plm: 567 | # model = torch.load(os.path.join(checkpoint, "pytorch_model.bin")) 568 | # logger.info("--------------load model without pretrained language model-----------------") 569 | # else: 570 | # logger.info("--------------load model with pretrained language model--------------------") 571 | result, preds, ex_ids = evaluate(student_model, tokenizer, Configs, Configs.task_name, split="test", prefix=prefix) 572 | test_acc = result["accuracy"] 573 | test_precision = result["precision"] 574 | test_recall = result["recall"] 575 | test_Fscore = result["f1_score"] 576 | 577 | else: 578 | raise NotImplementedError 579 | 580 | record_file = Configs.record_file if Configs.record_file is not None else "record.txt" 581 | result_path = os.path.join(Configs.out_dir, time_stamp+"----"+record_file) 582 | 
with open(result_path, "w", encoding="utf-8") as f: 583 | f.write("test phase:\naccuracy\t{:.4f}\nprecision\t{:.4f}\nrecall\t{:.4f}\nf1_score\t{:.4f}" 584 | .format(test_acc*100,test_precision*100,test_recall*100,test_Fscore*100)) 585 | 586 | 587 | if __name__ == '__main__': 588 | import argparse 589 | parser = argparse.ArgumentParser(description="argument for generation") 590 | parser.add_argument("--config_path", type=str, default="./configs/stage2.json") 591 | args = parser.parse_args() 592 | Configs = utils.Config(args.config_path).get_configs() 593 | os.environ["CUDA_VISIBLE_DEVICES"] = Configs.gpuid 594 | main(Configs) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import random 3 | import dataset 4 | import numpy as np 5 | import logging 6 | import os 7 | import json 8 | import time 9 | import csv 10 | from tqdm import tqdm 11 | 12 | import torch 13 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 14 | 15 | from sklearn.model_selection import train_test_split 16 | from models import birnn as BiRNN,cnn as CNN,lstmatt as LSTMATT,fcn as FCN, r_bilstm_c as RBC,\ 17 | bilstm_dense as BLSTMDENSE,sesy as SESY 18 | 19 | 20 | from transformers import ( 21 | AdamW, 22 | AutoTokenizer, 23 | get_linear_schedule_with_warmup, 24 | ) 25 | 26 | try: 27 | from torch.utils.tensorboard import SummaryWriter 28 | except ImportError: 29 | from tensorboardX import SummaryWriter 30 | 31 | task_metrics = {"steganalysis" : "accuracy", 32 | "graph_steganalysis" : "accuracy",} 33 | 34 | 35 | logger = logging.getLogger(__name__) 36 | time_stamp = "-".join(time.ctime().split()) 37 | 38 | 39 | def set_seed(seed): 40 | random.seed(seed) 41 | np.random.seed(seed) 42 | torch.manual_seed(seed) 43 | torch.cuda.manual_seed_all(seed) 44 | 45 | 46 | def load_model(Configs, VOCAB_SIZE=None, checkpoint=None): 47 | # set model 48 | 49 | logger.info("----------------init model-----------------------") 50 | 51 | 52 | if not Configs.use_plm: 53 | if Configs.model.lower() in ["ts-csw", "cnn"]: 54 | Model_Configs = Configs.CNN 55 | Model_Configs.vocab_size = VOCAB_SIZE 56 | model = CNN.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 57 | elif Configs.model.lower() in ["birnn"]: 58 | Model_Configs = Configs.RNN 59 | Model_Configs.vocab_size = VOCAB_SIZE 60 | model = BiRNN.TC(**{**Model_Configs, "class_num": Configs.class_num}) 61 | elif Configs.model.lower() in ["fcn", "fc"]: 62 | Model_Configs = Configs.FCN 63 | Model_Configs.vocab_size = VOCAB_SIZE 64 | model = FCN.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 65 | elif Configs.model.lower() in ["lstmatt"]: 66 | Model_Configs = Configs.LSTMATT 67 | Model_Configs.vocab_size = VOCAB_SIZE 68 | model = LSTMATT.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 69 | elif Configs.model.lower() in ["r-bilstm-c", "r-b-c", "rbc", "rbilstmc"]: 70 | Model_Configs = Configs.RBiLSTMC 71 | Model_Configs.vocab_size = VOCAB_SIZE 72 | model = RBC.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 73 | elif Configs.model.lower() in ["bilstmdense", "bilstm-dense", "bilstm_dense", "bi-lstm-dense"]: 74 | Model_Configs = Configs.BiLSTMDENSE 75 | Model_Configs.vocab_size = VOCAB_SIZE 76 | model = BLSTMDENSE.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 77 | elif Configs.model.lower() in ["sesy"]: 78 | Model_Configs = Configs.SESY 79 | Model_Configs.vocab_size = VOCAB_SIZE 80 | model 
= SESY.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 81 | elif Configs.model.lower() in ["gnn"]: 82 | Model_Configs = Configs.GNN 83 | Model_Configs.vocab_size = VOCAB_SIZE 84 | model = GNN.TC(**{**Model_Configs, "class_num": Configs.class_num, }) 85 | else: 86 | logger.error("no such model, exit") 87 | exit() 88 | if checkpoint is not None: 89 | logger.info("---------------------loading model from {}------------\n\n".format(checkpoint)) 90 | model = torch.load(os.path.join(checkpoint, "pytorch_model.bin")) 91 | else: 92 | if checkpoint is not None: 93 | logger.info("---------------------loading model from {}------------\n\n".format(checkpoint)) 94 | model_name_or_path = checkpoint 95 | else: 96 | logger.info("-------------loading pretrained language model from huggingface--------------\n\n") 97 | model_name_or_path = Configs.model_name_or_path 98 | 99 | if Configs.model.lower() in ["ts-csw", "cnn"]: 100 | Model_Configs = Configs.CNN 101 | model = CNN.BERT_TC.from_pretrained(model_name_or_path, 102 | **{**Model_Configs, "class_num": Configs.class_num, }) 103 | elif Configs.model.lower() in ["birnn"]: 104 | Model_Configs = Configs.RNN 105 | model = BiRNN.BERT_TC.from_pretrained(model_name_or_path, 106 | **{**Model_Configs, "class_num": Configs.class_num, }) 107 | elif Configs.model.lower() in ["fcn", "fc"]: 108 | Model_Configs = Configs.FCN 109 | model = FCN.BERT_TC.from_pretrained(model_name_or_path,**{**Model_Configs, "class_num": Configs.class_num, }) 110 | elif Configs.model.lower() in ["lstmatt"]: 111 | Model_Configs = Configs.LSTMATT 112 | model = LSTMATT.BERT_TC.from_pretrained(model_name_or_path, 113 | **{**Model_Configs, "class_num": Configs.class_num, }) 114 | elif Configs.model.lower() in ["r-bilstm-c", "r-b-c", "rbc", "rbilstmc"]: 115 | Model_Configs = Configs.RBiLSTMC 116 | model = RBC.BERT_TC.from_pretrained(model_name_or_path, 117 | **{**Model_Configs, "class_num": Configs.class_num, }) 118 | elif Configs.model.lower() in ["bilstmdense", "bilstm-dense", "bilstm_dense", "bi-lstm-dense"]: 119 | Model_Configs = Configs.BiLSTMDENSE 120 | model = BLSTMDENSE.BERT_TC.from_pretrained(model_name_or_path, 121 | **{**Model_Configs, "class_num": Configs.class_num, }) 122 | elif Configs.model.lower() in ["sesy"]: 123 | Model_Configs = Configs.SESY 124 | model = SESY.BERT_TC.from_pretrained(model_name_or_path, 125 | **{**Model_Configs, "class_num": Configs.class_num, }) 126 | else: 127 | logger.error("no such model, exit") 128 | exit() 129 | 130 | 131 | logger.info("Model Configs") 132 | logger.info(json.dumps({**{"MODEL_TYPE": Configs.model}, **Model_Configs, })) 133 | model = model.to(Configs.device) 134 | return model 135 | 136 | 137 | def train_with_helper(data_helper,model,Configs,): 138 | os.makedirs(os.path.join(Configs.out_dir, Configs.checkpoint),exist_ok=True) 139 | checkpoint = os.path.join(Configs.out_dir, Configs.checkpoint, "maxacc.pth") 140 | Training_Configs = Configs.Training 141 | logger.info("Training Configs") 142 | logger.info(Training_Configs) 143 | logger.info("-----------------------------------------------") 144 | t_total = data_helper.train_num// Training_Configs.batch_size * Training_Configs.epoch 145 | num_warmup_steps = int(Training_Configs.warmup_ratio * t_total) 146 | no_decay = ["bias", "LayerNorm.weight"] 147 | optimizer_grouped_parameters = [ 148 | { 149 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 150 | "weight_decay": Training_Configs.weight_decay, 151 | }, 152 | {"params": [p for n, p in 
model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 153 | ] 154 | optimizer = AdamW(optimizer_grouped_parameters, lr=Training_Configs.learning_rate, eps=Training_Configs.adam_epsilon) 155 | scheduler = get_linear_schedule_with_warmup( 156 | optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total 157 | ) 158 | # optimizer = optim.SGD(model.parameters(), lr=Training_Configs.learning_rate, momentum=0.9) 159 | early_stop = 0 160 | best_acc = 0 161 | best_test_loss = 1000 162 | best_precison = 0 163 | best_recall = 0 164 | best_F1 = 0 165 | 166 | logger.info("------------number of instance-------------") 167 | logger.info(format(f"train\t{data_helper.train_num}")) 168 | logger.info(format(f"val \t{data_helper.val_num}")) 169 | logger.info(format(f"test \t{data_helper.test_num}")) 170 | 171 | for epoch in range(Training_Configs.epoch): 172 | model.train() 173 | generator_train = data_helper.train_generator(Training_Configs.batch_size) 174 | train_loss = [] 175 | train_acc = [] 176 | while True: 177 | try: 178 | text, label = generator_train.__next__() 179 | except: 180 | break 181 | optimizer.zero_grad() 182 | loss,y = model(torch.from_numpy(text).long().to(Configs.device), torch.from_numpy(label).long().to(Configs.device)) 183 | # loss = criteration(y, torch.from_numpy(label).long().to(Training_configs.device)) 184 | loss.backward() 185 | optimizer.step() 186 | scheduler.step() 187 | train_loss.append(loss.item()) 188 | y = y.cpu().detach().numpy() 189 | train_acc += [1 if np.argmax(y[i]) == label[i] else 0 for i in range(len(y))] 190 | 191 | val_loss, val_acc, val_precision, val_recall, val_Fscore = eval_with_helper(data_helper,model,Configs) 192 | 193 | logger.info( 194 | "epoch {:d}, training loss {:.4f}, train acc {:.4f}, val loss {:.4f}, val acc {:.4f}, val pre {:.4f},val recall {:.4f},val F1 {:.4f}" 195 | .format(epoch + 1, np.mean(train_loss), np.mean(train_acc), val_loss, val_acc, val_precision, val_recall, 196 | val_Fscore)) 197 | 198 | 199 | if val_acc > best_acc: 200 | state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch, "scheduler":scheduler.state_dict(), 201 | "training_args":Configs.state_dict,"val loss": val_loss, "val acc": np.mean(val_acc)} 202 | torch.save(state, checkpoint) 203 | best_test_loss = val_loss 204 | best_acc = val_acc 205 | best_precison = val_precision 206 | best_recall = val_recall 207 | best_F1 = val_Fscore 208 | early_stop = 0 209 | else: 210 | early_stop += 1 211 | if early_stop >= Training_Configs.early_stop : 212 | break 213 | 214 | logger.info("--------------start calculate metrics--------------") 215 | 216 | state = torch.load(checkpoint) 217 | optimizer.load_state_dict(state["optimizer"]) 218 | scheduler.load_state_dict(state["scheduler"]) 219 | model.load_state_dict(state["model"]) 220 | test_loss, test_acc, test_precision, test_recall, test_Fscore = eval_with_helper(data_helper,model,Configs,"test") 221 | logger.info('val: loss: {:.4f}, acc: {:.4f}, pre {:.4f}, recall {:.4f}, F1 {:.4f}'.format(best_test_loss, best_acc, 222 | best_precison, best_recall, best_F1)) 223 | logger.info( 224 | "test: loss {:.4f}, acc {:.4f}, pre {:.4f}, recall {:.4f}, F1 {:.4f}".format(test_loss, test_acc, test_precision, 225 | test_recall, test_Fscore)) 226 | return test_acc, test_precision, test_recall, test_Fscore 227 | 228 | 229 | def eval_with_helper(data_helper,model, Configs, eval_or_test="eval"): 230 | Training_Configs = Configs.Training 231 | model.eval() 232 | generator = 
data_helper.val_generator(Training_Configs.batch_size) if eval_or_test == "eval" \ 233 | else data_helper.test_generator(Training_Configs.batch_size) 234 | test_loss = 0 235 | test_acc = [] 236 | test_tp = [] 237 | tfn = [] 238 | tpfn = [] 239 | length_sum = 0 240 | 241 | while True: 242 | with torch.no_grad(): 243 | try: 244 | text, label = generator.__next__() 245 | except: 246 | break 247 | loss,y = model(torch.from_numpy(text).long().to(Configs.device), 248 | torch.from_numpy(label).long().to(Configs.device)) 249 | # loss = criteration(y, torch.from_numpy(label).long().to(Training_configs.device)) 250 | loss = loss.cpu().numpy() 251 | test_loss += loss * len(text) 252 | length_sum += len(text) 253 | y = y.cpu().numpy() 254 | label_pred = np.argmax(y, axis=-1) 255 | test_acc += [1 if np.argmax(y[i]) == label[i] else 0 for i in range(len(y))] 256 | test_tp += [1 if np.argmax(y[i]) == label[i] and label[i] == 1 else 0 for i in range(len(y))] 257 | tfn += [1 if np.argmax(y[i]) == 1 else 0 for i in range(len(y))] 258 | tpfn += [1 if label[i] == 1 else 0 for i in range(len(y))] 259 | 260 | test_loss = test_loss / length_sum 261 | acc = np.mean(test_acc) 262 | tpsum = np.sum(test_tp) 263 | test_precision = tpsum / (np.sum(tfn) + 1e-5) 264 | test_recall = tpsum / np.sum(tpfn) 265 | test_Fscore = 2 * test_precision * test_recall / (test_recall + test_precision + 1e-10) 266 | return test_loss, acc, test_precision, test_recall, test_Fscore 267 | 268 | 269 | def train( model, Configs, tokenizer): 270 | train_dataset = load_and_cache_examples(Configs.Dataset, Configs.task_name, tokenizer) 271 | Training_Configs = Configs.Training_with_Processor 272 | Training_Configs.train_batch_size = Training_Configs.per_gpu_train_batch_size * max(1, Training_Configs.n_gpu) 273 | train_sampler = RandomSampler(train_dataset) 274 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=Training_Configs.train_batch_size) 275 | 276 | if Training_Configs.max_steps > 0: 277 | t_total = Training_Configs.max_steps 278 | Training_Configs.num_train_epochs = Training_Configs.max_steps // (len(train_dataloader) // Training_Configs.gradient_accumulation_steps) + 1 279 | else: 280 | t_total = len(train_dataloader) // Training_Configs.gradient_accumulation_steps * Training_Configs.num_train_epochs 281 | 282 | num_warmup_steps = int(Training_Configs.warmup_ratio * t_total) 283 | # Prepare optimizer and schedule (linear warmup and decay) 284 | no_decay = ["bias", "LayerNorm.weight"] 285 | optimizer_grouped_parameters = [ 286 | { 287 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 288 | "weight_decay": Training_Configs.weight_decay, 289 | }, 290 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 291 | ] 292 | optimizer = AdamW(optimizer_grouped_parameters, lr=Training_Configs.learning_rate, eps=Training_Configs.adam_epsilon) 293 | scheduler = get_linear_schedule_with_warmup( 294 | optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total 295 | ) 296 | 297 | # Check if saved optimizer or scheduler states exist 298 | if os.path.isfile(os.path.join(Training_Configs.model_name_or_path, "optimizer.pt")) and os.path.isfile( 299 | os.path.join(Training_Configs.model_name_or_path, "scheduler.pt") 300 | ): 301 | # Load in optimizer and scheduler states 302 | optimizer.load_state_dict(torch.load(os.path.join(Training_Configs.model_name_or_path, "optimizer.pt"))) 303 | 
scheduler.load_state_dict(torch.load(os.path.join(Training_Configs.model_name_or_path, "scheduler.pt"))) 304 | 305 | 306 | # Train! 307 | logger.info("***** Running training *****") 308 | logger.info(" Num examples = %d", len(train_dataset)) 309 | logger.info(" Num Epochs = %d", Training_Configs.num_train_epochs) 310 | logger.info(" Instantaneous batch size per GPU = %d", Training_Configs.per_gpu_train_batch_size) 311 | logger.info( 312 | " Total train batch size (w. parallel, distributed & accumulation) = %d", 313 | Training_Configs.train_batch_size 314 | * Training_Configs.gradient_accumulation_steps) 315 | logger.info(" Gradient Accumulation steps = %d", Training_Configs.gradient_accumulation_steps) 316 | logger.info(" Total optimization steps = %d", t_total) 317 | 318 | global_step = 0 319 | epochs_trained = 0 320 | steps_trained_in_current_epoch = 0 321 | 322 | # Check if continuing training from a checkpoint 323 | if os.path.exists(Training_Configs.model_name_or_path): 324 | # set global_step to global_step of last saved checkpoint from model path 325 | try: 326 | global_step = int(Training_Configs.model_name_or_path.split("-")[-1].split("/")[0]) 327 | except ValueError: 328 | global_step = 0 329 | 330 | epochs_trained = global_step // (len(train_dataloader) // Training_Configs.gradient_accumulation_steps) 331 | steps_trained_in_current_epoch = global_step % (len(train_dataloader) // Training_Configs.gradient_accumulation_steps) 332 | 333 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 334 | logger.info(" Continuing training from epoch %d", epochs_trained) 335 | logger.info(" Continuing training from global step %d", global_step) 336 | logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 337 | 338 | tr_loss, logging_loss = 0.0, 0.0 339 | model.zero_grad() 340 | train_iterator = range(epochs_trained, int(Training_Configs.num_train_epochs)) 341 | 342 | set_seed(Configs.seed) # Added here for reproducibility 343 | 344 | best_val_metric = None 345 | for epoch_n in train_iterator: 346 | epoch_iterator = tqdm(train_dataloader, desc=f"Epoch {epoch_n}", disable=False) 347 | for step, batch in enumerate(epoch_iterator): 348 | 349 | # Skip past any already trained steps if resuming training 350 | if steps_trained_in_current_epoch > 0: 351 | steps_trained_in_current_epoch -= 1 352 | continue 353 | 354 | model.train() 355 | batch = tuple(t.to(Configs.device) for t in batch) 356 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2], "labels": batch[3]} 357 | if Configs.task_name == "graph_steganalysis": 358 | inputs = {**inputs,"graph":batch[4]} 359 | 360 | outputs = model(**inputs) 361 | loss = outputs[0] 362 | 363 | if Training_Configs.n_gpu > 1: 364 | loss = loss.mean() # mean() to average on multi-gpu parallel training 365 | if Training_Configs.gradient_accumulation_steps > 1: 366 | loss = loss / Training_Configs.gradient_accumulation_steps 367 | 368 | loss.backward() 369 | 370 | tr_loss += loss.item() 371 | if (step + 1) % Training_Configs.gradient_accumulation_steps == 0: 372 | torch.nn.utils.clip_grad_norm_(model.parameters(), Training_Configs.max_grad_norm) 373 | optimizer.step() 374 | scheduler.step() # Update learning rate schedule 375 | model.zero_grad() 376 | global_step += 1 377 | 378 | logs = {} 379 | if Training_Configs.logging_steps > 0 and global_step % Training_Configs.logging_steps == 0: 380 | # Log metrics 381 | if Training_Configs.evaluate_during_training: # Only evaluate when single 
GPU otherwise metrics may not average well 382 | results, _ , _ = evaluate(model, tokenizer, Configs, Configs.task_name, use_tqdm=False, split="train") 383 | for key, value in results.items(): 384 | eval_key = "Train_{}".format(key) 385 | logs[eval_key] = value 386 | 387 | loss_scalar = (tr_loss - logging_loss) / Training_Configs.logging_steps 388 | learning_rate_scalar = scheduler.get_last_lr()[0] 389 | logs["learning_rate"] = learning_rate_scalar 390 | logs["avg_loss_since_last_log"] = loss_scalar 391 | logging_loss = tr_loss 392 | 393 | logging.info(json.dumps({**logs, **{"step": global_step}})) 394 | 395 | 396 | if ( Training_Configs.eval_and_save_steps > 0 and global_step % Training_Configs.eval_and_save_steps == 0) \ 397 | or (step+1==t_total): 398 | # evaluate 399 | results, _, _ = evaluate(model, tokenizer, Configs, Configs.task_name, use_tqdm=False) 400 | # logger.info("------Next Evalset will be loaded from cached file------") 401 | Configs.Dataset.overwrite_cache = False 402 | for key, value in results.items(): 403 | logs[f"eval_{key}"] = value 404 | logger.info(json.dumps({**logs, **{"step": global_step}})) 405 | 406 | # save 407 | if Training_Configs.save_only_best: 408 | output_dirs = [os.path.join(Configs.out_dir, Configs.checkpoint)] 409 | else: 410 | output_dirs = [os.path.join(Configs.out_dir, f"checkpoint-{global_step}")] 411 | curr_val_metric = results[task_metrics[Configs.task_name]] 412 | if best_val_metric is None or curr_val_metric > best_val_metric: 413 | # check if best model so far 414 | logger.info("Congratulations, best model so far!") 415 | best_val_metric = curr_val_metric 416 | 417 | for output_dir in output_dirs: 418 | # in each dir, save model, tokenizer, args, optimizer, scheduler 419 | if not os.path.exists(output_dir): 420 | os.makedirs(output_dir) 421 | model_to_save = ( 422 | model.module if hasattr(model, "module") else model 423 | ) # Take care of distributed/parallel training 424 | logger.info("Saving model checkpoint to %s", output_dir) 425 | if Configs.use_plm: 426 | model_to_save.save_pretrained(output_dir) 427 | else: 428 | torch.save(model_to_save, os.path.join(output_dir, "pytorch_model.bin")) 429 | torch.save(Configs.state_dict, os.path.join(output_dir, "training_args.bin")) 430 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 431 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 432 | tokenizer.save_pretrained(output_dir) 433 | logger.info("\tSaved model checkpoint to %s", output_dir) 434 | 435 | 436 | if Training_Configs.max_steps > 0 and global_step > Training_Configs.max_steps: 437 | epoch_iterator.close() 438 | break 439 | if Training_Configs.max_steps > 0 and global_step > Training_Configs.max_steps: 440 | # train_iterator.close() 441 | break 442 | 443 | return global_step, tr_loss / global_step 444 | 445 | 446 | def evaluate(model, tokenizer, Configs, task_name, split="dev", prefix="", use_tqdm=True): 447 | Training_Configs = Configs.Training_with_Processor 448 | results = {} 449 | if task_name == "record": 450 | eval_dataset, eval_answers = load_and_cache_examples(Configs.Dataset, task_name, tokenizer, split=split) 451 | else: 452 | eval_dataset = load_and_cache_examples(Configs.Dataset, task_name, tokenizer, split=split) 453 | 454 | if not os.path.exists(Configs.out_dir): 455 | os.makedirs(Configs.out_dir) 456 | 457 | Training_Configs.eval_batch_size = Training_Configs.per_gpu_eval_batch_size * max(1, Training_Configs.n_gpu) 458 | # Note that DistributedSampler samples 
randomly 459 | eval_sampler = SequentialSampler(eval_dataset) 460 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=Training_Configs.eval_batch_size) 461 | 462 | # multi-gpu eval 463 | if Training_Configs.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): 464 | model = torch.nn.DataParallel(model) 465 | 466 | # Eval! 467 | logger.info(f"***** Running evaluation: {prefix} on {task_name} {split} *****") 468 | logger.info("Num examples = %d", len(eval_dataset)) 469 | logger.info("Batch size = %d", Training_Configs.eval_batch_size) 470 | eval_loss = 0.0 471 | nb_eval_steps = 0 472 | preds = None 473 | out_label_ids = None 474 | ex_ids = None 475 | eval_dataloader = tqdm(eval_dataloader, desc="Evaluating") if use_tqdm else eval_dataloader 476 | for batch in eval_dataloader: 477 | model.eval() 478 | batch = tuple(t.to(Configs.device) for t in batch) 479 | guids = batch[-1] 480 | 481 | max_seq_length = batch[0].size(1) 482 | if Training_Configs.use_fixed_seq_length: # no dynamic sequence length 483 | batch_seq_length = max_seq_length 484 | else: 485 | batch_seq_length = torch.max(batch[-2], 0)[0].item() 486 | 487 | if batch_seq_length < max_seq_length: 488 | inputs = {"input_ids": batch[0][:, :batch_seq_length].contiguous(), 489 | "attention_mask": batch[1][:, :batch_seq_length].contiguous(), 490 | "token_type_ids":batch[2][:, :batch_seq_length].contiguous(), 491 | "labels": batch[3]} 492 | # inputs["token_type_ids"] = ( 493 | # batch[2][:, :batch_seq_length].contiguous() if Configs.model_type in ["bert", "xlnet", "albert"] else None 494 | # ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 495 | else: 496 | inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids":batch[2],"labels": batch[3]} 497 | 498 | if Configs.task_name == "graph_steganalysis": 499 | inputs = {**inputs,"graph":batch[4]} 500 | # inputs["token_type_ids"] = ( 501 | # batch[2][:, :batch_seq_length].contiguous() if Configs.model_type in ["bert", "xlnet", "albert"] else None 502 | # ) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids 503 | 504 | with torch.no_grad(): 505 | outputs = model(**inputs) 506 | tmp_eval_loss, logits = outputs[:2] 507 | 508 | eval_loss += tmp_eval_loss.mean().item() 509 | nb_eval_steps += 1 510 | if preds is None: 511 | preds = logits.detach().cpu().numpy() 512 | out_label_ids = inputs["labels"].detach().cpu().numpy() 513 | ex_ids = [guids.detach().cpu().numpy()] 514 | else: 515 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 516 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) 517 | ex_ids.append(guids.detach().cpu().numpy()) 518 | 519 | ex_ids = np.concatenate(ex_ids, axis=0) 520 | eval_loss = eval_loss / nb_eval_steps 521 | 522 | preds = np.argmax(preds, axis=1) 523 | 524 | result = utils.compute_metrics(task_name, preds, out_label_ids,) 525 | results.update(result) 526 | if prefix == "": 527 | return results, preds, ex_ids 528 | output_eval_file = os.path.join(Configs.out_dir, prefix, "eval_results.txt") 529 | with open(output_eval_file, "w") as writer: 530 | logger.info(f"***** {split} results: {prefix} *****") 531 | for key in sorted(result.keys()): 532 | logger.info(" %s = %s", key, str(result[key])) 533 | writer.write("%s = %s\n" % (key, str(result[key]))) 534 | 535 | return results, preds, ex_ids 536 | 537 | 538 | def load_and_cache_examples(Dataset_Configs, task, tokenizer, split="train"): 539 | if task == "steganalysis": 540 | from 
processors.process import SteganalysisProcessor as DataProcessor 541 | elif task == "graph_steganalysis": 542 | from processors.graph_process import GraphSteganalysisProcessor as DataProcessor 543 | 544 | processor = DataProcessor(tokenizer) 545 | # Load data features from cache or dataset file 546 | cached_tensors_file = os.path.join( 547 | Dataset_Configs.csv_dir, 548 | "tensors_{}_{}_{}".format( 549 | split, time_stamp, str(task), 550 | ), 551 | ) 552 | if os.path.exists(cached_tensors_file) and not Dataset_Configs.overwrite_cache: 553 | logger.info("Loading tensors from cached file %s", cached_tensors_file) 554 | start_time = time.time() 555 | dataset = torch.load(cached_tensors_file) 556 | logger.info("Finished loading tensors") 557 | logger.info(f"in {time.time() - start_time}s") 558 | 559 | else: 560 | # no cached tensors, process data from scratch 561 | logger.info("Creating features from dataset file at %s", Dataset_Configs.csv_dir) 562 | if split == "train": 563 | get_examples = processor.get_train_examples 564 | elif split == "dev": 565 | get_examples = processor.get_dev_examples 566 | elif split == "test": 567 | get_examples = processor.get_test_examples 568 | 569 | examples = get_examples(Dataset_Configs.csv_dir) 570 | dataset = processor.convert_examples_to_features(examples,) 571 | logger.info("Finished creating features") 572 | 573 | logger.info("Finished converting features into tensors") 574 | if Dataset_Configs.save_cache: 575 | logger.info("Saving features into cached file %s", cached_tensors_file) 576 | torch.save(dataset, cached_tensors_file) 577 | logger.info("Finished saving tensors") 578 | 579 | if task == "record" and split in ["dev", "test"]: 580 | answers = processor.get_answers(Dataset_Configs.csv_dir, split) 581 | return dataset, answers 582 | else: 583 | return dataset 584 | 585 | 586 | def main(Configs, seed_shift=0): 587 | # args conflict checking 588 | if Configs.use_plm: 589 | assert Configs.use_processor, "\nWhen using plm, You can only use processor to process dataset!!\n" 590 | 591 | Dataset_Configs = Configs.Dataset 592 | Configs.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 593 | os.makedirs(Configs.out_dir,exist_ok=True) 594 | set_seed(Configs.seed+seed_shift) 595 | 596 | # Setup logging 597 | logging.basicConfig( 598 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 599 | datefmt="%m/%d/%Y %H:%M:%S", 600 | level=logging.INFO, 601 | ) 602 | handler = logging.FileHandler(os.path.join(Configs.out_dir,time_stamp+"_log")) 603 | handler.setLevel(logging.INFO) 604 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 605 | handler.setFormatter(formatter) 606 | logger.addHandler(handler) 607 | 608 | logger.info("--------------Main Configs-------------------") 609 | logger.info(Configs) 610 | 611 | logger.info("--------------loading data-------------------") 612 | logger.info("Dataset Configs") 613 | logger.info(json.dumps(Dataset_Configs)) 614 | 615 | # check whether to use plm 616 | logger.info("----------------use plm or not----------------") 617 | if Configs.use_plm: 618 | logger.info("------------------YES-----------------------------") 619 | Configs.model_name_or_path = Configs.Training_with_Processor.model_name_or_path 620 | logger.info("\tload plm name or path from Training_with_Processor args") 621 | else: 622 | logger.info("--------------------NO-------------------------") 623 | if Configs.use_processor: 624 | Configs.model_name_or_path = 
Configs.Training_with_Processor.model_name_or_path 625 | logger.info("\tload plm name or path from Training_with_Processor args") 626 | else: 627 | Configs.model_name_or_path = Configs.Tokenizer.model_name_or_path 628 | logger.info("\tload plm name or path from Tokenizer args") 629 | 630 | logger.info("-------------------------------------------------------------------------------------------------------") 631 | # prepare data 632 | if Configs.use_processor: 633 | # translate txt into csv 634 | if not Dataset_Configs.resplit and os.path.exists(Dataset_Configs.csv_dir) and \ 635 | os.path.exists(os.path.join(Dataset_Configs.csv_dir,"train.csv")) and \ 636 | os.path.exists(os.path.join(Dataset_Configs.csv_dir,"val.csv")) and \ 637 | os.path.exists(os.path.join(Dataset_Configs.csv_dir,"test.csv")): 638 | pass 639 | else: 640 | os.makedirs(Dataset_Configs.csv_dir, exist_ok=True) 641 | with open(Dataset_Configs.cover_file, 'r', encoding='utf-8') as f: 642 | covers = f.read().split("\n") 643 | covers = list(filter(lambda x: x not in ['', None], covers)) 644 | random.shuffle(covers) 645 | with open(Dataset_Configs.stego_file, 'r', encoding='utf-8') as f: 646 | stegos = f.read().split("\n") 647 | stegos = list(filter(lambda x: x not in ['', None], stegos)) 648 | random.shuffle(stegos) 649 | texts = covers+stegos 650 | labels = [0]*len(covers) + [1]*len(stegos) 651 | val_ratio = (1-Dataset_Configs.split_ratio)/Dataset_Configs.split_ratio 652 | train_texts,test_texts,train_labels,test_labels = train_test_split(texts,labels,train_size=Dataset_Configs.split_ratio) 653 | train_texts,val_texts, train_labels,val_labels = train_test_split(train_texts, train_labels, train_size=1-val_ratio) 654 | def write2file(X, Y, filename): 655 | with open(filename, "w", encoding="utf-8", newline="") as f: 656 | writer = csv.writer(f) 657 | writer.writerow(["text", "label"]) 658 | for x, y in zip(X, Y): 659 | writer.writerow([x, y]) 660 | write2file(train_texts,train_labels, os.path.join(Dataset_Configs.csv_dir,"train.csv")) 661 | write2file(val_texts, val_labels, os.path.join(Dataset_Configs.csv_dir, "val.csv")) 662 | write2file(test_texts, test_labels, os.path.join(Dataset_Configs.csv_dir, "test.csv")) 663 | tokenizer = AutoTokenizer.from_pretrained(Configs.model_name_or_path,) 664 | VOCAB_SIZE = tokenizer.vocab_size 665 | 666 | else: 667 | # not recommended 668 | with open(Dataset_Configs.cover_file, 'r', encoding='utf-8') as f: 669 | covers = f.read().split("\n") 670 | covers = list(filter(lambda x: x not in ['', None], covers)) 671 | random.shuffle(covers) 672 | with open(Dataset_Configs.stego_file, 'r', encoding='utf-8') as f: 673 | stegos = f.read().split("\n") 674 | stegos = list(filter(lambda x: x not in ['', None], stegos)) 675 | random.shuffle(stegos) 676 | 677 | 678 | if Configs.tokenizer: 679 | Tokenizer_Configs = Configs.Tokenizer 680 | data_helper = dataset.BertDataHelper([covers, stegos], ratio=Dataset_Configs.split_ratio, 681 | tokenizer_config=Tokenizer_Configs) 682 | else: 683 | Vocabulary_Configs = Configs.Vocabulary 684 | data_helper = dataset.DataHelper([covers, stegos], use_label=True, 685 | ratio=Dataset_Configs.split_ratio, 686 | word_drop=Vocabulary_Configs.word_drop, 687 | do_lower=Vocabulary_Configs.do_lower, 688 | max_length=Vocabulary_Configs.max_length) 689 | 690 | VOCAB_SIZE = data_helper.vocab_size 691 | 692 | model = load_model(Configs, VOCAB_SIZE=VOCAB_SIZE) 693 | 694 | logger.info("--------------start training--------------------") 695 | 696 | if Configs.use_processor: 697 | # 
train_dataset = load_and_cache_examples(Dataset_Configs, Configs.task_name, tokenizer) # , evaluate=False) 698 | global_step, tr_loss = train(model, Configs, tokenizer) 699 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 700 | Training_Configs = Configs.Training_with_Processor 701 | 702 | checkpoints = [os.path.join(Configs.out_dir, Configs.checkpoint)] 703 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 704 | 705 | for checkpoint in checkpoints: 706 | tokenizer = AutoTokenizer.from_pretrained(checkpoint, do_lower_case=Training_Configs.do_lower_case) 707 | prefix = checkpoint.split("/")[-1] 708 | model = load_model(Configs, VOCAB_SIZE=tokenizer.vocab_size,checkpoint=checkpoint) 709 | # if not Configs.use_plm: 710 | # model = torch.load(os.path.join(checkpoint, "pytorch_model.bin")) 711 | # logger.info("--------------load model without pretrained language model-----------------") 712 | # else: 713 | # logger.info("--------------load model with pretrained language model--------------------") 714 | result, preds, ex_ids = evaluate(model, tokenizer, Configs, Configs.task_name, split="test", prefix=prefix) 715 | test_acc = result["accuracy"] 716 | test_precision = result["precision"] 717 | test_recall = result["recall"] 718 | test_Fscore = result["f1_score"] 719 | 720 | else: 721 | test_acc, test_precision, test_recall, test_Fscore = train_with_helper(data_helper,model,Configs) 722 | 723 | record_file = Configs.record_file if Configs.record_file is not None else "record.txt" 724 | result_path = os.path.join(Configs.out_dir, time_stamp+"----"+record_file) 725 | with open(result_path, "w", encoding="utf-8") as f: 726 | f.write("test phase:\naccuracy\t{:.4f}\nprecision\t{:.4f}\nrecall\t{:.4f}\nf1_score\t{:.4f}" 727 | .format(test_acc*100,test_precision*100,test_recall*100,test_Fscore*100)) 728 | 729 | return test_acc, test_precision, test_recall, test_Fscore 730 | 731 | 732 | if __name__ == '__main__': 733 | import argparse 734 | import numpy as np 735 | parser = argparse.ArgumentParser(description="argument for generation") 736 | parser.add_argument("--config_path", type=str, default="./configs/test.json") 737 | args = parser.parse_args() 738 | Configs = utils.Config(args.config_path).get_configs() 739 | os.environ["CUDA_VISIBLE_DEVICES"] = Configs.gpuid 740 | total_test_acc=[] 741 | total_test_precision=[] 742 | total_test_recall=[] 743 | total_test_Fscore=[] 744 | for i in range(Configs.get("repeat_num", 1)): 745 | test_acc, test_precision, test_recall, test_Fscore = main(Configs,seed_shift=i) 746 | total_test_acc.append(test_acc) 747 | total_test_precision.append(test_precision) 748 | total_test_recall.append(test_recall) 749 | total_test_Fscore.append(test_Fscore) 750 | message = "Final results\n(repeat times: {}):\naccuracy\t{:.2f}%+{:.2f}%\nprecision\t{:.2f}%+{:.2f}%\nrecall\t{:.2f}%+{:.2f}%\nf1_score\t{:.2f}%+{:.2f}%"\ 751 | .format(Configs.get("repeat_num", 1), np.mean(total_test_acc)*100, np.std(total_test_acc)*100, 752 | np.mean(total_test_precision)*100, np.std(total_test_precision)*100, 753 | np.mean(total_test_recall)*100, np.std(total_test_recall)*100, 754 | np.mean(total_test_Fscore)*100, np.std(total_test_Fscore)*100) 755 | logger.info(message) 756 | --------------------------------------------------------------------------------
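
Note on the distillation objective: the snippet below is a minimal, self-contained sketch of the loss computed in distil.py's train() when a teacher model is supplied (temperature Training_Configs.distil_T, weight Training_Configs.distil_alpha). The function name and argument names here are illustrative and not part of the repository.

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, hard_loss, T, alpha):
    # Soft-target term: KL divergence between the temperature-scaled student and
    # teacher class distributions, rescaled by T*T as in standard knowledge distillation.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * T * T
    # Final loss: blend the student's own hard-label loss with the soft-target term.
    return hard_loss * (1 - alpha) + soft_loss * alpha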