├── README.md
└── src
├── GAN_main.py
├── RL_helper.py
├── RL_main.py
├── constants.py
├── data_generation
├── create_amazon_data.py
├── create_amazon_pqa_data.py
├── create_amazon_pqa_data_candqs.py
├── create_amazon_pqa_data_from_asins.py
├── create_lucene_amazon.py
├── create_pqa_data.py
├── filter_amazon_data_byid.py
├── filter_output_perid.py
├── remove_unks.py
├── run_create_amazon_data.sh
├── run_create_amazon_pqa_data.sh
├── run_create_amazon_pqa_data_from_asins.sh
├── run_create_lucene_amazon.sh
├── run_create_pqa_data.sh
├── run_filter_amazon_data_byid.sh
└── run_filter_output_perid.sh
├── decode.py
├── embedding_generation
├── create_we_vocab.py
├── extract_amazon_data.py
├── extract_data.py
├── run_create_we_vocab.sh
├── run_extract_amazon_data.sh
├── run_extract_data.sh
└── run_glove.sh
├── evaluation
├── all_ngrams.pl
├── calculate_diversity.sh
├── calculate_inter_annotator_agreement.py
├── combine_refs_for_meteor.py
├── create_amazon_multi_refs.py
├── create_crowdflower_data.py
├── create_crowdflower_data_beam.py
├── create_crowdflower_data_compare_ques.py
├── create_crowdflower_data_compare_ques_allpairs.py
├── create_crowdflower_data_specificity.py
├── create_crowdflower_data_specificity_multi.py
├── create_preds_for_refs.py
├── eval_HK
├── eval_HK_spec
├── eval_aus
├── read_crowdflower_full_results.py
├── read_crowdflower_full_results_compare_ques.py
├── read_crowdflower_results.py
├── read_crowdflower_results_binary.py
├── read_crowdflower_results_both.py
├── read_crowdflower_results_compare_ques.py
├── read_crowdflower_results_diff_formats.py
├── read_crowdflower_results_style.py
├── run_bleu.sh
├── run_create_amazon_multi_refs.sh
├── run_create_crowdflower_data.sh
├── run_create_crowdflower_data_beam.sh
├── run_create_crowdflower_data_compare_ques.sh
├── run_create_crowdflower_data_compare_ques_allpairs.sh
├── run_create_crowdflower_data_specificity.sh
├── run_create_crowdflower_data_specificity_multi.sh
├── run_create_preds_for_refs.sh
├── run_meteor.sh
└── run_meteor_HK.sh
├── lucene
├── create_amazon_lucene_baseline.py
├── create_stackexchange_lucene_baseline.py
├── run_create_amazon_lucene_baseline.sh
└── run_create_stackexchange_lucene_baseline.sh
├── main.py
├── run_GAN_decode.sh
├── run_GAN_decode_HK.sh
├── run_GAN_main.sh
├── run_GAN_main_HK.sh
├── run_GAN_main_electronics.sh
├── run_RL_decode.sh
├── run_RL_decode_HK.sh
├── run_RL_main.sh
├── run_RL_main_HK.sh
├── run_decode.sh
├── run_decode_HK.sh
├── run_decode_electronics.sh
├── run_main.sh
├── run_main_HK.sh
├── run_main_electronics.sh
├── run_pretrain_ans.sh
├── run_pretrain_ans_HK.sh
├── run_pretrain_util.sh
├── run_pretrain_util_HK.sh
├── seq2seq
├── GAN_train.py
├── RL_beam_decoder.py
├── RL_evaluate.py
├── RL_inference.py
├── RL_train.py
├── __init__.py
├── ans_train.py
├── attn.py
├── attnDecoderRNN.py
├── baselineFF.py
├── encoderRNN.py
├── evaluate.py
├── helper.py
├── main.py
├── masked_cross_entropy.py
├── prepare_data.py
├── read_data.py
└── train.py
└── utility
├── FeedForward.py
├── RL_evaluate.py
├── RL_train.py
├── RNN.py
├── __init__.py
├── combine_pickle.py
├── data_loader.py
├── evaluate_utility.py
├── helper_utility.py
├── main.py
├── rnn_classifier.py
├── run_combine_domains.sh
├── run_data_loader.sh
├── run_rnn_classifier.sh
└── train_utility.py
/README.md:
--------------------------------------------------------------------------------
1 | # Repository information
2 |
3 | This repository contains data and code for the paper below:
4 |
5 |
6 | Answer-based Adversarial Training for Generating Clarification Questions
7 | Sudha Rao (Sudha.Rao@microsoft.com) and Hal Daumé III (me@hal3.name)
8 | Proceedings of NAACL-HLT 2019
9 |
10 | # Downloading data
11 |
12 | * Download embeddings from https://go.umd.edu/clarification_questions_embeddings
13 | and save them into the repository folder
14 | * Download data from https://go.umd.edu/clarification_question_generation_dataset
15 | Unzip the two folders inside and copy them into the repository folder
16 |
17 | # Training models on StackExchange dataset
18 |
19 | * To train an MLE model, run src/run_main.sh
20 |
21 | * To train a Max-Utility model, follow these three steps:
22 |
23 | * run src/run_pretrain_ans.sh
24 |
25 | * run src/run_pretrain_util.sh
26 |
27 | * run src/run_RL_main.sh
28 |
29 | * To train a GAN-Utility model, follow these three steps (note: you can skip the first two steps if you have already run them for the Max-Utility model):
30 |
31 | * run src/run_pretrain_ans.sh
32 |
33 | * run src/run_pretrain_util.sh
34 |
35 | * run src/run_GAN_main.sh
36 |
37 | # Training models on Amazon (Home & Kitchen) dataset
38 |
39 | * To train an MLE model, run src/run_main_HK.sh
40 |
41 | * To train a Max-Utility model, follow these three steps:
42 |
43 | * run src/run_pretrain_ans_HK.sh
44 |
45 | * run src/run_pretrain_util_HK.sh
46 |
47 | * run src/run_RL_main_HK.sh
48 |
49 | * To train a GAN-Utility model, follow these three steps (note: you can skip the first two steps if you have already run them for the Max-Utility model):
50 |
51 | * run src/run_pretrain_ans_HK.sh
52 |
53 | * run src/run_pretrain_util_HK.sh
54 |
55 | * run src/run_GAN_main_HK.sh
56 |
57 | # Generating outputs using trained models
58 |
59 | * Run following scripts to generate outputs for models trained on StackExchange dataset:
60 |
61 | * For MLE model, run src/run_decode.sh
62 |
63 | * For Max-Utility model, run src/run_RL_decode.sh
64 |
65 | * For GAN-Utility model, run src/run_GAN_decode.sh
66 |
67 | * Run following scripts to generate outputs for models trained on Amazon dataset:
68 |
69 | * For MLE model, run src/run_decode_HK.sh
70 |
71 | * For Max-Utility model, run src/run_RL_decode_HK.sh
72 |
73 | * For GAN-Utility model, run src/run_GAN_decode_HK.sh
74 |
75 | # Evaluating generated outputs
76 |
77 | * For the StackExchange dataset, references for a subset of the test set were collected using human annotators.
78 | Hence we first create a version of the predictions file for which we have references by running following:
79 | src/evaluation/run_create_preds_for_refs.sh
80 |
81 | * For Amazon dataset, we have references for all instances in the test set.
82 |
83 | * We remove the unknown-word (UNK) tokens from the generated outputs by simply deleting them from the predictions file.
84 |
85 | * For BLEU score, run src/evaluation/run_bleu.sh
86 |
87 | * For METEOR score, run src/evaluation/run_meteor.sh
88 |
89 | * For Diversity score, run src/evaluation/calculate_diversity.sh
90 |
--------------------------------------------------------------------------------
/src/RL_helper.py:
--------------------------------------------------------------------------------
1 | from constants import *
2 | import math
3 | import numpy as np
4 | import nltk
5 | import time
6 | import torch
7 |
8 |
def as_minutes(s):
    """Format a duration in seconds as ``'<minutes>m <seconds>s'``.

    Both parts are rendered with ``%d``, so fractional seconds are truncated.
    """
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)
13 |
14 |
def time_since(since, percent):
    """Describe elapsed time and the projected time remaining.

    Args:
        since: Start timestamp, as returned by ``time.time()``.
        percent: Fraction of the work completed so far; must be non-zero
            because the elapsed time is divided by it.

    Returns:
        A string ``'<elapsed> (- <remaining>)'`` with both durations
        formatted by ``as_minutes``.
    """
    elapsed = time.time() - since
    # Linear extrapolation: total time = elapsed / fraction done.
    projected_total = elapsed / (percent)
    remaining = projected_total - elapsed
    return '%s (- %s)' % (as_minutes(elapsed), as_minutes(remaining))
21 |
22 |
def iterate_minibatches(p, pl, q, ql, pq, pql, a, al, batch_size, shuffle=True):
    """Yield aligned minibatches from eight parallel sequences.

    Args:
        p, pl, q, ql, pq, pql, a, al: Parallel sequences of equal length;
            example ``i`` is made of the ``i``-th element of each.
        batch_size: Number of examples per yielded batch.
        shuffle: If true, examples are visited in a random order drawn from
            ``np.random``; otherwise in their original order.

    Yields:
        An 8-tuple of numpy arrays, one per input and in the same order as the
        parameters, each holding ``batch_size`` examples.  A trailing partial
        batch (fewer than ``batch_size`` examples) is not yielded.
    """
    num_examples = len(p)
    if shuffle:
        indices = np.arange(num_examples)
        np.random.shuffle(indices)
    # Convert every input to an array once, instead of once per batch.
    arrays = [np.array(seq) for seq in (p, pl, q, ql, pq, pql, a, al)]
    for start_idx in range(0, num_examples - batch_size + 1, batch_size):
        if shuffle:
            ex = indices[start_idx:start_idx + batch_size]
        else:
            ex = slice(start_idx, start_idx + batch_size)
        yield tuple(arr[ex] for arr in arrays)
34 |
35 |
def reverse_dict(word2index):
    """Return the inverse of ``word2index``, mapping each index to its word.

    If several words share an index, the one iterated last wins.
    """
    return {index: word for word, index in word2index.items()}
42 |
43 |
def calculate_bleu(true, true_lens, pred, pred_lens, index2word, max_len):
    """Compute a sentence-level BLEU score for every prediction in a batch.

    Args:
        true: Reference sequences as word indices, one per example.
        true_lens: Number of valid leading positions in each ``true`` row.
        pred: Predicted sequences as word indices, one per example.
        pred_lens: Number of valid leading positions in each ``pred`` row.
        index2word: Mapping from word index to word string.
        max_len: Unused; kept so existing callers keep working.

    Returns:
        A 1-D ``torch`` tensor with one BLEU score per prediction, moved to
        the GPU when ``USE_CUDA`` is set.
    """
    sent_bleu_scores = torch.zeros(len(pred))
    for i in range(len(pred)):
        true_sent = [index2word[idx] for idx in true[i][:true_lens[i]]]
        pred_sent = [index2word[idx] for idx in pred[i][:pred_lens[i]]]
        # sentence_bleu expects a *list of references*, each a token list.
        # Passing the bare token list made NLTK treat every token string as
        # its own reference, so the single reference is wrapped in a list.
        sent_bleu_scores[i] = nltk.translate.bleu_score.sentence_bleu([true_sent], pred_sent)
    if USE_CUDA:
        sent_bleu_scores = sent_bleu_scores.cuda()
    return sent_bleu_scores
53 |
--------------------------------------------------------------------------------
/src/constants.py:
--------------------------------------------------------------------------------
# Global switch read by other modules (e.g. RL_helper.calculate_bleu) to
# decide whether tensors are moved to the GPU with .cuda().
USE_CUDA = True

# Configure models
HIDDEN_SIZE = 100
DROPOUT = 0.5

# Configure training/optimization
LEARNING_RATE = 0.0001
DECODER_LEARNING_RATIO = 5.0

# Special vocabulary tokens.
# NOTE(review): every token below is the empty string in this copy of the
# file, which makes them indistinguishable from one another. The values look
# like angle-bracketed markers that were stripped during extraction; restore
# them from the upstream repository before use -- TODO confirm.
PAD_token = ''
SOS_token = ''
EOP_token = ''
EOS_token = ''
UNK_token = ''
# UNK_token = 'unk'
SPECIFIC_token = ''
GENERIC_token = ''

# Beam width constant (presumably consumed by the decoding code -- verify).
BEAM_SIZE = 5
21 |
--------------------------------------------------------------------------------
/src/data_generation/create_amazon_pqa_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gzip
3 | import nltk
4 | import pdb
5 | import sys, os
6 | import re
7 | from collections import defaultdict
8 |
def parse(path):
    """Yield one parsed record per line of a gzip-compressed file.

    Args:
        path: Path to a gzip file whose lines are Python literal expressions
            (the callers in this file index the records like dicts).

    Yields:
        The object obtained by evaluating each line.
    """
    # NOTE(security): eval() executes whatever expression the line contains.
    # Only use this on trusted dumps; ast.literal_eval is the safe choice if
    # the data ever comes from an untrusted source.
    # The with-block closes the gzip handle once the generator is exhausted
    # or garbage-collected; previously it was never closed.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)
13 |
14 | exception_chars = ['|', '/', '\\', '-', '(', ')', '!', ':', ';', '<', '>']
15 |
def preprocess(text):
    """Lower-case ``text`` and return it as space-separated NLTK tokens.

    The characters ``|``, ``/`` and ``\\`` are replaced by spaces first.  The
    text is then split into sentences and each sentence into word tokens;
    every sentence's tokens are joined by single spaces and followed by one
    trailing space.
    """
    for separator in ('|', '/', '\\'):
        text = text.replace(separator, ' ')
    text = text.lower()
    tokenized_sentences = [' '.join(nltk.word_tokenize(sent)) + ' '
                           for sent in nltk.sent_tokenize(text)]
    return ''.join(tokenized_sentences)
26 |
def _write_split(rows, asin_fname, context_fname, ques_fname, ans_fname):
    """Write (asin, context, question, answer) rows to four line-aligned files."""
    with open(asin_fname, 'w') as asin_file, \
            open(context_fname, 'w') as context_file, \
            open(ques_fname, 'w') as ques_file, \
            open(ans_fname, 'w') as ans_file:
        for asin, context, question, answer in rows:
            asin_file.write(asin + '\n')
            context_file.write(context + '\n')
            ques_file.write(question + '\n')
            ans_file.write(answer + '\n')


def main(args):
    """Build train/tune/test product-question-answer files from Amazon dumps.

    Products that have both a title and a description become the context
    ``'<title> . <description>'``.  Every QA record whose product is known and
    whose preprocessed answer is non-empty yields one example.  Examples are
    split in input order (no shuffling): the first 80% go to train, the next
    10% to tune and the rest to test.

    Args:
        args: Parsed command-line namespace providing ``metadata_fname``,
            ``qa_data_fname`` and the twelve
            ``{train,tune,test}_{asin,context,ques,ans}_fname`` output paths.
    """
    products = {}
    for v in parse(args.metadata_fname):
        if 'description' not in v or 'title' not in v:
            continue
        title = preprocess(v['title'])
        description = preprocess(v['description'])
        products[v['asin']] = title + ' . ' + description

    rows = []
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        if asin not in products or 'answer' not in v:
            continue
        question = preprocess(v['question'])
        answer = preprocess(v['answer'])
        if not answer:
            continue
        rows.append((asin, products[asin], question, answer))

    total = len(rows)
    train_end = int(total * 0.8)
    tune_end = int(total * 0.9)
    # Each split is written inside a with-block so the handles are closed
    # even if a write fails part-way through.
    _write_split(rows[:train_end], args.train_asin_fname, args.train_context_fname,
                 args.train_ques_fname, args.train_ans_fname)
    _write_split(rows[train_end:tune_end], args.tune_asin_fname, args.tune_context_fname,
                 args.tune_ques_fname, args.tune_ans_fname)
    _write_split(rows[tune_end:], args.test_asin_fname, args.test_context_fname,
                 args.test_ques_fname, args.test_ans_fname)
96 |
97 |
if __name__ == "__main__":
    # Command-line entry point: every option is a file path.
    argparser = argparse.ArgumentParser(sys.argv[0])
    for option in ("qa_data_fname", "metadata_fname",
                   "train_asin_fname", "train_context_fname",
                   "train_ques_fname", "train_ans_fname",
                   "tune_asin_fname", "tune_context_fname",
                   "tune_ques_fname", "tune_ans_fname",
                   "test_asin_fname", "test_context_fname",
                   "test_ques_fname", "test_ans_fname"):
        argparser.add_argument("--" + option, type=str)
    args = argparser.parse_args()
    # Single-argument parenthesised print behaves the same on Python 2 and 3;
    # the bare `print args` statement is a syntax error on Python 3.
    print(args)
    print("")
    main(args)
118 |
119 |
--------------------------------------------------------------------------------
/src/data_generation/create_amazon_pqa_data_from_asins.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gzip
3 | import nltk
4 | import pdb
5 | import sys, os
6 | import re
7 | from collections import defaultdict
8 |
9 |
def parse(path):
    """Yield one parsed record per line of a gzip-compressed file.

    Args:
        path: Path to a gzip file whose lines are Python literal expressions
            (the callers in this file index the records like dicts).

    Yields:
        The object obtained by evaluating each line.
    """
    # NOTE(security): eval() runs arbitrary expressions found in the file;
    # keep this restricted to trusted dumps (ast.literal_eval is the safe
    # alternative for untrusted input).
    # The context manager closes the gzip handle, which was never closed
    # before.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)
14 |
15 |
16 | exception_chars = ['|', '/', '\\', '-', '(', ')', '!', ':', ';', '<', '>']
17 |
18 |
def preprocess(text):
    """Normalise ``text`` into lower-cased, space-separated NLTK tokens.

    ``|``, ``/`` and ``\\`` are turned into spaces before tokenisation.  Each
    sentence found by ``nltk.sent_tokenize`` contributes its word tokens
    joined by spaces, followed by a single trailing space.
    """
    for separator in ('|', '/', '\\'):
        text = text.replace(separator, ' ')
    text = text.lower()
    pieces = []
    for sent in nltk.sent_tokenize(text):
        pieces.append(' '.join(nltk.word_tokenize(sent)) + ' ')
    return ''.join(pieces)
29 |
30 |
def _read_lines(fname):
    """Return the lines of ``fname`` with their trailing newline removed."""
    with open(fname, 'r') as handle:
        return [line.strip('\n') for line in handle]


def main(args):
    """Regenerate answer and tune files for an existing train/test ASIN split.

    Reads the ASIN lists of an existing train and test split, writes the
    answer for each of those ASINs, and writes every remaining QA example to
    the tune files.

    NOTE(review): questions and answers are stored per ASIN, so only the last
    QA pair seen for a product is kept, while the tune loop iterates over one
    entry per QA record -- a product with several QA records therefore
    produces repeated identical tune lines.  This mirrors the original
    behaviour; confirm it is intended.

    Args:
        args: Parsed command-line namespace providing ``metadata_fname``,
            ``qa_data_fname``, the input ``train_asin_fname`` /
            ``test_asin_fname`` and the output ``train_ans_fname``,
            ``tune_{asin,context,ques,ans}_fname`` and ``test_ans_fname``.

    Raises:
        KeyError: If a train or test ASIN has no usable answer in the QA data.
    """
    products = {}
    for v in parse(args.metadata_fname):
        if 'description' not in v or 'title' not in v:
            continue
        title = preprocess(v['title'])
        description = preprocess(v['description'])
        products[v['asin']] = title + ' . ' + description

    train_asins = _read_lines(args.train_asin_fname)
    test_asins = _read_lines(args.test_asin_fname)
    # Set membership is O(1); the original tested against the two lists for
    # every QA record.
    held_out_asins = set(train_asins) | set(test_asins)

    asins = []
    contexts = {}
    questions = {}
    answers = {}
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        if asin not in products or 'answer' not in v:
            continue
        question = preprocess(v['question'])
        answer = preprocess(v['answer'])
        if not answer:
            continue
        asins.append(asin)
        contexts[asin] = products[asin]
        questions[asin] = question
        answers[asin] = answer

    with open(args.train_ans_fname, 'w') as train_ans_file:
        for asin in train_asins:
            train_ans_file.write(answers[asin] + '\n')

    with open(args.tune_asin_fname, 'w') as tune_asin_file, \
            open(args.tune_context_fname, 'w') as tune_context_file, \
            open(args.tune_ques_fname, 'w') as tune_ques_file, \
            open(args.tune_ans_fname, 'w') as tune_ans_file:
        for asin in asins:
            if asin in held_out_asins:
                continue
            tune_asin_file.write(asin + '\n')
            tune_context_file.write(contexts[asin] + '\n')
            tune_ques_file.write(questions[asin] + '\n')
            tune_ans_file.write(answers[asin] + '\n')

    with open(args.test_ans_fname, 'w') as test_ans_file:
        for asin in test_asins:
            test_ans_file.write(answers[asin] + '\n')
96 |
97 |
if __name__ == "__main__":
    # Command-line entry point: every option is a file path.
    argparser = argparse.ArgumentParser(sys.argv[0])
    for option in ("qa_data_fname", "metadata_fname",
                   "train_asin_fname", "train_ans_fname",
                   "tune_asin_fname", "tune_context_fname",
                   "tune_ques_fname", "tune_ans_fname",
                   "test_asin_fname", "test_ans_fname"):
        argparser.add_argument("--" + option, type=str)
    args = argparser.parse_args()
    # Parenthesised single-argument print is valid on both Python 2 and 3.
    print(args)
    print("")
    main(args)
114 |
115 |
--------------------------------------------------------------------------------
/src/data_generation/create_lucene_amazon.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gzip
3 | import nltk
4 | import pdb
5 | import sys, os
6 | import re
7 | from collections import defaultdict
8 |
def parse(path):
    """Yield one parsed record per line of a gzip-compressed file.

    Args:
        path: Path to a gzip file whose lines are Python literal expressions
            (the callers in this file index the records like dicts).

    Yields:
        The object obtained by evaluating each line.
    """
    # NOTE(security): eval() will execute any expression present in the
    # file; use only on trusted data (prefer ast.literal_eval otherwise).
    # Opening through a with-block ensures the gzip handle is closed.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)
13 |
14 | exception_chars = ['|', '/', '\\', '-', '(', ')', '!', ':', ';', '<', '>']
15 |
def preprocess(text):
    """Return ``text`` lower-cased and re-joined as NLTK word tokens.

    The separators ``|``, ``/`` and ``\\`` become spaces.  The result is the
    concatenation, over all detected sentences, of that sentence's tokens
    joined by spaces plus one trailing space.
    """
    cleaned = text.replace('|', ' ').replace('/', ' ').replace('\\', ' ').lower()
    return ''.join(
        ' '.join(nltk.word_tokenize(sentence)) + ' '
        for sentence in nltk.sent_tokenize(cleaned)
    )
26 |
def main(args):
    """Write one text document per product and per question.

    Only products that have a title, a description and a brand, and whose
    brand occurs fewer than 100 times in the metadata, are kept.  For each
    kept product that has at least one QA record, ``<asin>.txt`` is written to
    ``args.product_dir``; each of its questions is written to
    ``<asin>_<k>.txt`` in ``args.question_dir``, where ``k`` counts that
    product's questions from 1.

    Args:
        args: Parsed command-line namespace providing ``metadata_fname``,
            ``qa_data_fname``, ``product_dir`` and ``question_dir``.
    """
    brand_counts = defaultdict(int)
    for v in parse(args.metadata_fname):
        if 'description' not in v or 'title' not in v or 'brand' not in v:
            continue
        brand_counts[v['brand']] += 1
    # A set gives O(1) membership tests in the second pass; .items() (instead
    # of the Python-2-only .iteritems()) works on both Python 2 and 3.
    low_ct_brands = {brand for brand, ct in brand_counts.items() if ct < 100}

    products = {}
    for v in parse(args.metadata_fname):
        if 'description' not in v or 'title' not in v or 'brand' not in v:
            continue
        if v['brand'] not in low_ct_brands:
            continue
        title = preprocess(v['title'])
        description = preprocess(v['description'])
        products[v['asin']] = title + ' . ' + description

    question_list = {}
    print('Creating docs')
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        if asin not in products:
            continue
        print(asin)
        if asin not in question_list:
            # First question for this product: emit the product document once.
            with open(os.path.join(args.product_dir, asin + '.txt'), 'w') as f:
                f.write(products[asin])
            question_list[asin] = []
        question = preprocess(v['question'])
        question_list[asin].append(question)
        question_fname = asin + '_' + str(len(question_list[asin])) + '.txt'
        with open(os.path.join(args.question_dir, question_fname), 'w') as f:
            f.write(question)
66 |
if __name__ == "__main__":
    # Command-line entry point: two input dumps and two output directories.
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--qa_data_fname", type=str)
    argparser.add_argument("--metadata_fname", type=str)
    argparser.add_argument("--product_dir", type=str)
    argparser.add_argument("--question_dir", type=str)
    args = argparser.parse_args()
    # Parenthesised single-argument print is valid on both Python 2 and 3.
    print(args)
    print("")
    main(args)
77 |
78 |
--------------------------------------------------------------------------------
/src/data_generation/create_pqa_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | from collections import defaultdict
4 | import csv
5 | import math
6 | import nltk
7 |
8 |
9 | MAX_POST_LEN = 100
10 | MAX_QUES_LEN = 20
11 | MAX_ANS_LEN = 20
12 |
13 |
def write_to_file(ids, args, posts, question_candidates, answer_candidates, split):
    """Write the context/question/answer files for one data split.

    For every id in ``ids`` one line is written to each of the three output
    files: the post text, the first question candidate and the first answer
    candidate.

    Args:
        ids: Post ids belonging to this split, in output order.
        args: Namespace providing ``<split>_context_fname``,
            ``<split>_question_fname`` and ``<split>_answer_fname``.
        posts: Mapping from post id to post text.
        question_candidates: Mapping from post id to a list of questions.
        answer_candidates: Mapping from post id to a list of answers.
        split: One of ``'train'``, ``'tune'`` or ``'test'``.

    Raises:
        ValueError: If ``split`` is not one of the three supported names
            (this used to surface as an ``UnboundLocalError``).
    """
    if split not in ('train', 'tune', 'test'):
        raise ValueError("split must be 'train', 'tune' or 'test', got %r" % (split,))

    context_fname = getattr(args, split + '_context_fname')
    question_fname = getattr(args, split + '_question_fname')
    answer_fname = getattr(args, split + '_answer_fname')

    # with-block guarantees the three handles are closed even when a lookup
    # or write fails part-way through.
    with open(context_fname, 'w') as context_file, \
            open(question_fname, 'w') as question_file, \
            open(answer_fname, 'w') as answer_file:
        for post_id in ids:
            context_file.write(posts[post_id] + '\n')
            question_file.write(question_candidates[post_id][0] + '\n')
            answer_file.write(answer_candidates[post_id][0] + '\n')
37 |
38 |
def trim_by_len(s, max_len):
    """Lower-case ``s`` and keep at most its first ``max_len`` words.

    Words are whitespace-separated; the kept words are re-joined with single
    spaces.
    """
    tokens = s.lower().strip().split()
    return ' '.join(tokens[:max_len])
44 |
45 |
def trim_by_tfidf(posts, p_tf, p_idf, min_tfidf=None):
    """Keep only the words of each post whose TF-IDF reaches a threshold.

    A word's term frequency is its count within the post.  Words are kept in
    their original order (repeats included) and collection stops once
    ``MAX_POST_LEN`` words have been kept.

    Args:
        posts: Mapping from post id to post text; updated in place.
        p_tf: Unused; kept so existing callers keep working.
        p_idf: Mapping from word to its inverse document frequency.
        min_tfidf: Minimum ``tf * idf`` a word needs to be kept.  When
            omitted, the module-level ``MIN_TFIDF`` is used.
            NOTE(review): no ``MIN_TFIDF`` definition is visible in this
            file, so calling without ``min_tfidf`` raises ``NameError``
            unless that constant is defined elsewhere -- confirm.

    Returns:
        The same ``posts`` mapping, with every value replaced by its trimmed
        text.
    """
    if min_tfidf is None:
        min_tfidf = MIN_TFIDF
    for post_id in posts:
        words = posts[post_id].split()
        # Count each word once per post rather than calling words.count(w)
        # for every token, which was quadratic in the post length.
        term_counts = defaultdict(int)
        for w in words:
            term_counts[w] += 1
        kept = []
        for w in words:
            if term_counts[w] * p_idf[w] >= min_tfidf:
                kept.append(w)
                if len(kept) >= MAX_POST_LEN:
                    break
        posts[post_id] = ' '.join(kept)
    return posts
58 |
59 |
def _read_ids(fname):
    """Return the lines of ``fname`` with their trailing newline removed."""
    with open(fname, 'r') as id_file:
        return [line.strip('\n') for line in id_file]


def read_data(args):
    """Read the post and QA TSV files and write the train/tune/test files.

    The first row of each TSV file is treated as a header and skipped.  Post
    rows are ``(post_id, title, post)``; the context is the lower-cased,
    stripped ``'<title> <post>'``.  In QA rows, column 0 is the post id,
    columns 1-10 are question candidates and columns 11-20 are answer
    candidates.  The ids in the three id files select which posts go to which
    split; the files themselves are written by ``write_to_file``.

    Args:
        args: Parsed command-line namespace providing ``post_data_tsvfile``,
            ``qa_data_tsvfile``, the three ``*_ids_file`` paths and the
            output paths consumed by ``write_to_file``.
    """
    print("Reading lines...")
    posts = {}
    question_candidates = {}
    answer_candidates = {}
    # Term / document frequencies are only needed by the TF-IDF trimming
    # step, which is currently disabled (see the comment further down).
    p_tf = defaultdict(int)
    p_idf = defaultdict(int)
    num_posts = 0
    # 'rb' is the mode the Python 2 csv module expects.
    with open(args.post_data_tsvfile, 'rb') as tsvfile:
        post_reader = csv.reader(tsvfile, delimiter='\t')
        next(post_reader, None)  # skip the header row
        for row in post_reader:
            # Count data rows only; the header used to be counted as a
            # document, inflating the IDF numerator by one.
            num_posts += 1
            post_id, title, post = row
            post = (title + ' ' + post).lower().strip()
            words = post.split()
            for w in words:
                p_tf[w] += 1
            for w in set(words):
                p_idf[w] += 1
            posts[post_id] = post

    for w in p_idf:
        p_idf[w] = math.log(num_posts * 1.0 / p_idf[w])

    # Length / TF-IDF trimming of posts is intentionally disabled:
    # posts = trim_by_tfidf(posts, p_tf, p_idf)

    with open(args.qa_data_tsvfile, 'rb') as tsvfile:
        qa_reader = csv.reader(tsvfile, delimiter='\t')
        next(qa_reader, None)  # skip the header row
        for row in qa_reader:
            post_id = row[0]
            question_candidates[post_id] = row[1:11]
            answer_candidates[post_id] = row[11:21]

    # The id files are read through a helper that closes them; they were
    # previously opened inline and never closed.
    train_ids = _read_ids(args.train_ids_file)
    tune_ids = _read_ids(args.tune_ids_file)
    test_ids = _read_ids(args.test_ids_file)

    write_to_file(train_ids, args, posts, question_candidates, answer_candidates, 'train')
    write_to_file(tune_ids, args, posts, question_candidates, answer_candidates, 'tune')
    write_to_file(test_ids, args, posts, question_candidates, answer_candidates, 'test')
112 |
113 |
if __name__ == "__main__":
    # Command-line entry point: every option is a file path.
    argparser = argparse.ArgumentParser(sys.argv[0])
    for option in ("post_data_tsvfile", "qa_data_tsvfile",
                   "train_ids_file", "train_context_fname",
                   "train_question_fname", "train_answer_fname",
                   "tune_ids_file", "tune_context_fname",
                   "tune_question_fname", "tune_answer_fname",
                   "test_ids_file", "test_context_fname",
                   "test_question_fname", "test_answer_fname"):
        argparser.add_argument("--" + option, type=str)
    args = argparser.parse_args()
    # Parenthesised single-argument print is valid on both Python 2 and 3.
    print(args)
    print("")
    read_data(args)
