├── generate ├── make.sh ├── generate_semeval_NLI_B_QA_B.py ├── data_utils_sentihood.py ├── generate_sentihood_NLI_M.py ├── generate_sentihood_QA_M.py ├── generate_semeval_NLI_M.py ├── generate_sentihood_BERT_single.py ├── generate_semeval_BERT_single.py ├── generate_semeval_QA_M.py └── generate_sentihood_NLI_B_QA_B.py ├── LICENSE ├── convert_tf_checkpoint_to_pytorch.py ├── README.md ├── optimization.py ├── tokenization.py ├── evaluation.py ├── processor.py ├── modeling.py └── run_classifier_TABSA.py /generate/make.sh: -------------------------------------------------------------------------------- 1 | # generate datasets 2 | 3 | python generate_${1}_NLI_M.py 4 | python generate_${1}_QA_M.py 5 | python generate_${1}_NLI_B_QA_B.py 6 | python generate_${1}_BERT_single.py 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 HSLCY 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /generate/generate_semeval_NLI_B_QA_B.py: -------------------------------------------------------------------------------- 1 | data_dir='../data/semeval2014/bert-pair/' 2 | 3 | labels=['positive', 'neutral', 'negative', 'conflict', 'none'] 4 | with open(data_dir+"test_NLI_M.csv","r",encoding="utf-8") as f, \ 5 | open(data_dir+"test_NLI_B.csv","w",encoding="utf-8") as g_nli, \ 6 | open(data_dir+"test_QA_B.csv","w",encoding="utf-8") as g_qa: 7 | s=f.readline().strip() 8 | while s: 9 | tmp=s.split("\t") 10 | for label in labels: 11 | t_nli = label + " - " + tmp[2] 12 | t_qa = "the polarity of the aspect " + tmp[2] + " is " + label + " ." 13 | if tmp[1]==label: 14 | g_nli.write(tmp[0]+"\t1\t"+t_nli+"\t"+tmp[3]+"\n") 15 | g_qa.write(tmp[0]+"\t1\t"+t_qa+"\t"+tmp[3]+"\n") 16 | else: 17 | g_nli.write(tmp[0]+"\t0\t"+t_nli+"\t"+tmp[3]+"\n") 18 | g_qa.write(tmp[0]+"\t0\t"+t_qa+"\t"+tmp[3]+"\n") 19 | s = f.readline().strip() 20 | 21 | 22 | with open(data_dir+"train_NLI_M.csv","r",encoding="utf-8") as f, \ 23 | open(data_dir+"train_NLI_B.csv","w",encoding="utf-8") as g_nli, \ 24 | open(data_dir+"train_QA_B.csv","w",encoding="utf-8") as g_qa: 25 | s=f.readline().strip() 26 | while s: 27 | tmp=s.split("\t") 28 | for label in labels: 29 | t_nli = label + " - " + tmp[2] 30 | t_qa = "the polarity of the aspect " + tmp[2] + " is " + label + " ." 
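            # Each *_NLI_M row is "id <tab> gold polarity <tab> aspect <tab> sentence"; the branch below
            # expands it into one binary example per candidate label, written with label 1 when the
            # candidate matches the gold polarity and label 0 otherwise (this is the NLI_B / QA_B format).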
31 | if tmp[1]==label: 32 | g_nli.write(tmp[0]+"\t1\t"+t_nli+"\t"+tmp[3]+"\n") 33 | g_qa.write(tmp[0]+"\t1\t"+t_qa+"\t"+tmp[3]+"\n") 34 | else: 35 | g_nli.write(tmp[0]+"\t0\t"+t_nli+"\t"+tmp[3]+"\n") 36 | g_qa.write(tmp[0]+"\t0\t"+t_qa+"\t"+tmp[3]+"\n") 37 | s = f.readline().strip() -------------------------------------------------------------------------------- /generate/data_utils_sentihood.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/liufly/delayed-memory-update-entnet 2 | 3 | from __future__ import absolute_import 4 | 5 | import json 6 | import operator 7 | import os 8 | import re 9 | import sys 10 | import xml.etree.ElementTree 11 | 12 | import nltk 13 | import numpy as np 14 | 15 | 16 | def load_task(data_dir, aspect2idx): 17 | in_file = os.path.join(data_dir, 'sentihood-train.json') 18 | train = parse_sentihood_json(in_file) 19 | in_file = os.path.join(data_dir, 'sentihood-dev.json') 20 | dev = parse_sentihood_json(in_file) 21 | in_file = os.path.join(data_dir, 'sentihood-test.json') 22 | test = parse_sentihood_json(in_file) 23 | 24 | train = convert_input(train, aspect2idx) 25 | train_aspect_idx = get_aspect_idx(train, aspect2idx) 26 | train = tokenize(train) 27 | dev = convert_input(dev, aspect2idx) 28 | dev_aspect_idx = get_aspect_idx(dev, aspect2idx) 29 | dev = tokenize(dev) 30 | test = convert_input(test, aspect2idx) 31 | test_aspect_idx = get_aspect_idx(test, aspect2idx) 32 | test = tokenize(test) 33 | 34 | return (train, train_aspect_idx), (dev, dev_aspect_idx), (test, test_aspect_idx) 35 | 36 | 37 | def get_aspect_idx(data, aspect2idx): 38 | ret = [] 39 | for _, _, _, aspect, _ in data: 40 | ret.append(aspect2idx[aspect]) 41 | assert len(data) == len(ret) 42 | return np.array(ret) 43 | 44 | 45 | def parse_sentihood_json(in_file): 46 | with open(in_file) as f: 47 | data = json.load(f) 48 | ret = [] 49 | for d in data: 50 | text = d['text'] 51 | sent_id = d['id'] 52 | opinions = [] 53 | targets = set() 54 | for opinion in d['opinions']: 55 | sentiment = opinion['sentiment'] 56 | aspect = opinion['aspect'] 57 | target_entity = opinion['target_entity'] 58 | targets.add(target_entity) 59 | opinions.append((target_entity, aspect, sentiment)) 60 | ret.append((sent_id, text, opinions)) 61 | return ret 62 | 63 | 64 | def convert_input(data, all_aspects): 65 | ret = [] 66 | for sent_id, text, opinions in data: 67 | for target_entity, aspect, sentiment in opinions: 68 | if aspect not in all_aspects: 69 | continue 70 | ret.append((sent_id, text, target_entity, aspect, sentiment)) 71 | assert 'LOCATION1' in text 72 | targets = set(['LOCATION1']) 73 | if 'LOCATION2' in text: 74 | targets.add('LOCATION2') 75 | for target in targets: 76 | aspects = set([a for t, a, _ in opinions if t == target]) 77 | none_aspects = [a for a in all_aspects if a not in aspects] 78 | for aspect in none_aspects: 79 | ret.append((sent_id, text, target, aspect, 'None')) 80 | return ret 81 | 82 | 83 | def tokenize(data): 84 | ret = [] 85 | for sent_id, text, target_entity, aspect, sentiment in data: 86 | new_text = nltk.word_tokenize(text) 87 | new_aspect = aspect.split('-') 88 | ret.append((sent_id, new_text, target_entity, new_aspect, sentiment)) 89 | return ret 90 | -------------------------------------------------------------------------------- /convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # Reference: 
https://github.com/huggingface/pytorch-pretrained-BERT 4 | 5 | """Convert BERT checkpoint.""" 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import argparse 10 | import re 11 | 12 | import numpy as np 13 | import torch 14 | 15 | import tensorflow as tf 16 | from modeling import BertConfig, BertModel 17 | 18 | parser = argparse.ArgumentParser() 19 | 20 | ## Required parameters 21 | parser.add_argument("--tf_checkpoint_path", 22 | default = None, 23 | type = str, 24 | required = True, 25 | help = "Path the TensorFlow checkpoint path.") 26 | parser.add_argument("--bert_config_file", 27 | default = None, 28 | type = str, 29 | required = True, 30 | help = "The config json file corresponding to the pre-trained BERT model. \n" 31 | "This specifies the model architecture.") 32 | parser.add_argument("--pytorch_dump_path", 33 | default = None, 34 | type = str, 35 | required = True, 36 | help = "Path to the output PyTorch model.") 37 | 38 | args = parser.parse_args() 39 | 40 | def convert(): 41 | # Initialise PyTorch model 42 | config = BertConfig.from_json_file(args.bert_config_file) 43 | model = BertModel(config) 44 | 45 | # Load weights from TF model 46 | path = args.tf_checkpoint_path 47 | print("Converting TensorFlow checkpoint from {}".format(path)) 48 | 49 | init_vars = tf.train.list_variables(path) 50 | names = [] 51 | arrays = [] 52 | for name, shape in init_vars: 53 | print("Loading {} with shape {}".format(name, shape)) 54 | array = tf.train.load_variable(path, name) 55 | print("Numpy array shape {}".format(array.shape)) 56 | names.append(name) 57 | arrays.append(array) 58 | 59 | for name, array in zip(names, arrays): 60 | name = name[5:] # skip "bert/" 61 | print("Loading {}".format(name)) 62 | name = name.split('/') 63 | if any(n in ["adam_v", "adam_m","l_step"] for n in name): 64 | print("Skipping {}".format("/".join(name))) 65 | continue 66 | if name[0] in ['redictions', 'eq_relationship']: 67 | print("Skipping") 68 | continue 69 | pointer = model 70 | for m_name in name: 71 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name): 72 | l = re.split(r'_(\d+)', m_name) 73 | else: 74 | l = [m_name] 75 | if l[0] == 'kernel': 76 | pointer = getattr(pointer, 'weight') 77 | else: 78 | pointer = getattr(pointer, l[0]) 79 | if len(l) >= 2: 80 | num = int(l[1]) 81 | pointer = pointer[num] 82 | if m_name[-11:] == '_embeddings': 83 | pointer = getattr(pointer, 'weight') 84 | elif m_name == 'kernel': 85 | array = np.transpose(array) 86 | try: 87 | assert pointer.shape == array.shape 88 | except AssertionError as e: 89 | e.args += (pointer.shape, array.shape) 90 | raise 91 | pointer.data = torch.from_numpy(array) 92 | 93 | # Save pytorch-model 94 | torch.save(model.state_dict(), args.pytorch_dump_path) 95 | 96 | if __name__ == "__main__": 97 | convert() 98 | -------------------------------------------------------------------------------- /generate/generate_sentihood_NLI_M.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data_utils_sentihood import * 4 | 5 | data_dir='../data/sentihood/' 6 | aspect2idx = { 7 | 'general': 0, 8 | 'price': 1, 9 | 'transit-location': 2, 10 | 'safety': 3, 11 | } 12 | 13 | (train, train_aspect_idx), (val, val_aspect_idx), (test, test_aspect_idx) = load_task(data_dir, aspect2idx) 14 | 15 | print("len(train) = ", len(train)) 16 | print("len(val) = ", len(val)) 17 | print("len(test) = ", len(test)) 18 | 19 | train.sort(key=lambda x:x[2]+str(x[0])+x[3][0]) 20 | val.sort(key=lambda 
x:x[2]+str(x[0])+x[3][0]) 21 | test.sort(key=lambda x:x[2]+str(x[0])+x[3][0]) 22 | 23 | dir_path = data_dir+'bert-pair/' 24 | if not os.path.exists(dir_path): 25 | os.makedirs(dir_path) 26 | 27 | with open(dir_path+"train_NLI_M.tsv","w",encoding="utf-8") as f: 28 | f.write("id\tsentence1\tsentence2\tlabel\n") 29 | for v in train: 30 | f.write(str(v[0])+"\t") 31 | word=v[1][0].lower() 32 | if word=='location1':f.write('location - 1') 33 | elif word=='location2':f.write('location - 2') 34 | elif word[0]=='\'':f.write("\' "+word[1:]) 35 | else:f.write(word) 36 | for i in range(1,len(v[1])): 37 | word=v[1][i].lower() 38 | f.write(" ") 39 | if word == 'location1': 40 | f.write('location - 1') 41 | elif word == 'location2': 42 | f.write('location - 2') 43 | elif word[0] == '\'': 44 | f.write("\' " + word[1:]) 45 | else: 46 | f.write(word) 47 | f.write("\t") 48 | if v[2]=='LOCATION1':f.write('location - 1 - ') 49 | if v[2]=='LOCATION2':f.write('location - 2 - ') 50 | if len(v[3])==1: 51 | f.write(v[3][0]+"\t") 52 | else: 53 | f.write("transit location\t") 54 | f.write(v[4]+"\n") 55 | 56 | with open(dir_path+"dev_NLI_M.tsv","w",encoding="utf-8") as f: 57 | f.write("id\tsentence1\tsentence2\tlabel\n") 58 | for v in val: 59 | f.write(str(v[0])+"\t") 60 | word=v[1][0].lower() 61 | if word=='location1':f.write('location - 1') 62 | elif word=='location2':f.write('location - 2') 63 | elif word[0]=='\'':f.write("\' "+word[1:]) 64 | else:f.write(word) 65 | for i in range(1,len(v[1])): 66 | word=v[1][i].lower() 67 | f.write(" ") 68 | if word == 'location1': 69 | f.write('location - 1') 70 | elif word == 'location2': 71 | f.write('location - 2') 72 | elif word[0] == '\'': 73 | f.write("\' " + word[1:]) 74 | else: 75 | f.write(word) 76 | f.write("\t") 77 | if v[2]=='LOCATION1':f.write('location - 1 - ') 78 | if v[2]=='LOCATION2':f.write('location - 2 - ') 79 | if len(v[3])==1: 80 | f.write(v[3][0]+"\t") 81 | else: 82 | f.write("transit location\t") 83 | f.write(v[4]+"\n") 84 | 85 | with open(dir_path+"test_NLI_M.tsv","w",encoding="utf-8") as f: 86 | f.write("id\tsentence1\tsentence2\tlabel\n") 87 | for v in test: 88 | f.write(str(v[0])+"\t") 89 | word=v[1][0].lower() 90 | if word=='location1':f.write('location - 1') 91 | elif word=='location2':f.write('location - 2') 92 | elif word[0]=='\'':f.write("\' "+word[1:]) 93 | else:f.write(word) 94 | for i in range(1,len(v[1])): 95 | word=v[1][i].lower() 96 | f.write(" ") 97 | if word == 'location1': 98 | f.write('location - 1') 99 | elif word == 'location2': 100 | f.write('location - 2') 101 | elif word[0] == '\'': 102 | f.write("\' " + word[1:]) 103 | else: 104 | f.write(word) 105 | f.write("\t") 106 | if v[2]=='LOCATION1':f.write('location - 1 - ') 107 | if v[2]=='LOCATION2':f.write('location - 2 - ') 108 | if len(v[3])==1: 109 | f.write(v[3][0]+"\t") 110 | else: 111 | f.write("transit location\t") 112 | f.write(v[4]+"\n") 113 | -------------------------------------------------------------------------------- /generate/generate_sentihood_QA_M.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data_utils_sentihood import * 4 | 5 | data_dir='../data/sentihood/' 6 | aspect2idx = { 7 | 'general': 0, 8 | 'price': 1, 9 | 'transit-location': 2, 10 | 'safety': 3, 11 | } 12 | 13 | (train, train_aspect_idx), (val, val_aspect_idx), (test, test_aspect_idx) = load_task(data_dir, aspect2idx) 14 | 15 | print("len(train) = ", len(train)) 16 | print("len(val) = ", len(val)) 17 | print("len(test) = ", len(test)) 18 | 19 | 
train.sort(key=lambda x:x[2]+str(x[0])+x[3][0]) 20 | val.sort(key=lambda x:x[2]+str(x[0])+x[3][0]) 21 | test.sort(key=lambda x:x[2]+str(x[0])+x[3][0]) 22 | 23 | dir_path = data_dir+'bert-pair/' 24 | if not os.path.exists(dir_path): 25 | os.makedirs(dir_path) 26 | 27 | with open(dir_path+"train_QA_M.tsv","w",encoding="utf-8") as f: 28 | f.write("id\tsentence1\tsentence2\tlabel\n") 29 | for v in train: 30 | f.write(str(v[0])+"\t") 31 | word=v[1][0].lower() 32 | if word=='location1':f.write('location - 1') 33 | elif word=='location2':f.write('location - 2') 34 | elif word[0]=='\'':f.write("\' "+word[1:]) 35 | else:f.write(word) 36 | for i in range(1,len(v[1])): 37 | word=v[1][i].lower() 38 | f.write(" ") 39 | if word == 'location1': 40 | f.write('location - 1') 41 | elif word == 'location2': 42 | f.write('location - 2') 43 | elif word[0] == '\'': 44 | f.write("\' " + word[1:]) 45 | else: 46 | f.write(word) 47 | f.write("\t") 48 | f.write("what do you think of the ") 49 | if len(v[3])==1: 50 | f.write(v[3][0]+" ") 51 | else: 52 | f.write("transit location ") 53 | if v[2]=='LOCATION1':f.write('of location - 1 ?\t') 54 | if v[2]=='LOCATION2':f.write('of location - 2 ?\t') 55 | f.write(v[4]+"\n") 56 | 57 | with open(dir_path+"dev_QA_M.tsv","w",encoding="utf-8") as f: 58 | f.write("id\tsentence1\tsentence2\tlabel\n") 59 | for v in val: 60 | f.write(str(v[0])+"\t") 61 | word=v[1][0].lower() 62 | if word=='location1':f.write('location - 1') 63 | elif word=='location2':f.write('location - 2') 64 | elif word[0]=='\'':f.write("\' "+word[1:]) 65 | else:f.write(word) 66 | for i in range(1,len(v[1])): 67 | word=v[1][i].lower() 68 | f.write(" ") 69 | if word == 'location1': 70 | f.write('location - 1') 71 | elif word == 'location2': 72 | f.write('location - 2') 73 | elif word[0] == '\'': 74 | f.write("\' " + word[1:]) 75 | else: 76 | f.write(word) 77 | f.write("\t") 78 | f.write("what do you think of the ") 79 | if len(v[3])==1: 80 | f.write(v[3][0]+" ") 81 | else: 82 | f.write("transit location ") 83 | if v[2]=='LOCATION1':f.write('of location - 1 ?\t') 84 | if v[2]=='LOCATION2':f.write('of location - 2 ?\t') 85 | f.write(v[4]+"\n") 86 | 87 | with open(dir_path+"test_QA_M.tsv","w",encoding="utf-8") as f: 88 | f.write("id\tsentence1\tsentence2\tlabel\n") 89 | for v in test: 90 | f.write(str(v[0])+"\t") 91 | word=v[1][0].lower() 92 | if word=='location1':f.write('location - 1') 93 | elif word=='location2':f.write('location - 2') 94 | elif word[0]=='\'':f.write("\' "+word[1:]) 95 | else:f.write(word) 96 | for i in range(1,len(v[1])): 97 | word=v[1][i].lower() 98 | f.write(" ") 99 | if word == 'location1': 100 | f.write('location - 1') 101 | elif word == 'location2': 102 | f.write('location - 2') 103 | elif word[0] == '\'': 104 | f.write("\' " + word[1:]) 105 | else: 106 | f.write(word) 107 | f.write("\t") 108 | f.write("what do you think of the ") 109 | if len(v[3])==1: 110 | f.write(v[3][0]+" ") 111 | else: 112 | f.write("transit location ") 113 | if v[2]=='LOCATION1':f.write('of location - 1 ?\t') 114 | if v[2]=='LOCATION2':f.write('of location - 2 ?\t') 115 | f.write(v[4]+"\n") 116 | -------------------------------------------------------------------------------- /generate/generate_semeval_NLI_M.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | data_dir='../data/semeval2014/' 4 | 5 | dir_path = data_dir+'bert-pair/' 6 | if not os.path.exists(dir_path): 7 | os.makedirs(dir_path) 8 | 9 | with open(dir_path+"test_NLI_M.csv","w",encoding="utf-8") as g: 
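    # The block below scans the SemEval XML line by line with plain string matching (no XML parser):
    # for each <sentence> element it collects the text and its (category, polarity) pairs, then writes
    # one "id <tab> polarity <tab> aspect <tab> text" row per aspect category (price, anecdotes, food,
    # ambience, service), using "none" when the sentence expresses no opinion on that aspect.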
10 | with open(data_dir+"Restaurants_Test_Gold.xml","r",encoding="utf-8") as f: 11 | s=f.readline().strip() 12 | while s: 13 | category=[] 14 | polarity=[] 15 | if "<sentence id=" in s: 16 | left=s.find("id=") 17 | right=s.find(">") 18 | id=s[left+4:right-1] 19 | while not "</sentence>" in s: 20 | if "<text>" in s: 21 | left=s.find("<text>") 22 | right=s.find("</text>") 23 | text=s[left+6:right] 24 | if "aspectCategory" in s: 25 | left=s.find("category=") 26 | right=s.find("polarity=") 27 | category.append(s[left+10:right-2]) 28 | left=s.find("polarity=") 29 | right=s.find("/>") 30 | polarity.append(s[left+10:right-2]) 31 | s=f.readline().strip() 32 | if "price" in category: 33 | g.write(id+"\t"+polarity[category.index("price")]+"\t"+"price"+"\t"+text+"\n") 34 | else: 35 | g.write(id + "\t" + "none" + "\t" + "price" + "\t" + text + "\n") 36 | if "anecdotes/miscellaneous" in category: 37 | g.write(id+"\t"+polarity[category.index("anecdotes/miscellaneous")]+"\t"+"anecdotes"+"\t"+text+"\n") 38 | else: 39 | g.write(id + "\t" + "none" + "\t" + "anecdotes" + "\t" + text + "\n") 40 | if "food" in category: 41 | g.write(id+"\t"+polarity[category.index("food")]+"\t"+"food"+"\t"+text+"\n") 42 | else: 43 | g.write(id + "\t" + "none" + "\t" + "food" + "\t" + text + "\n") 44 | if "ambience" in category: 45 | g.write(id+"\t"+polarity[category.index("ambience")]+"\t"+"ambience"+"\t"+text+"\n") 46 | else: 47 | g.write(id + "\t" + "none" + "\t" + "ambience" + "\t" + text + "\n") 48 | if "service" in category: 49 | g.write(id+"\t"+polarity[category.index("service")]+"\t"+"service"+"\t"+text+"\n") 50 | else: 51 | g.write(id + "\t" + "none" + "\t" + "service" + "\t" + text + "\n") 52 | else: 53 | s = f.readline().strip() 54 | 55 | 56 | with open(dir_path+"train_NLI_M.csv","w",encoding="utf-8") as g: 57 | with open(data_dir+"Restaurants_Train.xml","r",encoding="utf-8") as f: 58 | s=f.readline().strip() 59 | while s: 60 | category=[] 61 | polarity=[] 62 | if "<sentence id=" in s: 63 | left=s.find("id=") 64 | right=s.find(">") 65 | id=s[left+4:right-1] 66 | while not "</sentence>" in s: 67 | if "<text>" in s: 68 | left=s.find("<text>") 69 | right=s.find("</text>") 70 | text=s[left+6:right] 71 | if "aspectCategory" in s: 72 | left=s.find("category=") 73 | right=s.find("polarity=") 74 | category.append(s[left+10:right-2]) 75 | left=s.find("polarity=") 76 | right=s.find("/>") 77 | polarity.append(s[left+10:right-1]) 78 | s=f.readline().strip() 79 | if "price" in category: 80 | g.write(id+"\t"+polarity[category.index("price")]+"\t"+"price"+"\t"+text+"\n") 81 | else: 82 | g.write(id + "\t" + "none" + "\t" + "price" + "\t" + text + "\n") 83 | if "anecdotes/miscellaneous" in category: 84 | g.write(id+"\t"+polarity[category.index("anecdotes/miscellaneous")]+"\t"+"anecdotes"+"\t"+text+"\n") 85 | else: 86 | g.write(id + "\t" + "none" + "\t" + "anecdotes" + "\t" + text + "\n") 87 | if "food" in category: 88 | g.write(id+"\t"+polarity[category.index("food")]+"\t"+"food"+"\t"+text+"\n") 89 | else: 90 | g.write(id + "\t" + "none" + "\t" + "food" + "\t" + text + "\n") 91 | if "ambience" in category: 92 | g.write(id+"\t"+polarity[category.index("ambience")]+"\t"+"ambience"+"\t"+text+"\n") 93 | else: 94 | g.write(id + "\t" + "none" + "\t" + "ambience" + "\t" + text + "\n") 95 | if "service" in category: 96 | g.write(id+"\t"+polarity[category.index("service")]+"\t"+"service"+"\t"+text+"\n") 97 | else: 98 | g.write(id + "\t" + "none" + "\t" + "service" + "\t" + text + "\n") 99 | else: 100 | s = f.readline().strip() 101 | 102 | 103 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ABSA as
a Sentence Pair Classification Task 2 | 3 | Code and corpora for the paper "Utilizing BERT for Aspect-Based Sentiment Analysis via Constructing Auxiliary Sentence" (NAACL 2019). 4 | 5 | ## Requirements 6 | 7 | * pytorch: 1.0.0 8 | * python: 3.7.1 9 | * tensorflow: 1.13.1 (only needed for converting the BERT TensorFlow model to a PyTorch model) 10 | * numpy: 1.15.4 11 | * nltk 12 | * sklearn 13 | 14 | ## Step 1: prepare datasets 15 | 16 | ### SentiHood 17 | 18 | Since the link given in the [dataset release paper]() no longer works, we use the [dataset mirror]() listed in [NLP-progress](https://github.com/sebastianruder/NLP-progress/blob/master/english/sentiment_analysis.md) and fix some mistakes (there are duplicate aspect entries in several sentences). See directory: `data/sentihood/`. 19 | 20 | Run the following commands to prepare the datasets for all tasks: 21 | 22 | ``` 23 | cd generate/ 24 | bash make.sh sentihood 25 | ``` 26 | 27 | ### SemEval 2014 28 | 29 | The train data is available at [SemEval-2014 ABSA Restaurant Reviews - Train Data](http://metashare.ilsp.gr:8080/repository/browse/semeval-2014-absa-restaurant-reviews-train-data/479d18c0625011e38685842b2b6a04d72cb57ba6c07743b9879d1a04e72185b8/) and the gold test data at [SemEval-2014 ABSA Test Data - Gold Annotations](http://metashare.ilsp.gr:8080/repository/browse/semeval-2014-absa-test-data-gold-annotations/b98d11cec18211e38229842b2b6a04d77591d40acd7542b7af823a54fb03a155/). See directory: `data/semeval2014/`. 30 | 31 | Run the following commands to prepare the datasets for all tasks: 32 | 33 | ``` 34 | cd generate/ 35 | bash make.sh semeval 36 | ``` 37 | 38 | ## Step 2: prepare BERT-pytorch-model 39 | 40 | Download [BERT-Base (Google's pre-trained models)](https://github.com/google-research/bert) and then convert the TensorFlow checkpoint to a PyTorch model. 41 | 42 | For example: 43 | 44 | ``` 45 | python convert_tf_checkpoint_to_pytorch.py \ 46 | --tf_checkpoint_path uncased_L-12_H-768_A-12/bert_model.ckpt \ 47 | --bert_config_file uncased_L-12_H-768_A-12/bert_config.json \ 48 | --pytorch_dump_path uncased_L-12_H-768_A-12/pytorch_model.bin 49 | ``` 50 | 51 | ## Step 3: train 52 | 53 | For example, the **BERT-pair-NLI_M** task on the **SentiHood** dataset: 54 | 55 | ``` 56 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_classifier_TABSA.py \ 57 | --task_name sentihood_NLI_M \ 58 | --data_dir data/sentihood/bert-pair/ \ 59 | --vocab_file uncased_L-12_H-768_A-12/vocab.txt \ 60 | --bert_config_file uncased_L-12_H-768_A-12/bert_config.json \ 61 | --init_checkpoint uncased_L-12_H-768_A-12/pytorch_model.bin \ 62 | --eval_test \ 63 | --do_lower_case \ 64 | --max_seq_length 512 \ 65 | --train_batch_size 24 \ 66 | --learning_rate 2e-5 \ 67 | --num_train_epochs 6.0 \ 68 | --output_dir results/sentihood/NLI_M \ 69 | --seed 42 70 | ``` 71 | 72 | Note: 73 | 74 | * For SentiHood, `--task_name` must be one of `sentihood_NLI_M`, `sentihood_QA_M`, `sentihood_NLI_B`, `sentihood_QA_B` and `sentihood_single`. For the `sentihood_single` task, 8 separate models (one per dataset generated in step 1, see directory `data/sentihood/bert-single`) must be trained separately and then evaluated together (see the sketch after this list). 75 | * For SemEval-2014, `--task_name` must be one of `semeval_NLI_M`, `semeval_QA_M`, `semeval_NLI_B`, `semeval_QA_B` and `semeval_single`. For the `semeval_single` task, 5 separate models (one per dataset generated in step 1, see directory `data/semeval2014/bert-single`) must be trained separately and then evaluated together.
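The note above says the 8 `sentihood_single` sub-tasks have to be trained one by one. Below is a minimal sketch of a driver loop over the sub-task directories produced in step 1; this helper script is not part of the repository, and the output paths and reused hyper-parameters are only an illustration:

```
# train_sentihood_single.py -- hypothetical helper, not included in this repo.
# Runs the 8 BERT-single sub-tasks one after another, reusing the hyper-parameters
# from the example above; adjust paths, GPU settings and output layout to your setup.
import subprocess

subtasks = ["loc1_general", "loc1_price", "loc1_safety", "loc1_transit",
            "loc2_general", "loc2_price", "loc2_safety", "loc2_transit"]

for name in subtasks:
    subprocess.run([
        "python", "run_classifier_TABSA.py",
        "--task_name", "sentihood_single",
        "--data_dir", "data/sentihood/bert-single/{}/".format(name),
        "--vocab_file", "uncased_L-12_H-768_A-12/vocab.txt",
        "--bert_config_file", "uncased_L-12_H-768_A-12/bert_config.json",
        "--init_checkpoint", "uncased_L-12_H-768_A-12/pytorch_model.bin",
        "--eval_test",
        "--do_lower_case",
        "--max_seq_length", "512",
        "--train_batch_size", "24",
        "--learning_rate", "2e-5",
        "--num_train_epochs", "6.0",
        "--output_dir", "results/sentihood/single/{}".format(name),
        "--seed", "42",
    ], check=True)
```

The per-sub-task prediction files can then be collected (renamed to `loc1_general.txt`, ..., `loc2_transit.txt`) in one directory for the evaluation step below.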
76 | 77 | ## Step 4: evaluation 78 | 79 | Evaluate the results on the test set (accuracy, F1, etc.). 80 | 81 | For example, the **BERT-pair-NLI_M** task on the **SentiHood** dataset: 82 | 83 | ``` 84 | python evaluation.py --task_name sentihood_NLI_M --pred_data_dir results/sentihood/NLI_M/test_ep_4.txt 85 | ``` 86 | 87 | Note: 88 | 89 | * As mentioned in step 3, for the `sentihood_single` task, 8 different models should be trained separately and then evaluated together. `--pred_data_dir` should be a directory that contains **8 files** named as follows: `loc1_general.txt`, `loc1_price.txt`, `loc1_safety.txt`, `loc1_transit.txt`, `loc2_general.txt`, `loc2_price.txt`, `loc2_safety.txt` and `loc2_transit.txt`. 90 | * As mentioned in step 3, for the `semeval_single` task, 5 different models should be trained separately and then evaluated together. `--pred_data_dir` should be a directory that contains **5 files** named as follows: `price.txt`, `anecdotes.txt`, `food.txt`, `ambience.txt` and `service.txt`. 91 | * For the remaining 8 tasks, `--pred_data_dir` should be a single prediction file, as in the example above. 92 | 93 | 94 | ## Citation 95 | 96 | ``` 97 | @inproceedings{sun-etal-2019-utilizing, 98 | title = "Utilizing {BERT} for Aspect-Based Sentiment Analysis via Constructing Auxiliary Sentence", 99 | author = "Sun, Chi and 100 | Huang, Luyao and 101 | Qiu, Xipeng", 102 | booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)", 103 | month = jun, 104 | year = "2019", 105 | address = "Minneapolis, Minnesota", 106 | publisher = "Association for Computational Linguistics", 107 | url = "https://www.aclweb.org/anthology/N19-1035", 108 | pages = "380--385" 109 | } 110 | ``` 111 | -------------------------------------------------------------------------------- /generate/generate_sentihood_BERT_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data_utils_sentihood import * 4 | 5 | data_dir='../data/sentihood/' 6 | aspect2idx = { 7 | 'general': 0, 8 | 'price': 1, 9 | 'transit-location': 2, 10 | 'safety': 3, 11 | } 12 | 13 | (train, train_aspect_idx), (val, val_aspect_idx), (test, test_aspect_idx) = load_task(data_dir, aspect2idx) 14 | 15 | print("len(train) = ", len(train)) 16 | print("len(val) = ", len(val)) 17 | print("len(test) = ", len(test)) 18 | 19 | train.sort(key=lambda x:x[2]+str(x[0])+x[3][0]) 20 | val.sort(key=lambda x:x[2]+str(x[0])+x[3][0]) 21 | test.sort(key=lambda x:x[2]+str(x[0])+x[3][0]) 22 | 23 | location_name = ['loc1', 'loc2'] 24 | aspect_name = ['general', 'price', 'safety', 'transit'] 25 | dir_path = [data_dir + 'bert-single/' + i + '_' + j + '/' for i in location_name for j in aspect_name] 26 | for path in dir_path: 27 | if not os.path.exists(path): 28 | os.makedirs(path) 29 | 30 | count=0 31 | with open(dir_path[0]+"train.tsv","w",encoding="utf-8") as f1_general, \ 32 | open(dir_path[1]+"train.tsv", "w", encoding="utf-8") as f1_price, \ 33 | open(dir_path[2]+"train.tsv", "w", encoding="utf-8") as f1_safety, \ 34 | open(dir_path[3]+"train.tsv", "w", encoding="utf-8") as f1_transit, \ 35 | open(dir_path[4]+"train.tsv", "w", encoding="utf-8") as f2_general, \ 36 | open(dir_path[5]+"train.tsv", "w", encoding="utf-8") as f2_price, \ 37 | open(dir_path[6]+"train.tsv", "w", encoding="utf-8") as f2_safety, \ 38 | open(dir_path[7]+"train.tsv", "w",encoding="utf-8") as f2_transit, \ 39 | open(data_dir +
"bert-pair/train_NLI_M.tsv", "r", encoding="utf-8") as f: 40 | s = f.readline().strip() 41 | s = f.readline().strip() 42 | while s: 43 | count+=1 44 | tmp=s.split("\t") 45 | line=tmp[0]+"\t"+tmp[1]+"\t"+tmp[3]+"\n" 46 | if count<=11908: #loc1 47 | if count%4==1: 48 | f1_general.write(line) 49 | if count%4==2: 50 | f1_price.write(line) 51 | if count%4==3: 52 | f1_safety.write(line) 53 | if count%4==0: 54 | f1_transit.write(line) 55 | else: #loc2 56 | if count%4==1: 57 | f2_general.write(line) 58 | if count%4==2: 59 | f2_price.write(line) 60 | if count%4==3: 61 | f2_safety.write(line) 62 | if count%4==0: 63 | f2_transit.write(line) 64 | s = f.readline().strip() 65 | 66 | count=0 67 | with open(dir_path[0]+"dev.tsv","w",encoding="utf-8") as f1_general, \ 68 | open(dir_path[1]+"dev.tsv", "w", encoding="utf-8") as f1_price, \ 69 | open(dir_path[2]+"dev.tsv", "w", encoding="utf-8") as f1_safety, \ 70 | open(dir_path[3]+"dev.tsv", "w", encoding="utf-8") as f1_transit, \ 71 | open(dir_path[4]+"dev.tsv", "w", encoding="utf-8") as f2_general, \ 72 | open(dir_path[5]+"dev.tsv", "w", encoding="utf-8") as f2_price, \ 73 | open(dir_path[6]+"dev.tsv", "w", encoding="utf-8") as f2_safety, \ 74 | open(dir_path[7]+"dev.tsv", "w",encoding="utf-8") as f2_transit, \ 75 | open(data_dir + "bert-pair/dev_NLI_M.tsv", "r", encoding="utf-8") as f: 76 | s = f.readline().strip() 77 | s = f.readline().strip() 78 | while s: 79 | count+=1 80 | tmp=s.split("\t") 81 | line=tmp[0]+"\t"+tmp[1]+"\t"+tmp[3]+"\n" 82 | if count<=2988: #loc1 83 | if count%4==1: 84 | f1_general.write(line) 85 | if count%4==2: 86 | f1_price.write(line) 87 | if count%4==3: 88 | f1_safety.write(line) 89 | if count%4==0: 90 | f1_transit.write(line) 91 | else: #loc2 92 | if count%4==1: 93 | f2_general.write(line) 94 | if count%4==2: 95 | f2_price.write(line) 96 | if count%4==3: 97 | f2_safety.write(line) 98 | if count%4==0: 99 | f2_transit.write(line) 100 | s = f.readline().strip() 101 | 102 | count=0 103 | with open(dir_path[0]+"test.tsv","w",encoding="utf-8") as f1_general, \ 104 | open(dir_path[1]+"test.tsv", "w", encoding="utf-8") as f1_price, \ 105 | open(dir_path[2]+"test.tsv", "w", encoding="utf-8") as f1_safety, \ 106 | open(dir_path[3]+"test.tsv", "w", encoding="utf-8") as f1_transit, \ 107 | open(dir_path[4]+"test.tsv", "w", encoding="utf-8") as f2_general, \ 108 | open(dir_path[5]+"test.tsv", "w", encoding="utf-8") as f2_price, \ 109 | open(dir_path[6]+"test.tsv", "w", encoding="utf-8") as f2_safety, \ 110 | open(dir_path[7]+"test.tsv", "w",encoding="utf-8") as f2_transit, \ 111 | open(data_dir + "bert-pair/test_NLI_M.tsv", "r", encoding="utf-8") as f: 112 | s = f.readline().strip() 113 | s = f.readline().strip() 114 | while s: 115 | count+=1 116 | tmp=s.split("\t") 117 | line=tmp[0]+"\t"+tmp[1]+"\t"+tmp[3]+"\n" 118 | if count<=5964: #loc1 119 | if count%4==1: 120 | f1_general.write(line) 121 | if count%4==2: 122 | f1_price.write(line) 123 | if count%4==3: 124 | f1_safety.write(line) 125 | if count%4==0: 126 | f1_transit.write(line) 127 | else: #loc2 128 | if count%4==1: 129 | f2_general.write(line) 130 | if count%4==2: 131 | f2_price.write(line) 132 | if count%4==3: 133 | f2_safety.write(line) 134 | if count%4==0: 135 | f2_transit.write(line) 136 | s = f.readline().strip() 137 | 138 | print("Finished!") -------------------------------------------------------------------------------- /generate/generate_semeval_BERT_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 
data_dir='../data/semeval2014/' 4 | 5 | aspect_name = ['price', 'anecdotes', 'food', 'ambience', 'service'] 6 | dir_path = [data_dir + 'bert-single/' + i + '/' for i in aspect_name] 7 | for path in dir_path: 8 | if not os.path.exists(path): 9 | os.makedirs(path) 10 | 11 | with open(dir_path[0]+"test.csv", "w", encoding="utf-8") as g_price, \ 12 | open(dir_path[1]+"test.csv", "w", encoding="utf-8") as g_anecdotes,\ 13 | open(dir_path[2]+"test.csv", "w", encoding="utf-8") as g_food,\ 14 | open(dir_path[3]+"test.csv", "w", encoding="utf-8") as g_ambience,\ 15 | open(dir_path[4]+"test.csv", "w", encoding="utf-8") as g_service,\ 16 | open(data_dir+"Restaurants_Test_Gold.xml","r",encoding="utf-8") as f: 17 | s=f.readline().strip() 18 | while s: 19 | category=[] 20 | polarity=[] 21 | if "") 24 | id=s[left+4:right-1] 25 | while not "" in s: 26 | if "" in s: 27 | left=s.find("") 28 | right=s.find("") 29 | text=s[left+6:right] 30 | if "aspectCategory" in s: 31 | left=s.find("category=") 32 | right=s.find("polarity=") 33 | category.append(s[left+10:right-2]) 34 | left=s.find("polarity=") 35 | right=s.find("/>") 36 | polarity.append(s[left+10:right-2]) 37 | s=f.readline().strip() 38 | if "price" in category: 39 | g_price.write(id+"\t"+polarity[category.index("price")]+"\t"+"price"+"\t"+text+"\n") 40 | else: 41 | g_price.write(id + "\t" + "none" + "\t" + "price" + "\t" + text + "\n") 42 | if "anecdotes/miscellaneous" in category: 43 | g_anecdotes.write(id+"\t"+polarity[category.index("anecdotes/miscellaneous")]+"\t"+"anecdotes"+"\t"+text+"\n") 44 | else: 45 | g_anecdotes.write(id + "\t" + "none" + "\t" + "anecdotes" + "\t" + text + "\n") 46 | if "food" in category: 47 | g_food.write(id+"\t"+polarity[category.index("food")]+"\t"+"food"+"\t"+text+"\n") 48 | else: 49 | g_food.write(id + "\t" + "none" + "\t" + "food" + "\t" + text + "\n") 50 | if "ambience" in category: 51 | g_ambience.write(id+"\t"+polarity[category.index("ambience")]+"\t"+"ambience"+"\t"+text+"\n") 52 | else: 53 | g_ambience.write(id + "\t" + "none" + "\t" + "ambience" + "\t" + text + "\n") 54 | if "service" in category: 55 | g_service.write(id+"\t"+polarity[category.index("service")]+"\t"+"service"+"\t"+text+"\n") 56 | else: 57 | g_service.write(id + "\t" + "none" + "\t" + "service" + "\t" + text + "\n") 58 | else: 59 | s = f.readline().strip() 60 | 61 | 62 | with open(dir_path[0]+"train.csv", "w", encoding="utf-8") as g_price, \ 63 | open(dir_path[1]+"train.csv", "w", encoding="utf-8") as g_anecdotes,\ 64 | open(dir_path[2]+"train.csv", "w", encoding="utf-8") as g_food,\ 65 | open(dir_path[3]+"train.csv", "w", encoding="utf-8") as g_ambience,\ 66 | open(dir_path[4]+"train.csv", "w", encoding="utf-8") as g_service,\ 67 | open(data_dir+"Restaurants_Train.xml","r",encoding="utf-8") as f: 68 | s=f.readline().strip() 69 | while s: 70 | category=[] 71 | polarity=[] 72 | if "") 75 | id=s[left+4:right-1] 76 | while not "" in s: 77 | if "" in s: 78 | left=s.find("") 79 | right=s.find("") 80 | text=s[left+6:right] 81 | if "aspectCategory" in s: 82 | left=s.find("category=") 83 | right=s.find("polarity=") 84 | category.append(s[left+10:right-2]) 85 | left=s.find("polarity=") 86 | right=s.find("/>") 87 | polarity.append(s[left+10:right-1]) 88 | s=f.readline().strip() 89 | if "price" in category: 90 | g_price.write(id+"\t"+polarity[category.index("price")]+"\t"+"price"+"\t"+text+"\n") 91 | else: 92 | g_price.write(id + "\t" + "none" + "\t" + "price" + "\t" + text + "\n") 93 | if "anecdotes/miscellaneous" in category: 94 | 
g_anecdotes.write(id+"\t"+polarity[category.index("anecdotes/miscellaneous")]+"\t"+"anecdotes"+"\t"+text+"\n") 95 | else: 96 | g_anecdotes.write(id + "\t" + "none" + "\t" + "anecdotes" + "\t" + text + "\n") 97 | if "food" in category: 98 | g_food.write(id+"\t"+polarity[category.index("food")]+"\t"+"food"+"\t"+text+"\n") 99 | else: 100 | g_food.write(id + "\t" + "none" + "\t" + "food" + "\t" + text + "\n") 101 | if "ambience" in category: 102 | g_ambience.write(id+"\t"+polarity[category.index("ambience")]+"\t"+"ambience"+"\t"+text+"\n") 103 | else: 104 | g_ambience.write(id + "\t" + "none" + "\t" + "ambience" + "\t" + text + "\n") 105 | if "service" in category: 106 | g_service.write(id+"\t"+polarity[category.index("service")]+"\t"+"service"+"\t"+text+"\n") 107 | else: 108 | g_service.write(id + "\t" + "none" + "\t" + "service" + "\t" + text + "\n") 109 | else: 110 | s = f.readline().strip() 111 | 112 | print("Finished!") 113 | 114 | -------------------------------------------------------------------------------- /generate/generate_semeval_QA_M.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | data_dir='../data/semeval2014/' 4 | 5 | dir_path = data_dir+'bert-pair/' 6 | if not os.path.exists(dir_path): 7 | os.makedirs(dir_path) 8 | 9 | with open(dir_path+"test_QA_M.csv","w",encoding="utf-8") as g: 10 | with open(data_dir+"Restaurants_Test_Gold.xml","r",encoding="utf-8") as f: 11 | s=f.readline().strip() 12 | while s: 13 | category=[] 14 | polarity=[] 15 | if "") 18 | id=s[left+4:right-1] 19 | while not "" in s: 20 | if "" in s: 21 | left=s.find("") 22 | right=s.find("") 23 | text=s[left+6:right] 24 | if "aspectCategory" in s: 25 | left=s.find("category=") 26 | right=s.find("polarity=") 27 | category.append(s[left+10:right-2]) 28 | left=s.find("polarity=") 29 | right=s.find("/>") 30 | polarity.append(s[left+10:right-2]) 31 | s=f.readline().strip() 32 | if "price" in category: 33 | g.write(id+"\t"+polarity[category.index("price")]+"\t"+"what do you think of the price of it ?"+"\t"+text+"\n") 34 | else: 35 | g.write(id + "\t" + "none" + "\t" + "what do you think of the price of it ?" + "\t" + text + "\n") 36 | if "anecdotes/miscellaneous" in category: 37 | g.write(id+"\t"+polarity[category.index("anecdotes/miscellaneous")]+"\t"+"what do you think of the anecdotes of it ?"+"\t"+text+"\n") 38 | else: 39 | g.write(id + "\t" + "none" + "\t" + "what do you think of the anecdotes of it ?" + "\t" + text + "\n") 40 | if "food" in category: 41 | g.write(id+"\t"+polarity[category.index("food")]+"\t"+"what do you think of the food of it ?"+"\t"+text+"\n") 42 | else: 43 | g.write(id + "\t" + "none" + "\t" + "what do you think of the food of it ?" + "\t" + text + "\n") 44 | if "ambience" in category: 45 | g.write(id+"\t"+polarity[category.index("ambience")]+"\t"+"what do you think of the ambience of it ?"+"\t"+text+"\n") 46 | else: 47 | g.write(id + "\t" + "none" + "\t" + "what do you think of the ambience of it ?" + "\t" + text + "\n") 48 | if "service" in category: 49 | g.write(id+"\t"+polarity[category.index("service")]+"\t"+"what do you think of the service of it ?"+"\t"+text+"\n") 50 | else: 51 | g.write(id + "\t" + "none" + "\t" + "what do you think of the service of it ?" 
+ "\t" + text + "\n") 52 | else: 53 | s = f.readline().strip() 54 | 55 | 56 | with open(dir_path+"train_QA_M.csv","w",encoding="utf-8") as g: 57 | with open(data_dir+"Restaurants_Train.xml","r",encoding="utf-8") as f: 58 | s=f.readline().strip() 59 | while s: 60 | category=[] 61 | polarity=[] 62 | if "") 65 | id=s[left+4:right-1] 66 | while not "" in s: 67 | if "" in s: 68 | left=s.find("") 69 | right=s.find("") 70 | text=s[left+6:right] 71 | if "aspectCategory" in s: 72 | left=s.find("category=") 73 | right=s.find("polarity=") 74 | category.append(s[left+10:right-2]) 75 | left=s.find("polarity=") 76 | right=s.find("/>") 77 | polarity.append(s[left+10:right-1]) 78 | s=f.readline().strip() 79 | if "price" in category: 80 | g.write(id+"\t"+polarity[category.index("price")]+"\t"+"what do you think of the price of it ?"+"\t"+text+"\n") 81 | else: 82 | g.write(id + "\t" + "none" + "\t" + "what do you think of the price of it ?" + "\t" + text + "\n") 83 | if "anecdotes/miscellaneous" in category: 84 | g.write(id+"\t"+polarity[category.index("anecdotes/miscellaneous")]+"\t"+"what do you think of the anecdotes of it ?"+"\t"+text+"\n") 85 | else: 86 | g.write(id + "\t" + "none" + "\t" + "what do you think of the anecdotes of it ?" + "\t" + text + "\n") 87 | if "food" in category: 88 | g.write(id+"\t"+polarity[category.index("food")]+"\t"+"what do you think of the food of it ?"+"\t"+text+"\n") 89 | else: 90 | g.write(id + "\t" + "none" + "\t" + "what do you think of the food of it ?" + "\t" + text + "\n") 91 | if "ambience" in category: 92 | g.write(id+"\t"+polarity[category.index("ambience")]+"\t"+"what do you think of the ambience of it ?"+"\t"+text+"\n") 93 | else: 94 | g.write(id + "\t" + "none" + "\t" + "what do you think of the ambience of it ?" + "\t" + text + "\n") 95 | if "service" in category: 96 | g.write(id+"\t"+polarity[category.index("service")]+"\t"+"what do you think of the service of it ?"+"\t"+text+"\n") 97 | else: 98 | g.write(id + "\t" + "none" + "\t" + "what do you think of the service of it ?" + "\t" + text + "\n") 99 | else: 100 | s = f.readline().strip() -------------------------------------------------------------------------------- /optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # Reference: https://github.com/huggingface/pytorch-pretrained-BERT 4 | 5 | """PyTorch optimization for BERT model.""" 6 | 7 | import math 8 | 9 | import torch 10 | from torch.nn.utils import clip_grad_norm_ 11 | from torch.optim import Optimizer 12 | 13 | 14 | def warmup_cosine(x, warmup=0.002): 15 | if x < warmup: 16 | return x/warmup 17 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 18 | 19 | def warmup_constant(x, warmup=0.002): 20 | if x < warmup: 21 | return x/warmup 22 | return 1.0 23 | 24 | def warmup_linear(x, warmup=0.002): 25 | if x < warmup: 26 | return x/warmup 27 | return 1.0 - x 28 | 29 | SCHEDULES = { 30 | 'warmup_cosine':warmup_cosine, 31 | 'warmup_constant':warmup_constant, 32 | 'warmup_linear':warmup_linear, 33 | } 34 | 35 | 36 | class BERTAdam(Optimizer): 37 | """Implements BERT version of Adam algorithm with weight decay fix (and no ). 38 | Params: 39 | lr: learning rate 40 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 41 | t_total: total number of training steps for the learning 42 | rate schedule, -1 means constant learning rate. Default: -1 43 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 44 | b1: Adams b1. 
Default: 0.9 45 | b2: Adams b2. Default: 0.999 46 | e: Adams epsilon. Default: 1e-6 47 | weight_decay_rate: Weight decay. Default: 0.01 48 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0 49 | """ 50 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear', 51 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01, 52 | max_grad_norm=1.0): 53 | if not lr >= 0.0: 54 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 55 | if schedule not in SCHEDULES: 56 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 57 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 58 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 59 | if not 0.0 <= b1 < 1.0: 60 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 61 | if not 0.0 <= b2 < 1.0: 62 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 63 | if not e >= 0.0: 64 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 65 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 66 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate, 67 | max_grad_norm=max_grad_norm) 68 | super(BERTAdam, self).__init__(params, defaults) 69 | 70 | def get_lr(self): 71 | lr = [] 72 | print("l_total=",len(self.param_groups)) 73 | for group in self.param_groups: 74 | print("l_p=",len(group['params'])) 75 | for p in group['params']: 76 | state = self.state[p] 77 | if len(state) == 0: 78 | return [0] 79 | if group['t_total'] != -1: 80 | schedule_fct = SCHEDULES[group['schedule']] 81 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 82 | else: 83 | lr_scheduled = group['lr'] 84 | lr.append(lr_scheduled) 85 | return lr 86 | 87 | def to(self, device): 88 | """ Move the optimizer state to a specified device""" 89 | for state in self.state.values(): 90 | state['exp_avg'].to(device) 91 | state['exp_avg_sq'].to(device) 92 | 93 | def initialize_step(self, initial_step): 94 | """Initialize state with a defined step (but we don't have stored averaged). 95 | Arguments: 96 | initial_step (int): Initial step number. 97 | """ 98 | for group in self.param_groups: 99 | for p in group['params']: 100 | state = self.state[p] 101 | # State initialization 102 | state['step'] = initial_step 103 | # Exponential moving average of gradient values 104 | state['exp_avg'] = torch.zeros_like(p.data) 105 | # Exponential moving average of squared gradient values 106 | state['exp_avg_sq'] = torch.zeros_like(p.data) 107 | 108 | def step(self, closure=None): 109 | """Performs a single optimization step. 110 | 111 | Arguments: 112 | closure (callable, optional): A closure that reevaluates the model 113 | and returns the loss. 
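        Note: the effective learning rate follows the parameter group's warmup schedule (see
        SCHEDULES above), and weight decay is applied directly to the update rather than via the loss.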
114 | """ 115 | loss = None 116 | if closure is not None: 117 | loss = closure() 118 | 119 | for group in self.param_groups: 120 | for p in group['params']: 121 | if p.grad is None: 122 | continue 123 | grad = p.grad.data 124 | if grad.is_sparse: 125 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 126 | 127 | state = self.state[p] 128 | 129 | # State initialization 130 | if len(state) == 0: 131 | state['step'] = 0 132 | # Exponential moving average of gradient values 133 | state['next_m'] = torch.zeros_like(p.data) 134 | # Exponential moving average of squared gradient values 135 | state['next_v'] = torch.zeros_like(p.data) 136 | 137 | next_m, next_v = state['next_m'], state['next_v'] 138 | beta1, beta2 = group['b1'], group['b2'] 139 | 140 | # Add grad clipping 141 | if group['max_grad_norm'] > 0: 142 | clip_grad_norm_(p, group['max_grad_norm']) 143 | 144 | # Decay the first and second moment running average coefficient 145 | # In-place operations to update the averages at the same time 146 | next_m.mul_(beta1).add_(1 - beta1, grad) 147 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 148 | update = next_m / (next_v.sqrt() + group['e']) 149 | 150 | # Just adding the square of the weights to the loss function is *not* 151 | # the correct way of using L2 regularization/weight decay with Adam, 152 | # since that will interact with the m and v parameters in strange ways. 153 | # 154 | # Instead we want ot decay the weights in a manner that doesn't interact 155 | # with the m/v parameters. This is equivalent to adding the square 156 | # of the weights to the loss with plain (non-momentum) SGD. 157 | if group['weight_decay_rate'] > 0.0: 158 | update += group['weight_decay_rate'] * p.data 159 | 160 | if group['t_total'] != -1: 161 | schedule_fct = SCHEDULES[group['schedule']] 162 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 163 | else: 164 | lr_scheduled = group['lr'] 165 | 166 | update_with_lr = lr_scheduled * update 167 | p.data.add_(-update_with_lr) 168 | 169 | state['step'] += 1 170 | 171 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 172 | # bias_correction1 = 1 - beta1 ** state['step'] 173 | # bias_correction2 = 1 - beta2 ** state['step'] 174 | 175 | return loss 176 | -------------------------------------------------------------------------------- /generate/generate_sentihood_NLI_B_QA_B.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data_utils_sentihood import * 4 | 5 | data_dir='../data/sentihood/' 6 | aspect2idx = { 7 | 'general': 0, 8 | 'price': 1, 9 | 'transit-location': 2, 10 | 'safety': 3, 11 | } 12 | 13 | (train, train_aspect_idx), (val, val_aspect_idx), (test, test_aspect_idx) = load_task(data_dir, aspect2idx) 14 | 15 | print("len(train) = ", len(train)) 16 | print("len(val) = ", len(val)) 17 | print("len(test) = ", len(test)) 18 | 19 | train.sort(key=lambda x:x[2]+str(x[0])+x[3][0]) 20 | val.sort(key=lambda x:x[2]+str(x[0])+x[3][0]) 21 | test.sort(key=lambda x:x[2]+str(x[0])+x[3][0]) 22 | 23 | dir_path = data_dir+'bert-pair/' 24 | if not os.path.exists(dir_path): 25 | os.makedirs(dir_path) 26 | 27 | sentiments=["None","Positive","Negative"] 28 | with open(dir_path+"train_NLI_B.tsv","w",encoding="utf-8") as f: 29 | f.write("id\tsentence1\tsentence2\tlabel\n") 30 | for v in train: 31 | for sentiment in sentiments: 32 | f.write(str(v[0])+"\t") 33 | word=v[1][0].lower() 34 | if 
word=='location1':f.write('location - 1') 35 | elif word=='location2':f.write('location - 2') 36 | elif word[0]=='\'':f.write("\' "+word[1:]) 37 | else:f.write(word) 38 | for i in range(1,len(v[1])): 39 | word=v[1][i].lower() 40 | f.write(" ") 41 | if word == 'location1': 42 | f.write('location - 1') 43 | elif word == 'location2': 44 | f.write('location - 2') 45 | elif word[0] == '\'': 46 | f.write("\' " + word[1:]) 47 | else: 48 | f.write(word) 49 | f.write("\t") 50 | f.write(sentiment+" - ") 51 | if v[2]=='LOCATION1':f.write('location - 1 - ') 52 | if v[2]=='LOCATION2':f.write('location - 2 - ') 53 | if len(v[3])==1: 54 | f.write(v[3][0]+"\t") 55 | else: 56 | f.write("transit location\t") 57 | if v[4]==sentiment: 58 | f.write("1\n") 59 | else: 60 | f.write("0\n") 61 | 62 | with open(dir_path+"train_QA_B.tsv","w",encoding="utf-8") as f: 63 | f.write("id\tsentence1\tsentence2\tlabel\n") 64 | for v in train: 65 | for sentiment in sentiments: 66 | f.write(str(v[0])+"\t") 67 | word=v[1][0].lower() 68 | if word=='location1':f.write('location - 1') 69 | elif word=='location2':f.write('location - 2') 70 | elif word[0]=='\'':f.write("\' "+word[1:]) 71 | else:f.write(word) 72 | for i in range(1,len(v[1])): 73 | word=v[1][i].lower() 74 | f.write(" ") 75 | if word == 'location1': 76 | f.write('location - 1') 77 | elif word == 'location2': 78 | f.write('location - 2') 79 | elif word[0] == '\'': 80 | f.write("\' " + word[1:]) 81 | else: 82 | f.write(word) 83 | f.write("\t") 84 | f.write("the polarity of the aspect ") 85 | if len(v[3])==1: 86 | f.write(v[3][0]) 87 | else: 88 | f.write("transit location") 89 | if v[2]=='LOCATION1':f.write(' of location - 1 - is ') 90 | if v[2]=='LOCATION2':f.write(' of location - 2 - is ') 91 | f.write(sentiment+" .\t") 92 | if v[4]==sentiment: 93 | f.write("1\n") 94 | else: 95 | f.write("0\n") 96 | 97 | with open(dir_path+"dev_NLI_B.tsv","w",encoding="utf-8") as f: 98 | f.write("id\tsentence1\tsentence2\tlabel\n") 99 | for v in val: 100 | for sentiment in sentiments: 101 | f.write(str(v[0])+"\t") 102 | word=v[1][0].lower() 103 | if word=='location1':f.write('location - 1') 104 | elif word=='location2':f.write('location - 2') 105 | elif word[0]=='\'':f.write("\' "+word[1:]) 106 | else:f.write(word) 107 | for i in range(1,len(v[1])): 108 | word=v[1][i].lower() 109 | f.write(" ") 110 | if word == 'location1': 111 | f.write('location - 1') 112 | elif word == 'location2': 113 | f.write('location - 2') 114 | elif word[0] == '\'': 115 | f.write("\' " + word[1:]) 116 | else: 117 | f.write(word) 118 | f.write("\t") 119 | f.write(sentiment+" - ") 120 | if v[2]=='LOCATION1':f.write('location - 1 - ') 121 | if v[2]=='LOCATION2':f.write('location - 2 - ') 122 | if len(v[3])==1: 123 | f.write(v[3][0]+"\t") 124 | else: 125 | f.write("transit location\t") 126 | if v[4]==sentiment: 127 | f.write("1\n") 128 | else: 129 | f.write("0\n") 130 | 131 | with open(dir_path+"dev_QA_B.tsv","w",encoding="utf-8") as f: 132 | f.write("id\tsentence1\tsentence2\tlabel\n") 133 | for v in val: 134 | for sentiment in sentiments: 135 | f.write(str(v[0])+"\t") 136 | word=v[1][0].lower() 137 | if word=='location1':f.write('location - 1') 138 | elif word=='location2':f.write('location - 2') 139 | elif word[0]=='\'':f.write("\' "+word[1:]) 140 | else:f.write(word) 141 | for i in range(1,len(v[1])): 142 | word=v[1][i].lower() 143 | f.write(" ") 144 | if word == 'location1': 145 | f.write('location - 1') 146 | elif word == 'location2': 147 | f.write('location - 2') 148 | elif word[0] == '\'': 149 | f.write("\' 
" + word[1:]) 150 | else: 151 | f.write(word) 152 | f.write("\t") 153 | f.write("the polarity of the aspect ") 154 | if len(v[3])==1: 155 | f.write(v[3][0]) 156 | else: 157 | f.write("transit location") 158 | if v[2]=='LOCATION1':f.write(' of location - 1 - is ') 159 | if v[2]=='LOCATION2':f.write(' of location - 2 - is ') 160 | f.write(sentiment+" .\t") 161 | if v[4]==sentiment: 162 | f.write("1\n") 163 | else: 164 | f.write("0\n") 165 | 166 | with open(dir_path+"test_NLI_B.tsv","w",encoding="utf-8") as f: 167 | f.write("id\tsentence1\tsentence2\tlabel\n") 168 | for v in test: 169 | for sentiment in sentiments: 170 | f.write(str(v[0])+"\t") 171 | word=v[1][0].lower() 172 | if word=='location1':f.write('location - 1') 173 | elif word=='location2':f.write('location - 2') 174 | elif word[0]=='\'':f.write("\' "+word[1:]) 175 | else:f.write(word) 176 | for i in range(1,len(v[1])): 177 | word=v[1][i].lower() 178 | f.write(" ") 179 | if word == 'location1': 180 | f.write('location - 1') 181 | elif word == 'location2': 182 | f.write('location - 2') 183 | elif word[0] == '\'': 184 | f.write("\' " + word[1:]) 185 | else: 186 | f.write(word) 187 | f.write("\t") 188 | f.write(sentiment + " - ") 189 | if v[2]=='LOCATION1':f.write('location - 1 - ') 190 | if v[2]=='LOCATION2':f.write('location - 2 - ') 191 | if len(v[3])==1: 192 | f.write(v[3][0]+"\t") 193 | else: 194 | f.write("transit location\t") 195 | if v[4]==sentiment: 196 | f.write("1\n") 197 | else: 198 | f.write("0\n") 199 | 200 | with open(dir_path+"test_QA_B.tsv","w",encoding="utf-8") as f: 201 | f.write("id\tsentence1\tsentence2\tlabel\n") 202 | for v in test: 203 | for sentiment in sentiments: 204 | f.write(str(v[0])+"\t") 205 | word=v[1][0].lower() 206 | if word=='location1':f.write('location - 1') 207 | elif word=='location2':f.write('location - 2') 208 | elif word[0]=='\'':f.write("\' "+word[1:]) 209 | else:f.write(word) 210 | for i in range(1,len(v[1])): 211 | word=v[1][i].lower() 212 | f.write(" ") 213 | if word == 'location1': 214 | f.write('location - 1') 215 | elif word == 'location2': 216 | f.write('location - 2') 217 | elif word[0] == '\'': 218 | f.write("\' " + word[1:]) 219 | else: 220 | f.write(word) 221 | f.write("\t") 222 | f.write("the polarity of the aspect ") 223 | if len(v[3])==1: 224 | f.write(v[3][0]) 225 | else: 226 | f.write("transit location") 227 | if v[2]=='LOCATION1':f.write(' of location - 1 - is ') 228 | if v[2]=='LOCATION2':f.write(' of location - 2 - is ') 229 | f.write(sentiment+" .\t") 230 | if v[4]==sentiment: 231 | f.write("1\n") 232 | else: 233 | f.write("0\n") -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # Reference: https://github.com/huggingface/pytorch-pretrained-BERT 4 | 5 | """Tokenization classes.""" 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import collections 10 | import unicodedata 11 | 12 | import six 13 | 14 | 15 | def convert_to_unicode(text): 16 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 17 | if six.PY3: 18 | if isinstance(text, str): 19 | return text 20 | elif isinstance(text, bytes): 21 | return text.decode("utf-8", "ignore") 22 | else: 23 | raise ValueError("Unsupported string type: %s" % (type(text))) 24 | elif six.PY2: 25 | if isinstance(text, str): 26 | return text.decode("utf-8", "ignore") 27 | elif isinstance(text, unicode): 28 | return text 29 | 
else: 30 | raise ValueError("Unsupported string type: %s" % (type(text))) 31 | else: 32 | raise ValueError("Not running on Python2 or Python 3?") 33 | 34 | 35 | def printable_text(text): 36 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 37 | 38 | # These functions want `str` for both Python2 and Python3, but in one case 39 | # it's a Unicode string and in the other it's a byte string. 40 | if six.PY3: 41 | if isinstance(text, str): 42 | return text 43 | elif isinstance(text, bytes): 44 | return text.decode("utf-8", "ignore") 45 | else: 46 | raise ValueError("Unsupported string type: %s" % (type(text))) 47 | elif six.PY2: 48 | if isinstance(text, str): 49 | return text 50 | elif isinstance(text, unicode): 51 | return text.encode("utf-8") 52 | else: 53 | raise ValueError("Unsupported string type: %s" % (type(text))) 54 | else: 55 | raise ValueError("Not running on Python2 or Python 3?") 56 | 57 | 58 | def load_vocab(vocab_file): 59 | """Loads a vocabulary file into a dictionary.""" 60 | vocab = collections.OrderedDict() 61 | index = 0 62 | with open(vocab_file, "r") as reader: 63 | while True: 64 | token = convert_to_unicode(reader.readline()) 65 | if not token: 66 | break 67 | token = token.strip() 68 | vocab[token] = index 69 | index += 1 70 | return vocab 71 | 72 | 73 | def convert_tokens_to_ids(vocab, tokens): 74 | """Converts a sequence of tokens into ids using the vocab.""" 75 | ids = [] 76 | for token in tokens: 77 | ids.append(vocab[token]) 78 | return ids 79 | 80 | 81 | def whitespace_tokenize(text): 82 | """Runs basic whitespace cleaning and splitting on a peice of text.""" 83 | text = text.strip() 84 | if not text: 85 | return [] 86 | tokens = text.split() 87 | return tokens 88 | 89 | 90 | class FullTokenizer(object): 91 | """Runs end-to-end tokenziation.""" 92 | 93 | def __init__(self, vocab_file, do_lower_case=True): 94 | self.vocab = load_vocab(vocab_file) 95 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 96 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 97 | 98 | def tokenize(self, text): 99 | split_tokens = [] 100 | for token in self.basic_tokenizer.tokenize(text): 101 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 102 | split_tokens.append(sub_token) 103 | 104 | return split_tokens 105 | 106 | def convert_tokens_to_ids(self, tokens): 107 | return convert_tokens_to_ids(self.vocab, tokens) 108 | 109 | 110 | class BasicTokenizer(object): 111 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 112 | 113 | def __init__(self, do_lower_case=True): 114 | """Constructs a BasicTokenizer. 115 | 116 | Args: 117 | do_lower_case: Whether to lower case the input. 
118 | """ 119 | self.do_lower_case = do_lower_case 120 | 121 | def tokenize(self, text): 122 | """Tokenizes a piece of text.""" 123 | text = convert_to_unicode(text) 124 | text = self._clean_text(text) 125 | orig_tokens = whitespace_tokenize(text) 126 | split_tokens = [] 127 | for token in orig_tokens: 128 | if self.do_lower_case: 129 | token = token.lower() 130 | token = self._run_strip_accents(token) 131 | split_tokens.extend(self._run_split_on_punc(token)) 132 | 133 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 134 | return output_tokens 135 | 136 | def _run_strip_accents(self, text): 137 | """Strips accents from a piece of text.""" 138 | text = unicodedata.normalize("NFD", text) 139 | output = [] 140 | for char in text: 141 | cat = unicodedata.category(char) 142 | if cat == "Mn": 143 | continue 144 | output.append(char) 145 | return "".join(output) 146 | 147 | def _run_split_on_punc(self, text): 148 | """Splits punctuation on a piece of text.""" 149 | chars = list(text) 150 | i = 0 151 | start_new_word = True 152 | output = [] 153 | while i < len(chars): 154 | char = chars[i] 155 | if _is_punctuation(char): 156 | output.append([char]) 157 | start_new_word = True 158 | else: 159 | if start_new_word: 160 | output.append([]) 161 | start_new_word = False 162 | output[-1].append(char) 163 | i += 1 164 | 165 | return ["".join(x) for x in output] 166 | 167 | def _clean_text(self, text): 168 | """Performs invalid character removal and whitespace cleanup on text.""" 169 | output = [] 170 | for char in text: 171 | cp = ord(char) 172 | if cp == 0 or cp == 0xfffd or _is_control(char): 173 | continue 174 | if _is_whitespace(char): 175 | output.append(" ") 176 | else: 177 | output.append(char) 178 | return "".join(output) 179 | 180 | 181 | class WordpieceTokenizer(object): 182 | """Runs WordPiece tokenization.""" 183 | 184 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): 185 | self.vocab = vocab 186 | self.unk_token = unk_token 187 | self.max_input_chars_per_word = max_input_chars_per_word 188 | 189 | def tokenize(self, text): 190 | """Tokenizes a piece of text into its word pieces. 191 | 192 | This uses a greedy longest-match-first algorithm to perform tokenization 193 | using the given vocabulary. 194 | 195 | For example: 196 | input = "unaffable" 197 | output = ["un", "##aff", "##able"] 198 | 199 | Args: 200 | text: A single token or whitespace separated tokens. This should have 201 | already been passed through `BasicTokenizer. 202 | 203 | Returns: 204 | A list of wordpiece tokens. 
205 | """ 206 | 207 | text = convert_to_unicode(text) 208 | 209 | output_tokens = [] 210 | for token in whitespace_tokenize(text): 211 | chars = list(token) 212 | if len(chars) > self.max_input_chars_per_word: 213 | output_tokens.append(self.unk_token) 214 | continue 215 | 216 | is_bad = False 217 | start = 0 218 | sub_tokens = [] 219 | while start < len(chars): 220 | end = len(chars) 221 | cur_substr = None 222 | while start < end: 223 | substr = "".join(chars[start:end]) 224 | if start > 0: 225 | substr = "##" + substr 226 | if substr in self.vocab: 227 | cur_substr = substr 228 | break 229 | end -= 1 230 | if cur_substr is None: 231 | is_bad = True 232 | break 233 | sub_tokens.append(cur_substr) 234 | start = end 235 | 236 | if is_bad: 237 | output_tokens.append(self.unk_token) 238 | else: 239 | output_tokens.extend(sub_tokens) 240 | return output_tokens 241 | 242 | 243 | def _is_whitespace(char): 244 | """Checks whether `chars` is a whitespace character.""" 245 | # \t, \n, and \r are technically contorl characters but we treat them 246 | # as whitespace since they are generally considered as such. 247 | if char == " " or char == "\t" or char == "\n" or char == "\r": 248 | return True 249 | cat = unicodedata.category(char) 250 | if cat == "Zs": 251 | return True 252 | return False 253 | 254 | 255 | def _is_control(char): 256 | """Checks whether `chars` is a control character.""" 257 | # These are technically control characters but we count them as whitespace 258 | # characters. 259 | if char == "\t" or char == "\n" or char == "\r": 260 | return False 261 | cat = unicodedata.category(char) 262 | if cat.startswith("C"): 263 | return True 264 | return False 265 | 266 | 267 | def _is_punctuation(char): 268 | """Checks whether `chars` is a punctuation character.""" 269 | cp = ord(char) 270 | # We treat all non-letter/number ASCII as punctuation. 271 | # Characters such as "^", "$", and "`" are not in the Unicode 272 | # Punctuation class but we treat them as punctuation anyways, for 273 | # consistency. 274 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 275 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 276 | return True 277 | cat = unicodedata.category(char) 278 | if cat.startswith("P"): 279 | return True 280 | return False 281 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from sklearn import metrics 7 | from sklearn.preprocessing import label_binarize 8 | 9 | 10 | def get_y_true(task_name): 11 | """ 12 | Read file to obtain y_true. 13 | All of five tasks of Sentihood use the test set of task-BERT-pair-NLI-M to get true labels. 14 | All of five tasks of SemEval-2014 use the test set of task-BERT-pair-NLI-M to get true labels. 15 | """ 16 | if task_name in ["sentihood_single", "sentihood_NLI_M", "sentihood_QA_M", "sentihood_NLI_B", "sentihood_QA_B"]: 17 | true_data_file = "data/sentihood/bert-pair/test_NLI_M.tsv" 18 | 19 | df = pd.read_csv(true_data_file,sep='\t') 20 | y_true = [] 21 | for i in range(len(df)): 22 | label = df['label'][i] 23 | assert label in ['None', 'Positive', 'Negative'], "error!" 
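A minimal usage sketch for the WordPiece pipeline defined in tokenization.py above (FullTokenizer chains BasicTokenizer and WordpieceTokenizer). The tiny vocabulary and the temporary file are invented for illustration only; real experiments point vocab_file at the vocab.txt shipped with the pre-trained BERT checkpoint (passed to run_classifier_TABSA.py as --vocab_file).

```python
# Small self-contained check of FullTokenizer from tokenization.py above.
# The vocabulary below is made up for the sketch.
import tempfile

import tokenization

vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "un", "##aff", "##able", "the", "dog"]
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
    f.write("\n".join(vocab_tokens) + "\n")
    vocab_path = f.name

tokenizer = tokenization.FullTokenizer(vocab_file=vocab_path, do_lower_case=True)
tokens = tokenizer.tokenize("The dog is unaffable")
print(tokens)                                   # ['the', 'dog', '[UNK]', 'un', '##aff', '##able']
print(tokenizer.convert_tokens_to_ids(tokens))  # [6, 7, 0, 3, 4, 5]
```

"is" is not in the toy vocabulary, so WordpieceTokenizer falls back to the [UNK] token, while "unaffable" is split greedily into the longest in-vocabulary pieces, matching the example given in the WordpieceTokenizer docstring.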
24 | if label == 'None': 25 | n = 0 26 | elif label == 'Positive': 27 | n = 1 28 | else: 29 | n = 2 30 | y_true.append(n) 31 | else: 32 | true_data_file = "data/semeval2014/bert-pair/test_NLI_M.csv" 33 | 34 | df = pd.read_csv(true_data_file,sep='\t',header=None).values 35 | y_true=[] 36 | for i in range(len(df)): 37 | label = df[i][1] 38 | assert label in ['positive', 'neutral', 'negative', 'conflict', 'none'], "error!" 39 | if label == 'positive': 40 | n = 0 41 | elif label == 'neutral': 42 | n = 1 43 | elif label == 'negative': 44 | n = 2 45 | elif label == 'conflict': 46 | n = 3 47 | elif label == 'none': 48 | n = 4 49 | y_true.append(n) 50 | 51 | return y_true 52 | 53 | 54 | def get_y_pred(task_name, pred_data_dir): 55 | """ 56 | Read file to obtain y_pred and scores. 57 | """ 58 | pred=[] 59 | score=[] 60 | if task_name in ["sentihood_NLI_M", "sentihood_QA_M"]: 61 | with open(pred_data_dir, "r", encoding="utf-8") as f: 62 | s=f.readline().strip().split() 63 | while s: 64 | pred.append(int(s[0])) 65 | score.append([float(s[1]),float(s[2]),float(s[3])]) 66 | s = f.readline().strip().split() 67 | elif task_name in ["sentihood_NLI_B", "sentihood_QA_B"]: 68 | count = 0 69 | tmp = [] 70 | with open(pred_data_dir, "r", encoding="utf-8") as f: 71 | s = f.readline().strip().split() 72 | while s: 73 | tmp.append([float(s[2])]) 74 | count += 1 75 | if count % 3 == 0: 76 | tmp_sum = np.sum(tmp) 77 | t = [] 78 | for i in range(3): 79 | t.append(tmp[i] / tmp_sum) 80 | score.append(t) 81 | if t[0] >= t[1] and t[0] >= t[2]: 82 | pred.append(0) 83 | elif t[1] >= t[0] and t[1] >= t[2]: 84 | pred.append(1) 85 | else: 86 | pred.append(2) 87 | tmp = [] 88 | s = f.readline().strip().split() 89 | elif task_name == "sentihood_single": 90 | count = 0 91 | with open(pred_data_dir + "loc1_general.txt", "r", encoding="utf-8") as f1_general, \ 92 | open(pred_data_dir + "loc1_price.txt", "r", encoding="utf-8") as f1_price, \ 93 | open(pred_data_dir + "loc1_safety.txt", "r", encoding="utf-8") as f1_safety, \ 94 | open(pred_data_dir + "loc1_transit.txt", "r", encoding="utf-8") as f1_transit: 95 | s = f1_general.readline().strip().split() 96 | while s: 97 | count += 1 98 | pred.append(int(s[0])) 99 | score.append([float(s[1]), float(s[2]), float(s[3])]) 100 | if count % 4 == 0: 101 | s = f1_general.readline().strip().split() 102 | if count % 4 == 1: 103 | s = f1_price.readline().strip().split() 104 | if count % 4 == 2: 105 | s = f1_safety.readline().strip().split() 106 | if count % 4 == 3: 107 | s = f1_transit.readline().strip().split() 108 | 109 | with open(pred_data_dir + "loc2_general.txt", "r", encoding="utf-8") as f2_general, \ 110 | open(pred_data_dir + "loc2_price.txt", "r", encoding="utf-8") as f2_price, \ 111 | open(pred_data_dir + "loc2_safety.txt", "r", encoding="utf-8") as f2_safety, \ 112 | open(pred_data_dir + "loc2_transit.txt", "r", encoding="utf-8") as f2_transit: 113 | s = f2_general.readline().strip().split() 114 | while s: 115 | count += 1 116 | pred.append(int(s[0])) 117 | score.append([float(s[1]), float(s[2]), float(s[3])]) 118 | if count % 4 == 0: 119 | s = f2_general.readline().strip().split() 120 | if count % 4 == 1: 121 | s = f2_price.readline().strip().split() 122 | if count % 4 == 2: 123 | s = f2_safety.readline().strip().split() 124 | if count % 4 == 3: 125 | s = f2_transit.readline().strip().split() 126 | elif task_name in ["semeval_NLI_M", "semeval_QA_M"]: 127 | with open(pred_data_dir,"r",encoding="utf-8") as f: 128 | s=f.readline().strip().split() 129 | while s: 130 | 
pred.append(int(s[0])) 131 | score.append([float(s[1]), float(s[2]), float(s[3]), float(s[4]), float(s[5])]) 132 | s = f.readline().strip().split() 133 | elif task_name in ["semeval_NLI_B", "semeval_QA_B"]: 134 | count = 0 135 | tmp = [] 136 | with open(pred_data_dir, "r", encoding="utf-8") as f: 137 | s = f.readline().strip().split() 138 | while s: 139 | tmp.append([float(s[2])]) 140 | count += 1 141 | if count % 5 == 0: 142 | tmp_sum = np.sum(tmp) 143 | t = [] 144 | for i in range(5): 145 | t.append(tmp[i] / tmp_sum) 146 | score.append(t) 147 | if t[0] >= t[1] and t[0] >= t[2] and t[0]>=t[3] and t[0]>=t[4]: 148 | pred.append(0) 149 | elif t[1] >= t[0] and t[1] >= t[2] and t[1]>=t[3] and t[1]>=t[4]: 150 | pred.append(1) 151 | elif t[2] >= t[0] and t[2] >= t[1] and t[2]>=t[3] and t[2]>=t[4]: 152 | pred.append(2) 153 | elif t[3] >= t[0] and t[3] >= t[1] and t[3]>=t[2] and t[3]>=t[4]: 154 | pred.append(3) 155 | else: 156 | pred.append(4) 157 | tmp = [] 158 | s = f.readline().strip().split() 159 | else: 160 | count = 0 161 | with open(pred_data_dir+"price.txt","r",encoding="utf-8") as f_price, \ 162 | open(pred_data_dir+"anecdotes.txt", "r", encoding="utf-8") as f_anecdotes, \ 163 | open(pred_data_dir+"food.txt", "r", encoding="utf-8") as f_food, \ 164 | open(pred_data_dir+"ambience.txt", "r", encoding="utf-8") as f_ambience, \ 165 | open(pred_data_dir+"service.txt", "r", encoding="utf-8") as f_service: 166 | s = f_price.readline().strip().split() 167 | while s: 168 | count += 1 169 | pred.append(int(s[0])) 170 | score.append([float(s[1]), float(s[2]), float(s[3]), float(s[4]), float(s[5])]) 171 | if count % 5 == 0: 172 | s = f_price.readline().strip().split() 173 | if count % 5 == 1: 174 | s = f_anecdotes.readline().strip().split() 175 | if count % 5 == 2: 176 | s = f_food.readline().strip().split() 177 | if count % 5 == 3: 178 | s = f_ambience.readline().strip().split() 179 | if count % 5 == 4: 180 | s = f_service.readline().strip().split() 181 | 182 | return pred, score 183 | 184 | 185 | def sentihood_strict_acc(y_true, y_pred): 186 | """ 187 | Calculate "strict Acc" of aspect detection task of Sentihood. 188 | """ 189 | total_cases=int(len(y_true)/4) 190 | true_cases=0 191 | for i in range(total_cases): 192 | if y_true[i*4]!=y_pred[i*4]:continue 193 | if y_true[i*4+1]!=y_pred[i*4+1]:continue 194 | if y_true[i*4+2]!=y_pred[i*4+2]:continue 195 | if y_true[i*4+3]!=y_pred[i*4+3]:continue 196 | true_cases+=1 197 | aspect_strict_Acc = true_cases/total_cases 198 | 199 | return aspect_strict_Acc 200 | 201 | 202 | def sentihood_macro_F1(y_true, y_pred): 203 | """ 204 | Calculate "Macro-F1" of aspect detection task of Sentihood. 205 | """ 206 | p_all=0 207 | r_all=0 208 | count=0 209 | for i in range(len(y_pred)//4): 210 | a=set() 211 | b=set() 212 | for j in range(4): 213 | if y_pred[i*4+j]!=0: 214 | a.add(j) 215 | if y_true[i*4+j]!=0: 216 | b.add(j) 217 | if len(b)==0:continue 218 | a_b=a.intersection(b) 219 | if len(a_b)>0: 220 | p=len(a_b)/len(a) 221 | r=len(a_b)/len(b) 222 | else: 223 | p=0 224 | r=0 225 | count+=1 226 | p_all+=p 227 | r_all+=r 228 | Ma_p=p_all/count 229 | Ma_r=r_all/count 230 | aspect_Macro_F1 = 2*Ma_p*Ma_r/(Ma_p+Ma_r) 231 | 232 | return aspect_Macro_F1 233 | 234 | 235 | def sentihood_AUC_Acc(y_true, score): 236 | """ 237 | Calculate "Macro-AUC" of both aspect detection and sentiment classification tasks of Sentihood. 238 | Calculate "Acc" of sentiment classification task of Sentihood. 
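A small worked example for the Sentihood aspect-detection metrics defined just above (sentihood_strict_acc and sentihood_macro_F1). The label vectors are invented; running it assumes the repository's dependencies (numpy, pandas, scikit-learn) are installed, since evaluation.py imports them at module level.

```python
# Worked example for sentihood_strict_acc and sentihood_macro_F1 above; labels are invented.
# Predictions come in blocks of four aspects per (sentence, target) pair,
# with 0 = "None", 1 = "Positive", 2 = "Negative".
from evaluation import sentihood_strict_acc, sentihood_macro_F1

y_true = [1, 0, 0, 0,   0, 2, 0, 0]   # two blocks, one gold aspect in each
y_pred = [1, 0, 0, 0,   0, 2, 0, 2]   # second block adds one spurious aspect prediction

print(sentihood_strict_acc(y_true, y_pred))   # 0.5   -> only the first block matches on all four aspects
print(sentihood_macro_F1(y_true, y_pred))     # ~0.857 -> per-block P and R are averaged, then combined
```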
239 | """ 240 | # aspect-Macro-AUC 241 | aspect_y_true=[] 242 | aspect_y_score=[] 243 | aspect_y_trues=[[],[],[],[]] 244 | aspect_y_scores=[[],[],[],[]] 245 | for i in range(len(y_true)): 246 | if y_true[i]>0: 247 | aspect_y_true.append(0) 248 | else: 249 | aspect_y_true.append(1) # "None": 1 250 | tmp_score=score[i][0] # probability of "None" 251 | aspect_y_score.append(tmp_score) 252 | aspect_y_trues[i%4].append(aspect_y_true[-1]) 253 | aspect_y_scores[i%4].append(aspect_y_score[-1]) 254 | 255 | aspect_auc=[] 256 | for i in range(4): 257 | aspect_auc.append(metrics.roc_auc_score(aspect_y_trues[i], aspect_y_scores[i])) 258 | aspect_Macro_AUC = np.mean(aspect_auc) 259 | 260 | # sentiment-Macro-AUC 261 | sentiment_y_true=[] 262 | sentiment_y_pred=[] 263 | sentiment_y_score=[] 264 | sentiment_y_trues=[[],[],[],[]] 265 | sentiment_y_scores=[[],[],[],[]] 266 | for i in range(len(y_true)): 267 | if y_true[i]>0: 268 | sentiment_y_true.append(y_true[i]-1) # "Postive":0, "Negative":1 269 | tmp_score=score[i][2]/(score[i][1]+score[i][2]) # probability of "Negative" 270 | sentiment_y_score.append(tmp_score) 271 | if tmp_score>0.5: 272 | sentiment_y_pred.append(1) # "Negative": 1 273 | else: 274 | sentiment_y_pred.append(0) 275 | sentiment_y_trues[i%4].append(sentiment_y_true[-1]) 276 | sentiment_y_scores[i%4].append(sentiment_y_score[-1]) 277 | 278 | sentiment_auc=[] 279 | for i in range(4): 280 | sentiment_auc.append(metrics.roc_auc_score(sentiment_y_trues[i], sentiment_y_scores[i])) 281 | sentiment_Macro_AUC = np.mean(sentiment_auc) 282 | 283 | # sentiment Acc 284 | sentiment_y_true = np.array(sentiment_y_true) 285 | sentiment_y_pred = np.array(sentiment_y_pred) 286 | sentiment_Acc = metrics.accuracy_score(sentiment_y_true,sentiment_y_pred) 287 | 288 | return aspect_Macro_AUC, sentiment_Acc, sentiment_Macro_AUC 289 | 290 | 291 | def semeval_PRF(y_true, y_pred): 292 | """ 293 | Calculate "Micro P R F" of aspect detection task of SemEval-2014. 294 | """ 295 | s_all=0 296 | g_all=0 297 | s_g_all=0 298 | for i in range(len(y_pred)//5): 299 | s=set() 300 | g=set() 301 | for j in range(5): 302 | if y_pred[i*5+j]!=4: 303 | s.add(j) 304 | if y_true[i*5+j]!=4: 305 | g.add(j) 306 | if len(g)==0:continue 307 | s_g=s.intersection(g) 308 | s_all+=len(s) 309 | g_all+=len(g) 310 | s_g_all+=len(s_g) 311 | 312 | p=s_g_all/s_all 313 | r=s_g_all/g_all 314 | f=2*p*r/(p+r) 315 | 316 | return p,r,f 317 | 318 | 319 | def semeval_Acc(y_true, y_pred, score, classes=4): 320 | """ 321 | Calculate "Acc" of sentiment classification task of SemEval-2014. 322 | """ 323 | assert classes in [2, 3, 4], "classes must be 2 or 3 or 4." 
324 | 325 | if classes == 4: 326 | total=0 327 | total_right=0 328 | for i in range(len(y_true)): 329 | if y_true[i]==4:continue 330 | total+=1 331 | tmp=y_pred[i] 332 | if tmp==4: 333 | if score[i][0]>=score[i][1] and score[i][0]>=score[i][2] and score[i][0]>=score[i][3]: 334 | tmp=0 335 | elif score[i][1]>=score[i][0] and score[i][1]>=score[i][2] and score[i][1]>=score[i][3]: 336 | tmp=1 337 | elif score[i][2]>=score[i][0] and score[i][2]>=score[i][1] and score[i][2]>=score[i][3]: 338 | tmp=2 339 | else: 340 | tmp=3 341 | if y_true[i]==tmp: 342 | total_right+=1 343 | sentiment_Acc = total_right/total 344 | elif classes == 3: 345 | total=0 346 | total_right=0 347 | for i in range(len(y_true)): 348 | if y_true[i]>=3:continue 349 | total+=1 350 | tmp=y_pred[i] 351 | if tmp>=3: 352 | if score[i][0]>=score[i][1] and score[i][0]>=score[i][2]: 353 | tmp=0 354 | elif score[i][1]>=score[i][0] and score[i][1]>=score[i][2]: 355 | tmp=1 356 | else: 357 | tmp=2 358 | if y_true[i]==tmp: 359 | total_right+=1 360 | sentiment_Acc = total_right/total 361 | else: 362 | total=0 363 | total_right=0 364 | for i in range(len(y_true)): 365 | if y_true[i]>=3 or y_true[i]==1:continue 366 | total+=1 367 | tmp=y_pred[i] 368 | if tmp>=3 or tmp==1: 369 | if score[i][0]>=score[i][2]: 370 | tmp=0 371 | else: 372 | tmp=2 373 | if y_true[i]==tmp: 374 | total_right+=1 375 | sentiment_Acc = total_right/total 376 | 377 | return sentiment_Acc 378 | 379 | 380 | def main(): 381 | parser = argparse.ArgumentParser() 382 | parser.add_argument("--task_name", 383 | default=None, 384 | type=str, 385 | required=True, 386 | choices=["sentihood_single", "sentihood_NLI_M", "sentihood_QA_M", \ 387 | "sentihood_NLI_B", "sentihood_QA_B", "semeval_single", \ 388 | "semeval_NLI_M", "semeval_QA_M", "semeval_NLI_B", "semeval_QA_B"], 389 | help="The name of the task to evalution.") 390 | parser.add_argument("--pred_data_dir", 391 | default=None, 392 | type=str, 393 | required=True, 394 | help="The pred data dir.") 395 | args = parser.parse_args() 396 | 397 | 398 | result = collections.OrderedDict() 399 | if args.task_name in ["sentihood_single", "sentihood_NLI_M", "sentihood_QA_M", "sentihood_NLI_B", "sentihood_QA_B"]: 400 | y_true = get_y_true(args.task_name) 401 | y_pred, score = get_y_pred(args.task_name, args.pred_data_dir) 402 | aspect_strict_Acc = sentihood_strict_acc(y_true, y_pred) 403 | aspect_Macro_F1 = sentihood_macro_F1(y_true, y_pred) 404 | aspect_Macro_AUC, sentiment_Acc, sentiment_Macro_AUC = sentihood_AUC_Acc(y_true, score) 405 | result = {'aspect_strict_Acc': aspect_strict_Acc, 406 | 'aspect_Macro_F1': aspect_Macro_F1, 407 | 'aspect_Macro_AUC': aspect_Macro_AUC, 408 | 'sentiment_Acc': sentiment_Acc, 409 | 'sentiment_Macro_AUC': sentiment_Macro_AUC} 410 | else: 411 | y_true = get_y_true(args.task_name) 412 | y_pred, score = get_y_pred(args.task_name, args.pred_data_dir) 413 | aspect_P, aspect_R, aspect_F = semeval_PRF(y_true, y_pred) 414 | sentiment_Acc_4_classes = semeval_Acc(y_true, y_pred, score, 4) 415 | sentiment_Acc_3_classes = semeval_Acc(y_true, y_pred, score, 3) 416 | sentiment_Acc_2_classes = semeval_Acc(y_true, y_pred, score, 2) 417 | result = {'aspect_P': aspect_P, 418 | 'aspect_R': aspect_R, 419 | 'aspect_F': aspect_F, 420 | 'sentiment_Acc_4_classes': sentiment_Acc_4_classes, 421 | 'sentiment_Acc_3_classes': sentiment_Acc_3_classes, 422 | 'sentiment_Acc_2_classes': sentiment_Acc_2_classes} 423 | 424 | for key in result.keys(): 425 | print(key, "=",str(result[key])) 426 | 427 | 428 | if __name__ == "__main__": 429 | 
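An analogous invented example for the SemEval-2014 aspect-detection metric above. Here each sentence contributes five consecutive entries (one per aspect category) and label 4 stands for "none"; as before, importing evaluation assumes its pandas and scikit-learn dependencies are available.

```python
# Worked example for semeval_PRF above; labels are invented.
from evaluation import semeval_PRF

y_true = [0, 4, 4, 4, 4,   2, 1, 4, 4, 4]   # sentence 1 has one gold aspect, sentence 2 has two
y_pred = [0, 4, 4, 4, 2,   2, 4, 4, 4, 4]   # one spurious aspect and one missed aspect overall

p, r, f = semeval_PRF(y_true, y_pred)
print(p, r, f)   # micro P = 2/3, R = 2/3, F = 2/3
```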
main() 430 | -------------------------------------------------------------------------------- /processor.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """Processors for different tasks.""" 4 | 5 | import csv 6 | import os 7 | 8 | import pandas as pd 9 | 10 | import tokenization 11 | 12 | 13 | class InputExample(object): 14 | """A single training/test example for simple sequence classification.""" 15 | 16 | def __init__(self, guid, text_a, text_b=None, label=None): 17 | """Constructs a InputExample. 18 | 19 | Args: 20 | guid: Unique id for the example. 21 | text_a: string. The untokenized text of the first sequence. For single 22 | sequence tasks, only this sequence must be specified. 23 | text_b: (Optional) string. The untokenized text of the second sequence. 24 | Only must be specified for sequence pair tasks. 25 | label: (Optional) string. The label of the example. This should be 26 | specified for train and dev examples, but not for test examples. 27 | """ 28 | self.guid = guid 29 | self.text_a = text_a 30 | self.text_b = text_b 31 | self.label = label 32 | 33 | 34 | class DataProcessor(object): 35 | """Base class for data converters for sequence classification data sets.""" 36 | 37 | def get_train_examples(self, data_dir): 38 | """Gets a collection of `InputExample`s for the train set.""" 39 | raise NotImplementedError() 40 | 41 | def get_dev_examples(self, data_dir): 42 | """Gets a collection of `InputExample`s for the dev set.""" 43 | raise NotImplementedError() 44 | 45 | def get_test_examples(self, data_dir): 46 | """Gets a collection of `InputExample`s for the test set.""" 47 | raise NotImplementedError() 48 | 49 | def get_labels(self): 50 | """Gets the list of labels for this data set.""" 51 | raise NotImplementedError() 52 | 53 | @classmethod 54 | def _read_tsv(cls, input_file, quotechar=None): 55 | """Reads a tab separated value file.""" 56 | with open(input_file, "r") as f: 57 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 58 | lines = [] 59 | for line in reader: 60 | lines.append(line) 61 | return lines 62 | 63 | 64 | class Sentihood_single_Processor(DataProcessor): 65 | """Processor for the Sentihood data set.""" 66 | 67 | def get_train_examples(self, data_dir): 68 | """See base class.""" 69 | train_data = pd.read_csv(os.path.join(data_dir, "train.tsv"),header=None,sep="\t").values 70 | return self._create_examples(train_data, "train") 71 | 72 | def get_dev_examples(self, data_dir): 73 | """See base class.""" 74 | dev_data = pd.read_csv(os.path.join(data_dir, "dev.tsv"),header=None,sep="\t").values 75 | return self._create_examples(dev_data, "dev") 76 | 77 | def get_test_examples(self, data_dir): 78 | """See base class.""" 79 | test_data = pd.read_csv(os.path.join(data_dir, "test.tsv"),header=None,sep="\t").values 80 | return self._create_examples(test_data, "test") 81 | 82 | def get_labels(self): 83 | """See base class.""" 84 | return ['None', 'Positive', 'Negative'] 85 | 86 | def _create_examples(self, lines, set_type): 87 | """Creates examples for the training and dev sets.""" 88 | examples = [] 89 | for (i, line) in enumerate(lines): 90 | # if i>50:break 91 | guid = "%s-%s" % (set_type, i) 92 | text_a = tokenization.convert_to_unicode(str(line[1])) 93 | label = tokenization.convert_to_unicode(str(line[2])) 94 | if i%1000==0: 95 | print(i) 96 | print("guid=",guid) 97 | print("text_a=",text_a) 98 | print("label=",label) 99 | examples.append( 100 | InputExample(guid=guid, text_a=text_a, 
text_b=None, label=label)) 101 | return examples 102 | 103 | 104 | class Sentihood_NLI_M_Processor(DataProcessor): 105 | """Processor for the Sentihood data set.""" 106 | 107 | def get_train_examples(self, data_dir): 108 | """See base class.""" 109 | train_data = pd.read_csv(os.path.join(data_dir, "train_NLI_M.tsv"),sep="\t").values 110 | return self._create_examples(train_data, "train") 111 | 112 | def get_dev_examples(self, data_dir): 113 | """See base class.""" 114 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_NLI_M.tsv"),sep="\t").values 115 | return self._create_examples(dev_data, "dev") 116 | 117 | def get_test_examples(self, data_dir): 118 | """See base class.""" 119 | test_data = pd.read_csv(os.path.join(data_dir, "test_NLI_M.tsv"),sep="\t").values 120 | return self._create_examples(test_data, "test") 121 | 122 | def get_labels(self): 123 | """See base class.""" 124 | return ['None', 'Positive', 'Negative'] 125 | 126 | def _create_examples(self, lines, set_type): 127 | """Creates examples for the training and dev sets.""" 128 | examples = [] 129 | for (i, line) in enumerate(lines): 130 | # if i>50:break 131 | guid = "%s-%s" % (set_type, i) 132 | text_a = tokenization.convert_to_unicode(str(line[1])) 133 | text_b = tokenization.convert_to_unicode(str(line[2])) 134 | label = tokenization.convert_to_unicode(str(line[3])) 135 | if i%1000==0: 136 | print(i) 137 | print("guid=",guid) 138 | print("text_a=",text_a) 139 | print("text_b=",text_b) 140 | print("label=",label) 141 | examples.append( 142 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 143 | return examples 144 | 145 | 146 | class Sentihood_QA_M_Processor(DataProcessor): 147 | """Processor for the Sentihood data set.""" 148 | 149 | def get_train_examples(self, data_dir): 150 | """See base class.""" 151 | train_data = pd.read_csv(os.path.join(data_dir, "train_QA_M.tsv"),sep="\t").values 152 | return self._create_examples(train_data, "train") 153 | 154 | def get_dev_examples(self, data_dir): 155 | """See base class.""" 156 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_QA_M.tsv"),sep="\t").values 157 | return self._create_examples(dev_data, "dev") 158 | 159 | def get_test_examples(self, data_dir): 160 | """See base class.""" 161 | test_data = pd.read_csv(os.path.join(data_dir, "test_QA_M.tsv"),sep="\t").values 162 | return self._create_examples(test_data, "test") 163 | 164 | def get_labels(self): 165 | """See base class.""" 166 | return ['None', 'Positive', 'Negative'] 167 | 168 | def _create_examples(self, lines, set_type): 169 | """Creates examples for the training and dev sets.""" 170 | examples = [] 171 | for (i, line) in enumerate(lines): 172 | # if i>50:break 173 | guid = "%s-%s" % (set_type, i) 174 | text_a = tokenization.convert_to_unicode(str(line[1])) 175 | text_b = tokenization.convert_to_unicode(str(line[2])) 176 | label = tokenization.convert_to_unicode(str(line[3])) 177 | if i%1000==0: 178 | print(i) 179 | print("guid=",guid) 180 | print("text_a=",text_a) 181 | print("text_b=",text_b) 182 | print("label=",label) 183 | examples.append( 184 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 185 | return examples 186 | 187 | 188 | class Sentihood_NLI_B_Processor(DataProcessor): 189 | """Processor for the Sentihood data set.""" 190 | 191 | def get_train_examples(self, data_dir): 192 | """See base class.""" 193 | train_data = pd.read_csv(os.path.join(data_dir, "train_NLI_B.tsv"),sep="\t").values 194 | return self._create_examples(train_data, "train") 195 | 196 | def 
get_dev_examples(self, data_dir): 197 | """See base class.""" 198 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_NLI_B.tsv"),sep="\t").values 199 | return self._create_examples(dev_data, "dev") 200 | 201 | def get_test_examples(self, data_dir): 202 | """See base class.""" 203 | test_data = pd.read_csv(os.path.join(data_dir, "test_NLI_B.tsv"),sep="\t").values 204 | return self._create_examples(test_data, "test") 205 | 206 | def get_labels(self): 207 | """See base class.""" 208 | return ['0', '1'] 209 | 210 | def _create_examples(self, lines, set_type): 211 | """Creates examples for the training and dev sets.""" 212 | examples = [] 213 | for (i, line) in enumerate(lines): 214 | # if i>50:break 215 | guid = "%s-%s" % (set_type, i) 216 | text_a = tokenization.convert_to_unicode(str(line[2])) 217 | text_b = tokenization.convert_to_unicode(str(line[1])) 218 | label = tokenization.convert_to_unicode(str(line[3])) 219 | if i%1000==0: 220 | print(i) 221 | print("guid=",guid) 222 | print("text_a=",text_a) 223 | print("text_b=",text_b) 224 | print("label=",label) 225 | examples.append( 226 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 227 | return examples 228 | 229 | 230 | class Sentihood_QA_B_Processor(DataProcessor): 231 | """Processor for the Sentihood data set.""" 232 | 233 | def get_train_examples(self, data_dir): 234 | """See base class.""" 235 | train_data = pd.read_csv(os.path.join(data_dir, "train_QA_B.tsv"),sep="\t").values 236 | return self._create_examples(train_data, "train") 237 | 238 | def get_dev_examples(self, data_dir): 239 | """See base class.""" 240 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_QA_B.tsv"),sep="\t").values 241 | return self._create_examples(dev_data, "dev") 242 | 243 | def get_test_examples(self, data_dir): 244 | """See base class.""" 245 | test_data = pd.read_csv(os.path.join(data_dir, "test_QA_B.tsv"),sep="\t").values 246 | return self._create_examples(test_data, "test") 247 | 248 | def get_labels(self): 249 | """See base class.""" 250 | return ['0', '1'] 251 | 252 | def _create_examples(self, lines, set_type): 253 | """Creates examples for the training and dev sets.""" 254 | examples = [] 255 | for (i, line) in enumerate(lines): 256 | # if i>50:break 257 | guid = "%s-%s" % (set_type, i) 258 | text_a = tokenization.convert_to_unicode(str(line[2])) 259 | text_b = tokenization.convert_to_unicode(str(line[1])) 260 | label = tokenization.convert_to_unicode(str(line[3])) 261 | if i%1000==0: 262 | print(i) 263 | print("guid=",guid) 264 | print("text_a=",text_a) 265 | print("text_b=",text_b) 266 | print("label=",label) 267 | examples.append( 268 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 269 | return examples 270 | 271 | 272 | class Semeval_single_Processor(DataProcessor): 273 | """Processor for the Semeval 2014 data set.""" 274 | 275 | def get_train_examples(self, data_dir): 276 | """See base class.""" 277 | train_data = pd.read_csv(os.path.join(data_dir, "train.csv"),header=None,sep="\t").values 278 | return self._create_examples(train_data, "train") 279 | 280 | def get_dev_examples(self, data_dir): 281 | """See base class.""" 282 | dev_data = pd.read_csv(os.path.join(data_dir, "dev.csv"),header=None,sep="\t").values 283 | return self._create_examples(dev_data, "dev") 284 | 285 | def get_test_examples(self, data_dir): 286 | """See base class.""" 287 | test_data = pd.read_csv(os.path.join(data_dir, "test.csv"),header=None,sep="\t").values 288 | return self._create_examples(test_data, "test") 289 | 290 | 
def get_labels(self): 291 | """See base class.""" 292 | return ['positive', 'neutral', 'negative', 'conflict', 'none'] 293 | 294 | def _create_examples(self, lines, set_type): 295 | """Creates examples for the training and dev sets.""" 296 | examples = [] 297 | for (i, line) in enumerate(lines): 298 | # if i>50:break 299 | guid = "%s-%s" % (set_type, i) 300 | text_a = tokenization.convert_to_unicode(str(line[3])) 301 | label = tokenization.convert_to_unicode(str(line[1])) 302 | if i%1000==0: 303 | print(i) 304 | print("guid=",guid) 305 | print("text_a=",text_a) 306 | print("label=",label) 307 | examples.append( 308 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 309 | return examples 310 | 311 | 312 | class Semeval_NLI_M_Processor(DataProcessor): 313 | """Processor for the Semeval 2014 data set.""" 314 | 315 | def get_train_examples(self, data_dir): 316 | """See base class.""" 317 | train_data = pd.read_csv(os.path.join(data_dir, "train_NLI_M.csv"),header=None,sep="\t").values 318 | return self._create_examples(train_data, "train") 319 | 320 | def get_dev_examples(self, data_dir): 321 | """See base class.""" 322 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_NLI_M.csv"),header=None,sep="\t").values 323 | return self._create_examples(dev_data, "dev") 324 | 325 | def get_test_examples(self, data_dir): 326 | """See base class.""" 327 | test_data = pd.read_csv(os.path.join(data_dir, "test_NLI_M.csv"),header=None,sep="\t").values 328 | return self._create_examples(test_data, "test") 329 | 330 | def get_labels(self): 331 | """See base class.""" 332 | return ['positive', 'neutral', 'negative', 'conflict', 'none'] 333 | 334 | def _create_examples(self, lines, set_type): 335 | """Creates examples for the training and dev sets.""" 336 | examples = [] 337 | for (i, line) in enumerate(lines): 338 | # if i>50:break 339 | guid = "%s-%s" % (set_type, i) 340 | text_a = tokenization.convert_to_unicode(str(line[3])) 341 | text_b = tokenization.convert_to_unicode(str(line[2])) 342 | label = tokenization.convert_to_unicode(str(line[1])) 343 | if i%1000==0: 344 | print(i) 345 | print("guid=",guid) 346 | print("text_a=",text_a) 347 | print("label=",label) 348 | examples.append( 349 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 350 | return examples 351 | 352 | 353 | class Semeval_QA_M_Processor(DataProcessor): 354 | """Processor for the Semeval 2014 data set.""" 355 | 356 | def get_train_examples(self, data_dir): 357 | """See base class.""" 358 | train_data = pd.read_csv(os.path.join(data_dir, "train_QA_M.csv"),header=None,sep="\t").values 359 | return self._create_examples(train_data, "train") 360 | 361 | def get_dev_examples(self, data_dir): 362 | """See base class.""" 363 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_QA_M.csv"),header=None,sep="\t").values 364 | return self._create_examples(dev_data, "dev") 365 | 366 | def get_test_examples(self, data_dir): 367 | """See base class.""" 368 | test_data = pd.read_csv(os.path.join(data_dir, "test_QA_M.csv"),header=None,sep="\t").values 369 | return self._create_examples(test_data, "test") 370 | 371 | def get_labels(self): 372 | """See base class.""" 373 | return ['positive', 'neutral', 'negative', 'conflict', 'none'] 374 | 375 | def _create_examples(self, lines, set_type): 376 | """Creates examples for the training and dev sets.""" 377 | examples = [] 378 | for (i, line) in enumerate(lines): 379 | # if i>50:break 380 | guid = "%s-%s" % (set_type, i) 381 | text_a = 
tokenization.convert_to_unicode(str(line[3])) 382 | text_b = tokenization.convert_to_unicode(str(line[2])) 383 | label = tokenization.convert_to_unicode(str(line[1])) 384 | if i%1000==0: 385 | print(i) 386 | print("guid=",guid) 387 | print("text_a=",text_a) 388 | print("label=",label) 389 | examples.append( 390 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 391 | return examples 392 | 393 | 394 | class Semeval_NLI_B_Processor(DataProcessor): 395 | """Processor for the Semeval 2014 data set.""" 396 | 397 | def get_train_examples(self, data_dir): 398 | """See base class.""" 399 | train_data = pd.read_csv(os.path.join(data_dir, "train_NLI_B.csv"),header=None,sep="\t").values 400 | return self._create_examples(train_data, "train") 401 | 402 | def get_dev_examples(self, data_dir): 403 | """See base class.""" 404 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_NLI_B.csv"),header=None,sep="\t").values 405 | return self._create_examples(dev_data, "dev") 406 | 407 | def get_test_examples(self, data_dir): 408 | """See base class.""" 409 | test_data = pd.read_csv(os.path.join(data_dir, "test_NLI_B.csv"),header=None,sep="\t").values 410 | return self._create_examples(test_data, "test") 411 | 412 | def get_labels(self): 413 | """See base class.""" 414 | return ['0', '1'] 415 | 416 | def _create_examples(self, lines, set_type): 417 | """Creates examples for the training and dev sets.""" 418 | examples = [] 419 | for (i, line) in enumerate(lines): 420 | # if i>50:break 421 | guid = "%s-%s" % (set_type, i) 422 | text_a = tokenization.convert_to_unicode(str(line[2])) 423 | text_b = tokenization.convert_to_unicode(str(line[3])) 424 | label = tokenization.convert_to_unicode(str(line[1])) 425 | if i%1000==0: 426 | print(i) 427 | print("guid=",guid) 428 | print("text_a=",text_a) 429 | print("label=",label) 430 | examples.append( 431 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 432 | return examples 433 | 434 | 435 | class Semeval_QA_B_Processor(DataProcessor): 436 | """Processor for the Semeval 2014 data set.""" 437 | 438 | def get_train_examples(self, data_dir): 439 | """See base class.""" 440 | train_data = pd.read_csv(os.path.join(data_dir, "train_QA_B.csv"),header=None,sep="\t").values 441 | return self._create_examples(train_data, "train") 442 | 443 | def get_dev_examples(self, data_dir): 444 | """See base class.""" 445 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_QA_B.csv"),header=None,sep="\t").values 446 | return self._create_examples(dev_data, "dev") 447 | 448 | def get_test_examples(self, data_dir): 449 | """See base class.""" 450 | test_data = pd.read_csv(os.path.join(data_dir, "test_QA_B.csv"),header=None,sep="\t").values 451 | return self._create_examples(test_data, "test") 452 | 453 | def get_labels(self): 454 | """See base class.""" 455 | return ['0', '1'] 456 | 457 | def _create_examples(self, lines, set_type): 458 | """Creates examples for the training and dev sets.""" 459 | examples = [] 460 | for (i, line) in enumerate(lines): 461 | # if i>50:break 462 | guid = "%s-%s" % (set_type, i) 463 | text_a = tokenization.convert_to_unicode(str(line[2])) 464 | text_b = tokenization.convert_to_unicode(str(line[3])) 465 | label = tokenization.convert_to_unicode(str(line[1])) 466 | if i%1000==0: 467 | print(i) 468 | print("guid=",guid) 469 | print("text_a=",text_a) 470 | print("label=",label) 471 | examples.append( 472 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 473 | return examples 474 | 
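Before moving on to modeling.py, a short usage sketch for the task processors above, using the Sentihood NLI_M variant. The directory follows the data/sentihood/bert-pair/ convention also used in evaluation.py and is assumed to already contain the TSV files produced by the generation scripts under generate/.

```python
# Usage sketch for the task processors above (Sentihood NLI_M variant).
# Assumes data/sentihood/bert-pair/ already contains the generated *_NLI_M.tsv files.
from processor import Sentihood_NLI_M_Processor

processor = Sentihood_NLI_M_Processor()
print(processor.get_labels())                  # ['None', 'Positive', 'Negative']

examples = processor.get_train_examples("data/sentihood/bert-pair/")
ex = examples[0]
print(ex.guid)      # e.g. "train-0"
print(ex.text_a)    # original sentence
print(ex.text_b)    # auxiliary (target, aspect) sentence from the generation script
print(ex.label)     # one of 'None', 'Positive', 'Negative'
```

The resulting InputExample list is what run_classifier_TABSA.py feeds into convert_examples_to_features before training.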
-------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # Reference: https://github.com/huggingface/pytorch-pretrained-BERT 4 | 5 | """PyTorch BERT model.""" 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import copy 10 | import json 11 | import math 12 | 13 | import six 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import CrossEntropyLoss 17 | 18 | 19 | def gelu(x): 20 | """Implementation of the gelu activation function. 21 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 22 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 23 | """ 24 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 25 | 26 | 27 | class BertConfig(object): 28 | """Configuration class to store the configuration of a `BertModel`. 29 | """ 30 | def __init__(self, 31 | vocab_size, 32 | hidden_size=768, 33 | num_hidden_layers=12, 34 | num_attention_heads=12, 35 | intermediate_size=3072, 36 | hidden_act="gelu", 37 | hidden_dropout_prob=0.1, 38 | attention_probs_dropout_prob=0.1, 39 | max_position_embeddings=512, 40 | type_vocab_size=16, 41 | initializer_range=0.02): 42 | """Constructs BertConfig. 43 | 44 | Args: 45 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 46 | hidden_size: Size of the encoder layers and the pooler layer. 47 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 48 | num_attention_heads: Number of attention heads for each attention layer in 49 | the Transformer encoder. 50 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 51 | layer in the Transformer encoder. 52 | hidden_act: The non-linear activation function (function or string) in the 53 | encoder and pooler. 54 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 55 | layers in the embeddings, encoder, and pooler. 56 | attention_probs_dropout_prob: The dropout ratio for the attention 57 | probabilities. 58 | max_position_embeddings: The maximum sequence length that this model might 59 | ever be used with. Typically set this to something large just in case 60 | (e.g., 512 or 1024 or 2048). 61 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 62 | `BertModel`. 63 | initializer_range: The sttdev of the truncated_normal_initializer for 64 | initializing all weight matrices. 
65 | """ 66 | self.vocab_size = vocab_size 67 | self.hidden_size = hidden_size 68 | self.num_hidden_layers = num_hidden_layers 69 | self.num_attention_heads = num_attention_heads 70 | self.hidden_act = hidden_act 71 | self.intermediate_size = intermediate_size 72 | self.hidden_dropout_prob = hidden_dropout_prob 73 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 74 | self.max_position_embeddings = max_position_embeddings 75 | self.type_vocab_size = type_vocab_size 76 | self.initializer_range = initializer_range 77 | 78 | @classmethod 79 | def from_dict(cls, json_object): 80 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 81 | config = BertConfig(vocab_size=None) 82 | for (key, value) in six.iteritems(json_object): 83 | config.__dict__[key] = value 84 | return config 85 | 86 | @classmethod 87 | def from_json_file(cls, json_file): 88 | """Constructs a `BertConfig` from a json file of parameters.""" 89 | with open(json_file, "r") as reader: 90 | text = reader.read() 91 | return cls.from_dict(json.loads(text)) 92 | 93 | def to_dict(self): 94 | """Serializes this instance to a Python dictionary.""" 95 | output = copy.deepcopy(self.__dict__) 96 | return output 97 | 98 | def to_json_string(self): 99 | """Serializes this instance to a JSON string.""" 100 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 101 | 102 | 103 | class BERTLayerNorm(nn.Module): 104 | def __init__(self, config, variance_epsilon=1e-12): 105 | """Construct a layernorm module in the TF style (epsilon inside the square root). 106 | """ 107 | super(BERTLayerNorm, self).__init__() 108 | self.gamma = nn.Parameter(torch.ones(config.hidden_size)) 109 | self.beta = nn.Parameter(torch.zeros(config.hidden_size)) 110 | self.variance_epsilon = variance_epsilon 111 | 112 | def forward(self, x): 113 | u = x.mean(-1, keepdim=True) 114 | s = (x - u).pow(2).mean(-1, keepdim=True) 115 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 116 | return self.gamma * x + self.beta 117 | 118 | class BERTEmbeddings(nn.Module): 119 | def __init__(self, config): 120 | super(BERTEmbeddings, self).__init__() 121 | """Construct the embedding module from word, position and token_type embeddings. 
122 | """ 123 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) 124 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 125 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 126 | 127 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 128 | # any TensorFlow checkpoint file 129 | self.LayerNorm = BERTLayerNorm(config) 130 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 131 | 132 | def forward(self, input_ids, token_type_ids=None): 133 | seq_length = input_ids.size(1) 134 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 135 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 136 | if token_type_ids is None: 137 | token_type_ids = torch.zeros_like(input_ids) 138 | 139 | words_embeddings = self.word_embeddings(input_ids) 140 | position_embeddings = self.position_embeddings(position_ids) 141 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 142 | 143 | embeddings = words_embeddings + position_embeddings + token_type_embeddings 144 | embeddings = self.LayerNorm(embeddings) 145 | embeddings = self.dropout(embeddings) 146 | return embeddings 147 | 148 | 149 | class BERTSelfAttention(nn.Module): 150 | def __init__(self, config): 151 | super(BERTSelfAttention, self).__init__() 152 | if config.hidden_size % config.num_attention_heads != 0: 153 | raise ValueError( 154 | "The hidden size (%d) is not a multiple of the number of attention " 155 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 156 | self.num_attention_heads = config.num_attention_heads 157 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 158 | self.all_head_size = self.num_attention_heads * self.attention_head_size 159 | 160 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 161 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 162 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 163 | 164 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 165 | 166 | def transpose_for_scores(self, x): 167 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 168 | x = x.view(*new_x_shape) 169 | return x.permute(0, 2, 1, 3) 170 | 171 | def forward(self, hidden_states, attention_mask): 172 | mixed_query_layer = self.query(hidden_states) 173 | mixed_key_layer = self.key(hidden_states) 174 | mixed_value_layer = self.value(hidden_states) 175 | 176 | query_layer = self.transpose_for_scores(mixed_query_layer) 177 | key_layer = self.transpose_for_scores(mixed_key_layer) 178 | value_layer = self.transpose_for_scores(mixed_value_layer) 179 | 180 | # Take the dot product between "query" and "key" to get the raw attention scores. 181 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 182 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 183 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 184 | attention_scores = attention_scores + attention_mask 185 | 186 | # Normalize the attention scores to probabilities. 187 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 188 | 189 | # This is actually dropping out entire tokens to attend to, which might 190 | # seem a bit unusual, but is taken from the original Transformer paper. 
191 | attention_probs = self.dropout(attention_probs) 192 | 193 | context_layer = torch.matmul(attention_probs, value_layer) 194 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 195 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 196 | context_layer = context_layer.view(*new_context_layer_shape) 197 | return context_layer 198 | 199 | 200 | class BERTSelfOutput(nn.Module): 201 | def __init__(self, config): 202 | super(BERTSelfOutput, self).__init__() 203 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 204 | self.LayerNorm = BERTLayerNorm(config) 205 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 206 | 207 | def forward(self, hidden_states, input_tensor): 208 | hidden_states = self.dense(hidden_states) 209 | hidden_states = self.dropout(hidden_states) 210 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 211 | return hidden_states 212 | 213 | 214 | class BERTAttention(nn.Module): 215 | def __init__(self, config): 216 | super(BERTAttention, self).__init__() 217 | self.self = BERTSelfAttention(config) 218 | self.output = BERTSelfOutput(config) 219 | 220 | def forward(self, input_tensor, attention_mask): 221 | self_output = self.self(input_tensor, attention_mask) 222 | attention_output = self.output(self_output, input_tensor) 223 | return attention_output 224 | 225 | 226 | class BERTIntermediate(nn.Module): 227 | def __init__(self, config): 228 | super(BERTIntermediate, self).__init__() 229 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 230 | self.intermediate_act_fn = gelu 231 | 232 | def forward(self, hidden_states): 233 | hidden_states = self.dense(hidden_states) 234 | hidden_states = self.intermediate_act_fn(hidden_states) 235 | return hidden_states 236 | 237 | 238 | class BERTOutput(nn.Module): 239 | def __init__(self, config): 240 | super(BERTOutput, self).__init__() 241 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 242 | self.LayerNorm = BERTLayerNorm(config) 243 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 244 | 245 | def forward(self, hidden_states, input_tensor): 246 | hidden_states = self.dense(hidden_states) 247 | hidden_states = self.dropout(hidden_states) 248 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 249 | return hidden_states 250 | 251 | 252 | class BERTLayer(nn.Module): 253 | def __init__(self, config): 254 | super(BERTLayer, self).__init__() 255 | self.attention = BERTAttention(config) 256 | self.intermediate = BERTIntermediate(config) 257 | self.output = BERTOutput(config) 258 | 259 | def forward(self, hidden_states, attention_mask): 260 | attention_output = self.attention(hidden_states, attention_mask) 261 | intermediate_output = self.intermediate(attention_output) 262 | layer_output = self.output(intermediate_output, attention_output) 263 | return layer_output 264 | 265 | 266 | class BERTEncoder(nn.Module): 267 | def __init__(self, config): 268 | super(BERTEncoder, self).__init__() 269 | layer = BERTLayer(config) 270 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) 271 | 272 | def forward(self, hidden_states, attention_mask): 273 | all_encoder_layers = [] 274 | for layer_module in self.layer: 275 | hidden_states = layer_module(hidden_states, attention_mask) 276 | all_encoder_layers.append(hidden_states) 277 | return all_encoder_layers 278 | 279 | 280 | class BERTPooler(nn.Module): 281 | def __init__(self, config): 282 | super(BERTPooler, self).__init__() 283 | 
self.dense = nn.Linear(config.hidden_size, config.hidden_size) 284 | self.activation = nn.Tanh() 285 | 286 | def forward(self, hidden_states): 287 | # We "pool" the model by simply taking the hidden state corresponding 288 | # to the first token. 289 | first_token_tensor = hidden_states[:, 0] 290 | #return first_token_tensor 291 | pooled_output = self.dense(first_token_tensor) 292 | pooled_output = self.activation(pooled_output) 293 | return pooled_output 294 | 295 | 296 | class BertModel(nn.Module): 297 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 298 | 299 | Example usage: 300 | ```python 301 | # Already been converted into WordPiece token ids 302 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 303 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 304 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 305 | 306 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512, 307 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 308 | 309 | model = modeling.BertModel(config=config) 310 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) 311 | ``` 312 | """ 313 | def __init__(self, config: BertConfig): 314 | """Constructor for BertModel. 315 | 316 | Args: 317 | config: `BertConfig` instance. 318 | """ 319 | super(BertModel, self).__init__() 320 | self.embeddings = BERTEmbeddings(config) 321 | self.encoder = BERTEncoder(config) 322 | self.pooler = BERTPooler(config) 323 | 324 | def forward(self, input_ids, token_type_ids=None, attention_mask=None): 325 | if attention_mask is None: 326 | attention_mask = torch.ones_like(input_ids) 327 | if token_type_ids is None: 328 | token_type_ids = torch.zeros_like(input_ids) 329 | 330 | # We create a 3D attention mask from a 2D tensor mask. 331 | # Sizes are [batch_size, 1, 1, from_seq_length] 332 | # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length] 333 | # this attention mask is more simple than the triangular masking of causal attention 334 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 335 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 336 | 337 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 338 | # masked positions, this operation will create a tensor which is 0.0 for 339 | # positions we want to attend and -10000.0 for masked positions. 340 | # Since we are adding it to the raw scores before the softmax, this is 341 | # effectively the same as removing these entirely. 342 | extended_attention_mask = extended_attention_mask.float() 343 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 344 | 345 | embedding_output = self.embeddings(input_ids, token_type_ids) 346 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask) 347 | sequence_output = all_encoder_layers[-1] 348 | pooled_output = self.pooler(sequence_output) 349 | return all_encoder_layers, pooled_output 350 | 351 | class BertForSequenceClassification(nn.Module): 352 | """BERT model for classification. 353 | This module is composed of the BERT model with a linear layer on top of 354 | the pooled output. 
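A numeric illustration of the additive attention-mask trick built in BertModel.forward above and consumed by BERTSelfAttention.forward: padding positions receive a -10000.0 offset so that, after the softmax, they end up with effectively zero attention probability. The mask values are invented.

```python
# Standalone illustration of the extended attention mask from BertModel.forward above.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])             # 1 = real token, 0 = padding
extended = attention_mask.unsqueeze(1).unsqueeze(2).float()  # [batch, 1, 1, seq_len]
extended = (1.0 - extended) * -10000.0                       # 0.0 for real tokens, -10000.0 for padding

raw_scores = torch.zeros(1, 1, 1, 5)                         # pretend all raw attention scores are equal
probs = torch.softmax(raw_scores + extended, dim=-1)
print(probs)   # ~[0.3333, 0.3333, 0.3333, 0.0, 0.0] -> padded positions are ignored
```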
355 | 356 | Example usage: 357 | ```python 358 | # Already been converted into WordPiece token ids 359 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 360 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 361 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 362 | 363 | config = BertConfig(vocab_size=32000, hidden_size=512, 364 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 365 | 366 | num_labels = 2 367 | 368 | model = BertForSequenceClassification(config, num_labels) 369 | logits = model(input_ids, token_type_ids, input_mask) 370 | ``` 371 | """ 372 | def __init__(self, config, num_labels): 373 | super(BertForSequenceClassification, self).__init__() 374 | self.bert = BertModel(config) 375 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 376 | self.classifier = nn.Linear(config.hidden_size, num_labels) 377 | 378 | def init_weights(module): 379 | if isinstance(module, (nn.Linear, nn.Embedding)): 380 | # Slightly different from the TF version which uses truncated_normal for initialization 381 | # cf https://github.com/pytorch/pytorch/pull/5617 382 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 383 | elif isinstance(module, BERTLayerNorm): 384 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 385 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 386 | if isinstance(module, nn.Linear): 387 | module.bias.data.zero_() 388 | self.apply(init_weights) 389 | 390 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None): 391 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask) 392 | pooled_output = self.dropout(pooled_output) 393 | logits = self.classifier(pooled_output) 394 | 395 | if labels is not None: 396 | loss_fct = CrossEntropyLoss() 397 | loss = loss_fct(logits, labels) 398 | return loss, logits 399 | else: 400 | return logits 401 | 402 | 403 | class BertForQuestionAnswering(nn.Module): 404 | """BERT model for Question Answering (span extraction). 
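A toy end-to-end sketch for BertForSequenceClassification above, with random inputs and a small randomly initialized config; real experiments instead load a converted pre-trained checkpoint through run_classifier_TABSA.py. Note that hidden_size must be divisible by num_attention_heads (BERTSelfAttention raises a ValueError otherwise), so the sketch uses 8 heads rather than the 6 shown in the docstring example.

```python
# Toy forward/backward pass through BertForSequenceClassification above.
# Config values and inputs are arbitrary; hidden_size must be divisible by num_attention_heads.
import torch
from modeling import BertConfig, BertForSequenceClassification

config = BertConfig(vocab_size=32000, hidden_size=512, num_hidden_layers=4,
                    num_attention_heads=8, intermediate_size=1024)
model = BertForSequenceClassification(config, num_labels=3)

input_ids      = torch.randint(0, config.vocab_size, (2, 16))   # batch of 2, sequence length 16
token_type_ids = torch.zeros(2, 16, dtype=torch.long)
input_mask     = torch.ones(2, 16, dtype=torch.long)
labels         = torch.tensor([0, 2])

loss, logits = model(input_ids, token_type_ids, input_mask, labels)
loss.backward()
print(loss.item(), logits.shape)   # scalar loss and logits of shape [2, 3]
```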
405 | This module is composed of the BERT model with a linear layer on top of 406 | the sequence output that computes start_logits and end_logits 407 | 408 | Example usage: 409 | ```python 410 | # Already been converted into WordPiece token ids 411 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 412 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 413 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]]) 414 | 415 | config = BertConfig(vocab_size=32000, hidden_size=512, 416 | num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024) 417 | 418 | model = BertForQuestionAnswering(config) 419 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) 420 | ``` 421 | """ 422 | def __init__(self, config): 423 | super(BertForQuestionAnswering, self).__init__() 424 | self.bert = BertModel(config) 425 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version 426 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) 427 | self.qa_outputs = nn.Linear(config.hidden_size, 2) 428 | 429 | def init_weights(module): 430 | if isinstance(module, (nn.Linear, nn.Embedding)): 431 | # Slightly different from the TF version which uses truncated_normal for initialization 432 | # cf https://github.com/pytorch/pytorch/pull/5617 433 | module.weight.data.normal_(mean=0.0, std=config.initializer_range) 434 | elif isinstance(module, BERTLayerNorm): 435 | module.beta.data.normal_(mean=0.0, std=config.initializer_range) 436 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range) 437 | if isinstance(module, nn.Linear): 438 | module.bias.data.zero_() 439 | self.apply(init_weights) 440 | 441 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None): 442 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask) 443 | sequence_output = all_encoder_layers[-1] 444 | logits = self.qa_outputs(sequence_output) 445 | start_logits, end_logits = logits.split(1, dim=-1) 446 | start_logits = start_logits.squeeze(-1) 447 | end_logits = end_logits.squeeze(-1) 448 | 449 | if start_positions is not None and end_positions is not None: 450 | # If we are on multi-GPU, split add a dimension - if not this is a no-op 451 | start_positions = start_positions.squeeze(-1) 452 | end_positions = end_positions.squeeze(-1) 453 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 454 | ignored_index = start_logits.size(1) 455 | start_positions.clamp_(0, ignored_index) 456 | end_positions.clamp_(0, ignored_index) 457 | 458 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 459 | start_loss = loss_fct(start_logits, start_positions) 460 | end_loss = loss_fct(end_logits, end_positions) 461 | total_loss = (start_loss + end_loss) / 2 462 | return total_loss 463 | else: 464 | return start_logits, end_logits 465 | -------------------------------------------------------------------------------- /run_classifier_TABSA.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """BERT finetuning runner.""" 4 | 5 | from __future__ import absolute_import, division, print_function 6 | 7 | import argparse 8 | import collections 9 | import logging 10 | import os 11 | import random 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | from torch.utils.data import DataLoader, TensorDataset 17 | from torch.utils.data.distributed import DistributedSampler 18 | 
19 | from tqdm import tqdm, trange
20 |
21 | import tokenization
22 | from modeling import BertConfig, BertForSequenceClassification
23 | from optimization import BERTAdam
24 | from processor import (Semeval_NLI_B_Processor, Semeval_NLI_M_Processor,
25 |                        Semeval_QA_B_Processor, Semeval_QA_M_Processor,
26 |                        Semeval_single_Processor, Sentihood_NLI_B_Processor,
27 |                        Sentihood_NLI_M_Processor, Sentihood_QA_B_Processor,
28 |                        Sentihood_QA_M_Processor, Sentihood_single_Processor)
29 |
30 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
31 |                     datefmt = '%m/%d/%Y %H:%M:%S',
32 |                     level = logging.INFO)
33 | logger = logging.getLogger(__name__)
34 |
35 |
36 | class InputFeatures(object):
37 |     """A single set of features of data."""
38 |
39 |     def __init__(self, input_ids, input_mask, segment_ids, label_id):
40 |         self.input_ids = input_ids
41 |         self.input_mask = input_mask
42 |         self.segment_ids = segment_ids
43 |         self.label_id = label_id
44 |
45 |
46 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
47 |     """Loads a data file into a list of `InputBatch`s."""
48 |
49 |     label_map = {}
50 |     for (i, label) in enumerate(label_list):
51 |         label_map[label] = i
52 |
53 |     features = []
54 |     for (ex_index, example) in enumerate(tqdm(examples)):
55 |         tokens_a = tokenizer.tokenize(example.text_a)
56 |
57 |         tokens_b = None
58 |         if example.text_b:
59 |             tokens_b = tokenizer.tokenize(example.text_b)
60 |
61 |         if tokens_b:
62 |             # Modifies `tokens_a` and `tokens_b` in place so that the total
63 |             # length is less than the specified length.
64 |             # Account for [CLS], [SEP], [SEP] with "- 3"
65 |             _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
66 |         else:
67 |             # Account for [CLS] and [SEP] with "- 2"
68 |             if len(tokens_a) > max_seq_length - 2:
69 |                 tokens_a = tokens_a[0:(max_seq_length - 2)]
70 |
71 |         # The convention in BERT is:
72 |         # (a) For sequence pairs:
73 |         #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
74 |         #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
75 |         # (b) For single sequences:
76 |         #  tokens:   [CLS] the dog is hairy . [SEP]
77 |         #  type_ids: 0     0   0   0  0     0 0
78 |         #
79 |         # Where "type_ids" are used to indicate whether this is the first
80 |         # sequence or the second sequence. The embedding vectors for `type=0` and
81 |         # `type=1` were learned during pre-training and are added to the wordpiece
82 |         # embedding vector (and position vector). This is not *strictly* necessary
83 |         # since the [SEP] token unambiguously separates the sequences, but it makes
84 |         # it easier for the model to learn the concept of sequences.
85 |         #
86 |         # For classification tasks, the first vector (corresponding to [CLS]) is
87 |         # used as the "sentence vector". Note that this only makes sense because
88 |         # the entire model is fine-tuned.
89 |         tokens = []
90 |         segment_ids = []
91 |         tokens.append("[CLS]")
92 |         segment_ids.append(0)
93 |         for token in tokens_a:
94 |             tokens.append(token)
95 |             segment_ids.append(0)
96 |         tokens.append("[SEP]")
97 |         segment_ids.append(0)
98 |
99 |         if tokens_b:
100 |             for token in tokens_b:
101 |                 tokens.append(token)
102 |                 segment_ids.append(1)
103 |             tokens.append("[SEP]")
104 |             segment_ids.append(1)
105 |
106 |         input_ids = tokenizer.convert_tokens_to_ids(tokens)
107 |
108 |         # The mask has 1 for real tokens and 0 for padding tokens. Only real
109 |         # tokens are attended to.
110 |         input_mask = [1] * len(input_ids)
111 |
112 |         # Zero-pad up to the sequence length.
113 |         while len(input_ids) < max_seq_length:
114 |             input_ids.append(0)
115 |             input_mask.append(0)
116 |             segment_ids.append(0)
117 |
118 |         assert len(input_ids) == max_seq_length
119 |         assert len(input_mask) == max_seq_length
120 |         assert len(segment_ids) == max_seq_length
121 |
122 |         label_id = label_map[example.label]
123 |
124 |         features.append(
125 |             InputFeatures(
126 |                 input_ids=input_ids,
127 |                 input_mask=input_mask,
128 |                 segment_ids=segment_ids,
129 |                 label_id=label_id))
130 |     return features
131 |
132 |
133 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
134 |     """Truncates a sequence pair in place to the maximum length."""
135 |
136 |     # This is a simple heuristic which will always truncate the longer sequence
137 |     # one token at a time. This makes more sense than truncating an equal percent
138 |     # of tokens from each, since if one sequence is very short then each token
139 |     # that's truncated likely contains more information than a longer sequence.
140 |     while True:
141 |         total_length = len(tokens_a) + len(tokens_b)
142 |         if total_length <= max_length:
143 |             break
144 |         if len(tokens_a) > len(tokens_b):
145 |             tokens_a.pop()
146 |         else:
147 |             tokens_b.pop()
148 |
149 |
150 | def main():
151 |     parser = argparse.ArgumentParser()
152 |
153 |     ## Required parameters
154 |     parser.add_argument("--task_name",
155 |                         default=None,
156 |                         type=str,
157 |                         required=True,
158 |                         choices=["sentihood_single", "sentihood_NLI_M", "sentihood_QA_M", \
159 |                                  "sentihood_NLI_B", "sentihood_QA_B", "semeval_single", \
160 |                                  "semeval_NLI_M", "semeval_QA_M", "semeval_NLI_B", "semeval_QA_B"],
161 |                         help="The name of the task to train.")
162 |     parser.add_argument("--data_dir",
163 |                         default=None,
164 |                         type=str,
165 |                         required=True,
166 |                         help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
167 |     parser.add_argument("--vocab_file",
168 |                         default=None,
169 |                         type=str,
170 |                         required=True,
171 |                         help="The vocabulary file that the BERT model was trained on.")
172 |     parser.add_argument("--bert_config_file",
173 |                         default=None,
174 |                         type=str,
175 |                         required=True,
176 |                         help="The config json file corresponding to the pre-trained BERT model. \n"
177 |                              "This specifies the model architecture.")
178 |     parser.add_argument("--output_dir",
179 |                         default=None,
180 |                         type=str,
181 |                         required=True,
182 |                         help="The output directory where the model checkpoints will be written.")
183 |
184 |     ## Other parameters
185 |     parser.add_argument("--init_checkpoint",
186 |                         default=None,
187 |                         type=str,
188 |                         help="Initial checkpoint (usually from a pre-trained BERT model).")
189 |     parser.add_argument("--init_eval_checkpoint",
190 |                         default=None,
191 |                         type=str,
192 |                         help="Initial checkpoint (usually from a pre-trained BERT model + classifier).")
193 |     parser.add_argument("--do_save_model",
194 |                         default=False,
195 |                         action='store_true',
196 |                         help="Whether to save model.")
197 |     parser.add_argument("--eval_test",
198 |                         default=False,
199 |                         action='store_true',
200 |                         help="Whether to run eval on the test set.")
201 |     parser.add_argument("--do_lower_case",
202 |                         default=False,
203 |                         action='store_true',
204 |                         help="Whether to lower case the input text. True for uncased models, False for cased models.")
205 |     parser.add_argument("--max_seq_length",
206 |                         default=128,
207 |                         type=int,
208 |                         help="The maximum total input sequence length after WordPiece tokenization. \n"
209 |                              "Sequences longer than this will be truncated, and sequences shorter \n"
210 |                              "than this will be padded.")
211 |     parser.add_argument("--train_batch_size",
212 |                         default=32,
213 |                         type=int,
214 |                         help="Total batch size for training.")
215 |     parser.add_argument("--eval_batch_size",
216 |                         default=8,
217 |                         type=int,
218 |                         help="Total batch size for eval.")
219 |     parser.add_argument("--learning_rate",
220 |                         default=5e-5,
221 |                         type=float,
222 |                         help="The initial learning rate for Adam.")
223 |     parser.add_argument("--num_train_epochs",
224 |                         default=3.0,
225 |                         type=float,
226 |                         help="Total number of training epochs to perform.")
227 |     parser.add_argument("--warmup_proportion",
228 |                         default=0.1,
229 |                         type=float,
230 |                         help="Proportion of training to perform linear learning rate warmup for. "
231 |                              "E.g., 0.1 = 10%% of training.")
232 |     parser.add_argument("--no_cuda",
233 |                         default=False,
234 |                         action='store_true',
235 |                         help="Whether not to use CUDA when available")
236 |     parser.add_argument("--accumulate_gradients",
237 |                         type=int,
238 |                         default=1,
239 |                         help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)")
240 |     parser.add_argument("--local_rank",
241 |                         type=int,
242 |                         default=-1,
243 |                         help="local_rank for distributed training on gpus")
244 |     parser.add_argument('--seed',
245 |                         type=int,
246 |                         default=42,
247 |                         help="random seed for initialization")
248 |     parser.add_argument('--gradient_accumulation_steps',
249 |                         type=int,
250 |                         default=1,
251 |                         help="Number of update steps to accumulate before performing a backward/update pass.")
252 |     args = parser.parse_args()
253 |
254 |
255 |     if args.local_rank == -1 or args.no_cuda:
256 |         device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
257 |         n_gpu = torch.cuda.device_count()
258 |     else:
259 |         device = torch.device("cuda", args.local_rank)
260 |         n_gpu = 1
261 |         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
262 |         torch.distributed.init_process_group(backend='nccl')
263 |     logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
264 |
265 |     if args.accumulate_gradients < 1:
266 |         raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
267 |             args.accumulate_gradients))
268 |
269 |     args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients)
270 |
271 |     random.seed(args.seed)
272 |     np.random.seed(args.seed)
273 |     torch.manual_seed(args.seed)
274 |     if n_gpu > 0:
275 |         torch.cuda.manual_seed_all(args.seed)
276 |
277 |     bert_config = BertConfig.from_json_file(args.bert_config_file)
278 |
279 |     if args.max_seq_length > bert_config.max_position_embeddings:
280 |         raise ValueError(
281 |             "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format(
282 |             args.max_seq_length, bert_config.max_position_embeddings))
283 |
284 |     if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
285 |         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
286 |     os.makedirs(args.output_dir, exist_ok=True)
287 |
288 |
289 |     # prepare dataloaders
290 |     processors = {
291 |         "sentihood_single":Sentihood_single_Processor,
292 |         "sentihood_NLI_M":Sentihood_NLI_M_Processor,
293 |         "sentihood_QA_M":Sentihood_QA_M_Processor,
294 |         "sentihood_NLI_B":Sentihood_NLI_B_Processor,
295 |         "sentihood_QA_B":Sentihood_QA_B_Processor,
296 |         "semeval_single":Semeval_single_Processor,
297 |         "semeval_NLI_M":Semeval_NLI_M_Processor,
298 |         "semeval_QA_M":Semeval_QA_M_Processor,
299 |         "semeval_NLI_B":Semeval_NLI_B_Processor,
300 |         "semeval_QA_B":Semeval_QA_B_Processor,
301 |     }
302 |
303 |     processor = processors[args.task_name]()
304 |     label_list = processor.get_labels()
305 |
306 |     tokenizer = tokenization.FullTokenizer(
307 |         vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
308 |
309 |     # training set
310 |     train_examples = None
311 |     num_train_steps = None
312 |     train_examples = processor.get_train_examples(args.data_dir)
313 |     num_train_steps = int(
314 |         len(train_examples) / args.train_batch_size * args.num_train_epochs)
315 |
316 |     train_features = convert_examples_to_features(
317 |         train_examples, label_list, args.max_seq_length, tokenizer)
318 |     logger.info("***** Running training *****")
319 |     logger.info("  Num examples = %d", len(train_examples))
320 |     logger.info("  Batch size = %d", args.train_batch_size)
321 |     logger.info("  Num steps = %d", num_train_steps)
322 |
323 |     all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
324 |     all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
325 |     all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
326 |     all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
327 |
328 |     train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
329 |     if args.local_rank == -1:
330 |         train_sampler = RandomSampler(train_data)
331 |     else:
332 |         train_sampler = DistributedSampler(train_data)
333 |     train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
334 |
335 |     # test set
336 |     if args.eval_test:
337 |         test_examples = processor.get_test_examples(args.data_dir)
338 |         test_features = convert_examples_to_features(
339 |             test_examples, label_list, args.max_seq_length, tokenizer)
340 |
341 |         all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long)
342 |         all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long)
343 |         all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long)
344 |         all_label_ids = torch.tensor([f.label_id for f in test_features], dtype=torch.long)
345 |
346 |         test_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
347 |         test_dataloader = DataLoader(test_data, batch_size=args.eval_batch_size, shuffle=False)
348 |
349 |
350 |     # model and optimizer
351 |     model = BertForSequenceClassification(bert_config, len(label_list))
352 |     if args.init_eval_checkpoint is not None:
353 |         model.load_state_dict(torch.load(args.init_eval_checkpoint, map_location='cpu'))
354 |     elif args.init_checkpoint is not None:
355 |         model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
356 |     model.to(device)
357 |
358 |     if args.local_rank != -1:
359 |         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
360 |                                                           output_device=args.local_rank)
361 |     elif n_gpu > 1:
362 |         model = torch.nn.DataParallel(model)
363 |
364 |     no_decay = ['bias', 'gamma', 'beta']
365 |     optimizer_parameters = [
366 |         {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
367 |         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
368 |     ]
369 |
370 |     optimizer = BERTAdam(optimizer_parameters,
371 |                          lr=args.learning_rate,
372 |                          warmup=args.warmup_proportion,
373 |                          t_total=num_train_steps)
374 |
375 |
376 |     # train
377 |     output_log_file = os.path.join(args.output_dir, "log.txt")
378 |     print("output_log_file=",output_log_file)
379 |     with open(output_log_file, "w") as writer:
380 |         if args.eval_test:
381 |             writer.write("epoch\tglobal_step\tloss\ttest_loss\ttest_accuracy\n")
382 |         else:
383 |             writer.write("epoch\tglobal_step\tloss\n")
384 |
385 |     global_step = 0
386 |     epoch = 0
387 |     for _ in trange(int(args.num_train_epochs), desc="Epoch"):
388 |         epoch += 1
389 |         model.train()
390 |         tr_loss = 0
391 |         nb_tr_examples, nb_tr_steps = 0, 0
392 |         for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
393 |             batch = tuple(t.to(device) for t in batch)
394 |             input_ids, input_mask, segment_ids, label_ids = batch
395 |             loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
396 |             if n_gpu > 1:
397 |                 loss = loss.mean() # mean() to average on multi-gpu.
398 |             if args.gradient_accumulation_steps > 1:
399 |                 loss = loss / args.gradient_accumulation_steps
400 |             loss.backward()
401 |             tr_loss += loss.item()
402 |             nb_tr_examples += input_ids.size(0)
403 |             nb_tr_steps += 1
404 |             if (step + 1) % args.gradient_accumulation_steps == 0:
405 |                 optimizer.step()    # We have accumulated enough gradients
406 |                 model.zero_grad()
407 |                 global_step += 1
408 |
409 |         if args.do_save_model:
410 |             if n_gpu > 1:
411 |                 torch.save(model.module.state_dict(), os.path.join(args.output_dir, f'model_ep_{epoch}.bin'))
412 |             else:
413 |                 torch.save(model.state_dict(), os.path.join(args.output_dir, f'model_ep_{epoch}.bin'))
414 |
415 |         # eval_test
416 |         if args.eval_test:
417 |             model.eval()
418 |             test_loss, test_accuracy = 0, 0
419 |             nb_test_steps, nb_test_examples = 0, 0
420 |             with open(os.path.join(args.output_dir, f"test_ep_{epoch}.txt"), "w") as f_test:
421 |                 for input_ids, input_mask, segment_ids, label_ids in test_dataloader:
422 |                     input_ids = input_ids.to(device)
423 |                     input_mask = input_mask.to(device)
424 |                     segment_ids = segment_ids.to(device)
425 |                     label_ids = label_ids.to(device)
426 |
427 |                     with torch.no_grad():
428 |                         tmp_test_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
429 |
430 |                     logits = F.softmax(logits, dim=-1)
431 |                     logits = logits.detach().cpu().numpy()
432 |                     label_ids = label_ids.to('cpu').numpy()
433 |                     outputs = np.argmax(logits, axis=1)
434 |                     for output_i in range(len(outputs)):
435 |                         f_test.write(str(outputs[output_i]))
436 |                         for ou in logits[output_i]:
437 |                             f_test.write(" "+str(ou))
438 |                         f_test.write("\n")
439 |                     tmp_test_accuracy=np.sum(outputs == label_ids)
440 |
441 |                     test_loss += tmp_test_loss.mean().item()
442 |                     test_accuracy += tmp_test_accuracy
443 |
444 |                     nb_test_examples += input_ids.size(0)
445 |                     nb_test_steps += 1
446 |
447 |             test_loss = test_loss / nb_test_steps
448 |             test_accuracy = test_accuracy / nb_test_examples
449 |
450 |
451 |         result = collections.OrderedDict()
452 |         if args.eval_test:
453 |             result = {'epoch': epoch,
454 |                       'global_step': global_step,
455 |                       'loss': tr_loss/nb_tr_steps,
456 |                       'test_loss': test_loss,
457 |                       'test_accuracy': test_accuracy}
458 |         else:
459 |             result = {'epoch': epoch,
460 |                       'global_step': global_step,
461 |                       'loss': tr_loss/nb_tr_steps}
462 |
463 |         logger.info("***** Eval results *****")
464 |         with open(output_log_file, "a+") as writer:
465 |             for key in result.keys():
466 |                 logger.info("  %s = %s\n", key, str(result[key]))
467 |                 writer.write("%s\t" % (str(result[key])))
468 |             writer.write("\n")
469 |
470 | if __name__ == "__main__":
471 |     main()
472 |
--------------------------------------------------------------------------------