├── README.md ├── README.pdf ├── TAS_BERT_joint.py ├── TAS_BERT_separate.py ├── convert_tf_checkpoint_to_pytorch.py ├── data ├── __pycache__ │ └── change_TO_to_BIO.cpython-36.pyc ├── change_TO_to_BIO.py ├── data_preprocessing_for_TAS.py ├── semeval2015 │ ├── ABSA_15_Restaurants_Test.txt │ └── ABSA_15_Restaurants_Train.txt └── semeval2016 │ ├── ABSA_16_Restaurants_Test.txt │ └── ABSA_16_Restaurants_Train.txt ├── evaluation_for_AD_TD_TAD ├── A.jar ├── ABSA15.xsd ├── ABSA15.xsd~ ├── ABSA15BaseEvalValid.pdf ├── ABSA15_Restaurants_Test.xml ├── ABSA16_Restaurants_Test.xml ├── absa15.conf ├── absa15.sh └── change_pre_to_xml.py ├── evaluation_for_TSD_ASD_TASD.py ├── evaluation_for_loss_separate.py ├── modeling.py ├── optimization.py ├── processor.py └── tokenization.py /README.md: -------------------------------------------------------------------------------- 1 | # TAS-BERT 2 | 3 | Code and data for our paper ["Target-Aspect-Sentiment Joint Detection for Aspect-Based Sentiment Analysis" (AAAI 2020)](https://aaai.org/ojs/index.php/AAAI/article/view/6447). 4 | 5 | Our code is based on [Utilizing BERT for Aspect-Based Sentiment Analysis via Constructing Auxiliary Sentence (NAACL 2019)](https://github.com/HSLCY/ABSA-BERT-pair). 6 | 7 | 8 | 9 | ## Requirements 10 | 11 | - pytorch: 1.0.1 12 | 13 | - python: 3.6.8 14 | 15 | - tensorflow: 1.13.1 (only for creating the BERT-pytorch-model) 16 | 17 | - pytorch-crf: 0.7.2 18 | 19 | - numpy: 1.16.4 20 | 21 | - nltk: 3.4.4 22 | 23 | - sklearn: 0.21.2 24 | 25 | 26 | 27 | ## Data Preprocessing 28 | 29 | - Download the [uncased BERT-Base model](https://github.com/google-research/bert), unpack it, place the folder in the root directory, and run `convert_tf_checkpoint_to_pytorch.py` to create the BERT-pytorch-model. 30 | 31 | - Run the following commands to get the preprocessed data: 32 | 33 | ``` 34 | cd data 35 | python data_preprocessing_for_TAS.py --dataset semeval2015 36 | python data_preprocessing_for_TAS.py --dataset semeval2016 37 | cd ../ 38 | ``` 39 | 40 | The preprocessed data is in the folders **data/semeval2015/three_joint/BIO**, **data/semeval2015/three_joint/TO**, **data/semeval2016/three_joint/BIO** and **data/semeval2016/three_joint/TO**. BIO and TO are the two tagging schemes mentioned in our paper. 41 | 42 | The preprocessed data structure is as follows: 43 | 44 | | Key | Description | 45 | | ------------------ | ------------------------------------------------------------ | 46 | | *sentence_id* | ID of the sentence. | 47 | | *yes_no* | Whether the sentence expresses the aspect-sentiment pair given in *aspect_sentiment* (1 = yes, 0 = no). | 48 | | *aspect_sentiment* | The aspect-sentiment pair of this line, such as "food quality positive". | 49 | | *sentence* | Content of the sentence. | 50 | | *ner_tags* | Label sequence marking the targets associated with the aspect-sentiment pair given in *aspect_sentiment*. | 51 | 52 | 53 | 54 | ## Code Structure 55 | 56 | - `TAS_BERT_joint.py`: Runner for the joint detection model. 57 | 58 | - `modeling.py`: Model definitions. 59 | 60 | - `optimization.py`: Optimizer (BERTAdam) for the model. 61 | 62 | - `processor.py`: Data processor. 63 | 64 | - `tokenization.py`: Tokenization, including three strategies for unknown words (prefix_match, unk_replace, word_split). 65 | 66 | - `evaluation_for_TSD_ASD_TASD.py`: Evaluation for the TSD, ASD and TASD tasks. 67 | 68 | - `evaluation_for_AD_TD_TAD/`: The official evaluation tool for the AD, TD and TAD tasks. 69 | 70 | - `TAS_BERT_separate.py` and `evaluation_for_loss_separate.py`: Separate detection models for the ablation study.
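For a quick sanity check after preprocessing, the Python sketch below prints a few rows of the five-column TSV described above. It is an illustration only, not part of the repository scripts; the path is the 2016 BIO test file that the evaluation commands below also use.

```
# Minimal sketch: print the first rows of a preprocessed TAS file.
# Columns (tab-separated): sentence_id, yes_no, aspect_sentiment, sentence, ner_tags.
path = 'data/semeval2016/three_joint/BIO/test_TAS.tsv'

with open(path, encoding='utf-8') as fin:
    header = fin.readline().rstrip('\n').split('\t')
    for _ in range(3):
        row = dict(zip(header, fin.readline().rstrip('\n').split('\t')))
        print(row['sentence_id'], row['yes_no'], row['aspect_sentiment'])
        print(row['sentence'])
        print(row['ner_tags'])
```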
71 | 72 | 73 | 74 | ## Training & Testing 75 | 76 | If you want to train and test a joint detection model, you can use the following command: 77 | 78 | ``` 79 | CUDA_VISIBLE_DEVICES=0 python TAS_BERT_joint.py \ 80 | --data_dir data/semeval2016/three_joint/BIO/ \ 81 | --output_dir results/semeval2016/three_joint/BIO/my_result \ 82 | --vocab_file uncased_L-12_H-768_A-12/vocab.txt \ 83 | --bert_config_file uncased_L-12_H-768_A-12/bert_config.json \ 84 | --init_checkpoint uncased_L-12_H-768_A-12/pytorch_model.bin \ 85 | --tokenize_method word_split \ 86 | --use_crf \ 87 | --eval_test \ 88 | --do_lower_case \ 89 | --max_seq_length 128 \ 90 | --train_batch_size 24 \ 91 | --eval_batch_size 8 \ 92 | --learning_rate 2e-5 \ 93 | --num_train_epochs 30.0 94 | ``` 95 | 96 | The test predictions for each epoch will be stored in `test_ep_*.txt` files in the output folder. 97 | 98 | 99 | 100 | ## Evaluation 101 | 102 | If you want to evaluate the test results for each epoch, you can use the following commands: 103 | 104 | **Note: We choose the epoch that performs best on the TASD task and evaluate its results on all the subtasks.** 105 | 106 | - If you want to evaluate on the TASD task, ASD task, TSD task ignoring implicit targets, and TSD task considering implicit targets, you can use the following command: 107 | 108 | ``` 109 | python evaluation_for_TSD_ASD_TASD.py \ 110 | --output_dir results/semeval2016/three_joint/BIO/my_result \ 111 | --num_epochs 30 \ 112 | --tag_schema BIO 113 | ``` 114 | 115 | 116 | ⭐**The *tag_schema* must be consistent with the contents of the *output_dir*; otherwise you will get incorrect results.** 117 | 118 | "All tuples" corresponds to "C1" in Table 3 of our paper. 119 | 120 | "Only NULL tuples" corresponds to "C2" in Table 3 of our paper. 121 | 122 | "NO and pure O tag sequence" corresponds to "C3" in Table 3 of our paper. 123 | 124 | 125 | 126 | 127 | As for the TD, AD and TAD tasks, we use [the evaluation tool provided by the SemEval2015 competition](http://alt.qcri.org/semeval2015/task12/index.php?id=data-and-tools). The tool requires a Java environment. 128 | 129 | - First, we convert our test results into XML format. You can use the following command: 130 | 131 | ``` 132 | cd evaluation_for_AD_TD_TAD 133 | python change_pre_to_xml.py \ 134 | --gold_path ../data/semeval2016/three_joint/BIO/test_TAS.tsv \ 135 | --pre_path ../results/semeval2016/three_joint/BIO/my_result/test_ep_23.txt \ 136 | --gold_xml_file ABSA16_Restaurants_Test.xml \ 137 | --pre_xml_file pred_file_2016.xml \ 138 | --tag_schema BIO 139 | ``` 140 | 141 | **Note: the "test_ep_*.txt" passed to --pre_path should be the one from the epoch that performed best on the TASD task.** 142 | 143 | We will get a prediction file in XML format: pred_file_2016.xml.
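Before moving on to the Java tool, you can take a quick look at the chosen `test_ep_*.txt`: it is tab-separated with the columns `yes_not`, `yes_not_pre`, `sentence`, `true_ner` and `predict_ner` (one row per sentence and aspect-sentiment pair). Below is a minimal Python sketch, an illustration only and not part of the repository, that simply recomputes the yes/no accuracy reported in `log.txt`.

```
# Minimal sketch: yes/no accuracy on one epoch's prediction file.
# Columns (tab-separated): yes_not, yes_not_pre, sentence, true_ner, predict_ner.
path = 'results/semeval2016/three_joint/BIO/my_result/test_ep_23.txt'

total = correct = 0
with open(path, encoding='utf-8') as fin:
    fin.readline()  # skip the header line
    for line in fin:
        fields = line.rstrip('\n').split('\t')
        if len(fields) < 5:
            continue
        total += 1
        correct += int(fields[0] == fields[1])  # gold vs. predicted yes/no label

if total:
    print('yes/no accuracy: %.4f over %d rows' % (correct / total, total))
```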
144 | 145 | 146 | 147 | - If you want to evaluate on the AD task: 148 | 149 | ``` 150 | java -cp ./A.jar absa15.Do Eval ./pred_file_2016.xml ./ABSA16_Restaurants_Test.xml 1 0 151 | ``` 152 | 153 | 154 | 155 | - If you want to evaluate on the TD task: 156 | 157 | ``` 158 | java -cp ./A.jar absa15.Do Eval ./pred_file_2016.xml ./ABSA16_Restaurants_Test.xml 2 0 159 | ``` 160 | 161 | 162 | 163 | - If you want to evaluate on the TAD task: 164 | 165 | ``` 166 | java -cp ./A.jar absa15.Do Eval ./pred_file_2016.xml ./ABSA16_Restaurants_Test.xml 3 0 167 | ``` 168 | 169 | 170 | 171 | ## Ablation Study 172 | 173 | If you want to try the separate models, please use the following commands: 174 | 175 | 176 | 177 | ``` 178 | CUDA_VISIBLE_DEVICES=0 python TAS_BERT_separate.py \ 179 | --data_dir data/semeval2016/three_joint/BIO/ \ 180 | --output_dir results/semeval2016/three_joint/BIO/my_result_AS \ 181 | --vocab_file uncased_L-12_H-768_A-12/vocab.txt \ 182 | --bert_config_file uncased_L-12_H-768_A-12/bert_config.json \ 183 | --init_checkpoint uncased_L-12_H-768_A-12/pytorch_model.bin \ 184 | --tokenize_method word_split \ 185 | --use_crf \ 186 | --subtask AS \ 187 | --eval_test \ 188 | --do_lower_case \ 189 | --max_seq_length 128 \ 190 | --train_batch_size 24 \ 191 | --eval_batch_size 8 \ 192 | --learning_rate 2e-5 \ 193 | --num_train_epochs 30.0 194 | ``` 195 | 196 | ``` 197 | CUDA_VISIBLE_DEVICES=0 python TAS_BERT_separate.py \ 198 | --data_dir data/semeval2016/three_joint/BIO/ \ 199 | --output_dir results/semeval2016/three_joint/BIO/my_result_T \ 200 | --vocab_file uncased_L-12_H-768_A-12/vocab.txt \ 201 | --bert_config_file uncased_L-12_H-768_A-12/bert_config.json \ 202 | --init_checkpoint uncased_L-12_H-768_A-12/pytorch_model.bin \ 203 | --tokenize_method word_split \ 204 | --use_crf \ 205 | --subtask T \ 206 | --eval_test \ 207 | --do_lower_case \ 208 | --max_seq_length 128 \ 209 | --train_batch_size 24 \ 210 | --eval_batch_size 8 \ 211 | --learning_rate 2e-5 \ 212 | --num_train_epochs 30.0 213 | ``` 214 | 215 | ``` 216 | python evaluation_for_loss_separate.py \ 217 | --output_dir_AS results/semeval2016/three_joint/BIO/my_result_AS \ 218 | --output_dir_T results/semeval2016/three_joint/BIO/my_result_T \ 219 | --num_epochs 30 \ 220 | --tag_schema BIO 221 | ``` 222 | -------------------------------------------------------------------------------- /README.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysulic/TAS-BERT/704472423e4dfe1c50d9b53de2c2376db2af5ed0/README.pdf -------------------------------------------------------------------------------- /TAS_BERT_joint.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """CUDA_VISIBLE_DEVICES=0 python TAS_BERT_joint.py""" 4 | from __future__ import absolute_import, division, print_function 5 | 6 | """ 7 | three-joint detection for target & aspect & sentiment 8 | """ 9 | 10 | import argparse 11 | import collections 12 | import logging 13 | import os 14 | import random 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | from torch.utils.data import DataLoader, TensorDataset 20 | from torch.utils.data.distributed import DistributedSampler 21 | from torch.utils.data.sampler import RandomSampler, SequentialSampler 22 | from tqdm import tqdm, trange 23 | 24 | import tokenization 25 | from modeling import BertConfig, BertForTABSAJoint, BertForTABSAJoint_CRF 26 | from optimization import BERTAdam 
27 | 28 | import datetime 29 | 30 | from processor import Semeval_Processor 31 | 32 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 33 | datefmt = '%m/%d/%Y %H:%M:%S', 34 | level = logging.INFO) 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | class InputFeatures(object): 39 | """A single set of features of data.""" 40 | 41 | def __init__(self, input_ids, input_mask, segment_ids, label_id, ner_label_ids, ner_mask): 42 | self.input_ids = input_ids 43 | self.input_mask = input_mask 44 | self.segment_ids = segment_ids 45 | self.label_id = label_id 46 | self.ner_label_ids = ner_label_ids 47 | self.ner_mask = ner_mask 48 | 49 | 50 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, ner_label_list, tokenize_method): 51 | """Loads a data file into a list of `InputBatch`s.""" 52 | 53 | label_map = {} 54 | for (i, label) in enumerate(label_list): 55 | label_map[label] = i 56 | 57 | #here start with zero this means that "[PAD]" is zero 58 | ner_label_map = {} 59 | for (i, label) in enumerate(ner_label_list): 60 | ner_label_map[label] = i 61 | 62 | features = [] 63 | all_tokens = [] 64 | for (ex_index, example) in enumerate(tqdm(examples)): 65 | if tokenize_method == "word_split": 66 | # word_split 67 | word_num = 0 68 | tokens_a = tokenizer.tokenize(example.text_a) 69 | ner_labels_org = example.ner_labels_a.strip().split() 70 | ner_labels_a = [] 71 | token_bias_num = 0 72 | 73 | for (i, token) in enumerate(tokens_a): 74 | if token.startswith('##'): 75 | if ner_labels_org[i - 1 - token_bias_num] in ['O', 'T', 'I']: 76 | ner_labels_a.append(ner_labels_org[i - 1 - token_bias_num]) 77 | else: 78 | ner_labels_a.append('I') 79 | token_bias_num += 1 80 | else: 81 | word_num += 1 82 | ner_labels_a.append(ner_labels_org[i - token_bias_num]) 83 | 84 | assert word_num == len(ner_labels_org) 85 | assert len(ner_labels_a) == len(tokens_a) 86 | 87 | else: 88 | # prefix_match or unk_replace 89 | tokens_a = tokenizer.tokenize(example.text_a) 90 | ner_labels_a = example.ner_labels_a.strip().split() 91 | 92 | tokens_b = None 93 | if example.text_b: 94 | tokens_b = tokenizer.tokenize(example.text_b) 95 | 96 | if tokens_b: 97 | # Modifies `tokens_a` and `tokens_b` in place so that the total 98 | # length is less than the specified length. 99 | # Account for [CLS], [SEP], [SEP] with "- 3" 100 | _truncate_seq_pair(tokens_a, tokens_b, ner_labels_a, max_seq_length - 3) 101 | else: 102 | # Account for [CLS] and [SEP] with "- 2" 103 | if len(tokens_a) > max_seq_length - 2: 104 | tokens_a = tokens_a[0:(max_seq_length - 2)] 105 | ner_labels_a = ner_labels_a[0:(max_seq_length - 2)] 106 | 107 | # The convention in BERT is: 108 | # (a) For sequence pairs: 109 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 110 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 111 | # (b) For single sequences: 112 | # tokens: [CLS] the dog is hairy . [SEP] 113 | # type_ids: 0 0 0 0 0 0 0 114 | # 115 | # Where "type_ids" are used to indicate whether this is the first 116 | # sequence or the second sequence. The embedding vectors for `type=0` and 117 | # `type=1` were learned during pre-training and are added to the wordpiece 118 | # embedding vector (and position vector). This is not *strictly* necessary 119 | # since the [SEP] token unambigiously separates the sequences, but it makes 120 | # it easier for the model to learn the concept of sequences. 
121 | # 122 | # For classification tasks, the first vector (corresponding to [CLS]) is 123 | # used as as the "sentence vector". Note that this only makes sense because 124 | # the entire model is fine-tuned. 125 | tokens = [] 126 | segment_ids = [] 127 | ner_label_ids = [] 128 | tokens.append("[CLS]") 129 | segment_ids.append(0) 130 | ner_label_ids.append(ner_label_map["[CLS]"]) 131 | try: 132 | for (i, token) in enumerate(tokens_a): 133 | tokens.append(token) 134 | segment_ids.append(0) 135 | ner_label_ids.append(ner_label_map[ner_labels_a[i]]) 136 | except: 137 | print(tokens_a) 138 | print(ner_labels_a) 139 | 140 | ner_mask = [1] * len(ner_label_ids) 141 | token_length = len(tokens) 142 | tokens.append("[SEP]") 143 | segment_ids.append(0) 144 | ner_label_ids.append(ner_label_map["[PAD]"]) 145 | 146 | if tokens_b: 147 | for token in tokens_b: 148 | tokens.append(token) 149 | segment_ids.append(1) 150 | ner_label_ids.append(ner_label_map["[PAD]"]) 151 | tokens.append("[SEP]") 152 | segment_ids.append(1) 153 | ner_label_ids.append(ner_label_map["[PAD]"]) 154 | 155 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 156 | 157 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 158 | # tokens are attended to. 159 | input_mask = [1] * len(input_ids) 160 | # Zero-pad up to the sequence length. 161 | while len(input_ids) < max_seq_length: 162 | input_ids.append(0) 163 | input_mask.append(0) 164 | segment_ids.append(0) 165 | ner_label_ids.append(ner_label_map["[PAD]"]) 166 | while len(ner_mask) < max_seq_length: 167 | ner_mask.append(0) 168 | 169 | assert len(input_ids) == max_seq_length 170 | assert len(input_mask) == max_seq_length 171 | assert len(segment_ids) == max_seq_length 172 | assert len(ner_mask) == max_seq_length 173 | assert len(ner_label_ids) == max_seq_length 174 | 175 | label_id = label_map[example.label] 176 | 177 | features.append( 178 | InputFeatures( 179 | input_ids=input_ids, 180 | input_mask=input_mask, 181 | segment_ids=segment_ids, 182 | label_id=label_id, 183 | ner_label_ids=ner_label_ids, 184 | ner_mask=ner_mask)) 185 | all_tokens.append(tokens[0:token_length]) 186 | return features, all_tokens 187 | 188 | 189 | def _truncate_seq_pair(tokens_a, tokens_b, ner_labels_a, max_length): 190 | """Truncates a sequence pair in place to the maximum length.""" 191 | 192 | # This is a simple heuristic which will always truncate the longer sequence 193 | # one token at a time. This makes more sense than truncating an equal percent 194 | # of tokens from each, since if one sequence is very short then each token 195 | # that's truncated likely contains more information than a longer sequence. 196 | while True: 197 | total_length = len(tokens_a) + len(tokens_b) 198 | if total_length <= max_length: 199 | break 200 | if len(tokens_a) > len(tokens_b): 201 | tokens_a.pop() 202 | ner_labels_a.pop() 203 | else: 204 | tokens_b.pop() 205 | 206 | 207 | def main(): 208 | parser = argparse.ArgumentParser() 209 | 210 | ## Required parameters 211 | parser.add_argument("--data_dir", 212 | default='data/semeval2015/three_joint/TO/', 213 | type=str, 214 | required=True, 215 | help="The input data dir. 
Should contain the .tsv files (or other data files) for the task.") 216 | parser.add_argument("--output_dir", 217 | default='results/semeval2015/three_joint/TO/my_result', 218 | type=str, 219 | required=True, 220 | help="The output directory where the model checkpoints will be written.") 221 | parser.add_argument("--vocab_file", 222 | default='uncased_L-12_H-768_A-12/vocab.txt', 223 | type=str, 224 | required=True, 225 | help="The vocabulary file that the BERT model was trained on.") 226 | parser.add_argument("--bert_config_file", 227 | default='uncased_L-12_H-768_A-12/bert_config.json', 228 | type=str, 229 | required=True, 230 | help="The config json file corresponding to the pre-trained BERT model. \n" 231 | "This specifies the model architecture.") 232 | parser.add_argument("--init_checkpoint", 233 | default='uncased_L-12_H-768_A-12/pytorch_model.bin', 234 | type=str, 235 | required=True, 236 | help="Initial checkpoint (usually from a pre-trained BERT model).") 237 | parser.add_argument("--tokenize_method", 238 | default='word_split', 239 | type=str, 240 | required=True, 241 | choices=["prefix_match", "unk_replace", "word_split"], 242 | help="how to solve the unknow words, max prefix match or replace with [UNK] or split to some words") 243 | parser.add_argument("--use_crf", 244 | default=True, 245 | required=True, 246 | action='store_true', 247 | help="Whether to use CRF after Bert sequence_output") 248 | 249 | ## Other parameters 250 | parser.add_argument("--eval_test", 251 | default=True, 252 | action='store_true', 253 | help="Whether to run eval on the test set.") 254 | parser.add_argument("--do_lower_case", 255 | default=True, 256 | action='store_true', 257 | help="Whether to lower case the input text. True for uncased models, False for cased models.") 258 | parser.add_argument("--max_seq_length", 259 | default=128, 260 | type=int, 261 | help="The maximum total input sequence length after WordPiece tokenization. \n" 262 | "Sequences longer than this will be truncated, and sequences shorter \n" 263 | "than this will be padded.") 264 | parser.add_argument("--train_batch_size", 265 | default=24, 266 | type=int, 267 | help="Total batch size for training.") 268 | parser.add_argument("--eval_batch_size", 269 | default=8, 270 | type=int, 271 | help="Total batch size for eval.") 272 | parser.add_argument("--learning_rate", 273 | default=2e-5, 274 | type=float, 275 | help="The initial learning rate for Adam.") 276 | parser.add_argument("--num_train_epochs", 277 | default=30.0, 278 | type=float, 279 | help="Total number of training epochs to perform.") 280 | parser.add_argument("--warmup_proportion", 281 | default=0.1, 282 | type=float, 283 | help="Proportion of training to perform linear learning rate warmup for. 
" 284 | "E.g., 0.1 = 10%% of training.") 285 | parser.add_argument("--no_cuda", 286 | default=False, 287 | action='store_true', 288 | help="Whether not to use CUDA when available") 289 | parser.add_argument("--accumulate_gradients", 290 | type=int, 291 | default=1, 292 | help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)") 293 | parser.add_argument("--local_rank", 294 | type=int, 295 | default=-1, 296 | help="local_rank for distributed training on gpus") 297 | parser.add_argument('--seed', 298 | type=int, 299 | default=42, 300 | help="random seed for initialization") 301 | parser.add_argument('--gradient_accumulation_steps', 302 | type=int, 303 | default=1, 304 | help="Number of updates steps to accumualte before performing a backward/update pass.") 305 | args = parser.parse_args() 306 | 307 | 308 | if args.local_rank == -1 or args.no_cuda: 309 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 310 | n_gpu = torch.cuda.device_count() 311 | else: 312 | device = torch.device("cuda", args.local_rank) 313 | n_gpu = 1 314 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 315 | torch.distributed.init_process_group(backend='nccl') 316 | logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) 317 | 318 | if args.accumulate_gradients < 1: 319 | raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( 320 | args.accumulate_gradients)) 321 | 322 | args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients) 323 | 324 | random.seed(args.seed) 325 | np.random.seed(args.seed) 326 | torch.manual_seed(args.seed) 327 | if n_gpu > 0: 328 | torch.cuda.manual_seed_all(args.seed) 329 | 330 | bert_config = BertConfig.from_json_file(args.bert_config_file) 331 | 332 | if args.max_seq_length > bert_config.max_position_embeddings: 333 | raise ValueError( 334 | "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format( 335 | args.max_seq_length, bert_config.max_position_embeddings)) 336 | 337 | processor = Semeval_Processor() 338 | label_list = processor.get_labels() 339 | ner_label_list = processor.get_ner_labels(args.data_dir) # BIO or TO tags for ner entity 340 | 341 | tokenizer = tokenization.FullTokenizer( 342 | vocab_file=args.vocab_file, tokenize_method=args.tokenize_method, do_lower_case=args.do_lower_case) 343 | 344 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 345 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 346 | os.makedirs(args.output_dir, exist_ok=True) 347 | 348 | # training set 349 | train_examples = None 350 | num_train_steps = None 351 | train_examples = processor.get_train_examples(args.data_dir) 352 | num_train_steps = int( 353 | len(train_examples) / args.train_batch_size * args.num_train_epochs) 354 | 355 | train_features, _ = convert_examples_to_features( 356 | train_examples, label_list, args.max_seq_length, tokenizer, ner_label_list, args.tokenize_method) 357 | logger.info("***** Running training *****") 358 | logger.info(" Num examples = %d", len(train_examples)) 359 | logger.info(" Batch size = %d", args.train_batch_size) 360 | logger.info(" Num steps = %d", num_train_steps) 361 | 362 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) 363 | all_input_mask = torch.tensor([f.input_mask for f in train_features], 
dtype=torch.long) 364 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 365 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 366 | all_ner_label_ids = torch.tensor([f.ner_label_ids for f in train_features], dtype=torch.long) 367 | all_ner_mask = torch.tensor([f.ner_mask for f in train_features], dtype=torch.long) 368 | 369 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_ner_label_ids, all_ner_mask) 370 | if args.local_rank == -1: 371 | train_sampler = RandomSampler(train_data) 372 | else: 373 | train_sampler = DistributedSampler(train_data) 374 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 375 | 376 | # test set 377 | if args.eval_test: 378 | test_examples = processor.get_test_examples(args.data_dir) 379 | test_features, test_tokens = convert_examples_to_features( 380 | test_examples, label_list, args.max_seq_length, tokenizer, ner_label_list, args.tokenize_method) 381 | 382 | all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long) 383 | all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long) 384 | all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long) 385 | all_label_ids = torch.tensor([f.label_id for f in test_features], dtype=torch.long) 386 | all_ner_label_ids = torch.tensor([f.ner_label_ids for f in test_features], dtype=torch.long) 387 | all_ner_mask = torch.tensor([f.ner_mask for f in test_features], dtype=torch.long) 388 | 389 | test_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_ner_label_ids, all_ner_mask) 390 | test_dataloader = DataLoader(test_data, batch_size=args.eval_batch_size, shuffle=False) 391 | 392 | 393 | # model and optimizer 394 | 395 | if args.use_crf: 396 | model = BertForTABSAJoint_CRF(bert_config, len(label_list), len(ner_label_list)) 397 | else: 398 | model = BertForTABSAJoint(bert_config, len(label_list), len(ner_label_list), args.max_seq_length) 399 | 400 | if args.init_checkpoint is not None: 401 | model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) 402 | model.to(device) 403 | 404 | if args.local_rank != -1: 405 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 406 | output_device=args.local_rank) 407 | elif n_gpu > 1: 408 | model = torch.nn.DataParallel(model) 409 | 410 | no_decay = ['bias', 'gamma', 'beta'] 411 | optimizer_parameters = [ 412 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, 413 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0} 414 | ] 415 | 416 | optimizer = BERTAdam(optimizer_parameters, 417 | lr=args.learning_rate, 418 | warmup=args.warmup_proportion, 419 | t_total=num_train_steps) 420 | 421 | 422 | # train 423 | output_log_file = os.path.join(args.output_dir, "log.txt") 424 | print("output_log_file=",output_log_file) 425 | with open(output_log_file, "w") as writer: 426 | if args.eval_test: 427 | writer.write("epoch\tglobal_step\tloss\ttest_loss\ttest_accuracy\n") 428 | else: 429 | writer.write("epoch\tglobal_step\tloss\n") 430 | 431 | global_step = 0 432 | epoch=0 433 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 434 | epoch+=1 435 | model.train() 436 | tr_loss = 0 437 | tr_ner_loss = 0 438 | nb_tr_examples, 
nb_tr_steps = 0, 0 439 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 440 | batch = tuple(t.to(device) for t in batch) 441 | input_ids, input_mask, segment_ids, label_ids, ner_label_ids, ner_mask = batch 442 | loss, ner_loss, _, _ = model(input_ids, segment_ids, input_mask, label_ids, ner_label_ids, ner_mask) 443 | 444 | if n_gpu > 1: 445 | loss = loss.mean() # mean() to average on multi-gpu. 446 | ner_loss = ner_loss.mean() 447 | if args.gradient_accumulation_steps > 1: 448 | loss = loss / args.gradient_accumulation_steps 449 | ner_loss = ner_loss / args.gradient_accumulation_steps 450 | loss.backward(retain_graph=True) 451 | ner_loss.backward() 452 | 453 | tr_loss += loss.item() 454 | tr_ner_loss += ner_loss.item() 455 | nb_tr_examples += input_ids.size(0) 456 | nb_tr_steps += 1 457 | if (step + 1) % args.gradient_accumulation_steps == 0: 458 | optimizer.step() # We have accumulated enought gradients 459 | model.zero_grad() 460 | global_step += 1 461 | 462 | 463 | # eval_test 464 | if args.eval_test: 465 | 466 | model.eval() 467 | test_loss, test_accuracy = 0, 0 468 | ner_test_loss = 0 469 | nb_test_steps, nb_test_examples = 0, 0 470 | with open(os.path.join(args.output_dir, "test_ep_"+str(epoch)+".txt"),"w") as f_test: 471 | f_test.write('yes_not\tyes_not_pre\tsentence\ttrue_ner\tpredict_ner\n') 472 | batch_index = 0 473 | for input_ids, input_mask, segment_ids, label_ids, ner_label_ids, ner_mask in test_dataloader: 474 | input_ids = input_ids.to(device) 475 | input_mask = input_mask.to(device) 476 | segment_ids = segment_ids.to(device) 477 | label_ids = label_ids.to(device) 478 | ner_label_ids = ner_label_ids.to(device) 479 | ner_mask = ner_mask.to(device) 480 | # test_tokens is the origin word in sentences [batch_size, sequence_length] 481 | ner_test_tokens = test_tokens[batch_index*args.eval_batch_size:(batch_index+1)*args.eval_batch_size] 482 | batch_index += 1 483 | 484 | with torch.no_grad(): 485 | tmp_test_loss, tmp_ner_test_loss, logits, ner_predict = model(input_ids, segment_ids, input_mask, label_ids, ner_label_ids, ner_mask) 486 | 487 | # category & polarity 488 | logits = F.softmax(logits, dim=-1) 489 | logits = logits.detach().cpu().numpy() 490 | label_ids = label_ids.to('cpu').numpy() 491 | outputs = np.argmax(logits, axis=1) 492 | 493 | if args.use_crf: 494 | # CRF 495 | ner_logits = ner_predict 496 | else: 497 | # softmax 498 | ner_logits = torch.argmax(F.log_softmax(ner_predict, dim=2),dim=2) 499 | ner_logits = ner_logits.detach().cpu().numpy() 500 | 501 | ner_label_ids = ner_label_ids.to('cpu').numpy() 502 | ner_mask = ner_mask.to('cpu').numpy() 503 | 504 | 505 | for output_i in range(len(outputs)): 506 | # category & polarity 507 | f_test.write(str(label_ids[output_i])) 508 | f_test.write('\t') 509 | f_test.write(str(outputs[output_i])) 510 | f_test.write('\t') 511 | 512 | # sentence & ner labels 513 | sentence_clean = [] 514 | label_true = [] 515 | label_pre = [] 516 | sentence_len = len(ner_test_tokens[output_i]) 517 | 518 | for i in range(sentence_len): 519 | if not ner_test_tokens[output_i][i].startswith('##'): 520 | sentence_clean.append(ner_test_tokens[output_i][i]) 521 | label_true.append(ner_label_list[ner_label_ids[output_i][i]]) 522 | label_pre.append(ner_label_list[ner_logits[output_i][i]]) 523 | 524 | f_test.write(' '.join(sentence_clean)) 525 | f_test.write('\t') 526 | f_test.write(' '.join(label_true)) 527 | f_test.write("\t") 528 | f_test.write(' '.join(label_pre)) 529 | f_test.write("\n") 530 | 
tmp_test_accuracy=np.sum(outputs == label_ids) 531 | test_loss += tmp_test_loss.mean().item() 532 | ner_test_loss += tmp_ner_test_loss.mean().item() 533 | test_accuracy += tmp_test_accuracy 534 | 535 | nb_test_examples += input_ids.size(0) 536 | nb_test_steps += 1 537 | 538 | test_loss = test_loss / nb_test_steps 539 | ner_test_loss = ner_test_loss / nb_test_steps 540 | test_accuracy = test_accuracy / nb_test_examples 541 | 542 | 543 | result = collections.OrderedDict() 544 | if args.eval_test: 545 | result = {'epoch': epoch, 546 | 'global_step': global_step, 547 | 'loss': tr_loss/nb_tr_steps, 548 | 'test_loss': test_loss, 549 | 'ner_test_loss': ner_test_loss, 550 | 'test_accuracy': test_accuracy} 551 | else: 552 | result = {'epoch': epoch, 553 | 'global_step': global_step, 554 | 'loss': tr_loss/nb_tr_steps, 555 | 'ner_loss': tr_ner_loss / nb_tr_steps} 556 | 557 | logger.info("***** Eval results *****") 558 | with open(output_log_file, "a+") as writer: 559 | for key in result.keys(): 560 | logger.info(" %s = %s\n", key, str(result[key])) 561 | writer.write("%s\t" % (str(result[key]))) 562 | writer.write("\n") 563 | 564 | if __name__ == "__main__": 565 | main() 566 | -------------------------------------------------------------------------------- /TAS_BERT_separate.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """CUDA_VISIBLE_DEVICES=0 python TAS_BERT_separate.py""" 4 | from __future__ import absolute_import, division, print_function 5 | 6 | import argparse 7 | import collections 8 | import logging 9 | import os 10 | import random 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.utils.data import DataLoader, TensorDataset 16 | from torch.utils.data.distributed import DistributedSampler 17 | from torch.utils.data.sampler import RandomSampler, SequentialSampler 18 | from tqdm import tqdm, trange 19 | 20 | import tokenization 21 | from modeling_split import BertConfig, BertForTABSAJoint_AS, BertForTABSAJoint_T, BertForTABSAJoint_CRF_AS, BertForTABSAJoint_CRF_T 22 | from optimization import BERTAdam 23 | 24 | import tensorflow as tf 25 | import datetime 26 | 27 | from processor import Semeval_Processor 28 | 29 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 30 | datefmt = '%m/%d/%Y %H:%M:%S', 31 | level = logging.INFO) 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | class InputFeatures(object): 36 | """A single set of features of data.""" 37 | 38 | def __init__(self, input_ids, input_mask, segment_ids, label_id, ner_label_ids, ner_mask): 39 | self.input_ids = input_ids 40 | self.input_mask = input_mask 41 | self.segment_ids = segment_ids 42 | self.label_id = label_id 43 | self.ner_label_ids = ner_label_ids 44 | self.ner_mask = ner_mask 45 | 46 | 47 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, ner_label_list, tokenize_method): 48 | """Loads a data file into a list of `InputBatch`s.""" 49 | 50 | label_map = {} 51 | for (i, label) in enumerate(label_list): 52 | label_map[label] = i 53 | 54 | #here start with zero this means that "[PAD]" is zero 55 | ner_label_map = {} 56 | for (i, label) in enumerate(ner_label_list): 57 | ner_label_map[label] = i 58 | 59 | features = [] 60 | all_tokens = [] 61 | for (ex_index, example) in enumerate(tqdm(examples)): 62 | if tokenize_method == "word_split": 63 | # word_split 64 | word_num = 0 65 | tokens_a = tokenizer.tokenize(example.text_a) 66 | ner_labels_org = 
example.ner_labels_a.strip().split() 67 | ner_labels_a = [] 68 | token_bias_num = 0 69 | 70 | for (i, token) in enumerate(tokens_a): 71 | if token.startswith('##'): 72 | if ner_labels_org[i - 1 - token_bias_num] in ['O', 'T', 'I']: 73 | ner_labels_a.append(ner_labels_org[i - 1 - token_bias_num]) 74 | else: 75 | ner_labels_a.append('I') 76 | token_bias_num += 1 77 | else: 78 | word_num += 1 79 | ner_labels_a.append(ner_labels_org[i - token_bias_num]) 80 | 81 | assert word_num == len(ner_labels_org) 82 | assert len(ner_labels_a) == len(tokens_a) 83 | 84 | else: 85 | # prefix_match or unk_replace 86 | tokens_a = tokenizer.tokenize(example.text_a) 87 | ner_labels_a = example.ner_labels_a.strip().split() 88 | 89 | 90 | 91 | tokens_b = None 92 | if example.text_b: 93 | tokens_b = tokenizer.tokenize(example.text_b) 94 | 95 | if tokens_b: 96 | # Modifies `tokens_a` and `tokens_b` in place so that the total 97 | # length is less than the specified length. 98 | # Account for [CLS], [SEP], [SEP] with "- 3" 99 | _truncate_seq_pair(tokens_a, tokens_b, ner_labels_a, max_seq_length - 3) 100 | else: 101 | # Account for [CLS] and [SEP] with "- 2" 102 | if len(tokens_a) > max_seq_length - 2: 103 | tokens_a = tokens_a[0:(max_seq_length - 2)] 104 | ner_labels_a = ner_labels_a[0:(max_seq_length - 2)] 105 | 106 | # The convention in BERT is: 107 | # (a) For sequence pairs: 108 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 109 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 110 | # (b) For single sequences: 111 | # tokens: [CLS] the dog is hairy . [SEP] 112 | # type_ids: 0 0 0 0 0 0 0 113 | # 114 | # Where "type_ids" are used to indicate whether this is the first 115 | # sequence or the second sequence. The embedding vectors for `type=0` and 116 | # `type=1` were learned during pre-training and are added to the wordpiece 117 | # embedding vector (and position vector). This is not *strictly* necessary 118 | # since the [SEP] token unambigiously separates the sequences, but it makes 119 | # it easier for the model to learn the concept of sequences. 120 | # 121 | # For classification tasks, the first vector (corresponding to [CLS]) is 122 | # used as as the "sentence vector". Note that this only makes sense because 123 | # the entire model is fine-tuned. 124 | tokens = [] 125 | segment_ids = [] 126 | ner_label_ids = [] 127 | tokens.append("[CLS]") 128 | segment_ids.append(0) 129 | ner_label_ids.append(ner_label_map["[CLS]"]) 130 | try: 131 | for (i, token) in enumerate(tokens_a): 132 | tokens.append(token) 133 | segment_ids.append(0) 134 | ner_label_ids.append(ner_label_map[ner_labels_a[i]]) 135 | except: 136 | print(tokens_a) 137 | print(ner_labels_a) 138 | 139 | ner_mask = [1] * len(ner_label_ids) 140 | token_length = len(tokens) 141 | tokens.append("[SEP]") 142 | segment_ids.append(0) 143 | ner_label_ids.append(ner_label_map["[PAD]"]) 144 | 145 | if tokens_b: 146 | for token in tokens_b: 147 | tokens.append(token) 148 | segment_ids.append(1) 149 | ner_label_ids.append(ner_label_map["[PAD]"]) 150 | tokens.append("[SEP]") 151 | segment_ids.append(1) 152 | ner_label_ids.append(ner_label_map["[PAD]"]) 153 | 154 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 155 | 156 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 157 | # tokens are attended to. 158 | input_mask = [1] * len(input_ids) 159 | # Zero-pad up to the sequence length. 
160 | while len(input_ids) < max_seq_length: 161 | input_ids.append(0) 162 | input_mask.append(0) 163 | segment_ids.append(0) 164 | ner_label_ids.append(ner_label_map["[PAD]"]) 165 | while len(ner_mask) < max_seq_length: 166 | ner_mask.append(0) 167 | 168 | assert len(input_ids) == max_seq_length 169 | assert len(input_mask) == max_seq_length 170 | assert len(segment_ids) == max_seq_length 171 | assert len(ner_mask) == max_seq_length 172 | assert len(ner_label_ids) == max_seq_length 173 | 174 | label_id = label_map[example.label] 175 | 176 | features.append( 177 | InputFeatures( 178 | input_ids=input_ids, 179 | input_mask=input_mask, 180 | segment_ids=segment_ids, 181 | label_id=label_id, 182 | ner_label_ids=ner_label_ids, 183 | ner_mask=ner_mask)) 184 | all_tokens.append(tokens[0:token_length]) 185 | return features, all_tokens 186 | 187 | 188 | def _truncate_seq_pair(tokens_a, tokens_b, ner_labels_a, max_length): 189 | """Truncates a sequence pair in place to the maximum length.""" 190 | 191 | # This is a simple heuristic which will always truncate the longer sequence 192 | # one token at a time. This makes more sense than truncating an equal percent 193 | # of tokens from each, since if one sequence is very short then each token 194 | # that's truncated likely contains more information than a longer sequence. 195 | while True: 196 | total_length = len(tokens_a) + len(tokens_b) 197 | if total_length <= max_length: 198 | break 199 | if len(tokens_a) > len(tokens_b): 200 | tokens_a.pop() 201 | ner_labels_a.pop() 202 | else: 203 | tokens_b.pop() 204 | 205 | 206 | def main(): 207 | parser = argparse.ArgumentParser() 208 | 209 | ## Required parameters 210 | parser.add_argument("--data_dir", 211 | default='data/semeval2015/three_joint/TO/', 212 | type=str, 213 | required=True, 214 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 215 | parser.add_argument("--output_dir", 216 | default='results/semeval2015/loss_split/TO/LPM_TO_AS', 217 | type=str, 218 | required=True, 219 | help="The output directory where the model checkpoints will be written.") 220 | parser.add_argument("--vocab_file", 221 | default='uncased_L-12_H-768_A-12/vocab.txt', 222 | type=str, 223 | required=True, 224 | help="The vocabulary file that the BERT model was trained on.") 225 | parser.add_argument("--bert_config_file", 226 | default='uncased_L-12_H-768_A-12/bert_config.json', 227 | type=str, 228 | required=True, 229 | help="The config json file corresponding to the pre-trained BERT model. 
\n" 230 | "This specifies the model architecture.") 231 | parser.add_argument("--init_checkpoint", 232 | default='uncased_L-12_H-768_A-12/pytorch_model.bin', 233 | type=str, 234 | required=True, 235 | help="Initial checkpoint (usually from a pre-trained BERT model).") 236 | parser.add_argument("--tokenize_method", 237 | default='word_split', 238 | type=str, 239 | required=True, 240 | choices=["prefix_match", "unk_replace", "word_split"], 241 | help="how to solve the unknow words, max prefix match or replace with [UNK] or split to some words") 242 | parser.add_argument("--subtask", 243 | default='AS', 244 | type=str, 245 | required=True, 246 | choices=["AS", "T"], 247 | help="just one subtask") 248 | parser.add_argument("--use_crf", 249 | default=False, 250 | action='store_true', 251 | help="Whether to use CRF after Bert sequence_output") 252 | 253 | ## Other parameters 254 | parser.add_argument("--eval_test", 255 | default=True, 256 | action='store_true', 257 | help="Whether to run eval on the test set.") 258 | parser.add_argument("--do_lower_case", 259 | default=True, 260 | action='store_true', 261 | help="Whether to lower case the input text. True for uncased models, False for cased models.") 262 | parser.add_argument("--max_seq_length", 263 | default=128, 264 | type=int, 265 | help="The maximum total input sequence length after WordPiece tokenization. \n" 266 | "Sequences longer than this will be truncated, and sequences shorter \n" 267 | "than this will be padded.") 268 | parser.add_argument("--train_batch_size", 269 | default=18, 270 | type=int, 271 | help="Total batch size for training.") 272 | parser.add_argument("--eval_batch_size", 273 | default=8, 274 | type=int, 275 | help="Total batch size for eval.") 276 | parser.add_argument("--learning_rate", 277 | default=2e-5, 278 | type=float, 279 | help="The initial learning rate for Adam.") 280 | parser.add_argument("--num_train_epochs", 281 | default=30.0, 282 | type=float, 283 | help="Total number of training epochs to perform.") 284 | parser.add_argument("--warmup_proportion", 285 | default=0.1, 286 | type=float, 287 | help="Proportion of training to perform linear learning rate warmup for. 
" 288 | "E.g., 0.1 = 10%% of training.") 289 | parser.add_argument("--no_cuda", 290 | default=False, 291 | action='store_true', 292 | help="Whether not to use CUDA when available") 293 | parser.add_argument("--accumulate_gradients", 294 | type=int, 295 | default=1, 296 | help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)") 297 | parser.add_argument("--local_rank", 298 | type=int, 299 | default=-1, 300 | help="local_rank for distributed training on gpus") 301 | parser.add_argument('--seed', 302 | type=int, 303 | default=42, 304 | help="random seed for initialization") 305 | parser.add_argument('--gradient_accumulation_steps', 306 | type=int, 307 | default=1, 308 | help="Number of updates steps to accumualte before performing a backward/update pass.") 309 | args = parser.parse_args() 310 | 311 | 312 | if args.local_rank == -1 or args.no_cuda: 313 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 314 | n_gpu = torch.cuda.device_count() 315 | else: 316 | device = torch.device("cuda", args.local_rank) 317 | n_gpu = 1 318 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 319 | torch.distributed.init_process_group(backend='nccl') 320 | logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1)) 321 | 322 | if args.accumulate_gradients < 1: 323 | raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format( 324 | args.accumulate_gradients)) 325 | 326 | args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients) 327 | 328 | random.seed(args.seed) 329 | np.random.seed(args.seed) 330 | torch.manual_seed(args.seed) 331 | if n_gpu > 0: 332 | torch.cuda.manual_seed_all(args.seed) 333 | 334 | bert_config = BertConfig.from_json_file(args.bert_config_file) 335 | 336 | if args.max_seq_length > bert_config.max_position_embeddings: 337 | raise ValueError( 338 | "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format( 339 | args.max_seq_length, bert_config.max_position_embeddings)) 340 | 341 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir): 342 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 343 | os.makedirs(args.output_dir, exist_ok=True) 344 | 345 | 346 | # prepare dataloaders 347 | processors = { 348 | 349 | "semeval":Semeval_Processor, 350 | } 351 | 352 | processor = Semeval_Processor() 353 | label_list = processor.get_labels() 354 | ner_label_list = processor.get_ner_labels(args.data_dir) # BIO or TO tags for ner entity 355 | 356 | tokenizer = tokenization.FullTokenizer( 357 | vocab_file=args.vocab_file, tokenize_method=args.tokenize_method, do_lower_case=args.do_lower_case) 358 | 359 | # training set 360 | train_examples = None 361 | num_train_steps = None 362 | train_examples = processor.get_train_examples(args.data_dir) 363 | num_train_steps = int( 364 | len(train_examples) / args.train_batch_size * args.num_train_epochs) 365 | 366 | train_features, _ = convert_examples_to_features( 367 | train_examples, label_list, args.max_seq_length, tokenizer, ner_label_list, args.tokenize_method) 368 | logger.info("***** Running training *****") 369 | logger.info(" Num examples = %d", len(train_examples)) 370 | logger.info(" Batch size = %d", args.train_batch_size) 371 | logger.info(" Num steps = %d", num_train_steps) 372 | 373 | all_input_ids = torch.tensor([f.input_ids for f in train_features], 
dtype=torch.long) 374 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) 375 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) 376 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long) 377 | all_ner_label_ids = torch.tensor([f.ner_label_ids for f in train_features], dtype=torch.long) 378 | all_ner_mask = torch.tensor([f.ner_mask for f in train_features], dtype=torch.long) 379 | 380 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_ner_label_ids, all_ner_mask) 381 | if args.local_rank == -1: 382 | train_sampler = RandomSampler(train_data) 383 | else: 384 | train_sampler = DistributedSampler(train_data) 385 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) 386 | 387 | # test set 388 | if args.eval_test: 389 | test_examples = processor.get_test_examples(args.data_dir) 390 | test_features, test_tokens = convert_examples_to_features( 391 | test_examples, label_list, args.max_seq_length, tokenizer, ner_label_list, args.tokenize_method) 392 | 393 | all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long) 394 | all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long) 395 | all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long) 396 | all_label_ids = torch.tensor([f.label_id for f in test_features], dtype=torch.long) 397 | all_ner_label_ids = torch.tensor([f.ner_label_ids for f in test_features], dtype=torch.long) 398 | all_ner_mask = torch.tensor([f.ner_mask for f in test_features], dtype=torch.long) 399 | 400 | test_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_ner_label_ids, all_ner_mask) 401 | test_dataloader = DataLoader(test_data, batch_size=args.eval_batch_size, shuffle=False) 402 | 403 | 404 | # model and optimizer 405 | 406 | if args.use_crf: 407 | if args.subtask == 'AS': 408 | model = BertForTABSAJoint_CRF_AS(bert_config, len(label_list), len(ner_label_list)) 409 | else: 410 | model = BertForTABSAJoint_CRF_T(bert_config, len(label_list), len(ner_label_list)) 411 | else: 412 | if args.subtask == 'AS': 413 | model = BertForTABSAJoint_AS(bert_config, len(label_list), len(ner_label_list), args.max_seq_length) 414 | else: 415 | model = BertForTABSAJoint_T(bert_config, len(label_list), len(ner_label_list), args.max_seq_length) 416 | 417 | if args.init_checkpoint is not None: 418 | model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) 419 | model.to(device) 420 | 421 | if args.local_rank != -1: 422 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 423 | output_device=args.local_rank) 424 | elif n_gpu > 1: 425 | model = torch.nn.DataParallel(model) 426 | 427 | no_decay = ['bias', 'gamma', 'beta'] 428 | optimizer_parameters = [ 429 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, 430 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0} 431 | ] 432 | 433 | optimizer = BERTAdam(optimizer_parameters, 434 | lr=args.learning_rate, 435 | warmup=args.warmup_proportion, 436 | t_total=num_train_steps) 437 | 438 | 439 | # train 440 | output_log_file = os.path.join(args.output_dir, "log.txt") 441 | print("output_log_file=",output_log_file) 442 | with open(output_log_file, 
"w") as writer: 443 | if args.eval_test: 444 | writer.write("epoch\tglobal_step\tloss\ttest_loss\ttest_accuracy\n") 445 | else: 446 | writer.write("epoch\tglobal_step\tloss\n") 447 | 448 | global_step = 0 449 | epoch=0 450 | for _ in trange(int(args.num_train_epochs), desc="Epoch"): 451 | epoch+=1 452 | model.train() 453 | tr_loss = 0 454 | tr_ner_loss = 0 455 | nb_tr_examples, nb_tr_steps = 0, 0 456 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): 457 | batch = tuple(t.to(device) for t in batch) 458 | input_ids, input_mask, segment_ids, label_ids, ner_label_ids, ner_mask = batch 459 | # loss, ner_loss, _, _ = model(input_ids, segment_ids, input_mask, label_ids, ner_label_ids, ner_mask) 460 | loss, _ = model(input_ids, segment_ids, input_mask, label_ids, ner_label_ids, ner_mask) 461 | if n_gpu > 1: 462 | loss = loss.mean() # mean() to average on multi-gpu. 463 | 464 | if args.gradient_accumulation_steps > 1: 465 | loss = loss / args.gradient_accumulation_steps 466 | loss.backward() 467 | tr_loss += loss.item() 468 | nb_tr_examples += input_ids.size(0) 469 | nb_tr_steps += 1 470 | if (step + 1) % args.gradient_accumulation_steps == 0: 471 | optimizer.step() # We have accumulated enought gradients 472 | model.zero_grad() 473 | global_step += 1 474 | 475 | 476 | # eval_test 477 | if args.eval_test: 478 | 479 | model.eval() 480 | test_loss, test_accuracy = 0, 0 481 | nb_test_steps, nb_test_examples = 0, 0 482 | with open(os.path.join(args.output_dir, "test_ep_"+str(epoch)+".txt"),"w") as f_test: 483 | if args.subtask == 'AS': 484 | f_test.write('yes_not\tyes_not_pre\n') 485 | else: 486 | f_test.write('sentence\ttrue_ner\tpredict_ner\n') 487 | 488 | batch_index = 0 489 | for input_ids, input_mask, segment_ids, label_ids, ner_label_ids, ner_mask in test_dataloader: 490 | input_ids = input_ids.to(device) 491 | input_mask = input_mask.to(device) 492 | segment_ids = segment_ids.to(device) 493 | label_ids = label_ids.to(device) 494 | ner_label_ids = ner_label_ids.to(device) 495 | ner_mask = ner_mask.to(device) 496 | # test_tokens is the origin word in sentences [batch_size, sequence_length] 497 | ner_test_tokens = test_tokens[batch_index*args.eval_batch_size:(batch_index+1)*args.eval_batch_size] 498 | batch_index += 1 499 | 500 | with torch.no_grad(): 501 | # tmp_test_loss, tmp_ner_test_loss, logits, ner_predict = model(input_ids, segment_ids, input_mask, label_ids, ner_label_ids, ner_mask) 502 | tmp_test_loss, logits = model(input_ids, segment_ids, input_mask, label_ids, ner_label_ids, ner_mask) 503 | if args.subtask == 'AS': 504 | # category & polarity 505 | logits = F.softmax(logits, dim=-1) 506 | logits = logits.detach().cpu().numpy() 507 | label_ids = label_ids.to('cpu').numpy() 508 | outputs = np.argmax(logits, axis=1) 509 | for output_i in range(len(outputs)): 510 | # category & polarity 511 | f_test.write(str(label_ids[output_i])) 512 | f_test.write('\t') 513 | f_test.write(str(outputs[output_i])) 514 | f_test.write('\t') 515 | f_test.write("\n") 516 | tmp_test_accuracy=np.sum(outputs == label_ids) 517 | test_accuracy += tmp_test_accuracy 518 | else: 519 | if args.use_crf: 520 | # CRF 521 | ner_logits = logits 522 | else: 523 | # softmax 524 | ner_logits = torch.argmax(F.log_softmax(logits, dim=2),dim=2) 525 | ner_logits = ner_logits.detach().cpu().numpy() 526 | 527 | ner_label_ids = ner_label_ids.to('cpu').numpy() 528 | ner_mask = ner_mask.to('cpu').numpy() 529 | for output_i in range(len(ner_test_tokens)): 530 | # sentence & ner labels 531 | sentence_clean = [] 532 
| label_true = [] 533 | label_pre = [] 534 | sentence_len = len(ner_test_tokens[output_i]) 535 | 536 | for i in range(sentence_len): 537 | if not ner_test_tokens[output_i][i].startswith('##'): 538 | sentence_clean.append(ner_test_tokens[output_i][i]) 539 | label_true.append(ner_label_list[ner_label_ids[output_i][i]]) 540 | label_pre.append(ner_label_list[ner_logits[output_i][i]]) 541 | 542 | f_test.write(' '.join(sentence_clean)) 543 | f_test.write('\t') 544 | f_test.write(' '.join(label_true)) 545 | f_test.write("\t") 546 | f_test.write(' '.join(label_pre)) 547 | f_test.write("\n") 548 | 549 | test_loss += tmp_test_loss.mean().item() 550 | nb_test_examples += input_ids.size(0) 551 | nb_test_steps += 1 552 | 553 | test_loss = test_loss / nb_test_steps 554 | #test_accuracy = test_accuracy / nb_test_examples 555 | 556 | 557 | result = collections.OrderedDict() 558 | if args.eval_test: 559 | result = {'epoch': epoch, 560 | 'global_step': global_step, 561 | 'loss': tr_loss/nb_tr_steps, 562 | 'test_loss': test_loss} 563 | else: 564 | result = {'epoch': epoch, 565 | 'global_step': global_step, 566 | 'loss': tr_loss/nb_tr_steps} 567 | 568 | logger.info("***** Eval results *****") 569 | with open(output_log_file, "a+") as writer: 570 | for key in result.keys(): 571 | logger.info(" %s = %s\n", key, str(result[key])) 572 | writer.write("%s\t" % (str(result[key]))) 573 | writer.write("\n") 574 | 575 | if __name__ == "__main__": 576 | main() 577 | -------------------------------------------------------------------------------- /convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # Reference: https://github.com/huggingface/pytorch-pretrained-BERT 4 | 5 | """Convert BERT checkpoint.""" 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import argparse 10 | import re 11 | 12 | import numpy as np 13 | import torch 14 | 15 | import tensorflow as tf 16 | from modeling import BertConfig, BertModel 17 | 18 | parser = argparse.ArgumentParser() 19 | 20 | ## Required parameters 21 | parser.add_argument("--tf_checkpoint_path", 22 | default = 'uncased_L-12_H-768_A-12/bert_model.ckpt', 23 | type = str, 24 | required = False, 25 | help = "Path the TensorFlow checkpoint path.") 26 | parser.add_argument("--bert_config_file", 27 | default = 'uncased_L-12_H-768_A-12/bert_config.json', 28 | type = str, 29 | required = False, 30 | help = "The config json file corresponding to the pre-trained BERT model. 
\n" 31 | "This specifies the model architecture.") 32 | parser.add_argument("--pytorch_dump_path", 33 | default = 'uncased_L-12_H-768_A-12/pytorch_model.bin', 34 | type = str, 35 | required = False, 36 | help = "Path to the output PyTorch model.") 37 | 38 | args = parser.parse_args() 39 | 40 | def convert(): 41 | # Initialise PyTorch model 42 | config = BertConfig.from_json_file(args.bert_config_file) 43 | model = BertModel(config) 44 | 45 | # Load weights from TF model 46 | path = args.tf_checkpoint_path 47 | print("Converting TensorFlow checkpoint from {}".format(path)) 48 | 49 | init_vars = tf.train.list_variables(path) 50 | names = [] 51 | arrays = [] 52 | for name, shape in init_vars: 53 | print("Loading {} with shape {}".format(name, shape)) 54 | array = tf.train.load_variable(path, name) 55 | print("Numpy array shape {}".format(array.shape)) 56 | names.append(name) 57 | arrays.append(array) 58 | 59 | for name, array in zip(names, arrays): 60 | name = name[5:] # skip "bert/" 61 | print("Loading {}".format(name)) 62 | name = name.split('/') 63 | if any(n in ["adam_v", "adam_m","l_step"] for n in name): 64 | print("Skipping {}".format("/".join(name))) 65 | continue 66 | if name[0] in ['redictions', 'eq_relationship']: 67 | print("Skipping") 68 | continue 69 | pointer = model 70 | for m_name in name: 71 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 72 | l = re.split(r'_(\d+)', m_name) 73 | else: 74 | l = [m_name] 75 | if l[0] == 'kernel': 76 | pointer = getattr(pointer, 'weight') 77 | else: 78 | pointer = getattr(pointer, l[0]) 79 | if len(l) >= 2: 80 | num = int(l[1]) 81 | pointer = pointer[num] 82 | if m_name[-11:] == '_embeddings': 83 | pointer = getattr(pointer, 'weight') 84 | elif m_name == 'kernel': 85 | array = np.transpose(array) 86 | try: 87 | assert pointer.shape == array.shape 88 | except AssertionError as e: 89 | e.args += (pointer.shape, array.shape) 90 | raise 91 | pointer.data = torch.from_numpy(array) 92 | 93 | # Save pytorch-model 94 | torch.save(model.state_dict(), args.pytorch_dump_path) 95 | 96 | if __name__ == "__main__": 97 | convert() 98 | -------------------------------------------------------------------------------- /data/__pycache__/change_TO_to_BIO.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysulic/TAS-BERT/704472423e4dfe1c50d9b53de2c2376db2af5ed0/data/__pycache__/change_TO_to_BIO.cpython-36.pyc -------------------------------------------------------------------------------- /data/change_TO_to_BIO.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | create BIO labels for targets 5 | """ 6 | 7 | import csv 8 | import os 9 | import re 10 | 11 | def TXT_file(name): 12 | return '{}.txt'.format(name) 13 | 14 | def TSV_file(name): 15 | return '{}.tsv'.format(name) 16 | 17 | 18 | def change_TO_to_BIO(path, file_name): 19 | input_path = path + '/TO' 20 | entity_label = r"T+" 21 | output_path = path + '/BIO' 22 | file_out = file_name 23 | if not os.path.exists(output_path): 24 | os.makedirs(output_path) 25 | with open(os.path.join(input_path, TSV_file(file_name)), 'r', encoding='utf-8') as fin, open(os.path.join(output_path, TSV_file(file_out)), 'w', encoding='utf-8') as fout: 26 | fout.write('\t'.join(['sentence_id', 'yes_no', 'aspect_sentiment', 'sentence', 'ner_tags'])) 27 | fout.write('\n') 28 | fin.readline() 29 | for line in fin: 30 | line_arr = line.strip().split('\t') 31 | # change TO to BIO tags 32 | 
ner_tags = ''.join(line_arr[-1].split()) 33 | entity_list = re.finditer(entity_label, ner_tags) 34 | BIO_tags = ['O'] * len(ner_tags) 35 | for x in entity_list: 36 | start = x.start() 37 | en_len = len(x.group()) 38 | BIO_tags[start] = 'B' 39 | for m in range(start+1, start+en_len): 40 | BIO_tags[m] = 'I' 41 | 42 | line_arr[-1] = ' '.join(BIO_tags) 43 | fout.write('\t'.join(line_arr)) 44 | fout.write('\n') -------------------------------------------------------------------------------- /data/data_preprocessing_for_TAS.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | Preprocessing for SemEval-2015 and SemEval-2016 Datasets 5 | for three joint task (target & aspect & sentiment) 6 | 7 | TO floder means TO labeling schema for targets 8 | BIO floder means BIO labeling schema for targets 9 | """ 10 | 11 | import csv 12 | import os 13 | import re 14 | import argparse 15 | from change_TO_to_BIO import TXT_file, TSV_file, change_TO_to_BIO 16 | 17 | def get_aspect_sentiment_compose(path, file_name): 18 | aspect_set = [] 19 | sentiment_set = ['positive', 'negative', 'neutral'] # sentiment polarity 20 | with open(os.path.join(path, TXT_file(file_name)), 'r', encoding='utf-8') as fin: 21 | fin.readline() 22 | for line in fin: 23 | line_arr = line.strip().split('\t') 24 | if line_arr[6] == 'yes': # entailed == yes 25 | if line_arr[3] not in aspect_set: 26 | aspect_set.append(line_arr[3]) # aspect 27 | 28 | compose_set = [] 29 | for ca in aspect_set: 30 | for po in sentiment_set: 31 | compose_set.append(ca + ' ' + po) 32 | 33 | return compose_set 34 | 35 | 36 | def create_dataset_file(input_path, output_path, input_file, output_file, compose_set): 37 | one_data_nums = 0 38 | zero_data_nums = 0 39 | max_len = 0 40 | entity_sum = 0 41 | if not os.path.exists(output_path): 42 | os.makedirs(output_path) 43 | with open(os.path.join(input_path, TXT_file(input_file)), 'r', encoding='utf-8') as fin, open(os.path.join(output_path, TSV_file(output_file)), 'w', encoding='utf-8') as fout: 44 | fout.write('\t'.join(['sentence_id', 'yes_no', 'aspect_sentiment', 'sentence', 'ner_tags'])) 45 | fout.write('\n') 46 | fin.readline() 47 | pre_start = False 48 | pre_sentence_id = 'XXX' # sentence id of the previous line 49 | pre_sentence = 'XXX' 50 | record_of_one_sentence = set() # the set of aspect&sentiment that this sentence contains 51 | record_of_one_sentence_ner_tag = {} # the NER tags of the set of aspect&sentiment that this sentence contains 52 | for line in fin: 53 | line_arr = line.strip().split('\t') 54 | sentence_id = line_arr[0] 55 | if sentence_id != pre_sentence_id: # this is a new sentence 56 | if pre_start == True: 57 | for x in compose_set: 58 | # create yes line of this sentence 59 | if x in record_of_one_sentence: 60 | fout.write(pre_sentence_id + '\t' + '1' + '\t' + x + '\t' + pre_sentence + '\t' + record_of_one_sentence_ner_tag[x] + '\n') 61 | one_data_nums += 1 62 | # create no line 63 | else: 64 | fout.write(pre_sentence_id + '\t' + '0' + '\t' + x + '\t' + pre_sentence + '\t' + ' '.join(['O']*len(pre_sentence.split())) + '\n') 65 | zero_data_nums += 1 66 | 67 | else: 68 | pre_start = True 69 | record_of_one_sentence.clear() 70 | record_of_one_sentence_ner_tag.clear() 71 | pre_sentence_id = sentence_id 72 | 73 | 74 | if line_arr[6] == 'yes': # entailed == yes 75 | # get NER labels 76 | sentence = line_arr[1].strip().split(' ') 77 | gold_target = ' '.join(line_arr[2].strip().split()) 78 | ner_tags = ['O'] * len(sentence) 79 | start 
= int(line_arr[7]) - 1 80 | end = int(line_arr[8]) - 1 81 | if line_arr[1].startswith(' '): 82 | start -= 1 83 | end -= 1 84 | if not (start < 0 and end < 0): # not NULL 85 | get_target = ' '.join(sentence[start:end]) 86 | if gold_target != get_target: 87 | print('Error!!!!') 88 | print(line_arr[1]) 89 | print(gold_target) 90 | print(get_target) 91 | print(str(start) + ' - ' + str(end)) 92 | 93 | for x in range(start, end): 94 | ner_tags[x] = 'T' 95 | 96 | sentence_clear = [] 97 | ner_tags_clear = [] 98 | # solve the ' ' multi space 99 | special_token = "$()*+.[]?\\^}{|!'#%&,-/:;_~@<=>`\"’“”‘…" 100 | special_token_re = r"[\$\(\)\*\+\.\[\]\?\\\^\{\}\|!'#%&,-/:;_~@<=>`\"’‘“”…]{1,1}" 101 | for x in range(len(sentence)): 102 | in_word = False 103 | if sentence[x] != '': 104 | punctuation_list = re.finditer(special_token_re, sentence[x]) 105 | punctuation_list_start = [] 106 | punctuation_list_len = [] 107 | for m in punctuation_list: 108 | punctuation_list_start.append(m.start()) 109 | punctuation_list_len.append(len(m.group())) 110 | 111 | if len(punctuation_list_start) != 0: 112 | # the start is word 113 | if punctuation_list_start[0] != 0: 114 | sentence_clear.append(sentence[x][0:punctuation_list_start[0]]) 115 | ner_tags_clear.append(ner_tags[x]) 116 | for (i, m) in enumerate(punctuation_list_start): 117 | #print(len(punctuation_list_start)) 118 | #print(len(punctuation_list_len)) 119 | #print(str(m) + ' - ' + str(m+punctuation_list_len[i])) 120 | sentence_clear.append(sentence[x][m:m+punctuation_list_len[i]]) 121 | ner_tags_clear.append(ner_tags[x]) 122 | 123 | if i != len(punctuation_list_start) - 1: 124 | if m+punctuation_list_len[i] != punctuation_list_start[i+1] : 125 | sentence_clear.append(sentence[x][m+punctuation_list_len[i]:punctuation_list_start[i+1]]) 126 | ner_tags_clear.append(ner_tags[x]) 127 | 128 | else: 129 | if m+punctuation_list_len[i] < len(sentence[x]): 130 | sentence_clear.append(sentence[x][m+punctuation_list_len[i]:]) 131 | ner_tags_clear.append(ner_tags[x]) 132 | 133 | 134 | else: # has no punctuation 135 | sentence_clear.append(sentence[x]) 136 | ner_tags_clear.append(ner_tags[x]) 137 | 138 | assert '' not in sentence_clear 139 | assert len(sentence_clear) == len(ner_tags_clear) 140 | 141 | # get aspect&sentiment 142 | cate_pola = line_arr[3] + ' ' + line_arr[4] 143 | 144 | pre_sentence = ' '.join(sentence_clear) 145 | assert ' ' not in pre_sentence 146 | 147 | if len(sentence_clear) > max_len: 148 | max_len = len(sentence_clear) 149 | if cate_pola in record_of_one_sentence: # this aspect&sentiment has more than one target 150 | ner_tags_A = record_of_one_sentence_ner_tag[cate_pola].split() 151 | if len(ner_tags_A) != len(ner_tags_clear): 152 | print('Ner Tags Length Error!!!') 153 | else: 154 | for x in range(len(ner_tags_A)): 155 | if ner_tags_A[x] != 'O': 156 | ner_tags_clear[x] = ner_tags_A[x] 157 | record_of_one_sentence_ner_tag[cate_pola] = ' '.join(ner_tags_clear) 158 | if 'T' in ner_tags_A: 159 | entity_sum += 1 160 | 161 | else: 162 | record_of_one_sentence.add(cate_pola) 163 | record_of_one_sentence_ner_tag[cate_pola] = ' '.join(ner_tags_clear) 164 | entity_sum += 1 165 | print('entity_sum: ', entity_sum) 166 | print('max_sen_len: ', max_len) 167 | print('sample ratio: ', str(one_data_nums), '-', str(zero_data_nums)) 168 | 169 | 170 | 171 | 172 | if __name__ == '__main__': 173 | parser = argparse.ArgumentParser() 174 | parser.add_argument('--dataset', 175 | type=str, 176 | choices=["semeval2015", "semeval2016"], 177 | help='dataset, as a folder name, 
you can choose from semeval2015 and semeval2016') 178 | args = parser.parse_args() 179 | 180 | path = args.dataset + '/three_joint' 181 | output_path = path + '/TO' 182 | 183 | if '2015' in args.dataset: 184 | train_file = 'ABSA_15_Restaurants_Train' 185 | test_file = 'ABSA_15_Restaurants_Test' 186 | else: 187 | train_file = 'ABSA_16_Restaurants_Train' 188 | test_file = 'ABSA_16_Restaurants_Test' 189 | 190 | train_output = 'train_TAS' 191 | test_output = 'test_TAS' 192 | 193 | # get set of aspect-sentiment 194 | compose_set = get_aspect_sentiment_compose(args.dataset, train_file) 195 | 196 | for input_file, output_file in zip([train_file, test_file], [train_output, test_output]): 197 | # get preprocessed data, TO labeling schema 198 | create_dataset_file(args.dataset, output_path, input_file, output_file, compose_set) 199 | # get preprocessed data, BIO labeling schema 200 | change_TO_to_BIO(path, output_file) 201 | 202 | -------------------------------------------------------------------------------- /evaluation_for_AD_TD_TAD/A.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysulic/TAS-BERT/704472423e4dfe1c50d9b53de2c2376db2af5ed0/evaluation_for_AD_TD_TAD/A.jar -------------------------------------------------------------------------------- /evaluation_for_AD_TD_TAD/ABSA15.xsd: -------------------------------------------------------------------------------- (XML schema, 58 lines; the markup was not preserved in this text dump) -------------------------------------------------------------------------------- /evaluation_for_AD_TD_TAD/ABSA15.xsd~: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysulic/TAS-BERT/704472423e4dfe1c50d9b53de2c2376db2af5ed0/evaluation_for_AD_TD_TAD/ABSA15.xsd~ -------------------------------------------------------------------------------- /evaluation_for_AD_TD_TAD/ABSA15BaseEvalValid.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sysulic/TAS-BERT/704472423e4dfe1c50d9b53de2c2376db2af5ed0/evaluation_for_AD_TD_TAD/ABSA15BaseEvalValid.pdf -------------------------------------------------------------------------------- /evaluation_for_AD_TD_TAD/absa15.conf: -------------------------------------------------------------------------------- 1 | 2 | # 1. Define source xml 3 | src=ABSA-15_Restaurants_Train_Final.xml 4 | #src=ABSA-15_Laptops_Train_Data.xml 5 | 6 | # 2. Choose domain (rest|lapt) 7 | dom=rest 8 | #dom=lapt 9 | 10 | # 3. Threshold for aspect prediction 11 | #laptops 12 | #thr=0.12 13 | # restaurants 14 | thr=0.2 15 | 16 | # 4. suffle reviews 1.yes 0.no 17 | sfl=1 18 | 19 | # 5. Cross validation 1.yes 0.no 20 | xva=0 21 | 22 | # 6. Num of parts that the src file will be split into. 23 | fld=10 24 | 25 | # 7. part (0,1,2,...,fld-1) 26 | partIdx=9 27 | 28 | # 8. Num of BOW features used in the SVM models. 29 | ftr=1000 30 | 31 | # 9.
Folder for train and test xml files, models and outputs 32 | ttd=Files 33 | 34 | -------------------------------------------------------------------------------- /evaluation_for_AD_TD_TAD/absa15.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | clear 3 | 4 | # Baselines folder (full path) 5 | dir=$(pwd)/ 6 | 7 | echo "Reading config...." >&2 8 | . $dir"absa15.conf" 9 | 10 | # ttd full path 11 | pth=$dir$ttd/ 12 | 13 | 14 | # --------------------------------------------------- 15 | 16 | ABSABaseAndEval () 17 | { 18 | 19 | echo -e "***** Harvesting Features from Train *****" 20 | java -cp ./A.jar absa15.Do ExtractFeats $dom $dir $ttd $ftr$pIdxArg 21 | 22 | echo -e "***** Creating Train Vectors for Stages 1 and 2 *****" 23 | java -cp ./A.jar absa15.Do CreateVecs $dom $dir $ttd 1 "1"$pIdxArg 24 | 25 | echo -e "***** Training SVM model for category prediction *****" 26 | ./libsvm-3.18/svm-train -t 0 -b 1 -q ${pth}"tr.svm.asp"${suff} ${pth}"tr.svm.model.asp"${suff} 27 | 28 | # Stage 1 Predict 29 | 30 | echo -e "***** Creating Test Vectors for Stage 1 *****" 31 | java -cp ./A.jar absa15.Do CreateVecs $dom $dir $ttd 2 "1"$pIdxArg 32 | 33 | echo -e "***** Predicting categories *****" 34 | ./libsvm-3.18/svm-predict -b 1 ${pth}"te.svm.asp"${suff} ${pth}"tr.svm.model.asp"${suff} ${pth}"Out.asp"${suff} 35 | 36 | echo -e "***** Assigning categories using a threshold on the SVM prediction *****" 37 | java -cp ./A.jar absa15.Do Assign $dom $dir $ttd $thr 1 "0"$pIdxArg 38 | 39 | if [ "$dom" = "rest" ]; then 40 | 41 | echo -e "***** Determining targets using a target list created from train data *****" 42 | java -cp ./A.jar absa15.Do IdentifyTargets $dom $dir $ttd $pIdxArg 43 | fi 44 | 45 | # Stage 2 Training 46 | 47 | echo -e "***** Training polarity category model *****" 48 | ./libsvm-3.18/svm-train -t 0 -b 1 -q ${pth}"tr.svm.pol"${suff} ${pth}"tr.svm.model.pol"${suff} 49 | 50 | # Stage 2 Predict 51 | 52 | echo -e "***** Creating Test Vectors for Stage 2 *****" 53 | java -cp ./A.jar absa15.Do CreateVecs $dom $dir $ttd 2 "2"$pIdxArg 54 | 55 | # gold aspects 56 | echo -e "***** Predicting polarities using SVM for gold aspect categories *****" 57 | ./libsvm-3.18/svm-predict -b 1 ${pth}"te.svm.pol4g"${suff} ${pth}"tr.svm.model.pol"${suff} ${pth}"Out.pol"${suff} 58 | 59 | echo -e "***** Assigning polarities based on SVM prediction *****" 60 | java -cp ./A.jar absa15.Do Assign $dom $dir $ttd 0 2 "0"$pIdxArg 61 | 62 | # pred aspects 63 | #echo -e "***** Predicting polarities using SVM for predicted aspect categories *****" 64 | #./libsvm-3.18/svm-predict -b 1 ${pth}"te.svm.pol4p"${suff} ${pth}"tr.svm.model.pol"${suff} ${pth}"Out.pol"${suff} 65 | 66 | #echo -e "***** Assigning polarities based on SVM prediction *****" 67 | #java -cp ./A.jar absa15.Do Assign $dom $dir $ttd 0 2 "1"$pIdxArg 68 | 69 | # Evaluate results 70 | 71 | echo -e "\n" 72 | echo -e "***** Evaluate Stage 1 Output (target and category) *****" 73 | 74 | java -cp ./A.jar absa15.Do Eval ${pth}"teCln.PrdAspTrg.xml"${suff} ${pth}"teGld.xml"${suff} 1 0 75 | 76 | if [ "$dom" = "rest" ]; then 77 | java -cp ./A.jar absa15.Do Eval ${pth}"teCln.PrdAspTrg.xml"${suff} ${pth}"teGld.xml"${suff} 2 0 78 | java -cp ./A.jar absa15.Do Eval ${pth}"teCln.PrdAspTrg.xml"${suff} ${pth}"teGld.xml"${suff} 3 0 79 | fi 80 | 81 | echo -e "***** Evaluate Stage 2 Output (Polarity) *****" 82 | java -cp ./A.jar absa15.Do Eval ${pth}"teGldAspTrg.PrdPol.xml"${suff} ${pth}"teGld.xml"${suff} 5 1 83 | 84 | } 85 | 86 | echo -e 
"*******************************************" 87 | echo -e "BASELINES DIR:" $dir 88 | echo -e "Stage 1: Aspect and OTE extraction" 89 | echo -e "Stage 2: Polarity classification" 90 | 91 | echo -e "***** Validate Input XML *****" 92 | java -cp ./A.jar absa15.Do Validate ${dir}${src} ${dir}"ABSA15.xsd" $dom 93 | 94 | if [ "$xva" -eq 0 ]; then 95 | echo -e "***** Split Train Test *****" 96 | java -cp ./A.jar absa15.Do Split $sfl $dir $ttd $src $fld $partIdx 97 | ABSABaseAndEval 98 | else 99 | echo -e "***** Split *****" 100 | java -cp ./A.jar absa15.Do Split $sfl $dir $ttd $src $fld 101 | echo -e "\n***** Cross Validation*****\n" 102 | for i in $(eval echo {1..$fld}); do 103 | echo -e "Round " $i 104 | pIdxArg=" "$(($i-1)) 105 | suff="."$(($i-1)) 106 | echo $pIdxArg $suff 107 | ABSABaseAndEval 108 | echo -e "\n" 109 | done 110 | fi 111 | echo -e "*******************************************" 112 | 113 | 114 | -------------------------------------------------------------------------------- /evaluation_for_AD_TD_TAD/change_pre_to_xml.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """preprocess for dataset.""" 4 | 5 | import csv 6 | import os 7 | import re 8 | import argparse 9 | import xml.etree.ElementTree as ET 10 | import xml.dom.minidom as DOM 11 | 12 | def TXT_file(name): 13 | return '{}.txt'.format(name) 14 | 15 | def TSV_file(name): 16 | return '{}.tsv'.format(name) 17 | 18 | def XML_file(name): 19 | return '{}.xml'.format(name) 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser() 23 | 24 | ## Required parameters 25 | parser.add_argument("--gold_path", 26 | type=str, 27 | required=True, 28 | help="The gold file") 29 | parser.add_argument("--pre_path", 30 | type=str, 31 | required=True, 32 | help="The test result file, such as */test_ep_21.txt") 33 | parser.add_argument("--gold_xml_file", 34 | type=str, 35 | required=True, 36 | help="gold xml file -> used to get the file structure") 37 | 38 | parser.add_argument("--pre_xml_file", 39 | type=str, 40 | required=True, 41 | help="get the prediction file in XML format") 42 | parser.add_argument("--tag_schema", 43 | type=str, 44 | required=True, 45 | choices=["TO", "BIO"], 46 | help="The tag schema of the result in pre_path") 47 | 48 | args = parser.parse_args() 49 | 50 | if args.tag_schema == 'BIO': 51 | entity_label = r"BI*" # for BIO 52 | else: 53 | entity_label = r"T+" # for TO 54 | 55 | with open(args.pre_path, 'r', encoding='utf-8') as f_pre, open(args.gold_path, 'r', encoding='utf-8') as f_gold: 56 | # clear the gold opinions and get the empty framework 57 | sen_tree_map = {} 58 | xml_tree = ET.parse(args.gold_xml_file) 59 | root = xml_tree.getroot() 60 | for review in root.findall('Review'): 61 | for sen in review.find('sentences').getchildren(): 62 | sen_key = sen.get('id') 63 | sen_tree_map[sen_key] = sen 64 | opinions = sen.find('Opinions') 65 | if opinions is not None: 66 | opinions.clear() 67 | 68 | f_pre.readline() 69 | f_gold.readline() 70 | pre_lines = f_pre.readlines() 71 | gold_lines = f_gold.readlines() 72 | for line_1, line_2 in zip(pre_lines, gold_lines): 73 | pre_line = line_1.strip().split('\t') 74 | gold_line = line_2.strip().split('\t') 75 | 76 | sentence_id = gold_line[0] 77 | yes_not = pre_line[1] 78 | category_polarity = gold_line[2].split() 79 | sentence = gold_line[3].split() 80 | # move [CLS] 81 | ner_tags = ''.join(pre_line[-1].split()[1:]) 82 | 83 | if yes_not == '1': 84 | category = '#'.join(category_polarity[0:2]).upper() 
85 | polarity = category_polarity[-1] 86 | entitys = [] 87 | entity_list = re.finditer(entity_label, ner_tags) 88 | for x in entity_list: 89 | entitys.append(str(x.start()) + '-' + str(len(x.group()))) 90 | 91 | # write to the xml file 92 | #sen_reco = ".//sentence[@id='" + sentence_id.replace("'", r"\'") + "']" 93 | #print(sen_reco) 94 | #current_sen = root.find(sen_reco) 95 | current_sen = sen_tree_map[sentence_id] 96 | current_opinions = current_sen.find('Opinions') 97 | if current_opinions == None: 98 | current_opinions = ET.Element('Opinions') 99 | current_sen.append(current_opinions) 100 | 101 | if len(entitys) == 0: # NULL for this category&polarity 102 | op = ET.Element('Opinion') 103 | op.set('target', 'NULL') 104 | op.set('category', category) 105 | op.set('polarity', polarity) 106 | op.set('from', '0') 107 | op.set('to', '0') 108 | current_opinions.append(op) 109 | 110 | else: 111 | for x in entitys: 112 | start = int(x.split('-')[0]) 113 | end = int(x.split('-')[1]) + start 114 | target_match = re.compile('\\s*'.join(sentence[start:end])) 115 | sentence_org = ' '.join(sentence) 116 | target_match_list = re.finditer(target_match, sentence_org) 117 | true_idx = 0 118 | for m in target_match_list: 119 | if start == sentence_org[0:m.start()].count(' '): 120 | break 121 | true_idx += 1 122 | 123 | gold_sentence = current_sen.find('text').text 124 | target_match_list = re.finditer(target_match, gold_sentence) 125 | match_list = [] 126 | for m in target_match_list: 127 | match_list.append(str(m.start()) + '###' + str(len(m.group())) + '###' + m.group()) 128 | if len(match_list) < true_idx + 1: 129 | print("Error!!!!!!!!!!!!!!!!!!!!!") 130 | print(len(match_list)) 131 | print(target_match) 132 | print(sentence_org) 133 | else: 134 | info_list = match_list[true_idx].split('###') 135 | target = info_list[2] 136 | from_idx = info_list[0] 137 | to_idx = str(int(from_idx) + int(info_list[1])) 138 | op = ET.Element('Opinion') 139 | op.set('target', target) 140 | op.set('category', category) 141 | op.set('polarity', polarity) 142 | op.set('from', from_idx) 143 | op.set('to', to_idx) 144 | current_opinions.append(op) 145 | 146 | 147 | xml_string = ET.tostring(root) 148 | xml_write = DOM.parseString(xml_string) 149 | with open(args.pre_xml_file, 'w') as handle: 150 | xml_write.writexml(handle, indent=' ', encoding='utf-8') 151 | -------------------------------------------------------------------------------- /evaluation_for_TSD_ASD_TASD.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """evaluate P R F1 for target & polarity joint task""" 4 | 5 | import csv 6 | import os 7 | import re 8 | import argparse 9 | 10 | def TXT_file(name): 11 | return '{}.txt'.format(name) 12 | 13 | def Clean_file(name): 14 | return '{}.tsv'.format(name) 15 | 16 | def evaluate_TSD_contain_NULL(path, best_epoch_file, tag_schema): 17 | with open(os.path.join(path, TXT_file(best_epoch_file)), 'r', encoding='utf-8') as f_pre: 18 | Gold_Num = 0 19 | True_Num = 0 20 | Pre_Num = 0 21 | tag_schema == 'TO' 22 | if tag_schema == 'TO': 23 | entity_label = r"T+" # for TO 24 | else: 25 | entity_label = r"BI*" # for BIO 26 | f_pre.readline() 27 | pre_lines = f_pre.readlines() 28 | 29 | # the polarity order in test file is: positive, negative, neutral 30 | lin_idx = 0 31 | positive_targets_gold = set() 32 | positive_targets_pred = set() 33 | negative_targets_gold = set() 34 | negative_targets_pred = set() 35 | neutral_targets_gold = set() 36 | neutral_targets_pred = 
set() 37 | NULL_for_positive_gold = False 38 | NULL_for_positive_pred = False 39 | NULL_for_negative_gold = False 40 | NULL_for_negative_pred = False 41 | NULL_for_neutral_gold = False 42 | NULL_for_neutral_pred = False 43 | pre_sen = '' 44 | now_sen = '' 45 | for line in pre_lines: 46 | lin_idx += 1 47 | pre_line = line.strip().split('\t') 48 | now_sen = pre_line[2] 49 | if now_sen != pre_sen: # a new sentence now, evaluate for pre sentence 50 | pre_sen = now_sen 51 | # positive 52 | if NULL_for_positive_gold: 53 | Gold_Num += 1 54 | if NULL_for_positive_pred: 55 | Pre_Num += 1 56 | if NULL_for_positive_gold and NULL_for_positive_pred: 57 | True_Num += 1 58 | Gold_Num += len(positive_targets_gold) 59 | Pre_Num += len(positive_targets_pred) 60 | True_Num += len(positive_targets_gold & positive_targets_pred) 61 | # negative 62 | if NULL_for_negative_gold: 63 | Gold_Num += 1 64 | if NULL_for_negative_pred: 65 | Pre_Num += 1 66 | if NULL_for_negative_gold and NULL_for_negative_pred: 67 | True_Num += 1 68 | Gold_Num += len(negative_targets_gold) 69 | Pre_Num += len(negative_targets_pred) 70 | True_Num += len(negative_targets_gold & negative_targets_pred) 71 | # neutral 72 | if NULL_for_neutral_gold: 73 | Gold_Num += 1 74 | if NULL_for_neutral_pred: 75 | Pre_Num += 1 76 | if NULL_for_neutral_gold and NULL_for_neutral_pred: 77 | True_Num += 1 78 | Gold_Num += len(neutral_targets_gold) 79 | Pre_Num += len(neutral_targets_pred) 80 | True_Num += len(neutral_targets_gold & neutral_targets_pred) 81 | 82 | # initialize for new sentence 83 | positive_targets_gold.clear() 84 | positive_targets_pred.clear() 85 | negative_targets_gold.clear() 86 | negative_targets_pred.clear() 87 | neutral_targets_gold.clear() 88 | neutral_targets_pred.clear() 89 | NULL_for_positive_gold = False 90 | NULL_for_positive_pred = False 91 | NULL_for_negative_gold = False 92 | NULL_for_negative_pred = False 93 | NULL_for_neutral_gold = False 94 | NULL_for_neutral_pred = False 95 | 96 | sentence_length = len(pre_line[2].split()) 97 | pre_ner_tags = ''.join(pre_line[-1].split()[1:]) # [CLS] sentence [SEP] ........ 
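# the first tag in each sequence corresponds to [CLS], so it is dropped to keep target offsets aligned with the sentence tokens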
98 | gold_ner_tags = ''.join(pre_line[-2].split()[1:]) 99 | if pre_line[0] == '1': # yes on gold 100 | gold_entity = set() 101 | gold_entity_list = re.finditer(entity_label, gold_ner_tags) 102 | for x in gold_entity_list: 103 | gold_entity.add(str(x.start()) + '-' + str(len(x.group()))) 104 | 105 | if lin_idx % 3 == 1: # this line for positive 106 | if len(gold_entity) == 0: # NULL 107 | NULL_for_positive_gold = True 108 | else: # not NULL, has entity in this sentence 109 | positive_targets_gold = positive_targets_gold | gold_entity 110 | elif lin_idx % 3 == 2: # this line for negative 111 | if len(gold_entity) == 0: # NULL 112 | NULL_for_negative_gold = True 113 | else: # not NULL, has entity in this sentence 114 | negative_targets_gold = negative_targets_gold | gold_entity 115 | else: # this line for neutral 116 | if len(gold_entity) == 0: # NULL 117 | NULL_for_neutral_gold = True 118 | else: # not NULL, has entity in this sentence 119 | neutral_targets_gold = neutral_targets_gold | gold_entity 120 | 121 | if pre_line[1] == '1': # yes on pre 122 | pre_entity = set() 123 | pre_entity_list = re.finditer(entity_label, pre_ner_tags) 124 | for x in pre_entity_list: 125 | pre_entity.add(str(x.start()) + '-' + str(len(x.group()))) 126 | 127 | if lin_idx % 3 == 1: # this line for positive 128 | if len(pre_entity) == 0: # NULL 129 | NULL_for_positive_pred = True 130 | else: # not NULL, has entity in this sentence 131 | positive_targets_pred = positive_targets_pred | pre_entity 132 | elif lin_idx % 3 == 2: # this line for negative 133 | if len(pre_entity) == 0: # NULL 134 | NULL_for_negative_pred = True 135 | else: # not NULL, has entity in this sentence 136 | negative_targets_pred = negative_targets_pred | pre_entity 137 | else: # this line for neutral 138 | if len(pre_entity) == 0: # NULL 139 | NULL_for_neutral_pred = True 140 | else: # not NULL, has entity in this sentence 141 | neutral_targets_pred = neutral_targets_pred | pre_entity 142 | 143 | P = True_Num / float(Pre_Num) if Pre_Num != 0 else 0 144 | R = True_Num / float(Gold_Num) 145 | F = (2*P*R)/float(P+R) if P!=0 else 0 146 | 147 | print('TSD task containing NULL:') 148 | print("\tP: ", P, " R: ", R, " F1: ", F) 149 | print('----------------------------------------------------\n\n') 150 | 151 | def evaluate_TSD_ignore_NULL(path, best_epoch_file, tag_schema): 152 | with open(os.path.join(path, TXT_file(best_epoch_file)), 'r', encoding='utf-8') as f_pre: 153 | Gold_Num = 0 154 | True_Num = 0 155 | Pre_Num = 0 156 | if tag_schema == 'TO': 157 | entity_label = r"T+" # for TO 158 | else: 159 | entity_label = r"BI*" # for BIO 160 | f_pre.readline() 161 | pre_lines = f_pre.readlines() 162 | 163 | # the polarity order in test file is: positive, negative, neutral 164 | lin_idx = 0 165 | positive_targets_gold = set() 166 | positive_targets_pred = set() 167 | negative_targets_gold = set() 168 | negative_targets_pred = set() 169 | neutral_targets_gold = set() 170 | neutral_targets_pred = set() 171 | pre_sen = '' 172 | now_sen = '' 173 | for line in pre_lines: 174 | lin_idx += 1 175 | pre_line = line.strip().split('\t') 176 | now_sen = pre_line[2] 177 | if now_sen != pre_sen: # a new sentence now, evaluate for pre sentence 178 | pre_sen = now_sen 179 | # positive 180 | Gold_Num += len(positive_targets_gold) 181 | Pre_Num += len(positive_targets_pred) 182 | True_Num += len(positive_targets_gold & positive_targets_pred) 183 | # negative 184 | Gold_Num += len(negative_targets_gold) 185 | Pre_Num += len(negative_targets_pred) 186 | True_Num += 
len(negative_targets_gold & negative_targets_pred) 187 | # neutral 188 | Gold_Num += len(neutral_targets_gold) 189 | Pre_Num += len(neutral_targets_pred) 190 | True_Num += len(neutral_targets_gold & neutral_targets_pred) 191 | 192 | # initialize for new sentence 193 | positive_targets_gold.clear() 194 | positive_targets_pred.clear() 195 | negative_targets_gold.clear() 196 | negative_targets_pred.clear() 197 | neutral_targets_gold.clear() 198 | neutral_targets_pred.clear() 199 | NULL_for_positive_gold = False 200 | NULL_for_positive_pred = False 201 | NULL_for_negative_gold = False 202 | NULL_for_negative_pred = False 203 | NULL_for_neutral_gold = False 204 | NULL_for_neutral_pred = False 205 | 206 | sentence_length = len(pre_line[2].split()) 207 | pre_ner_tags = ''.join(pre_line[-1].split()[1:]) # [CLS] sentence [SEP] ........ 208 | gold_ner_tags = ''.join(pre_line[-2].split()[1:]) 209 | if pre_line[0] == '1': # yes on gold 210 | gold_entity = set() 211 | gold_entity_list = re.finditer(entity_label, gold_ner_tags) 212 | for x in gold_entity_list: 213 | gold_entity.add(str(x.start()) + '-' + str(len(x.group()))) 214 | 215 | if lin_idx % 3 == 1: # this line for positive 216 | if len(gold_entity) != 0: # not NULL, has entity in this sentence 217 | positive_targets_gold = positive_targets_gold | gold_entity 218 | elif lin_idx % 3 == 2: # this line for negative 219 | if len(gold_entity) != 0: # not NULL, has entity in this sentence 220 | negative_targets_gold = negative_targets_gold | gold_entity 221 | else: # this line for neutral 222 | if len(gold_entity) != 0: # not NULL, has entity in this sentence 223 | neutral_targets_gold = neutral_targets_gold | gold_entity 224 | 225 | if pre_line[1] == '1': # yes on pre 226 | pre_entity = set() 227 | pre_entity_list = re.finditer(entity_label, pre_ner_tags) 228 | for x in pre_entity_list: 229 | pre_entity.add(str(x.start()) + '-' + str(len(x.group()))) 230 | 231 | if lin_idx % 3 == 1: # this line for positive 232 | if len(pre_entity) != 0: # not NULL, has entity in this sentence 233 | positive_targets_pred = positive_targets_pred | pre_entity 234 | elif lin_idx % 3 == 2: # this line for negative 235 | if len(pre_entity) != 0: # not NULL, has entity in this sentence 236 | negative_targets_pred = negative_targets_pred | pre_entity 237 | else: # this line for neutral 238 | if len(pre_entity) != 0: # not NULL, has entity in this sentence 239 | neutral_targets_pred = neutral_targets_pred | pre_entity 240 | 241 | P = True_Num / float(Pre_Num) if Pre_Num != 0 else 0 242 | R = True_Num / float(Gold_Num) 243 | F = (2*P*R)/float(P+R) if P!=0 else 0 244 | 245 | print('TSD task ignoring NULL:') 246 | print("\tP: ", P, " R: ", R, " F1: ", F) 247 | print('----------------------------------------------------\n\n') 248 | 249 | 250 | def evaluate_ASD(path, best_epoch_file): 251 | with open(os.path.join(path, TXT_file(best_epoch_file)), 'r', encoding='utf-8') as f_pre: 252 | Gold_Num = 0 253 | True_Num = 0 254 | Pre_Num = 0 255 | f_pre.readline() 256 | pre_lines = f_pre.readlines() 257 | for line in pre_lines: 258 | pre_line = line.strip().split('\t') 259 | 260 | if pre_line[0] == '1': # yes on gold 261 | Gold_Num += 1 262 | if pre_line[1] == '1': # yes on pre 263 | True_Num += 1 264 | 265 | if pre_line[1] == '1': # yes on pre 266 | Pre_Num += 1 267 | 268 | P = True_Num / float(Pre_Num) if Pre_Num != 0 else 0 269 | R = True_Num / float(Gold_Num) 270 | F = (2*P*R)/float(P+R) if P!=0 else 0 271 | 272 | print('ASD task:') 273 | print("\tP: ", P, " R: ", R, " F1: ", F) 274 
| print('----------------------------------------------------\n\n') 275 | 276 | 277 | def evaluate_TASD(path, epochs, tag_schema): 278 | # record the best epoch 279 | best_epoch_file = '' 280 | best_P = 0 281 | best_R = 0 282 | best_F1 = 0 283 | best_NULL_P = 0 284 | best_NULL_R = 0 285 | best_NULL_F1 = 0 286 | best_NO_and_O_P = 0 287 | best_NO_and_O_R = 0 288 | best_NO_and_O_F1 = 0 289 | for index in range(epochs): 290 | file_pre = 'test_ep_' + str(index+1) 291 | with open(os.path.join(path, TXT_file(file_pre)), 'r', encoding='utf-8') as f_pre: 292 | Gold_Num = 0 293 | True_Num = 0 294 | Pre_Num = 0 295 | NULL_Gold_Num = 0 296 | NULL_True_Num = 0 297 | NULL_Pre_Num = 0 298 | NO_and_O_Gold_Num = 0 299 | NO_and_O_True_Num = 0 300 | NO_and_O_Pre_Num = 0 301 | if tag_schema == 'TO': 302 | entity_label = r"T+" # for TO 303 | else: 304 | entity_label = r"BI*" # for BIO 305 | f_pre.readline() 306 | pre_lines = f_pre.readlines() 307 | for line in pre_lines: 308 | pre_line = line.strip().split('\t') 309 | sentence_length = len(pre_line[2].split()) 310 | pre_ner_tags = ''.join(pre_line[-1].split()[1:]) # [CLS] sentence [SEP] ........ 311 | gold_ner_tags = ''.join(pre_line[-2].split()[1:]) 312 | if pre_line[0] == '1': # yes on gold 313 | gold_entity = [] 314 | pre_entity = [] 315 | gold_entity_list = re.finditer(entity_label, gold_ner_tags) 316 | pre_entity_list = re.finditer(entity_label, pre_ner_tags) 317 | for x in gold_entity_list: 318 | gold_entity.append(str(x.start()) + '-' + str(len(x.group()))) 319 | for x in pre_entity_list: 320 | pre_entity.append(str(x.start()) + '-' + str(len(x.group()))) 321 | 322 | if len(gold_entity) == 0: # NULL 323 | Gold_Num += 1 324 | NULL_Gold_Num += 1 325 | if len(pre_entity) == 0 and pre_line[1] == '1': 326 | True_Num += 1 327 | NULL_True_Num += 1 328 | else: # not NULL, has entity in this sentence 329 | Gold_Num += len(gold_entity) 330 | for x in gold_entity: 331 | if x in pre_entity and pre_line[1] == '1': 332 | True_Num += 1 333 | else: # no on gold 334 | NO_and_O_Gold_Num += 1 335 | if pre_line[1] == '0' and 'T' not in pre_ner_tags and 'B' not in pre_ner_tags and 'I' not in pre_ner_tags: 336 | NO_and_O_True_Num += 1 337 | 338 | if pre_line[1] == '1': # yes on pre 339 | pre_entity = [] 340 | pre_entity_list = re.finditer(entity_label, pre_ner_tags) 341 | for x in pre_entity_list: 342 | pre_entity.append(str(x.start()) + '-' + str(len(x.group()))) 343 | 344 | if len(pre_entity) == 0: # NULL 345 | Pre_Num += 1 346 | NULL_Pre_Num += 1 347 | else: # not NULL, has entity in this sentence 348 | Pre_Num += len(pre_entity) 349 | else: # no on pre 350 | if 'T' not in pre_ner_tags and 'B' not in pre_ner_tags and 'I' not in pre_ner_tags: 351 | NO_and_O_Pre_Num += 1 352 | 353 | P = True_Num / float(Pre_Num) if Pre_Num != 0 else 0 354 | R = True_Num / float(Gold_Num) 355 | F = (2*P*R)/float(P+R) if P!=0 else 0 356 | 357 | P_NULL = NULL_True_Num / float(NULL_Pre_Num) if NULL_Pre_Num != 0 else 0 358 | R_NULL = NULL_True_Num / float(NULL_Gold_Num) 359 | F_NULL = (2*P_NULL*R_NULL)/float(P_NULL+R_NULL) if P_NULL!=0 else 0 360 | 361 | P_NO_and_O = NO_and_O_True_Num / float(NO_and_O_Pre_Num) if NO_and_O_Pre_Num != 0 else 0 362 | R_NO_and_O = NO_and_O_True_Num / float(NO_and_O_Gold_Num) 363 | F_NO_and_O = (2*P_NO_and_O*R_NO_and_O)/float(P_NO_and_O+R_NO_and_O) if P_NO_and_O!=0 else 0 364 | 365 | if F > best_F1: 366 | best_P = P 367 | best_R = R 368 | best_F1 = F 369 | 370 | best_NULL_P = P_NULL 371 | best_NULL_R = R_NULL 372 | best_NULL_F1 = F_NULL 373 | 374 | best_NO_and_O_P = 
P_NO_and_O 375 | best_NO_and_O_R = R_NO_and_O 376 | best_NO_and_O_F1 = F_NO_and_O 377 | 378 | best_epoch_file = file_pre 379 | 380 | ''' 381 | print(file_pre, ' :') 382 | print('All tuples') 383 | print("\tP: ", P, " R: ", R, " F1: ", F) 384 | print('\t\tgold sum: ', Gold_Num) 385 | print('\t\tpre sum: ', Pre_Num) 386 | print('\t\ttrue sum: ', True_Num) 387 | print('----------------------------------------------------\n') 388 | 389 | print('Only NULL tuples') 390 | print("\tP: ", P_NULL, " R: ", R_NULL, " F1: ", F_NULL) 391 | print('\t\tgold sum: ', NULL_Gold_Num) 392 | print('\t\tpre sum: ', NULL_Pre_Num) 393 | print('\t\ttrue sum: ', NULL_True_Num) 394 | print('----------------------------------------------------\n') 395 | 396 | print('NO and pure O tag sequence') 397 | print("\tP: ", P_NO_and_O, " R: ", R_NO_and_O, " F1: ", F_NO_and_O) 398 | print('\t\tgold sum: ', NO_and_O_Gold_Num) 399 | print('\t\tpre sum: ', NO_and_O_Pre_Num) 400 | print('\t\ttrue sum: ', NO_and_O_True_Num) 401 | print('----------------------------------------------------\n') 402 | ''' 403 | 404 | print('\n') 405 | print("The best result is in ", best_epoch_file, ' :') 406 | 407 | print("TASD task:") 408 | print("\tAll tuples") 409 | print("\t\tP: ", best_P, " R: ", best_R, " F1: ", best_F1) 410 | print('----------------------------------------------------\n') 411 | print("\tOnly NULL tuples") 412 | print("\t\tP: ", best_NULL_P, " R: ", best_NULL_R, " F1: ", best_NULL_F1) 413 | print('----------------------------------------------------\n') 414 | print("\tNO and pure O tag sequence") 415 | print("\t\tP: ", best_NO_and_O_P, " R: ", best_NO_and_O_R, " F1: ", best_NO_and_O_F1) 416 | print('----------------------------------------------------\n\n') 417 | 418 | return best_epoch_file 419 | 420 | 421 | if __name__ == '__main__': 422 | parser = argparse.ArgumentParser() 423 | 424 | ## Required parameters 425 | parser.add_argument("--output_dir", 426 | type=str, 427 | required=True, 428 | help="The output_dir in training & testing") 429 | parser.add_argument("--tag_schema", 430 | type=str, 431 | required=True, 432 | choices=["TO", "BIO"], 433 | help="The tag schema of the result") 434 | parser.add_argument("--num_epochs", 435 | type=int, 436 | required=True, 437 | default=30, 438 | help="The epochs num in training & testing") 439 | 440 | args = parser.parse_args() 441 | 442 | best_epoch_file = evaluate_TASD(args.output_dir, args.num_epochs, args.tag_schema) 443 | evaluate_ASD(args.output_dir, best_epoch_file) 444 | evaluate_TSD_contain_NULL(args.output_dir, best_epoch_file, args.tag_schema) 445 | evaluate_TSD_ignore_NULL(args.output_dir, best_epoch_file, args.tag_schema) 446 | 447 | 448 | 449 | -------------------------------------------------------------------------------- /evaluation_for_loss_separate.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import argparse 3 | import csv 4 | import os 5 | import re 6 | import pandas as pd 7 | 8 | def TXT_file(name): 9 | return '{}.txt'.format(name) 10 | 11 | def Clean_file(name): 12 | return '{}.tsv'.format(name) 13 | 14 | def evaluate_AS(pre_lines): 15 | Gold_Num = 0 16 | True_Num = 0 17 | Pre_Num = 0 18 | for line in pre_lines: 19 | pre_line = line.strip().split('\t') 20 | 21 | if pre_line[0] == '1': # yes on gold 22 | Gold_Num += 1 23 | if pre_line[1] == '1': # yes on pre 24 | True_Num += 1 25 | 26 | if pre_line[1] == '1': # yes on pre 27 | Pre_Num += 1 28 | 29 | P = True_Num / float(Pre_Num) if Pre_Num != 0 else 0 30 | R = True_Num /
float(Gold_Num) 31 | F = (2*P*R)/float(P+R) if P!=0 else 0 32 | print("\tP: ", P, " R: ", R, " F1: ", F) 33 | print('\t\tgold sum: ', Gold_Num) 34 | print('\t\tpre sum: ', Pre_Num) 35 | print('\t\ttrue sum: ', True_Num) 36 | print('----------------------------------------------------\n') 37 | return F 38 | 39 | def evaluate_T(pre_lines, entity_label): 40 | Gold_Num = 0 41 | True_Num = 0 42 | Pre_Num = 0 43 | for line in pre_lines: 44 | pre_line = line.strip().split('\t') 45 | sentence_length = len(pre_line[0].split()) 46 | pre_ner_tags = ''.join(pre_line[2].split()[1:]) # [CLS] sentence [SEP] ........ 47 | gold_ner_tags = ''.join(pre_line[1].split()[1:]) 48 | 49 | gold_entity = [] 50 | pre_entity = [] 51 | gold_entity_list = re.finditer(entity_label, gold_ner_tags) 52 | pre_entity_list = re.finditer(entity_label, pre_ner_tags) 53 | for x in gold_entity_list: 54 | gold_entity.append(str(x.start()) + '-' + str(len(x.group()))) 55 | for x in pre_entity_list: 56 | pre_entity.append(str(x.start()) + '-' + str(len(x.group()))) 57 | 58 | Pre_Num += len(pre_entity) 59 | Gold_Num += len(gold_entity) 60 | for x in gold_entity: 61 | if x in pre_entity: 62 | True_Num += 1 63 | 64 | P = True_Num / float(Pre_Num) if Pre_Num != 0 else 0 65 | R = True_Num / float(Gold_Num) 66 | F = (2*P*R)/float(P+R) if P!=0 else 0 67 | print("\tP: ", P, " R: ", R, " F1: ", F) 68 | print('\t\tgold sum: ', Gold_Num) 69 | print('\t\tpre sum: ', Pre_Num) 70 | print('\t\ttrue sum: ', True_Num) 71 | print('----------------------------------------------------\n') 72 | return F 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser() 77 | 78 | ## Required parameters 79 | parser.add_argument("--output_dir_AS", 80 | type=str, 81 | required=True, 82 | help="The output_dir of subtask AS in training & testing") 83 | parser.add_argument("--output_dir_T", 84 | type=str, 85 | required=True, 86 | help="The output_dir of subtask T in training & testing") 87 | parser.add_argument("--tag_schema", 88 | type=str, 89 | required=True, 90 | choices=["TO", "BIO"], 91 | help="The tag schema of the result") 92 | parser.add_argument("--num_epochs", 93 | type=int, 94 | required=True, 95 | default=30, 96 | help="The epochs num in training & testing") 97 | 98 | 99 | args = parser.parse_args() 100 | 101 | best_ep_AS = 0 102 | best_F_AS = 0 103 | best_ep_T = 0 104 | best_F_T = 0 105 | if args.tag_schema == 'TO': 106 | entity_label = r"T+" # for TO 107 | else: 108 | entity_label = r"BI*" # for BIO 109 | 110 | ## for AS 111 | for index in range(1, args.num_epochs+1): 112 | file_pre = 'test_ep_' + str(index) 113 | with open(os.path.join(args.output_dir_AS, TXT_file(file_pre)), 'r', encoding='utf-8') as f_pre: 114 | print(file_pre) 115 | f_pre.readline() 116 | F = evaluate_AS(f_pre.readlines()) 117 | if F > best_F_AS: 118 | best_F_AS = F 119 | best_ep_AS = file_pre 120 | ## for T 121 | for index in range(1, args.num_epochs+1): 122 | file_pre = 'test_ep_' + str(index) 123 | with open(os.path.join(args.output_dir_T, TXT_file(file_pre)), 'r', encoding='utf-8') as f_pre: 124 | print(file_pre) 125 | f_pre.readline() 126 | F = evaluate_T(f_pre.readlines(), entity_label) 127 | if F > best_F_T: 128 | best_F_T = F 129 | best_ep_T = file_pre 130 | 131 | 132 | ### get best epoch for AS and best epoch for T, then put them together to get tuples 133 | # all tuples & only NULL tuples 134 | # prediction 'no'(0) and pure 'OOOOOOOOO' tag sequence 135 | with open(os.path.join(args.output_dir_AS, TXT_file(best_ep_AS)), 'r', encoding='utf-8') as f_AS, 
open(os.path.join(args.output_dir_T, TXT_file(best_ep_T)), 'r', encoding='utf-8') as f_T: 136 | Gold_Num = 0 137 | True_Num = 0 138 | Pre_Num = 0 139 | NULL_Gold_Num = 0 140 | NULL_True_Num = 0 141 | NULL_Pre_Num = 0 142 | NO_and_O_Gold_Num = 0 143 | NO_and_O_True_Num = 0 144 | NO_and_O_Pre_Num = 0 145 | 146 | f_AS.readline() 147 | f_T.readline() 148 | pre_lines_AS = f_AS.readlines() 149 | pre_lines_T = f_T.readlines() 150 | for line_AS, line_T in zip(pre_lines_AS, pre_lines_T): 151 | pre_line = line_AS.strip().split('\t') + line_T.strip().split('\t') 152 | sentence_length = len(pre_line[2].split()) 153 | pre_ner_tags = ''.join(pre_line[-1].split()[1:]) # [CLS] sentence [SEP] ........ 154 | gold_ner_tags = ''.join(pre_line[-2].split()[1:]) 155 | if pre_line[0] == '1': # yes on gold 156 | gold_entity = [] 157 | pre_entity = [] 158 | gold_entity_list = re.finditer(entity_label, gold_ner_tags) 159 | pre_entity_list = re.finditer(entity_label, pre_ner_tags) 160 | for x in gold_entity_list: 161 | gold_entity.append(str(x.start()) + '-' + str(len(x.group()))) 162 | for x in pre_entity_list: 163 | pre_entity.append(str(x.start()) + '-' + str(len(x.group()))) 164 | 165 | if len(gold_entity) == 0: # NULL 166 | Gold_Num += 1 167 | NULL_Gold_Num += 1 168 | if len(pre_entity) == 0 and pre_line[1] == '1': 169 | True_Num += 1 170 | NULL_True_Num += 1 171 | else: # not NULL, has entity in this sentence 172 | Gold_Num += len(gold_entity) 173 | for x in gold_entity: 174 | if x in pre_entity and pre_line[1] == '1': 175 | True_Num += 1 176 | else: # no on gold 177 | NO_and_O_Gold_Num += 1 178 | if pre_line[1] == '0' and 'T' not in pre_ner_tags and 'B' not in pre_ner_tags and 'I' not in pre_ner_tags: 179 | NO_and_O_True_Num += 1 180 | 181 | if pre_line[1] == '1': # yes on pre 182 | pre_entity = [] 183 | pre_entity_list = re.finditer(entity_label, pre_ner_tags) 184 | for x in pre_entity_list: 185 | pre_entity.append(str(x.start()) + '-' + str(len(x.group()))) 186 | 187 | if len(pre_entity) == 0: # NULL 188 | Pre_Num += 1 189 | NULL_Pre_Num += 1 190 | else: # not NULL, has entity in this sentence 191 | Pre_Num += len(pre_entity) 192 | else: # no on pre 193 | if 'T' not in pre_ner_tags and 'B' not in pre_ner_tags and 'I' not in pre_ner_tags: 194 | NO_and_O_Pre_Num += 1 195 | 196 | P = True_Num / float(Pre_Num) if Pre_Num != 0 else 0 197 | R = True_Num / float(Gold_Num) 198 | F = (2*P*R)/float(P+R) if P!=0 else 0 199 | 200 | P_NULL = NULL_True_Num / float(NULL_Pre_Num) if NULL_Pre_Num != 0 else 0 201 | R_NULL = NULL_True_Num / float(NULL_Gold_Num) 202 | F_NULL = (2*P_NULL*R_NULL)/float(P_NULL+R_NULL) if P_NULL!=0 else 0 203 | 204 | P_NO_and_O = NO_and_O_True_Num / float(NO_and_O_Pre_Num) if NO_and_O_Pre_Num != 0 else 0 205 | R_NO_and_O = NO_and_O_True_Num / float(NO_and_O_Gold_Num) 206 | F_NO_and_O = (2*P_NO_and_O*R_NO_and_O)/float(P_NO_and_O+R_NO_and_O) if P_NO_and_O!=0 else 0 207 | 208 | print(best_ep_AS + ' + ' + best_ep_T) 209 | print('All tuples') 210 | print("\tP: ", P, " R: ", R, " F1: ", F) 211 | print('\t\tgold sum: ', Gold_Num) 212 | print('\t\tpre sum: ', Pre_Num) 213 | print('\t\ttrue sum: ', True_Num) 214 | print('----------------------------------------------------\n') 215 | 216 | print('Only NULL tuples') 217 | print("\tP: ", P_NULL, " R: ", R_NULL, " F1: ", F_NULL) 218 | print('\t\tgold sum: ', NULL_Gold_Num) 219 | print('\t\tpre sum: ', NULL_Pre_Num) 220 | print('\t\ttrue sum: ', NULL_True_Num) 221 | print('----------------------------------------------------\n') 222 | 223 | print('NO and pure O 
tag sequence') 224 | print("\tP: ", P_NO_and_O, " R: ", R_NO_and_O, " F1: ", F_NO_and_O) 225 | print('\t\tgold sum: ', NO_and_O_Gold_Num) 226 | print('\t\tpre sum: ', NO_and_O_Pre_Num) 227 | print('\t\ttrue sum: ', NO_and_O_True_Num) 228 | print('----------------------------------------------------\n') -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # Reference: https://github.com/huggingface/pytorch-pretrained-BERT 4 | 5 | """PyTorch BERT model.""" 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import copy 10 | import json 11 | import math 12 | 13 | import six 14 | import torch 15 | from torchcrf import CRF 16 | import torch.nn as nn 17 | from torch.nn import CrossEntropyLoss 18 | import numpy as np 19 | import tensorflow as tf 20 | import datetime 21 | 22 | 23 | 24 | 25 | def gelu(x): 26 | """Implementation of the gelu activation function. 27 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 28 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 29 | """ 30 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 31 | 32 | 33 | class BertConfig(object): 34 | """Configuration class to store the configuration of a `BertModel`. 35 | """ 36 | def __init__(self, 37 | vocab_size, 38 | hidden_size=768, 39 | num_hidden_layers=12, 40 | num_attention_heads=12, 41 | intermediate_size=3072, 42 | hidden_act="gelu", 43 | hidden_dropout_prob=0.1, 44 | attention_probs_dropout_prob=0.1, 45 | max_position_embeddings=512, 46 | type_vocab_size=16, 47 | initializer_range=0.02): 48 | """Constructs BertConfig. 49 | 50 | Args: 51 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 52 | hidden_size: Size of the encoder layers and the pooler layer. 53 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 54 | num_attention_heads: Number of attention heads for each attention layer in 55 | the Transformer encoder. 56 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 57 | layer in the Transformer encoder. 58 | hidden_act: The non-linear activation function (function or string) in the 59 | encoder and pooler. 60 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 61 | layers in the embeddings, encoder, and pooler. 62 | attention_probs_dropout_prob: The dropout ratio for the attention 63 | probabilities. 64 | max_position_embeddings: The maximum sequence length that this model might 65 | ever be used with. Typically set this to something large just in case 66 | (e.g., 512 or 1024 or 2048). 67 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 68 | `BertModel`. 69 | initializer_range: The sttdev of the truncated_normal_initializer for 70 | initializing all weight matrices. 
71 | """ 72 | self.vocab_size = vocab_size 73 | self.hidden_size = hidden_size 74 | self.num_hidden_layers = num_hidden_layers 75 | self.num_attention_heads = num_attention_heads 76 | self.hidden_act = hidden_act 77 | self.intermediate_size = intermediate_size 78 | self.hidden_dropout_prob = hidden_dropout_prob 79 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 80 | self.max_position_embeddings = max_position_embeddings 81 | self.type_vocab_size = type_vocab_size 82 | self.initializer_range = initializer_range 83 | 84 | @classmethod 85 | def from_dict(cls, json_object): 86 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 87 | config = BertConfig(vocab_size=None) 88 | for (key, value) in six.iteritems(json_object): 89 | config.__dict__[key] = value 90 | return config 91 | 92 | @classmethod 93 | def from_json_file(cls, json_file): 94 | """Constructs a `BertConfig` from a json file of parameters.""" 95 | with open(json_file, "r") as reader: 96 | text = reader.read() 97 | return cls.from_dict(json.loads(text)) 98 | 99 | def to_dict(self): 100 | """Serializes this instance to a Python dictionary.""" 101 | output = copy.deepcopy(self.__dict__) 102 | return output 103 | 104 | def to_json_string(self): 105 | """Serializes this instance to a JSON string.""" 106 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 107 | 108 | 109 | class BERTLayerNorm(nn.Module): 110 | def __init__(self, config, variance_epsilon=1e-12): 111 | """Construct a layernorm module in the TF style (epsilon inside the square root). 112 | """ 113 | super(BERTLayerNorm, self).__init__() 114 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 115 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 116 | self.variance_epsilon = variance_epsilon 117 | 118 | def forward(self, x): 119 | u = x.mean(-1, keepdim=True) 120 | s = (x - u).pow(2).mean(-1, keepdim=True) 121 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 122 | return self.gamma * x + self.beta 123 | 124 | class BERTEmbeddings(nn.Module): 125 | def __init__(self, config): 126 | super(BERTEmbeddings, self).__init__() 127 | """Construct the embedding module from word, position and token_type embeddings. 
128 | """ 129 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 130 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 131 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 132 | 133 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 134 | # any TensorFlow checkpoint file 135 | self.LayerNorm = BERTLayerNorm(config) 136 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 137 | 138 | def forward(self, input_ids, token_type_ids=None): 139 | seq_length = input_ids.size(1) 140 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 141 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 142 | if token_type_ids is None: 143 | token_type_ids = torch.zeros_like(input_ids) 144 | 145 | words_embeddings = self.word_embeddings(input_ids) 146 | position_embeddings = self.position_embeddings(position_ids) 147 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 148 | 149 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 150 | embeddings = self.LayerNorm(embeddings) 151 | embeddings = self.dropout(embeddings) 152 | return embeddings 153 | 154 | 155 | class BERTSelfAttention(nn.Module): 156 | def __init__(self, config): 157 | super(BERTSelfAttention, self).__init__() 158 | if config.hidden_size % config.num_attention_heads != 0: 159 | raise ValueError( 160 | "The hidden size (%d) is not a multiple of the number of attention " 161 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 162 | self.num_attention_heads = config.num_attention_heads 163 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 164 | self.all_head_size = self.num_attention_heads * self.attention_head_size 165 | 166 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 167 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 168 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 169 | 170 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 171 | 172 | def transpose_for_scores(self, x): 173 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 174 | x = x.view(*new_x_shape) 175 | return x.permute(0, 2, 1, 3) 176 | 177 | def forward(self, hidden_states, attention_mask): 178 | mixed_query_layer = self.query(hidden_states) 179 | mixed_key_layer = self.key(hidden_states) 180 | mixed_value_layer = self.value(hidden_states) 181 | 182 | query_layer = self.transpose_for_scores(mixed_query_layer) 183 | key_layer = self.transpose_for_scores(mixed_key_layer) 184 | value_layer = self.transpose_for_scores(mixed_value_layer) 185 | 186 | # Take the dot product between "query" and "key" to get the raw attention scores. 187 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 188 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 189 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 190 | attention_scores = attention_scores + attention_mask 191 | 192 | # Normalize the attention scores to probabilities. 193 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 194 | 195 | # This is actually dropping out entire tokens to attend to, which might 196 | # seem a bit unusual, but is taken from the original Transformer paper. 
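# dropping an entry of attention_probs means the corresponding value vector is ignored in the weighted sum computed below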
197 | attention_probs = self.dropout(attention_probs) 198 | 199 | context_layer = torch.matmul(attention_probs, value_layer) 200 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 201 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 202 | context_layer = context_layer.view(*new_context_layer_shape) 203 | return context_layer 204 | 205 | 206 | class BERTSelfOutput(nn.Module): 207 | def __init__(self, config): 208 | super(BERTSelfOutput, self).__init__() 209 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 210 | self.LayerNorm = BERTLayerNorm(config) 211 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 212 | 213 | def forward(self, hidden_states, input_tensor): 214 | hidden_states = self.dense(hidden_states) 215 | hidden_states = self.dropout(hidden_states) 216 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 217 | return hidden_states 218 | 219 | 220 | class BERTAttention(nn.Module): 221 | def __init__(self, config): 222 | super(BERTAttention, self).__init__() 223 | self.self = BERTSelfAttention(config) 224 | self.output = BERTSelfOutput(config) 225 | 226 | def forward(self, input_tensor, attention_mask): 227 | self_output = self.self(input_tensor, attention_mask) 228 | attention_output = self.output(self_output, input_tensor) 229 | return attention_output 230 | 231 | 232 | class BERTIntermediate(nn.Module): 233 | def __init__(self, config): 234 | super(BERTIntermediate, self).__init__() 235 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 236 | self.intermediate_act_fn = gelu 237 | 238 | def forward(self, hidden_states): 239 | hidden_states = self.dense(hidden_states) 240 | hidden_states = self.intermediate_act_fn(hidden_states) 241 | return hidden_states 242 | 243 | 244 | class BERTOutput(nn.Module): 245 | def __init__(self, config): 246 | super(BERTOutput, self).__init__() 247 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 248 | self.LayerNorm = BERTLayerNorm(config) 249 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 250 | 251 | def forward(self, hidden_states, input_tensor): 252 | hidden_states = self.dense(hidden_states) 253 | hidden_states = self.dropout(hidden_states) 254 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 255 | return hidden_states 256 | 257 | 258 | class BERTLayer(nn.Module): 259 | def __init__(self, config): 260 | super(BERTLayer, self).__init__() 261 | self.attention = BERTAttention(config) 262 | self.intermediate = BERTIntermediate(config) 263 | self.output = BERTOutput(config) 264 | 265 | def forward(self, hidden_states, attention_mask): 266 | attention_output = self.attention(hidden_states, attention_mask) 267 | intermediate_output = self.intermediate(attention_output) 268 | layer_output = self.output(intermediate_output, attention_output) 269 | return layer_output 270 | 271 | 272 | class BERTEncoder(nn.Module): 273 | def __init__(self, config): 274 | super(BERTEncoder, self).__init__() 275 | layer = BERTLayer(config) 276 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 277 | 278 | def forward(self, hidden_states, attention_mask): 279 | all_encoder_layers = [] 280 | for layer_module in self.layer: 281 | hidden_states = layer_module(hidden_states, attention_mask) 282 | all_encoder_layers.append(hidden_states) 283 | return all_encoder_layers 284 | 285 | 286 | class BERTPooler(nn.Module): 287 | def __init__(self, config): 288 | super(BERTPooler, self).__init__() 289 | 
self.dense = nn.Linear(config.hidden_size, config.hidden_size) 290 | self.activation = nn.Tanh() 291 | 292 | def forward(self, hidden_states): 293 | # We "pool" the model by simply taking the hidden state corresponding 294 | # to the first token. 295 | first_token_tensor = hidden_states[:, 0] 296 | #return first_token_tensor 297 | pooled_output = self.dense(first_token_tensor) 298 | pooled_output = self.activation(pooled_output) 299 | return pooled_output 300 | 301 | 302 | class BertModel(nn.Module): 303 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 304 | 305 | Example usage: 306 | ```python 307 | # Already been converted into WordPiece token ids 308 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 309 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 310 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 311 | 312 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 313 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 314 | 315 | model = modeling.BertModel(config=config) 316 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 317 | ``` 318 | """ 319 | def __init__(self, config: BertConfig): 320 | """Constructor for BertModel. 321 | 322 | Args: 323 | config: `BertConfig` instance. 324 | """ 325 | super(BertModel, self).__init__() 326 | self.embeddings = BERTEmbeddings(config) 327 | self.encoder = BERTEncoder(config) 328 | self.pooler = BERTPooler(config) 329 | 330 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 331 | if attention_mask is None: 332 | attention_mask = torch.ones_like(input_ids) 333 | if token_type_ids is None: 334 | token_type_ids = torch.zeros_like(input_ids) 335 | 336 | # We create a 3D attention mask from a 2D tensor mask. 337 | # Sizes are [batch_size, 1, 1, from_seq_length] 338 | # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length] 339 | # this attention mask is more simple than the triangular masking of causal attention 340 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 341 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 342 | 343 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 344 | # masked positions, this operation will create a tensor which is 0.0 for 345 | # positions we want to attend and -10000.0 for masked positions. 346 | # Since we are adding it to the raw scores before the softmax, this is 347 | # effectively the same as removing these entirely. 348 | extended_attention_mask = extended_attention_mask.float() 349 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 350 | 351 | embedding_output = self.embeddings(input_ids, token_type_ids) 352 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask) 353 | sequence_output = all_encoder_layers[-1] 354 | pooled_output = self.pooler(sequence_output) 355 | return all_encoder_layers, pooled_output 356 | 357 | class BertForSequenceClassification(nn.Module): 358 | """BERT model for classification. 359 | This module is composed of the BERT model with a linear layer on top of 360 | the pooled output. 
361 | 362 | Example usage: 363 | ```python 364 | # Already been converted into WordPiece token ids 365 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 366 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 367 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 368 | 369 | config = BertConfig(vocab_size=32000, hidden_size=512, 370 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 371 | 372 | num_labels = 2 373 | 374 | model = BertForSequenceClassification(config, num_labels) 375 | logits = model(input_ids, token_type_ids, input_mask) 376 | ``` 377 | """ 378 | def __init__(self, config, num_labels): 379 | super(BertForSequenceClassification, self).__init__() 380 | self.bert = BertModel(config) 381 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 382 | self.classifier = nn.Linear(config.hidden_size, num_labels) 383 | 384 | def init_weights(module): 385 | if isinstance(module, (nn.Linear, nn.Embedding)): 386 | # Slightly different from the TF version which uses truncated_normal for initialization 387 | # cf https://github.com/pytorch/pytorch/pull/5617 388 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 389 | elif isinstance(module, BERTLayerNorm): 390 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 391 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 392 | if isinstance(module, nn.Linear): 393 | module.bias.data.zero_() 394 | self.apply(init_weights) 395 | 396 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None): 397 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 398 | pooled_output = self.dropout(pooled_output) 399 | logits = self.classifier(pooled_output) 400 | 401 | if labels is not None: 402 | loss_fct = CrossEntropyLoss() 403 | loss = loss_fct(logits, labels) 404 | return loss, logits 405 | else: 406 | return logits 407 | 408 | 409 | class BertForQuestionAnswering(nn.Module): 410 | """BERT model for Question Answering (span extraction). 
411 | This module is composed of the BERT model with a linear layer on top of 412 | the sequence output that computes start_logits and end_logits 413 | 414 | Example usage: 415 | ```python 416 | # Already been converted into WordPiece token ids 417 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 418 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 419 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 420 | 421 | config = BertConfig(vocab_size=32000, hidden_size=512, 422 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 423 | 424 | model = BertForQuestionAnswering(config) 425 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 426 | ``` 427 | """ 428 | def __init__(self, config): 429 | super(BertForQuestionAnswering, self).__init__() 430 | self.bert = BertModel(config) 431 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 432 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 433 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 434 | 435 | def init_weights(module): 436 | if isinstance(module, (nn.Linear, nn.Embedding)): 437 | # Slightly different from the TF version which uses truncated_normal for initialization 438 | # cf https://github.com/pytorch/pytorch/pull/5617 439 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 440 | elif isinstance(module, BERTLayerNorm): 441 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 442 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 443 | if isinstance(module, nn.Linear): 444 | module.bias.data.zero_() 445 | self.apply(init_weights) 446 | 447 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): 448 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask) 449 | sequence_output = all_encoder_layers[-1] 450 | logits = self.qa_outputs(sequence_output) 451 | start_logits, end_logits = logits.split(1, dim=-1) 452 | start_logits = start_logits.squeeze(-1) 453 | end_logits = end_logits.squeeze(-1) 454 | 455 | if start_positions is not None and end_positions is not None: 456 | # If we are on multi-GPU, split add a dimension - if not this is a no-op 457 | start_positions = start_positions.squeeze(-1) 458 | end_positions = end_positions.squeeze(-1) 459 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 460 | ignored_index = start_logits.size(1) 461 | start_positions.clamp_(0, ignored_index) 462 | end_positions.clamp_(0, ignored_index) 463 | 464 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 465 | start_loss = loss_fct(start_logits, start_positions) 466 | end_loss = loss_fct(end_logits, end_positions) 467 | total_loss = (start_loss + end_loss) / 2 468 | return total_loss 469 | else: 470 | return start_logits, end_logits 471 | 472 | 473 | # BERT + softmax 474 | class BertForTABSAJoint(nn.Module): 475 | def __init__(self, config, num_labels, num_ner_labels, max_seq_length): 476 | super(BertForTABSAJoint, self).__init__() 477 | self.bert = BertModel(config) 478 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 479 | self.classifier = nn.Linear(config.hidden_size, num_labels) # num_labels is the type sum of 0 & 1 480 | self.ner_hidden2tag = nn.Linear(config.hidden_size, num_ner_labels) # num_ner_labels is the type sum of ner labels: TO or BIO etc 481 | self.num_labels = num_labels 482 | self.num_ner_labels = num_ner_labels 483 | 
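# Note: the two heads above share a single BERT encoder. `classifier` scores the
# pooled [CLS] vector with the yes/no label of the auxiliary aspect-sentiment
# sentence, while `ner_hidden2tag` maps each token's hidden state to a TO/BIO tag
# score for target extraction; both losses are computed jointly in forward().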
self.max_seq_length = max_seq_length 484 | 485 | def init_weights(module): 486 | if isinstance(module, (nn.Linear, nn.Embedding)): 487 | # Slightly different from the TF version which uses truncated_normal for initialization 488 | # cf https://github.com/pytorch/pytorch/pull/5617 489 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 490 | elif isinstance(module, BERTLayerNorm): 491 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 492 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 493 | if isinstance(module, nn.Linear): 494 | module.bias.data.zero_() 495 | self.apply(init_weights) 496 | 497 | def forward(self, input_ids, token_type_ids, attention_mask, labels, ner_labels, ner_mask): 498 | all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 499 | # get the last hidden layer 500 | sequence_output = all_encoder_layers[-1] 501 | # cross a dropout layer 502 | sequence_output = self.dropout(sequence_output) 503 | pooled_output = self.dropout(pooled_output) 504 | # the Classifier of category & polarity 505 | logits = self.classifier(pooled_output) 506 | ner_logits = self.ner_hidden2tag(sequence_output) 507 | ner_logits.reshape([-1, self.max_seq_length, self.num_ner_labels]) 508 | 509 | loss_fct = CrossEntropyLoss() 510 | loss = loss_fct(logits, labels) 511 | ner_loss_fct = CrossEntropyLoss(ignore_index=0) 512 | ner_loss = ner_loss_fct(ner_logits.view(-1, self.num_ner_labels), ner_labels.view(-1)) 513 | return loss, ner_loss, logits, ner_logits 514 | 515 | 516 | # BERT + CRF 517 | class BertForTABSAJoint_CRF(nn.Module): 518 | 519 | def __init__(self, config, num_labels, num_ner_labels): 520 | super(BertForTABSAJoint_CRF, self).__init__() 521 | self.bert = BertModel(config) 522 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 523 | self.classifier = nn.Linear(config.hidden_size, num_labels) # num_labels is the type sum of 0 & 1 524 | self.ner_hidden2tag = nn.Linear(config.hidden_size, num_ner_labels) # num_ner_labels is the type sum of ner labels: TO or BIO etc 525 | self.num_labels = num_labels 526 | self.num_ner_labels = num_ner_labels 527 | # CRF 528 | self.CRF_model = CRF(num_ner_labels, batch_first=True) 529 | 530 | def init_weights(module): 531 | if isinstance(module, (nn.Linear, nn.Embedding)): 532 | # Slightly different from the TF version which uses truncated_normal for initialization 533 | # cf https://github.com/pytorch/pytorch/pull/5617 534 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 535 | elif isinstance(module, BERTLayerNorm): 536 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 537 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 538 | if isinstance(module, nn.Linear): 539 | module.bias.data.zero_() 540 | self.apply(init_weights) 541 | 542 | def forward(self, input_ids, token_type_ids, attention_mask, labels, ner_labels, ner_mask): 543 | all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 544 | # get the last hidden layer 545 | sequence_output = all_encoder_layers[-1] 546 | # cross a dropout layer 547 | sequence_output = self.dropout(sequence_output) 548 | pooled_output = self.dropout(pooled_output) 549 | # the Classifier of category & polarity 550 | logits = self.classifier(pooled_output) 551 | ner_logits = self.ner_hidden2tag(sequence_output) 552 | 553 | # the CRF layer of NER labels 554 | ner_loss_list = self.CRF_model(ner_logits, ner_labels, ner_mask.type(torch.ByteTensor).cuda(), 
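# pytorch-crf: CRF.forward returns the log-likelihood of the gold tag sequence
# (one value per sequence with reduction='none'), so ner_loss below is its
# negated mean; CRF.decode returns the most likely tag path under the mask.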
reduction='none') 555 | ner_loss = torch.mean(-ner_loss_list) 556 | ner_predict = self.CRF_model.decode(ner_logits, ner_mask.type(torch.ByteTensor).cuda()) 557 | 558 | # the classifier of category & polarity 559 | loss_fct = CrossEntropyLoss() 560 | loss = loss_fct(logits, labels) 561 | return loss, ner_loss, logits, ner_predict 562 | 563 | #the model for ablation study, separate training 564 | # BERT + softmax 565 | class BertForTABSAJoint_AS(nn.Module): 566 | def __init__(self, config, num_labels, num_ner_labels, max_seq_length): 567 | super(BertForTABSAJoint_AS, self).__init__() 568 | self.bert = BertModel(config) 569 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 570 | self.classifier = nn.Linear(config.hidden_size, num_labels) # num_labels is the type sum of 0 & 1 571 | self.ner_hidden2tag = nn.Linear(config.hidden_size, num_ner_labels) # num_ner_labels is the type sum of ner labels: TO or BIO etc 572 | self.num_labels = num_labels 573 | self.num_ner_labels = num_ner_labels 574 | self.max_seq_length = max_seq_length 575 | 576 | def init_weights(module): 577 | if isinstance(module, (nn.Linear, nn.Embedding)): 578 | # Slightly different from the TF version which uses truncated_normal for initialization 579 | # cf https://github.com/pytorch/pytorch/pull/5617 580 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 581 | elif isinstance(module, BERTLayerNorm): 582 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 583 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 584 | if isinstance(module, nn.Linear): 585 | module.bias.data.zero_() 586 | self.apply(init_weights) 587 | 588 | def forward(self, input_ids, token_type_ids, attention_mask, labels, ner_labels, ner_mask): 589 | all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 590 | # get the last hidden layer 591 | sequence_output = all_encoder_layers[-1] 592 | # cross a dropout layer 593 | sequence_output = self.dropout(sequence_output) 594 | pooled_output = self.dropout(pooled_output) 595 | # the Classifier of category & polarity 596 | logits = self.classifier(pooled_output) 597 | loss_fct = CrossEntropyLoss() 598 | loss = loss_fct(logits, labels) 599 | return loss, logits 600 | 601 | #the model for ablation study, separate training 602 | # BERT + softmax 603 | class BertForTABSAJoint_T(nn.Module): 604 | def __init__(self, config, num_labels, num_ner_labels, max_seq_length): 605 | super(BertForTABSAJoint_T, self).__init__() 606 | self.bert = BertModel(config) 607 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 608 | self.classifier = nn.Linear(config.hidden_size, num_labels) # num_labels is the type sum of 0 & 1 609 | self.ner_hidden2tag = nn.Linear(config.hidden_size, num_ner_labels) # num_ner_labels is the type sum of ner labels: TO or BIO etc 610 | self.num_labels = num_labels 611 | self.num_ner_labels = num_ner_labels 612 | self.max_seq_length = max_seq_length 613 | 614 | def init_weights(module): 615 | if isinstance(module, (nn.Linear, nn.Embedding)): 616 | # Slightly different from the TF version which uses truncated_normal for initialization 617 | # cf https://github.com/pytorch/pytorch/pull/5617 618 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 619 | elif isinstance(module, BERTLayerNorm): 620 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 621 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 622 | if isinstance(module, nn.Linear): 623 | module.bias.data.zero_() 624 
| self.apply(init_weights) 625 | 626 | def forward(self, input_ids, token_type_ids, attention_mask, labels, ner_labels, ner_mask): 627 | all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 628 | # get the last hidden layer 629 | sequence_output = all_encoder_layers[-1] 630 | # cross a dropout layer 631 | sequence_output = self.dropout(sequence_output) 632 | pooled_output = self.dropout(pooled_output) 633 | # the Classifier of category & polarity 634 | ner_logits = self.ner_hidden2tag(sequence_output) 635 | ner_logits.reshape([-1, self.max_seq_length, self.num_ner_labels]) 636 | 637 | ner_loss_fct = CrossEntropyLoss(ignore_index=0) 638 | ner_loss = ner_loss_fct(ner_logits.view(-1, self.num_ner_labels), ner_labels.view(-1)) 639 | return ner_loss, ner_logits 640 | 641 | #the model for ablation study, separate training 642 | # BERT + CRF 643 | class BertForTABSAJoint_CRF_AS(nn.Module): 644 | 645 | def __init__(self, config, num_labels, num_ner_labels): 646 | super(BertForTABSAJoint_CRF_AS, self).__init__() 647 | self.bert = BertModel(config) 648 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 649 | self.classifier = nn.Linear(config.hidden_size, num_labels) # num_labels is the type sum of 0 & 1 650 | self.ner_hidden2tag = nn.Linear(config.hidden_size, num_ner_labels) # num_ner_labels is the type sum of ner labels: TO or BIO etc 651 | self.num_labels = num_labels 652 | self.num_ner_labels = num_ner_labels 653 | # CRF 654 | self.CRF_model = CRF(num_ner_labels, batch_first=True) 655 | 656 | def init_weights(module): 657 | if isinstance(module, (nn.Linear, nn.Embedding)): 658 | # Slightly different from the TF version which uses truncated_normal for initialization 659 | # cf https://github.com/pytorch/pytorch/pull/5617 660 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 661 | elif isinstance(module, BERTLayerNorm): 662 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 663 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 664 | if isinstance(module, nn.Linear): 665 | module.bias.data.zero_() 666 | self.apply(init_weights) 667 | 668 | def forward(self, input_ids, token_type_ids, attention_mask, labels, ner_labels, ner_mask): 669 | all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 670 | # get the last hidden layer 671 | sequence_output = all_encoder_layers[-1] 672 | # cross a dropout layer 673 | sequence_output = self.dropout(sequence_output) 674 | pooled_output = self.dropout(pooled_output) 675 | # the Classifier of category & polarity 676 | logits = self.classifier(pooled_output) 677 | # the classifier of category & polarity 678 | loss_fct = CrossEntropyLoss() 679 | loss = loss_fct(logits, labels) 680 | return loss, logits 681 | 682 | #the model for ablation study, separate training 683 | # BERT + CRF 684 | class BertForTABSAJoint_CRF_T(nn.Module): 685 | def __init__(self, config, num_labels, num_ner_labels): 686 | super(BertForTABSAJoint_CRF_T, self).__init__() 687 | self.bert = BertModel(config) 688 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 689 | self.classifier = nn.Linear(config.hidden_size, num_labels) # num_labels is the type sum of 0 & 1 690 | self.ner_hidden2tag = nn.Linear(config.hidden_size, num_ner_labels) # num_ner_labels is the type sum of ner labels: TO or BIO etc 691 | self.num_labels = num_labels 692 | self.num_ner_labels = num_ner_labels 693 | # CRF 694 | self.CRF_model = CRF(num_ner_labels, batch_first=True) 695 | 696 | def 
init_weights(module): 697 | if isinstance(module, (nn.Linear, nn.Embedding)): 698 | # Slightly different from the TF version which uses truncated_normal for initialization 699 | # cf https://github.com/pytorch/pytorch/pull/5617 700 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 701 | elif isinstance(module, BERTLayerNorm): 702 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 703 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 704 | if isinstance(module, nn.Linear): 705 | module.bias.data.zero_() 706 | self.apply(init_weights) 707 | 708 | def forward(self, input_ids, token_type_ids, attention_mask, labels, ner_labels, ner_mask): 709 | all_encoder_layers, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 710 | # get the last hidden layer 711 | sequence_output = all_encoder_layers[-1] 712 | # cross a dropout layer 713 | sequence_output = self.dropout(sequence_output) 714 | pooled_output = self.dropout(pooled_output) 715 | ner_logits = self.ner_hidden2tag(sequence_output) 716 | 717 | # the CRF layer of NER labels 718 | ner_loss_list = self.CRF_model(ner_logits, ner_labels, ner_mask.type(torch.ByteTensor).cuda(), reduction='none') 719 | ner_loss = torch.mean(-ner_loss_list) 720 | ner_predict = self.CRF_model.decode(ner_logits, ner_mask.type(torch.ByteTensor).cuda()) 721 | 722 | return ner_loss, ner_predict 723 | 724 | 725 | -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # Reference: https://github.com/huggingface/pytorch-pretrained-BERT 4 | 5 | """PyTorch optimization for BERT model.""" 6 | 7 | import math 8 | 9 | import torch 10 | from torch.nn.utils import clip_grad_norm_ 11 | from torch.optim import Optimizer 12 | 13 | 14 | def warmup_cosine(x, warmup=0.002): 15 | if x < warmup: 16 | return x/warmup 17 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 18 | 19 | def warmup_constant(x, warmup=0.002): 20 | if x < warmup: 21 | return x/warmup 22 | return 1.0 23 | 24 | def warmup_linear(x, warmup=0.002): 25 | if x < warmup: 26 | return x/warmup 27 | return 1.0 - x 28 | 29 | SCHEDULES = { 30 | 'warmup_cosine':warmup_cosine, 31 | 'warmup_constant':warmup_constant, 32 | 'warmup_linear':warmup_linear, 33 | } 34 | 35 | 36 | class BERTAdam(Optimizer): 37 | """Implements BERT version of Adam algorithm with weight decay fix (and no ). 38 | Params: 39 | lr: learning rate 40 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 41 | t_total: total number of training steps for the learning 42 | rate schedule, -1 means constant learning rate. Default: -1 43 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 44 | b1: Adams b1. Default: 0.9 45 | b2: Adams b2. Default: 0.999 46 | e: Adams epsilon. Default: 1e-6 47 | weight_decay_rate: Weight decay. Default: 0.01 48 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 49 | """ 50 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 51 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, 52 | max_grad_norm=1.0): 53 | if not lr >= 0.0: 54 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 55 | if schedule not in SCHEDULES: 56 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 57 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 58 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 59 | if not 0.0 <= b1 < 1.0: 60 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 61 | if not 0.0 <= b2 < 1.0: 62 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 63 | if not e >= 0.0: 64 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 65 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 66 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 67 | max_grad_norm=max_grad_norm) 68 | super(BERTAdam, self).__init__(params, defaults) 69 | 70 | def get_lr(self): 71 | lr = [] 72 | print("l_total=",len(self.param_groups)) 73 | for group in self.param_groups: 74 | print("l_p=",len(group['params'])) 75 | for p in group['params']: 76 | state = self.state[p] 77 | if len(state) == 0: 78 | return [0] 79 | if group['t_total'] != -1: 80 | schedule_fct = SCHEDULES[group['schedule']] 81 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 82 | else: 83 | lr_scheduled = group['lr'] 84 | lr.append(lr_scheduled) 85 | return lr 86 | 87 | def to(self, device): 88 | """ Move the optimizer state to a specified device""" 89 | for state in self.state.values(): 90 | state['exp_avg'].to(device) 91 | state['exp_avg_sq'].to(device) 92 | 93 | def initialize_step(self, initial_step): 94 | """Initialize state with a defined step (but we don't have stored averaged). 95 | Arguments: 96 | initial_step (int): Initial step number. 97 | """ 98 | for group in self.param_groups: 99 | for p in group['params']: 100 | state = self.state[p] 101 | # State initialization 102 | state['step'] = initial_step 103 | # Exponential moving average of gradient values 104 | state['exp_avg'] = torch.zeros_like(p.data) 105 | # Exponential moving average of squared gradient values 106 | state['exp_avg_sq'] = torch.zeros_like(p.data) 107 | 108 | def step(self, closure=None): 109 | """Performs a single optimization step. 110 | 111 | Arguments: 112 | closure (callable, optional): A closure that reevaluates the model 113 | and returns the loss. 
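Each step clips the gradient of the current parameter to max_grad_norm,
updates the exponential moving averages of the gradient and of its square,
and applies update = m / (sqrt(v) + e) plus decoupled weight decay, all
scaled by the (optionally warmup-scheduled) learning rate. No bias
correction is applied, matching the TensorFlow BERT optimizer.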
114 | """ 115 | loss = None 116 | if closure is not None: 117 | loss = closure() 118 | 119 | for group in self.param_groups: 120 | for p in group['params']: 121 | if p.grad is None: 122 | continue 123 | grad = p.grad.data 124 | if grad.is_sparse: 125 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 126 | 127 | state = self.state[p] 128 | 129 | # State initialization 130 | if len(state) == 0: 131 | state['step'] = 0 132 | # Exponential moving average of gradient values 133 | state['next_m'] = torch.zeros_like(p.data) 134 | # Exponential moving average of squared gradient values 135 | state['next_v'] = torch.zeros_like(p.data) 136 | 137 | next_m, next_v = state['next_m'], state['next_v'] 138 | beta1, beta2 = group['b1'], group['b2'] 139 | 140 | # Add grad clipping 141 | if group['max_grad_norm'] > 0: 142 | clip_grad_norm_(p, group['max_grad_norm']) 143 | 144 | # Decay the first and second moment running average coefficient 145 | # In-place operations to update the averages at the same time 146 | next_m.mul_(beta1).add_(1 - beta1, grad) 147 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 148 | update = next_m / (next_v.sqrt() + group['e']) 149 | 150 | # Just adding the square of the weights to the loss function is *not* 151 | # the correct way of using L2 regularization/weight decay with Adam, 152 | # since that will interact with the m and v parameters in strange ways. 153 | # 154 | # Instead we want ot decay the weights in a manner that doesn't interact 155 | # with the m/v parameters. This is equivalent to adding the square 156 | # of the weights to the loss with plain (non-momentum) SGD. 157 | if group['weight_decay_rate'] > 0.0: 158 | update += group['weight_decay_rate'] * p.data 159 | 160 | if group['t_total'] != -1: 161 | schedule_fct = SCHEDULES[group['schedule']] 162 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 163 | else: 164 | lr_scheduled = group['lr'] 165 | 166 | update_with_lr = lr_scheduled * update 167 | p.data.add_(-update_with_lr) 168 | 169 | state['step'] += 1 170 | 171 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 172 | # bias_correction1 = 1 - beta1 ** state['step'] 173 | # bias_correction2 = 1 - beta2 ** state['step'] 174 | 175 | return loss 176 | -------------------------------------------------------------------------------- /processor.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Processors for Semeval Dataset.""" 4 | 5 | import csv 6 | import os 7 | import tokenization 8 | 9 | 10 | class InputExample(object): 11 | """A single training/test example for simple sequence classification.""" 12 | 13 | def __init__(self, guid, text_a, text_b=None, label=None, ner_labels_a=None): 14 | """Constructs a InputExample. 15 | 16 | Args: 17 | guid: Unique id for the example. 18 | text_a: string. The untokenized text of the first sequence. For single 19 | sequence tasks, only this sequence must be specified. 20 | text_b: (Optional) string. The untokenized text of the second sequence. 21 | Only must be specified for sequence pair tasks. 22 | label: (Optional) string. The label of the example. This should be 23 | specified for train and dev examples, but not for test examples. 24 | ner_labels_a: ner tag sequence for text_a. This should be 25 | specified for train and dev examples, but not for test examples. 
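In this repository, Semeval_Processor fills text_a with the review sentence,
text_b with the auxiliary "category polarity" sentence, label with the
yes/no flag, and ner_labels_a with the tag sequence, all read from the
preprocessed *_TAS.tsv files.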
26 | """ 27 | self.guid = guid 28 | self.text_a = text_a 29 | self.text_b = text_b 30 | self.label = label 31 | self.ner_labels_a = ner_labels_a 32 | 33 | 34 | class DataProcessor(object): 35 | """Base class for data converters for sequence classification data sets.""" 36 | 37 | def get_train_examples(self, data_dir): 38 | """Gets a collection of `InputExample`s for the train set.""" 39 | raise NotImplementedError() 40 | 41 | def get_dev_examples(self, data_dir): 42 | """Gets a collection of `InputExample`s for the dev set.""" 43 | raise NotImplementedError() 44 | 45 | def get_test_examples(self, data_dir): 46 | """Gets a collection of `InputExample`s for the test set.""" 47 | raise NotImplementedError() 48 | 49 | def get_labels(self): 50 | """Gets the list of labels for this data set.""" 51 | raise NotImplementedError() 52 | 53 | def get_ner_labels(self): 54 | """Gets the list of labels for this data set.""" 55 | raise NotImplementedError() 56 | 57 | @classmethod 58 | def _read_tsv(cls, input_file, quotechar=None): 59 | """Reads a tab separated value file.""" 60 | with open(input_file, "r") as f: 61 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 62 | lines = [] 63 | for line in reader: 64 | lines.append(line) 65 | return lines 66 | 67 | 68 | class Semeval_Processor(DataProcessor): 69 | """Processor for the SemEval 2015 and 2016 data set.""" 70 | 71 | def get_train_examples(self, data_dir): 72 | """See base class.""" 73 | with open(os.path.join(data_dir, "train_TAS.tsv"), 'r', encoding='utf-8') as fin: 74 | fin.readline() 75 | train_data = fin.readlines() 76 | return self._create_examples(train_data, "train") 77 | 78 | def get_dev_examples(self, data_dir): 79 | """See base class.""" 80 | with open(os.path.join(data_dir, "dev_TAS.tsv"), 'r', encoding='utf-8') as fin: 81 | fin.readline() 82 | dev_data = fin.readlines() 83 | return self._create_examples(dev_data, "dev") 84 | 85 | def get_test_examples(self, data_dir): 86 | """See base class.""" 87 | with open(os.path.join(data_dir, "test_TAS.tsv"), 'r', encoding='utf-8') as fin: 88 | fin.readline() 89 | test_data = fin.readlines() 90 | return self._create_examples(test_data, "test") 91 | 92 | def get_labels(self): 93 | """See base class.""" 94 | return ['0', '1'] 95 | 96 | def get_ner_labels(self, data_dir): 97 | ner_labels = ['[PAD]', '[CLS]'] 98 | with open(os.path.join(data_dir, "train_TAS.tsv"), 'r', encoding='utf-8') as fin: 99 | fin.readline() 100 | for line in fin: 101 | tags = line.strip().split('\t')[-1].split() 102 | for x in tags: 103 | if x not in ner_labels: 104 | ner_labels.append(x) 105 | print(ner_labels) 106 | return ner_labels 107 | 108 | def _create_examples(self, lines, set_type): 109 | """Creates examples.""" 110 | examples = [] 111 | for (i, line) in enumerate(lines): 112 | line_arr = line.strip().split('\t') 113 | guid = "%s-%s" % (set_type, i) 114 | text_a = tokenization.convert_to_unicode(str(line_arr[3])) # sentence 115 | text_b = tokenization.convert_to_unicode(str(line_arr[2])) # category_polarity 116 | label = tokenization.convert_to_unicode(str(line_arr[1])) # yes or no 117 | ner_labels_a = tokenization.convert_to_unicode(str(line_arr[4])) # ner tags 118 | 119 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, ner_labels_a=ner_labels_a)) 120 | return examples 121 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 
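# The WordpieceTokenizer below supports three strategies for out-of-vocabulary
# words, selected via `tokenize_method`:
#   "prefix_match" - keep the longest vocabulary prefix of the word;
#   "unk_replace"  - map the whole word to [UNK];
#   "word_split"   - standard WordPiece greedy longest-match into "##" pieces.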
3 | # Reference: https://github.com/huggingface/pytorch-pretrained-BERT 4 | 5 | """Tokenization classes.""" 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import collections 10 | import unicodedata 11 | 12 | import six 13 | 14 | 15 | def convert_to_unicode(text): 16 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 17 | if six.PY3: 18 | if isinstance(text, str): 19 | return text 20 | elif isinstance(text, bytes): 21 | return text.decode("utf-8", "ignore") 22 | else: 23 | raise ValueError("Unsupported string type: %s" % (type(text))) 24 | elif six.PY2: 25 | if isinstance(text, str): 26 | return text.decode("utf-8", "ignore") 27 | elif isinstance(text, unicode): 28 | return text 29 | else: 30 | raise ValueError("Unsupported string type: %s" % (type(text))) 31 | else: 32 | raise ValueError("Not running on Python2 or Python 3?") 33 | 34 | 35 | def printable_text(text): 36 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 37 | 38 | # These functions want `str` for both Python2 and Python3, but in one case 39 | # it's a Unicode string and in the other it's a byte string. 40 | if six.PY3: 41 | if isinstance(text, str): 42 | return text 43 | elif isinstance(text, bytes): 44 | return text.decode("utf-8", "ignore") 45 | else: 46 | raise ValueError("Unsupported string type: %s" % (type(text))) 47 | elif six.PY2: 48 | if isinstance(text, str): 49 | return text 50 | elif isinstance(text, unicode): 51 | return text.encode("utf-8") 52 | else: 53 | raise ValueError("Unsupported string type: %s" % (type(text))) 54 | else: 55 | raise ValueError("Not running on Python2 or Python 3?") 56 | 57 | 58 | def load_vocab(vocab_file): 59 | """Loads a vocabulary file into a dictionary.""" 60 | vocab = collections.OrderedDict() 61 | index = 0 62 | with open(vocab_file, "r") as reader: 63 | while True: 64 | token = convert_to_unicode(reader.readline()) 65 | if not token: 66 | break 67 | token = token.strip() 68 | vocab[token] = index 69 | index += 1 70 | return vocab 71 | 72 | 73 | def convert_tokens_to_ids(vocab, tokens): 74 | """Converts a sequence of tokens into ids using the vocab.""" 75 | ids = [] 76 | for token in tokens: 77 | ids.append(vocab[token]) 78 | return ids 79 | 80 | 81 | def whitespace_tokenize(text): 82 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 83 | text = text.strip() 84 | if not text: 85 | return [] 86 | tokens = text.split() 87 | return tokens 88 | 89 | 90 | class FullTokenizer(object): 91 | """Runs end-to-end tokenziation.""" 92 | 93 | def __init__(self, vocab_file, tokenize_method, do_lower_case=True): 94 | self.vocab = load_vocab(vocab_file) 95 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 96 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, tokenize_method=tokenize_method) 97 | 98 | def tokenize(self, text): 99 | split_tokens = [] 100 | for token in self.basic_tokenizer.tokenize(text): 101 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 102 | split_tokens.append(sub_token) 103 | 104 | return split_tokens 105 | 106 | def convert_tokens_to_ids(self, tokens): 107 | return convert_tokens_to_ids(self.vocab, tokens) 108 | 109 | 110 | class BasicTokenizer(object): 111 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 112 | 113 | def __init__(self, do_lower_case=True): 114 | """Constructs a BasicTokenizer. 115 | 116 | Args: 117 | do_lower_case: Whether to lower case the input. 
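When True, tokenize() also strips accents from each lower-cased token
(see _run_strip_accents), as expected by the uncased BERT vocabularies.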
118 | """ 119 | self.do_lower_case = do_lower_case 120 | 121 | def tokenize(self, text): 122 | """Tokenizes a piece of text.""" 123 | text = convert_to_unicode(text) 124 | text = self._clean_text(text) 125 | orig_tokens = whitespace_tokenize(text) 126 | split_tokens = [] 127 | for token in orig_tokens: 128 | if self.do_lower_case: 129 | token = token.lower() 130 | token = self._run_strip_accents(token) 131 | split_tokens.extend(self._run_split_on_punc(token)) 132 | 133 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 134 | return output_tokens 135 | 136 | def _run_strip_accents(self, text): 137 | """Strips accents from a piece of text.""" 138 | text = unicodedata.normalize("NFD", text) 139 | output = [] 140 | for char in text: 141 | cat = unicodedata.category(char) 142 | if cat == "Mn": 143 | continue 144 | output.append(char) 145 | return "".join(output) 146 | 147 | def _run_split_on_punc(self, text): 148 | """Splits punctuation on a piece of text.""" 149 | chars = list(text) 150 | i = 0 151 | start_new_word = True 152 | output = [] 153 | while i < len(chars): 154 | char = chars[i] 155 | if _is_punctuation(char): 156 | output.append([char]) 157 | start_new_word = True 158 | else: 159 | if start_new_word: 160 | output.append([]) 161 | start_new_word = False 162 | output[-1].append(char) 163 | i += 1 164 | 165 | return ["".join(x) for x in output] 166 | 167 | def _clean_text(self, text): 168 | """Performs invalid character removal and whitespace cleanup on text.""" 169 | output = [] 170 | for char in text: 171 | cp = ord(char) 172 | if cp == 0 or cp == 0xfffd or _is_control(char): 173 | continue 174 | if _is_whitespace(char): 175 | output.append(" ") 176 | else: 177 | output.append(char) 178 | return "".join(output) 179 | 180 | 181 | class WordpieceTokenizer(object): 182 | """Runs WordPiece tokenization.""" 183 | 184 | def __init__(self, vocab, tokenize_method, unk_token="[UNK]", max_input_chars_per_word=100): 185 | self.vocab = vocab 186 | self.unk_token = unk_token 187 | self.max_input_chars_per_word = max_input_chars_per_word 188 | self.tokenize_method = tokenize_method 189 | 190 | 191 | def tokenize(self, text): 192 | text = convert_to_unicode(text) 193 | output_tokens = [] 194 | # replace unknow word with the longest prefix match word in dictionary 195 | if self.tokenize_method == "prefix_match": 196 | for token in whitespace_tokenize(text): 197 | chars = list(token) 198 | if len(chars) > self.max_input_chars_per_word: 199 | output_tokens.append(self.unk_token) 200 | continue 201 | start = 0 202 | end = len(chars) 203 | cur_substr = None 204 | while start < end: 205 | substr = ''.join(chars[start:end]) 206 | if substr in self.vocab: 207 | cur_substr = substr 208 | break 209 | end -= 1 210 | if cur_substr is None: 211 | output_tokens.append(self.unk_token) 212 | else: 213 | output_tokens.append(cur_substr) 214 | # replace unknow word to [UNK] 215 | elif self.tokenize_method == "unk_replace": 216 | for token in whitespace_tokenize(text): 217 | chars = list(token) 218 | if len(chars) > self.max_input_chars_per_word: 219 | output_tokens.append(self.unk_token) 220 | continue 221 | temp_str = "".join(chars[0:]) 222 | if temp_str in self.vocab: 223 | output_tokens.append(temp_str) 224 | else: 225 | output_tokens.append(self.unk_token) 226 | # split unknown word to several words 227 | elif self.tokenize_method == "word_split": 228 | for token in whitespace_tokenize(text): 229 | chars = list(token) 230 | if len(chars) > self.max_input_chars_per_word: 231 | 
output_tokens.append(self.unk_token) 232 | continue 233 | 234 | is_bad = False 235 | start = 0 236 | sub_tokens = [] 237 | while start < len(chars): 238 | end = len(chars) 239 | cur_substr = None 240 | while start < end: 241 | substr = "".join(chars[start:end]) 242 | if start > 0: 243 | substr = "##" + substr 244 | if substr in self.vocab: 245 | cur_substr = substr 246 | break 247 | end -= 1 248 | if cur_substr is None: 249 | is_bad = True 250 | break 251 | sub_tokens.append(cur_substr) 252 | start = end 253 | 254 | if is_bad: 255 | output_tokens.append(self.unk_token) 256 | else: 257 | output_tokens.extend(sub_tokens) 258 | 259 | return output_tokens 260 | 261 | def _is_whitespace(char): 262 | """Checks whether `chars` is a whitespace character.""" 263 | # \t, \n, and \r are technically contorl characters but we treat them 264 | # as whitespace since they are generally considered as such. 265 | if char == " " or char == "\t" or char == "\n" or char == "\r": 266 | return True 267 | cat = unicodedata.category(char) 268 | if cat == "Zs": 269 | return True 270 | return False 271 | 272 | 273 | def _is_control(char): 274 | """Checks whether `chars` is a control character.""" 275 | # These are technically control characters but we count them as whitespace 276 | # characters. 277 | if char == "\t" or char == "\n" or char == "\r": 278 | return False 279 | cat = unicodedata.category(char) 280 | if cat.startswith("C"): 281 | return True 282 | return False 283 | 284 | 285 | def _is_punctuation(char): 286 | """Checks whether `chars` is a punctuation character.""" 287 | cp = ord(char) 288 | # We treat all non-letter/number ASCII as punctuation. 289 | # Characters such as "^", "$", and "`" are not in the Unicode 290 | # Punctuation class but we treat them as punctuation anyways, for 291 | # consistency. 292 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 293 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 294 | return True 295 | cat = unicodedata.category(char) 296 | if cat.startswith("P"): 297 | return True 298 | return False 299 | --------------------------------------------------------------------------------
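A minimal usage sketch for the tokenizer above (not part of the repository; the vocabulary path is a placeholder for the vocab.txt of an uncased BERT checkpoint). It shows FullTokenizer chaining BasicTokenizer and WordpieceTokenizer with the word_split strategy:

import tokenization

# Placeholder path: point this at the vocab.txt of a downloaded BERT checkpoint.
tokenizer = tokenization.FullTokenizer(
    vocab_file="path/to/vocab.txt",
    tokenize_method="word_split",   # or "prefix_match" / "unk_replace"
    do_lower_case=True)

tokens = tokenizer.tokenize("The bruschetta was outstanding!")
ids = tokenizer.convert_tokens_to_ids(tokens)
# With word_split, out-of-vocabulary words come back as greedy "##"-prefixed
# pieces (or [UNK] if no piece matches); the other two methods never split words.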