├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── dataset
│   └── .gitkeep
├── images
│   └── architecture.png
├── main.py
├── model
│   ├── __init__.py
│   ├── crf.py
│   └── model.py
├── output
│   ├── baidu
│   │   └── .gitkeep
│   ├── dianping
│   │   └── .gitkeep
│   └── mafengwo
│       └── .gitkeep
├── requirements.txt
└── utils
    ├── __init__.py
    ├── data.py
    ├── preprocess.py
    ├── score.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | .idea
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Chinese Opinion Target Extraction
2 | PyTorch implementation of "Character-based BiLSTM-CRF Incorporating POS and Dictionaries for Chinese Opinion Target Extraction", ACML 2018 [\[paper](http://proceedings.mlr.press/v95/li18d.html), [pdf\]](http://proceedings.mlr.press/v95/li18d/li18d.pdf)
3 |
4 | ### Dependencies
5 |
6 | This implementation may work in other setups, but it has only been tested with the environment below:
7 |
8 | ```
9 | python == 3.6.8
10 | torch == 1.1.0
11 | thulac == 0.2.0
12 | tqdm
13 | keras == 2.3.0
14 | numpy == 1.17.0
15 | numba
16 | ```
17 |
18 | ### Usage
19 |
20 | 1. Install the dependencies (see `requirements.txt`)
21 | 2. Download the dataset from [this repo](https://github.com/lsvih/chinese-customer-review), move the files into the `./dataset` folder, then unzip `dictionary.zip`.
22 | 3. Train model: `python3 main.py --mode=train --dataset=baidu`
23 | 4. Test model: `python3 main.py --mode=test --dataset=baidu`
24 |
25 | > Note: Pre-processing takes about 10–20 minutes and runs automatically the first time a dataset is used.
26 |
27 | ### Architecture
28 |
29 | ![architecture](./images/architecture.png)
30 |
31 |
32 |
33 | ### Results
34 |
35 | | | Baidu | Mafengwo | Dianping |
36 | | --- | --- | --- | --- |
37 | | P | 85.791 | 83.273 | 83.753 |
38 | | R | 82.531 | 89.989 | 85.672 |
39 | | F1 | 84.130 | 86.501 | 84.702 |
40 |
41 | ### Citation
42 |
43 | If you find this work useful in your research, please consider citing:
44 |
45 | ```
46 | @inproceedings{li2018character,
47 | title={Character-based BiLSTM-CRF Incorporating POS and Dictionaries for Chinese Opinion Target Extraction},
48 | author={Li, Yanzeng and Liu, Tingwen and Li, Diying and Li, Quangang and Shi, Jinqiao and Wang, Yanqiu},
49 | booktitle={Asian Conference on Machine Learning},
50 | pages={518--533},
51 | year={2018}
52 | }
53 | ```
54 |
55 |
--------------------------------------------------------------------------------
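Before running step 2 of the Usage section, it may help to know which files the pipeline expects. The sketch below is only illustrative, inferred from the path strings in `main.py` and `utils/preprocess.py`; the `baidu` dataset is shown, and `dianping`/`mafengwo` follow the same pattern.

```
# Expected dataset files (illustrative; paths taken from main.py / utils/preprocess.py)
expected = [
    'dataset/dictionary.txt',       # unzipped from dictionary.zip
    'dataset/baidu/train.txt',      # raw data from the linked repository
    'dataset/baidu/test.txt',
    'dataset/baidu/train_seg.txt',  # generated automatically on first run
    'dataset/baidu/test_seg.txt',
]
```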
/config.py:
--------------------------------------------------------------------------------
1 | class Config:
2 | def __init__(self):
3 | self.epoch = 20
4 | self.batch_size = 128
5 | self.MAX_SENTENCE_LENGTH = 250
6 | self.char_emb_dim = 50
7 | self.pos_emb_dim = 20
8 | self.tag_emb_dim = 20
9 | self.pos_hidden_dim = 50
10 | self.lstm_hidden_dim = 50
11 | self.dropout = 0.5
12 | self.lr = 0.002
13 | self.lr_decay = 0.03
14 | self.momentum = 0.01
15 | self.config_path = ''
16 | self.model_path = ''
17 | self.result_path = ''
18 | self.output_path = ''
19 |
20 | def set_dataset(self, dataset):
21 | self.model_path = './output/%s/model' % dataset
22 | self.config_path = './output/%s/setting' % dataset
23 | self.result_path = './output/%s/result.txt' % dataset
24 | self.output_path = './output/%s/' % dataset
25 |
--------------------------------------------------------------------------------
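A minimal usage sketch of `Config.set_dataset`, assuming it is run from the repository root; it only shows how the per-dataset output paths are derived from the dataset name.

```
from config import Config

cfg = Config()
cfg.set_dataset('baidu')
print(cfg.model_path)   # ./output/baidu/model
print(cfg.config_path)  # ./output/baidu/setting
print(cfg.result_path)  # ./output/baidu/result.txt
```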
/dataset/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kdsec/chinese-opinion-target-extraction/05447962e6536a9c591fced1c09686a5209ac2f5/dataset/.gitkeep
--------------------------------------------------------------------------------
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kdsec/chinese-opinion-target-extraction/05447962e6536a9c591fced1c09686a5209ac2f5/images/architecture.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 |
5 | from torch import optim
6 | from tqdm import trange
7 |
8 | from model.model import BiLSTM_CRF as Model
9 | from utils.data import Data
10 | from utils.preprocess import preprocess
11 | from utils.score import score
12 | from utils.utils import *
13 |
14 | seed_num = 123456
15 | random.seed(seed_num)
16 | torch.manual_seed(seed_num)
17 | np.random.seed(seed_num)
18 |
19 |
20 | def train(data):
21 | print('Training model...')
22 | save_data_setting(data)
23 | model = Model(data).to(device)
24 | optimizer = optim.RMSprop(model.parameters(), lr=data.lr, momentum=data.momentum)
25 | for epoch in range(data.epoch):
26 | print('Epoch: %s/%s' % (epoch, data.epoch))
27 | optimizer = lr_decay(optimizer, epoch, data.lr_decay, data.lr)
28 | total_loss = 0
29 | random.shuffle(data.ids)
30 | model.train()
31 | model.zero_grad()
32 | train_num = len(data.ids)
33 | total_batch = train_num // data.batch_size + 1
34 | for batch in trange(total_batch):
35 | start, end = slice_set(batch, data.batch_size, train_num)
36 | instance = data.ids[start:end]
37 | if not instance: continue
38 | *model_input, _ = load_batch(instance)
39 | loss = model.neg_log_likelihood_loss(*model_input)
40 | total_loss += loss.data.item()
41 | loss.backward()
42 | optimizer.step()
43 | model.zero_grad()
44 | print('Epoch %d loss = %.3f' % (epoch, total_loss))
45 | torch.save(model.state_dict(), data.model_path)
46 |
47 |
48 | def test(data):
49 | print('Testing model...')
50 | model = Model(data).to(device)
51 | model.load_state_dict(torch.load(data.model_path))
52 | instances = data.ids
53 | pred_results = []
54 | model.eval()
55 | test_num = len(instances)
56 | total_batch = test_num // data.batch_size + 1
57 | for batch in trange(total_batch):
58 | start, end = slice_set(batch, data.batch_size, test_num)
59 | instance = instances[start:end]
60 | if not instance: continue
61 | _, mask, *model_input, char_recover = load_batch(instance, True)
62 | tag_seq = model(mask, *model_input)
63 | pred_label = seq2label(tag_seq, mask, data.label_alphabet, char_recover)
64 | pred_results += pred_label
65 | return pred_results
66 |
67 |
68 | if __name__ == '__main__':
69 | parser = argparse.ArgumentParser(description='Setting mode and dataset.')
70 | parser.add_argument('--mode', choices=['train', 'test'], help='run mode: train or test', default='train')
71 | parser.add_argument('--dataset', choices=['baidu', 'dianping', 'mafengwo'], help='select dataset', default='baidu')
72 | args = parser.parse_args()
73 | mode = args.mode.lower()
74 | dataset = args.dataset.lower()
75 | print('Using dataset', dataset)
76 | train_file = './dataset/' + dataset + '/train_seg.txt'
77 | test_file = './dataset/' + dataset + '/test_seg.txt'
78 | if not os.path.exists(train_file) or not os.path.exists(test_file):
79 | preprocess(dataset)
80 | data = Data()
81 | data.set_dataset(dataset)
82 | if mode == 'train':
83 | data.data_loader(train_file, 'train')
84 | train(data)
85 | elif mode == 'test':
86 | data = pickle.load(open(data.config_path, 'rb'))
87 | data.data_loader(test_file, 'test')
88 | results = test(data)
89 | save_results(data, results)
90 | score(data.result_path, test_file, data.output_path)
91 |
--------------------------------------------------------------------------------
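A small standalone sketch of the learning-rate schedule that `train()` applies each epoch via `utils.utils.lr_decay`, using the defaults from `config.py` (`lr=0.002`, `lr_decay=0.03`, 20 epochs); the printed values are only illustrative.

```
# lr = init_lr / (1 + decay_rate * epoch), as in utils/utils.py
init_lr, decay_rate = 0.002, 0.03
for epoch in range(20):
    lr = init_lr / (1 + decay_rate * epoch)
    if epoch in (0, 10, 19):
        print('epoch %2d -> lr %.6f' % (epoch, lr))
# epoch  0 -> lr 0.002000
# epoch 10 -> lr 0.001538
# epoch 19 -> lr 0.001274
```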
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kdsec/chinese-opinion-target-extraction/05447962e6536a9c591fced1c09686a5209ac2f5/model/__init__.py
--------------------------------------------------------------------------------
/model/crf.py:
--------------------------------------------------------------------------------
1 | # Reference https://github.com/liu-nlper/SLTK/blob/master/sltk/nn/modules/crf.py
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 |
6 |
7 | def log_sum_exp(vec, m_size):
8 | """
9 | Args:
10 | vec: size=(batch_size, vanishing_dim, hidden_dim)
11 | m_size: hidden_dim
12 | Returns:
13 | size=(batch_size, hidden_dim)
14 | """
15 | _, idx = torch.max(vec, 1)  # idx: B * M
16 | max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size)  # max_score: B * 1 * M
17 | return max_score.view(-1, m_size) + torch.log(torch.sum(
18 | torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)
19 |
20 |
21 | class CRF(nn.Module):
22 |
23 | def __init__(self, **kwargs):
24 | """
25 | Args:
26 | target_size: int, target size
27 | use_cuda: bool, whether to run on GPU, default is True
28 | average_batch: bool, whether to average the loss over the batch, default is True
29 | """
30 | super(CRF, self).__init__()
31 | for k in kwargs:
32 | self.__setattr__(k, kwargs[k])
33 | if not hasattr(self, 'average_batch'):
34 | self.__setattr__('average_batch', True)
35 | if not hasattr(self, 'use_cuda'):
36 | self.__setattr__('use_cuda', True)
37 |
38 | # init transitions
39 | self.START_TAG_IDX, self.END_TAG_IDX = -2, -1
40 | init_transitions = torch.zeros(self.target_size + 2, self.target_size + 2)
41 | init_transitions[:, self.START_TAG_IDX] = -1000.
42 | init_transitions[self.END_TAG_IDX, :] = -1000.
43 | if self.use_cuda:
44 | init_transitions = init_transitions.cuda()
45 | self.transitions = nn.Parameter(init_transitions)
46 |
47 | def _forward_alg(self, feats, mask):
48 | """
49 | Do the forward algorithm to compute the partition function (batched).
50 | Args:
51 | feats: size=(batch_size, seq_len, self.target_size+2)
52 | mask: size=(batch_size, seq_len)
53 | Returns:
54 | the partition score summed over the batch, and the per-step score tensor
55 | """
56 | batch_size = feats.size(0)
57 | seq_len = feats.size(1)
58 | tag_size = feats.size(-1)
59 |
60 | mask = mask.transpose(1, 0).contiguous()
61 | ins_num = batch_size * seq_len
62 |
63 | feats = feats.transpose(1, 0).contiguous().view(
64 | ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
65 |
66 | scores = feats + self.transitions.view(
67 | 1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
68 | scores = scores.view(seq_len, batch_size, tag_size, tag_size)
69 |
70 | seq_iter = enumerate(scores)
71 | try:
72 | _, inivalues = seq_iter.__next__()
73 | except AttributeError:  # Python 2 fallback
74 | _, inivalues = seq_iter.next()
75 | partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
76 |
77 | for idx, cur_values in seq_iter:
78 | cur_values = cur_values + partition.contiguous().view(
79 | batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
80 | cur_partition = log_sum_exp(cur_values, tag_size)
81 |
82 | mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)
83 |
84 | masked_cur_partition = cur_partition.masked_select(mask_idx)
85 | if masked_cur_partition.dim() != 0:
86 | mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)
87 | partition.masked_scatter_(mask_idx, masked_cur_partition)
88 |
89 | cur_values = self.transitions.view(1, tag_size, tag_size).expand(
90 | batch_size, tag_size, tag_size) + partition.contiguous().view(
91 | batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
92 | cur_partition = log_sum_exp(cur_values, tag_size)
93 | final_partition = cur_partition[:, self.END_TAG_IDX]
94 | return final_partition.sum(), scores
95 |
96 | def _viterbi_decode(self, feats, mask):
97 | """
98 | Args:
99 | feats: size=(batch_size, seq_len, self.target_size+2)
100 | mask: size=(batch_size, seq_len)
101 | Returns:
102 | decode_idx: (batch_size, seq_len), the Viterbi decoding result
103 | path_score: size=(batch_size, 1), the score of each sentence
104 | """
105 | batch_size = feats.size(0)
106 | seq_len = feats.size(1)
107 | tag_size = feats.size(-1)
108 |
109 | length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()
110 |
111 | mask = mask.transpose(1, 0).contiguous()
112 | ins_num = seq_len * batch_size
113 |
114 | feats = feats.transpose(1, 0).contiguous().view(
115 | ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
116 |
117 | scores = feats + self.transitions.view(
118 | 1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
119 | scores = scores.view(seq_len, batch_size, tag_size, tag_size)
120 |
121 | seq_iter = enumerate(scores)
122 | # record the position of the best score
123 | back_points = list()
124 | partition_history = list()
125 |
126 | # mask = 1 + (-1) * mask
127 | mask = (1 - mask.long()).byte()
128 | try:
129 | _, inivalues = seq_iter.__next__()
130 | except AttributeError:  # Python 2 fallback
131 | _, inivalues = seq_iter.next()
132 |
133 | partition = inivalues[:, self.START_TAG_IDX, :].clone().view(batch_size, tag_size, 1)
134 | partition_history.append(partition)
135 |
136 | for idx, cur_values in seq_iter:
137 | cur_values = cur_values + partition.contiguous().view(
138 | batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
139 | partition, cur_bp = torch.max(cur_values, 1)
140 | partition_history.append(partition.unsqueeze(-1))
141 |
142 | cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
143 | back_points.append(cur_bp)
144 |
145 | partition_history = torch.cat(partition_history).view(
146 | seq_len, batch_size, -1).transpose(1, 0).contiguous()
147 |
148 | last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
149 | last_partition = torch.gather(
150 | partition_history, 1, last_position).view(batch_size, tag_size, 1)
151 |
152 | last_values = last_partition.expand(batch_size, tag_size, tag_size) + \
153 | self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size)
154 | _, last_bp = torch.max(last_values, 1)
155 | pad_zero = Variable(torch.zeros(batch_size, tag_size)).long()
156 | if self.use_cuda:
157 | pad_zero = pad_zero.cuda()
158 | back_points.append(pad_zero)
159 | back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
160 |
161 | pointer = last_bp[:, self.END_TAG_IDX]
162 | insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size)
163 | back_points = back_points.transpose(1, 0).contiguous()
164 |
165 | back_points.scatter_(1, last_position, insert_last)
166 |
167 | back_points = back_points.transpose(1, 0).contiguous()
168 |
169 | decode_idx = Variable(torch.LongTensor(seq_len, batch_size))
170 | if self.use_cuda:
171 | decode_idx = decode_idx.cuda()
172 | decode_idx[-1] = pointer.data
173 | for idx in range(len(back_points) - 2, -1, -1):
174 | pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))
175 | decode_idx[idx] = pointer.view(-1).data
176 | path_score = None
177 | decode_idx = decode_idx.transpose(1, 0)
178 | return path_score, decode_idx
179 |
180 | def forward(self, feats, mask):
181 | path_score, best_path = self._viterbi_decode(feats, mask)
182 | return path_score, best_path
183 |
184 | def _score_sentence(self, scores, mask, tags):
185 | """
186 | Args:
187 | scores: size=(seq_len, batch_size, tag_size, tag_size)
188 | mask: size=(batch_size, seq_len)
189 | tags: size=(batch_size, seq_len)
190 | Returns:
191 | gold_score: the score of the gold tag sequence, summed over the batch
192 | """
193 | batch_size = scores.size(1)
194 | seq_len = scores.size(0)
195 | tag_size = scores.size(-1)
196 |
197 | new_tags = Variable(torch.LongTensor(batch_size, seq_len))
198 | if self.use_cuda:
199 | new_tags = new_tags.cuda()
200 | for idx in range(seq_len):
201 | if idx == 0:
202 | new_tags[:, 0] = (tag_size - 2) * tag_size + tags[:, 0]
203 | else:
204 | new_tags[:, idx] = tags[:, idx - 1] * tag_size + tags[:, idx]
205 |
206 | end_transition = self.transitions[:, self.END_TAG_IDX].contiguous().view(
207 | 1, tag_size).expand(batch_size, tag_size)
208 | length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()
209 | end_ids = torch.gather(tags, 1, length_mask - 1)
210 |
211 | end_energy = torch.gather(end_transition, 1, end_ids)
212 |
213 | new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1)
214 | tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(
215 | seq_len, batch_size)
216 | tg_energy = tg_energy.masked_select(mask.transpose(1, 0))
217 |
218 | gold_score = tg_energy.sum() + end_energy.sum()
219 |
220 | return gold_score
221 |
222 | def neg_log_likelihood_loss(self, feats, mask, tags):
223 | """
224 | Args:
225 | feats: size=(batch_size, seq_len, tag_size)
226 | mask: size=(batch_size, seq_len)
227 | tags: size=(batch_size, seq_len)
228 | """
229 | batch_size = feats.size(0)
230 | forward_score, scores = self._forward_alg(feats, mask)
231 | gold_score = self._score_sentence(scores, mask, tags)
232 | if self.average_batch:
233 | return (forward_score - gold_score) / batch_size
234 | return forward_score - gold_score
235 |
--------------------------------------------------------------------------------
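A quick, illustrative sanity check of `log_sum_exp`: for the `(batch, tag, tag)`-shaped score tensors used above, it should agree with `torch.logsumexp` over dimension 1 (assumes the repository root is on `PYTHONPATH`).

```
import torch

from model.crf import log_sum_exp

vec = torch.randn(4, 6, 6)            # (batch_size, from_tag, to_tag)
ours = log_sum_exp(vec, 6)            # (batch_size, to_tag)
ref = torch.logsumexp(vec, dim=1)
print(torch.allclose(ours, ref, atol=1e-6))  # True
```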
/model/model.py:
--------------------------------------------------------------------------------
1 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
2 |
3 | from config import Config
4 | from model.crf import CRF
5 | from utils.utils import *
6 |
7 |
8 | class BiLSTM_CRF(nn.Module):
9 | def __init__(self, data):
10 | super(BiLSTM_CRF, self).__init__()
11 | label_size = data.label_alphabet.size()
12 | data.label_size = label_size + 2
13 | self.lstm = BiLSTM(data).to(device)
14 | self.crf = CRF(target_size=label_size, use_cuda=use_cuda, average_batch=True).to(device)
15 |
16 | def neg_log_likelihood_loss(self, batch_label, mask, *args):
17 | outs = self.lstm.get_output_score(*args)
18 | total_loss = self.crf.neg_log_likelihood_loss(outs, mask, batch_label)
19 | return total_loss
20 |
21 | def forward(self, mask, *args):
22 | outs = self.lstm.get_output_score(*args)
23 | scores, tag_seq = self.crf(outs, mask)
24 | return tag_seq
25 |
26 |
27 | class BiLSTM(nn.Module, Config):
28 | def __init__(self, data):
29 | Config.__init__(self)
30 | super(BiLSTM, self).__init__()
31 | self.drop = nn.Dropout(self.dropout).to(device)
32 | self.char_embeddings = init_embedding(data.char_alphabet.size(), self.char_emb_dim)
33 | self.pos_embeddings = init_embedding(data.dict_alphabet.size(), self.pos_emb_dim)
34 | self.tag_embeddings = init_embedding(data.tag_alphabet.size(), self.tag_emb_dim)
35 | self.lstm = nn.LSTM(self.char_emb_dim + self.pos_emb_dim + self.pos_hidden_dim, self.lstm_hidden_dim // 2,
36 | batch_first=True, bidirectional=True).to(device)
37 | self.hidden2tag = nn.Linear(self.lstm_hidden_dim, data.label_size).to(device)
38 | self.posBiLSTM = PosBiLSTM(data).to(device)
39 |
40 | def get_lstm_features(self, char_inputs, pos_inputs, tag_inputs, seq_lengths):
41 | char_embs = self.char_embeddings(char_inputs)
42 | char_embs = self.drop(char_embs)
43 | pos_embs = self.pos_embeddings(pos_inputs)
44 | pos_embs = self.drop(pos_embs)
45 | pos_lstm_out = self.posBiLSTM.get_lstm_features(tag_inputs, seq_lengths)
46 | emb = torch.cat([char_embs, pos_embs, pos_lstm_out], 2)
47 | packed_chars = pack_padded_sequence(emb, seq_lengths.cpu().numpy(), True)
48 | lstm_out, _ = self.lstm(packed_chars)
49 | lstm_out, _ = pad_packed_sequence(lstm_out)
50 | lstm_out = self.drop(lstm_out.transpose(1, 0))
51 | return lstm_out
52 |
53 | def get_output_score(self, *args):
54 | lstm_out = self.get_lstm_features(*args)
55 | outputs = self.hidden2tag(lstm_out)
56 | return outputs
57 |
58 | def forward(self, mask, *args):
59 | batch_size = args[0].size(0)
60 | seq_len = args[0].size(1)
61 | outs = self.get_output_score(*args)
62 | outs = outs.view(batch_size * seq_len, -1)
63 | _, tag_seq = torch.max(outs, 1)
64 | tag_seq = tag_seq.view(batch_size, seq_len)
65 | decode_seq = mask.long() * tag_seq
66 | return decode_seq
67 |
68 |
69 | class PosBiLSTM(nn.Module, Config):
70 | def __init__(self, data):
71 | Config.__init__(self)
72 | super(PosBiLSTM, self).__init__()
73 | self.drop = nn.Dropout(self.dropout).to(device)
74 | self.pos_embeddings = init_embedding(data.dict_alphabet.size(), self.pos_emb_dim)
75 | self.lstm = nn.LSTM(self.pos_emb_dim, self.pos_hidden_dim // 2, batch_first=True, bidirectional=True).to(device)
76 | self.hidden2tag = nn.Linear(self.pos_hidden_dim, data.label_size).to(device)
77 |
78 | def get_lstm_features(self, pos_inputs, seq_lengths):
79 | pos_embs = self.pos_embeddings(pos_inputs)
80 | pos_embs = self.drop(pos_embs)
81 | packed_words = pack_padded_sequence(pos_embs, seq_lengths.cpu().numpy(), True)
82 | lstm_out, _ = self.lstm(packed_words)
83 | lstm_out, _ = pad_packed_sequence(lstm_out)
84 | lstm_out = self.drop(lstm_out.transpose(1, 0))
85 | return lstm_out
86 |
87 | def get_output_score(self, *args):
88 | lstm_out = self.get_lstm_features(*args)
89 | outputs = self.hidden2tag(lstm_out)
90 | return outputs
91 |
92 | def forward(self, mask, *args):
93 | batch_size = args[0].size(0)
94 | seq_len = args[0].size(1)
95 | outs = self.get_output_score(*args)
96 | outs = outs.view(batch_size * seq_len, -1)
97 | _, tag_seq = torch.max(outs, 1)
98 | tag_seq = tag_seq.view(batch_size, seq_len)
99 | decode_seq = mask.long() * tag_seq
100 | return decode_seq
101 |
--------------------------------------------------------------------------------
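A minimal shape sketch (dummy tensors, not the repo's `Data` object) of the feature concatenation that feeds the character BiLSTM in `BiLSTM.get_lstm_features`: character embedding (50) + dictionary-feature embedding (20) + POS-BiLSTM output (50) per character, using the dimensions from `config.py`.

```
import torch

batch_size, seq_len = 2, 7
char_embs = torch.randn(batch_size, seq_len, 50)     # char_emb_dim
dict_embs = torch.randn(batch_size, seq_len, 20)     # pos_emb_dim (dictionary features)
pos_lstm_out = torch.randn(batch_size, seq_len, 50)  # pos_hidden_dim
emb = torch.cat([char_embs, dict_embs, pos_lstm_out], 2)
print(emb.shape)  # torch.Size([2, 7, 120]) -> LSTM input size in model/model.py
```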
/output/baidu/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kdsec/chinese-opinion-target-extraction/05447962e6536a9c591fced1c09686a5209ac2f5/output/baidu/.gitkeep
--------------------------------------------------------------------------------
/output/dianping/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kdsec/chinese-opinion-target-extraction/05447962e6536a9c591fced1c09686a5209ac2f5/output/dianping/.gitkeep
--------------------------------------------------------------------------------
/output/mafengwo/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kdsec/chinese-opinion-target-extraction/05447962e6536a9c591fced1c09686a5209ac2f5/output/mafengwo/.gitkeep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.1.0
2 | thulac==0.2.0
3 | tqdm
4 | keras==2.3.0
5 | numpy==1.17.0
6 | numba
7 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kdsec/chinese-opinion-target-extraction/05447962e6536a9c591fced1c09686a5209ac2f5/utils/__init__.py
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from keras.preprocessing.text import Tokenizer
4 |
5 | from config import Config
6 |
7 | Tokenizer.size = lambda x: len(x.word_index) + 1
8 | Tokenizer.get_index = lambda x, item: x.word_index[item] if item in x.word_index else 0
9 | Tokenizer.get_item = lambda x, index: x.index_item[index] if index in x.index_item else None
10 |
11 |
12 | class Data(Config):
13 | def __init__(self):
14 | Config.__init__(self)
15 | self.char_alphabet, self.dict_alphabet, self.label_alphabet, self.tag_alphabet = [None] * 4
16 | self.texts, self.ids, self.sentences = [], [], []
17 |
18 | def build_alphabet(self):
19 | self.label_alphabet = Tokenizer(char_level=True)
20 | self.label_alphabet.fit_on_texts('OBME')
21 | self.label_alphabet.index_item = dict(map(reversed, self.label_alphabet.word_index.items()))
22 | self.char_alphabet = Tokenizer(char_level=True)
23 | self.char_alphabet.fit_on_texts(map(lambda s: s['char'], self.sentences))
24 | self.tag_alphabet = Tokenizer(char_level=True)
25 | self.tag_alphabet.fit_on_texts(map(lambda s: s['char_pos_tag'], self.sentences))
26 | self.dict_alphabet = Tokenizer(char_level=True)
27 | self.dict_alphabet.fit_on_texts(map(lambda s: [str(sum([2 ** i * x for i, x in enumerate(word_dict)]))
28 | for word_dict in s['dict_feature']], self.sentences))
29 |
30 | def read_instance(self):
31 | instance_texts = []
32 | instance_ids = []
33 | for sentence in self.sentences:
34 | chars, labels, dict_feats, tags, char_id, label_id, dict_id, tag_id = [[] for _ in range(8)]
35 | for i, (char, label, dict_feat, tag) in enumerate(
36 | zip(sentence['char'], sentence['char_tag'], sentence['dict_feature'], sentence['char_pos_tag'])):
37 | if i >= self.MAX_SENTENCE_LENGTH: break  # truncate overly long sentences
38 | chars.append(char)
39 | char_id.append(self.char_alphabet.get_index(char))
40 | labels.append(label)
41 | label_id.append(self.label_alphabet.get_index(label.lower()))
42 | dict_feat = str(sum([2 ** i * x for i, x in enumerate(dict_feat)]))
43 | dict_feats.append(dict_feat)
44 | dict_id.append(self.dict_alphabet.get_index(dict_feat))
45 | tags.append(tag)
46 | tag_id.append(self.tag_alphabet.get_index(tag.lower()))
47 | instance_texts.append([chars, dict_feats, tags, labels])
48 | instance_ids.append([char_id, dict_id, tag_id, label_id])
49 | return instance_texts, instance_ids
50 |
51 | def data_loader(self, input_file, name):
52 | self.sentences = [json.loads(line) for line in open(input_file, 'r', encoding='utf-8')]
53 | if name == 'train':
54 | self.build_alphabet()
55 | self.texts, self.ids = self.read_instance()
56 | self.sentences = []
57 |
--------------------------------------------------------------------------------
/utils/preprocess.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | from itertools import product
4 |
5 | from numba import jit
6 | from thulac import thulac
7 | from tqdm import tqdm
8 |
9 | segment_tool = None
10 | dictionary = None
11 |
12 |
13 | @jit
14 | def kmp_search(T, P):
15 | mapping = [0]
16 | for x in P[1:]:
17 | check_index = mapping[-1]
18 | if P[check_index] == x:
19 | mapping += [check_index + 1]
20 | else:
21 | mapping += [0]
22 | result = []
23 | p_pointer = 0
24 | t_pointer = 0
25 | while t_pointer < len(T):
26 | if P[p_pointer] == T[t_pointer]:
27 | p_pointer += 1
28 | t_pointer += 1
29 | if p_pointer >= len(P):
30 | result += [t_pointer - len(P)]
31 | p_pointer = 0 if p_pointer == 0 else mapping[p_pointer - 1]
32 | else:
33 | t_pointer += 1 if p_pointer == 0 else 0
34 | p_pointer = 0 if p_pointer == 0 else mapping[p_pointer - 1]
35 | return result
36 |
37 |
38 | @jit
39 | def make_char_tag(words: list, target: str) -> list:
40 | rs = []
41 | sentence = ''.join(words)
42 | kmp_rs = kmp_search(sentence, target)
43 | i = 0
44 | while i < len(sentence):
45 | if i in kmp_rs:
46 | for c_i in range(len(target)):
47 | if c_i == 0:
48 | rs.append('B')
49 | elif c_i == len(target) - 1:
50 | rs.append('E')
51 | else:
52 | rs.append('M')
53 | i += len(target)
54 | else:
55 | rs.append('O')
56 | i += 1
57 | return rs
58 |
59 |
60 | @jit
61 | def n_gram_in_dict(chars: list, char_index: int) -> list:
62 | rs = []
63 | for n in range(2, 6):
64 | if n > len(chars):
65 | rs += [0, 0]
66 | continue
67 | # front n-gram
68 | if char_index < n - 1:
69 | rs.append(0)
70 | else:
71 | word = ''.join(chars[char_index - n + 1: char_index + 1])
72 | rs.append(int(word in dictionary))
73 | # rear n-gram
74 | if char_index > len(chars) - n:
75 | rs.append(0)
76 | else:
77 | word = ''.join(chars[char_index:char_index + n])
78 | rs.append(int(word in dictionary))
79 | return rs
80 |
81 |
82 | def make_dict_feat(chars):
83 | vector = []
84 | for char_index in range(len(chars)):
85 | vector.append(n_gram_in_dict(chars, char_index))
86 | return vector
87 |
88 |
89 | # @jit
90 | def construct_features(origin: dict) -> dict:
91 | features = {'content': origin['s'], 'label': origin['ot'], 'word': [], 'POS': [], 'char': [], 'char_pos': [],
92 | 'char_pos_tag': [], 'char_word_tag': []}
93 | sentence = origin['s'].replace('\xa0','')
94 | # Segment
95 | cut_word, cut_pos = [], []
96 | tokens = segment_tool.cut(sentence)
97 | for word, pos in tokens:
98 | cut_word.append(word)
99 | cut_pos.append(pos)
100 | features['word'] += cut_word
101 | features['POS'] += cut_pos
102 | for word in features['word']:
103 | features['char'] += list(word)
104 | # Build char pos
105 | for word_index, word in enumerate(features['word']):
106 | for _ in word:
107 | features['char_pos'].append(features['POS'][word_index])
108 | # Build char tag(BMEO)
109 | features['char_tag'] = make_char_tag(features['word'], origin['ot'])
110 | # Build char_pos_tag
111 | for word_index, word in enumerate(features['word']):
112 | if len(word) == 1:
113 | features['char_pos_tag'].append('S_' + features['POS'][word_index])
114 | else:
115 | for index, char in enumerate(word):
116 | if index == 0:
117 | features['char_pos_tag'].append('B_' + features['POS'][word_index])
118 | elif index == len(word) - 1:
119 | features['char_pos_tag'].append('E_' + features['POS'][word_index])
120 | else:
121 | features['char_pos_tag'].append('M_' + features['POS'][word_index])
122 | # Build char_word_tag(BEMS)
123 | for word_index, word in enumerate(features['word']):
124 | if len(word) == 1:
125 | features['char_word_tag'].append('S')
126 | else:
127 | for index, char in enumerate(word):
128 | if index == 0:
129 | features['char_word_tag'].append('B')
130 | elif index == len(word) - 1:
131 | features['char_word_tag'].append('E')
132 | else:
133 | features['char_word_tag'].append('M')
134 | # Build dict_feat
135 | features['dict_feature'] = make_dict_feat(features['char'])
136 | return features
137 |
138 |
139 | def handle_data(dataset: str) -> list:
140 | print('Processing %s ...' % dataset)
141 | data = [construct_features(json.loads(line)) for line in
142 | tqdm(open(dataset, 'r', encoding='utf-8').readlines())]
143 | return data
144 |
145 |
146 | def preprocess(dataset: str):
147 | global segment_tool, dictionary
148 | print('Loading Segment Model...')
149 | segment_tool = thulac(rm_space=True)
150 | print('Loading dictionary')
151 | dictionary = set(map(lambda s: s.rstrip('\n'), open('dataset/dictionary.txt', encoding='utf-8').readlines()))
152 |
153 | dataset_list = (['train', 'test'], [dataset])
154 | for dataset_type, dataset_name in product(*dataset_list):
155 | with open('dataset/%s/%s_seg.txt' % (dataset_name, dataset_type), 'w', encoding='utf-8') as f:
156 | for line in handle_data('dataset/%s/%s.txt' % (dataset_name, dataset_type)):
157 | f.write(json.dumps(line, ensure_ascii=False) + '\n')
158 |
159 |
160 | if __name__ == '__main__':
161 | preprocess(sys.argv[1])
162 |
--------------------------------------------------------------------------------
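An illustrative trace of `make_char_tag` on a toy segmented review (hypothetical input; assumes the pinned environment, where numba's `@jit` falls back to object mode for these string inputs): the opinion target is located with `kmp_search` and labelled with the B/M/E/O scheme.

```
from utils.preprocess import make_char_tag

words = ['环境', '很', '好']   # segmented words of a toy review
target = '环境'                # the opinion target ('ot' field)
print(make_char_tag(words, target))
# ['B', 'E', 'O', 'O']
```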
/utils/score.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 |
5 |
6 | def score(input_file, test_file, result_path):
7 | pred_file = open(input_file, 'r', encoding='utf-8')
8 | pred = []
9 | for line in pred_file:
10 | if line == '': break
11 | pair = line.split('][')
12 | words = eval(pair[0] + ']')
13 | chars = ''.join(words)
14 | tags = eval('[' + pair[1])
15 | rs = []
16 | start, end = 0, 0
17 | while start < len(tags) - 1:
18 | if tags[start].split('_')[0] != 'o':
19 | while end < len(tags) - 1:
20 | if tags[end].split('_')[0] != 'o':
21 | end += 1
22 | else:
23 | break
24 | if end - start > 1:
25 | rs.append(chars[start:end])
26 | start = end
27 | start += 1
28 | end = start
29 | rs = list(set(rs))
30 | pred.append(rs)
31 | true_file = open(test_file, 'r', encoding='utf-8').readlines()
32 | result_file = os.path.join(result_path, 'result.txt')
33 | result_file_f = open(result_file, 'w', encoding='utf-8')
34 | for i, line in enumerate(true_file):
35 | info = json.loads(line)
36 | rs = {'content': info['content'], 'true_label': info['label'], 'pred_label': pred[i]}
37 | result_file_f.write(json.dumps(rs, ensure_ascii=False) + '\n')
38 | result_file_f.close()
39 | true_positive = 0
40 | positive = 0  # total number of predicted targets (TP + FP)
41 | total_num = 0 # TP + FN
42 | false_positive = 0
43 | wrong_file = open(os.path.join(result_path, 'wrong.json'), 'w', encoding='utf-8')
44 | with open(result_file, 'r', encoding='utf-8') as f:
45 | for line in f.readlines():
46 | rs_line = json.loads(line.strip())
47 | predict_label = rs_line['pred_label']
48 | true_label = rs_line['true_label']
49 | content = rs_line['content']
50 | if true_label in predict_label:
51 | true_positive += 1
52 | else:
53 | false_positive += 1
54 | wrong_file.write('{}\t{}\t{}\n'.format(predict_label, true_label, content))
55 | positive += len(predict_label)
56 | total_num += 1
57 | precision = 100.0 * true_positive / positive
58 | recall = 100.0 * true_positive / total_num
59 | F1 = 2 * precision * recall / (precision + recall)
60 | print('Results: right:%d wrong:%d model find:%d total:%d' % (true_positive, false_positive, positive, total_num))
61 | print('Metrics: Precision:%.3f Recall:%.3f F1:%.3f' % (precision, recall, F1))
62 | print(time.asctime(time.localtime(time.time())))
63 |
--------------------------------------------------------------------------------
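A worked example (hypothetical counts) of the metrics computed at the end of `score`: precision is measured over all predicted targets, recall over all test sentences, and F1 is their harmonic mean.

```
true_positive, positive, total_num = 800, 950, 1000    # hypothetical counts
precision = 100.0 * true_positive / positive           # 84.211
recall = 100.0 * true_positive / total_num             # 80.000
F1 = 2 * precision * recall / (precision + recall)     # 82.051
print('Precision:%.3f Recall:%.3f F1:%.3f' % (precision, recall, F1))
```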
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import pickle
3 | import warnings
4 |
5 | import numpy as np
6 | import torch
7 | from torch import nn
8 |
9 | warnings.filterwarnings('ignore')
10 |
11 | use_cuda = torch.cuda.is_available()
12 | device = torch.device('cuda' if use_cuda else 'cpu')
13 |
14 |
15 | def init_embedding(vocab_size, embedding_dim):
16 | emb = nn.Embedding(vocab_size, embedding_dim).to(device)
17 | pretrain_emb = np.empty([vocab_size, embedding_dim])
18 | scale = np.sqrt(3.0 / embedding_dim)
19 | for index in range(vocab_size):
20 | pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim])
21 | emb.weight.data.copy_(torch.from_numpy(pretrain_emb))
22 | return emb
23 |
24 |
25 | def seq2label(pred_tensor, mask_tensor, label_alphabet, char_recover):
26 | pred_tensor = pred_tensor[char_recover]
27 | mask_tensor = mask_tensor[char_recover]
28 | seq_len = pred_tensor.size(1)
29 | mask = mask_tensor.cpu().data.numpy()
30 | pred_ids = pred_tensor.cpu().data.numpy()
31 | batch_size = mask.shape[0]
32 | labels = []
33 | for i in range(batch_size):
34 | pred = [label_alphabet.get_item(pred_ids[i][j]) for j in range(seq_len) if mask[i][j] != 0]
35 | labels.append(pred)
36 | return labels
37 |
38 |
39 | def slice_set(batch: int, batch_size: int, max_num: int):
40 | start = batch * batch_size
41 | end = (batch + 1) * batch_size
42 | if end > max_num:
43 | end = max_num
44 | return start, end
45 |
46 |
47 | def load_batch(instances, test_mode=False):
48 | batch_size = len(instances)
49 | chars = [instance[0] for instance in instances]
50 | dict_feats = [instance[1] for instance in instances]
51 | tags = [instance[2] for instance in instances]
52 | labels = [instance[3] for instance in instances]
53 | seq_lengths = torch.tensor(list(map(len, chars)), dtype=torch.long, device=device)
54 | max_seq_len = seq_lengths.max()
55 | with torch.set_grad_enabled(test_mode):
56 | char_seq_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
57 | dict_seq_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
58 | tag_seq_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
59 | label_seq_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long, device=device)
60 | mask = torch.zeros((batch_size, max_seq_len), dtype=torch.uint8, device=device)
61 | for idx, (seq, pos, tag, label, seqlen) in enumerate(zip(chars, dict_feats, tags, labels, seq_lengths)):
62 | char_seq_tensor[idx, :seqlen] = torch.tensor(seq, dtype=torch.long)
63 | dict_seq_tensor[idx, :seqlen] = torch.tensor(pos, dtype=torch.long)
64 | tag_seq_tensor[idx, :seqlen] = torch.tensor(tag, dtype=torch.long)
65 | label_seq_tensor[idx, :seqlen] = torch.tensor(label, dtype=torch.long)
66 | mask[idx, :seqlen] = torch.ones(seqlen.item(), dtype=torch.int64)
67 | seq_lengths, char_perm_idx = seq_lengths.sort(0, descending=True)
68 | char_seq_tensor = char_seq_tensor[char_perm_idx]
69 | dict_seq_tensor = dict_seq_tensor[char_perm_idx]
70 | tag_seq_tensor = tag_seq_tensor[char_perm_idx]
71 | label_seq_tensor = label_seq_tensor[char_perm_idx]
72 | mask = mask[char_perm_idx]
73 | _, char_seq_recover = char_perm_idx.sort(0, descending=False)
74 | return label_seq_tensor, mask, char_seq_tensor, dict_seq_tensor, tag_seq_tensor, seq_lengths, char_seq_recover
75 |
76 |
77 | def lr_decay(optimizer, epoch, decay_rate, init_lr):
78 | lr = init_lr / (1 + decay_rate * epoch)
79 | print('learning rate: {0}'.format(lr))
80 | for param_group in optimizer.param_groups:
81 | param_group['lr'] = lr
82 | return optimizer
83 |
84 |
85 | def save_data_setting(data):
86 | _data = copy.deepcopy(data)
87 | _data.texts, _data.ids = [], []
88 | pickle.dump(_data, open(data.config_path, 'wb+'))
89 |
90 |
91 | def save_results(data, results):
92 | result_file = open(data.result_path, 'w', encoding='utf-8')
93 | sent_num = len(results)
94 | content_list = data.texts
95 | for i in range(sent_num):
96 | result_file.write('{}{}\n'.format(content_list[i][0], results[i]))
97 | print('Results have been written into %s' % data.result_path)
98 |
--------------------------------------------------------------------------------
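An illustrative look at how `slice_set` produces the `[start, end)` ranges used for mini-batching in `main.py` (assumes the repository root is on `PYTHONPATH`); the final, out-of-range batch yields an empty slice that `main.py` skips.

```
from utils.utils import slice_set

print(slice_set(0, 128, 300))  # (0, 128)
print(slice_set(2, 128, 300))  # (256, 300)
print(slice_set(3, 128, 300))  # (384, 300) -> empty slice, skipped in main.py
```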