134 |
--------------------------------------------------------------------------------
/src/data_generation/filter_amazon_data_byid.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | from collections import defaultdict
4 |
5 |
def main(args):
    """Select the training answers that match a list of candidate-question ids.

    Each line of ``args.train_ids`` is turned into a unique key
    ``'<id>_<k>'``, where ``k`` is the running count of that id (starting at
    1).  The i-th line of ``args.train_answer`` is the answer for the i-th
    key.  For every key listed in ``args.train_candqs_ids`` the matching
    answer is written, one per line, to ``args.train_answer_candqs``.

    Args:
        args: Parsed command-line namespace providing ``train_ids``,
            ``train_answer``, ``train_candqs_ids`` (inputs) and
            ``train_answer_candqs`` (output).

    Raises:
        IndexError: If the answer file has more lines than the id file.
        KeyError: If a candidate-question id has no matching answer.
    """
    uniq_id_ct = defaultdict(int)
    train_ids = []
    with open(args.train_ids, 'r') as train_ids_file:
        for line in train_ids_file:
            curr_id = line.strip('\n')
            uniq_id_ct[curr_id] += 1
            train_ids.append(curr_id + '_' + str(uniq_id_ct[curr_id]))

    train_answers = {}
    with open(args.train_answer, 'r') as train_answer_file:
        for i, line in enumerate(train_answer_file):
            train_answers[train_ids[i]] = line.strip('\n')

    with open(args.train_candqs_ids, 'r') as train_candqs_ids_file, \
            open(args.train_answer_candqs, 'w') as train_answer_candqs_file:
        for line in train_candqs_ids_file:
            curr_id = line.strip('\n')
            if curr_id not in train_answers:
                # Fail loudly instead of dropping into pdb from a bare
                # except: skipping the id would silently misalign the output
                # with the candidate-question file.
                raise KeyError('no training answer found for id %r' % (curr_id,))
            # The stored answer had its newline stripped, so add it back;
            # without it all answers ended up concatenated on a single line.
            train_answer_candqs_file.write(train_answers[curr_id] + '\n')
29 |
30 |
if __name__ == '__main__':
    # Command-line entry point: three input files and one output file.
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--train_ids", type=str)
    argparser.add_argument("--train_answer", type=str)
    argparser.add_argument("--train_candqs_ids", type=str)
    argparser.add_argument("--train_answer_candqs", type=str)
    args = argparser.parse_args()
    # Parenthesised single-argument print is valid on both Python 2 and 3.
    print(args)
    print("")
    main(args)
