├── src_g2s ├── __init__.py ├── namespace_utils.py ├── amr_utils.py ├── padding_utils.py ├── metric_utils.py ├── metric_rouge_utils.py ├── metric_bleu_utils.py ├── G2S_data_stream.py ├── G2S_beam_decoder.py ├── G2S_model_graph.py └── G2S_trainer.py ├── train_s2s.sh ├── train_g2s.sh ├── decode_s2s.sh ├── src_s2s ├── namespace_utils.py ├── prepare_question_generation_dataset.py ├── phrase_projection_layer_utils.py ├── padding_utils.py ├── metric_utils.py ├── metric_rouge_utils.py ├── prepare_paraphrase_dataset.py ├── prepare_summarization_dataset.py ├── sent_utils.py ├── NP2P_data_stream.py ├── phrase_lattice_utils.py ├── metric_bleu_utils.py ├── encoder_utils.py ├── NP2P_phrase_trainer.py ├── NP2P_trainer.py └── NP2P_beam_decoder.py ├── decode_g2s.sh ├── data ├── 3_extract_embedding.py ├── 1_make_json.py ├── 4_insert_symbols_front.py └── 2_gen_vocab.py ├── AMR_multiline_to_singleline.py ├── logs_g2s └── extract_and_eval.py ├── config_g2s.json ├── config_s2s.json └── README.md /src_g2s/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /train_s2s.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -J AMR_s2s --partition=gpu --gres=gpu:1 --time=5-00:00:00 --output=train.out_s2s --error=train.err_s2s 3 | #SBATCH --mem=15GB 4 | #SBATCH -c 5 5 | 6 | export PYTHONPATH=$PYTHONPATH:/home/lsong10/ws/exp.graph_to_seq/neural-graph-to-seq-mp 7 | 8 | python src_s2s/NP2P_trainer.py --config_path config_s2s.json 9 | 10 | -------------------------------------------------------------------------------- /train_g2s.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -J g_2m -C K80 --partition=gpu --gres=gpu:1 --time=5-00:00:00 --output=train.out --error=train.err 3 | #SBATCH --mem=80GB 4 | #SBATCH -c 5 5 | 6 | export PYTHONPATH=$PYTHONPATH:/home/lsong10/ws/exp.graph_to_seq/neural-graph-to-seq-mp 7 | 8 | python src_g2s/G2S_trainer.py --config_path logs_g2s/G2S.silver_2m.config.json 9 | 10 | -------------------------------------------------------------------------------- /decode_s2s.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu --gres=gpu:1 --time=1:00:00 --output=decode.out --error=decode.err 3 | #SBATCH --mem=10GB 4 | #SBATCH -c 6 5 | 6 | export PYTHONPATH=$PYTHONPATH:/home/lsong10/ws/exp.graph_to_seq/neural-graph-to-seq-mp 7 | 8 | python src_s2s/NP2P_beam_decoder.py --model_prefix logs_s2s/NP2P.$1 \ 9 | --in_path data/test.json \ 10 | --out_path logs_s2s/test.s2s.$1\.tok \ 11 | --mode beam 12 | 13 | -------------------------------------------------------------------------------- /src_g2s/namespace_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | class Bunch(object): 4 | def __init__(self, adict): 5 | self.__dict__.update(adict) 6 | 7 | def save_namespace(FLAGS, out_path): 8 | FLAGS_dict = vars(FLAGS) 9 | with open(out_path, 'w') as fp: 10 | json.dump(FLAGS_dict, fp) 11 | 12 | def load_namespace(in_path): 13 | with open(in_path, 'r') as fp: 14 | FLAGS_dict = json.load(fp) 15 | return Bunch(FLAGS_dict) -------------------------------------------------------------------------------- /src_s2s/namespace_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 
class Bunch(object): 4 | def __init__(self, adict): 5 | self.__dict__.update(adict) 6 | 7 | def save_namespace(FLAGS, out_path): 8 | FLAGS_dict = vars(FLAGS) 9 | with open(out_path, 'w') as fp: 10 | json.dump(FLAGS_dict, fp) 11 | 12 | def load_namespace(in_path): 13 | with open(in_path, 'r') as fp: 14 | FLAGS_dict = json.load(fp) 15 | return Bunch(FLAGS_dict) -------------------------------------------------------------------------------- /decode_g2s.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --partition=gpu --gres=gpu:1 -C K80 --time=1:00:00 --output=decode.out --error=decode.err 3 | #SBATCH --mem=10GB 4 | #SBATCH -c 6 5 | 6 | export PYTHONPATH=$PYTHONPATH:/home/lsong10/ws/exp.graph_to_seq/neural-graph-to-seq-mp 7 | 8 | python src_g2s/G2S_beam_decoder.py --model_prefix logs_g2s/G2S.$1 \ 9 | --in_path data/test.json \ 10 | --out_path logs_g2s/test.g2s.$1\.tok \ 11 | --mode beam 12 | 13 | -------------------------------------------------------------------------------- /data/3_extract_embedding.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy 3 | import sys, os 4 | from collections import Counter 5 | 6 | vocab = set(line.strip() for line in open(sys.argv[1], 'rU')) 7 | print 'len(vocab)', len(vocab) 8 | 9 | intersect = set() 10 | f = open(sys.argv[2], 'w') 11 | for line in open('/home/lsong10/ws/data.embedding/glove.840B.300d.txt', 'rU'): 12 | word = line.strip().split()[0] 13 | if word in vocab: 14 | intersect.add(word) 15 | print >>f, line.strip() 16 | print len(intersect) 17 | 18 | for w in vocab - intersect: 19 | embedding = ' '.join([str('%.6f'%x) for x in numpy.random.normal(size=600)]) 20 | print >>f, w, embedding 21 | 22 | f.close() 23 | -------------------------------------------------------------------------------- /src_s2s/prepare_question_generation_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def read_all_SQuAD_questions(inpath): 4 | with open(inpath) as dataset_file: 5 | dataset_json = json.load(dataset_file, encoding='utf-8') 6 | dataset = dataset_json['data'] 7 | all_questions = [] 8 | for article in dataset: 9 | # title = article['title'] 10 | for paragraph in article['paragraphs']: 11 | context = paragraph['context'] 12 | for question in paragraph['qas']: 13 | question_text = question['question'] 14 | question_id = question['id'] 15 | answers = question['answers'] 16 | return all_questions 17 | 18 | -------------------------------------------------------------------------------- /data/1_make_json.py: -------------------------------------------------------------------------------- 1 | import sys, os, json 2 | 3 | amr = [x.strip() for x in open(sys.argv[1]+'-dfs-linear_src.txt','rU')] 4 | print 'len(amr)', len(amr) 5 | sent = [x.strip() for x in open(sys.argv[1]+'-dfs-linear_targ.txt','rU')] 6 | print 'len(sent)', len(sent) 7 | assert len(amr) == len(sent) 8 | 9 | ids = None 10 | if os.path.isfile(sys.argv[1]+'-ids.txt'): 11 | ids = [x.strip() for x in open(sys.argv[1]+'-ids.txt','rU')] 12 | assert len(amr) == len(ids) 13 | 14 | data = [] 15 | for i in range(len(amr)): 16 | json_obj = {'amr':amr[i],'sent':sent[i],} 17 | if ids != None: 18 | json_obj['id'] = ids[i] 19 | data.append(json_obj) 20 | print len(data) 21 | json.dump(data,open(sys.argv[1]+'.json','w')) 22 | 23 | -------------------------------------------------------------------------------- 
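The four numbered scripts under data/ (1_make_json.py, 2_gen_vocab.py, 3_extract_embedding.py and 4_insert_symbols_front.py) are meant to be run in sequence when preparing a dataset. The sketch below is an illustration of that pipeline and of the JSON format it produces; the "training" prefix and the output file names are assumptions for the example, not requirements of the scripts.
```
# Illustrative sketch (not part of the repository): how the data/ scripts are typically chained.
# Assumes a dataset prefix "training"; adjust the file names to your own data.
#
#   python 1_make_json.py training                       # training-dfs-linear_src.txt + _targ.txt -> training.json
#   python 2_gen_vocab.py                                 # training.json -> vocab.txt
#   python 3_extract_embedding.py vocab.txt vectors.txt   # restrict GloVe vectors to vocab.txt (random vectors for the rest)
#   python 4_insert_symbols_front.py vectors.txt          # prepend special-symbol rows -> vectors.txt.st
import json

# one instance in the JSON format produced by 1_make_json.py and consumed by the data loaders
instance = {"amr": "describe :arg0 ( person :name ( name :op1 ryan ) ) :arg1 person :arg2 genius",
            "sent": "ryan 's description of himself : a genius .",
            "id": "demo"}
with open("demo.json", "w") as fp:
    json.dump([instance], fp)
```
The resulting ".st" embedding file is the kind of file that the "word_vec_path" entries in config_g2s.json and config_s2s.json point to.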
/data/4_insert_symbols_front.py: -------------------------------------------------------------------------------- 1 | 2 | import os,sys 3 | import numpy 4 | 5 | inpath = sys.argv[1] 6 | outpath = sys.argv[1]+'.st' 7 | f = open(outpath,'w') 8 | for i,line in enumerate(open(inpath,'rU')): 9 | if i == 0: 10 | vsize = len(line.strip().split())-1 11 | print vsize 12 | print >>f, '\t'.join(['0', '#pad#', ' '.join([str('%.6f'%x) for x in numpy.zeros(vsize)])]) 13 | print >>f, '\t'.join(['1', '', ' '.join([str('%.6f'%x) for x in numpy.random.normal(size=vsize)])]) 14 | print >>f, '\t'.join(['2', '', ' '.join([str('%.6f'%x) for x in numpy.random.normal(size=vsize)])]) 15 | line = line.strip().split() 16 | word = line[0] 17 | line = ' '.join(line[1:]) 18 | print >>f, '\t'.join([str(i+3), word, line]) 19 | 20 | -------------------------------------------------------------------------------- /AMR_multiline_to_singleline.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | 4 | ids = [] 5 | id_dict = {} 6 | amrs = [] 7 | amr_str = '' 8 | for line in open(sys.argv[1],'rU'): 9 | if line.startswith('#'): 10 | if line.startswith('# ::id'): 11 | id = line.lower().strip().split()[2] 12 | ids.append(id) 13 | id_dict[id] = len(ids)-1 14 | continue 15 | line = line.strip() 16 | if line == '': 17 | if amr_str != '': 18 | amrs.append(amr_str.strip()) 19 | amr_str = '' 20 | else: 21 | amr_str = amr_str + line + ' ' 22 | 23 | if amr_str != '': 24 | amrs.append(amr_str.strip()) 25 | amr_str = '' 26 | 27 | if len(sys.argv) == 3: 28 | for line in open(sys.argv[2],'rU'): 29 | id = line.lower().strip() 30 | print amrs[id_dict[id]] 31 | 32 | -------------------------------------------------------------------------------- /logs_g2s/extract_and_eval.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import os 4 | 5 | output_dict = {} 6 | ref_dict = {} 7 | for i,line in enumerate(open(sys.argv[1],'rU')): 8 | if i%5 == 0: 9 | id = line.strip() 10 | elif i%5 == 1: 11 | ref = line.strip().lower().split() 12 | ref_dict[id] = ref 13 | elif i%5 == 2: 14 | rst = line.strip().replace('','',10).split() 15 | output_dict[id] = rst 16 | 17 | dataset_type = None 18 | if sys.argv[1].startswith('test'): 19 | dataset_type = 'test' 20 | elif sys.argv[1].startswith('dev'): 21 | dataset_type = 'dev' 22 | 23 | fout = open(sys.argv[1]+'.1best','w') 24 | fref = open(sys.argv[1]+'.ref','w') 25 | for id in output_dict.keys(): 26 | print >>fout, ' '.join(output_dict[id]).lower() 27 | print >>fref, ' '.join(ref_dict[id]).lower() 28 | fout.close() 29 | fref.close() 30 | 31 | os.system('/home/lsong10/ws/exp.graph_to_seq/mosesdecoder/scripts/generic/multi-bleu.perl %s.ref < %s.1best' %(sys.argv[1],sys.argv[1])) 32 | 33 | -------------------------------------------------------------------------------- /data/2_gen_vocab.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | import sys 4 | import cPickle 5 | import re 6 | from collections import Counter 7 | import codecs 8 | 9 | def update(l, v): 10 | v.update([x.lower() for x in l]) 11 | 12 | def update_vocab(path, vocab, vocab_edge, vocab_node): 13 | words = [] 14 | words_edge = [] 15 | words_node = [] 16 | data = json.load(open(path,'rU')) 17 | for inst in data: 18 | words += inst['sent'].lower().strip().split() 19 | words += [x for x in inst['amr'].lower().strip().split() if x[0] != ':'] 20 | words_edge += [x for x in 
inst['amr'].lower().strip().split() if x[0] == ':'] 21 | words_node += [x for x in inst['amr'].lower().strip().split() if re.search('_[0-9]+', x) != None or x == 'num_unk'] 22 | update(words, vocab) 23 | update(words_edge, vocab_edge) 24 | update(words_node, vocab_node) 25 | 26 | def output(d, path): 27 | f = codecs.open(path,'w',encoding='utf-8') 28 | for k,v in sorted(d.items(), key=lambda x:-x[1]): 29 | print >>f, k 30 | f.close() 31 | 32 | ################## 33 | 34 | vocab = Counter() 35 | vocab_edge = Counter() 36 | vocab_node = Counter() 37 | update_vocab('training.json', vocab, vocab_edge, vocab_node) 38 | print len(vocab), len(vocab_edge), len(vocab_node) 39 | 40 | output(vocab, 'vocab.txt') 41 | #output(vocab_edge, 'vocab_edge.txt') 42 | #output(vocab_node, 'vocab_node.txt') 43 | 44 | -------------------------------------------------------------------------------- /config_g2s.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_path": "data/2m.json", 3 | "finetune_path": "data/training.json", 4 | "test_path": "data/dev.json", 5 | "word_vec_path": "data/vectors_RwNN.txt.st", 6 | "suffix": "silver_2m", 7 | "model_dir": "logs_g2s", 8 | "isLower": true, 9 | 10 | "pointer_gen": true, 11 | "use_coverage": true, 12 | "attention_vec_size": 600, 13 | "batch_size": 100, 14 | "beam_size": 5, 15 | 16 | "num_syntax_match_layer": 9, 17 | "max_node_num": 180, 18 | "max_in_neigh_num": 2, 19 | "max_out_neigh_num": 8, 20 | "min_answer_len": 0, 21 | "max_answer_len": 50, 22 | "learning_rate": 0.0005, 23 | "lambda_l2": 1e-8, 24 | "dropout_rate": 0.1, 25 | "cov_loss_wt": 0.1, 26 | "max_epochs": 10, 27 | "optimize_type": "adam", 28 | 29 | "with_highway": false, 30 | "highway_layer_num": 1, 31 | 32 | "with_char": false, 33 | "char_dim": 100, 34 | "char_lstm_dim": 100, 35 | "max_char_per_word": 20, 36 | 37 | "attention_type": "hidden", 38 | "way_init_decoder": "all", 39 | "edgelabel_dim": 100, 40 | "neighbor_vector_dim": 600, 41 | "fix_word_vec": false, 42 | "compress_input": false, 43 | "compress_input_dim": 300, 44 | 45 | "gen_hidden_size": 600, 46 | "num_softmax_samples": 100, 47 | "mode": "ce_train", 48 | 49 | "config_path": "config.json", 50 | "generate_config": false 51 | } 52 | -------------------------------------------------------------------------------- /config_s2s.json: -------------------------------------------------------------------------------- 1 | { 2 | "finetune_path": "", 3 | "train_path": "data/training.json", 4 | "test_path": "data/dev.json", 5 | "enc_word_vec_path": "data/vectors_RwNN.txt.st", 6 | "dec_word_vec_path": "data/vectors_RwNN.txt.st", 7 | "suffix": "gold", 8 | "model_dir": "logs_s2s", 9 | "isLower": true, 10 | "two_sent_inputs": false, 11 | 12 | "direction": "bidir", 13 | "pointer_gen": true, 14 | "use_coverage": true, 15 | "switch_qa": false, 16 | "attention_vec_size": 300, 17 | "cov_loss_wt": 0.1, 18 | 19 | "beam_size": 20, 20 | 21 | "min_answer_len": 0, 22 | "max_answer_len": 100, 23 | "batch_size": 20, 24 | "learning_rate": 0.001, 25 | "lambda_l2": 0.001, 26 | "dropout_rate": 0.2, 27 | "max_epochs": 50, 28 | "optimize_type": "adam", 29 | 30 | "with_filter_layer": true, 31 | "filter_layer_threshold": 0.2, 32 | "with_word_match": true, 33 | "with_sequential_match": true, 34 | "context_layer_num": 1, 35 | "context_lstm_dim": 300, 36 | 37 | "with_lex_decomposition": false, 38 | "lex_decompsition_dim": 50, 39 | "with_question_passage_word_feature": false, 40 | 41 | "with_phrase_projection": false, 42 | 43 | 
"with_highway": false, 44 | "highway_layer_num": 1, 45 | "with_match_highway": false, 46 | "with_aggregation_highway": false, 47 | 48 | "with_word": true, 49 | "with_char": true, 50 | "with_POS": false, 51 | "with_NER": false, 52 | "POS_dim": 20, 53 | "NER_dim": 20, 54 | "char_dim": 50, 55 | "char_lstm_dim": 100, 56 | "compress_input": false, 57 | "compress_input_dim": 300, 58 | "fix_word_vec": true, 59 | "max_passage_len": 100, 60 | "max_char_per_word": 20, 61 | 62 | "gen_hidden_size": 300, 63 | "num_softmax_samples": 100, 64 | "mode": "ce_train", 65 | 66 | "config_path": "config.json", 67 | "generate_config": false 68 | } 69 | -------------------------------------------------------------------------------- /src_s2s/phrase_projection_layer_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def collect_representation(representation, positions): 4 | ''' 5 | representation: [batch_size, passsage_length, dim] 6 | positions: [batch_size, num_positions] 7 | ''' 8 | def singel_instance(x): 9 | # x[0]: [passage_length, dim] 10 | # x[1]: [num_positions] 11 | return tf.gather(x[0], x[1]) 12 | elems = (representation, positions) 13 | return tf.map_fn(singel_instance, elems, dtype=tf.float32) # [batch_size, num_positions, dim] 14 | 15 | class PhraseProjectionLayer(object): 16 | def __init__(self, placeholders): 17 | # placeholder assignments 18 | self.max_phrase_size = placeholders.max_phrase_size # a scaler, max number of phrases within a batch 19 | self.phrase_starts = placeholders.phrase_starts # [batch_size, chunk_len] 20 | self.phrase_ends = placeholders.phrase_ends # [batch_size, chunk_len] 21 | self.phrase_lengths = placeholders.phrase_lengths # [batch_size] 22 | 23 | def project_to_phrase_representation(self, encoder_representations): 24 | ''' 25 | encoder_represenations: [batch_size, passage_length, encoder_dim] 26 | ''' 27 | start_representations = collect_representation(encoder_representations, self.phrase_starts) # [batch_size, chunk_len, encoder_dim] 28 | end_representations = collect_representation(encoder_representations, self.phrase_ends) # [batch_size, chunk_len, encoder_dim] 29 | phrase_representations = tf.concat(2, [start_representations, end_representations], name='phrase_representation') 30 | 31 | phrase_len = tf.shape(self.phrase_starts)[1] 32 | phrase_mask = tf.sequence_mask(self.phrase_lengths, phrase_len, dtype=tf.float32) # [batch_size, phrase_len] 33 | phrase_mask = tf.expand_dims(phrase_mask, axis=-1, name='phrase_mask') # [batch_size, phrase_len, 'x'] 34 | 35 | phrase_representations = phrase_representations * phrase_mask 36 | return phrase_representations # [batch_size, phrase_len, 2*encoder_dim] 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /src_g2s/amr_utils.py: -------------------------------------------------------------------------------- 1 | 2 | def read_anonymized(amr_lst, amr_node, amr_edge): 3 | assert sum(x=='(' for x in amr_lst) == sum(x==')' for x in amr_lst) 4 | cur_str = amr_lst[0] 5 | cur_id = len(amr_node) 6 | amr_node.append(cur_str) 7 | 8 | i = 1 9 | while i < len(amr_lst): 10 | if amr_lst[i].startswith(':') == False: ## cur cur-num_0 11 | nxt_str = amr_lst[i] 12 | nxt_id = len(amr_node) 13 | amr_node.append(nxt_str) 14 | amr_edge.append((cur_id, nxt_id, ':value')) 15 | i = i + 1 16 | elif amr_lst[i].startswith(':') and len(amr_lst) == 2: ## cur :edge 17 | nxt_str = 'num_unk' 18 | nxt_id = len(amr_node) 19 | 
amr_node.append(nxt_str) 20 | amr_edge.append((cur_id, nxt_id, amr_lst[i])) 21 | i = i + 1 22 | elif amr_lst[i].startswith(':') and amr_lst[i+1] != '(': ## cur :edge nxt 23 | nxt_str = amr_lst[i+1] 24 | nxt_id = len(amr_node) 25 | amr_node.append(nxt_str) 26 | amr_edge.append((cur_id, nxt_id, amr_lst[i])) 27 | i = i + 2 28 | elif amr_lst[i].startswith(':') and amr_lst[i+1] == '(': ## cur :edge ( ... ) 29 | number = 1 30 | j = i+2 31 | while j < len(amr_lst): 32 | number += (amr_lst[j] == '(') 33 | number -= (amr_lst[j] == ')') 34 | if number == 0: 35 | break 36 | j += 1 37 | assert number == 0 and amr_lst[j] == ')', ' '.join(amr_lst[i+2:j]) 38 | nxt_id = read_anonymized(amr_lst[i+2:j], amr_node, amr_edge) 39 | amr_edge.append((cur_id, nxt_id, amr_lst[i])) 40 | i = j + 1 41 | else: 42 | assert False, ' '.join(amr_lst) 43 | return cur_id 44 | 45 | if __name__ == '__main__': 46 | for path in ['data/dev-dfs-linear_src.txt', 'data/test-dfs-linear_src.txt', 'data/training-dfs-linear_src.txt', ]: 47 | print path 48 | for i, line in enumerate(open(path, 'rU')): 49 | amr_node = [] 50 | amr_edge = [] 51 | read_anonymized(line.strip().split(), amr_node, amr_edge) 52 | -------------------------------------------------------------------------------- /src_g2s/padding_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | def make_batches(size, batch_size): 3 | nb_batch = int(np.ceil(size/float(batch_size))) 4 | return [(i*batch_size, min(size, (i+1)*batch_size)) for i in range(0, nb_batch)] # zgwang: starting point of each batch 5 | 6 | def pad_2d_vals_no_size(in_vals, dtype=np.int32): 7 | size1 = len(in_vals) 8 | size2 = np.max([len(x) for x in in_vals]) 9 | return pad_2d_vals(in_vals, size1, size2, dtype=dtype) 10 | 11 | def pad_2d_vals(in_vals, dim1_size, dim2_size, dtype=np.int32): 12 | out_val = np.zeros((dim1_size, dim2_size), dtype=dtype) 13 | if dim1_size > len(in_vals): dim1_size = len(in_vals) 14 | for i in xrange(dim1_size): 15 | cur_in_vals = in_vals[i] 16 | cur_dim2_size = dim2_size 17 | if cur_dim2_size > len(cur_in_vals): cur_dim2_size = len(cur_in_vals) 18 | out_val[i,:cur_dim2_size] = cur_in_vals[:cur_dim2_size] 19 | return out_val 20 | 21 | def pad_3d_vals_no_size(in_vals, dtype=np.int32): 22 | size1 = len(in_vals) 23 | size2 = np.max([len(x) for x in in_vals]) 24 | size3 = 0 25 | for val in in_vals: 26 | cur_size3 = np.max([len(x) for x in val]) 27 | if size3 len(in_vals): dim1_size = len(in_vals) 33 | for i in xrange(dim1_size): 34 | in_vals_i = in_vals[i] 35 | cur_dim2_size = dim2_size 36 | if cur_dim2_size > len(in_vals_i): cur_dim2_size = len(in_vals_i) 37 | for j in xrange(cur_dim2_size): 38 | in_vals_ij = in_vals_i[j] 39 | cur_dim3_size = dim3_size 40 | if cur_dim3_size > len(in_vals_ij): cur_dim3_size = len(in_vals_ij) 41 | out_val[i, j, :cur_dim3_size] = in_vals_ij[:cur_dim3_size] 42 | return out_val 43 | 44 | def pad_4d_vals(in_vals, dim1_size, dim2_size, dim3_size, dim4_size, dtype=np.int32): 45 | out_val = np.zeros((dim1_size, dim2_size, dim3_size, dim4_size), dtype=dtype) 46 | if dim1_size > len(in_vals): dim1_size = len(in_vals) 47 | for i in xrange(dim1_size): 48 | in_vals_i = in_vals[i] 49 | cur_dim2_size = dim2_size 50 | if cur_dim2_size > len(in_vals_i): cur_dim2_size = len(in_vals_i) 51 | for j in xrange(cur_dim2_size): 52 | in_vals_ij = in_vals_i[j] 53 | cur_dim3_size = dim3_size 54 | if cur_dim3_size > len(in_vals_ij): cur_dim3_size = len(in_vals_ij) 55 | for k in xrange(cur_dim3_size): 56 | in_vals_ijk 
= in_vals_ij[k] 57 | cur_dim4_size = dim4_size 58 | if cur_dim4_size > len(in_vals_ijk): cur_dim4_size = len(in_vals_ijk) 59 | out_val[i, j, k, :cur_dim4_size] = in_vals_ijk[:cur_dim4_size] 60 | return out_val 61 | 62 | def pad_target_labels(in_val, max_length, dtype=np.float32): 63 | batch_size = len(in_val) 64 | out_val = np.zeros((batch_size, max_length), dtype=dtype) 65 | for i in xrange(batch_size): 66 | for index in in_val[i]: 67 | out_val[i,index] = 1.0 68 | return out_val 69 | -------------------------------------------------------------------------------- /src_s2s/padding_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | def make_batches(size, batch_size): 3 | nb_batch = int(np.ceil(size/float(batch_size))) 4 | return [(i*batch_size, min(size, (i+1)*batch_size)) for i in range(0, nb_batch)] # zgwang: starting point of each batch 5 | 6 | def pad_2d_vals_no_size(in_vals, dtype=np.int32): 7 | size1 = len(in_vals) 8 | size2 = np.max([len(x) for x in in_vals]) 9 | return pad_2d_vals(in_vals, size1, size2, dtype=dtype) 10 | 11 | def pad_2d_vals(in_vals, dim1_size, dim2_size, dtype=np.int32): 12 | out_val = np.zeros((dim1_size, dim2_size), dtype=dtype) 13 | if dim1_size > len(in_vals): dim1_size = len(in_vals) 14 | for i in xrange(dim1_size): 15 | cur_in_vals = in_vals[i] 16 | cur_dim2_size = dim2_size 17 | if cur_dim2_size > len(cur_in_vals): cur_dim2_size = len(cur_in_vals) 18 | out_val[i,:cur_dim2_size] = cur_in_vals[:cur_dim2_size] 19 | return out_val 20 | 21 | def pad_3d_vals_no_size(in_vals, dtype=np.int32): 22 | size1 = len(in_vals) 23 | size2 = np.max([len(x) for x in in_vals]) 24 | size3 = 0 25 | for val in in_vals: 26 | cur_size3 = np.max([len(x) for x in val]) 27 | if size3 len(in_vals): dim1_size = len(in_vals) 34 | for i in xrange(dim1_size): 35 | in_vals_i = in_vals[i] 36 | cur_dim2_size = dim2_size 37 | if cur_dim2_size > len(in_vals_i): cur_dim2_size = len(in_vals_i) 38 | for j in xrange(cur_dim2_size): 39 | in_vals_ij = in_vals_i[j] 40 | cur_dim3_size = dim3_size 41 | if cur_dim3_size > len(in_vals_ij): cur_dim3_size = len(in_vals_ij) 42 | out_val[i, j, :cur_dim3_size] = in_vals_ij[:cur_dim3_size] 43 | return out_val 44 | 45 | def pad_4d_vals(in_vals, dim1_size, dim2_size, dim3_size, dim4_size, dtype=np.int32): 46 | out_val = np.zeros((dim1_size, dim2_size, dim3_size, dim4_size), dtype=dtype) 47 | if dim1_size > len(in_vals): dim1_size = len(in_vals) 48 | for i in xrange(dim1_size): 49 | in_vals_i = in_vals[i] 50 | cur_dim2_size = dim2_size 51 | if cur_dim2_size > len(in_vals_i): cur_dim2_size = len(in_vals_i) 52 | for j in xrange(cur_dim2_size): 53 | in_vals_ij = in_vals_i[j] 54 | cur_dim3_size = dim3_size 55 | if cur_dim3_size > len(in_vals_ij): cur_dim3_size = len(in_vals_ij) 56 | for k in xrange(cur_dim3_size): 57 | in_vals_ijk = in_vals_ij[k] 58 | cur_dim4_size = dim4_size 59 | if cur_dim4_size > len(in_vals_ijk): cur_dim4_size = len(in_vals_ijk) 60 | out_val[i, j, k, :cur_dim4_size] = in_vals_ijk[:cur_dim4_size] 61 | return out_val 62 | 63 | def pad_target_labels(in_val, max_length, dtype=np.float32): 64 | batch_size = len(in_val) 65 | out_val = np.zeros((batch_size, max_length), dtype=dtype) 66 | for i in xrange(batch_size): 67 | for index in in_val[i]: 68 | out_val[i,index] = 1.0 69 | return out_val 70 | -------------------------------------------------------------------------------- /src_g2s/metric_utils.py: -------------------------------------------------------------------------------- 1 | 
import cPickle as pickle 2 | import os 3 | import sys 4 | from metric_bleu_utils import Bleu 5 | from metric_rouge_utils import Rouge 6 | 7 | def score_all(ref, hypo): 8 | scorers = [ 9 | (Bleu(4),["Bleu_1","Bleu_2","Bleu_3","Bleu_4"]), 10 | (Rouge(),"ROUGE_L"), 11 | ] 12 | final_scores = {} 13 | for scorer,method in scorers: 14 | score,scores = scorer.compute_score(ref,hypo) 15 | if type(score)==list: 16 | for m,s in zip(method,score): 17 | final_scores[m] = s 18 | else: 19 | final_scores[method] = score 20 | 21 | return final_scores 22 | 23 | def score(ref, hypo): 24 | scorers = [ 25 | (Bleu(4),["Bleu_1","Bleu_2","Bleu_3","Bleu_4"]) 26 | 27 | ] 28 | final_scores = {} 29 | for scorer,method in scorers: 30 | score,scores = scorer.compute_score(ref,hypo) 31 | if type(score)==list: 32 | for m,s in zip(method,score): 33 | final_scores[m] = s 34 | else: 35 | final_scores[method] = score 36 | 37 | return final_scores 38 | 39 | def evaluate_captions(ref,cand): 40 | hypo = {} 41 | refe = {} 42 | for i, caption in enumerate(cand): 43 | hypo[i] = [caption,] 44 | refe[i] = ref[i] 45 | final_scores = score(refe, hypo) 46 | return 1*final_scores['Bleu_4'] + 1*final_scores['Bleu_3'] + 0.5*final_scores['Bleu_1'] + 0.5*final_scores['Bleu_2'] 47 | 48 | def evaluate(data_path='./data', split='val', get_scores=False): 49 | reference_path = os.path.join(data_path, "%s/%s.references.pkl" %(split, split)) 50 | candidate_path = os.path.join(data_path, "%s/%s.candidate.captions.pkl" %(split, split)) 51 | 52 | # load caption data 53 | with open(reference_path, 'rb') as f: 54 | ref = pickle.load(f) 55 | with open(candidate_path, 'rb') as f: 56 | cand = pickle.load(f) 57 | 58 | # make dictionary 59 | hypo = {} 60 | for i, caption in enumerate(cand): 61 | hypo[i] = [caption] 62 | 63 | # compute bleu score 64 | final_scores = score_all(ref, hypo) 65 | 66 | # print out scores 67 | print 'Bleu_1:\t',final_scores['Bleu_1'] 68 | print 'Bleu_2:\t',final_scores['Bleu_2'] 69 | print 'Bleu_3:\t',final_scores['Bleu_3'] 70 | print 'Bleu_4:\t',final_scores['Bleu_4'] 71 | print 'METEOR:\t',final_scores['METEOR'] 72 | print 'ROUGE_L:',final_scores['ROUGE_L'] 73 | print 'CIDEr:\t',final_scores['CIDEr'] 74 | 75 | if get_scores: 76 | return final_scores 77 | 78 | 79 | if __name__ == "__main__": 80 | ref = [[u'a tiddy bear',u'a animal'],[u' a number of luggage bags on a cart in a lobby .', u' a cart filled with suitcases and bags .', u' trolley used for transporting personal luggage to guests rooms .', u' wheeled cart with luggage at lobby of commercial business .', u' a luggage cart topped with lots of luggage .']] 81 | dec = [u'some one',u' a man is standing next to a car with a suitcase .'] 82 | r = [evaluate_captions([k], [v]) for k, v in zip(ref, dec)] 83 | print r 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /src_s2s/metric_utils.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | import os 3 | import sys 4 | from metric_bleu_utils import Bleu 5 | from metric_rouge_utils import Rouge 6 | 7 | def score_all(ref, hypo): 8 | scorers = [ 9 | (Bleu(4),["Bleu_1","Bleu_2","Bleu_3","Bleu_4"]), 10 | (Rouge(),"ROUGE_L"), 11 | ] 12 | final_scores = {} 13 | for scorer,method in scorers: 14 | score,scores = scorer.compute_score(ref,hypo) 15 | if type(score)==list: 16 | for m,s in zip(method,score): 17 | final_scores[m] = s 18 | else: 19 | 
final_scores[method] = score 20 | 21 | return final_scores 22 | 23 | def score(ref, hypo): 24 | scorers = [ 25 | (Bleu(4),["Bleu_1","Bleu_2","Bleu_3","Bleu_4"]) 26 | 27 | ] 28 | final_scores = {} 29 | for scorer,method in scorers: 30 | score,scores = scorer.compute_score(ref,hypo) 31 | if type(score)==list: 32 | for m,s in zip(method,score): 33 | final_scores[m] = s 34 | else: 35 | final_scores[method] = score 36 | 37 | return final_scores 38 | 39 | def evaluate_captions(ref,cand): 40 | hypo = {} 41 | refe = {} 42 | for i, caption in enumerate(cand): 43 | hypo[i] = [caption,] 44 | refe[i] = ref[i] 45 | final_scores = score(refe, hypo) 46 | return 1*final_scores['Bleu_4'] + 1*final_scores['Bleu_3'] + 0.5*final_scores['Bleu_1'] + 0.5*final_scores['Bleu_2'] 47 | 48 | def evaluate(data_path='./data', split='val', get_scores=False): 49 | reference_path = os.path.join(data_path, "%s/%s.references.pkl" %(split, split)) 50 | candidate_path = os.path.join(data_path, "%s/%s.candidate.captions.pkl" %(split, split)) 51 | 52 | # load caption data 53 | with open(reference_path, 'rb') as f: 54 | ref = pickle.load(f) 55 | with open(candidate_path, 'rb') as f: 56 | cand = pickle.load(f) 57 | 58 | # make dictionary 59 | hypo = {} 60 | for i, caption in enumerate(cand): 61 | hypo[i] = [caption] 62 | 63 | # compute bleu score 64 | final_scores = score_all(ref, hypo) 65 | 66 | # print out scores 67 | print 'Bleu_1:\t',final_scores['Bleu_1'] 68 | print 'Bleu_2:\t',final_scores['Bleu_2'] 69 | print 'Bleu_3:\t',final_scores['Bleu_3'] 70 | print 'Bleu_4:\t',final_scores['Bleu_4'] 71 | print 'METEOR:\t',final_scores['METEOR'] 72 | print 'ROUGE_L:',final_scores['ROUGE_L'] 73 | print 'CIDEr:\t',final_scores['CIDEr'] 74 | 75 | if get_scores: 76 | return final_scores 77 | 78 | 79 | if __name__ == "__main__": 80 | ref = [[u'a tiddy bear',u'a animal'],[u' a number of luggage bags on a cart in a lobby .', u' a cart filled with suitcases and bags .', u' trolley used for transporting personal luggage to guests rooms .', u' wheeled cart with luggage at lobby of commercial business .', u' a luggage cart topped with lots of luggage .']] 81 | dec = [u'some one',u' a man is standing next to a car with a suitcase .'] 82 | r = [evaluate_captions([k], [v]) for k, v in zip(ref, dec)] 83 | print r 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /src_g2s/metric_rouge_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 20 | """ 21 | if(len(string)< len(sub)): 22 | sub, string = string, sub 23 | 24 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 25 | 26 | for j in range(1,len(sub)+1): 27 | for i in 
range(1,len(string)+1): 28 | if(string[i-1] == sub[j-1]): 29 | lengths[i][j] = lengths[i-1][j-1] + 1 30 | else: 31 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 32 | 33 | return lengths[len(string)][len(sub)] 34 | 35 | class Rouge(): 36 | ''' 37 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 38 | ''' 39 | def __init__(self): 40 | # vrama91: updated the value below based on discussion with Hovey 41 | self.beta = 1.2 42 | 43 | def calc_score(self, candidate, refs): 44 | """ 45 | Compute ROUGE-L score given one candidate and references for an image 46 | :param candidate: str : candidate sentence to be evaluated 47 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 48 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 49 | """ 50 | assert(len(candidate)==1) 51 | assert(len(refs)>0) 52 | prec = [] 53 | rec = [] 54 | 55 | # split into tokens 56 | token_c = candidate[0].split(" ") 57 | 58 | for reference in refs: 59 | # split into tokens 60 | token_r = reference.split(" ") 61 | # compute the longest common subsequence 62 | lcs = my_lcs(token_r, token_c) 63 | prec.append(lcs/float(len(token_c))) 64 | rec.append(lcs/float(len(token_r))) 65 | 66 | prec_max = max(prec) 67 | rec_max = max(rec) 68 | 69 | if(prec_max!=0 and rec_max !=0): 70 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 71 | else: 72 | score = 0.0 73 | return score 74 | 75 | def compute_score(self, gts, res): 76 | """ 77 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 78 | Invoked by evaluate_captions.py 79 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 80 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 81 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 82 | """ 83 | assert(gts.keys() == res.keys()) 84 | imgIds = gts.keys() 85 | 86 | score = [] 87 | for id in imgIds: 88 | hypo = res[id] 89 | ref = gts[id] 90 | 91 | score.append(self.calc_score(hypo, ref)) 92 | 93 | # Sanity check. 
94 | assert(type(hypo) is list) 95 | assert(len(hypo) == 1) 96 | assert(type(ref) is list) 97 | assert(len(ref) > 0) 98 | 99 | average_score = np.mean(np.array(score)) 100 | return average_score, np.array(score) 101 | 102 | def method(self): 103 | return "Rouge" -------------------------------------------------------------------------------- /src_s2s/metric_rouge_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 20 | """ 21 | if(len(string)< len(sub)): 22 | sub, string = string, sub 23 | 24 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 25 | 26 | for j in range(1,len(sub)+1): 27 | for i in range(1,len(string)+1): 28 | if(string[i-1] == sub[j-1]): 29 | lengths[i][j] = lengths[i-1][j-1] + 1 30 | else: 31 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 32 | 33 | return lengths[len(string)][len(sub)] 34 | 35 | class Rouge(): 36 | ''' 37 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 38 | ''' 39 | def __init__(self): 40 | # vrama91: updated the value below based on discussion with Hovey 41 | self.beta = 1.2 42 | 43 | def calc_score(self, candidate, refs): 44 | """ 45 | Compute ROUGE-L score given one candidate and references for an image 46 | :param candidate: str : candidate sentence to be evaluated 47 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 48 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 49 | """ 50 | assert(len(candidate)==1) 51 | assert(len(refs)>0) 52 | prec = [] 53 | rec = [] 54 | 55 | # split into tokens 56 | token_c = candidate[0].split(" ") 57 | 58 | for reference in refs: 59 | # split into tokens 60 | token_r = reference.split(" ") 61 | # compute the longest common subsequence 62 | lcs = my_lcs(token_r, token_c) 63 | prec.append(lcs/float(len(token_c))) 64 | rec.append(lcs/float(len(token_r))) 65 | 66 | prec_max = max(prec) 67 | rec_max = max(rec) 68 | 69 | if(prec_max!=0 and rec_max !=0): 70 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 71 | else: 72 | score = 0.0 73 | return score 74 | 75 | def compute_score(self, gts, res): 76 | """ 77 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 78 | Invoked by evaluate_captions.py 79 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 80 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 81 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 82 | """ 83 | assert(gts.keys() == res.keys()) 84 | imgIds = gts.keys() 85 | 86 | score = [] 87 | for 
id in imgIds: 88 | hypo = res[id] 89 | ref = gts[id] 90 | 91 | score.append(self.calc_score(hypo, ref)) 92 | 93 | # Sanity check. 94 | assert(type(hypo) is list) 95 | assert(len(hypo) == 1) 96 | assert(type(ref) is list) 97 | assert(len(ref) > 0) 98 | 99 | average_score = np.mean(np.array(score)) 100 | return average_score, np.array(score) 101 | 102 | def method(self): 103 | return "Rouge" -------------------------------------------------------------------------------- /src_s2s/prepare_paraphrase_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import re 4 | 5 | def read_MSCOCO(inpath): 6 | # load json file 7 | with open(inpath) as dataset_file: 8 | dataset_json = json.load(dataset_file, encoding='utf-8') 9 | annotations = dataset_json['annotations'] 10 | print(len(annotations)) 11 | 12 | # dispatch each caption to its corresponding image 13 | image_captions_dict = {} 14 | for annotation in annotations: 15 | image_id = annotation['image_id'] 16 | id = annotation['id'] 17 | caption = annotation['caption'] 18 | captions = None 19 | if image_captions_dict.has_key(image_id): 20 | captions = image_captions_dict[image_id] 21 | else: 22 | captions = [] 23 | captions.append(caption) 24 | image_captions_dict[image_id] = captions 25 | 26 | # check number of captions for each image 27 | all_instances = [] 28 | for image_id in image_captions_dict.keys(): 29 | captions = image_captions_dict[image_id] 30 | random.shuffle(captions) 31 | # print(len(captions)) 32 | all_instances.append((image_id, captions)) 33 | return all_instances 34 | 35 | def read_Quora(inpath): 36 | with open(inpath, "rt") as f: 37 | for line in f: 38 | line = line.decode('utf-8') 39 | line = line.strip() 40 | items = re.split('\t', line) 41 | 42 | 43 | def dump_out_to_json(instances, outpath): 44 | json_instances = [] 45 | for (image_id, captions) in instances: 46 | json_instances.append({'id': str(image_id), 'text1': captions[0], 'text2': captions[1]}) 47 | # json_instances.append({'id': str(image_id) + "-2", 'text1': captions[2], 'text2': captions[3]}) 48 | with open(outpath, 'w') as outfile: 49 | json.dump(json_instances, outfile) 50 | 51 | def create_json_file(all_instances, outpath, batch_size=5000): 52 | import padding_utils 53 | batch_spans = padding_utils.make_batches(len(all_instances), batch_size) 54 | for batch_index, (batch_start, batch_end) in enumerate(batch_spans): 55 | cur_instances = all_instances[batch_start:batch_end] 56 | cur_outpath = outpath + ".{}".format(batch_index) 57 | print("Dump {} instances out to {}".format(len(cur_instances), cur_outpath)) 58 | dump_out_to_json(cur_instances, cur_outpath) 59 | 60 | if __name__ == "__main__": 61 | ''' # create mscoco dataset 62 | dataset = "train" 63 | inpath = "/u/zhigwang/zhigwang1/sentence_generation/mscoco/annotations/captions_" + dataset + "2014.json" 64 | outpath = "/u/zhigwang/zhigwang1/sentence_generation/mscoco/data/" + dataset + ".json" 65 | all_instances = read_MSCOCO(inpath) 66 | batch_size = 5000 67 | create_json_file(all_instances, outpath, batch_size=batch_size) 68 | ''' 69 | 70 | # create quora dataset 71 | rawpath = "/u/zhigwang/zhigwang1/sentence_match/quora/quora_duplicate_questions.tsv" 72 | batch_size = 10000 73 | # load all pairs 74 | print('Loading all question pairs ...') 75 | id_instances_dict = {} 76 | with open(rawpath, "rt") as f: 77 | for line in f: 78 | line = line.decode('utf-8').strip() 79 | if not line.endswith("1"): continue 80 | items = re.split('\t', 
line) 81 | cur_id = items[0] 82 | sent1 = items[3] 83 | sent2 = items[4] 84 | id_instances_dict[cur_id] = (sent1, sent2) 85 | print(len(id_instances_dict)) 86 | 87 | for dataset in ['dev', 'test', 'train']: 88 | # collect all isntances 89 | inpath = "/u/zhigwang/zhigwang1/sentence_match/quora/" + dataset + ".tsv" 90 | outpath = "/u/zhigwang/zhigwang1/sentence_generation/quora/data/" + dataset + ".json" 91 | print(inpath) 92 | all_instances = [] 93 | with open(inpath, "rt") as f: 94 | for line in f: 95 | line = line.decode('utf-8').strip() 96 | if not line.startswith("1"): continue 97 | items = re.split('\t', line) 98 | cur_id = items[3] 99 | all_instances.append((cur_id, id_instances_dict[cur_id])) 100 | create_json_file(all_instances, outpath, batch_size=batch_size) 101 | 102 | 103 | 104 | print('DONE!') -------------------------------------------------------------------------------- /src_s2s/prepare_summarization_dataset.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import json 4 | 5 | dm_single_close_quote = u'\u2019' # unicode 6 | dm_double_close_quote = u'\u201d' 7 | END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence 8 | 9 | SENTENCE_START = '' 10 | SENTENCE_END = '' 11 | 12 | def read_text_file(text_file): 13 | lines = [] 14 | with open(text_file, "rt") as f: 15 | for line in f: 16 | line = line.decode('utf-8') 17 | lines.append(line.strip()) 18 | return lines 19 | 20 | def fix_missing_period(line): 21 | """Adds a period to a line that is missing a period""" 22 | if "@highlight" in line: return line 23 | if line=="": return line 24 | if line[-1] in END_TOKENS: return line 25 | return line + " ." 
26 | 27 | def get_art_abs(story_file): 28 | lines = read_text_file(story_file) 29 | 30 | # Lowercase everything 31 | # lines = [line.lower() for line in lines] 32 | # Put periods on the ends of lines that are missing them (this is a problem in the dataset because many image captions don't end in periods; consequently they end up in the body of the article as run-on sentences) 33 | lines = [fix_missing_period(line) for line in lines] 34 | 35 | # Separate out article and abstract sentences 36 | article_lines = [] 37 | highlights = [] 38 | next_is_highlight = False 39 | for idx,line in enumerate(lines): 40 | if line == "": 41 | continue # empty line 42 | elif line.startswith("@highlight"): 43 | next_is_highlight = True 44 | elif next_is_highlight: 45 | highlights.append(line) 46 | else: 47 | article_lines.append(line) 48 | 49 | # Make article into a single string 50 | article = ' '.join(article_lines) 51 | 52 | # Make abstract into a signle string, putting and tags around the sentences 53 | abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) 54 | 55 | return article, abstract 56 | 57 | def hashhex(s): 58 | """Returns a heximal formated SHA1 hash of the input string.""" 59 | h = hashlib.sha1() 60 | h.update(s) 61 | return h.hexdigest() 62 | 63 | def dump_out_to_json(hash_path_dict, urls, outpath): 64 | all_instances = [] 65 | for cur_url in urls: 66 | cur_hash_code = hashhex(cur_url) 67 | cur_path = hash_path_dict[cur_hash_code] 68 | (article, abstract) = get_art_abs(cur_path) 69 | all_instances.append({'id': cur_hash_code, 'text1': article, 'text2': abstract}) 70 | 71 | with open(outpath, 'w') as outfile: 72 | json.dump(all_instances, outfile) 73 | 74 | def create_json_file(hash_path_dict, urlpath, outpath, batch_size=5000): 75 | all_urls = read_text_file(urlpath) 76 | import padding_utils 77 | batch_spans = padding_utils.make_batches(len(all_urls), batch_size) 78 | for batch_index, (batch_start, batch_end) in enumerate(batch_spans): 79 | cur_urls = all_urls[batch_start:batch_end] 80 | cur_outpath = outpath + ".{}".format(batch_index) 81 | print("Dump {} instances out to {}".format(len(cur_urls), cur_outpath)) 82 | dump_out_to_json(hash_path_dict, cur_urls, cur_outpath) 83 | 84 | 85 | 86 | def process(): 87 | cnn_in_dir = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/cnn/cnn/stories" 88 | daily_mail_in_dir = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/dailymail/dailymail/stories" 89 | train_urls = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/url_lists/all_train.txt" 90 | val_urls = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/url_lists/all_val.txt" 91 | test_urls = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/url_lists/all_test.txt" 92 | outdir = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/data" 93 | batch_size = 5000 94 | 95 | # collect all files 96 | print('collecting all files') 97 | in_dirs = [cnn_in_dir, daily_mail_in_dir] 98 | hash_path_dict = {} 99 | for in_dir in in_dirs: 100 | all_paths= os.listdir(in_dir) 101 | for cur_path in all_paths: 102 | if not cur_path.endswith(".story"): continue 103 | cur_hash_code = cur_path[:-len('.story')] 104 | cur_path = in_dir + "/" + cur_path 105 | hash_path_dict[cur_hash_code] = cur_path 106 | # print(cur_path) 107 | # print(cur_hash_code) 108 | print('number of files: {}'.format(len(hash_path_dict))) 109 | print('Creating val.json') 110 | create_json_file(hash_path_dict, val_urls, outdir + "/val.json", batch_size=batch_size) 111 | 
print('Creating test.json') 112 | create_json_file(hash_path_dict, test_urls, outdir + "/test.json", batch_size=batch_size) 113 | print('Creating train.json') 114 | create_json_file(hash_path_dict, train_urls, outdir + "/train.json", batch_size=batch_size) 115 | 116 | def generate_tok_commandlines(inpath): 117 | all_paths = read_text_file(inpath) 118 | for cur_path in all_paths: 119 | print("jbsub -q x86_12h -cores 10 -name process_summarization -mem 21G " 120 | + "/u/zhigwang/workspace/FactoidQA_Java/scripts/process_generation_datasets.sh " 121 | + "{} {}.tok 10".format(cur_path, cur_path)) 122 | 123 | 124 | if __name__ == "__main__": 125 | ''' 126 | inpath = "/u/zhigwang/zhigwang1/sentence_generation/cnn/cnn/stories/fffcd65676a501860ae312754e8cefc71f5ddab8.story" 127 | (article, abstract) = get_art_abs(inpath) 128 | print(article) 129 | print(abstract) 130 | ''' 131 | 132 | # process() 133 | #''' 134 | inpath = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/data/fof" 135 | # inpath = "/u/zhigwang/zhigwang1/sentence_generation/mscoco/data/fof" 136 | # inpath = "/u/zhigwang/zhigwang1/sentence_generation/quora/data/fof" 137 | generate_tok_commandlines(inpath) 138 | #''' 139 | # print("DONE!") -------------------------------------------------------------------------------- /src_s2s/sent_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | class QASentence(object): 4 | def __init__(self, rawText, annotation, ID_num=None, isLower=False, end_sym=None): 5 | self.rawText = rawText 6 | self.annotation = annotation 7 | self.tokText = annotation['toks'] 8 | # it's the answer sequence 9 | if end_sym != None: 10 | self.rawText += ' ' + end_sym 11 | self.tokText += ' ' + end_sym 12 | if isLower: self.tokText = self.tokText.lower() 13 | self.words = re.split("\\s+", self.tokText) 14 | self.startPositions = [] 15 | self.endPositions = [] 16 | positions = re.split("\\s+", annotation['positions']) 17 | for i in xrange(len(positions)): 18 | tmps = re.split("-", positions[i]) 19 | self.startPositions.append(int(tmps[1])) 20 | self.endPositions.append(int(tmps[2])) 21 | self.POSs = annotation['POSs'] 22 | self.NERs = annotation['NERs'] 23 | if annotation.has_key('spans'): self.syntaxSpans = annotation['spans'] 24 | self.length = len(self.words) 25 | self.ID_num = ID_num 26 | 27 | self.index_convered = False 28 | self.chunk_starts = None 29 | 30 | def chunk(self, maxlen): 31 | self.words = self.words[:maxlen] 32 | self.startPositions = self.startPositions[:maxlen] 33 | self.endPositions = self.endPositions[:maxlen] 34 | self.POSs = self.POSs[:maxlen] 35 | self.NERs = self.NERs[:maxlen] 36 | 37 | if self.index_convered: 38 | self.word_idx_seq = self.word_idx_seq[:maxlen] 39 | self.char_idx_seq = self.char_idx_seq[:maxlen] 40 | self.POS_idx_seq = self.POS_idx_seq[:maxlen] 41 | self.NER_idx_seq = self.NER_idx_seq[:maxlen] 42 | 43 | self.length = len(self.words) 44 | 45 | def TokSpan2RawSpan(self, startTokID, endTokID): 46 | start = self.startPositions[startTokID] 47 | end = self.endPositions[endTokID] 48 | return (start, end) 49 | 50 | def RawSpan2TokSpan(self, start, end): 51 | startTokID = -1 52 | endTokID = -1 53 | for i in xrange(len(self.startPositions)): 54 | if self.startPositions[i] == start: 55 | startTokID = i 56 | if self.endPositions[i] == end: 57 | endTokID = i 58 | return (startTokID, endTokID) 59 | 60 | def getRawChunk(self, start, end): 61 | if end>len(self.rawText): 62 | return None 63 | return self.rawText[start:end] 64 | 65 | 
def getRawChunkWithTokSpan(self, startTokID, endTokID): 66 | start = self.startPositions[startTokID] 67 | end = self.endPositions[endTokID] 68 | return self.rawText[start:end] 69 | 70 | def getTokChunk(self, startTokID, endTokID): 71 | curWords = [] 72 | for i in xrange(startTokID,endTokID+1): 73 | curWords.append(self.words[i]) 74 | return " ".join(curWords) 75 | 76 | def get_length(self): 77 | return self.length 78 | 79 | def get_max_word_len(self): 80 | max_word_len = 0 81 | for word in self.words: 82 | cur_len = len(word) 83 | if max_word_len < cur_len: max_word_len = cur_len 84 | return max_word_len 85 | 86 | def get_char_len(self): 87 | char_lens = [] 88 | for word in self.words: 89 | cur_len = len(word) 90 | char_lens.append(cur_len) 91 | return char_lens 92 | 93 | def convert2index(self, word_vocab, char_vocab, POS_vocab, NER_vocab, max_char_per_word=-1): 94 | if self.index_convered: return # for each sentence, only conver once 95 | 96 | if word_vocab is not None: 97 | self.word_idx_seq = word_vocab.to_index_sequence(self.tokText) 98 | 99 | if char_vocab is not None: 100 | self.char_idx_seq = char_vocab.to_character_matrix(self.tokText, max_char_per_word=max_char_per_word) 101 | 102 | if POS_vocab is not None: 103 | self.POS_idx_seq = POS_vocab.to_index_sequence(self.POSs) 104 | 105 | if NER_vocab is not None: 106 | self.NER_idx_seq = NER_vocab.to_index_sequence(self.NERs) 107 | 108 | self.index_convered = True 109 | 110 | def collect_all_possible_chunks(self, max_chunk_len): 111 | if self.chunk_starts is None: 112 | self.chunk_starts = [] 113 | self.chunk_ends = [] 114 | for i in xrange(self.length): 115 | cur_word = self.words[i] 116 | if cur_word in ".!?;": continue 117 | for j in xrange(i, i+max_chunk_len): 118 | if j>=self.length: break 119 | cur_word = self.words[j] 120 | if cur_word in ".!?;": break 121 | self.chunk_starts.append(i) 122 | self.chunk_ends.append(j) 123 | return (self.chunk_starts, self.chunk_ends) 124 | 125 | def collect_all_entities(self): 126 | items = re.split("\\s+", self.NERs) 127 | prev_label = "O" 128 | cur_start = -1 129 | chunk_starts = [] 130 | chunk_ends = [] 131 | for i in xrange(len(items)): 132 | cur_label = items[i] 133 | if cur_label != prev_label: 134 | if cur_start != -1: 135 | chunk_starts.append(cur_start) 136 | chunk_ends.append(i-1) 137 | cur_start = -1 138 | if cur_label != "O": 139 | cur_start = i 140 | prev_label = cur_label 141 | if cur_start !=-1: 142 | chunk_starts.append(cur_start) 143 | chunk_ends.append(len(items)-1) 144 | return (chunk_starts, chunk_ends) 145 | 146 | def collect_all_syntax_chunks(self, max_chunk_len): 147 | if self.chunk_starts is None: 148 | self.chunk_starts = [] 149 | self.chunk_ends = [] 150 | self.chunk_labels = [] 151 | all_spans = re.split("\\s+", self.syntaxSpans) 152 | for i in xrange(len(all_spans)): 153 | cur_span = all_spans[i] 154 | items = re.split("-", cur_span) 155 | cur_start = int(items[0]) 156 | cur_end = int(items[1]) 157 | cur_label = items[2] 158 | if cur_end-cur_start>=max_chunk_len: continue 159 | self.chunk_starts.append(cur_start) 160 | self.chunk_ends.append(cur_end) 161 | self.chunk_labels.append(cur_label) 162 | return (self.chunk_starts, self.chunk_ends, self.chunk_labels) 163 | 164 | if __name__ == "__main__": 165 | import NP2P_data_stream 166 | inpath = "/u/zhigwang/zhigwang1/sentence_generation/cnn-dailymail/data/val.json.tok" 167 | all_instances,_ = NP2P_data_stream.read_all_GenerationDatasets(inpath, isLower=True) 168 | sample_instance = all_instances[0][1] 169 | print('Raw 
text: {}'.format(sample_instance.rawText)) 170 | (chunk_starts, chunk_ends, chunk_labels) = sample_instance.collect_all_syntax_chunks(5) 171 | for i in xrange(len(chunk_starts)): 172 | cur_start = chunk_starts[i] 173 | cur_end = chunk_ends[i] 174 | cur_label = chunk_labels[i] 175 | cur_text = sample_instance.getTokChunk(cur_start, cur_end) 176 | print("{}-{}-{}:{}".format(cur_start, cur_end, cur_label, cur_text)) 177 | print("DONE!") 178 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Graph to Sequence Model 2 | 3 | This repository contains the code for our paper [A Graph-to-Sequence Model for AMR-to-Text Generation](https://arxiv.org/abs/1805.02473) in ACL 2018. 4 | 5 | The code is developed under TensorFlow 1.4.1. 6 | We share our pretrained models along with this repository. 7 | For TensorFlow compatibility reasons, they may not load under lower versions (such as 1.0.0). 8 | 9 | Please create issues if there are any questions! This makes things easier to track. 10 | 11 | ## Update logs 12 | 13 | ### Link update for silver AMR data and AMR simplifier (Oct. 13th, 2020) 14 | Because of the long period since my PhD graduation, my old homepage can no longer be accessed. 15 | I've moved the data and scripts to Google Drive for download. 16 | Pretrained models will be uploaded to Google Drive soon. 17 | 18 | ### Results on WebNLG and LDC2017T10 datasets (Aug. 4th, 2019) 19 | Following the same setting (such as data preprocessing) as [Marcheggiani and Perez-Beltrachini (INLG 2018)](https://www.aclweb.org/anthology/W18-6501), our model achieves a BLEU score of *64.2* on the WebNLG dataset. 20 | 21 | For another widely used AMR dataset, LDC2017T10, our model achieves a BLEU score similar to the one it obtains on LDC2015E86. 22 | 23 | ### Be careful about your tokenizer (Feb. 27th, 2019) 24 | We use the [PTB_tokenizer](https://nlp.stanford.edu/software/tokenizer.shtml) from Stanford CoreNLP to preprocess our data. If you plan to use our pretrained model, please be careful about the tokenizer you use. 25 | Also, be careful to keep the words cased during preprocessing, as the PTB_tokenizer is sensitive to that. 26 | 27 | ### Release of 2M automatically parsed data (link updated on Sep. 24th, 2020) 28 | We release our [2M sentences with their automatically parsed AMRs](https://drive.google.com/file/d/1mmTZFxRdBFOXbH6FCViQ30bVRIvGrwvx/view?usp=sharing) to the public. 29 | 30 | ## About AMR 31 | 32 | AMR is a graph-based semantic formalism that can give unified representations to several sentences with the same meaning. 33 | Compared with other structures, such as dependency trees and semantic roles, AMR graphs have several key differences: 34 | * AMRs only focus on concepts and their relations, so no function words are included; the edge labels effectively serve the role of function words. 35 | * Inflections are dropped when converting a noun, a verb or a named entity into an AMR concept. Sometimes a synonym is used instead of the original word. This makes the AMRs more unified, so that each AMR graph can represent more sentences. 36 | * Relation tags (edge labels) are predefined and are not extracted from the text (the way OpenIE does). 37 | More details are on the official AMR page [AMR website@ISI](https://amr.isi.edu/download.html), where 38 | you can download the publicly available AMR bank: [little prince](https://amr.isi.edu/download/amr-bank-struct-v1.6.txt). 39 | Try it for fun!
40 |
41 | ## Data preprocessing
42 | The [data loader](./src_g2s/G2S_data_stream.py) of our model requires simplified AMR graphs, where variable tags, sense tags and quotes are removed. For example, the following AMR
43 | ```
44 | (d / describe-01
45 |       :ARG0 (p / person
46 |            :name (n / name
47 |                 :op1 "Ryan"))
48 |       :ARG1 p
49 |       :ARG2 genius)
50 | ```
51 | needs to be simplified to
52 | ```
53 | describe :arg0 ( person :name ( name :op1 ryan ) ) :arg1 person :arg2 genius
54 | ```
55 | before being consumed by our model.
56 |
57 |
58 | We provide our scripts for AMR simplification.
59 | First, you need to put each AMR onto a single line; our released [script](./AMR_multiline_to_singleline.py) may serve this goal (you may need to modify it slightly).
60 | Second, to simplify the single-line AMRs, we release a tool that can be downloaded [here](https://drive.google.com/file/d/1PE8b44-H3Hu4I2Xf1XDAPFhH-GonR43i/view?usp=sharing).
61 | It is adapted from the [NeuralAMR](https://github.com/sinantie/NeuralAmr) system.
62 | To run our simplifier on a file ```demo.amr```, simply execute
63 | ```
64 | ./anonDeAnon_java.sh anonymizeAmrFull true demo.amr
65 | ```
66 | and it will output the simplified AMRs into ```demo.amr.anonymized```.
67 | Please note that our simplifier *does not* do anonymization.
68 | The output filename contains the 'anonymized' string only because the original NeuralAMR system creates that suffix.
69 |
70 |
71 | Alternatively, you can write your own data loading code for the format of your own AMR data.
72 |
73 |
74 | ### Input data format
75 | After simplifying your AMRs, you can merge them with the corresponding sentences into a JSON file.
76 | The JSON file is the actual input to the system.
77 | Its format is shown by the following sample:
78 | ```
79 | [{"amr": "describe :arg0 ( person :name ( name :op1 ryan ) ) :arg1 person :arg2 genius",
80 |   "sent": "ryan 's description of himself : a genius .",
81 |   "id": "demo"}]
82 | ```
83 | In general, the JSON file contains a list of instances, and each instance is a dictionary with the fields "amr", "sent" and (optionally) "id".
84 |
85 | ### Vocabulary extraction
86 | Once you have the JSON files, you can extract vocabularies with our released scripts in the [./data/](./data/) directory.
87 | We also encourage you to write your own scripts.
88 |
89 | ## Training
90 |
91 | First, modify the PYTHONPATH within [train_g2s.sh](./train_g2s.sh) (for our graph-to-sequence model) or [train_s2s.sh](./train_s2s.sh) (for the baseline).
92 | Second, modify config_g2s.json or config_s2s.json. Pay attention to the field "suffix", which identifies the model being trained and saved. We usually use the experiment setting, such as "bch20_lr1e3_l21e3", as the identifier; an illustrative config fragment is shown below.
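As a rough illustration only (the authoritative field list is in the shipped config_g2s.json / config_s2s.json; the values below are placeholders, and any field not mentioned elsewhere in this README or in the trainer code should be treated as an assumption), a config fragment might look like:
```
{
  "suffix": "bch20_lr1e3_l21e3",
  "train_path": "data/training.json",
  "test_path": "data/test.json",
  "word_vec_path": "data/vectors.txt",
  "model_dir": "logs_g2s",
  "batch_size": 20,
  "max_answer_len": 100,
  "finetune_path": "data/gold_training.json"
}
```
The "finetune_path" entry is optional and only needed for the silver-data setting described below; for gold-only training it can be omitted.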
93 | Finally, execute the corresponding script file, such as "./train_g2s.sh".
94 |
95 | ### Using large-scale automatic AMRs
96 |
97 | In this setting, we follow [Konstas et al. (2017)](https://arxiv.org/abs/1704.08381) and take the large-scale automatic data as the training set, with the original gold data as a fine-tuning set.
98 | To train this way, add a new field "finetune_path" to your config file and point it to the gold data. Besides, the original "train_path" should point to the automatic data.
99 |
100 | For training on the gold data only, we use an initial learning rate of 1e-3 and L2 normalization of 1e-3. We then lower the learning rate to 8e-4, 5e-4 and 2e-4 after a number of epochs.
101 |
102 | For training on both gold and automatic data, the initial learning rate and L2 normalization are 5e-4 and 1e-8. We also lower the learning rate during training.
103 |
104 | The idea of lowering the learning rate during training was first introduced by Konstas et al. (2017).
105 |
106 |
107 | ## Decoding with a pretrained model
108 |
109 | Simply execute the corresponding decoding script with one argument: the identifier of the model you want to use.
110 | For instance, you can execute "./decode_g2s.sh bch20_lr1e3_l21e3".
111 | Please make sure you use the associated word vectors (not others), because the pretrained models are *optimized* for those word vectors.
112 |
113 | ### Pretrained model
114 |
115 | We release a pretrained model (and word vectors) trained on the gold data plus 2M automatically-parsed AMRs [here](https://drive.google.com/file/d/1cP_VrGTGlRPdNOt2rV2wdOBnHrYDB10j/view?usp=sharing). With this model, we observed a BLEU score of *33.6*, which is higher than the 33.0 reported in our paper. The pretrained model trained with only gold data is [here](https://drive.google.com/file/d/186snRVgx5nSLbAbgU9anbu4OMvrCTUP-/view?usp=sharing); it achieves a test BLEU score of 23.3.
116 | The corresponding word-embedding file is [here](https://drive.google.com/file/d/1XhCW0eI1PQY51o_gB_MyG1DKsOG0bjIu/view?usp=sharing).
117 | 118 | ## Cite 119 | If you like our paper, please cite 120 | ``` 121 | @inproceedings{song2018graph, 122 | title={A Graph-to-Sequence Model for AMR-to-Text Generation}, 123 | author={Song, Linfeng and Zhang, Yue and Wang, Zhiguo and Gildea, Daniel}, 124 | booktitle={Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)}, 125 | pages={1616--1626}, 126 | year={2018} 127 | } 128 | ``` 129 | -------------------------------------------------------------------------------- /src_s2s/NP2P_data_stream.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import numpy as np 4 | import random 5 | import padding_utils 6 | import phrase_lattice_utils 7 | 8 | def read_text_file(text_file): 9 | lines = [] 10 | with open(text_file, "rt") as f: 11 | for line in f: 12 | line = line.decode('utf-8') 13 | lines.append(line.strip()) 14 | return lines 15 | 16 | 17 | def read_all_GenerationDatasets(inpath, isLower=True): 18 | with open(inpath) as dataset_file: 19 | dataset = json.load(dataset_file, encoding='utf-8') 20 | all_instances = [] 21 | max_answer_len = 0 22 | for instance in dataset: 23 | sent1 = instance['amr'].strip() 24 | sent2 = instance['sent'].strip() 25 | id = instance['id'] if 'id' in instance else None 26 | if sent1 == "" or sent2 == "": 27 | continue 28 | max_answer_len = max(max_answer_len, len(sent2.split())) # text2 is the sequence to be generated 29 | all_instances.append((sent1, sent2, id)) 30 | return all_instances, max_answer_len 31 | 32 | def read_generation_datasets_from_fof(fofpath, isLower=True): 33 | all_paths = read_text_file(fofpath) 34 | all_instances = [] 35 | max_answer_len = 0 36 | for cur_path in all_paths: 37 | print(cur_path) 38 | (cur_instances, cur_max_answer_len) = read_all_GenerationDatasets(cur_path, isLower=isLower) 39 | print("cur_max_answer_len: %s" % cur_max_answer_len) 40 | all_instances.extend(cur_instances) 41 | if max_answer_len 0.2 or oov_rate2 > 0.2: 68 | print('!!!!!oov_rate for ENC {} and DEC {}'.format(oov_rate1, oov_rate2)) 69 | print(sent1) 70 | print(sent2) 71 | print('==============') 72 | if options.max_passage_len != -1: sent1_idx = sent1_idx[:options.max_passage_len] 73 | if options.max_answer_len != -1: sent2_idx = sent2_idx[:options.max_answer_len] 74 | instances.append((sent1_idx, sent2_idx, sent1, sent2, id)) 75 | 76 | all_questions = instances 77 | instances = None 78 | 79 | # sort instances based on length 80 | if isSort: 81 | all_questions = sorted(all_questions, key=lambda xxx: (len(xxx[0]), len(xxx[1]))) 82 | else: 83 | pass 84 | self.num_instances = len(all_questions) 85 | 86 | # distribute questions into different buckets 87 | batch_spans = padding_utils.make_batches(self.num_instances, batch_size) 88 | self.batches = [] 89 | for batch_index, (batch_start, batch_end) in enumerate(batch_spans): 90 | cur_questions = [] 91 | for i in xrange(batch_start, batch_end): 92 | cur_questions.append(all_questions[i]) 93 | cur_batch = Batch(cur_questions, options, word_vocab=dec_word_vocab, char_vocab=char_vocab) 94 | self.batches.append(cur_batch) 95 | 96 | self.num_batch = len(self.batches) 97 | self.index_array = np.arange(self.num_batch) 98 | self.isShuffle = isShuffle 99 | if self.isShuffle: np.random.shuffle(self.index_array) 100 | self.isLoop = isLoop 101 | self.cur_pointer = 0 102 | 103 | def nextBatch(self): 104 | if self.cur_pointer>=self.num_batch: 105 | if not self.isLoop: return None 106 | self.cur_pointer = 0 
107 | if self.isShuffle: np.random.shuffle(self.index_array) 108 | cur_batch = self.batches[self.index_array[self.cur_pointer]] 109 | self.cur_pointer += 1 110 | return cur_batch 111 | 112 | def reset(self): 113 | if self.isShuffle: np.random.shuffle(self.index_array) 114 | self.cur_pointer = 0 115 | 116 | def get_num_batch(self): 117 | return self.num_batch 118 | 119 | def get_num_instance(self): 120 | return self.num_instances 121 | 122 | def get_batch(self, i): 123 | if i>= self.num_batch: return None 124 | return self.batches[i] 125 | 126 | class Batch(object): 127 | def __init__(self, instances, options, word_vocab=None, char_vocab=None, POS_vocab=None, NER_vocab=None): 128 | self.options = options 129 | 130 | self.batch_size = len(instances) 131 | self.vocab = word_vocab 132 | 133 | self.id = [inst[-1] for inst in instances] 134 | self.source = [inst[-3] for inst in instances] 135 | self.target_ref = [inst[-2] for inst in instances] 136 | 137 | # create length 138 | self.sent1_length = [] # [batch_size] 139 | self.sent2_length = [] # [batch_size] 140 | for (sent1_idx, sent2_idx, _, _, _) in instances: 141 | self.sent1_length.append(len(sent1_idx)) 142 | self.sent2_length.append(min(len(sent2_idx)+1, options.max_answer_len)) 143 | self.sent1_length = np.array(self.sent1_length, dtype=np.int32) 144 | self.sent2_length = np.array(self.sent2_length, dtype=np.int32) 145 | 146 | # create word representation 147 | start_id = word_vocab.getIndex('') 148 | end_id = word_vocab.getIndex('') 149 | if options.with_word: 150 | self.sent1_word = [] # [batch_size, sent1_len] 151 | self.sent2_word = [] # [batch_size, sent2_len] 152 | self.sent2_input_word = [] 153 | for (sent1_idx, sent2_idx, _, _, _) in instances: 154 | self.sent1_word.append(sent1_idx) 155 | self.sent2_word.append(sent2_idx+[end_id]) 156 | self.sent2_input_word.append([start_id]+sent2_idx) 157 | self.sent1_word = padding_utils.pad_2d_vals(self.sent1_word, len(instances), np.max(self.sent1_length)) 158 | self.sent2_word = padding_utils.pad_2d_vals(self.sent2_word, len(instances), options.max_answer_len) 159 | self.sent2_input_word = padding_utils.pad_2d_vals(self.sent2_input_word, len(instances), options.max_answer_len) 160 | 161 | self.in_answer_words = self.sent2_word 162 | self.gen_input_words = self.sent2_input_word 163 | self.answer_lengths = self.sent2_length 164 | 165 | if options.with_char: 166 | self.sent1_char = [] # [batch_size, sent1_len] 167 | self.sent1_char_lengths = [] 168 | for (_, _, sent1, sent2, _) in instances: 169 | sent1_char_idx = char_vocab.to_character_matrix_for_list(sent1.split()[:options.max_passage_len]) 170 | self.sent1_char.append(sent1_char_idx) 171 | self.sent1_char_lengths.append([len(x) for x in sent1_char_idx]) 172 | self.sent1_char = padding_utils.pad_3d_vals_no_size(self.sent1_char) 173 | self.sent1_char_lengths = padding_utils.pad_2d_vals_no_size(self.sent1_char_lengths) 174 | 175 | 176 | if __name__ == "__main__": 177 | all_instances, _ = read_all_GenerationDatasets('./data/training.json', True) 178 | print(1.0*sum(1 for sent1, sent2, sent3 in all_instances if sent1.get_length() > 200)) 179 | print(1.0*sum(1 for sent1, sent2, sent3 in all_instances if sent2.get_length() > 100)) 180 | print('DONE!') 181 | all_instances, _ = read_all_GenerationDatasets('./data/test.json', True) 182 | print(1.0*sum(1 for sent1, sent2, sent3 in all_instances if sent1.get_length() > 200)) 183 | print(1.0*sum(1 for sent1, sent2, sent3 in all_instances if sent2.get_length() > 100)) 184 | print('DONE!') 185 | 
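# ---------------------------------------------------------------------------
# Illustrative usage sketch added for documentation; it is NOT part of the
# original file. It mirrors how NP2P_trainer.py / NP2P_phrase_trainer.py drive
# this module. The config and data paths below are placeholders, and the
# sketch assumes a config with with_char / with_POS / with_NER disabled, so
# the extra vocabularies can be left as None.
def _demo_build_stream(config_path='config_s2s.json', train_json='./data/training.json'):
    import namespace_utils
    from vocab_utils import Vocab
    FLAGS = namespace_utils.load_namespace(config_path)          # training options (suffix, max_answer_len, ...)
    word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2')   # pretrained word vectors
    trainset, _ = read_all_GenerationDatasets(train_json, isLower=FLAGS.isLower)
    stream = QADataStream(trainset, word_vocab, None, None, None, options=FLAGS,
                          isShuffle=True, isLoop=True, isSort=True)
    return stream.nextBatch()   # a padded Batch ready to be fed into the model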
-------------------------------------------------------------------------------- /src_s2s/phrase_lattice_utils.py: -------------------------------------------------------------------------------- 1 | from graphviz import Digraph 2 | import numpy as np 3 | 4 | class prefix_tree_node(object): 5 | def __init__(self, node_id): 6 | self.node_id = node_id 7 | self.phrase_id = -1# phrase_id==-1 for all intermediate nodes 8 | self.children = {} 9 | self.parent = None 10 | 11 | def add_child(self, word, next_node): 12 | self.children[word] = next_node 13 | next_node.parent = (word, self) 14 | 15 | def set_phrase_id(self, phrase_id): 16 | self.phrase_id = phrase_id 17 | 18 | def find_child(self, word): 19 | if self.children.has_key(word): 20 | return self.children.get(word) 21 | else: 22 | return None 23 | 24 | class prefix_tree(object): 25 | def __init__(self, phrase2id): 26 | self.root_node = prefix_tree_node(0) 27 | self.all_nodes = [self.root_node] 28 | self.phrase_id_node = {} 29 | 30 | for phrase,phrase_id in dict.iteritems(phrase2id): 31 | words = phrase.split() 32 | if len(words)<=1: continue 33 | cur_node = self.root_node 34 | for word in words: 35 | next_node = cur_node.find_child(word) 36 | if next_node == None: 37 | next_node = prefix_tree_node(len(self.all_nodes)) 38 | cur_node.add_child(word, next_node) 39 | self.all_nodes.append(next_node) 40 | cur_node = next_node 41 | cur_node.set_phrase_id(phrase_id) 42 | self.phrase_id_node[phrase_id] = cur_node 43 | 44 | def get_phrase_id(self, phrase): 45 | words = phrase.split() 46 | cur_node = self.root_node 47 | for word in words: 48 | cur_node = cur_node.find_child(word) 49 | if cur_node is None: return None 50 | return cur_node.phrase_id 51 | 52 | def get_phrase(self, phrase_id): 53 | if not self.phrase_id_node.has_key(phrase_id): return None 54 | cur_node = self.phrase_id_node[phrase_id] 55 | words = [] 56 | while cur_node: 57 | if cur_node.parent is None: break 58 | (cur_word, cur_parent) = cur_node.parent 59 | words.insert(0, cur_word) 60 | cur_node = cur_parent 61 | return " ".join(words) 62 | 63 | def has_phrase_id(self, phrase_id): 64 | return self.phrase_id_node.has_key(phrase_id) 65 | 66 | 67 | def init_bak(self, phrase_id): 68 | self.root_node = prefix_tree_node(0) 69 | self.all_nodes = [self.root_node] 70 | 71 | for phrase, phrase_id in phrase_id.iteritems(): 72 | words = phrase.split() 73 | cur_node = self.root_node 74 | for word in words: 75 | next_node = cur_node.find_child(word) 76 | if next_node == None: 77 | next_node = prefix_tree_node(len(self.all_nodes)) 78 | cur_node.add_child(word, next_node) 79 | self.all_nodes.append(next_node) 80 | cur_node = next_node 81 | cur_node.set_phrase_id(phrase_id) 82 | 83 | def __str__(self): 84 | dot = Digraph(name='prefix_tree') 85 | for cur_node in self.all_nodes: 86 | dot.node(str(cur_node.node_id), str(cur_node.phrase_id)) 87 | for edge in cur_node.children.keys(): 88 | cur_edge_node = cur_node.children[edge] 89 | dot.edge(str(cur_node.node_id), str(cur_edge_node.node_id), edge) 90 | 91 | dot.body.append(r'label = "prefix tree"') 92 | return dot.source 93 | 94 | 95 | class lattice_node(object): 96 | def __init__(self, node_id): 97 | self.node_id = node_id 98 | self.out_deges = [] 99 | 100 | def add_edge(self, edge): 101 | self.out_deges.append(edge) 102 | 103 | def get_edge(self, i): 104 | if len(self.out_deges)-1 and cur_tail_node.node_id-cur_start_node.node_id>1: # one phrase is match 158 | # add one edge 159 | cur_phrase_id = cur_prefix_node.phrase_id 160 | # cur_phrase = 
id2phrase[cur_phrase_id] 161 | cur_phrase = prefix_tree.get_phrase(cur_phrase_id) 162 | new_edge = lattice_edge(cur_phrase, cur_phrase_id, cur_tail_node) 163 | cur_start_node.add_edge(new_edge) 164 | cur_node = cur_tail_node 165 | 166 | def sample_a_partition(self, max_matching=False): 167 | phrases = [] 168 | phrase_ids = [] 169 | cur_node = self.start_node 170 | while cur_node: 171 | all_edges = cur_node.out_deges 172 | edge_size = len(all_edges) 173 | if edge_size<=0: break 174 | if max_matching: 175 | max_len = -1 176 | sampled_idx = 0 177 | for i, cur_edge in enumerate(all_edges): 178 | cur_length = len(cur_edge.phrase.split()) 179 | if max_len=len(words): break 208 | cur_phrase = " ".join(words[i:j+1]) 209 | if phrase2id.has_key(cur_phrase): continue 210 | cur_index = len(phrase2id) 211 | phrase2id[cur_phrase] = cur_index 212 | id2phrase[cur_index] = cur_phrase 213 | return (phrase2id, id2phrase) 214 | 215 | 216 | if __name__ == "__main__": 217 | ''' # test create prefix tree 218 | phrase_id = {"a": 0, "a b f": 1, "a b c":2, "a b d":3, "b":4, "b d":5, "b e":6} 219 | tree = prefix_tree(phrase_id) 220 | print(tree) 221 | #''' 222 | 223 | ''' # test creating lattice 224 | sentence = "a b cc dd ee f g hhh iiiii jj kk lm" 225 | toks = sentence.split() 226 | lattice = phrase_lattice(toks) 227 | print(lattice) 228 | #''' 229 | 230 | src_sentence = "what is the significance of the periodic table ?" 231 | tgt_sentence = "what is a periodic table ?" 232 | 233 | # collect phrases from src_setence 234 | max_chunk_len = 4 235 | (phrase2id, id2phrase) = collect_all_possible_phrases(src_sentence, max_chunk_len=max_chunk_len) 236 | 237 | # create prefix tree 238 | tree = prefix_tree(phrase2id) 239 | # print(tree) 240 | #''' 241 | for phrase in phrase2id.keys(): 242 | phrase_id = tree.get_phrase_id(phrase) 243 | cur_phrase = tree.get_phrase(phrase_id) 244 | print(phrase) 245 | print(phrase_id) 246 | print(cur_phrase) 247 | print() 248 | #''' 249 | 250 | ''' 251 | # create lattice for the target sentence 252 | lattice = phrase_lattice(tgt_sentence.split(), word_vocab=None, prefix_tree=tree) 253 | # print(lattice) 254 | 255 | # sample partitions 256 | (phrases, phrase_ids) = lattice.sample_a_partition() 257 | print(phrases) 258 | print(phrase_ids) 259 | 260 | (phrases, phrase_ids) = lattice.sample_a_partition() 261 | print(phrases) 262 | print(phrase_ids) 263 | 264 | (phrases, phrase_ids) = lattice.sample_a_partition() 265 | print(phrases) 266 | print(phrase_ids) 267 | 268 | (phrases, phrase_ids) = lattice.sample_a_partition(max_matching=True) 269 | print(phrases) 270 | print(phrase_ids) 271 | #''' -------------------------------------------------------------------------------- /src_g2s/metric_bleu_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 
17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | def precook(s, n=4, out=False): 24 | """Takes a string as input and returns an object that can be given to 25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 26 | can take string arguments as well.""" 27 | words = s.split() 28 | counts = defaultdict(int) 29 | for k in xrange(1,n+1): 30 | for i in xrange(len(words)-k+1): 31 | ngram = tuple(words[i:i+k]) 32 | counts[ngram] += 1 33 | return (len(words), counts) 34 | 35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 36 | '''Takes a list of reference sentences for a single segment 37 | and returns an object that encapsulates everything that BLEU 38 | needs to know about them.''' 39 | 40 | reflen = [] 41 | maxcounts = {} 42 | for ref in refs: 43 | rl, counts = precook(ref, n) 44 | reflen.append(rl) 45 | for (ngram,count) in counts.iteritems(): 46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 47 | 48 | # Calculate effective reference sentence length. 49 | if eff == "shortest": 50 | reflen = min(reflen) 51 | elif eff == "average": 52 | reflen = float(sum(reflen))/len(reflen) 53 | 54 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 55 | 56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 57 | 58 | return (reflen, maxcounts) 59 | 60 | def cook_test(test, (reflen, refmaxcounts), eff=None, n=4): 61 | '''Takes a test sentence and returns an object that 62 | encapsulates everything that BLEU needs to know about it.''' 63 | 64 | testlen, counts = precook(test, n, True) 65 | 66 | result = {} 67 | 68 | # Calculate effective reference sentence length. 69 | 70 | if eff == "closest": 71 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 72 | else: ## i.e., "average" or "shortest" or None 73 | result["reflen"] = reflen 74 | 75 | result["testlen"] = testlen 76 | 77 | result["guess"] = [max(0,testlen-k+1) for k in xrange(1,n+1)] 78 | 79 | result['correct'] = [0]*n 80 | for (ngram, count) in counts.iteritems(): 81 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 82 | 83 | return result 84 | 85 | class BleuScorer(object): 86 | """Bleu scorer. 87 | """ 88 | 89 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 90 | # special_reflen is used in oracle (proportional effective ref len for a node). 
91 | 92 | def copy(self): 93 | ''' copy the refs.''' 94 | new = BleuScorer(n=self.n) 95 | new.ctest = copy.copy(self.ctest) 96 | new.crefs = copy.copy(self.crefs) 97 | new._score = None 98 | return new 99 | 100 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 101 | ''' singular instance ''' 102 | 103 | self.n = n 104 | self.crefs = [] 105 | self.ctest = [] 106 | self.cook_append(test, refs) 107 | self.special_reflen = special_reflen 108 | 109 | def cook_append(self, test, refs): 110 | '''called by constructor and __iadd__ to avoid creating new instances.''' 111 | 112 | if refs is not None: 113 | self.crefs.append(cook_refs(refs)) 114 | if test is not None: 115 | cooked_test = cook_test(test, self.crefs[-1]) 116 | self.ctest.append(cooked_test) ## N.B.: -1 117 | else: 118 | self.ctest.append(None) # lens of crefs and ctest have to match 119 | 120 | self._score = None ## need to recompute 121 | 122 | def ratio(self, option=None): 123 | self.compute_score(option=option) 124 | return self._ratio 125 | 126 | def score_ratio(self, option=None): 127 | '''return (bleu, len_ratio) pair''' 128 | return (self.fscore(option=option), self.ratio(option=option)) 129 | 130 | def score_ratio_str(self, option=None): 131 | return "%.4f (%.2f)" % self.score_ratio(option) 132 | 133 | def reflen(self, option=None): 134 | self.compute_score(option=option) 135 | return self._reflen 136 | 137 | def testlen(self, option=None): 138 | self.compute_score(option=option) 139 | return self._testlen 140 | 141 | def retest(self, new_test): 142 | if type(new_test) is str: 143 | new_test = [new_test] 144 | assert len(new_test) == len(self.crefs), new_test 145 | self.ctest = [] 146 | for t, rs in zip(new_test, self.crefs): 147 | self.ctest.append(cook_test(t, rs)) 148 | self._score = None 149 | 150 | return self 151 | 152 | def rescore(self, new_test): 153 | ''' replace test(s) with new test(s), and returns the new score.''' 154 | 155 | return self.retest(new_test).compute_score() 156 | 157 | def size(self): 158 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 159 | return len(self.crefs) 160 | 161 | def __iadd__(self, other): 162 | '''add an instance (e.g., from another sentence).''' 163 | 164 | if type(other) is tuple: 165 | ## avoid creating new BleuScorer instances 166 | self.cook_append(other[0], other[1]) 167 | else: 168 | assert self.compatible(other), "incompatible BLEUs." 
169 | self.ctest.extend(other.ctest) 170 | self.crefs.extend(other.crefs) 171 | self._score = None ## need to recompute 172 | 173 | return self 174 | 175 | def compatible(self, other): 176 | return isinstance(other, BleuScorer) and self.n == other.n 177 | 178 | def single_reflen(self, option="average"): 179 | return self._single_reflen(self.crefs[0][0], option) 180 | 181 | def _single_reflen(self, reflens, option=None, testlen=None): 182 | 183 | if option == "shortest": 184 | reflen = min(reflens) 185 | elif option == "average": 186 | reflen = float(sum(reflens))/len(reflens) 187 | elif option == "closest": 188 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 189 | else: 190 | assert False, "unsupported reflen option %s" % option 191 | 192 | return reflen 193 | 194 | def recompute_score(self, option=None, verbose=0): 195 | self._score = None 196 | return self.compute_score(option, verbose) 197 | 198 | def compute_score(self, option=None, verbose=0): 199 | n = self.n 200 | small = 1e-9 201 | tiny = 1e-15 ## so that if guess is 0 still return 0 202 | bleu_list = [[] for _ in range(n)] 203 | 204 | if self._score is not None: 205 | return self._score 206 | 207 | if option is None: 208 | option = "average" if len(self.crefs) == 1 else "closest" 209 | 210 | self._testlen = 0 211 | self._reflen = 0 212 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 213 | 214 | # for each sentence 215 | for comps in self.ctest: 216 | testlen = comps['testlen'] 217 | self._testlen += testlen 218 | 219 | if self.special_reflen is None: ## need computation 220 | reflen = self._single_reflen(comps['reflen'], option, testlen) 221 | else: 222 | reflen = self.special_reflen 223 | 224 | self._reflen += reflen 225 | 226 | for key in ['guess','correct']: 227 | for k in xrange(n): 228 | totalcomps[key][k] += comps[key][k] 229 | 230 | # append per image bleu score 231 | bleu = 1. 232 | for k in xrange(n): 233 | bleu *= (float(comps['correct'][k]) + tiny) \ 234 | /(float(comps['guess'][k]) + small) 235 | bleu_list[k].append(bleu ** (1./(k+1))) 236 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 237 | if ratio < 1: 238 | for k in xrange(n): 239 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 240 | 241 | if verbose > 1: 242 | print comps, reflen 243 | 244 | totalcomps['reflen'] = self._reflen 245 | totalcomps['testlen'] = self._testlen 246 | 247 | bleus = [] 248 | bleu = 1. 249 | for k in xrange(n): 250 | bleu *= float(totalcomps['correct'][k] + tiny) \ 251 | / (totalcomps['guess'][k] + small) 252 | bleus.append(bleu ** (1./(k+1))) 253 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 254 | if ratio < 1: 255 | for k in xrange(n): 256 | bleus[k] *= math.exp(1 - 1/ratio) 257 | 258 | if verbose > 0: 259 | print totalcomps 260 | print "ratio:", ratio 261 | 262 | self._score = bleus 263 | return self._score, bleu_list 264 | 265 | class Bleu: 266 | def __init__(self, n=4): 267 | # default compute Blue score up to 4 268 | self._n = n 269 | self._hypo_for_image = {} 270 | self.ref_for_image = {} 271 | 272 | def compute_score(self, gts, res): 273 | 274 | assert(gts.keys() == res.keys()) 275 | imgIds = gts.keys() 276 | 277 | bleu_scorer = BleuScorer(n=self._n) 278 | for id in imgIds: 279 | hypo = res[id] 280 | ref = gts[id] 281 | 282 | # Sanity check. 
283 | assert(type(hypo) is list) 284 | assert(len(hypo) == 1) 285 | assert(type(ref) is list) 286 | assert(len(ref) >= 1) 287 | 288 | bleu_scorer += (hypo[0], ref) 289 | 290 | #score, scores = bleu_scorer.compute_score(option='shortest') 291 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 292 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 293 | 294 | # return (bleu, bleu_info) 295 | return score, scores 296 | 297 | def method(self): 298 | return "Bleu" 299 | -------------------------------------------------------------------------------- /src_s2s/metric_bleu_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | def precook(s, n=4, out=False): 24 | """Takes a string as input and returns an object that can be given to 25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 26 | can take string arguments as well.""" 27 | words = s.split() 28 | counts = defaultdict(int) 29 | for k in xrange(1,n+1): 30 | for i in xrange(len(words)-k+1): 31 | ngram = tuple(words[i:i+k]) 32 | counts[ngram] += 1 33 | return (len(words), counts) 34 | 35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 36 | '''Takes a list of reference sentences for a single segment 37 | and returns an object that encapsulates everything that BLEU 38 | needs to know about them.''' 39 | 40 | reflen = [] 41 | maxcounts = {} 42 | for ref in refs: 43 | rl, counts = precook(ref, n) 44 | reflen.append(rl) 45 | for (ngram,count) in counts.iteritems(): 46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 47 | 48 | # Calculate effective reference sentence length. 49 | if eff == "shortest": 50 | reflen = min(reflen) 51 | elif eff == "average": 52 | reflen = float(sum(reflen))/len(reflen) 53 | 54 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 55 | 56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design) 57 | 58 | return (reflen, maxcounts) 59 | 60 | def cook_test(test, (reflen, refmaxcounts), eff=None, n=4): 61 | '''Takes a test sentence and returns an object that 62 | encapsulates everything that BLEU needs to know about it.''' 63 | 64 | testlen, counts = precook(test, n, True) 65 | 66 | result = {} 67 | 68 | # Calculate effective reference sentence length. 
69 | 70 | if eff == "closest": 71 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 72 | else: ## i.e., "average" or "shortest" or None 73 | result["reflen"] = reflen 74 | 75 | result["testlen"] = testlen 76 | 77 | result["guess"] = [max(0,testlen-k+1) for k in xrange(1,n+1)] 78 | 79 | result['correct'] = [0]*n 80 | for (ngram, count) in counts.iteritems(): 81 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 82 | 83 | return result 84 | 85 | class BleuScorer(object): 86 | """Bleu scorer. 87 | """ 88 | 89 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 90 | # special_reflen is used in oracle (proportional effective ref len for a node). 91 | 92 | def copy(self): 93 | ''' copy the refs.''' 94 | new = BleuScorer(n=self.n) 95 | new.ctest = copy.copy(self.ctest) 96 | new.crefs = copy.copy(self.crefs) 97 | new._score = None 98 | return new 99 | 100 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 101 | ''' singular instance ''' 102 | 103 | self.n = n 104 | self.crefs = [] 105 | self.ctest = [] 106 | self.cook_append(test, refs) 107 | self.special_reflen = special_reflen 108 | 109 | def cook_append(self, test, refs): 110 | '''called by constructor and __iadd__ to avoid creating new instances.''' 111 | 112 | if refs is not None: 113 | self.crefs.append(cook_refs(refs)) 114 | if test is not None: 115 | cooked_test = cook_test(test, self.crefs[-1]) 116 | self.ctest.append(cooked_test) ## N.B.: -1 117 | else: 118 | self.ctest.append(None) # lens of crefs and ctest have to match 119 | 120 | self._score = None ## need to recompute 121 | 122 | def ratio(self, option=None): 123 | self.compute_score(option=option) 124 | return self._ratio 125 | 126 | def score_ratio(self, option=None): 127 | '''return (bleu, len_ratio) pair''' 128 | return (self.fscore(option=option), self.ratio(option=option)) 129 | 130 | def score_ratio_str(self, option=None): 131 | return "%.4f (%.2f)" % self.score_ratio(option) 132 | 133 | def reflen(self, option=None): 134 | self.compute_score(option=option) 135 | return self._reflen 136 | 137 | def testlen(self, option=None): 138 | self.compute_score(option=option) 139 | return self._testlen 140 | 141 | def retest(self, new_test): 142 | if type(new_test) is str: 143 | new_test = [new_test] 144 | assert len(new_test) == len(self.crefs), new_test 145 | self.ctest = [] 146 | for t, rs in zip(new_test, self.crefs): 147 | self.ctest.append(cook_test(t, rs)) 148 | self._score = None 149 | 150 | return self 151 | 152 | def rescore(self, new_test): 153 | ''' replace test(s) with new test(s), and returns the new score.''' 154 | 155 | return self.retest(new_test).compute_score() 156 | 157 | def size(self): 158 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 159 | return len(self.crefs) 160 | 161 | def __iadd__(self, other): 162 | '''add an instance (e.g., from another sentence).''' 163 | 164 | if type(other) is tuple: 165 | ## avoid creating new BleuScorer instances 166 | self.cook_append(other[0], other[1]) 167 | else: 168 | assert self.compatible(other), "incompatible BLEUs." 
169 | self.ctest.extend(other.ctest) 170 | self.crefs.extend(other.crefs) 171 | self._score = None ## need to recompute 172 | 173 | return self 174 | 175 | def compatible(self, other): 176 | return isinstance(other, BleuScorer) and self.n == other.n 177 | 178 | def single_reflen(self, option="average"): 179 | return self._single_reflen(self.crefs[0][0], option) 180 | 181 | def _single_reflen(self, reflens, option=None, testlen=None): 182 | 183 | if option == "shortest": 184 | reflen = min(reflens) 185 | elif option == "average": 186 | reflen = float(sum(reflens))/len(reflens) 187 | elif option == "closest": 188 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 189 | else: 190 | assert False, "unsupported reflen option %s" % option 191 | 192 | return reflen 193 | 194 | def recompute_score(self, option=None, verbose=0): 195 | self._score = None 196 | return self.compute_score(option, verbose) 197 | 198 | def compute_score(self, option=None, verbose=0): 199 | n = self.n 200 | small = 1e-9 201 | tiny = 1e-15 ## so that if guess is 0 still return 0 202 | bleu_list = [[] for _ in range(n)] 203 | 204 | if self._score is not None: 205 | return self._score 206 | 207 | if option is None: 208 | option = "average" if len(self.crefs) == 1 else "closest" 209 | 210 | self._testlen = 0 211 | self._reflen = 0 212 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 213 | 214 | # for each sentence 215 | for comps in self.ctest: 216 | testlen = comps['testlen'] 217 | self._testlen += testlen 218 | 219 | if self.special_reflen is None: ## need computation 220 | reflen = self._single_reflen(comps['reflen'], option, testlen) 221 | else: 222 | reflen = self.special_reflen 223 | 224 | self._reflen += reflen 225 | 226 | for key in ['guess','correct']: 227 | for k in xrange(n): 228 | totalcomps[key][k] += comps[key][k] 229 | 230 | # append per image bleu score 231 | bleu = 1. 232 | for k in xrange(n): 233 | bleu *= (float(comps['correct'][k]) + tiny) \ 234 | /(float(comps['guess'][k]) + small) 235 | bleu_list[k].append(bleu ** (1./(k+1))) 236 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 237 | if ratio < 1: 238 | for k in xrange(n): 239 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 240 | 241 | if verbose > 1: 242 | print comps, reflen 243 | 244 | totalcomps['reflen'] = self._reflen 245 | totalcomps['testlen'] = self._testlen 246 | 247 | bleus = [] 248 | bleu = 1. 249 | for k in xrange(n): 250 | bleu *= float(totalcomps['correct'][k] + tiny) \ 251 | / (totalcomps['guess'][k] + small) 252 | bleus.append(bleu ** (1./(k+1))) 253 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 254 | if ratio < 1: 255 | for k in xrange(n): 256 | bleus[k] *= math.exp(1 - 1/ratio) 257 | 258 | if verbose > 0: 259 | print totalcomps 260 | print "ratio:", ratio 261 | 262 | self._score = bleus 263 | return self._score, bleu_list 264 | 265 | class Bleu: 266 | def __init__(self, n=4): 267 | # default compute Blue score up to 4 268 | self._n = n 269 | self._hypo_for_image = {} 270 | self.ref_for_image = {} 271 | 272 | def compute_score(self, gts, res): 273 | 274 | assert(gts.keys() == res.keys()) 275 | imgIds = gts.keys() 276 | 277 | bleu_scorer = BleuScorer(n=self._n) 278 | for id in imgIds: 279 | hypo = res[id] 280 | ref = gts[id] 281 | 282 | # Sanity check. 
283 | assert(type(hypo) is list) 284 | assert(len(hypo) == 1) 285 | assert(type(ref) is list) 286 | assert(len(ref) >= 1) 287 | 288 | bleu_scorer += (hypo[0], ref) 289 | 290 | #score, scores = bleu_scorer.compute_score(option='shortest') 291 | score, scores = bleu_scorer.compute_score(option='closest', verbose=0) 292 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 293 | 294 | # return (bleu, bleu_info) 295 | return score, scores 296 | 297 | def method(self): 298 | return "Bleu" 299 | -------------------------------------------------------------------------------- /src_s2s/encoder_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import match_utils 3 | 4 | def collect_final_step_lstm(lstm_rep, lens): 5 | lens = tf.maximum(lens, tf.zeros_like(lens, dtype=tf.int32)) # [batch,] 6 | idxs = tf.range(0, limit=tf.shape(lens)[0]) # [batch,] 7 | indices = tf.stack((idxs,lens,), axis=1) # [batch_size, 2] 8 | return tf.gather_nd(lstm_rep, indices, name='lstm-forward-last') 9 | 10 | class SeqEncoder(object): 11 | def __init__(self, placeholders, options, word_vocab=None, char_vocab=None, POS_vocab=None, NER_vocab=None): 12 | 13 | self.options = options 14 | 15 | self.word_vocab = word_vocab 16 | self.char_vocab = char_vocab 17 | self.POS_vocab = POS_vocab 18 | self.NER_vocab = NER_vocab 19 | 20 | self.passage_lengths = placeholders.passage_lengths #tf.placeholder(tf.int32, [None]) 21 | if options.with_word: 22 | self.in_passage_words = placeholders.in_passage_words #tf.placeholder(tf.int32, [None, None]) # [batch_size, passage_len] 23 | 24 | if options.with_char: 25 | self.passage_char_lengths = placeholders.passage_char_lengths 26 | #tf.placeholder(tf.int32, [None,None]) # [batch_size, passage_len] 27 | self.in_passage_chars = placeholders.in_passage_chars 28 | #tf.placeholder(tf.int32, [None, None, None]) # [batch_size, passage_len, p_char_len] 29 | 30 | if options.with_POS: 31 | self.in_passage_POSs = placeholders.in_passage_POSs #tf.placeholder(tf.int32, [None, None]) # [batch_size, passage_len] 32 | 33 | if options.with_NER: 34 | self.in_passage_NERs = placeholders.in_passage_NERs #tf.placeholder(tf.int32, [None, None]) # [batch_size, passage_len] 35 | 36 | def encode(self, is_training=True): 37 | options = self.options 38 | 39 | # ======word representation layer====== 40 | in_passage_repres = [] 41 | input_dim = 0 42 | if options.with_word and self.word_vocab is not None: 43 | word_vec_trainable = True 44 | cur_device = '/gpu:0' 45 | if options.fix_word_vec: 46 | word_vec_trainable = False 47 | cur_device = '/cpu:0' 48 | with tf.variable_scope("embedding"), tf.device(cur_device): 49 | self.word_embedding = tf.get_variable("word_embedding", trainable=word_vec_trainable, 50 | initializer=tf.constant(self.word_vocab.word_vecs), dtype=tf.float32) 51 | 52 | in_passage_word_repres = tf.nn.embedding_lookup(self.word_embedding, self.in_passage_words) 53 | # [batch_size, passage_len, word_dim] 54 | in_passage_repres.append(in_passage_word_repres) 55 | 56 | input_shape = tf.shape(self.in_passage_words) 57 | batch_size = input_shape[0] 58 | passage_len = input_shape[1] 59 | input_dim += self.word_vocab.word_dim 60 | 61 | if options.with_char and self.char_vocab is not None: 62 | input_shape = tf.shape(self.in_passage_chars) 63 | batch_size = input_shape[0] 64 | passage_len = input_shape[1] 65 | p_char_len = input_shape[2] 66 | char_dim = self.char_vocab.word_dim 67 | self.char_embedding = 
tf.get_variable("char_embedding", 68 | initializer=tf.constant(self.char_vocab.word_vecs), dtype=tf.float32) 69 | in_passage_char_repres = tf.nn.embedding_lookup(self.char_embedding, 70 | self.in_passage_chars) # [batch_size, passage_len, p_char_len, char_dim] 71 | in_passage_char_repres = tf.reshape(in_passage_char_repres, shape=[-1, p_char_len, char_dim]) 72 | passage_char_lengths = tf.reshape(self.passage_char_lengths, [-1]) 73 | with tf.variable_scope('char_lstm'): 74 | # lstm cell 75 | char_lstm_cell = tf.contrib.rnn.BasicLSTMCell(options.char_lstm_dim) 76 | # dropout 77 | if is_training: char_lstm_cell = tf.contrib.rnn.DropoutWrapper(char_lstm_cell, 78 | output_keep_prob=(1 - options.dropout_rate)) 79 | char_lstm_cell = tf.contrib.rnn.MultiRNNCell([char_lstm_cell]) 80 | # passage representation 81 | passage_char_outputs = tf.nn.dynamic_rnn(char_lstm_cell, in_passage_char_repres, 82 | sequence_length=passage_char_lengths,dtype=tf.float32)[0] 83 | # [batch_size*question_len, q_char_len, char_lstm_dim] 84 | passage_char_outputs = collect_final_step_lstm(passage_char_outputs, passage_char_lengths-1) 85 | passage_char_outputs = tf.reshape(passage_char_outputs, [batch_size, passage_len, options.char_lstm_dim]) 86 | 87 | in_passage_repres.append(passage_char_outputs) 88 | input_dim += options.char_lstm_dim 89 | 90 | in_passage_repres = tf.concat(in_passage_repres, 2) # [batch_size, passage_len, dim] 91 | 92 | if options.compress_input: # compress input word vector into smaller vectors 93 | w_compress = tf.get_variable("w_compress_input", [input_dim, options.compress_input_dim], dtype=tf.float32) 94 | b_compress = tf.get_variable("b_compress_input", [options.compress_input_dim], dtype=tf.float32) 95 | 96 | in_passage_repres = tf.reshape(in_passage_repres, [-1, input_dim]) 97 | in_passage_repres = tf.matmul(in_passage_repres, w_compress) + b_compress 98 | in_passage_repres = tf.tanh(in_passage_repres) 99 | in_passage_repres = tf.reshape(in_passage_repres, [batch_size, passage_len, options.compress_input_dim]) 100 | input_dim = options.compress_input_dim 101 | 102 | if is_training: 103 | in_passage_repres = tf.nn.dropout(in_passage_repres, (1 - options.dropout_rate)) 104 | else: 105 | in_passage_repres = tf.multiply(in_passage_repres, (1 - options.dropout_rate)) 106 | 107 | passage_mask = tf.sequence_mask(self.passage_lengths, passage_len, dtype=tf.float32) # [batch_size, passage_len] 108 | 109 | # sequential context matching 110 | passage_forward = None 111 | passage_backward = None 112 | all_passage_representation = [] 113 | passage_dim = 0 114 | with_lstm = True 115 | if with_lstm: 116 | with tf.variable_scope('biLSTM'): 117 | cur_in_passage_repres = in_passage_repres 118 | for i in xrange(options.context_layer_num): 119 | with tf.variable_scope('layer-{}'.format(i)): 120 | with tf.variable_scope('context_represent'): 121 | # parameters 122 | context_lstm_cell_fw = tf.contrib.rnn.LSTMCell(options.context_lstm_dim) 123 | context_lstm_cell_bw = tf.contrib.rnn.LSTMCell(options.context_lstm_dim) 124 | if is_training: 125 | context_lstm_cell_fw = tf.contrib.rnn.DropoutWrapper(context_lstm_cell_fw, output_keep_prob=(1 - options.dropout_rate)) 126 | context_lstm_cell_bw = tf.contrib.rnn.DropoutWrapper(context_lstm_cell_bw, output_keep_prob=(1 - options.dropout_rate)) 127 | 128 | # passage representation 129 | ((passage_context_representation_fw, passage_context_representation_bw), 130 | (passage_forward, passage_backward)) = tf.nn.bidirectional_dynamic_rnn( 131 | context_lstm_cell_fw, 
context_lstm_cell_bw, cur_in_passage_repres, dtype=tf.float32, 132 | sequence_length=self.passage_lengths) # [batch_size, passage_len, context_lstm_dim] 133 | if options.direction == 'forward': 134 | # [batch_size, passage_len, context_lstm_dim] 135 | cur_in_passage_repres = passage_context_representation_fw 136 | passage_dim += options.context_lstm_dim 137 | elif options.direction == 'backward': 138 | # [batch_size, passage_len, context_lstm_dim] 139 | cur_in_passage_repres = passage_context_representation_bw 140 | passage_dim += options.context_lstm_dim 141 | elif options.direction == 'bidir': 142 | # [batch_size, passage_len, 2*context_lstm_dim] 143 | cur_in_passage_repres = tf.concat( 144 | [passage_context_representation_fw, passage_context_representation_bw], 2) 145 | passage_dim += 2 * options.context_lstm_dim 146 | else: 147 | assert False 148 | all_passage_representation.append(cur_in_passage_repres) 149 | 150 | 151 | all_passage_representation = tf.concat(all_passage_representation, 2) # [batch_size, passage_len, passage_dim] 152 | 153 | if is_training: 154 | all_passage_representation = tf.nn.dropout(all_passage_representation, (1 - options.dropout_rate)) 155 | else: 156 | all_passage_representation = tf.multiply(all_passage_representation, (1 - options.dropout_rate)) 157 | 158 | # ======Highway layer====== 159 | if options.with_match_highway: 160 | with tf.variable_scope("context_highway"): 161 | all_passage_representation = match_utils.multi_highway_layer(all_passage_representation, 162 | passage_dim,options.highway_layer_num) 163 | 164 | all_passage_representation = all_passage_representation * tf.expand_dims(passage_mask, axis=-1) 165 | 166 | # initial state for the LSTM decoder 167 | #''' 168 | with tf.variable_scope('initial_state_for_decoder'): 169 | # Define weights and biases to reduce the cell and reduce the state 170 | w_reduce_c = tf.get_variable('w_reduce_c', [2*options.context_lstm_dim, options.gen_hidden_size], dtype=tf.float32) 171 | w_reduce_h = tf.get_variable('w_reduce_h', [2*options.context_lstm_dim, options.gen_hidden_size], dtype=tf.float32) 172 | bias_reduce_c = tf.get_variable('bias_reduce_c', [options.gen_hidden_size], dtype=tf.float32) 173 | bias_reduce_h = tf.get_variable('bias_reduce_h', [options.gen_hidden_size], dtype=tf.float32) 174 | 175 | old_c = tf.concat(values=[passage_forward.c, passage_backward.c], axis=1) 176 | old_h = tf.concat(values=[passage_forward.h, passage_backward.h], axis=1) 177 | new_c = tf.nn.tanh(tf.matmul(old_c, w_reduce_c) + bias_reduce_c) 178 | new_h = tf.nn.tanh(tf.matmul(old_h, w_reduce_h) + bias_reduce_h) 179 | 180 | init_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h) 181 | ''' 182 | new_c = tf.zeros([batch_size, options.gen_hidden_size]) 183 | new_h = tf.zeros([batch_size, options.gen_hidden_size]) 184 | init_state = LSTMStateTuple(new_c, new_h) 185 | ''' 186 | return (passage_dim, all_passage_representation, init_state) 187 | 188 | -------------------------------------------------------------------------------- /src_g2s/G2S_data_stream.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import numpy as np 4 | import random 5 | import padding_utils 6 | import amr_utils 7 | 8 | def read_text_file(text_file): 9 | lines = [] 10 | with open(text_file, "rt") as f: 11 | for line in f: 12 | line = line.decode('utf-8') 13 | lines.append(line.strip()) 14 | return lines 15 | 16 | def read_amr_file(inpath): 17 | nodes = [] # [batch, node_num,] 18 | 
in_neigh_indices = [] # [batch, node_num, neighbor_num,] 19 | in_neigh_edges = [] 20 | out_neigh_indices = [] # [batch, node_num, neighbor_num,] 21 | out_neigh_edges = [] 22 | sentences = [] # [batch, sent_length,] 23 | ids = [] 24 | max_in_neigh = 0 25 | max_out_neigh = 0 26 | max_node = 0 27 | max_sent = 0 28 | with open(inpath, "rU") as f: 29 | for inst in json.load(f): 30 | amr = inst['amr'] 31 | sent = inst['sent'].strip().split() 32 | id = inst['id'] if inst.has_key('id') else None 33 | amr_node = [] 34 | amr_edge = [] 35 | amr_utils.read_anonymized(amr.strip().split(), amr_node, amr_edge) 36 | # 1. 37 | nodes.append(amr_node) 38 | # 2. & 3. 39 | in_indices = [[i,] for i, x in enumerate(amr_node)] 40 | in_edges = [[':self',] for i, x in enumerate(amr_node)] 41 | out_indices = [[i,] for i, x in enumerate(amr_node)] 42 | out_edges = [[':self',] for i, x in enumerate(amr_node)] 43 | for (i,j,lb) in amr_edge: 44 | in_indices[j].append(i) 45 | in_edges[j].append(lb) 46 | out_indices[i].append(j) 47 | out_edges[i].append(lb) 48 | in_neigh_indices.append(in_indices) 49 | in_neigh_edges.append(in_edges) 50 | out_neigh_indices.append(out_indices) 51 | out_neigh_edges.append(out_edges) 52 | # 4. 53 | sentences.append(sent) 54 | ids.append(id) 55 | # update lengths 56 | max_in_neigh = max(max_in_neigh, max(len(x) for x in in_indices)) 57 | max_out_neigh = max(max_out_neigh, max(len(x) for x in out_indices)) 58 | max_node = max(max_node, len(amr_node)) 59 | max_sent = max(max_sent, len(sent)) 60 | return zip(nodes, in_neigh_indices, in_neigh_edges, out_neigh_indices, out_neigh_edges, sentences, ids), \ 61 | max_node, max_in_neigh, max_out_neigh, max_sent 62 | 63 | def read_amr_from_fof(fofpath): 64 | all_paths = read_text_file(fofpath) 65 | all_instances = [] 66 | max_node = 0 67 | max_in_neigh = 0 68 | max_out_neigh = 0 69 | max_sent = 0 70 | for cur_path in all_paths: 71 | print(cur_path) 72 | cur_instances, cur_node, cur_in_neigh, cur_out_neigh, cur_sent = read_amr_file(cur_path) 73 | all_instances.extend(cur_instances) 74 | max_node = max(max_node, cur_node) 75 | max_in_neigh = max(max_in_neigh, cur_in_neigh) 76 | max_out_neigh = max(max_out_neigh, cur_out_neigh) 77 | max_sent = max(max_sent, cur_sent) 78 | return all_instances, max_node, max_in_neigh, max_out_neigh, max_sent 79 | 80 | def collect_vocabs(all_instances): 81 | all_words = set() 82 | all_chars = set() 83 | all_edgelabels = set() 84 | # nodes: [corpus_size,node_num,] 85 | # neigh_indices & neigh_edges: [corpus_size,node_num,neigh_num,] 86 | # sentence: [corpus_size,sent_len,] 87 | for (nodes, in_neigh_indices, in_neigh_edges, out_neigh_indices, out_neigh_edges, sentence, id) in all_instances: 88 | all_words.update(nodes) 89 | all_words.update(sentence) 90 | for edges in in_neigh_edges: 91 | all_edgelabels.update(edges) 92 | for edges in out_neigh_edges: 93 | all_edgelabels.update(edges) 94 | for w in all_words: 95 | all_chars.update(w) 96 | return (all_words, all_chars, all_edgelabels) 97 | 98 | class G2SDataStream(object): 99 | def __init__(self, all_instances, word_vocab=None, char_vocab=None, edgelabel_vocab=None, options=None, 100 | isShuffle=False, isLoop=False, isSort=True, batch_size=-1): 101 | self.options = options 102 | if batch_size ==-1: batch_size=options.batch_size 103 | # index tokens and filter the dataset 104 | instances = [] 105 | for (nodes, in_neigh_indices, in_neigh_edges, out_neigh_indices, out_neigh_edges, sentence, id) in all_instances: 106 | if options.max_node_num != -1 and len(nodes) > 
options.max_node_num: 107 | continue # remove very long passages 108 | in_neigh_indices = [x[:options.max_in_neigh_num] for x in in_neigh_indices] 109 | in_neigh_edges = [x[:options.max_in_neigh_num] for x in in_neigh_edges] 110 | out_neigh_indices = [x[:options.max_out_neigh_num] for x in out_neigh_indices] 111 | out_neigh_edges = [x[:options.max_out_neigh_num] for x in out_neigh_edges] 112 | 113 | nodes_idx = word_vocab.to_index_sequence_for_list(nodes) 114 | nodes_chars_idx = None 115 | if options.with_char: 116 | nodes_chars_idx = char_vocab.to_character_matrix_for_list(nodes, max_char_per_word=options.max_char_per_word) 117 | in_neigh_edges_idx = [edgelabel_vocab.to_index_sequence_for_list(edges) for edges in in_neigh_edges] 118 | out_neigh_edges_idx = [edgelabel_vocab.to_index_sequence_for_list(edges) for edges in out_neigh_edges] 119 | sentence_idx = word_vocab.to_index_sequence_for_list(sentence[:options.max_answer_len]) 120 | instances.append((nodes_idx, nodes_chars_idx, 121 | in_neigh_indices, in_neigh_edges_idx, out_neigh_indices, out_neigh_edges_idx, sentence_idx, sentence, id)) 122 | 123 | all_instances = instances 124 | instances = None 125 | 126 | # sort instances based on length 127 | if isSort: 128 | all_instances = sorted(all_instances, key=lambda inst: (len(inst[0]), len(inst[-2]))) 129 | 130 | self.num_instances = len(all_instances) 131 | 132 | # distribute questions into different buckets 133 | batch_spans = padding_utils.make_batches(self.num_instances, batch_size) 134 | self.batches = [] 135 | for batch_index, (batch_start, batch_end) in enumerate(batch_spans): 136 | cur_instances = [] 137 | for i in xrange(batch_start, batch_end): 138 | cur_instances.append(all_instances[i]) 139 | cur_batch = G2SBatch(cur_instances, options, word_vocab=word_vocab) 140 | self.batches.append(cur_batch) 141 | 142 | self.num_batch = len(self.batches) 143 | self.index_array = np.arange(self.num_batch) 144 | self.isShuffle = isShuffle 145 | if self.isShuffle: np.random.shuffle(self.index_array) 146 | self.isLoop = isLoop 147 | self.cur_pointer = 0 148 | 149 | def nextBatch(self): 150 | if self.cur_pointer>=self.num_batch: 151 | if not self.isLoop: return None 152 | self.cur_pointer = 0 153 | if self.isShuffle: np.random.shuffle(self.index_array) 154 | cur_batch = self.batches[self.index_array[self.cur_pointer]] 155 | self.cur_pointer += 1 156 | return cur_batch 157 | 158 | def reset(self): 159 | if self.isShuffle: np.random.shuffle(self.index_array) 160 | self.cur_pointer = 0 161 | 162 | def get_num_batch(self): 163 | return self.num_batch 164 | 165 | def get_num_instance(self): 166 | return self.num_instances 167 | 168 | def get_batch(self, i): 169 | if i>= self.num_batch: return None 170 | return self.batches[i] 171 | 172 | class G2SBatch(object): 173 | def __init__(self, instances, options, word_vocab=None): 174 | self.options = options 175 | 176 | self.amr_node = [x[0] for x in instances] 177 | self.id = [x[-1] for x in instances] 178 | self.target_ref = [x[-2] for x in instances] # list of tuples 179 | self.batch_size = len(instances) 180 | self.vocab = word_vocab 181 | 182 | # create length 183 | self.node_num = [] # [batch_size] 184 | self.sent_len = [] # [batch_size] 185 | for (nodes_idx, nodes_chars_idx, in_neigh_indices, in_neigh_edges_idx, out_neigh_indices, out_neigh_edges_idx, 186 | sentence_idx, sentence, id) in instances: 187 | self.node_num.append(len(nodes_idx)) 188 | self.sent_len.append(min(len(sentence_idx)+1, options.max_answer_len)) 189 | self.node_num = 
np.array(self.node_num, dtype=np.int32) 190 | self.sent_len = np.array(self.sent_len, dtype=np.int32) 191 | 192 | # node char num 193 | if options.with_char: 194 | self.nodes_chars_num = [[len(nodes_chars_idx) for nodes_chars_idx in instance[1]] for instance in instances] 195 | self.nodes_chars_num = padding_utils.pad_2d_vals_no_size(self.nodes_chars_num) 196 | 197 | # neigh mask 198 | self.in_neigh_mask = [] # [batch_size, node_num, neigh_num] 199 | self.out_neigh_mask = [] 200 | for instance in instances: 201 | ins = [] 202 | for in_neighs in instance[2]: 203 | ins.append([1 for _ in in_neighs]) 204 | self.in_neigh_mask.append(ins) 205 | outs = [] 206 | for out_neighs in instance[4]: 207 | outs.append([1 for _ in out_neighs]) 208 | self.out_neigh_mask.append(outs) 209 | self.in_neigh_mask = padding_utils.pad_3d_vals_no_size(self.in_neigh_mask) 210 | self.out_neigh_mask = padding_utils.pad_3d_vals_no_size(self.out_neigh_mask) 211 | 212 | # create word representation 213 | start_id = word_vocab.getIndex('') 214 | end_id = word_vocab.getIndex('') 215 | 216 | self.nodes = [x[0] for x in instances] 217 | if options.with_char: 218 | self.nodes_chars = [inst[1] for inst in instances] # [batch_size, sent_len, char_num] 219 | self.in_neigh_indices = [x[2] for x in instances] 220 | self.in_neigh_edges = [x[3] for x in instances] 221 | self.out_neigh_indices = [x[4] for x in instances] 222 | self.out_neigh_edges = [x[5] for x in instances] 223 | 224 | self.sent_inp = [] 225 | self.sent_out = [] 226 | for _, _, _, _, _, _, sentence_idx, sentence, id in instances: 227 | if len(sentence_idx) < options.max_answer_len: 228 | self.sent_inp.append([start_id,]+sentence_idx) 229 | self.sent_out.append(sentence_idx+[end_id,]) 230 | else: 231 | self.sent_inp.append([start_id,]+sentence_idx[:-1]) 232 | self.sent_out.append(sentence_idx) 233 | 234 | # making ndarray 235 | self.nodes = padding_utils.pad_2d_vals_no_size(self.nodes) 236 | if options.with_char: 237 | self.nodes_chars = padding_utils.pad_3d_vals_no_size(self.nodes_chars) 238 | self.in_neigh_indices = padding_utils.pad_3d_vals_no_size(self.in_neigh_indices) 239 | self.in_neigh_edges = padding_utils.pad_3d_vals_no_size(self.in_neigh_edges) 240 | self.out_neigh_indices = padding_utils.pad_3d_vals_no_size(self.out_neigh_indices) 241 | self.out_neigh_edges = padding_utils.pad_3d_vals_no_size(self.out_neigh_edges) 242 | 243 | assert self.in_neigh_mask.shape == self.in_neigh_indices.shape 244 | assert self.in_neigh_mask.shape == self.in_neigh_edges.shape 245 | assert self.out_neigh_mask.shape == self.out_neigh_indices.shape 246 | assert self.out_neigh_mask.shape == self.out_neigh_edges.shape 247 | 248 | # [batch_size, sent_len_max] 249 | self.sent_inp = padding_utils.pad_2d_vals(self.sent_inp, len(self.sent_inp), options.max_answer_len) 250 | self.sent_out = padding_utils.pad_2d_vals(self.sent_out, len(self.sent_out), options.max_answer_len) 251 | 252 | 253 | if __name__ == "__main__": 254 | print('testset') 255 | all_instances, max_node_num, max_in_neigh_num, max_out_neigh_num, max_sent_len = read_amr_file('./data/test.json') 256 | print(max_in_neigh_num) 257 | print(max_out_neigh_num) 258 | print(max_node_num) 259 | print(max_sent_len) 260 | print('DONE!') 261 | 262 | -------------------------------------------------------------------------------- /src_s2s/NP2P_phrase_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import argparse 4 | 
import os 5 | import sys 6 | import time 7 | import numpy as np 8 | 9 | from vocab_utils import Vocab 10 | import namespace_utils 11 | import NP2P_data_stream 12 | from NP2P_model_graph import ModelGraph 13 | import NP2P_beam_decoder 14 | import metric_utils 15 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu, corpus_bleu 16 | cc = SmoothingFunction() 17 | 18 | FLAGS = None 19 | import tensorflow as tf 20 | tf.logging.set_verbosity(tf.logging.ERROR) # DEBUG, INFO, WARN, ERROR, and FATAL 21 | 22 | 23 | def evaluate(sess, valid_graph, devDataStream, word_vocab, options=None): 24 | devDataStream.reset() 25 | gen = [] 26 | ref = [] 27 | for batch_index in xrange(devDataStream.get_num_batch()): # for each batch 28 | cur_batch = devDataStream.get_batch(batch_index) 29 | (sentences, _, _, _) = NP2P_beam_decoder.search(sess, valid_graph, word_vocab, cur_batch, FLAGS, decode_mode="greedy") 30 | for i in xrange(cur_batch.batch_size): 31 | cur_ref = cur_batch.instances[i][1].tokText 32 | cur_prediction = sentences[i] 33 | ref.append([cur_ref.split()]) 34 | gen.append(cur_prediction.split()) 35 | return corpus_bleu(ref, gen, smoothing_function=cc.method3) 36 | 37 | def main(_): 38 | print('Configurations:') 39 | print(FLAGS) 40 | 41 | log_dir = FLAGS.model_dir 42 | if not os.path.exists(log_dir): 43 | os.makedirs(log_dir) 44 | 45 | path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix) 46 | init_model_prefix = FLAGS.init_model # "/u/zhigwang/zhigwang1/sentence_generation/mscoco/logs/NP2P.phrase_ce_train" 47 | log_file_path = path_prefix + ".log" 48 | print('Log file path: {}'.format(log_file_path)) 49 | log_file = open(log_file_path, 'wt') 50 | log_file.write("{}\n".format(FLAGS)) 51 | log_file.flush() 52 | 53 | # save configuration 54 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") 55 | 56 | print('Loading train set.') 57 | if FLAGS.infile_format == 'fof': 58 | trainset, train_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(FLAGS.train_path, isLower=FLAGS.isLower) 59 | if FLAGS.max_answer_len>train_ans_len: FLAGS.max_answer_len = train_ans_len 60 | else: 61 | trainset, train_ans_len = NP2P_data_stream.read_all_GQA_questions(FLAGS.train_path, isLower=FLAGS.isLower) 62 | print('Number of training samples: {}'.format(len(trainset))) 63 | 64 | print('Loading test set.') 65 | if FLAGS.infile_format == 'fof': 66 | testset, test_ans_len = NP2P_data_stream.read_generation_datasets_from_fof(FLAGS.test_path, isLower=FLAGS.isLower) 67 | else: 68 | testset, test_ans_len = NP2P_data_stream.read_all_GQA_questions(FLAGS.test_path, isLower=FLAGS.isLower) 69 | print('Number of test samples: {}'.format(len(testset))) 70 | 71 | max_actual_len = max(train_ans_len, test_ans_len) 72 | print('Max answer length: {}, truncated to {}'.format(max_actual_len, FLAGS.max_answer_len)) 73 | 74 | word_vocab = None 75 | POS_vocab = None 76 | NER_vocab = None 77 | char_vocab = None 78 | has_pretrained_model = False 79 | best_path = path_prefix + ".best.model" 80 | if os.path.exists(init_model_prefix + ".best.model.index"): 81 | has_pretrained_model = True 82 | print('!!Existing pretrained model. 
Loading vocabs.') 83 | if FLAGS.with_word: 84 | word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') 85 | print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) 86 | if FLAGS.with_char: 87 | char_vocab = Vocab(init_model_prefix + ".char_vocab", fileformat='txt2') 88 | print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) 89 | if FLAGS.with_POS: 90 | POS_vocab = Vocab(init_model_prefix + ".POS_vocab", fileformat='txt2') 91 | print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) 92 | if FLAGS.with_NER: 93 | NER_vocab = Vocab(init_model_prefix + ".NER_vocab", fileformat='txt2') 94 | print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape)) 95 | else: 96 | print('Collecting vocabs.') 97 | (allWords, allChars, allPOSs, allNERs) = NP2P_data_stream.collect_vocabs(trainset) 98 | print('Number of words: {}'.format(len(allWords))) 99 | print('Number of allChars: {}'.format(len(allChars))) 100 | print('Number of allPOSs: {}'.format(len(allPOSs))) 101 | print('Number of allNERs: {}'.format(len(allNERs))) 102 | 103 | if FLAGS.with_word: 104 | word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') 105 | if FLAGS.with_char: 106 | char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') 107 | char_vocab.dump_to_txt2(path_prefix + ".char_vocab") 108 | if FLAGS.with_POS: 109 | POS_vocab = Vocab(voc=allPOSs, dim=FLAGS.POS_dim, fileformat='build') 110 | POS_vocab.dump_to_txt2(path_prefix + ".POS_vocab") 111 | if FLAGS.with_NER: 112 | NER_vocab = Vocab(voc=allNERs, dim=FLAGS.NER_dim, fileformat='build') 113 | NER_vocab.dump_to_txt2(path_prefix + ".NER_vocab") 114 | 115 | print('word vocab size {}'.format(word_vocab.vocab_size)) 116 | sys.stdout.flush() 117 | 118 | print('Build DataStream ... ') 119 | trainDataStream = NP2P_data_stream.QADataStream(trainset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS, 120 | isShuffle=True, isLoop=True, isSort=True) 121 | 122 | devDataStream = NP2P_data_stream.QADataStream(testset, word_vocab, char_vocab, POS_vocab, NER_vocab, options=FLAGS, 123 | isShuffle=False, isLoop=False, isSort=True) 124 | print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) 125 | print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) 126 | print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) 127 | print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) 128 | sys.stdout.flush() 129 | 130 | init_scale = 0.01 131 | with tf.Graph().as_default(): 132 | initializer = tf.random_uniform_initializer(-init_scale, init_scale) 133 | with tf.name_scope("Train"): 134 | with tf.variable_scope("Model", reuse=None, initializer=initializer): 135 | train_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, 136 | NER_vocab=NER_vocab, options=FLAGS, mode="rl_train_for_phrase") 137 | 138 | with tf.name_scope("Valid"): 139 | with tf.variable_scope("Model", reuse=True, initializer=initializer): 140 | valid_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, POS_vocab=POS_vocab, 141 | NER_vocab=NER_vocab, options=FLAGS, mode="decode") 142 | 143 | initializer = tf.global_variables_initializer() 144 | 145 | vars_ = {} 146 | for var in tf.all_variables(): 147 | if "word_embedding" in var.name: continue 148 | if not var.name.startswith("Model"): continue 149 | vars_[var.name.split(":")[0]] = var 150 | saver = tf.train.Saver(vars_) 151 | 152 | sess = tf.Session() 153 | sess.run(initializer) 
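        # Added note: the Saver built above intentionally skips word-embedding variables and
        # anything outside the "Model" scope, so the restore below only reloads trainable
        # model parameters from the pretrained (init_model) checkpoint.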
154 | if has_pretrained_model: 155 | print("Restoring model from " + init_model_prefix + ".best.model") 156 | saver.restore(sess, init_model_prefix + ".best.model") 157 | print("DONE!") 158 | sys.stdout.flush() 159 | 160 | # for first-time rl training, we get the current BLEU score 161 | print("First-time rl training, get the current BLEU score on dev") 162 | sys.stdout.flush() 163 | best_bleu = evaluate(sess, valid_graph, devDataStream, word_vocab, options=FLAGS) 164 | print('First-time bleu = %.4f' % best_bleu) 165 | log_file.write('First-time bleu = %.4f\n' % best_bleu) 166 | 167 | print('Start the training loop.') 168 | sys.stdout.flush() 169 | train_size = trainDataStream.get_num_batch() 170 | max_steps = train_size * FLAGS.max_epochs 171 | total_loss = 0.0 172 | start_time = time.time() 173 | for step in xrange(max_steps): 174 | cur_batch = trainDataStream.nextBatch() 175 | if FLAGS.with_baseline: 176 | # greedy search 177 | (greedy_sentences, _, _, _) = NP2P_beam_decoder.search(sess, valid_graph, word_vocab, 178 | cur_batch, FLAGS, decode_mode="greedy") 179 | 180 | if FLAGS.with_target_lattice: 181 | (sampled_sentences, sampled_prediction_lengths, sampled_generator_input_idx, 182 | sampled_generator_output_idx) = cur_batch.sample_a_partition() 183 | else: 184 | # multinomial sampling 185 | (sampled_sentences, sampled_prediction_lengths, sampled_generator_input_idx, 186 | sampled_generator_output_idx) = NP2P_beam_decoder.search(sess, valid_graph, word_vocab, 187 | cur_batch, FLAGS, decode_mode="multinomial") 188 | # calculate rewards 189 | rewards = [] 190 | for i in xrange(cur_batch.batch_size): 191 | # print(sampled_sentences[i]) 192 | # print(sampled_generator_input_idx[i]) 193 | # print(sampled_generator_output_idx[i]) 194 | cur_toks = cur_batch.instances[i][1].tokText.split() 195 | # r = sentence_bleu([cur_toks], sampled_sentences[i].split(), smoothing_function=cc.method3) 196 | r = 1.0 197 | b = 0.0 198 | if FLAGS.with_baseline: 199 | b = sentence_bleu([cur_toks], greedy_sentences[i].split(), smoothing_function=cc.method3) 200 | # r = metric_utils.evaluate_captions([cur_toks],[sampled_sentences[i]]) 201 | # b = metric_utils.evaluate_captions([cur_toks],[greedy_sentences[i]]) 202 | rewards.append(1.0* (r-b)) 203 | rewards = np.array(rewards, dtype=np.float32) 204 | # sys.exit(-1) 205 | 206 | # update parameters 207 | feed_dict = train_graph.run_encoder(sess, cur_batch, FLAGS, only_feed_dict=True) 208 | feed_dict[train_graph.reward] = rewards 209 | feed_dict[train_graph.gen_input_words] = sampled_generator_input_idx 210 | feed_dict[train_graph.in_answer_words] = sampled_generator_output_idx 211 | feed_dict[train_graph.answer_lengths] = sampled_prediction_lengths 212 | (_, loss_value) = sess.run([train_graph.train_op, train_graph.loss], feed_dict) 213 | total_loss += loss_value 214 | 215 | if step % 100==0: 216 | print('{} '.format(step), end="") 217 | sys.stdout.flush() 218 | 219 | # Save a checkpoint and evaluate the model periodically. 220 | if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps: 221 | print() 222 | duration = time.time() - start_time 223 | print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) 224 | log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) 225 | log_file.flush() 226 | sys.stdout.flush() 227 | total_loss = 0.0 228 | 229 | # Evaluate against the validation set. 
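            # Added note: evaluate() greedily re-decodes the whole dev set and scores it with
            # corpus BLEU; the best checkpoint is only overwritten when dev BLEU improves.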
230 | start_time = time.time() 231 | print('Validation Data Eval:') 232 | dev_bleu = evaluate(sess, valid_graph, devDataStream, word_vocab, options=FLAGS) 233 | print('Dev bleu = %.4f' % dev_bleu) 234 | log_file.write('Dev bleu = %.4f\n' % dev_bleu) 235 | log_file.flush() 236 | if best_bleu < dev_bleu: 237 | print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(best_bleu, dev_bleu)) 238 | best_bleu = dev_bleu 239 | saver.save(sess, best_path) # TODO: save model 240 | duration = time.time() - start_time 241 | print('Duration %.3f sec' % (duration)) 242 | sys.stdout.flush() 243 | log_file.write('Duration %.3f sec\n' % (duration)) 244 | log_file.flush() 245 | 246 | log_file.close() 247 | 248 | def enrich_options(options): 249 | if not options.__dict__.has_key("infile_format"): 250 | options.__dict__["infile_format"] = "old" 251 | 252 | if not options.__dict__.has_key("with_target_lattice"): 253 | options.__dict__["with_target_lattice"] = False 254 | 255 | if not options.__dict__.has_key("with_baseline"): 256 | options.__dict__["with_baseline"] = True 257 | 258 | if not options.__dict__.has_key("add_first_word_prob_for_phrase"): 259 | options.__dict__["add_first_word_prob_for_phrase"] = False 260 | 261 | if not options.__dict__.has_key("pretrain_with_max_matching"): 262 | options.__dict__["pretrain_with_max_matching"] = False 263 | return options 264 | 265 | 266 | if __name__ == '__main__': 267 | parser = argparse.ArgumentParser() 268 | parser.add_argument('--config_path', type=str, help='Configuration file.') 269 | 270 | print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES']) 271 | FLAGS, unparsed = parser.parse_known_args() 272 | 273 | 274 | if FLAGS.config_path is not None: 275 | print('Loading the configuration from ' + FLAGS.config_path) 276 | FLAGS = namespace_utils.load_namespace(FLAGS.config_path) 277 | 278 | FLAGS = enrich_options(FLAGS) 279 | 280 | sys.stdout.flush() 281 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 282 | -------------------------------------------------------------------------------- /src_g2s/G2S_beam_decoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import argparse 4 | import re 5 | import os 6 | import sys 7 | import time 8 | import numpy as np 9 | 10 | import tensorflow as tf 11 | import namespace_utils 12 | 13 | import G2S_trainer 14 | import G2S_data_stream 15 | from G2S_model_graph import ModelGraph 16 | 17 | from vocab_utils import Vocab 18 | 19 | 20 | 21 | tf.logging.set_verbosity(tf.logging.ERROR) # DEBUG, INFO, WARN, ERROR, and FATAL 22 | 23 | def map_idx_to_word(predictions, vocab, passage, attn_dist): 24 | ''' 25 | predictions: [batch, 1] 26 | ''' 27 | word_size = vocab.vocab_size + 1 28 | all_words = [] 29 | all_word_idx = [] 30 | for i, idx in enumerate(predictions): 31 | if idx == vocab.vocab_size: 32 | k = np.argmax(attn_dist[i,]) 33 | word = batch.instances[i][-1][k] 34 | else: 35 | word = vocab.getWord(idx) 36 | all_words.append(word) 37 | all_word_idx.append(idx) 38 | return all_words, all_word_idx 39 | 40 | def search(sess, model, vocab, batch, options, decode_mode='greedy'): 41 | ''' 42 | for greedy search, multinomial search 43 | ''' 44 | # Run the encoder to get the encoder hidden states and decoder initial state 45 | (encoder_states, encoder_features, node_idx, node_mask, initial_state) = model.run_encoder(sess, batch, options) 46 | # encoder_states: [batch_size, passage_len, encode_dim] 47 | # 
encoder_features: [batch_size, passage_len, attention_vec_size] 48 | # node_idx: [batch_size, passage_len] 49 | # node_mask: [batch_size, passage_len] 50 | # initial_state: a tupel of [batch_size, gen_dim] 51 | 52 | word_t = batch.sent_inp[:,0] 53 | state_t = initial_state 54 | context_t = np.zeros([batch.batch_size, model.encoder_dim]) 55 | coverage_t = np.zeros([batch.batch_size, encoder_states.shape[1]]) 56 | generator_output_idx = [] # store phrase index prediction 57 | text_results = [] 58 | generator_input_idx = [word_t] # store word index 59 | for i in xrange(options.max_answer_len): 60 | if decode_mode == "pointwise": word_t = batch.sent_inp[:,i] 61 | feed_dict = {} 62 | feed_dict[model.init_decoder_state] = state_t 63 | feed_dict[model.context_t_1] = context_t 64 | feed_dict[model.coverage_t_1] = coverage_t 65 | feed_dict[model.word_t] = word_t 66 | 67 | feed_dict[model.encoder_states] = encoder_states 68 | feed_dict[model.encoder_features] = encoder_features 69 | feed_dict[model.node_idx] = node_idx 70 | feed_dict[model.node_mask] = node_mask 71 | 72 | if decode_mode in ["greedy", "pointwise", ]: 73 | prediction = model.greedy_prediction 74 | elif decode_mode == "multinomial": 75 | prediction = model.multinomial_prediction 76 | 77 | (state_t, attn_dist_t, context_t, coverage_t, prediction) = sess.run([model.state_t, model.attn_dist_t, model.context_t, 78 | model.coverage_t, prediction], feed_dict) 79 | # convert prediction to word ids 80 | generator_output_idx.append(prediction) 81 | prediction = np.reshape(prediction, [prediction.size, 1]) 82 | cur_words, cur_word_idx = map_idx_to_word(prediction) # [batch_size, 1] 83 | cur_word_idx = np.array(cur_word_idx) 84 | cur_word_idx = np.reshape(cur_word_idx, [cur_word_idx.size]) 85 | word_t = cur_word_idx 86 | text_results.append(cur_words) 87 | generator_input_idx.append(cur_word_idx) 88 | 89 | generator_input_idx = generator_input_idx[:-1] # remove the last word to shift one position to the right 90 | generator_input_idx = np.stack(generator_input_idx, axis=1) # [batch_size, max_len] 91 | generator_output_idx = np.stack(generator_output_idx, axis=1) # [batch_size, max_len] 92 | 93 | prediction_lengths = [] # [batch_size] 94 | sentences = [] # [batch_size] 95 | for i in xrange(batch.batch_size): 96 | words = [] 97 | for j in xrange(options.max_answer_len): 98 | cur_phrase = text_results[j][i] 99 | words.append(cur_phrase) 100 | if cur_phrase == "": break # filter out based on end symbol 101 | prediction_lengths.append(len(words)) 102 | cur_sent = " ".join(words) 103 | sentences.append(cur_sent) 104 | 105 | return (sentences, prediction_lengths, generator_input_idx, generator_output_idx) 106 | 107 | class Hypothesis(object): 108 | def __init__(self, tokens, log_ps, attn, state, context_vector, coverage_vector=None): 109 | self.tokens = tokens # store all tokens 110 | self.log_probs = log_ps # store log_probs for each time-step 111 | self.attn_ids = attn 112 | self.state = state 113 | self.context_vector = context_vector 114 | self.coverage_vector = coverage_vector 115 | 116 | def extend(self, token, log_prob, attn_i, state, context_vector, coverage_vector=None): 117 | return Hypothesis(self.tokens + [token], self.log_probs + [log_prob], self.attn_ids + [attn_i], state, 118 | context_vector, coverage_vector=coverage_vector) 119 | 120 | def latest_token(self): 121 | return self.tokens[-1] 122 | 123 | def avg_log_prob(self): 124 | return np.sum(self.log_probs[1:])/ (len(self.tokens)-1) 125 | 126 | def probs2string(self): 127 | out_string 
= "" 128 | for prob in self.log_probs: 129 | out_string += " %.4f" % prob 130 | return out_string.strip() 131 | 132 | def idx_seq_to_string(self, passage, vocab, options): 133 | word_size = vocab.vocab_size + 1 134 | all_words = [] 135 | for i, idx in enumerate(self.tokens): 136 | cur_word = vocab.getWord(idx) 137 | if cur_word == 'UNK': 138 | idx = passage[self.attn_ids[i]] 139 | cur_word = vocab.getWord(idx) 140 | all_words.append(cur_word) 141 | return " ".join(all_words[1:]) 142 | 143 | 144 | def sort_hyps(hyps): 145 | return sorted(hyps, key=lambda h: h.avg_log_prob(), reverse=True) 146 | 147 | 148 | 149 | def run_beam_search(sess, model, vocab, batch, options): 150 | # Run encoder 151 | (encoder_states, encoder_features, node_idx, node_mask, initial_state) = model.run_encoder(sess, batch, options) 152 | # encoder_states: [1, passage_len, encode_dim] 153 | # initial_state: a tupel of [1, gen_dim] 154 | # encoder_features: [1, passage_len, attention_vec_size] 155 | # node_idx: [1, passage_len] 156 | # node_mask: [1, passage_len] 157 | 158 | sent_stop_id = vocab.getIndex('') 159 | 160 | # Initialize this first hypothesis 161 | context_t = np.zeros([model.encoder_dim]) # [encode_dim] 162 | coverage_t = np.zeros((encoder_states.shape[1])) # [passage_len] 163 | hyps = [] 164 | hyps.append(Hypothesis([batch.sent_inp[0][0]], [0.0], [-1], initial_state, context_t, coverage_vector=coverage_t)) 165 | 166 | # beam search decoding 167 | results = [] # this will contain finished hypotheses (those that have emitted the token) 168 | steps = 0 169 | while steps < options.max_answer_len and len(results) < options.beam_size: 170 | cur_size = len(hyps) # current number of hypothesis in the beam 171 | cur_encoder_states = np.tile(encoder_states, (cur_size, 1, 1)) 172 | cur_encoder_features = np.tile(encoder_features, (cur_size, 1, 1)) # [batch_size,passage_len, options.attention_vec_size] 173 | cur_node_idx = np.tile(node_idx, (cur_size, 1)) # [batch_size, passage_len] 174 | cur_node_mask = np.tile(node_mask, (cur_size, 1)) # [batch_size, passage_len] 175 | cur_state_t_1 = [] # [2, gen_steps] 176 | cur_context_t_1 = [] # [batch_size, encoder_dim] 177 | cur_coverage_t_1 = [] # [batch_size, passage_len] 178 | cur_word_t = [] # [batch_size] 179 | for h in hyps: 180 | cur_state_t_1.append(h.state) 181 | cur_context_t_1.append(h.context_vector) 182 | cur_word_t.append(h.latest_token()) 183 | cur_coverage_t_1.append(h.coverage_vector) 184 | cur_context_t_1 = np.stack(cur_context_t_1, axis=0) 185 | cur_coverage_t_1 = np.stack(cur_coverage_t_1, axis=0) 186 | cur_word_t = np.array(cur_word_t) 187 | 188 | cells = [state.c for state in cur_state_t_1] 189 | hidds = [state.h for state in cur_state_t_1] 190 | new_c = np.concatenate(cells, axis=0) 191 | new_h = np.concatenate(hidds, axis=0) 192 | new_dec_init_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h) 193 | 194 | feed_dict = {} 195 | feed_dict[model.init_decoder_state] = new_dec_init_state 196 | feed_dict[model.context_t_1] = cur_context_t_1 197 | feed_dict[model.word_t] = cur_word_t 198 | 199 | feed_dict[model.encoder_states] = cur_encoder_states 200 | feed_dict[model.encoder_features] = cur_encoder_features 201 | feed_dict[model.nodes] = cur_node_idx 202 | feed_dict[model.nodes_mask] = cur_node_mask 203 | feed_dict[model.coverage_t_1] = cur_coverage_t_1 204 | 205 | (state_t, context_t, attn_dist_t, coverage_t, topk_log_probs, topk_ids) = sess.run([model.state_t, model.context_t, model.attn_dist_t, 206 | model.coverage_t, model.topk_log_probs, 
model.topk_ids], feed_dict) 207 | 208 | new_states = [tf.contrib.rnn.LSTMStateTuple(state_t.c[i:i+1, :], state_t.h[i:i+1, :]) for i in xrange(cur_size)] 209 | 210 | # Extend each hypothesis and collect them all in all_hyps 211 | if steps == 0: cur_size = 1 212 | all_hyps = [] 213 | for i in xrange(cur_size): 214 | h = hyps[i] 215 | cur_state = new_states[i] 216 | cur_context = context_t[i] 217 | cur_coverage = coverage_t[i] 218 | for j in xrange(options.beam_size): 219 | cur_tok = topk_ids[i, j] 220 | # add anony constraint 221 | #if cur_tok in vocab.anony_ids and cur_tok not in batch.amr_anony_ids: 222 | # continue 223 | cur_tok_log_prob = topk_log_probs[i, j] 224 | cur_attn_i = np.argmax(attn_dist_t[i, :]) 225 | new_hyp = h.extend(cur_tok, cur_tok_log_prob, cur_attn_i, cur_state, cur_context, coverage_vector=cur_coverage) 226 | all_hyps.append(new_hyp) 227 | 228 | # Filter and collect any hypotheses that have produced the end token. 229 | # hyps will contain hypotheses for the next step 230 | hyps = [] 231 | for h in sort_hyps(all_hyps): 232 | # If this hypothesis is sufficiently long, put in results. Otherwise discard. 233 | if h.latest_token() == sent_stop_id: 234 | if steps >= options.min_answer_len: 235 | results.append(h) 236 | # hasn't reached stop token, so continue to extend this hypothesis 237 | else: 238 | hyps.append(h) 239 | if len(hyps) == options.beam_size or len(results) == options.beam_size: 240 | break 241 | 242 | steps += 1 243 | 244 | # At this point, either we've got beam_size results, or we've reached maximum decoder steps 245 | # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results 246 | if len(results)==0: 247 | results = hyps 248 | 249 | # Sort hypotheses by average log probability 250 | hyps_sorted = sort_hyps(results) 251 | 252 | # Return the hypothesis with highest average log prob 253 | return hyps_sorted 254 | 255 | if __name__ == '__main__': 256 | parser = argparse.ArgumentParser() 257 | parser.add_argument('--model_prefix', type=str, required=True, help='Prefix to the models.') 258 | parser.add_argument('--in_path', type=str, required=True, help='The path to the test file.') 259 | parser.add_argument('--out_path', type=str, help='The path to the output file.') 260 | parser.add_argument('--mode', type=str,default='pointwise', help='The path to the output file.') 261 | 262 | args, unparsed = parser.parse_known_args() 263 | 264 | model_prefix = args.model_prefix 265 | in_path = args.in_path 266 | out_path = args.out_path 267 | mode = args.mode 268 | 269 | print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES']) 270 | 271 | # load the configuration file 272 | print('Loading configurations from ' + model_prefix + ".config.json") 273 | FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json") 274 | FLAGS = G2S_trainer.enrich_options(FLAGS) 275 | 276 | # load vocabs 277 | print('Loading vocabs.') 278 | word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') 279 | print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) 280 | edgelabel_vocab = Vocab(model_prefix + ".edgelabel_vocab", fileformat='txt2') 281 | print('edgelabel_vocab: {}'.format(edgelabel_vocab.word_vecs.shape)) 282 | char_vocab = None 283 | if FLAGS.with_char: 284 | char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2') 285 | print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) 286 | 287 | 288 | print('Loading test set from {}.'.format(in_path)) 289 | testset, _, _, _, _ = 
G2S_data_stream.read_amr_file(in_path) 290 | print('Number of samples: {}'.format(len(testset))) 291 | 292 | print('Build DataStream ... ') 293 | batch_size=-1 294 | if mode not in ('pointwise', 'multinomial', 'greedy', 'greedy_evaluate', ): 295 | batch_size = 1 296 | 297 | devDataStream = G2S_data_stream.G2SDataStream(testset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS, 298 | isShuffle=False, isLoop=False, isSort=True, batch_size=batch_size) 299 | print('Number of instances in testDataStream: {}'.format(devDataStream.get_num_instance())) 300 | print('Number of batches in testDataStream: {}'.format(devDataStream.get_num_batch())) 301 | 302 | best_path = model_prefix + ".best.model" 303 | with tf.Graph().as_default(): 304 | initializer = tf.random_uniform_initializer(-0.01, 0.01) 305 | with tf.name_scope("Valid"): 306 | with tf.variable_scope("Model", reuse=False, initializer=initializer): 307 | valid_graph = ModelGraph(word_vocab=word_vocab, char_vocab=char_vocab, Edgelabel_vocab=edgelabel_vocab, 308 | options=FLAGS, mode="decode") 309 | 310 | ## remove word _embedding 311 | vars_ = {} 312 | for var in tf.all_variables(): 313 | if FLAGS.fix_word_vec and "word_embedding" in var.name: continue 314 | if not var.name.startswith("Model"): continue 315 | vars_[var.name.split(":")[0]] = var 316 | saver = tf.train.Saver(vars_) 317 | 318 | initializer = tf.global_variables_initializer() 319 | sess = tf.Session() 320 | sess.run(initializer) 321 | 322 | saver.restore(sess, best_path) # restore the model 323 | 324 | total = 0 325 | correct = 0 326 | if mode.endswith('evaluate'): 327 | ref_outfile = open(out_path+ ".ref", 'wt') 328 | pred_outfile = open(out_path+ ".pred", 'wt') 329 | else: 330 | outfile = open(out_path, 'wt') 331 | total_num = devDataStream.get_num_batch() 332 | devDataStream.reset() 333 | for i in range(total_num): 334 | cur_batch = devDataStream.get_batch(i) 335 | if mode in ['greedy', 'multinomial']: 336 | print('Batch {}'.format(i)) 337 | (sentences, prediction_lengths, generator_input_idx, 338 | generator_output_idx) = search(sess, valid_graph, word_vocab, cur_batch, FLAGS, decode_mode=mode) 339 | for j in xrange(cur_batch.batch_size): 340 | outfile.write(cur_batch.target_ref[j].encode('utf-8') + "\n") 341 | outfile.write(sentences[j].encode('utf-8') + "\n") 342 | outfile.write("========\n") 343 | outfile.flush() 344 | else: # beam search 345 | print('Instance {}'.format(i)) 346 | hyps = run_beam_search(sess, valid_graph, word_vocab, cur_batch, FLAGS) 347 | outfile.write(cur_batch.id[0] + "\n") 348 | outfile.write(' '.join(cur_batch.target_ref[0]).encode('utf-8') + "\n") 349 | for j in xrange(1): 350 | hyp = hyps[j] 351 | cur_passage = cur_batch.amr_node[0] 352 | cur_sent = hyp.idx_seq_to_string(cur_passage, word_vocab, FLAGS) 353 | outfile.write(cur_sent.encode('utf-8') + "\n") 354 | outfile.write("--------\n") 355 | outfile.write("========\n") 356 | outfile.flush() 357 | if mode.endswith('evaluate'): 358 | ref_outfile.close() 359 | pred_outfile.close() 360 | else: 361 | outfile.close() 362 | 363 | 364 | 365 | 366 | -------------------------------------------------------------------------------- /src_g2s/G2S_model_graph.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import graph_encoder_utils 3 | import generator_utils 4 | import padding_utils 5 | from tensorflow.python.ops import variable_scope 6 | import numpy as np 7 | import random 8 | 9 | from nltk.translate.bleu_score import SmoothingFunction, 
sentence_bleu 10 | cc = SmoothingFunction() 11 | 12 | class ModelGraph(object): 13 | def __init__(self, word_vocab, char_vocab, Edgelabel_vocab, options=None, mode='ce_train'): 14 | # here 'mode', whose value can be: 15 | # 'ce_train', 16 | # 'rl_train', 17 | # 'evaluate', 18 | # 'evaluate_bleu', 19 | # 'decode'. 20 | # it is different from 'mode_gen' in generator_utils.py 21 | # value of 'mode_gen' can be ['ce_loss', 'rl_loss', 'greedy' or 'sample'] 22 | self.mode = mode 23 | 24 | # is_training controls whether to use dropout 25 | is_training = True if mode in ('ce_train', ) else False 26 | 27 | self.options = options 28 | self.word_vocab = word_vocab 29 | 30 | # encode the input instance 31 | # encoder.graph_hidden [batch, node_num, vsize] 32 | # encoder.graph_cell [batch, node_num, vsize] 33 | self.encoder = graph_encoder_utils.GraphEncoder( 34 | word_vocab = word_vocab, 35 | edge_label_vocab = Edgelabel_vocab, 36 | char_vocab = char_vocab, 37 | is_training = is_training, options = options) 38 | 39 | # ============== Choices of attention memory ================ 40 | if options.attention_type == 'hidden': 41 | self.encoder_dim = options.neighbor_vector_dim 42 | self.encoder_states = self.encoder.graph_hiddens 43 | elif options.attention_type == 'hidden_cell': 44 | self.encoder_dim = options.neighbor_vector_dim * 2 45 | self.encoder_states = tf.concat([self.encoder.graph_hiddens, self.encoder.graph_cells], 2) 46 | elif options.attention_type == 'hidden_embed': 47 | self.encoder_dim = options.neighbor_vector_dim + self.encoder.input_dim 48 | self.encoder_states = tf.concat([self.encoder.graph_hiddens, self.encoder.node_representations], 2) 49 | else: 50 | assert False, '%s not supported yet' % options.attention_type 51 | 52 | # ============== Choices of initializing decoder state ============= 53 | if options.way_init_decoder == 'zero': 54 | new_c = tf.zeros([self.encoder.batch_size, options.gen_hidden_size]) 55 | new_h = tf.zeros([self.encoder.batch_size, options.gen_hidden_size]) 56 | elif options.way_init_decoder == 'all': 57 | new_c = tf.reduce_sum(self.encoder.graph_cells, axis=1) 58 | new_h = tf.reduce_sum(self.encoder.graph_hiddens, axis=1) 59 | elif options.way_init_decoder == 'root': 60 | new_c = self.encoder.graph_cells[:,0,:] 61 | new_h = self.encoder.graph_hiddens[:,0,:] 62 | else: 63 | assert False, 'way to initial decoder (%s) not supported' % options.way_init_decoder 64 | self.init_decoder_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h) 65 | 66 | # prepare AMR-side input for decoder 67 | self.nodes = self.encoder.passage_nodes 68 | self.nodes_num = self.encoder.passage_nodes_size 69 | if options.with_char: 70 | self.nodes_chars = self.encoder.passage_nodes_chars 71 | self.nodes_chars_num = self.encoder.passage_nodes_chars_size 72 | self.nodes_mask = self.encoder.passage_nodes_mask 73 | 74 | self.in_neigh_indices = self.encoder.passage_in_neighbor_indices 75 | self.in_neigh_edges = self.encoder.passage_in_neighbor_edges 76 | self.in_neigh_mask = self.encoder.passage_in_neighbor_mask 77 | 78 | self.out_neigh_indices = self.encoder.passage_out_neighbor_indices 79 | self.out_neigh_edges = self.encoder.passage_out_neighbor_edges 80 | self.out_neigh_mask = self.encoder.passage_out_neighbor_mask 81 | 82 | self.create_placeholders(options) 83 | 84 | loss_weights = tf.sequence_mask(self.answer_len, options.max_answer_len, dtype=tf.float32) # [batch_size, gen_steps] 85 | 86 | with variable_scope.variable_scope("generator"): 87 | # create generator 88 | self.generator = 
generator_utils.CovCopyAttenGen(self, options, word_vocab) 89 | # calculate encoder_features 90 | self.encoder_features = self.generator.calculate_encoder_features(self.encoder_states, self.encoder_dim) 91 | 92 | if mode == 'decode': 93 | self.context_t_1 = tf.placeholder(tf.float32, [None, self.encoder_dim], name='context_t_1') # [batch_size, encoder_dim] 94 | self.coverage_t_1 = tf.placeholder(tf.float32, [None, None], name='coverage_t_1') # [batch_size, encoder_dim] 95 | self.word_t = tf.placeholder(tf.int32, [None], name='word_t') # [batch_size] 96 | 97 | (self.state_t, self.context_t, self.coverage_t, self.attn_dist_t, self.p_gen_t, self.ouput_t, 98 | self.topk_log_probs, self.topk_ids, self.greedy_prediction, self.multinomial_prediction) = self.generator.decode_mode( 99 | word_vocab, options.beam_size, self.init_decoder_state, self.context_t_1, self.coverage_t_1, self.word_t, 100 | self.encoder_states, self.encoder_features, self.nodes, self.nodes_mask) 101 | # not buiding training op for this mode 102 | return 103 | elif mode == 'evaluate_bleu': 104 | _, _, self.greedy_words = self.generator.train_mode(word_vocab, self.encoder_dim, self.encoder_states, self.encoder_features, 105 | self.nodes, self.nodes_mask, self.init_decoder_state, 106 | self.answer_inp, self.answer_ref, loss_weights, mode_gen='greedy') 107 | # not buiding training op for this mode 108 | return 109 | elif mode in ('ce_train', 'evaluate', ): 110 | self.accu, self.loss, _ = self.generator.train_mode(word_vocab, self.encoder_dim, self.encoder_states, self.encoder_features, 111 | self.nodes, self.nodes_mask, self.init_decoder_state, 112 | self.answer_inp, self.answer_ref, loss_weights, mode_gen='ce_loss') 113 | if mode == 'evaluate': return # not buiding training op for evaluation 114 | elif mode == 'rl_train': 115 | _, self.loss, _ = self.generator.train_mode(word_vocab, self.encoder_dim, self.encoder_states,self.encoder_features, 116 | self.nodes, self.nodes_mask, self.init_decoder_state, 117 | self.answer_inp, self.answer_ref, loss_weights, mode_gen='rl_loss') 118 | 119 | tf.get_variable_scope().reuse_variables() 120 | 121 | _, _, self.sampled_words = self.generator.train_mode(word_vocab, self.encoder_dim, self.encoder_states,self.encoder_features, 122 | self.nodes, self.nodes_mask, self.init_decoder_state, 123 | self.answer_inp, self.answer_ref, None, mode_gen='sample') 124 | 125 | _, _, self.greedy_words = self.generator.train_mode(word_vocab, self.encoder_dim, self.encoder_states,self.encoder_features, 126 | self.nodes, self.nodes_mask, self.init_decoder_state, 127 | self.answer_inp, self.answer_ref, None, mode_gen='greedy') 128 | 129 | 130 | if options.optimize_type == 'adadelta': 131 | clipper = 50 132 | optimizer = tf.train.AdadeltaOptimizer(learning_rate=options.learning_rate) 133 | tvars = tf.trainable_variables() 134 | if options.lambda_l2>0.0: 135 | l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tvars if v.get_shape().ndims > 1]) 136 | self.loss = self.loss + options.lambda_l2 * l2_loss 137 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, tvars), clipper) 138 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 139 | elif options.optimize_type == 'adam': 140 | clipper = 50 141 | optimizer = tf.train.AdamOptimizer(learning_rate=options.learning_rate) 142 | tvars = tf.trainable_variables() 143 | if options.lambda_l2>0.0: 144 | l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in tvars if v.get_shape().ndims > 1]) 145 | self.loss = self.loss + options.lambda_l2 * l2_loss 146 | grads, _ = 
tf.clip_by_global_norm(tf.gradients(self.loss, tvars), clipper) 147 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 148 | 149 | extra_train_ops = [] 150 | train_ops = [self.train_op] + extra_train_ops 151 | self.train_op = tf.group(*train_ops) 152 | 153 | def create_placeholders(self, options): 154 | # build placeholder for answer 155 | self.answer_ref = tf.placeholder(tf.int32, [None, options.max_answer_len], name="answer_ref") # [batch_size, gen_steps] 156 | self.answer_inp = tf.placeholder(tf.int32, [None, options.max_answer_len], name="answer_inp") # [batch_size, gen_steps] 157 | self.answer_len = tf.placeholder(tf.int32, [None], name="answer_len") # [batch_size] 158 | 159 | # build placeholder for reinforcement learning 160 | self.reward = tf.placeholder(tf.float32, [None], name="reward") 161 | 162 | 163 | def run_greedy(self, sess, batch, options): 164 | feed_dict = self.run_encoder(sess, batch, options, only_feed_dict=True) # reuse this function to construct feed_dict 165 | feed_dict[self.answer_inp] = batch.sent_inp 166 | return sess.run(self.greedy_words, feed_dict) 167 | 168 | 169 | def run_ce_training(self, sess, batch, options, only_eval=False): 170 | feed_dict = self.run_encoder(sess, batch, options, only_feed_dict=True) # reuse this function to construct feed_dict 171 | feed_dict[self.answer_inp] = batch.sent_inp 172 | feed_dict[self.answer_ref] = batch.sent_out 173 | feed_dict[self.answer_len] = batch.sent_len 174 | 175 | if only_eval: 176 | return sess.run([self.accu, self.loss], feed_dict) 177 | else: 178 | return sess.run([self.train_op, self.loss], feed_dict)[1] 179 | 180 | 181 | def run_rl_training_subsample(self, sess, batch, options): 182 | flipp = options.flipp if options.__dict__.has_key('flipp') else 0.1 183 | 184 | # make feed_dict 185 | feed_dict = self.run_encoder(sess, batch, options, only_feed_dict=True) 186 | feed_dict[self.answer_inp] = batch.sent_inp 187 | 188 | # get greedy and gold outputs 189 | greedy_output = sess.run(self.greedy_words, feed_dict) # [batch, sent_len] 190 | greedy_output = greedy_output.tolist() 191 | gold_output = batch.sent_out.tolist() 192 | 193 | # generate sample_output by flipping coins 194 | sample_output = np.copy(batch.sent_out) 195 | for i in range(batch.sent_out.shape[0]): 196 | seq_len = min(options.max_answer_len, batch.sent_len[i]-1) # don't change stop token '' 197 | for j in range(seq_len): 198 | if greedy_output[i][j] != 0 and random.random() < flipp: 199 | sample_output[i,j] = greedy_output[i][j] 200 | sample_output = sample_output.tolist() 201 | 202 | st_wid = self.word_vocab.getIndex('') 203 | en_wid = self.word_vocab.getIndex('') 204 | 205 | rl_inputs = [] 206 | rl_outputs = [] 207 | rl_input_lengths = [] 208 | reward = [] 209 | for i, (sout,gout) in enumerate(zip(sample_output,greedy_output)): 210 | sout, slex = self.word_vocab.getLexical(sout) 211 | gout, glex = self.word_vocab.getLexical(gout) 212 | rl_inputs.append([st_wid,]+sout[:-1]) 213 | rl_outputs.append(sout) 214 | rl_input_lengths.append(len(sout)) 215 | _, ref_lex = self.word_vocab.getLexical(gold_output[i]) 216 | slst = slex.split() 217 | glst = glex.split() 218 | rlst = ref_lex.split() 219 | if options.reward_type == 'bleu': 220 | r = sentence_bleu([rlst], slst, smoothing_function=cc.method3) 221 | b = sentence_bleu([rlst], glst, smoothing_function=cc.method3) 222 | elif options.reward_type == 'rouge': 223 | r = sentence_rouge(ref_lex, slex, smoothing_function=cc.method3) 224 | b = sentence_rouge(ref_lex, glex, 
smoothing_function=cc.method3) 225 | reward.append(r-b) 226 | #print('Ref: {}'.format(ref_lex.encode('utf-8','ignore'))) 227 | #print('Sample: {}'.format(slex.encode('utf-8','ignore'))) 228 | #print('Greedy: {}'.format(glex.encode('utf-8','ignore'))) 229 | #print('R-B: {}'.format(reward[-1])) 230 | #print('-----') 231 | 232 | rl_inputs = padding_utils.pad_2d_vals(rl_inputs, len(rl_inputs), self.options.max_answer_len) 233 | rl_outputs = padding_utils.pad_2d_vals(rl_outputs, len(rl_outputs), self.options.max_answer_len) 234 | rl_input_lengths = np.array(rl_input_lengths, dtype=np.int32) 235 | reward = np.array(reward, dtype=np.float32) 236 | assert rl_inputs.shape == rl_outputs.shape 237 | 238 | feed_dict = self.run_encoder(sess, batch, options, only_feed_dict=True) 239 | feed_dict[self.reward] = reward 240 | feed_dict[self.answer_inp] = rl_inputs 241 | feed_dict[self.answer_ref] = rl_outputs 242 | feed_dict[self.answer_len] = rl_input_lengths 243 | 244 | _, loss = sess.run([self.train_op, self.loss], feed_dict) 245 | return loss 246 | 247 | 248 | def run_rl_training_model(self, sess, batch, options): 249 | feed_dict = self.run_encoder(sess, batch, options, only_feed_dict=True) 250 | feed_dict[self.answer_inp] = batch.sent_inp 251 | 252 | sample_output, greedy_output = sess.run( 253 | [self.sampled_words, self.greedy_words], feed_dict) 254 | 255 | sample_output = sample_output.tolist() 256 | greedy_output = greedy_output.tolist() 257 | 258 | st_wid = self.word_vocab.getIndex('') 259 | en_wid = self.word_vocab.getIndex('') 260 | 261 | rl_inputs = [] 262 | rl_outputs = [] 263 | rl_input_lengths = [] 264 | reward = [] 265 | for i, (sout,gout) in enumerate(zip(sample_output,greedy_output)): 266 | sout, slex = self.word_vocab.getLexical(sout) 267 | gout, glex = self.word_vocab.getLexical(gout) 268 | rl_inputs.append([st_wid,]+sout[:-1]) 269 | rl_outputs.append(sout) 270 | rl_input_lengths.append(len(sout)) 271 | ref_lex = batch.instances[i][-1] 272 | #r = metric_utils.evaluate_captions([ref_lex,],[slex,]) 273 | #b = metric_utils.evaluate_captions([ref_lex,],[glex,]) 274 | slst = slex.split() 275 | glst = glex.split() 276 | rlst = ref_lex.split() 277 | if options.reward_type == 'bleu': 278 | r = sentence_bleu([rlst], slst, smoothing_function=cc.method3) 279 | b = sentence_bleu([rlst], glst, smoothing_function=cc.method3) 280 | elif options.reward_type == 'rouge': 281 | r = sentence_rouge(ref_lex, slex, smoothing_function=cc.method3) 282 | b = sentence_rouge(ref_lex, glex, smoothing_function=cc.method3) 283 | reward.append(r-b) 284 | #print('Ref: {}'.format(ref_lex.encode('utf-8','ignore'))) 285 | #print('Sample: {}'.format(slex.encode('utf-8','ignore'))) 286 | #print('Greedy: {}'.format(glex.encode('utf-8','ignore'))) 287 | #print('R-B: {}'.format(reward[-1])) 288 | #print('-----') 289 | 290 | rl_inputs = padding_utils.pad_2d_vals(rl_inputs, len(rl_inputs), self.options.max_answer_len) 291 | rl_outputs = padding_utils.pad_2d_vals(rl_outputs, len(rl_outputs), self.options.max_answer_len) 292 | rl_input_lengths = np.array(rl_input_lengths, dtype=np.int32) 293 | reward = np.array(reward, dtype=np.float32) 294 | assert rl_inputs.shape == rl_outputs.shape 295 | 296 | feed_dict = self.run_encoder(sess, batch, options, only_feed_dict=True) 297 | feed_dict[self.reward] = reward 298 | feed_dict[self.answer_inp] = rl_inputs 299 | feed_dict[self.answer_out] = rl_outputs 300 | feed_dict[self.answer_len] = rl_input_lengths 301 | 302 | _, loss = sess.run([self.train_op, self.loss], feed_dict) 303 | return loss 
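    # Added note: both RL routines above are self-critical: each example's reward is the
    # sentence-level BLEU (or ROUGE) of the sampled/perturbed output minus that of the
    # greedy baseline for the same input.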
304 | 305 | def run_encoder(self, sess, batch, options, only_feed_dict=False): 306 | feed_dict = {} 307 | feed_dict[self.nodes] = batch.nodes 308 | feed_dict[self.nodes_num] = batch.node_num 309 | if options.with_char: 310 | feed_dict[self.nodes_chars] = batch.nodes_chars 311 | feed_dict[self.nodes_chars_num] = batch.nodes_chars_num 312 | 313 | feed_dict[self.in_neigh_indices] = batch.in_neigh_indices 314 | feed_dict[self.in_neigh_edges] = batch.in_neigh_edges 315 | feed_dict[self.in_neigh_mask] = batch.in_neigh_mask 316 | 317 | feed_dict[self.out_neigh_indices] = batch.out_neigh_indices 318 | feed_dict[self.out_neigh_edges] = batch.out_neigh_edges 319 | feed_dict[self.out_neigh_mask] = batch.out_neigh_mask 320 | 321 | if only_feed_dict: 322 | return feed_dict 323 | 324 | return sess.run([self.encoder_states, self.encoder_features, self.nodes, self.nodes_mask, self.init_decoder_state], 325 | feed_dict) 326 | 327 | if __name__ == '__main__': 328 | summary = " Tokyo is the one of the biggest city in the world." 329 | reference = "The capital of Japan, Tokyo, is the center of Japanese economy." 330 | 331 | -------------------------------------------------------------------------------- /src_g2s/G2S_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import argparse 4 | import os 5 | import sys 6 | import time 7 | import numpy as np 8 | import codecs 9 | 10 | from vocab_utils import Vocab 11 | import namespace_utils 12 | import G2S_data_stream 13 | from G2S_model_graph import ModelGraph 14 | 15 | FLAGS = None 16 | import tensorflow as tf 17 | tf.logging.set_verbosity(tf.logging.ERROR) # DEBUG, INFO, WARN, ERROR, and FATAL 18 | 19 | from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu 20 | cc = SmoothingFunction() 21 | 22 | import metric_utils 23 | 24 | import platform 25 | def get_machine_name(): 26 | return platform.node() 27 | 28 | def vec2string(val): 29 | result = "" 30 | for v in val: 31 | result += " {}".format(v) 32 | return result.strip() 33 | 34 | 35 | def softmax(x): 36 | """Compute softmax values for each sets of scores in x.""" 37 | e_x = np.exp(x - np.max(x)) 38 | return e_x / e_x.sum() 39 | 40 | 41 | def document_bleu(vocab, gen, ref, suffix=''): 42 | genlex = [vocab.getLexical(x)[1] for x in gen] 43 | reflex = [[vocab.getLexical(x)[1],] for x in ref] 44 | genlst = [x.split() for x in genlex] 45 | reflst = [[x[0].split()] for x in reflex] 46 | f = codecs.open('gen.txt'+suffix,'w','utf-8') 47 | for line in genlex: 48 | print(line, end='\n', file=f) 49 | f.close() 50 | f = codecs.open('ref.txt'+suffix,'w','utf-8') 51 | for line in reflex: 52 | print(line[0], end='\n', file=f) 53 | f.close() 54 | return corpus_bleu(reflst, genlst, smoothing_function=cc.method3) 55 | 56 | 57 | def evaluate(sess, valid_graph, devDataStream, options=None, suffix=''): 58 | devDataStream.reset() 59 | gen = [] 60 | ref = [] 61 | dev_loss = 0.0 62 | dev_right = 0.0 63 | dev_total = 0.0 64 | for batch_index in xrange(devDataStream.get_num_batch()): # for each batch 65 | cur_batch = devDataStream.get_batch(batch_index) 66 | if valid_graph.mode == 'evaluate': 67 | accu_value, loss_value = valid_graph.run_ce_training(sess, cur_batch, options, only_eval=True) 68 | dev_loss += loss_value 69 | dev_right += accu_value 70 | dev_total += np.sum(cur_batch.sent_len) 71 | elif valid_graph.mode == 'evaluate_bleu': 72 | gen.extend(valid_graph.run_greedy(sess, cur_batch, 
options).tolist()) 73 | ref.extend(cur_batch.sent_out.tolist()) 74 | else: 75 | assert False 76 | 77 | if valid_graph.mode == 'evaluate': 78 | return {'dev_loss':dev_loss, 'dev_accu':1.0*dev_right/dev_total, 'dev_right':dev_right, 'dev_total':dev_total, } 79 | else: 80 | return {'dev_bleu':document_bleu(valid_graph.word_vocab,gen,ref,suffix), } 81 | 82 | 83 | 84 | def main(_): 85 | print('Configurations:') 86 | print(FLAGS) 87 | 88 | log_dir = FLAGS.model_dir 89 | if not os.path.exists(log_dir): 90 | os.makedirs(log_dir) 91 | 92 | path_prefix = log_dir + "/G2S.{}".format(FLAGS.suffix) 93 | log_file_path = path_prefix + ".log" 94 | print('Log file path: {}'.format(log_file_path)) 95 | log_file = open(log_file_path, 'wt') 96 | log_file.write("{}\n".format(FLAGS)) 97 | log_file.flush() 98 | 99 | # save configuration 100 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") 101 | 102 | print('Loading train set.') 103 | trainset, trn_node, trn_in_neigh, trn_out_neigh, trn_sent = G2S_data_stream.read_amr_file(FLAGS.train_path) 104 | print('Number of training samples: {}'.format(len(trainset))) 105 | 106 | print('Loading dev set.') 107 | devset, tst_node, tst_in_neigh, tst_out_neigh, tst_sent = G2S_data_stream.read_amr_file(FLAGS.test_path) 108 | print('Number of dev samples: {}'.format(len(devset))) 109 | 110 | if FLAGS.finetune_path != "": 111 | print('Loading finetune set.') 112 | ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = G2S_data_stream.read_amr_file(FLAGS.finetune_path) 113 | print('Number of finetune samples: {}'.format(len(ftset))) 114 | else: 115 | ftset, ft_node, ft_in_neigh, ft_out_neigh, ft_sent = (None, 0, 0, 0, 0) 116 | 117 | max_node = max(trn_node, tst_node, ft_node) 118 | max_in_neigh = max(trn_in_neigh, tst_in_neigh, ft_in_neigh) 119 | max_out_neigh = max(trn_out_neigh, tst_out_neigh, ft_out_neigh) 120 | max_sent = max(trn_sent, tst_sent, ft_sent) 121 | print('Max node number: {}, while max allowed is {}'.format(max_node, FLAGS.max_node_num)) 122 | print('Max parent number: {}, truncated to {}'.format(max_in_neigh, FLAGS.max_in_neigh_num)) 123 | print('Max children number: {}, truncated to {}'.format(max_out_neigh, FLAGS.max_out_neigh_num)) 124 | print('Max answer length: {}, truncated to {}'.format(max_sent, FLAGS.max_answer_len)) 125 | 126 | word_vocab = None 127 | char_vocab = None 128 | edgelabel_vocab = None 129 | has_pretrained_model = False 130 | best_path = path_prefix + ".best.model" 131 | if os.path.exists(best_path + ".index"): 132 | has_pretrained_model = True 133 | print('!!Existing pretrained model. 
Loading vocabs.') 134 | word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') 135 | print('word_vocab: {}'.format(word_vocab.word_vecs.shape)) 136 | char_vocab = None 137 | if FLAGS.with_char: 138 | char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2') 139 | print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) 140 | edgelabel_vocab = Vocab(path_prefix + ".edgelabel_vocab", fileformat='txt2') 141 | else: 142 | print('Collecting vocabs.') 143 | (allWords, allChars, allEdgelabels) = G2S_data_stream.collect_vocabs(trainset) 144 | print('Number of words: {}'.format(len(allWords))) 145 | print('Number of allChars: {}'.format(len(allChars))) 146 | print('Number of allEdgelabels: {}'.format(len(allEdgelabels))) 147 | 148 | word_vocab = Vocab(FLAGS.word_vec_path, fileformat='txt2') 149 | char_vocab = None 150 | if FLAGS.with_char: 151 | char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') 152 | char_vocab.dump_to_txt2(path_prefix + ".char_vocab") 153 | edgelabel_vocab = Vocab(voc=allEdgelabels, dim=FLAGS.edgelabel_dim, fileformat='build') 154 | edgelabel_vocab.dump_to_txt2(path_prefix + ".edgelabel_vocab") 155 | 156 | print('word vocab size {}'.format(word_vocab.vocab_size)) 157 | sys.stdout.flush() 158 | 159 | print('Build DataStream ... ') 160 | trainDataStream = G2S_data_stream.G2SDataStream(trainset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS, 161 | isShuffle=True, isLoop=True, isSort=True) 162 | 163 | devDataStream = G2S_data_stream.G2SDataStream(devset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS, 164 | isShuffle=False, isLoop=False, isSort=True) 165 | print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) 166 | print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) 167 | print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) 168 | print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) 169 | if ftset != None: 170 | ftDataStream = G2S_data_stream.G2SDataStream(ftset, word_vocab, char_vocab, edgelabel_vocab, options=FLAGS, 171 | isShuffle=True, isLoop=True, isSort=True) 172 | print('Number of instances in ftDataStream: {}'.format(ftDataStream.get_num_instance())) 173 | print('Number of batches in ftDataStream: {}'.format(ftDataStream.get_num_batch())) 174 | 175 | sys.stdout.flush() 176 | 177 | # initialize the best bleu and accu scores for current training session 178 | best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0 179 | best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0 180 | if best_accu > 0.0: 181 | print('With initial dev accuracy {}'.format(best_accu)) 182 | if best_bleu > 0.0: 183 | print('With initial dev BLEU score {}'.format(best_bleu)) 184 | 185 | init_scale = 0.01 186 | with tf.Graph().as_default(): 187 | initializer = tf.random_uniform_initializer(-init_scale, init_scale) 188 | with tf.name_scope("Train"): 189 | with tf.variable_scope("Model", reuse=None, initializer=initializer): 190 | train_graph = ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab, 191 | char_vocab=char_vocab, options=FLAGS, mode=FLAGS.mode) 192 | 193 | assert FLAGS.mode in ('ce_train', 'rl_train', ) 194 | valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu' 195 | 196 | with tf.name_scope("Valid"): 197 | with tf.variable_scope("Model", reuse=True, initializer=initializer): 198 | valid_graph = 
ModelGraph(word_vocab=word_vocab, Edgelabel_vocab=edgelabel_vocab, 199 | char_vocab=char_vocab, options=FLAGS, mode=valid_mode) 200 | 201 | initializer = tf.global_variables_initializer() 202 | 203 | vars_ = {} 204 | for var in tf.all_variables(): 205 | if FLAGS.fix_word_vec and "word_embedding" in var.name: continue 206 | if not var.name.startswith("Model"): continue 207 | vars_[var.name.split(":")[0]] = var 208 | print(var) 209 | saver = tf.train.Saver(vars_) 210 | 211 | sess = tf.Session() 212 | sess.run(initializer) 213 | if has_pretrained_model: 214 | print("Restoring model from " + best_path) 215 | saver.restore(sess, best_path) 216 | print("DONE!") 217 | 218 | if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001: 219 | print("Getting BLEU score for the model") 220 | sys.stdout.flush() 221 | best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu'] 222 | FLAGS.best_bleu = best_bleu 223 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") 224 | print('BLEU = %.4f' % best_bleu) 225 | sys.stdout.flush() 226 | log_file.write('BLEU = %.4f\n' % best_bleu) 227 | if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001: 228 | print("Getting ACCU score for the model") 229 | best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu'] 230 | FLAGS.best_accu = best_accu 231 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") 232 | print('ACCU = %.4f' % best_accu) 233 | log_file.write('ACCU = %.4f\n' % best_accu) 234 | 235 | print('Start the training loop.') 236 | train_size = trainDataStream.get_num_batch() 237 | max_steps = train_size * FLAGS.max_epochs 238 | total_loss = 0.0 239 | start_time = time.time() 240 | for step in xrange(max_steps): 241 | cur_batch = trainDataStream.nextBatch() 242 | if FLAGS.mode == 'rl_train': 243 | loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS) 244 | elif FLAGS.mode == 'ce_train': 245 | loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS) 246 | total_loss += loss_value 247 | 248 | if step % 100==0: 249 | print('{} '.format(step), end="") 250 | sys.stdout.flush() 251 | 252 | # Save a checkpoint and evaluate the model periodically. 
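            # Added note: besides epoch boundaries and the final step, the condition below also
            # triggers an extra evaluation every 2000 steps when an epoch spans more than
            # 10000 batches (the large silver-data setting).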
253 | if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or \ 254 | (trainDataStream.get_num_batch() > 10000 and (step+1)%2000 == 0): 255 | print() 256 | duration = time.time() - start_time 257 | print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) 258 | log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) 259 | log_file.flush() 260 | sys.stdout.flush() 261 | total_loss = 0.0 262 | 263 | if ftset != None: 264 | best_accu, best_bleu = fine_tune(sess, saver, FLAGS, log_file, 265 | ftDataStream, devDataStream, train_graph, valid_graph, path_prefix, best_accu, best_bleu) 266 | else: 267 | best_accu, best_bleu = validate_and_save(sess, saver, FLAGS, log_file, 268 | devDataStream, valid_graph, path_prefix, best_accu, best_bleu) 269 | start_time = time.time() 270 | 271 | log_file.close() 272 | 273 | 274 | def validate_and_save(sess, saver, FLAGS, log_file, 275 | devDataStream, valid_graph, path_prefix, best_accu, best_bleu): 276 | best_path = path_prefix + ".best.model" 277 | start_time = time.time() 278 | print('Validation Data Eval:') 279 | res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS) 280 | if valid_graph.mode == 'evaluate': 281 | dev_loss = res_dict['dev_loss'] 282 | dev_accu = res_dict['dev_accu'] 283 | dev_right = int(res_dict['dev_right']) 284 | dev_total = int(res_dict['dev_total']) 285 | print('Dev loss = %.4f' % dev_loss) 286 | log_file.write('Dev loss = %.4f\n' % dev_loss) 287 | print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total)) 288 | log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total)) 289 | log_file.flush() 290 | if best_accu < dev_accu: 291 | print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu)) 292 | saver.save(sess, best_path) 293 | best_accu = dev_accu 294 | FLAGS.best_accu = dev_accu 295 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") 296 | else: 297 | dev_bleu = res_dict['dev_bleu'] 298 | print('Dev bleu = %.4f' % dev_bleu) 299 | log_file.write('Dev bleu = %.4f\n' % dev_bleu) 300 | log_file.flush() 301 | if best_bleu < dev_bleu: 302 | print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(best_bleu, dev_bleu)) 303 | saver.save(sess, best_path) 304 | best_bleu = dev_bleu 305 | FLAGS.best_bleu = dev_bleu 306 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") 307 | duration = time.time() - start_time 308 | print('Duration %.3f sec' % (duration)) 309 | sys.stdout.flush() 310 | return best_accu, best_bleu 311 | 312 | 313 | def fine_tune(sess, saver, FLAGS, log_file, 314 | ftDataStream, devDataStream, train_graph, valid_graph, path_prefix, best_accu, best_bleu): 315 | print('=====Start the fine tuning.') 316 | sys.stdout.flush() 317 | max_steps = ftDataStream.get_num_batch() * 1 318 | best_path = path_prefix + ".best.model" 319 | total_loss = 0.0 320 | start_time = time.time() 321 | for step in xrange(max_steps): 322 | cur_batch = ftDataStream.nextBatch() 323 | if FLAGS.mode == 'rl_train': 324 | loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS) 325 | elif FLAGS.mode == 'ce_train': 326 | loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS) 327 | total_loss += loss_value 328 | 329 | if step % 100==0: 330 | print('{} '.format(step), end="") 331 | sys.stdout.flush() 332 | 333 | # Save a checkpoint and evaluate the model periodically. 
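        # Added note: fine-tuning runs for a single pass over the fine-tune data and calls
        # validate_and_save at the end of that pass, reusing the same best-model criterion
        # as the main training loop.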
334 | if (step + 1) % ftDataStream.get_num_batch() == 0 or (step + 1) == max_steps: 335 | print() 336 | duration = time.time() - start_time 337 | print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) 338 | sys.stdout.flush() 339 | log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) 340 | log_file.flush() 341 | best_accu, best_bleu = validate_and_save(sess, saver, FLAGS, log_file, 342 | devDataStream, valid_graph, path_prefix, best_accu, best_bleu) 343 | total_loss = 0.0 344 | start_time = time.time() 345 | 346 | print('=====End the fine tuning.') 347 | sys.stdout.flush() 348 | return best_accu, best_bleu 349 | 350 | 351 | def enrich_options(options): 352 | if not options.__dict__.has_key("finetune_path"): 353 | options.__dict__["finetune_path"] = "" 354 | 355 | if not options.__dict__.has_key("CE_loss"): 356 | options.__dict__["CE_loss"] = False 357 | 358 | if not options.__dict__.has_key("reward_type"): 359 | options.__dict__["reward_type"] = "bleu" 360 | 361 | if not options.__dict__.has_key("way_init_decoder"): 362 | options.__dict__["way_init_decoder"] = 'zero' 363 | 364 | return options 365 | 366 | 367 | if __name__ == '__main__': 368 | parser = argparse.ArgumentParser() 369 | parser.add_argument('--config_path', type=str, help='Configuration file.') 370 | 371 | #os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 372 | #os.environ["CUDA_VISIBLE_DEVICES"]="2" 373 | 374 | print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES']) 375 | FLAGS, unparsed = parser.parse_known_args() 376 | 377 | 378 | if FLAGS.config_path is not None: 379 | print('Loading the configuration from ' + FLAGS.config_path) 380 | FLAGS = namespace_utils.load_namespace(FLAGS.config_path) 381 | 382 | FLAGS = enrich_options(FLAGS) 383 | 384 | sys.stdout.flush() 385 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 386 | -------------------------------------------------------------------------------- /src_s2s/NP2P_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import argparse 4 | import os 5 | import sys 6 | import time 7 | import numpy as np 8 | import codecs 9 | 10 | from vocab_utils import Vocab 11 | import namespace_utils 12 | import NP2P_data_stream 13 | from NP2P_model_graph import ModelGraph 14 | 15 | FLAGS = None 16 | import tensorflow as tf 17 | tf.logging.set_verbosity(tf.logging.ERROR) # DEBUG, INFO, WARN, ERROR, and FATAL 18 | 19 | from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu 20 | cc = SmoothingFunction() 21 | 22 | import metric_utils 23 | 24 | import platform 25 | def get_machine_name(): 26 | return platform.node() 27 | 28 | def vec2string(val): 29 | result = "" 30 | for v in val: 31 | result += " {}".format(v) 32 | return result.strip() 33 | 34 | 35 | def softmax(x): 36 | """Compute softmax values for each sets of scores in x.""" 37 | e_x = np.exp(x - np.max(x)) 38 | return e_x / e_x.sum() 39 | 40 | 41 | def document_bleu(vocab, gen, ref, suffix=''): 42 | genlex = [vocab.getLexical(x)[1] for x in gen] 43 | reflex = [[vocab.getLexical(x)[1],] for x in ref] 44 | #return metric_utils.evaluate_captions(genlex,reflex) 45 | genlst = [x.split() for x in genlex] 46 | reflst = [[x[0].split()] for x in reflex] 47 | f = codecs.open('gen.txt'+suffix,'w','utf-8') 48 | for line in genlex: 49 | print(line, end='\n', file=f) 50 | f.close() 51 | f = codecs.open('ref.txt'+suffix,'w','utf-8') 52 
| for line in reflex: 53 | print(line[0], end='\n', file=f) 54 | f.close() 55 | return corpus_bleu(reflst, genlst, smoothing_function=cc.method3) 56 | 57 | 58 | def evaluate(sess, valid_graph, devDataStream, options=None, suffix=''): 59 | devDataStream.reset() 60 | gen = [] 61 | ref = [] 62 | dev_loss = 0.0 63 | dev_right = 0.0 64 | dev_total = 0.0 65 | for batch_index in xrange(devDataStream.get_num_batch()): # for each batch 66 | cur_batch = devDataStream.get_batch(batch_index) 67 | if valid_graph.mode == 'evaluate': 68 | accu_value, loss_value = valid_graph.run_ce_training(sess, cur_batch, options, only_eval=True) 69 | dev_loss += loss_value 70 | dev_right += accu_value 71 | dev_total += np.sum(cur_batch.answer_lengths) 72 | elif valid_graph.mode == 'evaluate_bleu': 73 | gen.extend(valid_graph.run_greedy(sess, cur_batch, options).tolist()) 74 | ref.extend(cur_batch.in_answer_words.tolist()) 75 | else: 76 | assert False 77 | 78 | if valid_graph.mode == 'evaluate': 79 | return {'dev_loss':dev_loss, 'dev_accu':1.0*dev_right/dev_total, 'dev_right':dev_right, 'dev_total':dev_total, } 80 | else: 81 | return {'dev_bleu':document_bleu(valid_graph.dec_word_vocab,gen,ref,suffix), } 82 | 83 | 84 | 85 | def main(_): 86 | print('Configurations:') 87 | print(FLAGS) 88 | 89 | log_dir = FLAGS.model_dir 90 | if not os.path.exists(log_dir): 91 | os.makedirs(log_dir) 92 | 93 | path_prefix = log_dir + "/NP2P.{}".format(FLAGS.suffix) 94 | log_file_path = path_prefix + ".log" 95 | print('Log file path: {}'.format(log_file_path)) 96 | log_file = open(log_file_path, 'wt') 97 | log_file.write("{}\n".format(FLAGS)) 98 | log_file.flush() 99 | 100 | # save configuration 101 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") 102 | 103 | print('Loading training set.') 104 | trainset, train_ans_len = NP2P_data_stream.read_all_GenerationDatasets(FLAGS.train_path, isLower=FLAGS.isLower) 105 | print('Number of training samples: {}'.format(len(trainset))) 106 | 107 | print('Loading dev set.') 108 | devset, dev_ans_len = NP2P_data_stream.read_all_GenerationDatasets(FLAGS.test_path, isLower=FLAGS.isLower) 109 | print('Number of dev samples: {}'.format(len(devset))) 110 | 111 | if FLAGS.finetune_path != "": 112 | print('Loading finetune set.') 113 | ftset, ft_ans_len = NP2P_data_stream.read_all_GenerationDatasets(FLAGS.ft_path, isLower=FLAGS.isLower) 114 | print('Number of finetune samples: {}'.format(len(ftset))) 115 | else: 116 | ftset, ft_ans_len = (None, 0) 117 | 118 | max_actual_len = max(train_ans_len, ft_ans_len, dev_ans_len) 119 | print('Max answer length: {}, truncated to {}'.format(max_actual_len, FLAGS.max_answer_len)) 120 | 121 | enc_word_vocab = None 122 | dec_word_vocab = None 123 | char_vocab = None 124 | has_pretrained_model = False 125 | best_path = path_prefix + ".best.model" 126 | if os.path.exists(best_path + ".index"): 127 | has_pretrained_model = True 128 | print('!!Existing pretrained model. 
Loading vocabs.') 129 | if FLAGS.with_word: 130 | enc_word_vocab = Vocab(FLAGS.enc_word_vec_path, fileformat='txt2') 131 | dec_word_vocab = Vocab(FLAGS.dec_word_vec_path, fileformat='txt2') 132 | print('Encoder word vocab: {}'.format(enc_word_vocab.word_vecs.shape)) 133 | print('Decoder word vocab: {}'.format(dec_word_vocab.word_vecs.shape)) 134 | if FLAGS.with_char: 135 | char_vocab = Vocab(path_prefix + ".char_vocab", fileformat='txt2') 136 | print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) 137 | else: 138 | print('Collecting vocabs.') 139 | (allWords, allChars) = NP2P_data_stream.collect_vocabs(trainset) 140 | print('Number of words: {}'.format(len(allWords))) 141 | print('Number of allChars: {}'.format(len(allChars))) 142 | 143 | if FLAGS.with_word: 144 | enc_word_vocab = Vocab(FLAGS.enc_word_vec_path, fileformat='txt2') 145 | dec_word_vocab = Vocab(FLAGS.dec_word_vec_path, fileformat='txt2') 146 | if FLAGS.with_char: 147 | char_vocab = Vocab(voc=allChars, dim=FLAGS.char_dim, fileformat='build') 148 | char_vocab.dump_to_txt2(path_prefix + ".char_vocab") 149 | 150 | print('Encoder word vocab size {}'.format(enc_word_vocab.vocab_size)) 151 | print('Decoder word vocab size {}'.format(dec_word_vocab.vocab_size)) 152 | sys.stdout.flush() 153 | 154 | print('Build DataStream ... ') 155 | trainDataStream = NP2P_data_stream.DataStream(trainset, enc_word_vocab, dec_word_vocab, char_vocab, options=FLAGS, 156 | isShuffle=True, isLoop=True, isSort=True) 157 | devDataStream = NP2P_data_stream.DataStream(devset, enc_word_vocab, dec_word_vocab, char_vocab, options=FLAGS, 158 | isShuffle=False, isLoop=False, isSort=True) 159 | print('Number of instances in trainDataStream: {}'.format(trainDataStream.get_num_instance())) 160 | print('Number of instances in devDataStream: {}'.format(devDataStream.get_num_instance())) 161 | print('Number of batches in trainDataStream: {}'.format(trainDataStream.get_num_batch())) 162 | print('Number of batches in devDataStream: {}'.format(devDataStream.get_num_batch())) 163 | if ftset != None: 164 | ftDataStream = NP2P_data_stream.DataStream(ftset, enc_word_vocab, dec_word_vocab, char_vocab, options=FLAGS, 165 | isShuffle=True, isLoop=True, isSort=True) 166 | print('Number of instances in ftDataStream: {}'.format(ftDataStream.get_num_instance())) 167 | print('Number of batches in ftDataStream: {}'.format(ftDataStream.get_num_batch())) 168 | 169 | sys.stdout.flush() 170 | 171 | init_scale = 0.01 172 | # initialize the best bleu and accu scores for current training session 173 | best_accu = FLAGS.best_accu if FLAGS.__dict__.has_key('best_accu') else 0.0 174 | best_bleu = FLAGS.best_bleu if FLAGS.__dict__.has_key('best_bleu') else 0.0 175 | if best_accu > 0.0: 176 | print('With initial dev accuracy {}'.format(best_accu)) 177 | if best_bleu > 0.0: 178 | print('With initial dev BLEU score {}'.format(best_bleu)) 179 | 180 | with tf.Graph().as_default(): 181 | initializer = tf.random_uniform_initializer(-init_scale, init_scale) 182 | with tf.name_scope("Train"): 183 | with tf.variable_scope("Model", reuse=None, initializer=initializer): 184 | train_graph = ModelGraph(enc_word_vocab=enc_word_vocab, dec_word_vocab=dec_word_vocab, char_vocab=char_vocab, 185 | POS_vocab=None, NER_vocab=None, options=FLAGS, mode=FLAGS.mode) 186 | 187 | assert FLAGS.mode in ('ce_train', 'rl_train', ) 188 | valid_mode = 'evaluate' if FLAGS.mode == 'ce_train' else 'evaluate_bleu' 189 | 190 | with tf.name_scope("Valid"): 191 | with tf.variable_scope("Model", reuse=True, initializer=initializer): 
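# reuse=True makes this validation graph share every variable created by the
# training graph under the same "Model" scope, so evaluation always runs on
# the weights currently being trained.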
192 | valid_graph = ModelGraph(enc_word_vocab=enc_word_vocab, dec_word_vocab=dec_word_vocab, char_vocab=char_vocab, 193 | POS_vocab=None, NER_vocab=None, options=FLAGS, mode=valid_mode) 194 | 195 | initializer = tf.global_variables_initializer() 196 | 197 | vars_ = {} 198 | for var in tf.all_variables(): 199 | if FLAGS.fix_word_vec and "word_embedding" in var.name: continue 200 | if not var.name.startswith("Model"): continue 201 | print(var) 202 | vars_[var.name.split(":")[0]] = var 203 | saver = tf.train.Saver(vars_) 204 | 205 | sess = tf.Session() 206 | sess.run(initializer) 207 | if has_pretrained_model: 208 | print("Restoring model from " + best_path) 209 | saver.restore(sess, best_path) 210 | print("DONE!") 211 | 212 | if FLAGS.mode == 'rl_train' and abs(best_bleu) < 0.00001: 213 | print("Getting BLEU score for the model") 214 | best_bleu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_bleu'] 215 | FLAGS.best_bleu = best_bleu 216 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") 217 | print('BLEU = %.4f' % best_bleu) 218 | log_file.write('BLEU = %.4f\n' % best_bleu) 219 | if FLAGS.mode == 'ce_train' and abs(best_accu) < 0.00001: 220 | print("Getting ACCU score for the model") 221 | best_accu = evaluate(sess, valid_graph, devDataStream, options=FLAGS)['dev_accu'] 222 | FLAGS.best_accu = best_accu 223 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") 224 | print('ACCU = %.4f' % best_accu) 225 | log_file.write('ACCU = %.4f\n' % best_accu) 226 | 227 | print('Start the training loop.') 228 | train_size = trainDataStream.get_num_batch() 229 | max_steps = train_size * FLAGS.max_epochs 230 | total_loss = 0.0 231 | start_time = time.time() 232 | for step in xrange(max_steps): 233 | cur_batch = trainDataStream.nextBatch() 234 | if FLAGS.mode == 'rl_train': 235 | loss_value = train_graph.run_rl_training_2(sess, cur_batch, FLAGS) 236 | elif FLAGS.mode == 'ce_train': 237 | loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS) 238 | total_loss += loss_value 239 | 240 | if step % 100==0: 241 | print('{} '.format(step), end="") 242 | sys.stdout.flush() 243 | 244 | 245 | # Save a checkpoint and evaluate the model periodically. 246 | if (step + 1) % trainDataStream.get_num_batch() == 0 or (step + 1) == max_steps or \ 247 | (trainDataStream.get_num_batch() > 10000 and (step + 1) % 2000 == 0): 248 | print() 249 | duration = time.time() - start_time 250 | print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) 251 | log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) 252 | log_file.flush() 253 | sys.stdout.flush() 254 | total_loss = 0.0 255 | 256 | if ftset != None: 257 | best_accu, best_bleu = fine_tune(sess, saver, FLAGS, log_file, 258 | ftDataStream, devDataStream, train_graph, valid_graph, path_prefix, best_accu, best_bleu) 259 | else: 260 | best_accu, best_bleu = validate_and_save(sess, saver, FLAGS, log_file, 261 | devDataStream, valid_graph, path_prefix, best_accu, best_bleu) 262 | start_time = time.time() 263 | 264 | log_file.close() 265 | 266 | def validate_and_save(sess, saver, FLAGS, log_file, 267 | devDataStream, valid_graph, path_prefix, best_accu, best_bleu): 268 | best_path = path_prefix + ".best.model" 269 | # Evaluate against the validation set. 
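# A checkpoint is written only when the dev metric (accuracy for ce_train,
# BLEU for rl_train) improves on the previous best; the new best score is also
# written back into the .config.json so that a restarted run keeps the better model.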
270 | start_time = time.time() 271 | print('Validation Data Eval:') 272 | res_dict = evaluate(sess, valid_graph, devDataStream, options=FLAGS) 273 | if valid_graph.mode == 'evaluate': 274 | dev_loss = res_dict['dev_loss'] 275 | dev_accu = res_dict['dev_accu'] 276 | dev_right = int(res_dict['dev_right']) 277 | dev_total = int(res_dict['dev_total']) 278 | print('Dev loss = %.4f' % dev_loss) 279 | log_file.write('Dev loss = %.4f\n' % dev_loss) 280 | print('Dev accu = %.4f %d/%d' % (dev_accu, dev_right, dev_total)) 281 | log_file.write('Dev accu = %.4f %d/%d\n' % (dev_accu, dev_right, dev_total)) 282 | log_file.flush() 283 | if best_accu < dev_accu: 284 | print('Saving weights, ACCU {} (prev_best) < {} (cur)'.format(best_accu, dev_accu)) 285 | saver.save(sess, best_path) 286 | best_accu = dev_accu 287 | FLAGS.best_accu = dev_accu 288 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") 289 | else: 290 | dev_bleu = res_dict['dev_bleu'] 291 | print('Dev bleu = %.4f' % dev_bleu) 292 | log_file.write('Dev bleu = %.4f\n' % dev_bleu) 293 | log_file.flush() 294 | if best_bleu < dev_bleu: 295 | print('Saving weights, BLEU {} (prev_best) < {} (cur)'.format(best_bleu, dev_bleu)) 296 | saver.save(sess, best_path) 297 | best_bleu = dev_bleu 298 | FLAGS.best_bleu = dev_bleu 299 | namespace_utils.save_namespace(FLAGS, path_prefix + ".config.json") 300 | duration = time.time() - start_time 301 | print('Duration %.3f sec' % (duration)) 302 | sys.stdout.flush() 303 | 304 | log_file.write('Duration %.3f sec\n' % (duration)) 305 | log_file.flush() 306 | return best_accu, best_bleu 307 | 308 | 309 | def fine_tune(sess, saver, FLAGS, log_file, 310 | ftDataStream, devDataStream, train_graph, valid_graph, path_prefix, best_accu, best_bleu): 311 | print('=====Start the fine tuning.') 312 | sys.stdout.flush() 313 | train_size = ftDataStream.get_num_batch() 314 | max_steps = train_size * 3 315 | best_path = path_prefix + ".best.model" 316 | total_loss = 0.0 317 | start_time = time.time() 318 | for step in xrange(max_steps): 319 | cur_batch = ftDataStream.nextBatch() 320 | if FLAGS.mode == 'rl_train': 321 | loss_value = train_graph.run_rl_training_subsample(sess, cur_batch, FLAGS) 322 | elif FLAGS.mode == 'ce_train': 323 | loss_value = train_graph.run_ce_training(sess, cur_batch, FLAGS) 324 | total_loss += loss_value 325 | 326 | if step % 100==0: 327 | print('{} '.format(step), end="") 328 | sys.stdout.flush() 329 | 330 | # Save a checkpoint and evaluate the model periodically. 
331 | if (step + 1) % ftDataStream.get_num_batch() == 0 or (step + 1) == max_steps: 332 | print() 333 | duration = time.time() - start_time 334 | print('Step %d: loss = %.2f (%.3f sec)' % (step, total_loss, duration)) 335 | sys.stdout.flush() 336 | log_file.write('Step %d: loss = %.2f (%.3f sec)\n' % (step, total_loss, duration)) 337 | log_file.flush() 338 | total_loss = 0.0 339 | 340 | best_accu, best_bleu = validate_and_save(sess, saver, FLAGS, log_file, 341 | devDataStream, valid_graph, path_prefix, best_accu, best_bleu) 342 | print('=====End the fine tuning.') 343 | sys.stdout.flush() 344 | return best_accu, best_bleu 345 | 346 | 347 | def enrich_options(options): 348 | if not options.__dict__.has_key("finetune_path"): 349 | options.__dict__["finetune_path"] = "" 350 | 351 | if not options.__dict__.has_key("CE_loss"): 352 | options.__dict__["CE_loss"] = False 353 | 354 | if not options.__dict__.has_key("infile_format"): 355 | options.__dict__["infile_format"] = "plain" 356 | 357 | if not options.__dict__.has_key("with_target_lattice"): 358 | options.__dict__["with_target_lattice"] = False 359 | 360 | if not options.__dict__.has_key("add_first_word_prob_for_phrase"): 361 | options.__dict__["add_first_word_prob_for_phrase"] = False 362 | 363 | if not options.__dict__.has_key("pretrain_with_max_matching"): 364 | options.__dict__["pretrain_with_max_matching"] = False 365 | 366 | if not options.__dict__.has_key("reward_type"): 367 | options.__dict__["reward_type"] = "bleu" 368 | 369 | return options 370 | 371 | 372 | if __name__ == '__main__': 373 | parser = argparse.ArgumentParser() 374 | parser.add_argument('--config_path', type=str, help='Configuration file.') 375 | 376 | #os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 377 | #os.environ["CUDA_VISIBLE_DEVICES"]="3" 378 | 379 | print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES']) 380 | FLAGS, unparsed = parser.parse_known_args() 381 | 382 | 383 | if FLAGS.config_path is not None: 384 | print('Loading the configuration from ' + FLAGS.config_path) 385 | FLAGS = namespace_utils.load_namespace(FLAGS.config_path) 386 | 387 | FLAGS = enrich_options(FLAGS) 388 | 389 | sys.stdout.flush() 390 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 391 | -------------------------------------------------------------------------------- /src_s2s/NP2P_beam_decoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import argparse 4 | import os 5 | import sys 6 | import time 7 | import numpy as np 8 | 9 | from vocab_utils import Vocab 10 | import namespace_utils 11 | import NP2P_data_stream 12 | from NP2P_model_graph import ModelGraph 13 | 14 | import re 15 | 16 | import tensorflow as tf 17 | import NP2P_trainer 18 | tf.logging.set_verbosity(tf.logging.ERROR) # DEBUG, INFO, WARN, ERROR, and FATAL 19 | 20 | def search(sess, model, vocab, batch, options, decode_mode='greedy'): 21 | ''' 22 | for greedy search, multinomial search 23 | ''' 24 | # Run the encoder to get the encoder hidden states and decoder initial state 25 | (phrase_representations, initial_state, encoder_features,phrase_idx, phrase_mask) = model.run_encoder(sess, batch, options) 26 | # phrase_representations: [batch_size, passage_len, encode_dim] 27 | # initial_state: a tupel of [batch_size, gen_dim] 28 | # encoder_features: [batch_size, passage_len, attention_vec_size] 29 | # phrase_idx: [batch_size, passage_len] 30 | # phrase_mask: [batch_size, 
passage_len] 31 | 32 | word_t = batch.gen_input_words[:,0] 33 | state_t = initial_state 34 | context_t = np.zeros([batch.batch_size, model.encode_dim]) 35 | coverage_t = np.zeros((batch.batch_size, phrase_representations.shape[1])) 36 | generator_output_idx = [] # store phrase index prediction 37 | text_results = [] 38 | generator_input_idx = [word_t] # store word index 39 | for i in xrange(options.max_answer_len): 40 | if decode_mode == "pointwise": word_t = batch.gen_input_words[:,i] 41 | feed_dict = {} 42 | feed_dict[model.init_decoder_state] = state_t 43 | feed_dict[model.context_t_1] = context_t 44 | feed_dict[model.coverage_t_1] = coverage_t 45 | feed_dict[model.word_t] = word_t 46 | 47 | feed_dict[model.phrase_representations] = phrase_representations 48 | feed_dict[model.encoder_features] = encoder_features 49 | feed_dict[model.phrase_idx] = phrase_idx 50 | feed_dict[model.phrase_mask] = phrase_mask 51 | if options.with_phrase_projection: 52 | feed_dict[model.max_phrase_size] = batch.max_phrase_size 53 | if options.add_first_word_prob_for_phrase: 54 | feed_dict[model.in_passage_words] = batch.sent1_word 55 | feed_dict[model.phrase_starts] = batch.phrase_starts 56 | 57 | 58 | 59 | if decode_mode in ["greedy","pointwise"]: 60 | prediction = model.greedy_prediction 61 | elif decode_mode == "multinomial": 62 | prediction = model.multinomial_prediction 63 | 64 | (state_t, context_t, attn_dist_t, coverage_t, prediction) = sess.run([model.state_t, model.context_t, model.attn_dist_t, 65 | model.coverage_t, prediction], feed_dict) 66 | attn_idx = np.argmax(attn_dist_t, axis=1) # [batch_size] 67 | # convert prediction to word ids 68 | generator_output_idx.append(prediction) 69 | prediction = np.reshape(prediction, [prediction.size, 1]) 70 | [cur_words, cur_word_idx] = batch.map_phrase_idx_to_text(prediction) # [batch_size, 1] 71 | cur_word_idx = np.array(cur_word_idx) 72 | cur_word_idx = np.reshape(cur_word_idx, [cur_word_idx.size]) 73 | word_t = cur_word_idx 74 | cur_words = flatten_words(cur_words) # [batch_size] 75 | 76 | for i, wword in enumerate(cur_words): 77 | if wword == 'UNK' and attn_idx[i] < len(batch.passage_words[i]): 78 | cur_words[i] = batch.passage_words[i][attn_idx[i]] 79 | 80 | text_results.append(cur_words) 81 | generator_input_idx.append(cur_word_idx) 82 | 83 | generator_input_idx = generator_input_idx[:-1] # remove the last word to shift one position to the right 84 | generator_output_idx = np.stack(generator_output_idx, axis=1) # [batch_size, max_len] 85 | generator_input_idx = np.stack(generator_input_idx, axis=1) # [batch_size, max_len] 86 | 87 | prediction_lengths = [] # [batch_size] 88 | sentences = [] # [batch_size] 89 | for i in xrange(batch.batch_size): 90 | words = [] 91 | for j in xrange(options.max_answer_len): 92 | cur_phrase = text_results[j][i] 93 | # cur_phrase = cur_batch_text[j] 94 | words.append(cur_phrase) 95 | if cur_phrase == "": break# filter out based on end symbol 96 | prediction_lengths.append(len(words)) 97 | cur_sent = " ".join(words) 98 | sentences.append(cur_sent) 99 | 100 | return (sentences, prediction_lengths, generator_input_idx, generator_output_idx) 101 | 102 | def flatten_words(cur_words): 103 | all_words = [] 104 | for i in xrange(len(cur_words)): 105 | all_words.append(cur_words[i][0]) 106 | return all_words 107 | 108 | class Hypothesis(object): 109 | def __init__(self, tokens, log_ps, attn, state, context_vector, coverage_vector=None): 110 | self.tokens = tokens # store all tokens 111 | self.log_probs = log_ps # store log_probs for 
each time-step 112 | self.attn_ids = attn 113 | self.state = state 114 | self.context_vector = context_vector 115 | self.coverage_vector = coverage_vector 116 | 117 | def extend(self, token, log_prob, attn_i, state, context_vector, coverage_vector=None): 118 | return Hypothesis(self.tokens + [token], self.log_probs + [log_prob], self.attn_ids + [attn_i], state, 119 | context_vector, coverage_vector=coverage_vector) 120 | 121 | def latest_token(self): 122 | return self.tokens[-1] 123 | 124 | def avg_log_prob(self): 125 | return np.sum(self.log_probs[1:])/ (len(self.tokens)-1) 126 | 127 | def probs2string(self): 128 | out_string = "" 129 | for prob in self.log_probs: 130 | out_string += " %.4f" % prob 131 | return out_string.strip() 132 | 133 | def idx_seq_to_string(self, passage, id2phrase, vocab, options): 134 | word_size = vocab.vocab_size + 1 135 | all_words = [] 136 | for i, idx in enumerate(self.tokens): 137 | cur_word = vocab.getWord(idx) 138 | if cur_word == 'UNK': 139 | cur_word = passage[self.attn_ids[i]] 140 | all_words.append(cur_word) 141 | return " ".join(all_words[1:]) 142 | 143 | 144 | def sort_hyps(hyps): 145 | return sorted(hyps, key=lambda h: h.avg_log_prob(), reverse=True) 146 | 147 | 148 | 149 | def run_beam_search(sess, model, vocab, batch, options): 150 | # Run encoder 151 | st = time.time() 152 | (phrase_representations, initial_state, encoder_features,phrase_idx, phrase_mask) = model.run_encoder(sess, batch, options) 153 | encoding_dur = time.time() - st 154 | # phrase_representations: [1, passage_len, encode_dim] 155 | # initial_state: a tupel of [1, gen_dim] 156 | # encoder_features: [1, passage_len, attention_vec_size] 157 | # phrase_idx: [1, passage_len] 158 | # phrase_mask: [1, passage_len] 159 | 160 | sent_stop_id = vocab.getIndex('') 161 | 162 | # Initialize this first hypothesis 163 | context_t = np.zeros([model.encode_dim]) # [encode_dim] 164 | coverage_t = np.zeros((phrase_representations.shape[1])) # [passage_len] 165 | hyps = [] 166 | hyps.append(Hypothesis([batch.gen_input_words[0][0]], [0.0], [-1], initial_state, context_t, coverage_vector=coverage_t)) 167 | 168 | # beam search decoding 169 | results = [] # this will contain finished hypotheses (those that have emitted the token) 170 | steps = 0 171 | while steps < options.max_answer_len and len(results) < options.beam_size: 172 | cur_size = len(hyps) # current number of hypothesis in the beam 173 | cur_phrase_representations = np.tile(phrase_representations, (cur_size, 1, 1)) 174 | cur_encoder_features = np.tile(encoder_features, (cur_size, 1, 1)) # [batch_size,passage_len, options.attention_vec_size] 175 | cur_phrase_idx = np.tile(phrase_idx, (cur_size, 1)) # [batch_size, passage_len] 176 | cur_phrase_mask = np.tile(phrase_mask, (cur_size, 1)) # [batch_size, passage_len] 177 | cur_state_t_1 = [] # [2, gen_steps] 178 | cur_context_t_1 = [] # [batch_size, encoder_dim] 179 | cur_coverage_t_1 = [] # [batch_size, passage_len] 180 | cur_word_t = [] # [batch_size] 181 | for h in hyps: 182 | cur_state_t_1.append(h.state) 183 | cur_context_t_1.append(h.context_vector) 184 | cur_word_t.append(h.latest_token()) 185 | cur_coverage_t_1.append(h.coverage_vector) 186 | cur_context_t_1 = np.stack(cur_context_t_1, axis=0) 187 | cur_coverage_t_1 = np.stack(cur_coverage_t_1, axis=0) 188 | cur_word_t = np.array(cur_word_t) 189 | 190 | cells = [state.c for state in cur_state_t_1] 191 | hidds = [state.h for state in cur_state_t_1] 192 | new_c = np.concatenate(cells, axis=0) 193 | new_h = np.concatenate(hidds, axis=0) 194 | 
new_dec_init_state = tf.contrib.rnn.LSTMStateTuple(new_c, new_h) 195 | 196 | feed_dict = {} 197 | feed_dict[model.init_decoder_state] = new_dec_init_state 198 | feed_dict[model.context_t_1] = cur_context_t_1 199 | feed_dict[model.word_t] = cur_word_t 200 | 201 | feed_dict[model.phrase_representations] = cur_phrase_representations 202 | feed_dict[model.encoder_features] = cur_encoder_features 203 | feed_dict[model.phrase_idx] = cur_phrase_idx 204 | feed_dict[model.phrase_mask] = cur_phrase_mask 205 | feed_dict[model.coverage_t_1] = cur_coverage_t_1 206 | if options.with_phrase_projection: 207 | feed_dict[model.max_phrase_size] = batch.max_phrase_size 208 | if options.add_first_word_prob_for_phrase: 209 | feed_dict[model.in_passage_words] = batch.sent1_word 210 | feed_dict[model.phrase_starts] = batch.phrase_starts 211 | 212 | (state_t, context_t, attn_dist_t, coverage_t, topk_log_probs, topk_ids) = sess.run([model.state_t, model.context_t, 213 | model.attn_dist_t, model.coverage_t, model.topk_log_probs, model.topk_ids], feed_dict) 214 | 215 | new_states = [tf.contrib.rnn.LSTMStateTuple(state_t.c[i:i+1, :], state_t.h[i:i+1, :]) for i in xrange(cur_size)] 216 | 217 | # Extend each hypothesis and collect them all in all_hyps 218 | if steps == 0: cur_size = 1 219 | all_hyps = [] 220 | for i in xrange(cur_size): 221 | h = hyps[i] 222 | cur_state = new_states[i] 223 | cur_context = context_t[i] 224 | cur_coverage = coverage_t[i] 225 | for j in xrange(options.beam_size): 226 | cur_tok = topk_ids[i, j] 227 | cur_tok_log_prob = topk_log_probs[i, j] 228 | cur_attn_i = np.argmax(attn_dist_t[i, :]) 229 | new_hyp = h.extend(cur_tok, cur_tok_log_prob, cur_attn_i, cur_state, cur_context, coverage_vector=cur_coverage) 230 | all_hyps.append(new_hyp) 231 | 232 | # Filter and collect any hypotheses that have produced the end token. 233 | # hyps will contain hypotheses for the next step 234 | hyps = [] 235 | for h in sort_hyps(all_hyps): 236 | # If this hypothesis is sufficiently long, put in results. Otherwise discard. 
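# Finished hypotheses (latest token is the stop symbol) go to `results` if the
# decoder has run for at least min_answer_len steps, and are dropped otherwise;
# unfinished hypotheses are kept in `hyps` for further expansion. Collection
# stops once either list reaches beam_size.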
237 | if h.latest_token() == sent_stop_id: 238 | if steps >= options.min_answer_len: 239 | results.append(h) 240 | # hasn't reached stop token, so continue to extend this hypothesis 241 | else: 242 | hyps.append(h) 243 | if len(hyps) == options.beam_size or len(results) == options.beam_size: 244 | break 245 | 246 | steps += 1 247 | 248 | # At this point, either we've got beam_size results, or we've reached maximum decoder steps 249 | # if we don't have any complete results, add all current hypotheses (incomplete summaries) to results 250 | if len(results)==0: 251 | results = hyps 252 | 253 | # Sort hypotheses by average log probability 254 | hyps_sorted = sort_hyps(results) 255 | 256 | # Return the hypothesis with highest average log prob 257 | return hyps_sorted, encoding_dur 258 | 259 | if __name__ == '__main__': 260 | parser = argparse.ArgumentParser() 261 | parser.add_argument('--model_prefix', type=str, required=True, help='Prefix to the models.') 262 | parser.add_argument('--in_path', type=str, required=True, help='The path to the test file.') 263 | parser.add_argument('--out_path', type=str, help='The path to the output file.') 264 | parser.add_argument('--mode', type=str,default='pointwise', help='The path to the output file.') 265 | parser.add_argument('--beam_size', type=int, default=-1, help='') 266 | 267 | args, unparsed = parser.parse_known_args() 268 | 269 | model_prefix = args.model_prefix 270 | in_path = args.in_path 271 | out_path = args.out_path 272 | mode = args.mode 273 | 274 | print("CUDA_VISIBLE_DEVICES " + os.environ['CUDA_VISIBLE_DEVICES']) 275 | 276 | # load the configuration file 277 | print('Loading configurations from ' + model_prefix + ".config.json") 278 | FLAGS = namespace_utils.load_namespace(model_prefix + ".config.json") 279 | FLAGS = NP2P_trainer.enrich_options(FLAGS) 280 | if args.beam_size != -1: 281 | FLAGS.beam_size = args.beam_size 282 | 283 | # load vocabs 284 | print('Loading vocabs.') 285 | enc_word_vocab = dec_word_vocab = char_vocab = POS_vocab = NER_vocab = None 286 | if FLAGS.with_word: 287 | enc_word_vocab = Vocab(FLAGS.enc_word_vec_path, fileformat='txt2') 288 | print('enc_word_vocab: {}'.format(enc_word_vocab.word_vecs.shape)) 289 | dec_word_vocab = Vocab(FLAGS.dec_word_vec_path, fileformat='txt2') 290 | print('dec_word_vocab: {}'.format(dec_word_vocab.word_vecs.shape)) 291 | if FLAGS.with_char: 292 | char_vocab = Vocab(model_prefix + ".char_vocab", fileformat='txt2') 293 | print('char_vocab: {}'.format(char_vocab.word_vecs.shape)) 294 | if FLAGS.with_POS: 295 | POS_vocab = Vocab(model_prefix + ".POS_vocab", fileformat='txt2') 296 | print('POS_vocab: {}'.format(POS_vocab.word_vecs.shape)) 297 | if FLAGS.with_NER: 298 | NER_vocab = Vocab(model_prefix + ".NER_vocab", fileformat='txt2') 299 | print('NER_vocab: {}'.format(NER_vocab.word_vecs.shape)) 300 | 301 | 302 | print('Loading test set.') 303 | if FLAGS.infile_format == 'fof': 304 | testset, _ = NP2P_data_stream.read_generation_datasets_from_fof(in_path, isLower=FLAGS.isLower) 305 | elif FLAGS.infile_format == 'plain': 306 | testset, _ = NP2P_data_stream.read_all_GenerationDatasets(in_path, isLower=FLAGS.isLower) 307 | else: 308 | testset, _ = NP2P_data_stream.read_all_GQA_questions(in_path, isLower=FLAGS.isLower, switch=FLAGS.switch_qa) 309 | print('Number of samples: {}'.format(len(testset))) 310 | 311 | print('Build DataStream ... 
') 312 | batch_size = -1 313 | if mode not in ('pointwise', 'multinomial', 'greedy', 'greedy_evaluate', ): batch_size = 1 314 | devDataStream = NP2P_data_stream.DataStream(testset, enc_word_vocab, dec_word_vocab, char_vocab, options=FLAGS, 315 | isShuffle=False, isLoop=False, isSort=True, batch_size=batch_size) 316 | print('Number of instances in testDataStream: {}'.format(devDataStream.get_num_instance())) 317 | print('Number of batches in testDataStream: {}'.format(devDataStream.get_num_batch())) 318 | 319 | best_path = model_prefix + ".best.model" 320 | with tf.Graph().as_default(): 321 | initializer = tf.random_uniform_initializer(-0.01, 0.01) 322 | with tf.name_scope("Valid"): 323 | with tf.variable_scope("Model", reuse=False, initializer=initializer): 324 | valid_graph = ModelGraph(enc_word_vocab=enc_word_vocab, dec_word_vocab=dec_word_vocab, char_vocab=char_vocab, 325 | options=FLAGS, mode="decode") 326 | 327 | ## remove word _embedding 328 | vars_ = {} 329 | for var in tf.all_variables(): 330 | if FLAGS.fix_word_vec and "word_embedding" in var.name: continue 331 | if not var.name.startswith("Model"): continue 332 | vars_[var.name.split(":")[0]] = var 333 | saver = tf.train.Saver(vars_) 334 | 335 | initializer = tf.global_variables_initializer() 336 | sess = tf.Session() 337 | sess.run(initializer) 338 | 339 | saver.restore(sess, best_path) # restore the model 340 | 341 | total = 0 342 | correct = 0 343 | outfile = open(out_path, 'wt') 344 | total_num = devDataStream.get_num_batch() 345 | devDataStream.reset() 346 | total_dur = 0 347 | for i in range(total_num): 348 | cur_batch = devDataStream.get_batch(i) 349 | if mode in ['greedy', 'multinomial']: 350 | print('Batch {}'.format(i)) 351 | (sentences, prediction_lengths, generator_input_idx, 352 | generator_output_idx) = search(sess, valid_graph, dec_word_vocab, cur_batch, FLAGS, decode_mode=mode) 353 | for j in xrange(cur_batch.batch_size): 354 | outfile.write(cur_batch.id[j].encode('utf-8') + "\n") 355 | outfile.write(cur_batch.target_ref[j].encode('utf-8') + "\n") 356 | outfile.write(sentences[j].encode('utf-8') + "\n") 357 | outfile.write("========\n") 358 | outfile.flush() 359 | else: # beam search 360 | print('Instance {}'.format(i)) 361 | hyps, dur = run_beam_search(sess, valid_graph, dec_word_vocab, cur_batch, FLAGS) 362 | total_dur += dur 363 | outfile.write(cur_batch.id[0].encode('utf-8') + "\n") 364 | outfile.write(cur_batch.target_ref[0].encode('utf-8') + "\n") 365 | for j in xrange(1): 366 | hyp = hyps[j] 367 | cur_passage = cur_batch.source[0] 368 | cur_id2phrase = None 369 | cur_sent = hyp.idx_seq_to_string(cur_passage, cur_id2phrase, dec_word_vocab, FLAGS) 370 | outfile.write(cur_sent.encode('utf-8') + "\n") 371 | outfile.write("--------\n") 372 | outfile.write("========\n") 373 | outfile.flush() 374 | outfile.close() 375 | print('Total encoding time {}'.format(total_dur)) 376 | 377 | 378 | 379 | --------------------------------------------------------------------------------
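For reference, the length-normalised ranking performed by sort_hyps/avg_log_prob above can be illustrated with the following minimal, self-contained sketch; the class and values here are illustrative stand-ins, not part of the repository.

import numpy as np

class ToyHypothesis(object):
    # hypothetical stand-in for the Hypothesis class defined in NP2P_beam_decoder.py
    def __init__(self, tokens, log_ps):
        self.tokens = tokens        # decoded token ids; position 0 is the start symbol
        self.log_probs = log_ps     # per-step log probabilities; 0.0 for the start symbol

    def avg_log_prob(self):
        # average over generated tokens only, skipping the start symbol
        return np.sum(self.log_probs[1:]) / (len(self.tokens) - 1)

hyps = [ToyHypothesis([0, 5, 9],    [0.0, -0.2, -0.3]),        # average score -0.25
        ToyHypothesis([0, 7, 2, 4], [0.0, -0.1, -0.2, -0.3])]  # average score -0.20
best = sorted(hyps, key=lambda h: h.avg_log_prob(), reverse=True)[0]
print(best.tokens)  # -> [0, 7, 2, 4]; normalisation keeps longer hypotheses competitive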