'])
275 | indices += [2]
276 | out = transducer.tag([indices], char2idx, sess, batch_size=1)
277 | out = out[0].replace(' ', ' ')
278 | return out
279 |
280 | def train(self, t_x, t_y, v_x, v_y_raw, v_y_gold, idx2tag, idx2char, unk_chars, trans_dict, sess, epochs,
281 | trained_model, transducer=None, lr=0.05, decay=0.05, decay_step=1, sent_seg=False, outpath=None):
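    |         """Train the tagger: characters in unk_chars are remapped to the
    |         unknown index (1), the learning rate decays every decay_step epochs,
    |         samples are shuffled each epoch, and the weights of the epoch that
    |         scores best on the dev set are saved to trained_model."""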
282 | lr_r = lr
283 |
284 | best_epoch = 0
285 | best_score = [0] * 6
286 |
287 | chars = toolbox.decode_chars(v_x[0], idx2char)
288 |
289 | for i in range(len(v_x[0])):
290 | for j, n in enumerate(v_x[0][i]):
291 | if n in unk_chars:
292 | v_x[0][i][j] = 1
293 |
294 | for i in range(len(t_x[0])):
295 | for k in range(len(t_x[0][i])):
296 | for j, n in enumerate(t_x[0][i][k]):
297 | if n in unk_chars:
298 | t_x[0][i][k][j] = 1
299 |
300 | transducer_dict = None
301 | if transducer is not None:
302 |             char2idx = {v: k for k, v in idx2char.items()}
303 |
304 | def transducer_dict(trans_str):
305 | return self.define_transducer_dict(trans_str, char2idx, sess[-1], transducer)
306 |
307 | for epoch in range(epochs):
308 | print 'epoch: %d' % (epoch + 1)
309 | t = time()
310 | if epoch % decay_step == 0 and decay > 0:
311 | lr_r = lr/(1 + decay*(epoch/decay_step))
312 |
313 | data_list = t_x + t_y
314 |
315 | samples = zip(*data_list)
316 |
317 | random.shuffle(samples)
318 |
319 | for sample in samples:
320 | c_len = len(sample[0][0])
321 | idx = self.bucket_dit[c_len]
322 | real_batch_size = self.real_batches[idx]
323 | model = self.input_v[idx] + self.output_[idx]
324 | Batch.train(sess=sess[0], model=model, batch_size=real_batch_size, config=self.train_step[idx],
325 | lr=self.l_rate, lrv=lr_r, dr=self.drop_out, drv=self.drop_out_v, data=list(sample),
326 | verbose=False)
327 |
328 | predictions = []
329 |
330 | #for v_b_x in zip(*v_x):
331 | c_len = len(v_x[0][0])
332 | idx = self.bucket_dit[c_len]
333 | b_prediction = self.predict(data=v_x, sess=sess, model=self.input_v[idx] + self.output[idx], index=idx,
334 | argmax=True, batch_size=200)
335 | b_prediction = toolbox.decode_tags(b_prediction, idx2tag)
336 | predictions.append(b_prediction)
337 |
338 | predictions = zip(*predictions)
339 | predictions = toolbox.merge_bucket(predictions)
340 |
341 | if self.is_space == 'sea':
342 | prediction_out, raw_out = toolbox.generate_output_sea(chars, predictions)
343 | else:
344 | prediction_out, raw_out = toolbox.generate_output(chars, predictions, trans_dict, transducer_dict)
345 |
346 | if sent_seg:
347 | scores = evaluation.evaluator(prediction_out, v_y_gold, raw_out, v_y_raw)
348 | else:
349 | scores = evaluation.evaluator(prediction_out, v_y_gold)
350 | if sent_seg:
351 | c_score = scores[2] * scores[5]
352 | c_best_score = best_score[2] * best_score[5]
353 | else:
354 | c_score = scores[2]
355 | c_best_score = best_score[2]
356 |
357 | if c_score > c_best_score:
358 | best_epoch = epoch + 1
359 | best_score = scores
360 | self.saver.save(sess[0], trained_model, write_meta_graph=False)
361 |
362 | if outpath is not None:
363 | wt = codecs.open(outpath, 'w', encoding='utf-8')
364 | for pre in prediction_out[0]:
365 | wt.write(pre + '\n')
366 | wt.close()
367 |
368 |
369 | if sent_seg:
370 | print 'Sentence segmentation:'
371 | print 'F score: %f\n' % scores[5]
372 | print 'Word segmentation:'
373 | print 'F score: %f' % scores[2]
374 | else:
375 | print 'F score: %f' % c_score
376 | print 'Time consumed: %d seconds' % int(time() - t)
377 | print 'Training is finished!'
378 | if sent_seg:
379 | print 'Sentence segmentation:'
380 | print 'Best F score: %f' % best_score[5]
381 | print 'Best Precision: %f' % best_score[3]
382 | print 'Best Recall: %f\n' % best_score[4]
383 | print 'Word segmentation:'
384 | print 'Best F score: %f' % best_score[2]
385 | print 'Best Precision: %f' % best_score[0]
386 | print 'Best Recall: %f\n' % best_score[1]
387 | else:
388 | print 'Best F score: %f' % best_score[2]
389 | print 'Best Precision: %f' % best_score[0]
390 | print 'Best Recall: %f\n' % best_score[1]
391 | print 'Best epoch: %d' % best_epoch
392 |
393 | def test(self, t_x, t_y_raw, t_y_gold, idx2tag, idx2char, unk_chars, sub_dict, trans_dict, sess, transducer,
394 | ensemble=None, batch_size=100, sent_seg=False, bias=-1, outpath=None, trans_type='mix'):
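    |         """Evaluate on the test set: characters are remapped via sub_dict or
    |         to the unknown index, decoding is argmax unless a segmentation bias
    |         is given, and the evaluation scores are printed."""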
395 |
396 | chars = toolbox.decode_chars(t_x[0], idx2char)
397 | gold_out = t_y_gold
398 |
399 | for i in range(len(t_x[0])):
400 | for j, n in enumerate(t_x[0][i]):
401 | if n in sub_dict:
402 | t_x[0][i][j] = sub_dict[n]
403 | elif n in unk_chars:
404 | t_x[0][i][j] = 1
405 |
406 | transducer_dict = None
407 | if transducer is not None:
408 | char2idx = {v: k for k, v in idx2char.items()}
409 |
410 | def transducer_dict(trans_str):
411 | return self.define_transducer_dict(trans_str, char2idx, sess[-1], transducer)
412 |
413 | if bias < 0:
414 | argmax = True
415 | else:
416 | argmax = False
417 |
418 | prediction = self.predict(data=t_x, sess=sess, model=self.input_v[0] + self.output[0], index=0,
419 | argmax=argmax, batch_size=batch_size, ensemble=ensemble)
420 |
421 | if bias >= 0 and self.crf == 0:
422 | prediction = [toolbox.biased_out(prediction[0], bias)]
423 |
424 | predictions = toolbox.decode_tags(prediction, idx2tag)
425 |
426 | if self.is_space == 'sea':
427 | prediction_out, raw_out = toolbox.generate_output_sea(chars, predictions)
428 | else:
429 | prediction_out, raw_out = toolbox.generate_output(chars, predictions, trans_dict, transducer_dict,
430 | trans_type=trans_type)
431 |
432 | if sent_seg:
433 | scores = evaluation.evaluator(prediction_out, gold_out, raw_out, t_y_raw)
434 | else:
435 | scores = evaluation.evaluator(prediction_out, gold_out, verbose=True)
436 |
437 | if outpath is not None:
438 | wt = codecs.open(outpath, 'w', encoding='utf-8')
439 | for pre in prediction_out[0]:
440 | wt.write(pre + '\n')
441 | wt.close()
442 |
443 | print 'Evaluation scores:'
444 | if sent_seg:
445 | print 'Sentence segmentation:'
446 | print 'F score: %f' % scores[5]
447 | print 'Precision: %f' % scores[3]
448 | print 'Recall: %f\n' % scores[4]
449 | print 'Word segmentation:'
450 | print 'F score: %f' % scores[2]
451 | print 'Precision: %f' % scores[0]
452 | print 'Recall: %f\n' % scores[1]
453 | else:
454 | print 'Precision: %f' % scores[0]
455 | print 'Recall: %f' % scores[1]
456 | print 'F score: %f' % scores[2]
457 | print 'True negative rate: %f' % scores[3]
458 |
459 | def tag(self, r_x, r_x_raw, idx2tag, idx2char, unk_chars, sub_dict, trans_dict, sess, transducer, ensemble=None,
460 | batch_size=100, outpath=None, sent_seg=False, seg_large=False, form='conll'):
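    |         """Tag raw input; the effective batch size scales inversely with the
    |         bucket length. Returns the predictions when seg_large is set,
    |         otherwise writes them out via toolbox.printer."""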
461 |
462 | chars = toolbox.decode_chars(r_x[0], idx2char)
463 |
464 | for i in range(len(r_x[0])):
465 | for j, n in enumerate(r_x[0][i]):
466 | if n in sub_dict:
467 | r_x[0][i][j] = sub_dict[n]
468 | elif n in unk_chars:
469 | r_x[0][i][j] = 1
470 |
471 | c_len = len(r_x[0][0])
472 | idx = self.bucket_dit[c_len]
473 |
474 | real_batch = batch_size * 300 / c_len
475 |
476 | transducer_dict = None
477 | if transducer is not None:
478 | char2idx = {v: k for k, v in idx2char.items()}
479 |
480 | def transducer_dict(trans_str):
481 | return self.define_transducer_dict(trans_str, char2idx, sess[-1], transducer)
482 |
483 | prediction = self.predict(data=r_x, sess=sess, model=self.input_v[idx] + self.output[idx], index=idx,
484 | argmax=True, batch_size=real_batch, ensemble=ensemble)
485 |
486 | predictions = toolbox.decode_tags(prediction, idx2tag)
487 |
488 | if self.is_space == 'sea':
489 | prediction_out, raw_out = toolbox.generate_output_sea(chars, predictions)
490 | multi_out = prediction_out
491 | else:
492 | prediction_out, raw_out, multi_out = toolbox.generate_output(chars, predictions, trans_dict,
493 | transducer_dict, multi_tok=True)
494 |
495 | pre_out = []
496 | mut_out = []
497 | for pre in prediction_out:
498 | pre_out += pre
499 | for mul in multi_out:
500 | mut_out += mul
501 | prediction_out = pre_out
502 | multi_out = mut_out
503 |
504 | if form == 'mlp1' or form == 'mlp2':
505 | prediction_out = toolbox.mlp_post(r_x_raw, prediction_out, self.is_space, form)
506 |
507 | if not seg_large:
508 | toolbox.printer(r_x_raw, prediction_out, multi_out, outpath, sent_seg, form)
509 |
510 | else:
511 | return prediction_out, multi_out
512 |
513 | def predict(self, data, sess, model, index=None, argmax=True, batch_size=100, ensemble=None, verbose=False):
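    |         """Run the tagger; with a CRF the emission scores are decoded with
    |         the transition matrix in the separate decode session, otherwise
    |         plain (arg)max over the outputs is used."""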
514 | if self.crf:
515 | assert index is not None
516 | predictions = Batch.predict(sess=sess[0], decode_sess=sess[1], model=model,
517 | transitions=[self.transition_char], crf=self.crf, scores=self.scores[index],
518 | decode_holders=self.decode_holders[index], batch_size=batch_size,
519 | data=data, dr=self.drop_out, ensemble=ensemble, verbose=verbose)
520 | else:
521 | predictions = Batch.predict(sess=sess[0], model=model, crf=self.crf, argmax=argmax, batch_size=batch_size,
522 | data=data, dr=self.drop_out, ensemble=ensemble, verbose=verbose)
523 | return predictions
524 |
525 |
--------------------------------------------------------------------------------
/reader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: Yan Shao, yan.shao@lingfil.uu.se
4 | """
5 | import codecs
6 |
7 |
8 | def gold(path, is_dev=True, form='conll', is_space=False):
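    |     """Read a gold-annotated file in conll, mlp1 or mlp2 format into lists
    |     of token tuples; with is_dev=False every 10th sentence is split off
    |     into a separate dev set."""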
9 | sents = []
10 | sent = []
11 | cter = 0
12 | sents_dev = None
13 | if not is_dev:
14 | sents_dev = []
15 | for line in codecs.open(path, 'rb', encoding='utf8'):
16 | line = line.strip()
17 | if form == 'conll':
18 | segs = line.split('\t')
19 | if len(segs) == 10:
20 | if '.' not in segs[0]:
21 | sent.append(tuple(segs))
22 | elif len(sent) > 0:
23 | if not is_dev and cter == 9:
24 | sents_dev.append(sent)
25 | cter = 0
26 | else:
27 | sents.append(sent)
28 | cter += 1
29 | sent = []
30 | elif form == 'mlp1' or form == 'mlp2':
31 | if len(line) > 0:
32 | if form == 'mlp1':
33 | segs = []
34 | for l_seg in line.split(' '):
35 | if len(l_seg) > 0:
36 | if is_space == 'sea':
37 | segs.append(l_seg.replace('_', ' '))
38 | else:
39 | segs += l_seg.split('\\\\')
40 | else:
41 | segs = line.split()
42 | for i, seg in enumerate(segs):
43 | sent.append((str(i + 1), seg))
44 | if not is_dev and cter == 9:
45 | sents_dev.append(sent)
46 | cter = 0
47 | else:
48 | sents.append(sent)
49 | cter += 1
50 | sent = []
51 | else:
52 | raise Exception('Format error, available: conll, mlp1, mlp2')
53 | if is_dev:
54 | return sents
55 | else:
56 | return sents, sents_dev
57 |
58 |
59 | def raw(path):
60 | sents = []
61 | for line in codecs.open(path, 'rb', encoding='utf-8'):
62 | line = line.strip()
63 | sents.append(line)
64 | return sents
65 |
66 |
67 | def get_gold(sent, ignore_mwt=False):
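    |     """Rebuild one sentence as a space-delimited token string, expanding
    |     multi-word tokens (ids like '2-3') into their parts; with ignore_mwt,
    |     expansions that differ from the surface form are replaced by it."""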
68 | line = ''
69 | nt = -1
70 | mwt = ''
71 | segs = []
72 | for tk in sent:
73 | if '-' in tk[0]:
74 | if nt == 0:
75 | s_mwt = ''.join(segs)
76 | if ignore_mwt and s_mwt != mwt:
77 | line += ' ' + mwt
78 | else:
79 | for seg in segs:
80 | line += ' ' + seg
81 | mwt = tk[1]
82 | sp = tk[0].split('-')
83 | nt = int(sp[1]) - int(sp[0]) + 1
84 | segs = []
85 | elif nt == -1:
86 | line += ' ' + tk[1]
87 | elif nt > 0:
88 | segs.append(tk[1])
89 | nt -= 1
90 | elif nt == 0:
91 | s_mwt = ''.join(segs)
92 | if ignore_mwt and s_mwt != mwt:
93 | line += ' ' + mwt
94 | else:
95 | for seg in segs:
96 | line += ' ' + seg
97 | nt = -1
98 | mwt = ''
99 | segs = []
100 | line += ' ' + tk[1]
101 | return line.strip()
102 |
103 |
104 | def test_gold(path, form='conll', is_space=False, ignore_mwt=False):
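    |     """Read the gold test file and return one space-delimited token string
    |     per sentence; conll sentences are flattened with get_gold."""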
105 | sents = []
106 | sent = []
107 | st = ''
108 | for line in codecs.open(path, 'rb', encoding='utf-8'):
109 | line = line.strip()
110 | if form == 'conll':
111 | segs = line.split('\t')
112 | if len(segs) == 10:
113 | if '.' not in segs[0]:
114 | sent.append(tuple(segs))
115 | elif len(sent) > 0:
116 | sents.append(sent)
117 | sent = []
118 | elif form == 'mlp1' or form == 'mlp2':
119 | if len(line) > 0:
120 | if form == 'mlp1':
121 | segs = []
122 | for l_seg in line.split(' '):
123 | if is_space == 'sea':
124 | segs.append(l_seg.replace('_', ' '))
125 | else:
126 | segs += l_seg.split('\\\\')
127 | else:
128 | segs = line.split()
129 | for seg in segs:
130 | st += ' ' + seg
131 | sents.append(st.strip())
132 | st = ''
133 | else:
134 | raise Exception('Format error, available: conll, mlp1, mlp2')
135 | if form == 'conll':
136 | p_sents = [get_gold(s_sent, ignore_mwt=ignore_mwt) for s_sent in sents]
137 | sents = p_sents
138 | return sents
139 |
140 |
141 | def get_raw(path, fin, fout, cat='other', new=True, is_dev=True, form='conll', is_space=False):
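    |     """Recover raw (untokenised) text from an annotated file: cat='gold'
    |     copies the '# sentence'/'# text' comments, cat='zh' concatenates the
    |     tokens, and anything else re-attaches punctuation heuristically. With
    |     is_dev=False every 10th line goes to raw_dev.txt instead."""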
142 | fout = codecs.open(path + '/' + fout, 'w', encoding='utf-8')
143 | fout_dev = None
144 | if not is_dev:
145 | fout_dev = codecs.open(path + '/raw_dev.txt', 'w', encoding='utf-8')
146 | cter = 0
147 | if form == 'conll':
148 | if cat == 'gold':
149 | for line in codecs.open(path + '/' + fin, 'r', encoding='utf-8'):
150 | line = line.strip()
151 |                 line = line.replace('&apos;', '\'')
152 | if len(line) > 0 and ('# sentence' in line or '# text' in line):
153 | if new:
154 | if not is_dev and cter == 9:
155 | fout_dev.write(line[line.index('=') + 1:].lstrip() + '\n')
156 | cter = 0
157 | else:
158 | fout.write(line[line.index('=') + 1:].lstrip() + '\n')
159 | cter += 1
160 | else:
161 | if not is_dev and cter == 9:
162 | fout_dev.write(line[line.index(':') + 1:].lstrip() + '\n')
163 | cter = 0
164 | else:
165 | fout.write(line[line.index(':') + 1:].lstrip() + '\n')
166 | cter += 1
167 |
168 | elif cat == 'zh':
169 | pt = ''
170 | for line in codecs.open(path + '/' + fin, 'r', encoding='utf-8'):
171 | line = line.strip()
172 | line = line.split('\t')
173 | if len(line) == 10:
174 | pt += line[1]
175 | else:
176 | if len(pt) > 0:
177 | if not is_dev and cter == 9:
178 | fout_dev.write(pt + '\n')
179 | cter = 0
180 | else:
181 | fout.write(pt + '\n')
182 | cter += 1
183 | pt = ''
184 |
185 | else:
186 | punc_e = ['!', ')', ',', '.', ';', ':', '?', '»', '...', ']', '..', '....', '%', 'º', '²', '°']
187 | punc_b = ['¿', '¡', '(', '«', '[']
188 | punc_m = ['"', '\'']
189 | punc_e = [s.decode('utf-8') for s in punc_e]
190 | punc_b = [s.decode('utf-8') for s in punc_b]
191 | punc_m = [s.decode('utf-8') for s in punc_m]
192 | md = {}
193 | for p in punc_m:
194 | md[p] = True
195 | pt = ''
196 | ct = 0
197 | for line in codecs.open(path + '/' + fin, 'r', encoding='utf-8'):
198 | line = line.strip()
199 | segs = line.split('\t')
200 | if len(segs) == 10:
201 | if '-' in segs[0]:
202 | sp = segs[0].split('-')
203 | ct = int(sp[1]) - int(sp[0]) + 1
204 | if len(pt) > 0 and pt[-1] in punc_b:
205 | pt += segs[1]
206 | elif len(pt) > 0 and pt[-1] in punc_m:
207 | if md[pt[-1]]:
208 | pt += ' ' + segs[1]
209 | else:
210 | pt += segs[1]
211 | else:
212 | pt += ' ' + segs[1]
213 | elif ct == 0:
214 | if segs[1] in punc_e:
215 | pt += segs[1]
216 | elif len(pt) > 0 and pt[-1] in punc_b:
217 | pt += segs[1]
218 | if segs[1] in punc_m:
219 | if md[segs[1]]:
220 | md[segs[1]] = False
221 | else:
222 | md[segs[1]] = True
223 | elif segs[1] in punc_m:
224 | if md[segs[1]]:
225 | pt += ' ' + segs[1]
226 | md[segs[1]] = False
227 | else:
228 | pt += segs[1]
229 | md[segs[1]] = True
230 | elif len(pt) > 0 and pt[-1] in punc_m:
231 | if md[pt[-1]]:
232 | pt += ' ' + segs[1]
233 | else:
234 | pt += segs[1]
235 | elif segs[1][0] == '\'':
236 | pt += segs[1]
237 | else:
238 | pt += ' ' + segs[1]
239 | else:
240 | ct -= 1
241 | else:
242 | if len(pt) > 0:
243 | pt = pt.lstrip()
244 | pt = pt.replace(' ",', '",')
245 | pt = pt.replace(' ".', '".')
246 | pt = pt.replace(':"...', ': "...')
247 | pt = pt.replace(' n\'t', 'n\'t')
248 | pt = pt.replace(' - ', '-')
249 | pt = pt.replace(' -- ', '--')
250 | pt = pt.replace(' / ', '/')
251 | if not is_dev and cter == 9:
252 | fout_dev.write(pt + '\n')
253 | cter = 0
254 | else:
255 | fout.write(pt + '\n')
256 | cter += 1
257 | pt = ''
258 | for p in punc_m:
259 | md[p] = True
260 |
261 | elif form == 'mlp1' or form == 'mlp2':
262 | for line in codecs.open(path + '/' + fin, 'r', encoding='utf-8'):
263 | line = line.strip()
264 | if len(line) > 0:
265 | if form == 'mlp1':
266 | if is_space == 'sea':
267 | line = line.replace('_', ' ')
268 | else:
269 | line = line.replace('\\\\', '')
270 | else:
271 | line = ''.join(line.split())
272 | if not is_dev and cter == 9:
273 | fout_dev.write(line + '\n')
274 | cter = 0
275 | else:
276 | fout.write(line + '\n')
277 | cter += 1
278 | else:
279 | raise Exception('Format error, available: conll, mlp1, mlp2')
280 | fout.close()
281 | if not is_dev:
282 | fout_dev.close()
283 |
--------------------------------------------------------------------------------
/segmenter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: Yan Shao, yan.shao@lingfil.uu.se
4 | """
5 | import reader
6 | import toolbox
7 | from model import Model
8 | from transducer_model import Seq2seq
9 | import sys
10 | import argparse
11 | import os
12 | import codecs
13 | import tensorflow as tf
14 | import cPickle as pickle
15 |
16 | from time import time
17 |
18 | parser = argparse.ArgumentParser(description='A Universal Tokeniser. Written by Y. Shao, Uppsala University')
19 | parser.add_argument('action', default='tag', choices=['train', 'test', 'tag'], help='train, test or tag')
20 |
21 | parser.add_argument('-f', '--format', default='conll', help='Data format of different tasks, conll, mlp1 or mlp2')
22 |
23 | parser.add_argument('-p', '--path', default=None, help='Path of the workstation')
24 |
25 | parser.add_argument('-t', '--train', default=None, help='File for training')
26 | parser.add_argument('-d', '--dev', default=None, help='File for validation')
27 | parser.add_argument('-e', '--test', default=None, help='File for evaluation')
28 | parser.add_argument('-r', '--raw', default=None, help='Raw file for tagging')
29 |
30 | parser.add_argument('-m', '--model', default='trained_model', help='Name of the trained model')
31 | parser.add_argument('-crf', '--crf', default=1, type=int, help='Use the CRF interface')
32 |
33 | parser.add_argument('-bt', '--bucket_size', default=50, type=int, help='Bucket size')
34 | parser.add_argument('-sl', '--sent_limit', default=300, type=int, help='Sentences longer than this are chopped')
35 |
36 | parser.add_argument('-tg', '--tags', default='BIES', help='Boundary Tagging, default is BIES')
37 |
38 | parser.add_argument('-ed', '--emb_dimension', default=50, type=int, help='Dimension of the embeddings')
39 | parser.add_argument('-emb', '--embeddings', default=None, help='Path and name of pre-trained char embeddings')
40 |
41 | parser.add_argument('-ng', '--ngram', default=1, type=int, help='Maximum n-gram length used as features')
42 |
43 | parser.add_argument('-cell', '--cell', default='gru', help='Recurrent cell type', choices=['gru', 'lstm'])
44 | parser.add_argument('-rnn', '--rnn_cell_dimension', default=200, type=int, help='Dimension of the RNN cells')
45 | parser.add_argument('-layer', '--rnn_layer_number', default=1, type=int, help='Numbers of the RNN layers')
46 |
47 | parser.add_argument('-dr', '--dropout_rate', default=0.5, type=float, help='Dropout rate')
48 |
49 | parser.add_argument('-iter', '--epochs', default=30, type=int, help='Numbers of epochs')
50 | parser.add_argument('-iter_trans', '--epochs_trans', default=50, type=int, help='Epochs for training the transducer')
51 |
52 | parser.add_argument('-op', '--optimizer', default='adagrad', help='Optimizer')
53 | parser.add_argument('-lr', '--learning_rate', default=0.2, type=float, help='Initial learning rate')
54 | parser.add_argument('-lr_trans', '--learning_rate_trans', default=0.3, type=float, help='Initial learning rate')
55 | parser.add_argument('-ld', '--decay_rate', default=0.05, type=float, help='Learning rate decay')
56 | parser.add_argument('-mt', '--momentum', default=None, type=float, help='Momentum')
57 |
58 | parser.add_argument('-ncp', '--no_clipping', default=False, action='store_true', help='Do not apply gradient clipping')
59 |
60 | parser.add_argument("-tb", "--train_batch", help="Training batch size", default=10, type=int)
61 | parser.add_argument("-eb", "--test_batch", help="Testing batch size", default=500, type=int)
62 | parser.add_argument("-rb", "--tag_batch", help="Tagging batch size", default=500, type=int)
63 |
64 | parser.add_argument("-g", "--gpu", help="ID of the GPU to use, default 0", default=0, type=int)
65 |
66 | parser.add_argument('-opth', '--output_path', default=None, help='Output path')
67 |
68 | parser.add_argument('-sea', '--sea', help='Process languages like Vietnamese', default=False, action='store_true')
69 |
70 | parser.add_argument('-ss', '--sent_seg', help='Perform sentence seg', default=False, action='store_true')
71 |
72 | parser.add_argument('-ens', '--ensemble', default=False, help='Ensemble several weights', action='store_true')
73 |
74 | parser.add_argument('-sgl', '--segment_large', default=False, help='Segment (very) large file', action='store_true')
75 |
76 | parser.add_argument('-lgs', '--large_size', default=10000, type=int, help='Number of sentences per batch when segmenting (very) large files')
77 |
78 | parser.add_argument('-ot', '--only_tokenised', default=False,
79 | help='Only output the tokenised file when segment (very) large file', action='store_true')
80 |
81 | parser.add_argument('-ts', '--train_size', default=-1, type=int, help='No. of sentences used for training')
82 |
83 | parser.add_argument('-rs', '--reset', default=False, help='Delete and re-initialise the intermediate files',
84 | action='store_true')
85 | parser.add_argument('-rst', '--reset_trans', default=False, help='Retrain the transducers', action='store_true')
86 |
87 | parser.add_argument('-isp', '--ignore_space', default=False, help='Ignore space delimiters', action='store_true')
88 | parser.add_argument('-imt', '--ignore_mwt', default=False, help='Ignore multi-word tokens to be transcribed',
89 | action='store_true')
90 |
91 | parser.add_argument('-sb', '--segmentation_bias', default=-1, type=float,
92 | help='Add segmentation bias to under(over)-splitting')
93 |
94 | parser.add_argument('-tt', '--transduction_type', default='mix', choices=['mix', 'dict', 'trans', 'none'],
95 | help='Different ways of transducing the non-segmental MWTs')
96 |
97 | args = parser.parse_args()
98 |
99 | sys = reload(sys)
100 | sys.setdefaultencoding('utf-8')
101 | print 'Encoding: ', sys.getdefaultencoding()
102 |
103 | if args.action == 'train':
104 | assert args.path is not None
105 | path = args.path
106 | train_file = args.train
107 | dev_file = args.dev
108 | model_file = args.model
109 | print 'Reading data......'
110 | f_names = os.listdir(path)
111 | if train_file is None or dev_file is None:
112 | for f_n in f_names:
113 | if 'ud-train.conllu' in f_n or 'training.segd' in f_n or 'ud-sample.conllu' in f_n:
114 | train_file = f_n
115 | elif 'ud-dev.conllu' in f_n or 'development.segd' in f_n:
116 | dev_file = f_n
117 | assert train_file is not None
118 | is_space = True
119 | if 'Chinese' in path or 'Japanese' in path or args.format == 'mlp2':
120 | is_space = False
121 |
122 | if args.sea:
123 | is_space = 'sea'
124 | if args.reset or not os.path.isfile(path + '/raw_train.txt') or not os.path.isfile(path + '/raw_dev.txt'):
125 | cat = 'other'
126 | if 'Chinese' in path or 'Japanese' in path:
127 | cat = 'zh'
128 | for line in codecs.open(path + '/' + train_file, 'r', encoding='utf-8'):
129 | if len(line) < 2:
130 | break
131 | if '# sentence' in line or '# text' in line:
132 | cat = 'gold'
133 |
134 | if dev_file is None:
135 | reader.get_raw(path, train_file, '/raw_train.txt', cat, is_dev=False, form=args.format, is_space=is_space)
136 | else:
137 | reader.get_raw(path, train_file, '/raw_train.txt', cat, form=args.format, is_space=is_space)
138 | reader.get_raw(path, dev_file, '/raw_dev.txt', cat, form=args.format, is_space=is_space)
139 |
140 | if args.reset or not os.path.isfile(path + '/tag_train.txt') or not os.path.isfile(path + '/tag_dev.txt') or \
141 | not os.path.isfile(path + '/tag_dev_gold.txt'):
142 | if dev_file is None:
143 | raws_train = reader.raw(path + '/raw_train.txt')
144 | raws_dev = reader.raw(path + '/raw_dev.txt')
145 | sents_train, sents_dev = reader.gold(path + '/' + train_file, False, form=args.format, is_space=is_space)
146 | else:
147 | raws_train = reader.raw(path + '/raw_train.txt')
148 | sents_train = reader.gold(path + '/' + train_file, form=args.format, is_space=is_space)
149 |
150 | raws_dev = reader.raw(path + '/raw_dev.txt')
151 | sents_dev = reader.gold(path + '/' + dev_file, form=args.format, is_space=is_space)
152 |
153 | if is_space != 'sea':
154 | toolbox.raw2tags(raws_train, sents_train, path, 'tag_train.txt', ignore_space=args.ignore_space,
155 | reset=args.reset, tag_scheme=args.tags, ignore_mwt=args.ignore_mwt)
156 | toolbox.raw2tags(raws_dev, sents_dev, path, 'tag_dev.txt', creat_dict=False, gold_path='tag_dev_gold.txt',
157 | ignore_space=args.ignore_space, tag_scheme=args.tags, ignore_mwt=args.ignore_mwt)
158 | else:
159 | toolbox.raw2tags_sea(raws_train, sents_train, path, 'tag_train.txt', reset=args.reset, tag_scheme=args.tags)
160 | toolbox.raw2tags_sea(raws_dev, sents_dev, path, 'tag_dev.txt', gold_path='tag_dev_gold.txt',
161 | tag_scheme=args.tags)
162 |
163 | if args.reset or not os.path.isfile(path + '/chars.txt'):
164 | toolbox.get_chars(path, ['raw_train.txt', 'raw_dev.txt'], sea=is_space)
165 |
166 | char2idx, unk_chars_idx, idx2char, tag2idx, idx2tag, trans_dict = toolbox.get_dicts(path, args.sent_seg, args.tags,
167 | args.crf)
168 |
169 | if args.embeddings is not None:
170 | print 'Reading embeddings...'
171 | short_emb = args.embeddings[args.embeddings.index('/') + 1: args.embeddings.index('.')]
172 | if args.reset or not os.path.isfile(path + '/' + short_emb + '_sub.txt'):
173 | toolbox.get_sample_embedding(path, args.embeddings, char2idx)
174 | emb_dim, emb, valid_chars = toolbox.read_sample_embedding(path, short_emb, char2idx)
175 | for vch in valid_chars:
176 | if char2idx[vch] in unk_chars_idx:
177 | unk_chars_idx.remove(char2idx[vch])
178 | else:
179 | emb_dim = args.emb_dimension
180 | emb = None
181 |
182 | train_x, train_y, max_len_train = toolbox.get_input_vec(path, 'tag_train.txt', char2idx, tag2idx,
183 | limit=args.sent_limit, sent_seg=args.sent_seg,
184 | is_space=is_space, train_size=args.train_size,
185 | ignore_space=args.ignore_space)
186 |
187 | dev_x, max_len_dev = toolbox.get_input_vec_raw(path, 'raw_dev.txt', char2idx, limit=args.sent_limit,
188 | sent_seg=args.sent_seg, is_space=is_space,
189 | ignore_space=args.ignore_space)
190 | if args.sent_seg:
191 | print 'Joint sentence segmentation...'
192 | else:
193 | print 'Training set: %d instances; Dev set: %d instances.' % (len(train_x[0]), len(dev_x[0]))
194 |
195 | nums_grams = None
196 | ng_embs = None
197 |
198 | if args.ngram > 1 and (args.reset or not os.path.isfile(path + '/' + str(args.ngram) + 'gram.txt')):
199 | toolbox.get_ngrams(path, args.ngram, is_space)
200 |
201 | ngram = toolbox.read_ngrams(path, args.ngram)
202 |
203 | if args.ngram > 1:
204 | gram2idx = toolbox.get_ngram_dic(ngram)
205 |     train_gram = toolbox.get_gram_vec(path, 'tag_train.txt', gram2idx, limit=args.sent_limit, sent_seg=args.sent_seg,
206 | is_space=is_space, ignore_space=args.ignore_space)
207 | dev_gram = toolbox.get_gram_vec(path, 'raw_dev.txt', gram2idx, is_raw=True, limit=args.sent_limit,
208 | sent_seg=args.sent_seg, is_space=is_space, ignore_space=args.ignore_space)
209 | train_x += train_gram
210 | dev_x += dev_gram
211 | nums_grams = []
212 | for dic in gram2idx:
213 | nums_grams.append(len(dic.keys()))
214 |
215 | max_len = max(max_len_train, max_len_dev)
216 |
217 | b_train_x, b_train_y = toolbox.buckets(train_x, train_y, size=args.bucket_size)
218 | b_train_x, b_train_y, b_lens, b_count = toolbox.pad_bucket(b_train_x, b_train_y, max_len)
219 |
220 | b_dev_x = [toolbox.pad_zeros(dev_x_i, max_len) for dev_x_i in dev_x]
221 |
222 | b_dev_y_gold = [line.strip() for line in codecs.open(path + '/tag_dev_gold.txt', 'r', encoding='utf-8')]
223 |
224 | nums_tag = len(tag2idx)
225 |
226 | config = tf.ConfigProto(allow_soft_placement=True)
227 | gpu_config = "/gpu:" + str(args.gpu)
228 |
229 | transducer = None
230 | transducer_graph = None
231 | trans_model = None
232 | trans_init = None
233 |
234 | if len(trans_dict) > 200 and not args.ignore_mwt:
235 | transducer = toolbox.get_dict_vec(trans_dict, char2idx)
236 | t = time()
237 |
238 | initializer = tf.contrib.layers.xavier_initializer()
239 |
240 | if transducer is not None:
241 | transducer_graph = tf.Graph()
242 | with transducer_graph.as_default():
243 | with tf.variable_scope("transducer") as scope:
244 | trans_model = Seq2seq(path + '/' + model_file + '_transducer')
245 | print 'Defining transducer...'
246 | trans_model.define(char_num=len(char2idx), rnn_dim=args.rnn_cell_dimension, emb_dim=args.emb_dimension,
247 | max_x=len(transducer[0][0]), max_y=len(transducer[1][0]))
248 | trans_init = tf.global_variables_initializer()
249 | transducer_graph.finalize()
250 |
251 |     print 'Initialisation...'
252 | main_graph = tf.Graph()
253 | with main_graph.as_default():
254 | with tf.variable_scope("tagger") as scope:
255 | model = Model(nums_chars=len(char2idx) + 2, nums_tags=nums_tag, buckets_char=b_lens, counts=b_count,
256 | crf=args.crf, ngram=nums_grams, batch_size=args.train_batch, sent_seg=args.sent_seg,
257 | is_space=is_space, emb_path=args.embeddings, tag_scheme=args.tags)
258 |
259 | model.main_graph(trained_model=path + '/' + model_file + '_model', scope=scope,
260 | emb_dim=emb_dim, cell=args.cell, rnn_dim=args.rnn_cell_dimension,
261 | rnn_num=args.rnn_layer_number, drop_out=args.dropout_rate, emb=emb)
262 | t = time()
263 |
264 | model.config(optimizer=args.optimizer, decay=args.decay_rate, lr_v=args.learning_rate,
265 | momentum=args.momentum, clipping=not args.no_clipping)
266 | init = tf.global_variables_initializer()
267 |
268 | print 'Done. Time consumed: %d seconds' % int(time() - t)
269 |
270 | main_graph.finalize()
271 |
272 | main_sess = tf.Session(config=config, graph=main_graph)
273 |
274 | if args.crf > 0:
275 | decode_graph = tf.Graph()
276 | with decode_graph.as_default():
277 | model.decode_graph()
278 | decode_graph.finalize()
279 |
280 | decode_sess = tf.Session(config=config, graph=decode_graph)
281 |
282 | sess = [main_sess, decode_sess]
283 |
284 | else:
285 | sess = [main_sess, None]
286 |
287 | with tf.device(gpu_config):
288 |
289 | if transducer is not None:
290 | print 'Building transducer...'
291 | t = time()
292 | trans_sess = tf.Session(config=config, graph=transducer_graph)
293 | trans_sess.run(trans_init)
294 | trans_model.train(transducer[0], transducer[1], transducer[2], transducer[3], args.learning_rate_trans,
295 | char2idx, trans_sess, args.epochs_trans, batch_size=10, reset=args.reset_trans)
296 | sess.append(trans_sess)
297 | print 'Done. Time consumed: %d seconds' % int(time() - t)
298 | print 'Training the main segmenter..'
299 | main_sess.run(init)
300 | print 'Initialisation...'
301 | print 'Done. Time consumed: %d seconds' % int(time() - t)
302 | t = time()
303 | b_dev_raw = [line.strip() for line in codecs.open(path + '/raw_dev.txt', 'r', encoding='utf-8')]
304 | model.train(b_train_x, b_train_y, b_dev_x, b_dev_raw, b_dev_y_gold, idx2tag, idx2char, unk_chars_idx, trans_dict,
305 | sess, args.epochs, path + '/' + model_file + '_weights', transducer=trans_model,
306 | lr=args.learning_rate, decay=args.decay_rate, sent_seg=args.sent_seg, outpath=args.output_path)
307 |
308 | else:
309 |
310 | assert args.path is not None
311 | assert args.model is not None
312 | path = args.path
313 | assert os.path.isfile(path + '/chars.txt')
314 |
315 | model_file = args.model
316 |
317 | if args.ensemble:
318 | if not os.path.isfile(path + '/' + model_file + '_1_model') or not os.path.isfile(path + '/' + model_file +
319 | '_1_weights.index'):
320 | raise Exception('Not any model file or weights file under the name of ' + model_file + '.')
321 | fin = open(path + '/' + model_file + '_1_model', 'rb')
322 | else:
323 | if not os.path.isfile(path + '/' + model_file + '_model') or not os.path.isfile(path + '/' + model_file +
324 | '_weights.index'):
325 | raise Exception('No model file or weights file under the name of ' + model_file + '.')
326 | fin = open(path + '/' + model_file + '_model', 'rb')
327 |
328 | weight_path = path + '/' + model_file
329 |
330 | param_dic = pickle.load(fin)
331 | fin.close()
332 |
333 | nums_chars = param_dic['nums_chars']
334 | nums_tags = param_dic['nums_tags']
335 | crf = param_dic['crf']
336 | emb_dim = param_dic['emb_dim']
337 | cell = param_dic['cell']
338 | rnn_dim = param_dic['rnn_dim']
339 | rnn_num = param_dic['rnn_num']
340 | drop_out = param_dic['drop_out']
341 | buckets_char = param_dic['buckets_char']
342 | nums_ngrams = param_dic['ngram']
343 | is_space = param_dic['is_space']
344 | sent_seg = param_dic['sent_seg']
345 | emb_path = param_dic['emb_path']
346 | tag_scheme = param_dic['tag_scheme']
347 |
348 | if args.embeddings is not None:
349 | emb_path = args.embeddings
350 |
351 | ngram = 1
352 | grams, gram2idx = None, None
353 | if nums_ngrams is not None:
354 | ngram = len(nums_ngrams) + 1
355 |
356 | char2idx, unk_chars_idx, idx2char, tag2idx, idx2tag, trans_dict = toolbox.get_dicts(path, sent_seg, tag_scheme, crf)
357 |
358 | trans_char_num = len(char2idx)
359 |
360 | if ngram > 1:
361 | grams = toolbox.read_ngrams(path, ngram)
362 |
363 | new_chars, new_grams = None, None
364 |
365 | test_x, test_y, raw_x, test_y_gold = None, None, None, None
366 |
367 | sub_dict = None
368 |
369 | max_step = None
370 |
371 | raw_file = None
372 |
373 | if args.action == 'test':
374 | test_file = args.test
375 | f_names = os.listdir(path)
376 | if test_file is None:
377 | for f_n in f_names:
378 | if 'ud-test.conllu' in f_n:
379 | test_file = f_n
380 | assert test_file is not None
381 |
382 | cat = 'other'
383 | if 'Chinese' in path or 'Japanese' in path:
384 | cat = 'zh'
385 | for line in codecs.open(path + '/' + test_file, 'r', encoding='utf-8'):
386 | if len(line) < 2:
387 | break
388 | if '# sentence' in line or '# text' in line:
389 | cat = 'gold'
390 |     reader.get_raw(path, test_file, 'raw_test.txt', cat, form=args.format, is_space=is_space)
391 |
392 | raws_test = reader.raw(path + '/raw_test.txt')
393 | test_y_gold = reader.test_gold(path + '/' + test_file, form=args.format, is_space=is_space,
394 | ignore_mwt=args.ignore_mwt)
395 |
396 | new_chars = toolbox.get_new_chars(path + '/raw_test.txt', char2idx, is_space)
397 |
398 | if emb_path is not None:
399 | valid_chars = toolbox.get_valid_chars(new_chars + char2idx.keys(), emb_path)
400 | else:
401 | valid_chars = None
402 |
403 | char2idx, idx2char, unk_chars_idx, sub_dict = toolbox.update_char_dict(char2idx, new_chars, unk_chars_idx, valid_chars)
404 |
405 | test_x, max_len_test = toolbox.get_input_vec_raw(path, 'raw_test.txt', char2idx, limit=args.sent_limit + 100,
406 | sent_seg=sent_seg, is_space=is_space,
407 | ignore_space=args.ignore_space)
408 |
409 | max_step = max_len_test
410 |
411 | if sent_seg:
412 | print 'Joint sentence segmentation...'
413 | else:
414 | print 'Test set: %d instances.' % len(test_x[0])
415 |
416 | if ngram > 1:
417 | gram2idx = toolbox.get_ngram_dic(grams)
418 | new_grams = toolbox.get_new_grams(path + '/' + test_file, gram2idx, is_space=is_space)
419 |
420 | test_grams = toolbox.get_gram_vec(path, 'raw_test.txt', gram2idx, is_raw=True, limit=args.sent_limit + 100,
421 | sent_seg=sent_seg, is_space=is_space, ignore_space=args.ignore_space)
422 | test_x += test_grams
423 |
424 | for k in range(len(test_x)):
425 | test_x[k] = toolbox.pad_zeros(test_x[k], max_step)
426 |
427 | elif args.action == 'tag':
428 | assert args.raw is not None
429 |
430 | raw_file = args.raw
431 | new_chars = toolbox.get_new_chars(raw_file, char2idx, is_space)
432 |
433 | if emb_path is not None:
434 | valid_chars = toolbox.get_valid_chars(new_chars, emb_path)
435 | else:
436 | valid_chars = None
437 |
438 | char2idx, idx2char, unk_chars_idx, sub_dict = toolbox.update_char_dict(char2idx, new_chars, unk_chars_idx,
439 | valid_chars)
440 |
441 | if not args.segment_large:
442 |
443 | if sent_seg:
444 | raw_x, raw_len = toolbox.get_input_vec_tag(None, raw_file, char2idx, limit=args.sent_limit + 100,
445 | is_space=is_space)
446 | else:
447 | raw_x, raw_len = toolbox.get_input_vec_raw(None, raw_file, char2idx, limit=args.sent_limit + 100,
448 | sent_seg=sent_seg, is_space=is_space)
449 |
450 | if sent_seg:
451 | print 'Joint sentence segmentation...'
452 | else:
453 |             print 'Raw sentences: %d instances.' % len(raw_x[0])
454 |
455 | max_step = raw_len
456 |
457 | else:
458 |
459 | max_step = args.sent_limit
460 |
461 | if ngram > 1:
462 | gram2idx = toolbox.get_ngram_dic(grams)
463 | new_grams = toolbox.get_new_grams(raw_file, gram2idx, is_raw=True, is_space=is_space)
464 |
465 | if not args.segment_large:
466 | if sent_seg:
467 | raw_grams = toolbox.get_gram_vec_tag(None, raw_file, gram2idx, limit=args.sent_limit + 100,
468 | is_space=is_space)
469 | else:
470 | raw_grams = toolbox.get_gram_vec(None, raw_file, gram2idx, is_raw=True, limit=args.sent_limit + 100,
471 | sent_seg=sent_seg, is_space=is_space)
472 |
473 | raw_x += raw_grams
474 |
475 | if not args.segment_large:
476 | for k in range(len(raw_x)):
477 | raw_x[k] = toolbox.pad_zeros(raw_x[k], max_step)
478 |
479 | config = tf.ConfigProto(allow_soft_placement=True)
480 | gpu_config = "/gpu:" + str(args.gpu)
481 |
482 | transducer = None
483 | transducer_graph = None
484 | trans_model = None
485 | trans_init = None
486 |
487 | if len(trans_dict) > 200:
488 | transducer = toolbox.get_dict_vec(trans_dict, char2idx)
489 | t = time()
490 |
491 | initializer = tf.contrib.layers.xavier_initializer()
492 |
493 | if transducer is not None:
494 | transducer_graph = tf.Graph()
495 | with transducer_graph.as_default():
496 | with tf.variable_scope("transducer") as scope:
497 | trans_model = Seq2seq(path + '/' + model_file + '_transducer')
498 | trans_fin = open(path + '/' + model_file + '_transducer_model', 'rb')
499 | trans_param_dic = pickle.load(trans_fin)
500 | trans_fin.close()
501 |
502 | tr_char_num = trans_param_dic['char_num']
503 | tr_rnn_dim = trans_param_dic['rnn_dim']
504 | tr_emb_dim = trans_param_dic['emb_dim']
505 | tr_max_x = trans_param_dic['max_x']
506 | tr_max_y = trans_param_dic['max_y']
507 |
508 | print 'Defining transducer...'
509 | trans_model.define(char_num=tr_char_num, rnn_dim=tr_rnn_dim, emb_dim=tr_emb_dim,
510 | max_x=tr_max_x, max_y=tr_max_y, write_trans_model=False)
511 | trans_init = tf.global_variables_initializer()
512 | transducer_graph.finalize()
513 |
514 |     print 'Initialisation...'
515 | main_graph = tf.Graph()
516 | with main_graph.as_default():
517 | with tf.variable_scope("tagger") as scope:
518 | model = Model(nums_chars=nums_chars, nums_tags=nums_tags, buckets_char=[max_step], counts=[200],
519 | crf=crf, ngram=nums_ngrams, batch_size=args.tag_batch, is_space=is_space)
520 |
521 | model.main_graph(trained_model=None, scope=scope, emb_dim=emb_dim, cell=cell,
522 | rnn_dim=rnn_dim, rnn_num=rnn_num, drop_out=drop_out)
523 |
524 | model.define_updates(new_chars=new_chars, emb_path=emb_path, char2idx=char2idx)
525 |
526 | init = tf.global_variables_initializer()
527 |
528 | print 'Done. Time consumed: %d seconds' % int(time() - t)
529 | main_graph.finalize()
530 |
531 | idx = None
532 |
533 | if args.ensemble:
534 | idx = 1
535 | main_sess = []
536 | while os.path.isfile(path + '/' + model_file + '_' + str(idx) + '_weights.index'):
537 | main_sess.append(tf.Session(config=config, graph=main_graph))
538 | idx += 1
539 | else:
540 | main_sess = tf.Session(config=config, graph=main_graph)
541 |
542 | if crf:
543 | decode_graph = tf.Graph()
544 |
545 | with decode_graph.as_default():
546 | model.decode_graph()
547 | decode_graph.finalize()
548 |
549 | decode_sess = tf.Session(config=config, graph=decode_graph)
550 |
551 | sess = [main_sess, decode_sess]
552 |
553 | else:
554 | sess = [main_sess, None]
555 |
556 | with tf.device(gpu_config):
557 | ens_model = None
558 | print 'Loading weights....'
559 | if args.ensemble:
560 | for i in range(1, idx):
561 | print 'Ensemble: ' + str(i)
562 | main_sess[i - 1].run(init)
563 | model.run_updates(main_sess[i - 1], weight_path + '_' + str(i) + '_weights')
564 | else:
565 | main_sess.run(init)
566 | model.run_updates(main_sess, weight_path + '_weights')
567 |
568 | if transducer is not None:
569 | print 'Loading transducer...'
570 | t = time()
571 | trans_sess = tf.Session(config=config, graph=transducer_graph)
572 | trans_sess.run(trans_init)
573 | if os.path.isfile(path + '/' + model_file + '_transducer_weights'):
574 | trans_weight_path = path + '/' + model_file + '_transducer_weights'
575 | trans_weight_path = trans_weight_path.replace('//', '/')
576 | trans_model.saver.restore(trans_sess, trans_weight_path)
577 | sess.append(trans_sess)
578 |
579 | if args.action == 'test':
580 | test_y_raw = [line.strip() for line in codecs.open(path + '/raw_test.txt', 'rb', encoding='utf-8')]
581 | model.test(test_x, test_y_raw, test_y_gold, idx2tag, idx2char, unk_chars_idx, sub_dict, trans_dict, sess,
582 | transducer=trans_model, ensemble=args.ensemble, batch_size=args.test_batch, sent_seg=sent_seg,
583 | bias=args.segmentation_bias, outpath=args.output_path, trans_type=args.transduction_type)
584 |
585 | if args.action == 'tag':
586 |
587 | if not args.segment_large:
588 | raw_sents = []
589 | for line in codecs.open(raw_file, 'rb', encoding='utf-8'):
590 | line = line.strip()
591 | if len(line) > 0:
592 | raw_sents.append(line)
593 | model.tag(raw_x, raw_sents, idx2tag, idx2char, unk_chars_idx, sub_dict, trans_dict, sess,
594 | transducer=trans_model, outpath=args.output_path, ensemble=args.ensemble,
595 | batch_size=args.tag_batch, sent_seg=sent_seg, seg_large=args.segment_large, form=args.format)
596 | else:
597 | count = 0
598 | c_line = 0
599 | l_writer = codecs.open(args.output_path, 'w', encoding='utf-8')
600 | out = []
601 | with codecs.open(raw_file, 'r', encoding='utf-8') as l_file:
602 | lines = []
603 | for line in l_file:
604 | line = line.strip()
605 | if len(line) > 0:
606 | lines.append(line)
607 | else:
608 | c_line += 1
609 | if c_line >= args.large_size:
610 | count += len(lines)
611 | c_line = 0
612 | print count
613 | if args.sent_seg:
614 | raw_x, _ = toolbox.get_input_vec_tag(None, None, char2idx, lines=lines,
615 | limit=args.sent_limit, is_space=is_space)
616 | else:
617 | raw_x, _ = toolbox.get_input_vec_raw(None, None, char2idx, lines=lines,
618 | limit=args.sent_limit, sent_seg=sent_seg,
619 | is_space=is_space)
620 | if ngram > 1:
621 | if sent_seg:
622 | raw_grams = toolbox.get_gram_vec_tag(None, None, gram2idx, lines=lines,
623 | limit=args.sent_limit, is_space=is_space)
624 | else:
625 | raw_grams = toolbox.get_gram_vec(None, None, gram2idx, lines=lines, is_raw=True,
626 | limit=args.sent_limit, sent_seg=sent_seg,
627 | is_space=is_space)
628 | raw_x += raw_grams
629 |
630 | for k in range(len(raw_x)):
631 | raw_x[k] = toolbox.pad_zeros(raw_x[k], max_step)
632 |
633 |                             prediction, multi = model.tag(raw_x, lines, idx2tag, idx2char, unk_chars_idx, sub_dict,
634 | trans_dict, sess, transducer=trans_model,
635 | outpath=args.output_path, ensemble=args.ensemble,
636 | batch_size=args.tag_batch, sent_seg=sent_seg,
637 | seg_large=args.segment_large, form=args.format)
638 |
639 | if args.only_tokenised:
640 |                                 for l_out in prediction:
641 | if len(l_out.strip()) > 0:
642 | l_writer.write(l_out + '\n')
643 | else:
644 |                                 for tagged_t, multi_t in zip(prediction, multi):
645 | if len(tagged_t.strip()) > 0:
646 | l_writer.write('#sent_tok: ' + tagged_t + '\n')
647 | idx = 1
648 | tgs = multi_t.split(' ')
649 | pl = ''
650 | for _ in range(8):
651 | pl += '\t' + '_'
652 | for tg in tgs:
653 | if '!#!' in tg:
654 | segs = tg.split('!#!')
655 | l_writer.write(str(idx) + '-' + str(int(segs[1]) + idx - 1) + '\t' +
656 | segs[0] + pl + '\n')
657 | else:
658 | l_writer.write(str(idx) + '\t' + tg + pl + '\n')
659 | idx += 1
660 | l_writer.write('\n')
661 | lines = []
662 | if len(lines) > 0:
663 |
664 | if args.sent_seg:
665 | raw_x, _ = toolbox.get_input_vec_tag(None, None, char2idx, lines=lines,
666 | limit=args.sent_limit, is_space=is_space)
667 | else:
668 | raw_x, _ = toolbox.get_input_vec_raw(None, None, char2idx, lines=lines,
669 | limit=args.sent_limit, sent_seg=sent_seg,
670 | is_space=is_space)
671 | if ngram > 1:
672 | if sent_seg:
673 | raw_grams = toolbox.get_gram_vec_tag(None, None, gram2idx, lines=lines,
674 | limit=args.sent_limit, is_space=is_space)
675 | else:
676 | raw_grams = toolbox.get_gram_vec(None, None, gram2idx, lines=lines, is_raw=True,
677 | limit=args.sent_limit, sent_seg=sent_seg,
678 | is_space=is_space)
679 | raw_x += raw_grams
680 |
681 | for k in range(len(raw_x)):
682 | raw_x[k] = toolbox.pad_zeros(raw_x[k], max_step)
683 |
684 | prediction, multi = model.tag(raw_x, lines, idx2tag, idx2char, unk_chars_idx, sub_dict,
685 | trans_dict, sess, transducer=trans_model,
686 | outpath=args.output_path, ensemble=args.ensemble,
687 | batch_size=args.tag_batch, sent_seg=sent_seg,
688 | seg_large=args.segment_large, form=args.format)
689 |
690 | if args.only_tokenised:
691 | for l_out in prediction:
692 | if len(l_out.strip()) > 0:
693 | l_writer.write(l_out + '\n')
694 | else:
695 | for tagged_t, multi_t in zip(prediction, multi):
696 | if len(tagged_t.strip()) > 0:
697 | l_writer.write('#sent_tok: ' + tagged_t + '\n')
698 | idx = 1
699 | tgs = multi_t.split(' ')
700 | pl = ''
701 | for _ in range(8):
702 | pl += '\t' + '_'
703 | for tg in tgs:
704 | if '!#!' in tg:
705 | segs = tg.split('!#!')
706 | l_writer.write(str(idx) + '-' + str(int(segs[1]) + idx - 1) + '\t' +
707 | segs[0] + pl + '\n')
708 | else:
709 | l_writer.write(str(idx) + '\t' + tg + pl + '\n')
710 | idx += 1
711 | l_writer.write('\n')
712 | l_writer.close()
713 |
714 | print 'Done.'
715 |
--------------------------------------------------------------------------------
/toolbox.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @author: Yan Shao, yan.shao@lingfil.uu.se
4 | """
5 | import codecs
6 | import sys
7 | import numpy as np
8 | import random
9 | import os
10 | import math
11 | from reader import get_gold
12 |
13 | sys = reload(sys)
14 | sys.setdefaultencoding('utf-8')
15 |
16 | punc = ['!', ')', ',', '.', ';', ':', '?', '»', '...', '..', '....', '%', 'º', '²', '°', '¿', '¡', '(', '«',
17 | '"', '\'', '-', '。', '·', '।', '۔']
18 |
19 |
20 | def pre_token(line):
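    |     """Split leading and trailing punctuation and digits off each
    |     space-delimited segment (used for SEA-style input such as Vietnamese)."""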
21 | out = []
22 | for seg in line.split(' '):
23 | f_out = []
24 | b_out = []
25 | while len(seg) > 0 and (seg[0] in punc or ('0' <= seg[0] <= '9')):
26 | f_out.append(seg[0])
27 | seg = seg[1:]
28 | while len(seg) > 0 and (seg[-1] in punc or ('0' <= seg[-1] <= '9')):
29 | b_out = [seg[-1]] + b_out
30 | seg = seg[:-1]
31 | if len(seg) > 0:
32 | out += f_out + [seg] + b_out
33 | else:
34 | out += f_out + b_out
35 | return out
36 |
37 |
38 | def get_chars(path, filelist, sea=False):
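    |     """Collect character counts over the given files (counts are only
    |     incremented for the first file) and write them to chars.txt."""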
39 | char_set = {}
40 | out_char = codecs.open(path + '/chars.txt', 'w', encoding='utf-8')
41 | for i, file_name in enumerate(filelist):
42 | for line in codecs.open(path + '/' + file_name, 'rb', encoding='utf-8'):
43 | line = line.strip()
44 |             if sea == 'sea':
45 | line = pre_token(line)
46 | for ch in line:
47 | if ch in char_set:
48 | if i == 0:
49 | char_set[ch] += 1
50 | else:
51 | char_set[ch] = 1
52 | for k, v in char_set.items():
53 | out_char.write(k + '\t' + str(v) + '\n')
54 | out_char.close()
55 |
56 |
57 | def get_dicts(path, sent_seg, tag_scheme='BIES', crf=1):
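    |     """Build the character and tag index dictionaries from chars.txt and
    |     tags.txt, treating characters seen only once as unknown, and load the
    |     transduction dictionary from dict.txt if present."""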
58 |     char2idx = {'<P>': 0, '<UNK>': 1, '<#>': 2}  # padding, unknown character, chop marker
59 | unk_chars_idx = []
60 | idx = 3
61 | for line in codecs.open(path + '/chars.txt', 'r', encoding='utf-8'):
62 | segs = line.split('\t')
63 | if len(segs[0].strip()) == 0:
64 | if ' ' not in char2idx:
65 | char2idx[' '] = idx
66 | idx += 1
67 | else:
68 | char2idx[segs[0]] = idx
69 | if int(segs[1]) == 1:
70 | unk_chars_idx.append(idx)
71 | idx += 1
72 | idx2char = {k: v for v, k in char2idx.items()}
73 | if tag_scheme == 'BI':
74 | if crf > 0:
75 |             tag2idx = {'<P>': 0, 'B': 1, 'I': 2}
76 | idx = 3
77 | else:
78 | tag2idx = {'B': 0, 'I': 1}
79 | idx = 2
80 | else:
81 | if crf > 0:
82 |             tag2idx = {'<P>': 0, 'B': 1, 'I': 2, 'E': 3, 'S': 4}
83 | idx = 5
84 | else:
85 |             tag2idx = {'B': 0, 'I': 1, 'E': 2, 'S': 3}
86 | idx = 4
87 | for line in codecs.open(path + '/tags.txt', 'r', encoding='utf-8'):
88 | line = line.strip()
89 | if line not in tag2idx:
90 | tag2idx[line] = idx
91 | idx += 1
92 | if sent_seg:
93 | tag2idx['T'] = idx
94 | tag2idx['U'] = idx + 1
95 | idx2tag = {k: v for v, k in tag2idx.items()}
96 |
97 | trans_dict = {}
98 | key = ''
99 | if os.path.isfile(path + '/dict.txt'):
100 | for line in codecs.open(path + '/dict.txt', 'r', encoding='utf-8'):
101 | line = line.strip()
102 | if len(line) > 0:
103 | segs = line.split('\t')
104 | if len(segs) == 1:
105 | key = segs[0]
106 | trans_dict[key] = None
107 | elif len(segs) == 2:
108 | if trans_dict[key] is None:
109 | trans_dict[key] = segs[0].replace(' ', ' ')
110 |
111 | return char2idx, unk_chars_idx, idx2char, tag2idx, idx2tag, trans_dict
112 |
113 |
114 | def ngrams(raw, gram, is_space):
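    |     """Count n-grams of order `gram` over every line (with shorter windows
    |     at line edges) and over the junction between consecutive lines."""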
115 | gram_set = {}
116 | li = gram/2
117 | ri = gram - li - 1
118 | p = ''
119 | last_line = ''
120 | is_first = True
121 | for line in raw:
122 | for i in range(len(line)):
123 | if i - li < 0:
124 | if is_space != 'sea':
125 | lp = p * (li - i) + line[:i]
126 | else:
127 | lp = [p] * (li - i) + line[:i]
128 | else:
129 | lp = line[i - li:i]
130 | if i + ri + 1 > len(line):
131 | if is_space != 'sea':
132 | rp = line[i:] + p*(i + ri + 1 - len(line))
133 | else:
134 | rp = line[i:] + [p] * (i + ri + 1 - len(line))
135 | else:
136 | rp = line[i:i+ri+1]
137 | ch = lp + rp
138 | if is_space == 'sea':
139 | ch = '_'.join(ch)
140 | if ch in gram_set:
141 | gram_set[ch] += 1
142 | else:
143 | gram_set[ch] = 1
144 | if is_first:
145 | is_first = False
146 | else:
147 | if is_space is True:
148 | last_line += ' '
149 | start_idx = len(last_line) - ri
150 | if start_idx < 0:
151 | start_idx = 0
152 | end_idx = li + len(last_line)
153 | j_line = last_line + line
154 | for i in range(start_idx, end_idx):
155 | if i - li < 0:
156 | if is_space != 'sea':
157 | j_lp = p * (-i) + j_line[start_idx:i]
158 | else:
159 | j_lp = [p] * (-i) + j_line[start_idx:i]
160 | else:
161 | j_lp = j_line[i - li:i]
162 | if i + ri + 1 > len(j_line):
163 | if is_space != 'sea':
164 | j_rp = j_line[i:end_idx] + p * (ri + i + 1 - len(j_line))
165 | else:
166 | j_rp = j_line[i:end_idx] + [p] * (ri + i + 1 - len(j_line))
167 | else:
168 | j_rp = j_line[i:ri + 1 + i]
169 |                 j_ch = j_lp + j_rp
170 |                 if is_space == 'sea':
171 |                     j_ch = '_'.join(j_ch)
172 |                 if j_ch in gram_set:
173 |                     gram_set[j_ch] += 1
174 |                 else:
175 |                     gram_set[j_ch] = 1
176 | last_line = line
177 | return gram_set
178 |
179 |
180 | def get_ngrams(path, ng, is_space):
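    |     """Write n-gram frequency files 2gram.txt ... <ng>gram.txt computed
    |     from raw_train.txt."""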
181 | raw = []
182 | for line in codecs.open(path + '/raw_train.txt', 'r', encoding='utf-8'):
183 | if is_space == 'sea':
184 | segs = pre_token(line.strip())
185 | else:
186 | segs = line.strip()
187 | raw.append(segs)
188 | if ng > 1:
189 | for i in range(2, ng + 1):
190 | out_gram = codecs.open(path + '/' + str(i) + 'gram.txt', 'w', encoding='utf-8')
191 | grams = ngrams(raw, i, is_space)
192 | for k, v in grams.items():
193 | out_gram.write(k + '\t' + str(v) + '\n')
194 | out_gram.close()
195 |
196 |
197 | def read_ngrams(path, ng):
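    |     """Load the 2..ng n-gram frequency files, restoring the trailing
    |     spaces removed by rstrip."""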
198 | ngs = []
199 | for i in range(2, ng + 1):
200 | ng = {}
201 | for line in codecs.open(path + '/' + str(i) + 'gram.txt', 'r', encoding='utf-8'):
202 | line = line.rstrip()
203 | segs = line.split('\t')
204 | while len(segs[0]) < i:
205 | segs[0] += ' '
206 | ng[segs[0]] = int(segs[1])
207 | ngs.append(ng)
208 | return ngs
209 |
210 |
211 | def get_sample_embedding(path, emb, chars2idx):
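    |     """Extract the pre-trained vectors of all dictionary characters from
    |     the embedding file and write them to <embedding-name>_sub.txt."""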
212 | chars = chars2idx.keys()
213 | short_emb = emb[emb.index('/') + 1: emb.index('.')]
214 | emb_dic = {}
215 |     valid_chars = []
216 | for line in codecs.open(emb, 'rb', encoding='utf-8'):
217 | line = line.strip()
218 | sets = line.split(' ')
219 | emb_dic[sets[0]] = np.asarray(sets[1:], dtype='float32')
220 | fout = codecs.open(path + '/' + short_emb + '_sub.txt', 'w', encoding='utf-8')
221 | for ch in chars:
222 | p_line = ch
223 | if ch in emb_dic:
224 | valid_chars.append(ch)
225 |             for val in emb_dic[ch]:
226 |                 p_line += ' ' + unicode(val)
227 | fout.write(p_line + '\n')
228 | fout.close()
229 |
230 |
231 | def read_sample_embedding(path, short_emb, char2idx):
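    |     """Read the sampled embedding file back in; characters without a
    |     pre-trained vector receive a random uniform vector instead."""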
232 | emb_values = []
233 | valid_chars = []
234 |     emb_dic = {}
235 | for line in codecs.open(path + '/' + short_emb + '_sub.txt', 'rb', encoding='utf-8'):
236 | first_ch = line[0]
237 | line = line.rstrip()
238 | sets = line.split(' ')
239 | if first_ch == ' ':
240 |             emb_dic[' '] = np.asarray(sets[1:], dtype='float32')  # sets[0] is the empty string before the leading space
241 | else:
242 | emb_dic[sets[0]] = np.asarray(sets[1:], dtype='float32')
243 | emb_dim = len(emb_dic.items()[0][1])
244 | for ch in char2idx.keys():
245 | if ch in emb_dic:
246 | emb_values.append(emb_dic[ch])
247 | valid_chars.append(ch)
248 | else:
249 | rand = np.random.uniform(-math.sqrt(float(3) / emb_dim), math.sqrt(float(3) / emb_dim), emb_dim)
250 | emb_values.append(np.asarray(rand, dtype='float32'))
251 | emb_dim = len(emb_values[0])
252 | return emb_dim, emb_values, valid_chars
253 |
254 |
255 | def get_sent_raw(path, fname, is_space=True):
256 | long_line = ''
257 | for line in codecs.open(path + '/' + fname, 'r', encoding='utf-8'):
258 | line = line.strip()
259 | if is_space:
260 | long_line += ' ' + line
261 | else:
262 | long_line += line
263 | if is_space:
264 | long_line = long_line[1:]
265 |
266 | return long_line
267 |
268 |
269 | def chop(line, ad_s, limit):
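    |     """Chop an index sequence into pieces of at most `limit` items with
    |     roughly ten items of overlap; continuation pieces are prefixed with
    |     the ad_s marker and the last piece is zero-padded."""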
270 | out = []
271 | chopped = False
272 | while len(line) > 0:
273 | if chopped:
274 | s_line = line[:limit - 1]
275 | s_line = [ad_s] + s_line
276 | else:
277 | chopped = True
278 | s_line = line[:limit]
279 | out.append(s_line)
280 | line = line[limit - 10:]
281 | if len(line) < 10:
282 | line = ''
283 | while len(out) > 0 and len(out[-1]) < limit-1:
284 | out[-1].append(0)
285 | return out
286 |
287 |
288 | def get_input_vec(path, fname, char2idx, tag2idx, limit=500, sent_seg=False, is_space=True, train_size=-1, ignore_space=False):
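    |     """Turn a tagged training file into character and tag index sequences;
    |     with sent_seg the corpus becomes one chopped stream whose sentence-final
    |     tags are relabelled 'T' (for 'S') or 'U', otherwise over-long sentences
    |     are chopped with a 10-character overlap."""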
289 | ct = 0
290 | max_len = 0
291 | space_idx = None
292 | if is_space is True:
293 | assert ' ' in char2idx
294 | space_idx = char2idx[' ']
295 | x_indices = []
296 | y_indices = []
297 | s_count = 0
298 | l_count = 0
299 | x = []
300 | y = []
301 |
302 | n_sent = 0
303 |
304 | if sent_seg:
305 | for line in codecs.open(path + '/' + fname, 'r', encoding='utf-8'):
306 | line = line.strip()
307 | if len(line) == 0:
308 | ct = 0
309 | elif ct == 0:
310 | if is_space == 'sea':
311 | line = pre_token(line)
312 | for ch in line:
313 | if len(ch.strip()) == 0:
314 | x.append(char2idx[' '])
315 | elif ch in char2idx:
316 | x.append(char2idx[ch])
317 | else:
318 |                     x.append(char2idx['<UNK>'])
319 | if is_space is True and not ignore_space:
320 | x = [space_idx] + x
321 | x_indices += x
322 | x = []
323 | ct = 1
324 | elif ct == 1:
325 | for ch in line:
326 | y.append(tag2idx[ch])
327 | if y[-1] == tag2idx['S']:
328 | y[-1] = tag2idx['T']
329 | else:
330 | y[-1] = tag2idx['U']
331 | if is_space is True and not ignore_space:
332 | y = [tag2idx['X']] + y
333 | y_indices += y
334 | y = []
335 | n_sent += 1
336 | if 0 < train_size <= n_sent:
337 | break
338 | x_indices = chop(x_indices, char2idx['<#>'], limit)
339 | y_indices = chop(y_indices, tag2idx['I'], limit)
340 | max_len = limit
341 | else:
342 | for line in codecs.open(path + '/' + fname, 'r', encoding='utf-8'):
343 | line = line.strip()
344 | if len(line) == 0:
345 | ct = 0
346 | elif ct == 0:
347 | if is_space == 'sea':
348 | line = pre_token(line)
349 | max_len = max(max_len, len(line))
350 | s_count += 1
351 | if len(line) > limit:
352 | l_count += 1
353 | chopped = False
354 | while len(line) > 0:
355 | s_line = line[:limit - 1]
356 | line = line[limit - 10:]
357 | if len(line) < 10:
358 | line = ''
359 | if not chopped:
360 | chopped = True
361 | else:
362 | x.append(char2idx['<#>'])
363 | for ch in s_line:
364 | if len(ch.strip()) == 0:
365 | x.append(char2idx[' '])
366 | elif ch in char2idx:
367 | x.append(char2idx[ch])
368 | else:
369 |                             x.append(char2idx['<UNK>'])
370 | x_indices.append(x)
371 | x = []
372 | ct = 1
373 | elif ct == 1:
374 | chopped = False
375 | while len(line) > 0:
376 | s_line = line[:limit - 1]
377 | line = line[limit - 10:]
378 | if len(line) < 10:
379 | line = ''
380 | if not chopped:
381 | chopped = True
382 | else:
383 | y.append(tag2idx['I'])
384 | for ch in s_line:
385 | y.append(tag2idx[ch])
386 | y_indices.append(y)
387 | y = []
388 | n_sent += 1
389 | if 0 < train_size <= n_sent:
390 | break
391 | max_len = min(max_len, limit)
392 | if l_count > 0:
393 |         print '%d (out of %d) sentences were chopped.' % (l_count, s_count)
394 | return [x_indices], [y_indices], max_len
395 |
396 |
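# get_input_vec_sent() builds training data for the sentence-boundary classifier:
# each character position yields a window of 2 * win_size + 1 character indices
# (out-of-range positions map to the sentinel index), labelled 1 if the position
# ends a line of the source file and 0 otherwise.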
397 | def get_input_vec_sent(path, fname, char2idx, win_size, is_space=True):
398 | pre_line = ''
399 | c_line = ''
400 | x = []
401 | y = []
402 | is_first = True
403 | for line in codecs.open(path + '/' + fname, 'r', encoding='utf-8'):
404 | line = line.strip()
405 | if is_space == 'sea':
406 | line = pre_token(line)
407 | start_idx = len(pre_line)
408 | if is_space is True:
409 | j_line = pre_line + ' ' + c_line + ' ' + line
410 | end_idx = start_idx + len(c_line) + 1
411 | if is_first:
412 | is_first = False
413 | j_line = j_line[1:]
414 | end_idx -= 1
415 | else:
416 | j_line = pre_line + c_line + line
417 | end_idx = start_idx + len(c_line)
418 | for i in range(start_idx, end_idx):
419 | indices = []
420 | for j in range(i - win_size, i + win_size + 1):
421 | if j < 0 or j >= len(j_line):
422 | indices.append(char2idx[''])
423 | else:
424 | if j_line[j] in char2idx:
425 | indices.append(char2idx[j_line[j]])
426 | else:
427 | indices.append(char2idx[''])
428 | x.append(indices)
429 | if i == end_idx - 1:
430 | y.append(1)
431 | else:
432 | y.append(0)
433 | pre_line = c_line
434 | c_line = line
435 | if is_space is True:
436 | j_line = pre_line + ' ' + c_line
437 | else:
438 | j_line = pre_line + c_line
439 | start_idx = len(pre_line)
440 | end_idx = start_idx + len(c_line)
441 | for i in range(start_idx, end_idx):
442 | indices = []
443 | for j in range(i - win_size, i + win_size + 1):
444 | if j < 0 or j >= len(j_line):
445 | indices.append(char2idx[''])
446 | else:
447 | if j_line[j] in char2idx:
448 | indices.append(char2idx[j_line[j]])
449 | else:
450 | indices.append(char2idx[''])
451 | x.append(indices)
452 | if i == end_idx - 1:
453 | y.append(1)
454 | else:
455 | y.append(0)
456 |
457 | assert len(x) == len(y)
458 | return x, y
459 |
460 |
461 | def get_input_vec_sent_raw(raws, char2idx, win_size):
462 | x = []
463 | for i in range(len(raws)):
464 | indices = []
465 | for j in range(i - win_size, i + win_size + 1):
466 | if j < 0 or j >= len(raws):
467 | indices.append(char2idx[''])
468 | else:
469 | if raws[j] in char2idx:
470 | indices.append(char2idx[raws[j]])
471 | else:
472 | indices.append(char2idx[''])
473 | x.append(indices)
474 | return x
475 |
476 |
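# get_input_vec_raw() is the inference-time counterpart of get_input_vec(): it
# vectorises raw text without tag lines, concatenating everything in sent_seg mode
# and otherwise chopping only the lines longer than `limit`.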
477 | def get_input_vec_raw(path, fname, char2idx, lines=None, limit=500, sent_seg=False, is_space=True, ignore_space=False):
478 | max_len = 0
479 | space_idx = None
480 | is_first = True
481 | if is_space is True:
482 | assert ' ' in char2idx
483 | space_idx = char2idx[' ']
484 | x_indices = []
485 | s_count = 0
486 | l_count = 0
487 | x = []
488 | if lines is None:
489 | assert fname is not None
490 | if path is None:
491 | real_path = fname
492 | else:
493 | real_path = path + '/' + fname
494 | lines = codecs.open(real_path, 'r', encoding='utf-8')
495 | if sent_seg:
496 | for line in lines:
497 | line = line.strip()
498 | if is_space == 'sea':
499 | line = pre_token(line)
500 | elif ignore_space:
501 | line = ''.join(line.split())
502 | for ch in line:
503 | if len(ch.strip()) == 0:
504 | x.append(char2idx[' '])
505 | elif ch in char2idx:
506 | x.append(char2idx[ch])
507 | else:
508 | x.append(char2idx[''])
509 | if is_space is True and not ignore_space:
510 | if is_first:
511 | is_first = False
512 | else:
513 | x = [space_idx] + x
514 | x_indices += x
515 | x = []
516 | x_indices = chop(x_indices, char2idx['<#>'], limit)
517 | max_len = limit
518 | else:
519 | for line in lines:
520 | line = line.strip()
521 | if len(line) > 0:
522 | if is_space == 'sea':
523 | line = pre_token(line)
524 | elif ignore_space:
525 | line = ''.join(line.split())
526 | max_len = max(max_len, len(line))
527 | s_count += 1
528 |
529 | for ch in line:
530 | if len(ch.strip()) == 0:
531 | x.append(char2idx[' '])
532 | elif ch in char2idx:
533 | x.append(char2idx[ch])
534 | else:
535 | x.append(char2idx[''])
536 |
537 | if len(line) > limit:
538 | l_count += 1
539 | chop_x = chop(x, char2idx['<#>'], limit)
540 | x_indices += chop_x
541 | else:
542 | x_indices.append(x)
543 | x = []
544 | max_len = min(max_len, limit)
545 | if l_count > 0:
546 |         print '%d (out of %d) sentences were chopped.' % (l_count, s_count)
547 | return [x_indices], max_len
548 |
549 |
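# get_input_vec_tag() vectorises raw text for joint sentence segmentation: the
# non-empty lines of each block are joined (space-separated when is_space is True),
# an empty line closes the block, and every block is chopped to `limit`.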
550 | def get_input_vec_tag(path, fname, char2idx, lines=None, limit=500, is_space=True):
551 | space_idx = None
552 | if is_space is True:
553 | assert ' ' in char2idx
554 | space_idx = char2idx[' ']
555 | x_indices = []
556 | out = []
557 | x = []
558 | is_first = True
559 | if lines is None:
560 | assert fname is not None
561 | if path is None:
562 | real_path = fname
563 | else:
564 | real_path = path + '/' + fname
565 | lines = codecs.open(real_path, 'r', encoding='utf-8')
566 | for line in lines:
567 | line = line.strip()
568 | if len(line) > 0:
569 | if is_space == 'sea':
570 | line = pre_token(line)
571 | if len(line) > 0:
572 | for ch in line:
573 | if len(ch.strip()) == 0:
574 | x.append(char2idx[' '])
575 | elif ch in char2idx:
576 | x.append(char2idx[ch])
577 | else:
578 | x.append(char2idx[''])
579 | if is_space is True:
580 | if is_first:
581 | is_first = False
582 | else:
583 | x = [space_idx] + x
584 | x_indices += x
585 | x = []
586 | elif len(x_indices) > 0:
587 | x_indices = chop(x_indices, char2idx['<#>'], limit)
588 | out += x_indices
589 | x_indices = []
590 | is_first = True
591 |
592 | if len(x_indices) > 0:
593 | x_indices = chop(x_indices, char2idx['<#>'], limit)
594 | out += x_indices
595 |
596 | return [out], limit
597 |
598 |
599 | def get_vecs(str, char2idx):
600 | out = []
601 | for ch in str:
602 | if ch in char2idx:
603 | out.append(char2idx[ch])
604 | return out
605 |
606 |
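# get_dict_vec() turns the transliteration dictionary into seq2seq training data:
# keys become encoder inputs and values decoder targets terminated by index 2
# ('<#>'); both sides are zero-padded and shuffled, with a 95/5 train/dev split.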
607 | def get_dict_vec(trans_dict, char2idx):
608 | max_x, max_y = 0, 0
609 | x = []
610 | y = []
611 | for k, v in trans_dict.items():
612 | x.append(get_vecs(k, char2idx))
613 |         y.append(get_vecs(v, char2idx) + [2])  # 2 = '<#>' end marker; a replace() with garbled (no-op) literals was dropped here
614 | if len(k) > max_x:
615 | max_x = len(k)
616 | if len(v) > max_y:
617 | max_y = len(v)
618 | max_x += 5
619 | max_y += 5
620 | x = pad_zeros(x, max_x)
621 | y = pad_zeros(y, max_y)
622 | assert len(x) == len(y)
623 | num = len(x)
624 | xy = zip(x, y)
625 | random.shuffle(xy)
626 | xy = zip(*xy)
627 | t_x = xy[0][:int(num * 0.95)]
628 | t_y = xy[1][:int(num * 0.95)]
629 | v_x = xy[0][int(num * 0.95):]
630 | v_y = xy[1][int(num * 0.95):]
631 | return t_x, t_y, v_x, v_y
632 |
633 |
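# get_ngram_dic() converts raw n-gram counts into index dictionaries: 0 is reserved
# for padding, 1 for unknown, 2 for the '<#>' chop marker (sentinel names
# reconstructed, see the note below); grams seen only once share the unknown index.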
634 | def get_ngram_dic(ng):
635 | gram_dics = []
636 | for i, gram in enumerate(ng):
637 |         g_dic = {'<P>': 0, '<UNK>': 1, '<#>': 2}  # NOTE: the extracted literal had two empty-string keys; '<P>' (padding, 0) and '<UNK>' (unknown, 1) are reconstructed stand-ins
638 | idx = 3
639 | for g in gram.keys():
640 |             if gram[g] > 1:
641 |                 g_dic[g] = idx
642 |                 idx += 1  # NOTE: moved from the rare branch; incrementing only for rare grams would give all frequent grams the same index
643 |             else:
644 |                 g_dic[g] = 1
645 | gram_dics.append(g_dic)
646 | return gram_dics
647 |
648 |
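# gram_vec() maps each character position to the index of the n-gram centred on it
# (li items of left context, ri of right, padded at the edges). In sent_seg mode,
# windows that straddle two consecutive lines are recomputed on the joined line so
# context is not cut off at line breaks, and the result is chopped like the
# character inputs.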
649 | def gram_vec(raw, dic, limit=500, sent_seg=False, is_space=True):
650 | out = []
651 | if is_space == 'sea':
652 | ngram = len(dic.keys()[0].split('_'))
653 | else:
654 | ngram = 0
655 | for k in dic.keys():
656 |             if '<' not in k:  # skip sentinel entries ('<P>', '<UNK>', '<#>'); the original membership test was garbled in extraction
657 | ngram = len(k)
658 | break
659 | li = ngram/2
660 | ri = ngram - li - 1
661 |     p = '<P>'  # context-padding symbol (reconstructed; garbled to an empty string in extraction)
662 | indices = []
663 | is_first = True
664 | if sent_seg:
665 | last_line = ''
666 | for line in raw:
667 | for i in range(len(line)):
668 | if i - li < 0:
669 | if is_space != 'sea':
670 | lp = p * (li - i) + line[:i]
671 | else:
672 | lp = [p] * (li - i) + line[:i]
673 | else:
674 | lp = line[i - li:i]
675 | if i + ri + 1 > len(line):
676 | if is_space != 'sea':
677 | rp = line[i:] + p * (i + ri + 1 - len(line))
678 | else:
679 | rp = line[i:] + [p] * (i + ri + 1 - len(line))
680 | else:
681 | rp = line[i:i + ri + 1]
682 | ch = lp + rp
683 | if is_space == 'sea':
684 | ch = '_'.join(ch)
685 | if ch in dic:
686 | indices.append(dic[ch])
687 | else:
688 |                 indices.append(dic['<UNK>'])  # unknown n-gram (sentinel name reconstructed)
689 | if is_first:
690 | is_first = False
691 | else:
692 | start_idx = len(last_line) - ri
693 | if start_idx < 0:
694 | start_idx = 0
695 | if is_space:
696 | last_line += ' '
697 | j_line = last_line + line
698 | end_idx = len(last_line) + li
699 | j_indices = []
700 | for i in range(start_idx, end_idx):
701 | if i - li < 0:
702 | if is_space != 'sea':
703 | j_lp = p * (-i) + j_line[start_idx:i]
704 | else:
705 | j_lp = [p] * (-i) + j_line[start_idx:i]
706 | else:
707 | j_lp = j_line[i - li:i]
708 | if i + ri + 1 > len(j_line):
709 | if is_space != 'sea':
710 | j_rp = j_line[i:end_idx] + p * (ri + i + 1 - len(j_line))
711 | else:
712 | j_rp = j_line[i:end_idx] + [p] * (ri + i + 1 - len(j_line))
713 | else:
714 | j_rp = j_line[i:ri + 1 + i]
715 | j_ch = j_lp + j_rp
716 | if is_space == 'sea':
717 | j_ch = '_'.join(j_ch)
718 | if j_ch in dic:
719 | j_indices.append(dic[j_ch])
720 | else:
721 |                         j_indices.append(dic['<UNK>'])
722 | if ri > 0:
723 | out = out[: - ri] + j_indices[:ri]
724 | if is_space:
725 | indices = j_indices[ - (li + 1):] + indices[li:]
726 | else:
727 | indices = j_indices[ - li:] + indices[li:]
728 | out += indices
729 | indices = []
730 | last_line = line
731 | out = chop(out, dic['<#>'], limit)
732 |
733 | else:
734 | for line in raw:
735 | chopped = False
736 | while len(line) > 0:
737 | s_line = line[:limit - 1]
738 | line = line[limit - 10:]
739 | if len(line) < 10:
740 | line = ''
741 | if not chopped:
742 | chopped = True
743 | else:
744 | indices.append(dic['<#>'])
745 | for i in range(len(s_line)):
746 | if i - li < 0:
747 | if is_space != 'sea':
748 | lp = p * (li - i) + s_line[:i]
749 | else:
750 | lp = [p] * (li - i) + s_line[:i]
751 | else:
752 | lp = s_line[i - li:i]
753 | if i + ri + 1 > len(s_line):
754 | if is_space != 'sea':
755 | rp = s_line[i:] + p * (i + ri + 1 - len(s_line))
756 | else:
757 | rp = s_line[i:] + [p] * (i + ri + 1 - len(s_line))
758 | else:
759 | rp = s_line[i:i + ri + 1]
760 | ch = lp + rp
761 | if is_space == 'sea':
762 | ch = '_'.join(ch)
763 | if ch in dic:
764 | indices.append(dic[ch])
765 | else:
766 |                         indices.append(dic['<UNK>'])
767 | out.append(indices)
768 | indices = []
769 | return out
770 |
771 |
772 | def get_gram_vec(path, fname, gram2index, lines=None, is_raw=False, limit=500, sent_seg=False, is_space=True, ignore_space=False):
773 | raw = []
774 | i = 0
775 | if lines is None:
776 | assert fname is not None
777 | if path is None:
778 | real_path = fname
779 | else:
780 | real_path = path + '/' + fname
781 | lines = codecs.open(real_path, 'r', encoding='utf-8')
782 | for line in lines:
783 | line = line.strip()
784 | if is_space == 'sea':
785 | line = pre_token(line)
786 | elif ignore_space:
787 | line = ''.join(line.split())
788 | if i == 0 or is_raw:
789 | raw.append(line)
790 | i += 1
791 | if len(line) > 0:
792 | i += 1
793 | else:
794 | i = 0
795 | out = []
796 | for g_dic in gram2index:
797 | out.append(gram_vec(raw, g_dic, limit, sent_seg, is_space))
798 | return out
799 |
800 |
801 | def get_gram_vec_tag(path, fname, gram2index, lines=None, limit=500, is_space=True, ignore_space=False):
802 | raw = []
803 | out = [[] for _ in range(len(gram2index))]
804 | if lines is None:
805 | assert fname is not None
806 | if path is None:
807 | real_path = fname
808 | else:
809 | real_path = path + '/' + fname
810 | lines = codecs.open(real_path, 'r', encoding='utf-8')
811 | for line in lines:
812 | line = line.strip()
813 | if is_space == 'sea':
814 | line = pre_token(line)
815 | elif ignore_space:
816 | line = ''.join(line.split())
817 | if len(line) > 0:
818 | raw.append(line)
819 | else:
820 | for i, g_dic in enumerate(gram2index):
821 | out[i] += gram_vec(raw, g_dic, limit, True, is_space)
822 | raw = []
823 | if len(raw) > 0:
824 | for i, g_dic in enumerate(gram2index):
825 | out[i] += gram_vec(raw, g_dic, limit, True, is_space)
826 | return out
827 |
828 |
829 | def read_vocab_tag(path):
830 | '''
831 |     Read tags from an index file and create the tag dictionaries.
832 |     :param path: path to the tag index file, one tag per line
833 |     :return: tag2idx, idx2tag
834 | '''
835 | tag2idx = {}
836 | for i, line in enumerate(codecs.open(path, 'rb', encoding='utf-8')):
837 | line = line.strip()
838 | tag2idx[line] = i
839 | idx2tag = {k: v for v, k in tag2idx.items()}
840 | return tag2idx, idx2tag
841 |
842 |
843 | def get_tags(can, action='sep', tag_scheme='BIES', ignore_mwt=False):
844 | tags = []
845 | if tag_scheme == 'BI':
846 | for i in range(len(can)):
847 | if i == 0:
848 | if action == 'sep' or ignore_mwt:
849 | tags.append('B')
850 | else:
851 | tags.append('K')
852 | else:
853 | if action == 'sep' or ignore_mwt:
854 | tags.append('I')
855 | else:
856 | tags.append('Z')
857 | else:
858 | for i in range(len(can)):
859 | if len(can) == 1:
860 | if action == 'sep' or ignore_mwt:
861 | tags.append('S')
862 | else:
863 | tags.append('D')
864 | elif i == 0:
865 | if action == 'sep' or ignore_mwt:
866 | tags.append('B')
867 | else:
868 | tags.append('K')
869 | elif i == len(can) - 1:
870 | if action == 'sep' or ignore_mwt:
871 | tags.append('E')
872 | else:
873 | tags.append('J')
874 | else:
875 | if action == 'sep' or ignore_mwt:
876 | tags.append('I')
877 | else:
878 | tags.append('Z')
879 | return tags
880 |
881 |
882 | def update_dict(trans_dic, can, trans):
883 | can = can.lower()
884 | if can not in trans_dic:
885 | trans_dic[can] = {}
886 | if trans not in trans_dic[can]:
887 | trans_dic[can][trans] = 1
888 | else:
889 | trans_dic[can][trans] += 1
890 | return trans_dic
891 |
892 |
893 | def raw2tags(raw, sents, path, train_file, creat_dict=True, gold_path=None, ignore_space=False, reset=False,
894 | tag_scheme='BIES', ignore_mwt=False):
895 | wt = codecs.open(path + '/' + train_file, 'w', encoding='utf-8')
896 | if creat_dict and not ignore_mwt:
897 | wd = codecs.open(path + '/dict.txt', 'w', encoding='utf-8')
898 | wg = None
899 | if gold_path is not None:
900 | wg = codecs.open(path + '/' + gold_path, 'w', encoding='utf-8')
901 | wtg = None
902 | if reset or not os.path.isfile(path + '/tags.txt'):
903 | wtg = codecs.open(path + '/tags.txt', 'w', encoding='utf-8')
904 | final_dic = {}
905 | assert len(raw) == len(sents)
906 | invalid = 0
907 | s_tags = set()
908 |
909 | def matched(can, sent_l, tags, trans_dic):
910 | if '-' in sent_l[0][0]:
911 | nums = sent_l[0][0].split('-')
912 | count = int(nums[1]) - int(nums[0])
913 | sent_l.pop(0)
914 | segs = []
915 | while count >= 0:
916 | segs.append(sent_l[0][1])
917 | sent_l.pop(0)
918 | count -= 1
919 | j_seg = ''.join(segs)
920 | if j_seg == can:
921 | for seg in segs:
922 | tags += get_tags(seg, tag_scheme=tag_scheme)
923 | elif can.replace('-', '') == j_seg:
924 | for c_split in can.split('-'):
925 | tags += get_tags(c_split, tag_scheme=tag_scheme)
926 | if tag_scheme == 'BI':
927 | tags.append('I')
928 | else:
929 | tags.append('X')
930 | tags.pop()
931 | else:
932 | tags += get_tags(can, action='trans', tag_scheme=tag_scheme, ignore_mwt=ignore_mwt)
933 | if not ignore_mwt:
934 | trans = ' '.join(segs)
935 | trans_dic = update_dict(trans_dic, can, trans)
936 | else:
937 | tags += get_tags(can, tag_scheme=tag_scheme)
938 | sent_l.pop(0)
939 |
940 | return tags, trans_dic
941 |
942 | for raw_l, sent_l in zip(raw, sents):
943 | if ignore_space:
944 | raw_l = ''.join(raw_l.split())
945 | tags = []
946 | cans = raw_l.split(' ')
947 | trans_dic = {}
948 | gold = get_gold(sent_l, ignore_mwt=ignore_mwt)
949 | pre = ''
950 | for can in cans:
951 | t_can = can.strip()
952 | purged = len(can) - len(t_can)
953 | if purged > 0:
954 | can = t_can
955 | while purged > 0:
956 | if tag_scheme == 'BI':
957 | tags.append('I')
958 | else:
959 | tags.append('X')
960 | purged -= 1
961 | done = False
962 | if len(pre) > 0:
963 | can = pre + ' ' + can
964 | while not done:
965 | if can == sent_l[0][1]:
966 | tags, trans_dic = matched(can, sent_l, tags, trans_dic)
967 | done = True
968 | pre = ''
969 | else:
970 | if len(can) >= len(sent_l[0][1]):
971 | s_l = len(sent_l[0][1])
972 | s_can = can[:s_l]
973 | if s_can != sent_l[0][1]:
974 | done = True
975 | tags, trans_dic = matched(s_can, sent_l, tags, trans_dic)
976 | can = can[s_l:]
977 | if len(can) == 0:
978 | done = True
979 | pre = ''
980 | else:
981 | pre = can
982 | done = True
983 | if len(pre) == 0:
984 | if tag_scheme == 'BI':
985 | tags.append('I')
986 | else:
987 | tags.append('X')
988 | if len(tags) > 0:
989 | tags.pop()
990 | if len(tags) == len(raw_l):
991 | for tg in tags:
992 | s_tags.add(tg)
993 | wt.write(raw_l + '\n')
994 | wt.write(''.join(tags) + '\n')
995 | wt.write('\n')
996 | for key in trans_dic:
997 | if key not in final_dic:
998 | final_dic[key] = trans_dic[key]
999 | else:
1000 | for tr in trans_dic[key]:
1001 | if tr in final_dic[key]:
1002 | final_dic[key][tr] += trans_dic[key][tr]
1003 | else:
1004 | final_dic[key][tr] = trans_dic[key][tr]
1005 | else:
1006 | invalid += 1
1007 | if wg is not None:
1008 | wg.write(gold + '\n')
1009 | if wg is not None:
1010 | wg.close()
1011 | if wtg is not None:
1012 | for stg in s_tags:
1013 | wtg.write(stg + '\n')
1014 | wtg.close()
1015 | if creat_dict and not ignore_mwt:
1016 | for key in final_dic:
1017 | wd.write(key + '\n')
1018 | s_dic = sorted(final_dic[key].items(), key=lambda x: x[1], reverse=True)
1019 | for i in s_dic:
1020 | wd.write(i[0] + '\t' + str(i[1]) + '\n')
1021 | wd.write('\n')
1022 | wt.close()
1023 | print 'invalid sentences: ', invalid, len(raw)
1024 |
1025 |
1026 | def raw2tags_sea(raw, sents, path, train_file, gold_path=None, reset=False, tag_scheme='BIES'):
1027 | wt = codecs.open(path + '/' + train_file, 'w', encoding='utf-8')
1028 | wg = None
1029 | if gold_path is not None:
1030 | wg = codecs.open(path + '/' + gold_path, 'w', encoding='utf-8')
1031 | assert len(raw) == len(sents)
1032 | invalid = 0
1033 | wtg = None
1034 | if reset or not os.path.isfile(path + '/tags.txt'):
1035 | wtg = codecs.open(path + '/tags.txt', 'w', encoding='utf-8')
1036 |
1037 | s_tags = set()
1038 |
1039 | def matched(can, sent_l, tags):
1040 | segs = can.split(' ')
1041 | sent_l.pop(0)
1042 | if len(segs) == 1:
1043 | tags.append('S')
1044 | elif len(segs) > 1:
1045 | if tag_scheme == 'BI':
1046 | tags += ['B'] + ['I'] * (len(segs) - 1)
1047 | else:
1048 | mid_t = ['I'] * (len(segs) - 2)
1049 | tags += ['B'] + mid_t + ['E']
1050 | return tags
1051 |
1052 | for raw_l, sent_l in zip(raw, sents):
1053 | tags = []
1054 | cans = pre_token(raw_l)
1055 | gold = get_gold(sent_l)
1056 | pre = ''
1057 | for can in cans:
1058 | t_can = can.strip()
1059 | purged = len(can) - len(t_can)
1060 | if purged > 0:
1061 | can = t_can
1062 | while purged > 0:
1063 | if tag_scheme == 'BI':
1064 | tags.append('I')
1065 | else:
1066 | tags.append('X')
1067 | purged -= 1
1068 | if len(pre) > 0:
1069 | can = pre + ' ' + can
1070 | j_can = ''.join(can.split())
1071 | if sent_l:
1072 | j_sent = ''.join(sent_l[0][1].split())
1073 | if j_can == j_sent:
1074 | tags = matched(can, sent_l, tags)
1075 | pre = ''
1076 | else:
1077 | assert len(j_can) < len(j_sent)
1078 | pre = can
1079 | if len(tags) == len(cans):
1080 | for tg in tags:
1081 | s_tags.add(tg)
1082 | wt.write(raw_l + '\n')
1083 | wt.write(''.join(tags) + '\n')
1084 | wt.write('\n')
1085 | else:
1086 | invalid += 1
1087 | if wg is not None:
1088 | wg.write(gold + '\n')
1089 | if wg is not None:
1090 | wg.close()
1091 | if wtg is not None:
1092 | for stg in s_tags:
1093 | wtg.write(stg + '\n')
1094 | wtg.close()
1095 | wt.close()
1096 |
1097 | print 'invalid sentences: ', invalid, len(raw)
1098 |
1099 |
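# pad_zeros() right-pads every sequence with zeros to max_len (truncating longer
# ones); it also accepts a dict of lists of sequences. A minimal sketch:
#   pad_zeros([[1, 2], [3, 4, 5, 6]], 4) -> array([[1, 2, 0, 0], [3, 4, 5, 6]])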
1100 | def pad_zeros(l, max_len):
1101 | padded = None
1102 | if type(l) is list:
1103 | padded = []
1104 | for item in l:
1105 | if len(item) <= max_len:
1106 | padded.append(np.pad(item, (0, max_len - len(item)), 'constant', constant_values=0))
1107 | else:
1108 | padded.append(np.asarray(item[:max_len]))
1109 | padded = np.asarray(padded)
1110 | elif type(l) is dict:
1111 | padded = {}
1112 | for k, v in l.iteritems():
1113 | padded[k] = [np.pad(item, (0, max_len - len(item)), 'constant', constant_values=0) for item in v]
1114 | return padded
1115 |
1116 | def unpad_zeros(l):
1117 | out = []
1118 | for tags in l:
1119 | out.append([np.trim_zeros(line) for line in tags])
1120 | return out
1121 |
1122 |
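# buckets() sorts samples by input length and groups them into buckets whose
# capacities grow in steps of `size`, so each bucket is padded only to its own
# maximum length (see pad_bucket) rather than to the global maximum.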
1123 | def buckets(x, y, size=50):
1124 | assert len(x[0]) == len(y[0])
1125 | num_inputs = len(x)
1126 | samples = x + y
1127 | num_items = len(samples)
1128 | xy = zip(*samples)
1129 | xy.sort(key=lambda i: len(i[0]))
1130 | t_len = size
1131 | idx = 0
1132 | bucks = [[[]] for _ in range(num_items)]
1133 | for item in xy:
1134 | if len(item[0]) > t_len:
1135 | if len(bucks[0][idx]) > 0:
1136 | for buck in bucks:
1137 | buck.append([])
1138 | idx += 1
1139 | while len(item[0]) > t_len:
1140 | t_len += size
1141 | for i in range(num_items):
1142 | #print item[i]
1143 | bucks[i][idx].append(item[i])
1144 |
1145 | return bucks[:num_inputs], bucks[num_inputs:]
1146 |
1147 |
1148 | def pad_bucket(x, y, limit, bucket_len_c=None):
1149 | assert len(x[0]) == len(y[0])
1150 | num_inputs = len(x)
1151 | num_tags = len(y)
1152 | padded = [[] for _ in range(num_tags + num_inputs)]
1153 | bucket_counts = []
1154 | samples = x + y
1155 | xy = zip(*samples)
1156 | if bucket_len_c is None:
1157 | bucket_len_c = []
1158 | for i, item in enumerate(xy):
1159 | max_len = len(item[0][-1])
1160 | if i == len(xy) - 1:
1161 | max_len = limit
1162 | bucket_len_c.append(max_len)
1163 | bucket_counts.append(len(item[0]))
1164 | for idx in range(num_tags + num_inputs):
1165 | padded[idx].append(pad_zeros(item[idx], max_len))
1166 | print 'Number of buckets: ', len(bucket_len_c)
1167 | else:
1168 | idy = 0
1169 | for item in xy:
1170 | max_len = len(item[0][-1])
1171 | while idy < len(bucket_len_c) and max_len > bucket_len_c[idy]:
1172 | idy += 1
1173 | bucket_counts.append(len(item[0]))
1174 | if idy >= len(bucket_len_c):
1175 | for idx in range(num_tags + num_inputs):
1176 | padded[idx].append(pad_zeros(item[idx], max_len))
1177 | bucket_len_c.append(max_len)
1178 | else:
1179 | for idx in range(num_tags + num_inputs):
1180 | padded[idx].append(pad_zeros(item[idx], bucket_len_c[idy]))
1181 | return padded[:num_inputs], padded[num_inputs:], bucket_len_c, bucket_counts
1182 |
1183 |
1184 | def get_real_batch(counts, b_size):
1185 | real_batch_sizes = []
1186 | for c in counts:
1187 | if c < b_size:
1188 | real_batch_sizes.append(c)
1189 | else:
1190 | real_batch_sizes.append(b_size)
1191 | return real_batch_sizes
1192 |
1193 |
1194 | def merge_bucket(x):
1195 | out = []
1196 | for item in x:
1197 | m = []
1198 | for i in item:
1199 | m += i
1200 | out.append(m)
1201 | return out
1202 |
1203 |
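# decode_tags() maps predicted indices back to tag strings and collapses the
# fine-grained scheme to the one generate_output() branches on: E -> I and S -> B
# for word tags, J -> Z and D -> K for transliteration tags.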
1204 | def decode_tags(idx, index2tags):
1205 | out = []
1206 | for id in idx:
1207 | sents = []
1208 | for line in id:
1209 | sent = []
1210 | for item in line:
1211 | tag = index2tags[item]
1212 | tag = tag.replace('E', 'I')
1213 | tag = tag.replace('S', 'B')
1214 | tag = tag.replace('J', 'Z')
1215 | tag = tag.replace('D', 'K')
1216 | sent.append(tag)
1217 | sents.append(sent)
1218 | out.append(sents)
1219 | return out
1220 |
1221 |
1222 | def decode_chars(idx, idx2chars):
1223 | out = []
1224 | for line in idx:
1225 | line = np.trim_zeros(line)
1226 | out.append([idx2chars[item] for item in line])
1227 | return out
1228 |
1229 |
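# generate_output() rebuilds tokenized text from per-character predictions. Tag
# semantics, as used below: B/I begin/continue a word, K/Z begin/continue a unit to
# be transliterated (via the dictionary and/or the seq2seq transducer in map_trans),
# T marks a separate sentence-final token, U a sentence-final character attached to
# the current unit, and X joins the pieces of a multi-word token. With multi_tok,
# multi-word tokens are annotated as 'surface!#!n' followed by their n parts, which
# printer() later expands into CoNLL-U style ranges.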
1230 | def generate_output(chars, tags, trans_dict, transducer_dict=None, multi_tok=False, trans_type='mix'):
1231 | out = []
1232 | mult_out = []
1233 | raw_out = []
1234 | sent_seg = False
1235 |
1236 | def map_trans(c_trans, type=trans_type):
1237 | if c_trans in trans_dict and (type == 'mix' or type == 'dict'):
1238 | c_trans = trans_dict[c_trans]
1239 | elif transducer_dict is not None and (type == 'mix' or type == 'trans'):
1240 | c_trans = transducer_dict(c_trans)
1241 | sp = c_trans.split()
1242 | c_trans = ' '.join(sp)
1243 |
1244 | return c_trans
1245 |
1246 | def add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=False):
1247 | c_trans = c_trans.strip()
1248 | if len(c_trans) > 0:
1249 | if trans:
1250 | o_trans = c_trans
1251 | c_trans = map_trans(c_trans)
1252 | if multi_tok:
1253 | num_tr = len(c_trans.split(' '))
1254 | mt_p_line += ' ' + o_trans + '!#!' + str(num_tr) + ' ' + c_trans
1255 | else:
1256 | if multi_tok:
1257 | mt_p_line += ' ' + c_trans
1258 | p_line += ' ' + c_trans
1259 | return p_line, mt_p_line
1260 |
1261 | def split_sent(lines, s_str):
1262 | for i in range(len(lines)):
1263 | s_line = lines[i].strip()
1264 |             while s_line and s_line.endswith(s_str):  # suffix-based strip also works for a multi-character marker
1265 |                 s_line = s_line[:-len(s_str)]
1266 | sents = s_line.split(s_str)
1267 | lines[i] = [sent.strip() for sent in sents]
1268 | return lines
1269 |
1270 | for i, tag in enumerate(tags):
1271 | assert len(chars) == len(tag)
1272 | sub_out = []
1273 | sub_raw_out = []
1274 | multi_sub_out = []
1275 | j_chars = []
1276 | j_tags = []
1277 | is_first = True
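        # Stitch chopped chunks back together: chop() overlaps consecutive chunks
        # and marks continuations with a leading '<#>', so each join drops 5 items
        # from the previous tail and the marker plus 5 (first join, whose preceding
        # chunk is one item longer) or plus 4 (later joins) from the continuation head.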
1278 | for chs, tgs in zip(chars, tag):
1279 | if chs[0] == '<#>':
1280 | assert len(j_chars) > 0
1281 | if is_first:
1282 | is_first = False
1283 | j_chars[-1] = j_chars[-1][:-5] + chs[6:]
1284 | j_tags[-1] = j_tags[-1][:-5] + tgs[6:]
1285 | else:
1286 | j_chars[-1] = j_chars[-1][:-5] + chs[5:]
1287 | j_tags[-1] = j_tags[-1][:-5] + tgs[5:]
1288 | else:
1289 | j_chars.append(chs)
1290 | j_tags.append(tgs)
1291 | is_first = True
1292 | chars = j_chars
1293 | tag = j_tags
1294 | for chs, tgs in zip(chars, tag):
1295 | assert len(chs) == len(tgs)
1296 | c_word = ''
1297 | c_trans = ''
1298 | p_line = ''
1299 | r_line = ''
1300 | mt_p_line = ''
1301 | for ch, tg in zip(chs, tgs):
1302 | r_line += ch
1303 | if tg == 'I':
1304 | if len(c_trans) > 0:
1305 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True)
1306 | c_trans = ''
1307 | c_word = ch
1308 | else:
1309 | c_word += ch
1310 | elif tg == 'Z':
1311 | if len(c_word) > 0:
1312 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok)
1313 | c_word = ''
1314 | c_trans = ch
1315 | else:
1316 | c_trans += ch
1317 | elif tg == 'B':
1318 | if len(c_word) > 0:
1319 | c_word = c_word.strip()
1320 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok)
1321 | elif len(c_trans) > 0:
1322 | c_trans = c_trans.strip()
1323 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True)
1324 | c_trans = ''
1325 | c_word = ch
1326 | elif tg == 'K':
1327 | if len(c_word) > 0:
1328 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok)
1329 | c_word = ''
1330 | elif len(c_trans) > 0:
1331 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True)
1332 | c_trans = ch
1333 | elif tg == 'T':
1334 | sent_seg = True
1335 | if len(c_word) > 0:
1336 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok)
1337 | c_word = ''
1338 | elif len(c_trans) > 0:
1339 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True)
1340 | c_trans = ''
1341 |                     p_line += ' ' + ch + '<SENT>'  # '<SENT>' marks a sentence boundary; the original marker character was garbled in extraction
1342 |                     if multi_tok:
1343 |                         mt_p_line += ' ' + ch + '<SENT>'
1344 |                     r_line += '<SENT>'
1345 | elif tg == 'U':
1346 | sent_seg = True
1347 | if len(c_word) > 0:
1348 | c_word += ch
1349 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok)
1350 | c_word = ''
1351 | elif len(c_trans) > 0:
1352 | c_trans += ch
1353 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True)
1354 | c_trans = ''
1355 | elif len(ch.strip()) > 0:
1356 | p_line += ch
1357 | if multi_tok:
1358 | mt_p_line += ch
1359 |                     p_line += '<SENT>'
1360 |                     if multi_tok:
1361 |                         mt_p_line += '<SENT>'
1362 |                     r_line += '<SENT>'
1363 | elif tg == 'X' and len(ch.strip()) > 0:
1364 | if len(c_word) > 0:
1365 | c_word += ch
1366 | elif len(c_trans) > 0:
1367 | c_trans += ch
1368 | else:
1369 | c_word = ch
1370 | elif len(ch.strip()) > 0:
1371 | if len(c_word) > 0:
1372 | c_word += ' ' + ch
1373 | elif len(c_trans) > 0:
1374 | c_trans += ' ' + ch
1375 | else:
1376 | c_word = ch
1377 | if len(c_word) > 0:
1378 | c_word = c_word.strip()
1379 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_word, multi_tok)
1380 | elif len(c_trans) > 0:
1381 | c_trans = c_trans.strip()
1382 | p_line, mt_p_line = add_pline(p_line, mt_p_line, c_trans, multi_tok, trans=True)
1383 | sub_out.append(p_line.strip())
1384 | sub_raw_out.append(r_line.strip())
1385 | if multi_tok:
1386 | multi_sub_out.append(mt_p_line.strip())
1387 | out.append(sub_out)
1388 | raw_out.append(sub_raw_out)
1389 | if multi_tok:
1390 | mult_out.append(multi_sub_out)
1391 |     # NOTE: the extracted code called str.rstrip() on the last entries and discarded the
1392 |     # result (a no-op); trailing '<SENT>' markers are stripped inside split_sent() instead.
1393 |     if sent_seg:
1394 |         out = split_sent(out[0], '<SENT>')
1395 |         raw_out = split_sent(raw_out[0], '<SENT>')
1396 |     if multi_tok:
1397 |         if sent_seg:
1398 |             mult_out = split_sent(mult_out[0], '<SENT>')
1399 |         # return unconditionally (as extracted, multi_tok without sent_seg returned None)
1400 |         return out, raw_out, mult_out
1401 |     else:
1402 |         return out, raw_out
1403 |
1404 |
1405 | def generate_output_sea(chars, tags):
1406 | out = []
1407 | raw_out = []
1408 | sent_seg = False
1409 |
1410 | def split_sent(lines, s_str):
1411 | for i in range(len(lines)):
1412 | s_line = lines[i].strip()
1413 |             while s_line and s_line.endswith(s_str):  # suffix-based strip also works for a multi-character marker
1414 |                 s_line = s_line[:-len(s_str)]
1415 | sents = s_line.split(s_str)
1416 | lines[i] = [sent.strip() for sent in sents]
1417 | return lines
1418 |
1419 | for i, tag in enumerate(tags):
1420 | assert len(chars) == len(tag)
1421 | sub_out = []
1422 | sub_raw_out = []
1423 | j_chars = []
1424 | j_tags = []
1425 | is_first = True
1426 | for chs, tgs in zip(chars, tag):
1427 | if chs[0] == '<#>':
1428 | assert len(j_chars) > 0
1429 | if is_first:
1430 | is_first = False
1431 | j_chars[-1] = j_chars[-1][:-5] + chs[6:]
1432 | j_tags[-1] = j_tags[-1][:-5] + tgs[6:]
1433 | else:
1434 | j_chars[-1] = j_chars[-1][:-5] + chs[5:]
1435 | j_tags[-1] = j_tags[-1][:-5] + tgs[5:]
1436 | else:
1437 | j_chars.append(chs)
1438 | j_tags.append(tgs)
1439 | is_first = True
1440 | chars = j_chars
1441 | tag = j_tags
1442 | for chs, tgs in zip(chars, tag):
1443 | assert len(chs) == len(tgs)
1444 | p_line = ''
1445 | r_line = ''
1446 | for ch, tg in zip(chs, tgs):
1447 | r_line += ' ' + ch
1448 | if tg == 'I':
1449 | if ch == '.' or (ch >= '0' and ch <= '9'):
1450 | p_line += ch
1451 | else:
1452 | p_line += ' ' + ch
1453 | elif tg == 'B':
1454 | p_line += ' ' + ch
1455 |                 elif tg == 'T':
1456 |                     sent_seg = True
1457 |                     p_line += ' ' + ch + '<SENT>'  # '<SENT>': reconstructed sentence-boundary marker (original garbled in extraction)
1458 |                     r_line += '<SENT>'
1459 |                 elif tg == 'U':
1460 |                     sent_seg = True
1461 |                     p_line += ch + '<SENT>'
1462 |                     r_line += '<SENT>'
1463 | elif len(ch.strip()) > 0:
1464 | p_line += ' ' + ch
1465 | sub_out.append(p_line.strip())
1466 | sub_raw_out.append(r_line.strip())
1467 | out.append(sub_out)
1468 | raw_out.append(sub_raw_out)
1469 |     # NOTE: two str.rstrip() calls whose results were discarded (no-ops) were dropped
1470 |     # here; split_sent() strips trailing '<SENT>' markers itself.
1471 |     if sent_seg:
1472 |         out = split_sent(out[0], '<SENT>')
1473 |         raw_out = split_sent(raw_out[0], '<SENT>')
1474 |     return out, raw_out
1475 |
1476 |
1477 | def trim_output(out, length):
1478 | assert len(out) == len(length)
1479 | trimmed_out = []
1480 | for item, l in zip(out, length):
1481 | trimmed_out.append(item[:l])
1482 | return trimmed_out
1483 |
1484 |
1485 | def generate_trans_out(x, idx2char):
1486 | out = ''
1487 | for i in x:
1488 | if i == 3:
1489 | out += ' '
1490 | elif i in idx2char:
1491 | out += idx2char[i]
1492 | if '<#>' in out:
1493 | out = out[:out.index('<#>')]
1494 |     # NOTE: the original two replace() calls here had garbled literals (no-ops as extracted);
1495 |     out = ' '.join(out.split())  # normalise whitespace instead
1496 | return out
1497 |
1498 |
1499 | def generate_sent_out(raw, predictions):
1500 | out = []
1501 | line = ''
1502 | assert len(raw) == len(predictions)
1503 | for ch, tag in zip(raw, predictions):
1504 | line += ch
1505 | if tag == 1:
1506 | line = line.strip()
1507 | out.append(line)
1508 | line = ''
1509 | if len(line) > 0:
1510 | line = line.strip()
1511 | out.append(line)
1512 | return out
1513 |
1514 |
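# viterbi() recovers the best tag path per sequence by backtracking:
# max_scores[m][t] holds the best score of each state at step t and
# max_scores_pre[m][t] the corresponding back-pointers; the path is built from the
# best final state backwards and then reversed.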
1515 | def viterbi(max_scores, max_scores_pre, length, batch_size):
1516 | best_paths = []
1517 | for m in range(batch_size):
1518 | path = []
1519 | last_max_node = np.argmax(max_scores[m][length[m] - 1])
1520 | path.append(last_max_node)
1521 | for t in range(1, length[m])[::-1]:
1522 | last_max_node = max_scores_pre[m][t][last_max_node]
1523 | path.append(last_max_node)
1524 | path = path[::-1]
1525 | best_paths.append(path)
1526 | return best_paths
1527 |
1528 |
1529 | def get_new_chars(path, char2idx, is_space):
1530 | new_chars = set()
1531 | for line in codecs.open(path, 'rb', encoding='utf-8'):
1532 | line = line.strip()
1533 | if is_space == 'sea':
1534 | line = pre_token(line)
1535 | for ch in line:
1536 | if ch not in char2idx:
1537 | new_chars.add(ch)
1538 | return new_chars
1539 |
1540 |
1541 | def get_valid_chars(chars, emb_path):
1542 | valid_chars = []
1543 | total = []
1544 | for line in codecs.open(emb_path, 'rb', encoding='utf-8'):
1545 | line = line.strip()
1546 | sets = line.split(' ')
1547 | total.append(sets[0])
1548 | for ch in chars:
1549 | if ch in total:
1550 | valid_chars.append(ch)
1551 | return valid_chars
1552 |
1553 |
1554 | def get_new_embeddings(new_chars, emb_dim, emb_path):
1555 | assert os.path.isfile(emb_path)
1556 | emb = {}
1557 | new_emb = []
1558 | for line in codecs.open(emb_path, 'rb', encoding='utf-8'):
1559 | line = line.strip()
1560 | sets = line.split(' ')
1561 | emb[sets[0]] = np.asarray(sets[1:], dtype='float32')
1562 | if '' not in emb:
1563 | unk = np.random.uniform(-math.sqrt(float(3) / emb_dim), math.sqrt(float(3) / emb_dim), emb_dim)
1564 | emb[''] = np.asarray(unk, dtype='float32')
1565 | for ch in new_chars:
1566 | if ch in emb:
1567 | new_emb.append(emb[ch])
1568 | else:
1569 | new_emb.append(emb[''])
1570 | return new_emb
1571 |
1572 |
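# update_char_dict() assigns indices (starting at len(char2idx) + 10) to unseen
# characters, marking those without pretrained embeddings as unknown, and builds a
# substitution table mapping unseen opening/closing quotation marks to a known
# quote of the same side. (The first entry of each quote list below appears to
# have been garbled to a plain '"' in extraction.)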
1573 | def update_char_dict(char2idx, new_chars, unk_chars_idx, valid_chars=None):
1574 | l_quos = ['"', '«', '„']
1575 | r_quos = ['"', '»', '“']
1576 | l_quos = [unicode(ch) for ch in l_quos]
1577 | r_quos = [unicode(ch) for ch in r_quos]
1578 | sub_dict = {}
1579 | old_chars = char2idx.keys()
1580 | dim = len(char2idx) + 10
1581 | if valid_chars is not None:
1582 | for ch in valid_chars:
1583 | if char2idx[ch] in unk_chars_idx:
1584 | unk_chars_idx.remove(ch)
1585 | for char in new_chars:
1586 | if char not in char2idx and len(char.strip()) > 0:
1587 | char2idx[char] = dim
1588 | if valid_chars is None or char not in valid_chars:
1589 | unk_chars_idx.append(dim)
1590 | dim += 1
1591 | idx2char = {k: v for v, k in char2idx.items()}
1592 | for ch in new_chars:
1593 | if ch in l_quos:
1594 | for l_ch in l_quos:
1595 | if l_ch in old_chars:
1596 | sub_dict[char2idx[ch]] = char2idx[l_ch]
1597 | if char2idx[ch] in unk_chars_idx:
1598 | unk_chars_idx.remove(char2idx[ch])
1599 | break
1600 | elif ch in r_quos:
1601 | for r_ch in r_quos:
1602 | if r_ch in old_chars:
1603 | sub_dict[char2idx[ch]] = char2idx[r_ch]
1604 | if char2idx[ch] in unk_chars_idx:
1605 | unk_chars_idx.remove(char2idx[ch])
1606 | break
1607 | return char2idx, idx2char, unk_chars_idx, sub_dict
1608 |
1609 |
1610 | def get_new_grams(path, gram2idx, is_raw=False, is_space=True):
1611 | raw = []
1612 | i = 0
1613 | for line in codecs.open(path, 'rb', encoding='utf-8'):
1614 | line = line.strip()
1615 | if is_space == 'sea':
1616 | line = pre_token(line)
1617 | if i == 0 or is_raw:
1618 | raw.append(line)
1619 | i += 1
1620 | if len(line) > 0:
1621 | i += 1
1622 | else:
1623 | i = 0
1624 | new_grams = []
1625 | for g_dic in gram2idx:
1626 | new_g = []
1627 | if is_space == 'sea':
1628 | n = len(g_dic.keys()[0].split('_'))
1629 | else:
1630 | n = 0
1631 | for k in g_dic.keys():
1632 | if '' not in k:
1633 | n = len(k)
1634 | break
1635 | grams = ngrams(raw, n, is_space)
1636 | for g in grams:
1637 | if g not in g_dic:
1638 | new_g.append(g)
1639 | new_grams.append(new_g)
1640 | return new_grams
1641 |
1642 |
1643 | def printer(raw, tagged, multi_out, outpath, sent_seg, form='conll'):
1644 | assert len(tagged) == len(multi_out)
1645 | validator(raw, multi_out)
1646 | wt = codecs.open(outpath, 'w', encoding='utf-8')
1647 | if form == 'conll':
1648 | if not sent_seg:
1649 | for raw_t, tagged_t, multi_t in zip(raw, tagged, multi_out):
1650 | if len(multi_t) > 0:
1651 | wt.write('#sent_raw: ' + raw_t + '\n')
1652 | wt.write('#sent_tok: ' + tagged_t + '\n')
1653 | idx = 1
1654 | tgs = multi_t.split(' ')
1655 | pl = ''
1656 | for _ in range(8):
1657 | pl += '\t' + '_'
1658 | for tg in tgs:
1659 | if '!#!' in tg:
1660 | segs = tg.split('!#!')
1661 | wt.write(str(idx) + '-' + str(int(segs[1]) + idx - 1) + '\t' + segs[0] + pl + '\n')
1662 | else:
1663 | wt.write(str(idx) + '\t' + tg + pl + '\n')
1664 | idx += 1
1665 | wt.write('\n')
1666 | else:
1667 | for tagged_t, multi_t in zip(tagged, multi_out):
1668 | if len(tagged_t.strip()) > 0:
1669 | wt.write('#sent_tok: '+ tagged_t + '\n')
1670 | idx = 1
1671 | tgs = multi_t.split(' ')
1672 | pl = ''
1673 | for _ in range(8):
1674 | pl += '\t' + '_'
1675 | for tg in tgs:
1676 | if '!#!' in tg:
1677 | segs = tg.split('!#!')
1678 | wt.write(str(idx) + '-' + str(int(segs[1]) + idx - 1) + '\t' + segs[0] + pl + '\n')
1679 | else:
1680 | wt.write(str(idx) + '\t' + tg + pl + '\n')
1681 | idx += 1
1682 | wt.write('\n')
1683 | else:
1684 | for tg in tagged:
1685 | wt.write(tg + '\n')
1686 | wt.close()
1687 |
1688 |
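# biased_out() thresholds two-class scores instead of taking the argmax: it pools
# the margins pre[:, 0] - pre[:, 1] over all positions, picks the value at the
# `bias` quantile of the descending sort as the threshold, and labels positions
# with a margin above it as class 0 and the rest as class 1.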
1689 | def biased_out(prediction, bias):
1690 | out = []
1691 | b_pres = []
1692 | for pre in prediction:
1693 | b_pres.append(pre[:,0] - pre[:,1])
1694 | props = np.concatenate(b_pres)
1695 | props = np.sort(props)[::-1]
1696 | idx = int(bias*len(props))
1697 | if idx == len(props):
1698 | idx -= 1
1699 | th = props[idx]
1700 | print 'threshold: ', th, 1 / (1 + np.exp(-th))
1701 | for pre in b_pres:
1702 | pre[pre >= th] = 0
1703 | pre[pre != 0] = 1
1704 | out.append(pre)
1705 | return out
1706 |
1707 |
1708 | def to_one_hot(y, nb_classes=None):
1709 | '''Convert class vector (integers from 0 to nb_classes) to binary class matrix, for use with categorical_crossentropy.
1710 | # Arguments
1711 | y: class vector to be converted into a matrix
1712 | nb_classes: total number of classes
1713 | # Returns
1714 | A binary matrix representation of the input.
1715 | '''
1716 | if not nb_classes:
1717 | nb_classes = np.max(y)+1
1718 | Y = np.zeros((len(y), nb_classes))
1719 | for i in range(len(y)):
1720 | Y[i, y[i]] = 1.
1721 | return Y
1722 |
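# A minimal usage sketch of to_one_hot():
#   to_one_hot(np.asarray([0, 2, 1]), 3) -> [[1, 0, 0], [0, 0, 1], [0, 1, 0]]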
1723 |
1724 | def validator(raw, generated):
1725 | raw_l = ''.join(raw)
1726 | raw_l = ''.join(raw_l.split())
1727 | for g in generated:
1728 | g_tokens = g.split(' ')
1729 | j = 0
1730 | while j < len(g_tokens):
1731 | if '!#!' in g_tokens[j]:
1732 | segs = g_tokens[j].split('!#!')
1733 | c_t = int(segs[1])
1734 | r_seg = ''.join(segs[0].split())
1735 | l_w = len(r_seg)
1736 | if r_seg == raw_l[:l_w]:
1737 | raw_l = raw_l[l_w:]
1738 | raw_l = raw_l.strip()
1739 | else:
1740 |                 raise Exception('Error: generated tokens do not match the raw text')
1741 | j += c_t
1742 | else:
1743 | r_seg = ''.join(g_tokens[j].split())
1744 | l_w = len(r_seg)
1745 | if r_seg == raw_l[:l_w]:
1746 | raw_l = raw_l[l_w:]
1747 | raw_l = raw_l.strip()
1748 | else:
1749 | print r_seg
1750 | print raw_l[:l_w]
1751 | print ''
1752 |                 raise Exception('Error: generated tokens do not match the raw text')
1753 | j += 1
1754 |
1755 |
1756 | def mlp_post(raw, prediction, is_space=False, form='mlp1'):
1757 | assert len(raw) == len(prediction)
1758 | out = []
1759 | for r_l, p_l in zip(raw, prediction):
1760 | st = ''
1761 | rtokens = r_l.split()
1762 | ptokens = p_l.split(' ')
1763 | purged = []
1764 | for pt in ptokens:
1765 | purged.append(pt.strip())
1766 | ptokens = purged
1767 | ptokens_str = ''.join(ptokens)
1768 | assert ''.join(rtokens) == ''.join(ptokens_str.split())
1769 | if form == 'mlp1':
1770 | if is_space == 'sea':
1771 | for p_t in ptokens:
1772 | st += p_t.replace(' ', '_') + ' '
1773 | else:
1774 | while rtokens and ptokens:
1775 | if rtokens[0] == ptokens[0]:
1776 | st += ptokens[0] + ' '
1777 | rtokens.pop(0)
1778 | ptokens.pop(0)
1779 | else:
1780 | if len(rtokens[0]) <= len(ptokens[0]):
1781 | assert ptokens[0][:len(rtokens[0])] == rtokens[0]
1782 | st += rtokens[0] + ' '
1783 | ptokens[0] = ptokens[0][len(rtokens[0]):].strip()
1784 | rtokens.pop(0)
1785 | else:
1786 | can = ''
1787 | while can != rtokens[0] and ptokens:
1788 | can += ptokens[0]
1789 | st += ptokens[0] + '\\\\'
1790 | ptokens.pop(0)
1791 | st = st[:-2] + ' '
1792 | rtokens.pop(0)
1793 | else:
1794 | for p_t in ptokens:
1795 | st += p_t + ' '
1796 | out.append(st.strip())
1797 | return out
--------------------------------------------------------------------------------
/transducer_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import tensorflow as tf
3 | import tensorflow.contrib.legacy_seq2seq as seq2seq
4 | import toolbox
5 | import batch as Batch
6 | import numpy as np
7 | import cPickle as pickle
8 | import evaluation
9 |
10 | import os
11 |
12 | class Seq2seq(object):
13 |
14 | def __init__(self, trained_model):
15 | self.en_vec = None
16 | self.de_vec = None
17 | self.trans_output = None
18 | self.trans_labels = None
19 |         self.feed_previous = None  # typo fixed: the rest of the class uses feed_previous
20 | self.trans_l_rate = None
21 | self.trained = trained_model
22 | self.decode_step = None
23 | self.encode_step = None
24 |
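    # define() builds an attention-based encoder-decoder over character indices
    # using the legacy_seq2seq API: max_x encoder steps, max_y decoder steps with
    # zero-padding masked out of the loss, and Adagrad with gradients clipped to a
    # global norm of 5.0. feed_previous toggles between teacher forcing (training)
    # and feeding back the decoder's own predictions (decoding).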
25 | def define(self, char_num, rnn_dim, emb_dim, max_x, max_y, write_trans_model=True):
26 | self.decode_step = max_y
27 | self.encode_step = max_x
28 | self.en_vec = [tf.placeholder(tf.int32, [None], name='en_input' + str(i)) for i in range(max_x)]
29 | self.trans_labels = [tf.placeholder(tf.int32, [None], name='de_input' + str(i)) for i in range(max_y)]
30 | weights = [tf.cast(tf.sign(ot_t), tf.float32) for ot_t in self.trans_labels]
31 | self.de_vec = [tf.zeros_like(self.trans_labels[0], tf.int32)] + self.trans_labels[:-1]
32 | self.feed_previous = tf.placeholder(tf.bool)
33 | self.trans_l_rate = tf.placeholder(tf.float32, [], name='learning_rate')
34 | seq_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_dim, state_is_tuple=True)
35 | self.trans_output, states = seq2seq.embedding_attention_seq2seq(self.en_vec, self.de_vec, seq_cell, char_num,
36 | char_num, emb_dim, feed_previous=self.feed_previous)
37 |
38 | loss = seq2seq.sequence_loss(self.trans_output, self.trans_labels, weights)
39 | optimizer = tf.train.AdagradOptimizer(learning_rate=self.trans_l_rate)
40 |
41 | params = tf.trainable_variables()
42 | gradients = tf.gradients(loss, params)
43 | clipped_gradients, norm = tf.clip_by_global_norm(gradients, 5.0)
44 | self.trans_train = optimizer.apply_gradients(zip(clipped_gradients, params))
45 |
46 | self.saver = tf.train.Saver()
47 |
48 | if write_trans_model:
49 | param_dic = {}
50 | param_dic['char_num'] = char_num
51 | param_dic['rnn_dim'] = rnn_dim
52 | param_dic['emb_dim'] = emb_dim
53 | param_dic['max_x'] = max_x
54 | param_dic['max_y'] = max_y
55 | # print param_dic
56 | f_model = open(self.trained + '_model', 'w')
57 | pickle.dump(param_dic, f_model)
58 | f_model.close()
59 |
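    # train() re-trains unless reset is False and a checkpoint exists, keeping the
    # weights with the best token F-score on the held-out dictionary entries from
    # get_dict_vec(); tag() then truncates/pads candidates to the encoder length
    # and decodes greedily (argmax).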
60 | def train(self, t_x, t_y, v_x, v_y, lrv, char2idx, sess, epochs, batch_size=10, reset=True):
61 |
62 | idx2char = {k: v for v, k in char2idx.items()}
63 | v_y_g = [np.trim_zeros(v_y_t) for v_y_t in v_y]
64 | gold_out = [toolbox.generate_trans_out(v_y_t, idx2char) for v_y_t in v_y_g]
65 |
66 | best_score = 0
67 |
68 | if reset or not os.path.isfile(self.trained + '_weights.index'):
69 | for epoch in range(epochs):
70 | Batch.train_seq2seq(sess, model=self.en_vec + self.trans_labels, decoding=self.feed_previous,
71 | batch_size=batch_size, config=self.trans_train, lr=self.trans_l_rate, lrv=lrv,
72 | data=[t_x] + [t_y])
73 | pred = Batch.predict_seq2seq(sess, model=self.en_vec + self.de_vec + self.trans_output,
74 | decoding=self.feed_previous, decode_len=self.decode_step,
75 | data=[v_x], argmax=True, batch_size=100)
76 | pred_out = [toolbox.generate_trans_out(pre_t, idx2char) for pre_t in pred]
77 |
78 | c_scores = evaluation.trans_evaluator(gold_out, pred_out)
79 |
80 | print 'epoch: %d' % (epoch + 1)
81 |
82 | print 'ACC: %f' % c_scores[0]
83 | print 'Token F score: %f' % c_scores[1]
84 |
85 | if c_scores[1] > best_score:
86 | best_score = c_scores[1]
87 | self.saver.save(sess, self.trained + '_weights', write_meta_graph=False)
88 |
89 | if best_score > 0 or not reset:
90 | self.saver.restore(sess, self.trained + '_weights')
91 |
92 | def tag(self, t_x, char2idx, sess, batch_size=100):
93 |
94 | t_x = [t_x_t[:self.encode_step] for t_x_t in t_x]
95 | t_x = toolbox.pad_zeros(t_x, self.encode_step)
96 |
97 | idx2char = {k: v for v, k in char2idx.items()}
98 |
99 | pred = Batch.predict_seq2seq(sess, model=self.en_vec + self.de_vec + self.trans_output, decoding=self.feed_previous,
100 | decode_len=self.decode_step, data=[t_x], argmax=True, batch_size=batch_size)
101 | pred_out = [toolbox.generate_trans_out(pre_t, idx2char) for pre_t in pred]
102 |
103 | return pred_out
104 |
105 |
106 |
--------------------------------------------------------------------------------