41 |
--------------------------------------------------------------------------------
/src/data_generation/filter_output_perid.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | from collections import defaultdict
4 |
5 |
def main(args):
    """Keep one output line per run of identical ids and trim to whole batches.

    The id file gives, line by line, the id each output line belongs to.  When
    ``args.max_per_id`` is set, only the first ``max_per_id`` occurrences of
    each id are kept in that id sequence.  The output file is then walked in
    parallel with the id sequence and a line is kept whenever its id differs
    from the previously kept one.  Finally only the first
    ``(kept // batch_size - 1) * batch_size`` kept lines are written.

    Args:
        args: Parsed command-line namespace providing ``test_output``,
            ``test_ids`` (inputs), ``test_output_perid`` (output),
            ``max_per_id`` (int or ``None``) and ``batch_size`` (int).

    Raises:
        IndexError: If the output file has more lines than the id sequence.
    """
    id_seq_in_output = []
    data_ct_per_id = defaultdict(int)
    with open(args.test_ids, 'r') as test_ids_file:
        for line in test_ids_file:
            curr_id = line.strip('\n')
            data_ct_per_id[curr_id] += 1
            if args.max_per_id is None or data_ct_per_id[curr_id] <= args.max_per_id:
                id_seq_in_output.append(curr_id)

    ongoing_id = None
    test_output_perid = []
    with open(args.test_output, 'r') as test_output_file:
        for i, line in enumerate(test_output_file):
            if id_seq_in_output[i] != ongoing_id:
                test_output_perid.append(line)
                ongoing_id = id_seq_in_output[i]

    # Floor division keeps the slice bound an int on Python 3 as well; plain
    # `/` only truncated on Python 2.
    # NOTE(review): the "- 1" drops one full batch in addition to the partial
    # one; kept as in the original -- confirm it is intended.
    total_count = (len(test_output_perid) // args.batch_size - 1) * args.batch_size
    print(total_count)
    print(len(test_output_perid))
    with open(args.test_output_perid, 'w') as test_output_perid_file:
        for line in test_output_perid[:total_count]:
            test_output_perid_file.write(line)
36 |
37 |
if __name__ == '__main__':
    # Command-line entry point.
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--test_output", type=str)
    argparser.add_argument("--test_ids", type=str)
    argparser.add_argument("--test_output_perid", type=str)
    # Optional cap on how many occurrences of each id are considered.
    argparser.add_argument("--max_per_id", type=int, default=None)
    argparser.add_argument("--batch_size", type=int)
    args = argparser.parse_args()
    # Parenthesised single-argument print is valid on both Python 2 and 3.
    print(args)
    print("")
    main(args)
--------------------------------------------------------------------------------
/src/data_generation/remove_unks.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/clarification_question_generation_pytorch/23ae8aa0160eee70565751f4b6de13563a19d6ed/src/data_generation/remove_unks.py
--------------------------------------------------------------------------------
/src/data_generation/run_create_amazon_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=candqs_Home_and_Kitchen
#SBATCH --output=candqs_Home_and_Kitchen
#SBATCH --qos=batch
#SBATCH --mem=4g
#SBATCH --time=4:00:00

# Builds the train/tune/test source and target files for one Amazon category
# by running create_amazon_data.py on the Lucene product/question documents.

SITENAME=Home_and_Kitchen
METADATA_DIR=/fs/clip-corpora/amazon_qa
#DATA_DIR=/fs/clip-corpora/amazon_qa/$SITENAME
DATA_DIR=/fs/clip-scratch/raosudha/amazon_qa/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation/src-opennmt
CQ_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/data/amazon/$SITENAME

# -p creates missing parent directories and does not fail when the directory
# already exists (e.g. on a re-run).
mkdir -p "$CQ_DATA_DIR"

#mkdir /fs/clip-corpora/amazon_qa/$SITENAME
#mkdir /fs/clip-amr/clarification_question_generation/data/amazon/$SITENAME

# The command ends on the --candqs line. The alternative flags are listed
# after it rather than behind a trailing backslash, so the command can never
# run into a commented-out continuation line.
python "$SCRIPT_DIR/create_amazon_data.py" --prod_dir "$DATA_DIR/prod_docs" \
    --ques_dir "$DATA_DIR/ques_docs" \
    --metadata_fname "$METADATA_DIR/meta_${SITENAME}.json.gz" \
    --sim_prod_fname "$DATA_DIR/lucene_similar_prods.txt" \
    --sim_ques_fname "$DATA_DIR/lucene_similar_ques.txt" \
    --train_src_fname "$CQ_DATA_DIR/train_src" \
    --train_tgt_fname "$CQ_DATA_DIR/train_tgt" \
    --tune_src_fname "$CQ_DATA_DIR/tune_src" \
    --tune_tgt_fname "$CQ_DATA_DIR/tune_tgt" \
    --test_src_fname "$CQ_DATA_DIR/test_src" \
    --test_tgt_fname "$CQ_DATA_DIR/test_tgt" \
    --train_ids_file "$CQ_DATA_DIR/train_ids" \
    --tune_ids_file "$CQ_DATA_DIR/tune_ids" \
    --test_ids_file "$CQ_DATA_DIR/test_ids" \
    --candqs True

# Alternative mode flags that were toggled in place of --candqs:
#   --onlycontext True
#   --simqs True
#   --template True
#   --nocontext True
40 |
--------------------------------------------------------------------------------
/src/data_generation/run_create_amazon_pqa_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=pqa_data_Home_and_Kitchen
#SBATCH --output=pqa_data_Home_and_Kitchen
#SBATCH --qos=batch
#SBATCH --mem=36g
#SBATCH --time=24:00:00

# Creates the train/tune/test asin, context, question and answer files for
# one Amazon category with create_amazon_pqa_data.py.

SITENAME=Home_and_Kitchen
DATA_DIR=/fs/clip-corpora/amazon_qa
CQ_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/$SITENAME
# NOTE(review): in this repository the Python script lives under
# src/data_generation/; confirm SCRIPT_DIR matches the deployed layout.
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

# All expansions are quoted and the last argument no longer ends in a
# dangling line-continuation backslash.
python "$SCRIPT_DIR/create_amazon_pqa_data.py" --qa_data_fname "$DATA_DIR/qa_${SITENAME}.json.gz" \
    --metadata_fname "$DATA_DIR/meta_${SITENAME}.json.gz" \
    --train_asin_fname "$CQ_DATA_DIR/train_asin.txt" \
    --train_context_fname "$CQ_DATA_DIR/train_context.txt" \
    --train_ques_fname "$CQ_DATA_DIR/train_ques.txt" \
    --train_ans_fname "$CQ_DATA_DIR/train_ans.txt" \
    --tune_asin_fname "$CQ_DATA_DIR/tune_asin.txt" \
    --tune_context_fname "$CQ_DATA_DIR/tune_context.txt" \
    --tune_ques_fname "$CQ_DATA_DIR/tune_ques.txt" \
    --tune_ans_fname "$CQ_DATA_DIR/tune_ans.txt" \
    --test_asin_fname "$CQ_DATA_DIR/test_asin.txt" \
    --test_context_fname "$CQ_DATA_DIR/test_context.txt" \
    --test_ques_fname "$CQ_DATA_DIR/test_ques.txt" \
    --test_ans_fname "$CQ_DATA_DIR/test_ans.txt"

30 |
31 |
--------------------------------------------------------------------------------
/src/data_generation/run_create_amazon_pqa_data_from_asins.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=pqa_data_Home_and_Kitchen
#SBATCH --output=pqa_data_Home_and_Kitchen
#SBATCH --qos=batch
#SBATCH --mem=36g
#SBATCH --time=24:00:00

# Regenerates the answer files for an existing train/test ASIN split and the
# tune files with create_amazon_pqa_data_from_asins.py.

SITENAME=Home_and_Kitchen
DATA_DIR=/fs/clip-corpora/amazon_qa
CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME
# NOTE(review): in this repository the Python script lives under
# src/data_generation/; confirm SCRIPT_DIR matches the deployed layout.
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

# train_asin.txt and test_asin.txt are inputs (the Python script opens them
# for reading); every other file is written. Expansions are quoted and the
# last argument no longer ends in a dangling backslash.
python "$SCRIPT_DIR/create_amazon_pqa_data_from_asins.py" --qa_data_fname "$DATA_DIR/qa_${SITENAME}.json.gz" \
    --metadata_fname "$DATA_DIR/meta_${SITENAME}.json.gz" \
    --train_asin_fname "$CQ_DATA_DIR/train_asin.txt" \
    --train_ans_fname "$CQ_DATA_DIR/train_ans.txt" \
    --tune_asin_fname "$CQ_DATA_DIR/tune_asin.txt" \
    --tune_context_fname "$CQ_DATA_DIR/tune_context.txt" \
    --tune_ques_fname "$CQ_DATA_DIR/tune_ques.txt" \
    --tune_ans_fname "$CQ_DATA_DIR/tune_ans.txt" \
    --test_asin_fname "$CQ_DATA_DIR/test_asin.txt" \
    --test_ans_fname "$CQ_DATA_DIR/test_ans.txt"

27 |
--------------------------------------------------------------------------------
/src/data_generation/run_create_lucene_amazon.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=lucene_data_Home_and_Kitchen
#SBATCH --output=lucene_data_Home_and_Kitchen
#SBATCH --qos=batch
#SBATCH --mem=36g
#SBATCH --time=24:00:00

# Recreates empty product / question document directories for one Amazon
# category and runs create_lucene_amazon.py to fill them.

SITENAME=Home_and_Kitchen
DATA_DIR=/fs/clip-corpora/amazon_qa
SCRATCH_DATA_DIR=/fs/clip-scratch/raosudha/amazon_qa
# NOTE(review): this points at a different checkout
# (clarification_question_generation, not ..._pytorch) than the sibling
# launchers - confirm it is intentional.
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation/src

PROD_DIR="$SCRATCH_DATA_DIR/$SITENAME/prod_docs/"
QUES_DIR="$SCRATCH_DATA_DIR/$SITENAME/ques_docs/"

# -f: do not fail on a first run when the directories do not exist yet.
# ${VAR:?}: abort instead of expanding to an empty path if a variable is unset.
rm -rf -- "${PROD_DIR:?}"
rm -rf -- "${QUES_DIR:?}"
mkdir -p "$PROD_DIR"
mkdir -p "$QUES_DIR"

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

python "$SCRIPT_DIR/create_lucene_amazon.py" \
    --qa_data_fname "$DATA_DIR/qa_${SITENAME}.json.gz" \
    --metadata_fname "$DATA_DIR/meta_${SITENAME}.json.gz" \
    --product_dir "$PROD_DIR" \
    --question_dir "$QUES_DIR"
25 |
--------------------------------------------------------------------------------
/src/data_generation/run_create_pqa_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=pqa_data_aus
#SBATCH --output=pqa_data_aus
#SBATCH --qos=batch
#SBATCH --mem=36g
#SBATCH --time=24:00:00

# Runs create_pqa_data.py on the StackExchange (askubuntu/unix/superuser)
# data: post and QA TSV files plus per-split id files go in, per-split
# context / question / answer paths are passed alongside.

SITENAME=askubuntu_unix_superuser

DATA_DIR=/fs/clip-amr/ranking_clarification_questions/data/$SITENAME
CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME
# create_pqa_data.py lives under src/data_generation in this repository.
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/data_generation

python "$SCRIPT_DIR/create_pqa_data.py" \
    --post_data_tsvfile "$DATA_DIR/post_data.tsv" \
    --qa_data_tsvfile "$DATA_DIR/qa_data.tsv" \
    --train_ids_file "$DATA_DIR/train_ids" \
    --tune_ids_file "$DATA_DIR/tune_ids" \
    --test_ids_file "$DATA_DIR/test_ids" \
    --train_context_fname "$CQ_DATA_DIR/train_context.txt" \
    --train_question_fname "$CQ_DATA_DIR/train_question.txt" \
    --train_answer_fname "$CQ_DATA_DIR/train_answer.txt" \
    --tune_context_fname "$CQ_DATA_DIR/tune_context.txt" \
    --tune_question_fname "$CQ_DATA_DIR/tune_question.txt" \
    --tune_answer_fname "$CQ_DATA_DIR/tune_answer.txt" \
    --test_context_fname "$CQ_DATA_DIR/test_context.txt" \
    --test_question_fname "$CQ_DATA_DIR/test_question.txt" \
    --test_answer_fname "$CQ_DATA_DIR/test_answer.txt"
29 |
--------------------------------------------------------------------------------
/src/data_generation/run_filter_amazon_data_byid.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Runs filter_amazon_data_byid.py, passing the training asin and answer
# files together with the candidate-question id file from the older data
# directory and the path for the candidate-aligned answers.

SITENAME=Home_and_Kitchen
DATA_DIR=/fs/clip-corpora/amazon_qa
OLD_CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation/data/amazon/$SITENAME
CQ_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/$SITENAME
# filter_amazon_data_byid.py lives under src/data_generation in this repository.
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/data_generation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

python "$SCRIPT_DIR/filter_amazon_data_byid.py" \
    --train_ids "$CQ_DATA_DIR/train_asin.txt" \
    --train_answer "$CQ_DATA_DIR/train_answer.txt" \
    --train_candqs_ids "$OLD_CQ_DATA_DIR/train_tgt_candqs.txt.ids" \
    --train_answer_candqs "$CQ_DATA_DIR/train_answer_candqs.txt"
15 |
16 |
--------------------------------------------------------------------------------
/src/data_generation/run_filter_output_perid.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Runs filter_output_perid.py on a predictions file, writing the result next
# to it with a ".perid" suffix.

SITENAME=Home_and_Kitchen
DATA_DIR=/fs/clip-corpora/amazon_qa
OLD_CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation/data/amazon/$SITENAME
CQ_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/$SITENAME
# filter_output_perid.py lives under src/data_generation in this repository.
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/data_generation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

# NOTE(review): --test_ids is given tune_asin.txt while the outputs are named
# test_pred_question.* - confirm the intended split.
python "$SCRIPT_DIR/filter_output_perid.py" \
    --test_ids "$CQ_DATA_DIR/tune_asin.txt" \
    --test_output "$CQ_DATA_DIR/test_pred_question.txt.pretrained_greedy" \
    --test_output_perid "$CQ_DATA_DIR/test_pred_question.txt.pretrained_greedy.perid" \
    --batch_size 128

# Alternative arguments kept for reference (not part of the command above):
# --max_per_id 3
# --test_output $CQ_DATA_DIR/test_pred_question.txt.GAN_mixer_pred_ans_3perid.epoch8.beam0
# --test_output_perid $CQ_DATA_DIR/test_pred_question.txt.GAN_mixer_pred_ans_3perid.epoch8.beam0.perid
# --test_output $CQ_DATA_DIR/test_pred_question.txt.pretrained_greedy
# --test_output_perid $CQ_DATA_DIR/test_pred_question.txt.pretrained_greedy.perid
20 |
21 |
--------------------------------------------------------------------------------
/src/decode.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle as p
3 | import sys
4 | import torch
5 |
6 | from seq2seq.encoderRNN import *
7 | from seq2seq.attnDecoderRNN import *
8 | from seq2seq.RL_beam_decoder import *
9 | # from seq2seq.diverse_beam_decoder import *
10 | from seq2seq.RL_evaluate import *
11 | # from RL_helper import *
12 | from constants import *
13 |
14 |
def main(args):
    """Load embeddings, vocabulary and test data, then decode the test set.

    Args:
        args: parsed command-line namespace; reads ``word_embeddings``,
            ``vocab``, ``test_context``, ``test_question``, ``test_ids`` and
            the ``max_*_len`` limits, and is forwarded to ``run_model``.
    """
    print('Enter main')
    # NOTE: pickle.load can execute arbitrary code; only point these options
    # at embedding / vocabulary files you created yourself.
    with open(args.word_embeddings, 'rb') as emb_file:
        word_embeddings = p.load(emb_file)
    print('Loaded emb of size %d' % len(word_embeddings))
    word_embeddings = np.array(word_embeddings)
    with open(args.vocab, 'rb') as vocab_file:
        word2index = p.load(vocab_file)
    index2word = reverse_dict(word2index)

    # args.test_ids is None when the option is omitted, which is exactly the
    # value the id-less call used to pass, so one call covers both cases.
    test_data = read_data(args.test_context, args.test_question, None, args.test_ids,
                          args.max_post_len, args.max_ques_len, args.max_ans_len, mode='test')

    print('No. of test_data %d' % len(test_data))
    run_model(test_data, word_embeddings, word2index, index2word, args)
32 |
33 |
def run_model(test_data, word_embeddings, word2index, index2word, args):
    """Build the question encoder/decoder, load their weights and decode.

    The decoding strategy is chosen from ``args.greedy``, ``args.beam`` and
    ``args.diverse_beam`` (first truthy one wins); predictions are written by
    the chosen evaluate function to ``<args.test_pred_question>.<args.model>``.
    """
    print('Preprocessing test data..')
    preprocessed = preprocess_data(test_data, word2index,
                                   args.max_post_len, args.max_ques_len, args.max_ans_len)
    # Nine values come back; the post+question and answer sequences are not
    # needed for decoding.
    (te_id_seqs, te_post_seqs, te_post_lens, te_ques_seqs, te_ques_lens,
     _te_post_ques_seqs, _te_post_ques_lens, _te_ans_seqs, _te_ans_lens) = preprocessed

    print('Defining encoder decoder models')
    q_encoder = EncoderRNN(HIDDEN_SIZE, word_embeddings, n_layers=2, dropout=DROPOUT)
    q_decoder = AttnDecoderRNN(HIDDEN_SIZE, len(word2index), word_embeddings, n_layers=2)

    device = torch.device('cuda:0' if USE_CUDA else 'cpu')
    q_encoder = q_encoder.to(device)
    q_decoder = q_decoder.to(device)

    # Load encoder, decoder params
    print('Loading encoded, decoder params')
    # On a CPU-only run the checkpoints are remapped to CPU storage.
    load_kwargs = {} if USE_CUDA else {'map_location': 'cpu'}
    q_encoder.load_state_dict(torch.load(args.q_encoder_params, **load_kwargs))
    q_decoder.load_state_dict(torch.load(args.q_decoder_params, **load_kwargs))

    out_fname = args.test_pred_question+'.'+args.model

    if args.greedy:
        decode_fn = evaluate_seq2seq
    elif args.beam:
        decode_fn = evaluate_beam
    elif args.diverse_beam:
        # NOTE(review): the seq2seq.diverse_beam_decoder import at the top of
        # the file is commented out, so this name may be undefined unless
        # another star-import provides it.
        decode_fn = evaluate_diverse_beam
    else:
        print('Please specify mode of decoding: --greedy OR --beam OR --diverse_beam')
        return

    decode_fn(word2index, index2word, q_encoder, q_decoder,
              te_id_seqs, te_post_seqs, te_post_lens, te_ques_seqs, te_ques_lens,
              args.batch_size, args.max_ques_len, out_fname)
78 |
79 |
def str2bool(value):
    """Parse a command-line flag value into a bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so
    ``--greedy False`` used to evaluate to True. Here the usual "false"
    spellings (and the empty string) map to False; anything else stays True,
    which keeps every previously-true invocation true.
    """
    return str(value).strip().lower() not in ('', 'false', 'f', '0', 'no', 'n')


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--test_context", type=str)
    argparser.add_argument("--test_question", type=str)
    argparser.add_argument("--test_answer", type=str)
    argparser.add_argument("--test_ids", type=str)
    argparser.add_argument("--test_pred_question", type=str)
    argparser.add_argument("--q_encoder_params", type=str)
    argparser.add_argument("--q_decoder_params", type=str)
    argparser.add_argument("--vocab", type=str)
    argparser.add_argument("--word_embeddings", type=str)
    argparser.add_argument("--max_post_len", type=int, default=300)
    argparser.add_argument("--max_ques_len", type=int, default=50)
    argparser.add_argument("--max_ans_len", type=int, default=50)
    argparser.add_argument("--batch_size", type=int, default=128)
    argparser.add_argument("--model", type=str)
    # nargs='?' + const=True additionally allows the bare form, e.g. `--greedy`.
    argparser.add_argument("--greedy", type=str2bool, nargs='?', const=True, default=False)
    argparser.add_argument("--beam", type=str2bool, nargs='?', const=True, default=False)
    argparser.add_argument("--diverse_beam", type=str2bool, nargs='?', const=True, default=False)
    args = argparser.parse_args()
    print(args)
    print("")
    main(args)
103 |
--------------------------------------------------------------------------------
/src/embedding_generation/create_we_vocab.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pickle as p
3 |
# Special tokens reserved at vocabulary indices 0-3. They must be four
# distinct strings: with identical values they collapse into a single dict
# key and indices 0-2 end up with no word at all.
# NOTE(review): keep these in sync with the definitions in constants.py.
PAD_token = '<PAD>'
SOS_token = '<SOS>'
EOP_token = '<EOP>'
EOS_token = '<EOS>'


def build_vocab(lines):
    """Build the embedding table and vocabulary from word-vector text lines.

    Each line is ``word v1 v2 ...`` separated by single spaces. Indices 0-3
    are reserved for the special tokens and get all-zero vectors of the same
    width as the first real word vector.

    Args:
        lines: iterable of text lines.

    Returns:
        (word_embeddings, vocab): a list of float lists and a dict mapping
        each word to its row index.

    Raises:
        ValueError: if ``lines`` contains no word vector.
    """
    special_tokens = [PAD_token, SOS_token, EOP_token, EOS_token]
    vocab = {}
    for idx, token in enumerate(special_tokens):
        vocab[token] = idx
    word_embeddings = [None] * len(special_tokens)

    for line in lines:
        vals = line.rstrip().split(' ')
        if not vals[0]:
            # Blank line: nothing to index.
            continue
        vocab[vals[0]] = len(word_embeddings)
        word_embeddings.append([float(v) for v in vals[1:]])

    if len(word_embeddings) == len(special_tokens):
        raise ValueError('no word vectors found in input')

    dim = len(word_embeddings[len(special_tokens)])
    for idx in range(len(special_tokens)):
        word_embeddings[idx] = [0] * dim
    return word_embeddings, vocab


if __name__ == "__main__":
    if len(sys.argv) < 4:
        print("usage: python create_we_vocab.py <word_vectors.txt> <word_embeddings.p> <vocab.p>")
        sys.exit(1)
    with open(sys.argv[1], 'r') as word_vectors_file:
        word_embeddings, vocab = build_vocab(word_vectors_file)

    with open(sys.argv[2], 'wb') as embeddings_out:
        p.dump(word_embeddings, embeddings_out)
    with open(sys.argv[3], 'wb') as vocab_out:
        p.dump(vocab, vocab_out)
39 |
40 |
--------------------------------------------------------------------------------
/src/embedding_generation/extract_amazon_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gzip
3 | import nltk
4 | import pdb
5 | import sys, os
6 | import re
7 | from collections import defaultdict
8 |
def parse(path):
    """Yield one evaluated record per line of a gzip-compressed file.

    The file is closed once the generator is exhausted or discarded.
    """
    with gzip.open(path, 'r') as g:
        for l in g:
            # SECURITY: eval executes whatever the line contains; only use
            # this on trusted dumps.
            yield eval(l)
13 |
14 | exception_chars = ['|', '/', '\\', '-', '(', ')', '!', ':', ';', '<', '>']
15 |
def preprocess(text):
    """Lower-case and tokenize ``text`` into one space-separated string.

    '|', '/' and '\\' are replaced by spaces first; every sentence found by
    nltk is then word-tokenized and followed by a single trailing space.
    """
    for separator in ('|', '/', '\\'):
        text = text.replace(separator, ' ')
    lowered = text.lower()
    tokenized_sentences = [' '.join(nltk.word_tokenize(sentence)) + ' '
                           for sentence in nltk.sent_tokenize(lowered)]
    return ''.join(tokenized_sentences)
25 |
def main(args):
    """Write preprocessed product and QA text, one item per line.

    Products contribute ``title . description``; QA records contribute the
    question and the answer on separate lines. Records lacking the needed
    fields are skipped.
    """
    # `with` guarantees the output is flushed and closed even if parsing fails.
    with open(args.output_fname, 'w') as output_file:
        for v in parse(args.metadata_fname):
            if 'description' not in v or 'title' not in v:
                continue
            title = preprocess(v['title'])
            description = preprocess(v['description'])
            output_file.write(title + ' . ' + description + '\n')

        for v in parse(args.qa_data_fname):
            if 'answer' not in v:
                continue
            question = preprocess(v['question'])
            answer = preprocess(v['answer'])
            output_file.write(question + '\n')
            output_file.write(answer + '\n')
44 |
if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--qa_data_fname", type=str)
    argparser.add_argument("--metadata_fname", type=str)
    argparser.add_argument("--output_fname", type=str)
    args = argparser.parse_args()
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print(args)
    print("")
    main(args)
54 |
55 |
--------------------------------------------------------------------------------
/src/embedding_generation/extract_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import sys
4 |
def main(args):
    """Append lower-cased posts, questions and answers to ``args.output_data``.

    Posts are ``title + ' ' + post`` from ``args.post_data_tsv``; questions
    and answers are the ``q1`` / ``a1`` columns of ``args.qa_data_tsv``. The
    output is opened in append mode, so re-running adds to existing content.
    """
    print("Reading lines...")
    # The csv module wants binary files on Python 2 but text files opened
    # with newline='' on Python 3.
    if sys.version_info[0] < 3:
        read_mode, open_kwargs = 'rb', {}
    else:
        read_mode, open_kwargs = 'r', {'newline': ''}

    with open(args.output_data, 'a') as output_file:
        with open(args.post_data_tsv, read_mode, **open_kwargs) as tsvfile:
            post_reader = csv.DictReader(tsvfile, delimiter='\t')
            for row in post_reader:
                post = row['title'] + ' ' + row['post']
                post = post.lower().strip()
                output_file.write(post + '\n')

        with open(args.qa_data_tsv, read_mode, **open_kwargs) as tsvfile:
            qa_reader = csv.DictReader(tsvfile, delimiter='\t')
            for row in qa_reader:
                ques = row['q1'].lower().strip()
                ans = row['a1'].lower().strip()
                output_file.write(ques + '\n')
                output_file.write(ans + '\n')
24 |
if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--post_data_tsv", type=str)
    argparser.add_argument("--qa_data_tsv", type=str)
    argparser.add_argument("--output_data", type=str)
    args = argparser.parse_args()
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print(args)
    print("")
    main(args)
34 |
35 |
--------------------------------------------------------------------------------
/src/embedding_generation/run_create_we_vocab.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SITENAME=askubuntu_unix_superuser
4 | SITENAME=Home_and_Kitchen
5 |
6 | SCRIPTS_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/embedding_generation
7 | EMB_DIR=/fs/clip-amr/clarification_question_generation_pytorch/embeddings/$SITENAME/200
8 |
9 | python $SCRIPTS_DIR/create_we_vocab.py $EMB_DIR/vectors.txt $EMB_DIR/word_embeddings.p $EMB_DIR/vocab.p
10 |
11 |
--------------------------------------------------------------------------------
/src/embedding_generation/run_extract_amazon_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=emb_data_Home_and_Kitchen
4 | #SBATCH --output=emb_data_Home_and_Kitchen
5 | #SBATCH --qos=batch
6 | #SBATCH --mem=36g
7 | #SBATCH --time=24:00:00
8 |
9 | SITENAME=Home_and_Kitchen
10 | DATA_DIR=/fs/clip-corpora/amazon_qa
11 | EMB_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/embeddings/$SITENAME
12 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/embedding_generation
13 |
14 | python $SCRIPT_DIR/extract_amazon_data.py --qa_data_fname $DATA_DIR/qa_${SITENAME}.json.gz \
15 | --metadata_fname $DATA_DIR/meta_${SITENAME}.json.gz \
16 | --output_fname $EMB_DATA_DIR/${SITENAME}_data.txt
17 |
18 |
--------------------------------------------------------------------------------
/src/embedding_generation/run_extract_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SITENAME=askubuntu.com
4 | #SITENAME=unix.stackexchange.com
5 | SITENAME=superuser.com
6 |
7 | CQ_DATA_DIR=/fs/clip-amr/ranking_clarification_questions/data/$SITENAME
8 | OUT_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/embeddings
9 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/embedding_generation
10 |
11 | python $SCRIPT_DIR/extract_data.py --post_data_tsv $CQ_DATA_DIR/post_data.tsv \
12 | --qa_data_tsv $CQ_DATA_DIR/qa_data.tsv \
13 | --output_data $OUT_DATA_DIR/${SITENAME}.data.txt
14 |
15 |
--------------------------------------------------------------------------------
/src/embedding_generation/run_glove.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Trains GloVe vectors on a pre-extracted corpus in four stages:
# vocab_count -> cooccur -> shuffle -> glove.
# CORPUS, DATADIR and BUILDDIR are relative paths, so run this from the
# directory that contains both the repository checkout and GloVe-1.2.

# Stop at the first failing stage instead of feeding later stages an empty
# or partial file.
set -e

SITENAME=askubuntu_unix_superuser
#SITENAME=Home_and_Kitchen

DATADIR=clarification_question_generation_pytorch/embeddings/$SITENAME/200_10Kvocab

CORPUS=clarification_question_generation_pytorch/embeddings/$SITENAME/${SITENAME}_data.txt
VOCAB_FILE=$DATADIR/vocab.txt
COOCCURRENCE_FILE=$DATADIR/cooccurrence.bin
COOCCURRENCE_SHUF_FILE=$DATADIR/cooccurrence.shuf.bin
BUILDDIR=GloVe-1.2/build #Download from https://nlp.stanford.edu/projects/glove/
SAVE_FILE=$DATADIR/vectors
VERBOSE=2
MEMORY=4.0
#VOCAB_MIN_COUNT=10
VOCAB_MIN_COUNT=50
#VECTOR_SIZE=100
VECTOR_SIZE=200
MAX_ITER=30
WINDOW_SIZE=15
BINARY=2
NUM_THREADS=4
X_MAX=10

# The output redirections below fail if the directory does not exist yet.
mkdir -p "$DATADIR"

"$BUILDDIR/vocab_count" -min-count $VOCAB_MIN_COUNT -verbose $VERBOSE < "$CORPUS" > "$VOCAB_FILE"
"$BUILDDIR/cooccur" -memory $MEMORY -vocab-file "$VOCAB_FILE" -verbose $VERBOSE -window-size $WINDOW_SIZE < "$CORPUS" > "$COOCCURRENCE_FILE"
"$BUILDDIR/shuffle" -memory $MEMORY -verbose $VERBOSE < "$COOCCURRENCE_FILE" > "$COOCCURRENCE_SHUF_FILE"
"$BUILDDIR/glove" -save-file "$SAVE_FILE" -eta 0.05 -threads $NUM_THREADS -input-file "$COOCCURRENCE_SHUF_FILE" -x-max $X_MAX -iter $MAX_ITER -vector-size $VECTOR_SIZE -binary $BINARY -vocab-file "$VOCAB_FILE" -verbose $VERBOSE
33 |
--------------------------------------------------------------------------------
/src/evaluation/all_ngrams.pl:
--------------------------------------------------------------------------------
#!/usr/bin/perl -w
use strict;

# Prints every N-gram of each non-blank input line, one N-gram per output
# line with words separated by single spaces.
#
# Arguments: optional boundary switches followed by the n-gram order N.
#   (default)       pad with N-1 positions before the first word and after
#                   the last one
#   -smallboundary  pad with a single position on each side (when N > 1)
#   -noboundary     emit only n-grams made entirely of real words
my $usage = "usage: cat FILE | all_ngrams.pl [-noboundary|-smallboundary] \n";

my $boundary = 2;
my $N = 1;

# Consume arguments until the numeric N is seen; dies with the usage text if
# the arguments run out or an unknown switch is given.
while (1) {
    my $tmp = shift or die $usage;
    if ($tmp eq '-noboundary') { $boundary = 0; }
    elsif ($tmp eq '-smallboundary') { $boundary = 1; }
    elsif ($tmp =~ /^[0-9]+$/) { $N = $tmp; last; }
    else { die $usage; }
}

while (<>) {
    chomp;
    if (/^[\s]*$/) { next; }    # skip blank lines
    my @w = split;
    my $M = scalar @w;          # number of words on this line

    # $lo / $hi: first start index and one past the last start index.
    my $lo = -$N+1;
    if ($boundary == 0) { $lo = 0; }
    if (($boundary == 1) && ($N>1)) { $lo = -1; }

    my $hi = $M;
    if ($boundary == 0) { $hi = $M-$N+1; }
    if (($boundary == 1) && ($N>1)) { $hi = $M-$N+2; }

    for (my $i=$lo; $i<$hi; $i++) {
        for (my $j=0; $j<$N; $j++) {
            if ($j > 0) { print ' '; }
            # Positions before the first or after the last word print the
            # boundary string instead of a word.
            # NOTE(review): both boundary strings are empty here, so padded
            # positions are indistinguishable from each other - confirm
            # whether start/end markers were intended.
            print (($i+$j<0) ? '' : (($i+$j>=$M) ? '' : $w[$i+$j]));
        }
        print "\n";
    }
}
39 |
--------------------------------------------------------------------------------
/src/evaluation/calculate_diversity.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Usage: calculate_diversity.sh FILE
# Prints the ratio of distinct to total n-grams in FILE for trigrams,
# bigrams and unigrams (in that order), four decimal places via bc.

export LANGUAGE=en_US.UTF-8
export LANG=en_US.UTF-8
export LC_ALL=en_US.UTF-8

NGRAMS_SCRIPT=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation/all_ngrams.pl

if [ $# -lt 1 ]; then
    echo "usage: $0 FILE" >&2
    exit 1
fi

# report_diversity ORDER LABEL FILE
report_diversity() {
    # `sort | uniq | wc -l` counts distinct n-grams; the total needs no sort.
    count_uniq=$( cat "$3" | "$NGRAMS_SCRIPT" "$1" | sort | uniq | wc -l )
    count_all=$( cat "$3" | "$NGRAMS_SCRIPT" "$1" | wc -l )
    echo "$2 diversity"
    if [ "$count_all" -eq 0 ]; then
        # Empty input: avoid bc's divide-by-zero error.
        echo 0
    else
        echo "scale=4; $count_uniq / $count_all" | bc
    fi
}

report_diversity 3 Trigram "$1"
report_diversity 2 Bigram "$1"
report_diversity 1 Unigram "$1"
21 |
22 |
23 |
--------------------------------------------------------------------------------
/src/evaluation/calculate_inter_annotator_agreement.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | from collections import defaultdict
4 | import sys
5 | import numpy as np
6 | import pdb
7 |
8 | on_topic_levels = {'Yes': 1, 'No': 0}
9 | is_grammatical_levels = {'Grammatical': 1, 'Comprehensible': 1, 'Incomprehensible': 0}
10 | is_specific_levels = {'Specific pretty much only to this product': 4,
11 | 'Specific to this and other very similar products (or the same product from a different manufacturer)': 3,
12 | 'Generic enough to be applicable to many other products of this type': 2,
13 | 'Generic enough to be applicable to any product under Home and Kitchen': 1,
14 | 'N/A (Not Applicable)': -1}
15 | asks_new_info_levels = {'Completely': 1, 'Somewhat': 1, 'No': 0, 'N/A (Not Applicable)': -1}
16 | useful_levels = {'Useful enough to be included in the product description': 4,
17 | 'Useful to a large number of potential buyers (or current users)': 3,
18 | 'Useful to a small number of potential buyers (or current users)': 2,
19 | 'Useful only to the person asking the question': 1,
20 | 'N/A (Not Applicable)': -1}
21 |
22 | model_dict = {'ref': 0, 'lucene': 1, 'seq2seq.beam': 2,
23 | 'rl.beam': 3,
24 | 'gan.beam': 4}
25 |
26 | model_list = ['ref', 'lucene', 'seq2seq.beam',
27 | 'rl.beam',
28 | 'gan.beam']
29 |
30 | frequent_words = ['dimensions', 'dimension', 'size', 'measurements', 'measurement',
31 | 'weight', 'height', 'width', 'diameter', 'density',
32 | 'bpa', 'difference', 'thread',
33 | 'china', 'usa']
34 |
35 |
def get_avg_score(score_dict, ignore_na=True):
    """Return the count-weighted average of the scores in ``score_dict``.

    Args:
        score_dict: mapping of score -> number of occurrences. A score of
            -1 marks "not applicable" and never adds to the numerator.
        ignore_na: when True, -1 entries are also left out of the
            denominator; when False they still count as observations.

    Returns:
        The average as a float, or 0.0 when no observation is counted.
    """
    total = 0.
    n = 0
    # .items() works on both Python 2 and 3 (iteritems is Python 2 only).
    for score, count in score_dict.items():
        is_na = score == -1
        if not is_na:
            total += score * count
        if not (ignore_na and is_na):
            n += count
    if n == 0:
        return 0.0
    return total * 1.0 / n
49 |
50 |
def _mean_confidence(values):
    """Return the arithmetic mean of ``values`` (nan when empty)."""
    if not values:
        return float('nan')
    return sum(values) / float(len(values))


def main(args):
    """Print the mean annotator-confidence per question criterion.

    Reads the aggregate results CSV at ``args.aggregate_results``. Golden
    rows, rows from models outside ``model_list`` and repeated
    (model, asin) pairs are skipped. Specificity and new-info confidences
    are only collected for on-topic, grammatical questions; usefulness only
    when the question additionally asks for new information.
    """
    on_topic_conf_scores = []
    is_grammatical_conf_scores = []
    is_specific_conf_scores = []
    asks_new_info_conf_scores = []
    useful_conf_scores = []

    # One set per model gives O(1) duplicate detection.
    asins_so_far = dict((name, set()) for name in model_list)

    with open(args.aggregate_results) as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            if row['_golden'] == 'true' or row['_unit_state'] == 'golden':
                continue
            asin = row['asin']
            model_name = row['model_name']
            if model_name not in model_list:
                continue
            if asin in asins_so_far[model_name]:
                print('%s duplicate %s' % (model_name, asin))
                continue
            asins_so_far[model_name].add(asin)

            on_topic_score = on_topic_levels[row['on_topic']]
            is_grammatical_score = is_grammatical_levels[row['grammatical']]
            asks_new_info_score = asks_new_info_levels[row['new_info']]

            on_topic_conf_scores.append(float(row['on_topic:confidence']))
            is_grammatical_conf_scores.append(float(row['grammatical:confidence']))
            is_specific_conf_score = float(row['is_specific:confidence'])
            asks_new_info_conf_score = float(row['new_info:confidence'])
            useful_conf_score = float(row['useful_to_another_buyer:confidence'])

            if on_topic_score != 0 and is_grammatical_score != 0:
                is_specific_conf_scores.append(is_specific_conf_score)
                asks_new_info_conf_scores.append(asks_new_info_conf_score)
                if asks_new_info_score != 0:
                    useful_conf_scores.append(useful_conf_score)

    print('On topic confidence: %.4f' % _mean_confidence(on_topic_conf_scores))
    print('Is grammatical confidence: %.4f' % _mean_confidence(is_grammatical_conf_scores))
    print('Asks new info confidence: %.4f' % _mean_confidence(asks_new_info_conf_scores))
    print('Useful confidence: %.4f' % _mean_confidence(useful_conf_scores))
    print('Specificity confidence: %.4f' % _mean_confidence(is_specific_conf_scores))
100 |
101 |
if __name__ == '__main__':
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--aggregate_results", type=str)
    args = argparser.parse_args()
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print(args)
    print("")
    main(args)
109 |
--------------------------------------------------------------------------------
/src/evaluation/combine_refs_for_meteor.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 |
def main(args):
    """Interleave several reference files into one combined file.

    Reads ``args.ref_prefix + str(i)`` for i in ``[0, args.no_of_refs)`` and
    writes, for each line index of the first file, that line from every
    reference file in order to ``args.combined_ref_fname``.

    Raises:
        IndexError: if a later reference file has fewer lines than the first.
    """
    ref_sents = []
    for i in range(int(args.no_of_refs)):
        with open(args.ref_prefix + str(i), 'r') as f:
            ref_sents.append([line.strip('\n') for line in f])

    with open(args.combined_ref_fname, 'w') as combined_ref_file:
        if not ref_sents:
            # No references requested: leave an empty combined file.
            return
        for i in range(len(ref_sents[0])):
            for sents in ref_sents:
                combined_ref_file.write(sents[i] + '\n')
15 |
if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--ref_prefix", type=str)
    argparser.add_argument("--no_of_refs", type=int)
    argparser.add_argument("--combined_ref_fname", type=str)
    args = argparser.parse_args()
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print(args)
    print("")
    main(args)
25 |
--------------------------------------------------------------------------------
/src/evaluation/create_amazon_multi_refs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import sys, os, pdb
4 | import nltk
5 | import time
6 | import random
7 | from collections import defaultdict
8 |
def create_refs(test_ids, question_candidates, ref_prefix):
    """Write line-aligned multi-reference files ``ref_prefix + str(i)``.

    Every test id (``<asin>_...``) contributes exactly one line to each
    reference file. Products with fewer candidate questions than the maximum
    fill the remaining files with their own candidates, picked in a random
    order.

    Args:
        test_ids: iterable of ids whose part before the first '_' is the asin.
        question_candidates: dict asin -> non-empty list of questions.
        ref_prefix: path prefix of the reference files to create.

    Raises:
        KeyError: if an asin has no entry in ``question_candidates``.
    """
    candidates_per_id = [question_candidates[ques_id.split('_')[0]]
                         for ques_id in test_ids]
    max_ref_count = 0
    for candidates in candidates_per_id:
        max_ref_count = max(max_ref_count, len(candidates))

    ref_files = [open(ref_prefix + str(i), 'w') for i in range(max_ref_count)]
    try:
        for candidates in candidates_per_id:
            N = len(candidates)
            for i, ques in enumerate(candidates):
                ref_files[i].write(ques + '\n')
            # list(...) is required on Python 3, where range cannot be shuffled.
            choices = list(range(N))
            random.shuffle(choices)
            for j in range(N, max_ref_count):
                r = choices[j % N]
                ref_files[j].write(candidates[r] + '\n')
    finally:
        # Close explicitly so every reference file is flushed to disk.
        for ref_file in ref_files:
            ref_file.close()
28 |
def main(args):
    """Collect candidate questions per asin and write the reference files.

    Each file in ``args.ques_dir`` contributes its first line as one
    candidate; the asin is the file name without its last four characters,
    up to the first '_'.
    """
    with open(args.test_ids_file, 'r') as ids_file:
        test_ids = [test_id.strip('\n') for test_id in ids_file]

    question_candidates = {}
    # Sorted so the candidate order does not depend on the filesystem's
    # directory listing order.
    for fname in sorted(os.listdir(args.ques_dir)):
        asin = fname[:-4].split('_')[0]
        with open(os.path.join(args.ques_dir, fname), 'r') as f:
            ques = f.readline().strip('\n')
        question_candidates.setdefault(asin, []).append(ques)

    create_refs(test_ids, question_candidates, args.ref_prefix)
42 |
if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--ques_dir", type=str)
    argparser.add_argument("--test_ids_file", type=str)
    argparser.add_argument("--ref_prefix", type=str)
    args = argparser.parse_args()
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print(args)
    print("")
    main(args)
52 |
53 |
--------------------------------------------------------------------------------
/src/evaluation/create_crowdflower_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gzip
3 | import nltk
4 | import pdb
5 | import sys
6 | from collections import defaultdict
7 | import csv
8 | import random
9 |
10 |
def parse(path):
    """Yield one evaluated record per line of a gzip-compressed file.

    The file is closed once the generator is exhausted or discarded.
    """
    with gzip.open(path, 'r') as g:
        for l in g:
            # SECURITY: eval executes whatever the line contains; only use
            # this on trusted dumps.
            yield eval(l)
15 |
16 |
def read_model_outputs(model_fname):
    """Map each test id to the corresponding line of a model output file.

    Ids are read from ``model_fname + '.ids'``; line i of the output file
    belongs to id i.

    Raises:
        IndexError: if the output file has more lines than the ids file.
    """
    with open(model_fname + '.ids', 'r') as ids_file:
        test_ids = [line.strip('\n') for line in ids_file]

    model_outputs = {}
    with open(model_fname, 'r') as model_file:
        for i, line in enumerate(model_file):
            # NOTE(review): as written this removes every space from the
            # line and the second replace is a no-op; it looks like the
            # patterns were meant to strip special tokens - confirm.
            model_outputs[test_ids[i]] = line.strip('\n').replace(' ', '').replace(' ', '')
    return model_outputs
26 |
27 |
def _read_batch_asins(csv_fname):
    """Return the set of values in the 'asin' column of a CSV file."""
    with open(csv_fname) as csvfile:
        return set(row['asin'] for row in csv.DictReader(csvfile))


def main(args):
    """Build a shuffled CSV of questions for human annotation.

    For up to 200 products that have a lucene output, were not part of the
    three earlier batches and have a 10-99 word description differing in
    length from the title, one row is written per system: a random
    reference question plus the lucene, seq2seq, RL and GAN outputs.
    """
    lucene_model_outs = read_model_outputs(args.lucene_model_fname)
    seq2seq_model_outs = read_model_outputs(args.seq2seq_model_fname)
    rl_model_outs = read_model_outputs(args.rl_model_fname)
    gan_model_outs = read_model_outputs(args.gan_model_fname)

    # A set keeps the membership tests below O(1); they run once per record
    # of the full metadata and QA dumps.
    prev_batch_asins = set()
    for batch_fname in (args.batch1_csv_file, args.batch2_csv_file, args.batch3_csv_file):
        prev_batch_asins |= _read_batch_asins(batch_fname)

    titles = {}
    descriptions = {}
    for v in parse(args.metadata_fname):
        asin = v['asin']
        if asin in prev_batch_asins or asin not in lucene_model_outs:
            continue
        if 'title' not in v or 'description' not in v:
            # Incomplete metadata cannot be shown to annotators.
            continue
        title = v['title']
        description = v['description']
        length = len(description.split())
        if length >= 100 or length < 10 or len(title.split()) == length:
            continue
        titles[asin] = title
        descriptions[asin] = description

    questions = defaultdict(list)
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        if asin in prev_batch_asins or asin not in lucene_model_outs:
            continue
        if asin not in descriptions:
            continue
        question = ' '.join(nltk.sent_tokenize(v['question'])).lower()
        questions[asin].append(question)

    all_rows = []
    max_count = 200
    i = 0
    for asin, lucene_question in lucene_model_outs.items():
        if asin in prev_batch_asins or asin not in titles:
            continue
        ref_candidates = questions.get(asin)
        if not ref_candidates:
            # No reference question available for this product.
            continue
        title = titles[asin]
        description = descriptions[asin]
        ref_question = random.choice(ref_candidates)
        if lucene_question == '':
            print('Found empty line in lucene')
            continue
        seq2seq_question = seq2seq_model_outs[asin]
        rl_question = rl_model_outs[asin]
        gan_question = gan_model_outs[asin]
        all_rows.append([asin, title, description, 'ref', ref_question])
        all_rows.append([asin, title, description, args.lucene_model_name, lucene_question])
        all_rows.append([asin, title, description, args.seq2seq_model_name, seq2seq_question])
        all_rows.append([asin, title, description, args.rl_model_name, rl_question])
        all_rows.append([asin, title, description, args.gan_model_name, gan_question])
        i += 1
        if i >= max_count:
            break
    random.shuffle(all_rows)

    with open(args.csv_file, 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['asin', 'title', 'description', 'model_name', 'question'])
        for row in all_rows:
            writer.writerow(row)
112 |
113 |
if __name__ == "__main__":
    # Every option of this script is a plain string; register them in the
    # order they should appear in the usage message.
    option_names = (
        "qa_data_fname",
        "metadata_fname",
        "batch1_csv_file",
        "batch2_csv_file",
        "batch3_csv_file",
        "csv_file",
        "lucene_model_fname",
        "lucene_model_name",
        "seq2seq_model_fname",
        "seq2seq_model_name",
        "rl_model_fname",
        "rl_model_name",
        "gan_model_fname",
        "gan_model_name",
    )
    argparser = argparse.ArgumentParser(sys.argv[0])
    for option_name in option_names:
        argparser.add_argument("--" + option_name, type=str)
    args = argparser.parse_args()
    # Echo the parsed options followed by an empty line, then run.
    print(args)
    print("")
    main(args)
134 |
135 |
--------------------------------------------------------------------------------
/src/evaluation/create_crowdflower_data_beam.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gzip
3 | import nltk
4 | import pdb
5 | import sys
6 | from collections import defaultdict
7 | import csv
8 | import random
9 |
10 |
# Names of the systems whose rows main() keeps.
model_list = ['ref', 'lucene', 'seq2seq.beam', 'rl.beam', 'gan.beam']
# Index of each system name within model_list.
model_dict = {name: position for position, name in enumerate(model_list)}
17 |
18 |
def parse(path):
    """Yield one evaluated record per line of the gzip file at `path`.

    Every line is handed to eval(), so each line must be a Python
    expression. The file is closed once the generator is exhausted or
    closed.
    """
    # NOTE(security): eval() executes whatever the file contains. Only use
    # this on trusted dumps; consider ast.literal_eval if the lines are
    # plain literals.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)
23 |
24 |
def main(args):
    """Copy the rows of known systems from a previous CSV, minus duplicates.

    Reads args.previous_csv_file, keeps rows whose 'model_name' is in
    model_list, and drops (with a printed notice) any row whose asin was
    already seen for the same model. The surviving rows are shuffled and
    written under a header row to args.output_csv_file.

    The input is read completely before the output is opened, so a failure
    while reading does not leave a truncated output file behind.
    """
    all_rows = []
    # model name -> asins already emitted for that model (set: O(1) lookups)
    asins_so_far = defaultdict(set)
    with open(args.previous_csv_file) as csvfile:
        for row in csv.DictReader(csvfile):
            model_name = row['model_name']
            if model_name not in model_list:
                continue
            asin = row['asin']
            if asin in asins_so_far[model_name]:
                print('Duplicate asin %s in %s' % (asin, model_name))
                continue
            asins_so_far[model_name].add(asin)
            all_rows.append([asin, row['title'], row['description'], model_name, row['question']])

    random.shuffle(all_rows)
    with open(args.output_csv_file, 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['asin', 'title', 'description', 'model_name', 'question'])
        writer.writerows(all_rows)
47 |
48 |
if __name__ == "__main__":
    # Both options are plain string paths.
    argparser = argparse.ArgumentParser(sys.argv[0])
    for option_name in ("previous_csv_file", "output_csv_file"):
        argparser.add_argument("--" + option_name, type=str)
    args = argparser.parse_args()
    # Echo the parsed options followed by an empty line, then run.
    print(args)
    print("")
    main(args)
57 |
58 |
--------------------------------------------------------------------------------
/src/evaluation/create_crowdflower_data_compare_ques.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gzip
3 | import nltk
4 | import pdb
5 | import sys
6 | from collections import defaultdict
7 | import csv
8 | import random
9 |
10 |
def parse(path):
    """Yield one evaluated record per line of the gzip file at `path`.

    Every line is handed to eval(), so each line must be a Python
    expression. The file is closed once the generator is exhausted or
    closed.
    """
    # NOTE(security): eval() executes whatever the file contains. Only use
    # this on trusted dumps; consider ast.literal_eval if the lines are
    # plain literals.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)
15 |
16 |
def _read_asins(csv_path):
    """Return the set of values found in the 'asin' column of `csv_path`."""
    with open(csv_path) as csvfile:
        return set(row['asin'] for row in csv.DictReader(csvfile))


def main(args):
    """Write a shuffled CSV of question pairs for up to 200 products.

    A product is eligible when its asin is listed in args.train_asins, does
    not appear in any of the four previous CSV files, and its description
    has at least 10 and fewer than 100 whitespace-separated tokens, with a
    token count different from the title's. For each eligible product the
    questions are shuffled and consecutive ones are paired (q0,q1),
    (q1,q2), ...; the chain is closed by pairing the last question used
    (index 6 at most) with q0.

    NOTE(review): a product with exactly one question yields a row pairing
    that question with itself, and a product with no questions yields no
    row but still counts toward the 200 limit. Confirm this is intended.
    """
    # Sets keep the per-record membership tests below O(1).
    previous_asins = set()
    for prev_csv in (args.previous_csv_file_v1, args.previous_csv_file_v2,
                     args.previous_csv_file_v3, args.previous_csv_file_v4):
        previous_asins.update(_read_asins(prev_csv))

    with open(args.train_asins, 'r') as train_file:
        train_asins = set(line.strip('\n') for line in train_file)

    titles = {}
    descriptions = {}
    for v in parse(args.metadata_fname):
        asin = v['asin']
        if asin not in train_asins or asin in previous_asins:
            continue
        title = v['title']
        description = v['description']
        length = len(description.split())
        if length >= 100 or length < 10 or len(title.split()) == length:
            continue
        titles[asin] = title
        descriptions[asin] = description

    questions = defaultdict(list)
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        # `descriptions` only holds asins that already passed the
        # train/previous filters, so this single test covers all three.
        if asin not in descriptions:
            continue
        question = ' '.join(nltk.sent_tokenize(v['question'])).lower()
        questions[asin].append(question)

    all_rows = []
    max_count = 200
    i = 0
    for asin in titles:
        title = titles[asin]
        description = descriptions[asin]
        asin_questions = questions[asin]
        random.shuffle(asin_questions)
        for k in range(len(asin_questions)):
            if k == 6 or k == len(asin_questions) - 1:
                # Close the chain: pair the last question used with the first.
                all_rows.append([asin, title, description, asin_questions[k], asin_questions[0]])
                print('%s %s' % (asin, k + 1))
                break
            all_rows.append([asin, title, description, asin_questions[k], asin_questions[k + 1]])
        i += 1
        if i >= max_count:
            break

    random.shuffle(all_rows)
    with open(args.csv_file, 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['asin', 'title', 'description', 'question_a', 'question_b'])
        writer.writerows(all_rows)
99 |
100 |
if __name__ == "__main__":
    # Every option of this script is a plain string; register them in the
    # order they should appear in the usage message.
    option_names = (
        "qa_data_fname",
        "metadata_fname",
        "csv_file",
        "train_asins",
        "previous_csv_file_v1",
        "previous_csv_file_v2",
        "previous_csv_file_v3",
        "previous_csv_file_v4",
    )
    argparser = argparse.ArgumentParser(sys.argv[0])
    for option_name in option_names:
        argparser.add_argument("--" + option_name, type=str)
    args = argparser.parse_args()
    # Echo the parsed options followed by an empty line, then run.
    print(args)
    print("")
    main(args)
115 |
116 |
--------------------------------------------------------------------------------
/src/evaluation/create_crowdflower_data_compare_ques_allpairs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gzip
3 | import nltk
4 | import pdb
5 | import sys
6 | from collections import defaultdict
7 | import csv
8 | import random
9 |
10 |
def parse(path):
    """Yield one evaluated record per line of the gzip file at `path`.

    Every line is handed to eval(), so each line must be a Python
    expression. The file is closed once the generator is exhausted or
    closed.
    """
    # NOTE(security): eval() executes whatever the file contains. Only use
    # this on trusted dumps; consider ast.literal_eval if the lines are
    # plain literals.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)
15 |
16 |
def main(args):
    """Write a shuffled CSV of all question pairs for up to 25 products.

    A product is eligible when its asin is listed in args.train_asins, does
    not appear in args.previous_csv_file, and its description has at least
    10 and fewer than 100 whitespace-separated tokens, with a token count
    different from the title's. For each eligible product the questions are
    shuffled and every unordered pair among the first seven is emitted as
    one row.

    NOTE(review): a product with fewer than two questions yields no row but
    still counts toward the 25 limit. Confirm this is intended.
    """
    # Sets keep the per-record membership tests below O(1).
    with open(args.previous_csv_file) as csvfile:
        previous_asins = set(row['asin'] for row in csv.DictReader(csvfile))

    with open(args.train_asins, 'r') as train_file:
        train_asins = set(line.strip('\n') for line in train_file)

    titles = {}
    descriptions = {}
    for v in parse(args.metadata_fname):
        asin = v['asin']
        if asin not in train_asins or asin in previous_asins:
            continue
        title = v['title']
        description = v['description']
        length = len(description.split())
        if length >= 100 or length < 10 or len(title.split()) == length:
            continue
        titles[asin] = title
        descriptions[asin] = description

    questions = defaultdict(list)
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        # `descriptions` only holds asins that already passed the
        # train/previous filters, so this single test covers all three.
        if asin not in descriptions:
            continue
        question = ' '.join(nltk.sent_tokenize(v['question'])).lower()
        questions[asin].append(question)

    all_rows = []
    max_count = 25
    i = 0
    for asin in titles:
        title = titles[asin]
        description = descriptions[asin]
        asin_questions = questions[asin]
        random.shuffle(asin_questions)
        # Only the first seven shuffled questions take part in pairing.
        top = asin_questions[:7]
        for j in range(len(top)):
            for k in range(j + 1, len(top)):
                all_rows.append([asin, title, description, top[j], top[k]])
        i += 1
        if i >= max_count:
            break

    random.shuffle(all_rows)
    with open(args.csv_file, 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['asin', 'title', 'description', 'question_a', 'question_b'])
        writer.writerows(all_rows)
82 |
83 |
if __name__ == "__main__":
    # Every option of this script is a plain string; register them in the
    # order they should appear in the usage message.
    option_names = (
        "qa_data_fname",
        "metadata_fname",
        "csv_file",
        "train_asins",
        "previous_csv_file",
    )
    argparser = argparse.ArgumentParser(sys.argv[0])
    for option_name in option_names:
        argparser.add_argument("--" + option_name, type=str)
    args = argparser.parse_args()
    # Echo the parsed options followed by an empty line, then run.
    print(args)
    print("")
    main(args)
95 |
96 |
--------------------------------------------------------------------------------
/src/evaluation/create_crowdflower_data_specificity.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gzip
3 | import nltk
4 | import pdb
5 | import sys
6 | from collections import defaultdict
7 | import csv
8 | import random
9 |
10 |
def parse(path):
    """Yield one evaluated record per line of the gzip file at `path`.

    Every line is handed to eval(), so each line must be a Python
    expression. The file is closed once the generator is exhausted or
    closed.
    """
    # NOTE(security): eval() executes whatever the file contains. Only use
    # this on trusted dumps; consider ast.literal_eval if the lines are
    # plain literals.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)
15 |
16 |
def read_model_outputs(model_fname):
    """Map each test id to the corresponding line of a model output file.

    The ids are read from `model_fname + '.ids'`, one per line; the i-th id
    is paired with the i-th line of `model_fname`. Extra ids are ignored.

    Raises:
        IndexError: if the output file has more lines than the ids file.
    """
    with open(model_fname + '.ids', 'r') as ids_file:
        test_ids = [line.strip('\n') for line in ids_file]
    model_outputs = {}
    with open(model_fname, 'r') as model_file:
        for i, line in enumerate(model_file):
            # NOTE(review): as written this removes every space from the
            # line and the second replace() is then a no-op. The literals
            # may originally have named special tokens; confirm.
            model_outputs[test_ids[i]] = line.strip('\n').replace(' ', '').replace(' ', '')
    return model_outputs
26 |
27 |
def main(args):
    """Write a shuffled CSV comparing five question sources per product.

    For up to 50 asins present in the "specific" seq2seq outputs, emits one
    row each for a randomly chosen reference question and for the lucene,
    seq2seq, seq2seq-specific and seq2seq-generic outputs. A product is
    used only if its metadata has both a title and a description, the
    description has at least 10 and fewer than 100 whitespace-separated
    tokens, and that count differs from the title's token count.

    Asins whose specific output contains a '?' token anywhere but in final
    position are skipped, as are asins without any reference question
    (previously the latter raised IndexError in random.choice).

    Raises:
        KeyError: if a selected asin is missing from the lucene, seq2seq or
            seq2seq-generic outputs.
    """
    lucene_model_outs = read_model_outputs(args.lucene_model_fname)
    seq2seq_model_outs = read_model_outputs(args.seq2seq_model_fname)
    seq2seq_specific_model_outs = read_model_outputs(args.seq2seq_specific_model_fname)
    seq2seq_generic_model_outs = read_model_outputs(args.seq2seq_generic_model_fname)

    titles = {}
    descriptions = {}
    for v in parse(args.metadata_fname):
        asin = v['asin']
        if 'description' not in v or 'title' not in v:
            continue
        title = v['title']
        description = v['description']
        length = len(description.split())
        if length >= 100 or length < 10 or len(title.split()) == length:
            continue
        titles[asin] = title
        descriptions[asin] = description

    questions = defaultdict(list)
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        if asin not in descriptions:
            continue
        question = ' '.join(nltk.sent_tokenize(v['question'])).lower()
        questions[asin].append(question)

    all_rows = []
    i = 0
    max_count = 50
    for asin in seq2seq_specific_model_outs:
        if asin not in descriptions:
            continue
        seq2seq_specific_question = seq2seq_specific_model_outs[asin]
        tokens = seq2seq_specific_question.split()
        # Skip outputs whose first '?' token is not the final token.
        if '?' in tokens and tokens.index('?') != len(tokens) - 1:
            continue
        # .get() avoids inserting an empty list into the defaultdict.
        reference_candidates = questions.get(asin)
        if not reference_candidates:
            # Without a reference question the 'ref' row cannot be built.
            continue
        title = titles[asin]
        description = descriptions[asin]
        ref_question = random.choice(reference_candidates)
        lucene_question = lucene_model_outs[asin]
        seq2seq_question = seq2seq_model_outs[asin]
        seq2seq_generic_question = seq2seq_generic_model_outs[asin]
        all_rows.append([asin, title, description, "ref", ref_question])
        all_rows.append([asin, title, description, args.lucene_model_name, lucene_question])
        all_rows.append([asin, title, description, args.seq2seq_model_name, seq2seq_question])
        all_rows.append([asin, title, description, args.seq2seq_specific_model_name, seq2seq_specific_question])
        all_rows.append([asin, title, description, args.seq2seq_generic_model_name, seq2seq_generic_question])
        i += 1
        if i >= max_count:
            break

    random.shuffle(all_rows)
    with open(args.csv_file, 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['asin', 'title', 'description', 'model_name', 'question'])
        writer.writerows(all_rows)
88 |
89 |
if __name__ == "__main__":
    # Every option of this script is a plain string; register them in the
    # order they should appear in the usage message.
    option_names = (
        "qa_data_fname",
        "metadata_fname",
        # "prev_csv_file" is disabled in this script.
        "csv_file",
        "lucene_model_name",
        "lucene_model_fname",
        "seq2seq_model_name",
        "seq2seq_model_fname",
        "seq2seq_specific_model_name",
        "seq2seq_specific_model_fname",
        "seq2seq_generic_model_name",
        "seq2seq_generic_model_fname",
    )
    argparser = argparse.ArgumentParser(sys.argv[0])
    for option_name in option_names:
        argparser.add_argument("--" + option_name, type=str)
    args = argparser.parse_args()
    # Echo the parsed options followed by an empty line, then run.
    print(args)
    print("")
    main(args)
108 |
109 |
--------------------------------------------------------------------------------
/src/evaluation/create_crowdflower_data_specificity_multi.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gzip
3 | import nltk
4 | import pdb
5 | import sys
6 | from collections import defaultdict
7 | import csv
8 | import random
9 |
10 |
def parse(path):
    """Yield one evaluated record per line of the gzip file at `path`.

    Every line is handed to eval(), so each line must be a Python
    expression. The file is closed once the generator is exhausted or
    closed.
    """
    # NOTE(security): eval() executes whatever the file contains. Only use
    # this on trusted dumps; consider ast.literal_eval if the lines are
    # plain literals.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)
15 |
16 |
def read_model_outputs(model_fname):
    """Map each test id to the corresponding line of a model output file.

    The ids are read from `model_fname + '.ids'`, one per line; the i-th id
    is paired with the i-th line of `model_fname`. Extra ids are ignored.

    Raises:
        IndexError: if the output file has more lines than the ids file.
    """
    with open(model_fname + '.ids', 'r') as ids_file:
        test_ids = [line.strip('\n') for line in ids_file]
    model_outputs = {}
    with open(model_fname, 'r') as model_file:
        for i, line in enumerate(model_file):
            # NOTE(review): as written this removes every space from the
            # line and the second replace() is then a no-op. The literals
            # may originally have named special tokens; confirm.
            model_outputs[test_ids[i]] = line.strip('\n').replace(' ', '').replace(' ', '')
    return model_outputs
26 |
27 |
def main(args):
    """Write a shuffled CSV comparing five question sources per product.

    For up to 50 asins present in the "specific" seq2seq outputs and absent
    from args.prev_csv_file, emits one row each for a randomly chosen
    reference question and for the lucene, seq2seq, seq2seq-specific and
    seq2seq-generic outputs. A product is used only if its metadata has
    both a title and a description, the description has at least 10 and
    fewer than 100 whitespace-separated tokens, and that count differs from
    the title's token count.

    Asins whose specific output has its first '?' token in final position
    are skipped, as are asins without any reference question (previously
    the latter raised IndexError in random.choice).

    Raises:
        KeyError: if a selected asin is missing from the lucene, seq2seq or
            seq2seq-generic outputs.
    """
    lucene_model_outs = read_model_outputs(args.lucene_model_fname)
    seq2seq_model_outs = read_model_outputs(args.seq2seq_model_fname)
    seq2seq_specific_model_outs = read_model_outputs(args.seq2seq_specific_model_fname)
    seq2seq_generic_model_outs = read_model_outputs(args.seq2seq_generic_model_fname)

    # A set keeps the per-record membership test below O(1).
    with open(args.prev_csv_file) as csvfile:
        prev_asins = set(row['asin'] for row in csv.DictReader(csvfile))

    titles = {}
    descriptions = {}
    for v in parse(args.metadata_fname):
        asin = v['asin']
        if asin in prev_asins:
            continue
        if 'description' not in v or 'title' not in v:
            continue
        title = v['title']
        description = v['description']
        length = len(description.split())
        if length >= 100 or length < 10 or len(title.split()) == length:
            continue
        titles[asin] = title
        descriptions[asin] = description

    questions = defaultdict(list)
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        # `descriptions` never holds an asin from prev_asins, so this
        # single test also excludes previously used products.
        if asin not in descriptions:
            continue
        question = ' '.join(nltk.sent_tokenize(v['question'])).lower()
        questions[asin].append(question)

    all_rows = []
    i = 0
    max_count = 50
    for asin in seq2seq_specific_model_outs:
        if asin not in descriptions:
            continue
        seq2seq_specific_question = seq2seq_specific_model_outs[asin]
        tokens = seq2seq_specific_question.split()
        # Skip outputs whose first '?' token is the final token.
        if '?' in tokens and tokens.index('?') == len(tokens) - 1:
            continue
        # .get() avoids inserting an empty list into the defaultdict.
        reference_candidates = questions.get(asin)
        if not reference_candidates:
            # Without a reference question the 'ref' row cannot be built.
            continue
        title = titles[asin]
        description = descriptions[asin]
        ref_question = random.choice(reference_candidates)
        lucene_question = lucene_model_outs[asin]
        seq2seq_question = seq2seq_model_outs[asin]
        seq2seq_generic_question = seq2seq_generic_model_outs[asin]
        all_rows.append([asin, title, description, "ref", ref_question])
        all_rows.append([asin, title, description, args.lucene_model_name, lucene_question])
        all_rows.append([asin, title, description, args.seq2seq_model_name, seq2seq_question])
        all_rows.append([asin, title, description, args.seq2seq_specific_model_name, seq2seq_specific_question])
        all_rows.append([asin, title, description, args.seq2seq_generic_model_name, seq2seq_generic_question])
        i += 1
        if i >= max_count:
            break

    random.shuffle(all_rows)
    with open(args.csv_file, 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['asin', 'title', 'description', 'model_name', 'question'])
        writer.writerows(all_rows)
99 |
100 |
if __name__ == "__main__":
    # Every option of this script is a plain string; register them in the
    # order they should appear in the usage message.
    option_names = (
        "qa_data_fname",
        "metadata_fname",
        "prev_csv_file",
        "csv_file",
        "lucene_model_name",
        "lucene_model_fname",
        "seq2seq_model_name",
        "seq2seq_model_fname",
        "seq2seq_specific_model_name",
        "seq2seq_specific_model_fname",
        "seq2seq_generic_model_name",
        "seq2seq_generic_model_fname",
    )
    argparser = argparse.ArgumentParser(sys.argv[0])
    for option_name in option_names:
        argparser.add_argument("--" + option_name, type=str)
    args = argparser.parse_args()
    # Echo the parsed options followed by an empty line, then run.
    print(args)
    print("")
    main(args)
119 |
120 |
--------------------------------------------------------------------------------
/src/evaluation/create_preds_for_refs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import sys, os, pdb
4 | import nltk
5 | import time
6 |
def get_annotations(line):
    """Split one comma-separated annotation record into typed fields.

    `line` must hold exactly five comma-separated fields:
    set info ("<annotator>_<site>..."), post id, best index, a
    whitespace-separated list of valid indices, and a confidence value.

    Returns:
        (post_id, annotator_name, sitename, best, valids, confidence) where
        best and confidence are ints and valids is a list of ints.
    """
    fields = line.split(',')
    set_info, post_id, best_field, valids_field, confidence_field = fields
    set_parts = set_info.split('_')
    annotator_name = set_parts[0]
    sitename = set_parts[1]
    best = int(best_field)
    valids = list(map(int, valids_field.split()))
    confidence = int(confidence_field)
    return post_id, annotator_name, sitename, best, valids, confidence
15 |
def read_human_annotations(human_annotations_filename):
    """Read paired annotations and combine them per post.

    Each line holds two tab-separated records (one per annotator) in the
    format accepted by get_annotations. Both records must name the same
    site and post id.

    Returns:
        dict mapping "<sitename>_<post_id>" to a list of the distinct
        indices that either annotator picked as best or that both
        annotators marked as valid.

    Raises:
        AssertionError: if the two records of a line disagree on the site
            or the post id.
    """
    annotations = {}
    with open(human_annotations_filename, 'r') as human_annotations_file:
        for line in human_annotations_file:
            splits = line.strip('\n').split('\t')
            # Annotator name and confidence are parsed but not used here.
            post_id1, _, sitename1, best1, valids1, _ = get_annotations(splits[0])
            post_id2, _, sitename2, best2, valids2, _ = get_annotations(splits[1])
            assert(sitename1 == sitename2)
            assert(post_id1 == post_id2)
            post_id = sitename1 + '_' + post_id1
            best_union = list(set([best1, best2]))
            valids_inter = list(set(valids1).intersection(set(valids2)))
            annotations[post_id] = list(set(best_union + valids_inter))
    return annotations
31 |
32 |
def main(args):
    """Keep only the model outputs whose post has human annotations.

    The i-th line of args.model_output_file is taken to belong to the i-th
    id in args.test_ids_file. Lines whose id is a key of the annotations
    read from args.human_annotations are written, in order, to
    args.model_output_file + '.hasrefs'.

    Raises:
        IndexError: if an annotated id sits at a position beyond the last
            line of the model output file.
    """
    with open(args.test_ids_file, 'r') as ids_file:
        test_ids = [test_id.strip('\n') for test_id in ids_file]

    # NOTE(review): the candidates table is loaded but never consulted when
    # producing the output. It is kept so the TSV is still read as before.
    question_candidates = {}
    with open(args.qa_data_tsvfile, 'rb') as tsvfile:
        qa_reader = csv.reader(tsvfile, delimiter='\t')
        next(qa_reader, None)  # skip the header row
        for row in qa_reader:
            post_id, questions = row[0], row[1:11]
            question_candidates[post_id] = questions

    annotations = read_human_annotations(args.human_annotations)

    with open(args.model_output_file, 'r') as model_output_file:
        model_outputs = [line.strip('\n') for line in model_output_file]

    # The context manager guarantees the filtered file is flushed and closed.
    with open(args.model_output_file + '.hasrefs', 'w') as pred_file:
        for i, post_id in enumerate(test_ids):
            if post_id not in annotations:
                continue
            pred_file.write(model_outputs[i] + '\n')
58 |
if __name__ == "__main__":
    # Every option of this script is a plain string; register them in the
    # order they should appear in the usage message.
    option_names = (
        "qa_data_tsvfile",
        "human_annotations",
        "model_output_file",
        "test_ids_file",
    )
    argparser = argparse.ArgumentParser(sys.argv[0])
    for option_name in option_names:
        argparser.add_argument("--" + option_name, type=str)
    args = argparser.parse_args()
    # Echo the parsed options followed by an empty line, then run.
    print(args)
    print("")
    main(args)
69 |
70 |
--------------------------------------------------------------------------------
/src/evaluation/eval_HK:
--------------------------------------------------------------------------------
1 |
# For each Home_and_Kitchen system: count the distinct lines that
# all_ngrams.pl (called with argument 3) emits for the predictions, then
# score the predictions against the references with multi-bleu.
NGRAMS=/fs/clip-ml/hal/bin/all_ngrams.pl
BLEU=/fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl
REF=Home_and_Kitchen/test_ref

# Lucene
PRED=Home_and_Kitchen/blind_test_pred_question.lucene.txt
cat "$PRED" | "$NGRAMS" 3 | sort | uniq -c | sort -gr | wc -l
"$BLEU" "$REF" < "$PRED"

# Seq2seq
PRED=Home_and_Kitchen/blind_test_pred_ques.txt.seq2seq.epoch100.beam0.nounks
cat "$PRED" | "$NGRAMS" 3 | sort | uniq -c | sort -gr | wc -l
"$BLEU" "$REF" < "$PRED"

# RL
PRED=Home_and_Kitchen/blind_test_pred_ques.txt.RL_mixer_3perid.epoch5.beam0.nounks
cat "$PRED" | "$NGRAMS" 3 | sort | uniq -c | sort -gr | wc -l
"$BLEU" "$REF" < "$PRED"

# GAN
PRED=Home_and_Kitchen/blind_test_pred_ques.txt.GAN_selfcritic_pred_ans_3perid.epoch8.beam0.nounks
cat "$PRED" | "$NGRAMS" 3 | sort | uniq -c | sort -gr | wc -l
"$BLEU" "$REF" < "$PRED"
25 |
--------------------------------------------------------------------------------
/src/evaluation/eval_HK_spec:
--------------------------------------------------------------------------------
# Distinct-line counts from all_ngrams.pl (argument 3) and multi-bleu scores
# for the Home_and_Kitchen systems against the specific/generic references.
NGRAMS=/fs/clip-ml/hal/bin/all_ngrams.pl
BLEU=/fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl
REF_DIR=../style_clarification_question_generation/specificQ_classifier/test_data

# Lucene
PRED=Home_and_Kitchen/blind_test_pred_question.lucene.txt
cat "$PRED" | "$NGRAMS" 3 | sort | uniq -c | sort -gr | wc -l
"$BLEU" "$REF_DIR/test_ref_specific" < "$PRED"

# Seq2seq model
PRED=Home_and_Kitchen/blind_test_pred_ques.txt.seq2seq.epoch100.beam0.nounks
cat "$PRED" | "$NGRAMS" 3 | sort | uniq -c | sort -gr | wc -l
"$BLEU" "$REF_DIR/test_ref_specific" < "$PRED"

# RL model
PRED=Home_and_Kitchen/blind_test_pred_ques.txt.RL_mixer_3perid.epoch5.beam0.nounks
cat "$PRED" | "$NGRAMS" 3 | sort | uniq -c | sort -gr | wc -l
"$BLEU" "$REF_DIR/test_ref_specific" < "$PRED"

# GAN model
# NOTE(review): unlike the sections above, this one has no n-gram count and
# scores against test_ref_generic rather than test_ref_specific; confirm.
PRED=Home_and_Kitchen/blind_test_pred_ques.txt.GAN_selfcritic_pred_ans_3perid.epoch8.beam0.nounks
"$BLEU" "$REF_DIR/test_ref_generic" < "$PRED"

# Specificity seq2seq model
PRED=Home_and_Kitchen/blind_test_pred_ques.txt.seq2seq_tobeginning_tospecific.epoch65.beam0.nounks
cat "$PRED" | "$NGRAMS" 3 | sort | uniq -c | sort -gr | wc -l
"$BLEU" "$REF_DIR/test_ref_specific" < "$PRED"

# Specificity GAN model
# NOTE(review): the n-gram count reads the "tospecific" predictions while
# BLEU scores the "togeneric" predictions against test_ref_generic; confirm
# that the two commands are meant to use different files.
cat Home_and_Kitchen/blind_test_pred_ques.txt.GAN_selfcritic_pred_ans_tobeginning_tospecific.epoch8.beam0.nounks | "$NGRAMS" 3 | sort | uniq -c | sort -gr | wc -l
"$BLEU" "$REF_DIR/test_ref_generic" < Home_and_Kitchen/blind_test_pred_ques.txt.GAN_selfcritic_pred_ans_tobeginning_togeneric.epoch8.beam0.nounks
34 |
--------------------------------------------------------------------------------
/src/evaluation/eval_aus:
--------------------------------------------------------------------------------
# multi-bleu scores and all_ngrams.pl distinct-line counts for the
# askubuntu_unix_superuser systems.
NGRAMS=/fs/clip-ml/hal/bin/all_ngrams.pl
BLEU=/fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl
DATA=../clarification_question_generation/data/askubuntu_unix_superuser
SCRATCH=/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/askubuntu_unix_superuser
REF=$DATA/test_ref

# Reference
"$BLEU" "$REF" < "$SCRATCH/test_question.txt.hasrefs"

# Lucene
"$BLEU" "$REF" < "$DATA/test_pred_lucene.txt.hasrefs"
cat "$DATA/test_pred_lucene.txt.hasrefs" | "$NGRAMS" 1 | sort | uniq -c | sort -gr | wc -l

# Seq2Seq
"$BLEU" "$REF" < "$SCRATCH/test_pred_question.txt.seq2seq.len_norm.beam0.hasrefs.nounks"

# RL
"$BLEU" "$REF" < "$SCRATCH/test_pred_question.txt.RL_mixer.epoch8.len_norm.beam0.hasrefs.nounks"
cat "$SCRATCH/test_pred_question.txt.RL_selfcritic.epoch8.len_norm.beam0.hasrefs.nounks" | "$NGRAMS" 3 | sort | uniq -c | sort -gr | wc -l

# GAN
"$BLEU" "$REF" < "$SCRATCH/test_pred_question.txt.mixer_pred_ans.epoch8.len_norm.beam0.hasrefs.nounks"
"$BLEU" "$REF" < "$SCRATCH/test_pred_question.txt.selfcritic_pred_ans.epoch8.len_norm.beam0.hasrefs.nounks"
cat "$SCRATCH/test_pred_question.txt.mixer_pred_ans.epoch8.len_norm.beam0.hasrefs.nounks" | "$NGRAMS" 1 | sort | uniq -c | sort -gr | wc -l
28 |
--------------------------------------------------------------------------------
/src/evaluation/read_crowdflower_full_results_compare_ques.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | from collections import defaultdict
4 | import sys
5 | import pdb
6 | import numpy as np
7 |
# Answer strings for the pairwise specificity judgment; main() compares the
# 'on_topic' column of each result row against the first three of these.
specificity_levels = [
    'Question A is more specific',
    'Question B is more specific',
    'Both are at the same level of specificity',
    'N/A: One or both questions are not applicable to the product',
]
12 |
13 |
def get_avg_score(score_dict, ignore_na=False):
    """Return the count-weighted mean of the scores in `score_dict`.

    Args:
        score_dict: mapping from a numeric score to the number of times
            that score was given.
        ignore_na: when True, entries whose score is 0 are left out of the
            denominator (they contribute nothing to the numerator anyway).

    Returns:
        The weighted mean as a float, or 0.0 when no entry is counted
        (empty mapping, or only zero scores with ignore_na=True) instead of
        raising ZeroDivisionError.
    """
    weighted_sum = 0.
    n = 0
    # .items() works on both Python 2 and Python 3 dictionaries.
    for score, count in score_dict.items():
        weighted_sum += score * count
        if not ignore_na or score != 0:
            n += count
    if n == 0:
        return 0.0
    return weighted_sum / n
26 |
27 |
def main(args):
    """Score questions from pairwise specificity judgments and print a report.

    Rows of args.full_results flagged '_golden' or '_tainted' are ignored.
    Consecutive rows sharing the same (asin, question_a, question_b) form
    one unit. For each unit the '_trust' values of the judgments are summed
    per answer and divided by 5; question A receives a_score + 0.5 * tie
    and question B receives b_score + 0.5 * tie.

    Two defects of the earlier version are fixed here:
      * a new unit is detected when any of asin / question_a / question_b
        changes (the old test required all three to change at once, so
        judgments of different pairs of one product were merged);
      * the final unit is recorded after the last row (it used to be
        dropped).

    The report prints, per product, each question with the mean of all its
    scores and the mean of the first half of its scores, thresholds both at
    0.5 and prints the agreement rate and the tp/fp/fn/tn counts.
    NOTE(review): the meaning of the "first half" estimate is not evident
    from this file; confirm against the experiment design.
    """
    # The sums are divided by this fixed number of judgments per unit.
    judgments_per_unit = 5

    titles = {}
    descriptions = {}
    cand_ques_dict = defaultdict(list)    # asin -> questions, in first-seen order
    cand_scores_dict = defaultdict(list)  # asin -> one score list per question

    def record_unit(unit, a_scores, b_scores, ab_scores):
        """Append the finished unit's scores to its two questions."""
        unit_asin, unit_a, unit_b = unit
        a_score = np.sum(a_scores) / judgments_per_unit
        b_score = np.sum(b_scores) / judgments_per_unit
        ab_score = np.sum(ab_scores) / judgments_per_unit
        known_questions = cand_ques_dict[unit_asin]
        cand_scores_dict[unit_asin][known_questions.index(unit_a)].append(a_score + 0.5 * ab_score)
        cand_scores_dict[unit_asin][known_questions.index(unit_b)].append(b_score + 0.5 * ab_score)

    curr_unit = None
    a_scores, b_scores, ab_scores = [], [], []
    with open(args.full_results) as csvfile:
        for row in csv.DictReader(csvfile):
            if row['_golden'] == 'true' or row['_tainted'] == 'true':
                continue
            asin = row['asin']
            titles[asin] = row['title']
            descriptions[asin] = row['description']
            question_a = row['question_a']
            question_b = row['question_b']
            trust = row['_trust']

            unit = (asin, question_a, question_b)
            if curr_unit is None:
                curr_unit = unit
            elif unit != curr_unit:
                record_unit(curr_unit, a_scores, b_scores, ab_scores)
                curr_unit = unit
                a_scores, b_scores, ab_scores = [], [], []

            if question_a not in cand_ques_dict[asin]:
                cand_ques_dict[asin].append(question_a)
                cand_scores_dict[asin].append([])
            if question_b not in cand_ques_dict[asin]:
                cand_ques_dict[asin].append(question_b)
                cand_scores_dict[asin].append([])

            if row['on_topic'] == 'Question A is more specific':
                a_scores.append(float(trust))
            elif row['on_topic'] == 'Question B is more specific':
                b_scores.append(float(trust))
            elif row['on_topic'] == 'Both are at the same level of specificity':
                ab_scores.append(float(trust))
            else:
                print('ID: %s has irrelevant question' % asin)

    # Record the unit that was still open when the input ended.
    if curr_unit is not None:
        record_unit(curr_unit, a_scores, b_scores, ab_scores)

    corr = 0
    total = 0
    fp, tp, fn, tn = 0, 0, 0, 0
    for asin in titles:
        print(asin)
        print(titles[asin])
        print(descriptions[asin])
        for i, ques in enumerate(cand_ques_dict[asin]):
            scores = cand_scores_dict[asin][i]
            true_v = np.mean(scores)
            pred_v = np.mean(scores[:len(scores) // 2])
            print('%s %s %s' % (true_v, scores, ques))
            print(pred_v)
            true_l = 0 if true_v < 0.5 else 1
            pred_l = 0 if pred_v < 0.5 else 1
            if true_l == pred_l:
                corr += 1
            if true_l == 0 and pred_l == 1:
                fp += 1
            if true_l == 0 and pred_l == 0:
                tn += 1
            if true_l == 1 and pred_l == 0:
                fn += 1
            if true_l == 1 and pred_l == 1:
                tp += 1
            total += 1
        print('')
    print('accuracy')
    print(corr * 1.0 / total)
    print('%s %s %s %s' % (tp, fp, fn, tn))
119 |
120 |
if __name__ == '__main__':
    # Command-line entry point: parse the input path, echo the parsed
    # arguments, then run the analysis.
    argparser = argparse.ArgumentParser(sys.argv[0])
    # Required: main() opens this path unconditionally, so a missing value
    # would otherwise surface as an obscure open(None) error.
    argparser.add_argument("--full_results", type=str, required=True,
                           help="CSV file of per-judgment results read by main()")
    args = argparser.parse_args()
    print(args)
    print("")
    main(args)
128 |
--------------------------------------------------------------------------------
/src/evaluation/read_crowdflower_results_binary.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | from collections import defaultdict
4 | import sys
5 | import numpy as np
6 | import pdb
7 |
# Numeric value assigned to each answer found in the 'on_topic' column.
generic_levels = {
    'Yes': 1,
    'No': 0,
}

# Models reported by this script, in output order.
model_list = [
    'ref',
    'lucene',
    'seq2seq.beam',
    'rl.beam',
    'gan.beam',
]

# Model name -> position of that model in model_list.
model_dict = dict((name, idx) for idx, name in enumerate(model_list))
17 |
18 |
def get_avg_score(score_dict, ignore_na=False):
    """Return the count-weighted mean of the scores in *score_dict*.

    Args:
        score_dict: mapping ``score -> number of items with that score``.
        ignore_na: when true, items whose score is 0 are left out of the
            denominator (they contribute nothing to the numerator anyway).

    Returns:
        The weighted average as a float, or ``nan`` when no item is counted
        (empty mapping, or only zero scores with ``ignore_na``).  The
        previous version raised ZeroDivisionError in that case, which
        aborted the report for any model without judgments.
    """
    weighted_sum = 0.
    num_counted = 0
    # .items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    for score, count in score_dict.items():
        weighted_sum += score * count
        if not ignore_na or score != 0:
            num_counted += count
    if num_counted == 0:
        return float('nan')
    return weighted_sum / num_counted
31 |
32 |
def main(args):
    """Tally the binary 'on_topic' answers per model and print a summary.

    Reads the CSV at ``args.aggregate_results``, skips golden rows and rows
    of models not listed in ``model_list``, and keeps only the first row
    seen for each (model, asin); later duplicates are reported and ignored.
    For every model it prints the name, the number of distinct asins, the
    average score and the raw score histogram.

    Changes over the previous version: seen asins are kept in sets (the
    list membership test was linear per row) and the unused ``question``
    column read was dropped.
    """
    num_models = len(model_list)
    generic_scores = [defaultdict(int) for _ in range(num_models)]
    asins_so_far = [set() for _ in range(num_models)]

    with open(args.aggregate_results) as csvfile:
        for row in csv.DictReader(csvfile):
            if row['_golden'] == 'true' or row['_unit_state'] == 'golden':
                continue
            asin = row['asin']
            model_name = row['model_name']
            if model_name not in model_dict:
                continue
            model_idx = model_dict[model_name]
            if asin in asins_so_far[model_idx]:
                print('%s duplicate %s' % (model_name, asin))
                continue
            asins_so_far[model_idx].add(asin)
            generic_score = generic_levels[row['on_topic']]
            generic_scores[model_idx][generic_score] += 1

    for i in range(num_models):
        print(model_list[i])
        print(len(asins_so_far[i]))
        print('Avg on generic score: %.2f' % get_avg_score(generic_scores[i]))
        print('Generic: %s' % (generic_scores[i],))
64 |
65 |
if __name__ == '__main__':
    # Command-line entry point: parse the input path, echo the parsed
    # arguments, then print the per-model summary.
    argparser = argparse.ArgumentParser(sys.argv[0])
    # Required: main() opens this path unconditionally.
    argparser.add_argument("--aggregate_results", type=str, required=True,
                           help="aggregated results CSV read by main()")
    args = argparser.parse_args()
    print(args)
    print("")
    main(args)
--------------------------------------------------------------------------------
/src/evaluation/read_crowdflower_results_compare_ques.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | from collections import defaultdict
4 | import sys
5 | import pdb
6 | import numpy as np
7 |
# The four answer strings of the pairwise specificity question; the first
# three are the values main() compares the 'on_topic' column against.
specificity_levels = [
    'Question A is more specific',
    'Question B is more specific',
    'Both are at the same level of specificity',
    'N/A: One or both questions are not applicable to the product',
]
12 |
13 |
def get_avg_score(score_dict, ignore_na=False):
    """Return the count-weighted mean of the scores in *score_dict*.

    Args:
        score_dict: mapping ``score -> number of items with that score``.
        ignore_na: when true, items scored 0 are excluded from the
            denominator.

    Returns:
        The weighted average as a float; ``nan`` when nothing is counted
        (previously this raised ZeroDivisionError).
    """
    weighted_sum = 0.
    num_counted = 0
    # .items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    for score, count in score_dict.items():
        weighted_sum += score * count
        if not ignore_na or score != 0:
            num_counted += count
    if num_counted == 0:
        return float('nan')
    return weighted_sum / num_counted
26 |
27 |
def main(args):
    """Print per-question specificity scores from aggregated pair results.

    Only rows whose ``_unit_state`` is ``finalized`` are used.  For each
    row the winning question receives the row's ``on_topic:confidence`` and
    the other one its complement; a tie gives both 0.5.  Afterwards every
    product is printed with its title, description and, per question, the
    mean score, the individual scores and the question text.
    """
    product_titles = {}
    product_descs = {}
    ques_by_asin = defaultdict(list)
    scores_by_asin = defaultdict(list)

    with open(args.aggregate_results) as csvfile:
        for row in csv.DictReader(csvfile):
            if row['_unit_state'] != 'finalized':
                continue
            asin = row['asin']
            product_titles[asin] = row['title']
            product_descs[asin] = row['description']
            question_a = row['question_a']
            question_b = row['question_b']
            confidence = row['on_topic:confidence']

            known_ques = ques_by_asin[asin]
            score_lists = scores_by_asin[asin]
            # Give each question a slot the first time it is seen.
            for ques in (question_a, question_b):
                if ques not in known_ques:
                    known_ques.append(ques)
                    score_lists.append([])
            slot_a = score_lists[known_ques.index(question_a)]
            slot_b = score_lists[known_ques.index(question_b)]

            verdict = row['on_topic']
            if verdict == 'Question A is more specific':
                slot_a.append(float(confidence))
                slot_b.append((1 - float(confidence)))
            elif verdict == 'Question B is more specific':
                slot_b.append(float(confidence))
                slot_a.append((1 - float(confidence)))
            elif verdict == 'Both are at the same level of specificity':
                slot_b.append(0.5)
                slot_a.append(0.5)
            else:
                print('ID: %s has irrelevant question' % asin)

    for asin in product_titles:
        print(asin)
        print(product_titles[asin])
        print(product_descs[asin])
        for idx, ques in enumerate(ques_by_asin[asin]):
            ques_scores = scores_by_asin[asin][idx]
            print('%s %s %s' % (np.mean(ques_scores), ques_scores, ques))
        print('')
71 |
72 |
if __name__ == '__main__':
    # Command-line entry point: parse the input path, echo the parsed
    # arguments, then print the per-question scores.
    argparser = argparse.ArgumentParser(sys.argv[0])
    # Required: main() opens this path unconditionally.
    argparser.add_argument("--aggregate_results", type=str, required=True,
                           help="aggregated results CSV read by main()")
    args = argparser.parse_args()
    print(args)
    print("")
    main(args)
--------------------------------------------------------------------------------
/src/evaluation/read_crowdflower_results_style.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | from collections import defaultdict
4 | import sys
5 | import numpy as np
6 | import pdb
7 |
# Answer text -> numeric score, one table per annotated column.
relevance_levels = {
    'Completely makes sense': 2,
    'Somewhat makes sense': 1,
    'Does not make sense': 0,
}
is_grammatical_levels = {
    'Grammatical': 2,
    'Comprehensible': 1,
    'Incomprehensible': 0,
}
# Both the most generic answer and N/A map to 0.
is_specific_levels = {
    'Specific to this or the same product from a different manufacturer': 3,
    'Specific to this or some other similar products': 2,
    'Generic enough to be applicable to many other products of this type': 1,
    'Generic enough to be applicable to any product under Home and Kitchen': 0,
    'N/A (Not Applicable)': 0,
}
asks_new_info_levels = {
    'Yes': 1,
    'No': 0,
    'N/A (Not Applicable)': 0,
}

# Models reported by this script, in output order.
model_list = [
    'ref',
    'lucene',
    'seq2seq',
    'seq2seq.generic',
    'seq2seq.specific',
]

# Model name -> position of that model in model_list.
model_dict = dict((name, idx) for idx, name in enumerate(model_list))
26 |
27 |
def get_avg_score(score_dict, ignore_na=False):
    """Return the count-weighted mean of the scores in *score_dict*.

    Args:
        score_dict: mapping ``score -> number of items with that score``.
        ignore_na: when true, items scored 0 are excluded from the
            denominator.

    Returns:
        The weighted average as a float; ``nan`` when nothing is counted.
        The previous version raised ZeroDivisionError there, e.g. for a
        model with no rows in the results file.
    """
    weighted_sum = 0.
    num_counted = 0
    # .items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    for score, count in score_dict.items():
        weighted_sum += score * count
        if not ignore_na or score != 0:
            num_counted += count
    if num_counted == 0:
        return float('nan')
    return weighted_sum / num_counted
40 |
41 |
def main(args):
    """Tally the four annotation columns per model and print a summary.

    Reads the CSV at ``args.aggregate_results``, skips golden rows and rows
    of models not in ``model_list``, and keeps only the first row seen for
    each (model, asin).  Relevance and grammaticality are always counted;
    specificity and new-info are counted only when both of those scores
    are non-zero.

    Changes over the previous version: seen asins are kept in sets (the
    list membership test was linear per row) and the unused ``question``
    column read was dropped.
    """
    num_models = len(model_list)
    relevance_scores = [defaultdict(int) for _ in range(num_models)]
    is_grammatical_scores = [defaultdict(int) for _ in range(num_models)]
    is_specific_scores = [defaultdict(int) for _ in range(num_models)]
    asks_new_info_scores = [defaultdict(int) for _ in range(num_models)]
    asins_so_far = [set() for _ in range(num_models)]

    with open(args.aggregate_results) as csvfile:
        for row in csv.DictReader(csvfile):
            if row['_golden'] == 'true' or row['_unit_state'] == 'golden':
                continue
            asin = row['asin']
            model_name = row['model_name']
            if model_name not in model_dict:
                continue
            model_idx = model_dict[model_name]
            if asin in asins_so_far[model_idx]:
                # Duplicate (model, asin) rows are skipped silently here.
                continue
            asins_so_far[model_idx].add(asin)

            relevance_score = relevance_levels[row['makes_sense']]
            is_grammatical_score = is_grammatical_levels[row['grammatical']]
            specific_score = is_specific_levels[row['is_specific']]
            asks_new_info_score = asks_new_info_levels[row['new_info']]
            relevance_scores[model_idx][relevance_score] += 1
            is_grammatical_scores[model_idx][is_grammatical_score] += 1
            # NOTE(review): the source's indentation was lost; both tallies
            # below are assumed to sit inside this condition -- confirm.
            if relevance_score != 0 and is_grammatical_score != 0:
                is_specific_scores[model_idx][specific_score] += 1
                asks_new_info_scores[model_idx][asks_new_info_score] += 1

    for i in range(num_models):
        print(model_list[i])
        print(len(asins_so_far[i]))
        print('Avg on topic score: %.2f' % get_avg_score(relevance_scores[i]))
        print('Avg grammaticality score: %.2f' % get_avg_score(is_grammatical_scores[i]))
        print('Avg specificity score: %.2f' % get_avg_score(is_specific_scores[i]))
        print('Avg new info score: %.2f' % get_avg_score(asks_new_info_scores[i]))
        print('')
        print('%s %s' % ('On topic:', relevance_scores[i]))
        print('%s %s' % ('Is grammatical: ', is_grammatical_scores[i]))
        print('%s %s' % ('Is specific: ', is_specific_scores[i]))
        print('%s %s' % ('Asks new info: ', asks_new_info_scores[i]))
        print('')
94 |
95 |
if __name__ == '__main__':
    # Command-line entry point: parse the input path, echo the parsed
    # arguments, then print the per-model summary.
    argparser = argparse.ArgumentParser(sys.argv[0])
    # Required: main() opens this path unconditionally.
    argparser.add_argument("--aggregate_results", type=str, required=True,
                           help="aggregated results CSV read by main()")
    args = argparser.parse_args()
    print(args)
    print("")
    main(args)
103 |
--------------------------------------------------------------------------------
/src/evaluation/run_bleu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Runs multi-bleu.perl on the GAN epoch-8 test predictions (stdin) against
# the test_ref reference prefix.
SITENAME="Home_and_Kitchen"
CQ_DATA_DIR="/fs/clip-amr/clarification_question_generation_pytorch/${SITENAME}"
BLEU_SCRIPT="/fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl"

"${BLEU_SCRIPT}" "${CQ_DATA_DIR}/test_ref" < "${CQ_DATA_DIR}/GAN_test_pred_question.txt.epoch8.hasrefs"
8 |
--------------------------------------------------------------------------------
/src/evaluation/run_create_amazon_multi_refs.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_amazon_multi_refs
#SBATCH --output=create_amazon_multi_refs
#SBATCH --qos=batch
#SBATCH --mem=4g
#SBATCH --time=4:00:00

# Runs create_amazon_multi_refs.py for the Home_and_Kitchen data.
SITENAME="Home_and_Kitchen"
DATA_DIR="/fs/clip-corpora/amazon_qa/${SITENAME}"
SCRIPT_DIR="/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation/"
CQ_DATA_DIR="/fs/clip-amr/clarification_question_generation_pytorch/${SITENAME}"

python "${SCRIPT_DIR}/create_amazon_multi_refs.py" \
    --ques_dir "${DATA_DIR}/ques_docs" \
    --test_ids_file "${CQ_DATA_DIR}/blind_test_pred_question.txt.GAN_selfcritic_pred_ans_3perid.epoch8.len_norm.beam0.ids" \
    --ref_prefix "${CQ_DATA_DIR}/test_ref" \
    --test_context_file "${CQ_DATA_DIR}/test_context.txt"
18 |
--------------------------------------------------------------------------------
/src/evaluation/run_create_crowdflower_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_crowdflower_HK_beam_batch4
#SBATCH --output=create_crowdflower_HK_beam_batch4
#SBATCH --qos=batch
#SBATCH --mem=32g
#SBATCH --time=4:00:00

# Runs create_crowdflower_data.py to write the batch-4 CSV from the lucene,
# seq2seq, RL and GAN beam outputs; the batch 1-3 CSVs are passed as well.
# Fix: the last argument no longer ends in a dangling "\" (anything added
# on the following line would have been swallowed as an argument), and all
# expansions are quoted.
SITENAME=Home_and_Kitchen
CORPORA_DIR=/fs/clip-corpora/amazon_qa
DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/$SITENAME
CROWDFLOWER_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

python "$SCRIPT_DIR/create_crowdflower_data.py" \
    --qa_data_fname "$CORPORA_DIR/qa_${SITENAME}.json.gz" \
    --metadata_fname "$CORPORA_DIR/meta_${SITENAME}.json.gz" \
    --batch1_csv_file "$CROWDFLOWER_DIR/crowdflower_lucene_seq2seq_rl_gan_diverse_beam_epoch8.batch1.csv" \
    --batch2_csv_file "$CROWDFLOWER_DIR/crowdflower_lucene_seq2seq_rl_gan_beam_epoch8.batch2.csv" \
    --batch3_csv_file "$CROWDFLOWER_DIR/crowdflower_lucene_seq2seq_rl_gan_beam_epoch8.batch3.csv" \
    --csv_file "$CROWDFLOWER_DIR/crowdflower_lucene_seq2seq_rl_gan_beam_epoch8.batch4.csv" \
    --lucene_model_name lucene \
    --lucene_model_fname "$DATA_DIR/blind_test_pred_question.lucene.txt" \
    --seq2seq_model_name seq2seq.beam \
    --seq2seq_model_fname "$DATA_DIR/blind_test_pred_question.txt.seq2seq.len_norm.beam0" \
    --rl_model_name rl.beam \
    --rl_model_fname "$DATA_DIR/blind_test_pred_question.txt.RL_selfcritic.epoch8.len_norm.beam0" \
    --gan_model_name gan.beam \
    --gan_model_fname "$DATA_DIR/blind_test_pred_question.txt.GAN_selfcritic_pred_ans_3perid.epoch8.len_norm.beam0"
31 |
--------------------------------------------------------------------------------
/src/evaluation/run_create_crowdflower_data_beam.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_crowdflower_HK_beam
#SBATCH --output=create_crowdflower_HK_beam
#SBATCH --qos=batch
#SBATCH --mem=32g
#SBATCH --time=4:00:00

# Runs create_crowdflower_data_beam.py on a previous aggregate CSV.
# Fix: the command used to end in a dangling "\" at end of file; it is
# removed and all expansions are quoted.
# CORPORA_DIR and DATA_DIR are not used below; kept as in the original.
SITENAME=Home_and_Kitchen
CORPORA_DIR=/fs/clip-corpora/amazon_qa
DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/$SITENAME
CROWDFLOWER_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

python "$SCRIPT_DIR/create_crowdflower_data_beam.py" \
    --previous_csv_file "$CROWDFLOWER_DIR/crowdflower_lucene_seq2seq_rl_gan_diverse_beam_seq2seq_beam_gan_beam_rl_beam_epoch8.batch1.aggregate.csv" \
    --output_csv_file "$CROWDFLOWER_DIR/crowdflower_lucene_seq2seq_rl_gan_beam.batch1.csv"
--------------------------------------------------------------------------------
/src/evaluation/run_create_crowdflower_data_compare_ques.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_crowdflower_HK_compare_batchD
#SBATCH --output=create_crowdflower_HK_compare_batchD
#SBATCH --qos=batch
#SBATCH --mem=32g
#SBATCH --time=4:00:00

# Runs create_crowdflower_data_compare_ques.py to write the batch-D CSV;
# four earlier CSVs are passed as --previous_csv_file_v1..v4.
# Fix: removed the dangling "\" after the last argument and quoted all
# expansions.
SITENAME=Home_and_Kitchen
CORPORA_DIR=/fs/clip-corpora/amazon_qa
DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME
CROWDFLOWER_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

python "$SCRIPT_DIR/create_crowdflower_data_compare_ques.py" \
    --qa_data_fname "$CORPORA_DIR/qa_${SITENAME}.json.gz" \
    --metadata_fname "$CORPORA_DIR/meta_${SITENAME}.json.gz" \
    --csv_file "$CROWDFLOWER_DIR/crowdflower_compare_ques_batchD_100.csv" \
    --train_asins "$DATA_DIR/train_asin.txt" \
    --previous_csv_file_v1 "$CROWDFLOWER_DIR/crowdflower_compare_ques_batchA_100.csv" \
    --previous_csv_file_v2 "$CROWDFLOWER_DIR/crowdflower_compare_ques_allpairs.csv" \
    --previous_csv_file_v3 "$CROWDFLOWER_DIR/crowdflower_compare_ques_batchB_100.csv" \
    --previous_csv_file_v4 "$CROWDFLOWER_DIR/crowdflower_compare_ques_batchC_100.csv"
25 |
--------------------------------------------------------------------------------
/src/evaluation/run_create_crowdflower_data_compare_ques_allpairs.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_crowdflower_HK_compare_allpairs
#SBATCH --output=create_crowdflower_HK_compare_allpairs
#SBATCH --qos=batch
#SBATCH --mem=32g
#SBATCH --time=4:00:00

# Runs create_crowdflower_data_compare_ques_allpairs.py; the pilot CSV is
# passed as --previous_csv_file.
SITENAME="Home_and_Kitchen"
CORPORA_DIR="/fs/clip-corpora/amazon_qa"
DATA_DIR="/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/${SITENAME}"
CROWDFLOWER_DIR="/fs/clip-amr/clarification_question_generation_pytorch/evaluation/${SITENAME}"
SCRIPT_DIR="/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation"

export PATH="/fs/clip-amr/anaconda2/bin:${PATH}"

python "${SCRIPT_DIR}/create_crowdflower_data_compare_ques_allpairs.py" \
    --qa_data_fname "${CORPORA_DIR}/qa_${SITENAME}.json.gz" \
    --metadata_fname "${CORPORA_DIR}/meta_${SITENAME}.json.gz" \
    --csv_file "${CROWDFLOWER_DIR}/crowdflower_compare_ques_allpairs.csv" \
    --train_asins "${DATA_DIR}/train_asin.txt" \
    --previous_csv_file "${CROWDFLOWER_DIR}/crowdflower_compare_ques_pilot.csv"
22 |
--------------------------------------------------------------------------------
/src/evaluation/run_create_crowdflower_data_specificity.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_crowdflower_HK_beam_spec_single_sent
#SBATCH --output=create_crowdflower_HK_beam_spec_single_sent
#SBATCH --qos=batch
#SBATCH --mem=32g
#SBATCH --time=4:00:00

# Runs create_crowdflower_data_specificity.py on the lucene, seq2seq,
# seq2seq.specific and seq2seq.generic outputs (single_sent CSV).
# Fix: removed the dangling "\" after the last argument and quoted all
# expansions.
SITENAME=Home_and_Kitchen
CORPORA_DIR=/fs/clip-corpora/amazon_qa
DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME
CROWDFLOWER_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

python "$SCRIPT_DIR/create_crowdflower_data_specificity.py" \
    --qa_data_fname "$CORPORA_DIR/qa_${SITENAME}.json.gz" \
    --metadata_fname "$CORPORA_DIR/meta_${SITENAME}.json.gz" \
    --csv_file "$CROWDFLOWER_DIR/crowdflower_seq2seq_epoch100_seq2seq_specific_seq2seq_generic_p100_q30_style_emb_single_sent.epoch100.csv" \
    --lucene_model_name lucene \
    --lucene_model_fname "$DATA_DIR/blind_test_pred_question.lucene.txt" \
    --seq2seq_model_name seq2seq \
    --seq2seq_model_fname "$DATA_DIR/blind_test_pred_ques.txt.seq2seq.epoch100.beam0" \
    --seq2seq_specific_model_name seq2seq.specific \
    --seq2seq_specific_model_fname "$DATA_DIR/blind_test_pred_ques.txt.seq2seq_tobeginning_tospecific_p100_q30_style_emb.epoch100.beam0" \
    --seq2seq_generic_model_name seq2seq.generic \
    --seq2seq_generic_model_fname "$DATA_DIR/blind_test_pred_ques.txt.seq2seq_tobeginning_togeneric_p100_q30_style_emb.epoch100.beam0"
28 |
--------------------------------------------------------------------------------
/src/evaluation/run_create_crowdflower_data_specificity_multi.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_crowdflower_HK_beam_spec_multi_sent
#SBATCH --output=create_crowdflower_HK_beam_spec_multi_sent
#SBATCH --qos=batch
#SBATCH --mem=32g
#SBATCH --time=4:00:00

# Writes the multi_sent CSV, passing the single_sent CSV as --prev_csv_file.
# Fix: removed the dangling "\" after the last argument and quoted all
# expansions.
# NOTE(review): this job is named "multi" and the repo also contains
# create_crowdflower_data_specificity_multi.py, yet the script invoked
# below is create_crowdflower_data_specificity.py -- confirm which one is
# intended (left unchanged here).
SITENAME=Home_and_Kitchen
CORPORA_DIR=/fs/clip-corpora/amazon_qa
DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME
CROWDFLOWER_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

python "$SCRIPT_DIR/create_crowdflower_data_specificity.py" \
    --qa_data_fname "$CORPORA_DIR/qa_${SITENAME}.json.gz" \
    --metadata_fname "$CORPORA_DIR/meta_${SITENAME}.json.gz" \
    --prev_csv_file "$CROWDFLOWER_DIR/crowdflower_seq2seq_epoch100_seq2seq_specific_seq2seq_generic_p100_q30_style_emb_single_sent.epoch100.csv" \
    --csv_file "$CROWDFLOWER_DIR/crowdflower_seq2seq_epoch100_seq2seq_specific_seq2seq_generic_p100_q30_style_emb_multi_sent.epoch100.csv" \
    --lucene_model_name lucene \
    --lucene_model_fname "$DATA_DIR/blind_test_pred_question.lucene.txt" \
    --seq2seq_model_name seq2seq \
    --seq2seq_model_fname "$DATA_DIR/blind_test_pred_ques.txt.seq2seq.epoch100.beam0" \
    --seq2seq_specific_model_name seq2seq.specific \
    --seq2seq_specific_model_fname "$DATA_DIR/blind_test_pred_ques.txt.seq2seq_tobeginning_tospecific_p100_q30_style_emb.epoch100.beam0" \
    --seq2seq_generic_model_name seq2seq.generic \
    --seq2seq_generic_model_fname "$DATA_DIR/blind_test_pred_ques.txt.seq2seq_tobeginning_togeneric_p100_q30_style_emb.epoch100.beam0"
29 |
--------------------------------------------------------------------------------
/src/evaluation/run_create_preds_for_refs.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --qos=batch
#SBATCH --mem=4g
#SBATCH --time=05:00:00

# Runs create_preds_for_refs.py on one model output file (PRED_FILE) of
# the askubuntu_unix_superuser data.
SITENAME="askubuntu_unix_superuser"
CQ_DATA_DIR="/fs/clip-amr/clarification_question_generation_pytorch/${SITENAME}"
SCRIPT_DIR="/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation"

PRED_FILE="test_pred_question.txt.GAN_selfcritic_pred_ans_util_dis.epoch8"

python "${SCRIPT_DIR}/create_preds_for_refs.py" \
    --qa_data_tsvfile "${CQ_DATA_DIR}/qa_data.tsv" \
    --test_ids_file "${CQ_DATA_DIR}/test_ids" \
    --human_annotations "${CQ_DATA_DIR}/human_annotations" \
    --model_output_file "${CQ_DATA_DIR}/${PRED_FILE}"
17 |
18 |
--------------------------------------------------------------------------------
/src/evaluation/run_meteor.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=meteor
#SBATCH --output=meteor
#SBATCH --qos=batch
#SBATCH --mem=4g
#SBATCH --time=4:00:00

# Runs the METEOR 1.5 jar on TEST_SET against test_ref_combined (-r 6) and
# stores the output under RESULTS_DIR.
SITENAME="askubuntu_unix_superuser"

METEOR="/fs/clip-software/user-supported/meteor-1.5"
CQ_DATA_DIR="/fs/clip-amr/clarification_question_generation_pytorch/${SITENAME}"
RESULTS_DIR="/fs/clip-amr/clarification_question_generation_pytorch/evaluation/results/${SITENAME}"

TEST_SET="test_pred_question.txt.RL_selfcritic.epoch8.hasrefs.nounks"

java -Xmx2G -jar "${METEOR}/meteor-1.5.jar" \
    "${CQ_DATA_DIR}/${TEST_SET}" "${CQ_DATA_DIR}/test_ref_combined" \
    -l en -norm -r 6 \
    > "${RESULTS_DIR}/${TEST_SET}.meteor"
20 |
--------------------------------------------------------------------------------
/src/evaluation/run_meteor_HK.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=meteor
#SBATCH --qos=batch
#SBATCH --mem=4g
#SBATCH --time=4:00:00

# Runs the METEOR 1.5 jar on the Home_and_Kitchen TEST_SET against
# test_ref_combined (-r 10) and stores the output under RESULTS_DIR.
METEOR="/fs/clip-software/user-supported/meteor-1.5"
CQ_DATA_DIR="/fs/clip-amr/clarification_question_generation_pytorch/Home_and_Kitchen/"
RESULTS_DIR="/fs/clip-amr/clarification_question_generation_pytorch/evaluation/results/Home_and_Kitchen"

TEST_SET="test_pred_ques.txt.seq2seq.epoch100.beam0.nounks"

java -Xmx2G -jar "${METEOR}/meteor-1.5.jar" \
    "${CQ_DATA_DIR}/${TEST_SET}" "${CQ_DATA_DIR}/test_ref_combined" \
    -l en -norm -r 10 \
    > "${RESULTS_DIR}/${TEST_SET}.meteor"
17 |
--------------------------------------------------------------------------------
/src/lucene/create_amazon_lucene_baseline.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys, os
3 | from collections import defaultdict
4 | import csv
5 | import math
6 | import pdb
7 | import random
8 | import gzip
9 |
10 |
def parse(path):
    """Yield one evaluated record per line of the gzip file at *path*.

    The file is now opened in a ``with`` block so the handle is closed
    once the generator is exhausted or discarded (it used to stay open).
    """
    # SECURITY NOTE: eval() executes whatever expression each line holds.
    # Only use this on trusted dumps; if the lines are plain Python
    # literals, ast.literal_eval would be the safe replacement (not swapped
    # in here because the data format is not visible from this file).
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)
15 |
16 |
def get_brand_info(metadata_fname):
    """Map asin -> brand for every metadata record that has both a
    'description' and a 'title'; the brand is None when the record has no
    'brand' key."""
    brands = {}
    for record in parse(metadata_fname):
        has_text = 'description' in record and 'title' in record
        if has_text:
            product = record['asin']
            brands[product] = record['brand'] if 'brand' in record.keys() else None
    return brands
28 |
29 |
def create_lucene_preds(ids, args, quess, sim_prod):
    """Write one predicted question per test id to ``args.lucene_pred_fname``.

    For each id the product id is the part before the first '_'.  The
    candidate pool is the first three questions of each of the first three
    similar products; one candidate is drawn at random, or an empty line
    is written when the pool is empty.

    Args:
        ids: test ids of the form ``<prod_id>_<suffix>``.
        args: namespace providing ``lucene_pred_fname``.
        quess: mapping product id -> list of questions.
        sim_prod: mapping product id -> list of similar product ids.

    The output file is now managed by ``with`` so it is closed even when a
    lookup raises part-way through.
    """
    with open(args.lucene_pred_fname, 'w') as pred_file:
        for test_id in ids:
            prod_id = test_id.split('_')[0]
            choices = []
            for sim_id in sim_prod[prod_id][:3]:
                choices += quess[sim_id][:3]
            if choices:
                pred_file.write(random.choice(choices) + '\n')
            else:
                pred_file.write('\n')
43 |
44 |
def get_sim_docs(sim_docs_filename, brand_info):
    """Read the similar-products file and keep other-brand products.

    Each non-blank line is whitespace-separated with the asin first.  From
    the entries at position 2 onwards, those whose brand is known and
    differs from the asin's brand are kept.  If none qualifies, entries
    10-12 of the line are used instead.

    Args:
        sim_docs_filename: path of the similar-products text file.
        brand_info: mapping product id -> brand (or None).

    Returns:
        dict mapping asin -> list of similar product ids (possibly empty).

    Fixes over the previous version: the file is closed via ``with``;
    blank lines are skipped instead of raising IndexError; and the
    leftover ``pdb.set_trace()`` for products without any similar product
    is replaced by a warning on stderr, so batch runs no longer stop in
    the debugger.
    """
    sim_docs = {}
    with open(sim_docs_filename, 'r') as sim_docs_file:
        for line in sim_docs_file:
            parts = line.split()
            if not parts:
                continue
            asin = parts[0]
            sim_docs[asin] = []
            if len(parts) == 1:
                continue
            # parts[1] is skipped, as in the original -- presumably it is
            # the query product itself; TODO confirm.
            for prod_id in parts[2:]:
                if brand_info[prod_id] and (brand_info[prod_id] != brand_info[asin]):
                    sim_docs[asin].append(prod_id)
            if not sim_docs[asin]:
                sim_docs[asin] = parts[10:13]
            if not sim_docs[asin]:
                sys.stderr.write('Warning: no similar products for %s\n' % asin)
    return sim_docs
63 |
64 |
def read_data(args):
    """Load the inputs and write the lucene baseline predictions.

    Reads the test ids, then the first line of every file in
    ``args.ques_dir`` (file names of the form ``<asin>_<q_no>`` followed by
    a 4-character extension), orders each product's questions by their
    1-based ``q_no``, and finally builds the brand table and similar
    products and hands everything to ``create_lucene_preds``.

    The test-ids file is now read inside a ``with`` block; its handle was
    previously never closed.
    """
    print("Reading lines...")
    with open(args.test_ids_file, 'r') as ids_file:
        test_ids = [test_id.strip('\n') for test_id in ids_file]
    print('No. of test ids: %d' % len(test_ids))

    quess_rand = defaultdict(list)
    for fname in os.listdir(args.ques_dir):
        with open(os.path.join(args.ques_dir, fname), 'r') as f:
            ques_id = fname[:-4]  # drop the 4-character extension
            asin, q_no = ques_id.split('_')
            ques = f.readline().strip('\n')
        quess_rand[asin].append((ques, q_no))

    # Place each question at index q_no - 1 of its product's list.
    quess = {}
    for asin, numbered in quess_rand.items():
        ordered = [None] * len(numbered)
        for ques, q_no in numbered:
            ordered[int(q_no) - 1] = ques
        quess[asin] = ordered

    brand_info = get_brand_info(args.metadata_fname)
    sim_prod = get_sim_docs(args.sim_prod_fname, brand_info)
    create_lucene_preds(test_ids, args, quess, sim_prod)
87 |
88 |
if __name__ == "__main__":
    # Command-line entry point.  Every option is required: read_data() and
    # its helpers open or list each of these paths unconditionally.
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--ques_dir", type=str, required=True,
                           help="directory of question files")
    argparser.add_argument("--sim_prod_fname", type=str, required=True,
                           help="similar-products text file")
    argparser.add_argument("--test_ids_file", type=str, required=True,
                           help="file with one test id per line")
    argparser.add_argument("--lucene_pred_fname", type=str, required=True,
                           help="output file for the predicted questions")
    argparser.add_argument("--metadata_fname", type=str, required=True,
                           help="gzip metadata file")
    args = argparser.parse_args()
    print(args)
    print("")
    read_data(args)
100 |
--------------------------------------------------------------------------------
/src/lucene/create_stackexchange_lucene_baseline.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys, os
3 | from collections import defaultdict
4 | import csv
5 | import math
6 | import pdb
7 | import random
8 |
def read_data(args):
    """Write one randomly chosen candidate question per test id.

    Reads the post TSV and the QA TSV (both with a header row), takes
    columns 2..10 of each QA row as that post's candidate questions
    (column 1 is the true question and is ignored), and for every id in
    ``args.test_ids_file`` writes a random candidate to
    ``args.lucene_output_file``.

    Fixes over the previous version:
      * the pick uses ``random.choice`` over the candidates actually
        present; ``random.randint(0, 8)`` raised IndexError for posts with
        fewer than nine candidates.  A post without candidates gets an
        empty line.
      * all files are opened in ``with`` blocks (handles used to leak);
      * the csv files are opened in the mode the running Python expects
        ('rb' on Python 2, 'r' on Python 3).
    """
    print("Reading lines...")
    csv_mode = 'rb' if sys.version_info[0] == 2 else 'r'

    # NOTE(review): posts is built but never used afterwards; the read is
    # kept so a malformed post file still fails as before.
    posts = {}
    with open(args.post_data_tsvfile, csv_mode) as tsvfile:
        post_reader = csv.reader(tsvfile, delimiter='\t')
        next(post_reader, None)  # skip the header row
        for post_id, title, post in post_reader:
            posts[post_id] = (title + ' ' + post).lower().strip()

    question_candidates = {}
    with open(args.qa_data_tsvfile, csv_mode) as tsvfile:
        qa_reader = csv.reader(tsvfile, delimiter='\t')
        next(qa_reader, None)  # skip the header row
        for row in qa_reader:
            # Ignore the first question since that is the true question.
            question_candidates[row[0]] = row[2:11]

    with open(args.test_ids_file, 'r') as ids_file:
        test_ids = [test_id.strip('\n') for test_id in ids_file]

    with open(args.lucene_output_file, 'w') as lucene_out_file:
        for test_id in test_ids:
            candidates = question_candidates[test_id]
            pred_ques = random.choice(candidates) if candidates else ''
            lucene_out_file.write(pred_ques + '\n')
43 |
if __name__ == "__main__":
    # Command-line entry point.  Every option is required: read_data()
    # opens each of these paths unconditionally.
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--post_data_tsvfile", type=str, required=True,
                           help="TSV of posts (id, title, body) with a header row")
    argparser.add_argument("--qa_data_tsvfile", type=str, required=True,
                           help="TSV of candidate questions with a header row")
    argparser.add_argument("--test_ids_file", type=str, required=True,
                           help="file with one test id per line")
    argparser.add_argument("--lucene_output_file", type=str, required=True,
                           help="output file for the predicted questions")
    args = argparser.parse_args()
    print(args)
    print("")
    read_data(args)
54 |
55 |
--------------------------------------------------------------------------------
/src/lucene/run_create_amazon_lucene_baseline.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=lucene_Home_and_Kitchen_test
#SBATCH --output=lucene_Home_and_Kitchen_test
#SBATCH --qos=batch
#SBATCH --mem=4g
#SBATCH --time=4:00:00

# Runs create_amazon_lucene_baseline.py to write the lucene predictions
# for the blind test ids.
# Fix: removed the dangling "\" after the last argument and quoted all
# expansions.
SITENAME=Home_and_Kitchen
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/lucene
CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME

python "$SCRIPT_DIR/create_amazon_lucene_baseline.py" \
    --ques_dir "$CQ_DATA_DIR/ques_docs" \
    --sim_prod_fname "$CQ_DATA_DIR/lucene_similar_prods.txt" \
    --test_ids_file "$CQ_DATA_DIR/blind_test_pred_ques.txt.seq2seq.epoch100.beam0.ids" \
    --lucene_pred_fname "$CQ_DATA_DIR/blind_test_pred_question.lucene.txt" \
    --metadata_fname "$CQ_DATA_DIR/meta_Home_and_Kitchen.json.gz"
18 |
19 |
--------------------------------------------------------------------------------
/src/lucene/run_create_stackexchange_lucene_baseline.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Runs create_stackexchange_lucene_baseline.py for askubuntu_unix_superuser.
# Fix: removed the dangling "\" after the last argument and quoted all
# expansions.
# NOTE(review): CQ_DATA_DIR points at clarification_question_generation/
# while SCRIPT_DIR uses clarification_question_generation_pytorch/ --
# confirm the data directory is intended (left unchanged).
SITENAME=askubuntu_unix_superuser
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/lucene
CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation/$SITENAME

python "$SCRIPT_DIR/create_stackexchange_lucene_baseline.py" \
    --post_data_tsvfile "$CQ_DATA_DIR/post_data.tsv" \
    --qa_data_tsvfile "$CQ_DATA_DIR/qa_data.tsv" \
    --test_ids_file "$CQ_DATA_DIR/test_ids" \
    --lucene_output_file "$CQ_DATA_DIR/test_pred_question_lucene.txt"
11 |
12 |
--------------------------------------------------------------------------------
/src/run_GAN_decode.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Runs decode.py with the GAN epoch-12 question encoder/decoder parameters
# on the askubuntu_unix_superuser test set, with --beam True.
# Paths are relative to the directory the script is launched from.
# Fix: removed the dangling "\" after the last argument and quoted all
# expansions.
SITENAME=askubuntu_unix_superuser

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

python "$SCRIPT_DIR/decode.py" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_question.txt" \
    --test_ans "$CQ_DATA_DIR/test_answer.txt" \
    --test_ids "$CQ_DATA_DIR/test_ids" \
    --test_pred_ques "$CQ_DATA_DIR/test_pred_question.txt" \
    --q_encoder_params "$CQ_DATA_DIR/q_encoder_params.epoch100.GAN_selfcritic_pred_ans.epoch12" \
    --q_decoder_params "$CQ_DATA_DIR/q_decoder_params.epoch100.GAN_selfcritic_pred_ans.epoch12" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --model GAN.epoch12 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 256 \
    --beam True
24 |
25 |
--------------------------------------------------------------------------------
/src/run_GAN_decode_HK.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs decode.py on the Home_and_Kitchen test split with the question
# encoder/decoder checkpoints named *.GAN_selfcritic_pred_ans.epoch12,
# with --beam True.

SITENAME=Home_and_Kitchen

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

# Checkpoints are read from the data directory.
PARAMS_DIR=$CQ_DATA_DIR

# Expansions are quoted so the paths survive word splitting and globbing.
python "$SCRIPT_DIR/decode.py" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_ques.txt" \
    --test_ans "$CQ_DATA_DIR/test_ans.txt" \
    --test_ids "$CQ_DATA_DIR/test_asin.txt" \
    --test_pred_ques "$CQ_DATA_DIR/blind_test_pred_ques.txt" \
    --q_encoder_params "$PARAMS_DIR/q_encoder_params.epoch100.GAN_selfcritic_pred_ans.epoch12" \
    --q_decoder_params "$PARAMS_DIR/q_decoder_params.epoch100.GAN_selfcritic_pred_ans.epoch12" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --model GAN_selfcritic_pred_ans.epoch12 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 128 \
    --beam True
--------------------------------------------------------------------------------
/src/run_GAN_main.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs GAN_main.py (model label GAN_selfcritic_pred_ans) on the $SITENAME
# train/tune/test splits. The question/answer generator checkpoints passed
# are the *.epoch100 files and the context/question/answer/utility
# checkpoints are the *.epoch10 files (presumably the pretrained starting
# point -- confirm against GAN_main.py).

SITENAME=askubuntu_unix_superuser

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

# Expansions are quoted against word splitting; the final argument has no
# trailing backslash so the command ends where it appears to end.
python "$SCRIPT_DIR/GAN_main.py" \
    --train_context "$CQ_DATA_DIR/train_context.txt" \
    --train_ques "$CQ_DATA_DIR/train_question.txt" \
    --train_ans "$CQ_DATA_DIR/train_answer.txt" \
    --tune_context "$CQ_DATA_DIR/tune_context.txt" \
    --tune_ques "$CQ_DATA_DIR/tune_question.txt" \
    --tune_ans "$CQ_DATA_DIR/tune_answer.txt" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_question.txt" \
    --test_ans "$CQ_DATA_DIR/test_answer.txt" \
    --test_pred_ques "$CQ_DATA_DIR/GAN_test_pred_question.txt" \
    --q_encoder_params "$CQ_DATA_DIR/q_encoder_params.epoch100" \
    --q_decoder_params "$CQ_DATA_DIR/q_decoder_params.epoch100" \
    --a_encoder_params "$CQ_DATA_DIR/a_encoder_params.epoch100" \
    --a_decoder_params "$CQ_DATA_DIR/a_decoder_params.epoch100" \
    --context_params "$CQ_DATA_DIR/context_params.epoch10" \
    --question_params "$CQ_DATA_DIR/question_params.epoch10" \
    --answer_params "$CQ_DATA_DIR/answer_params.epoch10" \
    --utility_params "$CQ_DATA_DIR/utility_params.epoch10" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --model GAN_selfcritic_pred_ans \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 256 \
    --n_epochs 40
--------------------------------------------------------------------------------
/src/run_GAN_main_HK.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs GAN_main.py (model label GAN_selfcritic_pred_ans) on the
# Home_and_Kitchen train/tune/test splits, including the per-split ASIN id
# files. Generator checkpoints passed are *.epoch100, context/question/
# answer/utility checkpoints are *.epoch10.

SITENAME=Home_and_Kitchen

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

# Checkpoints live next to the data.
PARAMS_DIR=$CQ_DATA_DIR

# Expansions are quoted against word splitting; the final argument has no
# trailing backslash so the command ends where it appears to end.
python "$SCRIPT_DIR/GAN_main.py" \
    --train_context "$CQ_DATA_DIR/train_context.txt" \
    --train_ques "$CQ_DATA_DIR/train_ques.txt" \
    --train_ans "$CQ_DATA_DIR/train_ans.txt" \
    --train_ids "$CQ_DATA_DIR/train_asin.txt" \
    --tune_context "$CQ_DATA_DIR/tune_context.txt" \
    --tune_ques "$CQ_DATA_DIR/tune_ques.txt" \
    --tune_ans "$CQ_DATA_DIR/tune_ans.txt" \
    --tune_ids "$CQ_DATA_DIR/tune_asin.txt" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_ques.txt" \
    --test_ans "$CQ_DATA_DIR/test_ans.txt" \
    --test_ids "$CQ_DATA_DIR/test_asin.txt" \
    --test_pred_ques "$CQ_DATA_DIR/blind_test_pred_ques.txt" \
    --q_encoder_params "$PARAMS_DIR/q_encoder_params.epoch100" \
    --q_decoder_params "$PARAMS_DIR/q_decoder_params.epoch100" \
    --a_encoder_params "$PARAMS_DIR/a_encoder_params.epoch100" \
    --a_decoder_params "$PARAMS_DIR/a_decoder_params.epoch100" \
    --context_params "$PARAMS_DIR/context_params.epoch10" \
    --question_params "$PARAMS_DIR/question_params.epoch10" \
    --answer_params "$PARAMS_DIR/answer_params.epoch10" \
    --utility_params "$PARAMS_DIR/utility_params.epoch10" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --model GAN_selfcritic_pred_ans \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 64 \
    --n_epochs 20
--------------------------------------------------------------------------------
/src/run_GAN_main_electronics.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs src/GAN_main.py on the data under baseline_data/ (paths are relative,
# so this must be started from the directory containing src/, baseline_data/
# and embeddings/). Checkpoints passed: question generator epoch49, answer
# generator epoch43, utility-related params epoch10.

# The final argument has no trailing backslash so the command ends where it
# appears to end.
python src/GAN_main.py \
    --train_context baseline_data/train_context.txt \
    --train_ques baseline_data/train_question.txt \
    --train_ans baseline_data/train_answer.txt \
    --train_ids baseline_data/train_asin.txt \
    --tune_context baseline_data/valid_context.txt \
    --tune_ques baseline_data/valid_question.txt \
    --tune_ans baseline_data/valid_answer.txt \
    --tune_ids baseline_data/valid_asin.txt \
    --test_context baseline_data/test_context.txt \
    --test_ques baseline_data/test_question.txt \
    --test_ans baseline_data/test_answer.txt \
    --test_ids baseline_data/test_asin.txt \
    --test_pred_ques baseline_data/gan_ques49_ans43_util10_valid_predicted_question.txt \
    --q_encoder_params baseline_data/seq2seq-pretrain-ques-v3/q_encoder_params.epoch49 \
    --q_decoder_params baseline_data/seq2seq-pretrain-ques-v3/q_decoder_params.epoch49 \
    --a_encoder_params baseline_data/seq2seq-pretrain-ans-v3/a_encoder_params.epoch43 \
    --a_decoder_params baseline_data/seq2seq-pretrain-ans-v3/a_decoder_params.epoch43 \
    --context_params baseline_data/seq2seq-pretrain-util-v3/context_params.epoch10 \
    --question_params baseline_data/seq2seq-pretrain-util-v3/question_params.epoch10 \
    --answer_params baseline_data/seq2seq-pretrain-util-v3/answer_params.epoch10 \
    --utility_params baseline_data/seq2seq-pretrain-util-v3/utility_params.epoch10 \
    --word_embeddings embeddings/amazon_200d_embeddings.p \
    --vocab embeddings/amazon_200d_vocab.p \
    --model GAN_selfcritic_ques49_ans43_util10 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 64 \
    --n_epochs 20
--------------------------------------------------------------------------------
/src/run_RL_decode.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs decode.py on the $SITENAME test split with the question
# encoder/decoder checkpoints named *.RL_selfcritic.epoch8, with --beam True.

SITENAME=askubuntu_unix_superuser

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

# Expansions are quoted against word splitting; the final argument has no
# trailing backslash so the command ends where it appears to end.
python "$SCRIPT_DIR/decode.py" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_question.txt" \
    --test_ans "$CQ_DATA_DIR/test_answer.txt" \
    --test_ids "$CQ_DATA_DIR/test_ids" \
    --test_pred_ques "$CQ_DATA_DIR/test_pred_question.txt" \
    --q_encoder_params "$CQ_DATA_DIR/q_encoder_params.epoch100.RL_selfcritic.epoch8" \
    --q_decoder_params "$CQ_DATA_DIR/q_decoder_params.epoch100.RL_selfcritic.epoch8" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --model RL.epoch8 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 256 \
    --beam True
--------------------------------------------------------------------------------
/src/run_RL_decode_HK.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs decode.py on the Home_and_Kitchen test split with the question
# encoder/decoder checkpoints named *.RL_selfcritic.epoch8, with --beam True.

SITENAME=Home_and_Kitchen

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

# Checkpoints are read from the data directory.
PARAMS_DIR=$CQ_DATA_DIR

# Expansions are quoted so the paths survive word splitting and globbing.
python "$SCRIPT_DIR/decode.py" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_ques.txt" \
    --test_ans "$CQ_DATA_DIR/test_ans.txt" \
    --test_ids "$CQ_DATA_DIR/test_asin.txt" \
    --test_pred_ques "$CQ_DATA_DIR/blind_test_pred_ques.txt" \
    --q_encoder_params "$PARAMS_DIR/q_encoder_params.epoch100.RL_selfcritic.epoch8" \
    --q_decoder_params "$PARAMS_DIR/q_decoder_params.epoch100.RL_selfcritic.epoch8" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --model RL_selfcritic.epoch8 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 128 \
    --beam True
--------------------------------------------------------------------------------
/src/run_RL_main.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs RL_main.py (model label RL_selfcritic) on the $SITENAME
# train/tune/test splits. Generator checkpoints passed are *.epoch100,
# context/question/answer/utility checkpoints are *.epoch10.

SITENAME=askubuntu_unix_superuser

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

# Expansions are quoted against word splitting; the final argument has no
# trailing backslash so the command ends where it appears to end.
python "$SCRIPT_DIR/RL_main.py" \
    --train_context "$CQ_DATA_DIR/train_context.txt" \
    --train_ques "$CQ_DATA_DIR/train_question.txt" \
    --train_ans "$CQ_DATA_DIR/train_answer.txt" \
    --tune_context "$CQ_DATA_DIR/tune_context.txt" \
    --tune_ques "$CQ_DATA_DIR/tune_question.txt" \
    --tune_ans "$CQ_DATA_DIR/tune_answer.txt" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_question.txt" \
    --test_ans "$CQ_DATA_DIR/test_answer.txt" \
    --test_pred_ques "$CQ_DATA_DIR/test_pred_question.txt" \
    --q_encoder_params "$CQ_DATA_DIR/q_encoder_params.epoch100" \
    --q_decoder_params "$CQ_DATA_DIR/q_decoder_params.epoch100" \
    --a_encoder_params "$CQ_DATA_DIR/a_encoder_params.epoch100" \
    --a_decoder_params "$CQ_DATA_DIR/a_decoder_params.epoch100" \
    --context_params "$CQ_DATA_DIR/context_params.epoch10" \
    --question_params "$CQ_DATA_DIR/question_params.epoch10" \
    --answer_params "$CQ_DATA_DIR/answer_params.epoch10" \
    --utility_params "$CQ_DATA_DIR/utility_params.epoch10" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --model RL_selfcritic \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 256 \
    --n_epochs 40
--------------------------------------------------------------------------------
/src/run_RL_main_HK.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs RL_main.py (model label RL_selfcritic) on the Home_and_Kitchen
# train/tune/test splits, including the per-split ASIN id files. Generator
# checkpoints passed are *.epoch100, context/question/answer/utility
# checkpoints are *.epoch10.

SITENAME=Home_and_Kitchen

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

# Checkpoints live next to the data.
PARAMS_DIR=$CQ_DATA_DIR

# Expansions are quoted against word splitting; the final argument has no
# trailing backslash so the command ends where it appears to end.
python "$SCRIPT_DIR/RL_main.py" \
    --train_context "$CQ_DATA_DIR/train_context.txt" \
    --train_ques "$CQ_DATA_DIR/train_ques.txt" \
    --train_ans "$CQ_DATA_DIR/train_ans.txt" \
    --train_ids "$CQ_DATA_DIR/train_asin.txt" \
    --tune_context "$CQ_DATA_DIR/tune_context.txt" \
    --tune_ques "$CQ_DATA_DIR/tune_ques.txt" \
    --tune_ans "$CQ_DATA_DIR/tune_ans.txt" \
    --tune_ids "$CQ_DATA_DIR/tune_asin.txt" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_ques.txt" \
    --test_ans "$CQ_DATA_DIR/test_ans.txt" \
    --test_ids "$CQ_DATA_DIR/test_asin.txt" \
    --test_pred_ques "$CQ_DATA_DIR/test_pred_ques.txt" \
    --q_encoder_params "$PARAMS_DIR/q_encoder_params.epoch100" \
    --q_decoder_params "$PARAMS_DIR/q_decoder_params.epoch100" \
    --a_encoder_params "$PARAMS_DIR/a_encoder_params.epoch100" \
    --a_decoder_params "$PARAMS_DIR/a_decoder_params.epoch100" \
    --context_params "$PARAMS_DIR/context_params.epoch10" \
    --question_params "$PARAMS_DIR/question_params.epoch10" \
    --answer_params "$PARAMS_DIR/answer_params.epoch10" \
    --utility_params "$PARAMS_DIR/utility_params.epoch10" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --model RL_selfcritic \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 128 \
    --n_epochs 40
--------------------------------------------------------------------------------
/src/run_decode.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs decode.py on the $SITENAME test split with the *.epoch100 question
# encoder/decoder checkpoints (model label seq2seq.epoch100), with
# --beam True.

SITENAME=askubuntu_unix_superuser

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

# Expansions are quoted against word splitting; the final argument has no
# trailing backslash so the command ends where it appears to end.
python "$SCRIPT_DIR/decode.py" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_question.txt" \
    --test_ans "$CQ_DATA_DIR/test_answer.txt" \
    --test_ids "$CQ_DATA_DIR/test_ids" \
    --test_pred_ques "$CQ_DATA_DIR/test_pred_question.txt" \
    --q_encoder_params "$CQ_DATA_DIR/q_encoder_params.epoch100" \
    --q_decoder_params "$CQ_DATA_DIR/q_decoder_params.epoch100" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --model seq2seq.epoch100 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 256 \
    --beam True
--------------------------------------------------------------------------------
/src/run_decode_HK.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs decode.py on the Home_and_Kitchen test split with the *.epoch100
# question encoder/decoder checkpoints, with --beam True.

SITENAME=Home_and_Kitchen

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

# Checkpoints are read from the data directory.
PARAMS_DIR=$CQ_DATA_DIR

# NOTE(review): --model says seq2seq.epoch12 while the checkpoints passed are
# *.epoch100 (run_decode.sh labels the same checkpoints seq2seq.epoch100).
# Left unchanged; confirm the label is intended.
# Expansions are quoted so the paths survive word splitting and globbing.
python "$SCRIPT_DIR/decode.py" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_ques.txt" \
    --test_ans "$CQ_DATA_DIR/test_ans.txt" \
    --test_ids "$CQ_DATA_DIR/test_asin.txt" \
    --test_pred_ques "$CQ_DATA_DIR/blind_test_pred_ques.txt" \
    --q_encoder_params "$PARAMS_DIR/q_encoder_params.epoch100" \
    --q_decoder_params "$PARAMS_DIR/q_decoder_params.epoch100" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --model seq2seq.epoch12 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 128 \
    --beam True
--------------------------------------------------------------------------------
/src/run_decode_electronics.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs src/decode.py with the epoch49 question encoder/decoder checkpoints.
# Note that the --test_* flags are pointed at the *valid* split files under
# baseline_data/. Paths are relative, so start this from the directory
# containing src/, baseline_data/ and embeddings/.
#
# Decoding uses --greedy True; replace that flag with "--beam True" to use
# the beam option that the other decode scripts pass.

python src/decode.py \
    --test_context baseline_data/valid_context.txt \
    --test_ques baseline_data/valid_question.txt \
    --test_ans baseline_data/valid_answer.txt \
    --test_ids baseline_data/valid_asin.txt \
    --test_pred_ques baseline_data/valid_predicted_question.txt \
    --q_encoder_params baseline_data/seq2seq-pretrain-ques-v3/q_encoder_params.epoch49 \
    --q_decoder_params baseline_data/seq2seq-pretrain-ques-v3/q_decoder_params.epoch49 \
    --word_embeddings embeddings/amazon_200d_embeddings.p \
    --vocab embeddings/amazon_200d_vocab.p \
    --model seq2seq.epoch49 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 10 \
    --n_epochs 40 \
    --greedy True
--------------------------------------------------------------------------------
/src/run_main.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs main.py with --pretrain_ques True on the $SITENAME train/tune/test
# splits for 100 epochs; checkpoint path prefixes point into the data
# directory.

SITENAME=askubuntu_unix_superuser

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

PARAMS_DIR=$CQ_DATA_DIR

# The question/answer flags use the --*_ques / --*_ans spelling that the
# other main.py launchers use. That spelling is accepted whether main.py
# declares the short or the long option names (argparse prefix matching),
# whereas --*_question / --*_answer only works with the long declaration.
# Expansions are quoted, and the final argument has no trailing backslash.
python "$SCRIPT_DIR/main.py" \
    --train_context "$CQ_DATA_DIR/train_context.txt" \
    --train_ques "$CQ_DATA_DIR/train_question.txt" \
    --train_ans "$CQ_DATA_DIR/train_answer.txt" \
    --train_ids "$CQ_DATA_DIR/train_ids" \
    --tune_context "$CQ_DATA_DIR/tune_context.txt" \
    --tune_ques "$CQ_DATA_DIR/tune_question.txt" \
    --tune_ans "$CQ_DATA_DIR/tune_answer.txt" \
    --tune_ids "$CQ_DATA_DIR/tune_ids" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_question.txt" \
    --test_ans "$CQ_DATA_DIR/test_answer.txt" \
    --test_ids "$CQ_DATA_DIR/test_ids" \
    --q_encoder_params "$PARAMS_DIR/q_encoder_params" \
    --q_decoder_params "$PARAMS_DIR/q_decoder_params" \
    --a_encoder_params "$PARAMS_DIR/a_encoder_params" \
    --a_decoder_params "$PARAMS_DIR/a_decoder_params" \
    --context_params "$PARAMS_DIR/context_params" \
    --question_params "$PARAMS_DIR/question_params" \
    --answer_params "$PARAMS_DIR/answer_params" \
    --utility_params "$PARAMS_DIR/utility_params" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --n_epochs 100 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --pretrain_ques True
--------------------------------------------------------------------------------
/src/run_main_HK.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs main.py with --pretrain_ques True on the Home_and_Kitchen
# train/tune/test splits (with per-split ASIN id files) for 100 epochs.
#
# To run one of the other pretraining modes instead, replace the last flag
# with "--pretrain_ans True" or "--pretrain_util True".

SITENAME=Home_and_Kitchen

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

PARAMS_DIR=$CQ_DATA_DIR

# Expansions are quoted against word splitting. The command ends on the last
# flag instead of continuing into commented-out lines.
python "$SCRIPT_DIR/main.py" \
    --train_context "$CQ_DATA_DIR/train_context.txt" \
    --train_ques "$CQ_DATA_DIR/train_ques.txt" \
    --train_ans "$CQ_DATA_DIR/train_ans.txt" \
    --train_ids "$CQ_DATA_DIR/train_asin.txt" \
    --tune_context "$CQ_DATA_DIR/tune_context.txt" \
    --tune_ques "$CQ_DATA_DIR/tune_ques.txt" \
    --tune_ans "$CQ_DATA_DIR/tune_ans.txt" \
    --tune_ids "$CQ_DATA_DIR/tune_asin.txt" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_ques.txt" \
    --test_ans "$CQ_DATA_DIR/test_ans.txt" \
    --test_ids "$CQ_DATA_DIR/test_asin.txt" \
    --q_encoder_params "$PARAMS_DIR/q_encoder_params" \
    --q_decoder_params "$PARAMS_DIR/q_decoder_params" \
    --a_encoder_params "$PARAMS_DIR/a_encoder_params" \
    --a_decoder_params "$PARAMS_DIR/a_decoder_params" \
    --context_params "$PARAMS_DIR/context_params" \
    --question_params "$PARAMS_DIR/question_params" \
    --answer_params "$PARAMS_DIR/answer_params" \
    --utility_params "$PARAMS_DIR/utility_params" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --n_epochs 100 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --pretrain_ques True
--------------------------------------------------------------------------------
/src/run_main_electronics.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs src/main.py with --pretrain_ques True on the data under
# baseline_data/ for 100 epochs. Paths are relative, so start this from the
# directory containing src/, baseline_data/ and embeddings/.
#
# To run one of the other pretraining modes instead, replace the last flag
# with "--pretrain_ans True" or "--pretrain_util True".

# PT_OUTPUT_DIR is not set by this script and must come from the
# environment. Fail fast when it is missing: otherwise every checkpoint path
# would silently collapse to "/q_encoder_params" and friends.
: "${PT_OUTPUT_DIR:?PT_OUTPUT_DIR must be set to the checkpoint directory}"

python src/main.py \
    --train_context baseline_data/train_context.txt \
    --train_ques baseline_data/train_question.txt \
    --train_ans baseline_data/train_answer.txt \
    --train_ids baseline_data/train_asin.txt \
    --tune_context baseline_data/valid_context.txt \
    --tune_ques baseline_data/valid_question.txt \
    --tune_ans baseline_data/valid_answer.txt \
    --tune_ids baseline_data/valid_asin.txt \
    --test_context baseline_data/test_context.txt \
    --test_ques baseline_data/test_question.txt \
    --test_ans baseline_data/test_answer.txt \
    --test_ids baseline_data/test_asin.txt \
    --q_encoder_params "$PT_OUTPUT_DIR/q_encoder_params" \
    --q_decoder_params "$PT_OUTPUT_DIR/q_decoder_params" \
    --a_encoder_params "$PT_OUTPUT_DIR/a_encoder_params" \
    --a_decoder_params "$PT_OUTPUT_DIR/a_decoder_params" \
    --context_params "$PT_OUTPUT_DIR/context_params" \
    --question_params "$PT_OUTPUT_DIR/question_params" \
    --answer_params "$PT_OUTPUT_DIR/answer_params" \
    --utility_params "$PT_OUTPUT_DIR/utility_params" \
    --word_embeddings embeddings/amazon_200d_embeddings.p \
    --vocab embeddings/amazon_200d_vocab.p \
    --n_epochs 100 \
    --batch_size 128 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --pretrain_ques True
--------------------------------------------------------------------------------
/src/run_pretrain_ans.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs main.py with --pretrain_ans True on the $SITENAME train/tune/test
# splits for 100 epochs; checkpoint path prefixes point into the data
# directory.

SITENAME=askubuntu_unix_superuser

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

PARAMS_DIR=$CQ_DATA_DIR

# The question/answer flags use the --*_ques / --*_ans spelling that the
# other main.py launchers use. That spelling is accepted whether main.py
# declares the short or the long option names (argparse prefix matching),
# whereas --*_question / --*_answer only works with the long declaration.
# Expansions are quoted, and the final argument has no trailing backslash.
python "$SCRIPT_DIR/main.py" \
    --train_context "$CQ_DATA_DIR/train_context.txt" \
    --train_ques "$CQ_DATA_DIR/train_question.txt" \
    --train_ans "$CQ_DATA_DIR/train_answer.txt" \
    --train_ids "$CQ_DATA_DIR/train_ids" \
    --tune_context "$CQ_DATA_DIR/tune_context.txt" \
    --tune_ques "$CQ_DATA_DIR/tune_question.txt" \
    --tune_ans "$CQ_DATA_DIR/tune_answer.txt" \
    --tune_ids "$CQ_DATA_DIR/tune_ids" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_question.txt" \
    --test_ans "$CQ_DATA_DIR/test_answer.txt" \
    --test_ids "$CQ_DATA_DIR/test_ids" \
    --q_encoder_params "$PARAMS_DIR/q_encoder_params" \
    --q_decoder_params "$PARAMS_DIR/q_decoder_params" \
    --a_encoder_params "$PARAMS_DIR/a_encoder_params" \
    --a_decoder_params "$PARAMS_DIR/a_decoder_params" \
    --context_params "$PARAMS_DIR/context_params" \
    --question_params "$PARAMS_DIR/question_params" \
    --answer_params "$PARAMS_DIR/answer_params" \
    --utility_params "$PARAMS_DIR/utility_params" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --n_epochs 100 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --pretrain_ans True
--------------------------------------------------------------------------------
/src/run_pretrain_ans_HK.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs main.py with --pretrain_ans True on the Home_and_Kitchen
# train/tune/test splits (with per-split ASIN id files) for 100 epochs.

SITENAME=Home_and_Kitchen

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

PARAMS_DIR=$CQ_DATA_DIR

# Expansions are quoted against word splitting; the final argument has no
# trailing backslash so the command ends where it appears to end.
python "$SCRIPT_DIR/main.py" \
    --train_context "$CQ_DATA_DIR/train_context.txt" \
    --train_ques "$CQ_DATA_DIR/train_ques.txt" \
    --train_ans "$CQ_DATA_DIR/train_ans.txt" \
    --train_ids "$CQ_DATA_DIR/train_asin.txt" \
    --tune_context "$CQ_DATA_DIR/tune_context.txt" \
    --tune_ques "$CQ_DATA_DIR/tune_ques.txt" \
    --tune_ans "$CQ_DATA_DIR/tune_ans.txt" \
    --tune_ids "$CQ_DATA_DIR/tune_asin.txt" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_ques.txt" \
    --test_ans "$CQ_DATA_DIR/test_ans.txt" \
    --test_ids "$CQ_DATA_DIR/test_asin.txt" \
    --q_encoder_params "$PARAMS_DIR/q_encoder_params" \
    --q_decoder_params "$PARAMS_DIR/q_decoder_params" \
    --a_encoder_params "$PARAMS_DIR/a_encoder_params" \
    --a_decoder_params "$PARAMS_DIR/a_decoder_params" \
    --context_params "$PARAMS_DIR/context_params" \
    --question_params "$PARAMS_DIR/question_params" \
    --answer_params "$PARAMS_DIR/answer_params" \
    --utility_params "$PARAMS_DIR/utility_params" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --n_epochs 100 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --pretrain_ans True
--------------------------------------------------------------------------------
/src/run_pretrain_util.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs main.py with --pretrain_util True on the $SITENAME train/tune/test
# splits for 10 epochs; checkpoint path prefixes point into the data
# directory.

SITENAME=askubuntu_unix_superuser

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

PARAMS_DIR=$CQ_DATA_DIR

# The question/answer flags use the --*_ques / --*_ans spelling that the
# other main.py launchers use. That spelling is accepted whether main.py
# declares the short or the long option names (argparse prefix matching),
# whereas --*_question / --*_answer only works with the long declaration.
# Expansions are quoted, and the final argument has no trailing backslash.
python "$SCRIPT_DIR/main.py" \
    --train_context "$CQ_DATA_DIR/train_context.txt" \
    --train_ques "$CQ_DATA_DIR/train_question.txt" \
    --train_ans "$CQ_DATA_DIR/train_answer.txt" \
    --train_ids "$CQ_DATA_DIR/train_ids" \
    --tune_context "$CQ_DATA_DIR/tune_context.txt" \
    --tune_ques "$CQ_DATA_DIR/tune_question.txt" \
    --tune_ans "$CQ_DATA_DIR/tune_answer.txt" \
    --tune_ids "$CQ_DATA_DIR/tune_ids" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_question.txt" \
    --test_ans "$CQ_DATA_DIR/test_answer.txt" \
    --test_ids "$CQ_DATA_DIR/test_ids" \
    --q_encoder_params "$PARAMS_DIR/q_encoder_params" \
    --q_decoder_params "$PARAMS_DIR/q_decoder_params" \
    --a_encoder_params "$PARAMS_DIR/a_encoder_params" \
    --a_decoder_params "$PARAMS_DIR/a_decoder_params" \
    --context_params "$PARAMS_DIR/context_params" \
    --question_params "$PARAMS_DIR/question_params" \
    --answer_params "$PARAMS_DIR/answer_params" \
    --utility_params "$PARAMS_DIR/utility_params" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --n_epochs 10 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --pretrain_util True
--------------------------------------------------------------------------------
/src/run_pretrain_util_HK.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs main.py with --pretrain_util True on the Home_and_Kitchen
# train/tune/test splits (with per-split ASIN id files) for 10 epochs.

SITENAME=Home_and_Kitchen

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

PARAMS_DIR=$CQ_DATA_DIR

# Expansions are quoted against word splitting; the final argument has no
# trailing backslash so the command ends where it appears to end.
python "$SCRIPT_DIR/main.py" \
    --train_context "$CQ_DATA_DIR/train_context.txt" \
    --train_ques "$CQ_DATA_DIR/train_ques.txt" \
    --train_ans "$CQ_DATA_DIR/train_ans.txt" \
    --train_ids "$CQ_DATA_DIR/train_asin.txt" \
    --tune_context "$CQ_DATA_DIR/tune_context.txt" \
    --tune_ques "$CQ_DATA_DIR/tune_ques.txt" \
    --tune_ans "$CQ_DATA_DIR/tune_ans.txt" \
    --tune_ids "$CQ_DATA_DIR/tune_asin.txt" \
    --test_context "$CQ_DATA_DIR/test_context.txt" \
    --test_ques "$CQ_DATA_DIR/test_ques.txt" \
    --test_ans "$CQ_DATA_DIR/test_ans.txt" \
    --test_ids "$CQ_DATA_DIR/test_asin.txt" \
    --q_encoder_params "$PARAMS_DIR/q_encoder_params" \
    --q_decoder_params "$PARAMS_DIR/q_decoder_params" \
    --a_encoder_params "$PARAMS_DIR/a_encoder_params" \
    --a_decoder_params "$PARAMS_DIR/a_decoder_params" \
    --context_params "$PARAMS_DIR/context_params" \
    --question_params "$PARAMS_DIR/question_params" \
    --answer_params "$PARAMS_DIR/answer_params" \
    --utility_params "$PARAMS_DIR/utility_params" \
    --word_embeddings "$EMB_DIR/word_embeddings.p" \
    --vocab "$EMB_DIR/vocab.p" \
    --n_epochs 10 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --pretrain_util True
--------------------------------------------------------------------------------
/src/seq2seq/RL_inference.py:
--------------------------------------------------------------------------------
1 | from constants import *
2 | from masked_cross_entropy import *
3 | import numpy as np
4 | import random
5 | import torch
6 | from torch.autograd import Variable
7 |
def get_decoded_seqs(decoder_outputs, word2index, max_len, batch_size):
    """Greedily read token ids out of per-step decoder scores.

    For every batch element the highest-scoring token is taken at each time
    step until the EOS token is produced or ``max_len`` steps have been read.

    Args:
        decoder_outputs: indexable as ``decoder_outputs[t][b]``, yielding a
            tensor of scores over the vocabulary.
        word2index: mapping from token string to id; must contain
            ``EOS_token`` and ``PAD_token``.
        max_len: number of time steps to read, and the padded output length.
        batch_size: number of sequences to decode.

    Returns:
        A tuple ``(decoded_seqs, decoded_lens, decoded_seq_masks)``:
        ``decoded_seqs`` is an array of shape ``(batch_size, max_len)`` padded
        with the PAD id; ``decoded_lens`` is an array holding the unpadded
        length of each sequence (a produced EOS is counted);
        ``decoded_seq_masks`` is a list of ``max_len``-long 0/1 lists that are
        1 on every decoded position except the EOS position.
    """
    # Loop-invariant vocabulary lookups.
    eos_id = word2index[EOS_token]
    pad_id = word2index[PAD_token]

    decoded_seqs = []
    decoded_lens = []
    decoded_seq_masks = []
    for b in range(batch_size):
        decoded_seq = []
        decoded_seq_mask = [0] * max_len
        for t in range(max_len):
            _, topi = decoder_outputs[t][b].data.topk(1)
            ni = topi[0].item()
            decoded_seq.append(ni)
            if ni == eos_id:
                # EOS stays in the sequence but is left out of the mask.
                break
            decoded_seq_mask[t] = 1
        decoded_lens.append(len(decoded_seq))
        decoded_seq += [pad_id] * (max_len - len(decoded_seq))
        decoded_seqs.append(decoded_seq)
        decoded_seq_masks.append(decoded_seq_mask)
    decoded_lens = np.array(decoded_lens)
    decoded_seqs = np.array(decoded_seqs)
    return decoded_seqs, decoded_lens, decoded_seq_masks
32 |
def inference(input_batches, input_lens, target_batches, target_lens,
              encoder, decoder, word2index, max_len, batch_size):
    """Greedily decode a batch and return the log-probs of the decoded tokens.

    The encoder is run once over the inputs, then the decoder is unrolled for
    ``max_len`` steps, feeding its own highest-scoring token back in at every
    step (no teacher forcing).

    Args:
        input_batches: batch-major nested sequence of input token ids; it is
            converted to a LongTensor and transposed to time-major.
        input_lens: input lengths, forwarded to ``encoder``.
        target_batches: unused; kept so existing callers keep working.
        target_lens: unused; kept so existing callers keep working.
        encoder: called as ``encoder(inputs, input_lens, None)`` and expected
            to return ``(encoder_outputs, encoder_hidden)``.
        decoder: called as ``decoder(input, hidden, encoder_outputs)``; must
            expose ``n_layers`` and ``output_size``.
        word2index: mapping from token string to id.
        max_len: number of decoding steps.
        batch_size: number of sequences in the batch.

    Returns:
        ``(log_probs, decoded_seqs, decoded_lens)`` where ``log_probs`` is the
        result of ``calculate_log_probs`` on the batch-major decoder outputs
        and the decoded-token masks, and the other two values come from
        ``get_decoded_seqs``.
    """
    input_batches = Variable(torch.LongTensor(np.array(input_batches))).transpose(0, 1)

    decoder_input = Variable(torch.LongTensor([word2index[SOS_token]] * batch_size))
    decoder_outputs = Variable(torch.zeros(max_len, batch_size, decoder.output_size))

    if USE_CUDA:
        input_batches = input_batches.cuda()
        decoder_input = decoder_input.cuda()
        decoder_outputs = decoder_outputs.cuda()

    # Run post words through encoder
    encoder_outputs, encoder_hidden = encoder(input_batches, input_lens, None)

    # Sum the first n_layers hidden states with the remaining ones
    # (presumably the two directions of a bidirectional encoder -- confirm).
    decoder_hidden = encoder_hidden[:decoder.n_layers] + encoder_hidden[decoder.n_layers:]

    # Run through decoder one time step at a time
    for t in range(max_len):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
        decoder_outputs[t] = decoder_output

        # Without teacher forcing: feed back the top token of every sequence.
        # One batched topk replaces the per-row loop and builds a fresh tensor
        # instead of writing in place into the tensor just given to the decoder.
        decoder_input = decoder_output.topk(1)[1].squeeze(1).detach()

    decoded_seqs, decoded_lens, decoded_seq_masks = get_decoded_seqs(
        decoder_outputs, word2index, max_len, batch_size)

    log_probs = calculate_log_probs(
        decoder_outputs.transpose(0, 1).contiguous(),  # -> batch x seq
        decoded_seq_masks,
    )

    return log_probs, decoded_seqs, decoded_lens
79 |
--------------------------------------------------------------------------------
/src/seq2seq/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/clarification_question_generation_pytorch/23ae8aa0160eee70565751f4b6de13563a19d6ed/src/seq2seq/__init__.py
--------------------------------------------------------------------------------
/src/seq2seq/ans_train.py:
--------------------------------------------------------------------------------
1 | from constants import *
2 | from masked_cross_entropy import *
3 | import numpy as np
4 | import random
5 | import torch
6 | from torch.autograd import Variable
7 | from constants import *
8 |
9 |
def train(input_batches, input_lens, target_batches, target_lens,
          encoder, decoder, encoder_optimizer, decoder_optimizer,
          word2index, args, mode='train'):
    """Run one teacher-forced pass of the answer model over a single batch.

    When ``mode == 'train'`` the encoder and decoder are put in training mode
    and, after the loss is computed, it is back-propagated and both optimizers
    take a step.  Gradients are zeroed in every mode.

    Returns ``(loss, decoded_seqs, decoded_lens)``: ``loss`` is the tensor
    returned by ``masked_cross_entropy``; ``decoded_seqs`` holds, per example,
    the teacher-forced target token ids cut after the first EOS and padded
    with PAD up to ``args.max_ans_len``; ``decoded_lens`` holds the unpadded
    lengths.
    """
    is_training = (mode == 'train')
    if is_training:
        encoder.train()
        decoder.train()

    # Clear gradients accumulated by the previous batch
    for optimizer in (encoder_optimizer, decoder_optimizer):
        optimizer.zero_grad()

    # Index tensors, moved to the GPU when enabled and laid out as seq x batch
    input_batches = torch.LongTensor(np.array(input_batches))
    target_batches = torch.LongTensor(np.array(target_batches))
    if USE_CUDA:
        input_batches = input_batches.cuda()
        target_batches = target_batches.cuda()
    input_batches = Variable(input_batches).transpose(0, 1)
    target_batches = Variable(target_batches).transpose(0, 1)

    # Run post words through encoder
    encoder_outputs, encoder_hidden = encoder(input_batches, input_lens, None)

    # Initial decoder state: first n_layers encoder states plus the remaining ones
    n_layers = decoder.n_layers
    decoder_hidden = encoder_hidden[:n_layers] + encoder_hidden[n_layers:]

    decoder_input = torch.LongTensor([word2index[SOS_token]] * args.batch_size)
    all_decoder_outputs = torch.zeros(args.max_ans_len, args.batch_size, decoder.output_size)
    decoder_outputs = torch.zeros(args.max_ans_len, args.batch_size)
    if USE_CUDA:
        decoder_input = decoder_input.cuda()
        all_decoder_outputs = all_decoder_outputs.cuda()
        decoder_outputs = decoder_outputs.cuda()

    # Run through decoder one time step at a time
    for step in range(args.max_ans_len):
        step_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
        all_decoder_outputs[step] = step_output

        # Teacher forcing: the gold token is both the next decoder input and
        # the token recorded as "decoded" for this step
        gold_tokens = target_batches[step]
        decoder_input = gold_tokens
        decoder_outputs[step] = gold_tokens

    eos_idx = word2index[EOS_token]
    pad_idx = word2index[PAD_token]
    decoded_seqs, decoded_lens = [], []
    for b in range(args.batch_size):
        tokens = []
        for step in range(args.max_ans_len):
            idx = int(decoder_outputs[step][b].data.item())
            tokens.append(idx)
            if idx == eos_idx:
                break
        decoded_lens.append(len(tokens))
        tokens += [pad_idx] * (args.max_ans_len - len(tokens))
        decoded_seqs.append(tokens)

    # Loss calculation and backpropagation
    loss = masked_cross_entropy(
        all_decoder_outputs.transpose(0, 1).contiguous(),  # -> batch x seq
        target_batches.transpose(0, 1).contiguous(),  # -> batch x seq
        target_lens, torch.nn.NLLLoss(), args.max_ans_len
    )
    if is_training:
        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

    return loss, decoded_seqs, decoded_lens
85 |
--------------------------------------------------------------------------------
/src/seq2seq/attn.py:
--------------------------------------------------------------------------------
1 | from constants import *
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import torch.nn.functional as F
6 |
class Attn(nn.Module):
    """Dot-product attention of a decoder state over the encoder outputs."""

    def __init__(self, hidden_size):
        super(Attn, self).__init__()
        # Kept for reference; the dot-product score has no parameters
        self.hidden_size = hidden_size

    def forward(self, hidden, encoder_outputs):
        """Return softmax-normalised attention weights.

        ``encoder_outputs`` is indexed as (source length S, batch B, features N).
        ``hidden`` is batch-matrix-multiplied against it after swapping its
        first two axes, so it is expected to be 1 x B x N.

        Returns a B x 1 x S tensor whose last axis sums to 1.
        """
        # (B x 1 x N) @ (B x N x S) -> B x 1 x S -> B x S
        attn_energies = torch.bmm(
            hidden.transpose(0, 1),
            encoder_outputs.transpose(0, 1).transpose(1, 2),
        ).squeeze(1)
        # Normalize energies to weights in range 0 to 1, resize to B x 1 x S
        return F.softmax(attn_energies, dim=1).unsqueeze(1)
23 |
--------------------------------------------------------------------------------
/src/seq2seq/attnDecoderRNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import sys
4 | sys.path.append('src/seq2seq')
5 | from attn import *
6 |
7 |
class AttnDecoderRNN(nn.Module):
    """GRU decoder that attends over encoder outputs, run one time step per call."""

    def __init__(self, hidden_size, output_size, word_embeddings, n_layers=1, dropout=0.1):
        """Build the decoder.

        ``word_embeddings`` is a 2-D numpy array (vocabulary x embedding dim);
        it is copied into the embedding layer, which is then frozen.
        """
        super(AttnDecoderRNN, self).__init__()

        # Keep for reference
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Define layers
        embedding_dim = len(word_embeddings[0])
        self.embedding = nn.Embedding(len(word_embeddings), embedding_dim)
        self.embedding.weight.data.copy_(torch.from_numpy(word_embeddings))
        self.embedding.weight.requires_grad = False
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(embedding_dim, hidden_size, n_layers, dropout=dropout)
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        self.attn = Attn(hidden_size)

    def forward(self, input_seq, last_hidden, p_encoder_outputs):
        """Decode a single step.

        ``input_seq`` holds one token index per batch element, ``last_hidden``
        is the previous GRU state and ``p_encoder_outputs`` the encoder outputs
        attended over.  Returns ``(output, hidden)``: unnormalised scores over
        the output vocabulary (B x output_size) and the new GRU state.
        """
        # Note: we run this one step at a time

        # Get the embedding of the current input word (last output word)
        embedded = self.embedding(input_seq)
        embedded = self.embedding_dropout(embedded)
        embedded = embedded.view(1, embedded.shape[0], embedded.shape[1])  # S=1 x B x N

        # Get current hidden state from input word and last hidden state
        rnn_output, hidden = self.gru(embedded, last_hidden)

        # Calculate attention from current RNN state and all p_encoder outputs;
        # apply to p_encoder outputs to get weighted average
        p_attn_weights = self.attn(rnn_output, p_encoder_outputs)
        p_context = p_attn_weights.bmm(p_encoder_outputs.transpose(0, 1))  # B x S=1 x N

        # Attentional vector using the RNN hidden state and context vector
        # concatenated together (Luong eq. 5)
        rnn_output = rnn_output.squeeze(0)  # S=1 x B x N -> B x N
        p_context = p_context.squeeze(1)  # B x S=1 x N -> B x N
        concat_input = torch.cat((rnn_output, p_context), 1)
        # torch.tanh replaces the deprecated torch.nn.functional.tanh
        concat_output = torch.tanh(self.concat(concat_input))

        # Finally predict next token (Luong eq. 6, without softmax)
        output = self.out(concat_output)

        # Return the vocabulary scores and the new hidden state
        return output, hidden
58 |
--------------------------------------------------------------------------------
/src/seq2seq/baselineFF.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from constants import *
4 |
class BaselineFF(nn.Module):
    """Two-layer feed-forward network ending in a sigmoid over a single output unit."""

    def __init__(self, input_dim):
        super(BaselineFF, self).__init__()

        # input_dim -> HIDDEN_SIZE -> 1
        self.layer1 = nn.Linear(input_dim, HIDDEN_SIZE)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(HIDDEN_SIZE, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Apply linear -> ReLU -> linear -> sigmoid to ``x``."""
        hidden = self.relu(self.layer1(x))
        return self.sigmoid(self.layer2(hidden))
20 |
21 |
--------------------------------------------------------------------------------
/src/seq2seq/encoderRNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
class EncoderRNN(nn.Module):
    """Bidirectional GRU encoder over frozen, pre-trained word embeddings."""

    def __init__(self, hidden_size, word_embeddings, n_layers=1, dropout=0.1):
        super(EncoderRNN, self).__init__()

        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout

        vocab_size = len(word_embeddings)
        embedding_dim = len(word_embeddings[0])

        # Embedding table initialised from the given numpy matrix and not trained
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.embedding.weight.data.copy_(torch.from_numpy(word_embeddings))
        self.embedding.weight.requires_grad = False
        self.gru = nn.GRU(embedding_dim, hidden_size, n_layers,
                          dropout=self.dropout, bidirectional=True)

    def forward(self, input_seqs, input_lengths, hidden=None):
        """Encode a whole batch of sequences at once.

        ``input_lengths`` is accepted but not used: the padded batch is fed to
        the GRU directly, without packing.

        Returns ``(outputs, hidden)`` where ``outputs`` is the sum of the two
        direction halves of the GRU output (last axis of size ``hidden_size``)
        and ``hidden`` is the GRU's final state, unchanged.
        """
        embedded = self.embedding(input_seqs)
        outputs, hidden = self.gru(embedded, hidden)
        # Sum bidirectional outputs
        first_half = outputs[:, :, :self.hidden_size]
        second_half = outputs[:, :, self.hidden_size:]
        return first_half + second_half, hidden
26 |
--------------------------------------------------------------------------------
/src/seq2seq/evaluate.py:
--------------------------------------------------------------------------------
1 | import random
2 | from constants import *
3 | from prepare_data import *
4 | from masked_cross_entropy import *
5 | from helper import *
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 |
10 |
def evaluate(word2index, index2word, encoder, decoder, test_data,
             max_output_length, BATCH_SIZE, out_file):
    """Greedily decode ``test_data`` and return the mean per-batch loss.

    ``test_data`` is the 5-tuple ``(ids, input_seqs, input_lens, output_seqs,
    output_lens)``.  Batches are visited in input order (no shuffling) so that,
    when ``out_file`` is given, the i-th written line is the decoding of the
    i-th example.  ``index2word`` is only used when ``out_file`` is truthy.
    Examples that do not fill a complete batch are skipped by
    ``iterate_minibatches``.

    The loss is averaged over the number of batches actually evaluated; 0.0 is
    returned when there is not even one full batch.  No gradients are tracked.
    """
    ids_seqs, input_seqs, input_lens, output_seqs, output_lens = test_data
    total_loss = 0.
    n_batches = 0
    loss_fn = torch.nn.NLLLoss()

    # Evaluation never back-propagates, so skip building autograd graphs
    with torch.no_grad():
        for ids_seqs_batch, input_seqs_batch, input_lens_batch, output_seqs_batch, output_lens_batch in \
                iterate_minibatches(ids_seqs, input_seqs, input_lens, output_seqs, output_lens,
                                    BATCH_SIZE, shuffle=False):

            input_seqs_batch = torch.LongTensor(np.array(input_seqs_batch))
            output_seqs_batch = torch.LongTensor(np.array(output_seqs_batch))
            # Create starting vectors for decoder
            decoder_input = torch.LongTensor([word2index[SOS_token]] * BATCH_SIZE)
            all_decoder_outputs = torch.zeros(max_output_length, BATCH_SIZE, decoder.output_size)
            if USE_CUDA:
                input_seqs_batch = input_seqs_batch.cuda()
                output_seqs_batch = output_seqs_batch.cuda()
                decoder_input = decoder_input.cuda()
                all_decoder_outputs = all_decoder_outputs.cuda()
            input_seqs_batch = input_seqs_batch.transpose(0, 1)  # -> seq x batch
            output_seqs_batch = output_seqs_batch.transpose(0, 1)  # -> seq x batch

            # Run post words through encoder
            encoder_outputs, encoder_hidden = encoder(input_seqs_batch, input_lens_batch, None)
            decoder_hidden = encoder_hidden[:decoder.n_layers] + encoder_hidden[decoder.n_layers:]

            # Run through decoder one time step at a time
            for t in range(max_output_length):
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
                all_decoder_outputs[t] = decoder_output
                # Choose top word from output as the next input
                topv, topi = decoder_output.data.topk(1)
                decoder_input = topi.squeeze(1)

            if out_file:
                eos_idx = word2index[EOS_token]
                for b in range(BATCH_SIZE):
                    decoded_words = []
                    for t in range(max_output_length):
                        topv, topi = all_decoder_outputs[t][b].data.topk(1)
                        ni = topi[0].item()
                        if ni == eos_idx:
                            decoded_words.append(EOS_token)
                            break
                        decoded_words.append(index2word[ni])
                    out_file.write(' '.join(decoded_words)+'\n')

            # Loss calculation
            loss = masked_cross_entropy(
                all_decoder_outputs.transpose(0, 1).contiguous(),  # -> batch x seq
                output_seqs_batch.transpose(0, 1).contiguous(),  # -> batch x seq
                output_lens_batch, loss_fn, max_output_length
            )
            total_loss += loss.item()
            n_batches += 1

    # Average over the batches that were really run, not len(data) / BATCH_SIZE
    if n_batches == 0:
        return 0.0
    return total_loss / n_batches
67 |
--------------------------------------------------------------------------------
/src/seq2seq/helper.py:
--------------------------------------------------------------------------------
1 | from constants import *
2 | import math
3 | import numpy as np
4 | import nltk
5 | import random
6 | import time
7 | import torch
8 |
9 |
def as_minutes(s):
    """Format a duration of ``s`` seconds as ``'<minutes>m <seconds>s'``."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)
14 |
15 |
def time_since(since, percent):
    """Return ``'<elapsed> (- <remaining>)'`` for a job started at ``since``.

    ``percent`` is the completed fraction of the job; the total duration is
    extrapolated as ``elapsed / percent``.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / (percent)
    remaining = estimated_total - elapsed
    return '%s (- %s)' % (as_minutes(elapsed), as_minutes(remaining))
22 |
23 |
def iterate_minibatches(id_seqs, input_seqs, input_lens, output_seqs, output_lens, batch_size, shuffle=True):
    """Yield aligned mini-batches of exactly ``batch_size`` examples.

    Each yielded item is a 5-tuple of numpy arrays
    ``(ids, inputs, input_lens, outputs, output_lens)`` selected with the same
    indices.  With ``shuffle=True`` the examples are visited in a random order
    drawn from ``np.random``; otherwise in their original order.  Trailing
    examples that do not fill a whole batch are not yielded.
    """
    n_examples = len(input_seqs)
    # Convert each column once up front instead of once per yielded batch
    columns = [np.array(col) for col in
               (id_seqs, input_seqs, input_lens, output_seqs, output_lens)]
    if shuffle:
        indices = np.arange(n_examples)
        np.random.shuffle(indices)
    for start_idx in range(0, n_examples - batch_size + 1, batch_size):
        if shuffle:
            ex = indices[start_idx:start_idx + batch_size]
        else:
            ex = slice(start_idx, start_idx + batch_size)
        yield tuple(col[ex] for col in columns)
35 |
36 |
def reverse_dict(word2index):
    """Return the inverse mapping (index -> word) of ``word2index``.

    If several words share an index, the one iterated last wins.
    """
    return {ix: w for w, ix in word2index.items()}
43 |
44 |
def calculate_bleu(true, true_lens, pred, pred_lens, index2word, max_len):
    """Return a tensor with one sentence-level BLEU score per prediction.

    Reference and predicted index sequences are truncated to their given
    lengths and mapped to words through ``index2word`` before scoring with
    NLTK's ``sentence_bleu``.  ``max_len`` is accepted but not used.  The
    result is moved to the GPU when ``USE_CUDA`` is set.
    """
    sent_bleu_scores = torch.zeros(len(pred))
    for i, pred_ids in enumerate(pred):
        reference = [index2word[idx] for idx in true[i][:true_lens[i]]]
        hypothesis = [index2word[idx] for idx in pred_ids[:pred_lens[i]]]
        sent_bleu_scores[i] = nltk.translate.bleu_score.sentence_bleu([reference], hypothesis)
    return sent_bleu_scores.cuda() if USE_CUDA else sent_bleu_scores
54 |
--------------------------------------------------------------------------------
/src/seq2seq/main.py:
--------------------------------------------------------------------------------
1 | from attn import *
2 | from attnDecoderRNN import *
3 | from constants import *
4 | from encoderRNN import *
5 | from evaluate import *
6 | from helper import *
7 | from train import *
8 | import torch
9 | import torch.optim as optim
10 | from prepare_data import *
11 |
12 |
def run_seq2seq(train_data, test_data, word2index, word_embeddings,
                encoder_params_file, decoder_params_file,
                encoder_params_contd_file, decoder_params_contd_file, max_target_length,
                n_epochs, batch_size, n_layers):
    """Train the attention seq2seq model and checkpoint it after every epoch.

    ``train_data`` / ``test_data`` are 5-tuples ``(ids, input_seqs, input_lens,
    output_seqs, output_lens)``.  After each epoch the mean training loss (over
    the batches actually run) and the dev loss from ``evaluate`` are printed,
    and the encoder/decoder state dicts are written to
    ``<encoder_params_file>.epoch<N>`` and ``<decoder_params_file>.epoch<N>``.

    ``encoder_params_contd_file`` / ``decoder_params_contd_file`` are currently
    unused; they name checkpoints from which training could be resumed.
    """
    # Initialize q models
    print('Initializing models')
    encoder = EncoderRNN(HIDDEN_SIZE, word_embeddings, n_layers, dropout=DROPOUT)
    decoder = AttnDecoderRNN(HIDDEN_SIZE, len(word2index), word_embeddings, n_layers)

    # Initialize optimizers (frozen parameters such as the embeddings are skipped)
    encoder_optimizer = optim.Adam([par for par in encoder.parameters() if par.requires_grad], lr=LEARNING_RATE)
    decoder_optimizer = optim.Adam([par for par in decoder.parameters() if par.requires_grad], lr=LEARNING_RATE * DECODER_LEARNING_RATIO)

    # Move models to GPU
    if USE_CUDA:
        encoder.cuda()
        decoder.cuda()

    # Keep track of time elapsed
    start = time.time()
    # To resume training, set `epoch` to the last finished epoch and load
    # encoder_params_contd_file / decoder_params_contd_file + '.epoch%d' % epoch
    # into the two models here.
    epoch = 0.0

    ids_seqs, input_seqs, input_lens, output_seqs, output_lens = train_data

    teacher_forcing_ratio = 1.0
    while epoch < n_epochs:
        epoch += 1
        epoch_loss_total = 0.0
        n_batches = 0
        for ids_seqs_batch, input_seqs_batch, input_lens_batch, \
            output_seqs_batch, output_lens_batch in \
                iterate_minibatches(ids_seqs, input_seqs, input_lens, output_seqs, output_lens, batch_size):

            # Run the train function
            loss = train(
                input_seqs_batch, input_lens_batch,
                output_seqs_batch, output_lens_batch,
                encoder, decoder,
                encoder_optimizer, decoder_optimizer,
                word2index[SOS_token], max_target_length,
                batch_size, teacher_forcing_ratio
            )

            # Keep track of loss
            epoch_loss_total += loss
            n_batches += 1

        # Average over the batches that really ran; the iterator drops a
        # trailing partial batch, so len(data) / batch_size would overcount
        print_loss_avg = epoch_loss_total / n_batches if n_batches else 0.0
        print('Epoch: %d' % epoch)
        print('Train Set')
        print_summary = '%s %d %.4f' % (time_since(start, epoch / n_epochs), epoch, print_loss_avg)
        print(print_summary)
        print('Dev Set')
        curr_test_loss = evaluate(word2index, None, encoder, decoder, test_data,
                                  max_target_length, batch_size, None)
        print('%.4f ' % curr_test_loss)

        print('Saving model params')
        torch.save(encoder.state_dict(), encoder_params_file+'.epoch%d' % epoch)
        torch.save(decoder.state_dict(), decoder_params_file+'.epoch%d' % epoch)
92 |
93 |
94 |
--------------------------------------------------------------------------------
/src/seq2seq/masked_cross_entropy.py:
--------------------------------------------------------------------------------
1 | from constants import *
2 | import torch
3 | from torch.nn import functional
4 | from torch.autograd import Variable
5 |
6 |
def calculate_log_probs(logits, output, length, loss_fn, mixer_delta):
    """Return one length-normalised ``loss_fn`` value per batch element.

    For element ``b`` the first ``length[b] - mixer_delta`` positions of the
    log-softmaxed ``logits`` (batch x seq x classes) are scored against
    ``output`` with ``loss_fn`` and the result is divided by that number of
    positions.  Elements for which that number is not positive keep the
    value 0.
    """
    batch_size = logits.shape[0]
    log_probs = functional.log_softmax(logits, dim=2)
    avg_log_probs = torch.zeros(batch_size)
    if USE_CUDA:
        avg_log_probs = avg_log_probs.cuda()
    for b in range(batch_size):
        curr_len = length[b]-mixer_delta
        if curr_len <= 0:
            continue
        scored = loss_fn(log_probs[b][:curr_len], output[b][:curr_len])
        avg_log_probs[b] = scored / curr_len
    return avg_log_probs
18 |
19 |
def masked_cross_entropy(logits, target, length, loss_fn, mixer_delta=None):
    """Average a length-limited, length-normalised loss over the batch.

    ``logits`` is (batch, max_len, num_classes) and ``target`` is
    (batch, max_len).  For each element ``b`` only the first ``curr_len``
    positions are scored with ``loss_fn`` on the log-softmaxed logits, where
    ``curr_len`` is ``length[b]`` capped at ``mixer_delta`` (no cap when
    ``mixer_delta`` is None).  Each per-sentence value is divided by
    ``curr_len`` and the mean over the batch is returned.
    """
    batch_size = logits.shape[0]
    # log_probs: (batch, max_len, num_classes)
    log_probs = functional.log_softmax(logits, dim=2)
    loss = 0.
    for b in range(batch_size):
        # min(x, None) raises on Python 3, so the default must bypass the cap
        curr_len = length[b] if mixer_delta is None else min(length[b], mixer_delta)
        sent_loss = loss_fn(log_probs[b][:curr_len], target[b][:curr_len]) / curr_len
        loss += sent_loss
    return loss / batch_size
31 |
--------------------------------------------------------------------------------
/src/seq2seq/prepare_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import pdb
4 | import sys
5 | sys.path.append('src/seq2seq')
6 | from read_data import *
7 | import numpy as np
8 | from constants import *
9 |
# Vocabulary key that prepare_sequence / prepare_pq_sequence look up for words
# missing from word2index.
# NOTE(review): the empty string looks like an angle-bracket token that was
# stripped when this file was extracted — confirm the intended key against the
# vocabulary before relying on it.
UNK_TOKEN=''
# UNK_TOKEN='unk'
12 |
13 | # Return a list of indexes, one for each word in the sentence, plus EOS
def prepare_sequence(seq, word2index, max_len):
    """Convert a space-separated sentence into a fixed-length list of indices.

    At most ``max_len - 1`` words are kept; words missing from ``word2index``
    map to the ``UNK_TOKEN`` index.  An EOS index is appended and the list is
    padded with the PAD index up to ``max_len``.

    Returns ``(sequence, length)`` where ``length`` counts the kept words plus
    EOS, excluding padding.
    """
    sequence = []
    for w in seq.split(' ')[:max_len-1]:
        key = w if w in word2index else UNK_TOKEN
        sequence.append(word2index[key])
    sequence.append(word2index[EOS_token])
    length = len(sequence)
    n_pad = int(max_len - length)
    sequence.extend([word2index[PAD_token]] * n_pad)
    return sequence, length
20 |
21 |
def prepare_pq_sequence(post_seq, ques_seq, word2index, max_post_len, max_ques_len):
    """Index a post and a question and concatenate them into one sequence.

    The post is truncated to ``max_post_len - 1`` words, terminated with the
    EOP index and padded to ``max_post_len``; the question is truncated to
    ``max_ques_len - 1`` words, terminated with the EOS index and padded to
    ``max_ques_len``.  Unknown words map to the ``UNK_TOKEN`` index.

    Returns ``(sequence, (max_post_len, max_ques_len))``.
    """
    def encode(text, max_len, end_token):
        ids = []
        for w in text.split(' ')[:(max_len-1)]:
            ids.append(word2index[w] if w in word2index else word2index[UNK_TOKEN])
        ids.append(word2index[end_token])
        ids += [word2index[PAD_token]]*int(max_len - len(ids))
        return ids

    p_sequence = encode(post_seq, max_post_len, EOP_token)
    q_sequence = encode(ques_seq, max_ques_len, EOS_token)
    return p_sequence + q_sequence, (max_post_len, max_ques_len)
32 |
33 |
def preprocess_data(triples, word2index, max_post_len, max_ques_len, max_ans_len):
    """Index every ``(id, post, question, answer)`` record in ``triples``.

    Returns nine parallel lists: ids, post sequences and lengths, question
    sequences and lengths, combined post+question sequences and lengths, and
    answer sequences and lengths.  Records whose answer is ``None`` contribute
    nothing to the two answer lists.
    """
    id_seqs = []
    post_seqs, post_lens = [], []
    ques_seqs, ques_lens = [], []
    post_ques_seqs, post_ques_lens = [], []
    ans_seqs, ans_lens = [], []

    for curr_id, post, ques, ans in triples:
        id_seqs.append(curr_id)

        post_seq, post_len = prepare_sequence(post, word2index, max_post_len)
        post_seqs.append(post_seq)
        post_lens.append(post_len)

        ques_seq, ques_len = prepare_sequence(ques, word2index, max_ques_len)
        ques_seqs.append(ques_seq)
        ques_lens.append(ques_len)

        post_ques_seq, post_ques_len = prepare_pq_sequence(post, ques, word2index, max_post_len, max_ques_len)
        post_ques_seqs.append(post_ques_seq)
        post_ques_lens.append(post_ques_len)

        if ans is None:
            continue
        ans_seq, ans_len = prepare_sequence(ans, word2index, max_ans_len)
        ans_seqs.append(ans_seq)
        ans_lens.append(ans_len)

    return (id_seqs, post_seqs, post_lens, ques_seqs, ques_lens,
            post_ques_seqs, post_ques_lens, ans_seqs, ans_lens)
64 |
--------------------------------------------------------------------------------
/src/seq2seq/read_data.py:
--------------------------------------------------------------------------------
1 | import re
2 | import csv
3 | from constants import *
4 | import unicodedata
5 | from collections import defaultdict
6 | import math
7 |
8 |
def unicode_to_ascii(s):
    """Return ``s`` NFD-normalised with all combining marks (category 'Mn') removed."""
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)
14 |
15 |
def normalize_string(s, max_len):
    """Lowercase ``s``, collapse whitespace and keep at most ``max_len`` words.

    The words are re-joined with single spaces.  Accent stripping via
    ``unicode_to_ascii`` is currently disabled.
    """
    tokens = s.lower().strip().split()
    return ' '.join(tokens[:max_len])
23 |
24 |
def get_context(line, max_post_len, max_ques_len):
    """Normalise one context line, reserving room for a specificity marker.

    ``max_ques_len`` is accepted but not used.

    NOTE(review): the marker literals below are empty or a bare space; they
    look like angle-bracket tags that were lost when this file was extracted —
    confirm the intended strings.  As written, ``'' in line`` is true for every
    string, so both flags are always set, ``line.replace(' ', '')`` removes
    every space from the line, and the final ``+=`` appends a second copy of
    the context to itself.
    """
    is_specific, is_generic = False, False
    # NOTE(review): always true as written; presumably a test for a "specific" tag
    if '' in line:
        line = line.replace(' ', '')
        is_specific = True
    # NOTE(review): always true as written; presumably a test for a "generic" tag
    if '' in line:
        line = line.replace(' ', '')
        is_generic = True
    if is_specific or is_generic:
        context = normalize_string(line, max_post_len-2) # one token space for specificity and another for EOS
    else:
        context = normalize_string(line, max_post_len-1)
    if is_specific:
        context = ' ' + context
    if is_generic:
        # NOTE(review): `+=` duplicates the context; presumably meant to prepend a tag — confirm
        context += ' ' + context
    return context
42 |
43 |
def _read_lines(fname, count=None):
    """Return the lines of ``fname``, stopping after ``count`` lines when ``count`` is truthy."""
    lines = []
    with open(fname, 'r') as f:
        for line in f:
            lines.append(line)
            if count and len(lines) == count:
                break
    return lines


def read_data(context_fname, question_fname, answer_fname, ids_fname,
              max_post_len, max_ques_len, max_ans_len, count=None, mode='train'):
    """Read parallel context/question/answer files into ``[id, context, question, answer]`` rows.

    ``answer_fname`` and ``ids_fname`` may be ``None``; the corresponding row
    entries are then ``None``.  When ``count`` is truthy, at most ``count``
    lines of the context, question and answer files are used.  The question
    (and answer) file must supply exactly as many lines as the context file,
    otherwise an ``AssertionError`` is raised.

    When ids are given, at most 20 rows are kept per id (1 when
    ``mode == 'test'``), in file order.  Every file is closed before returning.
    """
    ids = None
    if ids_fname is not None:
        ids = [line.strip('\n') for line in _read_lines(ids_fname)]

    print("Reading lines...")
    data = []
    for i, line in enumerate(_read_lines(context_fname, count)):
        context = get_context(line, max_post_len, max_ques_len)
        data.append([ids[i] if ids is not None else None, context, None, None])

    questions = _read_lines(question_fname, count)
    assert len(questions) == len(data)
    for row, line in zip(data, questions):
        row[2] = normalize_string(line, max_ques_len-1)

    if answer_fname is not None:
        answers = _read_lines(answer_fname, count)
        assert len(answers) == len(data)
        for row, line in zip(data, answers):
            row[3] = normalize_string(line, max_ans_len-1)  # one token space for EOS

    if ids is None:
        return data

    # Cap how many rows may share the same id
    max_per_id_count = 1 if mode == 'test' else 20
    data_ct_per_id = defaultdict(int)
    updated_data = []
    i = 0
    for curr_id in ids:
        data_ct_per_id[curr_id] += 1
        if data_ct_per_id[curr_id] <= max_per_id_count:
            updated_data.append(data[i])
        i += 1
        if count and i == count:
            break
    assert (i == len(data))
    return updated_data
103 |
--------------------------------------------------------------------------------
/src/seq2seq/train.py:
--------------------------------------------------------------------------------
1 | from constants import *
2 | from masked_cross_entropy import *
3 | import numpy as np
4 | import random
5 | import torch
6 | from torch.autograd import Variable
7 |
8 |
def train(input_batches, input_lens, target_batches, target_lens,
          encoder, decoder, encoder_optimizer, decoder_optimizer,
          SOS_idx, max_target_length, batch_size, teacher_forcing_ratio):
    """Run one optimisation step of the seq2seq model on a single batch.

    With probability ``teacher_forcing_ratio`` the whole batch is decoded with
    teacher forcing (gold token fed as the next input); otherwise the
    decoder's own top-scoring token is fed back at every step.  The masked
    cross-entropy against ``target_batches`` is back-propagated and both
    optimizers are stepped.

    Returns the loss as a Python float.
    """
    # Zero gradients of both optimizers
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_batches = torch.LongTensor(np.array(input_batches))
    target_batches = torch.LongTensor(np.array(target_batches))
    decoder_input = torch.LongTensor([SOS_idx] * batch_size)
    all_decoder_outputs = torch.zeros(max_target_length, batch_size, decoder.output_size)
    if USE_CUDA:
        input_batches = input_batches.cuda()
        target_batches = target_batches.cuda()
        decoder_input = decoder_input.cuda()
        all_decoder_outputs = all_decoder_outputs.cuda()
    input_batches = input_batches.transpose(0, 1)  # -> seq x batch
    target_batches = target_batches.transpose(0, 1)  # -> seq x batch

    # Run post words through encoder
    encoder_outputs, encoder_hidden = encoder(input_batches, input_lens, None)

    # Prepare the initial decoder state from the encoder's final states
    decoder_hidden = encoder_hidden[:decoder.n_layers] + encoder_hidden[decoder.n_layers:]

    # One draw decides the decoding strategy for the whole batch
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    # Run through decoder one time step at a time
    for t in range(max_target_length):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
        all_decoder_outputs[t] = decoder_output

        if use_teacher_forcing:
            # Teacher forcing: next input is current target
            decoder_input = target_batches[t]
        else:
            # Greedy decoding: feed back the top-scoring token of every
            # example in one batched call (B x 1 -> B)
            decoder_input = decoder_output.topk(1)[1].squeeze(1).detach()

    loss_fn = torch.nn.NLLLoss()

    # Loss calculation and backpropagation
    loss = masked_cross_entropy(
        all_decoder_outputs.transpose(0, 1).contiguous(),  # -> batch x seq
        target_batches.transpose(0, 1).contiguous(),  # -> batch x seq
        target_lens, loss_fn, max_target_length
    )
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    return loss.item()
65 |
--------------------------------------------------------------------------------
/src/utility/FeedForward.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from constants import *
4 |
class FeedForward(nn.Module):
    """Two-layer feed-forward network producing one unnormalised score per input."""

    def __init__(self, input_dim):
        super(FeedForward, self).__init__()

        # input_dim -> HIDDEN_SIZE -> 1
        self.layer1 = nn.Linear(input_dim, HIDDEN_SIZE)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(HIDDEN_SIZE, 1)

    def forward(self, x):
        """Apply linear -> ReLU -> linear to ``x``."""
        hidden = self.relu(self.layer1(x))
        return self.layer2(hidden)
18 |
19 |
--------------------------------------------------------------------------------
/src/utility/RL_evaluate.py:
--------------------------------------------------------------------------------
1 | from constants import *
2 | import sys
3 | sys.path.append('src/utility')
4 | from helper_utility import *
5 | import numpy as np
6 | import torch
7 | from torch.autograd import Variable
8 |
9 |
def _masked_sum_encoding(model, seqs, mask):
    """Run ``model`` on ``seqs`` (batch x seq) and sum its second output over time, weighted by ``mask``."""
    _, outputs = model(torch.transpose(seqs, 0, 1))
    mask = torch.transpose(mask, 0, 1).unsqueeze(2)
    mask = mask.expand(mask.shape[0], mask.shape[1], 2*HIDDEN_SIZE)
    return torch.sum(outputs * mask, dim=0)


def evaluate_utility(context_model, question_model, answer_model, utility_model, c, cl, q, ql, a, al, args):
    """Score (context, question, answer) batches with the utility model.

    All four models are switched to eval mode (and left in it) and no
    gradients are tracked.  ``cl`` / ``ql`` / ``al`` are turned into masks by
    ``get_masks`` using the maximum lengths in ``args``; each sequence batch
    is encoded by its model and reduced to a masked sum over time, and the
    three encodings are concatenated and fed to ``utility_model``.

    Returns the sigmoid of the utility model's output, one value per example.
    """
    with torch.no_grad():
        for model in (context_model, question_model, answer_model, utility_model):
            model.eval()
        cm = get_masks(cl, args.max_post_len)
        qm = get_masks(ql, args.max_ques_len)
        am = get_masks(al, args.max_ans_len)

        c = torch.tensor(c)
        cm = torch.FloatTensor(cm)
        q = torch.tensor(q)
        qm = torch.FloatTensor(qm)
        a = torch.tensor(a)
        am = torch.FloatTensor(am)
        if USE_CUDA:
            c, cm = c.cuda(), cm.cuda()
            q, qm = q.cuda(), qm.cuda()
            a, am = a.cuda(), am.cuda()

        c_out = _masked_sum_encoding(context_model, c, cm)
        q_out = _masked_sum_encoding(question_model, q, qm)
        a_out = _masked_sum_encoding(answer_model, a, am)

        predictions = utility_model(torch.cat((c_out, q_out, a_out), 1)).squeeze(1)
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid
        predictions = torch.sigmoid(predictions)

    return predictions
54 |
--------------------------------------------------------------------------------
/src/utility/RL_train.py:
--------------------------------------------------------------------------------
1 | from helper import *
2 | import numpy as np
3 | import torch
4 | from constants import *
5 |
6 |
def train_utility(context_model, question_model, answer_model, utility_model, optimizer, criterion,
                  c, cm, q, qm, a, am, labs):
    """Run one optimisation step of the utility model on a batch.

    ``c`` / ``q`` / ``a`` are index batches and ``cm`` / ``qm`` / ``am`` their
    masks.  Each batch is encoded by its model, reduced to a masked sum over
    time, and the three encodings are concatenated and scored by
    ``utility_model``.  The loss from ``criterion`` is back-propagated and
    ``optimizer`` is stepped.

    Returns ``(loss, acc)`` where ``acc`` comes from ``binary_accuracy``.
    """
    optimizer.zero_grad()

    def on_device(tensor):
        return tensor.cuda() if USE_CUDA else tensor

    def masked_sum(model, seqs, mask):
        # model returns (hidden, outputs); only the per-step outputs are used
        _, outputs = model(torch.transpose(seqs, 0, 1))
        mask = torch.transpose(mask, 0, 1).unsqueeze(2)
        mask = mask.expand(mask.shape[0], mask.shape[1], 2*HIDDEN_SIZE)
        return torch.sum(outputs * mask, dim=0)

    c, cm = on_device(torch.tensor(c)), on_device(torch.FloatTensor(cm))
    q, qm = on_device(torch.tensor(q)), on_device(torch.FloatTensor(qm))
    a, am = on_device(torch.tensor(a)), on_device(torch.FloatTensor(am))

    # Encode in the fixed order context, question, answer
    c_out = masked_sum(context_model, c, cm)
    q_out = masked_sum(question_model, q, qm)
    a_out = masked_sum(answer_model, a, am)

    predictions = utility_model(torch.cat((c_out, q_out, a_out), 1)).squeeze(1)
    labs = on_device(labs)
    loss = criterion(predictions, labs)
    acc = binary_accuracy(predictions, labs)
    loss.backward()
    optimizer.step()
    return loss, acc
46 |
--------------------------------------------------------------------------------
/src/utility/RNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from constants import *
4 |
5 |
class RNN(nn.Module):
    """Bidirectional LSTM encoder with dropout and a linear projection of its final state."""

    def __init__(self, vocab_size, embedding_dim, n_layers):
        super(RNN, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, HIDDEN_SIZE, num_layers=n_layers, bidirectional=True)
        self.fc = nn.Linear(HIDDEN_SIZE*2, HIDDEN_SIZE)
        self.dropout = nn.Dropout(DROPOUT)

    def forward(self, x):
        """Encode ``x`` of shape [sent len, batch size].

        Returns ``(projected, output)``: the linear projection of the
        concatenated last two hidden-state slices (after dropout), and the
        LSTM output of shape [sent len, batch size, hid dim * num directions].
        """
        # embedded = [sent len, batch size, emb dim]
        embedded = self.dropout(self.embedding(x))

        # hidden = [num layers * num directions, batch size, hid. dim]
        output, (hidden, _cell) = self.rnn(embedded)

        # Concatenate the last two hidden-state slices -> [batch size, hid. dim * 2]
        final_state = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        final_state = self.dropout(final_state)

        return self.fc(final_state.squeeze(0)), output
34 |
35 |
--------------------------------------------------------------------------------
/src/utility/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raosudha89/clarification_question_generation_pytorch/23ae8aa0160eee70565751f4b6de13563a19d6ed/src/utility/__init__.py
--------------------------------------------------------------------------------
/src/utility/combine_pickle.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pickle as p
3 |
if __name__ == "__main__":
    # Usage: combine_pickle.py <askubuntu.p> <unix.p> <superuser.p> <combined.p>
    # Loads three pickled sequences and writes their concatenation.
    # NOTE: pickle.load can execute arbitrary code; only run this on pickles
    # produced by trusted scripts.
    with open(sys.argv[1], 'rb') as askubuntu_file:
        askubuntu = p.load(askubuntu_file)
    with open(sys.argv[2], 'rb') as unix_file:
        unix = p.load(unix_file)
    with open(sys.argv[3], 'rb') as superuser_file:
        superuser = p.load(superuser_file)

    # Concatenation order is unix, superuser, askubuntu (not argument order).
    combined = unix + superuser + askubuntu

    with open(sys.argv[4], 'wb') as combined_file:
        p.dump(combined, combined_file)
10 |
--------------------------------------------------------------------------------
/src/utility/data_loader.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 | import sys
4 | import torch
5 | import torch.autograd as autograd
6 | import random
7 | import torch.utils.data as Data
8 | import pickle as p
9 |
def load_data(contexts_fname, answers_fname):
    """Build shuffled positive/negative (context, question, answer, label) rows.

    contexts_fname: text file whose lines each hold a context and a question.
    answers_fname: text file with one answer per line, aligned with the
        contexts file by line number.

    For every line ``i`` two rows are produced: the true triple with label 1 and
    the same context paired with the question/answer of a randomly drawn line
    with label 0. The draw is not prevented from picking ``i`` itself, so a
    negative row can occasionally equal the positive one.

    Returns the list of ``[context, question, answer, label]`` rows in random order.
    """
    contexts = []
    questions = []
    # Both files are opened up front and are closed even if parsing fails.
    with open(contexts_fname, 'r') as contexts_file, open(answers_fname, 'r') as answers_file:
        for line in contexts_file:
            # NOTE(review): str.split('') raises ValueError ("empty separator")
            # for every line, so the delimiter literal appears to have been
            # lost; restore the real context/question separator - TODO confirm
            # against the script that writes the *_src files.
            context, question = line.strip('\n').split('')
            contexts.append(context)
            questions.append(question)
        answers = [line.strip('\n') for line in answers_file]

    data = []
    for i in range(len(contexts)):
        data.append([contexts[i], questions[i], answers[i], 1])
        r = random.randint(0, len(contexts)-1)
        data.append([contexts[i], questions[r], answers[r], 0])
    random.shuffle(data)
    return data
28 |
def main(args):
    """Build the train/tune/test utility datasets and pickle each one.

    ``args`` supplies the ``{train,tune,test}_contexts_fname`` and
    ``{train,tune,test}_answers_fname`` inputs and the ``train_data``,
    ``tune_data`` and ``test_data`` output paths.
    """
    # All three splits are loaded before anything is written, as before.
    train_data = load_data(args.train_contexts_fname, args.train_answers_fname)
    dev_data = load_data(args.tune_contexts_fname, args.tune_answers_fname)
    test_data = load_data(args.test_contexts_fname, args.test_answers_fname)

    outputs = ((train_data, args.train_data),
               (dev_data, args.tune_data),
               (test_data, args.test_data))
    for split, out_fname in outputs:
        # ``with`` guarantees the pickle is flushed and the handle closed.
        with open(out_fname, 'wb') as out_file:
            p.dump(split, out_file)
37 |
if __name__ == "__main__":
    # Command-line entry point: every option is a file path (inputs first,
    # then the three pickle outputs).
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--train_contexts_fname", type = str)
    argparser.add_argument("--train_answers_fname", type = str)
    argparser.add_argument("--tune_contexts_fname", type = str)
    argparser.add_argument("--tune_answers_fname", type = str)
    argparser.add_argument("--test_contexts_fname", type = str)
    argparser.add_argument("--test_answers_fname", type = str)
    argparser.add_argument("--train_data", type = str)
    argparser.add_argument("--tune_data", type = str)
    argparser.add_argument("--test_data", type = str)
    args = argparser.parse_args()
    # Call-form print: valid on Python 3 and prints the same text on Python 2.
    print(args)
    print("")
    main(args)
53 |
54 |
--------------------------------------------------------------------------------
/src/utility/evaluate_utility.py:
--------------------------------------------------------------------------------
1 | from helper_utility import *
2 | import numpy as np
3 | import torch
4 | from constants import *
5 |
6 |
def _masked_sum_encoding(model, tokens, mask):
    """Encode a batch with ``model`` and sum its per-token outputs over unmasked positions.

    ``tokens`` and ``mask`` are batch-first; both are transposed so the encoder
    receives (sent_len, batch_size) input.
    """
    # outputs: (sent_len, batch_size, num_directions*HIDDEN_DIM)
    _, outputs = model(torch.transpose(tokens, 0, 1))
    mask = torch.transpose(mask, 0, 1).unsqueeze(2)
    mask = mask.expand(mask.shape[0], mask.shape[1], 2*HIDDEN_SIZE)
    return torch.sum(outputs * mask, dim=0)


def evaluate(context_model, question_model, answer_model, utility_model, dev_data, criterion, args):
    """Compute mean loss and accuracy of the utility model over ``dev_data``.

    dev_data: 7-tuple ``(contexts, context_lens, questions, question_lens,
        answers, answer_lens, labels)``.
    criterion: applied to sigmoid outputs, so it must accept probabilities.
    args: provides ``max_post_len``, ``max_ques_len``, ``max_ans_len`` and
        ``batch_size``.

    All four models are switched to eval mode and no gradients are recorded.
    ``get_masks`` replaces ``None`` lengths with 0 in the passed length lists.

    Returns ``(mean_loss, mean_accuracy)`` averaged over the batches.
    Raises ValueError when no batch is produced (fewer examples than
    ``args.batch_size``), instead of an opaque ZeroDivisionError.
    """
    epoch_loss = 0
    epoch_acc = 0
    num_batches = 0

    for model in (context_model, question_model, answer_model, utility_model):
        model.eval()

    with torch.no_grad():
        contexts, context_lens, questions, question_lens, answers, answer_lens, labels = dev_data
        context_masks = get_masks(context_lens, args.max_post_len)
        question_masks = get_masks(question_lens, args.max_ques_len)
        answer_masks = get_masks(answer_lens, args.max_ans_len)
        contexts = np.array(contexts)
        questions = np.array(questions)
        answers = np.array(answers)
        labels = np.array(labels)
        for c, cm, q, qm, a, am, l in iterate_minibatches(contexts, context_masks,
                                                          questions, question_masks,
                                                          answers, answer_masks,
                                                          labels, args.batch_size):
            c = torch.tensor(c)
            cm = torch.FloatTensor(cm)
            q = torch.tensor(q)
            qm = torch.FloatTensor(qm)
            a = torch.tensor(a)
            am = torch.FloatTensor(am)
            if USE_CUDA:
                c = c.cuda()
                cm = cm.cuda()
                q = q.cuda()
                qm = qm.cuda()
                a = a.cuda()
                am = am.cuda()

            c_out = _masked_sum_encoding(context_model, c, cm)
            q_out = _masked_sum_encoding(question_model, q, qm)
            a_out = _masked_sum_encoding(answer_model, a, am)

            predictions = utility_model(torch.cat((c_out, q_out, a_out), 1)).squeeze(1)
            # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
            predictions = torch.sigmoid(predictions)

            l = torch.FloatTensor([float(lab) for lab in l])
            if USE_CUDA:
                l = l.cuda()
            loss = criterion(predictions, l)
            acc = binary_accuracy(predictions, l)
            epoch_loss += loss.item()
            epoch_acc += acc
            num_batches += 1

    if num_batches == 0:
        raise ValueError('evaluate: no batches produced; dev_data has fewer examples than args.batch_size')

    return epoch_loss / num_batches, epoch_acc / num_batches
73 |
74 |
--------------------------------------------------------------------------------
/src/utility/helper_utility.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
def iterate_minibatches(c, cm, q, qm, a, am, l, batch_size, shuffle=True):
    """Yield aligned mini-batches from seven parallel indexable arrays.

    Every yielded item is a 7-tuple with the same selection applied to each
    input. Only full batches are produced: trailing examples that do not fill a
    batch of ``batch_size`` are dropped. With ``shuffle`` the examples are
    visited in an order drawn from ``np.random``; otherwise contiguous slices
    are taken in order.
    """
    total = len(c)
    order = None
    if shuffle:
        order = np.arange(total)
        np.random.shuffle(order)

    arrays = (c, cm, q, qm, a, am, l)
    for lo in range(0, total - batch_size + 1, batch_size):
        hi = lo + batch_size
        picker = order[lo:hi] if shuffle else slice(lo, hi)
        yield tuple(arr[picker] for arr in arrays)
16 |
17 |
def get_masks(lens, max_len):
    """Build a 0/1 padding mask with one row of width ``max_len`` per length.

    lens: mutable sequence of sequence lengths; ``None`` entries are treated as
        length 0 and are overwritten with 0 in place (kept from the original
        behaviour, which callers may rely on).
    max_len: width of every mask row.

    Row ``i`` has ones in its first ``lens[i]`` positions and zeros elsewhere.
    Lengths are clamped to ``[0, max_len]`` so every row has exactly
    ``max_len`` entries and the result is always a rectangular array.

    Returns a numpy array with one row per entry of ``lens``.
    """
    width = int(max_len)
    masks = []
    for i in range(len(lens)):
        if lens[i] is None:
            lens[i] = 0
        # Clamp so an over-long (or negative) length cannot change the row width.
        n_real = min(max(int(lens[i]), 0), width)
        masks.append([1]*n_real + [0]*(width - n_real))
    return np.array(masks)
25 |
26 |
def binary_accuracy(predictions, truth):
    """
    Fraction of predictions on the correct side of the 0.5 threshold,
    e.g. 8 correct out of 10 gives 0.8 (a ratio, not a count).

    A prediction >= 0.5 is correct when its truth value equals 1; a prediction
    < 0.5 is correct when its truth value equals 0. Any other truth value is
    never counted as correct.
    """
    n_items = len(predictions)
    n_right = 0.
    for pos in range(n_items):
        score = predictions[pos]
        if score >= 0.5:
            if truth[pos] == 1:
                n_right += 1
        elif score < 0.5:
            if truth[pos] == 0:
                n_right += 1
    return n_right/n_items
39 |
40 |
--------------------------------------------------------------------------------
/src/utility/main.py:
--------------------------------------------------------------------------------
1 | from constants import *
2 | import sys
3 | sys.path.append('src/utility')
4 | from FeedForward import *
5 | from evaluate_utility import *
6 | import random
7 | from RNN import *
8 | import time
9 | import torch
10 | from torch import optim
11 | import torch.nn as nn
12 | import torch.autograd as autograd
13 | from train_utility import *
14 |
15 |
def update_neg_data(data, index2word):
    """Interleave each example with a randomly drawn negative counterpart.

    data: 7-tuple ``(ids_seqs, post_seqs, post_lens, ques_seqs, ques_lens,
        ans_seqs, ans_lens)``; the ids are not used.
    index2word: accepted but not used here.

    For every example ``i`` the output holds, at position ``2*i``, the original
    (post, question, answer) with label 1 and, at ``2*i + 1``, the same post
    with the question and answer of a randomly chosen example and label 0.
    The random choice is not prevented from picking ``i`` itself.

    Returns ``(post_seqs, post_lens, ques_seqs, ques_lens, ans_seqs, ans_lens,
    labels)`` as seven lists of twice the input length.
    """
    _, post_seqs, post_lens, ques_seqs, ques_lens, ans_seqs, ans_lens = data
    copies = 2  # one positive plus (copies - 1) negatives per example
    n_examples = len(post_seqs)

    out_posts, out_post_lens = [], []
    out_ques, out_ques_lens = [], []
    out_ans, out_ans_lens = [], []
    labels = []

    for idx in range(n_examples):
        # Slot 0 pairs the post with its own question/answer; every further
        # slot pairs it with a randomly drawn example's question/answer.
        partners = [idx]
        partners.extend(random.randint(0, n_examples - 1) for _ in range(1, copies))
        for slot, partner in enumerate(partners):
            out_posts.append(post_seqs[idx])
            out_post_lens.append(post_lens[idx])
            out_ques.append(ques_seqs[partner])
            out_ques_lens.append(ques_lens[partner])
            out_ans.append(ans_seqs[partner])
            out_ans_lens.append(ans_lens[partner])
            labels.append(1 if slot == 0 else 0)

    return out_posts, out_post_lens, out_ques, out_ques_lens, out_ans, out_ans_lens, labels
49 |
50 |
def run_utility(train_data, test_data, word_embeddings, index2word, args, n_layers):
    """Train the utility calculator and save its parameters after every epoch.

    Three ``RNN`` encoders (context, question, answer) are initialised from
    ``word_embeddings`` with their embedding weights frozen, and a
    ``FeedForward`` model scores the concatenated encodings. Both data sets are
    expanded with negative examples by ``update_neg_data`` before training.

    args: provides ``n_epochs`` and the ``context_params``, ``question_params``,
        ``answer_params`` and ``utility_params`` checkpoint path prefixes, and
        is forwarded to ``train_fn`` and ``evaluate``.
    """
    vocab_size = len(word_embeddings)
    emb_dim = len(word_embeddings[0])

    # Construction order is kept: context, question, answer, then utility.
    context_model = RNN(vocab_size, emb_dim, n_layers)
    question_model = RNN(vocab_size, emb_dim, n_layers)
    answer_model = RNN(vocab_size, emb_dim, n_layers)
    utility_model = FeedForward(HIDDEN_SIZE*3*2)
    encoders = (context_model, question_model, answer_model)

    word_embeddings = torch.FloatTensor(word_embeddings)
    if USE_CUDA:
        word_embeddings = word_embeddings.cuda()
    word_embeddings = autograd.Variable(word_embeddings)

    for encoder in encoders:
        encoder.embedding.weight.data.copy_(word_embeddings)
        # Fix word embeddings
        encoder.embedding.weight.requires_grad = False

    # Only parameters that still require gradients are optimised.
    trainable = [par
                 for model in encoders + (utility_model,)
                 for par in model.parameters()
                 if par.requires_grad]
    optimizer = optim.Adam(trainable)

    criterion = nn.BCELoss()
    if USE_CUDA:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        context_model = context_model.to(device)
        question_model = question_model.to(device)
        answer_model = answer_model.to(device)
        utility_model = utility_model.to(device)
        criterion = criterion.to(device)

    train_data = update_neg_data(train_data, index2word)
    test_data = update_neg_data(test_data, index2word)

    checkpoints = ((context_model, args.context_params),
                   (question_model, args.question_params),
                   (answer_model, args.answer_params),
                   (utility_model, args.utility_params))

    for epoch in range(args.n_epochs):
        start_time = time.time()
        train_loss, train_acc = train_fn(context_model, question_model, answer_model, utility_model,
                                         train_data, optimizer, criterion, args)
        valid_loss, valid_acc = evaluate(context_model, question_model, answer_model, utility_model,
                                         test_data, criterion, args)
        print('Epoch %d: Train Loss: %.3f, Train Acc: %.3f, Val Loss: %.3f, Val Acc: %.3f' % \
              (epoch, train_loss, train_acc, valid_loss, valid_acc))
        print('Time taken: ', time.time()-start_time)

        # A checkpoint is written after every epoch.
        print('Saving model params')
        for model, path_prefix in checkpoints:
            torch.save(model.state_dict(), path_prefix+'.epoch%d' % epoch)
104 |
105 |
--------------------------------------------------------------------------------
/src/utility/run_combine_domains.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Concatenates the per-site utility pickles (askubuntu, unix, superuser) into
# one combined data set, once per split, using combine_pickle.py.

DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/utility
UBUNTU=askubuntu.com
UNIX=unix.stackexchange.com
SUPERUSER=superuser.com
SCRIPTS_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/utility
SITE_NAME=askubuntu_unix_superuser

# -p: do not fail when the output directory already exists (e.g. on a re-run).
mkdir -p "$DATA_DIR/$SITE_NAME"

for SPLIT in train tune test
do
	python "$SCRIPTS_DIR/combine_pickle.py" "$DATA_DIR/$UBUNTU/${SPLIT}_data.p" \
		"$DATA_DIR/$UNIX/${SPLIT}_data.p" \
		"$DATA_DIR/$SUPERUSER/${SPLIT}_data.p" \
		"$DATA_DIR/$SITE_NAME/${SPLIT}_data.p"
done
26 |
27 |
--------------------------------------------------------------------------------
/src/utility/run_data_loader.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=utility_data_aus
#SBATCH --output=utility_data_aus
#SBATCH --qos=batch
#SBATCH --mem=36g
#SBATCH --time=24:00:00

# Builds the utility train/tune/test pickles for one site by running
# data_loader.py on that site's question-answering src/tgt files.
# Select the site by leaving exactly one SITENAME line uncommented.

#SITENAME=askubuntu.com
#SITENAME=unix.stackexchange.com
SITENAME=superuser.com
#SITENAME=askubuntu_unix_superuser
#SITENAME=Home_and_Kitchen
QA_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/question_answering/$SITENAME
UTILITY_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/utility/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/utility

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

# The last argument carries no trailing backslash, so the command no longer
# continues into the commented-out option below.
python "$SCRIPT_DIR/data_loader.py" --train_contexts_fname "$QA_DATA_DIR/train_src" \
	--train_answers_fname "$QA_DATA_DIR/train_tgt" \
	--tune_contexts_fname "$QA_DATA_DIR/tune_src" \
	--tune_answers_fname "$QA_DATA_DIR/tune_tgt" \
	--test_contexts_fname "$QA_DATA_DIR/test_src" \
	--test_answers_fname "$QA_DATA_DIR/test_tgt" \
	--train_data "$UTILITY_DATA_DIR/train_data.p" \
	--tune_data "$UTILITY_DATA_DIR/tune_data.p" \
	--test_data "$UTILITY_DATA_DIR/test_data.p"
	#--word_to_ix $UTILITY_DATA_DIR/word_to_ix.p
30 |
31 |
32 |
--------------------------------------------------------------------------------
/src/utility/run_rnn_classifier.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=utility_aus_fullmodel_80K
#SBATCH --output=utility_aus_fullmodel_80K
#SBATCH --qos=gpu-short
#SBATCH --partition=gpu
#SBATCH --gres=gpu
#SBATCH --mem=16g

# Launches rnn_classifier.py on the utility pickles of the selected site.
# Select the site by leaving exactly one SITENAME line uncommented.

#SITENAME=askubuntu.com
#SITENAME=unix.stackexchange.com
#SITENAME=superuser.com
SITENAME=askubuntu_unix_superuser
#SITENAME=Home_and_Kitchen
DATA_DIR=/fs/clip-corpora/amazon_qa
UTILITY_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/utility/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/utility
EMB_DIR=/fs/clip-amr/clarification_question_generation_pytorch/embeddings/$SITENAME/200_5Kvocab
#EMB_DIR=/fs/clip-amr/ranking_clarification_questions/embeddings

source /fs/clip-amr/gpu_virtualenv/bin/activate
export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

# The last argument carries no trailing backslash, so nothing that is later
# added below the command can be swallowed into it.
python "$SCRIPT_DIR/rnn_classifier.py" --train_data "$UTILITY_DATA_DIR/train_data.p" \
	--tune_data "$UTILITY_DATA_DIR/tune_data.p" \
	--test_data "$UTILITY_DATA_DIR/test_data.p" \
	--word_embeddings "$EMB_DIR/word_embeddings.p" \
	--vocab "$EMB_DIR/vocab.p" \
	--cuda True
30 |
31 |
32 |
--------------------------------------------------------------------------------
/src/utility/train_utility.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('src/utility')
3 | from helper_utility import *
4 | import numpy as np
5 | import torch
6 | from constants import *
7 |
8 |
def _masked_sum_encoding(model, tokens, mask):
    """Encode a batch with ``model`` and sum its per-token outputs over unmasked positions.

    ``tokens`` and ``mask`` are batch-first; both are transposed so the encoder
    receives (sent_len, batch_size) input.
    """
    # outputs: (sent_len, batch_size, num_directions*HIDDEN_DIM)
    _, outputs = model(torch.transpose(tokens, 0, 1))
    mask = torch.transpose(mask, 0, 1).unsqueeze(2)
    mask = mask.expand(mask.shape[0], mask.shape[1], 2*HIDDEN_SIZE)
    return torch.sum(outputs * mask, dim=0)


def train_fn(context_model, question_model, answer_model, utility_model, train_data, optimizer, criterion, args):
    """Train the utility model for one epoch over ``train_data``.

    train_data: 7-tuple ``(contexts, context_lens, questions, question_lens,
        answers, answer_lens, labels)``.
    criterion: applied to sigmoid outputs, so it must accept probabilities.
    args: provides ``max_post_len``, ``max_ques_len``, ``max_ans_len`` and
        ``batch_size``.

    All four models are switched to train mode. ``get_masks`` replaces ``None``
    lengths with 0 in the passed length lists.

    Returns ``(mean_loss, mean_accuracy)`` averaged over the batches, so the
    loss is on the same per-batch scale as the one reported by ``evaluate``
    (previously the un-normalised sum of batch losses was returned).
    Raises ValueError when no batch is produced (fewer examples than
    ``args.batch_size``), instead of an opaque ZeroDivisionError.
    """
    epoch_loss = 0
    epoch_acc = 0

    for model in (context_model, question_model, answer_model, utility_model):
        model.train()

    contexts, context_lens, questions, question_lens, answers, answer_lens, labels = train_data
    context_masks = get_masks(context_lens, args.max_post_len)
    question_masks = get_masks(question_lens, args.max_ques_len)
    answer_masks = get_masks(answer_lens, args.max_ans_len)
    contexts = np.array(contexts)
    questions = np.array(questions)
    answers = np.array(answers)
    labels = np.array(labels)

    num_batches = 0
    for c, cm, q, qm, a, am, l in iterate_minibatches(contexts, context_masks, questions, question_masks,
                                                      answers, answer_masks, labels, args.batch_size):
        optimizer.zero_grad()
        c = torch.tensor(c.tolist())
        cm = torch.FloatTensor(cm)
        q = torch.tensor(q.tolist())
        qm = torch.FloatTensor(qm)
        a = torch.tensor(a.tolist())
        am = torch.FloatTensor(am)
        if USE_CUDA:
            c = c.cuda()
            cm = cm.cuda()
            q = q.cuda()
            qm = qm.cuda()
            a = a.cuda()
            am = am.cuda()

        c_out = _masked_sum_encoding(context_model, c, cm)
        q_out = _masked_sum_encoding(question_model, q, qm)
        a_out = _masked_sum_encoding(answer_model, a, am)

        predictions = utility_model(torch.cat((c_out, q_out, a_out), 1)).squeeze(1)
        # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
        predictions = torch.sigmoid(predictions)

        l = torch.FloatTensor([float(lab) for lab in l])
        if USE_CUDA:
            l = l.cuda()
        loss = criterion(predictions, l)
        acc = binary_accuracy(predictions, l)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_acc += acc
        num_batches += 1

    if num_batches == 0:
        raise ValueError('train_fn: no batches produced; train_data has fewer examples than args.batch_size')

    return epoch_loss / num_batches, epoch_acc / num_batches
78 |
79 |
80 |
--------------------------------------------------------------------------------