├── generate
│   ├── make.sh
│   ├── generate_semeval_NLI_B_QA_B.py
│   ├── data_utils_sentihood.py
│   ├── generate_sentihood_NLI_M.py
│   ├── generate_sentihood_QA_M.py
│   ├── generate_semeval_NLI_M.py
│   ├── generate_sentihood_BERT_single.py
│   ├── generate_semeval_BERT_single.py
│   ├── generate_semeval_QA_M.py
│   └── generate_sentihood_NLI_B_QA_B.py
├── LICENSE
├── convert_tf_checkpoint_to_pytorch.py
├── README.md
├── optimization.py
├── tokenization.py
├── evaluation.py
├── processor.py
├── modeling.py
└── run_classifier_TABSA.py
/generate/make.sh:
--------------------------------------------------------------------------------
1 | # generate datasets; usage: bash make.sh <sentihood|semeval>
2 |
3 | python generate_${1}_NLI_M.py
4 | python generate_${1}_QA_M.py
5 | python generate_${1}_NLI_B_QA_B.py
6 | python generate_${1}_BERT_single.py
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 HSLCY
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/generate/generate_semeval_NLI_B_QA_B.py:
--------------------------------------------------------------------------------
1 | data_dir='../data/semeval2014/bert-pair/'
2 |
3 | labels=['positive', 'neutral', 'negative', 'conflict', 'none']
4 | with open(data_dir+"test_NLI_M.csv","r",encoding="utf-8") as f, \
5 | open(data_dir+"test_NLI_B.csv","w",encoding="utf-8") as g_nli, \
6 | open(data_dir+"test_QA_B.csv","w",encoding="utf-8") as g_qa:
7 | s=f.readline().strip()
8 | while s:
9 | tmp=s.split("\t")
10 | for label in labels:
11 | t_nli = label + " - " + tmp[2]
12 | t_qa = "the polarity of the aspect " + tmp[2] + " is " + label + " ."
13 | if tmp[1]==label:
14 | g_nli.write(tmp[0]+"\t1\t"+t_nli+"\t"+tmp[3]+"\n")
15 | g_qa.write(tmp[0]+"\t1\t"+t_qa+"\t"+tmp[3]+"\n")
16 | else:
17 | g_nli.write(tmp[0]+"\t0\t"+t_nli+"\t"+tmp[3]+"\n")
18 | g_qa.write(tmp[0]+"\t0\t"+t_qa+"\t"+tmp[3]+"\n")
19 | s = f.readline().strip()
20 |
21 |
22 | with open(data_dir+"train_NLI_M.csv","r",encoding="utf-8") as f, \
23 | open(data_dir+"train_NLI_B.csv","w",encoding="utf-8") as g_nli, \
24 | open(data_dir+"train_QA_B.csv","w",encoding="utf-8") as g_qa:
25 | s=f.readline().strip()
26 | while s:
27 | tmp=s.split("\t")
28 | for label in labels:
29 | t_nli = label + " - " + tmp[2]
30 | t_qa = "the polarity of the aspect " + tmp[2] + " is " + label + " ."
31 | if tmp[1]==label:
32 | g_nli.write(tmp[0]+"\t1\t"+t_nli+"\t"+tmp[3]+"\n")
33 | g_qa.write(tmp[0]+"\t1\t"+t_qa+"\t"+tmp[3]+"\n")
34 | else:
35 | g_nli.write(tmp[0]+"\t0\t"+t_nli+"\t"+tmp[3]+"\n")
36 | g_qa.write(tmp[0]+"\t0\t"+t_qa+"\t"+tmp[3]+"\n")
37 | s = f.readline().strip()
--------------------------------------------------------------------------------
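
The script above expands every five-way NLI_M row (`id \t polarity \t aspect \t text`) into five binary rows, one per candidate label. A minimal sketch of that expansion on a single made-up row (the id, polarity and text are illustrative, not taken from the corpus):

```
labels = ['positive', 'neutral', 'negative', 'conflict', 'none']
row = "3121\tpositive\tprice\tthe prices are reasonable .".split("\t")
for label in labels:
    t_nli = label + " - " + row[2]                                         # NLI_B auxiliary sentence
    t_qa = "the polarity of the aspect " + row[2] + " is " + label + " ."  # QA_B auxiliary sentence
    gold = "1" if row[1] == label else "0"                                 # 1 only for the gold label
    print(row[0] + "\t" + gold + "\t" + t_nli + "\t" + row[3])
```
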
/generate/data_utils_sentihood.py:
--------------------------------------------------------------------------------
1 | # Reference: https://github.com/liufly/delayed-memory-update-entnet
2 |
3 | from __future__ import absolute_import
4 |
5 | import json
6 | import operator
7 | import os
8 | import re
9 | import sys
10 | import xml.etree.ElementTree
11 |
12 | import nltk
13 | import numpy as np
14 |
15 |
16 | def load_task(data_dir, aspect2idx):
17 | in_file = os.path.join(data_dir, 'sentihood-train.json')
18 | train = parse_sentihood_json(in_file)
19 | in_file = os.path.join(data_dir, 'sentihood-dev.json')
20 | dev = parse_sentihood_json(in_file)
21 | in_file = os.path.join(data_dir, 'sentihood-test.json')
22 | test = parse_sentihood_json(in_file)
23 |
24 | train = convert_input(train, aspect2idx)
25 | train_aspect_idx = get_aspect_idx(train, aspect2idx)
26 | train = tokenize(train)
27 | dev = convert_input(dev, aspect2idx)
28 | dev_aspect_idx = get_aspect_idx(dev, aspect2idx)
29 | dev = tokenize(dev)
30 | test = convert_input(test, aspect2idx)
31 | test_aspect_idx = get_aspect_idx(test, aspect2idx)
32 | test = tokenize(test)
33 |
34 | return (train, train_aspect_idx), (dev, dev_aspect_idx), (test, test_aspect_idx)
35 |
36 |
37 | def get_aspect_idx(data, aspect2idx):
38 | ret = []
39 | for _, _, _, aspect, _ in data:
40 | ret.append(aspect2idx[aspect])
41 | assert len(data) == len(ret)
42 | return np.array(ret)
43 |
44 |
45 | def parse_sentihood_json(in_file):
46 | with open(in_file) as f:
47 | data = json.load(f)
48 | ret = []
49 | for d in data:
50 | text = d['text']
51 | sent_id = d['id']
52 | opinions = []
53 | targets = set()
54 | for opinion in d['opinions']:
55 | sentiment = opinion['sentiment']
56 | aspect = opinion['aspect']
57 | target_entity = opinion['target_entity']
58 | targets.add(target_entity)
59 | opinions.append((target_entity, aspect, sentiment))
60 | ret.append((sent_id, text, opinions))
61 | return ret
62 |
63 |
64 | def convert_input(data, all_aspects):
65 | ret = []
66 | for sent_id, text, opinions in data:
67 | for target_entity, aspect, sentiment in opinions:
68 | if aspect not in all_aspects:
69 | continue
70 | ret.append((sent_id, text, target_entity, aspect, sentiment))
71 | assert 'LOCATION1' in text
72 | targets = set(['LOCATION1'])
73 | if 'LOCATION2' in text:
74 | targets.add('LOCATION2')
75 | for target in targets:
76 | aspects = set([a for t, a, _ in opinions if t == target])
77 | none_aspects = [a for a in all_aspects if a not in aspects]
78 | for aspect in none_aspects:
79 | ret.append((sent_id, text, target, aspect, 'None'))
80 | return ret
81 |
82 |
83 | def tokenize(data):
84 | ret = []
85 | for sent_id, text, target_entity, aspect, sentiment in data:
86 | new_text = nltk.word_tokenize(text)
87 | new_aspect = aspect.split('-')
88 | ret.append((sent_id, new_text, target_entity, new_aspect, sentiment))
89 | return ret
90 |
--------------------------------------------------------------------------------
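
For reference, `parse_sentihood_json` above expects each record to provide an `id`, a `text` containing `LOCATION1`/`LOCATION2` placeholders, and a list of `opinions`; `convert_input` then adds a `'None'` row for every aspect a target is not annotated with. A minimal sketch of one such record (values are illustrative, not taken from the corpus):

```
record = {
    "id": 101,
    "text": "LOCATION1 is great but getting to LOCATION2 is a pain",
    "opinions": [
        {"target_entity": "LOCATION1", "aspect": "general", "sentiment": "Positive"},
        {"target_entity": "LOCATION2", "aspect": "transit-location", "sentiment": "Negative"},
    ],
}
```
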
/convert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | # Reference: https://github.com/huggingface/pytorch-pretrained-BERT
4 |
5 | """Convert BERT checkpoint."""
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import argparse
10 | import re
11 |
12 | import numpy as np
13 | import torch
14 |
15 | import tensorflow as tf
16 | from modeling import BertConfig, BertModel
17 |
18 | parser = argparse.ArgumentParser()
19 |
20 | ## Required parameters
21 | parser.add_argument("--tf_checkpoint_path",
22 | default = None,
23 | type = str,
24 | required = True,
25 | help = "Path to the TensorFlow checkpoint.")
26 | parser.add_argument("--bert_config_file",
27 | default = None,
28 | type = str,
29 | required = True,
30 | help = "The config json file corresponding to the pre-trained BERT model. \n"
31 | "This specifies the model architecture.")
32 | parser.add_argument("--pytorch_dump_path",
33 | default = None,
34 | type = str,
35 | required = True,
36 | help = "Path to the output PyTorch model.")
37 |
38 | args = parser.parse_args()
39 |
40 | def convert():
41 | # Initialise PyTorch model
42 | config = BertConfig.from_json_file(args.bert_config_file)
43 | model = BertModel(config)
44 |
45 | # Load weights from TF model
46 | path = args.tf_checkpoint_path
47 | print("Converting TensorFlow checkpoint from {}".format(path))
48 |
49 | init_vars = tf.train.list_variables(path)
50 | names = []
51 | arrays = []
52 | for name, shape in init_vars:
53 | print("Loading {} with shape {}".format(name, shape))
54 | array = tf.train.load_variable(path, name)
55 | print("Numpy array shape {}".format(array.shape))
56 | names.append(name)
57 | arrays.append(array)
58 |
59 | for name, array in zip(names, arrays):
60 | name = name[5:] # skip "bert/"
61 | print("Loading {}".format(name))
62 | name = name.split('/')
63 | if any(n in ["adam_v", "adam_m","l_step"] for n in name):
64 | print("Skipping {}".format("/".join(name)))
65 | continue
66 | if name[0] in ['redictions', 'eq_relationship']:  # what remains of 'cls/predictions' / 'cls/seq_relationship' after the name[5:] strip above
67 | print("Skipping")
68 | continue
69 | pointer = model
70 | for m_name in name:
71 | if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
72 | l = re.split(r'_(\d+)', m_name)
73 | else:
74 | l = [m_name]
75 | if l[0] == 'kernel':
76 | pointer = getattr(pointer, 'weight')
77 | else:
78 | pointer = getattr(pointer, l[0])
79 | if len(l) >= 2:
80 | num = int(l[1])
81 | pointer = pointer[num]
82 | if m_name[-11:] == '_embeddings':
83 | pointer = getattr(pointer, 'weight')
84 | elif m_name == 'kernel':
85 | array = np.transpose(array)
86 | try:
87 | assert pointer.shape == array.shape
88 | except AssertionError as e:
89 | e.args += (pointer.shape, array.shape)
90 | raise
91 | pointer.data = torch.from_numpy(array)
92 |
93 | # Save pytorch-model
94 | torch.save(model.state_dict(), args.pytorch_dump_path)
95 |
96 | if __name__ == "__main__":
97 | convert()
98 |
--------------------------------------------------------------------------------
/generate/generate_sentihood_NLI_M.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data_utils_sentihood import *
4 |
5 | data_dir='../data/sentihood/'
6 | aspect2idx = {
7 | 'general': 0,
8 | 'price': 1,
9 | 'transit-location': 2,
10 | 'safety': 3,
11 | }
12 |
13 | (train, train_aspect_idx), (val, val_aspect_idx), (test, test_aspect_idx) = load_task(data_dir, aspect2idx)
14 |
15 | print("len(train) = ", len(train))
16 | print("len(val) = ", len(val))
17 | print("len(test) = ", len(test))
18 |
19 | train.sort(key=lambda x:x[2]+str(x[0])+x[3][0])
20 | val.sort(key=lambda x:x[2]+str(x[0])+x[3][0])
21 | test.sort(key=lambda x:x[2]+str(x[0])+x[3][0])
22 |
23 | dir_path = data_dir+'bert-pair/'
24 | if not os.path.exists(dir_path):
25 | os.makedirs(dir_path)
26 |
27 | with open(dir_path+"train_NLI_M.tsv","w",encoding="utf-8") as f:
28 | f.write("id\tsentence1\tsentence2\tlabel\n")
29 | for v in train:
30 | f.write(str(v[0])+"\t")
31 | word=v[1][0].lower()
32 | if word=='location1':f.write('location - 1')
33 | elif word=='location2':f.write('location - 2')
34 | elif word[0]=='\'':f.write("\' "+word[1:])
35 | else:f.write(word)
36 | for i in range(1,len(v[1])):
37 | word=v[1][i].lower()
38 | f.write(" ")
39 | if word == 'location1':
40 | f.write('location - 1')
41 | elif word == 'location2':
42 | f.write('location - 2')
43 | elif word[0] == '\'':
44 | f.write("\' " + word[1:])
45 | else:
46 | f.write(word)
47 | f.write("\t")
48 | if v[2]=='LOCATION1':f.write('location - 1 - ')
49 | if v[2]=='LOCATION2':f.write('location - 2 - ')
50 | if len(v[3])==1:
51 | f.write(v[3][0]+"\t")
52 | else:
53 | f.write("transit location\t")
54 | f.write(v[4]+"\n")
55 |
56 | with open(dir_path+"dev_NLI_M.tsv","w",encoding="utf-8") as f:
57 | f.write("id\tsentence1\tsentence2\tlabel\n")
58 | for v in val:
59 | f.write(str(v[0])+"\t")
60 | word=v[1][0].lower()
61 | if word=='location1':f.write('location - 1')
62 | elif word=='location2':f.write('location - 2')
63 | elif word[0]=='\'':f.write("\' "+word[1:])
64 | else:f.write(word)
65 | for i in range(1,len(v[1])):
66 | word=v[1][i].lower()
67 | f.write(" ")
68 | if word == 'location1':
69 | f.write('location - 1')
70 | elif word == 'location2':
71 | f.write('location - 2')
72 | elif word[0] == '\'':
73 | f.write("\' " + word[1:])
74 | else:
75 | f.write(word)
76 | f.write("\t")
77 | if v[2]=='LOCATION1':f.write('location - 1 - ')
78 | if v[2]=='LOCATION2':f.write('location - 2 - ')
79 | if len(v[3])==1:
80 | f.write(v[3][0]+"\t")
81 | else:
82 | f.write("transit location\t")
83 | f.write(v[4]+"\n")
84 |
85 | with open(dir_path+"test_NLI_M.tsv","w",encoding="utf-8") as f:
86 | f.write("id\tsentence1\tsentence2\tlabel\n")
87 | for v in test:
88 | f.write(str(v[0])+"\t")
89 | word=v[1][0].lower()
90 | if word=='location1':f.write('location - 1')
91 | elif word=='location2':f.write('location - 2')
92 | elif word[0]=='\'':f.write("\' "+word[1:])
93 | else:f.write(word)
94 | for i in range(1,len(v[1])):
95 | word=v[1][i].lower()
96 | f.write(" ")
97 | if word == 'location1':
98 | f.write('location - 1')
99 | elif word == 'location2':
100 | f.write('location - 2')
101 | elif word[0] == '\'':
102 | f.write("\' " + word[1:])
103 | else:
104 | f.write(word)
105 | f.write("\t")
106 | if v[2]=='LOCATION1':f.write('location - 1 - ')
107 | if v[2]=='LOCATION2':f.write('location - 2 - ')
108 | if len(v[3])==1:
109 | f.write(v[3][0]+"\t")
110 | else:
111 | f.write("transit location\t")
112 | f.write(v[4]+"\n")
113 |
--------------------------------------------------------------------------------
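
The TSV written above pairs each lower-cased, re-tokenized sentence with a pseudo-sentence naming the target and aspect. A sketch of one output row, with illustrative values:

```
# NLI_M row layout: id, sentence1, sentence2, label (tab-separated)
row = "\t".join([
    "101",
    "location - 1 is great but getting to location - 2 is a pain",
    "location - 1 - general",   # auxiliary pseudo-sentence: target followed by aspect
    "Positive",
])
```
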
/generate/generate_sentihood_QA_M.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data_utils_sentihood import *
4 |
5 | data_dir='../data/sentihood/'
6 | aspect2idx = {
7 | 'general': 0,
8 | 'price': 1,
9 | 'transit-location': 2,
10 | 'safety': 3,
11 | }
12 |
13 | (train, train_aspect_idx), (val, val_aspect_idx), (test, test_aspect_idx) = load_task(data_dir, aspect2idx)
14 |
15 | print("len(train) = ", len(train))
16 | print("len(val) = ", len(val))
17 | print("len(test) = ", len(test))
18 |
19 | train.sort(key=lambda x:x[2]+str(x[0])+x[3][0])
20 | val.sort(key=lambda x:x[2]+str(x[0])+x[3][0])
21 | test.sort(key=lambda x:x[2]+str(x[0])+x[3][0])
22 |
23 | dir_path = data_dir+'bert-pair/'
24 | if not os.path.exists(dir_path):
25 | os.makedirs(dir_path)
26 |
27 | with open(dir_path+"train_QA_M.tsv","w",encoding="utf-8") as f:
28 | f.write("id\tsentence1\tsentence2\tlabel\n")
29 | for v in train:
30 | f.write(str(v[0])+"\t")
31 | word=v[1][0].lower()
32 | if word=='location1':f.write('location - 1')
33 | elif word=='location2':f.write('location - 2')
34 | elif word[0]=='\'':f.write("\' "+word[1:])
35 | else:f.write(word)
36 | for i in range(1,len(v[1])):
37 | word=v[1][i].lower()
38 | f.write(" ")
39 | if word == 'location1':
40 | f.write('location - 1')
41 | elif word == 'location2':
42 | f.write('location - 2')
43 | elif word[0] == '\'':
44 | f.write("\' " + word[1:])
45 | else:
46 | f.write(word)
47 | f.write("\t")
48 | f.write("what do you think of the ")
49 | if len(v[3])==1:
50 | f.write(v[3][0]+" ")
51 | else:
52 | f.write("transit location ")
53 | if v[2]=='LOCATION1':f.write('of location - 1 ?\t')
54 | if v[2]=='LOCATION2':f.write('of location - 2 ?\t')
55 | f.write(v[4]+"\n")
56 |
57 | with open(dir_path+"dev_QA_M.tsv","w",encoding="utf-8") as f:
58 | f.write("id\tsentence1\tsentence2\tlabel\n")
59 | for v in val:
60 | f.write(str(v[0])+"\t")
61 | word=v[1][0].lower()
62 | if word=='location1':f.write('location - 1')
63 | elif word=='location2':f.write('location - 2')
64 | elif word[0]=='\'':f.write("\' "+word[1:])
65 | else:f.write(word)
66 | for i in range(1,len(v[1])):
67 | word=v[1][i].lower()
68 | f.write(" ")
69 | if word == 'location1':
70 | f.write('location - 1')
71 | elif word == 'location2':
72 | f.write('location - 2')
73 | elif word[0] == '\'':
74 | f.write("\' " + word[1:])
75 | else:
76 | f.write(word)
77 | f.write("\t")
78 | f.write("what do you think of the ")
79 | if len(v[3])==1:
80 | f.write(v[3][0]+" ")
81 | else:
82 | f.write("transit location ")
83 | if v[2]=='LOCATION1':f.write('of location - 1 ?\t')
84 | if v[2]=='LOCATION2':f.write('of location - 2 ?\t')
85 | f.write(v[4]+"\n")
86 |
87 | with open(dir_path+"test_QA_M.tsv","w",encoding="utf-8") as f:
88 | f.write("id\tsentence1\tsentence2\tlabel\n")
89 | for v in test:
90 | f.write(str(v[0])+"\t")
91 | word=v[1][0].lower()
92 | if word=='location1':f.write('location - 1')
93 | elif word=='location2':f.write('location - 2')
94 | elif word[0]=='\'':f.write("\' "+word[1:])
95 | else:f.write(word)
96 | for i in range(1,len(v[1])):
97 | word=v[1][i].lower()
98 | f.write(" ")
99 | if word == 'location1':
100 | f.write('location - 1')
101 | elif word == 'location2':
102 | f.write('location - 2')
103 | elif word[0] == '\'':
104 | f.write("\' " + word[1:])
105 | else:
106 | f.write(word)
107 | f.write("\t")
108 | f.write("what do you think of the ")
109 | if len(v[3])==1:
110 | f.write(v[3][0]+" ")
111 | else:
112 | f.write("transit location ")
113 | if v[2]=='LOCATION1':f.write('of location - 1 ?\t')
114 | if v[2]=='LOCATION2':f.write('of location - 2 ?\t')
115 | f.write(v[4]+"\n")
116 |
--------------------------------------------------------------------------------
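
QA_M differs from NLI_M only in the auxiliary sentence, which is phrased as a question. A sketch of one output row, with illustrative values:

```
# QA_M row layout: id, sentence1, sentence2, label (tab-separated)
row = "\t".join([
    "101",
    "location - 1 is great but getting to location - 2 is a pain",
    "what do you think of the general of location - 1 ?",   # auxiliary question
    "Positive",
])
```
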
/generate/generate_semeval_NLI_M.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | data_dir='../data/semeval2014/'
4 |
5 | dir_path = data_dir+'bert-pair/'
6 | if not os.path.exists(dir_path):
7 | os.makedirs(dir_path)
8 |
9 | with open(dir_path+"test_NLI_M.csv","w",encoding="utf-8") as g:
10 | with open(data_dir+"Restaurants_Test_Gold.xml","r",encoding="utf-8") as f:
11 | s=f.readline().strip()
12 | while s:
13 | category=[]
14 | polarity=[]
15 | if "<sentence id=" in s:
16 | left=s.find("id=")
17 | right=s.find(">")
18 | id=s[left+4:right-1]
19 | while not "</sentence>" in s:
20 | if "<text>" in s:
21 | left=s.find("<text>")
22 | right=s.find("</text>")
23 | text=s[left+6:right]
24 | if "aspectCategory" in s:
25 | left=s.find("category=")
26 | right=s.find("polarity=")
27 | category.append(s[left+10:right-2])
28 | left=s.find("polarity=")
29 | right=s.find("/>")
30 | polarity.append(s[left+10:right-2])
31 | s=f.readline().strip()
32 | if "price" in category:
33 | g.write(id+"\t"+polarity[category.index("price")]+"\t"+"price"+"\t"+text+"\n")
34 | else:
35 | g.write(id + "\t" + "none" + "\t" + "price" + "\t" + text + "\n")
36 | if "anecdotes/miscellaneous" in category:
37 | g.write(id+"\t"+polarity[category.index("anecdotes/miscellaneous")]+"\t"+"anecdotes"+"\t"+text+"\n")
38 | else:
39 | g.write(id + "\t" + "none" + "\t" + "anecdotes" + "\t" + text + "\n")
40 | if "food" in category:
41 | g.write(id+"\t"+polarity[category.index("food")]+"\t"+"food"+"\t"+text+"\n")
42 | else:
43 | g.write(id + "\t" + "none" + "\t" + "food" + "\t" + text + "\n")
44 | if "ambience" in category:
45 | g.write(id+"\t"+polarity[category.index("ambience")]+"\t"+"ambience"+"\t"+text+"\n")
46 | else:
47 | g.write(id + "\t" + "none" + "\t" + "ambience" + "\t" + text + "\n")
48 | if "service" in category:
49 | g.write(id+"\t"+polarity[category.index("service")]+"\t"+"service"+"\t"+text+"\n")
50 | else:
51 | g.write(id + "\t" + "none" + "\t" + "service" + "\t" + text + "\n")
52 | else:
53 | s = f.readline().strip()
54 |
55 |
56 | with open(dir_path+"train_NLI_M.csv","w",encoding="utf-8") as g:
57 | with open(data_dir+"Restaurants_Train.xml","r",encoding="utf-8") as f:
58 | s=f.readline().strip()
59 | while s:
60 | category=[]
61 | polarity=[]
62 | if "<sentence id=" in s:
63 | left=s.find("id=")
64 | right=s.find(">")
65 | id=s[left+4:right-1]
66 | while not "</sentence>" in s:
67 | if "<text>" in s:
68 | left=s.find("<text>")
69 | right=s.find("</text>")
70 | text=s[left+6:right]
71 | if "aspectCategory" in s:
72 | left=s.find("category=")
73 | right=s.find("polarity=")
74 | category.append(s[left+10:right-2])
75 | left=s.find("polarity=")
76 | right=s.find("/>")
77 | polarity.append(s[left+10:right-1])
78 | s=f.readline().strip()
79 | if "price" in category:
80 | g.write(id+"\t"+polarity[category.index("price")]+"\t"+"price"+"\t"+text+"\n")
81 | else:
82 | g.write(id + "\t" + "none" + "\t" + "price" + "\t" + text + "\n")
83 | if "anecdotes/miscellaneous" in category:
84 | g.write(id+"\t"+polarity[category.index("anecdotes/miscellaneous")]+"\t"+"anecdotes"+"\t"+text+"\n")
85 | else:
86 | g.write(id + "\t" + "none" + "\t" + "anecdotes" + "\t" + text + "\n")
87 | if "food" in category:
88 | g.write(id+"\t"+polarity[category.index("food")]+"\t"+"food"+"\t"+text+"\n")
89 | else:
90 | g.write(id + "\t" + "none" + "\t" + "food" + "\t" + text + "\n")
91 | if "ambience" in category:
92 | g.write(id+"\t"+polarity[category.index("ambience")]+"\t"+"ambience"+"\t"+text+"\n")
93 | else:
94 | g.write(id + "\t" + "none" + "\t" + "ambience" + "\t" + text + "\n")
95 | if "service" in category:
96 | g.write(id+"\t"+polarity[category.index("service")]+"\t"+"service"+"\t"+text+"\n")
97 | else:
98 | g.write(id + "\t" + "none" + "\t" + "service" + "\t" + text + "\n")
99 | else:
100 | s = f.readline().strip()
101 |
102 |
103 |
--------------------------------------------------------------------------------
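
Each sentence in the SemEval XML becomes five NLI_M rows, one per aspect category, with polarity `none` when the aspect is not annotated for that sentence. A sketch of the rows produced for one hypothetical sentence:

```
# test_NLI_M.csv row layout: id, polarity, aspect, text (tab-separated)
rows = [
    "3121\tpositive\tprice\tthe prices are reasonable .",
    "3121\tnone\tanecdotes\tthe prices are reasonable .",
    "3121\tnone\tfood\tthe prices are reasonable .",
    "3121\tnone\tambience\tthe prices are reasonable .",
    "3121\tnone\tservice\tthe prices are reasonable .",
]
```
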
/README.md:
--------------------------------------------------------------------------------
1 | # ABSA as a Sentence Pair Classification Task
2 |
3 | Code and corpora for the paper "Utilizing BERT for Aspect-Based Sentiment Analysis via Constructing Auxiliary Sentence" (NAACL 2019).
4 |
5 | ## Requirements
6 |
7 | * pytorch: 1.0.0
8 | * python: 3.7.1
9 | * tensorflow: 1.13.1 (only needed for converting the BERT TensorFlow checkpoint to a PyTorch model)
10 | * numpy: 1.15.4
11 | * nltk
12 | * sklearn
13 |
14 | ## Step 1: prepare datasets
15 |
16 | ### SentiHood
17 |
18 | Since the link given in the [paper that released the dataset]() no longer works, we use the [dataset mirror]() listed in [NLP-progress](https://github.com/sebastianruder/NLP-progress/blob/master/english/sentiment_analysis.md) and fix some mistakes (several sentences contained duplicate aspect annotations). See directory: `data/sentihood/`.
19 |
20 | Run the following commands to prepare the datasets for all tasks:
21 |
22 | ```
23 | cd generate/
24 | bash make.sh sentihood
25 | ```
26 |
27 | ### SemEval 2014
28 |
29 | Train Data is available in [SemEval-2014 ABSA Restaurant Reviews - Train Data](http://metashare.ilsp.gr:8080/repository/browse/semeval-2014-absa-restaurant-reviews-train-data/479d18c0625011e38685842b2b6a04d72cb57ba6c07743b9879d1a04e72185b8/) and Gold Test Data is available in [SemEval-2014 ABSA Test Data - Gold Annotations](http://metashare.ilsp.gr:8080/repository/browse/semeval-2014-absa-test-data-gold-annotations/b98d11cec18211e38229842b2b6a04d77591d40acd7542b7af823a54fb03a155/). See directory: `data/semeval2014/`.
30 |
31 | Run the following commands to prepare the datasets for all tasks:
32 |
33 | ```
34 | cd generate/
35 | bash make.sh semeval
36 | ```
37 |
38 | ## Step 2: prepare BERT-pytorch-model
39 |
40 | Download [BERT-Base (Google's pre-trained models)](https://github.com/google-research/bert) and then convert the TensorFlow checkpoint to a PyTorch model.
41 |
42 | For example:
43 |
44 | ```
45 | python convert_tf_checkpoint_to_pytorch.py \
46 | --tf_checkpoint_path uncased_L-12_H-768_A-12/bert_model.ckpt \
47 | --bert_config_file uncased_L-12_H-768_A-12/bert_config.json \
48 | --pytorch_dump_path uncased_L-12_H-768_A-12/pytorch_model.bin
49 | ```
50 |
51 | ## Step 3: train
52 |
53 | For example, **BERT-pair-NLI_M** task on **SentiHood** dataset:
54 |
55 | ```
56 | CUDA_VISIBLE_DEVICES=0,1,2,3 python run_classifier_TABSA.py \
57 | --task_name sentihood_NLI_M \
58 | --data_dir data/sentihood/bert-pair/ \
59 | --vocab_file uncased_L-12_H-768_A-12/vocab.txt \
60 | --bert_config_file uncased_L-12_H-768_A-12/bert_config.json \
61 | --init_checkpoint uncased_L-12_H-768_A-12/pytorch_model.bin \
62 | --eval_test \
63 | --do_lower_case \
64 | --max_seq_length 512 \
65 | --train_batch_size 24 \
66 | --learning_rate 2e-5 \
67 | --num_train_epochs 6.0 \
68 | --output_dir results/sentihood/NLI_M \
69 | --seed 42
70 | ```
71 |
72 | Note:
73 |
74 | * For SentiHood, `--task_name` must be one of `sentihood_NLI_M`, `sentihood_QA_M`, `sentihood_NLI_B`, `sentihood_QA_B` and `sentihood_single`. For the `sentihood_single` task, 8 different models (one per dataset generated in step 1, see directory `data/sentihood/bert-single`) should be trained separately and then evaluated together.
75 | * For SemEval-2014, `--task_name` must be one of `semeval_NLI_M`, `semeval_QA_M`, `semeval_NLI_B`, `semeval_QA_B` and `semeval_single`. For the `semeval_single` task, 5 different models (one per dataset generated in step 1, see directory `data/semeval2014/bert-single`) should be trained separately and then evaluated together.
76 |
77 | ## Step 4: evaluation
78 |
79 | Evaluate the results on the test set (calculate accuracy, F1, etc.).
80 |
81 | For example, **BERT-pair-NLI_M** task on **SentiHood** dataset:
82 |
83 | ```
84 | python evaluation.py --task_name sentihood_NLI_M --pred_data_dir results/sentihood/NLI_M/test_ep_4.txt
85 | ```
86 |
87 | Note:
88 |
89 | * As mentioned in step 3, for the `sentihood_single` task, 8 different models should be trained separately and then evaluated together. `--pred_data_dir` should be a directory that contains **8 files** named as follows: `loc1_general.txt`, `loc1_price.txt`, `loc1_safety.txt`, `loc1_transit.txt`, `loc2_general.txt`, `loc2_price.txt`, `loc2_safety.txt` and `loc2_transit.txt`.
90 | * As mentioned in step 3, for the `semeval_single` task, 5 different models should be trained separately and then evaluated together. `--pred_data_dir` should be a directory that contains **5 files** named as follows: `price.txt`, `anecdotes.txt`, `food.txt`, `ambience.txt` and `service.txt`.
91 | * For the remaining 8 tasks, `--pred_data_dir` should be a single file, as in the example above.
92 |
93 |
94 | ## Citation
95 |
96 | ```
97 | @inproceedings{sun-etal-2019-utilizing,
98 | title = "Utilizing {BERT} for Aspect-Based Sentiment Analysis via Constructing Auxiliary Sentence",
99 | author = "Sun, Chi and
100 | Huang, Luyao and
101 | Qiu, Xipeng",
102 | booktitle = "Proceedings of the 2019 Conference of the North {A}merican Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)",
103 | month = jun,
104 | year = "2019",
105 | address = "Minneapolis, Minnesota",
106 | publisher = "Association for Computational Linguistics",
107 | url = "https://www.aclweb.org/anthology/N19-1035",
108 | pages = "380--385"
109 | }
110 | ```
111 |
--------------------------------------------------------------------------------
/generate/generate_sentihood_BERT_single.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data_utils_sentihood import *
4 |
5 | data_dir='../data/sentihood/'
6 | aspect2idx = {
7 | 'general': 0,
8 | 'price': 1,
9 | 'transit-location': 2,
10 | 'safety': 3,
11 | }
12 |
13 | (train, train_aspect_idx), (val, val_aspect_idx), (test, test_aspect_idx) = load_task(data_dir, aspect2idx)
14 |
15 | print("len(train) = ", len(train))
16 | print("len(val) = ", len(val))
17 | print("len(test) = ", len(test))
18 |
19 | train.sort(key=lambda x:x[2]+str(x[0])+x[3][0])
20 | val.sort(key=lambda x:x[2]+str(x[0])+x[3][0])
21 | test.sort(key=lambda x:x[2]+str(x[0])+x[3][0])
22 |
23 | location_name = ['loc1', 'loc2']
24 | aspect_name = ['general', 'price', 'safety', 'transit']
25 | dir_path = [data_dir + 'bert-single/' + i + '_' + j + '/' for i in location_name for j in aspect_name]
26 | for path in dir_path:
27 | if not os.path.exists(path):
28 | os.makedirs(path)
29 |
30 | count=0
31 | with open(dir_path[0]+"train.tsv","w",encoding="utf-8") as f1_general, \
32 | open(dir_path[1]+"train.tsv", "w", encoding="utf-8") as f1_price, \
33 | open(dir_path[2]+"train.tsv", "w", encoding="utf-8") as f1_safety, \
34 | open(dir_path[3]+"train.tsv", "w", encoding="utf-8") as f1_transit, \
35 | open(dir_path[4]+"train.tsv", "w", encoding="utf-8") as f2_general, \
36 | open(dir_path[5]+"train.tsv", "w", encoding="utf-8") as f2_price, \
37 | open(dir_path[6]+"train.tsv", "w", encoding="utf-8") as f2_safety, \
38 | open(dir_path[7]+"train.tsv", "w",encoding="utf-8") as f2_transit, \
39 | open(data_dir + "bert-pair/train_NLI_M.tsv", "r", encoding="utf-8") as f:
40 | s = f.readline().strip()
41 | s = f.readline().strip()
42 | while s:
43 | count+=1
44 | tmp=s.split("\t")
45 | line=tmp[0]+"\t"+tmp[1]+"\t"+tmp[3]+"\n"
46 | if count<=11908: #loc1
47 | if count%4==1:
48 | f1_general.write(line)
49 | if count%4==2:
50 | f1_price.write(line)
51 | if count%4==3:
52 | f1_safety.write(line)
53 | if count%4==0:
54 | f1_transit.write(line)
55 | else: #loc2
56 | if count%4==1:
57 | f2_general.write(line)
58 | if count%4==2:
59 | f2_price.write(line)
60 | if count%4==3:
61 | f2_safety.write(line)
62 | if count%4==0:
63 | f2_transit.write(line)
64 | s = f.readline().strip()
65 |
66 | count=0
67 | with open(dir_path[0]+"dev.tsv","w",encoding="utf-8") as f1_general, \
68 | open(dir_path[1]+"dev.tsv", "w", encoding="utf-8") as f1_price, \
69 | open(dir_path[2]+"dev.tsv", "w", encoding="utf-8") as f1_safety, \
70 | open(dir_path[3]+"dev.tsv", "w", encoding="utf-8") as f1_transit, \
71 | open(dir_path[4]+"dev.tsv", "w", encoding="utf-8") as f2_general, \
72 | open(dir_path[5]+"dev.tsv", "w", encoding="utf-8") as f2_price, \
73 | open(dir_path[6]+"dev.tsv", "w", encoding="utf-8") as f2_safety, \
74 | open(dir_path[7]+"dev.tsv", "w",encoding="utf-8") as f2_transit, \
75 | open(data_dir + "bert-pair/dev_NLI_M.tsv", "r", encoding="utf-8") as f:
76 | s = f.readline().strip()
77 | s = f.readline().strip()
78 | while s:
79 | count+=1
80 | tmp=s.split("\t")
81 | line=tmp[0]+"\t"+tmp[1]+"\t"+tmp[3]+"\n"
82 | if count<=2988: #loc1
83 | if count%4==1:
84 | f1_general.write(line)
85 | if count%4==2:
86 | f1_price.write(line)
87 | if count%4==3:
88 | f1_safety.write(line)
89 | if count%4==0:
90 | f1_transit.write(line)
91 | else: #loc2
92 | if count%4==1:
93 | f2_general.write(line)
94 | if count%4==2:
95 | f2_price.write(line)
96 | if count%4==3:
97 | f2_safety.write(line)
98 | if count%4==0:
99 | f2_transit.write(line)
100 | s = f.readline().strip()
101 |
102 | count=0
103 | with open(dir_path[0]+"test.tsv","w",encoding="utf-8") as f1_general, \
104 | open(dir_path[1]+"test.tsv", "w", encoding="utf-8") as f1_price, \
105 | open(dir_path[2]+"test.tsv", "w", encoding="utf-8") as f1_safety, \
106 | open(dir_path[3]+"test.tsv", "w", encoding="utf-8") as f1_transit, \
107 | open(dir_path[4]+"test.tsv", "w", encoding="utf-8") as f2_general, \
108 | open(dir_path[5]+"test.tsv", "w", encoding="utf-8") as f2_price, \
109 | open(dir_path[6]+"test.tsv", "w", encoding="utf-8") as f2_safety, \
110 | open(dir_path[7]+"test.tsv", "w",encoding="utf-8") as f2_transit, \
111 | open(data_dir + "bert-pair/test_NLI_M.tsv", "r", encoding="utf-8") as f:
112 | s = f.readline().strip()
113 | s = f.readline().strip()
114 | while s:
115 | count+=1
116 | tmp=s.split("\t")
117 | line=tmp[0]+"\t"+tmp[1]+"\t"+tmp[3]+"\n"
118 | if count<=5964: #loc1
119 | if count%4==1:
120 | f1_general.write(line)
121 | if count%4==2:
122 | f1_price.write(line)
123 | if count%4==3:
124 | f1_safety.write(line)
125 | if count%4==0:
126 | f1_transit.write(line)
127 | else: #loc2
128 | if count%4==1:
129 | f2_general.write(line)
130 | if count%4==2:
131 | f2_price.write(line)
132 | if count%4==3:
133 | f2_safety.write(line)
134 | if count%4==0:
135 | f2_transit.write(line)
136 | s = f.readline().strip()
137 |
138 | print("Finished!")
--------------------------------------------------------------------------------
/generate/generate_semeval_BERT_single.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | data_dir='../data/semeval2014/'
4 |
5 | aspect_name = ['price', 'anecdotes', 'food', 'ambience', 'service']
6 | dir_path = [data_dir + 'bert-single/' + i + '/' for i in aspect_name]
7 | for path in dir_path:
8 | if not os.path.exists(path):
9 | os.makedirs(path)
10 |
11 | with open(dir_path[0]+"test.csv", "w", encoding="utf-8") as g_price, \
12 | open(dir_path[1]+"test.csv", "w", encoding="utf-8") as g_anecdotes,\
13 | open(dir_path[2]+"test.csv", "w", encoding="utf-8") as g_food,\
14 | open(dir_path[3]+"test.csv", "w", encoding="utf-8") as g_ambience,\
15 | open(dir_path[4]+"test.csv", "w", encoding="utf-8") as g_service,\
16 | open(data_dir+"Restaurants_Test_Gold.xml","r",encoding="utf-8") as f:
17 | s=f.readline().strip()
18 | while s:
19 | category=[]
20 | polarity=[]
21 | if "<sentence id=" in s:
22 | left=s.find("id=")
23 | right=s.find(">")
24 | id=s[left+4:right-1]
25 | while not "</sentence>" in s:
26 | if "<text>" in s:
27 | left=s.find("<text>")
28 | right=s.find("</text>")
29 | text=s[left+6:right]
30 | if "aspectCategory" in s:
31 | left=s.find("category=")
32 | right=s.find("polarity=")
33 | category.append(s[left+10:right-2])
34 | left=s.find("polarity=")
35 | right=s.find("/>")
36 | polarity.append(s[left+10:right-2])
37 | s=f.readline().strip()
38 | if "price" in category:
39 | g_price.write(id+"\t"+polarity[category.index("price")]+"\t"+"price"+"\t"+text+"\n")
40 | else:
41 | g_price.write(id + "\t" + "none" + "\t" + "price" + "\t" + text + "\n")
42 | if "anecdotes/miscellaneous" in category:
43 | g_anecdotes.write(id+"\t"+polarity[category.index("anecdotes/miscellaneous")]+"\t"+"anecdotes"+"\t"+text+"\n")
44 | else:
45 | g_anecdotes.write(id + "\t" + "none" + "\t" + "anecdotes" + "\t" + text + "\n")
46 | if "food" in category:
47 | g_food.write(id+"\t"+polarity[category.index("food")]+"\t"+"food"+"\t"+text+"\n")
48 | else:
49 | g_food.write(id + "\t" + "none" + "\t" + "food" + "\t" + text + "\n")
50 | if "ambience" in category:
51 | g_ambience.write(id+"\t"+polarity[category.index("ambience")]+"\t"+"ambience"+"\t"+text+"\n")
52 | else:
53 | g_ambience.write(id + "\t" + "none" + "\t" + "ambience" + "\t" + text + "\n")
54 | if "service" in category:
55 | g_service.write(id+"\t"+polarity[category.index("service")]+"\t"+"service"+"\t"+text+"\n")
56 | else:
57 | g_service.write(id + "\t" + "none" + "\t" + "service" + "\t" + text + "\n")
58 | else:
59 | s = f.readline().strip()
60 |
61 |
62 | with open(dir_path[0]+"train.csv", "w", encoding="utf-8") as g_price, \
63 | open(dir_path[1]+"train.csv", "w", encoding="utf-8") as g_anecdotes,\
64 | open(dir_path[2]+"train.csv", "w", encoding="utf-8") as g_food,\
65 | open(dir_path[3]+"train.csv", "w", encoding="utf-8") as g_ambience,\
66 | open(dir_path[4]+"train.csv", "w", encoding="utf-8") as g_service,\
67 | open(data_dir+"Restaurants_Train.xml","r",encoding="utf-8") as f:
68 | s=f.readline().strip()
69 | while s:
70 | category=[]
71 | polarity=[]
72 | if "<sentence id=" in s:
73 | left=s.find("id=")
74 | right=s.find(">")
75 | id=s[left+4:right-1]
76 | while not "</sentence>" in s:
77 | if "<text>" in s:
78 | left=s.find("<text>")
79 | right=s.find("</text>")
80 | text=s[left+6:right]
81 | if "aspectCategory" in s:
82 | left=s.find("category=")
83 | right=s.find("polarity=")
84 | category.append(s[left+10:right-2])
85 | left=s.find("polarity=")
86 | right=s.find("/>")
87 | polarity.append(s[left+10:right-1])
88 | s=f.readline().strip()
89 | if "price" in category:
90 | g_price.write(id+"\t"+polarity[category.index("price")]+"\t"+"price"+"\t"+text+"\n")
91 | else:
92 | g_price.write(id + "\t" + "none" + "\t" + "price" + "\t" + text + "\n")
93 | if "anecdotes/miscellaneous" in category:
94 | g_anecdotes.write(id+"\t"+polarity[category.index("anecdotes/miscellaneous")]+"\t"+"anecdotes"+"\t"+text+"\n")
95 | else:
96 | g_anecdotes.write(id + "\t" + "none" + "\t" + "anecdotes" + "\t" + text + "\n")
97 | if "food" in category:
98 | g_food.write(id+"\t"+polarity[category.index("food")]+"\t"+"food"+"\t"+text+"\n")
99 | else:
100 | g_food.write(id + "\t" + "none" + "\t" + "food" + "\t" + text + "\n")
101 | if "ambience" in category:
102 | g_ambience.write(id+"\t"+polarity[category.index("ambience")]+"\t"+"ambience"+"\t"+text+"\n")
103 | else:
104 | g_ambience.write(id + "\t" + "none" + "\t" + "ambience" + "\t" + text + "\n")
105 | if "service" in category:
106 | g_service.write(id+"\t"+polarity[category.index("service")]+"\t"+"service"+"\t"+text+"\n")
107 | else:
108 | g_service.write(id + "\t" + "none" + "\t" + "service" + "\t" + text + "\n")
109 | else:
110 | s = f.readline().strip()
111 |
112 | print("Finished!")
113 |
114 |
--------------------------------------------------------------------------------
/generate/generate_semeval_QA_M.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | data_dir='../data/semeval2014/'
4 |
5 | dir_path = data_dir+'bert-pair/'
6 | if not os.path.exists(dir_path):
7 | os.makedirs(dir_path)
8 |
9 | with open(dir_path+"test_QA_M.csv","w",encoding="utf-8") as g:
10 | with open(data_dir+"Restaurants_Test_Gold.xml","r",encoding="utf-8") as f:
11 | s=f.readline().strip()
12 | while s:
13 | category=[]
14 | polarity=[]
15 | if "<sentence id=" in s:
16 | left=s.find("id=")
17 | right=s.find(">")
18 | id=s[left+4:right-1]
19 | while not "</sentence>" in s:
20 | if "<text>" in s:
21 | left=s.find("<text>")
22 | right=s.find("</text>")
23 | text=s[left+6:right]
24 | if "aspectCategory" in s:
25 | left=s.find("category=")
26 | right=s.find("polarity=")
27 | category.append(s[left+10:right-2])
28 | left=s.find("polarity=")
29 | right=s.find("/>")
30 | polarity.append(s[left+10:right-2])
31 | s=f.readline().strip()
32 | if "price" in category:
33 | g.write(id+"\t"+polarity[category.index("price")]+"\t"+"what do you think of the price of it ?"+"\t"+text+"\n")
34 | else:
35 | g.write(id + "\t" + "none" + "\t" + "what do you think of the price of it ?" + "\t" + text + "\n")
36 | if "anecdotes/miscellaneous" in category:
37 | g.write(id+"\t"+polarity[category.index("anecdotes/miscellaneous")]+"\t"+"what do you think of the anecdotes of it ?"+"\t"+text+"\n")
38 | else:
39 | g.write(id + "\t" + "none" + "\t" + "what do you think of the anecdotes of it ?" + "\t" + text + "\n")
40 | if "food" in category:
41 | g.write(id+"\t"+polarity[category.index("food")]+"\t"+"what do you think of the food of it ?"+"\t"+text+"\n")
42 | else:
43 | g.write(id + "\t" + "none" + "\t" + "what do you think of the food of it ?" + "\t" + text + "\n")
44 | if "ambience" in category:
45 | g.write(id+"\t"+polarity[category.index("ambience")]+"\t"+"what do you think of the ambience of it ?"+"\t"+text+"\n")
46 | else:
47 | g.write(id + "\t" + "none" + "\t" + "what do you think of the ambience of it ?" + "\t" + text + "\n")
48 | if "service" in category:
49 | g.write(id+"\t"+polarity[category.index("service")]+"\t"+"what do you think of the service of it ?"+"\t"+text+"\n")
50 | else:
51 | g.write(id + "\t" + "none" + "\t" + "what do you think of the service of it ?" + "\t" + text + "\n")
52 | else:
53 | s = f.readline().strip()
54 |
55 |
56 | with open(dir_path+"train_QA_M.csv","w",encoding="utf-8") as g:
57 | with open(data_dir+"Restaurants_Train.xml","r",encoding="utf-8") as f:
58 | s=f.readline().strip()
59 | while s:
60 | category=[]
61 | polarity=[]
62 | if "<sentence id=" in s:
63 | left=s.find("id=")
64 | right=s.find(">")
65 | id=s[left+4:right-1]
66 | while not "</sentence>" in s:
67 | if "<text>" in s:
68 | left=s.find("<text>")
69 | right=s.find("</text>")
70 | text=s[left+6:right]
71 | if "aspectCategory" in s:
72 | left=s.find("category=")
73 | right=s.find("polarity=")
74 | category.append(s[left+10:right-2])
75 | left=s.find("polarity=")
76 | right=s.find("/>")
77 | polarity.append(s[left+10:right-1])
78 | s=f.readline().strip()
79 | if "price" in category:
80 | g.write(id+"\t"+polarity[category.index("price")]+"\t"+"what do you think of the price of it ?"+"\t"+text+"\n")
81 | else:
82 | g.write(id + "\t" + "none" + "\t" + "what do you think of the price of it ?" + "\t" + text + "\n")
83 | if "anecdotes/miscellaneous" in category:
84 | g.write(id+"\t"+polarity[category.index("anecdotes/miscellaneous")]+"\t"+"what do you think of the anecdotes of it ?"+"\t"+text+"\n")
85 | else:
86 | g.write(id + "\t" + "none" + "\t" + "what do you think of the anecdotes of it ?" + "\t" + text + "\n")
87 | if "food" in category:
88 | g.write(id+"\t"+polarity[category.index("food")]+"\t"+"what do you think of the food of it ?"+"\t"+text+"\n")
89 | else:
90 | g.write(id + "\t" + "none" + "\t" + "what do you think of the food of it ?" + "\t" + text + "\n")
91 | if "ambience" in category:
92 | g.write(id+"\t"+polarity[category.index("ambience")]+"\t"+"what do you think of the ambience of it ?"+"\t"+text+"\n")
93 | else:
94 | g.write(id + "\t" + "none" + "\t" + "what do you think of the ambience of it ?" + "\t" + text + "\n")
95 | if "service" in category:
96 | g.write(id+"\t"+polarity[category.index("service")]+"\t"+"what do you think of the service of it ?"+"\t"+text+"\n")
97 | else:
98 | g.write(id + "\t" + "none" + "\t" + "what do you think of the service of it ?" + "\t" + text + "\n")
99 | else:
100 | s = f.readline().strip()
--------------------------------------------------------------------------------
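
As in the SentiHood variant, QA_M replaces the bare aspect word with a question about it. A sketch of one output row, with illustrative values:

```
# QA_M row layout: id, polarity, auxiliary question, text (tab-separated)
row = "3121\tpositive\twhat do you think of the price of it ?\tthe prices are reasonable ."
```
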
/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | # Reference: https://github.com/huggingface/pytorch-pretrained-BERT
4 |
5 | """PyTorch optimization for BERT model."""
6 |
7 | import math
8 |
9 | import torch
10 | from torch.nn.utils import clip_grad_norm_
11 | from torch.optim import Optimizer
12 |
13 |
14 | def warmup_cosine(x, warmup=0.002):
15 | if x < warmup:
16 | return x/warmup
17 | return 0.5 * (1.0 + math.cos(math.pi * x))
18 |
19 | def warmup_constant(x, warmup=0.002):
20 | if x < warmup:
21 | return x/warmup
22 | return 1.0
23 |
24 | def warmup_linear(x, warmup=0.002):
25 | if x < warmup:
26 | return x/warmup
27 | return 1.0 - x
28 |
29 | SCHEDULES = {
30 | 'warmup_cosine':warmup_cosine,
31 | 'warmup_constant':warmup_constant,
32 | 'warmup_linear':warmup_linear,
33 | }
34 |
35 |
36 | class BERTAdam(Optimizer):
37 | """Implements BERT version of Adam algorithm with weight decay fix (and no ).
38 | Params:
39 | lr: learning rate
40 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
41 | t_total: total number of training steps for the learning
42 | rate schedule, -1 means constant learning rate. Default: -1
43 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
44 | b1: Adams b1. Default: 0.9
45 | b2: Adams b2. Default: 0.999
46 | e: Adams epsilon. Default: 1e-6
47 | weight_decay_rate: Weight decay. Default: 0.01
48 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
49 | """
50 | def __init__(self, params, lr, warmup=-1, t_total=-1, schedule='warmup_linear',
51 | b1=0.9, b2=0.999, e=1e-6, weight_decay_rate=0.01,
52 | max_grad_norm=1.0):
53 | if not lr >= 0.0:
54 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
55 | if schedule not in SCHEDULES:
56 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
57 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
58 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
59 | if not 0.0 <= b1 < 1.0:
60 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
61 | if not 0.0 <= b2 < 1.0:
62 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
63 | if not e >= 0.0:
64 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
65 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
66 | b1=b1, b2=b2, e=e, weight_decay_rate=weight_decay_rate,
67 | max_grad_norm=max_grad_norm)
68 | super(BERTAdam, self).__init__(params, defaults)
69 |
70 | def get_lr(self):
71 | lr = []
72 | print("l_total=",len(self.param_groups))
73 | for group in self.param_groups:
74 | print("l_p=",len(group['params']))
75 | for p in group['params']:
76 | state = self.state[p]
77 | if len(state) == 0:
78 | return [0]
79 | if group['t_total'] != -1:
80 | schedule_fct = SCHEDULES[group['schedule']]
81 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
82 | else:
83 | lr_scheduled = group['lr']
84 | lr.append(lr_scheduled)
85 | return lr
86 |
87 | def to(self, device):
88 | """ Move the optimizer state to a specified device"""
89 | for state in self.state.values():
90 | state['exp_avg'].to(device)
91 | state['exp_avg_sq'].to(device)
92 |
93 | def initialize_step(self, initial_step):
94 | """Initialize state with a defined step (but we don't have stored averaged).
95 | Arguments:
96 | initial_step (int): Initial step number.
97 | """
98 | for group in self.param_groups:
99 | for p in group['params']:
100 | state = self.state[p]
101 | # State initialization
102 | state['step'] = initial_step
103 | # Exponential moving average of gradient values
104 | state['exp_avg'] = torch.zeros_like(p.data)
105 | # Exponential moving average of squared gradient values
106 | state['exp_avg_sq'] = torch.zeros_like(p.data)
107 |
108 | def step(self, closure=None):
109 | """Performs a single optimization step.
110 |
111 | Arguments:
112 | closure (callable, optional): A closure that reevaluates the model
113 | and returns the loss.
114 | """
115 | loss = None
116 | if closure is not None:
117 | loss = closure()
118 |
119 | for group in self.param_groups:
120 | for p in group['params']:
121 | if p.grad is None:
122 | continue
123 | grad = p.grad.data
124 | if grad.is_sparse:
125 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
126 |
127 | state = self.state[p]
128 |
129 | # State initialization
130 | if len(state) == 0:
131 | state['step'] = 0
132 | # Exponential moving average of gradient values
133 | state['next_m'] = torch.zeros_like(p.data)
134 | # Exponential moving average of squared gradient values
135 | state['next_v'] = torch.zeros_like(p.data)
136 |
137 | next_m, next_v = state['next_m'], state['next_v']
138 | beta1, beta2 = group['b1'], group['b2']
139 |
140 | # Add grad clipping
141 | if group['max_grad_norm'] > 0:
142 | clip_grad_norm_(p, group['max_grad_norm'])
143 |
144 | # Decay the first and second moment running average coefficient
145 | # In-place operations to update the averages at the same time
146 | next_m.mul_(beta1).add_(1 - beta1, grad)
147 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
148 | update = next_m / (next_v.sqrt() + group['e'])
149 |
150 | # Just adding the square of the weights to the loss function is *not*
151 | # the correct way of using L2 regularization/weight decay with Adam,
152 | # since that will interact with the m and v parameters in strange ways.
153 | #
154 | # Instead we want to decay the weights in a manner that doesn't interact
155 | # with the m/v parameters. This is equivalent to adding the square
156 | # of the weights to the loss with plain (non-momentum) SGD.
157 | if group['weight_decay_rate'] > 0.0:
158 | update += group['weight_decay_rate'] * p.data
159 |
160 | if group['t_total'] != -1:
161 | schedule_fct = SCHEDULES[group['schedule']]
162 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
163 | else:
164 | lr_scheduled = group['lr']
165 |
166 | update_with_lr = lr_scheduled * update
167 | p.data.add_(-update_with_lr)
168 |
169 | state['step'] += 1
170 |
171 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
172 | # bias_correction1 = 1 - beta1 ** state['step']
173 | # bias_correction2 = 1 - beta2 ** state['step']
174 |
175 | return loss
176 |
--------------------------------------------------------------------------------
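
A minimal usage sketch for `BERTAdam` as defined above; the model and hyper-parameter values are placeholders, not the settings used by `run_classifier_TABSA.py`:

```
import torch
from optimization import BERTAdam

model = torch.nn.Linear(768, 3)                  # stand-in for the real classifier
optimizer = BERTAdam(model.parameters(),
                     lr=2e-5,
                     warmup=0.1,                 # linear warmup over the first 10% of steps
                     t_total=10000,              # total number of optimization steps
                     schedule='warmup_linear')

loss = model(torch.randn(8, 768)).sum()
loss.backward()
optimizer.step()
```
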
/generate/generate_sentihood_NLI_B_QA_B.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from data_utils_sentihood import *
4 |
5 | data_dir='../data/sentihood/'
6 | aspect2idx = {
7 | 'general': 0,
8 | 'price': 1,
9 | 'transit-location': 2,
10 | 'safety': 3,
11 | }
12 |
13 | (train, train_aspect_idx), (val, val_aspect_idx), (test, test_aspect_idx) = load_task(data_dir, aspect2idx)
14 |
15 | print("len(train) = ", len(train))
16 | print("len(val) = ", len(val))
17 | print("len(test) = ", len(test))
18 |
19 | train.sort(key=lambda x:x[2]+str(x[0])+x[3][0])
20 | val.sort(key=lambda x:x[2]+str(x[0])+x[3][0])
21 | test.sort(key=lambda x:x[2]+str(x[0])+x[3][0])
22 |
23 | dir_path = data_dir+'bert-pair/'
24 | if not os.path.exists(dir_path):
25 | os.makedirs(dir_path)
26 |
27 | sentiments=["None","Positive","Negative"]
28 | with open(dir_path+"train_NLI_B.tsv","w",encoding="utf-8") as f:
29 | f.write("id\tsentence1\tsentence2\tlabel\n")
30 | for v in train:
31 | for sentiment in sentiments:
32 | f.write(str(v[0])+"\t")
33 | word=v[1][0].lower()
34 | if word=='location1':f.write('location - 1')
35 | elif word=='location2':f.write('location - 2')
36 | elif word[0]=='\'':f.write("\' "+word[1:])
37 | else:f.write(word)
38 | for i in range(1,len(v[1])):
39 | word=v[1][i].lower()
40 | f.write(" ")
41 | if word == 'location1':
42 | f.write('location - 1')
43 | elif word == 'location2':
44 | f.write('location - 2')
45 | elif word[0] == '\'':
46 | f.write("\' " + word[1:])
47 | else:
48 | f.write(word)
49 | f.write("\t")
50 | f.write(sentiment+" - ")
51 | if v[2]=='LOCATION1':f.write('location - 1 - ')
52 | if v[2]=='LOCATION2':f.write('location - 2 - ')
53 | if len(v[3])==1:
54 | f.write(v[3][0]+"\t")
55 | else:
56 | f.write("transit location\t")
57 | if v[4]==sentiment:
58 | f.write("1\n")
59 | else:
60 | f.write("0\n")
61 |
62 | with open(dir_path+"train_QA_B.tsv","w",encoding="utf-8") as f:
63 | f.write("id\tsentence1\tsentence2\tlabel\n")
64 | for v in train:
65 | for sentiment in sentiments:
66 | f.write(str(v[0])+"\t")
67 | word=v[1][0].lower()
68 | if word=='location1':f.write('location - 1')
69 | elif word=='location2':f.write('location - 2')
70 | elif word[0]=='\'':f.write("\' "+word[1:])
71 | else:f.write(word)
72 | for i in range(1,len(v[1])):
73 | word=v[1][i].lower()
74 | f.write(" ")
75 | if word == 'location1':
76 | f.write('location - 1')
77 | elif word == 'location2':
78 | f.write('location - 2')
79 | elif word[0] == '\'':
80 | f.write("\' " + word[1:])
81 | else:
82 | f.write(word)
83 | f.write("\t")
84 | f.write("the polarity of the aspect ")
85 | if len(v[3])==1:
86 | f.write(v[3][0])
87 | else:
88 | f.write("transit location")
89 | if v[2]=='LOCATION1':f.write(' of location - 1 - is ')
90 | if v[2]=='LOCATION2':f.write(' of location - 2 - is ')
91 | f.write(sentiment+" .\t")
92 | if v[4]==sentiment:
93 | f.write("1\n")
94 | else:
95 | f.write("0\n")
96 |
97 | with open(dir_path+"dev_NLI_B.tsv","w",encoding="utf-8") as f:
98 | f.write("id\tsentence1\tsentence2\tlabel\n")
99 | for v in val:
100 | for sentiment in sentiments:
101 | f.write(str(v[0])+"\t")
102 | word=v[1][0].lower()
103 | if word=='location1':f.write('location - 1')
104 | elif word=='location2':f.write('location - 2')
105 | elif word[0]=='\'':f.write("\' "+word[1:])
106 | else:f.write(word)
107 | for i in range(1,len(v[1])):
108 | word=v[1][i].lower()
109 | f.write(" ")
110 | if word == 'location1':
111 | f.write('location - 1')
112 | elif word == 'location2':
113 | f.write('location - 2')
114 | elif word[0] == '\'':
115 | f.write("\' " + word[1:])
116 | else:
117 | f.write(word)
118 | f.write("\t")
119 | f.write(sentiment+" - ")
120 | if v[2]=='LOCATION1':f.write('location - 1 - ')
121 | if v[2]=='LOCATION2':f.write('location - 2 - ')
122 | if len(v[3])==1:
123 | f.write(v[3][0]+"\t")
124 | else:
125 | f.write("transit location\t")
126 | if v[4]==sentiment:
127 | f.write("1\n")
128 | else:
129 | f.write("0\n")
130 |
131 | with open(dir_path+"dev_QA_B.tsv","w",encoding="utf-8") as f:
132 | f.write("id\tsentence1\tsentence2\tlabel\n")
133 | for v in val:
134 | for sentiment in sentiments:
135 | f.write(str(v[0])+"\t")
136 | word=v[1][0].lower()
137 | if word=='location1':f.write('location - 1')
138 | elif word=='location2':f.write('location - 2')
139 | elif word[0]=='\'':f.write("\' "+word[1:])
140 | else:f.write(word)
141 | for i in range(1,len(v[1])):
142 | word=v[1][i].lower()
143 | f.write(" ")
144 | if word == 'location1':
145 | f.write('location - 1')
146 | elif word == 'location2':
147 | f.write('location - 2')
148 | elif word[0] == '\'':
149 | f.write("\' " + word[1:])
150 | else:
151 | f.write(word)
152 | f.write("\t")
153 | f.write("the polarity of the aspect ")
154 | if len(v[3])==1:
155 | f.write(v[3][0])
156 | else:
157 | f.write("transit location")
158 | if v[2]=='LOCATION1':f.write(' of location - 1 - is ')
159 | if v[2]=='LOCATION2':f.write(' of location - 2 - is ')
160 | f.write(sentiment+" .\t")
161 | if v[4]==sentiment:
162 | f.write("1\n")
163 | else:
164 | f.write("0\n")
165 |
166 | with open(dir_path+"test_NLI_B.tsv","w",encoding="utf-8") as f:
167 | f.write("id\tsentence1\tsentence2\tlabel\n")
168 | for v in test:
169 | for sentiment in sentiments:
170 | f.write(str(v[0])+"\t")
171 | word=v[1][0].lower()
172 | if word=='location1':f.write('location - 1')
173 | elif word=='location2':f.write('location - 2')
174 | elif word[0]=='\'':f.write("\' "+word[1:])
175 | else:f.write(word)
176 | for i in range(1,len(v[1])):
177 | word=v[1][i].lower()
178 | f.write(" ")
179 | if word == 'location1':
180 | f.write('location - 1')
181 | elif word == 'location2':
182 | f.write('location - 2')
183 | elif word[0] == '\'':
184 | f.write("\' " + word[1:])
185 | else:
186 | f.write(word)
187 | f.write("\t")
188 | f.write(sentiment + " - ")
189 | if v[2]=='LOCATION1':f.write('location - 1 - ')
190 | if v[2]=='LOCATION2':f.write('location - 2 - ')
191 | if len(v[3])==1:
192 | f.write(v[3][0]+"\t")
193 | else:
194 | f.write("transit location\t")
195 | if v[4]==sentiment:
196 | f.write("1\n")
197 | else:
198 | f.write("0\n")
199 |
200 | with open(dir_path+"test_QA_B.tsv","w",encoding="utf-8") as f:
201 | f.write("id\tsentence1\tsentence2\tlabel\n")
202 | for v in test:
203 | for sentiment in sentiments:
204 | f.write(str(v[0])+"\t")
205 | word=v[1][0].lower()
206 | if word=='location1':f.write('location - 1')
207 | elif word=='location2':f.write('location - 2')
208 | elif word[0]=='\'':f.write("\' "+word[1:])
209 | else:f.write(word)
210 | for i in range(1,len(v[1])):
211 | word=v[1][i].lower()
212 | f.write(" ")
213 | if word == 'location1':
214 | f.write('location - 1')
215 | elif word == 'location2':
216 | f.write('location - 2')
217 | elif word[0] == '\'':
218 | f.write("\' " + word[1:])
219 | else:
220 | f.write(word)
221 | f.write("\t")
222 | f.write("the polarity of the aspect ")
223 | if len(v[3])==1:
224 | f.write(v[3][0])
225 | else:
226 | f.write("transit location")
227 | if v[2]=='LOCATION1':f.write(' of location - 1 - is ')
228 | if v[2]=='LOCATION2':f.write(' of location - 2 - is ')
229 | f.write(sentiment+" .\t")
230 | if v[4]==sentiment:
231 | f.write("1\n")
232 | else:
233 | f.write("0\n")
--------------------------------------------------------------------------------
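
Here every (sentence, target, aspect) triple is expanded into three binary rows, one per candidate sentiment. A sketch of the auxiliary sentences generated for target `LOCATION1` and aspect `general` (illustrative only):

```
for sentiment in ["None", "Positive", "Negative"]:
    nli_b = sentiment + " - location - 1 - general"
    qa_b = "the polarity of the aspect general of location - 1 - is " + sentiment + " ."
    # the label column is 1 when `sentiment` matches the gold sentiment, else 0
    print(nli_b, "|", qa_b)
```
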
/tokenization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | # Reference: https://github.com/huggingface/pytorch-pretrained-BERT
4 |
5 | """Tokenization classes."""
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import collections
10 | import unicodedata
11 |
12 | import six
13 |
14 |
15 | def convert_to_unicode(text):
16 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
17 | if six.PY3:
18 | if isinstance(text, str):
19 | return text
20 | elif isinstance(text, bytes):
21 | return text.decode("utf-8", "ignore")
22 | else:
23 | raise ValueError("Unsupported string type: %s" % (type(text)))
24 | elif six.PY2:
25 | if isinstance(text, str):
26 | return text.decode("utf-8", "ignore")
27 | elif isinstance(text, unicode):
28 | return text
29 | else:
30 | raise ValueError("Unsupported string type: %s" % (type(text)))
31 | else:
32 | raise ValueError("Not running on Python2 or Python 3?")
33 |
34 |
35 | def printable_text(text):
36 | """Returns text encoded in a way suitable for print or `tf.logging`."""
37 |
38 | # These functions want `str` for both Python2 and Python3, but in one case
39 | # it's a Unicode string and in the other it's a byte string.
40 | if six.PY3:
41 | if isinstance(text, str):
42 | return text
43 | elif isinstance(text, bytes):
44 | return text.decode("utf-8", "ignore")
45 | else:
46 | raise ValueError("Unsupported string type: %s" % (type(text)))
47 | elif six.PY2:
48 | if isinstance(text, str):
49 | return text
50 | elif isinstance(text, unicode):
51 | return text.encode("utf-8")
52 | else:
53 | raise ValueError("Unsupported string type: %s" % (type(text)))
54 | else:
55 | raise ValueError("Not running on Python2 or Python 3?")
56 |
57 |
58 | def load_vocab(vocab_file):
59 | """Loads a vocabulary file into a dictionary."""
60 | vocab = collections.OrderedDict()
61 | index = 0
62 | with open(vocab_file, "r") as reader:
63 | while True:
64 | token = convert_to_unicode(reader.readline())
65 | if not token:
66 | break
67 | token = token.strip()
68 | vocab[token] = index
69 | index += 1
70 | return vocab
71 |
72 |
73 | def convert_tokens_to_ids(vocab, tokens):
74 | """Converts a sequence of tokens into ids using the vocab."""
75 | ids = []
76 | for token in tokens:
77 | ids.append(vocab[token])
78 | return ids
79 |
80 |
81 | def whitespace_tokenize(text):
82 | """Runs basic whitespace cleaning and splitting on a piece of text."""
83 | text = text.strip()
84 | if not text:
85 | return []
86 | tokens = text.split()
87 | return tokens
88 |
89 |
90 | class FullTokenizer(object):
91 | """Runs end-to-end tokenization."""
92 |
93 | def __init__(self, vocab_file, do_lower_case=True):
94 | self.vocab = load_vocab(vocab_file)
95 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
96 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
97 |
98 | def tokenize(self, text):
99 | split_tokens = []
100 | for token in self.basic_tokenizer.tokenize(text):
101 | for sub_token in self.wordpiece_tokenizer.tokenize(token):
102 | split_tokens.append(sub_token)
103 |
104 | return split_tokens
105 |
106 | def convert_tokens_to_ids(self, tokens):
107 | return convert_tokens_to_ids(self.vocab, tokens)
108 |
109 |
110 | class BasicTokenizer(object):
111 | """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
112 |
113 | def __init__(self, do_lower_case=True):
114 | """Constructs a BasicTokenizer.
115 |
116 | Args:
117 | do_lower_case: Whether to lower case the input.
118 | """
119 | self.do_lower_case = do_lower_case
120 |
121 | def tokenize(self, text):
122 | """Tokenizes a piece of text."""
123 | text = convert_to_unicode(text)
124 | text = self._clean_text(text)
125 | orig_tokens = whitespace_tokenize(text)
126 | split_tokens = []
127 | for token in orig_tokens:
128 | if self.do_lower_case:
129 | token = token.lower()
130 | token = self._run_strip_accents(token)
131 | split_tokens.extend(self._run_split_on_punc(token))
132 |
133 | output_tokens = whitespace_tokenize(" ".join(split_tokens))
134 | return output_tokens
135 |
136 | def _run_strip_accents(self, text):
137 | """Strips accents from a piece of text."""
138 | text = unicodedata.normalize("NFD", text)
139 | output = []
140 | for char in text:
141 | cat = unicodedata.category(char)
142 | if cat == "Mn":
143 | continue
144 | output.append(char)
145 | return "".join(output)
146 |
147 | def _run_split_on_punc(self, text):
148 | """Splits punctuation on a piece of text."""
149 | chars = list(text)
150 | i = 0
151 | start_new_word = True
152 | output = []
153 | while i < len(chars):
154 | char = chars[i]
155 | if _is_punctuation(char):
156 | output.append([char])
157 | start_new_word = True
158 | else:
159 | if start_new_word:
160 | output.append([])
161 | start_new_word = False
162 | output[-1].append(char)
163 | i += 1
164 |
165 | return ["".join(x) for x in output]
166 |
167 | def _clean_text(self, text):
168 | """Performs invalid character removal and whitespace cleanup on text."""
169 | output = []
170 | for char in text:
171 | cp = ord(char)
172 | if cp == 0 or cp == 0xfffd or _is_control(char):
173 | continue
174 | if _is_whitespace(char):
175 | output.append(" ")
176 | else:
177 | output.append(char)
178 | return "".join(output)
179 |
180 |
181 | class WordpieceTokenizer(object):
182 | """Runs WordPiece tokenization."""
183 |
184 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
185 | self.vocab = vocab
186 | self.unk_token = unk_token
187 | self.max_input_chars_per_word = max_input_chars_per_word
188 |
189 | def tokenize(self, text):
190 | """Tokenizes a piece of text into its word pieces.
191 |
192 | This uses a greedy longest-match-first algorithm to perform tokenization
193 | using the given vocabulary.
194 |
195 | For example:
196 | input = "unaffable"
197 | output = ["un", "##aff", "##able"]
198 |
199 | Args:
200 | text: A single token or whitespace-separated tokens. This should have
201 | already been passed through `BasicTokenizer`.
202 |
203 | Returns:
204 | A list of wordpiece tokens.
205 | """
206 |
207 | text = convert_to_unicode(text)
208 |
209 | output_tokens = []
210 | for token in whitespace_tokenize(text):
211 | chars = list(token)
212 | if len(chars) > self.max_input_chars_per_word:
213 | output_tokens.append(self.unk_token)
214 | continue
215 |
216 | is_bad = False
217 | start = 0
218 | sub_tokens = []
219 | while start < len(chars):
220 | end = len(chars)
221 | cur_substr = None
222 | while start < end:
223 | substr = "".join(chars[start:end])
224 | if start > 0:
225 | substr = "##" + substr
226 | if substr in self.vocab:
227 | cur_substr = substr
228 | break
229 | end -= 1
230 | if cur_substr is None:
231 | is_bad = True
232 | break
233 | sub_tokens.append(cur_substr)
234 | start = end
235 |
236 | if is_bad:
237 | output_tokens.append(self.unk_token)
238 | else:
239 | output_tokens.extend(sub_tokens)
240 | return output_tokens
241 |
242 |
243 | def _is_whitespace(char):
244 | """Checks whether `char` is a whitespace character."""
245 | # \t, \n, and \r are technically control characters but we treat them
246 | # as whitespace since they are generally considered as such.
247 | if char == " " or char == "\t" or char == "\n" or char == "\r":
248 | return True
249 | cat = unicodedata.category(char)
250 | if cat == "Zs":
251 | return True
252 | return False
253 |
254 |
255 | def _is_control(char):
256 | """Checks whether `char` is a control character."""
257 | # These are technically control characters but we count them as whitespace
258 | # characters.
259 | if char == "\t" or char == "\n" or char == "\r":
260 | return False
261 | cat = unicodedata.category(char)
262 | if cat.startswith("C"):
263 | return True
264 | return False
265 |
266 |
267 | def _is_punctuation(char):
268 | """Checks whether `char` is a punctuation character."""
269 | cp = ord(char)
270 | # We treat all non-letter/number ASCII as punctuation.
271 | # Characters such as "^", "$", and "`" are not in the Unicode
272 | # Punctuation class but we treat them as punctuation anyways, for
273 | # consistency.
274 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
275 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
276 | return True
277 | cat = unicodedata.category(char)
278 | if cat.startswith("P"):
279 | return True
280 | return False
281 |
--------------------------------------------------------------------------------
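For reference, a hedged usage sketch of the tokenizer defined above; the vocab path is a placeholder for wherever a pretrained BERT `vocab.txt` lives on disk.

```python
# Hedged usage sketch for tokenization.py (vocab path is a placeholder).
import tokenization

tokenizer = tokenization.FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)

tokens = tokenizer.tokenize("The staff was unaffable.")
# With the standard uncased BERT vocab this is roughly
# ['the', 'staff', 'was', 'un', '##aff', '##able', '.']
input_ids = tokenizer.convert_tokens_to_ids(tokens)
```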
/evaluation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import collections
3 |
4 | import numpy as np
5 | import pandas as pd
6 | from sklearn import metrics
7 | from sklearn.preprocessing import label_binarize
8 |
9 |
10 | def get_y_true(task_name):
11 | """
12 | Read file to obtain y_true.
13 | All five task variants of Sentihood use the test set of task BERT-pair-NLI-M to get the true labels.
14 | All five task variants of SemEval-2014 use the test set of task BERT-pair-NLI-M to get the true labels.
15 | """
16 | if task_name in ["sentihood_single", "sentihood_NLI_M", "sentihood_QA_M", "sentihood_NLI_B", "sentihood_QA_B"]:
17 | true_data_file = "data/sentihood/bert-pair/test_NLI_M.tsv"
18 |
19 | df = pd.read_csv(true_data_file,sep='\t')
20 | y_true = []
21 | for i in range(len(df)):
22 | label = df['label'][i]
23 | assert label in ['None', 'Positive', 'Negative'], "error!"
24 | if label == 'None':
25 | n = 0
26 | elif label == 'Positive':
27 | n = 1
28 | else:
29 | n = 2
30 | y_true.append(n)
31 | else:
32 | true_data_file = "data/semeval2014/bert-pair/test_NLI_M.csv"
33 |
34 | df = pd.read_csv(true_data_file,sep='\t',header=None).values
35 | y_true=[]
36 | for i in range(len(df)):
37 | label = df[i][1]
38 | assert label in ['positive', 'neutral', 'negative', 'conflict', 'none'], "error!"
39 | if label == 'positive':
40 | n = 0
41 | elif label == 'neutral':
42 | n = 1
43 | elif label == 'negative':
44 | n = 2
45 | elif label == 'conflict':
46 | n = 3
47 | elif label == 'none':
48 | n = 4
49 | y_true.append(n)
50 |
51 | return y_true
52 |
53 |
54 | def get_y_pred(task_name, pred_data_dir):
55 | """
56 | Read file to obtain y_pred and scores.
57 | """
58 | pred=[]
59 | score=[]
60 | if task_name in ["sentihood_NLI_M", "sentihood_QA_M"]:
61 | with open(pred_data_dir, "r", encoding="utf-8") as f:
62 | s=f.readline().strip().split()
63 | while s:
64 | pred.append(int(s[0]))
65 | score.append([float(s[1]),float(s[2]),float(s[3])])
66 | s = f.readline().strip().split()
67 | elif task_name in ["sentihood_NLI_B", "sentihood_QA_B"]:
68 | count = 0
69 | tmp = []
70 | with open(pred_data_dir, "r", encoding="utf-8") as f:
71 | s = f.readline().strip().split()
72 | while s:
73 | tmp.append([float(s[2])])
74 | count += 1
75 | if count % 3 == 0:
76 | tmp_sum = np.sum(tmp)
77 | t = []
78 | for i in range(3):
79 | t.append(tmp[i] / tmp_sum)
80 | score.append(t)
81 | if t[0] >= t[1] and t[0] >= t[2]:
82 | pred.append(0)
83 | elif t[1] >= t[0] and t[1] >= t[2]:
84 | pred.append(1)
85 | else:
86 | pred.append(2)
87 | tmp = []
88 | s = f.readline().strip().split()
89 | elif task_name == "sentihood_single":
90 | count = 0
91 | with open(pred_data_dir + "loc1_general.txt", "r", encoding="utf-8") as f1_general, \
92 | open(pred_data_dir + "loc1_price.txt", "r", encoding="utf-8") as f1_price, \
93 | open(pred_data_dir + "loc1_safety.txt", "r", encoding="utf-8") as f1_safety, \
94 | open(pred_data_dir + "loc1_transit.txt", "r", encoding="utf-8") as f1_transit:
95 | s = f1_general.readline().strip().split()
96 | while s:
97 | count += 1
98 | pred.append(int(s[0]))
99 | score.append([float(s[1]), float(s[2]), float(s[3])])
100 | if count % 4 == 0:
101 | s = f1_general.readline().strip().split()
102 | if count % 4 == 1:
103 | s = f1_price.readline().strip().split()
104 | if count % 4 == 2:
105 | s = f1_safety.readline().strip().split()
106 | if count % 4 == 3:
107 | s = f1_transit.readline().strip().split()
108 |
109 | with open(pred_data_dir + "loc2_general.txt", "r", encoding="utf-8") as f2_general, \
110 | open(pred_data_dir + "loc2_price.txt", "r", encoding="utf-8") as f2_price, \
111 | open(pred_data_dir + "loc2_safety.txt", "r", encoding="utf-8") as f2_safety, \
112 | open(pred_data_dir + "loc2_transit.txt", "r", encoding="utf-8") as f2_transit:
113 | s = f2_general.readline().strip().split()
114 | while s:
115 | count += 1
116 | pred.append(int(s[0]))
117 | score.append([float(s[1]), float(s[2]), float(s[3])])
118 | if count % 4 == 0:
119 | s = f2_general.readline().strip().split()
120 | if count % 4 == 1:
121 | s = f2_price.readline().strip().split()
122 | if count % 4 == 2:
123 | s = f2_safety.readline().strip().split()
124 | if count % 4 == 3:
125 | s = f2_transit.readline().strip().split()
126 | elif task_name in ["semeval_NLI_M", "semeval_QA_M"]:
127 | with open(pred_data_dir,"r",encoding="utf-8") as f:
128 | s=f.readline().strip().split()
129 | while s:
130 | pred.append(int(s[0]))
131 | score.append([float(s[1]), float(s[2]), float(s[3]), float(s[4]), float(s[5])])
132 | s = f.readline().strip().split()
133 | elif task_name in ["semeval_NLI_B", "semeval_QA_B"]:
134 | count = 0
135 | tmp = []
136 | with open(pred_data_dir, "r", encoding="utf-8") as f:
137 | s = f.readline().strip().split()
138 | while s:
139 | tmp.append([float(s[2])])
140 | count += 1
141 | if count % 5 == 0:
142 | tmp_sum = np.sum(tmp)
143 | t = []
144 | for i in range(5):
145 | t.append(tmp[i] / tmp_sum)
146 | score.append(t)
147 | if t[0] >= t[1] and t[0] >= t[2] and t[0]>=t[3] and t[0]>=t[4]:
148 | pred.append(0)
149 | elif t[1] >= t[0] and t[1] >= t[2] and t[1]>=t[3] and t[1]>=t[4]:
150 | pred.append(1)
151 | elif t[2] >= t[0] and t[2] >= t[1] and t[2]>=t[3] and t[2]>=t[4]:
152 | pred.append(2)
153 | elif t[3] >= t[0] and t[3] >= t[1] and t[3]>=t[2] and t[3]>=t[4]:
154 | pred.append(3)
155 | else:
156 | pred.append(4)
157 | tmp = []
158 | s = f.readline().strip().split()
159 | else:
160 | count = 0
161 | with open(pred_data_dir+"price.txt","r",encoding="utf-8") as f_price, \
162 | open(pred_data_dir+"anecdotes.txt", "r", encoding="utf-8") as f_anecdotes, \
163 | open(pred_data_dir+"food.txt", "r", encoding="utf-8") as f_food, \
164 | open(pred_data_dir+"ambience.txt", "r", encoding="utf-8") as f_ambience, \
165 | open(pred_data_dir+"service.txt", "r", encoding="utf-8") as f_service:
166 | s = f_price.readline().strip().split()
167 | while s:
168 | count += 1
169 | pred.append(int(s[0]))
170 | score.append([float(s[1]), float(s[2]), float(s[3]), float(s[4]), float(s[5])])
171 | if count % 5 == 0:
172 | s = f_price.readline().strip().split()
173 | if count % 5 == 1:
174 | s = f_anecdotes.readline().strip().split()
175 | if count % 5 == 2:
176 | s = f_food.readline().strip().split()
177 | if count % 5 == 3:
178 | s = f_ambience.readline().strip().split()
179 | if count % 5 == 4:
180 | s = f_service.readline().strip().split()
181 |
182 | return pred, score
183 |
184 |
185 | def sentihood_strict_acc(y_true, y_pred):
186 | """
187 | Calculate "strict Acc" of aspect detection task of Sentihood.
188 | """
189 | total_cases=int(len(y_true)/4)
190 | true_cases=0
191 | for i in range(total_cases):
192 | if y_true[i*4]!=y_pred[i*4]:continue
193 | if y_true[i*4+1]!=y_pred[i*4+1]:continue
194 | if y_true[i*4+2]!=y_pred[i*4+2]:continue
195 | if y_true[i*4+3]!=y_pred[i*4+3]:continue
196 | true_cases+=1
197 | aspect_strict_Acc = true_cases/total_cases
198 |
199 | return aspect_strict_Acc
200 |
201 |
202 | def sentihood_macro_F1(y_true, y_pred):
203 | """
204 | Calculate "Macro-F1" of aspect detection task of Sentihood.
205 | """
206 | p_all=0
207 | r_all=0
208 | count=0
209 | for i in range(len(y_pred)//4):
210 | a=set()
211 | b=set()
212 | for j in range(4):
213 | if y_pred[i*4+j]!=0:
214 | a.add(j)
215 | if y_true[i*4+j]!=0:
216 | b.add(j)
217 | if len(b)==0:continue
218 | a_b=a.intersection(b)
219 | if len(a_b)>0:
220 | p=len(a_b)/len(a)
221 | r=len(a_b)/len(b)
222 | else:
223 | p=0
224 | r=0
225 | count+=1
226 | p_all+=p
227 | r_all+=r
228 | Ma_p=p_all/count
229 | Ma_r=r_all/count
230 | aspect_Macro_F1 = 2*Ma_p*Ma_r/(Ma_p+Ma_r)
231 |
232 | return aspect_Macro_F1
233 |
234 |
235 | def sentihood_AUC_Acc(y_true, score):
236 | """
237 | Calculate "Macro-AUC" of both aspect detection and sentiment classification tasks of Sentihood.
238 | Calculate "Acc" of sentiment classification task of Sentihood.
239 | """
240 | # aspect-Macro-AUC
241 | aspect_y_true=[]
242 | aspect_y_score=[]
243 | aspect_y_trues=[[],[],[],[]]
244 | aspect_y_scores=[[],[],[],[]]
245 | for i in range(len(y_true)):
246 | if y_true[i]>0:
247 | aspect_y_true.append(0)
248 | else:
249 | aspect_y_true.append(1) # "None": 1
250 | tmp_score=score[i][0] # probability of "None"
251 | aspect_y_score.append(tmp_score)
252 | aspect_y_trues[i%4].append(aspect_y_true[-1])
253 | aspect_y_scores[i%4].append(aspect_y_score[-1])
254 |
255 | aspect_auc=[]
256 | for i in range(4):
257 | aspect_auc.append(metrics.roc_auc_score(aspect_y_trues[i], aspect_y_scores[i]))
258 | aspect_Macro_AUC = np.mean(aspect_auc)
259 |
260 | # sentiment-Macro-AUC
261 | sentiment_y_true=[]
262 | sentiment_y_pred=[]
263 | sentiment_y_score=[]
264 | sentiment_y_trues=[[],[],[],[]]
265 | sentiment_y_scores=[[],[],[],[]]
266 | for i in range(len(y_true)):
267 | if y_true[i]>0:
268 | sentiment_y_true.append(y_true[i]-1) # "Positive":0, "Negative":1
269 | tmp_score=score[i][2]/(score[i][1]+score[i][2]) # probability of "Negative"
270 | sentiment_y_score.append(tmp_score)
271 | if tmp_score>0.5:
272 | sentiment_y_pred.append(1) # "Negative": 1
273 | else:
274 | sentiment_y_pred.append(0)
275 | sentiment_y_trues[i%4].append(sentiment_y_true[-1])
276 | sentiment_y_scores[i%4].append(sentiment_y_score[-1])
277 |
278 | sentiment_auc=[]
279 | for i in range(4):
280 | sentiment_auc.append(metrics.roc_auc_score(sentiment_y_trues[i], sentiment_y_scores[i]))
281 | sentiment_Macro_AUC = np.mean(sentiment_auc)
282 |
283 | # sentiment Acc
284 | sentiment_y_true = np.array(sentiment_y_true)
285 | sentiment_y_pred = np.array(sentiment_y_pred)
286 | sentiment_Acc = metrics.accuracy_score(sentiment_y_true,sentiment_y_pred)
287 |
288 | return aspect_Macro_AUC, sentiment_Acc, sentiment_Macro_AUC
289 |
290 |
291 | def semeval_PRF(y_true, y_pred):
292 | """
293 | Calculate "Micro P R F" of aspect detection task of SemEval-2014.
294 | """
295 | s_all=0
296 | g_all=0
297 | s_g_all=0
298 | for i in range(len(y_pred)//5):
299 | s=set()
300 | g=set()
301 | for j in range(5):
302 | if y_pred[i*5+j]!=4:
303 | s.add(j)
304 | if y_true[i*5+j]!=4:
305 | g.add(j)
306 | if len(g)==0:continue
307 | s_g=s.intersection(g)
308 | s_all+=len(s)
309 | g_all+=len(g)
310 | s_g_all+=len(s_g)
311 |
312 | p=s_g_all/s_all
313 | r=s_g_all/g_all
314 | f=2*p*r/(p+r)
315 |
316 | return p,r,f
317 |
318 |
319 | def semeval_Acc(y_true, y_pred, score, classes=4):
320 | """
321 | Calculate "Acc" of sentiment classification task of SemEval-2014.
322 | """
323 | assert classes in [2, 3, 4], "classes must be 2 or 3 or 4."
324 |
325 | if classes == 4:
326 | total=0
327 | total_right=0
328 | for i in range(len(y_true)):
329 | if y_true[i]==4:continue
330 | total+=1
331 | tmp=y_pred[i]
332 | if tmp==4:
333 | if score[i][0]>=score[i][1] and score[i][0]>=score[i][2] and score[i][0]>=score[i][3]:
334 | tmp=0
335 | elif score[i][1]>=score[i][0] and score[i][1]>=score[i][2] and score[i][1]>=score[i][3]:
336 | tmp=1
337 | elif score[i][2]>=score[i][0] and score[i][2]>=score[i][1] and score[i][2]>=score[i][3]:
338 | tmp=2
339 | else:
340 | tmp=3
341 | if y_true[i]==tmp:
342 | total_right+=1
343 | sentiment_Acc = total_right/total
344 | elif classes == 3:
345 | total=0
346 | total_right=0
347 | for i in range(len(y_true)):
348 | if y_true[i]>=3:continue
349 | total+=1
350 | tmp=y_pred[i]
351 | if tmp>=3:
352 | if score[i][0]>=score[i][1] and score[i][0]>=score[i][2]:
353 | tmp=0
354 | elif score[i][1]>=score[i][0] and score[i][1]>=score[i][2]:
355 | tmp=1
356 | else:
357 | tmp=2
358 | if y_true[i]==tmp:
359 | total_right+=1
360 | sentiment_Acc = total_right/total
361 | else:
362 | total=0
363 | total_right=0
364 | for i in range(len(y_true)):
365 | if y_true[i]>=3 or y_true[i]==1:continue
366 | total+=1
367 | tmp=y_pred[i]
368 | if tmp>=3 or tmp==1:
369 | if score[i][0]>=score[i][2]:
370 | tmp=0
371 | else:
372 | tmp=2
373 | if y_true[i]==tmp:
374 | total_right+=1
375 | sentiment_Acc = total_right/total
376 |
377 | return sentiment_Acc
378 |
379 |
380 | def main():
381 | parser = argparse.ArgumentParser()
382 | parser.add_argument("--task_name",
383 | default=None,
384 | type=str,
385 | required=True,
386 | choices=["sentihood_single", "sentihood_NLI_M", "sentihood_QA_M", \
387 | "sentihood_NLI_B", "sentihood_QA_B", "semeval_single", \
388 | "semeval_NLI_M", "semeval_QA_M", "semeval_NLI_B", "semeval_QA_B"],
389 | help="The name of the task to evaluate.")
390 | parser.add_argument("--pred_data_dir",
391 | default=None,
392 | type=str,
393 | required=True,
394 | help="The prediction data path: a single file for the NLI/QA tasks, a directory prefix for the single-sentence tasks.")
395 | args = parser.parse_args()
396 |
397 |
398 | result = collections.OrderedDict()
399 | if args.task_name in ["sentihood_single", "sentihood_NLI_M", "sentihood_QA_M", "sentihood_NLI_B", "sentihood_QA_B"]:
400 | y_true = get_y_true(args.task_name)
401 | y_pred, score = get_y_pred(args.task_name, args.pred_data_dir)
402 | aspect_strict_Acc = sentihood_strict_acc(y_true, y_pred)
403 | aspect_Macro_F1 = sentihood_macro_F1(y_true, y_pred)
404 | aspect_Macro_AUC, sentiment_Acc, sentiment_Macro_AUC = sentihood_AUC_Acc(y_true, score)
405 | result = {'aspect_strict_Acc': aspect_strict_Acc,
406 | 'aspect_Macro_F1': aspect_Macro_F1,
407 | 'aspect_Macro_AUC': aspect_Macro_AUC,
408 | 'sentiment_Acc': sentiment_Acc,
409 | 'sentiment_Macro_AUC': sentiment_Macro_AUC}
410 | else:
411 | y_true = get_y_true(args.task_name)
412 | y_pred, score = get_y_pred(args.task_name, args.pred_data_dir)
413 | aspect_P, aspect_R, aspect_F = semeval_PRF(y_true, y_pred)
414 | sentiment_Acc_4_classes = semeval_Acc(y_true, y_pred, score, 4)
415 | sentiment_Acc_3_classes = semeval_Acc(y_true, y_pred, score, 3)
416 | sentiment_Acc_2_classes = semeval_Acc(y_true, y_pred, score, 2)
417 | result = {'aspect_P': aspect_P,
418 | 'aspect_R': aspect_R,
419 | 'aspect_F': aspect_F,
420 | 'sentiment_Acc_4_classes': sentiment_Acc_4_classes,
421 | 'sentiment_Acc_3_classes': sentiment_Acc_3_classes,
422 | 'sentiment_Acc_2_classes': sentiment_Acc_2_classes}
423 |
424 | for key in result.keys():
425 | print(key, "=",str(result[key]))
426 |
427 |
428 | if __name__ == "__main__":
429 | main()
430 |
--------------------------------------------------------------------------------
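The script is invoked as `python evaluation.py --task_name <task> --pred_data_dir <path>` and prints the metrics for the chosen task. As a sanity check of the strict-accuracy metric defined above, here is a toy sketch with invented values:

```python
# Toy sketch of the Sentihood "strict Acc" metric: every sentence contributes
# four consecutive entries (one per aspect), and a sentence only counts as
# correct when all four aspect labels match.
from evaluation import sentihood_strict_acc

y_true = [0, 1, 0, 2,   0, 0, 1, 0]   # two sentences x four aspects
y_pred = [0, 1, 0, 2,   0, 0, 2, 0]   # second sentence gets one aspect wrong
print(sentihood_strict_acc(y_true, y_pred))  # 0.5
```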
/processor.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | """Processors for different tasks."""
4 |
5 | import csv
6 | import os
7 |
8 | import pandas as pd
9 |
10 | import tokenization
11 |
12 |
13 | class InputExample(object):
14 | """A single training/test example for simple sequence classification."""
15 |
16 | def __init__(self, guid, text_a, text_b=None, label=None):
17 | """Constructs an InputExample.
18 |
19 | Args:
20 | guid: Unique id for the example.
21 | text_a: string. The untokenized text of the first sequence. For single
22 | sequence tasks, only this sequence must be specified.
23 | text_b: (Optional) string. The untokenized text of the second sequence.
24 | This only needs to be specified for sequence pair tasks.
25 | label: (Optional) string. The label of the example. This should be
26 | specified for train and dev examples, but not for test examples.
27 | """
28 | self.guid = guid
29 | self.text_a = text_a
30 | self.text_b = text_b
31 | self.label = label
32 |
33 |
34 | class DataProcessor(object):
35 | """Base class for data converters for sequence classification data sets."""
36 |
37 | def get_train_examples(self, data_dir):
38 | """Gets a collection of `InputExample`s for the train set."""
39 | raise NotImplementedError()
40 |
41 | def get_dev_examples(self, data_dir):
42 | """Gets a collection of `InputExample`s for the dev set."""
43 | raise NotImplementedError()
44 |
45 | def get_test_examples(self, data_dir):
46 | """Gets a collection of `InputExample`s for the test set."""
47 | raise NotImplementedError()
48 |
49 | def get_labels(self):
50 | """Gets the list of labels for this data set."""
51 | raise NotImplementedError()
52 |
53 | @classmethod
54 | def _read_tsv(cls, input_file, quotechar=None):
55 | """Reads a tab separated value file."""
56 | with open(input_file, "r") as f:
57 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
58 | lines = []
59 | for line in reader:
60 | lines.append(line)
61 | return lines
62 |
63 |
64 | class Sentihood_single_Processor(DataProcessor):
65 | """Processor for the Sentihood data set."""
66 |
67 | def get_train_examples(self, data_dir):
68 | """See base class."""
69 | train_data = pd.read_csv(os.path.join(data_dir, "train.tsv"),header=None,sep="\t").values
70 | return self._create_examples(train_data, "train")
71 |
72 | def get_dev_examples(self, data_dir):
73 | """See base class."""
74 | dev_data = pd.read_csv(os.path.join(data_dir, "dev.tsv"),header=None,sep="\t").values
75 | return self._create_examples(dev_data, "dev")
76 |
77 | def get_test_examples(self, data_dir):
78 | """See base class."""
79 | test_data = pd.read_csv(os.path.join(data_dir, "test.tsv"),header=None,sep="\t").values
80 | return self._create_examples(test_data, "test")
81 |
82 | def get_labels(self):
83 | """See base class."""
84 | return ['None', 'Positive', 'Negative']
85 |
86 | def _create_examples(self, lines, set_type):
87 | """Creates examples for the training and dev sets."""
88 | examples = []
89 | for (i, line) in enumerate(lines):
90 | # if i>50:break
91 | guid = "%s-%s" % (set_type, i)
92 | text_a = tokenization.convert_to_unicode(str(line[1]))
93 | label = tokenization.convert_to_unicode(str(line[2]))
94 | if i%1000==0:
95 | print(i)
96 | print("guid=",guid)
97 | print("text_a=",text_a)
98 | print("label=",label)
99 | examples.append(
100 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
101 | return examples
102 |
103 |
104 | class Sentihood_NLI_M_Processor(DataProcessor):
105 | """Processor for the Sentihood data set."""
106 |
107 | def get_train_examples(self, data_dir):
108 | """See base class."""
109 | train_data = pd.read_csv(os.path.join(data_dir, "train_NLI_M.tsv"),sep="\t").values
110 | return self._create_examples(train_data, "train")
111 |
112 | def get_dev_examples(self, data_dir):
113 | """See base class."""
114 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_NLI_M.tsv"),sep="\t").values
115 | return self._create_examples(dev_data, "dev")
116 |
117 | def get_test_examples(self, data_dir):
118 | """See base class."""
119 | test_data = pd.read_csv(os.path.join(data_dir, "test_NLI_M.tsv"),sep="\t").values
120 | return self._create_examples(test_data, "test")
121 |
122 | def get_labels(self):
123 | """See base class."""
124 | return ['None', 'Positive', 'Negative']
125 |
126 | def _create_examples(self, lines, set_type):
127 | """Creates examples for the training and dev sets."""
128 | examples = []
129 | for (i, line) in enumerate(lines):
130 | # if i>50:break
131 | guid = "%s-%s" % (set_type, i)
132 | text_a = tokenization.convert_to_unicode(str(line[1]))
133 | text_b = tokenization.convert_to_unicode(str(line[2]))
134 | label = tokenization.convert_to_unicode(str(line[3]))
135 | if i%1000==0:
136 | print(i)
137 | print("guid=",guid)
138 | print("text_a=",text_a)
139 | print("text_b=",text_b)
140 | print("label=",label)
141 | examples.append(
142 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
143 | return examples
144 |
145 |
146 | class Sentihood_QA_M_Processor(DataProcessor):
147 | """Processor for the Sentihood data set."""
148 |
149 | def get_train_examples(self, data_dir):
150 | """See base class."""
151 | train_data = pd.read_csv(os.path.join(data_dir, "train_QA_M.tsv"),sep="\t").values
152 | return self._create_examples(train_data, "train")
153 |
154 | def get_dev_examples(self, data_dir):
155 | """See base class."""
156 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_QA_M.tsv"),sep="\t").values
157 | return self._create_examples(dev_data, "dev")
158 |
159 | def get_test_examples(self, data_dir):
160 | """See base class."""
161 | test_data = pd.read_csv(os.path.join(data_dir, "test_QA_M.tsv"),sep="\t").values
162 | return self._create_examples(test_data, "test")
163 |
164 | def get_labels(self):
165 | """See base class."""
166 | return ['None', 'Positive', 'Negative']
167 |
168 | def _create_examples(self, lines, set_type):
169 | """Creates examples for the training and dev sets."""
170 | examples = []
171 | for (i, line) in enumerate(lines):
172 | # if i>50:break
173 | guid = "%s-%s" % (set_type, i)
174 | text_a = tokenization.convert_to_unicode(str(line[1]))
175 | text_b = tokenization.convert_to_unicode(str(line[2]))
176 | label = tokenization.convert_to_unicode(str(line[3]))
177 | if i%1000==0:
178 | print(i)
179 | print("guid=",guid)
180 | print("text_a=",text_a)
181 | print("text_b=",text_b)
182 | print("label=",label)
183 | examples.append(
184 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
185 | return examples
186 |
187 |
188 | class Sentihood_NLI_B_Processor(DataProcessor):
189 | """Processor for the Sentihood data set."""
190 |
191 | def get_train_examples(self, data_dir):
192 | """See base class."""
193 | train_data = pd.read_csv(os.path.join(data_dir, "train_NLI_B.tsv"),sep="\t").values
194 | return self._create_examples(train_data, "train")
195 |
196 | def get_dev_examples(self, data_dir):
197 | """See base class."""
198 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_NLI_B.tsv"),sep="\t").values
199 | return self._create_examples(dev_data, "dev")
200 |
201 | def get_test_examples(self, data_dir):
202 | """See base class."""
203 | test_data = pd.read_csv(os.path.join(data_dir, "test_NLI_B.tsv"),sep="\t").values
204 | return self._create_examples(test_data, "test")
205 |
206 | def get_labels(self):
207 | """See base class."""
208 | return ['0', '1']
209 |
210 | def _create_examples(self, lines, set_type):
211 | """Creates examples for the training and dev sets."""
212 | examples = []
213 | for (i, line) in enumerate(lines):
214 | # if i>50:break
215 | guid = "%s-%s" % (set_type, i)
216 | text_a = tokenization.convert_to_unicode(str(line[2]))
217 | text_b = tokenization.convert_to_unicode(str(line[1]))
218 | label = tokenization.convert_to_unicode(str(line[3]))
219 | if i%1000==0:
220 | print(i)
221 | print("guid=",guid)
222 | print("text_a=",text_a)
223 | print("text_b=",text_b)
224 | print("label=",label)
225 | examples.append(
226 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
227 | return examples
228 |
229 |
230 | class Sentihood_QA_B_Processor(DataProcessor):
231 | """Processor for the Sentihood data set."""
232 |
233 | def get_train_examples(self, data_dir):
234 | """See base class."""
235 | train_data = pd.read_csv(os.path.join(data_dir, "train_QA_B.tsv"),sep="\t").values
236 | return self._create_examples(train_data, "train")
237 |
238 | def get_dev_examples(self, data_dir):
239 | """See base class."""
240 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_QA_B.tsv"),sep="\t").values
241 | return self._create_examples(dev_data, "dev")
242 |
243 | def get_test_examples(self, data_dir):
244 | """See base class."""
245 | test_data = pd.read_csv(os.path.join(data_dir, "test_QA_B.tsv"),sep="\t").values
246 | return self._create_examples(test_data, "test")
247 |
248 | def get_labels(self):
249 | """See base class."""
250 | return ['0', '1']
251 |
252 | def _create_examples(self, lines, set_type):
253 | """Creates examples for the training and dev sets."""
254 | examples = []
255 | for (i, line) in enumerate(lines):
256 | # if i>50:break
257 | guid = "%s-%s" % (set_type, i)
258 | text_a = tokenization.convert_to_unicode(str(line[2]))
259 | text_b = tokenization.convert_to_unicode(str(line[1]))
260 | label = tokenization.convert_to_unicode(str(line[3]))
261 | if i%1000==0:
262 | print(i)
263 | print("guid=",guid)
264 | print("text_a=",text_a)
265 | print("text_b=",text_b)
266 | print("label=",label)
267 | examples.append(
268 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
269 | return examples
270 |
271 |
272 | class Semeval_single_Processor(DataProcessor):
273 | """Processor for the Semeval 2014 data set."""
274 |
275 | def get_train_examples(self, data_dir):
276 | """See base class."""
277 | train_data = pd.read_csv(os.path.join(data_dir, "train.csv"),header=None,sep="\t").values
278 | return self._create_examples(train_data, "train")
279 |
280 | def get_dev_examples(self, data_dir):
281 | """See base class."""
282 | dev_data = pd.read_csv(os.path.join(data_dir, "dev.csv"),header=None,sep="\t").values
283 | return self._create_examples(dev_data, "dev")
284 |
285 | def get_test_examples(self, data_dir):
286 | """See base class."""
287 | test_data = pd.read_csv(os.path.join(data_dir, "test.csv"),header=None,sep="\t").values
288 | return self._create_examples(test_data, "test")
289 |
290 | def get_labels(self):
291 | """See base class."""
292 | return ['positive', 'neutral', 'negative', 'conflict', 'none']
293 |
294 | def _create_examples(self, lines, set_type):
295 | """Creates examples for the training and dev sets."""
296 | examples = []
297 | for (i, line) in enumerate(lines):
298 | # if i>50:break
299 | guid = "%s-%s" % (set_type, i)
300 | text_a = tokenization.convert_to_unicode(str(line[3]))
301 | label = tokenization.convert_to_unicode(str(line[1]))
302 | if i%1000==0:
303 | print(i)
304 | print("guid=",guid)
305 | print("text_a=",text_a)
306 | print("label=",label)
307 | examples.append(
308 | InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
309 | return examples
310 |
311 |
312 | class Semeval_NLI_M_Processor(DataProcessor):
313 | """Processor for the Semeval 2014 data set."""
314 |
315 | def get_train_examples(self, data_dir):
316 | """See base class."""
317 | train_data = pd.read_csv(os.path.join(data_dir, "train_NLI_M.csv"),header=None,sep="\t").values
318 | return self._create_examples(train_data, "train")
319 |
320 | def get_dev_examples(self, data_dir):
321 | """See base class."""
322 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_NLI_M.csv"),header=None,sep="\t").values
323 | return self._create_examples(dev_data, "dev")
324 |
325 | def get_test_examples(self, data_dir):
326 | """See base class."""
327 | test_data = pd.read_csv(os.path.join(data_dir, "test_NLI_M.csv"),header=None,sep="\t").values
328 | return self._create_examples(test_data, "test")
329 |
330 | def get_labels(self):
331 | """See base class."""
332 | return ['positive', 'neutral', 'negative', 'conflict', 'none']
333 |
334 | def _create_examples(self, lines, set_type):
335 | """Creates examples for the training and dev sets."""
336 | examples = []
337 | for (i, line) in enumerate(lines):
338 | # if i>50:break
339 | guid = "%s-%s" % (set_type, i)
340 | text_a = tokenization.convert_to_unicode(str(line[3]))
341 | text_b = tokenization.convert_to_unicode(str(line[2]))
342 | label = tokenization.convert_to_unicode(str(line[1]))
343 | if i%1000==0:
344 | print(i)
345 | print("guid=",guid)
346 | print("text_a=",text_a)
347 | print("label=",label)
348 | examples.append(
349 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
350 | return examples
351 |
352 |
353 | class Semeval_QA_M_Processor(DataProcessor):
354 | """Processor for the Semeval 2014 data set."""
355 |
356 | def get_train_examples(self, data_dir):
357 | """See base class."""
358 | train_data = pd.read_csv(os.path.join(data_dir, "train_QA_M.csv"),header=None,sep="\t").values
359 | return self._create_examples(train_data, "train")
360 |
361 | def get_dev_examples(self, data_dir):
362 | """See base class."""
363 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_QA_M.csv"),header=None,sep="\t").values
364 | return self._create_examples(dev_data, "dev")
365 |
366 | def get_test_examples(self, data_dir):
367 | """See base class."""
368 | test_data = pd.read_csv(os.path.join(data_dir, "test_QA_M.csv"),header=None,sep="\t").values
369 | return self._create_examples(test_data, "test")
370 |
371 | def get_labels(self):
372 | """See base class."""
373 | return ['positive', 'neutral', 'negative', 'conflict', 'none']
374 |
375 | def _create_examples(self, lines, set_type):
376 | """Creates examples for the training and dev sets."""
377 | examples = []
378 | for (i, line) in enumerate(lines):
379 | # if i>50:break
380 | guid = "%s-%s" % (set_type, i)
381 | text_a = tokenization.convert_to_unicode(str(line[3]))
382 | text_b = tokenization.convert_to_unicode(str(line[2]))
383 | label = tokenization.convert_to_unicode(str(line[1]))
384 | if i%1000==0:
385 | print(i)
386 | print("guid=",guid)
387 | print("text_a=",text_a)
388 | print("label=",label)
389 | examples.append(
390 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
391 | return examples
392 |
393 |
394 | class Semeval_NLI_B_Processor(DataProcessor):
395 | """Processor for the Semeval 2014 data set."""
396 |
397 | def get_train_examples(self, data_dir):
398 | """See base class."""
399 | train_data = pd.read_csv(os.path.join(data_dir, "train_NLI_B.csv"),header=None,sep="\t").values
400 | return self._create_examples(train_data, "train")
401 |
402 | def get_dev_examples(self, data_dir):
403 | """See base class."""
404 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_NLI_B.csv"),header=None,sep="\t").values
405 | return self._create_examples(dev_data, "dev")
406 |
407 | def get_test_examples(self, data_dir):
408 | """See base class."""
409 | test_data = pd.read_csv(os.path.join(data_dir, "test_NLI_B.csv"),header=None,sep="\t").values
410 | return self._create_examples(test_data, "test")
411 |
412 | def get_labels(self):
413 | """See base class."""
414 | return ['0', '1']
415 |
416 | def _create_examples(self, lines, set_type):
417 | """Creates examples for the training and dev sets."""
418 | examples = []
419 | for (i, line) in enumerate(lines):
420 | # if i>50:break
421 | guid = "%s-%s" % (set_type, i)
422 | text_a = tokenization.convert_to_unicode(str(line[2]))
423 | text_b = tokenization.convert_to_unicode(str(line[3]))
424 | label = tokenization.convert_to_unicode(str(line[1]))
425 | if i%1000==0:
426 | print(i)
427 | print("guid=",guid)
428 | print("text_a=",text_a)
429 | print("label=",label)
430 | examples.append(
431 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
432 | return examples
433 |
434 |
435 | class Semeval_QA_B_Processor(DataProcessor):
436 | """Processor for the Semeval 2014 data set."""
437 |
438 | def get_train_examples(self, data_dir):
439 | """See base class."""
440 | train_data = pd.read_csv(os.path.join(data_dir, "train_QA_B.csv"),header=None,sep="\t").values
441 | return self._create_examples(train_data, "train")
442 |
443 | def get_dev_examples(self, data_dir):
444 | """See base class."""
445 | dev_data = pd.read_csv(os.path.join(data_dir, "dev_QA_B.csv"),header=None,sep="\t").values
446 | return self._create_examples(dev_data, "dev")
447 |
448 | def get_test_examples(self, data_dir):
449 | """See base class."""
450 | test_data = pd.read_csv(os.path.join(data_dir, "test_QA_B.csv"),header=None,sep="\t").values
451 | return self._create_examples(test_data, "test")
452 |
453 | def get_labels(self):
454 | """See base class."""
455 | return ['0', '1']
456 |
457 | def _create_examples(self, lines, set_type):
458 | """Creates examples for the training and dev sets."""
459 | examples = []
460 | for (i, line) in enumerate(lines):
461 | # if i>50:break
462 | guid = "%s-%s" % (set_type, i)
463 | text_a = tokenization.convert_to_unicode(str(line[2]))
464 | text_b = tokenization.convert_to_unicode(str(line[3]))
465 | label = tokenization.convert_to_unicode(str(line[1]))
466 | if i%1000==0:
467 | print(i)
468 | print("guid=",guid)
469 | print("text_a=",text_a)
470 | print("label=",label)
471 | examples.append(
472 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
473 | return examples
474 |
--------------------------------------------------------------------------------
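A hedged usage sketch of the processors above; the data directory mirrors the path used in evaluation.py and may differ in an actual setup.

```python
# Hedged usage sketch for processor.py (data directory is an assumption).
from processor import Sentihood_NLI_M_Processor

processor = Sentihood_NLI_M_Processor()
label_list = processor.get_labels()   # ['None', 'Positive', 'Negative']

examples = processor.get_test_examples("data/sentihood/bert-pair/")
first = examples[0]
# first.text_a is the review sentence, first.text_b the auxiliary sentence
# (e.g. "location - 1 - safety"), and first.label is one of label_list.
```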
/modeling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | # Reference: https://github.com/huggingface/pytorch-pretrained-BERT
4 |
5 | """PyTorch BERT model."""
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import copy
10 | import json
11 | import math
12 |
13 | import six
14 | import torch
15 | import torch.nn as nn
16 | from torch.nn import CrossEntropyLoss
17 |
18 |
19 | def gelu(x):
20 | """Implementation of the gelu activation function.
21 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
22 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
23 | """
24 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
25 |
26 |
27 | class BertConfig(object):
28 | """Configuration class to store the configuration of a `BertModel`.
29 | """
30 | def __init__(self,
31 | vocab_size,
32 | hidden_size=768,
33 | num_hidden_layers=12,
34 | num_attention_heads=12,
35 | intermediate_size=3072,
36 | hidden_act="gelu",
37 | hidden_dropout_prob=0.1,
38 | attention_probs_dropout_prob=0.1,
39 | max_position_embeddings=512,
40 | type_vocab_size=16,
41 | initializer_range=0.02):
42 | """Constructs BertConfig.
43 |
44 | Args:
45 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
46 | hidden_size: Size of the encoder layers and the pooler layer.
47 | num_hidden_layers: Number of hidden layers in the Transformer encoder.
48 | num_attention_heads: Number of attention heads for each attention layer in
49 | the Transformer encoder.
50 | intermediate_size: The size of the "intermediate" (i.e., feed-forward)
51 | layer in the Transformer encoder.
52 | hidden_act: The non-linear activation function (function or string) in the
53 | encoder and pooler.
54 | hidden_dropout_prob: The dropout probability for all fully connected
55 | layers in the embeddings, encoder, and pooler.
56 | attention_probs_dropout_prob: The dropout ratio for the attention
57 | probabilities.
58 | max_position_embeddings: The maximum sequence length that this model might
59 | ever be used with. Typically set this to something large just in case
60 | (e.g., 512 or 1024 or 2048).
61 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into
62 | `BertModel`.
63 | initializer_range: The stddev of the truncated_normal_initializer for
64 | initializing all weight matrices.
65 | """
66 | self.vocab_size = vocab_size
67 | self.hidden_size = hidden_size
68 | self.num_hidden_layers = num_hidden_layers
69 | self.num_attention_heads = num_attention_heads
70 | self.hidden_act = hidden_act
71 | self.intermediate_size = intermediate_size
72 | self.hidden_dropout_prob = hidden_dropout_prob
73 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
74 | self.max_position_embeddings = max_position_embeddings
75 | self.type_vocab_size = type_vocab_size
76 | self.initializer_range = initializer_range
77 |
78 | @classmethod
79 | def from_dict(cls, json_object):
80 | """Constructs a `BertConfig` from a Python dictionary of parameters."""
81 | config = BertConfig(vocab_size=None)
82 | for (key, value) in six.iteritems(json_object):
83 | config.__dict__[key] = value
84 | return config
85 |
86 | @classmethod
87 | def from_json_file(cls, json_file):
88 | """Constructs a `BertConfig` from a json file of parameters."""
89 | with open(json_file, "r") as reader:
90 | text = reader.read()
91 | return cls.from_dict(json.loads(text))
92 |
93 | def to_dict(self):
94 | """Serializes this instance to a Python dictionary."""
95 | output = copy.deepcopy(self.__dict__)
96 | return output
97 |
98 | def to_json_string(self):
99 | """Serializes this instance to a JSON string."""
100 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
101 |
102 |
103 | class BERTLayerNorm(nn.Module):
104 | def __init__(self, config, variance_epsilon=1e-12):
105 | """Construct a layernorm module in the TF style (epsilon inside the square root).
106 | """
107 | super(BERTLayerNorm, self).__init__()
108 | self.gamma = nn.Parameter(torch.ones(config.hidden_size))
109 | self.beta = nn.Parameter(torch.zeros(config.hidden_size))
110 | self.variance_epsilon = variance_epsilon
111 |
112 | def forward(self, x):
113 | u = x.mean(-1, keepdim=True)
114 | s = (x - u).pow(2).mean(-1, keepdim=True)
115 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
116 | return self.gamma * x + self.beta
117 |
118 | class BERTEmbeddings(nn.Module):
119 | def __init__(self, config):
120 | super(BERTEmbeddings, self).__init__()
121 | """Construct the embedding module from word, position and token_type embeddings.
122 | """
123 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
124 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
125 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
126 |
127 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
128 | # any TensorFlow checkpoint file
129 | self.LayerNorm = BERTLayerNorm(config)
130 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
131 |
132 | def forward(self, input_ids, token_type_ids=None):
133 | seq_length = input_ids.size(1)
134 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
135 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
136 | if token_type_ids is None:
137 | token_type_ids = torch.zeros_like(input_ids)
138 |
139 | words_embeddings = self.word_embeddings(input_ids)
140 | position_embeddings = self.position_embeddings(position_ids)
141 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
142 |
143 | embeddings = words_embeddings + position_embeddings + token_type_embeddings
144 | embeddings = self.LayerNorm(embeddings)
145 | embeddings = self.dropout(embeddings)
146 | return embeddings
147 |
148 |
149 | class BERTSelfAttention(nn.Module):
150 | def __init__(self, config):
151 | super(BERTSelfAttention, self).__init__()
152 | if config.hidden_size % config.num_attention_heads != 0:
153 | raise ValueError(
154 | "The hidden size (%d) is not a multiple of the number of attention "
155 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
156 | self.num_attention_heads = config.num_attention_heads
157 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
158 | self.all_head_size = self.num_attention_heads * self.attention_head_size
159 |
160 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
161 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
162 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
163 |
164 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
165 |
166 | def transpose_for_scores(self, x):
167 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
168 | x = x.view(*new_x_shape)
169 | return x.permute(0, 2, 1, 3)
170 |
171 | def forward(self, hidden_states, attention_mask):
172 | mixed_query_layer = self.query(hidden_states)
173 | mixed_key_layer = self.key(hidden_states)
174 | mixed_value_layer = self.value(hidden_states)
175 |
176 | query_layer = self.transpose_for_scores(mixed_query_layer)
177 | key_layer = self.transpose_for_scores(mixed_key_layer)
178 | value_layer = self.transpose_for_scores(mixed_value_layer)
179 |
180 | # Take the dot product between "query" and "key" to get the raw attention scores.
181 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
182 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
183 | # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
184 | attention_scores = attention_scores + attention_mask
185 |
186 | # Normalize the attention scores to probabilities.
187 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
188 |
189 | # This is actually dropping out entire tokens to attend to, which might
190 | # seem a bit unusual, but is taken from the original Transformer paper.
191 | attention_probs = self.dropout(attention_probs)
192 |
193 | context_layer = torch.matmul(attention_probs, value_layer)
194 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
195 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
196 | context_layer = context_layer.view(*new_context_layer_shape)
197 | return context_layer
198 |
199 |
200 | class BERTSelfOutput(nn.Module):
201 | def __init__(self, config):
202 | super(BERTSelfOutput, self).__init__()
203 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
204 | self.LayerNorm = BERTLayerNorm(config)
205 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
206 |
207 | def forward(self, hidden_states, input_tensor):
208 | hidden_states = self.dense(hidden_states)
209 | hidden_states = self.dropout(hidden_states)
210 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
211 | return hidden_states
212 |
213 |
214 | class BERTAttention(nn.Module):
215 | def __init__(self, config):
216 | super(BERTAttention, self).__init__()
217 | self.self = BERTSelfAttention(config)
218 | self.output = BERTSelfOutput(config)
219 |
220 | def forward(self, input_tensor, attention_mask):
221 | self_output = self.self(input_tensor, attention_mask)
222 | attention_output = self.output(self_output, input_tensor)
223 | return attention_output
224 |
225 |
226 | class BERTIntermediate(nn.Module):
227 | def __init__(self, config):
228 | super(BERTIntermediate, self).__init__()
229 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
230 | self.intermediate_act_fn = gelu
231 |
232 | def forward(self, hidden_states):
233 | hidden_states = self.dense(hidden_states)
234 | hidden_states = self.intermediate_act_fn(hidden_states)
235 | return hidden_states
236 |
237 |
238 | class BERTOutput(nn.Module):
239 | def __init__(self, config):
240 | super(BERTOutput, self).__init__()
241 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
242 | self.LayerNorm = BERTLayerNorm(config)
243 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
244 |
245 | def forward(self, hidden_states, input_tensor):
246 | hidden_states = self.dense(hidden_states)
247 | hidden_states = self.dropout(hidden_states)
248 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
249 | return hidden_states
250 |
251 |
252 | class BERTLayer(nn.Module):
253 | def __init__(self, config):
254 | super(BERTLayer, self).__init__()
255 | self.attention = BERTAttention(config)
256 | self.intermediate = BERTIntermediate(config)
257 | self.output = BERTOutput(config)
258 |
259 | def forward(self, hidden_states, attention_mask):
260 | attention_output = self.attention(hidden_states, attention_mask)
261 | intermediate_output = self.intermediate(attention_output)
262 | layer_output = self.output(intermediate_output, attention_output)
263 | return layer_output
264 |
265 |
266 | class BERTEncoder(nn.Module):
267 | def __init__(self, config):
268 | super(BERTEncoder, self).__init__()
269 | layer = BERTLayer(config)
270 | self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
271 |
272 | def forward(self, hidden_states, attention_mask):
273 | all_encoder_layers = []
274 | for layer_module in self.layer:
275 | hidden_states = layer_module(hidden_states, attention_mask)
276 | all_encoder_layers.append(hidden_states)
277 | return all_encoder_layers
278 |
279 |
280 | class BERTPooler(nn.Module):
281 | def __init__(self, config):
282 | super(BERTPooler, self).__init__()
283 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
284 | self.activation = nn.Tanh()
285 |
286 | def forward(self, hidden_states):
287 | # We "pool" the model by simply taking the hidden state corresponding
288 | # to the first token.
289 | first_token_tensor = hidden_states[:, 0]
290 | #return first_token_tensor
291 | pooled_output = self.dense(first_token_tensor)
292 | pooled_output = self.activation(pooled_output)
293 | return pooled_output
294 |
295 |
296 | class BertModel(nn.Module):
297 | """BERT model ("Bidirectional Encoder Representations from Transformers").
298 |
299 | Example usage:
300 | ```python
301 | # Already been converted into WordPiece token ids
302 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
303 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
304 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
305 |
306 | config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
307 | num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)
308 |
309 | model = modeling.BertModel(config=config)
310 | all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
311 | ```
312 | """
313 | def __init__(self, config: BertConfig):
314 | """Constructor for BertModel.
315 |
316 | Args:
317 | config: `BertConfig` instance.
318 | """
319 | super(BertModel, self).__init__()
320 | self.embeddings = BERTEmbeddings(config)
321 | self.encoder = BERTEncoder(config)
322 | self.pooler = BERTPooler(config)
323 |
324 | def forward(self, input_ids, token_type_ids=None, attention_mask=None):
325 | if attention_mask is None:
326 | attention_mask = torch.ones_like(input_ids)
327 | if token_type_ids is None:
328 | token_type_ids = torch.zeros_like(input_ids)
329 |
330 | # We create a 3D attention mask from a 2D tensor mask.
331 | # Sizes are [batch_size, 1, 1, from_seq_length]
332 | # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length]
333 | # This attention mask is simpler than the triangular masking of causal attention
334 | # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
335 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
336 |
337 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
338 | # masked positions, this operation will create a tensor which is 0.0 for
339 | # positions we want to attend and -10000.0 for masked positions.
340 | # Since we are adding it to the raw scores before the softmax, this is
341 | # effectively the same as removing these entirely.
342 | extended_attention_mask = extended_attention_mask.float()
343 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
344 |
345 | embedding_output = self.embeddings(input_ids, token_type_ids)
346 | all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
347 | sequence_output = all_encoder_layers[-1]
348 | pooled_output = self.pooler(sequence_output)
349 | return all_encoder_layers, pooled_output
350 |
351 | class BertForSequenceClassification(nn.Module):
352 | """BERT model for classification.
353 | This module is composed of the BERT model with a linear layer on top of
354 | the pooled output.
355 |
356 | Example usage:
357 | ```python
358 | # Already been converted into WordPiece token ids
359 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
360 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
361 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
362 |
363 | config = BertConfig(vocab_size=32000, hidden_size=512,
364 | num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)
365 |
366 | num_labels = 2
367 |
368 | model = BertForSequenceClassification(config, num_labels)
369 | logits = model(input_ids, token_type_ids, input_mask)
370 | ```
371 | """
372 | def __init__(self, config, num_labels):
373 | super(BertForSequenceClassification, self).__init__()
374 | self.bert = BertModel(config)
375 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
376 | self.classifier = nn.Linear(config.hidden_size, num_labels)
377 |
378 | def init_weights(module):
379 | if isinstance(module, (nn.Linear, nn.Embedding)):
380 | # Slightly different from the TF version which uses truncated_normal for initialization
381 | # cf https://github.com/pytorch/pytorch/pull/5617
382 | module.weight.data.normal_(mean=0.0, std=config.initializer_range)
383 | elif isinstance(module, BERTLayerNorm):
384 | module.beta.data.normal_(mean=0.0, std=config.initializer_range)
385 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
386 | if isinstance(module, nn.Linear):
387 | module.bias.data.zero_()
388 | self.apply(init_weights)
389 |
390 | def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
391 | _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
392 | pooled_output = self.dropout(pooled_output)
393 | logits = self.classifier(pooled_output)
394 |
395 | if labels is not None:
396 | loss_fct = CrossEntropyLoss()
397 | loss = loss_fct(logits, labels)
398 | return loss, logits
399 | else:
400 | return logits
401 |
402 |
403 | class BertForQuestionAnswering(nn.Module):
404 | """BERT model for Question Answering (span extraction).
405 | This module is composed of the BERT model with a linear layer on top of
406 | the sequence output that computes start_logits and end_logits
407 |
408 | Example usage:
409 | ```python
410 | # Already been converted into WordPiece token ids
411 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
412 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
413 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
414 |
415 | config = BertConfig(vocab_size=32000, hidden_size=512,
416 | num_hidden_layers=8, num_attention_heads=8, intermediate_size=1024)
417 |
418 | model = BertForQuestionAnswering(config)
419 | start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
420 | ```
421 | """
422 | def __init__(self, config):
423 | super(BertForQuestionAnswering, self).__init__()
424 | self.bert = BertModel(config)
425 | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
426 | # self.dropout = nn.Dropout(config.hidden_dropout_prob)
427 | self.qa_outputs = nn.Linear(config.hidden_size, 2)
428 |
429 | def init_weights(module):
430 | if isinstance(module, (nn.Linear, nn.Embedding)):
431 | # Slightly different from the TF version which uses truncated_normal for initialization
432 | # cf https://github.com/pytorch/pytorch/pull/5617
433 | module.weight.data.normal_(mean=0.0, std=config.initializer_range)
434 | elif isinstance(module, BERTLayerNorm):
435 | module.beta.data.normal_(mean=0.0, std=config.initializer_range)
436 | module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
437 | if isinstance(module, nn.Linear):
438 | module.bias.data.zero_()
439 | self.apply(init_weights)
440 |
441 | def forward(self, input_ids, token_type_ids, attention_mask, start_positions=None, end_positions=None):
442 | all_encoder_layers, _ = self.bert(input_ids, token_type_ids, attention_mask)
443 | sequence_output = all_encoder_layers[-1]
444 | logits = self.qa_outputs(sequence_output)
445 | start_logits, end_logits = logits.split(1, dim=-1)
446 | start_logits = start_logits.squeeze(-1)
447 | end_logits = end_logits.squeeze(-1)
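        # Shapes: sequence_output is [batch_size, seq_len, hidden_size]; the linear layer
        # maps it to [batch_size, seq_len, 2], and the split/squeeze yields two
        # [batch_size, seq_len] tensors of start and end logits.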
448 |
449 | if start_positions is not None and end_positions is not None:
450 |             # If we are on multi-GPU, splitting the batch can add an extra dimension;
451 |             # if not, this squeeze is a no-op
451 | start_positions = start_positions.squeeze(-1)
452 | end_positions = end_positions.squeeze(-1)
453 |             # Sometimes the start/end positions fall outside the model inputs; clamp them
454 |             # and exclude those terms from the loss
454 | ignored_index = start_logits.size(1)
455 | start_positions.clamp_(0, ignored_index)
456 | end_positions.clamp_(0, ignored_index)
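            # Positions clamped to `ignored_index` (== seq_len, one past the last valid
            # index) are then excluded from the loss via `ignore_index` below.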
457 |
458 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
459 | start_loss = loss_fct(start_logits, start_positions)
460 | end_loss = loss_fct(end_logits, end_positions)
461 | total_loss = (start_loss + end_loss) / 2
462 | return total_loss
463 | else:
464 | return start_logits, end_logits
465 |
--------------------------------------------------------------------------------
/run_classifier_TABSA.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | """BERT finetuning runner."""
4 |
5 | from __future__ import absolute_import, division, print_function
6 |
7 | import argparse
8 | import collections
9 | import logging
10 | import os
11 | import random
12 |
13 | import numpy as np
14 | import torch
15 | import torch.nn.functional as F
16 | from torch.utils.data import DataLoader, TensorDataset
17 | from torch.utils.data.distributed import DistributedSampler
18 | from torch.utils.data.sampler import RandomSampler, SequentialSampler
19 | from tqdm import tqdm, trange
20 |
21 | import tokenization
22 | from modeling import BertConfig, BertForSequenceClassification
23 | from optimization import BERTAdam
24 | from processor import (Semeval_NLI_B_Processor, Semeval_NLI_M_Processor,
25 | Semeval_QA_B_Processor, Semeval_QA_M_Processor,
26 | Semeval_single_Processor, Sentihood_NLI_B_Processor,
27 | Sentihood_NLI_M_Processor, Sentihood_QA_B_Processor,
28 | Sentihood_QA_M_Processor, Sentihood_single_Processor)
29 |
30 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
31 | datefmt = '%m/%d/%Y %H:%M:%S',
32 | level = logging.INFO)
33 | logger = logging.getLogger(__name__)
34 |
35 |
36 | class InputFeatures(object):
37 | """A single set of features of data."""
38 |
39 | def __init__(self, input_ids, input_mask, segment_ids, label_id):
40 | self.input_ids = input_ids
41 | self.input_mask = input_mask
42 | self.segment_ids = segment_ids
43 | self.label_id = label_id
44 |
45 |
46 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
47 | """Loads a data file into a list of `InputBatch`s."""
48 |
49 | label_map = {}
50 | for (i, label) in enumerate(label_list):
51 | label_map[label] = i
52 |
53 | features = []
54 | for (ex_index, example) in enumerate(tqdm(examples)):
55 | tokens_a = tokenizer.tokenize(example.text_a)
56 |
57 | tokens_b = None
58 | if example.text_b:
59 | tokens_b = tokenizer.tokenize(example.text_b)
60 |
61 | if tokens_b:
62 | # Modifies `tokens_a` and `tokens_b` in place so that the total
63 | # length is less than the specified length.
64 | # Account for [CLS], [SEP], [SEP] with "- 3"
65 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
66 | else:
67 | # Account for [CLS] and [SEP] with "- 2"
68 | if len(tokens_a) > max_seq_length - 2:
69 | tokens_a = tokens_a[0:(max_seq_length - 2)]
70 |
71 | # The convention in BERT is:
72 | # (a) For sequence pairs:
73 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
74 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
75 | # (b) For single sequences:
76 | # tokens: [CLS] the dog is hairy . [SEP]
77 | # type_ids: 0 0 0 0 0 0 0
78 | #
79 | # Where "type_ids" are used to indicate whether this is the first
80 | # sequence or the second sequence. The embedding vectors for `type=0` and
81 | # `type=1` were learned during pre-training and are added to the wordpiece
82 | # embedding vector (and position vector). This is not *strictly* necessary
83 |         # since the [SEP] token unambiguously separates the sequences, but it makes
84 | # it easier for the model to learn the concept of sequences.
85 | #
86 | # For classification tasks, the first vector (corresponding to [CLS]) is
87 |         # used as the "sentence vector". Note that this only makes sense because
88 | # the entire model is fine-tuned.
89 | tokens = []
90 | segment_ids = []
91 | tokens.append("[CLS]")
92 | segment_ids.append(0)
93 | for token in tokens_a:
94 | tokens.append(token)
95 | segment_ids.append(0)
96 | tokens.append("[SEP]")
97 | segment_ids.append(0)
98 |
99 | if tokens_b:
100 | for token in tokens_b:
101 | tokens.append(token)
102 | segment_ids.append(1)
103 | tokens.append("[SEP]")
104 | segment_ids.append(1)
105 |
106 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
107 |
108 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
109 | # tokens are attended to.
110 | input_mask = [1] * len(input_ids)
111 |
112 | # Zero-pad up to the sequence length.
113 | while len(input_ids) < max_seq_length:
114 | input_ids.append(0)
115 | input_mask.append(0)
116 | segment_ids.append(0)
117 |
118 | assert len(input_ids) == max_seq_length
119 | assert len(input_mask) == max_seq_length
120 | assert len(segment_ids) == max_seq_length
121 |
122 | label_id = label_map[example.label]
123 |
124 | features.append(
125 | InputFeatures(
126 | input_ids=input_ids,
127 | input_mask=input_mask,
128 | segment_ids=segment_ids,
129 | label_id=label_id))
130 | return features
131 |
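# Illustrative example (hypothetical inputs, assuming text_a is the review sentence and
# text_b the auxiliary sentence, here in the NLI_B style "positive - food"), with
# max_seq_length = 16:
#   tokens:      [CLS] the food was great [SEP] positive - food [SEP]  (then padding)
#   segment_ids:  0    0   0    0   0      0    1        1 1    1      0 0 0 0 0 0
#   input_mask:   1    1   1    1   1      1    1        1 1    1      0 0 0 0 0 0
#   input_ids are the corresponding WordPiece vocabulary ids, zero-padded to length 16.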
132 |
133 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
134 | """Truncates a sequence pair in place to the maximum length."""
135 |
136 | # This is a simple heuristic which will always truncate the longer sequence
137 | # one token at a time. This makes more sense than truncating an equal percent
138 | # of tokens from each, since if one sequence is very short then each token
139 | # that's truncated likely contains more information than a longer sequence.
140 | while True:
141 | total_length = len(tokens_a) + len(tokens_b)
142 | if total_length <= max_length:
143 | break
144 | if len(tokens_a) > len(tokens_b):
145 | tokens_a.pop()
146 | else:
147 | tokens_b.pop()
148 |
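# Illustrative example: with max_length = 6, tokens_a of length 5 and tokens_b of length 3,
# the loop pops from tokens_a (the longer sequence) twice, leaving 3 + 3 = 6 tokens in total.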
149 |
150 | def main():
151 | parser = argparse.ArgumentParser()
152 |
153 | ## Required parameters
154 | parser.add_argument("--task_name",
155 | default=None,
156 | type=str,
157 | required=True,
158 | choices=["sentihood_single", "sentihood_NLI_M", "sentihood_QA_M", \
159 | "sentihood_NLI_B", "sentihood_QA_B", "semeval_single", \
160 | "semeval_NLI_M", "semeval_QA_M", "semeval_NLI_B", "semeval_QA_B"],
161 | help="The name of the task to train.")
162 | parser.add_argument("--data_dir",
163 | default=None,
164 | type=str,
165 | required=True,
166 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
167 | parser.add_argument("--vocab_file",
168 | default=None,
169 | type=str,
170 | required=True,
171 | help="The vocabulary file that the BERT model was trained on.")
172 | parser.add_argument("--bert_config_file",
173 | default=None,
174 | type=str,
175 | required=True,
176 | help="The config json file corresponding to the pre-trained BERT model. \n"
177 | "This specifies the model architecture.")
178 | parser.add_argument("--output_dir",
179 | default=None,
180 | type=str,
181 | required=True,
182 | help="The output directory where the model checkpoints will be written.")
183 |
184 | ## Other parameters
185 | parser.add_argument("--init_checkpoint",
186 | default=None,
187 | type=str,
188 | help="Initial checkpoint (usually from a pre-trained BERT model).")
189 | parser.add_argument("--init_eval_checkpoint",
190 | default=None,
191 | type=str,
192 | help="Initial checkpoint (usually from a pre-trained BERT model + classifier).")
193 | parser.add_argument("--do_save_model",
194 | default=False,
195 | action='store_true',
196 | help="Whether to save model.")
197 | parser.add_argument("--eval_test",
198 | default=False,
199 | action='store_true',
200 | help="Whether to run eval on the test set.")
201 | parser.add_argument("--do_lower_case",
202 | default=False,
203 | action='store_true',
204 | help="Whether to lower case the input text. True for uncased models, False for cased models.")
205 | parser.add_argument("--max_seq_length",
206 | default=128,
207 | type=int,
208 | help="The maximum total input sequence length after WordPiece tokenization. \n"
209 | "Sequences longer than this will be truncated, and sequences shorter \n"
210 | "than this will be padded.")
211 | parser.add_argument("--train_batch_size",
212 | default=32,
213 | type=int,
214 | help="Total batch size for training.")
215 | parser.add_argument("--eval_batch_size",
216 | default=8,
217 | type=int,
218 | help="Total batch size for eval.")
219 | parser.add_argument("--learning_rate",
220 | default=5e-5,
221 | type=float,
222 | help="The initial learning rate for Adam.")
223 | parser.add_argument("--num_train_epochs",
224 | default=3.0,
225 | type=float,
226 | help="Total number of training epochs to perform.")
227 | parser.add_argument("--warmup_proportion",
228 | default=0.1,
229 | type=float,
230 | help="Proportion of training to perform linear learning rate warmup for. "
231 | "E.g., 0.1 = 10%% of training.")
232 | parser.add_argument("--no_cuda",
233 | default=False,
234 | action='store_true',
235 |                         help="Whether to disable CUDA even when it is available")
236 | parser.add_argument("--accumulate_gradients",
237 | type=int,
238 | default=1,
239 | help="Number of steps to accumulate gradient on (divide the batch_size and accumulate)")
240 | parser.add_argument("--local_rank",
241 | type=int,
242 | default=-1,
243 | help="local_rank for distributed training on gpus")
244 | parser.add_argument('--seed',
245 | type=int,
246 | default=42,
247 | help="random seed for initialization")
248 | parser.add_argument('--gradient_accumulation_steps',
249 | type=int,
250 | default=1,
251 |                         help="Number of update steps to accumulate before performing a backward/update pass.")
252 | args = parser.parse_args()
253 |
254 |
255 | if args.local_rank == -1 or args.no_cuda:
256 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
257 | n_gpu = torch.cuda.device_count()
258 | else:
259 | device = torch.device("cuda", args.local_rank)
260 | n_gpu = 1
261 |         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
262 | torch.distributed.init_process_group(backend='nccl')
263 | logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))
264 |
265 | if args.accumulate_gradients < 1:
266 | raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
267 | args.accumulate_gradients))
268 |
269 | args.train_batch_size = int(args.train_batch_size / args.accumulate_gradients)
270 |
271 | random.seed(args.seed)
272 | np.random.seed(args.seed)
273 | torch.manual_seed(args.seed)
274 | if n_gpu > 0:
275 | torch.cuda.manual_seed_all(args.seed)
276 |
277 | bert_config = BertConfig.from_json_file(args.bert_config_file)
278 |
279 | if args.max_seq_length > bert_config.max_position_embeddings:
280 | raise ValueError(
281 | "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format(
282 | args.max_seq_length, bert_config.max_position_embeddings))
283 |
284 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
285 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
286 | os.makedirs(args.output_dir, exist_ok=True)
287 |
288 |
289 | # prepare dataloaders
290 | processors = {
291 | "sentihood_single":Sentihood_single_Processor,
292 | "sentihood_NLI_M":Sentihood_NLI_M_Processor,
293 | "sentihood_QA_M":Sentihood_QA_M_Processor,
294 | "sentihood_NLI_B":Sentihood_NLI_B_Processor,
295 | "sentihood_QA_B":Sentihood_QA_B_Processor,
296 | "semeval_single":Semeval_single_Processor,
297 | "semeval_NLI_M":Semeval_NLI_M_Processor,
298 | "semeval_QA_M":Semeval_QA_M_Processor,
299 | "semeval_NLI_B":Semeval_NLI_B_Processor,
300 | "semeval_QA_B":Semeval_QA_B_Processor,
301 | }
302 |
303 | processor = processors[args.task_name]()
304 | label_list = processor.get_labels()
305 |
306 | tokenizer = tokenization.FullTokenizer(
307 | vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
308 |
309 | # training set
310 | train_examples = None
311 | num_train_steps = None
312 | train_examples = processor.get_train_examples(args.data_dir)
313 | num_train_steps = int(
314 | len(train_examples) / args.train_batch_size * args.num_train_epochs)
315 |
316 | train_features = convert_examples_to_features(
317 | train_examples, label_list, args.max_seq_length, tokenizer)
318 | logger.info("***** Running training *****")
319 | logger.info(" Num examples = %d", len(train_examples))
320 | logger.info(" Batch size = %d", args.train_batch_size)
321 | logger.info(" Num steps = %d", num_train_steps)
322 |
323 | all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
324 | all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
325 | all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
326 | all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
327 |
328 | train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
329 | if args.local_rank == -1:
330 | train_sampler = RandomSampler(train_data)
331 | else:
332 | train_sampler = DistributedSampler(train_data)
333 | train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
334 |
335 | # test set
336 | if args.eval_test:
337 | test_examples = processor.get_test_examples(args.data_dir)
338 | test_features = convert_examples_to_features(
339 | test_examples, label_list, args.max_seq_length, tokenizer)
340 |
341 | all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long)
342 | all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long)
343 | all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long)
344 | all_label_ids = torch.tensor([f.label_id for f in test_features], dtype=torch.long)
345 |
346 | test_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
347 | test_dataloader = DataLoader(test_data, batch_size=args.eval_batch_size, shuffle=False)
348 |
349 |
350 | # model and optimizer
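    # len(label_list) depends on the task: 2 for the binary *_B variants and, e.g.,
    # 5 for the semeval *_M variants (positive / neutral / negative / conflict / none).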
351 | model = BertForSequenceClassification(bert_config, len(label_list))
352 | if args.init_eval_checkpoint is not None:
353 | model.load_state_dict(torch.load(args.init_eval_checkpoint, map_location='cpu'))
354 | elif args.init_checkpoint is not None:
355 | model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
356 | model.to(device)
357 |
358 | if args.local_rank != -1:
359 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
360 | output_device=args.local_rank)
361 | elif n_gpu > 1:
362 | model = torch.nn.DataParallel(model)
363 |
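    # Exclude biases and the LayerNorm parameters (named 'gamma' / 'beta' in modeling.py)
    # from L2 weight decay; all other parameters use weight_decay_rate = 0.01.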
364 | no_decay = ['bias', 'gamma', 'beta']
365 | optimizer_parameters = [
366 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
367 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
368 | ]
369 |
370 | optimizer = BERTAdam(optimizer_parameters,
371 | lr=args.learning_rate,
372 | warmup=args.warmup_proportion,
373 | t_total=num_train_steps)
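    # With warmup_proportion = 0.1 and BERTAdam's default warmup_linear schedule
    # (see optimization.py), the learning rate rises linearly to args.learning_rate over
    # the first 10% of num_train_steps and then decays linearly for the rest of training.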
374 |
375 |
376 | # train
377 | output_log_file = os.path.join(args.output_dir, "log.txt")
378 | print("output_log_file=",output_log_file)
379 | with open(output_log_file, "w") as writer:
380 | if args.eval_test:
381 | writer.write("epoch\tglobal_step\tloss\ttest_loss\ttest_accuracy\n")
382 | else:
383 | writer.write("epoch\tglobal_step\tloss\n")
384 |
385 | global_step = 0
386 | epoch = 0
387 | for _ in trange(int(args.num_train_epochs), desc="Epoch"):
388 | epoch += 1
389 | model.train()
390 | tr_loss = 0
391 | nb_tr_examples, nb_tr_steps = 0, 0
392 | for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
393 | batch = tuple(t.to(device) for t in batch)
394 | input_ids, input_mask, segment_ids, label_ids = batch
395 | loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
396 | if n_gpu > 1:
397 | loss = loss.mean() # mean() to average on multi-gpu.
398 | if args.gradient_accumulation_steps > 1:
399 | loss = loss / args.gradient_accumulation_steps
400 | loss.backward()
401 | tr_loss += loss.item()
402 | nb_tr_examples += input_ids.size(0)
403 | nb_tr_steps += 1
404 | if (step + 1) % args.gradient_accumulation_steps == 0:
405 |                 optimizer.step()    # We have accumulated enough gradients
406 | model.zero_grad()
407 | global_step += 1
408 |
409 | if args.do_save_model:
410 | if n_gpu > 1:
411 | torch.save(model.module.state_dict(), os.path.join(args.output_dir, f'model_ep_{epoch}.bin'))
412 | else:
413 | torch.save(model.state_dict(), os.path.join(args.output_dir, f'model_ep_{epoch}.bin'))
414 |
415 | # eval_test
416 | if args.eval_test:
417 | model.eval()
418 | test_loss, test_accuracy = 0, 0
419 | nb_test_steps, nb_test_examples = 0, 0
420 | with open(os.path.join(args.output_dir, f"test_ep_{epoch}.txt"), "w") as f_test:
421 | for input_ids, input_mask, segment_ids, label_ids in test_dataloader:
422 | input_ids = input_ids.to(device)
423 | input_mask = input_mask.to(device)
424 | segment_ids = segment_ids.to(device)
425 | label_ids = label_ids.to(device)
426 |
427 | with torch.no_grad():
428 | tmp_test_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
429 |
430 | logits = F.softmax(logits, dim=-1)
431 | logits = logits.detach().cpu().numpy()
432 | label_ids = label_ids.to('cpu').numpy()
433 | outputs = np.argmax(logits, axis=1)
434 | for output_i in range(len(outputs)):
435 | f_test.write(str(outputs[output_i]))
436 | for ou in logits[output_i]:
437 | f_test.write(" "+str(ou))
438 | f_test.write("\n")
439 |                     tmp_test_accuracy = np.sum(outputs == label_ids)
440 |
441 | test_loss += tmp_test_loss.mean().item()
442 | test_accuracy += tmp_test_accuracy
443 |
444 | nb_test_examples += input_ids.size(0)
445 | nb_test_steps += 1
446 |
447 | test_loss = test_loss / nb_test_steps
448 | test_accuracy = test_accuracy / nb_test_examples
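            # test_loss is the mean per-batch loss; test_accuracy is the fraction of
            # correctly classified test examples.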
449 |
450 |
451 | result = collections.OrderedDict()
452 | if args.eval_test:
453 | result = {'epoch': epoch,
454 | 'global_step': global_step,
455 | 'loss': tr_loss/nb_tr_steps,
456 | 'test_loss': test_loss,
457 | 'test_accuracy': test_accuracy}
458 | else:
459 | result = {'epoch': epoch,
460 | 'global_step': global_step,
461 | 'loss': tr_loss/nb_tr_steps}
462 |
463 | logger.info("***** Eval results *****")
464 | with open(output_log_file, "a+") as writer:
465 | for key in result.keys():
466 | logger.info(" %s = %s\n", key, str(result[key]))
467 | writer.write("%s\t" % (str(result[key])))
468 | writer.write("\n")
469 |
470 | if __name__ == "__main__":
471 | main()
472 |
--------------------------------------------------------------------------------