├── src_g2s
│   ├── __init__.py
│   ├── namespace_utils.py
│   ├── amr_utils.py
│   ├── padding_utils.py
│   ├── metric_utils.py
│   ├── metric_rouge_utils.py
│   ├── metric_bleu_utils.py
│   ├── G2S_data_stream.py
│   ├── G2S_beam_decoder.py
│   ├── G2S_model_graph.py
│   └── G2S_trainer.py
├── train_s2s.sh
├── train_g2s.sh
├── decode_s2s.sh
├── src_s2s
│   ├── namespace_utils.py
│   ├── prepare_question_generation_dataset.py
│   ├── phrase_projection_layer_utils.py
│   ├── padding_utils.py
│   ├── metric_utils.py
│   ├── metric_rouge_utils.py
│   ├── prepare_paraphrase_dataset.py
│   ├── prepare_summarization_dataset.py
│   ├── sent_utils.py
│   ├── NP2P_data_stream.py
│   ├── phrase_lattice_utils.py
│   ├── metric_bleu_utils.py
│   ├── encoder_utils.py
│   ├── NP2P_phrase_trainer.py
│   ├── NP2P_trainer.py
│   └── NP2P_beam_decoder.py
├── decode_g2s.sh
├── data
│   ├── 3_extract_embedding.py
│   ├── 1_make_json.py
│   ├── 4_insert_symbols_front.py
│   └── 2_gen_vocab.py
├── AMR_multiline_to_singleline.py
├── logs_g2s
│   └── extract_and_eval.py
├── config_g2s.json
├── config_s2s.json
└── README.md
/src_g2s/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/train_s2s.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -J AMR_s2s --partition=gpu --gres=gpu:1 --time=5-00:00:00 --output=train.out_s2s --error=train.err_s2s
3 | #SBATCH --mem=15GB
4 | #SBATCH -c 5
5 |
6 | export PYTHONPATH=$PYTHONPATH:/home/lsong10/ws/exp.graph_to_seq/neural-graph-to-seq-mp
7 |
8 | python src_s2s/NP2P_trainer.py --config_path config_s2s.json
9 |
10 |
--------------------------------------------------------------------------------
/train_g2s.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -J g_2m -C K80 --partition=gpu --gres=gpu:1 --time=5-00:00:00 --output=train.out --error=train.err
3 | #SBATCH --mem=80GB
4 | #SBATCH -c 5
5 |
6 | export PYTHONPATH=$PYTHONPATH:/home/lsong10/ws/exp.graph_to_seq/neural-graph-to-seq-mp
7 |
8 | python src_g2s/G2S_trainer.py --config_path logs_g2s/G2S.silver_2m.config.json
9 |
10 |
--------------------------------------------------------------------------------
/decode_s2s.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --partition=gpu --gres=gpu:1 --time=1:00:00 --output=decode.out --error=decode.err
3 | #SBATCH --mem=10GB
4 | #SBATCH -c 6
5 |
6 | export PYTHONPATH=$PYTHONPATH:/home/lsong10/ws/exp.graph_to_seq/neural-graph-to-seq-mp
7 |
8 | python src_s2s/NP2P_beam_decoder.py --model_prefix logs_s2s/NP2P.$1 \
9 | --in_path data/test.json \
10 | --out_path logs_s2s/test.s2s.$1\.tok \
11 | --mode beam
12 |
13 |
--------------------------------------------------------------------------------
/src_g2s/namespace_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | class Bunch(object):
4 | def __init__(self, adict):
5 | self.__dict__.update(adict)
6 |
7 | def save_namespace(FLAGS, out_path):
8 | FLAGS_dict = vars(FLAGS)
9 | with open(out_path, 'w') as fp:
10 | json.dump(FLAGS_dict, fp)
11 |
12 | def load_namespace(in_path):
13 | with open(in_path, 'r') as fp:
14 | FLAGS_dict = json.load(fp)
15 | return Bunch(FLAGS_dict)
--------------------------------------------------------------------------------
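A minimal usage sketch for these helpers, assuming the repo root is on `PYTHONPATH` (as the training scripts set it) and an argparse-style `FLAGS` object; the file name `demo_config.json` is only for illustration:
```
import argparse
from src_g2s import namespace_utils

# Build a FLAGS-like namespace, save it as JSON, then reload it as a Bunch.
parser = argparse.ArgumentParser()
parser.add_argument('--config_path', type=str, default='config_g2s.json')
parser.add_argument('--suffix', type=str, default='demo')
FLAGS = parser.parse_args([])

namespace_utils.save_namespace(FLAGS, 'demo_config.json')    # writes vars(FLAGS) as JSON
FLAGS2 = namespace_utils.load_namespace('demo_config.json')  # returns a Bunch
print(FLAGS2.suffix)  # attribute-style access, just like the original FLAGS
```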
/src_s2s/namespace_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | class Bunch(object):
4 | def __init__(self, adict):
5 | self.__dict__.update(adict)
6 |
7 | def save_namespace(FLAGS, out_path):
8 | FLAGS_dict = vars(FLAGS)
9 | with open(out_path, 'w') as fp:
10 | json.dump(FLAGS_dict, fp)
11 |
12 | def load_namespace(in_path):
13 | with open(in_path, 'r') as fp:
14 | FLAGS_dict = json.load(fp)
15 | return Bunch(FLAGS_dict)
--------------------------------------------------------------------------------
/decode_g2s.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --partition=gpu --gres=gpu:1 -C K80 --time=1:00:00 --output=decode.out --error=decode.err
3 | #SBATCH --mem=10GB
4 | #SBATCH -c 6
5 |
6 | export PYTHONPATH=$PYTHONPATH:/home/lsong10/ws/exp.graph_to_seq/neural-graph-to-seq-mp
7 |
8 | python src_g2s/G2S_beam_decoder.py --model_prefix logs_g2s/G2S.$1 \
9 | --in_path data/test.json \
10 | --out_path logs_g2s/test.g2s.$1\.tok \
11 | --mode beam
12 |
13 |
--------------------------------------------------------------------------------
/data/3_extract_embedding.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy
3 | import sys, os
4 | from collections import Counter
5 |
6 | vocab = set(line.strip() for line in open(sys.argv[1], 'rU'))
7 | print 'len(vocab)', len(vocab)
8 |
9 | intersect = set()
10 | f = open(sys.argv[2], 'w')
11 | for line in open('/home/lsong10/ws/data.embedding/glove.840B.300d.txt', 'rU'):
12 | word = line.strip().split()[0]
13 | if word in vocab:
14 | intersect.add(word)
15 | print >>f, line.strip()
16 | print len(intersect)
17 |
18 | for w in vocab - intersect:
19 | embedding = ' '.join([str('%.6f'%x) for x in numpy.random.normal(size=600)])
20 | print >>f, w, embedding
21 |
22 | f.close()
23 |
--------------------------------------------------------------------------------
/src_s2s/prepare_question_generation_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | def read_all_SQuAD_questions(inpath):
4 | with open(inpath) as dataset_file:
5 | dataset_json = json.load(dataset_file, encoding='utf-8')
6 | dataset = dataset_json['data']
7 | all_questions = []
8 | for article in dataset:
9 | # title = article['title']
10 | for paragraph in article['paragraphs']:
11 | context = paragraph['context']
12 | for question in paragraph['qas']:
13 | question_text = question['question']
14 | question_id = question['id']
15 | answers = question['answers']
16 | return all_questions
17 |
18 |
--------------------------------------------------------------------------------
/data/1_make_json.py:
--------------------------------------------------------------------------------
1 | import sys, os, json
2 |
3 | amr = [x.strip() for x in open(sys.argv[1]+'-dfs-linear_src.txt','rU')]
4 | print 'len(amr)', len(amr)
5 | sent = [x.strip() for x in open(sys.argv[1]+'-dfs-linear_targ.txt','rU')]
6 | print 'len(sent)', len(sent)
7 | assert len(amr) == len(sent)
8 |
9 | ids = None
10 | if os.path.isfile(sys.argv[1]+'-ids.txt'):
11 | ids = [x.strip() for x in open(sys.argv[1]+'-ids.txt','rU')]
12 | assert len(amr) == len(ids)
13 |
14 | data = []
15 | for i in range(len(amr)):
16 | json_obj = {'amr':amr[i],'sent':sent[i],}
17 | if ids != None:
18 | json_obj['id'] = ids[i]
19 | data.append(json_obj)
20 | print len(data)
21 | json.dump(data,open(sys.argv[1]+'.json','w'))
22 |
23 |
--------------------------------------------------------------------------------
/data/4_insert_symbols_front.py:
--------------------------------------------------------------------------------
1 |
2 | import os,sys
3 | import numpy
4 |
5 | inpath = sys.argv[1]
6 | outpath = sys.argv[1]+'.st'
7 | f = open(outpath,'w')
8 | for i,line in enumerate(open(inpath,'rU')):
9 | if i == 0:
10 | vsize = len(line.strip().split())-1
11 | print vsize
12 | print >>f, '\t'.join(['0', '#pad#', ' '.join([str('%.6f'%x) for x in numpy.zeros(vsize)])])
13 | print >>f, '\t'.join(['1', '', ' '.join([str('%.6f'%x) for x in numpy.random.normal(size=vsize)])])
14 | print >>f, '\t'.join(['2', '', ' '.join([str('%.6f'%x) for x in numpy.random.normal(size=vsize)])])
15 | line = line.strip().split()
16 | word = line[0]
17 | line = ' '.join(line[1:])
18 | print >>f, '\t'.join([str(i+3), word, line])
19 |
20 |
--------------------------------------------------------------------------------
/AMR_multiline_to_singleline.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 |
4 | ids = []
5 | id_dict = {}
6 | amrs = []
7 | amr_str = ''
8 | for line in open(sys.argv[1],'rU'):
9 | if line.startswith('#'):
10 | if line.startswith('# ::id'):
11 | id = line.lower().strip().split()[2]
12 | ids.append(id)
13 | id_dict[id] = len(ids)-1
14 | continue
15 | line = line.strip()
16 | if line == '':
17 | if amr_str != '':
18 | amrs.append(amr_str.strip())
19 | amr_str = ''
20 | else:
21 | amr_str = amr_str + line + ' '
22 |
23 | if amr_str != '':
24 | amrs.append(amr_str.strip())
25 | amr_str = ''
26 |
27 | if len(sys.argv) == 3:
28 | for line in open(sys.argv[2],'rU'):
29 | id = line.lower().strip()
30 | print amrs[id_dict[id]]
31 |
32 |
--------------------------------------------------------------------------------
/logs_g2s/extract_and_eval.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 | import os
4 |
5 | output_dict = {}
6 | ref_dict = {}
7 | for i,line in enumerate(open(sys.argv[1],'rU')):
8 | if i%5 == 0:
9 | id = line.strip()
10 | elif i%5 == 1:
11 | ref = line.strip().lower().split()
12 | ref_dict[id] = ref
13 | elif i%5 == 2:
14 | rst = line.strip().replace('','',10).split()
15 | output_dict[id] = rst
16 |
17 | dataset_type = None
18 | if sys.argv[1].startswith('test'):
19 | dataset_type = 'test'
20 | elif sys.argv[1].startswith('dev'):
21 | dataset_type = 'dev'
22 |
23 | fout = open(sys.argv[1]+'.1best','w')
24 | fref = open(sys.argv[1]+'.ref','w')
25 | for id in output_dict.keys():
26 | print >>fout, ' '.join(output_dict[id]).lower()
27 | print >>fref, ' '.join(ref_dict[id]).lower()
28 | fout.close()
29 | fref.close()
30 |
31 | os.system('/home/lsong10/ws/exp.graph_to_seq/mosesdecoder/scripts/generic/multi-bleu.perl %s.ref < %s.1best' %(sys.argv[1],sys.argv[1]))
32 |
33 |
--------------------------------------------------------------------------------
/data/2_gen_vocab.py:
--------------------------------------------------------------------------------
1 |
2 | import json
3 | import sys
4 | import cPickle
5 | import re
6 | from collections import Counter
7 | import codecs
8 |
9 | def update(l, v):
10 | v.update([x.lower() for x in l])
11 |
12 | def update_vocab(path, vocab, vocab_edge, vocab_node):
13 | words = []
14 | words_edge = []
15 | words_node = []
16 | data = json.load(open(path,'rU'))
17 | for inst in data:
18 | words += inst['sent'].lower().strip().split()
19 | words += [x for x in inst['amr'].lower().strip().split() if x[0] != ':']
20 | words_edge += [x for x in inst['amr'].lower().strip().split() if x[0] == ':']
21 | words_node += [x for x in inst['amr'].lower().strip().split() if re.search('_[0-9]+', x) != None or x == 'num_unk']
22 | update(words, vocab)
23 | update(words_edge, vocab_edge)
24 | update(words_node, vocab_node)
25 |
26 | def output(d, path):
27 | f = codecs.open(path,'w',encoding='utf-8')
28 | for k,v in sorted(d.items(), key=lambda x:-x[1]):
29 | print >>f, k
30 | f.close()
31 |
32 | ##################
33 |
34 | vocab = Counter()
35 | vocab_edge = Counter()
36 | vocab_node = Counter()
37 | update_vocab('training.json', vocab, vocab_edge, vocab_node)
38 | print len(vocab), len(vocab_edge), len(vocab_node)
39 |
40 | output(vocab, 'vocab.txt')
41 | #output(vocab_edge, 'vocab_edge.txt')
42 | #output(vocab_node, 'vocab_node.txt')
43 |
44 |
--------------------------------------------------------------------------------
/config_g2s.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_path": "data/2m.json",
3 | "finetune_path": "data/training.json",
4 | "test_path": "data/dev.json",
5 | "word_vec_path": "data/vectors_RwNN.txt.st",
6 | "suffix": "silver_2m",
7 | "model_dir": "logs_g2s",
8 | "isLower": true,
9 |
10 | "pointer_gen": true,
11 | "use_coverage": true,
12 | "attention_vec_size": 600,
13 | "batch_size": 100,
14 | "beam_size": 5,
15 |
16 | "num_syntax_match_layer": 9,
17 | "max_node_num": 180,
18 | "max_in_neigh_num": 2,
19 | "max_out_neigh_num": 8,
20 | "min_answer_len": 0,
21 | "max_answer_len": 50,
22 | "learning_rate": 0.0005,
23 | "lambda_l2": 1e-8,
24 | "dropout_rate": 0.1,
25 | "cov_loss_wt": 0.1,
26 | "max_epochs": 10,
27 | "optimize_type": "adam",
28 |
29 | "with_highway": false,
30 | "highway_layer_num": 1,
31 |
32 | "with_char": false,
33 | "char_dim": 100,
34 | "char_lstm_dim": 100,
35 | "max_char_per_word": 20,
36 |
37 | "attention_type": "hidden",
38 | "way_init_decoder": "all",
39 | "edgelabel_dim": 100,
40 | "neighbor_vector_dim": 600,
41 | "fix_word_vec": false,
42 | "compress_input": false,
43 | "compress_input_dim": 300,
44 |
45 | "gen_hidden_size": 600,
46 | "num_softmax_samples": 100,
47 | "mode": "ce_train",
48 |
49 | "config_path": "config.json",
50 | "generate_config": false
51 | }
52 |
--------------------------------------------------------------------------------
/config_s2s.json:
--------------------------------------------------------------------------------
1 | {
2 | "finetune_path": "",
3 | "train_path": "data/training.json",
4 | "test_path": "data/dev.json",
5 | "enc_word_vec_path": "data/vectors_RwNN.txt.st",
6 | "dec_word_vec_path": "data/vectors_RwNN.txt.st",
7 | "suffix": "gold",
8 | "model_dir": "logs_s2s",
9 | "isLower": true,
10 | "two_sent_inputs": false,
11 |
12 | "direction": "bidir",
13 | "pointer_gen": true,
14 | "use_coverage": true,
15 | "switch_qa": false,
16 | "attention_vec_size": 300,
17 | "cov_loss_wt": 0.1,
18 |
19 | "beam_size": 20,
20 |
21 | "min_answer_len": 0,
22 | "max_answer_len": 100,
23 | "batch_size": 20,
24 | "learning_rate": 0.001,
25 | "lambda_l2": 0.001,
26 | "dropout_rate": 0.2,
27 | "max_epochs": 50,
28 | "optimize_type": "adam",
29 |
30 | "with_filter_layer": true,
31 | "filter_layer_threshold": 0.2,
32 | "with_word_match": true,
33 | "with_sequential_match": true,
34 | "context_layer_num": 1,
35 | "context_lstm_dim": 300,
36 |
37 | "with_lex_decomposition": false,
38 | "lex_decompsition_dim": 50,
39 | "with_question_passage_word_feature": false,
40 |
41 | "with_phrase_projection": false,
42 |
43 | "with_highway": false,
44 | "highway_layer_num": 1,
45 | "with_match_highway": false,
46 | "with_aggregation_highway": false,
47 |
48 | "with_word": true,
49 | "with_char": true,
50 | "with_POS": false,
51 | "with_NER": false,
52 | "POS_dim": 20,
53 | "NER_dim": 20,
54 | "char_dim": 50,
55 | "char_lstm_dim": 100,
56 | "compress_input": false,
57 | "compress_input_dim": 300,
58 | "fix_word_vec": true,
59 | "max_passage_len": 100,
60 | "max_char_per_word": 20,
61 |
62 | "gen_hidden_size": 300,
63 | "num_softmax_samples": 100,
64 | "mode": "ce_train",
65 |
66 | "config_path": "config.json",
67 | "generate_config": false
68 | }
69 |
--------------------------------------------------------------------------------
/src_s2s/phrase_projection_layer_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | def collect_representation(representation, positions):
4 | '''
5 | representation: [batch_size, passage_length, dim]
6 | positions: [batch_size, num_positions]
7 | '''
8 | def singel_instance(x):
9 | # x[0]: [passage_length, dim]
10 | # x[1]: [num_positions]
11 | return tf.gather(x[0], x[1])
12 | elems = (representation, positions)
13 | return tf.map_fn(singel_instance, elems, dtype=tf.float32) # [batch_size, num_positions, dim]
14 |
15 | class PhraseProjectionLayer(object):
16 | def __init__(self, placeholders):
17 | # placeholder assignments
18 | self.max_phrase_size = placeholders.max_phrase_size # a scaler, max number of phrases within a batch
19 | self.phrase_starts = placeholders.phrase_starts # [batch_size, chunk_len]
20 | self.phrase_ends = placeholders.phrase_ends # [batch_size, chunk_len]
21 | self.phrase_lengths = placeholders.phrase_lengths # [batch_size]
22 |
23 | def project_to_phrase_representation(self, encoder_representations):
24 | '''
25 | encoder_representations: [batch_size, passage_length, encoder_dim]
26 | '''
27 | start_representations = collect_representation(encoder_representations, self.phrase_starts) # [batch_size, chunk_len, encoder_dim]
28 | end_representations = collect_representation(encoder_representations, self.phrase_ends) # [batch_size, chunk_len, encoder_dim]
29 | phrase_representations = tf.concat([start_representations, end_representations], axis=2, name='phrase_representation') # [batch_size, chunk_len, 2*encoder_dim]
30 |
31 | phrase_len = tf.shape(self.phrase_starts)[1]
32 | phrase_mask = tf.sequence_mask(self.phrase_lengths, phrase_len, dtype=tf.float32) # [batch_size, phrase_len]
33 | phrase_mask = tf.expand_dims(phrase_mask, axis=-1, name='phrase_mask') # [batch_size, phrase_len, 'x']
34 |
35 | phrase_representations = phrase_representations * phrase_mask
36 | return phrase_representations # [batch_size, phrase_len, 2*encoder_dim]
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/src_g2s/amr_utils.py:
--------------------------------------------------------------------------------
1 |
2 | def read_anonymized(amr_lst, amr_node, amr_edge):
3 | assert sum(x=='(' for x in amr_lst) == sum(x==')' for x in amr_lst)
4 | cur_str = amr_lst[0]
5 | cur_id = len(amr_node)
6 | amr_node.append(cur_str)
7 |
8 | i = 1
9 | while i < len(amr_lst):
10 | if amr_lst[i].startswith(':') == False: ## cur cur-num_0
11 | nxt_str = amr_lst[i]
12 | nxt_id = len(amr_node)
13 | amr_node.append(nxt_str)
14 | amr_edge.append((cur_id, nxt_id, ':value'))
15 | i = i + 1
16 | elif amr_lst[i].startswith(':') and len(amr_lst) == 2: ## cur :edge
17 | nxt_str = 'num_unk'
18 | nxt_id = len(amr_node)
19 | amr_node.append(nxt_str)
20 | amr_edge.append((cur_id, nxt_id, amr_lst[i]))
21 | i = i + 1
22 | elif amr_lst[i].startswith(':') and amr_lst[i+1] != '(': ## cur :edge nxt
23 | nxt_str = amr_lst[i+1]
24 | nxt_id = len(amr_node)
25 | amr_node.append(nxt_str)
26 | amr_edge.append((cur_id, nxt_id, amr_lst[i]))
27 | i = i + 2
28 | elif amr_lst[i].startswith(':') and amr_lst[i+1] == '(': ## cur :edge ( ... )
29 | number = 1
30 | j = i+2
31 | while j < len(amr_lst):
32 | number += (amr_lst[j] == '(')
33 | number -= (amr_lst[j] == ')')
34 | if number == 0:
35 | break
36 | j += 1
37 | assert number == 0 and amr_lst[j] == ')', ' '.join(amr_lst[i+2:j])
38 | nxt_id = read_anonymized(amr_lst[i+2:j], amr_node, amr_edge)
39 | amr_edge.append((cur_id, nxt_id, amr_lst[i]))
40 | i = j + 1
41 | else:
42 | assert False, ' '.join(amr_lst)
43 | return cur_id
44 |
45 | if __name__ == '__main__':
46 | for path in ['data/dev-dfs-linear_src.txt', 'data/test-dfs-linear_src.txt', 'data/training-dfs-linear_src.txt', ]:
47 | print path
48 | for i, line in enumerate(open(path, 'rU')):
49 | amr_node = []
50 | amr_edge = []
51 | read_anonymized(line.strip().split(), amr_node, amr_edge)
52 |
--------------------------------------------------------------------------------
/src_g2s/padding_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | def make_batches(size, batch_size):
3 | nb_batch = int(np.ceil(size/float(batch_size)))
4 | return [(i*batch_size, min(size, (i+1)*batch_size)) for i in range(0, nb_batch)] # zgwang: starting point of each batch
5 |
6 | def pad_2d_vals_no_size(in_vals, dtype=np.int32):
7 | size1 = len(in_vals)
8 | size2 = np.max([len(x) for x in in_vals])
9 | return pad_2d_vals(in_vals, size1, size2, dtype=dtype)
10 |
11 | def pad_2d_vals(in_vals, dim1_size, dim2_size, dtype=np.int32):
12 | out_val = np.zeros((dim1_size, dim2_size), dtype=dtype)
13 | if dim1_size > len(in_vals): dim1_size = len(in_vals)
14 | for i in xrange(dim1_size):
15 | cur_in_vals = in_vals[i]
16 | cur_dim2_size = dim2_size
17 | if cur_dim2_size > len(cur_in_vals): cur_dim2_size = len(cur_in_vals)
18 | out_val[i,:cur_dim2_size] = cur_in_vals[:cur_dim2_size]
19 | return out_val
20 |
21 | def pad_3d_vals_no_size(in_vals, dtype=np.int32):
22 | size1 = len(in_vals)
23 | size2 = np.max([len(x) for x in in_vals])
24 | size3 = 0
25 | for val in in_vals:
26 | cur_size3 = np.max([len(x) for x in val])
27 | if size3<cur_size3: size3 = cur_size3
28 | return pad_3d_vals(in_vals, size1, size2, size3, dtype=dtype)
29 |
30 | def pad_3d_vals(in_vals, dim1_size, dim2_size, dim3_size, dtype=np.int32):
31 | out_val = np.zeros((dim1_size, dim2_size, dim3_size), dtype=dtype)
32 | if dim1_size > len(in_vals): dim1_size = len(in_vals)
33 | for i in xrange(dim1_size):
34 | in_vals_i = in_vals[i]
35 | cur_dim2_size = dim2_size
36 | if cur_dim2_size > len(in_vals_i): cur_dim2_size = len(in_vals_i)
37 | for j in xrange(cur_dim2_size):
38 | in_vals_ij = in_vals_i[j]
39 | cur_dim3_size = dim3_size
40 | if cur_dim3_size > len(in_vals_ij): cur_dim3_size = len(in_vals_ij)
41 | out_val[i, j, :cur_dim3_size] = in_vals_ij[:cur_dim3_size]
42 | return out_val
43 |
44 | def pad_4d_vals(in_vals, dim1_size, dim2_size, dim3_size, dim4_size, dtype=np.int32):
45 | out_val = np.zeros((dim1_size, dim2_size, dim3_size, dim4_size), dtype=dtype)
46 | if dim1_size > len(in_vals): dim1_size = len(in_vals)
47 | for i in xrange(dim1_size):
48 | in_vals_i = in_vals[i]
49 | cur_dim2_size = dim2_size
50 | if cur_dim2_size > len(in_vals_i): cur_dim2_size = len(in_vals_i)
51 | for j in xrange(cur_dim2_size):
52 | in_vals_ij = in_vals_i[j]
53 | cur_dim3_size = dim3_size
54 | if cur_dim3_size > len(in_vals_ij): cur_dim3_size = len(in_vals_ij)
55 | for k in xrange(cur_dim3_size):
56 | in_vals_ijk = in_vals_ij[k]
57 | cur_dim4_size = dim4_size
58 | if cur_dim4_size > len(in_vals_ijk): cur_dim4_size = len(in_vals_ijk)
59 | out_val[i, j, k, :cur_dim4_size] = in_vals_ijk[:cur_dim4_size]
60 | return out_val
61 |
62 | def pad_target_labels(in_val, max_length, dtype=np.float32):
63 | batch_size = len(in_val)
64 | out_val = np.zeros((batch_size, max_length), dtype=dtype)
65 | for i in xrange(batch_size):
66 | for index in in_val[i]:
67 | out_val[i,index] = 1.0
68 | return out_val
69 |
--------------------------------------------------------------------------------
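A small usage sketch of the padding helpers above, assuming the repo's Python 2 environment (the module uses `xrange`); the toy id lists are only for illustration:
```
from src_g2s import padding_utils

# Batch boundaries: 5 items with batch_size 2 -> [(0, 2), (2, 4), (4, 5)]
print(padding_utils.make_batches(5, 2))

# Pad a ragged 2-D batch of word ids up to the longest row (shorter rows are zero-padded).
word_ids = [[4, 7, 1], [9]]
print(padding_utils.pad_2d_vals_no_size(word_ids))   # shape (2, 3)

# Pad a ragged 3-D batch, e.g. per-word character ids.
char_ids = [[[2, 3], [5]], [[1, 1, 1]]]
print(padding_utils.pad_3d_vals_no_size(char_ids))   # shape (2, 2, 3)
```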
/src_s2s/padding_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | def make_batches(size, batch_size):
3 | nb_batch = int(np.ceil(size/float(batch_size)))
4 | return [(i*batch_size, min(size, (i+1)*batch_size)) for i in range(0, nb_batch)] # zgwang: starting point of each batch
5 |
6 | def pad_2d_vals_no_size(in_vals, dtype=np.int32):
7 | size1 = len(in_vals)
8 | size2 = np.max([len(x) for x in in_vals])
9 | return pad_2d_vals(in_vals, size1, size2, dtype=dtype)
10 |
11 | def pad_2d_vals(in_vals, dim1_size, dim2_size, dtype=np.int32):
12 | out_val = np.zeros((dim1_size, dim2_size), dtype=dtype)
13 | if dim1_size > len(in_vals): dim1_size = len(in_vals)
14 | for i in xrange(dim1_size):
15 | cur_in_vals = in_vals[i]
16 | cur_dim2_size = dim2_size
17 | if cur_dim2_size > len(cur_in_vals): cur_dim2_size = len(cur_in_vals)
18 | out_val[i,:cur_dim2_size] = cur_in_vals[:cur_dim2_size]
19 | return out_val
20 |
21 | def pad_3d_vals_no_size(in_vals, dtype=np.int32):
22 | size1 = len(in_vals)
23 | size2 = np.max([len(x) for x in in_vals])
24 | size3 = 0
25 | for val in in_vals:
26 | cur_size3 = np.max([len(x) for x in val])
27 | if size3<cur_size3: size3 = cur_size3
28 | return pad_3d_vals(in_vals, size1, size2, size3, dtype=dtype)
29 |
30 |
31 | def pad_3d_vals(in_vals, dim1_size, dim2_size, dim3_size, dtype=np.int32):
32 | out_val = np.zeros((dim1_size, dim2_size, dim3_size), dtype=dtype)
33 | if dim1_size > len(in_vals): dim1_size = len(in_vals)
34 | for i in xrange(dim1_size):
35 | in_vals_i = in_vals[i]
36 | cur_dim2_size = dim2_size
37 | if cur_dim2_size > len(in_vals_i): cur_dim2_size = len(in_vals_i)
38 | for j in xrange(cur_dim2_size):
39 | in_vals_ij = in_vals_i[j]
40 | cur_dim3_size = dim3_size
41 | if cur_dim3_size > len(in_vals_ij): cur_dim3_size = len(in_vals_ij)
42 | out_val[i, j, :cur_dim3_size] = in_vals_ij[:cur_dim3_size]
43 | return out_val
44 |
45 | def pad_4d_vals(in_vals, dim1_size, dim2_size, dim3_size, dim4_size, dtype=np.int32):
46 | out_val = np.zeros((dim1_size, dim2_size, dim3_size, dim4_size), dtype=dtype)
47 | if dim1_size > len(in_vals): dim1_size = len(in_vals)
48 | for i in xrange(dim1_size):
49 | in_vals_i = in_vals[i]
50 | cur_dim2_size = dim2_size
51 | if cur_dim2_size > len(in_vals_i): cur_dim2_size = len(in_vals_i)
52 | for j in xrange(cur_dim2_size):
53 | in_vals_ij = in_vals_i[j]
54 | cur_dim3_size = dim3_size
55 | if cur_dim3_size > len(in_vals_ij): cur_dim3_size = len(in_vals_ij)
56 | for k in xrange(cur_dim3_size):
57 | in_vals_ijk = in_vals_ij[k]
58 | cur_dim4_size = dim4_size
59 | if cur_dim4_size > len(in_vals_ijk): cur_dim4_size = len(in_vals_ijk)
60 | out_val[i, j, k, :cur_dim4_size] = in_vals_ijk[:cur_dim4_size]
61 | return out_val
62 |
63 | def pad_target_labels(in_val, max_length, dtype=np.float32):
64 | batch_size = len(in_val)
65 | out_val = np.zeros((batch_size, max_length), dtype=dtype)
66 | for i in xrange(batch_size):
67 | for index in in_val[i]:
68 | out_val[i,index] = 1.0
69 | return out_val
70 |
--------------------------------------------------------------------------------
/src_g2s/metric_utils.py:
--------------------------------------------------------------------------------
1 | import cPickle as pickle
2 | import os
3 | import sys
4 | from metric_bleu_utils import Bleu
5 | from metric_rouge_utils import Rouge
6 |
7 | def score_all(ref, hypo):
8 | scorers = [
9 | (Bleu(4),["Bleu_1","Bleu_2","Bleu_3","Bleu_4"]),
10 | (Rouge(),"ROUGE_L"),
11 | ]
12 | final_scores = {}
13 | for scorer,method in scorers:
14 | score,scores = scorer.compute_score(ref,hypo)
15 | if type(score)==list:
16 | for m,s in zip(method,score):
17 | final_scores[m] = s
18 | else:
19 | final_scores[method] = score
20 |
21 | return final_scores
22 |
23 | def score(ref, hypo):
24 | scorers = [
25 | (Bleu(4),["Bleu_1","Bleu_2","Bleu_3","Bleu_4"])
26 |
27 | ]
28 | final_scores = {}
29 | for scorer,method in scorers:
30 | score,scores = scorer.compute_score(ref,hypo)
31 | if type(score)==list:
32 | for m,s in zip(method,score):
33 | final_scores[m] = s
34 | else:
35 | final_scores[method] = score
36 |
37 | return final_scores
38 |
39 | def evaluate_captions(ref,cand):
40 | hypo = {}
41 | refe = {}
42 | for i, caption in enumerate(cand):
43 | hypo[i] = [caption,]
44 | refe[i] = ref[i]
45 | final_scores = score(refe, hypo)
46 | return 1*final_scores['Bleu_4'] + 1*final_scores['Bleu_3'] + 0.5*final_scores['Bleu_1'] + 0.5*final_scores['Bleu_2']
47 |
48 | def evaluate(data_path='./data', split='val', get_scores=False):
49 | reference_path = os.path.join(data_path, "%s/%s.references.pkl" %(split, split))
50 | candidate_path = os.path.join(data_path, "%s/%s.candidate.captions.pkl" %(split, split))
51 |
52 | # load caption data
53 | with open(reference_path, 'rb') as f:
54 | ref = pickle.load(f)
55 | with open(candidate_path, 'rb') as f:
56 | cand = pickle.load(f)
57 |
58 | # make dictionary
59 | hypo = {}
60 | for i, caption in enumerate(cand):
61 | hypo[i] = [caption]
62 |
63 | # compute bleu score
64 | final_scores = score_all(ref, hypo)
65 |
66 | # print out scores
67 | print 'Bleu_1:\t',final_scores['Bleu_1']
68 | print 'Bleu_2:\t',final_scores['Bleu_2']
69 | print 'Bleu_3:\t',final_scores['Bleu_3']
70 | print 'Bleu_4:\t',final_scores['Bleu_4']
71 | print 'METEOR:\t',final_scores['METEOR']
72 | print 'ROUGE_L:',final_scores['ROUGE_L']
73 | print 'CIDEr:\t',final_scores['CIDEr']
74 |
75 | if get_scores:
76 | return final_scores
77 |
78 |
79 | if __name__ == "__main__":
80 | ref = [[u'a tiddy bear',u'a animal'],[u' a number of luggage bags on a cart in a lobby .', u' a cart filled with suitcases and bags .', u' trolley used for transporting personal luggage to guests rooms .', u' wheeled cart with luggage at lobby of commercial business .', u' a luggage cart topped with lots of luggage .']]
81 | dec = [u'some one',u' a man is standing next to a car with a suitcase .']
82 | r = [evaluate_captions([k], [v]) for k, v in zip(ref, dec)]
83 | print r
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
--------------------------------------------------------------------------------
/src_s2s/metric_utils.py:
--------------------------------------------------------------------------------
1 | import cPickle as pickle
2 | import os
3 | import sys
4 | from metric_bleu_utils import Bleu
5 | from metric_rouge_utils import Rouge
6 |
7 | def score_all(ref, hypo):
8 | scorers = [
9 | (Bleu(4),["Bleu_1","Bleu_2","Bleu_3","Bleu_4"]),
10 | (Rouge(),"ROUGE_L"),
11 | ]
12 | final_scores = {}
13 | for scorer,method in scorers:
14 | score,scores = scorer.compute_score(ref,hypo)
15 | if type(score)==list:
16 | for m,s in zip(method,score):
17 | final_scores[m] = s
18 | else:
19 | final_scores[method] = score
20 |
21 | return final_scores
22 |
23 | def score(ref, hypo):
24 | scorers = [
25 | (Bleu(4),["Bleu_1","Bleu_2","Bleu_3","Bleu_4"])
26 |
27 | ]
28 | final_scores = {}
29 | for scorer,method in scorers:
30 | score,scores = scorer.compute_score(ref,hypo)
31 | if type(score)==list:
32 | for m,s in zip(method,score):
33 | final_scores[m] = s
34 | else:
35 | final_scores[method] = score
36 |
37 | return final_scores
38 |
39 | def evaluate_captions(ref,cand):
40 | hypo = {}
41 | refe = {}
42 | for i, caption in enumerate(cand):
43 | hypo[i] = [caption,]
44 | refe[i] = ref[i]
45 | final_scores = score(refe, hypo)
46 | return 1*final_scores['Bleu_4'] + 1*final_scores['Bleu_3'] + 0.5*final_scores['Bleu_1'] + 0.5*final_scores['Bleu_2']
47 |
48 | def evaluate(data_path='./data', split='val', get_scores=False):
49 | reference_path = os.path.join(data_path, "%s/%s.references.pkl" %(split, split))
50 | candidate_path = os.path.join(data_path, "%s/%s.candidate.captions.pkl" %(split, split))
51 |
52 | # load caption data
53 | with open(reference_path, 'rb') as f:
54 | ref = pickle.load(f)
55 | with open(candidate_path, 'rb') as f:
56 | cand = pickle.load(f)
57 |
58 | # make dictionary
59 | hypo = {}
60 | for i, caption in enumerate(cand):
61 | hypo[i] = [caption]
62 |
63 | # compute bleu score
64 | final_scores = score_all(ref, hypo)
65 |
66 | # print out scores
67 | print 'Bleu_1:\t',final_scores['Bleu_1']
68 | print 'Bleu_2:\t',final_scores['Bleu_2']
69 | print 'Bleu_3:\t',final_scores['Bleu_3']
70 | print 'Bleu_4:\t',final_scores['Bleu_4']
71 | print 'METEOR:\t',final_scores['METEOR']
72 | print 'ROUGE_L:',final_scores['ROUGE_L']
73 | print 'CIDEr:\t',final_scores['CIDEr']
74 |
75 | if get_scores:
76 | return final_scores
77 |
78 |
79 | if __name__ == "__main__":
80 | ref = [[u'a tiddy bear',u'a animal'],[u' a number of luggage bags on a cart in a lobby .', u' a cart filled with suitcases and bags .', u' trolley used for transporting personal luggage to guests rooms .', u' wheeled cart with luggage at lobby of commercial business .', u' a luggage cart topped with lots of luggage .']]
81 | dec = [u'some one',u' a man is standing next to a car with a suitcase .']
82 | r = [evaluate_captions([k], [v]) for k, v in zip(ref, dec)]
83 | print r
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
--------------------------------------------------------------------------------
/src_g2s/metric_rouge_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 | import pdb
12 |
13 | def my_lcs(string, sub):
14 | """
15 | Calculates longest common subsequence for a pair of tokenized strings
16 | :param string : list of str : tokens from a string split using whitespace
17 | :param sub : list of str : shorter string, also split using whitespace
18 | :returns: length (list of int): length of the longest common subsequence between the two strings
19 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
20 | """
21 | if(len(string)< len(sub)):
22 | sub, string = string, sub
23 |
24 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
25 |
26 | for j in range(1,len(sub)+1):
27 | for i in range(1,len(string)+1):
28 | if(string[i-1] == sub[j-1]):
29 | lengths[i][j] = lengths[i-1][j-1] + 1
30 | else:
31 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
32 |
33 | return lengths[len(string)][len(sub)]
34 |
35 | class Rouge():
36 | '''
37 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
38 | '''
39 | def __init__(self):
40 | # vrama91: updated the value below based on discussion with Hovey
41 | self.beta = 1.2
42 |
43 | def calc_score(self, candidate, refs):
44 | """
45 | Compute ROUGE-L score given one candidate and references for an image
46 | :param candidate: str : candidate sentence to be evaluated
47 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
48 | :returns score: int (ROUGE-L score for the candidate evaluated against references)
49 | """
50 | assert(len(candidate)==1)
51 | assert(len(refs)>0)
52 | prec = []
53 | rec = []
54 |
55 | # split into tokens
56 | token_c = candidate[0].split(" ")
57 |
58 | for reference in refs:
59 | # split into tokens
60 | token_r = reference.split(" ")
61 | # compute the longest common subsequence
62 | lcs = my_lcs(token_r, token_c)
63 | prec.append(lcs/float(len(token_c)))
64 | rec.append(lcs/float(len(token_r)))
65 |
66 | prec_max = max(prec)
67 | rec_max = max(rec)
68 |
69 | if(prec_max!=0 and rec_max !=0):
70 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
71 | else:
72 | score = 0.0
73 | return score
74 |
75 | def compute_score(self, gts, res):
76 | """
77 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
78 | Invoked by evaluate_captions.py
79 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
80 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
81 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
82 | """
83 | assert(gts.keys() == res.keys())
84 | imgIds = gts.keys()
85 |
86 | score = []
87 | for id in imgIds:
88 | hypo = res[id]
89 | ref = gts[id]
90 |
91 | score.append(self.calc_score(hypo, ref))
92 |
93 | # Sanity check.
94 | assert(type(hypo) is list)
95 | assert(len(hypo) == 1)
96 | assert(type(ref) is list)
97 | assert(len(ref) > 0)
98 |
99 | average_score = np.mean(np.array(score))
100 | return average_score, np.array(score)
101 |
102 | def method(self):
103 | return "Rouge"
--------------------------------------------------------------------------------
/src_s2s/metric_rouge_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 | import pdb
12 |
13 | def my_lcs(string, sub):
14 | """
15 | Calculates longest common subsequence for a pair of tokenized strings
16 | :param string : list of str : tokens from a string split using whitespace
17 | :param sub : list of str : shorter string, also split using whitespace
18 | :returns: length (list of int): length of the longest common subsequence between the two strings
19 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
20 | """
21 | if(len(string)< len(sub)):
22 | sub, string = string, sub
23 |
24 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
25 |
26 | for j in range(1,len(sub)+1):
27 | for i in range(1,len(string)+1):
28 | if(string[i-1] == sub[j-1]):
29 | lengths[i][j] = lengths[i-1][j-1] + 1
30 | else:
31 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
32 |
33 | return lengths[len(string)][len(sub)]
34 |
35 | class Rouge():
36 | '''
37 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
38 | '''
39 | def __init__(self):
40 | # vrama91: updated the value below based on discussion with Hovey
41 | self.beta = 1.2
42 |
43 | def calc_score(self, candidate, refs):
44 | """
45 | Compute ROUGE-L score given one candidate and references for an image
46 | :param candidate: str : candidate sentence to be evaluated
47 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
48 | :returns score: int (ROUGE-L score for the candidate evaluated against references)
49 | """
50 | assert(len(candidate)==1)
51 | assert(len(refs)>0)
52 | prec = []
53 | rec = []
54 |
55 | # split into tokens
56 | token_c = candidate[0].split(" ")
57 |
58 | for reference in refs:
59 | # split into tokens
60 | token_r = reference.split(" ")
61 | # compute the longest common subsequence
62 | lcs = my_lcs(token_r, token_c)
63 | prec.append(lcs/float(len(token_c)))
64 | rec.append(lcs/float(len(token_r)))
65 |
66 | prec_max = max(prec)
67 | rec_max = max(rec)
68 |
69 | if(prec_max!=0 and rec_max !=0):
70 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
71 | else:
72 | score = 0.0
73 | return score
74 |
75 | def compute_score(self, gts, res):
76 | """
77 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
78 | Invoked by evaluate_captions.py
79 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
80 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
81 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
82 | """
83 | assert(gts.keys() == res.keys())
84 | imgIds = gts.keys()
85 |
86 | score = []
87 | for id in imgIds:
88 | hypo = res[id]
89 | ref = gts[id]
90 |
91 | score.append(self.calc_score(hypo, ref))
92 |
93 | # Sanity check.
94 | assert(type(hypo) is list)
95 | assert(len(hypo) == 1)
96 | assert(type(ref) is list)
97 | assert(len(ref) > 0)
98 |
99 | average_score = np.mean(np.array(score))
100 | return average_score, np.array(score)
101 |
102 | def method(self):
103 | return "Rouge"
--------------------------------------------------------------------------------
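A small sanity check for the ROUGE-L scorer above (a sketch with made-up sentences; the dict format mirrors what `compute_score` expects, i.e. candidate lists of length one keyed by example id):
```
from src_s2s.metric_rouge_utils import Rouge

# References and candidates keyed by example id; each candidate list holds exactly one sentence.
gts = {0: ['the cat sat on the mat', 'a cat is sitting on the mat']}
res = {0: ['the cat is on the mat']}

avg, per_example = Rouge().compute_score(gts, res)
print(avg)  # mean ROUGE-L over all ids (here a single example)
```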
/src_s2s/prepare_paraphrase_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import re
4 |
5 | def read_MSCOCO(inpath):
6 | # load json file
7 | with open(inpath) as dataset_file:
8 | dataset_json = json.load(dataset_file, encoding='utf-8')
9 | annotations = dataset_json['annotations']
10 | print(len(annotations))
11 |
12 | # dispatch each caption to its corresponding image
13 | image_captions_dict = {}
14 | for annotation in annotations:
15 | image_id = annotation['image_id']
16 | id = annotation['id']
17 | caption = annotation['caption']
18 | captions = None
19 | if image_captions_dict.has_key(image_id):
20 | captions = image_captions_dict[image_id]
21 | else:
22 | captions = []
23 | captions.append(caption)
24 | image_captions_dict[image_id] = captions
25 |
26 | # check number of captions for each image
27 | all_instances = []
28 | for image_id in image_captions_dict.keys():
29 | captions = image_captions_dict[image_id]
30 | random.shuffle(captions)
31 | # print(len(captions))
32 | all_instances.append((image_id, captions))
33 | return all_instances
34 |
35 | def read_Quora(inpath):
36 | with open(inpath, "rt") as f:
37 | for line in f:
38 | line = line.decode('utf-8')
39 | line = line.strip()
40 | items = re.split('\t', line)
41 |
42 |
43 | def dump_out_to_json(instances, outpath):
44 | json_instances = []
45 | for (image_id, captions) in instances:
46 | json_instances.append({'id': str(image_id), 'text1': captions[0], 'text2': captions[1]})
47 | # json_instances.append({'id': str(image_id) + "-2", 'text1': captions[2], 'text2': captions[3]})
48 | with open(outpath, 'w') as outfile:
49 | json.dump(json_instances, outfile)
50 |
51 | def create_json_file(all_instances, outpath, batch_size=5000):
52 | import padding_utils
53 | batch_spans = padding_utils.make_batches(len(all_instances), batch_size)
54 | for batch_index, (batch_start, batch_end) in enumerate(batch_spans):
55 | cur_instances = all_instances[batch_start:batch_end]
56 | cur_outpath = outpath + ".{}".format(batch_index)
57 | print("Dump {} instances out to {}".format(len(cur_instances), cur_outpath))
58 | dump_out_to_json(cur_instances, cur_outpath)
59 |
60 | if __name__ == "__main__":
61 | ''' # create mscoco dataset
62 | dataset = "train"
63 | inpath = "/u/zhigwang/zhigwang1/sentence_generation/mscoco/annotations/captions_" + dataset + "2014.json"
64 | outpath = "/u/zhigwang/zhigwang1/sentence_generation/mscoco/data/" + dataset + ".json"
65 | all_instances = read_MSCOCO(inpath)
66 | batch_size = 5000
67 | create_json_file(all_instances, outpath, batch_size=batch_size)
68 | '''
69 |
70 | # create quora dataset
71 | rawpath = "/u/zhigwang/zhigwang1/sentence_match/quora/quora_duplicate_questions.tsv"
72 | batch_size = 10000
73 | # load all pairs
74 | print('Loading all question pairs ...')
75 | id_instances_dict = {}
76 | with open(rawpath, "rt") as f:
77 | for line in f:
78 | line = line.decode('utf-8').strip()
79 | if not line.endswith("1"): continue
80 | items = re.split('\t', line)
81 | cur_id = items[0]
82 | sent1 = items[3]
83 | sent2 = items[4]
84 | id_instances_dict[cur_id] = (sent1, sent2)
85 | print(len(id_instances_dict))
86 |
87 | for dataset in ['dev', 'test', 'train']:
88 | # collect all instances
89 | inpath = "/u/zhigwang/zhigwang1/sentence_match/quora/" + dataset + ".tsv"
90 | outpath = "/u/zhigwang/zhigwang1/sentence_generation/quora/data/" + dataset + ".json"
91 | print(inpath)
92 | all_instances = []
93 | with open(inpath, "rt") as f:
94 | for line in f:
95 | line = line.decode('utf-8').strip()
96 | if not line.startswith("1"): continue
97 | items = re.split('\t', line)
98 | cur_id = items[3]
99 | all_instances.append((cur_id, id_instances_dict[cur_id]))
100 | create_json_file(all_instances, outpath, batch_size=batch_size)
101 |
102 |
103 |
104 | print('DONE!')
--------------------------------------------------------------------------------
/src_s2s/prepare_summarization_dataset.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import json
4 |
5 | dm_single_close_quote = u'\u2019' # unicode
6 | dm_double_close_quote = u'\u201d'
7 | END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence
8 |
9 | SENTENCE_START = '<s>'
10 | SENTENCE_END = '</s>'
11 |
12 | def read_text_file(text_file):
13 | lines = []
14 | with open(text_file, "rt") as f:
15 | for line in f:
16 | line = line.decode('utf-8')
17 | lines.append(line.strip())
18 | return lines
19 |
20 | def fix_missing_period(line):
21 | """Adds a period to a line that is missing a period"""
22 | if "@highlight" in line: return line
23 | if line=="": return line
24 | if line[-1] in END_TOKENS: return line
25 | return line + " ."
26 |
27 | def get_art_abs(story_file):
28 | lines = read_text_file(story_file)
29 |
30 | # Lowercase everything
31 | # lines = [line.lower() for line in lines]
32 | # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; consequently they end up in the body of the article as run-on sentences)
33 | lines = [fix_missing_period(line) for line in lines]
34 |
35 | # Separate out article and abstract sentences
36 | article_lines = []
37 | highlights = []
38 | next_is_highlight = False
39 | for idx,line in enumerate(lines):
40 | if line == "":
41 | continue # empty line
42 | elif line.startswith("@highlight"):
43 | next_is_highlight = True
44 | elif next_is_highlight:
45 | highlights.append(line)
46 | else:
47 | article_lines.append(line)
48 |
49 | # Make article into a single string
50 | article = ' '.join(article_lines)
51 |
52 | # Make abstract into a single string, putting <s> and </s> tags around the sentences
53 | abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights])
54 |
55 | return article, abstract
56 |
57 | def hashhex(s):
58 | """Returns a heximal formated SHA1 hash of the input string."""
59 | h = hashlib.sha1()
60 | h.update(s)
61 | return h.hexdigest()
62 |
63 | def dump_out_to_json(hash_path_dict, urls, outpath):
64 | all_instances = []
65 | for cur_url in urls:
66 | cur_hash_code = hashhex(cur_url)
67 | cur_path = hash_path_dict[cur_hash_code]
68 | (article, abstract) = get_art_abs(cur_path)
69 | all_instances.append({'id': cur_hash_code, 'text1': article, 'text2': abstract})
70 |
71 | with open(outpath, 'w') as outfile:
72 | json.dump(all_instances, outfile)
73 |
74 | def create_json_file(hash_path_dict, urlpath, outpath, batch_size=5000):
75 | all_urls = read_text_file(urlpath)
76 | import padding_utils
77 | batch_spans = padding_utils.make_batches(len(all_urls), batch_size)
78 | for batch_index, (batch_start, batch_end) in enumerate(batch_spans):
79 | cur_urls = all_urls[batch_start:batch_end]
80 | cur_outpath = outpath + ".{}".format(batch_index)
81 | print("Dump {} instances out to {}".format(len(cur_urls), cur_outpath))
82 | dump_out_to_json(hash_path_dict, cur_urls, cur_outpath)
83 |
84 |
85 |
86 | def process():
87 | cnn_in_dir = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/cnn/cnn/stories"
88 | daily_mail_in_dir = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/dailymail/dailymail/stories"
89 | train_urls = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/url_lists/all_train.txt"
90 | val_urls = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/url_lists/all_val.txt"
91 | test_urls = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/url_lists/all_test.txt"
92 | outdir = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/data"
93 | batch_size = 5000
94 |
95 | # collect all files
96 | print('collecting all files')
97 | in_dirs = [cnn_in_dir, daily_mail_in_dir]
98 | hash_path_dict = {}
99 | for in_dir in in_dirs:
100 | all_paths= os.listdir(in_dir)
101 | for cur_path in all_paths:
102 | if not cur_path.endswith(".story"): continue
103 | cur_hash_code = cur_path[:-len('.story')]
104 | cur_path = in_dir + "/" + cur_path
105 | hash_path_dict[cur_hash_code] = cur_path
106 | # print(cur_path)
107 | # print(cur_hash_code)
108 | print('number of files: {}'.format(len(hash_path_dict)))
109 | print('Creating val.json')
110 | create_json_file(hash_path_dict, val_urls, outdir + "/val.json", batch_size=batch_size)
111 | print('Creating test.json')
112 | create_json_file(hash_path_dict, test_urls, outdir + "/test.json", batch_size=batch_size)
113 | print('Creating train.json')
114 | create_json_file(hash_path_dict, train_urls, outdir + "/train.json", batch_size=batch_size)
115 |
116 | def generate_tok_commandlines(inpath):
117 | all_paths = read_text_file(inpath)
118 | for cur_path in all_paths:
119 | print("jbsub -q x86_12h -cores 10 -name process_summarization -mem 21G "
120 | + "/u/zhigwang/workspace/FactoidQA_Java/scripts/process_generation_datasets.sh "
121 | + "{} {}.tok 10".format(cur_path, cur_path))
122 |
123 |
124 | if __name__ == "__main__":
125 | '''
126 | inpath = "/u/zhigwang/zhigwang1/sentence_generation/cnn/cnn/stories/fffcd65676a501860ae312754e8cefc71f5ddab8.story"
127 | (article, abstract) = get_art_abs(inpath)
128 | print(article)
129 | print(abstract)
130 | '''
131 |
132 | # process()
133 | #'''
134 | inpath = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/data/fof"
135 | # inpath = "/u/zhigwang/zhigwang1/sentence_generation/mscoco/data/fof"
136 | # inpath = "/u/zhigwang/zhigwang1/sentence_generation/quora/data/fof"
137 | generate_tok_commandlines(inpath)
138 | #'''
139 | # print("DONE!")
--------------------------------------------------------------------------------
/src_s2s/sent_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | class QASentence(object):
4 | def __init__(self, rawText, annotation, ID_num=None, isLower=False, end_sym=None):
5 | self.rawText = rawText
6 | self.annotation = annotation
7 | self.tokText = annotation['toks']
8 | # it's the answer sequence
9 | if end_sym != None:
10 | self.rawText += ' ' + end_sym
11 | self.tokText += ' ' + end_sym
12 | if isLower: self.tokText = self.tokText.lower()
13 | self.words = re.split("\\s+", self.tokText)
14 | self.startPositions = []
15 | self.endPositions = []
16 | positions = re.split("\\s+", annotation['positions'])
17 | for i in xrange(len(positions)):
18 | tmps = re.split("-", positions[i])
19 | self.startPositions.append(int(tmps[1]))
20 | self.endPositions.append(int(tmps[2]))
21 | self.POSs = annotation['POSs']
22 | self.NERs = annotation['NERs']
23 | if annotation.has_key('spans'): self.syntaxSpans = annotation['spans']
24 | self.length = len(self.words)
25 | self.ID_num = ID_num
26 |
27 | self.index_convered = False
28 | self.chunk_starts = None
29 |
30 | def chunk(self, maxlen):
31 | self.words = self.words[:maxlen]
32 | self.startPositions = self.startPositions[:maxlen]
33 | self.endPositions = self.endPositions[:maxlen]
34 | self.POSs = self.POSs[:maxlen]
35 | self.NERs = self.NERs[:maxlen]
36 |
37 | if self.index_convered:
38 | self.word_idx_seq = self.word_idx_seq[:maxlen]
39 | self.char_idx_seq = self.char_idx_seq[:maxlen]
40 | self.POS_idx_seq = self.POS_idx_seq[:maxlen]
41 | self.NER_idx_seq = self.NER_idx_seq[:maxlen]
42 |
43 | self.length = len(self.words)
44 |
45 | def TokSpan2RawSpan(self, startTokID, endTokID):
46 | start = self.startPositions[startTokID]
47 | end = self.endPositions[endTokID]
48 | return (start, end)
49 |
50 | def RawSpan2TokSpan(self, start, end):
51 | startTokID = -1
52 | endTokID = -1
53 | for i in xrange(len(self.startPositions)):
54 | if self.startPositions[i] == start:
55 | startTokID = i
56 | if self.endPositions[i] == end:
57 | endTokID = i
58 | return (startTokID, endTokID)
59 |
60 | def getRawChunk(self, start, end):
61 | if end>len(self.rawText):
62 | return None
63 | return self.rawText[start:end]
64 |
65 | def getRawChunkWithTokSpan(self, startTokID, endTokID):
66 | start = self.startPositions[startTokID]
67 | end = self.endPositions[endTokID]
68 | return self.rawText[start:end]
69 |
70 | def getTokChunk(self, startTokID, endTokID):
71 | curWords = []
72 | for i in xrange(startTokID,endTokID+1):
73 | curWords.append(self.words[i])
74 | return " ".join(curWords)
75 |
76 | def get_length(self):
77 | return self.length
78 |
79 | def get_max_word_len(self):
80 | max_word_len = 0
81 | for word in self.words:
82 | cur_len = len(word)
83 | if max_word_len < cur_len: max_word_len = cur_len
84 | return max_word_len
85 |
86 | def get_char_len(self):
87 | char_lens = []
88 | for word in self.words:
89 | cur_len = len(word)
90 | char_lens.append(cur_len)
91 | return char_lens
92 |
93 | def convert2index(self, word_vocab, char_vocab, POS_vocab, NER_vocab, max_char_per_word=-1):
94 | if self.index_convered: return # for each sentence, only conver once
95 |
96 | if word_vocab is not None:
97 | self.word_idx_seq = word_vocab.to_index_sequence(self.tokText)
98 |
99 | if char_vocab is not None:
100 | self.char_idx_seq = char_vocab.to_character_matrix(self.tokText, max_char_per_word=max_char_per_word)
101 |
102 | if POS_vocab is not None:
103 | self.POS_idx_seq = POS_vocab.to_index_sequence(self.POSs)
104 |
105 | if NER_vocab is not None:
106 | self.NER_idx_seq = NER_vocab.to_index_sequence(self.NERs)
107 |
108 | self.index_convered = True
109 |
110 | def collect_all_possible_chunks(self, max_chunk_len):
111 | if self.chunk_starts is None:
112 | self.chunk_starts = []
113 | self.chunk_ends = []
114 | for i in xrange(self.length):
115 | cur_word = self.words[i]
116 | if cur_word in ".!?;": continue
117 | for j in xrange(i, i+max_chunk_len):
118 | if j>=self.length: break
119 | cur_word = self.words[j]
120 | if cur_word in ".!?;": break
121 | self.chunk_starts.append(i)
122 | self.chunk_ends.append(j)
123 | return (self.chunk_starts, self.chunk_ends)
124 |
125 | def collect_all_entities(self):
126 | items = re.split("\\s+", self.NERs)
127 | prev_label = "O"
128 | cur_start = -1
129 | chunk_starts = []
130 | chunk_ends = []
131 | for i in xrange(len(items)):
132 | cur_label = items[i]
133 | if cur_label != prev_label:
134 | if cur_start != -1:
135 | chunk_starts.append(cur_start)
136 | chunk_ends.append(i-1)
137 | cur_start = -1
138 | if cur_label != "O":
139 | cur_start = i
140 | prev_label = cur_label
141 | if cur_start !=-1:
142 | chunk_starts.append(cur_start)
143 | chunk_ends.append(len(items)-1)
144 | return (chunk_starts, chunk_ends)
145 |
146 | def collect_all_syntax_chunks(self, max_chunk_len):
147 | if self.chunk_starts is None:
148 | self.chunk_starts = []
149 | self.chunk_ends = []
150 | self.chunk_labels = []
151 | all_spans = re.split("\\s+", self.syntaxSpans)
152 | for i in xrange(len(all_spans)):
153 | cur_span = all_spans[i]
154 | items = re.split("-", cur_span)
155 | cur_start = int(items[0])
156 | cur_end = int(items[1])
157 | cur_label = items[2]
158 | if cur_end-cur_start>=max_chunk_len: continue
159 | self.chunk_starts.append(cur_start)
160 | self.chunk_ends.append(cur_end)
161 | self.chunk_labels.append(cur_label)
162 | return (self.chunk_starts, self.chunk_ends, self.chunk_labels)
163 |
164 | if __name__ == "__main__":
165 | import NP2P_data_stream
166 | inpath = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/data/val.json.tok"
167 | all_instances,_ = NP2P_data_stream.read_all_GenerationDatasets(inpath, isLower=True)
168 | sample_instance = all_instances[0][1]
169 | print('Raw text: {}'.format(sample_instance.rawText))
170 | (chunk_starts, chunk_ends, chunk_labels) = sample_instance.collect_all_syntax_chunks(5)
171 | for i in xrange(len(chunk_starts)):
172 | cur_start = chunk_starts[i]
173 | cur_end = chunk_ends[i]
174 | cur_label = chunk_labels[i]
175 | cur_text = sample_instance.getTokChunk(cur_start, cur_end)
176 | print("{}-{}-{}:{}".format(cur_start, cur_end, cur_label, cur_text))
177 | print("DONE!")
178 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Graph to Sequence Model
2 |
3 | This repository contains the code for our paper [A Graph-to-Sequence Model for AMR-to-Text Generation](https://arxiv.org/abs/1805.02473), published at ACL 2018.
4 |
5 | The code is developed under TensorFlow 1.4.1.
6 | We share our pretrained model along with this repository.
7 | Due to TensorFlow compatibility issues, it may not load under lower versions (such as 1.0.0).
8 |
9 | Please create an issue if you have any questions; this makes them easier to track.
10 |
11 | ## Update logs
12 |
13 | ### Link update for silver AMR data and AMR simplifier (Oct. 13th, 2020)
14 | Since my PhD graduation, my old homepage can no longer be accessed.
15 | I have moved the data and scripts to Google Drive for download.
16 | Pretrained models will be uploaded to Google Drive soon.
17 |
18 | ### Results on WebNLG and LDC2017T10 datasets (Aug. 4th, 2019)
19 | Following the same setting (including data preprocessing) as [Marcheggiani and Perez-Beltrachini (INLG 2018)](https://www.aclweb.org/anthology/W18-6501), our model achieves a BLEU score of *64.2* on the WebNLG dataset.
20 |
21 | On another widely used AMR dataset, LDC2017T10, our model achieves a BLEU score similar to the one it obtains on LDC2015E86.
22 |
23 | ### Be careful about your tokenizer (Feb. 27th, 2019)
24 | We use the [PTB_tokenizer](https://nlp.stanford.edu/software/tokenizer.shtml) from Stanford CoreNLP to preprocess our data. If you plan to use our pretrained model, please be careful about which tokenizer you use.
25 | Also, keep the words cased (do not lowercase) during preprocessing, as the PTB_tokenizer is case-sensitive.
26 |
27 | ### Release of 2M automatically parsed data (link updated on Sep. 24th, 2020)
28 | We release our [2M sentences with their automatically parsed AMRs](https://drive.google.com/file/d/1mmTZFxRdBFOXbH6FCViQ30bVRIvGrwvx/view?usp=sharing) to the public.
29 |
30 | ## About AMR
31 |
32 | AMR is a graph-based semantic formalism that can unify the representations of several sentences sharing the same meaning.
33 | Compared with other structures, such as dependency trees and semantic roles, AMR graphs have several key differences:
34 | * AMRs focus only on concepts and their relations, so no function words are included; the edge labels serve the role of function words instead.
35 | * Inflections are dropped when converting a noun, a verb or a named entity into an AMR concept, and sometimes a synonym is used instead of the original word. This makes AMRs more unified, so each AMR graph can represent more sentences.
36 | * Relation tags (edge labels) are predefined and are not extracted from the text (unlike the way OpenIE does).
37 | More details are on the official [AMR website@ISI](https://amr.isi.edu/download.html), where
38 | you can download the publicly available AMR bank: [little prince](https://amr.isi.edu/download/amr-bank-struct-v1.6.txt).
39 | Try it for fun!
40 |
41 | ## Data preprocessing
42 | The [data loader](./src_g2s/G2S_data_stream.py) of our model requires simplified AMR graphs where variable tags, sense tags and quotes are removed. For example, the following AMR
43 | ```
44 | (d / describe-01
45 | :ARG0 (p / person
46 | :name (n / name
47 | :op1 "Ryan"))
48 | :ARG1 p
49 | :ARG2 genius)
50 | ```
51 | needs to be simplified to
52 | ```
53 | describe :arg0 ( person :name ( name :op1 ryan ) ) :arg1 person :arg2 genius
54 | ```
55 | before being consumed by our model.
56 |
57 |
58 | We provide our scripts for AMR simplification.
59 | First, you need to convert each AMR into a single line; our released [script](./AMR_multiline_to_singleline.py) may serve this purpose (you may need to modify it slightly).
60 | Second, to simplify the single-line AMRs, we release a tool that can be downloaded [here](https://drive.google.com/file/d/1PE8b44-H3Hu4I2Xf1XDAPFhH-GonR43i/view?usp=sharing).
61 | It is adapted from the [NeuralAMR](https://github.com/sinantie/NeuralAmr) system.
62 | To run our simplifier on a file ```demo.amr```, simply execute
63 | ```
64 | ./anonDeAnon_java.sh anonymizeAmrFull true demo.amr
65 | ```
66 | and it will output the simplified AMRs into ```demo.amr.anonymized```.
67 | Please note that our simplifier *does not* do anonymization.
68 | The resulting filename contains the string 'anonymized' only because the original NeuralAMR system creates that suffix.
69 |
70 |
71 | An alternative is to write your own data-loading code according to the format of your own AMR data.
72 |
73 |
74 | ### Input data format
75 | After simplifying your AMRs, you can merge them with the corresponding sentences into a JSON file.
76 | The JSON file is the actual input to the system.
77 | Its format is shown with the following sample:
78 | ```
79 | [{"amr": "describe :arg0 ( person :name ( name :op1 ryan ) ) :arg1 person :arg2 genius",
80 | "sent": "ryan 's description of himself : a genius .",
81 | "id": "demo"}]
82 | ```
83 | In general, the JSON file contains a list of instances, and each instance is a dictionary with the fields "amr", "sent" and "id" (optional). A minimal sketch for producing such a file is shown below.
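
As a rough sketch (the file names are placeholders, and it assumes your simplified AMRs and tokenized sentences sit in two parallel text files aligned line by line), such a JSON file could be produced with:
```
import json

# placeholder file names: one simplified AMR / tokenized sentence per line, aligned by line number
with open('demo.amr.anonymized') as f_amr, open('demo.sent') as f_sent:
    amr_lines = [line.strip() for line in f_amr]
    sent_lines = [line.strip() for line in f_sent]
assert len(amr_lines) == len(sent_lines)

instances = [{'amr': a, 'sent': s, 'id': str(i)}
             for i, (a, s) in enumerate(zip(amr_lines, sent_lines))]

with open('demo.json', 'w') as fp:
    json.dump(instances, fp)
```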
84 |
85 | ### Vocabulary extraction
86 | Once you have the JSON files, you can extract vocabularies with our released scripts in the [./data/](./data/) directory.
87 | We also encourage you to write your own scripts; a rough sketch is shown below.
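
The sketch only counts word and edge-label frequencies from a training JSON file; the file names are placeholders and the exact output format expected by the trainer may differ, so treat it as a starting point only:
```
import json
from collections import Counter

word_cnt, edge_cnt = Counter(), Counter()
with open('data/training.json') as fp:
    for inst in json.load(fp):
        for tok in inst['amr'].split():
            # in the simplified format, edge labels start with ':'
            (edge_cnt if tok.startswith(':') else word_cnt)[tok] += 1
        word_cnt.update(inst['sent'].split())

with open('vocab_words.txt', 'w') as fp:
    for word, cnt in word_cnt.most_common():
        fp.write('{}\t{}\n'.format(word, cnt))
with open('vocab_edges.txt', 'w') as fp:
    for edge, cnt in edge_cnt.most_common():
        fp.write('{}\t{}\n'.format(edge, cnt))
```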
88 |
89 | ## Training
90 |
91 | First, modify the PYTHONPATH within [train_g2s.sh](./train_g2s.sh) (for our graph-to-sequence model) or [train_s2s.sh](./train_s2s.sh) (for the seq-to-seq baseline).
92 | Second, modify config_g2s.json or config_s2s.json accordingly. Pay attention to the field "suffix", which identifies the model being trained and saved. We usually use the experiment setting, such as "bch20_lr1e3_l21e3", as the identifier.
93 | Finally, execute the corresponding script file, such as "./train_g2s.sh".
94 |
95 | ### Using large-scale automatic AMRs
96 |
97 | In this setting, we follow [Konstas et al. (2017)](https://arxiv.org/abs/1704.08381) in taking the large-scale automatic data as the training set and the original gold data as a fine-tuning set.
98 | To train this way, add a new field "finetune_path" to your config file and point it to the gold data. In addition, the original "train_path" should point to the automatic data, as sketched below.
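
For illustration only (the paths are placeholders; the snippet assumes nothing beyond the "train_path" and "finetune_path" fields and leaves all other fields untouched):
```
import json

with open('config_g2s.json') as fp:
    config = json.load(fp)

config['train_path'] = 'data/silver_2m.json'    # large-scale automatic (silver) data -- placeholder path
config['finetune_path'] = 'data/training.json'  # original gold data -- placeholder path

with open('config_g2s.json', 'w') as fp:
    json.dump(config, fp, indent=2)
```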
99 |
100 | For training on the gold data only, we use an initial learning rate of 1e-3 and L2 normalization of 1e-3. We then lower the learning rate to 8e-4, 5e-4 and 2e-4 after a number of epochs.
101 |
102 | For training on both gold and automatic data, the initial learning rate and L2 normalization are 5e-4 and 1e-8, respectively. We also lower the learning rate during training.
103 |
104 | The idea of lowering the learning rate was first introduced by Konstas et al. (2017).
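
We lower the learning rate manually between runs; purely as an illustration (the epoch boundaries below are hypothetical, not the ones we used), such a step schedule looks like:
```
def step_learning_rate(epoch):
    # hypothetical boundaries; in practice we lowered the rate after watching dev-set performance
    if epoch < 10:
        return 1e-3
    elif epoch < 20:
        return 8e-4
    elif epoch < 30:
        return 5e-4
    return 2e-4
```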
105 |
106 |
107 | ## Decoding with a pretrained model
108 |
109 | Simply execute the corresponding decoding script with one argument being the identifier of the model you want to use.
110 | For instance, you can execute "./decode_g2s.sh bch20_lr1e3_l21e3".
111 | Please make sure you use the associated word vectors (not others), because each pretrained model is *optimized* for its word vectors.
112 |
113 | ### Pretrained model
114 |
115 | We release a pretrained model (and word vectors) using gold plus 2M automatically-parsed AMRs [here](https://drive.google.com/file/d/1cP_VrGTGlRPdNOt2rV2wdOBnHrYDB10j/view?usp=sharing). With this model, we observed a BLEU of *33.6*, which is higher than our paper-reported number of 33.0. The pretrained model with only gold data is [here](https://drive.google.com/file/d/186snRVgx5nSLbAbgU9anbu4OMvrCTUP-/view?usp=sharing). It reports a test BLEU score of 23.3.
116 | The corresponding word-embedding file is [here](https://drive.google.com/file/d/1XhCW0eI1PQY51o_gB_MyG1DKsOG0bjIu/view?usp=sharing).
117 |
118 | ## Cite
119 | If you like our paper, please cite
120 | ```
121 | @inproceedings{song2018graph,
122 | title={A Graph-to-Sequence Model for AMR-to-Text Generation},
123 | author={Song, Linfeng and Zhang, Yue and Wang, Zhiguo and Gildea, Daniel},
124 | booktitle={Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
125 | pages={1616--1626},
126 | year={2018}
127 | }
128 | ```
129 |
--------------------------------------------------------------------------------
/src_s2s/NP2P_data_stream.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import numpy as np
4 | import random
5 | import padding_utils
6 | import phrase_lattice_utils
7 |
8 | def read_text_file(text_file):
9 | lines = []
10 | with open(text_file, "rt") as f:
11 | for line in f:
12 | line = line.decode('utf-8')
13 | lines.append(line.strip())
14 | return lines
15 |
16 |
17 | def read_all_GenerationDatasets(inpath, isLower=True):
18 | with open(inpath) as dataset_file:
19 | dataset = json.load(dataset_file, encoding='utf-8')
20 | all_instances = []
21 | max_answer_len = 0
22 | for instance in dataset:
23 | sent1 = instance['amr'].strip()
24 | sent2 = instance['sent'].strip()
25 | id = instance['id'] if 'id' in instance else None
26 | if sent1 == "" or sent2 == "":
27 | continue
28 | max_answer_len = max(max_answer_len, len(sent2.split())) # text2 is the sequence to be generated
29 | all_instances.append((sent1, sent2, id))
30 | return all_instances, max_answer_len
31 |
32 | def read_generation_datasets_from_fof(fofpath, isLower=True):
33 | all_paths = read_text_file(fofpath)
34 | all_instances = []
35 | max_answer_len = 0
36 | for cur_path in all_paths:
37 | print(cur_path)
38 | (cur_instances, cur_max_answer_len) = read_all_GenerationDatasets(cur_path, isLower=isLower)
39 | print("cur_max_answer_len: %s" % cur_max_answer_len)
40 | all_instances.extend(cur_instances)
41 | if max_answer_len < cur_max_answer_len: max_answer_len = cur_max_answer_len
67 | if oov_rate1 > 0.2 or oov_rate2 > 0.2:
68 | print('!!!!!oov_rate for ENC {} and DEC {}'.format(oov_rate1, oov_rate2))
69 | print(sent1)
70 | print(sent2)
71 | print('==============')
72 | if options.max_passage_len != -1: sent1_idx = sent1_idx[:options.max_passage_len]
73 | if options.max_answer_len != -1: sent2_idx = sent2_idx[:options.max_answer_len]
74 | instances.append((sent1_idx, sent2_idx, sent1, sent2, id))
75 |
76 | all_questions = instances
77 | instances = None
78 |
79 | # sort instances based on length
80 | if isSort:
81 | all_questions = sorted(all_questions, key=lambda xxx: (len(xxx[0]), len(xxx[1])))
82 | else:
83 | pass
84 | self.num_instances = len(all_questions)
85 |
86 | # distribute questions into different buckets
87 | batch_spans = padding_utils.make_batches(self.num_instances, batch_size)
88 | self.batches = []
89 | for batch_index, (batch_start, batch_end) in enumerate(batch_spans):
90 | cur_questions = []
91 | for i in xrange(batch_start, batch_end):
92 | cur_questions.append(all_questions[i])
93 | cur_batch = Batch(cur_questions, options, word_vocab=dec_word_vocab, char_vocab=char_vocab)
94 | self.batches.append(cur_batch)
95 |
96 | self.num_batch = len(self.batches)
97 | self.index_array = np.arange(self.num_batch)
98 | self.isShuffle = isShuffle
99 | if self.isShuffle: np.random.shuffle(self.index_array)
100 | self.isLoop = isLoop
101 | self.cur_pointer = 0
102 |
103 | def nextBatch(self):
104 | if self.cur_pointer>=self.num_batch:
105 | if not self.isLoop: return None
106 | self.cur_pointer = 0
107 | if self.isShuffle: np.random.shuffle(self.index_array)
108 | cur_batch = self.batches[self.index_array[self.cur_pointer]]
109 | self.cur_pointer += 1
110 | return cur_batch
111 |
112 | def reset(self):
113 | if self.isShuffle: np.random.shuffle(self.index_array)
114 | self.cur_pointer = 0
115 |
116 | def get_num_batch(self):
117 | return self.num_batch
118 |
119 | def get_num_instance(self):
120 | return self.num_instances
121 |
122 | def get_batch(self, i):
123 | if i>= self.num_batch: return None
124 | return self.batches[i]
125 |
126 | class Batch(object):
127 | def __init__(self, instances, options, word_vocab=None, char_vocab=None, POS_vocab=None, NER_vocab=None):
128 | self.options = options
129 |
130 | self.batch_size = len(instances)
131 | self.vocab = word_vocab
132 |
133 | self.id = [inst[-1] for inst in instances]
134 | self.source = [inst[-3] for inst in instances]
135 | self.target_ref = [inst[-2] for inst in instances]
136 |
137 | # create length
138 | self.sent1_length = [] # [batch_size]
139 | self.sent2_length = [] # [batch_size]
140 | for (sent1_idx, sent2_idx, _, _, _) in instances:
141 | self.sent1_length.append(len(sent1_idx))
142 | self.sent2_length.append(min(len(sent2_idx)+1, options.max_answer_len))
143 | self.sent1_length = np.array(self.sent1_length, dtype=np.int32)
144 | self.sent2_length = np.array(self.sent2_length, dtype=np.int32)
145 |
146 | # create word representation
147 | start_id = word_vocab.getIndex('')
148 | end_id = word_vocab.getIndex('')
149 | if options.with_word:
150 | self.sent1_word = [] # [batch_size, sent1_len]
151 | self.sent2_word = [] # [batch_size, sent2_len]
152 | self.sent2_input_word = []
153 | for (sent1_idx, sent2_idx, _, _, _) in instances:
154 | self.sent1_word.append(sent1_idx)
155 | self.sent2_word.append(sent2_idx+[end_id])
156 | self.sent2_input_word.append([start_id]+sent2_idx)
157 | self.sent1_word = padding_utils.pad_2d_vals(self.sent1_word, len(instances), np.max(self.sent1_length))
158 | self.sent2_word = padding_utils.pad_2d_vals(self.sent2_word, len(instances), options.max_answer_len)
159 | self.sent2_input_word = padding_utils.pad_2d_vals(self.sent2_input_word, len(instances), options.max_answer_len)
160 |
161 | self.in_answer_words = self.sent2_word
162 | self.gen_input_words = self.sent2_input_word
163 | self.answer_lengths = self.sent2_length
164 |
165 | if options.with_char:
166 | self.sent1_char = [] # [batch_size, sent1_len]
167 | self.sent1_char_lengths = []
168 | for (_, _, sent1, sent2, _) in instances:
169 | sent1_char_idx = char_vocab.to_character_matrix_for_list(sent1.split()[:options.max_passage_len])
170 | self.sent1_char.append(sent1_char_idx)
171 | self.sent1_char_lengths.append([len(x) for x in sent1_char_idx])
172 | self.sent1_char = padding_utils.pad_3d_vals_no_size(self.sent1_char)
173 | self.sent1_char_lengths = padding_utils.pad_2d_vals_no_size(self.sent1_char_lengths)
174 |
175 |
176 | if __name__ == "__main__":
177 | all_instances, _ = read_all_GenerationDatasets('./data/training.json', True)
178 | print(1.0*sum(1 for sent1, sent2, _ in all_instances if len(sent1.split()) > 200))
179 | print(1.0*sum(1 for sent1, sent2, _ in all_instances if len(sent2.split()) > 100))
180 | print('DONE!')
181 | all_instances, _ = read_all_GenerationDatasets('./data/test.json', True)
182 | print(1.0*sum(1 for sent1, sent2, _ in all_instances if len(sent1.split()) > 200))
183 | print(1.0*sum(1 for sent1, sent2, _ in all_instances if len(sent2.split()) > 100))
184 | print('DONE!')
185 |
--------------------------------------------------------------------------------
/src_s2s/phrase_lattice_utils.py:
--------------------------------------------------------------------------------
1 | from graphviz import Digraph
2 | import numpy as np
3 |
4 | class prefix_tree_node(object):
5 | def __init__(self, node_id):
6 | self.node_id = node_id
7 | self.phrase_id = -1# phrase_id==-1 for all intermediate nodes
8 | self.children = {}
9 | self.parent = None
10 |
11 | def add_child(self, word, next_node):
12 | self.children[word] = next_node
13 | next_node.parent = (word, self)
14 |
15 | def set_phrase_id(self, phrase_id):
16 | self.phrase_id = phrase_id
17 |
18 | def find_child(self, word):
19 | if self.children.has_key(word):
20 | return self.children.get(word)
21 | else:
22 | return None
23 |
24 | class prefix_tree(object):
25 | def __init__(self, phrase2id):
26 | self.root_node = prefix_tree_node(0)
27 | self.all_nodes = [self.root_node]
28 | self.phrase_id_node = {}
29 |
30 | for phrase,phrase_id in dict.iteritems(phrase2id):
31 | words = phrase.split()
32 | if len(words)<=1: continue
33 | cur_node = self.root_node
34 | for word in words:
35 | next_node = cur_node.find_child(word)
36 | if next_node == None:
37 | next_node = prefix_tree_node(len(self.all_nodes))
38 | cur_node.add_child(word, next_node)
39 | self.all_nodes.append(next_node)
40 | cur_node = next_node
41 | cur_node.set_phrase_id(phrase_id)
42 | self.phrase_id_node[phrase_id] = cur_node
43 |
44 | def get_phrase_id(self, phrase):
45 | words = phrase.split()
46 | cur_node = self.root_node
47 | for word in words:
48 | cur_node = cur_node.find_child(word)
49 | if cur_node is None: return None
50 | return cur_node.phrase_id
51 |
52 | def get_phrase(self, phrase_id):
53 | if not self.phrase_id_node.has_key(phrase_id): return None
54 | cur_node = self.phrase_id_node[phrase_id]
55 | words = []
56 | while cur_node:
57 | if cur_node.parent is None: break
58 | (cur_word, cur_parent) = cur_node.parent
59 | words.insert(0, cur_word)
60 | cur_node = cur_parent
61 | return " ".join(words)
62 |
63 | def has_phrase_id(self, phrase_id):
64 | return self.phrase_id_node.has_key(phrase_id)
65 |
66 |
67 | def init_bak(self, phrase_id):
68 | self.root_node = prefix_tree_node(0)
69 | self.all_nodes = [self.root_node]
70 |
71 | for phrase, phrase_id in phrase_id.iteritems():
72 | words = phrase.split()
73 | cur_node = self.root_node
74 | for word in words:
75 | next_node = cur_node.find_child(word)
76 | if next_node == None:
77 | next_node = prefix_tree_node(len(self.all_nodes))
78 | cur_node.add_child(word, next_node)
79 | self.all_nodes.append(next_node)
80 | cur_node = next_node
81 | cur_node.set_phrase_id(phrase_id)
82 |
83 | def __str__(self):
84 | dot = Digraph(name='prefix_tree')
85 | for cur_node in self.all_nodes:
86 | dot.node(str(cur_node.node_id), str(cur_node.phrase_id))
87 | for edge in cur_node.children.keys():
88 | cur_edge_node = cur_node.children[edge]
89 | dot.edge(str(cur_node.node_id), str(cur_edge_node.node_id), edge)
90 |
91 | dot.body.append(r'label = "prefix tree"')
92 | return dot.source
93 |
94 |
95 | class lattice_node(object):
96 | def __init__(self, node_id):
97 | self.node_id = node_id
98 | self.out_deges = []
99 |
100 | def add_edge(self, edge):
101 | self.out_deges.append(edge)
102 |
103 | def get_edge(self, i):
104 | if len(self.out_deges)<=i: return None
157 | if cur_prefix_node.phrase_id>-1 and cur_tail_node.node_id-cur_start_node.node_id>1: # one phrase is match
158 | # add one edge
159 | cur_phrase_id = cur_prefix_node.phrase_id
160 | # cur_phrase = id2phrase[cur_phrase_id]
161 | cur_phrase = prefix_tree.get_phrase(cur_phrase_id)
162 | new_edge = lattice_edge(cur_phrase, cur_phrase_id, cur_tail_node)
163 | cur_start_node.add_edge(new_edge)
164 | cur_node = cur_tail_node
165 |
166 | def sample_a_partition(self, max_matching=False):
167 | phrases = []
168 | phrase_ids = []
169 | cur_node = self.start_node
170 | while cur_node:
171 | all_edges = cur_node.out_deges
172 | edge_size = len(all_edges)
173 | if edge_size<=0: break
174 | if max_matching:
175 | max_len = -1
176 | sampled_idx = 0
177 | for i, cur_edge in enumerate(all_edges):
178 | cur_length = len(cur_edge.phrase.split())
179 | if max_len<cur_length:
207 | if i+j>=len(words): break
208 | cur_phrase = " ".join(words[i:j+1])
209 | if phrase2id.has_key(cur_phrase): continue
210 | cur_index = len(phrase2id)
211 | phrase2id[cur_phrase] = cur_index
212 | id2phrase[cur_index] = cur_phrase
213 | return (phrase2id, id2phrase)
214 |
215 |
216 | if __name__ == "__main__":
217 | ''' # test create prefix tree
218 | phrase_id = {"a": 0, "a b f": 1, "a b c":2, "a b d":3, "b":4, "b d":5, "b e":6}
219 | tree = prefix_tree(phrase_id)
220 | print(tree)
221 | #'''
222 |
223 | ''' # test creating lattice
224 | sentence = "a b cc dd ee f g hhh iiiii jj kk lm"
225 | toks = sentence.split()
226 | lattice = phrase_lattice(toks)
227 | print(lattice)
228 | #'''
229 |
230 | src_sentence = "what is the significance of the periodic table ?"
231 | tgt_sentence = "what is a periodic table ?"
232 |
233 | # collect phrases from src_setence
234 | max_chunk_len = 4
235 | (phrase2id, id2phrase) = collect_all_possible_phrases(src_sentence, max_chunk_len=max_chunk_len)
236 |
237 | # create prefix tree
238 | tree = prefix_tree(phrase2id)
239 | # print(tree)
240 | #'''
241 | for phrase in phrase2id.keys():
242 | phrase_id = tree.get_phrase_id(phrase)
243 | cur_phrase = tree.get_phrase(phrase_id)
244 | print(phrase)
245 | print(phrase_id)
246 | print(cur_phrase)
247 | print()
248 | #'''
249 |
250 | '''
251 | # create lattice for the target sentence
252 | lattice = phrase_lattice(tgt_sentence.split(), word_vocab=None, prefix_tree=tree)
253 | # print(lattice)
254 |
255 | # sample partitions
256 | (phrases, phrase_ids) = lattice.sample_a_partition()
257 | print(phrases)
258 | print(phrase_ids)
259 |
260 | (phrases, phrase_ids) = lattice.sample_a_partition()
261 | print(phrases)
262 | print(phrase_ids)
263 |
264 | (phrases, phrase_ids) = lattice.sample_a_partition()
265 | print(phrases)
266 | print(phrase_ids)
267 |
268 | (phrases, phrase_ids) = lattice.sample_a_partition(max_matching=True)
269 | print(phrases)
270 | print(phrase_ids)
271 | #'''
--------------------------------------------------------------------------------
/src_g2s/metric_bleu_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # bleu_scorer.py
4 | # David Chiang
5 |
6 | # Copyright (c) 2004-2006 University of Maryland. All rights
7 | # reserved. Do not redistribute without permission from the
8 | # author. Not for commercial use.
9 |
10 | # Modified by:
11 | # Hao Fang
12 | # Tsung-Yi Lin
13 |
14 | '''Provides:
15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17 | '''
18 |
19 | import copy
20 | import sys, math, re
21 | from collections import defaultdict
22 |
23 | def precook(s, n=4, out=False):
24 | """Takes a string as input and returns an object that can be given to
25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
26 | can take string arguments as well."""
27 | words = s.split()
28 | counts = defaultdict(int)
29 | for k in xrange(1,n+1):
30 | for i in xrange(len(words)-k+1):
31 | ngram = tuple(words[i:i+k])
32 | counts[ngram] += 1
33 | return (len(words), counts)
34 |
35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
36 | '''Takes a list of reference sentences for a single segment
37 | and returns an object that encapsulates everything that BLEU
38 | needs to know about them.'''
39 |
40 | reflen = []
41 | maxcounts = {}
42 | for ref in refs:
43 | rl, counts = precook(ref, n)
44 | reflen.append(rl)
45 | for (ngram,count) in counts.iteritems():
46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
47 |
48 | # Calculate effective reference sentence length.
49 | if eff == "shortest":
50 | reflen = min(reflen)
51 | elif eff == "average":
52 | reflen = float(sum(reflen))/len(reflen)
53 |
54 | ## lhuang: N.B.: leave reflen computaiton to the very end!!
55 |
56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
57 |
58 | return (reflen, maxcounts)
59 |
60 | def cook_test(test, (reflen, refmaxcounts), eff=None, n=4):
61 | '''Takes a test sentence and returns an object that
62 | encapsulates everything that BLEU needs to know about it.'''
63 |
64 | testlen, counts = precook(test, n, True)
65 |
66 | result = {}
67 |
68 | # Calculate effective reference sentence length.
69 |
70 | if eff == "closest":
71 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
72 | else: ## i.e., "average" or "shortest" or None
73 | result["reflen"] = reflen
74 |
75 | result["testlen"] = testlen
76 |
77 | result["guess"] = [max(0,testlen-k+1) for k in xrange(1,n+1)]
78 |
79 | result['correct'] = [0]*n
80 | for (ngram, count) in counts.iteritems():
81 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
82 |
83 | return result
84 |
85 | class BleuScorer(object):
86 | """Bleu scorer.
87 | """
88 |
89 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
90 | # special_reflen is used in oracle (proportional effective ref len for a node).
91 |
92 | def copy(self):
93 | ''' copy the refs.'''
94 | new = BleuScorer(n=self.n)
95 | new.ctest = copy.copy(self.ctest)
96 | new.crefs = copy.copy(self.crefs)
97 | new._score = None
98 | return new
99 |
100 | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
101 | ''' singular instance '''
102 |
103 | self.n = n
104 | self.crefs = []
105 | self.ctest = []
106 | self.cook_append(test, refs)
107 | self.special_reflen = special_reflen
108 |
109 | def cook_append(self, test, refs):
110 | '''called by constructor and __iadd__ to avoid creating new instances.'''
111 |
112 | if refs is not None:
113 | self.crefs.append(cook_refs(refs))
114 | if test is not None:
115 | cooked_test = cook_test(test, self.crefs[-1])
116 | self.ctest.append(cooked_test) ## N.B.: -1
117 | else:
118 | self.ctest.append(None) # lens of crefs and ctest have to match
119 |
120 | self._score = None ## need to recompute
121 |
122 | def ratio(self, option=None):
123 | self.compute_score(option=option)
124 | return self._ratio
125 |
126 | def score_ratio(self, option=None):
127 | '''return (bleu, len_ratio) pair'''
128 | return (self.fscore(option=option), self.ratio(option=option))
129 |
130 | def score_ratio_str(self, option=None):
131 | return "%.4f (%.2f)" % self.score_ratio(option)
132 |
133 | def reflen(self, option=None):
134 | self.compute_score(option=option)
135 | return self._reflen
136 |
137 | def testlen(self, option=None):
138 | self.compute_score(option=option)
139 | return self._testlen
140 |
141 | def retest(self, new_test):
142 | if type(new_test) is str:
143 | new_test = [new_test]
144 | assert len(new_test) == len(self.crefs), new_test
145 | self.ctest = []
146 | for t, rs in zip(new_test, self.crefs):
147 | self.ctest.append(cook_test(t, rs))
148 | self._score = None
149 |
150 | return self
151 |
152 | def rescore(self, new_test):
153 | ''' replace test(s) with new test(s), and returns the new score.'''
154 |
155 | return self.retest(new_test).compute_score()
156 |
157 | def size(self):
158 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
159 | return len(self.crefs)
160 |
161 | def __iadd__(self, other):
162 | '''add an instance (e.g., from another sentence).'''
163 |
164 | if type(other) is tuple:
165 | ## avoid creating new BleuScorer instances
166 | self.cook_append(other[0], other[1])
167 | else:
168 | assert self.compatible(other), "incompatible BLEUs."
169 | self.ctest.extend(other.ctest)
170 | self.crefs.extend(other.crefs)
171 | self._score = None ## need to recompute
172 |
173 | return self
174 |
175 | def compatible(self, other):
176 | return isinstance(other, BleuScorer) and self.n == other.n
177 |
178 | def single_reflen(self, option="average"):
179 | return self._single_reflen(self.crefs[0][0], option)
180 |
181 | def _single_reflen(self, reflens, option=None, testlen=None):
182 |
183 | if option == "shortest":
184 | reflen = min(reflens)
185 | elif option == "average":
186 | reflen = float(sum(reflens))/len(reflens)
187 | elif option == "closest":
188 | reflen = min((abs(l-testlen), l) for l in reflens)[1]
189 | else:
190 | assert False, "unsupported reflen option %s" % option
191 |
192 | return reflen
193 |
194 | def recompute_score(self, option=None, verbose=0):
195 | self._score = None
196 | return self.compute_score(option, verbose)
197 |
198 | def compute_score(self, option=None, verbose=0):
199 | n = self.n
200 | small = 1e-9
201 | tiny = 1e-15 ## so that if guess is 0 still return 0
202 | bleu_list = [[] for _ in range(n)]
203 |
204 | if self._score is not None:
205 | return self._score
206 |
207 | if option is None:
208 | option = "average" if len(self.crefs) == 1 else "closest"
209 |
210 | self._testlen = 0
211 | self._reflen = 0
212 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}
213 |
214 | # for each sentence
215 | for comps in self.ctest:
216 | testlen = comps['testlen']
217 | self._testlen += testlen
218 |
219 | if self.special_reflen is None: ## need computation
220 | reflen = self._single_reflen(comps['reflen'], option, testlen)
221 | else:
222 | reflen = self.special_reflen
223 |
224 | self._reflen += reflen
225 |
226 | for key in ['guess','correct']:
227 | for k in xrange(n):
228 | totalcomps[key][k] += comps[key][k]
229 |
230 | # append per image bleu score
231 | bleu = 1.
232 | for k in xrange(n):
233 | bleu *= (float(comps['correct'][k]) + tiny) \
234 | /(float(comps['guess'][k]) + small)
235 | bleu_list[k].append(bleu ** (1./(k+1)))
236 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
237 | if ratio < 1:
238 | for k in xrange(n):
239 | bleu_list[k][-1] *= math.exp(1 - 1/ratio)
240 |
241 | if verbose > 1:
242 | print comps, reflen
243 |
244 | totalcomps['reflen'] = self._reflen
245 | totalcomps['testlen'] = self._testlen
246 |
247 | bleus = []
248 | bleu = 1.
249 | for k in xrange(n):
250 | bleu *= float(totalcomps['correct'][k] + tiny) \
251 | / (totalcomps['guess'][k] + small)
252 | bleus.append(bleu ** (1./(k+1)))
253 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
254 | if ratio < 1:
255 | for k in xrange(n):
256 | bleus[k] *= math.exp(1 - 1/ratio)
257 |
258 | if verbose > 0:
259 | print totalcomps
260 | print "ratio:", ratio
261 |
262 | self._score = bleus
263 | return self._score, bleu_list
264 |
265 | class Bleu:
266 | def __init__(self, n=4):
267 | # default compute Blue score up to 4
268 | self._n = n
269 | self._hypo_for_image = {}
270 | self.ref_for_image = {}
271 |
272 | def compute_score(self, gts, res):
273 |
274 | assert(gts.keys() == res.keys())
275 | imgIds = gts.keys()
276 |
277 | bleu_scorer = BleuScorer(n=self._n)
278 | for id in imgIds:
279 | hypo = res[id]
280 | ref = gts[id]
281 |
282 | # Sanity check.
283 | assert(type(hypo) is list)
284 | assert(len(hypo) == 1)
285 | assert(type(ref) is list)
286 | assert(len(ref) >= 1)
287 |
288 | bleu_scorer += (hypo[0], ref)
289 |
290 | #score, scores = bleu_scorer.compute_score(option='shortest')
291 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
292 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1)
293 |
294 | # return (bleu, bleu_info)
295 | return score, scores
296 |
297 | def method(self):
298 | return "Bleu"
299 |
--------------------------------------------------------------------------------
/src_s2s/metric_bleu_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # bleu_scorer.py
4 | # David Chiang
5 |
6 | # Copyright (c) 2004-2006 University of Maryland. All rights
7 | # reserved. Do not redistribute without permission from the
8 | # author. Not for commercial use.
9 |
10 | # Modified by:
11 | # Hao Fang
12 | # Tsung-Yi Lin
13 |
14 | '''Provides:
15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17 | '''
18 |
19 | import copy
20 | import sys, math, re
21 | from collections import defaultdict
22 |
23 | def precook(s, n=4, out=False):
24 | """Takes a string as input and returns an object that can be given to
25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
26 | can take string arguments as well."""
27 | words = s.split()
28 | counts = defaultdict(int)
29 | for k in xrange(1,n+1):
30 | for i in xrange(len(words)-k+1):
31 | ngram = tuple(words[i:i+k])
32 | counts[ngram] += 1
33 | return (len(words), counts)
34 |
35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
36 | '''Takes a list of reference sentences for a single segment
37 | and returns an object that encapsulates everything that BLEU
38 | needs to know about them.'''
39 |
40 | reflen = []
41 | maxcounts = {}
42 | for ref in refs:
43 | rl, counts = precook(ref, n)
44 | reflen.append(rl)
45 | for (ngram,count) in counts.iteritems():
46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
47 |
48 | # Calculate effective reference sentence length.
49 | if eff == "shortest":
50 | reflen = min(reflen)
51 | elif eff == "average":
52 | reflen = float(sum(reflen))/len(reflen)
53 |
54 | ## lhuang: N.B.: leave reflen computaiton to the very end!!
55 |
56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
57 |
58 | return (reflen, maxcounts)
59 |
60 | def cook_test(test, (reflen, refmaxcounts), eff=None, n=4):
61 | '''Takes a test sentence and returns an object that
62 | encapsulates everything that BLEU needs to know about it.'''
63 |
64 | testlen, counts = precook(test, n, True)
65 |
66 | result = {}
67 |
68 | # Calculate effective reference sentence length.
69 |
70 | if eff == "closest":
71 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
72 | else: ## i.e., "average" or "shortest" or None
73 | result["reflen"] = reflen
74 |
75 | result["testlen"] = testlen
76 |
77 | result["guess"] = [max(0,testlen-k+1) for k in xrange(1,n+1)]
78 |
79 | result['correct'] = [0]*n
80 | for (ngram, count) in counts.iteritems():
81 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
82 |
83 | return result
84 |
85 | class BleuScorer(object):
86 | """Bleu scorer.
87 | """
88 |
89 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
90 | # special_reflen is used in oracle (proportional effective ref len for a node).
91 |
92 | def copy(self):
93 | ''' copy the refs.'''
94 | new = BleuScorer(n=self.n)
95 | new.ctest = copy.copy(self.ctest)
96 | new.crefs = copy.copy(self.crefs)
97 | new._score = None
98 | return new
99 |
100 | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
101 | ''' singular instance '''
102 |
103 | self.n = n
104 | self.crefs = []
105 | self.ctest = []
106 | self.cook_append(test, refs)
107 | self.special_reflen = special_reflen
108 |
109 | def cook_append(self, test, refs):
110 | '''called by constructor and __iadd__ to avoid creating new instances.'''
111 |
112 | if refs is not None:
113 | self.crefs.append(cook_refs(refs))
114 | if test is not None:
115 | cooked_test = cook_test(test, self.crefs[-1])
116 | self.ctest.append(cooked_test) ## N.B.: -1
117 | else:
118 | self.ctest.append(None) # lens of crefs and ctest have to match
119 |
120 | self._score = None ## need to recompute
121 |
122 | def ratio(self, option=None):
123 | self.compute_score(option=option)
124 | return self._ratio
125 |
126 | def score_ratio(self, option=None):
127 | '''return (bleu, len_ratio) pair'''
128 | return (self.fscore(option=option), self.ratio(option=option))
129 |
130 | def score_ratio_str(self, option=None):
131 | return "%.4f (%.2f)" % self.score_ratio(option)
132 |
133 | def reflen(self, option=None):
134 | self.compute_score(option=option)
135 | return self._reflen
136 |
137 | def testlen(self, option=None):
138 | self.compute_score(option=option)
139 | return self._testlen
140 |
141 | def retest(self, new_test):
142 | if type(new_test) is str:
143 | new_test = [new_test]
144 | assert len(new_test) == len(self.crefs), new_test
145 | self.ctest = []
146 | for t, rs in zip(new_test, self.crefs):
147 | self.ctest.append(cook_test(t, rs))
148 | self._score = None
149 |
150 | return self
151 |
152 | def rescore(self, new_test):
153 | ''' replace test(s) with new test(s), and returns the new score.'''
154 |
155 | return self.retest(new_test).compute_score()
156 |
157 | def size(self):
158 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
159 | return len(self.crefs)
160 |
161 | def __iadd__(self, other):
162 | '''add an instance (e.g., from another sentence).'''
163 |
164 | if type(other) is tuple:
165 | ## avoid creating new BleuScorer instances
166 | self.cook_append(other[0], other[1])
167 | else:
168 | assert self.compatible(other), "incompatible BLEUs."
169 | self.ctest.extend(other.ctest)
170 | self.crefs.extend(other.crefs)
171 | self._score = None ## need to recompute
172 |
173 | return self
174 |
175 | def compatible(self, other):
176 | return isinstance(other, BleuScorer) and self.n == other.n
177 |
178 | def single_reflen(self, option="average"):
179 | return self._single_reflen(self.crefs[0][0], option)
180 |
181 | def _single_reflen(self, reflens, option=None, testlen=None):
182 |
183 | if option == "shortest":
184 | reflen = min(reflens)
185 | elif option == "average":
186 | reflen = float(sum(reflens))/len(reflens)
187 | elif option == "closest":
188 | reflen = min((abs(l-testlen), l) for l in reflens)[1]
189 | else:
190 | assert False, "unsupported reflen option %s" % option
191 |
192 | return reflen
193 |
194 | def recompute_score(self, option=None, verbose=0):
195 | self._score = None
196 | return self.compute_score(option, verbose)
197 |
198 | def compute_score(self, option=None, verbose=0):
199 | n = self.n
200 | small = 1e-9
201 | tiny = 1e-15 ## so that if guess is 0 still return 0
202 | bleu_list = [[] for _ in range(n)]
203 |
204 | if self._score is not None:
205 | return self._score
206 |
207 | if option is None:
208 | option = "average" if len(self.crefs) == 1 else "closest"
209 |
210 | self._testlen = 0
211 | self._reflen = 0
212 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}
213 |
214 | # for each sentence
215 | for comps in self.ctest:
216 | testlen = comps['testlen']
217 | self._testlen += testlen
218 |
219 | if self.special_reflen is None: ## need computation
220 | reflen = self._single_reflen(comps['reflen'], option, testlen)
221 | else:
222 | reflen = self.special_reflen
223 |
224 | self._reflen += reflen
225 |
226 | for key in ['guess','correct']:
227 | for k in xrange(n):
228 | totalcomps[key][k] += comps[key][k]
229 |
230 | # append per image bleu score
231 | bleu = 1.
232 | for k in xrange(n):
233 | bleu *= (float(comps['correct'][k]) + tiny) \
234 | /(float(comps['guess'][k]) + small)
235 | bleu_list[k].append(bleu ** (1./(k+1)))
236 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
237 | if ratio < 1:
238 | for k in xrange(n):
239 | bleu_list[k][-1] *= math.exp(1 - 1/ratio)
240 |
241 | if verbose > 1:
242 | print comps, reflen
243 |
244 | totalcomps['reflen'] = self._reflen
245 | totalcomps['testlen'] = self._testlen
246 |
247 | bleus = []
248 | bleu = 1.
249 | for k in xrange(n):
250 | bleu *= float(totalcomps['correct'][k] + tiny) \
251 | / (totalcomps['guess'][k] + small)
252 | bleus.append(bleu ** (1./(k+1)))
253 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
254 | if ratio < 1:
255 | for k in xrange(n):
256 | bleus[k] *= math.exp(1 - 1/ratio)
257 |
258 | if verbose > 0:
259 | print totalcomps
260 | print "ratio:", ratio
261 |
262 | self._score = bleus
263 | return self._score, bleu_list
264 |
265 | class Bleu:
266 | def __init__(self, n=4):
267 | # default compute Blue score up to 4
268 | self._n = n
269 | self._hypo_for_image = {}
270 | self.ref_for_image = {}
271 |
272 | def compute_score(self, gts, res):
273 |
274 | assert(gts.keys() == res.keys())
275 | imgIds = gts.keys()
276 |
277 | bleu_scorer = BleuScorer(n=self._n)
278 | for id in imgIds:
279 | hypo = res[id]
280 | ref = gts[id]
281 |
282 | # Sanity check.
283 | assert(type(hypo) is list)
284 | assert(len(hypo) == 1)
285 | assert(type(ref) is list)
286 | assert(len(ref) >= 1)
287 |
288 | bleu_scorer += (hypo[0], ref)
289 |
290 | #score, scores = bleu_scorer.compute_score(option='shortest')
291 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
292 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1)
293 |
294 | # return (bleu, bleu_info)
295 | return score, scores
296 |
297 | def method(self):
298 | return "Bleu"
299 |
--------------------------------------------------------------------------------
/src_s2s/encoder_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import match_utils
3 |
4 | def collect_final_step_lstm(lstm_rep, lens):
5 | lens = tf.maximum(lens, tf.zeros_like(lens, dtype=tf.int32)) # [batch,]
6 | idxs = tf.range(0, limit=tf.shape(lens)[0]) # [batch,]
7 | indices = tf.stack((idxs,lens,), axis=1) # [batch_size, 2]
8 | return tf.gather_nd(lstm_rep, indices, name='lstm-forward-last')
9 |
10 | class SeqEncoder(object):
11 | def __init__(self, placeholders, options, word_vocab=None, char_vocab=None, POS_vocab=None, NER_vocab=None):
12 |
13 | self.options = options
14 |
15 | self.word_vocab = word_vocab
16 | self.char_vocab = char_vocab
17 | self.POS_vocab = POS_vocab
18 | self.NER_vocab = NER_vocab
19 |
20 | self.passage_lengths = placeholders.passage_lengths #tf.placeholder(tf.int32, [None])
21 | if options.with_word:
22 | self.in_passage_words = placeholders.in_passage_words #tf.placeholder(tf.int32, [None, None]) # [batch_size, passage_len]
23 |
24 | if options.with_char:
25 | self.passage_char_lengths = placeholders.passage_char_lengths
26 | #tf.placeholder(tf.int32, [None,None]) # [batch_size, passage_len]
27 | self.in_passage_chars = placeholders.in_passage_chars
28 | #tf.placeholder(tf.int32, [None, None, None]) # [batch_size, passage_len, p_char_len]
29 |
30 | if options.with_POS:
31 | self.in_passage_POSs = placeholders.in_passage_POSs #tf.placeholder(tf.int32, [None, None]) # [batch_size, passage_len]
32 |
33 | if options.with_NER:
34 | self.in_passage_NERs = placeholders.in_passage_NERs #tf.placeholder(tf.int32, [None, None]) # [batch_size, passage_len]
35 |
36 | def encode(self, is_training=True):
37 | options = self.options
38 |
39 | # ======word representation layer======
40 | in_passage_repres = []
41 | input_dim = 0
42 | if options.with_word and self.word_vocab is not None:
43 | word_vec_trainable = True
44 | cur_device = '/gpu:0'
45 | if options.fix_word_vec:
46 | word_vec_trainable = False
47 | cur_device = '/cpu:0'
48 | with tf.variable_scope("embedding"), tf.device(cur_device):
49 | self.word_embedding = tf.get_variable("word_embedding", trainable=word_vec_trainable,
50 | initializer=tf.constant(self.word_vocab.word_vecs), dtype=tf.float32)
51 |
52 | in_passage_word_repres = tf.nn.embedding_lookup(self.word_embedding, self.in_passage_words)
53 | # [batch_size, passage_len, word_dim]
54 | in_passage_repres.append(in_passage_word_repres)
55 |
56 | input_shape = tf.shape(self.in_passage_words)
57 | batch_size = input_shape[0]
58 | passage_len = input_shape[1]
59 | input_dim += self.word_vocab.word_dim
60 |
61 | if options.with_char and self.char_vocab is not None:
62 | input_shape = tf.shape(self.in_passage_chars)
63 | batch_size = input_shape[0]
64 | passage_len = input_shape[1]
65 | p_char_len = input_shape[2]
66 | char_dim = self.char_vocab.word_dim
67 | self.char_embedding = tf.get_variable("char_embedding",
68 | initializer=tf.constant(self.char_vocab.word_vecs), dtype=tf.float32)
69 | in_passage_char_repres = tf.nn.embedding_lookup(self.char_embedding,
70 | self.in_passage_chars) # [batch_size, passage_len, p_char_len, char_dim]
71 | in_passage_char_repres = tf.reshape(in_passage_char_repres, shape=[-1, p_char_len, char_dim])
72 | passage_char_lengths = tf.reshape(self.passage_char_lengths, [-1])
73 | with tf.variable_scope('char_lstm'):
74 | # lstm cell
75 | char_lstm_cell = tf.contrib.rnn.BasicLSTMCell(options.char_lstm_dim)
76 | # dropout
77 | if is_training: char_lstm_cell = tf.contrib.rnn.DropoutWrapper(char_lstm_cell,
78 | output_keep_prob=(1 - options.dropout_rate))
79 | char_lstm_cell = tf.contrib.rnn.MultiRNNCell([char_lstm_cell])
80 | # passage representation
81 | passage_char_outputs = tf.nn.dynamic_rnn(char_lstm_cell, in_passage_char_repres,
82 | sequence_length=passage_char_lengths,dtype=tf.float32)[0]
83 | # [batch_size*question_len, q_char_len, char_lstm_dim]
84 | passage_char_outputs = collect_final_step_lstm(passage_char_outputs, passage_char_lengths-1)
85 | passage_char_outputs = tf.reshape(passage_char_outputs, [batch_size, passage_len, options.char_lstm_dim])
86 |
87 | in_passage_repres.append(passage_char_outputs)
88 | input_dim += options.char_lstm_dim
89 |
90 | in_passage_repres = tf.concat(in_passage_repres, 2) # [batch_size, passage_len, dim]
91 |
92 | if options.compress_input: # compress input word vector into smaller vectors
93 | w_compress = tf.get_variable("w_compress_input", [input_dim, options.compress_input_dim], dtype=tf.float32)
94 | b_compress = tf.get_variable("b_compress_input", [options.compress_input_dim], dtype=tf.float32)
95 |
96 | in_passage_repres = tf.reshape(in_passage_repres, [-1, input_dim])
97 | in_passage_repres = tf.matmul(in_passage_repres, w_compress) + b_compress
98 | in_passage_repres = tf.tanh(in_passage_repres)
99 | in_passage_repres = tf.reshape(in_passage_repres, [batch_size, passage_len, options.compress_input_dim])
100 | input_dim = options.compress_input_dim
101 |
102 | if is_training:
103 | in_passage_repres = tf.nn.dropout(in_passage_repres, (1 - options.dropout_rate))
104 | else:
105 | in_passage_repres = tf.multiply(in_passage_repres, (1 - options.dropout_rate))
106 |
107 | passage_mask = tf.sequence_mask(self.passage_lengths, passage_len, dtype=tf.float32) # [batch_size, passage_len]
108 |
109 | # sequential context matching
110 | passage_forward = None
111 | passage_backward = None
112 | all_passage_representation = []
113 | passage_dim = 0
114 | with_lstm = True
115 | if with_lstm:
116 | with tf.variable_scope('biLSTM'):
117 | cur_in_passage_repres = in_passage_repres
118 | for i in xrange(options.context_layer_num):
119 | with tf.variable_scope('layer-{}'.format(i)):
120 | with tf.variable_scope('context_represent'):
121 | # parameters
122 | context_lstm_cell_fw = tf.contrib.rnn.LSTMCell(options.context_lstm_dim)
123 | context_lstm_cell_bw = tf.contrib.rnn.LSTMCell(options.context_lstm_dim)
124 | if is_training:
125 | context_lstm_cell_fw = tf.contrib.rnn.DropoutWrapper(context_lstm_cell_fw, output_keep_prob=(1 - options.dropout_rate))
126 | context_lstm_cell_bw = tf.contrib.rnn.DropoutWrapper(context_lstm_cell_bw, output_keep_prob=(1 - options.dropout_rate))
127 |
128 | # passage representation
129 | ((passage_context_representation_fw, passage_context_representation_bw),
130 | (passage_forward, passage_backward)) = tf.nn.bidirectional_dynamic_rnn(
131 | context_lstm_cell_fw, context_lstm_cell_bw, cur_in_passage_repres, dtype=tf.float32,
132 | sequence_length=self.passage_lengths) # [batch_size, passage_len, context_lstm_dim]
133 | if options.direction == 'forward':
134 | # [batch_size, passage_len, context_lstm_dim]
135 | cur_in_passage_repres = passage_context_representation_fw
136 | passage_dim += options.context_lstm_dim
137 | elif options.direction == 'backward':
138 | # [batch_size, passage_len, context_lstm_dim]
139 | cur_in_passage_repres = passage_context_representation_bw
140 | passage_dim += options.context_lstm_dim
141 | elif options.direction == 'bidir':
142 | # [batch_size, passage_len, 2*context_lstm_dim]
143 | cur_in_passage_repres = tf.concat(
144 | [passage_context_representation_fw, passage_context_representation_bw], 2)
145 | passage_dim += 2 * options.context_lstm_dim
146 | else:
147 | assert False
148 | all_passage_representation.append(cur_in_passage_repres)
149 |
150 |
151 | all_passage_representation = tf.concat(all_passage_representation, 2) # [batch_size, passage_len, passage_dim]
152 |
153 | if is_training:
154 | all_passage_representation = tf.nn.dropout(all_passage_representation, (1 - options.dropout_rate))
155 | else:
156 | all_passage_representation = tf.multiply(all_passage_representation, (1 - options.dropout_rate))
157 |
158 | # ======Highway layer======
159 | if options.with_match_highway:
160 | with tf.variable_scope("context_highway"):
161 | all_passage_representation = match_utils.multi_highway_layer(all_passage_representation,
162 | passage_dim,options.highway_layer_num)
163 |
164 | all_passage_representation = all_passage_representation * tf.expand_dims(passage_mask, axis=-1)
165 |
166 | # initial state for the LSTM decoder
167 | #'''
168 | with tf.variable_scope('initial_state_for_decoder'):
169 | # Define weights and biases to reduce the cell and reduce the state
170 | w_reduce_c = tf.get_variable('w_reduce_c', [2*options.context_lstm_dim, options.gen_hidden_size], dtype=tf.float32)
171 | w_reduce_h = tf.get_variable('w_reduce_h', [2*options.context_lstm_dim, options.gen_hidden_size], dtype=tf.float32)
172 | bias_reduce_c = tf.get_variable('bias_reduce_c', [options.gen_hidden_size], dtype=tf.float32)
173 | bias_reduce_h = tf.get_variable('bias_reduce_h', [options.gen_hidden_size], dtype=tf.float32)
174 |
175 | old_c = tf.concat(values=[passage_forward.c, passage_backward.c], axis=1)
176 | old_h = tf.concat(values=[passage_forward.h, passage_backward.h], axis=1)
177 | new_c = tf.nn.tanh(tf.matmul(old_c, w_reduce_c) + bias_reduce_c)
178 | new_h = tf.nn.tanh(tf.matmul(old_h, w_reduce_h) + bias_reduce_h)
179 |
180 | init_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
181 | '''
182 | new_c = tf.zeros([batch_size, options.gen_hidden_size])
183 | new_h = tf.zeros([batch_size, options.gen_hidden_size])
184 | init_state = LSTMStateTuple(new_c, new_h)
185 | '''
186 | return (passage_dim, all_passage_representation, init_state)
187 |
188 |
--------------------------------------------------------------------------------
/src_g2s/G2S_data_stream.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import numpy as np
4 | import random
5 | import padding_utils
6 | import amr_utils
7 |
8 | def read_text_file(text_file):
9 | lines = []
10 | with open(text_file, "rt") as f:
11 | for line in f:
12 | line = line.decode('utf-8')
13 | lines.append(line.strip())
14 | return lines
15 |
16 | def read_amr_file(inpath):
17 | nodes = [] # [batch, node_num,]
18 | in_neigh_indices = [] # [batch, node_num, neighbor_num,]
19 | in_neigh_edges = []
20 | out_neigh_indices = [] # [batch, node_num, neighbor_num,]
21 | out_neigh_edges = []
22 | sentences = [] # [batch, sent_length,]
23 | ids = []
24 | max_in_neigh = 0
25 | max_out_neigh = 0
26 | max_node = 0
27 | max_sent = 0
28 | with open(inpath, "rU") as f:
29 | for inst in json.load(f):
30 | amr = inst['amr']
31 | sent = inst['sent'].strip().split()
32 | id = inst['id'] if inst.has_key('id') else None
33 | amr_node = []
34 | amr_edge = []
35 | amr_utils.read_anonymized(amr.strip().split(), amr_node, amr_edge)
36 | # 1.
37 | nodes.append(amr_node)
38 | # 2. & 3.
39 | in_indices = [[i,] for i, x in enumerate(amr_node)]
40 | in_edges = [[':self',] for i, x in enumerate(amr_node)]
41 | out_indices = [[i,] for i, x in enumerate(amr_node)]
42 | out_edges = [[':self',] for i, x in enumerate(amr_node)]
43 | for (i,j,lb) in amr_edge:
44 | in_indices[j].append(i)
45 | in_edges[j].append(lb)
46 | out_indices[i].append(j)
47 | out_edges[i].append(lb)
48 | in_neigh_indices.append(in_indices)
49 | in_neigh_edges.append(in_edges)
50 | out_neigh_indices.append(out_indices)
51 | out_neigh_edges.append(out_edges)
52 | # 4.
53 | sentences.append(sent)
54 | ids.append(id)
55 | # update lengths
56 | max_in_neigh = max(max_in_neigh, max(len(x) for x in in_indices))
57 | max_out_neigh = max(max_out_neigh, max(len(x) for x in out_indices))
58 | max_node = max(max_node, len(amr_node))
59 | max_sent = max(max_sent, len(sent))
60 | return zip(nodes, in_neigh_indices, in_neigh_edges, out_neigh_indices, out_neigh_edges, sentences, ids), \
61 | max_node, max_in_neigh, max_out_neigh, max_sent
62 |
63 | def read_amr_from_fof(fofpath):
64 | all_paths = read_text_file(fofpath)
65 | all_instances = []
66 | max_node = 0
67 | max_in_neigh = 0
68 | max_out_neigh = 0
69 | max_sent = 0
70 | for cur_path in all_paths:
71 | print(cur_path)
72 | cur_instances, cur_node, cur_in_neigh, cur_out_neigh, cur_sent = read_amr_file(cur_path)
73 | all_instances.extend(cur_instances)
74 | max_node = max(max_node, cur_node)
75 | max_in_neigh = max(max_in_neigh, cur_in_neigh)
76 | max_out_neigh = max(max_out_neigh, cur_out_neigh)
77 | max_sent = max(max_sent, cur_sent)
78 | return all_instances, max_node, max_in_neigh, max_out_neigh, max_sent
79 |
80 | def collect_vocabs(all_instances):
81 | all_words = set()
82 | all_chars = set()
83 | all_edgelabels = set()
84 | # nodes: [corpus_size,node_num,]
85 | # neigh_indices & neigh_edges: [corpus_size,node_num,neigh_num,]
86 | # sentence: [corpus_size,sent_len,]
87 | for (nodes, in_neigh_indices, in_neigh_edges, out_neigh_indices, out_neigh_edges, sentence, id) in all_instances:
88 | all_words.update(nodes)
89 | all_words.update(sentence)
90 | for edges in in_neigh_edges:
91 | all_edgelabels.update(edges)
92 | for edges in out_neigh_edges:
93 | all_edgelabels.update(edges)
94 | for w in all_words:
95 | all_chars.update(w)
96 | return (all_words, all_chars, all_edgelabels)
97 |
98 | class G2SDataStream(object):
99 | def __init__(self, all_instances, word_vocab=None, char_vocab=None, edgelabel_vocab=None, options=None,
100 | isShuffle=False, isLoop=False, isSort=True, batch_size=-1):
101 | self.options = options
102 | if batch_size ==-1: batch_size=options.batch_size
103 | # index tokens and filter the dataset
104 | instances = []
105 | for (nodes, in_neigh_indices, in_neigh_edges, out_neigh_indices, out_neigh_edges, sentence, id) in all_instances:
106 | if options.max_node_num != -1 and len(nodes) > options.max_node_num:
107 | continue # remove very long passages
108 | in_neigh_indices = [x[:options.max_in_neigh_num] for x in in_neigh_indices]
109 | in_neigh_edges = [x[:options.max_in_neigh_num] for x in in_neigh_edges]
110 | out_neigh_indices = [x[:options.max_out_neigh_num] for x in out_neigh_indices]
111 | out_neigh_edges = [x[:options.max_out_neigh_num] for x in out_neigh_edges]
112 |
113 | nodes_idx = word_vocab.to_index_sequence_for_list(nodes)
114 | nodes_chars_idx = None
115 | if options.with_char:
116 | nodes_chars_idx = char_vocab.to_character_matrix_for_list(nodes, max_char_per_word=options.max_char_per_word)
117 | in_neigh_edges_idx = [edgelabel_vocab.to_index_sequence_for_list(edges) for edges in in_neigh_edges]
118 | out_neigh_edges_idx = [edgelabel_vocab.to_index_sequence_for_list(edges) for edges in out_neigh_edges]
119 | sentence_idx = word_vocab.to_index_sequence_for_list(sentence[:options.max_answer_len])
120 | instances.append((nodes_idx, nodes_chars_idx,
121 | in_neigh_indices, in_neigh_edges_idx, out_neigh_indices, out_neigh_edges_idx, sentence_idx, sentence, id))
122 |
123 | all_instances = instances
124 | instances = None
125 |
126 | # sort instances based on length
127 | if isSort:
128 | all_instances = sorted(all_instances, key=lambda inst: (len(inst[0]), len(inst[-2])))
129 |
130 | self.num_instances = len(all_instances)
131 |
132 | # distribute questions into different buckets
133 | batch_spans = padding_utils.make_batches(self.num_instances, batch_size)
134 | self.batches = []
135 | for batch_index, (batch_start, batch_end) in enumerate(batch_spans):
136 | cur_instances = []
137 | for i in xrange(batch_start, batch_end):
138 | cur_instances.append(all_instances[i])
139 | cur_batch = G2SBatch(cur_instances, options, word_vocab=word_vocab)
140 | self.batches.append(cur_batch)
141 |
142 | self.num_batch = len(self.batches)
143 | self.index_array = np.arange(self.num_batch)
144 | self.isShuffle = isShuffle
145 | if self.isShuffle: np.random.shuffle(self.index_array)
146 | self.isLoop = isLoop
147 | self.cur_pointer = 0
148 |
149 | def nextBatch(self):
150 | if self.cur_pointer>=self.num_batch:
151 | if not self.isLoop: return None
152 | self.cur_pointer = 0
153 | if self.isShuffle: np.random.shuffle(self.index_array)
154 | cur_batch = self.batches[self.index_array[self.cur_pointer]]
155 | self.cur_pointer += 1
156 | return cur_batch
157 |
158 | def reset(self):
159 | if self.isShuffle: np.random.shuffle(self.index_array)
160 | self.cur_pointer = 0
161 |
162 | def get_num_batch(self):
163 | return self.num_batch
164 |
165 | def get_num_instance(self):
166 | return self.num_instances
167 |
168 | def get_batch(self, i):
169 | if i>= self.num_batch: return None
170 | return self.batches[i]
171 |
172 | class G2SBatch(object):
173 | def __init__(self, instances, options, word_vocab=None):
174 | self.options = options
175 |
176 | self.amr_node = [x[0] for x in instances]
177 | self.id = [x[-1] for x in instances]
178 | self.target_ref = [x[-2] for x in instances] # list of tuples
179 | self.batch_size = len(instances)
180 | self.vocab = word_vocab
181 |
182 | # create length
183 | self.node_num = [] # [batch_size]
184 | self.sent_len = [] # [batch_size]
185 | for (nodes_idx, nodes_chars_idx, in_neigh_indices, in_neigh_edges_idx, out_neigh_indices, out_neigh_edges_idx,
186 | sentence_idx, sentence, id) in instances:
187 | self.node_num.append(len(nodes_idx))
188 | self.sent_len.append(min(len(sentence_idx)+1, options.max_answer_len))
189 | self.node_num = np.array(self.node_num, dtype=np.int32)
190 | self.sent_len = np.array(self.sent_len, dtype=np.int32)
191 |
192 | # node char num
193 | if options.with_char:
194 | self.nodes_chars_num = [[len(nodes_chars_idx) for nodes_chars_idx in instance[1]] for instance in instances]
195 | self.nodes_chars_num = padding_utils.pad_2d_vals_no_size(self.nodes_chars_num)
196 |
197 | # neigh mask
198 | self.in_neigh_mask = [] # [batch_size, node_num, neigh_num]
199 | self.out_neigh_mask = []
200 | for instance in instances:
201 | ins = []
202 | for in_neighs in instance[2]:
203 | ins.append([1 for _ in in_neighs])
204 | self.in_neigh_mask.append(ins)
205 | outs = []
206 | for out_neighs in instance[4]:
207 | outs.append([1 for _ in out_neighs])
208 | self.out_neigh_mask.append(outs)
209 | self.in_neigh_mask = padding_utils.pad_3d_vals_no_size(self.in_neigh_mask)
210 | self.out_neigh_mask = padding_utils.pad_3d_vals_no_size(self.out_neigh_mask)
211 |
212 | # create word representation
213 | start_id = word_vocab.getIndex('')
214 | end_id = word_vocab.getIndex('')
215 |
216 | self.nodes = [x[0] for x in instances]
217 | if options.with_char:
218 | self.nodes_chars = [inst[1] for inst in instances] # [batch_size, sent_len, char_num]
219 | self.in_neigh_indices = [x[2] for x in instances]
220 | self.in_neigh_edges = [x[3] for x in instances]
221 | self.out_neigh_indices = [x[4] for x in instances]
222 | self.out_neigh_edges = [x[5] for x in instances]
223 |
224 | self.sent_inp = []
225 | self.sent_out = []
226 | for _, _, _, _, _, _, sentence_idx, sentence, id in instances:
227 | if len(sentence_idx) < options.max_answer_len:
228 | self.sent_inp.append([start_id,]+sentence_idx)
229 | self.sent_out.append(sentence_idx+[end_id,])
230 | else:
231 | self.sent_inp.append([start_id,]+sentence_idx[:-1])
232 | self.sent_out.append(sentence_idx)
233 |
234 | # making ndarray
235 | self.nodes = padding_utils.pad_2d_vals_no_size(self.nodes)
236 | if options.with_char:
237 | self.nodes_chars = padding_utils.pad_3d_vals_no_size(self.nodes_chars)
238 | self.in_neigh_indices = padding_utils.pad_3d_vals_no_size(self.in_neigh_indices)
239 | self.in_neigh_edges = padding_utils.pad_3d_vals_no_size(self.in_neigh_edges)
240 | self.out_neigh_indices = padding_utils.pad_3d_vals_no_size(self.out_neigh_indices)
241 | self.out_neigh_edges = padding_utils.pad_3d_vals_no_size(self.out_neigh_edges)
242 |
243 | assert self.in_neigh_mask.shape == self.in_neigh_indices.shape
244 | assert self.in_neigh_mask.shape == self.in_neigh_edges.shape
245 | assert self.out_neigh_mask.shape == self.out_neigh_indices.shape
246 | assert self.out_neigh_mask.shape == self.out_neigh_edges.shape
247 |
248 | # [batch_size, sent_len_max]
249 | self.sent_inp = padding_utils.pad_2d_vals(self.sent_inp, len(self.sent_inp), options.max_answer_len)
250 | self.sent_out = padding_utils.pad_2d_vals(self.sent_out, len(self.sent_out), options.max_answer_len)
251 |
252 |
253 | if __name__ == "__main__":
254 | print('testset')
255 | all_instances, max_node_num, max_in_neigh_num, max_out_neigh_num, max_sent_len = read_amr_file('./data/test.json')
256 | print(max_in_neigh_num)
257 | print(max_out_neigh_num)
258 | print(max_node_num)
259 | print(max_sent_len)
260 | print('DONE!')
261 |
262 |
--------------------------------------------------------------------------------
/src_s2s/NP2P_phrase_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function
3 | import argparse
4 | import os
5 | import sys
6 | import time
7 | import numpy as np
8 |
9 | from vocab_utils import Vocab
10 | import namespace_utils
11 | import NP2P_data_stream
12 | from NP2P_model_graph import ModelGraph
13 | import NP2P_beam_decoder
14 | import metric_utils
15 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu, corpus_bleu
16 | cc = SmoothingFunction()
17 |
18 | FLAGS = None
19 | import tensorflow as tf
20 | tf.logging.set_verbosity(tf.logging.ERROR) # DEBUG, INFO, WARN, ERROR, and FATAL
21 |
22 |
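    | # Greedy-decode every dev batch and score corpus-level BLEU against the tokenized references.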
23 | def evaluate(sess, valid_graph, devDataStream, word_vocab, options=None):
24 | devDataStream.reset()
25 | gen = []
26 | ref = []
27 | for batch_index in xrange(devDataStream.get_num_batch()): # for each batch
28 | cur_batch = devDataStream.get_batch(batch_index)
29 | (sentences, _, _, _) = NP2P_beam_decoder.search(sess, valid_graph, word_vocab, cur_batch, FLAGS, decode_mode="greedy")
30 | for i in xrange(cur_batch.batch_size):
31 | cur_ref = cur_batch.instances[i][1].tokText
32 | cur_prediction = sentences[i]
33 | ref.append([cur_ref.split()])
34 | gen.append(cur_prediction.split())
35 | return corpus_bleu(ref, gen, smoothing_function=cc.method3)
36 |
37 | def main(_):
38 | print('Configurations:')
39 | print(FLAGS)
40 |
41 | log_dir = FLAGS.model_dir
42 | if not os.path.exists(log_dir):
43 | os.makedirs(log_dir)
44 |
45 | path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
46 | init_model_prefix = FLAGS.init_model # "/u/zhigwang/zhigwang1/sentence_generation/mscoco/logs/NP2P.phrase_ce_train"
47 | log_file_path = path_prefix + ".log"
48 | print('Log file path: {}'.format(log_file_path))
49 | log_file = open(log_file_path, 'wt')
50 | log_file.write("{}\n".format(FLAGS))
51 | log_file.flush()
52 |
53 | # save configuration
54 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
55 |
56 | print('Loading train set.')
57 | if FLAGS.infile_format == 'fof':
58 | trainset, train_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(FLAGS.train_path, isLower=FLAGS.isLower)
59 | if FLAGS.max_answer_len>train_ans_len: FLAGS.max_answer_len = train_ans_len
60 | else:
61 | trainset, train_ans_len = NP2P_data_stream.read_all_GQA_questions(FLAGS.train_path, isLower=FLAGS.isLower)
62 | print('Number of training samples: {}'.format(len(trainset)))
63 |
64 | print('Loading test set.')
65 | if FLAGS.infile_format == 'fof':
66 | testset, test_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(FLAGS.test_path, isLower=FLAGS.isLower)
67 | else:
68 | testset, test_ans_len = NP2P_data_stream.read_all_GQA_questions(FLAGS.test_path, isLower=FLAGS.isLower)
69 | print('Number of test samples: {}'.format(len(testset)))
70 |
71 | max_actual_len = max(train_ans_len, test_ans_len)
72 | print('Max answer length: {}, truncated to {}'.format(max_actual_len, FLAGS.max_answer_len))
73 |
74 | word_vocab = None
75 | POS_vocab = None
76 | NER_vocab = None
77 | char_vocab = None
78 | has_pretrained_model = False
79 | best_path = path_prefix + ".best.model"
80 | if os.path.exists(init_model_prefix + ".best.model.index"):
81 | has_pretrained_model = True
82 | print('!!Existing pretrained model. Loading vocabs.')
83 | if FLAGS.with_word:
84 | word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
85 | print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
86 | if FLAGS.with_char:
87 | char_vocab = Vocab(init_model_prefix + ".char_vocab", fileformat='txt2')
88 | print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
89 | if FLAGS.with_POS:
90 | POS_vocab = Vocab(init_model_prefix + ".POS_vocab", fileformat='txt2')
91 | print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
92 | if FLAGS.with_NER:
93 | NER_vocab = Vocab(init_model_prefix + ".NER_vocab", fileformat='txt2')
94 | print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
95 | else:
96 | print('Collecting vocabs.')
97 | (allWords, allChars, allPOSs, allNERs) = NP2P_data_stream.collect_vocabs(trainset)
98 | print('Number of words: {}'.format(len(allWords)))
99 | print('Number of allChars: {}'.format(len(allChars)))
100 | print('Number of allPOSs: {}'.format(len(allPOSs)))
101 | print('Number of allNERs: {}'.format(len(allNERs)))
102 |
103 | if FLAGS.with_word:
104 | word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
105 | if FLAGS.with_char:
106 | char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build')
107 | char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
108 | if FLAGS.with_POS:
109 | POS_vocab = Vocab(voc=allPOSs, dim=FLAGS.POS_dim, fileformat='build')
110 | POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab")
111 | if FLAGS.with_NER:
112 | NER_vocab = Vocab(voc=allNERs, dim=FLAGS.NER_dim, fileformat='build')
113 | NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab")
114 |
115 | print('word vocab size {}'.format(word_vocab.vocab_size))
116 | sys.stdout.flush()
117 |
118 | print('Build DataStream ... ')
119 | trainDataStream = NP2P_data_stream.QADataStream(trainset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS,
120 | isShuffle=True, isLoop=True, isSort=True)
121 |
122 | devDataStream = NP2P_data_stream.QADataStream(testset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS,
123 | isShuffle=False, isLoop=False, isSort=True)
124 | print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
125 | print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
126 | print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
127 | print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
128 | sys.stdout.flush()
129 |
130 | init_scale = 0.01
131 | with tf.Graph().as_default():
132 | initializer = tf.random_uniform_initializer(-init_scale, init_scale)
133 | with tf.name_scope("Train"):
134 | with tf.variable_scope("Model", reuse=None, initializer=initializer):
135 | train_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab,
136 | NER_vocab=NER_vocab, options=FLAGS, mode="rl_train_for_phrase")
137 |
138 | with tf.name_scope("Valid"):
139 | with tf.variable_scope("Model", reuse=True, initializer=initializer):
140 | valid_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab,
141 | NER_vocab=NER_vocab, options=FLAGS, mode="decode")
142 |
143 | initializer = tf.global_variables_initializer()
144 |
145 | vars_ = {}
146 | for var in tf.all_variables():
147 | if "word_embedding" in var.name: continue
148 | if not var.name.startswith("Model"): continue
149 | vars_[var.name.split(":")[0]] = var
150 | saver = tf.train.Saver(vars_)
151 |
152 | sess = tf.Session()
153 | sess.run(initializer)
154 | if has_pretrained_model:
155 | print("Restoring model from " + init_model_prefix + ".best.model")
156 | saver.restore(sess, init_model_prefix + ".best.model")
157 | print("DONE!")
158 | sys.stdout.flush()
159 |
160 | # for first-time rl training, we get the current BLEU score
161 | print("First-time rl training, get the current BLEU score on dev")
162 | sys.stdout.flush()
163 | best_bleu = evaluate(sess, valid_graph, devDataStream, word_vocab, options=FLAGS)
164 | print('First-time bleu = %.4f' % best_bleu)
165 | log_file.write('First-time bleu = %.4f\n' % best_bleu)
166 |
167 | print('Start the training loop.')
168 | sys.stdout.flush()
169 | train_size = trainDataStream.get_num_batch()
170 | max_steps = train_size * FLAGS.max_epochs
171 | total_loss = 0.0
172 | start_time = time.time()
173 | for step in xrange(max_steps):
174 | cur_batch = trainDataStream.nextBatch()
175 | if FLAGS.with_baseline:
176 | # greedy search
177 | (greedy_sentences, _, _, _) = NP2P_beam_decoder.search(sess, valid_graph, word_vocab,
178 | cur_batch, FLAGS, decode_mode="greedy")
179 |
180 | if FLAGS.with_target_lattice:
181 | (sampled_sentences, sampled_prediction_lengths, sampled_generator_input_idx,
182 | sampled_generator_output_idx) = cur_batch.sample_a_partition()
183 | else:
184 | # multinomial sampling
185 | (sampled_sentences, sampled_prediction_lengths, sampled_generator_input_idx,
186 | sampled_generator_output_idx) = NP2P_beam_decoder.search(sess, valid_graph, word_vocab,
187 | cur_batch, FLAGS, decode_mode="multinomial")
188 | # calculate rewards
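    |             # Self-critical style reward: the sampled output's score minus the greedy baseline's.
    |             # As written, r is fixed to 1.0 (the sentence-BLEU line is commented out), so with the
    |             # baseline enabled the reward reduces to 1.0 - BLEU(greedy output, reference).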
189 | rewards = []
190 | for i in xrange(cur_batch.batch_size):
191 | # print(sampled_sentences[i])
192 | # print(sampled_generator_input_idx[i])
193 | # print(sampled_generator_output_idx[i])
194 | cur_toks = cur_batch.instances[i][1].tokText.split()
195 | # r = sentence_bleu([cur_toks], sampled_sentences[i].split(), smoothing_function=cc.method3)
196 | r = 1.0
197 | b = 0.0
198 | if FLAGS.with_baseline:
199 | b = sentence_bleu([cur_toks], greedy_sentences[i].split(), smoothing_function=cc.method3)
200 | # r = metric_utils.evaluate_captions([cur_toks],[sampled_sentences[i]])
201 | # b = metric_utils.evaluate_captions([cur_toks],[greedy_sentences[i]])
202 | rewards.append(1.0* (r-b))
203 | rewards = np.array(rewards, dtype=np.float32)
204 | # sys.exit(-1)
205 |
206 | # update parameters
207 | feed_dict = train_graph.run_encoder(sess, cur_batch, FLAGS, only_feed_dict=True)
208 | feed_dict[train_graph.reward] = rewards
209 | feed_dict[train_graph.gen_input_words] = sampled_generator_input_idx
210 | feed_dict[train_graph.in_answer_words] = sampled_generator_output_idx
211 | feed_dict[train_graph.answer_lengths] = sampled_prediction_lengths
212 | (_, loss_value) = sess.run([train_graph.train_op, train_graph.loss], feed_dict)
213 | total_loss += loss_value
214 |
215 | if step % 100==0:
216 | print('{} '.format(step), end="")
217 | sys.stdout.flush()
218 |
219 | # Save a checkpoint and evaluate the model periodically.
220 | if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
221 | print()
222 | duration = time.time() - start_time
223 | print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
224 | log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
225 | log_file.flush()
226 | sys.stdout.flush()
227 | total_loss = 0.0
228 |
229 | # Evaluate against the validation set.
230 | start_time = time.time()
231 | print('Validation Data Eval:')
232 | dev_bleu = evaluate(sess, valid_graph, devDataStream, word_vocab, options=FLAGS)
233 | print('Dev bleu = %.4f' % dev_bleu)
234 | log_file.write('Dev bleu = %.4f\n' % dev_bleu)
235 | log_file.flush()
236 | if best_bleu < dev_bleu:
237 | print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(best_bleu, dev_bleu))
238 | best_bleu = dev_bleu
239 |                         saver.save(sess, best_path) # save the new best model
240 | duration = time.time() - start_time
241 | print('Duration %.3f sec' % (duration))
242 | sys.stdout.flush()
243 | log_file.write('Duration %.3f sec\n' % (duration))
244 | log_file.flush()
245 |
246 | log_file.close()
247 |
248 | def enrich_options(options):
249 | if not options.__dict__.has_key("infile_format"):
250 | options.__dict__["infile_format"] = "old"
251 |
252 | if not options.__dict__.has_key("with_target_lattice"):
253 | options.__dict__["with_target_lattice"] = False
254 |
255 | if not options.__dict__.has_key("with_baseline"):
256 | options.__dict__["with_baseline"] = True
257 |
258 | if not options.__dict__.has_key("add_first_word_prob_for_phrase"):
259 | options.__dict__["add_first_word_prob_for_phrase"] = False
260 |
261 | if not options.__dict__.has_key("pretrain_with_max_matching"):
262 | options.__dict__["pretrain_with_max_matching"] = False
263 | return options
264 |
265 |
266 | if __name__ == '__main__':
267 | parser = argparse.ArgumentParser()
268 | parser.add_argument('--config_path', type=str, help='Configuration file.')
269 |
270 | print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])
271 | FLAGS, unparsed = parser.parse_known_args()
272 |
273 |
274 | if FLAGS.config_path is not None:
275 | print('Loading the configuration from ' + FLAGS.config_path)
276 | FLAGS = namespace_utils.load_namespace(FLAGS.config_path)
277 |
278 | FLAGS = enrich_options(FLAGS)
279 |
280 | sys.stdout.flush()
281 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
282 |
--------------------------------------------------------------------------------
/src_g2s/G2S_beam_decoder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function
3 | import argparse
4 | import re
5 | import os
6 | import sys
7 | import time
8 | import numpy as np
9 |
10 | import tensorflow as tf
11 | import namespace_utils
12 |
13 | import G2S_trainer
14 | import G2S_data_stream
15 | from G2S_model_graph import ModelGraph
16 |
17 | from vocab_utils import Vocab
18 |
19 |
20 |
21 | tf.logging.set_verbosity(tf.logging.ERROR) # DEBUG, INFO, WARN, ERROR, and FATAL
22 |
23 | def map_idx_to_word(predictions, vocab, passage, attn_dist):
24 | '''
25 |     predictions: [batch, 1]; passage is assumed to hold per-instance node word indices for the copy case
26 | '''
27 | word_size = vocab.vocab_size + 1
28 | all_words = []
29 | all_word_idx = []
30 | for i, idx in enumerate(predictions):
31 |         if idx == vocab.vocab_size:
32 |             k = np.argmax(attn_dist[i,])
33 |             word = vocab.getWord(passage[i][k]) # copy: emit the input node receiving the most attention
34 | else:
35 | word = vocab.getWord(idx)
36 | all_words.append(word)
37 | all_word_idx.append(idx)
38 | return all_words, all_word_idx
39 |
40 | def search(sess, model, vocab, batch, options, decode_mode='greedy'):
41 | '''
42 |     step-by-step decoding for greedy search and multinomial (sampling) search
43 | '''
44 | # Run the encoder to get the encoder hidden states and decoder initial state
45 | (encoder_states, encoder_features, node_idx, node_mask, initial_state) = model.run_encoder(sess, batch, options)
46 | # encoder_states: [batch_size, passage_len, encode_dim]
47 | # encoder_features: [batch_size, passage_len, attention_vec_size]
48 | # node_idx: [batch_size, passage_len]
49 | # node_mask: [batch_size, passage_len]
50 |     # initial_state: a tuple of [batch_size, gen_dim]
51 |
52 | word_t = batch.sent_inp[:,0]
53 | state_t = initial_state
54 | context_t = np.zeros([batch.batch_size, model.encoder_dim])
55 | coverage_t = np.zeros([batch.batch_size, encoder_states.shape[1]])
56 | generator_output_idx = [] # store phrase index prediction
57 | text_results = []
58 | generator_input_idx = [word_t] # store word index
59 | for i in xrange(options.max_answer_len):
60 | if decode_mode == "pointwise": word_t = batch.sent_inp[:,i]
61 | feed_dict = {}
62 | feed_dict[model.init_decoder_state] = state_t
63 | feed_dict[model.context_t_1] = context_t
64 | feed_dict[model.coverage_t_1] = coverage_t
65 | feed_dict[model.word_t] = word_t
66 |
67 | feed_dict[model.encoder_states] = encoder_states
68 | feed_dict[model.encoder_features] = encoder_features
69 |         feed_dict[model.nodes] = node_idx
70 |         feed_dict[model.nodes_mask] = node_mask
71 |
72 | if decode_mode in ["greedy", "pointwise", ]:
73 | prediction = model.greedy_prediction
74 | elif decode_mode == "multinomial":
75 | prediction = model.multinomial_prediction
76 |
77 | (state_t, attn_dist_t, context_t, coverage_t, prediction) = sess.run([model.state_t, model.attn_dist_t, model.context_t,
78 | model.coverage_t, prediction], feed_dict)
79 | # convert prediction to word ids
80 | generator_output_idx.append(prediction)
81 | prediction = np.reshape(prediction, [prediction.size, 1])
82 |         cur_words, cur_word_idx = map_idx_to_word(prediction, vocab, batch.amr_node, attn_dist_t) # [batch_size, 1]
83 | cur_word_idx = np.array(cur_word_idx)
84 | cur_word_idx = np.reshape(cur_word_idx, [cur_word_idx.size])
85 | word_t = cur_word_idx
86 | text_results.append(cur_words)
87 | generator_input_idx.append(cur_word_idx)
88 |
89 | generator_input_idx = generator_input_idx[:-1] # remove the last word to shift one position to the right
90 | generator_input_idx = np.stack(generator_input_idx, axis=1) # [batch_size, max_len]
91 | generator_output_idx = np.stack(generator_output_idx, axis=1) # [batch_size, max_len]
92 |
93 | prediction_lengths = [] # [batch_size]
94 | sentences = [] # [batch_size]
95 | for i in xrange(batch.batch_size):
96 | words = []
97 | for j in xrange(options.max_answer_len):
98 | cur_phrase = text_results[j][i]
99 | words.append(cur_phrase)
100 |             if cur_phrase == "</s>": break # stop once the end-of-sentence symbol is produced
101 | prediction_lengths.append(len(words))
102 | cur_sent = " ".join(words)
103 | sentences.append(cur_sent)
104 |
105 | return (sentences, prediction_lengths, generator_input_idx, generator_output_idx)
106 |
107 | class Hypothesis(object):
108 | def __init__(self, tokens, log_ps, attn, state, context_vector, coverage_vector=None):
109 | self.tokens = tokens # store all tokens
110 | self.log_probs = log_ps # store log_probs for each time-step
111 | self.attn_ids = attn
112 | self.state = state
113 | self.context_vector = context_vector
114 | self.coverage_vector = coverage_vector
115 |
116 | def extend(self, token, log_prob, attn_i, state, context_vector, coverage_vector=None):
117 | return Hypothesis(self.tokens + [token], self.log_probs + [log_prob], self.attn_ids + [attn_i], state,
118 | context_vector, coverage_vector=coverage_vector)
119 |
120 | def latest_token(self):
121 | return self.tokens[-1]
122 |
123 | def avg_log_prob(self):
124 | return np.sum(self.log_probs[1:])/ (len(self.tokens)-1)
125 |
126 | def probs2string(self):
127 | out_string = ""
128 | for prob in self.log_probs:
129 | out_string += " %.4f" % prob
130 | return out_string.strip()
131 |
132 | def idx_seq_to_string(self, passage, vocab, options):
133 | word_size = vocab.vocab_size + 1
134 | all_words = []
135 | for i, idx in enumerate(self.tokens):
136 | cur_word = vocab.getWord(idx)
137 | if cur_word == 'UNK':
138 | idx = passage[self.attn_ids[i]]
139 | cur_word = vocab.getWord(idx)
140 | all_words.append(cur_word)
141 | return " ".join(all_words[1:])
142 |
143 |
144 | def sort_hyps(hyps):
145 | return sorted(hyps, key=lambda h: h.avg_log_prob(), reverse=True)
146 |
147 |
148 |
149 | def run_beam_search(sess, model, vocab, batch, options):
150 | # Run encoder
151 | (encoder_states, encoder_features, node_idx, node_mask, initial_state) = model.run_encoder(sess, batch, options)
152 | # encoder_states: [1, passage_len, encode_dim]
153 |     # initial_state: a tuple of [1, gen_dim]
154 | # encoder_features: [1, passage_len, attention_vec_size]
155 | # node_idx: [1, passage_len]
156 | # node_mask: [1, passage_len]
157 |
158 |     sent_stop_id = vocab.getIndex('</s>')
159 |
160 | # Initialize this first hypothesis
161 | context_t = np.zeros([model.encoder_dim]) # [encode_dim]
162 | coverage_t = np.zeros((encoder_states.shape[1])) # [passage_len]
163 | hyps = []
164 | hyps.append(Hypothesis([batch.sent_inp[0][0]], [0.0], [-1], initial_state, context_t, coverage_vector=coverage_t))
165 |
166 | # beam search decoding
167 |     results = [] # this will contain finished hypotheses (those that have emitted the </s> token)
168 | steps = 0
169 | while steps < options.max_answer_len and len(results) < options.beam_size:
170 |         cur_size = len(hyps) # current number of hypotheses in the beam
171 | cur_encoder_states = np.tile(encoder_states, (cur_size, 1, 1))
172 | cur_encoder_features = np.tile(encoder_features, (cur_size, 1, 1)) # [batch_size,passage_len, options.attention_vec_size]
173 | cur_node_idx = np.tile(node_idx, (cur_size, 1)) # [batch_size, passage_len]
174 | cur_node_mask = np.tile(node_mask, (cur_size, 1)) # [batch_size, passage_len]
175 | cur_state_t_1 = [] # [2, gen_steps]
176 | cur_context_t_1 = [] # [batch_size, encoder_dim]
177 | cur_coverage_t_1 = [] # [batch_size, passage_len]
178 | cur_word_t = [] # [batch_size]
179 | for h in hyps:
180 | cur_state_t_1.append(h.state)
181 | cur_context_t_1.append(h.context_vector)
182 | cur_word_t.append(h.latest_token())
183 | cur_coverage_t_1.append(h.coverage_vector)
184 | cur_context_t_1 = np.stack(cur_context_t_1, axis=0)
185 | cur_coverage_t_1 = np.stack(cur_coverage_t_1, axis=0)
186 | cur_word_t = np.array(cur_word_t)
187 |
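    |         # Pack the per-hypothesis LSTM states into a single batched LSTMStateTuple so that one
    |         # sess.run call advances every hypothesis in the beam simultaneously.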
188 | cells = [state.c for state in cur_state_t_1]
189 | hidds = [state.h for state in cur_state_t_1]
190 | new_c = np.concatenate(cells, axis=0)
191 | new_h = np.concatenate(hidds, axis=0)
192 | new_dec_init_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
193 |
194 | feed_dict = {}
195 | feed_dict[model.init_decoder_state] = new_dec_init_state
196 | feed_dict[model.context_t_1] = cur_context_t_1
197 | feed_dict[model.word_t] = cur_word_t
198 |
199 | feed_dict[model.encoder_states] = cur_encoder_states
200 | feed_dict[model.encoder_features] = cur_encoder_features
201 | feed_dict[model.nodes] = cur_node_idx
202 | feed_dict[model.nodes_mask] = cur_node_mask
203 | feed_dict[model.coverage_t_1] = cur_coverage_t_1
204 |
205 | (state_t, context_t, attn_dist_t, coverage_t, topk_log_probs, topk_ids) = sess.run([model.state_t, model.context_t, model.attn_dist_t,
206 | model.coverage_t, model.topk_log_probs, model.topk_ids], feed_dict)
207 |
208 | new_states = [tf.contrib.rnn.LSTMStateTuple(state_t.c[i:i+1, :], state_t.h[i:i+1, :]) for i in xrange(cur_size)]
209 |
210 | # Extend each hypothesis and collect them all in all_hyps
211 |         if steps == 0: cur_size = 1 # all initial hypotheses are identical, so only the first one needs expanding
212 | all_hyps = []
213 | for i in xrange(cur_size):
214 | h = hyps[i]
215 | cur_state = new_states[i]
216 | cur_context = context_t[i]
217 | cur_coverage = coverage_t[i]
218 | for j in xrange(options.beam_size):
219 | cur_tok = topk_ids[i, j]
220 | # add anony constraint
221 | #if cur_tok in vocab.anony_ids and cur_tok not in batch.amr_anony_ids:
222 | # continue
223 | cur_tok_log_prob = topk_log_probs[i, j]
224 | cur_attn_i = np.argmax(attn_dist_t[i, :])
225 | new_hyp = h.extend(cur_tok, cur_tok_log_prob, cur_attn_i, cur_state, cur_context, coverage_vector=cur_coverage)
226 | all_hyps.append(new_hyp)
227 |
228 | # Filter and collect any hypotheses that have produced the end token.
229 | # hyps will contain hypotheses for the next step
230 | hyps = []
231 | for h in sort_hyps(all_hyps):
232 | # If this hypothesis is sufficiently long, put in results. Otherwise discard.
233 | if h.latest_token() == sent_stop_id:
234 | if steps >= options.min_answer_len:
235 | results.append(h)
236 | # hasn't reached stop token, so continue to extend this hypothesis
237 | else:
238 | hyps.append(h)
239 | if len(hyps) == options.beam_size or len(results) == options.beam_size:
240 | break
241 |
242 | steps += 1
243 |
244 | # At this point, either we've got beam_size results, or we've reached maximum decoder steps
245 | # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results
246 | if len(results)==0:
247 | results = hyps
248 |
249 | # Sort hypotheses by average log probability
250 | hyps_sorted = sort_hyps(results)
251 |
252 | # Return the hypothesis with highest average log prob
253 | return hyps_sorted
254 |
255 | if __name__ == '__main__':
256 | parser = argparse.ArgumentParser()
257 | parser.add_argument('--model_prefix', type=str, required=True, help='Prefix to the models.')
258 | parser.add_argument('--in_path', type=str, required=True, help='The path to the test file.')
259 | parser.add_argument('--out_path', type=str, help='The path to the output file.')
260 |     parser.add_argument('--mode', type=str, default='pointwise', help='Decoding mode: pointwise, greedy, multinomial, greedy_evaluate or beam.')
261 |
262 | args, unparsed = parser.parse_known_args()
263 |
264 | model_prefix = args.model_prefix
265 | in_path = args.in_path
266 | out_path = args.out_path
267 | mode = args.mode
268 |
269 | print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])
270 |
271 | # load the configuration file
272 | print('Loading configurations from ' + model_prefix + ".config.json")
273 | FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
274 | FLAGS = G2S_trainer.enrich_options(FLAGS)
275 |
276 | # load vocabs
277 | print('Loading vocabs.')
278 | word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
279 | print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
280 | edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab", fileformat='txt2')
281 | print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape))
282 | char_vocab = None
283 | if FLAGS.with_char:
284 | char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
285 | print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
286 |
287 |
288 | print('Loading test set from {}.'.format(in_path))
289 | testset, _, _, _, _ = G2S_data_stream.read_amr_file(in_path)
290 | print('Number of samples: {}'.format(len(testset)))
291 |
292 | print('Build DataStream ... ')
293 | batch_size=-1
294 | if mode not in ('pointwise', 'multinomial', 'greedy', 'greedy_evaluate', ):
295 | batch_size = 1
296 |
297 | devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS,
298 | isShuffle=False, isLoop=False, isSort=True, batch_size=batch_size)
299 | print('Number of instances in testDataStream: {}'.format(devDataStream.get_num_instance()))
300 | print('Number of batches in testDataStream: {}'.format(devDataStream.get_num_batch()))
301 |
302 | best_path = model_prefix + ".best.model"
303 | with tf.Graph().as_default():
304 | initializer = tf.random_uniform_initializer(-0.01, 0.01)
305 | with tf.name_scope("Valid"):
306 | with tf.variable_scope("Model", reuse=False, initializer=initializer):
307 | valid_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, Edgelabel_vocab=edgelabel_vocab,
308 | options=FLAGS, mode="decode")
309 |
310 |     ## skip the word embedding variable from the saver when word vectors are fixed
311 | vars_ = {}
312 | for var in tf.all_variables():
313 | if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
314 | if not var.name.startswith("Model"): continue
315 | vars_[var.name.split(":")[0]] = var
316 | saver = tf.train.Saver(vars_)
317 |
318 | initializer = tf.global_variables_initializer()
319 | sess = tf.Session()
320 | sess.run(initializer)
321 |
322 | saver.restore(sess, best_path) # restore the model
323 |
324 | total = 0
325 | correct = 0
326 | if mode.endswith('evaluate'):
327 | ref_outfile = open(out_path+ ".ref", 'wt')
328 | pred_outfile = open(out_path+ ".pred", 'wt')
329 | else:
330 | outfile = open(out_path, 'wt')
331 | total_num = devDataStream.get_num_batch()
332 | devDataStream.reset()
333 | for i in range(total_num):
334 | cur_batch = devDataStream.get_batch(i)
335 | if mode in ['greedy', 'multinomial']:
336 | print('Batch {}'.format(i))
337 | (sentences, prediction_lengths, generator_input_idx,
338 | generator_output_idx) = search(sess, valid_graph, word_vocab, cur_batch, FLAGS, decode_mode=mode)
339 | for j in xrange(cur_batch.batch_size):
340 | outfile.write(cur_batch.target_ref[j].encode('utf-8') + "\n")
341 | outfile.write(sentences[j].encode('utf-8') + "\n")
342 | outfile.write("========\n")
343 | outfile.flush()
344 | else: # beam search
345 | print('Instance {}'.format(i))
346 | hyps = run_beam_search(sess, valid_graph, word_vocab, cur_batch, FLAGS)
347 | outfile.write(cur_batch.id[0] + "\n")
348 | outfile.write(' '.join(cur_batch.target_ref[0]).encode('utf-8') + "\n")
349 | for j in xrange(1):
350 | hyp = hyps[j]
351 | cur_passage = cur_batch.amr_node[0]
352 | cur_sent = hyp.idx_seq_to_string(cur_passage, word_vocab, FLAGS)
353 | outfile.write(cur_sent.encode('utf-8') + "\n")
354 | outfile.write("--------\n")
355 | outfile.write("========\n")
356 | outfile.flush()
357 | if mode.endswith('evaluate'):
358 | ref_outfile.close()
359 | pred_outfile.close()
360 | else:
361 | outfile.close()
362 |
363 |
364 |
365 |
366 |
--------------------------------------------------------------------------------
/src_g2s/G2S_model_graph.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import graph_encoder_utils
3 | import generator_utils
4 | import padding_utils
5 | from tensorflow.python.ops import variable_scope
6 | import numpy as np
7 | import random
8 |
9 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
10 | cc = SmoothingFunction()
11 |
12 | class ModelGraph(object):
13 | def __init__(self, word_vocab, char_vocab, Edgelabel_vocab, options=None, mode='ce_train'):
14 |         # 'mode' can take one of the following values:
15 | # 'ce_train',
16 | # 'rl_train',
17 | # 'evaluate',
18 | # 'evaluate_bleu',
19 | # 'decode'.
20 | # it is different from 'mode_gen' in generator_utils.py
21 | # value of 'mode_gen' can be ['ce_loss', 'rl_loss', 'greedy' or 'sample']
22 | self.mode = mode
23 |
24 | # is_training controls whether to use dropout
25 | is_training = True if mode in ('ce_train', ) else False
26 |
27 | self.options = options
28 | self.word_vocab = word_vocab
29 |
30 | # encode the input instance
31 | # encoder.graph_hidden [batch, node_num, vsize]
32 | # encoder.graph_cell [batch, node_num, vsize]
33 | self.encoder = graph_encoder_utils.GraphEncoder(
34 | word_vocab = word_vocab,
35 | edge_label_vocab = Edgelabel_vocab,
36 | char_vocab = char_vocab,
37 | is_training = is_training, options = options)
38 |
39 | # ============== Choices of attention memory ================
40 | if options.attention_type == 'hidden':
41 | self.encoder_dim = options.neighbor_vector_dim
42 | self.encoder_states = self.encoder.graph_hiddens
43 | elif options.attention_type == 'hidden_cell':
44 | self.encoder_dim = options.neighbor_vector_dim * 2
45 | self.encoder_states = tf.concat([self.encoder.graph_hiddens, self.encoder.graph_cells], 2)
46 | elif options.attention_type == 'hidden_embed':
47 | self.encoder_dim = options.neighbor_vector_dim + self.encoder.input_dim
48 | self.encoder_states = tf.concat([self.encoder.graph_hiddens, self.encoder.node_representations], 2)
49 | else:
50 | assert False, '%s not supported yet' % options.attention_type
51 |
52 | # ============== Choices of initializing decoder state =============
53 | if options.way_init_decoder == 'zero':
54 | new_c = tf.zeros([self.encoder.batch_size, options.gen_hidden_size])
55 | new_h = tf.zeros([self.encoder.batch_size, options.gen_hidden_size])
56 | elif options.way_init_decoder == 'all':
57 | new_c = tf.reduce_sum(self.encoder.graph_cells, axis=1)
58 | new_h = tf.reduce_sum(self.encoder.graph_hiddens, axis=1)
59 | elif options.way_init_decoder == 'root':
60 | new_c = self.encoder.graph_cells[:,0,:]
61 | new_h = self.encoder.graph_hiddens[:,0,:]
62 | else:
63 | assert False, 'way to initial decoder (%s) not supported' % options.way_init_decoder
64 | self.init_decoder_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
65 |
66 | # prepare AMR-side input for decoder
67 | self.nodes = self.encoder.passage_nodes
68 | self.nodes_num = self.encoder.passage_nodes_size
69 | if options.with_char:
70 | self.nodes_chars = self.encoder.passage_nodes_chars
71 | self.nodes_chars_num = self.encoder.passage_nodes_chars_size
72 | self.nodes_mask = self.encoder.passage_nodes_mask
73 |
74 | self.in_neigh_indices = self.encoder.passage_in_neighbor_indices
75 | self.in_neigh_edges = self.encoder.passage_in_neighbor_edges
76 | self.in_neigh_mask = self.encoder.passage_in_neighbor_mask
77 |
78 | self.out_neigh_indices = self.encoder.passage_out_neighbor_indices
79 | self.out_neigh_edges = self.encoder.passage_out_neighbor_edges
80 | self.out_neigh_mask = self.encoder.passage_out_neighbor_mask
81 |
82 | self.create_placeholders(options)
83 |
84 | loss_weights = tf.sequence_mask(self.answer_len, options.max_answer_len, dtype=tf.float32) # [batch_size, gen_steps]
85 |
86 | with variable_scope.variable_scope("generator"):
87 | # create generator
88 | self.generator = generator_utils.CovCopyAttenGen(self, options, word_vocab)
89 | # calculate encoder_features
90 | self.encoder_features = self.generator.calculate_encoder_features(self.encoder_states, self.encoder_dim)
91 |
92 | if mode == 'decode':
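    |             # Decode mode exposes single-step placeholders (previous context, coverage and word)
    |             # so the Python-side greedy/beam search in G2S_beam_decoder can drive the decoder
    |             # one step at a time; no training op is built in this mode.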
93 | self.context_t_1 = tf.placeholder(tf.float32, [None, self.encoder_dim], name='context_t_1') # [batch_size, encoder_dim]
94 |             self.coverage_t_1 = tf.placeholder(tf.float32, [None, None], name='coverage_t_1') # [batch_size, passage_len]
95 | self.word_t = tf.placeholder(tf.int32, [None], name='word_t') # [batch_size]
96 |
97 | (self.state_t, self.context_t, self.coverage_t, self.attn_dist_t, self.p_gen_t, self.ouput_t,
98 | self.topk_log_probs, self.topk_ids, self.greedy_prediction, self.multinomial_prediction) = self.generator.decode_mode(
99 | word_vocab, options.beam_size, self.init_decoder_state, self.context_t_1, self.coverage_t_1, self.word_t,
100 | self.encoder_states, self.encoder_features, self.nodes, self.nodes_mask)
101 |             # not building training op for this mode
102 | return
103 | elif mode == 'evaluate_bleu':
104 | _, _, self.greedy_words = self.generator.train_mode(word_vocab, self.encoder_dim, self.encoder_states, self.encoder_features,
105 | self.nodes, self.nodes_mask, self.init_decoder_state,
106 | self.answer_inp, self.answer_ref, loss_weights, mode_gen='greedy')
107 |             # not building training op for this mode
108 | return
109 | elif mode in ('ce_train', 'evaluate', ):
110 | self.accu, self.loss, _ = self.generator.train_mode(word_vocab, self.encoder_dim, self.encoder_states, self.encoder_features,
111 | self.nodes, self.nodes_mask, self.init_decoder_state,
112 | self.answer_inp, self.answer_ref, loss_weights, mode_gen='ce_loss')
113 |             if mode == 'evaluate': return # not building training op for evaluation
114 | elif mode == 'rl_train':
115 | _, self.loss, _ = self.generator.train_mode(word_vocab, self.encoder_dim, self.encoder_states,self.encoder_features,
116 | self.nodes, self.nodes_mask, self.init_decoder_state,
117 | self.answer_inp, self.answer_ref, loss_weights, mode_gen='rl_loss')
118 |
119 | tf.get_variable_scope().reuse_variables()
120 |
121 | _, _, self.sampled_words = self.generator.train_mode(word_vocab, self.encoder_dim, self.encoder_states,self.encoder_features,
122 | self.nodes, self.nodes_mask, self.init_decoder_state,
123 | self.answer_inp, self.answer_ref, None, mode_gen='sample')
124 |
125 | _, _, self.greedy_words = self.generator.train_mode(word_vocab, self.encoder_dim, self.encoder_states,self.encoder_features,
126 | self.nodes, self.nodes_mask, self.init_decoder_state,
127 | self.answer_inp, self.answer_ref, None, mode_gen='greedy')
128 |
129 |
130 | if options.optimize_type == 'adadelta':
131 | clipper = 50
132 | optimizer = tf.train.AdadeltaOptimizer(learning_rate=options.learning_rate)
133 | tvars = tf.trainable_variables()
134 | if options.lambda_l2>0.0:
135 | l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tvars if v.get_shape().ndims > 1])
136 | self.loss = self.loss + options.lambda_l2 * l2_loss
137 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars), clipper)
138 | self.train_op = optimizer.apply_gradients(zip(grads, tvars))
139 | elif options.optimize_type == 'adam':
140 | clipper = 50
141 | optimizer = tf.train.AdamOptimizer(learning_rate=options.learning_rate)
142 | tvars = tf.trainable_variables()
143 | if options.lambda_l2>0.0:
144 | l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tvars if v.get_shape().ndims > 1])
145 | self.loss = self.loss + options.lambda_l2 * l2_loss
146 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars), clipper)
147 | self.train_op = optimizer.apply_gradients(zip(grads, tvars))
148 |
149 | extra_train_ops = []
150 | train_ops = [self.train_op] + extra_train_ops
151 | self.train_op = tf.group(*train_ops)
152 |
153 | def create_placeholders(self, options):
154 | # build placeholder for answer
155 | self.answer_ref = tf.placeholder(tf.int32, [None, options.max_answer_len], name="answer_ref") # [batch_size, gen_steps]
156 | self.answer_inp = tf.placeholder(tf.int32, [None, options.max_answer_len], name="answer_inp") # [batch_size, gen_steps]
157 | self.answer_len = tf.placeholder(tf.int32, [None], name="answer_len") # [batch_size]
158 |
159 | # build placeholder for reinforcement learning
160 | self.reward = tf.placeholder(tf.float32, [None], name="reward")
161 |
162 |
163 | def run_greedy(self, sess, batch, options):
164 | feed_dict = self.run_encoder(sess, batch, options, only_feed_dict=True) # reuse this function to construct feed_dict
165 | feed_dict[self.answer_inp] = batch.sent_inp
166 | return sess.run(self.greedy_words, feed_dict)
167 |
168 |
169 | def run_ce_training(self, sess, batch, options, only_eval=False):
170 | feed_dict = self.run_encoder(sess, batch, options, only_feed_dict=True) # reuse this function to construct feed_dict
171 | feed_dict[self.answer_inp] = batch.sent_inp
172 | feed_dict[self.answer_ref] = batch.sent_out
173 | feed_dict[self.answer_len] = batch.sent_len
174 |
175 | if only_eval:
176 | return sess.run([self.accu, self.loss], feed_dict)
177 | else:
178 | return sess.run([self.train_op, self.loss], feed_dict)[1]
179 |
180 |
181 | def run_rl_training_subsample(self, sess, batch, options):
182 | flipp = options.flipp if options.__dict__.has_key('flipp') else 0.1
183 |
184 | # make feed_dict
185 | feed_dict = self.run_encoder(sess, batch, options, only_feed_dict=True)
186 | feed_dict[self.answer_inp] = batch.sent_inp
187 |
188 | # get greedy and gold outputs
189 | greedy_output = sess.run(self.greedy_words, feed_dict) # [batch, sent_len]
190 | greedy_output = greedy_output.tolist()
191 | gold_output = batch.sent_out.tolist()
192 |
193 | # generate sample_output by flipping coins
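    |         # Each gold token is replaced by the greedy model output with probability flipp, giving
    |         # a sampled sequence near the reference for the reward comparison below.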
194 | sample_output = np.copy(batch.sent_out)
195 | for i in range(batch.sent_out.shape[0]):
196 |             seq_len = min(options.max_answer_len, batch.sent_len[i]-1) # don't change the stop token '</s>'
197 | for j in range(seq_len):
198 | if greedy_output[i][j] != 0 and random.random() < flipp:
199 | sample_output[i,j] = greedy_output[i][j]
200 | sample_output = sample_output.tolist()
201 |
202 |         st_wid = self.word_vocab.getIndex('<s>')
203 |         en_wid = self.word_vocab.getIndex('</s>')
204 |
205 | rl_inputs = []
206 | rl_outputs = []
207 | rl_input_lengths = []
208 | reward = []
209 | for i, (sout,gout) in enumerate(zip(sample_output,greedy_output)):
210 | sout, slex = self.word_vocab.getLexical(sout)
211 | gout, glex = self.word_vocab.getLexical(gout)
212 | rl_inputs.append([st_wid,]+sout[:-1])
213 | rl_outputs.append(sout)
214 | rl_input_lengths.append(len(sout))
215 | _, ref_lex = self.word_vocab.getLexical(gold_output[i])
216 | slst = slex.split()
217 | glst = glex.split()
218 | rlst = ref_lex.split()
219 | if options.reward_type == 'bleu':
220 | r = sentence_bleu([rlst], slst, smoothing_function=cc.method3)
221 | b = sentence_bleu([rlst], glst, smoothing_function=cc.method3)
222 | elif options.reward_type == 'rouge':
223 | r = sentence_rouge(ref_lex, slex, smoothing_function=cc.method3)
224 | b = sentence_rouge(ref_lex, glex, smoothing_function=cc.method3)
225 | reward.append(r-b)
226 | #print('Ref: {}'.format(ref_lex.encode('utf-8','ignore')))
227 | #print('Sample: {}'.format(slex.encode('utf-8','ignore')))
228 | #print('Greedy: {}'.format(glex.encode('utf-8','ignore')))
229 | #print('R-B: {}'.format(reward[-1]))
230 | #print('-----')
231 |
232 | rl_inputs = padding_utils.pad_2d_vals(rl_inputs, len(rl_inputs), self.options.max_answer_len)
233 | rl_outputs = padding_utils.pad_2d_vals(rl_outputs, len(rl_outputs), self.options.max_answer_len)
234 | rl_input_lengths = np.array(rl_input_lengths, dtype=np.int32)
235 | reward = np.array(reward, dtype=np.float32)
236 | assert rl_inputs.shape == rl_outputs.shape
237 |
238 | feed_dict = self.run_encoder(sess, batch, options, only_feed_dict=True)
239 | feed_dict[self.reward] = reward
240 | feed_dict[self.answer_inp] = rl_inputs
241 | feed_dict[self.answer_ref] = rl_outputs
242 | feed_dict[self.answer_len] = rl_input_lengths
243 |
244 | _, loss = sess.run([self.train_op, self.loss], feed_dict)
245 | return loss
246 |
247 |
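    |     # Same idea as run_rl_training_subsample, except the sampled sequence comes from the model's
    |     # own multinomial sampler (self.sampled_words) rather than from flipping gold tokens.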
248 | def run_rl_training_model(self, sess, batch, options):
249 | feed_dict = self.run_encoder(sess, batch, options, only_feed_dict=True)
250 | feed_dict[self.answer_inp] = batch.sent_inp
251 |
252 | sample_output, greedy_output = sess.run(
253 | [self.sampled_words, self.greedy_words], feed_dict)
254 |
255 | sample_output = sample_output.tolist()
256 | greedy_output = greedy_output.tolist()
257 |
258 |         st_wid = self.word_vocab.getIndex('<s>')
259 |         en_wid = self.word_vocab.getIndex('</s>')
260 |
261 | rl_inputs = []
262 | rl_outputs = []
263 | rl_input_lengths = []
264 | reward = []
265 | for i, (sout,gout) in enumerate(zip(sample_output,greedy_output)):
266 | sout, slex = self.word_vocab.getLexical(sout)
267 | gout, glex = self.word_vocab.getLexical(gout)
268 | rl_inputs.append([st_wid,]+sout[:-1])
269 | rl_outputs.append(sout)
270 | rl_input_lengths.append(len(sout))
271 | ref_lex = batch.instances[i][-1]
272 | #r = metric_utils.evaluate_captions([ref_lex,],[slex,])
273 | #b = metric_utils.evaluate_captions([ref_lex,],[glex,])
274 | slst = slex.split()
275 | glst = glex.split()
276 | rlst = ref_lex.split()
277 | if options.reward_type == 'bleu':
278 | r = sentence_bleu([rlst], slst, smoothing_function=cc.method3)
279 | b = sentence_bleu([rlst], glst, smoothing_function=cc.method3)
280 | elif options.reward_type == 'rouge':
281 | r = sentence_rouge(ref_lex, slex, smoothing_function=cc.method3)
282 | b = sentence_rouge(ref_lex, glex, smoothing_function=cc.method3)
283 | reward.append(r-b)
284 | #print('Ref: {}'.format(ref_lex.encode('utf-8','ignore')))
285 | #print('Sample: {}'.format(slex.encode('utf-8','ignore')))
286 | #print('Greedy: {}'.format(glex.encode('utf-8','ignore')))
287 | #print('R-B: {}'.format(reward[-1]))
288 | #print('-----')
289 |
290 | rl_inputs = padding_utils.pad_2d_vals(rl_inputs, len(rl_inputs), self.options.max_answer_len)
291 | rl_outputs = padding_utils.pad_2d_vals(rl_outputs, len(rl_outputs), self.options.max_answer_len)
292 | rl_input_lengths = np.array(rl_input_lengths, dtype=np.int32)
293 | reward = np.array(reward, dtype=np.float32)
294 | assert rl_inputs.shape == rl_outputs.shape
295 |
296 | feed_dict = self.run_encoder(sess, batch, options, only_feed_dict=True)
297 | feed_dict[self.reward] = reward
298 | feed_dict[self.answer_inp] = rl_inputs
299 |         feed_dict[self.answer_ref] = rl_outputs
300 | feed_dict[self.answer_len] = rl_input_lengths
301 |
302 | _, loss = sess.run([self.train_op, self.loss], feed_dict)
303 | return loss
304 |
305 | def run_encoder(self, sess, batch, options, only_feed_dict=False):
306 | feed_dict = {}
307 | feed_dict[self.nodes] = batch.nodes
308 | feed_dict[self.nodes_num] = batch.node_num
309 | if options.with_char:
310 | feed_dict[self.nodes_chars] = batch.nodes_chars
311 | feed_dict[self.nodes_chars_num] = batch.nodes_chars_num
312 |
313 | feed_dict[self.in_neigh_indices] = batch.in_neigh_indices
314 | feed_dict[self.in_neigh_edges] = batch.in_neigh_edges
315 | feed_dict[self.in_neigh_mask] = batch.in_neigh_mask
316 |
317 | feed_dict[self.out_neigh_indices] = batch.out_neigh_indices
318 | feed_dict[self.out_neigh_edges] = batch.out_neigh_edges
319 | feed_dict[self.out_neigh_mask] = batch.out_neigh_mask
320 |
321 | if only_feed_dict:
322 | return feed_dict
323 |
324 | return sess.run([self.encoder_states, self.encoder_features, self.nodes, self.nodes_mask, self.init_decoder_state],
325 | feed_dict)
326 |
327 | if __name__ == '__main__':
328 | summary = " Tokyo is the one of the biggest city in the world."
329 | reference = "The capital of Japan, Tokyo, is the center of Japanese economy."
330 |
331 |
--------------------------------------------------------------------------------
/src_g2s/G2S_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function
3 | import argparse
4 | import os
5 | import sys
6 | import time
7 | import numpy as np
8 | import codecs
9 |
10 | from vocab_utils import Vocab
11 | import namespace_utils
12 | import G2S_data_stream
13 | from G2S_model_graph import ModelGraph
14 |
15 | FLAGS = None
16 | import tensorflow as tf
17 | tf.logging.set_verbosity(tf.logging.ERROR) # DEBUG, INFO, WARN, ERROR, and FATAL
18 |
19 | from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu
20 | cc = SmoothingFunction()
21 |
22 | import metric_utils
23 |
24 | import platform
25 | def get_machine_name():
26 | return platform.node()
27 |
28 | def vec2string(val):
29 | result = ""
30 | for v in val:
31 | result += " {}".format(v)
32 | return result.strip()
33 |
34 |
35 | def softmax(x):
36 |     """Compute softmax values for each set of scores in x."""
37 | e_x = np.exp(x - np.max(x))
38 | return e_x / e_x.sum()
39 |
40 |
41 | def document_bleu(vocab, gen, ref, suffix=''):
42 | genlex = [vocab.getLexical(x)[1] for x in gen]
43 | reflex = [[vocab.getLexical(x)[1],] for x in ref]
44 | genlst = [x.split() for x in genlex]
45 | reflst = [[x[0].split()] for x in reflex]
46 | f = codecs.open('gen.txt'+suffix,'w','utf-8')
47 | for line in genlex:
48 | print(line, end='\n', file=f)
49 | f.close()
50 | f = codecs.open('ref.txt'+suffix,'w','utf-8')
51 | for line in reflex:
52 | print(line[0], end='\n', file=f)
53 | f.close()
54 | return corpus_bleu(reflst, genlst, smoothing_function=cc.method3)
55 |
56 |
57 | def evaluate(sess, valid_graph, devDataStream, options=None, suffix=''):
58 | devDataStream.reset()
59 | gen = []
60 | ref = []
61 | dev_loss = 0.0
62 | dev_right = 0.0
63 | dev_total = 0.0
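    |     # 'evaluate' accumulates cross-entropy loss and token accuracy, while 'evaluate_bleu'
    |     # greedily decodes each batch and scores corpus BLEU via document_bleu().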
64 | for batch_index in xrange(devDataStream.get_num_batch()): # for each batch
65 | cur_batch = devDataStream.get_batch(batch_index)
66 | if valid_graph.mode == 'evaluate':
67 | accu_value, loss_value = valid_graph.run_ce_training(sess, cur_batch, options, only_eval=True)
68 | dev_loss += loss_value
69 | dev_right += accu_value
70 | dev_total += np.sum(cur_batch.sent_len)
71 | elif valid_graph.mode == 'evaluate_bleu':
72 | gen.extend(valid_graph.run_greedy(sess, cur_batch, options).tolist())
73 | ref.extend(cur_batch.sent_out.tolist())
74 | else:
75 | assert False
76 |
77 | if valid_graph.mode == 'evaluate':
78 | return {'dev_loss':dev_loss, 'dev_accu':1.0*dev_right/dev_total, 'dev_right':dev_right, 'dev_total':dev_total, }
79 | else:
80 | return {'dev_bleu':document_bleu(valid_graph.word_vocab,gen,ref,suffix), }
81 |
82 |
83 |
84 | def main(_):
85 | print('Configurations:')
86 | print(FLAGS)
87 |
88 | log_dir = FLAGS.model_dir
89 | if not os.path.exists(log_dir):
90 | os.makedirs(log_dir)
91 |
92 | path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix)
93 | log_file_path = path_prefix + ".log"
94 | print('Log file path: {}'.format(log_file_path))
95 | log_file = open(log_file_path, 'wt')
96 | log_file.write("{}\n".format(FLAGS))
97 | log_file.flush()
98 |
99 | # save configuration
100 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
101 |
102 | print('Loading train set.')
103 | trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file(FLAGS.train_path)
104 | print('Number of training samples: {}'.format(len(trainset)))
105 |
106 | print('Loading dev set.')
107 | devset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file(FLAGS.test_path)
108 | print('Number of dev samples: {}'.format(len(devset)))
109 |
110 | if FLAGS.finetune_path != "":
111 | print('Loading finetune set.')
112 | ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = G2S_data_stream.read_amr_file(FLAGS.finetune_path)
113 | print('Number of finetune samples: {}'.format(len(ftset)))
114 | else:
115 | ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = (None, 0, 0, 0, 0)
116 |
117 | max_node = max(trn_node, tst_node, ft_node)
118 | max_in_neigh = max(trn_in_neigh, tst_in_neigh, ft_in_neigh)
119 | max_out_neigh = max(trn_out_neigh, tst_out_neigh, ft_out_neigh)
120 | max_sent = max(trn_sent, tst_sent, ft_sent)
121 | print('Max node number: {}, while max allowed is {}'.format(max_node, FLAGS.max_node_num))
122 | print('Max parent number: {}, truncated to {}'.format(max_in_neigh, FLAGS.max_in_neigh_num))
123 | print('Max children number: {}, truncated to {}'.format(max_out_neigh, FLAGS.max_out_neigh_num))
124 | print('Max answer length: {}, truncated to {}'.format(max_sent, FLAGS.max_answer_len))
125 |
126 | word_vocab = None
127 | char_vocab = None
128 | edgelabel_vocab = None
129 | has_pretrained_model = False
130 | best_path = path_prefix + ".best.model"
131 | if os.path.exists(best_path + ".index"):
132 | has_pretrained_model = True
133 | print('!!Existing pretrained model. Loading vocabs.')
134 | word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
135 | print('word_vocab: {}'.format(word_vocab.word_vecs.shape))
136 | char_vocab = None
137 | if FLAGS.with_char:
138 | char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
139 | print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
140 | edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2')
141 | else:
142 | print('Collecting vocabs.')
143 | (allWords, allChars, allEdgelabels) = G2S_data_stream.collect_vocabs(trainset)
144 | print('Number of words: {}'.format(len(allWords)))
145 | print('Number of allChars: {}'.format(len(allChars)))
146 | print('Number of allEdgelabels: {}'.format(len(allEdgelabels)))
147 |
148 | word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')
149 | char_vocab = None
150 | if FLAGS.with_char:
151 | char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build')
152 | char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
153 | edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build')
154 | edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab")
155 |
156 | print('word vocab size {}'.format(word_vocab.vocab_size))
157 | sys.stdout.flush()
158 |
159 | print('Build DataStream ... ')
160 | trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS,
161 | isShuffle=True, isLoop=True, isSort=True)
162 |
163 | devDataStream = G2S_data_stream.G2SDataStream(devset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS,
164 | isShuffle=False, isLoop=False, isSort=True)
165 | print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
166 | print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
167 | print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
168 | print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
169 | if ftset != None:
170 | ftDataStream = G2S_data_stream.G2SDataStream(ftset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS,
171 | isShuffle=True, isLoop=True, isSort=True)
172 | print('Number of instances in ftDataStream: {}'.format(ftDataStream.get_num_instance()))
173 | print('Number of batches in ftDataStream: {}'.format(ftDataStream.get_num_batch()))
174 |
175 | sys.stdout.flush()
176 |
177 | # initialize the best bleu and accu scores for current training session
178 | best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0
179 | best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0
180 | if best_accu > 0.0:
181 | print('With initial dev accuracy {}'.format(best_accu))
182 | if best_bleu > 0.0:
183 | print('With initial dev BLEU score {}'.format(best_bleu))
184 |
185 | init_scale = 0.01
186 | with tf.Graph().as_default():
187 | initializer = tf.random_uniform_initializer(-init_scale, init_scale)
188 | with tf.name_scope("Train"):
189 | with tf.variable_scope("Model", reuse=None, initializer=initializer):
190 | train_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab,
191 | char_vocab=char_vocab, options=FLAGS, mode=FLAGS.mode)
192 |
193 | assert FLAGS.mode in ('ce_train', 'rl_train', )
194 | valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu'
195 |
196 | with tf.name_scope("Valid"):
197 | with tf.variable_scope("Model", reuse=True, initializer=initializer):
198 | valid_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab,
199 | char_vocab=char_vocab, options=FLAGS, mode=valid_mode)
200 |
201 | initializer = tf.global_variables_initializer()
202 |
203 | vars_ = {}
204 | for var in tf.all_variables():
205 | if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
206 | if not var.name.startswith("Model"): continue
207 | vars_[var.name.split(":")[0]] = var
208 | print(var)
209 | saver = tf.train.Saver(vars_)
210 |
211 | sess = tf.Session()
212 | sess.run(initializer)
213 | if has_pretrained_model:
214 | print("Restoring model from " + best_path)
215 | saver.restore(sess, best_path)
216 | print("DONE!")
217 |
218 | if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
219 | print("Getting BLEU score for the model")
220 | sys.stdout.flush()
221 | best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu']
222 | FLAGS.best_bleu = best_bleu
223 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
224 | print('BLEU = %.4f' % best_bleu)
225 | sys.stdout.flush()
226 | log_file.write('BLEU = %.4f\n' % best_bleu)
227 | if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
228 | print("Getting ACCU score for the model")
229 | best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu']
230 | FLAGS.best_accu = best_accu
231 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
232 | print('ACCU = %.4f' % best_accu)
233 | log_file.write('ACCU = %.4f\n' % best_accu)
234 |
235 | print('Start the training loop.')
236 | train_size = trainDataStream.get_num_batch()
237 | max_steps = train_size * FLAGS.max_epochs
238 | total_loss = 0.0
239 | start_time = time.time()
240 | for step in xrange(max_steps):
241 | cur_batch = trainDataStream.nextBatch()
242 | if FLAGS.mode == 'rl_train':
243 | loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS)
244 | elif FLAGS.mode == 'ce_train':
245 | loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
246 | total_loss += loss_value
247 |
248 | if step % 100==0:
249 | print('{} '.format(step), end="")
250 | sys.stdout.flush()
251 |
252 | # Save a checkpoint and evaluate the model periodically.
253 | if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or \
254 | (trainDataStream.get_num_batch() > 10000 and (step+1)%2000 == 0):
255 | print()
256 | duration = time.time() - start_time
257 | print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
258 | log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
259 | log_file.flush()
260 | sys.stdout.flush()
261 | total_loss = 0.0
262 |
263 | if ftset != None:
264 | best_accu, best_bleu = fine_tune(sess, saver, FLAGS, log_file,
265 | ftDataStream, devDataStream, train_graph, valid_graph, path_prefix, best_accu, best_bleu)
266 | else:
267 | best_accu, best_bleu = validate_and_save(sess, saver, FLAGS, log_file,
268 | devDataStream, valid_graph, path_prefix, best_accu, best_bleu)
269 | start_time = time.time()
270 |
271 | log_file.close()
272 |
273 |
274 | def validate_and_save(sess, saver, FLAGS, log_file,
275 | devDataStream, valid_graph, path_prefix, best_accu, best_bleu):
276 | best_path = path_prefix + ".best.model"
277 | start_time = time.time()
278 | print('Validation Data Eval:')
279 | res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS)
280 | if valid_graph.mode == 'evaluate':
281 | dev_loss = res_dict['dev_loss']
282 | dev_accu = res_dict['dev_accu']
283 | dev_right = int(res_dict['dev_right'])
284 | dev_total = int(res_dict['dev_total'])
285 | print('Dev loss = %.4f' % dev_loss)
286 | log_file.write('Dev loss = %.4f\n' % dev_loss)
287 | print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
288 | log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total))
289 | log_file.flush()
290 | if best_accu < dev_accu:
291 | print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
292 | saver.save(sess, best_path)
293 | best_accu = dev_accu
294 | FLAGS.best_accu = dev_accu
295 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
296 | else:
297 | dev_bleu = res_dict['dev_bleu']
298 | print('Dev bleu = %.4f' % dev_bleu)
299 | log_file.write('Dev bleu = %.4f\n' % dev_bleu)
300 | log_file.flush()
301 | if best_bleu < dev_bleu:
302 | print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(best_bleu, dev_bleu))
303 | saver.save(sess, best_path)
304 | best_bleu = dev_bleu
305 | FLAGS.best_bleu = dev_bleu
306 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
307 | duration = time.time() - start_time
308 | print('Duration %.3f sec' % (duration))
309 | sys.stdout.flush()
310 | return best_accu, best_bleu
311 |
312 |
313 | def fine_tune(sess, saver, FLAGS, log_file,
314 | ftDataStream, devDataStream, train_graph, valid_graph, path_prefix, best_accu, best_bleu):
315 | print('=====Start the fine tuning.')
316 | sys.stdout.flush()
317 | max_steps = ftDataStream.get_num_batch() * 1
318 | best_path = path_prefix + ".best.model"
319 | total_loss = 0.0
320 | start_time = time.time()
321 | for step in xrange(max_steps):
322 | cur_batch = ftDataStream.nextBatch()
323 | if FLAGS.mode == 'rl_train':
324 | loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS)
325 | elif FLAGS.mode == 'ce_train':
326 | loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
327 | total_loss += loss_value
328 |
329 | if step % 100==0:
330 | print('{} '.format(step), end="")
331 | sys.stdout.flush()
332 |
333 | # Save a checkpoint and evaluate the model periodically.
334 | if (step + 1) % ftDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
335 | print()
336 | duration = time.time() - start_time
337 | print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
338 | sys.stdout.flush()
339 | log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
340 | log_file.flush()
341 | best_accu, best_bleu = validate_and_save(sess, saver, FLAGS, log_file,
342 | devDataStream, valid_graph, path_prefix, best_accu, best_bleu)
343 | total_loss = 0.0
344 | start_time = time.time()
345 |
346 | print('=====End the fine tuning.')
347 | sys.stdout.flush()
348 | return best_accu, best_bleu
349 |
350 |
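# Backfill default values for options that older config files may not define.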
351 | def enrich_options(options):
352 | if not options.__dict__.has_key("finetune_path"):
353 | options.__dict__["finetune_path"] = ""
354 |
355 | if not options.__dict__.has_key("CE_loss"):
356 | options.__dict__["CE_loss"] = False
357 |
358 | if not options.__dict__.has_key("reward_type"):
359 | options.__dict__["reward_type"] = "bleu"
360 |
361 | if not options.__dict__.has_key("way_init_decoder"):
362 | options.__dict__["way_init_decoder"] = 'zero'
363 |
364 | return options
365 |
366 |
367 | if __name__ == '__main__':
368 | parser = argparse.ArgumentParser()
369 | parser.add_argument('--config_path', type=str, help='Configuration file.')
370 |
371 | #os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
372 | #os.environ["CUDA_VISIBLE_DEVICES"]="2"
373 |
374 | print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])
375 | FLAGS, unparsed = parser.parse_known_args()
376 |
377 |
378 | if FLAGS.config_path is not None:
379 | print('Loading the configuration from ' + FLAGS.config_path)
380 | FLAGS = namespace_utils.load_namespace(FLAGS.config_path)
381 |
382 | FLAGS = enrich_options(FLAGS)
383 |
384 | sys.stdout.flush()
385 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
386 |
--------------------------------------------------------------------------------
/src_s2s/NP2P_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function
3 | import argparse
4 | import os
5 | import sys
6 | import time
7 | import numpy as np
8 | import codecs
9 |
10 | from vocab_utils import Vocab
11 | import namespace_utils
12 | import NP2P_data_stream
13 | from NP2P_model_graph import ModelGraph
14 |
15 | FLAGS = None
16 | import tensorflow as tf
17 | tf.logging.set_verbosity(tf.logging.ERROR) # DEBUG, INFO, WARN, ERROR, and FATAL
18 |
19 | from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu
20 | cc = SmoothingFunction()
21 |
22 | import metric_utils
23 |
24 | import platform
25 | def get_machine_name():
26 | return platform.node()
27 |
28 | def vec2string(val):
29 | result = ""
30 | for v in val:
31 | result += " {}".format(v)
32 | return result.strip()
33 |
34 |
35 | def softmax(x):
36 | """Compute softmax values for each sets of scores in x."""
37 | e_x = np.exp(x - np.max(x))
38 | return e_x / e_x.sum()
39 |
40 |
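# Corpus-level BLEU between generated and reference token sequences; also dumps both sides to gen.txt<suffix> / ref.txt<suffix> for manual inspection.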
41 | def document_bleu(vocab, gen, ref, suffix=''):
42 | genlex = [vocab.getLexical(x)[1] for x in gen]
43 | reflex = [[vocab.getLexical(x)[1],] for x in ref]
44 | #return metric_utils.evaluate_captions(genlex,reflex)
45 | genlst = [x.split() for x in genlex]
46 | reflst = [[x[0].split()] for x in reflex]
47 | f = codecs.open('gen.txt'+suffix,'w','utf-8')
48 | for line in genlex:
49 | print(line, end='\n', file=f)
50 | f.close()
51 | f = codecs.open('ref.txt'+suffix,'w','utf-8')
52 | for line in reflex:
53 | print(line[0], end='\n', file=f)
54 | f.close()
55 | return corpus_bleu(reflst, genlst, smoothing_function=cc.method3)
56 |
57 |
58 | def evaluate(sess, valid_graph, devDataStream, options=None, suffix=''):
59 | devDataStream.reset()
60 | gen = []
61 | ref = []
62 | dev_loss = 0.0
63 | dev_right = 0.0
64 | dev_total = 0.0
65 | for batch_index in xrange(devDataStream.get_num_batch()): # for each batch
66 | cur_batch = devDataStream.get_batch(batch_index)
67 | if valid_graph.mode == 'evaluate':
68 | accu_value, loss_value = valid_graph.run_ce_training(sess, cur_batch, options, only_eval=True)
69 | dev_loss += loss_value
70 | dev_right += accu_value
71 | dev_total += np.sum(cur_batch.answer_lengths)
72 | elif valid_graph.mode == 'evaluate_bleu':
73 | gen.extend(valid_graph.run_greedy(sess, cur_batch, options).tolist())
74 | ref.extend(cur_batch.in_answer_words.tolist())
75 | else:
76 | assert False
77 |
78 | if valid_graph.mode == 'evaluate':
79 | return {'dev_loss':dev_loss, 'dev_accu':1.0*dev_right/dev_total, 'dev_right':dev_right, 'dev_total':dev_total, }
80 | else:
81 | return {'dev_bleu':document_bleu(valid_graph.dec_word_vocab,gen,ref,suffix), }
82 |
83 |
84 |
85 | def main(_):
86 | print('Configurations:')
87 | print(FLAGS)
88 |
89 | log_dir = FLAGS.model_dir
90 | if not os.path.exists(log_dir):
91 | os.makedirs(log_dir)
92 |
93 | path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix)
94 | log_file_path = path_prefix + ".log"
95 | print('Log file path: {}'.format(log_file_path))
96 | log_file = open(log_file_path, 'wt')
97 | log_file.write("{}\n".format(FLAGS))
98 | log_file.flush()
99 |
100 | # save configuration
101 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
102 |
103 | print('Loading training set.')
104 | trainset, train_ans_len = NP2P_data_stream.read_all_GenerationDatasets(FLAGS.train_path, isLower=FLAGS.isLower)
105 | print('Number of training samples: {}'.format(len(trainset)))
106 |
107 | print('Loading dev set.')
108 | devset, dev_ans_len = NP2P_data_stream.read_all_GenerationDatasets(FLAGS.test_path, isLower=FLAGS.isLower)
109 | print('Number of dev samples: {}'.format(len(devset)))
110 |
111 | if FLAGS.finetune_path != "":
112 | print('Loading finetune set.')
113 | ftset, ft_ans_len = NP2P_data_stream.read_all_GenerationDatasets(FLAGS.ft_path, isLower=FLAGS.isLower)
114 | print('Number of finetune samples: {}'.format(len(ftset)))
115 | else:
116 | ftset, ft_ans_len = (None, 0)
117 |
118 | max_actual_len = max(train_ans_len, ft_ans_len, dev_ans_len)
119 | print('Max answer length: {}, truncated to {}'.format(max_actual_len, FLAGS.max_answer_len))
120 |
121 | enc_word_vocab = None
122 | dec_word_vocab = None
123 | char_vocab = None
124 | has_pretrained_model = False
125 | best_path = path_prefix + ".best.model"
126 | if os.path.exists(best_path + ".index"):
127 | has_pretrained_model = True
128 | print('!!Existing pretrained model. Loading vocabs.')
129 | if FLAGS.with_word:
130 | enc_word_vocab = Vocab(FLAGS.enc_word_vec_path, fileformat='txt2')
131 | dec_word_vocab = Vocab(FLAGS.dec_word_vec_path, fileformat='txt2')
132 | print('Encoder word vocab: {}'.format(enc_word_vocab.word_vecs.shape))
133 | print('Decoder word vocab: {}'.format(dec_word_vocab.word_vecs.shape))
134 | if FLAGS.with_char:
135 | char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2')
136 | print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
137 | else:
138 | print('Collecting vocabs.')
139 | (allWords, allChars) = NP2P_data_stream.collect_vocabs(trainset)
140 | print('Number of words: {}'.format(len(allWords)))
141 | print('Number of allChars: {}'.format(len(allChars)))
142 |
143 | if FLAGS.with_word:
144 | enc_word_vocab = Vocab(FLAGS.enc_word_vec_path, fileformat='txt2')
145 | dec_word_vocab = Vocab(FLAGS.dec_word_vec_path, fileformat='txt2')
146 | if FLAGS.with_char:
147 | char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build')
148 | char_vocab.dump_to_txt2(path_prefix + ".char_vocab")
149 |
150 | print('Encoder word vocab size {}'.format(enc_word_vocab.vocab_size))
151 | print('Decoder word vocab size {}'.format(dec_word_vocab.vocab_size))
152 | sys.stdout.flush()
153 |
154 | print('Build DataStream ... ')
155 | trainDataStream = NP2P_data_stream.DataStream(trainset, enc_word_vocab, dec_word_vocab, char_vocab, options=FLAGS,
156 | isShuffle=True, isLoop=True, isSort=True)
157 | devDataStream = NP2P_data_stream.DataStream(devset, enc_word_vocab, dec_word_vocab, char_vocab, options=FLAGS,
158 | isShuffle=False, isLoop=False, isSort=True)
159 | print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance()))
160 | print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance()))
161 | print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch()))
162 | print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch()))
163 | if ftset != None:
164 | ftDataStream = NP2P_data_stream.DataStream(ftset, enc_word_vocab, dec_word_vocab, char_vocab, options=FLAGS,
165 | isShuffle=True, isLoop=True, isSort=True)
166 | print('Number of instances in ftDataStream: {}'.format(ftDataStream.get_num_instance()))
167 | print('Number of batches in ftDataStream: {}'.format(ftDataStream.get_num_batch()))
168 |
169 | sys.stdout.flush()
170 |
171 | init_scale = 0.01
172 | # initialize the best bleu and accu scores for current training session
173 | best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0
174 | best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0
175 | if best_accu > 0.0:
176 | print('With initial dev accuracy {}'.format(best_accu))
177 | if best_bleu > 0.0:
178 | print('With initial dev BLEU score {}'.format(best_bleu))
179 |
180 | with tf.Graph().as_default():
181 | initializer = tf.random_uniform_initializer(-init_scale, init_scale)
182 | with tf.name_scope("Train"):
183 | with tf.variable_scope("Model", reuse=None, initializer=initializer):
184 | train_graph = ModelGraph(enc_word_vocab=enc_word_vocab, dec_word_vocab=dec_word_vocab, char_vocab=char_vocab,
185 | POS_vocab=None, NER_vocab=None, options=FLAGS, mode=FLAGS.mode)
186 |
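        # ce_train is validated with token-level accuracy ('evaluate'); rl_train is validated with corpus BLEU over greedy decodes ('evaluate_bleu').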
187 | assert FLAGS.mode in ('ce_train', 'rl_train', )
188 | valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu'
189 |
190 | with tf.name_scope("Valid"):
191 | with tf.variable_scope("Model", reuse=True, initializer=initializer):
192 | valid_graph = ModelGraph(enc_word_vocab=enc_word_vocab, dec_word_vocab=dec_word_vocab, char_vocab=char_vocab,
193 | POS_vocab=None, NER_vocab=None, options=FLAGS, mode=valid_mode)
194 |
195 | initializer = tf.global_variables_initializer()
196 |
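        # Variables to checkpoint: skip the word-embedding matrices when they are fixed (fix_word_vec) and anything outside the "Model" scope.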
197 | vars_ = {}
198 | for var in tf.all_variables():
199 | if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
200 | if not var.name.startswith("Model"): continue
201 | print(var)
202 | vars_[var.name.split(":")[0]] = var
203 | saver = tf.train.Saver(vars_)
204 |
205 | sess = tf.Session()
206 | sess.run(initializer)
207 | if has_pretrained_model:
208 | print("Restoring model from " + best_path)
209 | saver.restore(sess, best_path)
210 | print("DONE!")
211 |
212 | if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001:
213 | print("Getting BLEU score for the model")
214 | best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu']
215 | FLAGS.best_bleu = best_bleu
216 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
217 | print('BLEU = %.4f' % best_bleu)
218 | log_file.write('BLEU = %.4f\n' % best_bleu)
219 | if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001:
220 | print("Getting ACCU score for the model")
221 | best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu']
222 | FLAGS.best_accu = best_accu
223 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
224 | print('ACCU = %.4f' % best_accu)
225 | log_file.write('ACCU = %.4f\n' % best_accu)
226 |
227 | print('Start the training loop.')
228 | train_size = trainDataStream.get_num_batch()
229 | max_steps = train_size * FLAGS.max_epochs
230 | total_loss = 0.0
231 | start_time = time.time()
232 | for step in xrange(max_steps):
233 | cur_batch = trainDataStream.nextBatch()
234 | if FLAGS.mode == 'rl_train':
235 | loss_value = train_graph.run_rl_training_2(sess, cur_batch, FLAGS)
236 | elif FLAGS.mode == 'ce_train':
237 | loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
238 | total_loss += loss_value
239 |
240 | if step % 100==0:
241 | print('{} '.format(step), end="")
242 | sys.stdout.flush()
243 |
244 |
245 | # Save a checkpoint and evaluate the model periodically.
246 | if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or \
247 | (trainDataStream.get_num_batch() > 10000 and (step + 1) % 2000 == 0):
248 | print()
249 | duration = time.time() - start_time
250 | print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
251 | log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
252 | log_file.flush()
253 | sys.stdout.flush()
254 | total_loss = 0.0
255 |
256 | if ftset != None:
257 | best_accu, best_bleu = fine_tune(sess, saver, FLAGS, log_file,
258 | ftDataStream, devDataStream, train_graph, valid_graph, path_prefix, best_accu, best_bleu)
259 | else:
260 | best_accu, best_bleu = validate_and_save(sess, saver, FLAGS, log_file,
261 | devDataStream, valid_graph, path_prefix, best_accu, best_bleu)
262 | start_time = time.time()
263 |
264 | log_file.close()
265 |
266 | def validate_and_save(sess, saver, FLAGS, log_file,
267 | devDataStream, valid_graph, path_prefix, best_accu, best_bleu):
268 | best_path = path_prefix + ".best.model"
269 | # Evaluate against the validation set.
270 | start_time = time.time()
271 | print('Validation Data Eval:')
272 | res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS)
273 | if valid_graph.mode == 'evaluate':
274 | dev_loss = res_dict['dev_loss']
275 | dev_accu = res_dict['dev_accu']
276 | dev_right = int(res_dict['dev_right'])
277 | dev_total = int(res_dict['dev_total'])
278 | print('Dev loss = %.4f' % dev_loss)
279 | log_file.write('Dev loss = %.4f\n' % dev_loss)
280 | print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total))
281 | log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total))
282 | log_file.flush()
283 | if best_accu < dev_accu:
284 | print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu))
285 | saver.save(sess, best_path)
286 | best_accu = dev_accu
287 | FLAGS.best_accu = dev_accu
288 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
289 | else:
290 | dev_bleu = res_dict['dev_bleu']
291 | print('Dev bleu = %.4f' % dev_bleu)
292 | log_file.write('Dev bleu = %.4f\n' % dev_bleu)
293 | log_file.flush()
294 | if best_bleu < dev_bleu:
295 | print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(best_bleu, dev_bleu))
296 | saver.save(sess, best_path)
297 | best_bleu = dev_bleu
298 | FLAGS.best_bleu = dev_bleu
299 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json")
300 | duration = time.time() - start_time
301 | print('Duration %.3f sec' % (duration))
302 | sys.stdout.flush()
303 |
304 | log_file.write('Duration %.3f sec\n' % (duration))
305 | log_file.flush()
306 | return best_accu, best_bleu
307 |
308 |
309 | def fine_tune(sess, saver, FLAGS, log_file,
310 | ftDataStream, devDataStream, train_graph, valid_graph, path_prefix, best_accu, best_bleu):
311 | print('=====Start the fine tuning.')
312 | sys.stdout.flush()
313 | train_size = ftDataStream.get_num_batch()
314 | max_steps = train_size * 3
315 | best_path = path_prefix + ".best.model"
316 | total_loss = 0.0
317 | start_time = time.time()
318 | for step in xrange(max_steps):
319 | cur_batch = ftDataStream.nextBatch()
320 | if FLAGS.mode == 'rl_train':
321 | loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS)
322 | elif FLAGS.mode == 'ce_train':
323 | loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS)
324 | total_loss += loss_value
325 |
326 | if step % 100==0:
327 | print('{} '.format(step), end="")
328 | sys.stdout.flush()
329 |
330 | # Save a checkpoint and evaluate the model periodically.
331 | if (step + 1) % ftDataStream.get_num_batch() == 0 or (step + 1) == max_steps:
332 | print()
333 | duration = time.time() - start_time
334 | print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration))
335 | sys.stdout.flush()
336 | log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration))
337 | log_file.flush()
338 | total_loss = 0.0
339 |
340 | best_accu, best_bleu = validate_and_save(sess, saver, FLAGS, log_file,
341 | devDataStream, valid_graph, path_prefix, best_accu, best_bleu)
342 | print('=====End the fine tuning.')
343 | sys.stdout.flush()
344 | return best_accu, best_bleu
345 |
346 |
347 | def enrich_options(options):
348 | if not options.__dict__.has_key("finetune_path"):
349 | options.__dict__["finetune_path"] = ""
350 |
351 | if not options.__dict__.has_key("CE_loss"):
352 | options.__dict__["CE_loss"] = False
353 |
354 | if not options.__dict__.has_key("infile_format"):
355 | options.__dict__["infile_format"] = "plain"
356 |
357 | if not options.__dict__.has_key("with_target_lattice"):
358 | options.__dict__["with_target_lattice"] = False
359 |
360 | if not options.__dict__.has_key("add_first_word_prob_for_phrase"):
361 | options.__dict__["add_first_word_prob_for_phrase"] = False
362 |
363 | if not options.__dict__.has_key("pretrain_with_max_matching"):
364 | options.__dict__["pretrain_with_max_matching"] = False
365 |
366 | if not options.__dict__.has_key("reward_type"):
367 | options.__dict__["reward_type"] = "bleu"
368 |
369 | return options
370 |
371 |
372 | if __name__ == '__main__':
373 | parser = argparse.ArgumentParser()
374 | parser.add_argument('--config_path', type=str, help='Configuration file.')
375 |
376 | #os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152
377 | #os.environ["CUDA_VISIBLE_DEVICES"]="3"
378 |
379 | print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])
380 | FLAGS, unparsed = parser.parse_known_args()
381 |
382 |
383 | if FLAGS.config_path is not None:
384 | print('Loading the configuration from ' + FLAGS.config_path)
385 | FLAGS = namespace_utils.load_namespace(FLAGS.config_path)
386 |
387 | FLAGS = enrich_options(FLAGS)
388 |
389 | sys.stdout.flush()
390 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
391 |
--------------------------------------------------------------------------------
/src_s2s/NP2P_beam_decoder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function
3 | import argparse
4 | import os
5 | import sys
6 | import time
7 | import numpy as np
8 |
9 | from vocab_utils import Vocab
10 | import namespace_utils
11 | import NP2P_data_stream
12 | from NP2P_model_graph import ModelGraph
13 |
14 | import re
15 |
16 | import tensorflow as tf
17 | import NP2P_trainer
18 | tf.logging.set_verbosity(tf.logging.ERROR) # DEBUG, INFO, WARN, ERROR, and FATAL
19 |
20 | def search(sess, model, vocab, batch, options, decode_mode='greedy'):
21 | '''
 22 |     Greedy or multinomial (sampling) decoding for a whole batch.
23 | '''
24 | # Run the encoder to get the encoder hidden states and decoder initial state
25 | (phrase_representations, initial_state, encoder_features,phrase_idx, phrase_mask) = model.run_encoder(sess, batch, options)
26 | # phrase_representations: [batch_size, passage_len, encode_dim]
 27 |     # initial_state: a tuple of [batch_size, gen_dim]
28 | # encoder_features: [batch_size, passage_len, attention_vec_size]
29 | # phrase_idx: [batch_size, passage_len]
30 | # phrase_mask: [batch_size, passage_len]
31 |
32 | word_t = batch.gen_input_words[:,0]
33 | state_t = initial_state
34 | context_t = np.zeros([batch.batch_size, model.encode_dim])
35 | coverage_t = np.zeros((batch.batch_size, phrase_representations.shape[1]))
36 | generator_output_idx = [] # store phrase index prediction
37 | text_results = []
38 | generator_input_idx = [word_t] # store word index
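    # Unrolled decoding loop: at every step feed the previous word, decoder state, context and coverage vectors,
    # run one decoder step, and map the predicted phrase index back to a word for the next step.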
39 | for i in xrange(options.max_answer_len):
40 | if decode_mode == "pointwise": word_t = batch.gen_input_words[:,i]
41 | feed_dict = {}
42 | feed_dict[model.init_decoder_state] = state_t
43 | feed_dict[model.context_t_1] = context_t
44 | feed_dict[model.coverage_t_1] = coverage_t
45 | feed_dict[model.word_t] = word_t
46 |
47 | feed_dict[model.phrase_representations] = phrase_representations
48 | feed_dict[model.encoder_features] = encoder_features
49 | feed_dict[model.phrase_idx] = phrase_idx
50 | feed_dict[model.phrase_mask] = phrase_mask
51 | if options.with_phrase_projection:
52 | feed_dict[model.max_phrase_size] = batch.max_phrase_size
53 | if options.add_first_word_prob_for_phrase:
54 | feed_dict[model.in_passage_words] = batch.sent1_word
55 | feed_dict[model.phrase_starts] = batch.phrase_starts
56 |
57 |
58 |
59 | if decode_mode in ["greedy","pointwise"]:
60 | prediction = model.greedy_prediction
61 | elif decode_mode == "multinomial":
62 | prediction = model.multinomial_prediction
63 |
64 | (state_t, context_t, attn_dist_t, coverage_t, prediction) = sess.run([model.state_t, model.context_t, model.attn_dist_t,
65 | model.coverage_t, prediction], feed_dict)
66 | attn_idx = np.argmax(attn_dist_t, axis=1) # [batch_size]
67 | # convert prediction to word ids
68 | generator_output_idx.append(prediction)
69 | prediction = np.reshape(prediction, [prediction.size, 1])
70 | [cur_words, cur_word_idx] = batch.map_phrase_idx_to_text(prediction) # [batch_size, 1]
71 | cur_word_idx = np.array(cur_word_idx)
72 | cur_word_idx = np.reshape(cur_word_idx, [cur_word_idx.size])
73 | word_t = cur_word_idx
74 | cur_words = flatten_words(cur_words) # [batch_size]
75 |
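        # Simple copy heuristic: replace an UNK prediction with the source word that received the highest attention weight.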
76 | for i, wword in enumerate(cur_words):
77 | if wword == 'UNK' and attn_idx[i] < len(batch.passage_words[i]):
78 | cur_words[i] = batch.passage_words[i][attn_idx[i]]
79 |
80 | text_results.append(cur_words)
81 | generator_input_idx.append(cur_word_idx)
82 |
83 | generator_input_idx = generator_input_idx[:-1] # remove the last word to shift one position to the right
84 | generator_output_idx = np.stack(generator_output_idx, axis=1) # [batch_size, max_len]
85 | generator_input_idx = np.stack(generator_input_idx, axis=1) # [batch_size, max_len]
86 |
87 | prediction_lengths = [] # [batch_size]
88 | sentences = [] # [batch_size]
89 | for i in xrange(batch.batch_size):
90 | words = []
91 | for j in xrange(options.max_answer_len):
92 | cur_phrase = text_results[j][i]
93 | # cur_phrase = cur_batch_text[j]
94 | words.append(cur_phrase)
 95 |             if cur_phrase == "</s>": break  # stop at the end-of-sentence symbol
96 | prediction_lengths.append(len(words))
97 | cur_sent = " ".join(words)
98 | sentences.append(cur_sent)
99 |
100 | return (sentences, prediction_lengths, generator_input_idx, generator_output_idx)
101 |
102 | def flatten_words(cur_words):
103 | all_words = []
104 | for i in xrange(len(cur_words)):
105 | all_words.append(cur_words[i][0])
106 | return all_words
107 |
108 | class Hypothesis(object):
109 | def __init__(self, tokens, log_ps, attn, state, context_vector, coverage_vector=None):
110 | self.tokens = tokens # store all tokens
111 | self.log_probs = log_ps # store log_probs for each time-step
112 | self.attn_ids = attn
113 | self.state = state
114 | self.context_vector = context_vector
115 | self.coverage_vector = coverage_vector
116 |
117 | def extend(self, token, log_prob, attn_i, state, context_vector, coverage_vector=None):
118 | return Hypothesis(self.tokens + [token], self.log_probs + [log_prob], self.attn_ids + [attn_i], state,
119 | context_vector, coverage_vector=coverage_vector)
120 |
121 | def latest_token(self):
122 | return self.tokens[-1]
123 |
124 | def avg_log_prob(self):
125 | return np.sum(self.log_probs[1:])/ (len(self.tokens)-1)
126 |
127 | def probs2string(self):
128 | out_string = ""
129 | for prob in self.log_probs:
130 | out_string += " %.4f" % prob
131 | return out_string.strip()
132 |
133 | def idx_seq_to_string(self, passage, id2phrase, vocab, options):
134 | word_size = vocab.vocab_size + 1
135 | all_words = []
136 | for i, idx in enumerate(self.tokens):
137 | cur_word = vocab.getWord(idx)
138 | if cur_word == 'UNK':
139 | cur_word = passage[self.attn_ids[i]]
140 | all_words.append(cur_word)
141 | return " ".join(all_words[1:])
142 |
143 |
144 | def sort_hyps(hyps):
145 | return sorted(hyps, key=lambda h: h.avg_log_prob(), reverse=True)
146 |
147 |
148 |
149 | def run_beam_search(sess, model, vocab, batch, options):
150 | # Run encoder
151 | st = time.time()
152 | (phrase_representations, initial_state, encoder_features,phrase_idx, phrase_mask) = model.run_encoder(sess, batch, options)
153 | encoding_dur = time.time() - st
154 | # phrase_representations: [1, passage_len, encode_dim]
155 |     # initial_state: a tuple of [1, gen_dim]
156 | # encoder_features: [1, passage_len, attention_vec_size]
157 | # phrase_idx: [1, passage_len]
158 | # phrase_mask: [1, passage_len]
159 |
160 |     sent_stop_id = vocab.getIndex('</s>')
161 |
162 | # Initialize this first hypothesis
163 | context_t = np.zeros([model.encode_dim]) # [encode_dim]
164 | coverage_t = np.zeros((phrase_representations.shape[1])) # [passage_len]
165 | hyps = []
166 | hyps.append(Hypothesis([batch.gen_input_words[0][0]], [0.0], [-1], initial_state, context_t, coverage_vector=coverage_t))
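    # The beam holds at most beam_size live hypotheses, ranked by average per-token log probability (see sort_hyps / avg_log_prob).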
167 |
168 | # beam search decoding
169 |     results = [] # this will contain finished hypotheses (those that have emitted the </s> token)
170 | steps = 0
171 | while steps < options.max_answer_len and len(results) < options.beam_size:
172 | cur_size = len(hyps) # current number of hypothesis in the beam
173 | cur_phrase_representations = np.tile(phrase_representations, (cur_size, 1, 1))
174 | cur_encoder_features = np.tile(encoder_features, (cur_size, 1, 1)) # [batch_size,passage_len, options.attention_vec_size]
175 | cur_phrase_idx = np.tile(phrase_idx, (cur_size, 1)) # [batch_size, passage_len]
176 | cur_phrase_mask = np.tile(phrase_mask, (cur_size, 1)) # [batch_size, passage_len]
177 | cur_state_t_1 = [] # [2, gen_steps]
178 | cur_context_t_1 = [] # [batch_size, encoder_dim]
179 | cur_coverage_t_1 = [] # [batch_size, passage_len]
180 | cur_word_t = [] # [batch_size]
181 | for h in hyps:
182 | cur_state_t_1.append(h.state)
183 | cur_context_t_1.append(h.context_vector)
184 | cur_word_t.append(h.latest_token())
185 | cur_coverage_t_1.append(h.coverage_vector)
186 | cur_context_t_1 = np.stack(cur_context_t_1, axis=0)
187 | cur_coverage_t_1 = np.stack(cur_coverage_t_1, axis=0)
188 | cur_word_t = np.array(cur_word_t)
189 |
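        # Stack the per-hypothesis LSTM states into one batched LSTMStateTuple so a single sess.run advances the whole beam by one step.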
190 | cells = [state.c for state in cur_state_t_1]
191 | hidds = [state.h for state in cur_state_t_1]
192 | new_c = np.concatenate(cells, axis=0)
193 | new_h = np.concatenate(hidds, axis=0)
194 | new_dec_init_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
195 |
196 | feed_dict = {}
197 | feed_dict[model.init_decoder_state] = new_dec_init_state
198 | feed_dict[model.context_t_1] = cur_context_t_1
199 | feed_dict[model.word_t] = cur_word_t
200 |
201 | feed_dict[model.phrase_representations] = cur_phrase_representations
202 | feed_dict[model.encoder_features] = cur_encoder_features
203 | feed_dict[model.phrase_idx] = cur_phrase_idx
204 | feed_dict[model.phrase_mask] = cur_phrase_mask
205 | feed_dict[model.coverage_t_1] = cur_coverage_t_1
206 | if options.with_phrase_projection:
207 | feed_dict[model.max_phrase_size] = batch.max_phrase_size
208 | if options.add_first_word_prob_for_phrase:
209 | feed_dict[model.in_passage_words] = batch.sent1_word
210 | feed_dict[model.phrase_starts] = batch.phrase_starts
211 |
212 | (state_t, context_t, attn_dist_t, coverage_t, topk_log_probs, topk_ids) = sess.run([model.state_t, model.context_t,
213 | model.attn_dist_t, model.coverage_t, model.topk_log_probs, model.topk_ids], feed_dict)
214 |
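        # Split the batched decoder state back into one LSTMStateTuple per hypothesis.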
215 | new_states = [tf.contrib.rnn.LSTMStateTuple(state_t.c[i:i+1, :], state_t.h[i:i+1, :]) for i in xrange(cur_size)]
216 |
217 | # Extend each hypothesis and collect them all in all_hyps
218 | if steps == 0: cur_size = 1
219 | all_hyps = []
220 | for i in xrange(cur_size):
221 | h = hyps[i]
222 | cur_state = new_states[i]
223 | cur_context = context_t[i]
224 | cur_coverage = coverage_t[i]
225 | for j in xrange(options.beam_size):
226 | cur_tok = topk_ids[i, j]
227 | cur_tok_log_prob = topk_log_probs[i, j]
228 | cur_attn_i = np.argmax(attn_dist_t[i, :])
229 | new_hyp = h.extend(cur_tok, cur_tok_log_prob, cur_attn_i, cur_state, cur_context, coverage_vector=cur_coverage)
230 | all_hyps.append(new_hyp)
231 |
232 | # Filter and collect any hypotheses that have produced the end token.
233 | # hyps will contain hypotheses for the next step
234 | hyps = []
235 | for h in sort_hyps(all_hyps):
236 | # If this hypothesis is sufficiently long, put in results. Otherwise discard.
237 | if h.latest_token() == sent_stop_id:
238 | if steps >= options.min_answer_len:
239 | results.append(h)
240 | # hasn't reached stop token, so continue to extend this hypothesis
241 | else:
242 | hyps.append(h)
243 | if len(hyps) == options.beam_size or len(results) == options.beam_size:
244 | break
245 |
246 | steps += 1
247 |
248 | # At this point, either we've got beam_size results, or we've reached maximum decoder steps
249 |     # if we don't have any complete results, add all current (incomplete) hypotheses to results
250 | if len(results)==0:
251 | results = hyps
252 |
253 | # Sort hypotheses by average log probability
254 | hyps_sorted = sort_hyps(results)
255 |
256 | # Return the hypothesis with highest average log prob
257 | return hyps_sorted, encoding_dur
258 |
259 | if __name__ == '__main__':
260 | parser = argparse.ArgumentParser()
261 | parser.add_argument('--model_prefix', type=str, required=True, help='Prefix to the models.')
262 | parser.add_argument('--in_path', type=str, required=True, help='The path to the test file.')
263 | parser.add_argument('--out_path', type=str, help='The path to the output file.')
264 |     parser.add_argument('--mode', type=str, default='pointwise', help='Decoding mode: pointwise, greedy, multinomial or beam.')
265 |     parser.add_argument('--beam_size', type=int, default=-1, help='Beam size; -1 keeps the value from the model config.')
266 |
267 | args, unparsed = parser.parse_known_args()
268 |
269 | model_prefix = args.model_prefix
270 | in_path = args.in_path
271 | out_path = args.out_path
272 | mode = args.mode
273 |
274 | print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES'])
275 |
276 | # load the configuration file
277 | print('Loading configurations from ' + model_prefix + ".config.json")
278 | FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json")
279 | FLAGS = NP2P_trainer.enrich_options(FLAGS)
280 | if args.beam_size != -1:
281 | FLAGS.beam_size = args.beam_size
282 |
283 | # load vocabs
284 | print('Loading vocabs.')
285 | enc_word_vocab = dec_word_vocab = char_vocab = POS_vocab = NER_vocab = None
286 | if FLAGS.with_word:
287 | enc_word_vocab = Vocab(FLAGS.enc_word_vec_path, fileformat='txt2')
288 | print('enc_word_vocab: {}'.format(enc_word_vocab.word_vecs.shape))
289 | dec_word_vocab = Vocab(FLAGS.dec_word_vec_path, fileformat='txt2')
290 | print('dec_word_vocab: {}'.format(dec_word_vocab.word_vecs.shape))
291 | if FLAGS.with_char:
292 | char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2')
293 | print('char_vocab: {}'.format(char_vocab.word_vecs.shape))
294 | if FLAGS.with_POS:
295 | POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2')
296 | print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape))
297 | if FLAGS.with_NER:
298 | NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2')
299 | print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape))
300 |
301 |
302 | print('Loading test set.')
303 | if FLAGS.infile_format == 'fof':
304 | testset, _ = NP2P_data_stream.read_generation_datasets_from_fof(in_path, isLower=FLAGS.isLower)
305 | elif FLAGS.infile_format == 'plain':
306 | testset, _ = NP2P_data_stream.read_all_GenerationDatasets(in_path, isLower=FLAGS.isLower)
307 | else:
308 | testset, _ = NP2P_data_stream.read_all_GQA_questions(in_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa)
309 | print('Number of samples: {}'.format(len(testset)))
310 |
311 | print('Build DataStream ... ')
312 | batch_size = -1
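    # Beam search decodes one instance at a time, so batch_size is forced to 1 for any mode not listed below; -1 presumably falls back to the configured batch size.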
313 | if mode not in ('pointwise', 'multinomial', 'greedy', 'greedy_evaluate', ): batch_size = 1
314 | devDataStream = NP2P_data_stream.DataStream(testset, enc_word_vocab, dec_word_vocab, char_vocab, options=FLAGS,
315 | isShuffle=False, isLoop=False, isSort=True, batch_size=batch_size)
316 | print('Number of instances in testDataStream: {}'.format(devDataStream.get_num_instance()))
317 | print('Number of batches in testDataStream: {}'.format(devDataStream.get_num_batch()))
318 |
319 | best_path = model_prefix + ".best.model"
320 | with tf.Graph().as_default():
321 | initializer = tf.random_uniform_initializer(-0.01, 0.01)
322 | with tf.name_scope("Valid"):
323 | with tf.variable_scope("Model", reuse=False, initializer=initializer):
324 | valid_graph = ModelGraph(enc_word_vocab=enc_word_vocab, dec_word_vocab=dec_word_vocab, char_vocab=char_vocab,
325 | options=FLAGS, mode="decode")
326 |
327 |     ## exclude the fixed word_embedding variables when building the saver
328 | vars_ = {}
329 | for var in tf.all_variables():
330 | if FLAGS.fix_word_vec and "word_embedding" in var.name: continue
331 | if not var.name.startswith("Model"): continue
332 | vars_[var.name.split(":")[0]] = var
333 | saver = tf.train.Saver(vars_)
334 |
335 | initializer = tf.global_variables_initializer()
336 | sess = tf.Session()
337 | sess.run(initializer)
338 |
339 | saver.restore(sess, best_path) # restore the model
340 |
341 | total = 0
342 | correct = 0
343 | outfile = open(out_path, 'wt')
344 | total_num = devDataStream.get_num_batch()
345 | devDataStream.reset()
346 | total_dur = 0
347 | for i in range(total_num):
348 | cur_batch = devDataStream.get_batch(i)
349 | if mode in ['greedy', 'multinomial']:
350 | print('Batch {}'.format(i))
351 | (sentences, prediction_lengths, generator_input_idx,
352 | generator_output_idx) = search(sess, valid_graph, dec_word_vocab, cur_batch, FLAGS, decode_mode=mode)
353 | for j in xrange(cur_batch.batch_size):
354 | outfile.write(cur_batch.id[j].encode('utf-8') + "\n")
355 | outfile.write(cur_batch.target_ref[j].encode('utf-8') + "\n")
356 | outfile.write(sentences[j].encode('utf-8') + "\n")
357 | outfile.write("========\n")
358 | outfile.flush()
359 | else: # beam search
360 | print('Instance {}'.format(i))
361 | hyps, dur = run_beam_search(sess, valid_graph, dec_word_vocab, cur_batch, FLAGS)
362 | total_dur += dur
363 | outfile.write(cur_batch.id[0].encode('utf-8') + "\n")
364 | outfile.write(cur_batch.target_ref[0].encode('utf-8') + "\n")
365 | for j in xrange(1):
366 | hyp = hyps[j]
367 | cur_passage = cur_batch.source[0]
368 | cur_id2phrase = None
369 | cur_sent = hyp.idx_seq_to_string(cur_passage, cur_id2phrase, dec_word_vocab, FLAGS)
370 | outfile.write(cur_sent.encode('utf-8') + "\n")
371 | outfile.write("--------\n")
372 | outfile.write("========\n")
373 | outfile.flush()
374 | outfile.close()
375 | print('Total encoding time {}'.format(total_dur))
376 |
377 |
378 |
379 |
--------------------------------------------------------------------------------