├── README.md └── src ├── GAN_main.py ├── RL_helper.py ├── RL_main.py ├── constants.py ├── data_generation ├── create_amazon_data.py ├── create_amazon_pqa_data.py ├── create_amazon_pqa_data_candqs.py ├── create_amazon_pqa_data_from_asins.py ├── create_lucene_amazon.py ├── create_pqa_data.py ├── filter_amazon_data_byid.py ├── filter_output_perid.py ├── remove_unks.py ├── run_create_amazon_data.sh ├── run_create_amazon_pqa_data.sh ├── run_create_amazon_pqa_data_from_asins.sh ├── run_create_lucene_amazon.sh ├── run_create_pqa_data.sh ├── run_filter_amazon_data_byid.sh └── run_filter_output_perid.sh ├── decode.py ├── embedding_generation ├── create_we_vocab.py ├── extract_amazon_data.py ├── extract_data.py ├── run_create_we_vocab.sh ├── run_extract_amazon_data.sh ├── run_extract_data.sh └── run_glove.sh ├── evaluation ├── all_ngrams.pl ├── calculate_diversity.sh ├── calculate_inter_annotator_agreement.py ├── combine_refs_for_meteor.py ├── create_amazon_multi_refs.py ├── create_crowdflower_data.py ├── create_crowdflower_data_beam.py ├── create_crowdflower_data_compare_ques.py ├── create_crowdflower_data_compare_ques_allpairs.py ├── create_crowdflower_data_specificity.py ├── create_crowdflower_data_specificity_multi.py ├── create_preds_for_refs.py ├── eval_HK ├── eval_HK_spec ├── eval_aus ├── read_crowdflower_full_results.py ├── read_crowdflower_full_results_compare_ques.py ├── read_crowdflower_results.py ├── read_crowdflower_results_binary.py ├── read_crowdflower_results_both.py ├── read_crowdflower_results_compare_ques.py ├── read_crowdflower_results_diff_formats.py ├── read_crowdflower_results_style.py ├── run_bleu.sh ├── run_create_amazon_multi_refs.sh ├── run_create_crowdflower_data.sh ├── run_create_crowdflower_data_beam.sh ├── run_create_crowdflower_data_compare_ques.sh ├── run_create_crowdflower_data_compare_ques_allpairs.sh ├── run_create_crowdflower_data_specificity.sh ├── run_create_crowdflower_data_specificity_multi.sh ├── run_create_preds_for_refs.sh ├── 
run_meteor.sh └── run_meteor_HK.sh ├── lucene ├── create_amazon_lucene_baseline.py ├── create_stackexchange_lucene_baseline.py ├── run_create_amazon_lucene_baseline.sh └── run_create_stackexchange_lucene_baseline.sh ├── main.py ├── run_GAN_decode.sh ├── run_GAN_decode_HK.sh ├── run_GAN_main.sh ├── run_GAN_main_HK.sh ├── run_GAN_main_electronics.sh ├── run_RL_decode.sh ├── run_RL_decode_HK.sh ├── run_RL_main.sh ├── run_RL_main_HK.sh ├── run_decode.sh ├── run_decode_HK.sh ├── run_decode_electronics.sh ├── run_main.sh ├── run_main_HK.sh ├── run_main_electronics.sh ├── run_pretrain_ans.sh ├── run_pretrain_ans_HK.sh ├── run_pretrain_util.sh ├── run_pretrain_util_HK.sh ├── seq2seq ├── GAN_train.py ├── RL_beam_decoder.py ├── RL_evaluate.py ├── RL_inference.py ├── RL_train.py ├── __init__.py ├── ans_train.py ├── attn.py ├── attnDecoderRNN.py ├── baselineFF.py ├── encoderRNN.py ├── evaluate.py ├── helper.py ├── main.py ├── masked_cross_entropy.py ├── prepare_data.py ├── read_data.py └── train.py └── utility ├── FeedForward.py ├── RL_evaluate.py ├── RL_train.py ├── RNN.py ├── __init__.py ├── combine_pickle.py ├── data_loader.py ├── evaluate_utility.py ├── helper_utility.py ├── main.py ├── rnn_classifier.py ├── run_combine_domains.sh ├── run_data_loader.sh ├── run_rnn_classifier.sh └── train_utility.py /README.md: -------------------------------------------------------------------------------- 1 | # Repository information 2 | 3 | This repository contains data and code for the paper below: 4 | 5 | 6 | Answer-based Adversarial Training for Generating Clarification Questions
7 | Sudha Rao (Sudha.Rao@microsoft.com) and Hal Daumé III (me@hal3.name)
8 | Proceedings of NAACL-HLT 2019 9 | 10 | # Downloading data 11 | 12 | * Download embeddings from https://go.umd.edu/clarification_questions_embeddings 13 | and save them into the repository folder 14 | * Download data from https://go.umd.edu/clarification_question_generation_dataset 15 | Unzip the two folders inside and copy them into the repository folder 16 | 17 | # Training models on StackExchange dataset 18 | 19 | * To train an MLE model, run src/run_main.sh 20 | 21 | * To train a Max-Utility model, follow these three steps: 22 | 23 | * run src/run_pretrain_ans.sh 24 | 25 | * run src/run_pretrain_util.sh 26 | 27 | * run src/run_RL_main.sh 28 | 29 | * To train a GAN-Utility model, follow these three steps (note: you can skip the first two steps if you have already run them for the Max-Utility model): 30 | 31 | * run src/run_pretrain_ans.sh 32 | 33 | * run src/run_pretrain_util.sh 34 | 35 | * run src/run_GAN_main.sh 36 | 37 | # Training models on Amazon (Home & Kitchen) dataset 38 | 39 | * To train an MLE model, run src/run_main_HK.sh 40 | 41 | * To train a Max-Utility model, follow these three steps: 42 | 43 | * run src/run_pretrain_ans_HK.sh 44 | 45 | * run src/run_pretrain_util_HK.sh 46 | 47 | * run src/run_RL_main_HK.sh 48 | 49 | * To train a GAN-Utility model, follow these three steps (note: you can skip the first two steps if you have already run them for the Max-Utility model): 50 | 51 | * run src/run_pretrain_ans_HK.sh 52 | 53 | * run src/run_pretrain_util_HK.sh 54 | 55 | * run src/run_GAN_main_HK.sh 56 | 57 | # Generating outputs using trained models 58 | 59 | * Run the following scripts to generate outputs for models trained on the StackExchange dataset: 60 | 61 | * For MLE model, run src/run_decode.sh 62 | 63 | * For Max-Utility model, run src/run_RL_decode.sh 64 | 65 | * For GAN-Utility model, run src/run_GAN_decode.sh 66 | 67 | * Run the following scripts to generate outputs for models trained on the Amazon dataset: 68 | 69 | * For MLE model, run src/run_decode_HK.sh 70 | 71 | * 
For Max-Utility model, run src/run_RL_decode_HK.sh 72 | 73 | * For GAN-Utility model, run src/run_GAN_decode_HK.sh 74 | 75 | # Evaluating generated outputs 76 | 77 | * For the StackExchange dataset, references for a subset of the test set were collected from human annotators. 78 | Hence we first create a version of the predictions file restricted to the instances that have references, by running: 79 | src/evaluation/run_create_preds_for_refs.sh 80 | 81 | * For the Amazon dataset, we have references for all instances in the test set. 82 | 83 | * We remove unknown-word (UNK) tokens from the generated outputs by simply deleting them from the predictions file. 84 | 85 | * For BLEU score, run src/evaluation/run_bleu.sh 86 | 87 | * For METEOR score, run src/evaluation/run_meteor.sh 88 | 89 | * For Diversity score, run src/evaluation/calculate_diversity.sh 90 | -------------------------------------------------------------------------------- /src/RL_helper.py: -------------------------------------------------------------------------------- 1 | from constants import * 2 | import math 3 | import numpy as np 4 | import nltk 5 | import time 6 | import torch 7 | 8 | 9 | def as_minutes(s): 10 | m = math.floor(s / 60) 11 | s -= m * 60 12 | return '%dm %ds' % (m, s) 13 | 14 | 15 | def time_since(since, percent): 16 | now = time.time() 17 | s = now - since 18 | es = s / (percent) 19 | rs = es - s 20 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) 21 | 22 | 23 | def iterate_minibatches(p, pl, q, ql, pq, pql, a, al, batch_size, shuffle=True): 24 | if shuffle: 25 | indices = np.arange(len(p)) 26 | np.random.shuffle(indices) 27 | for start_idx in range(0, len(p) - batch_size + 1, batch_size): 28 | if shuffle: 29 | ex = indices[start_idx:start_idx + batch_size] 30 | else: 31 | ex = slice(start_idx, start_idx + batch_size) 32 | yield np.array(p)[ex], np.array(pl)[ex], np.array(q)[ex], np.array(ql)[ex], \ 33 | np.array(pq)[ex], np.array(pql)[ex], np.array(a)[ex], np.array(al)[ex] 34 | 35 | 36 | def 
reverse_dict(word2index): 37 | index2word = {} 38 | for w in word2index: 39 | ix = word2index[w] 40 | index2word[ix] = w 41 | return index2word 42 | 43 | 44 | def calculate_bleu(true, true_lens, pred, pred_lens, index2word, max_len): 45 | sent_bleu_scores = torch.zeros(len(pred)) 46 | for i in range(len(pred)): 47 | true_sent = [index2word[idx] for idx in true[i][:true_lens[i]]] 48 | pred_sent = [index2word[idx] for idx in pred[i][:pred_lens[i]]] 49 | sent_bleu_scores[i] = nltk.translate.bleu_score.sentence_bleu(true_sent, pred_sent) 50 | if USE_CUDA: 51 | sent_bleu_scores = sent_bleu_scores.cuda() 52 | return sent_bleu_scores 53 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | USE_CUDA = True 2 | 3 | # Configure models 4 | HIDDEN_SIZE = 100 5 | DROPOUT = 0.5 6 | 7 | # Configure training/optimization 8 | LEARNING_RATE = 0.0001 9 | DECODER_LEARNING_RATIO = 5.0 10 | 11 | PAD_token = '' 12 | SOS_token = '' 13 | EOP_token = '' 14 | EOS_token = '' 15 | UNK_token = '' 16 | # UNK_token = 'unk' 17 | SPECIFIC_token = '' 18 | GENERIC_token = '' 19 | 20 | BEAM_SIZE = 5 21 | -------------------------------------------------------------------------------- /src/data_generation/create_amazon_pqa_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gzip 3 | import nltk 4 | import pdb 5 | import sys, os 6 | import re 7 | from collections import defaultdict 8 | 9 | def parse(path): 10 | g = gzip.open(path, 'r') 11 | for l in g: 12 | yield eval(l) 13 | 14 | exception_chars = ['|', '/', '\\', '-', '(', ')', '!', ':', ';', '<', '>'] 15 | 16 | def preprocess(text): 17 | text = text.replace('|', ' ') 18 | text = text.replace('/', ' ') 19 | text = text.replace('\\', ' ') 20 | text = text.lower() 21 | #text = re.sub(r'\W+', ' ', text) 22 | ret_text = '' 23 | for sent in 
nltk.sent_tokenize(text): 24 | ret_text += ' '.join(nltk.word_tokenize(sent)) + ' ' 25 | return ret_text 26 | 27 | def main(args): 28 | products = {} 29 | for v in parse(args.metadata_fname): 30 | if 'description' not in v or 'title' not in v: 31 | continue 32 | asin = v['asin'] 33 | title = preprocess(v['title']) 34 | description = preprocess(v['description']) 35 | product = title + ' . ' + description 36 | products[asin] = product 37 | 38 | train_asin_file = open(args.train_asin_fname, 'w') 39 | train_context_file = open(args.train_context_fname, 'w') 40 | train_ques_file = open(args.train_ques_fname, 'w') 41 | train_ans_file = open(args.train_ans_fname, 'w') 42 | tune_asin_file = open(args.tune_asin_fname, 'w') 43 | tune_context_file = open(args.tune_context_fname, 'w') 44 | tune_ques_file = open(args.tune_ques_fname, 'w') 45 | tune_ans_file = open(args.tune_ans_fname, 'w') 46 | test_asin_file = open(args.test_asin_fname, 'w') 47 | test_context_file = open(args.test_context_fname, 'w') 48 | test_ques_file = open(args.test_ques_fname, 'w') 49 | test_ans_file = open(args.test_ans_fname, 'w') 50 | 51 | asins = [] 52 | contexts = [] 53 | questions = [] 54 | answers = [] 55 | for v in parse(args.qa_data_fname): 56 | asin = v['asin'] 57 | if asin not in products or 'answer' not in v: 58 | continue 59 | question = preprocess(v['question']) 60 | answer = preprocess(v['answer']) 61 | if not answer: 62 | continue 63 | asins.append(asin) 64 | contexts.append(products[asin]) 65 | questions.append(question) 66 | answers.append(answer) 67 | N = len(contexts) 68 | for i in range(int(N*0.8)): 69 | train_asin_file.write(asins[i]+'\n') 70 | train_context_file.write(contexts[i]+'\n') 71 | train_ques_file.write(questions[i]+'\n') 72 | train_ans_file.write(answers[i]+'\n') 73 | for i in range(int(N*0.8), int(N*0.9)): 74 | tune_asin_file.write(asins[i]+'\n') 75 | tune_context_file.write(contexts[i]+'\n') 76 | tune_ques_file.write(questions[i]+'\n') 77 | 
tune_ans_file.write(answers[i]+'\n') 78 | for i in range(int(N*0.9), N): 79 | test_asin_file.write(asins[i]+'\n') 80 | test_context_file.write(contexts[i]+'\n') 81 | test_ques_file.write(questions[i]+'\n') 82 | test_ans_file.write(answers[i]+'\n') 83 | 84 | train_asin_file.close() 85 | train_context_file.close() 86 | train_ques_file.close() 87 | train_ans_file.close() 88 | tune_asin_file.close() 89 | tune_context_file.close() 90 | tune_ques_file.close() 91 | tune_ans_file.close() 92 | test_asin_file.close() 93 | test_context_file.close() 94 | test_ques_file.close() 95 | test_ans_file.close() 96 | 97 | 98 | if __name__ == "__main__": 99 | argparser = argparse.ArgumentParser(sys.argv[0]) 100 | argparser.add_argument("--qa_data_fname", type = str) 101 | argparser.add_argument("--metadata_fname", type = str) 102 | argparser.add_argument("--train_asin_fname", type = str) 103 | argparser.add_argument("--train_context_fname", type = str) 104 | argparser.add_argument("--train_ques_fname", type = str) 105 | argparser.add_argument("--train_ans_fname", type = str) 106 | argparser.add_argument("--tune_asin_fname", type = str) 107 | argparser.add_argument("--tune_context_fname", type = str) 108 | argparser.add_argument("--tune_ques_fname", type = str) 109 | argparser.add_argument("--tune_ans_fname", type = str) 110 | argparser.add_argument("--test_asin_fname", type = str) 111 | argparser.add_argument("--test_context_fname", type = str) 112 | argparser.add_argument("--test_ques_fname", type = str) 113 | argparser.add_argument("--test_ans_fname", type = str) 114 | args = argparser.parse_args() 115 | print args 116 | print "" 117 | main(args) 118 | 119 | -------------------------------------------------------------------------------- /src/data_generation/create_amazon_pqa_data_from_asins.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gzip 3 | import nltk 4 | import pdb 5 | import sys, os 6 | import re 7 | from 
collections import defaultdict 8 | 9 | 10 | def parse(path): 11 | g = gzip.open(path, 'r') 12 | for l in g: 13 | yield eval(l) 14 | 15 | 16 | exception_chars = ['|', '/', '\\', '-', '(', ')', '!', ':', ';', '<', '>'] 17 | 18 | 19 | def preprocess(text): 20 | text = text.replace('|', ' ') 21 | text = text.replace('/', ' ') 22 | text = text.replace('\\', ' ') 23 | text = text.lower() 24 | #text = re.sub(r'\W+', ' ', text) 25 | ret_text = '' 26 | for sent in nltk.sent_tokenize(text): 27 | ret_text += ' '.join(nltk.word_tokenize(sent)) + ' ' 28 | return ret_text 29 | 30 | 31 | def main(args): 32 | products = {} 33 | for v in parse(args.metadata_fname): 34 | if 'description' not in v or 'title' not in v: 35 | continue 36 | asin = v['asin'] 37 | title = preprocess(v['title']) 38 | description = preprocess(v['description']) 39 | product = title + ' . ' + description 40 | products[asin] = product 41 | 42 | train_asin_file = open(args.train_asin_fname, 'r') 43 | train_ans_file = open(args.train_ans_fname, 'w') 44 | tune_asin_file = open(args.tune_asin_fname, 'w') 45 | tune_context_file = open(args.tune_context_fname, 'w') 46 | tune_ques_file = open(args.tune_ques_fname, 'w') 47 | tune_ans_file = open(args.tune_ans_fname, 'w') 48 | test_asin_file = open(args.test_asin_fname, 'r') 49 | test_ans_file = open(args.test_ans_fname, 'w') 50 | 51 | train_asins = [] 52 | test_asins = [] 53 | for line in train_asin_file.readlines(): 54 | train_asins.append(line.strip('\n')) 55 | for line in test_asin_file.readlines(): 56 | test_asins.append(line.strip('\n')) 57 | 58 | asins = [] 59 | contexts = {} 60 | questions = {} 61 | answers = {} 62 | 63 | for v in parse(args.qa_data_fname): 64 | asin = v['asin'] 65 | if asin not in products or 'answer' not in v: 66 | continue 67 | question = preprocess(v['question']) 68 | answer = preprocess(v['answer']) 69 | if not answer: 70 | continue 71 | asins.append(asin) 72 | contexts[asin] = products[asin] 73 | questions[asin] = question 74 | 
answers[asin] = answer 75 | 76 | for asin in train_asins: 77 | train_ans_file.write(answers[asin]+'\n') 78 | for asin in asins: 79 | if asin in train_asins or asin in test_asins: 80 | continue 81 | tune_asin_file.write(asin+'\n') 82 | tune_context_file.write(contexts[asin]+'\n') 83 | tune_ques_file.write(questions[asin]+'\n') 84 | tune_ans_file.write(answers[asin]+'\n') 85 | for asin in test_asins: 86 | test_ans_file.write(answers[asin]+'\n') 87 | 88 | train_asin_file.close() 89 | train_ans_file.close() 90 | tune_asin_file.close() 91 | tune_context_file.close() 92 | tune_ques_file.close() 93 | tune_ans_file.close() 94 | test_asin_file.close() 95 | test_ans_file.close() 96 | 97 | 98 | if __name__ == "__main__": 99 | argparser = argparse.ArgumentParser(sys.argv[0]) 100 | argparser.add_argument("--qa_data_fname", type = str) 101 | argparser.add_argument("--metadata_fname", type = str) 102 | argparser.add_argument("--train_asin_fname", type = str) 103 | argparser.add_argument("--train_ans_fname", type = str) 104 | argparser.add_argument("--tune_asin_fname", type = str) 105 | argparser.add_argument("--tune_context_fname", type = str) 106 | argparser.add_argument("--tune_ques_fname", type = str) 107 | argparser.add_argument("--tune_ans_fname", type = str) 108 | argparser.add_argument("--test_asin_fname", type = str) 109 | argparser.add_argument("--test_ans_fname", type = str) 110 | args = argparser.parse_args() 111 | print args 112 | print "" 113 | main(args) 114 | 115 | -------------------------------------------------------------------------------- /src/data_generation/create_lucene_amazon.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gzip 3 | import nltk 4 | import pdb 5 | import sys, os 6 | import re 7 | from collections import defaultdict 8 | 9 | def parse(path): 10 | g = gzip.open(path, 'r') 11 | for l in g: 12 | yield eval(l) 13 | 14 | exception_chars = ['|', '/', '\\', '-', '(', ')', '!', ':', 
';', '<', '>'] 15 | 16 | def preprocess(text): 17 | text = text.replace('|', ' ') 18 | text = text.replace('/', ' ') 19 | text = text.replace('\\', ' ') 20 | text = text.lower() 21 | #text = re.sub(r'\W+', ' ', text) 22 | ret_text = '' 23 | for sent in nltk.sent_tokenize(text): 24 | ret_text += ' '.join(nltk.word_tokenize(sent)) + ' ' 25 | return ret_text 26 | 27 | def main(args): 28 | brand_counts = defaultdict(int) 29 | for v in parse(args.metadata_fname): 30 | if 'description' not in v or 'title' not in v or 'brand' not in v: 31 | continue 32 | brand_counts[v['brand']] += 1 33 | low_ct_brands = [] 34 | for brand, ct in brand_counts.iteritems(): 35 | if ct < 100: 36 | low_ct_brands.append(brand) 37 | products = {} 38 | for v in parse(args.metadata_fname): 39 | if 'description' not in v or 'title' not in v or 'brand' not in v: 40 | continue 41 | if v['brand'] not in low_ct_brands: 42 | continue 43 | asin = v['asin'] 44 | title = preprocess(v['title']) 45 | description = preprocess(v['description']) 46 | product = title + ' . 
' + description 47 | products[asin] = product 48 | 49 | question_list = {} 50 | print 'Creating docs' 51 | for v in parse(args.qa_data_fname): 52 | asin = v['asin'] 53 | if asin not in products: 54 | continue 55 | print asin 56 | if asin not in question_list: 57 | f = open(os.path.join(args.product_dir, asin + '.txt'), 'w') 58 | f.write(products[asin]) 59 | f.close() 60 | question_list[asin] = [] 61 | question = preprocess(v['question']) 62 | question_list[asin].append(question) 63 | f = open(os.path.join(args.question_dir, asin + '_' + str(len(question_list[asin])) + '.txt'), 'w') 64 | f.write(question) 65 | f.close() 66 | 67 | if __name__ == "__main__": 68 | argparser = argparse.ArgumentParser(sys.argv[0]) 69 | argparser.add_argument("--qa_data_fname", type = str) 70 | argparser.add_argument("--metadata_fname", type = str) 71 | argparser.add_argument("--product_dir", type = str) 72 | argparser.add_argument("--question_dir", type = str) 73 | args = argparser.parse_args() 74 | print args 75 | print "" 76 | main(args) 77 | 78 | -------------------------------------------------------------------------------- /src/data_generation/create_pqa_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from collections import defaultdict 4 | import csv 5 | import math 6 | import nltk 7 | 8 | 9 | MAX_POST_LEN = 100 10 | MAX_QUES_LEN = 20 11 | MAX_ANS_LEN = 20 12 | 13 | 14 | def write_to_file(ids, args, posts, question_candidates, answer_candidates, split): 15 | 16 | if split == 'train': 17 | context_file = open(args.train_context_fname, 'w') 18 | question_file = open(args.train_question_fname, 'w') 19 | answer_file = open(args.train_answer_fname, 'w') 20 | elif split == 'tune': 21 | context_file = open(args.tune_context_fname, 'w') 22 | question_file = open(args.tune_question_fname, 'w') 23 | answer_file = open(args.tune_answer_fname, 'w') 24 | elif split == 'test': 25 | context_file = 
open(args.test_context_fname, 'w') 26 | question_file = open(args.test_question_fname, 'w') 27 | answer_file = open(args.test_answer_fname, 'w') 28 | 29 | for k, post_id in enumerate(ids): 30 | context_file.write(posts[post_id]+'\n') 31 | question_file.write(question_candidates[post_id][0]+'\n') 32 | answer_file.write(answer_candidates[post_id][0]+'\n') 33 | 34 | context_file.close() 35 | question_file.close() 36 | answer_file.close() 37 | 38 | 39 | def trim_by_len(s, max_len): 40 | s = s.lower().strip() 41 | words = s.split() 42 | s = ' '.join(words[:max_len]) 43 | return s 44 | 45 | 46 | def trim_by_tfidf(posts, p_tf, p_idf): 47 | for post_id in posts: 48 | post = [] 49 | words = posts[post_id].split() 50 | for w in words: 51 | tf = words.count(w) 52 | if tf*p_idf[w] >= MIN_TFIDF: 53 | post.append(w) 54 | if len(post) >= MAX_POST_LEN: 55 | break 56 | posts[post_id] = ' '.join(post) 57 | return posts 58 | 59 | 60 | def read_data(args): 61 | print("Reading lines...") 62 | posts = {} 63 | question_candidates = {} 64 | answer_candidates = {} 65 | p_tf = defaultdict(int) 66 | p_idf = defaultdict(int) 67 | with open(args.post_data_tsvfile, 'rb') as tsvfile: 68 | post_reader = csv.reader(tsvfile, delimiter='\t') 69 | N = 0 70 | for row in post_reader: 71 | if N == 0: 72 | N += 1 73 | continue 74 | N += 1 75 | post_id, title, post = row 76 | post = title + ' ' + post 77 | post = post.lower().strip() 78 | for w in post.split(): 79 | p_tf[w] += 1 80 | for w in set(post.split()): 81 | p_idf[w] += 1 82 | posts[post_id] = post 83 | 84 | for w in p_idf: 85 | p_idf[w] = math.log(N*1.0/p_idf[w]) 86 | 87 | # for asin, post in posts.iteritems(): 88 | # posts[asin] = trim_by_len(post, MAX_POST_LEN) 89 | #posts = trim_by_tfidf(posts, p_tf, p_idf) 90 | N = 0 91 | with open(args.qa_data_tsvfile, 'rb') as tsvfile: 92 | qa_reader = csv.reader(tsvfile, delimiter='\t') 93 | i = 0 94 | for row in qa_reader: 95 | if i == 0: 96 | i += 1 97 | continue 98 | post_id, questions = row[0], 
row[1:11] 99 | answers = row[11:21] 100 | # questions = [trim_by_len(question, MAX_QUES_LEN) for question in questions] 101 | question_candidates[post_id] = questions 102 | # answers = [trim_by_len(answer, MAX_ANS_LEN) for answer in answers] 103 | answer_candidates[post_id] = answers 104 | 105 | train_ids = [train_id.strip('\n') for train_id in open(args.train_ids_file, 'r').readlines()] 106 | tune_ids = [tune_id.strip('\n') for tune_id in open(args.tune_ids_file, 'r').readlines()] 107 | test_ids = [test_id.strip('\n') for test_id in open(args.test_ids_file, 'r').readlines()] 108 | 109 | write_to_file(train_ids, args, posts, question_candidates, answer_candidates, 'train') 110 | write_to_file(tune_ids, args, posts, question_candidates, answer_candidates, 'tune') 111 | write_to_file(test_ids, args, posts, question_candidates, answer_candidates, 'test') 112 | 113 | 114 | if __name__ == "__main__": 115 | argparser = argparse.ArgumentParser(sys.argv[0]) 116 | argparser.add_argument("--post_data_tsvfile", type = str) 117 | argparser.add_argument("--qa_data_tsvfile", type = str) 118 | argparser.add_argument("--train_ids_file", type = str) 119 | argparser.add_argument("--train_context_fname", type = str) 120 | argparser.add_argument("--train_question_fname", type = str) 121 | argparser.add_argument("--train_answer_fname", type=str) 122 | argparser.add_argument("--tune_ids_file", type = str) 123 | argparser.add_argument("--tune_context_fname", type = str) 124 | argparser.add_argument("--tune_question_fname", type = str) 125 | argparser.add_argument("--tune_answer_fname", type=str) 126 | argparser.add_argument("--test_ids_file", type = str) 127 | argparser.add_argument("--test_context_fname", type=str) 128 | argparser.add_argument("--test_question_fname", type = str) 129 | argparser.add_argument("--test_answer_fname", type = str) 130 | args = argparser.parse_args() 131 | print args 132 | print "" 133 | read_data(args) 134 | 
-------------------------------------------------------------------------------- /src/data_generation/filter_amazon_data_byid.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from collections import defaultdict 4 | 5 | 6 | def main(args): 7 | train_ids_file = open(args.train_ids, 'r') 8 | train_answer_file = open(args.train_answer, 'r') 9 | train_candqs_ids_file = open(args.train_candqs_ids, 'r') 10 | train_answer_candqs_file = open(args.train_answer_candqs, 'w') 11 | train_ids = [] 12 | uniq_id_ct = defaultdict(int) 13 | for line in train_ids_file.readlines(): 14 | curr_id = line.strip('\n') 15 | uniq_id_ct[curr_id] += 1 16 | train_ids.append(curr_id+'_'+str(uniq_id_ct[curr_id])) 17 | i = 0 18 | train_answers = {} 19 | for line in train_answer_file.readlines(): 20 | train_answers[train_ids[i]] = line.strip('\n') 21 | i += 1 22 | for line in train_candqs_ids_file.readlines(): 23 | curr_id = line.strip('\n') 24 | try: 25 | train_answer_candqs_file.write(train_answers[curr_id]) 26 | except: 27 | import pdb 28 | pdb.set_trace() 29 | 30 | 31 | if __name__ == '__main__': 32 | argparser = argparse.ArgumentParser(sys.argv[0]) 33 | argparser.add_argument("--train_ids", type=str) 34 | argparser.add_argument("--train_answer", type=str) 35 | argparser.add_argument("--train_candqs_ids", type=str) 36 | argparser.add_argument("--train_answer_candqs", type=str) 37 | args = argparser.parse_args() 38 | print args 39 | print "" 40 | main(args) 41 | -------------------------------------------------------------------------------- /src/data_generation/filter_output_perid.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from collections import defaultdict 4 | 5 | 6 | def main(args): 7 | test_output_file = open(args.test_output, 'r') 8 | test_ids_file = open(args.test_ids, 'r') 9 | test_output_perid_file = open(args.test_output_perid, 'w') 
10 | 11 | # ongoing_id = None 12 | id_seq_in_output = [] 13 | data_ct_per_id = defaultdict(int) 14 | for line in test_ids_file.readlines(): 15 | curr_id = line.strip('\n') 16 | data_ct_per_id[curr_id] += 1 17 | if args.max_per_id is not None: 18 | if data_ct_per_id[curr_id] <= args.max_per_id: 19 | id_seq_in_output.append(curr_id) 20 | else: 21 | id_seq_in_output.append(curr_id) 22 | i = 0 23 | ongoing_id = None 24 | test_output_lines = test_output_file.readlines() 25 | test_output_perid = [] 26 | for line in test_output_lines: 27 | if id_seq_in_output[i] != ongoing_id: 28 | test_output_perid.append(line) 29 | ongoing_id = id_seq_in_output[i] 30 | i += 1 31 | total_count = (len(test_output_perid) / args.batch_size - 1) * args.batch_size 32 | print total_count 33 | print len(test_output_perid) 34 | for line in test_output_perid[:total_count]: 35 | test_output_perid_file.write(line) 36 | 37 | 38 | if __name__ == '__main__': 39 | argparser = argparse.ArgumentParser(sys.argv[0]) 40 | argparser.add_argument("--test_output", type=str) 41 | argparser.add_argument("--test_ids", type=str) 42 | argparser.add_argument("--test_output_perid", type=str) 43 | argparser.add_argument("--max_per_id", type=int, default=None) 44 | argparser.add_argument("--batch_size", type=int) 45 | args = argparser.parse_args() 46 | print args 47 | print "" 48 | main(args) -------------------------------------------------------------------------------- /src/data_generation/remove_unks.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/clarification_question_generation_pytorch/23ae8aa0160eee70565751f4b6de13563a19d6ed/src/data_generation/remove_unks.py -------------------------------------------------------------------------------- /src/data_generation/run_create_amazon_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=candqs_Home_and_Kitchen 
4 | #SBATCH --output=candqs_Home_and_Kitchen 5 | #SBATCH --qos=batch 6 | #SBATCH --mem=4g 7 | #SBATCH --time=4:00:00 8 | 9 | SITENAME=Home_and_Kitchen 10 | METADATA_DIR=/fs/clip-corpora/amazon_qa 11 | #DATA_DIR=/fs/clip-corpora/amazon_qa/$SITENAME 12 | DATA_DIR=/fs/clip-scratch/raosudha/amazon_qa/$SITENAME 13 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation/src-opennmt 14 | CQ_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/data/amazon/$SITENAME 15 | 16 | mkdir -p $CQ_DATA_DIR 17 | 18 | #mkdir /fs/clip-corpora/amazon_qa/$SITENAME 19 | #mkdir /fs/clip-amr/clarification_question_generation/data/amazon/$SITENAME 20 | 21 | python $SCRIPT_DIR/create_amazon_data.py --prod_dir $DATA_DIR/prod_docs \ 22 | --ques_dir $DATA_DIR/ques_docs \ 23 | --metadata_fname $METADATA_DIR/meta_${SITENAME}.json.gz \ 24 | --sim_prod_fname $DATA_DIR/lucene_similar_prods.txt \ 25 | --sim_ques_fname $DATA_DIR/lucene_similar_ques.txt \ 26 | --train_src_fname $CQ_DATA_DIR/train_src \ 27 | --train_tgt_fname $CQ_DATA_DIR/train_tgt \ 28 | --tune_src_fname $CQ_DATA_DIR/tune_src \ 29 | --tune_tgt_fname $CQ_DATA_DIR/tune_tgt \ 30 | --test_src_fname $CQ_DATA_DIR/test_src \ 31 | --test_tgt_fname $CQ_DATA_DIR/test_tgt \ 32 | --train_ids_file $CQ_DATA_DIR/train_ids \ 33 | --tune_ids_file $CQ_DATA_DIR/tune_ids \ 34 | --test_ids_file $CQ_DATA_DIR/test_ids \ 35 | --candqs True \ 36 | #--onlycontext True \ 37 | #--simqs True \ 38 | #--template True \ 39 | #--nocontext True \ 40 | -------------------------------------------------------------------------------- /src/data_generation/run_create_amazon_pqa_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=pqa_data_Home_and_Kitchen 4 | #SBATCH --output=pqa_data_Home_and_Kitchen 5 | #SBATCH --qos=batch 6 | #SBATCH --mem=36g 7 | #SBATCH --time=24:00:00 8 | 9 | SITENAME=Home_and_Kitchen 10 | DATA_DIR=/fs/clip-corpora/amazon_qa 11 | 
CQ_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/$SITENAME 12 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src 13 | 14 | export PATH="/fs/clip-amr/anaconda2/bin:$PATH" 15 | 16 | python $SCRIPT_DIR/create_amazon_pqa_data.py --qa_data_fname $DATA_DIR/qa_${SITENAME}.json.gz \ 17 | --metadata_fname $DATA_DIR/meta_${SITENAME}.json.gz \ 18 | --train_asin_fname $CQ_DATA_DIR/train_asin.txt \ 19 | --train_context_fname $CQ_DATA_DIR/train_context.txt \ 20 | --train_ques_fname $CQ_DATA_DIR/train_ques.txt \ 21 | --train_ans_fname $CQ_DATA_DIR/train_ans.txt \ 22 | --tune_asin_fname $CQ_DATA_DIR/tune_asin.txt \ 23 | --tune_context_fname $CQ_DATA_DIR/tune_context.txt \ 24 | --tune_ques_fname $CQ_DATA_DIR/tune_ques.txt \ 25 | --tune_ans_fname $CQ_DATA_DIR/tune_ans.txt \ 26 | --test_asin_fname $CQ_DATA_DIR/test_asin.txt \ 27 | --test_context_fname $CQ_DATA_DIR/test_context.txt \ 28 | --test_ques_fname $CQ_DATA_DIR/test_ques.txt \ 29 | --test_ans_fname $CQ_DATA_DIR/test_ans.txt \ 30 | 31 | -------------------------------------------------------------------------------- /src/data_generation/run_create_amazon_pqa_data_from_asins.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=pqa_data_Home_and_Kitchen 4 | #SBATCH --output=pqa_data_Home_and_Kitchen 5 | #SBATCH --qos=batch 6 | #SBATCH --mem=36g 7 | #SBATCH --time=24:00:00 8 | 9 | SITENAME=Home_and_Kitchen 10 | DATA_DIR=/fs/clip-corpora/amazon_qa 11 | CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME 12 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src 13 | 14 | export PATH="/fs/clip-amr/anaconda2/bin:$PATH" 15 | 16 | python $SCRIPT_DIR/create_amazon_pqa_data_from_asins.py --qa_data_fname $DATA_DIR/qa_${SITENAME}.json.gz \ 17 | --metadata_fname $DATA_DIR/meta_${SITENAME}.json.gz \ 18 | --train_asin_fname $CQ_DATA_DIR/train_asin.txt \ 19 | 
--train_ans_fname $CQ_DATA_DIR/train_ans.txt \ 20 | --tune_asin_fname $CQ_DATA_DIR/tune_asin.txt \ 21 | --tune_context_fname $CQ_DATA_DIR/tune_context.txt \ 22 | --tune_ques_fname $CQ_DATA_DIR/tune_ques.txt \ 23 | --tune_ans_fname $CQ_DATA_DIR/tune_ans.txt \ 24 | --test_asin_fname $CQ_DATA_DIR/test_asin.txt \ 25 | --test_ans_fname $CQ_DATA_DIR/test_ans.txt \ 26 | 27 | -------------------------------------------------------------------------------- /src/data_generation/run_create_lucene_amazon.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=lucene_data_Home_and_Kitchen 4 | #SBATCH --output=lucene_data_Home_and_Kitchen 5 | #SBATCH --qos=batch 6 | #SBATCH --mem=36g 7 | #SBATCH --time=24:00:00 8 | 9 | SITENAME=Home_and_Kitchen 10 | DATA_DIR=/fs/clip-corpora/amazon_qa 11 | SCRATCH_DATA_DIR=/fs/clip-scratch/raosudha/amazon_qa 12 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation/src 13 | 14 | rm -r $SCRATCH_DATA_DIR/${SITENAME}/prod_docs/ 15 | rm -r $SCRATCH_DATA_DIR/${SITENAME}/ques_docs/ 16 | mkdir -p $SCRATCH_DATA_DIR/${SITENAME}/prod_docs/ 17 | mkdir -p $SCRATCH_DATA_DIR/${SITENAME}/ques_docs/ 18 | 19 | export PATH="/fs/clip-amr/anaconda2/bin:$PATH" 20 | 21 | python $SCRIPT_DIR/create_lucene_amazon.py --qa_data_fname $DATA_DIR/qa_${SITENAME}.json.gz \ 22 | --metadata_fname $DATA_DIR/meta_${SITENAME}.json.gz \ 23 | --product_dir $SCRATCH_DATA_DIR/${SITENAME}/prod_docs/ \ 24 | --question_dir $SCRATCH_DATA_DIR/${SITENAME}/ques_docs/ 25 | -------------------------------------------------------------------------------- /src/data_generation/run_create_pqa_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=pqa_data_aus 4 | #SBATCH --output=pqa_data_aus 5 | #SBATCH --qos=batch 6 | #SBATCH --mem=36g 7 | #SBATCH --time=24:00:00 8 | 9 | SITENAME=askubuntu_unix_superuser 10 | 11 | 
DATA_DIR=/fs/clip-amr/ranking_clarification_questions/data/$SITENAME 12 | CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME 13 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src 14 | 15 | python $SCRIPT_DIR/create_pqa_data.py --post_data_tsvfile $DATA_DIR/post_data.tsv \ 16 | --qa_data_tsvfile $DATA_DIR/qa_data.tsv \ 17 | --train_ids_file $DATA_DIR/train_ids \ 18 | --tune_ids_file $DATA_DIR/tune_ids \ 19 | --test_ids_file $DATA_DIR/test_ids \ 20 | --train_context_fname $CQ_DATA_DIR/train_context.txt \ 21 | --train_question_fname $CQ_DATA_DIR/train_question.txt \ 22 | --train_answer_fname $CQ_DATA_DIR/train_answer.txt \ 23 | --tune_context_fname $CQ_DATA_DIR/tune_context.txt \ 24 | --tune_question_fname $CQ_DATA_DIR/tune_question.txt \ 25 | --tune_answer_fname $CQ_DATA_DIR/tune_answer.txt \ 26 | --test_context_fname $CQ_DATA_DIR/test_context.txt \ 27 | --test_question_fname $CQ_DATA_DIR/test_question.txt \ 28 | --test_answer_fname $CQ_DATA_DIR/test_answer.txt \ 29 | -------------------------------------------------------------------------------- /src/data_generation/run_filter_amazon_data_byid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=Home_and_Kitchen 4 | DATA_DIR=/fs/clip-corpora/amazon_qa 5 | OLD_CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation/data/amazon/$SITENAME 6 | CQ_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/$SITENAME 7 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src 8 | 9 | export PATH="/fs/clip-amr/anaconda2/bin:$PATH" 10 | 11 | python $SCRIPT_DIR/filter_amazon_data_byid.py --train_ids $CQ_DATA_DIR/train_asin.txt \ 12 | --train_answer $CQ_DATA_DIR/train_answer.txt \ 13 | --train_candqs_ids $OLD_CQ_DATA_DIR/train_tgt_candqs.txt.ids \ 14 | --train_answer_candqs $CQ_DATA_DIR/train_answer_candqs.txt \ 15 | 16 | 
-------------------------------------------------------------------------------- /src/data_generation/run_filter_output_perid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=Home_and_Kitchen 4 | DATA_DIR=/fs/clip-corpora/amazon_qa 5 | OLD_CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation/data/amazon/$SITENAME 6 | CQ_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/$SITENAME 7 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src 8 | 9 | export PATH="/fs/clip-amr/anaconda2/bin:$PATH" 10 | 11 | python $SCRIPT_DIR/filter_output_perid.py --test_ids $CQ_DATA_DIR/tune_asin.txt \ 12 | --test_output $CQ_DATA_DIR/test_pred_question.txt.pretrained_greedy \ 13 | --test_output_perid $CQ_DATA_DIR/test_pred_question.txt.pretrained_greedy.perid \ 14 | --batch_size 128 \ 15 | # --max_per_id 3 \ 16 | # --test_output $CQ_DATA_DIR/test_pred_question.txt.GAN_mixer_pred_ans_3perid.epoch8.beam0 \ 17 | # --test_output_perid $CQ_DATA_DIR/test_pred_question.txt.GAN_mixer_pred_ans_3perid.epoch8.beam0.perid \ 18 | # --test_output $CQ_DATA_DIR/test_pred_question.txt.pretrained_greedy \ 19 | # --test_output_perid $CQ_DATA_DIR/test_pred_question.txt.pretrained_greedy.perid \ 20 | 21 | -------------------------------------------------------------------------------- /src/decode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle as p 3 | import sys 4 | import torch 5 | 6 | from seq2seq.encoderRNN import * 7 | from seq2seq.attnDecoderRNN import * 8 | from seq2seq.RL_beam_decoder import * 9 | # from seq2seq.diverse_beam_decoder import * 10 | from seq2seq.RL_evaluate import * 11 | # from RL_helper import * 12 | from constants import * 13 | 14 | 15 | def main(args): 16 | print('Enter main') 17 | word_embeddings = p.load(open(args.word_embeddings, 'rb')) 18 | print('Loaded emb of size %d' % 
len(word_embeddings)) 19 | word_embeddings = np.array(word_embeddings) 20 | word2index = p.load(open(args.vocab, 'rb')) 21 | index2word = reverse_dict(word2index) 22 | 23 | if args.test_ids is not None: 24 | test_data = read_data(args.test_context, args.test_question, None, args.test_ids, 25 | args.max_post_len, args.max_ques_len, args.max_ans_len, mode='test') 26 | else: 27 | test_data = read_data(args.test_context, args.test_question, None, None, 28 | args.max_post_len, args.max_ques_len, args.max_ans_len, mode='test') 29 | 30 | print('No. of test_data %d' % len(test_data)) 31 | run_model(test_data, word_embeddings, word2index, index2word, args) 32 | 33 | 34 | def run_model(test_data, word_embeddings, word2index, index2word, args): 35 | print('Preprocessing test data..') 36 | te_id_seqs, te_post_seqs, te_post_lens, te_ques_seqs, te_ques_lens, \ 37 | te_post_ques_seqs, te_post_ques_lens, te_ans_seqs, te_ans_lens = preprocess_data(test_data, word2index, 38 | args.max_post_len, 39 | args.max_ques_len, 40 | args.max_ans_len) 41 | 42 | print('Defining encoder decoder models') 43 | q_encoder = EncoderRNN(HIDDEN_SIZE, word_embeddings, n_layers=2, dropout=DROPOUT) 44 | q_decoder = AttnDecoderRNN(HIDDEN_SIZE, len(word2index), word_embeddings, n_layers=2) 45 | 46 | if USE_CUDA: 47 | device = torch.device('cuda:0') 48 | else: 49 | device = torch.device('cpu') 50 | q_encoder = q_encoder.to(device) 51 | q_decoder = q_decoder.to(device) 52 | 53 | # Load encoder, decoder params 54 | print('Loading encoded, decoder params') 55 | if USE_CUDA: 56 | q_encoder.load_state_dict(torch.load(args.q_encoder_params)) 57 | q_decoder.load_state_dict(torch.load(args.q_decoder_params)) 58 | else: 59 | q_encoder.load_state_dict(torch.load(args.q_encoder_params, map_location='cpu')) 60 | q_decoder.load_state_dict(torch.load(args.q_decoder_params, map_location='cpu')) 61 | 62 | out_fname = args.test_pred_question+'.'+args.model 63 | # out_fname = None 64 | if args.greedy: 65 | 
evaluate_seq2seq(word2index, index2word, q_encoder, q_decoder, 66 | te_id_seqs, te_post_seqs, te_post_lens, te_ques_seqs, te_ques_lens, 67 | args.batch_size, args.max_ques_len, out_fname) 68 | elif args.beam: 69 | evaluate_beam(word2index, index2word, q_encoder, q_decoder, 70 | te_id_seqs, te_post_seqs, te_post_lens, te_ques_seqs, te_ques_lens, 71 | args.batch_size, args.max_ques_len, out_fname) 72 | elif args.diverse_beam: 73 | evaluate_diverse_beam(word2index, index2word, q_encoder, q_decoder, 74 | te_id_seqs, te_post_seqs, te_post_lens, te_ques_seqs, te_ques_lens, 75 | args.batch_size, args.max_ques_len, out_fname) 76 | else: 77 | print('Please specify mode of decoding: --greedy OR --beam OR --diverse_beam') 78 | 79 | 80 | if __name__ == "__main__": 81 | argparser = argparse.ArgumentParser(sys.argv[0]) 82 | argparser.add_argument("--test_context", type = str) 83 | argparser.add_argument("--test_question", type = str) 84 | argparser.add_argument("--test_answer", type = str) 85 | argparser.add_argument("--test_ids", type=str) 86 | argparser.add_argument("--test_pred_question", type = str) 87 | argparser.add_argument("--q_encoder_params", type = str) 88 | argparser.add_argument("--q_decoder_params", type = str) 89 | argparser.add_argument("--vocab", type = str) 90 | argparser.add_argument("--word_embeddings", type = str) 91 | argparser.add_argument("--max_post_len", type = int, default=300) 92 | argparser.add_argument("--max_ques_len", type = int, default=50) 93 | argparser.add_argument("--max_ans_len", type = int, default=50) 94 | argparser.add_argument("--batch_size", type = int, default=128) 95 | argparser.add_argument("--model", type=str) 96 | argparser.add_argument("--greedy", type=bool, default=False) 97 | argparser.add_argument("--beam", type=bool, default=False) 98 | argparser.add_argument("--diverse_beam", type=bool, default=False) 99 | args = argparser.parse_args() 100 | print(args) 101 | print("") 102 | main(args) 103 | 
-------------------------------------------------------------------------------- /src/embedding_generation/create_we_vocab.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pickle as p 3 | 4 | PAD_token = '' 5 | SOS_token = '' 6 | EOP_token = '' 7 | EOS_token = '' 8 | 9 | if __name__ == "__main__": 10 | if len(sys.argv) < 4: 11 | print("usage: python create_we_vocab.py ") 12 | sys.exit(0) 13 | word_vectors_file = open(sys.argv[1], 'r') 14 | word_embeddings = [] 15 | vocab = {} 16 | vocab[PAD_token] = 0 17 | vocab[SOS_token] = 1 18 | vocab[EOP_token] = 2 19 | vocab[EOS_token] = 3 20 | word_embeddings.append(None) 21 | word_embeddings.append(None) 22 | word_embeddings.append(None) 23 | word_embeddings.append(None) 24 | 25 | i = 4 26 | for line in word_vectors_file.readlines(): 27 | vals = line.rstrip().split(' ') 28 | vocab[vals[0]] = i 29 | word_embeddings.append([float(v) for v in vals[1:]]) 30 | i += 1 31 | 32 | word_embeddings[0] = [0]*len(word_embeddings[4]) 33 | word_embeddings[1] = [0]*len(word_embeddings[4]) 34 | word_embeddings[2] = [0]*len(word_embeddings[4]) 35 | word_embeddings[3] = [0]*len(word_embeddings[4]) 36 | 37 | p.dump(word_embeddings, open(sys.argv[2], 'wb')) 38 | p.dump(vocab, open(sys.argv[3], 'wb')) 39 | 40 | -------------------------------------------------------------------------------- /src/embedding_generation/extract_amazon_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gzip 3 | import nltk 4 | import pdb 5 | import sys, os 6 | import re 7 | from collections import defaultdict 8 | 9 | def parse(path): 10 | g = gzip.open(path, 'r') 11 | for l in g: 12 | yield eval(l) 13 | 14 | exception_chars = ['|', '/', '\\', '-', '(', ')', '!', ':', ';', '<', '>'] 15 | 16 | def preprocess(text): 17 | text = text.replace('|', ' ') 18 | text = text.replace('/', ' ') 19 | text = text.replace('\\', ' ') 20 | text = 
text.lower() 21 | ret_text = '' 22 | for sent in nltk.sent_tokenize(text): 23 | ret_text += ' '.join(nltk.word_tokenize(sent)) + ' ' 24 | return ret_text 25 | 26 | def main(args): 27 | output_file = open(args.output_fname, 'w') 28 | for v in parse(args.metadata_fname): 29 | if 'description' not in v or 'title' not in v: 30 | continue 31 | title = preprocess(v['title']) 32 | description = preprocess(v['description']) 33 | product = title + ' . ' + description 34 | output_file.write(product+'\n') 35 | 36 | for v in parse(args.qa_data_fname): 37 | if 'answer' not in v: 38 | continue 39 | question = preprocess(v['question']) 40 | answer = preprocess(v['answer']) 41 | output_file.write(question+'\n') 42 | output_file.write(answer+'\n') 43 | output_file.close() 44 | 45 | if __name__ == "__main__": 46 | argparser = argparse.ArgumentParser(sys.argv[0]) 47 | argparser.add_argument("--qa_data_fname", type = str) 48 | argparser.add_argument("--metadata_fname", type = str) 49 | argparser.add_argument("--output_fname", type = str) 50 | args = argparser.parse_args() 51 | print args 52 | print "" 53 | main(args) 54 | 55 | -------------------------------------------------------------------------------- /src/embedding_generation/extract_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import sys 4 | 5 | def main(args): 6 | print("Reading lines...") 7 | output_file = open(args.output_data, 'a') 8 | with open(args.post_data_tsv, 'rb') as tsvfile: 9 | post_reader = csv.DictReader(tsvfile, delimiter='\t') 10 | for row in post_reader: 11 | post = row['title'] + ' ' + row['post'] 12 | post = post.lower().strip() 13 | output_file.write(post+'\n') 14 | 15 | with open(args.qa_data_tsv, 'rb') as tsvfile: 16 | qa_reader = csv.DictReader(tsvfile, delimiter='\t') 17 | for row in qa_reader: 18 | ques = row['q1'].lower().strip() 19 | ans = row['a1'].lower().strip() 20 | output_file.write(ques+'\n') 21 | 
output_file.write(ans+'\n') 22 | 23 | output_file.close() 24 | 25 | if __name__ == "__main__": 26 | argparser = argparse.ArgumentParser(sys.argv[0]) 27 | argparser.add_argument("--post_data_tsv", type = str) 28 | argparser.add_argument("--qa_data_tsv", type = str) 29 | argparser.add_argument("--output_data", type = str) 30 | args = argparser.parse_args() 31 | print args 32 | print "" 33 | main(args) 34 | 35 | -------------------------------------------------------------------------------- /src/embedding_generation/run_create_we_vocab.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SITENAME=askubuntu_unix_superuser 4 | SITENAME=Home_and_Kitchen 5 | 6 | SCRIPTS_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/embedding_generation 7 | EMB_DIR=/fs/clip-amr/clarification_question_generation_pytorch/embeddings/$SITENAME/200 8 | 9 | python $SCRIPTS_DIR/create_we_vocab.py $EMB_DIR/vectors.txt $EMB_DIR/word_embeddings.p $EMB_DIR/vocab.p 10 | 11 | -------------------------------------------------------------------------------- /src/embedding_generation/run_extract_amazon_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=emb_data_Home_and_Kitchen 4 | #SBATCH --output=emb_data_Home_and_Kitchen 5 | #SBATCH --qos=batch 6 | #SBATCH --mem=36g 7 | #SBATCH --time=24:00:00 8 | 9 | SITENAME=Home_and_Kitchen 10 | DATA_DIR=/fs/clip-corpora/amazon_qa 11 | EMB_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/embeddings/$SITENAME 12 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/embedding_generation 13 | 14 | python $SCRIPT_DIR/extract_amazon_data.py --qa_data_fname $DATA_DIR/qa_${SITENAME}.json.gz \ 15 | --metadata_fname $DATA_DIR/meta_${SITENAME}.json.gz \ 16 | --output_fname $EMB_DATA_DIR/${SITENAME}_data.txt 17 | 18 | 
-------------------------------------------------------------------------------- /src/embedding_generation/run_extract_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SITENAME=askubuntu.com 4 | #SITENAME=unix.stackexchange.com 5 | SITENAME=superuser.com 6 | 7 | CQ_DATA_DIR=/fs/clip-amr/ranking_clarification_questions/data/$SITENAME 8 | OUT_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/embeddings 9 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/embedding_generation 10 | 11 | python $SCRIPT_DIR/extract_data.py --post_data_tsv $CQ_DATA_DIR/post_data.tsv \ 12 | --qa_data_tsv $CQ_DATA_DIR/qa_data.tsv \ 13 | --output_data $OUT_DATA_DIR/${SITENAME}.data.txt 14 | 15 | -------------------------------------------------------------------------------- /src/embedding_generation/run_glove.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Makes programs, downloads sample data, trains a GloVe model, and then evaluates it. 
4 | # One optional argument can specify the language used for eval script: matlab, octave or [default] python 5 | 6 | SITENAME=askubuntu_unix_superuser 7 | #SITENAME=Home_and_Kitchen 8 | 9 | DATADIR=clarification_question_generation_pytorch/embeddings/$SITENAME/200_10Kvocab 10 | 11 | CORPUS=clarification_question_generation_pytorch/embeddings/$SITENAME/${SITENAME}_data.txt 12 | VOCAB_FILE=$DATADIR/vocab.txt 13 | COOCCURRENCE_FILE=$DATADIR/cooccurrence.bin 14 | COOCCURRENCE_SHUF_FILE=$DATADIR/cooccurrence.shuf.bin 15 | BUILDDIR=GloVe-1.2/build #Download from https://nlp.stanford.edu/projects/glove/ 16 | SAVE_FILE=$DATADIR/vectors 17 | VERBOSE=2 18 | MEMORY=4.0 19 | #VOCAB_MIN_COUNT=10 20 | VOCAB_MIN_COUNT=50 21 | #VECTOR_SIZE=100 22 | VECTOR_SIZE=200 23 | MAX_ITER=30 24 | WINDOW_SIZE=15 25 | BINARY=2 26 | NUM_THREADS=4 27 | X_MAX=10 28 | 29 | $BUILDDIR/vocab_count -min-count $VOCAB_MIN_COUNT -verbose $VERBOSE < $CORPUS > $VOCAB_FILE 30 | $BUILDDIR/cooccur -memory $MEMORY -vocab-file $VOCAB_FILE -verbose $VERBOSE -window-size $WINDOW_SIZE < $CORPUS > $COOCCURRENCE_FILE 31 | $BUILDDIR/shuffle -memory $MEMORY -verbose $VERBOSE < $COOCCURRENCE_FILE > $COOCCURRENCE_SHUF_FILE 32 | $BUILDDIR/glove -save-file $SAVE_FILE -eta 0.05 -threads $NUM_THREADS -input-file $COOCCURRENCE_SHUF_FILE -x-max $X_MAX -iter $MAX_ITER -vector-size $VECTOR_SIZE -binary $BINARY -vocab-file $VOCAB_FILE -verbose $VERBOSE 33 | -------------------------------------------------------------------------------- /src/evaluation/all_ngrams.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | use strict; 3 | 4 | my $usage = "usage: cat FILE | all_ngrams.pl [-noboundary|-smallboundary] \n"; 5 | 6 | my $boundary = 2; 7 | my $N = 1; 8 | 9 | while (1) { 10 | my $tmp = shift or die $usage; 11 | if ($tmp eq '-noboundary') { $boundary = 0; } 12 | elsif ($tmp eq '-smallboundary') { $boundary = 1; } 13 | elsif ($tmp =~ /^[0-9]+$/) { $N = $tmp; last; } 14 | 
else { die $usage; } 15 | } 16 | 17 | while (<>) { 18 | chomp; 19 | if (/^[\s]*$/) { next; } 20 | my @w = split; 21 | my $M = scalar @w; 22 | 23 | my $lo = -$N+1; 24 | if ($boundary == 0) { $lo = 0; } 25 | if (($boundary == 1) && ($N>1)) { $lo = -1; } 26 | 27 | my $hi = $M; 28 | if ($boundary == 0) { $hi = $M-$N+1; } 29 | if (($boundary == 1) && ($N>1)) { $hi = $M-$N+2; } 30 | 31 | for (my $i=$lo; $i<$hi; $i++) { 32 | for (my $j=0; $j<$N; $j++) { 33 | if ($j > 0) { print ' '; } 34 | print (($i+$j<0) ? '' : (($i+$j>=$M) ? '' : $w[$i+$j])); 35 | } 36 | print "\n"; 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/evaluation/calculate_diversity.sh: -------------------------------------------------------------------------------- 1 | export LANGUAGE=en_US.UTF-8 2 | export LANG=en_US.UTF-8 3 | export LC_ALL=en_US.UTF-8 4 | 5 | NGRAMS_SCRIPT=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation/all_ngrams.pl 6 | 7 | count_uniq_trigrams=$( cat $1 | $NGRAMS_SCRIPT 3 | sort | uniq -c | sort -gr | wc -l ) 8 | count_all_trigrams=$( cat $1 | $NGRAMS_SCRIPT 3 | sort | sort -gr | wc -l ) 9 | echo "Trigram diversity" 10 | echo "scale=4; $count_uniq_trigrams / $count_all_trigrams" | bc 11 | 12 | count_uniq_bigrams=$( cat $1 | $NGRAMS_SCRIPT 2 | sort | uniq -c | sort -gr | wc -l ) 13 | count_all_bigrams=$( cat $1 | $NGRAMS_SCRIPT 2 | sort | sort -gr | wc -l ) 14 | echo "Bigram diversity" 15 | echo "scale=4; $count_uniq_bigrams / $count_all_bigrams" | bc 16 | 17 | count_uniq_unigrams=$( cat $1 | $NGRAMS_SCRIPT 1 | sort | uniq -c | sort -gr | wc -l ) 18 | count_all_unigrams=$( cat $1 | $NGRAMS_SCRIPT 1 | sort | sort -gr | wc -l ) 19 | echo "Unigram diversity" 20 | echo "scale=4; $count_uniq_unigrams / $count_all_unigrams" | bc 21 | 22 | 23 | -------------------------------------------------------------------------------- /src/evaluation/calculate_inter_annotator_agreement.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | from collections import defaultdict 4 | import sys 5 | import numpy as np 6 | import pdb 7 | 8 | on_topic_levels = {'Yes': 1, 'No': 0} 9 | is_grammatical_levels = {'Grammatical': 1, 'Comprehensible': 1, 'Incomprehensible': 0} 10 | is_specific_levels = {'Specific pretty much only to this product': 4, 11 | 'Specific to this and other very similar products (or the same product from a different manufacturer)': 3, 12 | 'Generic enough to be applicable to many other products of this type': 2, 13 | 'Generic enough to be applicable to any product under Home and Kitchen': 1, 14 | 'N/A (Not Applicable)': -1} 15 | asks_new_info_levels = {'Completely': 1, 'Somewhat': 1, 'No': 0, 'N/A (Not Applicable)': -1} 16 | useful_levels = {'Useful enough to be included in the product description': 4, 17 | 'Useful to a large number of potential buyers (or current users)': 3, 18 | 'Useful to a small number of potential buyers (or current users)': 2, 19 | 'Useful only to the person asking the question': 1, 20 | 'N/A (Not Applicable)': -1} 21 | 22 | model_dict = {'ref': 0, 'lucene': 1, 'seq2seq.beam': 2, 23 | 'rl.beam': 3, 24 | 'gan.beam': 4} 25 | 26 | model_list = ['ref', 'lucene', 'seq2seq.beam', 27 | 'rl.beam', 28 | 'gan.beam'] 29 | 30 | frequent_words = ['dimensions', 'dimension', 'size', 'measurements', 'measurement', 31 | 'weight', 'height', 'width', 'diameter', 'density', 32 | 'bpa', 'difference', 'thread', 33 | 'china', 'usa'] 34 | 35 | 36 | def get_avg_score(score_dict, ignore_na=True): 37 | curr_on_topic_score = 0. 
38 | N = 0 39 | for score, count in score_dict.iteritems(): 40 | if score != -1: 41 | curr_on_topic_score += score * count 42 | if ignore_na: 43 | if score != -1: 44 | N += count 45 | else: 46 | N += count 47 | #print N 48 | return curr_on_topic_score * 1.0 / N 49 | 50 | 51 | def main(args): 52 | num_models = len(model_list) 53 | on_topic_conf_scores = [] 54 | is_grammatical_conf_scores = [] 55 | is_specific_conf_scores = [] 56 | asks_new_info_conf_scores = [] 57 | useful_conf_scores = [] 58 | 59 | asins_so_far = [None] * num_models 60 | for i in range(num_models): 61 | asins_so_far[i] = [] 62 | 63 | with open(args.aggregate_results) as csvfile: 64 | reader = csv.DictReader(csvfile) 65 | for row in reader: 66 | if row['_golden'] == 'true' or row['_unit_state'] == 'golden': 67 | continue 68 | asin = row['asin'] 69 | question = row['question'] 70 | model_name = row['model_name'] 71 | if model_name not in model_list: 72 | continue 73 | if asin not in asins_so_far[model_dict[model_name]]: 74 | asins_so_far[model_dict[model_name]].append(asin) 75 | else: 76 | print '%s duplicate %s' % (model_name, asin) 77 | continue 78 | on_topic_score = on_topic_levels[row['on_topic']] 79 | on_topic_conf_score = float(row['on_topic:confidence']) 80 | is_grammatical_score = is_grammatical_levels[row['grammatical']] 81 | is_grammatical_conf_score = float(row['grammatical:confidence']) 82 | is_specific_conf_score = float(row['is_specific:confidence']) 83 | asks_new_info_score = asks_new_info_levels[row['new_info']] 84 | asks_new_info_conf_score = float(row['new_info:confidence']) 85 | useful_conf_score = float(row['useful_to_another_buyer:confidence']) 86 | 87 | on_topic_conf_scores.append(on_topic_conf_score) 88 | is_grammatical_conf_scores.append(is_grammatical_conf_score) 89 | if on_topic_score != 0 and is_grammatical_score != 0: 90 | is_specific_conf_scores.append(is_specific_conf_score) 91 | asks_new_info_conf_scores.append(asks_new_info_conf_score) 92 | if asks_new_info_score != 0: 
93 | useful_conf_scores.append(useful_conf_score) 94 | 95 | print 'On topic confidence: %.4f' % (sum(on_topic_conf_scores)/float(len(on_topic_conf_scores))) 96 | print 'Is grammatical confidence: %.4f' % (sum(is_grammatical_conf_scores)/float(len(is_grammatical_conf_scores))) 97 | print 'Asks new info confidence: %.4f' % (sum(asks_new_info_conf_scores)/float(len(asks_new_info_conf_scores))) 98 | print 'Useful confidence: %.4f' % (sum(useful_conf_scores)/float(len(useful_conf_scores))) 99 | print 'Specificity confidence: %.4f' % (sum(is_specific_conf_scores)/float(len(is_specific_conf_scores))) 100 | 101 | 102 | if __name__ == '__main__': 103 | argparser = argparse.ArgumentParser(sys.argv[0]) 104 | argparser.add_argument("--aggregate_results", type = str) 105 | args = argparser.parse_args() 106 | print args 107 | print "" 108 | main(args) 109 | -------------------------------------------------------------------------------- /src/evaluation/combine_refs_for_meteor.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | def main(args): 5 | ref_sents = [None]*int(args.no_of_refs) 6 | for i in range(int(args.no_of_refs)): 7 | with open(args.ref_prefix+str(i), 'r') as f: 8 | ref_sents[i] = [line.strip('\n') for line in f.readlines()] 9 | 10 | combined_ref_file = open(args.combined_ref_fname, 'w') 11 | for i in range(len(ref_sents[0])): 12 | for j in range(int(args.no_of_refs)): 13 | combined_ref_file.write(ref_sents[j][i]+'\n') 14 | combined_ref_file.close() 15 | 16 | if __name__ == "__main__": 17 | argparser = argparse.ArgumentParser(sys.argv[0]) 18 | argparser.add_argument("--ref_prefix", type = str) 19 | argparser.add_argument("--no_of_refs", type = int) 20 | argparser.add_argument("--combined_ref_fname", type = str) 21 | args = argparser.parse_args() 22 | print args 23 | print "" 24 | main(args) 25 | -------------------------------------------------------------------------------- 
/src/evaluation/create_amazon_multi_refs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import sys, os, pdb 4 | import nltk 5 | import time 6 | import random 7 | from collections import defaultdict 8 | 9 | def create_refs(test_ids, question_candidates, ref_prefix): 10 | max_ref_count=0 11 | for ques_id in test_ids: 12 | asin = ques_id.split('_')[0] 13 | N = len(question_candidates[asin]) 14 | max_ref_count = max(max_ref_count, N) 15 | ref_files = [None]*max_ref_count 16 | for i in range(max_ref_count): 17 | ref_files[i] = open(ref_prefix+str(i), 'w') 18 | for ques_id in test_ids: 19 | asin = ques_id.split('_')[0] 20 | N = len(question_candidates[asin]) 21 | for i, ques in enumerate(question_candidates[asin]): 22 | ref_files[i].write(ques+'\n') 23 | choices = range(N) 24 | random.shuffle(choices) 25 | for j in range(N, max_ref_count): 26 | r = choices[j%N] 27 | ref_files[j].write(question_candidates[asin][r]+'\n') 28 | 29 | def main(args): 30 | question_candidates = {} 31 | test_ids = [test_id.strip('\n') for test_id in open(args.test_ids_file, 'r').readlines()] 32 | 33 | for fname in os.listdir(args.ques_dir): 34 | with open(os.path.join(args.ques_dir, fname), 'r') as f: 35 | asin = fname[:-4].split('_')[0] 36 | if asin not in question_candidates: 37 | question_candidates[asin] = [] 38 | ques = f.readline().strip('\n') 39 | question_candidates[asin].append(ques) 40 | 41 | create_refs(test_ids, question_candidates, args.ref_prefix) 42 | 43 | if __name__ == "__main__": 44 | argparser = argparse.ArgumentParser(sys.argv[0]) 45 | argparser.add_argument("--ques_dir", type = str) 46 | argparser.add_argument("--test_ids_file", type = str) 47 | argparser.add_argument("--ref_prefix", type = str) 48 | args = argparser.parse_args() 49 | print args 50 | print "" 51 | main(args) 52 | 53 | -------------------------------------------------------------------------------- 
# =============================================================================
# /src/evaluation/create_crowdflower_data.py
#
# Builds a CrowdFlower CSV that shows, for up to 200 products not used in the
# three earlier batches, one reference question plus the Lucene, seq2seq, RL
# and GAN model questions (one row per product/model, rows shuffled).
# =============================================================================
import argparse
import csv
import gzip
import pdb
import random
import sys
from collections import defaultdict


def parse(path):
    """Yield one record per line of the gzip file at ``path``.

    Each line is evaluated as a Python expression; callers expect dict records
    with keys such as 'asin', 'title', 'description' and 'question'.
    """
    # NOTE(review): eval() executes whatever the file contains -- only run this
    # on trusted data files (ast.literal_eval is safer if every line is a
    # plain literal).
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)


def read_model_outputs(model_fname):
    """Return ``{test_id: question}`` for one model's decoded outputs.

    ``model_fname`` holds one generated question per line and
    ``model_fname + '.ids'`` holds the test id for the same line number.
    Raises IndexError if the output file has more lines than the id file.
    """
    with open(model_fname + '.ids', 'r') as ids_file:
        test_ids = [line.strip('\n') for line in ids_file]
    model_outputs = {}
    with open(model_fname, 'r') as model_file:
        for i, line in enumerate(model_file):
            # Drop unknown-word markers.  NOTE(review): previously both calls
            # read .replace(' ', ''), which deleted every space (the second
            # call was a no-op); the marker text had been lost.  '<unk>' is
            # assumed from the '.nounks' outputs / remove_unks.py -- confirm.
            model_outputs[test_ids[i]] = (line.strip('\n')
                                          .replace(' <unk>', '')
                                          .replace('<unk> ', ''))
    return model_outputs


def read_asins(csv_fnames):
    """Return the set of values in the 'asin' column of every CSV given."""
    asins = set()
    for fname in csv_fnames:
        with open(fname) as csvfile:
            for row in csv.DictReader(csvfile):
                asins.add(row['asin'])
    return asins


def main(args):
    """Write the five-systems comparison CSV described in the header."""
    import nltk  # only needed to normalise the reference questions

    titles = {}
    descriptions = {}
    lucene_model_outs = read_model_outputs(args.lucene_model_fname)
    seq2seq_model_outs = read_model_outputs(args.seq2seq_model_fname)
    rl_model_outs = read_model_outputs(args.rl_model_fname)
    gan_model_outs = read_model_outputs(args.gan_model_fname)

    # Products shown in earlier batches are never shown again; a set keeps
    # the membership tests below O(1).
    prev_batch_asins = read_asins([args.batch1_csv_file,
                                   args.batch2_csv_file,
                                   args.batch3_csv_file])

    for v in parse(args.metadata_fname):
        asin = v['asin']
        if asin in prev_batch_asins or asin not in lucene_model_outs:
            continue
        title = v['title']
        description = v['description']
        length = len(description.split())
        # Keep 10-99 word descriptions whose word count differs from the title's.
        if length >= 100 or length < 10 or len(title.split()) == length:
            continue
        titles[asin] = title
        descriptions[asin] = description

    questions = defaultdict(list)
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        # `descriptions` only holds asins that passed every filter above.
        if asin not in descriptions:
            continue
        question = ' '.join(nltk.sent_tokenize(v['question'])).lower()
        questions[asin].append(question)

    all_rows = []
    max_count = 200
    i = 0
    for asin in lucene_model_outs.keys():
        # `titles` already excludes previous-batch asins.
        if asin not in titles:
            continue
        # Without a reference question there is no 'ref' row to show
        # (random.choice would raise IndexError on the empty list).
        if not questions.get(asin):
            continue
        title = titles[asin]
        description = descriptions[asin]
        ref_question = random.choice(questions[asin])
        lucene_question = lucene_model_outs[asin]
        if lucene_question == '':
            print('Found empty line in lucene')
            continue
        model_questions = [
            ('ref', ref_question),
            (args.lucene_model_name, lucene_question),
            (args.seq2seq_model_name, seq2seq_model_outs[asin]),
            (args.rl_model_name, rl_model_outs[asin]),
            (args.gan_model_name, gan_model_outs[asin]),
        ]
        for model_name, question in model_questions:
            all_rows.append([asin, title, description, model_name, question])
        i += 1
        if i >= max_count:
            break
    # Presumably shuffled so row order does not reveal the system.
    random.shuffle(all_rows)

    with open(args.csv_file, 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['asin', 'title', 'description', 'model_name', 'question'])
        writer.writerows(all_rows)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--qa_data_fname", type=str)
    argparser.add_argument("--metadata_fname", type=str)
    argparser.add_argument("--batch1_csv_file", type=str)
    argparser.add_argument("--batch2_csv_file", type=str)
    argparser.add_argument("--batch3_csv_file", type=str)
    argparser.add_argument("--csv_file", type=str)
    argparser.add_argument("--lucene_model_fname", type=str)
    argparser.add_argument("--lucene_model_name", type=str)
    argparser.add_argument("--seq2seq_model_fname", type=str)
    argparser.add_argument("--seq2seq_model_name", type=str)
    argparser.add_argument("--rl_model_fname", type=str)
    argparser.add_argument("--rl_model_name", type=str)
    argparser.add_argument("--gan_model_fname", type=str)
    argparser.add_argument("--gan_model_name", type=str)
    args = argparser.parse_args()
    print(args)
    print("")
    main(args)


# =============================================================================
# /src/evaluation/create_crowdflower_data_beam.py
#
# Re-emits the rows of an earlier CrowdFlower CSV that belong to the models in
# model_list, dropping repeated (model, asin) rows, in shuffled order.
# =============================================================================
import argparse
import csv
import gzip
import pdb
import random
import sys
from collections import defaultdict


model_list = ['ref', 'lucene', 'seq2seq.beam',
              'rl.beam',
              'gan.beam']
model_dict = {'ref': 0, 'lucene': 1, 'seq2seq.beam': 2,
              'rl.beam': 3,
              'gan.beam': 4}


def parse(path):
    """Yield one eval()-ed record per line of the gzip file at ``path``."""
    # NOTE(review): eval() executes the file contents -- trusted data only.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)


def main(args):
    """Copy ``args.previous_csv_file`` to ``args.output_csv_file``.

    Only rows whose model is in ``model_list`` are kept, and only the first
    row for each (model, asin) pair; kept rows are written in shuffled order.
    """
    all_rows = []
    # Already-seen asins per model name (sets give O(1) duplicate checks).
    asins_so_far = defaultdict(set)
    with open(args.previous_csv_file) as csvfile:
        for row in csv.DictReader(csvfile):
            asin = row['asin']
            model_name = row['model_name']
            if model_name not in model_list:
                continue
            if asin in asins_so_far[model_name]:
                print('Duplicate asin %s in %s' % (asin, model_name))
                continue
            asins_so_far[model_name].add(asin)
            all_rows.append([asin, row['title'], row['description'],
                             model_name, row['question']])
    random.shuffle(all_rows)
    # Opened only after the input is fully read, so using the same path for
    # both arguments no longer truncates the input before it is read.
    with open(args.output_csv_file, 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['asin', 'title', 'description', 'model_name', 'question'])
        writer.writerows(all_rows)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--previous_csv_file", type=str)
    argparser.add_argument("--output_csv_file", type=str)
    args = argparser.parse_args()
    print(args)
    print("")
    main(args)


# =============================================================================
# /src/evaluation/create_crowdflower_data_compare_ques.py
#
# For up to 200 training products not used in four earlier CSVs, writes pairs
# of that product's (shuffled) reference questions for side-by-side rating.
# =============================================================================
import argparse
import csv
import gzip
import pdb
import random
import sys
from collections import defaultdict


def parse(path):
    """Yield one eval()-ed record per line of the gzip file at ``path``."""
    # NOTE(review): eval() executes the file contents -- trusted data only.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)


def read_asins(csv_fnames):
    """Return the set of values in the 'asin' column of every CSV given."""
    asins = set()
    for fname in csv_fnames:
        with open(fname) as csvfile:
            for row in csv.DictReader(csvfile):
                asins.add(row['asin'])
    return asins


def main(args):
    """Write question_a / question_b comparison rows to ``args.csv_file``."""
    import nltk  # only needed to normalise the reference questions

    titles = {}
    descriptions = {}

    previous_asins = read_asins([args.previous_csv_file_v1,
                                 args.previous_csv_file_v2,
                                 args.previous_csv_file_v3,
                                 args.previous_csv_file_v4])

    with open(args.train_asins, 'r') as train_file:
        train_asins = set(line.strip('\n') for line in train_file)

    for v in parse(args.metadata_fname):
        asin = v['asin']
        if asin not in train_asins or asin in previous_asins:
            continue
        title = v['title']
        description = v['description']
        length = len(description.split())
        # Keep 10-99 word descriptions whose word count differs from the title's.
        if length >= 100 or length < 10 or len(title.split()) == length:
            continue
        titles[asin] = title
        descriptions[asin] = description

    questions = defaultdict(list)
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        # `descriptions` only holds training asins not used previously.
        if asin not in descriptions:
            continue
        question = ' '.join(nltk.sent_tokenize(v['question'])).lower()
        questions[asin].append(question)

    all_rows = []
    max_count = 200
    i = 0
    for asin in titles.keys():
        asin_questions = questions.get(asin, [])
        # A pair needs two questions: with one, the ring below would compare
        # it with itself; with none, the product would still count towards
        # max_count without contributing any row.
        if len(asin_questions) < 2:
            continue
        title = titles[asin]
        description = descriptions[asin]
        random.shuffle(asin_questions)
        # Pair each question with the next one; after at most 7 questions the
        # ring is closed by pairing the last one with the first.
        for k in range(len(asin_questions)):
            if k == 6 or k == len(asin_questions) - 1:
                all_rows.append([asin, title, description,
                                 asin_questions[k], asin_questions[0]])
                print('%s %d' % (asin, k + 1))
                break
            all_rows.append([asin, title, description,
                             asin_questions[k], asin_questions[k + 1]])
        i += 1
        if i >= max_count:
            break
    random.shuffle(all_rows)

    with open(args.csv_file, 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['asin', 'title', 'description', 'question_a', 'question_b'])
        writer.writerows(all_rows)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--qa_data_fname", type=str)
    argparser.add_argument("--metadata_fname", type=str)
    argparser.add_argument("--csv_file", type=str)
    argparser.add_argument("--train_asins", type=str)
    argparser.add_argument("--previous_csv_file_v1", type=str)
    argparser.add_argument("--previous_csv_file_v2", type=str)
    argparser.add_argument("--previous_csv_file_v3", type=str)
    argparser.add_argument("--previous_csv_file_v4", type=str)
    args = argparser.parse_args()
    print(args)
    print("")
    main(args)


# =============================================================================
# /src/evaluation/create_crowdflower_data_compare_ques_allpairs.py
#
# For up to 25 training products not used in an earlier CSV, writes every
# pair among the first 7 (shuffled) reference questions of each product.
# =============================================================================
import argparse
import csv
import gzip
import itertools
import pdb
import random
import sys
from collections import defaultdict


def parse(path):
    """Yield one eval()-ed record per line of the gzip file at ``path``."""
    # NOTE(review): eval() executes the file contents -- trusted data only.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)


def read_asins(csv_fnames):
    """Return the set of values in the 'asin' column of every CSV given."""
    asins = set()
    for fname in csv_fnames:
        with open(fname) as csvfile:
            for row in csv.DictReader(csvfile):
                asins.add(row['asin'])
    return asins


def main(args):
    """Write all-pairs question_a / question_b rows to ``args.csv_file``."""
    import nltk  # only needed to normalise the reference questions

    titles = {}
    descriptions = {}

    previous_asins = read_asins([args.previous_csv_file])

    with open(args.train_asins, 'r') as train_file:
        train_asins = set(line.strip('\n') for line in train_file)

    for v in parse(args.metadata_fname):
        asin = v['asin']
        if asin not in train_asins or asin in previous_asins:
            continue
        title = v['title']
        description = v['description']
        length = len(description.split())
        # Keep 10-99 word descriptions whose word count differs from the title's.
        if length >= 100 or length < 10 or len(title.split()) == length:
            continue
        titles[asin] = title
        descriptions[asin] = description

    questions = defaultdict(list)
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        # `descriptions` only holds training asins not used previously.
        if asin not in descriptions:
            continue
        question = ' '.join(nltk.sent_tokenize(v['question'])).lower()
        questions[asin].append(question)

    all_rows = []
    max_count = 25
    i = 0
    for asin in titles.keys():
        asin_questions = questions.get(asin, [])
        # Fewer than two questions yields no pair; do not let such products
        # use up max_count.
        if len(asin_questions) < 2:
            continue
        title = titles[asin]
        description = descriptions[asin]
        random.shuffle(asin_questions)
        # Every unordered pair (j < k) among the first 7 shuffled questions.
        for question_a, question_b in itertools.combinations(asin_questions[:7], 2):
            all_rows.append([asin, title, description, question_a, question_b])
        i += 1
        if i >= max_count:
            break
    random.shuffle(all_rows)

    with open(args.csv_file, 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['asin', 'title', 'description', 'question_a', 'question_b'])
        writer.writerows(all_rows)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--qa_data_fname", type=str)
    argparser.add_argument("--metadata_fname", type=str)
    argparser.add_argument("--csv_file", type=str)
    argparser.add_argument("--train_asins", type=str)
    argparser.add_argument("--previous_csv_file", type=str)
    args = argparser.parse_args()
    print(args)
    print("")
    main(args)


# =============================================================================
# /src/evaluation/create_crowdflower_data_specificity.py
#
# For up to 50 products, writes the reference, Lucene, seq2seq, seq2seq
# "specific" and seq2seq "generic" questions (one row per product/model).
# =============================================================================
import argparse
import csv
import gzip
import pdb
import random
import sys
from collections import defaultdict


def parse(path):
    """Yield one eval()-ed record per line of the gzip file at ``path``."""
    # NOTE(review): eval() executes the file contents -- trusted data only.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)


def read_model_outputs(model_fname):
    """Return ``{test_id: question}`` read from ``model_fname`` and
    ``model_fname + '.ids'`` (same line number = same test example)."""
    with open(model_fname + '.ids', 'r') as ids_file:
        test_ids = [line.strip('\n') for line in ids_file]
    model_outputs = {}
    with open(model_fname, 'r') as model_file:
        for i, line in enumerate(model_file):
            # Drop unknown-word markers.  NOTE(review): previously both calls
            # read .replace(' ', ''), deleting every space; '<unk>' is assumed
            # from the '.nounks' outputs / remove_unks.py -- confirm.
            model_outputs[test_ids[i]] = (line.strip('\n')
                                          .replace(' <unk>', '')
                                          .replace('<unk> ', ''))
    return model_outputs


def main(args):
    """Write the specificity comparison CSV to ``args.csv_file``."""
    import nltk  # only needed to normalise the reference questions

    titles = {}
    descriptions = {}
    lucene_model_outs = read_model_outputs(args.lucene_model_fname)
    seq2seq_model_outs = read_model_outputs(args.seq2seq_model_fname)
    seq2seq_specific_model_outs = read_model_outputs(args.seq2seq_specific_model_fname)
    seq2seq_generic_model_outs = read_model_outputs(args.seq2seq_generic_model_fname)

    for v in parse(args.metadata_fname):
        asin = v['asin']
        if 'description' not in v or 'title' not in v:
            continue
        title = v['title']
        description = v['description']
        length = len(description.split())
        # Keep 10-99 word descriptions whose word count differs from the title's.
        if length >= 100 or length < 10 or len(title.split()) == length:
            continue
        titles[asin] = title
        descriptions[asin] = description

    questions = defaultdict(list)
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        if asin not in descriptions:
            continue
        question = ' '.join(nltk.sent_tokenize(v['question'])).lower()
        questions[asin].append(question)

    all_rows = []
    i = 0
    max_count = 50
    for asin in seq2seq_specific_model_outs.keys():
        if asin not in descriptions:
            continue
        seq2seq_specific_question = seq2seq_specific_model_outs[asin]
        tokens = seq2seq_specific_question.split()
        # Skip outputs whose first '?' token is not the final token (outputs
        # without any '?' are kept).
        if '?' in tokens and tokens.index('?') != len(tokens) - 1:
            continue
        # Without a reference question there is no "ref" row to show
        # (random.choice would raise IndexError on the empty list).
        if not questions.get(asin):
            continue
        title = titles[asin]
        description = descriptions[asin]
        ref_question = random.choice(questions[asin])
        model_questions = [
            ("ref", ref_question),
            (args.lucene_model_name, lucene_model_outs[asin]),
            (args.seq2seq_model_name, seq2seq_model_outs[asin]),
            (args.seq2seq_specific_model_name, seq2seq_specific_question),
            (args.seq2seq_generic_model_name, seq2seq_generic_model_outs[asin]),
        ]
        for model_name, question in model_questions:
            all_rows.append([asin, title, description, model_name, question])
        i += 1
        if i >= max_count:
            break
    random.shuffle(all_rows)

    with open(args.csv_file, 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['asin', 'title', 'description', 'model_name', 'question'])
        writer.writerows(all_rows)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--qa_data_fname", type=str)
    argparser.add_argument("--metadata_fname", type=str)
    argparser.add_argument("--csv_file", type=str)
    argparser.add_argument("--lucene_model_name", type=str)
    argparser.add_argument("--lucene_model_fname", type=str)
    argparser.add_argument("--seq2seq_model_name", type=str)
    argparser.add_argument("--seq2seq_model_fname", type=str)
    argparser.add_argument("--seq2seq_specific_model_name", type=str)
    argparser.add_argument("--seq2seq_specific_model_fname", type=str)
    argparser.add_argument("--seq2seq_generic_model_name", type=str)
    argparser.add_argument("--seq2seq_generic_model_fname", type=str)
    args = argparser.parse_args()
    print(args)
    print("")
    main(args)
# =============================================================================
# /src/evaluation/create_crowdflower_data_specificity_multi.py
#
# Like the single-question specificity script, but excludes products listed
# in --prev_csv_file and filters the "specific" outputs the other way round.
# =============================================================================
import argparse
import csv
import gzip
import pdb
import random
import sys
from collections import defaultdict


def parse(path):
    """Yield one eval()-ed record per line of the gzip file at ``path``."""
    # NOTE(review): eval() executes the file contents -- trusted data only.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)


def read_model_outputs(model_fname):
    """Return ``{test_id: question}`` read from ``model_fname`` and
    ``model_fname + '.ids'`` (same line number = same test example).

    Raises IndexError if the output file has more lines than the id file.
    """
    with open(model_fname + '.ids', 'r') as ids_file:
        test_ids = [line.strip('\n') for line in ids_file]
    model_outputs = {}
    with open(model_fname, 'r') as model_file:
        for i, line in enumerate(model_file):
            # Drop unknown-word markers.  NOTE(review): previously both calls
            # read .replace(' ', ''), deleting every space; '<unk>' is assumed
            # from the '.nounks' outputs / remove_unks.py -- confirm.
            model_outputs[test_ids[i]] = (line.strip('\n')
                                          .replace(' <unk>', '')
                                          .replace('<unk> ', ''))
    return model_outputs


def read_asins(csv_fnames):
    """Return the set of values in the 'asin' column of every CSV given."""
    asins = set()
    for fname in csv_fnames:
        with open(fname) as csvfile:
            for row in csv.DictReader(csvfile):
                asins.add(row['asin'])
    return asins


def main(args):
    """Write the multi-question specificity comparison CSV."""
    import nltk  # only needed to normalise the reference questions

    titles = {}
    descriptions = {}
    lucene_model_outs = read_model_outputs(args.lucene_model_fname)
    seq2seq_model_outs = read_model_outputs(args.seq2seq_model_fname)
    seq2seq_specific_model_outs = read_model_outputs(args.seq2seq_specific_model_fname)
    seq2seq_generic_model_outs = read_model_outputs(args.seq2seq_generic_model_fname)

    # Set for O(1) membership tests.
    prev_asins = read_asins([args.prev_csv_file])

    for v in parse(args.metadata_fname):
        asin = v['asin']
        if asin in prev_asins:
            continue
        if 'description' not in v or 'title' not in v:
            continue
        title = v['title']
        description = v['description']
        length = len(description.split())
        # Keep 10-99 word descriptions whose word count differs from the title's.
        if length >= 100 or length < 10 or len(title.split()) == length:
            continue
        titles[asin] = title
        descriptions[asin] = description

    questions = defaultdict(list)
    for v in parse(args.qa_data_fname):
        asin = v['asin']
        # `descriptions` already excludes previously used asins.
        if asin not in descriptions:
            continue
        question = ' '.join(nltk.sent_tokenize(v['question'])).lower()
        questions[asin].append(question)

    all_rows = []
    i = 0
    max_count = 50
    for asin in seq2seq_specific_model_outs.keys():
        if asin not in descriptions:
            continue
        seq2seq_specific_question = seq2seq_specific_model_outs[asin]
        tokens = seq2seq_specific_question.split()
        # Skip outputs whose first '?' token is the final token.
        # NOTE(review): presumably meant to keep outputs containing several
        # questions; outputs with no '?' at all are also kept -- confirm.
        if '?' in tokens and tokens.index('?') == len(tokens) - 1:
            continue
        # Without a reference question there is no "ref" row to show
        # (random.choice would raise IndexError on the empty list).
        if not questions.get(asin):
            continue
        title = titles[asin]
        description = descriptions[asin]
        ref_question = random.choice(questions[asin])
        model_questions = [
            ("ref", ref_question),
            (args.lucene_model_name, lucene_model_outs[asin]),
            (args.seq2seq_model_name, seq2seq_model_outs[asin]),
            (args.seq2seq_specific_model_name, seq2seq_specific_question),
            (args.seq2seq_generic_model_name, seq2seq_generic_model_outs[asin]),
        ]
        for model_name, question in model_questions:
            all_rows.append([asin, title, description, model_name, question])
        i += 1
        if i >= max_count:
            break
    random.shuffle(all_rows)

    with open(args.csv_file, 'w') as csv_file:
        writer = csv.writer(csv_file, delimiter=',')
        writer.writerow(['asin', 'title', 'description', 'model_name', 'question'])
        writer.writerows(all_rows)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--qa_data_fname", type=str)
    argparser.add_argument("--metadata_fname", type=str)
    argparser.add_argument("--prev_csv_file", type=str)
    argparser.add_argument("--csv_file", type=str)
    argparser.add_argument("--lucene_model_name", type=str)
    argparser.add_argument("--lucene_model_fname", type=str)
    argparser.add_argument("--seq2seq_model_name", type=str)
    argparser.add_argument("--seq2seq_model_fname", type=str)
    argparser.add_argument("--seq2seq_specific_model_name", type=str)
    argparser.add_argument("--seq2seq_specific_model_fname", type=str)
    argparser.add_argument("--seq2seq_generic_model_name", type=str)
    argparser.add_argument("--seq2seq_generic_model_fname", type=str)
    args = argparser.parse_args()
    print(args)
    print("")
    main(args)


# =============================================================================
# /src/evaluation/create_preds_for_refs.py
#
# Writes <model_output_file>.hasrefs containing only the predictions for test
# posts that have human reference annotations.
# =============================================================================
import argparse
import csv
import os
import pdb
import sys
import time


def get_annotations(line):
    """Parse one annotator record ``'<annotator>_<site>,<post_id>,<best>,<valids>,<confidence>'``.

    ``valids`` is a space-separated list of ints.  Returns the tuple
    ``(post_id, annotator_name, sitename, best, valids, confidence)``.
    """
    set_info, post_id, best, valids, confidence = line.split(',')
    annotator_name = set_info.split('_')[0]
    sitename = set_info.split('_')[1]
    best = int(best)
    valids = [int(v) for v in valids.split()]
    confidence = int(confidence)
    return post_id, annotator_name, sitename, best, valids, confidence


def read_human_annotations(human_annotations_filename):
    """Return ``{'<site>_<post_id>': [candidate indices]}``.

    Each input line holds two tab-separated annotator records for the same
    post; the kept indices are the union of both "best" picks plus the
    intersection of both "valid" lists (order not guaranteed).
    """
    annotations = {}
    with open(human_annotations_filename, 'r') as human_annotations_file:
        for line in human_annotations_file:
            splits = line.strip('\n').split('\t')
            post_id1, annotator_name1, sitename1, best1, valids1, confidence1 = get_annotations(splits[0])
            post_id2, annotator_name2, sitename2, best2, valids2, confidence2 = get_annotations(splits[1])
            assert sitename1 == sitename2
            assert post_id1 == post_id2
            post_id = sitename1 + '_' + post_id1
            best_union = list(set([best1, best2]))
            valids_inter = list(set(valids1).intersection(set(valids2)))
            annotations[post_id] = list(set(best_union + valids_inter))
    return annotations


def main(args):
    """Filter ``args.model_output_file`` down to annotated test posts."""
    with open(args.test_ids_file, 'r') as test_ids_file:
        test_ids = [test_id.strip('\n') for test_id in test_ids_file]

    # NOTE(review): question_candidates is built but never used below.
    question_candidates = {}
    with open(args.qa_data_tsvfile, 'rb') as tsvfile:
        qa_reader = csv.reader(tsvfile, delimiter='\t')
        next(qa_reader, None)  # skip the header row
        for row in qa_reader:
            post_id, questions = row[0], row[1:11]
            question_candidates[post_id] = questions

    annotations = read_human_annotations(args.human_annotations)
    with open(args.model_output_file, 'r') as model_output_file:
        model_outputs = [line.strip('\n') for line in model_output_file]

    # `with` guarantees the filtered predictions are flushed and closed
    # (the handle was previously never closed).
    with open(args.model_output_file + '.hasrefs', 'w') as pred_file:
        for i, post_id in enumerate(test_ids):
            if post_id not in annotations:
                continue
            pred_file.write(model_outputs[i] + '\n')


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--qa_data_tsvfile", type=str)
    argparser.add_argument("--human_annotations", type=str)
    argparser.add_argument("--model_output_file", type=str)
    argparser.add_argument("--test_ids_file", type=str)
    args = argparser.parse_args()
    print(args)
    print("")
    main(args)


# =============================================================================
# /src/evaluation/eval_HK  (shell commands, kept verbatim as comments)
# Per system: distinct-trigram count, then multi-BLEU against test_ref.
# =============================================================================
#
# #lucene
#
# cat Home_and_Kitchen/blind_test_pred_question.lucene.txt | /fs/clip-ml/hal/bin/all_ngrams.pl 3 | sort | uniq -c | sort -gr | wc -l
#
# /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl Home_and_Kitchen/test_ref < Home_and_Kitchen/blind_test_pred_question.lucene.txt
#
# #Seq2seq
#
# cat Home_and_Kitchen/blind_test_pred_ques.txt.seq2seq.epoch100.beam0.nounks | /fs/clip-ml/hal/bin/all_ngrams.pl 3 | sort | uniq -c | sort -gr | wc -l
#
# /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl Home_and_Kitchen/test_ref < Home_and_Kitchen/blind_test_pred_ques.txt.seq2seq.epoch100.beam0.nounks
#
# #RL
#
# cat Home_and_Kitchen/blind_test_pred_ques.txt.RL_mixer_3perid.epoch5.beam0.nounks| /fs/clip-ml/hal/bin/all_ngrams.pl 3 | sort | uniq -c | sort -gr | wc -l
#
# /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl Home_and_Kitchen/test_ref < Home_and_Kitchen/blind_test_pred_ques.txt.RL_mixer_3perid.epoch5.beam0.nounks
#
# #GAN
#
# cat Home_and_Kitchen/blind_test_pred_ques.txt.GAN_selfcritic_pred_ans_3perid.epoch8.beam0.nounks | /fs/clip-ml/hal/bin/all_ngrams.pl 3 | sort | uniq -c | sort -gr | wc -l
#
# /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl Home_and_Kitchen/test_ref < Home_and_Kitchen/blind_test_pred_ques.txt.GAN_selfcritic_pred_ans_3perid.epoch8.beam0.nounks


# =============================================================================
# /src/evaluation/eval_HK_spec  (shell commands, kept verbatim as comments)
# Per system: distinct-trigram count, then multi-BLEU against the
# specific/generic reference sets.
# =============================================================================
# # Lucene
#
# cat Home_and_Kitchen/blind_test_pred_question.lucene.txt | /fs/clip-ml/hal/bin/all_ngrams.pl 3 | sort | uniq -c | sort -gr | wc -l
#
# /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl ../style_clarification_question_generation/specificQ_classifier/test_data/test_ref_specific < Home_and_Kitchen/blind_test_pred_question.lucene.txt
#
# # Seq2seq model
#
# cat Home_and_Kitchen/blind_test_pred_ques.txt.seq2seq.epoch100.beam0.nounks | /fs/clip-ml/hal/bin/all_ngrams.pl 3 | sort | uniq -c | sort -gr | wc -l
#
# /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl ../style_clarification_question_generation/specificQ_classifier/test_data/test_ref_specific < Home_and_Kitchen/blind_test_pred_ques.txt.seq2seq.epoch100.beam0.nounks
#
#
# # RL model
# cat Home_and_Kitchen/blind_test_pred_ques.txt.RL_mixer_3perid.epoch5.beam0.nounks | /fs/clip-ml/hal/bin/all_ngrams.pl 3 | sort | uniq -c | sort -gr | wc -l
#
# /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl ../style_clarification_question_generation/specificQ_classifier/test_data/test_ref_specific < Home_and_Kitchen/blind_test_pred_ques.txt.RL_mixer_3perid.epoch5.beam0.nounks
#
# # GAN model
# # NOTE(review): unlike every other section, this scores against
# # test_ref_generic and has no n-gram count line -- confirm intended.
#
# /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl ../style_clarification_question_generation/specificQ_classifier/test_data/test_ref_generic < Home_and_Kitchen/blind_test_pred_ques.txt.GAN_selfcritic_pred_ans_3perid.epoch8.beam0.nounks
#
# # Specificity seq2seq model
#
# cat Home_and_Kitchen/blind_test_pred_ques.txt.seq2seq_tobeginning_tospecific.epoch65.beam0.nounks | /fs/clip-ml/hal/bin/all_ngrams.pl 3 | sort | uniq -c | sort -gr | wc -l
#
# /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl ../style_clarification_question_generation/specificQ_classifier/test_data/test_ref_specific < Home_and_Kitchen/blind_test_pred_ques.txt.seq2seq_tobeginning_tospecific.epoch65.beam0.nounks
#
# # Specificity GAN model
# # NOTE(review): the n-gram count uses the _tospecific output while BLEU
# # uses the _togeneric output against test_ref_generic -- confirm intended.
#
# cat Home_and_Kitchen/blind_test_pred_ques.txt.GAN_selfcritic_pred_ans_tobeginning_tospecific.epoch8.beam0.nounks | /fs/clip-ml/hal/bin/all_ngrams.pl 3 | sort | uniq -c | sort -gr | wc -l
#
# /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl ../style_clarification_question_generation/specificQ_classifier/test_data/test_ref_generic < Home_and_Kitchen/blind_test_pred_ques.txt.GAN_selfcritic_pred_ans_tobeginning_togeneric.epoch8.beam0.nounks
#
-------------------------------------------------------------------------------- /src/evaluation/eval_aus: -------------------------------------------------------------------------------- 1 | #Reference 2 | 3 | /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl ../clarification_question_generation/data/askubuntu_unix_superuser/test_ref < /fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/askubuntu_unix_superuser/test_question.txt.hasrefs 4 | 5 | #Lucene 6 | 7 | /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl ../clarification_question_generation/data/askubuntu_unix_superuser/test_ref < ../clarification_question_generation/data/askubuntu_unix_superuser/test_pred_lucene.txt.hasrefs 8 | 9 | cat ../clarification_question_generation/data/askubuntu_unix_superuser/test_pred_lucene.txt.hasrefs | /fs/clip-ml/hal/bin/all_ngrams.pl 1 | sort | uniq -c | sort -gr | wc -l 10 | 11 | #Seq2Seq 12 | 13 | /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl ../clarification_question_generation/data/askubuntu_unix_superuser/test_ref < /fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/askubuntu_unix_superuser/test_pred_question.txt.seq2seq.len_norm.beam0.hasrefs.nounks 14 | 15 | #RL 16 | 17 | /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl ../clarification_question_generation/data/askubuntu_unix_superuser/test_ref < /fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/askubuntu_unix_superuser/test_pred_question.txt.RL_mixer.epoch8.len_norm.beam0.hasrefs.nounks 18 | 19 | cat /fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/askubuntu_unix_superuser/test_pred_question.txt.RL_selfcritic.epoch8.len_norm.beam0.hasrefs.nounks | /fs/clip-ml/hal/bin/all_ngrams.pl 3 | sort | uniq -c | sort -gr | wc -l 20 | 21 | #GAN 22 | 23 | 
/fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl ../clarification_question_generation/data/askubuntu_unix_superuser/test_ref < /fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/askubuntu_unix_superuser/test_pred_question.txt.mixer_pred_ans.epoch8.len_norm.beam0.hasrefs.nounks 24 | 25 | /fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl ../clarification_question_generation/data/askubuntu_unix_superuser/test_ref < /fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/askubuntu_unix_superuser/test_pred_question.txt.selfcritic_pred_ans.epoch8.len_norm.beam0.hasrefs.nounks 26 | 27 | cat /fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/askubuntu_unix_superuser/test_pred_question.txt.mixer_pred_ans.epoch8.len_norm.beam0.hasrefs.nounks | /fs/clip-ml/hal/bin/all_ngrams.pl 1 | sort | uniq -c | sort -gr | wc -l 28 | -------------------------------------------------------------------------------- /src/evaluation/read_crowdflower_full_results_compare_ques.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | from collections import defaultdict 4 | import sys 5 | import pdb 6 | import numpy as np 7 | 8 | specificity_levels = ['Question A is more specific', 9 | 'Question B is more specific', 10 | 'Both are at the same level of specificity', 11 | 'N/A: One or both questions are not applicable to the product'] 12 | 13 | 14 | def get_avg_score(score_dict, ignore_na=False): 15 | curr_on_topic_score = 0. 
16 | N = 0 17 | for score, count in score_dict.iteritems(): 18 | curr_on_topic_score += score * count 19 | if ignore_na: 20 | if score != 0: 21 | N += count 22 | else: 23 | N += count 24 | # print N 25 | return curr_on_topic_score * 1.0 / N 26 | 27 | 28 | def main(args): 29 | titles = {} 30 | descriptions = {} 31 | cand_ques_dict = defaultdict(list) 32 | cand_scores_dict = defaultdict(list) 33 | curr_asin = None 34 | curr_a = None 35 | curr_b = None 36 | a_scores = [] 37 | b_scores = [] 38 | ab_scores = [] 39 | with open(args.full_results) as csvfile: 40 | reader = csv.DictReader(csvfile) 41 | for row in reader: 42 | if row['_golden'] == 'true' or row['_tainted'] == 'true': 43 | continue 44 | asin = row['asin'] 45 | titles[asin] = row['title'] 46 | descriptions[asin] = row['description'] 47 | question_a = row['question_a'] 48 | question_b = row['question_b'] 49 | trust = row['_trust'] 50 | 51 | if curr_asin is None: 52 | curr_asin = asin 53 | curr_a = question_a 54 | curr_b = question_b 55 | elif asin != curr_asin and curr_a != question_a and curr_b != question_b: 56 | a_score = np.sum(a_scores)/5 57 | b_score = np.sum(b_scores)/5 58 | ab_score = np.sum(ab_scores)/5 59 | cand_scores_dict[curr_asin][cand_ques_dict[curr_asin].index(curr_a)].append(a_score + 0.5*ab_score) 60 | cand_scores_dict[curr_asin][cand_ques_dict[curr_asin].index(curr_b)].append(b_score + 0.5*ab_score) 61 | curr_asin = asin 62 | curr_a = question_a 63 | curr_b = question_b 64 | a_scores = [] 65 | b_scores = [] 66 | ab_scores = [] 67 | 68 | if question_a not in cand_ques_dict[asin]: 69 | cand_ques_dict[asin].append(question_a) 70 | cand_scores_dict[asin].append([]) 71 | if question_b not in cand_ques_dict[asin]: 72 | cand_ques_dict[asin].append(question_b) 73 | cand_scores_dict[asin].append([]) 74 | 75 | if row['on_topic'] == 'Question A is more specific': 76 | a_scores.append(float(trust)) 77 | elif row['on_topic'] == 'Question B is more specific': 78 | b_scores.append(float(trust)) 79 | elif 
row['on_topic'] == 'Both are at the same level of specificity': 80 | ab_scores.append(float(trust)) 81 | else: 82 | print 'ID: %s has irrelevant question' % asin 83 | 84 | corr = 0 85 | total = 0 86 | fp, tp, fn, tn = 0, 0, 0, 0 87 | for asin in titles: 88 | print asin 89 | print titles[asin] 90 | print descriptions[asin] 91 | for i, ques in enumerate(cand_ques_dict[asin]): 92 | true_v = np.mean(cand_scores_dict[asin][i]) 93 | pred_v = np.mean(cand_scores_dict[asin][i][:int(len(cand_scores_dict[asin][i])/2)]) 94 | print true_v, cand_scores_dict[asin][i], ques 95 | print pred_v 96 | if true_v < 0.5: 97 | true_l = 0 98 | else: 99 | true_l = 1 100 | if pred_v < 0.5: 101 | pred_l = 0 102 | else: 103 | pred_l = 1 104 | if true_l == pred_l: 105 | corr += 1 106 | if true_l == 0 and pred_l == 1: 107 | fp += 1 108 | if true_l == 0 and pred_l == 0: 109 | tn += 1 110 | if true_l == 1 and pred_l == 0: 111 | fn += 1 112 | if true_l == 1 and pred_l == 1: 113 | tp += 1 114 | total += 1 115 | print 116 | print 'accuracy' 117 | print corr*1.0/total 118 | print tp, fp, fn, tn 119 | 120 | 121 | if __name__ == '__main__': 122 | argparser = argparse.ArgumentParser(sys.argv[0]) 123 | argparser.add_argument("--full_results", type = str) 124 | args = argparser.parse_args() 125 | print args 126 | print "" 127 | main(args) 128 |
--------------------------------------------------------------------------------
/src/evaluation/read_crowdflower_results_binary.py:
--------------------------------------------------------------------------------
import argparse
import csv
from collections import defaultdict
import sys
import numpy as np
import pdb

generic_levels = {'Yes': 1, 'No': 0}

model_dict = {'ref': 0, 'lucene': 1, 'seq2seq.beam': 2,
              'rl.beam': 3,
              'gan.beam': 4}

model_list = ['ref', 'lucene', 'seq2seq.beam',
              'rl.beam',
              'gan.beam']


def get_avg_score(score_dict, ignore_na=False):
    """Return the count-weighted mean of the score levels in score_dict.

    score_dict maps a numeric score level to the number of judgments that
    received it.  With ignore_na=True, judgments scored 0 still contribute 0
    to the numerator but are excluded from the denominator.  Returns NaN
    (printed as "nan" by '%.2f') instead of raising ZeroDivisionError when
    no judgment is counted, e.g. for a model with no annotations.
    """
    curr_on_topic_score = 0.
    N = 0
    for score, count in score_dict.iteritems():
        curr_on_topic_score += score * count
        if ignore_na:
            if score != 0:
                N += count
        else:
            N += count
    if N == 0:
        # Nothing counted: report NaN rather than abort the whole report.
        return float('nan')
    return curr_on_topic_score * 1.0 / N


def main(args):
    num_models = len(model_list)
    generic_scores = [None] * num_models
    asins_so_far = [None] * num_models
    for i in range(num_models):
        generic_scores[i] = defaultdict(int)
        asins_so_far[i] = []

    with open(args.aggregate_results) as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            if row['_golden'] == 'true' or row['_unit_state'] == 'golden':
                continue
            asin = row['asin']
            question = row['question']
            model_name = row['model_name']
            if model_name not in model_list:
                continue
            if asin not in asins_so_far[model_dict[model_name]]:
                asins_so_far[model_dict[model_name]].append(asin)
            else:
                print '%s duplicate %s' % (model_name, asin)
                continue
            generic_score = generic_levels[row['on_topic']]
            generic_scores[model_dict[model_name]][generic_score] += 1

    for i in range(num_models):
        print model_list[i]
        print len(asins_so_far[i])
        print 'Avg on generic score: %.2f' % get_avg_score(generic_scores[i])
        print 'Generic:', generic_scores[i]


if __name__ == '__main__':
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--aggregate_results", type = str)
    args = argparser.parse_args()
    print args
    print ""
    main(args)
--------------------------------------------------------------------------------
/src/evaluation/read_crowdflower_results_compare_ques.py:
--------------------------------------------------------------------------------
import argparse
import csv
from collections import defaultdict
import sys
import pdb
import numpy as np

specificity_levels = ['Question A is more specific',
                      'Question B is more specific',
                      'Both are at the same level of specificity',
                      'N/A: One or both questions are not applicable to the product']


def get_avg_score(score_dict, ignore_na=False):
    """Return the count-weighted mean of the score levels in score_dict.

    score_dict maps a numeric score level to the number of judgments that
    received it.  With ignore_na=True, judgments scored 0 still contribute 0
    to the numerator but are excluded from the denominator.  Returns NaN
    instead of raising ZeroDivisionError when no judgment is counted.
    """
    curr_on_topic_score = 0.
    N = 0
    for score, count in score_dict.iteritems():
        curr_on_topic_score += score * count
        if ignore_na:
            if score != 0:
                N += count
        else:
            N += count
    if N == 0:
        # Nothing counted: report NaN rather than abort the whole report.
        return float('nan')
    return curr_on_topic_score * 1.0 / N


def main(args):
    titles = {}
    descriptions = {}
    cand_ques_dict = defaultdict(list)
    cand_scores_dict = defaultdict(list)
    with open(args.aggregate_results) as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            if row['_unit_state'] != 'finalized':
                continue
            asin = row['asin']
            titles[asin] = row['title']
            descriptions[asin] = row['description']
            question_a = row['question_a']
            question_b = row['question_b']
            confidence = row['on_topic:confidence']

            if question_a not in cand_ques_dict[asin]:
                cand_ques_dict[asin].append(question_a)
                cand_scores_dict[asin].append([])
            if question_b not in cand_ques_dict[asin]:
                cand_ques_dict[asin].append(question_b)
                cand_scores_dict[asin].append([])

            if row['on_topic'] == 'Question A is more specific':
                cand_scores_dict[asin][cand_ques_dict[asin].index(question_a)].append(float(confidence))
                cand_scores_dict[asin][cand_ques_dict[asin].index(question_b)].append((1 - float(confidence)))
            elif row['on_topic'] == 'Question B is more specific':
                cand_scores_dict[asin][cand_ques_dict[asin].index(question_b)].append(float(confidence))
                cand_scores_dict[asin][cand_ques_dict[asin].index(question_a)].append((1 - float(confidence)))
            elif row['on_topic'] == 'Both are at the same level of specificity':
                cand_scores_dict[asin][cand_ques_dict[asin].index(question_b)].append(0.5)
                cand_scores_dict[asin][cand_ques_dict[asin].index(question_a)].append(0.5)
            else:
                print 'ID: %s has irrelevant question' % asin

    for asin in titles:
        print asin
        print titles[asin]
        print descriptions[asin]
        for i, ques in enumerate(cand_ques_dict[asin]):
            print np.mean(cand_scores_dict[asin][i]), cand_scores_dict[asin][i], ques
        print


if __name__ == '__main__':
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--aggregate_results", type = str)
    args = argparser.parse_args()
    print args
    print ""
    main(args)
--------------------------------------------------------------------------------
/src/evaluation/read_crowdflower_results_style.py:
--------------------------------------------------------------------------------
import argparse
import csv
from collections import defaultdict
import sys
import numpy as np
import pdb

relevance_levels = {'Completely makes sense': 2,
                    'Somewhat makes sense': 1,
                    'Does not make sense': 0}
is_grammatical_levels = {'Grammatical': 2, 'Comprehensible': 1, 'Incomprehensible': 0}
is_specific_levels = {'Specific to this or the same product from a different manufacturer': 3,
                      'Specific to this or some other similar products': 2,
                      'Generic enough to be applicable to many other products of this type': 1,
                      'Generic enough to be applicable to any product under Home and Kitchen': 0,
                      'N/A (Not Applicable)': 0}
asks_new_info_levels = {'Yes': 1, 'No': 0, 'N/A (Not Applicable)': 0}

model_dict = {'ref': 0, 'lucene': 1, 'seq2seq': 2,
              'seq2seq.generic': 3,
              'seq2seq.specific': 4}

model_list = ['ref', 'lucene', 'seq2seq',
              'seq2seq.generic',
              'seq2seq.specific']


def get_avg_score(score_dict, ignore_na=False):
    """Return the count-weighted mean of the score levels in score_dict.

    score_dict maps a numeric score level to the number of judgments that
    received it.  With ignore_na=True, judgments scored 0 still contribute 0
    to the numerator but are excluded from the denominator.  Returns NaN
    instead of raising ZeroDivisionError when no judgment is counted -- this
    happens for is_specific/asks_new_info when every row of a model was
    filtered out as nonsensical or ungrammatical.
    """
    curr_relevance_score = 0.
    N = 0
    for score, count in score_dict.iteritems():
        curr_relevance_score += score * count
        if ignore_na:
            if score != 0:
                N += count
        else:
            N += count
    if N == 0:
        # Nothing counted: report NaN rather than abort the whole report.
        return float('nan')
    return curr_relevance_score * 1.0 / N


def main(args):
    num_models = len(model_list)
    relevance_scores = [None] * num_models
    is_grammatical_scores = [None] * num_models
    is_specific_scores = [None] * num_models
    asks_new_info_scores = [None] * num_models
    asins_so_far = [None] * num_models
    for i in range(num_models):
        relevance_scores[i] = defaultdict(int)
        is_grammatical_scores[i] = defaultdict(int)
        is_specific_scores[i] = defaultdict(int)
        asks_new_info_scores[i] = defaultdict(int)
        asins_so_far[i] = []

    with open(args.aggregate_results) as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            if row['_golden'] == 'true' or row['_unit_state'] == 'golden':
                continue
            asin = row['asin']
            question = row['question']
            model_name = row['model_name']
            if model_name not in model_list:
                continue
            if asin not in asins_so_far[model_dict[model_name]]:
                asins_so_far[model_dict[model_name]].append(asin)
            else:
                #print '%s duplicate %s' % (model_name, asin)
                continue
            relevance_score = relevance_levels[row['makes_sense']]
            is_grammatical_score = is_grammatical_levels[row['grammatical']]
            specific_score = is_specific_levels[row['is_specific']]
            asks_new_info_score = asks_new_info_levels[row['new_info']]
            relevance_scores[model_dict[model_name]][relevance_score] += 1
            is_grammatical_scores[model_dict[model_name]][is_grammatical_score] += 1
            if relevance_score != 0 and is_grammatical_score != 0:
                is_specific_scores[model_dict[model_name]][specific_score] += 1
                asks_new_info_scores[model_dict[model_name]][asks_new_info_score] += 1

    for i in range(num_models):
        print model_list[i]
        print len(asins_so_far[i])
        print 'Avg on topic score: %.2f' % get_avg_score(relevance_scores[i])
        print 'Avg grammaticality score: %.2f' % get_avg_score(is_grammatical_scores[i])
        print 'Avg specificity score: %.2f' % get_avg_score(is_specific_scores[i])
        print 'Avg new info score: %.2f' % get_avg_score(asks_new_info_scores[i])
        print
        print 'On topic:', relevance_scores[i]
        print 'Is grammatical: ', is_grammatical_scores[i]
        print 'Is specific: ', is_specific_scores[i]
        print 'Asks new info: ', asks_new_info_scores[i]
        print


if __name__ == '__main__':
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--aggregate_results", type = str)
    args = argparser.parse_args()
    print args
    print ""
    main(args)

--------------------------------------------------------------------------------
/src/evaluation/run_bleu.sh:
--------------------------------------------------------------------------------
#!/bin/bash

SITENAME=Home_and_Kitchen
CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME
BLEU_SCRIPT=/fs/clip-software/user-supported/mosesdecoder/3.0/scripts/generic/multi-bleu.perl

$BLEU_SCRIPT $CQ_DATA_DIR/test_ref < $CQ_DATA_DIR/GAN_test_pred_question.txt.epoch8.hasrefs

--------------------------------------------------------------------------------
/src/evaluation/run_create_amazon_multi_refs.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_amazon_multi_refs
#SBATCH --output=create_amazon_multi_refs
#SBATCH --qos=batch
#SBATCH --mem=4g
#SBATCH --time=4:00:00

SITENAME=Home_and_Kitchen
DATA_DIR=/fs/clip-corpora/amazon_qa/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation/
CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME

python $SCRIPT_DIR/create_amazon_multi_refs.py --ques_dir $DATA_DIR/ques_docs \
    --test_ids_file $CQ_DATA_DIR/blind_test_pred_question.txt.GAN_selfcritic_pred_ans_3perid.epoch8.len_norm.beam0.ids \
    --ref_prefix $CQ_DATA_DIR/test_ref \
    --test_context_file $CQ_DATA_DIR/test_context.txt

--------------------------------------------------------------------------------
/src/evaluation/run_create_crowdflower_data.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_crowdflower_HK_beam_batch4
#SBATCH --output=create_crowdflower_HK_beam_batch4
#SBATCH --qos=batch
#SBATCH --mem=32g
#SBATCH --time=4:00:00

SITENAME=Home_and_Kitchen
CORPORA_DIR=/fs/clip-corpora/amazon_qa
DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/$SITENAME
CROWDFLOWER_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

python $SCRIPT_DIR/create_crowdflower_data.py --qa_data_fname $CORPORA_DIR/qa_${SITENAME}.json.gz \
    --metadata_fname $CORPORA_DIR/meta_${SITENAME}.json.gz \
    --batch1_csv_file $CROWDFLOWER_DIR/crowdflower_lucene_seq2seq_rl_gan_diverse_beam_epoch8.batch1.csv \
    --batch2_csv_file $CROWDFLOWER_DIR/crowdflower_lucene_seq2seq_rl_gan_beam_epoch8.batch2.csv \
    --batch3_csv_file $CROWDFLOWER_DIR/crowdflower_lucene_seq2seq_rl_gan_beam_epoch8.batch3.csv \
    --csv_file $CROWDFLOWER_DIR/crowdflower_lucene_seq2seq_rl_gan_beam_epoch8.batch4.csv \
    --lucene_model_name lucene \
    --lucene_model_fname $DATA_DIR/blind_test_pred_question.lucene.txt \
    --seq2seq_model_name seq2seq.beam \
    --seq2seq_model_fname $DATA_DIR/blind_test_pred_question.txt.seq2seq.len_norm.beam0 \
    --rl_model_name rl.beam \
    --rl_model_fname $DATA_DIR/blind_test_pred_question.txt.RL_selfcritic.epoch8.len_norm.beam0 \
    --gan_model_name gan.beam \
    --gan_model_fname $DATA_DIR/blind_test_pred_question.txt.GAN_selfcritic_pred_ans_3perid.epoch8.len_norm.beam0 \

--------------------------------------------------------------------------------
/src/evaluation/run_create_crowdflower_data_beam.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_crowdflower_HK_beam
#SBATCH --output=create_crowdflower_HK_beam
#SBATCH --qos=batch
#SBATCH --mem=32g
#SBATCH --time=4:00:00

SITENAME=Home_and_Kitchen
CORPORA_DIR=/fs/clip-corpora/amazon_qa
DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/$SITENAME
CROWDFLOWER_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

python $SCRIPT_DIR/create_crowdflower_data_beam.py --previous_csv_file $CROWDFLOWER_DIR/crowdflower_lucene_seq2seq_rl_gan_diverse_beam_seq2seq_beam_gan_beam_rl_beam_epoch8.batch1.aggregate.csv \
    --output_csv_file $CROWDFLOWER_DIR/crowdflower_lucene_seq2seq_rl_gan_beam.batch1.csv \
--------------------------------------------------------------------------------
/src/evaluation/run_create_crowdflower_data_compare_ques.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_crowdflower_HK_compare_batchD
#SBATCH --output=create_crowdflower_HK_compare_batchD
#SBATCH --qos=batch
#SBATCH --mem=32g
#SBATCH --time=4:00:00

SITENAME=Home_and_Kitchen
CORPORA_DIR=/fs/clip-corpora/amazon_qa
DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME
CROWDFLOWER_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

python $SCRIPT_DIR/create_crowdflower_data_compare_ques.py --qa_data_fname $CORPORA_DIR/qa_${SITENAME}.json.gz \
    --metadata_fname $CORPORA_DIR/meta_${SITENAME}.json.gz \
    --csv_file $CROWDFLOWER_DIR/crowdflower_compare_ques_batchD_100.csv \
    --train_asins $DATA_DIR/train_asin.txt \
    --previous_csv_file_v1 $CROWDFLOWER_DIR/crowdflower_compare_ques_batchA_100.csv \
    --previous_csv_file_v2 $CROWDFLOWER_DIR/crowdflower_compare_ques_allpairs.csv \
    --previous_csv_file_v3 $CROWDFLOWER_DIR/crowdflower_compare_ques_batchB_100.csv \
    --previous_csv_file_v4 $CROWDFLOWER_DIR/crowdflower_compare_ques_batchC_100.csv \

--------------------------------------------------------------------------------
/src/evaluation/run_create_crowdflower_data_compare_ques_allpairs.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_crowdflower_HK_compare_allpairs
#SBATCH --output=create_crowdflower_HK_compare_allpairs
#SBATCH --qos=batch
#SBATCH --mem=32g
#SBATCH --time=4:00:00

SITENAME=Home_and_Kitchen
CORPORA_DIR=/fs/clip-corpora/amazon_qa
DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/joint_learning/$SITENAME
CROWDFLOWER_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

python $SCRIPT_DIR/create_crowdflower_data_compare_ques_allpairs.py --qa_data_fname $CORPORA_DIR/qa_${SITENAME}.json.gz \
    --metadata_fname $CORPORA_DIR/meta_${SITENAME}.json.gz \
    --csv_file $CROWDFLOWER_DIR/crowdflower_compare_ques_allpairs.csv \
    --train_asins $DATA_DIR/train_asin.txt \
    --previous_csv_file $CROWDFLOWER_DIR/crowdflower_compare_ques_pilot.csv

--------------------------------------------------------------------------------
/src/evaluation/run_create_crowdflower_data_specificity.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_crowdflower_HK_beam_spec_single_sent
#SBATCH --output=create_crowdflower_HK_beam_spec_single_sent
#SBATCH --qos=batch
#SBATCH --mem=32g
#SBATCH --time=4:00:00

SITENAME=Home_and_Kitchen
CORPORA_DIR=/fs/clip-corpora/amazon_qa
DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME
CROWDFLOWER_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

python $SCRIPT_DIR/create_crowdflower_data_specificity.py --qa_data_fname $CORPORA_DIR/qa_${SITENAME}.json.gz \
    --metadata_fname $CORPORA_DIR/meta_${SITENAME}.json.gz \
    --csv_file $CROWDFLOWER_DIR/crowdflower_seq2seq_epoch100_seq2seq_specific_seq2seq_generic_p100_q30_style_emb_single_sent.epoch100.csv \
    --lucene_model_name lucene \
    --lucene_model_fname $DATA_DIR/blind_test_pred_question.lucene.txt \
    --seq2seq_model_name seq2seq \
    --seq2seq_model_fname $DATA_DIR/blind_test_pred_ques.txt.seq2seq.epoch100.beam0 \
    --seq2seq_specific_model_name seq2seq.specific \
    --seq2seq_specific_model_fname $DATA_DIR/blind_test_pred_ques.txt.seq2seq_tobeginning_tospecific_p100_q30_style_emb.epoch100.beam0 \
    --seq2seq_generic_model_name seq2seq.generic \
    --seq2seq_generic_model_fname $DATA_DIR/blind_test_pred_ques.txt.seq2seq_tobeginning_togeneric_p100_q30_style_emb.epoch100.beam0 \

--------------------------------------------------------------------------------
/src/evaluation/run_create_crowdflower_data_specificity_multi.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=create_crowdflower_HK_beam_spec_multi_sent
#SBATCH --output=create_crowdflower_HK_beam_spec_multi_sent
#SBATCH --qos=batch
#SBATCH --mem=32g
#SBATCH --time=4:00:00

SITENAME=Home_and_Kitchen
CORPORA_DIR=/fs/clip-corpora/amazon_qa
DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME
CROWDFLOWER_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation

export PATH="/fs/clip-amr/anaconda2/bin:$PATH"

# Multi-sentence batch: run the *_multi generator (the single-sentence
# create_crowdflower_data_specificity.py is what run_create_crowdflower_data_specificity.sh
# runs). --prev_csv_file is the single-sentence CSV produced by that script.
# NOTE(review): confirm create_crowdflower_data_specificity_multi.py accepts these flags.
python $SCRIPT_DIR/create_crowdflower_data_specificity_multi.py --qa_data_fname $CORPORA_DIR/qa_${SITENAME}.json.gz \
    --metadata_fname $CORPORA_DIR/meta_${SITENAME}.json.gz \
    --prev_csv_file $CROWDFLOWER_DIR/crowdflower_seq2seq_epoch100_seq2seq_specific_seq2seq_generic_p100_q30_style_emb_single_sent.epoch100.csv \
    --csv_file $CROWDFLOWER_DIR/crowdflower_seq2seq_epoch100_seq2seq_specific_seq2seq_generic_p100_q30_style_emb_multi_sent.epoch100.csv \
    --lucene_model_name lucene \
    --lucene_model_fname $DATA_DIR/blind_test_pred_question.lucene.txt \
    --seq2seq_model_name seq2seq \
    --seq2seq_model_fname $DATA_DIR/blind_test_pred_ques.txt.seq2seq.epoch100.beam0 \
    --seq2seq_specific_model_name seq2seq.specific \
    --seq2seq_specific_model_fname $DATA_DIR/blind_test_pred_ques.txt.seq2seq_tobeginning_tospecific_p100_q30_style_emb.epoch100.beam0 \
    --seq2seq_generic_model_name seq2seq.generic \
    --seq2seq_generic_model_fname $DATA_DIR/blind_test_pred_ques.txt.seq2seq_tobeginning_togeneric_p100_q30_style_emb.epoch100.beam0
--------------------------------------------------------------------------------
/src/evaluation/run_create_preds_for_refs.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --qos=batch
#SBATCH --mem=4g
#SBATCH --time=05:00:00

SITENAME=askubuntu_unix_superuser
CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/evaluation

PRED_FILE=test_pred_question.txt.GAN_selfcritic_pred_ans_util_dis.epoch8

python $SCRIPT_DIR/create_preds_for_refs.py --qa_data_tsvfile $CQ_DATA_DIR/qa_data.tsv \
    --test_ids_file $CQ_DATA_DIR/test_ids \
    --human_annotations $CQ_DATA_DIR/human_annotations \
    --model_output_file $CQ_DATA_DIR/$PRED_FILE


--------------------------------------------------------------------------------
/src/evaluation/run_meteor.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=meteor
#SBATCH --output=meteor
#SBATCH --qos=batch
#SBATCH --mem=4g
#SBATCH --time=4:00:00

SITENAME=askubuntu_unix_superuser

METEOR=/fs/clip-software/user-supported/meteor-1.5
CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME
RESULTS_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/results/$SITENAME

TEST_SET=test_pred_question.txt.RL_selfcritic.epoch8.hasrefs.nounks

java -Xmx2G -jar $METEOR/meteor-1.5.jar $CQ_DATA_DIR/$TEST_SET $CQ_DATA_DIR/test_ref_combined \
    -l en -norm -r 6 \
    > $RESULTS_DIR/${TEST_SET}.meteor

--------------------------------------------------------------------------------
/src/evaluation/run_meteor_HK.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=meteor
#SBATCH --qos=batch
#SBATCH --mem=4g
#SBATCH --time=4:00:00

METEOR=/fs/clip-software/user-supported/meteor-1.5
CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/Home_and_Kitchen/
RESULTS_DIR=/fs/clip-amr/clarification_question_generation_pytorch/evaluation/results/Home_and_Kitchen

TEST_SET=test_pred_ques.txt.seq2seq.epoch100.beam0.nounks

java -Xmx2G -jar $METEOR/meteor-1.5.jar $CQ_DATA_DIR/$TEST_SET $CQ_DATA_DIR/test_ref_combined \
    -l en -norm -r 10 \
    > $RESULTS_DIR/${TEST_SET}.meteor

--------------------------------------------------------------------------------
/src/lucene/create_amazon_lucene_baseline.py:
--------------------------------------------------------------------------------
import argparse
import sys, os
from collections import defaultdict
import csv
import math
import pdb
import random
import gzip


def parse(path):
    """Yield one evaluated record per line of the gzip file at path.

    Each line is passed to eval(); callers index the result like a dict.
    """
    # NOTE(review): eval() executes arbitrary code found in the input file.
    # Acceptable only for a trusted metadata dump; ast.literal_eval is the
    # safe replacement if every line is a plain literal -- confirm and switch.
    with gzip.open(path, 'r') as g:
        for l in g:
            yield eval(l)


def get_brand_info(metadata_fname):
    """Map asin -> brand (None when absent) for products with a title and description.

    Products missing 'description' or 'title' are skipped entirely, so the
    returned dict does not contain every asin seen in the metadata.
    """
    brand_info = {}
    for v in parse(metadata_fname):
        if 'description' not in v or 'title' not in v:
            continue
        brand_info[v['asin']] = v.get('brand')
    return brand_info


def create_lucene_preds(ids, args, quess, sim_prod):
    """Write one retrieved question per test id to args.lucene_pred_fname.

    For each test id (product id = text before the first '_'), pools the
    first 3 questions of each of its first 3 similar products and writes
    one chosen at random; writes an empty line when the pool is empty.
    """
    with open(args.lucene_pred_fname, 'w') as pred_file:
        for test_id in ids:
            prod_id = test_id.split('_')[0]
            choices = []
            # Unknown products / products without questions contribute
            # nothing instead of raising KeyError; unfilled (None) slots
            # are dropped so they can never be written out.
            for sim_id in sim_prod.get(prod_id, [])[:3]:
                choices += [q for q in quess.get(sim_id, [])[:3] if q is not None]
            if len(choices) == 0:
                pred_file.write('\n')
            else:
                pred_ques = random.choice(choices)
                pred_file.write(pred_ques + '\n')


def get_sim_docs(sim_docs_filename, brand_info):
    """Read the similar-products file into {asin: [similar asins]}.

    Each whitespace-separated line is "asin hit1 hit2 ...".  hit1 is skipped;
    from hit2 on, only products with a known brand that differs from the
    query product's brand are kept.  If none qualify, parts[10:13] of the
    line is used as a fallback, which may itself be empty.
    """
    sim_docs = {}
    with open(sim_docs_filename, 'r') as sim_docs_file:
        for line in sim_docs_file:
            parts = line.split()
            if not parts:
                # Blank line: nothing to index (parts[0] would raise).
                continue
            asin = parts[0]
            sim_docs[asin] = []
            if len(parts[1:]) == 0:
                continue
            # .get(): get_brand_info omits products without title/description,
            # which previously raised KeyError here.
            asin_brand = brand_info.get(asin)
            for prod_id in parts[2:]:
                prod_brand = brand_info.get(prod_id)
                if prod_brand and prod_brand != asin_brand:
                    sim_docs[asin].append(prod_id)
            if len(sim_docs[asin]) == 0:
                sim_docs[asin] = parts[10:13]
            if len(sim_docs[asin]) == 0:
                # This used to drop into pdb, which stalls or aborts a
                # non-interactive batch job.  Keep the empty list (a blank
                # prediction is written for it) and report it instead.
                sys.stderr.write('No similar products for %s\n' % asin)
    return sim_docs


def read_data(args):
    """Load questions per product, then write the Lucene baseline predictions.

    Question files in args.ques_dir are named "<asin>_<question number>"
    plus a 4-character suffix; the first line of each file is the question.
    """
    print("Reading lines...")
    quess = {}
    with open(args.test_ids_file, 'r') as ids_file:
        test_ids = [test_id.strip('\n') for test_id in ids_file]
    print 'No. of test ids: %d' % len(test_ids)
    quess_rand = defaultdict(list)
    for fname in os.listdir(args.ques_dir):
        with open(os.path.join(args.ques_dir, fname), 'r') as f:
            ques_id = fname[:-4]
            asin, q_no = ques_id.split('_')
            ques = f.readline().strip('\n')
            quess_rand[asin].append((ques, q_no))

    for asin in quess_rand:
        quess[asin] = [None]*len(quess_rand[asin])
        for (ques, q_no) in quess_rand[asin]:
            q_no = int(q_no)-1
            quess[asin][q_no] = ques

    brand_info = get_brand_info(args.metadata_fname)
    sim_prod = get_sim_docs(args.sim_prod_fname, brand_info)
    create_lucene_preds(test_ids, args, quess, sim_prod)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--ques_dir", type = str)
    argparser.add_argument("--sim_prod_fname", type = str)
    argparser.add_argument("--test_ids_file", type = str)
    argparser.add_argument("--lucene_pred_fname", type = str)
    argparser.add_argument("--metadata_fname", type = str)
    args = argparser.parse_args()
    print args
    print ""
    read_data(args)

--------------------------------------------------------------------------------
/src/lucene/create_stackexchange_lucene_baseline.py:
--------------------------------------------------------------------------------
import argparse
import sys, os
from collections import defaultdict
import csv
import math
import pdb
import random

def read_data(args):
    """Write one random candidate question per test post to args.lucene_output_file.

    Candidates are columns 2..10 of each qa_data row (column 1, the true
    question, is skipped).  A test id with no candidates gets an empty line
    and a warning on stderr instead of crashing the run.
    """
    print("Reading lines...")
    # NOTE(review): posts is built but not used by the selection below.
    posts = {}
    question_candidates = {}
    with open(args.post_data_tsvfile, 'rb') as tsvfile:
        post_reader = csv.reader(tsvfile, delimiter='\t')
        next(post_reader, None)  # skip header row
        for row in post_reader:
            post_id, title, post = row
            post = title + ' ' + post
            post = post.lower().strip()
            posts[post_id] = post

    with open(args.qa_data_tsvfile, 'rb') as tsvfile:
        qa_reader = csv.reader(tsvfile, delimiter='\t')
        next(qa_reader, None)  # skip header row
        for row in qa_reader:
            # Ignore the first question since that is the true question
            post_id, questions = row[0], row[2:11]
            question_candidates[post_id] = questions

    with open(args.test_ids_file, 'r') as ids_file:
        test_ids = [test_id.strip('\n') for test_id in ids_file]

    with open(args.lucene_output_file, 'w') as lucene_out_file:
        for test_id in test_ids:
            candidates = question_candidates.get(test_id, [])
            if candidates:
                # random.choice adapts to rows with fewer than 9 candidates,
                # where the old randint(0, 8) index could raise IndexError.
                lucene_out_file.write(random.choice(candidates) + '\n')
            else:
                sys.stderr.write('No candidate questions for %s\n' % test_id)
                lucene_out_file.write('\n')

if __name__ == "__main__":
    argparser = argparse.ArgumentParser(sys.argv[0])
    argparser.add_argument("--post_data_tsvfile", type = str)
    argparser.add_argument("--qa_data_tsvfile", type = str)
    argparser.add_argument("--test_ids_file", type = str)
    argparser.add_argument("--lucene_output_file", type = str)
    args = argparser.parse_args()
    print args
    print ""
    read_data(args)


--------------------------------------------------------------------------------
/src/lucene/run_create_amazon_lucene_baseline.sh:
--------------------------------------------------------------------------------
#!/bin/bash

#SBATCH --job-name=lucene_Home_and_Kitchen_test
#SBATCH --output=lucene_Home_and_Kitchen_test
#SBATCH --qos=batch
#SBATCH --mem=4g
#SBATCH --time=4:00:00

SITENAME=Home_and_Kitchen
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/lucene
CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation_pytorch/$SITENAME

python $SCRIPT_DIR/create_amazon_lucene_baseline.py --ques_dir $CQ_DATA_DIR/ques_docs \
    --sim_prod_fname $CQ_DATA_DIR/lucene_similar_prods.txt \
    --test_ids_file $CQ_DATA_DIR/blind_test_pred_ques.txt.seq2seq.epoch100.beam0.ids \
    --lucene_pred_fname $CQ_DATA_DIR/blind_test_pred_question.lucene.txt \
    --metadata_fname $CQ_DATA_DIR/meta_Home_and_Kitchen.json.gz \


--------------------------------------------------------------------------------
/src/lucene/run_create_stackexchange_lucene_baseline.sh:
--------------------------------------------------------------------------------
#!/bin/bash

SITENAME=askubuntu_unix_superuser
SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/lucene
CQ_DATA_DIR=/fs/clip-amr/clarification_question_generation/$SITENAME

python $SCRIPT_DIR/create_stackexchange_lucene_baseline.py --post_data_tsvfile $CQ_DATA_DIR/post_data.tsv \
    --qa_data_tsvfile $CQ_DATA_DIR/qa_data.tsv \
    --test_ids_file $CQ_DATA_DIR/test_ids \
    --lucene_output_file $CQ_DATA_DIR/test_pred_question_lucene.txt \


--------------------------------------------------------------------------------
/src/run_GAN_decode.sh:
--------------------------------------------------------------------------------
#!/bin/bash

SITENAME=askubuntu_unix_superuser

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

python $SCRIPT_DIR/decode.py --test_context $CQ_DATA_DIR/test_context.txt \
    --test_ques $CQ_DATA_DIR/test_question.txt \
    --test_ans $CQ_DATA_DIR/test_answer.txt \
    --test_ids $CQ_DATA_DIR/test_ids \
    --test_pred_ques $CQ_DATA_DIR/test_pred_question.txt \
    --q_encoder_params $CQ_DATA_DIR/q_encoder_params.epoch100.GAN_selfcritic_pred_ans.epoch12 \
    --q_decoder_params $CQ_DATA_DIR/q_decoder_params.epoch100.GAN_selfcritic_pred_ans.epoch12 \
    --word_embeddings $EMB_DIR/word_embeddings.p \
    --vocab $EMB_DIR/vocab.p \
    --model GAN.epoch12 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 256 \
    --beam True \


--------------------------------------------------------------------------------
/src/run_GAN_decode_HK.sh:
--------------------------------------------------------------------------------
#!/bin/bash

SITENAME=Home_and_Kitchen

CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME
SCRIPT_DIR=clarification_question_generation_pytorch/src
EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME

PARAMS_DIR=$CQ_DATA_DIR

python $SCRIPT_DIR/decode.py --test_context $CQ_DATA_DIR/test_context.txt \
    --test_ques $CQ_DATA_DIR/test_ques.txt \
    --test_ans $CQ_DATA_DIR/test_ans.txt \
    --test_ids $CQ_DATA_DIR/test_asin.txt \
    --test_pred_ques $CQ_DATA_DIR/blind_test_pred_ques.txt \
    --q_encoder_params $PARAMS_DIR/q_encoder_params.epoch100.GAN_selfcritic_pred_ans.epoch12 \
    --q_decoder_params $PARAMS_DIR/q_decoder_params.epoch100.GAN_selfcritic_pred_ans.epoch12 \
    --word_embeddings $EMB_DIR/word_embeddings.p \
    --vocab $EMB_DIR/vocab.p \
    --model GAN_selfcritic_pred_ans.epoch12 \
    --max_post_len 100 \
    --max_ques_len 20 \
    --max_ans_len 20 \
    --batch_size 128 \
    --beam True


--------------------------------------------------------------------------------
/src/run_GAN_main.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=askubuntu_unix_superuser 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | python $SCRIPT_DIR/GAN_main.py --train_context $CQ_DATA_DIR/train_context.txt \ 10 | --train_ques $CQ_DATA_DIR/train_question.txt \ 11 | --train_ans $CQ_DATA_DIR/train_answer.txt \ 12 | --tune_context $CQ_DATA_DIR/tune_context.txt \ 13 | --tune_ques $CQ_DATA_DIR/tune_question.txt \ 14 | --tune_ans $CQ_DATA_DIR/tune_answer.txt \ 15 | --test_context $CQ_DATA_DIR/test_context.txt \ 16 | --test_ques $CQ_DATA_DIR/test_question.txt \ 17 | --test_ans $CQ_DATA_DIR/test_answer.txt \ 18 | --test_pred_ques $CQ_DATA_DIR/GAN_test_pred_question.txt \ 19 | --q_encoder_params $CQ_DATA_DIR/q_encoder_params.epoch100 \ 20 | --q_decoder_params $CQ_DATA_DIR/q_decoder_params.epoch100 \ 21 | --a_encoder_params $CQ_DATA_DIR/a_encoder_params.epoch100 \ 22 | --a_decoder_params $CQ_DATA_DIR/a_decoder_params.epoch100 \ 23 | --context_params $CQ_DATA_DIR/context_params.epoch10 \ 24 | --question_params $CQ_DATA_DIR/question_params.epoch10 \ 25 | --answer_params $CQ_DATA_DIR/answer_params.epoch10 \ 26 | --utility_params $CQ_DATA_DIR/utility_params.epoch10 \ 27 | --word_embeddings $EMB_DIR/word_embeddings.p \ 28 | --vocab $EMB_DIR/vocab.p \ 29 | --model GAN_selfcritic_pred_ans \ 30 | --max_post_len 100 \ 31 | --max_ques_len 20 \ 32 | --max_ans_len 20 \ 33 | --batch_size 256 \ 34 | --n_epochs 40 \ 35 | 36 | -------------------------------------------------------------------------------- /src/run_GAN_main_HK.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=Home_and_Kitchen 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | 
SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | PARAMS_DIR=$CQ_DATA_DIR 10 | 11 | python $SCRIPT_DIR/GAN_main.py --train_context $CQ_DATA_DIR/train_context.txt \ 12 | --train_ques $CQ_DATA_DIR/train_ques.txt \ 13 | --train_ans $CQ_DATA_DIR/train_ans.txt \ 14 | --train_ids $CQ_DATA_DIR/train_asin.txt \ 15 | --tune_context $CQ_DATA_DIR/tune_context.txt \ 16 | --tune_ques $CQ_DATA_DIR/tune_ques.txt \ 17 | --tune_ans $CQ_DATA_DIR/tune_ans.txt \ 18 | --tune_ids $CQ_DATA_DIR/tune_asin.txt \ 19 | --test_context $CQ_DATA_DIR/test_context.txt \ 20 | --test_ques $CQ_DATA_DIR/test_ques.txt \ 21 | --test_ans $CQ_DATA_DIR/test_ans.txt \ 22 | --test_ids $CQ_DATA_DIR/test_asin.txt \ 23 | --test_pred_ques $CQ_DATA_DIR/blind_test_pred_ques.txt \ 24 | --q_encoder_params $PARAMS_DIR/q_encoder_params.epoch100 \ 25 | --q_decoder_params $PARAMS_DIR/q_decoder_params.epoch100 \ 26 | --a_encoder_params $PARAMS_DIR/a_encoder_params.epoch100 \ 27 | --a_decoder_params $PARAMS_DIR/a_decoder_params.epoch100 \ 28 | --context_params $PARAMS_DIR/context_params.epoch10 \ 29 | --question_params $PARAMS_DIR/question_params.epoch10 \ 30 | --answer_params $PARAMS_DIR/answer_params.epoch10 \ 31 | --utility_params $PARAMS_DIR/utility_params.epoch10 \ 32 | --word_embeddings $EMB_DIR/word_embeddings.p \ 33 | --vocab $EMB_DIR/vocab.p \ 34 | --model GAN_selfcritic_pred_ans \ 35 | --max_post_len 100 \ 36 | --max_ques_len 20 \ 37 | --max_ans_len 20 \ 38 | --batch_size 64 \ 39 | --n_epochs 20 \ 40 | 41 | -------------------------------------------------------------------------------- /src/run_GAN_main_electronics.sh: -------------------------------------------------------------------------------- 1 | python src/GAN_main.py --train_context baseline_data/train_context.txt \ 2 | --train_ques baseline_data/train_question.txt \ 3 | --train_ans baseline_data/train_answer.txt \ 4 | --train_ids baseline_data/train_asin.txt \ 5 | 
--tune_context baseline_data/valid_context.txt \ 6 | --tune_ques baseline_data/valid_question.txt \ 7 | --tune_ans baseline_data/valid_answer.txt \ 8 | --tune_ids baseline_data/valid_asin.txt \ 9 | --test_context baseline_data/test_context.txt \ 10 | --test_ques baseline_data/test_question.txt \ 11 | --test_ans baseline_data/test_answer.txt \ 12 | --test_ids baseline_data/test_asin.txt \ 13 | --test_pred_ques baseline_data/gan_ques49_ans43_util10_valid_predicted_question.txt \ 14 | --q_encoder_params baseline_data/seq2seq-pretrain-ques-v3/q_encoder_params.epoch49 \ 15 | --q_decoder_params baseline_data/seq2seq-pretrain-ques-v3/q_decoder_params.epoch49 \ 16 | --a_encoder_params baseline_data/seq2seq-pretrain-ans-v3/a_encoder_params.epoch43 \ 17 | --a_decoder_params baseline_data/seq2seq-pretrain-ans-v3/a_decoder_params.epoch43 \ 18 | --context_params baseline_data/seq2seq-pretrain-util-v3/context_params.epoch10 \ 19 | --question_params baseline_data/seq2seq-pretrain-util-v3/question_params.epoch10 \ 20 | --answer_params baseline_data/seq2seq-pretrain-util-v3/answer_params.epoch10 \ 21 | --utility_params baseline_data/seq2seq-pretrain-util-v3/utility_params.epoch10 \ 22 | --word_embeddings embeddings/amazon_200d_embeddings.p \ 23 | --vocab embeddings/amazon_200d_vocab.p \ 24 | --model GAN_selfcritic_ques49_ans43_util10 \ 25 | --max_post_len 100 \ 26 | --max_ques_len 20 \ 27 | --max_ans_len 20 \ 28 | --batch_size 64 \ 29 | --n_epochs 20 \ 30 | 31 | -------------------------------------------------------------------------------- /src/run_RL_decode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=askubuntu_unix_superuser 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | python $SCRIPT_DIR/decode.py --test_context 
$CQ_DATA_DIR/test_context.txt \ 10 | --test_ques $CQ_DATA_DIR/test_question.txt \ 11 | --test_ans $CQ_DATA_DIR/test_answer.txt \ 12 | --test_ids $CQ_DATA_DIR/test_ids \ 13 | --test_pred_ques $CQ_DATA_DIR/test_pred_question.txt \ 14 | --q_encoder_params $CQ_DATA_DIR/q_encoder_params.epoch100.RL_selfcritic.epoch8 \ 15 | --q_decoder_params $CQ_DATA_DIR/q_decoder_params.epoch100.RL_selfcritic.epoch8 \ 16 | --word_embeddings $EMB_DIR/word_embeddings.p \ 17 | --vocab $EMB_DIR/vocab.p \ 18 | --model RL.epoch8 \ 19 | --max_post_len 100 \ 20 | --max_ques_len 20 \ 21 | --max_ans_len 20 \ 22 | --batch_size 256 \ 23 | --beam True \ 24 | 25 | -------------------------------------------------------------------------------- /src/run_RL_decode_HK.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=Home_and_Kitchen 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | PARAMS_DIR=$CQ_DATA_DIR 10 | 11 | python $SCRIPT_DIR/decode.py --test_context $CQ_DATA_DIR/test_context.txt \ 12 | --test_ques $CQ_DATA_DIR/test_ques.txt \ 13 | --test_ans $CQ_DATA_DIR/test_ans.txt \ 14 | --test_ids $CQ_DATA_DIR/test_asin.txt \ 15 | --test_pred_ques $CQ_DATA_DIR/blind_test_pred_ques.txt \ 16 | --q_encoder_params $PARAMS_DIR/q_encoder_params.epoch100.RL_selfcritic.epoch8 \ 17 | --q_decoder_params $PARAMS_DIR/q_decoder_params.epoch100.RL_selfcritic.epoch8 \ 18 | --word_embeddings $EMB_DIR/word_embeddings.p \ 19 | --vocab $EMB_DIR/vocab.p \ 20 | --model RL_selfcritic.epoch8 \ 21 | --max_post_len 100 \ 22 | --max_ques_len 20 \ 23 | --max_ans_len 20 \ 24 | --batch_size 128 \ 25 | --beam True 26 | 27 | -------------------------------------------------------------------------------- /src/run_RL_main.sh: -------------------------------------------------------------------------------- 1 
| #!/bin/bash 2 | 3 | SITENAME=askubuntu_unix_superuser 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | python $SCRIPT_DIR/RL_main.py --train_context $CQ_DATA_DIR/train_context.txt \ 10 | --train_ques $CQ_DATA_DIR/train_question.txt \ 11 | --train_ans $CQ_DATA_DIR/train_answer.txt \ 12 | --tune_context $CQ_DATA_DIR/tune_context.txt \ 13 | --tune_ques $CQ_DATA_DIR/tune_question.txt \ 14 | --tune_ans $CQ_DATA_DIR/tune_answer.txt \ 15 | --test_context $CQ_DATA_DIR/test_context.txt \ 16 | --test_ques $CQ_DATA_DIR/test_question.txt \ 17 | --test_ans $CQ_DATA_DIR/test_answer.txt \ 18 | --test_pred_ques $CQ_DATA_DIR/test_pred_question.txt \ 19 | --q_encoder_params $CQ_DATA_DIR/q_encoder_params.epoch100 \ 20 | --q_decoder_params $CQ_DATA_DIR/q_decoder_params.epoch100 \ 21 | --a_encoder_params $CQ_DATA_DIR/a_encoder_params.epoch100 \ 22 | --a_decoder_params $CQ_DATA_DIR/a_decoder_params.epoch100 \ 23 | --context_params $CQ_DATA_DIR/context_params.epoch10 \ 24 | --question_params $CQ_DATA_DIR/question_params.epoch10 \ 25 | --answer_params $CQ_DATA_DIR/answer_params.epoch10 \ 26 | --utility_params $CQ_DATA_DIR/utility_params.epoch10 \ 27 | --word_embeddings $EMB_DIR/word_embeddings.p \ 28 | --vocab $EMB_DIR/vocab.p \ 29 | --model RL_selfcritic \ 30 | --max_post_len 100 \ 31 | --max_ques_len 20 \ 32 | --max_ans_len 20 \ 33 | --batch_size 256 \ 34 | --n_epochs 40 \ 35 | 36 | -------------------------------------------------------------------------------- /src/run_RL_main_HK.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=Home_and_Kitchen 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 
| 9 | PARAMS_DIR=$CQ_DATA_DIR 10 | 11 | python $SCRIPT_DIR/RL_main.py --train_context $CQ_DATA_DIR/train_context.txt \ 12 | --train_ques $CQ_DATA_DIR/train_ques.txt \ 13 | --train_ans $CQ_DATA_DIR/train_ans.txt \ 14 | --train_ids $CQ_DATA_DIR/train_asin.txt \ 15 | --tune_context $CQ_DATA_DIR/tune_context.txt \ 16 | --tune_ques $CQ_DATA_DIR/tune_ques.txt \ 17 | --tune_ans $CQ_DATA_DIR/tune_ans.txt \ 18 | --tune_ids $CQ_DATA_DIR/tune_asin.txt \ 19 | --test_context $CQ_DATA_DIR/test_context.txt \ 20 | --test_ques $CQ_DATA_DIR/test_ques.txt \ 21 | --test_ans $CQ_DATA_DIR/test_ans.txt \ 22 | --test_ids $CQ_DATA_DIR/test_asin.txt \ 23 | --test_pred_ques $CQ_DATA_DIR/test_pred_ques.txt \ 24 | --q_encoder_params $PARAMS_DIR/q_encoder_params.epoch100 \ 25 | --q_decoder_params $PARAMS_DIR/q_decoder_params.epoch100 \ 26 | --a_encoder_params $PARAMS_DIR/a_encoder_params.epoch100 \ 27 | --a_decoder_params $PARAMS_DIR/a_decoder_params.epoch100 \ 28 | --context_params $PARAMS_DIR/context_params.epoch10 \ 29 | --question_params $PARAMS_DIR/question_params.epoch10 \ 30 | --answer_params $PARAMS_DIR/answer_params.epoch10 \ 31 | --utility_params $PARAMS_DIR/utility_params.epoch10 \ 32 | --word_embeddings $EMB_DIR/word_embeddings.p \ 33 | --vocab $EMB_DIR/vocab.p \ 34 | --model RL_selfcritic \ 35 | --max_post_len 100 \ 36 | --max_ques_len 20 \ 37 | --max_ans_len 20 \ 38 | --batch_size 128 \ 39 | --n_epochs 40 \ 40 | -------------------------------------------------------------------------------- /src/run_decode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=askubuntu_unix_superuser 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | python $SCRIPT_DIR/decode.py --test_context $CQ_DATA_DIR/test_context.txt \ 10 | --test_ques $CQ_DATA_DIR/test_question.txt 
\ 11 | --test_ans $CQ_DATA_DIR/test_answer.txt \ 12 | --test_ids $CQ_DATA_DIR/test_ids \ 13 | --test_pred_ques $CQ_DATA_DIR/test_pred_question.txt \ 14 | --q_encoder_params $CQ_DATA_DIR/q_encoder_params.epoch100 \ 15 | --q_decoder_params $CQ_DATA_DIR/q_decoder_params.epoch100 \ 16 | --word_embeddings $EMB_DIR/word_embeddings.p \ 17 | --vocab $EMB_DIR/vocab.p \ 18 | --model seq2seq.epoch100 \ 19 | --max_post_len 100 \ 20 | --max_ques_len 20 \ 21 | --max_ans_len 20 \ 22 | --batch_size 256 \ 23 | --beam True \ 24 | 25 | -------------------------------------------------------------------------------- /src/run_decode_HK.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=Home_and_Kitchen 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | PARAMS_DIR=$CQ_DATA_DIR 10 | 11 | python $SCRIPT_DIR/decode.py --test_context $CQ_DATA_DIR/test_context.txt \ 12 | --test_ques $CQ_DATA_DIR/test_ques.txt \ 13 | --test_ans $CQ_DATA_DIR/test_ans.txt \ 14 | --test_ids $CQ_DATA_DIR/test_asin.txt \ 15 | --test_pred_ques $CQ_DATA_DIR/blind_test_pred_ques.txt \ 16 | --q_encoder_params $PARAMS_DIR/q_encoder_params.epoch100 \ 17 | --q_decoder_params $PARAMS_DIR/q_decoder_params.epoch100 \ 18 | --word_embeddings $EMB_DIR/word_embeddings.p \ 19 | --vocab $EMB_DIR/vocab.p \ 20 | --model seq2seq.epoch12 \ 21 | --max_post_len 100 \ 22 | --max_ques_len 20 \ 23 | --max_ans_len 20 \ 24 | --batch_size 128 \ 25 | --beam True 26 | 27 | -------------------------------------------------------------------------------- /src/run_decode_electronics.sh: -------------------------------------------------------------------------------- 1 | python src/decode.py --test_context baseline_data/valid_context.txt \ 2 | --test_ques baseline_data/valid_question.txt \ 3 | --test_ans 
baseline_data/valid_answer.txt \ 4 | --test_ids baseline_data/valid_asin.txt \ 5 | --test_pred_ques baseline_data/valid_predicted_question.txt \ 6 | --q_encoder_params baseline_data/seq2seq-pretrain-ques-v3/q_encoder_params.epoch49 \ 7 | --q_decoder_params baseline_data/seq2seq-pretrain-ques-v3/q_decoder_params.epoch49 \ 8 | --word_embeddings embeddings/amazon_200d_embeddings.p \ 9 | --vocab embeddings/amazon_200d_vocab.p \ 10 | --model seq2seq.epoch49 \ 11 | --max_post_len 100 \ 12 | --max_ques_len 20 \ 13 | --max_ans_len 20 \ 14 | --batch_size 10 \ 15 | --n_epochs 40 \ 16 | --greedy True 17 | #--beam True 18 | -------------------------------------------------------------------------------- /src/run_main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=askubuntu_unix_superuser 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | PARAMS_DIR=$CQ_DATA_DIR 10 | 11 | 12 | python $SCRIPT_DIR/main.py --train_context $CQ_DATA_DIR/train_context.txt \ 13 | --train_question $CQ_DATA_DIR/train_question.txt \ 14 | --train_answer $CQ_DATA_DIR/train_answer.txt \ 15 | --train_ids $CQ_DATA_DIR/train_ids \ 16 | --tune_context $CQ_DATA_DIR/tune_context.txt \ 17 | --tune_question $CQ_DATA_DIR/tune_question.txt \ 18 | --tune_answer $CQ_DATA_DIR/tune_answer.txt \ 19 | --tune_ids $CQ_DATA_DIR/tune_ids \ 20 | --test_context $CQ_DATA_DIR/test_context.txt \ 21 | --test_question $CQ_DATA_DIR/test_question.txt \ 22 | --test_answer $CQ_DATA_DIR/test_answer.txt \ 23 | --test_ids $CQ_DATA_DIR/test_ids \ 24 | --q_encoder_params $PARAMS_DIR/q_encoder_params \ 25 | --q_decoder_params $PARAMS_DIR/q_decoder_params \ 26 | --a_encoder_params $PARAMS_DIR/a_encoder_params \ 27 | --a_decoder_params $PARAMS_DIR/a_decoder_params \ 28 | --context_params $PARAMS_DIR/context_params 
\ 29 | --question_params $PARAMS_DIR/question_params \ 30 | --answer_params $PARAMS_DIR/answer_params \ 31 | --utility_params $PARAMS_DIR/utility_params \ 32 | --word_embeddings $EMB_DIR/word_embeddings.p \ 33 | --vocab $EMB_DIR/vocab.p \ 34 | --n_epochs 100 \ 35 | --max_post_len 100 \ 36 | --max_ques_len 20 \ 37 | --max_ans_len 20 \ 38 | --pretrain_ques True \ 39 | 40 | -------------------------------------------------------------------------------- /src/run_main_HK.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=Home_and_Kitchen 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | PARAMS_DIR=$CQ_DATA_DIR 10 | 11 | python $SCRIPT_DIR/main.py --train_context $CQ_DATA_DIR/train_context.txt \ 12 | --train_ques $CQ_DATA_DIR/train_ques.txt \ 13 | --train_ans $CQ_DATA_DIR/train_ans.txt \ 14 | --train_ids $CQ_DATA_DIR/train_asin.txt \ 15 | --tune_context $CQ_DATA_DIR/tune_context.txt \ 16 | --tune_ques $CQ_DATA_DIR/tune_ques.txt \ 17 | --tune_ans $CQ_DATA_DIR/tune_ans.txt \ 18 | --tune_ids $CQ_DATA_DIR/tune_asin.txt \ 19 | --test_context $CQ_DATA_DIR/test_context.txt \ 20 | --test_ques $CQ_DATA_DIR/test_ques.txt \ 21 | --test_ans $CQ_DATA_DIR/test_ans.txt \ 22 | --test_ids $CQ_DATA_DIR/test_asin.txt \ 23 | --q_encoder_params $PARAMS_DIR/q_encoder_params \ 24 | --q_decoder_params $PARAMS_DIR/q_decoder_params \ 25 | --a_encoder_params $PARAMS_DIR/a_encoder_params \ 26 | --a_decoder_params $PARAMS_DIR/a_decoder_params \ 27 | --context_params $PARAMS_DIR/context_params \ 28 | --question_params $PARAMS_DIR/question_params \ 29 | --answer_params $PARAMS_DIR/answer_params \ 30 | --utility_params $PARAMS_DIR/utility_params \ 31 | --word_embeddings $EMB_DIR/word_embeddings.p \ 32 | --vocab $EMB_DIR/vocab.p \ 33 | --n_epochs 100 \ 34 | --max_post_len 
100 \ 35 | --max_ques_len 20 \ 36 | --max_ans_len 20 \ 37 | --pretrain_ques True \ 38 | #--pretrain_ans True \ 39 | #--pretrain_util True \ 40 | 41 | 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /src/run_main_electronics.sh: -------------------------------------------------------------------------------- 1 | python src/main.py --train_context baseline_data/train_context.txt \ 2 | --train_ques baseline_data/train_question.txt \ 3 | --train_ans baseline_data/train_answer.txt \ 4 | --train_ids baseline_data/train_asin.txt \ 5 | --tune_context baseline_data/valid_context.txt \ 6 | --tune_ques baseline_data/valid_question.txt \ 7 | --tune_ans baseline_data/valid_answer.txt \ 8 | --tune_ids baseline_data/valid_asin.txt \ 9 | --test_context baseline_data/test_context.txt \ 10 | --test_ques baseline_data/test_question.txt \ 11 | --test_ans baseline_data/test_answer.txt \ 12 | --test_ids baseline_data/test_asin.txt \ 13 | --q_encoder_params $PT_OUTPUT_DIR/q_encoder_params \ 14 | --q_decoder_params $PT_OUTPUT_DIR/q_decoder_params \ 15 | --a_encoder_params $PT_OUTPUT_DIR/a_encoder_params \ 16 | --a_decoder_params $PT_OUTPUT_DIR/a_decoder_params \ 17 | --context_params $PT_OUTPUT_DIR/context_params \ 18 | --question_params $PT_OUTPUT_DIR/question_params \ 19 | --answer_params $PT_OUTPUT_DIR/answer_params \ 20 | --utility_params $PT_OUTPUT_DIR/utility_params \ 21 | --word_embeddings embeddings/amazon_200d_embeddings.p \ 22 | --vocab embeddings/amazon_200d_vocab.p \ 23 | --n_epochs 100 \ 24 | --batch_size 128 \ 25 | --max_post_len 100 \ 26 | --max_ques_len 20 \ 27 | --max_ans_len 20 \ 28 | --pretrain_ques True \ 29 | #--pretrain_ans True \ 30 | #--pretrain_util True \ 31 | -------------------------------------------------------------------------------- /src/run_pretrain_ans.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=askubuntu_unix_superuser 4 | 
5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | PARAMS_DIR=$CQ_DATA_DIR 10 | 11 | 12 | python $SCRIPT_DIR/main.py --train_context $CQ_DATA_DIR/train_context.txt \ 13 | --train_question $CQ_DATA_DIR/train_question.txt \ 14 | --train_answer $CQ_DATA_DIR/train_answer.txt \ 15 | --train_ids $CQ_DATA_DIR/train_ids \ 16 | --tune_context $CQ_DATA_DIR/tune_context.txt \ 17 | --tune_question $CQ_DATA_DIR/tune_question.txt \ 18 | --tune_answer $CQ_DATA_DIR/tune_answer.txt \ 19 | --tune_ids $CQ_DATA_DIR/tune_ids \ 20 | --test_context $CQ_DATA_DIR/test_context.txt \ 21 | --test_question $CQ_DATA_DIR/test_question.txt \ 22 | --test_answer $CQ_DATA_DIR/test_answer.txt \ 23 | --test_ids $CQ_DATA_DIR/test_ids \ 24 | --q_encoder_params $PARAMS_DIR/q_encoder_params \ 25 | --q_decoder_params $PARAMS_DIR/q_decoder_params \ 26 | --a_encoder_params $PARAMS_DIR/a_encoder_params \ 27 | --a_decoder_params $PARAMS_DIR/a_decoder_params \ 28 | --context_params $PARAMS_DIR/context_params \ 29 | --question_params $PARAMS_DIR/question_params \ 30 | --answer_params $PARAMS_DIR/answer_params \ 31 | --utility_params $PARAMS_DIR/utility_params \ 32 | --word_embeddings $EMB_DIR/word_embeddings.p \ 33 | --vocab $EMB_DIR/vocab.p \ 34 | --n_epochs 100 \ 35 | --max_post_len 100 \ 36 | --max_ques_len 20 \ 37 | --max_ans_len 20 \ 38 | --pretrain_ans True \ 39 | 40 | -------------------------------------------------------------------------------- /src/run_pretrain_ans_HK.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=Home_and_Kitchen 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | PARAMS_DIR=$CQ_DATA_DIR 10 | 11 
| python $SCRIPT_DIR/main.py --train_context $CQ_DATA_DIR/train_context.txt \ 12 | --train_ques $CQ_DATA_DIR/train_ques.txt \ 13 | --train_ans $CQ_DATA_DIR/train_ans.txt \ 14 | --train_ids $CQ_DATA_DIR/train_asin.txt \ 15 | --tune_context $CQ_DATA_DIR/tune_context.txt \ 16 | --tune_ques $CQ_DATA_DIR/tune_ques.txt \ 17 | --tune_ans $CQ_DATA_DIR/tune_ans.txt \ 18 | --tune_ids $CQ_DATA_DIR/tune_asin.txt \ 19 | --test_context $CQ_DATA_DIR/test_context.txt \ 20 | --test_ques $CQ_DATA_DIR/test_ques.txt \ 21 | --test_ans $CQ_DATA_DIR/test_ans.txt \ 22 | --test_ids $CQ_DATA_DIR/test_asin.txt \ 23 | --q_encoder_params $PARAMS_DIR/q_encoder_params \ 24 | --q_decoder_params $PARAMS_DIR/q_decoder_params \ 25 | --a_encoder_params $PARAMS_DIR/a_encoder_params \ 26 | --a_decoder_params $PARAMS_DIR/a_decoder_params \ 27 | --context_params $PARAMS_DIR/context_params \ 28 | --question_params $PARAMS_DIR/question_params \ 29 | --answer_params $PARAMS_DIR/answer_params \ 30 | --utility_params $PARAMS_DIR/utility_params \ 31 | --word_embeddings $EMB_DIR/word_embeddings.p \ 32 | --vocab $EMB_DIR/vocab.p \ 33 | --n_epochs 100 \ 34 | --max_post_len 100 \ 35 | --max_ques_len 20 \ 36 | --max_ans_len 20 \ 37 | --pretrain_ans True \ 38 | -------------------------------------------------------------------------------- /src/run_pretrain_util.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=askubuntu_unix_superuser 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | PARAMS_DIR=$CQ_DATA_DIR 10 | 11 | 12 | python $SCRIPT_DIR/main.py --train_context $CQ_DATA_DIR/train_context.txt \ 13 | --train_question $CQ_DATA_DIR/train_question.txt \ 14 | --train_answer $CQ_DATA_DIR/train_answer.txt \ 15 | --train_ids $CQ_DATA_DIR/train_ids \ 16 | --tune_context 
$CQ_DATA_DIR/tune_context.txt \ 17 | --tune_question $CQ_DATA_DIR/tune_question.txt \ 18 | --tune_answer $CQ_DATA_DIR/tune_answer.txt \ 19 | --tune_ids $CQ_DATA_DIR/tune_ids \ 20 | --test_context $CQ_DATA_DIR/test_context.txt \ 21 | --test_question $CQ_DATA_DIR/test_question.txt \ 22 | --test_answer $CQ_DATA_DIR/test_answer.txt \ 23 | --test_ids $CQ_DATA_DIR/test_ids \ 24 | --q_encoder_params $PARAMS_DIR/q_encoder_params \ 25 | --q_decoder_params $PARAMS_DIR/q_decoder_params \ 26 | --a_encoder_params $PARAMS_DIR/a_encoder_params \ 27 | --a_decoder_params $PARAMS_DIR/a_decoder_params \ 28 | --context_params $PARAMS_DIR/context_params \ 29 | --question_params $PARAMS_DIR/question_params \ 30 | --answer_params $PARAMS_DIR/answer_params \ 31 | --utility_params $PARAMS_DIR/utility_params \ 32 | --word_embeddings $EMB_DIR/word_embeddings.p \ 33 | --vocab $EMB_DIR/vocab.p \ 34 | --n_epochs 10 \ 35 | --max_post_len 100 \ 36 | --max_ques_len 20 \ 37 | --max_ans_len 20 \ 38 | --pretrain_util True \ 39 | 40 | -------------------------------------------------------------------------------- /src/run_pretrain_util_HK.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SITENAME=Home_and_Kitchen 4 | 5 | CQ_DATA_DIR=clarification_question_generation_pytorch/$SITENAME 6 | SCRIPT_DIR=clarification_question_generation_pytorch/src 7 | EMB_DIR=clarification_question_generation_pytorch/embeddings/$SITENAME 8 | 9 | PARAMS_DIR=$CQ_DATA_DIR 10 | 11 | python $SCRIPT_DIR/main.py --train_context $CQ_DATA_DIR/train_context.txt \ 12 | --train_ques $CQ_DATA_DIR/train_ques.txt \ 13 | --train_ans $CQ_DATA_DIR/train_ans.txt \ 14 | --train_ids $CQ_DATA_DIR/train_asin.txt \ 15 | --tune_context $CQ_DATA_DIR/tune_context.txt \ 16 | --tune_ques $CQ_DATA_DIR/tune_ques.txt \ 17 | --tune_ans $CQ_DATA_DIR/tune_ans.txt \ 18 | --tune_ids $CQ_DATA_DIR/tune_asin.txt \ 19 | --test_context $CQ_DATA_DIR/test_context.txt \ 20 | --test_ques 
$CQ_DATA_DIR/test_ques.txt \ 21 | --test_ans $CQ_DATA_DIR/test_ans.txt \ 22 | --test_ids $CQ_DATA_DIR/test_asin.txt \ 23 | --q_encoder_params $PARAMS_DIR/q_encoder_params \ 24 | --q_decoder_params $PARAMS_DIR/q_decoder_params \ 25 | --a_encoder_params $PARAMS_DIR/a_encoder_params \ 26 | --a_decoder_params $PARAMS_DIR/a_decoder_params \ 27 | --context_params $PARAMS_DIR/context_params \ 28 | --question_params $PARAMS_DIR/question_params \ 29 | --answer_params $PARAMS_DIR/answer_params \ 30 | --utility_params $PARAMS_DIR/utility_params \ 31 | --word_embeddings $EMB_DIR/word_embeddings.p \ 32 | --vocab $EMB_DIR/vocab.p \ 33 | --n_epochs 10 \ 34 | --max_post_len 100 \ 35 | --max_ques_len 20 \ 36 | --max_ans_len 20 \ 37 | --pretrain_util True \ 38 | -------------------------------------------------------------------------------- /src/seq2seq/RL_inference.py: -------------------------------------------------------------------------------- 1 | from constants import * 2 | from masked_cross_entropy import * 3 | import numpy as np 4 | import random 5 | import torch 6 | from torch.autograd import Variable 7 | 8 | def get_decoded_seqs(decoder_outputs, word2index, max_len, batch_size): 9 | decoded_seqs = [] 10 | decoded_lens = [] 11 | decoded_seq_masks = [] 12 | for b in range(batch_size): 13 | decoded_seq = [] 14 | decoded_seq_mask = [0]*max_len 15 | log_prob = 0. 
16 | for t in range(max_len): 17 | topv, topi = decoder_outputs[t][b].data.topk(1) 18 | ni = topi[0].item() 19 | if ni == word2index[EOS_token]: 20 | decoded_seq.append(ni) 21 | break 22 | else: 23 | decoded_seq.append(ni) 24 | decoded_seq_mask[t] = 1 25 | decoded_lens.append(len(decoded_seq)) 26 | decoded_seq += [word2index[PAD_token]]*int(max_len - len(decoded_seq)) 27 | decoded_seqs.append(decoded_seq) 28 | decoded_seq_masks.append(decoded_seq_mask) 29 | decoded_lens = np.array(decoded_lens) 30 | decoded_seqs = np.array(decoded_seqs) 31 | return decoded_seqs, decoded_lens, decoded_seq_masks 32 | 33 | def inference(input_batches, input_lens, target_batches, target_lens, \ 34 | encoder, decoder, word2index, max_len, batch_size): 35 | 36 | input_batches = Variable(torch.LongTensor(np.array(input_batches))).transpose(0, 1) 37 | target_batches = Variable(torch.LongTensor(np.array(target_batches))).transpose(0, 1) 38 | 39 | decoder_input = Variable(torch.LongTensor([word2index[SOS_token]] * batch_size)) 40 | decoder_outputs = Variable(torch.zeros(max_len, batch_size, decoder.output_size)) 41 | 42 | if USE_CUDA: 43 | input_batches = input_batches.cuda() 44 | target_batches = target_batches.cuda() 45 | decoder_input = decoder_input.cuda() 46 | decoder_outputs = decoder_outputs.cuda() 47 | 48 | # Run post words through encoder 49 | encoder_outputs, encoder_hidden = encoder(input_batches, input_lens, None) 50 | 51 | # Prepare input and output variables 52 | decoder_hidden = encoder_hidden[:decoder.n_layers] + encoder_hidden[decoder.n_layers:] 53 | 54 | # Run through decoder one time step at a time 55 | for t in range(max_len): 56 | decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs) 57 | decoder_outputs[t] = decoder_output 58 | 59 | # Without teacher forcing 60 | for b in range(batch_size): 61 | topv, topi = decoder_output[b].topk(1) 62 | decoder_input[b] = topi.squeeze().detach() 63 | 64 | decoded_seqs, decoded_lens, decoded_seq_masks 
= get_decoded_seqs(decoder_outputs, word2index, max_len, batch_size) 65 | 66 | # Loss calculation and backpropagation 67 | #loss = masked_cross_entropy( 68 | # decoder_outputs.transpose(0, 1).contiguous(), # -> batch x seq 69 | # target_batches.transpose(0, 1).contiguous(), # -> batch x seq 70 | # target_lens 71 | #) 72 | 73 | log_probs = calculate_log_probs( 74 | decoder_outputs.transpose(0, 1).contiguous(), # -> batch x seq 75 | decoded_seq_masks, 76 | ) 77 | 78 | return log_probs, decoded_seqs, decoded_lens 79 | -------------------------------------------------------------------------------- /src/seq2seq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/clarification_question_generation_pytorch/23ae8aa0160eee70565751f4b6de13563a19d6ed/src/seq2seq/__init__.py -------------------------------------------------------------------------------- /src/seq2seq/ans_train.py: -------------------------------------------------------------------------------- 1 | from constants import * 2 | from masked_cross_entropy import * 3 | import numpy as np 4 | import random 5 | import torch 6 | from torch.autograd import Variable 7 | from constants import * 8 | 9 | 10 | def train(input_batches, input_lens, target_batches, target_lens, 11 | encoder, decoder, encoder_optimizer, decoder_optimizer, 12 | word2index, args, mode='train'): 13 | if mode == 'train': 14 | encoder.train() 15 | decoder.train() 16 | 17 | # Zero gradients of both optimizers 18 | encoder_optimizer.zero_grad() 19 | decoder_optimizer.zero_grad() 20 | 21 | if USE_CUDA: 22 | input_batches = Variable(torch.LongTensor(np.array(input_batches)).cuda()).transpose(0, 1) 23 | target_batches = Variable(torch.LongTensor(np.array(target_batches)).cuda()).transpose(0, 1) 24 | else: 25 | input_batches = Variable(torch.LongTensor(np.array(input_batches))).transpose(0, 1) 26 | target_batches = 
Variable(torch.LongTensor(np.array(target_batches))).transpose(0, 1) 27 | 28 | # Run post words through encoder 29 | encoder_outputs, encoder_hidden = encoder(input_batches, input_lens, None) 30 | 31 | # Prepare input and output variables 32 | decoder_hidden = encoder_hidden[:decoder.n_layers] + encoder_hidden[decoder.n_layers:] 33 | 34 | if USE_CUDA: 35 | decoder_input = Variable(torch.LongTensor([word2index[SOS_token]] * args.batch_size).cuda()) 36 | all_decoder_outputs = Variable(torch.zeros(args.max_ans_len, args.batch_size, decoder.output_size).cuda()) 37 | decoder_outputs = Variable(torch.zeros(args.max_ans_len, args.batch_size).cuda()) 38 | else: 39 | decoder_input = Variable(torch.LongTensor([word2index[SOS_token]] * args.batch_size)) 40 | all_decoder_outputs = Variable(torch.zeros(args.max_ans_len, args.batch_size, decoder.output_size)) 41 | decoder_outputs = Variable(torch.zeros(args.max_ans_len, args.batch_size)) 42 | 43 | # Run through decoder one time step at a time 44 | for t in range(args.max_ans_len): 45 | decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs) 46 | all_decoder_outputs[t] = decoder_output 47 | 48 | # Teacher Forcing 49 | decoder_input = target_batches[t] # Next input is current target 50 | decoder_outputs[t] = target_batches[t] 51 | 52 | # # Greeding 53 | # topv, topi = decoder_output.data.topk(1) 54 | # decoder_outputs[t] = topi.squeeze(1) 55 | 56 | decoded_seqs = [] 57 | decoded_lens = [] 58 | for b in range(args.batch_size): 59 | decoded_seq = [] 60 | for t in range(args.max_ans_len): 61 | topi = decoder_outputs[t][b].data 62 | idx = int(topi.item()) 63 | if idx == word2index[EOS_token]: 64 | decoded_seq.append(idx) 65 | break 66 | else: 67 | decoded_seq.append(idx) 68 | decoded_lens.append(len(decoded_seq)) 69 | decoded_seq += [word2index[PAD_token]] * (args.max_ans_len - len(decoded_seq)) 70 | decoded_seqs.append(decoded_seq) 71 | 72 | loss_fn = torch.nn.NLLLoss() 73 | # Loss calculation and 
backpropagation 74 | loss = masked_cross_entropy( 75 | all_decoder_outputs.transpose(0, 1).contiguous(), # -> batch x seq 76 | target_batches.transpose(0, 1).contiguous(), # -> batch x seq 77 | target_lens, loss_fn, args.max_ans_len 78 | ) 79 | if mode == 'train': 80 | loss.backward() 81 | encoder_optimizer.step() 82 | decoder_optimizer.step() 83 | 84 | return loss, decoded_seqs, decoded_lens 85 | -------------------------------------------------------------------------------- /src/seq2seq/attn.py: -------------------------------------------------------------------------------- 1 | from constants import * 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import torch.nn.functional as F 6 | 7 | class Attn(nn.Module): 8 | def __init__(self, hidden_size): 9 | super(Attn, self).__init__() 10 | self.hidden_size = hidden_size 11 | 12 | def forward(self, hidden, encoder_outputs): 13 | max_len = encoder_outputs.size(0) 14 | this_batch_size = encoder_outputs.size(1) 15 | 16 | # Create variable to store attention energies 17 | attn_energies = Variable(torch.zeros(this_batch_size, max_len)) # B x S 18 | if USE_CUDA: 19 | attn_energies = attn_energies.cuda() 20 | attn_energies = torch.bmm(hidden.transpose(0, 1), encoder_outputs.transpose(0,1).transpose(1,2)).squeeze(1) 21 | # Normalize energies to weights in range 0 to 1, resize to 1 x B x S 22 | return F.softmax(attn_energies, dim=1).unsqueeze(1) 23 | -------------------------------------------------------------------------------- /src/seq2seq/attnDecoderRNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | sys.path.append('src/seq2seq') 5 | from attn import * 6 | 7 | 8 | class AttnDecoderRNN(nn.Module): 9 | def __init__(self, hidden_size, output_size, word_embeddings, n_layers=1, dropout=0.1): 10 | super(AttnDecoderRNN, self).__init__() 11 | 12 | # Keep for reference 13 | self.hidden_size = 
hidden_size 14 | self.output_size = output_size 15 | self.n_layers = n_layers 16 | self.dropout = dropout 17 | 18 | # Define layers 19 | self.embedding = nn.Embedding(len(word_embeddings), len(word_embeddings[0])) 20 | self.embedding.weight.data.copy_(torch.from_numpy(word_embeddings)) 21 | self.embedding.weight.requires_grad = False 22 | self.embedding_dropout = nn.Dropout(dropout) 23 | self.gru = nn.GRU(len(word_embeddings[0]), hidden_size, n_layers, dropout=dropout) 24 | self.concat = nn.Linear(hidden_size * 2, hidden_size) 25 | self.out = nn.Linear(hidden_size, output_size) 26 | 27 | self.attn = Attn(hidden_size) 28 | 29 | def forward(self, input_seq, last_hidden, p_encoder_outputs): 30 | # Note: we run this one step at a time 31 | 32 | # Get the embedding of the current input word (last output word) 33 | embedded = self.embedding(input_seq) 34 | embedded = self.embedding_dropout(embedded) 35 | embedded = embedded.view(1, embedded.shape[0], embedded.shape[1]) # S=1 x B x N 36 | 37 | # Get current hidden state from input word and last hidden state 38 | rnn_output, hidden = self.gru(embedded, last_hidden) 39 | 40 | # Calculate attention from current RNN state and all p_encoder outputs; 41 | # apply to p_encoder outputs to get weighted average 42 | p_attn_weights = self.attn(rnn_output, p_encoder_outputs) 43 | p_context = p_attn_weights.bmm(p_encoder_outputs.transpose(0, 1)) # B x S=1 x N 44 | 45 | # Attentional vector using the RNN hidden state and context vector 46 | # concatenated together (Luong eq. 5) 47 | rnn_output = rnn_output.squeeze(0) # S=1 x B x N -> B x N 48 | 49 | p_context = p_context.squeeze(1) # B x S=1 x N -> B x N 50 | concat_input = torch.cat((rnn_output, p_context), 1) 51 | concat_output = F.tanh(self.concat(concat_input)) 52 | 53 | # Finally predict next token (Luong eq. 
6, without softmax) 54 | output = self.out(concat_output) 55 | 56 | # Return final output, hidden state, and attention weights (for visualization) 57 | return output, hidden 58 | -------------------------------------------------------------------------------- /src/seq2seq/baselineFF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from constants import * 4 | 5 | class BaselineFF(nn.Module): 6 | def __init__(self, input_dim): 7 | super(BaselineFF, self).__init__() 8 | 9 | self.layer1 = nn.Linear(input_dim, HIDDEN_SIZE) 10 | self.relu = nn.ReLU() 11 | self.layer2 = nn.Linear(HIDDEN_SIZE, 1) 12 | self.sigmoid = nn.Sigmoid() 13 | 14 | def forward(self, x): 15 | x = self.layer1(x) 16 | x = self.relu(x) 17 | x = self.layer2(x) 18 | x = self.sigmoid(x) 19 | return x 20 | 21 | -------------------------------------------------------------------------------- /src/seq2seq/encoderRNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class EncoderRNN(nn.Module): 5 | def __init__(self, hidden_size, word_embeddings, n_layers=1, dropout=0.1): 6 | super(EncoderRNN, self).__init__() 7 | 8 | self.hidden_size = hidden_size 9 | self.n_layers = n_layers 10 | self.dropout = dropout 11 | 12 | self.embedding = nn.Embedding(len(word_embeddings), len(word_embeddings[0])) 13 | self.embedding.weight.data.copy_(torch.from_numpy(word_embeddings)) 14 | self.embedding.weight.requires_grad = False 15 | self.gru = nn.GRU(len(word_embeddings[0]), hidden_size, n_layers, dropout=self.dropout, bidirectional=True) 16 | 17 | def forward(self, input_seqs, input_lengths, hidden=None): 18 | # Note: we run this all at once (over multiple batches of multiple sequences) 19 | embedded = self.embedding(input_seqs) 20 | #packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths) 21 | #outputs, hidden = self.gru(packed, hidden) 22 | 
#outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs) # unpack (back to padded) 23 | outputs, hidden = self.gru(embedded, hidden) 24 | outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:] # Sum bidirectional outputs 25 | return outputs, hidden 26 | -------------------------------------------------------------------------------- /src/seq2seq/evaluate.py: -------------------------------------------------------------------------------- 1 | import random 2 | from constants import * 3 | from prepare_data import * 4 | from masked_cross_entropy import * 5 | from helper import * 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | 10 | 11 | def evaluate(word2index, index2word, encoder, decoder, test_data, 12 | max_output_length, BATCH_SIZE, out_file): 13 | ids_seqs, input_seqs, input_lens, output_seqs, output_lens = test_data 14 | total_loss = 0. 15 | n_batches = len(input_seqs) / BATCH_SIZE 16 | 17 | for ids_seqs_batch, input_seqs_batch, input_lens_batch, output_seqs_batch, output_lens_batch in \ 18 | iterate_minibatches(ids_seqs, input_seqs, input_lens, output_seqs, output_lens, BATCH_SIZE): 19 | 20 | if USE_CUDA: 21 | input_seqs_batch = Variable(torch.LongTensor(np.array(input_seqs_batch)).cuda()).transpose(0, 1) 22 | output_seqs_batch = Variable(torch.LongTensor(np.array(output_seqs_batch)).cuda()).transpose(0, 1) 23 | else: 24 | input_seqs_batch = Variable(torch.LongTensor(np.array(input_seqs_batch))).transpose(0, 1) 25 | output_seqs_batch = Variable(torch.LongTensor(np.array(output_seqs_batch))).transpose(0, 1) 26 | 27 | # Run post words through encoder 28 | encoder_outputs, encoder_hidden = encoder(input_seqs_batch, input_lens_batch, None) 29 | # Create starting vectors for decoder 30 | decoder_input = Variable(torch.LongTensor([word2index[SOS_token]] * BATCH_SIZE), volatile=True) 31 | decoder_hidden = encoder_hidden[:decoder.n_layers] + encoder_hidden[decoder.n_layers:] 32 | 
all_decoder_outputs = Variable(torch.zeros(max_output_length, BATCH_SIZE, decoder.output_size)) 33 | 34 | if USE_CUDA: 35 | decoder_input = decoder_input.cuda() 36 | all_decoder_outputs = all_decoder_outputs.cuda() 37 | 38 | # Run through decoder one time step at a time 39 | for t in range(max_output_length): 40 | decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs) 41 | all_decoder_outputs[t] = decoder_output 42 | # Choose top word from output 43 | topv, topi = decoder_output.data.topk(1) 44 | decoder_input = topi.squeeze(1) 45 | if out_file: 46 | for b in range(BATCH_SIZE): 47 | decoded_words = [] 48 | for t in range(max_output_length): 49 | topv, topi = all_decoder_outputs[t][b].data.topk(1) 50 | ni = topi[0].item() 51 | if ni == word2index[EOS_token]: 52 | decoded_words.append(EOS_token) 53 | break 54 | else: 55 | decoded_words.append(index2word[ni]) 56 | out_file.write(' '.join(decoded_words)+'\n') 57 | 58 | loss_fn = torch.nn.NLLLoss() 59 | # Loss calculation 60 | loss = masked_cross_entropy( 61 | all_decoder_outputs.transpose(0, 1).contiguous(), # -> batch x seq 62 | output_seqs_batch.transpose(0, 1).contiguous(), # -> batch x seq 63 | output_lens_batch, loss_fn, max_output_length 64 | ) 65 | total_loss += loss.item() 66 | return total_loss/n_batches 67 | -------------------------------------------------------------------------------- /src/seq2seq/helper.py: -------------------------------------------------------------------------------- 1 | from constants import * 2 | import math 3 | import numpy as np 4 | import nltk 5 | import random 6 | import time 7 | import torch 8 | 9 | 10 | def as_minutes(s): 11 | m = math.floor(s / 60) 12 | s -= m * 60 13 | return '%dm %ds' % (m, s) 14 | 15 | 16 | def time_since(since, percent): 17 | now = time.time() 18 | s = now - since 19 | es = s / (percent) 20 | rs = es - s 21 | return '%s (- %s)' % (as_minutes(s), as_minutes(rs)) 22 | 23 | 24 | def iterate_minibatches(id_seqs, input_seqs, 
input_lens, output_seqs, output_lens, batch_size, shuffle=True): 25 | if shuffle: 26 | indices = np.arange(len(input_seqs)) 27 | np.random.shuffle(indices) 28 | for start_idx in range(0, len(input_seqs) - batch_size + 1, batch_size): 29 | if shuffle: 30 | ex = indices[start_idx:start_idx + batch_size] 31 | else: 32 | ex = slice(start_idx, start_idx + batch_size) 33 | yield np.array(id_seqs)[ex], np.array(input_seqs)[ex], np.array(input_lens)[ex], \ 34 | np.array(output_seqs)[ex], np.array(output_lens)[ex] 35 | 36 | 37 | def reverse_dict(word2index): 38 | index2word = {} 39 | for w in word2index: 40 | ix = word2index[w] 41 | index2word[ix] = w 42 | return index2word 43 | 44 | 45 | def calculate_bleu(true, true_lens, pred, pred_lens, index2word, max_len): 46 | sent_bleu_scores = torch.zeros(len(pred)) 47 | for i in range(len(pred)): 48 | true_sent = [index2word[idx] for idx in true[i][:true_lens[i]]] 49 | pred_sent = [index2word[idx] for idx in pred[i][:pred_lens[i]]] 50 | sent_bleu_scores[i] = nltk.translate.bleu_score.sentence_bleu([true_sent], pred_sent) 51 | if USE_CUDA: 52 | sent_bleu_scores = sent_bleu_scores.cuda() 53 | return sent_bleu_scores 54 | -------------------------------------------------------------------------------- /src/seq2seq/main.py: -------------------------------------------------------------------------------- 1 | from attn import * 2 | from attnDecoderRNN import * 3 | from constants import * 4 | from encoderRNN import * 5 | from evaluate import * 6 | from helper import * 7 | from train import * 8 | import torch 9 | import torch.optim as optim 10 | from prepare_data import * 11 | 12 | 13 | def run_seq2seq(train_data, test_data, word2index, word_embeddings, 14 | encoder_params_file, decoder_params_file, 15 | encoder_params_contd_file, decoder_params_contd_file, max_target_length, 16 | n_epochs, batch_size, n_layers): 17 | # Initialize q models 18 | print('Initializing models') 19 | encoder = EncoderRNN(HIDDEN_SIZE, word_embeddings, n_layers, 
dropout=DROPOUT) 20 | decoder = AttnDecoderRNN(HIDDEN_SIZE, len(word2index), word_embeddings, n_layers) 21 | 22 | # Initialize optimizers 23 | encoder_optimizer = optim.Adam([par for par in encoder.parameters() if par.requires_grad], lr=LEARNING_RATE) 24 | decoder_optimizer = optim.Adam([par for par in decoder.parameters() if par.requires_grad], lr=LEARNING_RATE * DECODER_LEARNING_RATIO) 25 | 26 | # Move models to GPU 27 | if USE_CUDA: 28 | encoder.cuda() 29 | decoder.cuda() 30 | 31 | # Keep track of time elapsed and running averages 32 | start = time.time() 33 | print_loss_total = 0 # Reset every print_every 34 | epoch = 0.0 35 | #epoch = 12.0 36 | #print('Loading encoded, decoder params') 37 | #encoder.load_state_dict(torch.load(encoder_params_contd_file+'.epoch%d' % epoch)) 38 | #decoder.load_state_dict(torch.load(decoder_params_contd_file+'.epoch%d' % epoch)) 39 | 40 | ids_seqs, input_seqs, input_lens, output_seqs, output_lens = train_data 41 | 42 | n_batches = len(input_seqs) / batch_size 43 | teacher_forcing_ratio = 1.0 44 | # decr = teacher_forcing_ratio/n_epochs 45 | prev_test_loss = None 46 | num_decrease = 0.0 47 | while epoch < n_epochs: 48 | epoch += 1 49 | for ids_seqs_batch, input_seqs_batch, input_lens_batch, \ 50 | output_seqs_batch, output_lens_batch in \ 51 | iterate_minibatches(ids_seqs, input_seqs, input_lens, output_seqs, output_lens, batch_size): 52 | 53 | start_time = time.time() 54 | # Run the train function 55 | loss = train( 56 | input_seqs_batch, input_lens_batch, 57 | output_seqs_batch, output_lens_batch, 58 | encoder, decoder, 59 | encoder_optimizer, decoder_optimizer, 60 | word2index[SOS_token], max_target_length, 61 | batch_size, teacher_forcing_ratio 62 | ) 63 | 64 | # Keep track of loss 65 | print_loss_total += loss 66 | 67 | # teacher_forcing_ratio = teacher_forcing_ratio - decr 68 | print_loss_avg = print_loss_total / n_batches 69 | print_loss_total = 0 70 | print('Epoch: %d' % epoch) 71 | print('Train Set') 72 | print_summary = 
'%s %d %.4f' % (time_since(start, epoch / n_epochs), epoch, print_loss_avg) 73 | print(print_summary) 74 | print('Dev Set') 75 | curr_test_loss = evaluate(word2index, None, encoder, decoder, test_data, 76 | max_target_length, batch_size, None) 77 | print('%.4f ' % curr_test_loss) 78 | # if prev_test_loss is not None: 79 | # diff_test_loss = prev_test_loss - curr_test_loss 80 | # if diff_test_loss <= 0: 81 | # num_decrease += 1 82 | # if num_decrease > 5: 83 | # print 'Early stopping' 84 | # print 'Saving model params' 85 | # torch.save(encoder.state_dict(), encoder_params_file + '.epoch%d' % epoch) 86 | # torch.save(decoder.state_dict(), decoder_params_file + '.epoch%d' % epoch) 87 | # return 88 | # if epoch % 5 == 0: 89 | print('Saving model params') 90 | torch.save(encoder.state_dict(), encoder_params_file+'.epoch%d' % epoch) 91 | torch.save(decoder.state_dict(), decoder_params_file+'.epoch%d' % epoch) 92 | 93 | 94 | -------------------------------------------------------------------------------- /src/seq2seq/masked_cross_entropy.py: -------------------------------------------------------------------------------- 1 | from constants import * 2 | import torch 3 | from torch.nn import functional 4 | from torch.autograd import Variable 5 | 6 | 7 | def calculate_log_probs(logits, output, length, loss_fn, mixer_delta): 8 | batch_size = logits.shape[0] 9 | log_probs = functional.log_softmax(logits, dim=2) 10 | avg_log_probs = Variable(torch.zeros(batch_size)) 11 | if USE_CUDA: 12 | avg_log_probs = avg_log_probs.cuda() 13 | for b in range(batch_size): 14 | curr_len = length[b]-mixer_delta 15 | if curr_len > 0: 16 | avg_log_probs[b] = loss_fn(log_probs[b][:curr_len], output[b][:curr_len]) / curr_len 17 | return avg_log_probs 18 | 19 | 20 | def masked_cross_entropy(logits, target, length, loss_fn, mixer_delta=None): 21 | batch_size = logits.shape[0] 22 | # log_probs: (batch, max_len, num_classes) 23 | log_probs = functional.log_softmax(logits, dim=2) 24 | loss = 0. 
25 | for b in range(batch_size): 26 | curr_len = min(length[b], mixer_delta) 27 | sent_loss = loss_fn(log_probs[b][:curr_len], target[b][:curr_len]) / curr_len 28 | loss += sent_loss 29 | loss = loss / batch_size 30 | return loss 31 | -------------------------------------------------------------------------------- /src/seq2seq/prepare_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import pdb 4 | import sys 5 | sys.path.append('src/seq2seq') 6 | from read_data import * 7 | import numpy as np 8 | from constants import * 9 | 10 | UNK_TOKEN='' 11 | # UNK_TOKEN='unk' 12 | 13 | # Return a list of indexes, one for each word in the sentence, plus EOS 14 | def prepare_sequence(seq, word2index, max_len): 15 | sequence = [word2index[w] if w in word2index else word2index[UNK_TOKEN] for w in seq.split(' ')[:max_len-1]] 16 | sequence.append(word2index[EOS_token]) 17 | length = len(sequence) 18 | sequence += [word2index[PAD_token]]*int(max_len - len(sequence)) 19 | return sequence, length 20 | 21 | 22 | def prepare_pq_sequence(post_seq, ques_seq, word2index, max_post_len, max_ques_len): 23 | p_sequence = [word2index[w] if w in word2index else word2index[UNK_TOKEN] for w in post_seq.split(' ')[:(max_post_len-1)]] 24 | p_sequence.append(word2index[EOP_token]) 25 | p_sequence += [word2index[PAD_token]]*int(max_post_len - len(p_sequence)) 26 | q_sequence = [word2index[w] if w in word2index else word2index[UNK_TOKEN] for w in ques_seq.split(' ')[:(max_ques_len-1)]] 27 | q_sequence.append(word2index[EOS_token]) 28 | q_sequence += [word2index[PAD_token]]*int(max_ques_len - len(q_sequence)) 29 | sequence = p_sequence + q_sequence 30 | length = max_post_len, max_ques_len 31 | return sequence, length 32 | 33 | 34 | def preprocess_data(triples, word2index, max_post_len, max_ques_len, max_ans_len): 35 | id_seqs = [] 36 | post_seqs = [] 37 | post_lens = [] 38 | ques_seqs = [] 39 | ques_lens 
= [] 40 | post_ques_seqs = [] 41 | post_ques_lens = [] 42 | ans_seqs = [] 43 | ans_lens = [] 44 | 45 | for i in range(len(triples)): 46 | curr_id, post, ques, ans = triples[i] 47 | id_seqs.append(curr_id) 48 | post_seq, post_len = prepare_sequence(post, word2index, max_post_len) 49 | post_seqs.append(post_seq) 50 | post_lens.append(post_len) 51 | ques_seq, ques_len = prepare_sequence(ques, word2index, max_ques_len) 52 | ques_seqs.append(ques_seq) 53 | ques_lens.append(ques_len) 54 | post_ques_seq, post_ques_len = prepare_pq_sequence(post, ques, word2index, max_post_len, max_ques_len) 55 | post_ques_seqs.append(post_ques_seq) 56 | post_ques_lens.append(post_ques_len) 57 | if ans is not None: 58 | ans_seq, ans_len = prepare_sequence(ans, word2index, max_ans_len) 59 | ans_seqs.append(ans_seq) 60 | ans_lens.append(ans_len) 61 | 62 | return id_seqs, post_seqs, post_lens, ques_seqs, ques_lens, \ 63 | post_ques_seqs, post_ques_lens, ans_seqs, ans_lens 64 | -------------------------------------------------------------------------------- /src/seq2seq/read_data.py: -------------------------------------------------------------------------------- 1 | import re 2 | import csv 3 | from constants import * 4 | import unicodedata 5 | from collections import defaultdict 6 | import math 7 | 8 | 9 | def unicode_to_ascii(s): 10 | return ''.join( 11 | c for c in unicodedata.normalize('NFD', s) 12 | if unicodedata.category(c) != 'Mn' 13 | ) 14 | 15 | 16 | # Lowercase, trim, and remove non-letter characters 17 | def normalize_string(s, max_len): 18 | #s = unicode_to_ascii(s.lower().strip()) 19 | s = s.lower().strip() 20 | words = s.split() 21 | s = ' '.join(words[:max_len]) 22 | return s 23 | 24 | 25 | def get_context(line, max_post_len, max_ques_len): 26 | is_specific, is_generic = False, False 27 | if '' in line: 28 | line = line.replace(' ', '') 29 | is_specific = True 30 | if '' in line: 31 | line = line.replace(' ', '') 32 | is_generic = True 33 | if is_specific or is_generic: 34 | 
context = normalize_string(line, max_post_len-2) # one token space for specificity and another for EOS 35 | else: 36 | context = normalize_string(line, max_post_len-1) 37 | if is_specific: 38 | context = ' ' + context 39 | if is_generic: 40 | context += ' ' + context 41 | return context 42 | 43 | 44 | def read_data(context_fname, question_fname, answer_fname, ids_fname, 45 | max_post_len, max_ques_len, max_ans_len, count=None, mode='train'): 46 | if ids_fname is not None: 47 | ids = [] 48 | for line in open(ids_fname, 'r').readlines(): 49 | curr_id = line.strip('\n') 50 | ids.append(curr_id) 51 | 52 | print("Reading lines...") 53 | data = [] 54 | i = 0 55 | for line in open(context_fname, 'r').readlines(): 56 | context = get_context(line, max_post_len, max_ques_len) 57 | if ids_fname is not None: 58 | data.append([ids[i], context, None, None]) 59 | else: 60 | data.append([None, context, None, None]) 61 | i += 1 62 | if count and i == count: 63 | break 64 | 65 | i = 0 66 | for line in open(question_fname, 'r').readlines(): 67 | question = normalize_string(line, max_ques_len-1) 68 | data[i][2] = question 69 | i += 1 70 | if count and i == count: 71 | break 72 | assert(i == len(data)) 73 | 74 | if answer_fname is not None: 75 | i = 0 76 | for line in open(answer_fname, 'r').readlines(): 77 | answer = normalize_string(line, max_ans_len-1) # one token space for EOS 78 | data[i][3] = answer 79 | i += 1 80 | if count and i == count: 81 | break 82 | assert(i == len(data)) 83 | 84 | if ids_fname is not None: 85 | updated_data = [] 86 | i = 0 87 | if mode == 'test': 88 | max_per_id_count = 1 89 | else: 90 | max_per_id_count = 20 91 | data_ct_per_id = defaultdict(int) 92 | for curr_id in ids: 93 | data_ct_per_id[curr_id] += 1 94 | if data_ct_per_id[curr_id] <= max_per_id_count: 95 | updated_data.append(data[i]) 96 | i += 1 97 | if count and i == count: 98 | break 99 | assert (i == len(data)) 100 | return updated_data 101 | 102 | return data 103 | 
-------------------------------------------------------------------------------- /src/seq2seq/train.py: -------------------------------------------------------------------------------- 1 | from constants import * 2 | from masked_cross_entropy import * 3 | import numpy as np 4 | import random 5 | import torch 6 | from torch.autograd import Variable 7 | 8 | 9 | def train(input_batches, input_lens, target_batches, target_lens, 10 | encoder, decoder, encoder_optimizer, decoder_optimizer, 11 | SOS_idx, max_target_length, batch_size, teacher_forcing_ratio): 12 | 13 | # Zero gradients of both optimizers 14 | encoder_optimizer.zero_grad() 15 | decoder_optimizer.zero_grad() 16 | 17 | if USE_CUDA: 18 | input_batches = Variable(torch.LongTensor(np.array(input_batches)).cuda()).transpose(0, 1) 19 | target_batches = Variable(torch.LongTensor(np.array(target_batches)).cuda()).transpose(0, 1) 20 | else: 21 | input_batches = Variable(torch.LongTensor(np.array(input_batches))).transpose(0, 1) 22 | target_batches = Variable(torch.LongTensor(np.array(target_batches))).transpose(0, 1) 23 | 24 | # Run post words through encoder 25 | encoder_outputs, encoder_hidden = encoder(input_batches, input_lens, None) 26 | 27 | # Prepare input and output variables 28 | decoder_hidden = encoder_hidden[:decoder.n_layers] + encoder_hidden[decoder.n_layers:] 29 | 30 | if USE_CUDA: 31 | decoder_input = Variable(torch.LongTensor([SOS_idx] * batch_size).cuda()) 32 | all_decoder_outputs = Variable(torch.zeros(max_target_length, batch_size, decoder.output_size).cuda()) 33 | else: 34 | decoder_input = Variable(torch.LongTensor([SOS_idx] * batch_size)) 35 | all_decoder_outputs = Variable(torch.zeros(max_target_length, batch_size, decoder.output_size)) 36 | 37 | use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False 38 | 39 | # Run through decoder one time step at a time 40 | for t in range(max_target_length): 41 | decoder_output, decoder_hidden = decoder(decoder_input, 
decoder_hidden, encoder_outputs) 42 | all_decoder_outputs[t] = decoder_output 43 | 44 | if use_teacher_forcing: 45 | # Teacher Forcing 46 | decoder_input = target_batches[t] # Next input is current target 47 | else: 48 | # Greeding decoding 49 | for b in range(batch_size): 50 | topi = decoder_output[b].topk(1)[1][0] 51 | decoder_input[b] = topi.squeeze().detach() 52 | 53 | loss_fn = torch.nn.NLLLoss() 54 | 55 | # Loss calculation and backpropagation 56 | loss = masked_cross_entropy( 57 | all_decoder_outputs.transpose(0, 1).contiguous(), # -> batch x seq 58 | target_batches.transpose(0, 1).contiguous(), # -> batch x seq 59 | target_lens, loss_fn, max_target_length 60 | ) 61 | loss.backward() 62 | encoder_optimizer.step() 63 | decoder_optimizer.step() 64 | return loss.item() 65 | -------------------------------------------------------------------------------- /src/utility/FeedForward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from constants import * 4 | 5 | class FeedForward(nn.Module): 6 | def __init__(self, input_dim): 7 | super(FeedForward, self).__init__() 8 | 9 | self.layer1 = nn.Linear(input_dim, HIDDEN_SIZE) 10 | self.relu = nn.ReLU() 11 | self.layer2 = nn.Linear(HIDDEN_SIZE, 1) 12 | 13 | def forward(self, x): 14 | x = self.layer1(x) 15 | x = self.relu(x) 16 | x = self.layer2(x) 17 | return x 18 | 19 | -------------------------------------------------------------------------------- /src/utility/RL_evaluate.py: -------------------------------------------------------------------------------- 1 | from constants import * 2 | import sys 3 | sys.path.append('src/utility') 4 | from helper_utility import * 5 | import numpy as np 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | 10 | def evaluate_utility(context_model, question_model, answer_model, utility_model, c, cl, q, ql, a, al, args): 11 | with torch.no_grad(): 12 | context_model.eval() 13 | 
question_model.eval() 14 | answer_model.eval() 15 | utility_model.eval() 16 | cm = get_masks(cl, args.max_post_len) 17 | qm = get_masks(ql, args.max_ques_len) 18 | am = get_masks(al, args.max_ans_len) 19 | 20 | c = torch.tensor(c) 21 | cm = torch.FloatTensor(cm) 22 | q = torch.tensor(q) 23 | qm = torch.FloatTensor(qm) 24 | a = torch.tensor(a) 25 | am = torch.FloatTensor(am) 26 | if USE_CUDA: 27 | c = c.cuda() 28 | cm = cm.cuda() 29 | q = q.cuda() 30 | qm = qm.cuda() 31 | a = a.cuda() 32 | am = am.cuda() 33 | c_hid, c_out = context_model(torch.transpose(c, 0, 1)) 34 | cm = torch.transpose(cm, 0, 1).unsqueeze(2) 35 | cm = cm.expand(cm.shape[0], cm.shape[1], 2*HIDDEN_SIZE) 36 | c_out = torch.sum(c_out * cm, dim=0) 37 | 38 | q_hid, q_out = question_model(torch.transpose(q, 0, 1)) 39 | qm = torch.transpose(qm, 0, 1).unsqueeze(2) 40 | qm = qm.expand(qm.shape[0], qm.shape[1], 2*HIDDEN_SIZE) 41 | q_out = torch.sum(q_out * qm, dim=0) 42 | 43 | a_hid, a_out = answer_model(torch.transpose(a, 0, 1)) 44 | am = torch.transpose(am, 0, 1).unsqueeze(2) 45 | am = am.expand(am.shape[0], am.shape[1], 2*HIDDEN_SIZE) 46 | a_out = torch.sum(a_out * am, dim=0) 47 | 48 | predictions = utility_model(torch.cat((c_out, q_out, a_out), 1)).squeeze(1) 49 | # predictions = utility_model(torch.cat((c_out, q_out), 1)).squeeze(1) 50 | # predictions = utility_model(q_out).squeeze(1) 51 | predictions = torch.nn.functional.sigmoid(predictions) 52 | 53 | return predictions 54 | -------------------------------------------------------------------------------- /src/utility/RL_train.py: -------------------------------------------------------------------------------- 1 | from helper import * 2 | import numpy as np 3 | import torch 4 | from constants import * 5 | 6 | 7 | def train_utility(context_model, question_model, answer_model, utility_model, optimizer, criterion, 8 | c, cm, q, qm, a, am, labs): 9 | optimizer.zero_grad() 10 | c = torch.tensor(c) 11 | cm = torch.FloatTensor(cm) 12 | q = torch.tensor(q) 13 
| qm = torch.FloatTensor(qm) 14 | a = torch.tensor(a) 15 | am = torch.FloatTensor(am) 16 | if USE_CUDA: 17 | c = c.cuda() 18 | cm = cm.cuda() 19 | q = q.cuda() 20 | qm = qm.cuda() 21 | a = a.cuda() 22 | am = am.cuda() 23 | c_hid, c_out = context_model(torch.transpose(c, 0, 1)) 24 | cm = torch.transpose(cm, 0, 1).unsqueeze(2) 25 | cm = cm.expand(cm.shape[0], cm.shape[1], 2*HIDDEN_SIZE) 26 | c_out = torch.sum(c_out * cm, dim=0) 27 | 28 | q_hid, q_out = question_model(torch.transpose(q, 0, 1)) 29 | qm = torch.transpose(qm, 0, 1).unsqueeze(2) 30 | qm = qm.expand(qm.shape[0], qm.shape[1], 2*HIDDEN_SIZE) 31 | q_out = torch.sum(q_out * qm, dim=0) 32 | 33 | a_hid, a_out = answer_model(torch.transpose(a, 0, 1)) 34 | am = torch.transpose(am, 0, 1).unsqueeze(2) 35 | am = am.expand(am.shape[0], am.shape[1], 2*HIDDEN_SIZE) 36 | a_out = torch.sum(a_out * am, dim=0) 37 | 38 | predictions = utility_model(torch.cat((c_out, q_out, a_out), 1)).squeeze(1) 39 | if USE_CUDA: 40 | labs = labs.cuda() 41 | loss = criterion(predictions, labs) 42 | acc = binary_accuracy(predictions, labs) 43 | loss.backward() 44 | optimizer.step() 45 | return loss, acc 46 | -------------------------------------------------------------------------------- /src/utility/RNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from constants import * 4 | 5 | 6 | class RNN(nn.Module): 7 | def __init__(self, vocab_size, embedding_dim, n_layers): 8 | super(RNN, self).__init__() 9 | 10 | self.embedding = nn.Embedding(vocab_size, embedding_dim) 11 | self.rnn = nn.LSTM(embedding_dim, HIDDEN_SIZE, num_layers=n_layers, bidirectional=True) 12 | self.fc = nn.Linear(HIDDEN_SIZE*2, HIDDEN_SIZE) 13 | self.dropout = nn.Dropout(DROPOUT) 14 | 15 | def forward(self, x): 16 | 17 | #x = [sent len, batch size] 18 | 19 | embedded = self.dropout(self.embedding(x)) 20 | 21 | #embedded = [sent len, batch size, emb dim] 22 | 23 | output, (hidden, cell) = 
self.rnn(embedded) 24 | 25 | #output = [sent len, batch size, hid dim * num directions] 26 | #hidden = [num layers * num directions, batch size, hid. dim] 27 | #cell = [num layers * num directions, batch size, hid. dim] 28 | 29 | hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)) 30 | 31 | #hidden [batch size, hid. dim * num directions] 32 | 33 | return self.fc(hidden.squeeze(0)), output 34 | 35 | -------------------------------------------------------------------------------- /src/utility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raosudha89/clarification_question_generation_pytorch/23ae8aa0160eee70565751f4b6de13563a19d6ed/src/utility/__init__.py -------------------------------------------------------------------------------- /src/utility/combine_pickle.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pickle as p 3 | 4 | if __name__ == "__main__": 5 | askubuntu = p.load(open(sys.argv[1], 'rb')) 6 | unix = p.load(open(sys.argv[2], 'rb')) 7 | superuser = p.load(open(sys.argv[3], 'rb')) 8 | combined = unix + superuser + askubuntu 9 | p.dump(combined, open(sys.argv[4], 'wb')) 10 | -------------------------------------------------------------------------------- /src/utility/data_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | import sys 4 | import torch 5 | import torch.autograd as autograd 6 | import random 7 | import torch.utils.data as Data 8 | import pickle as p 9 | 10 | def load_data(contexts_fname, answers_fname): 11 | contexts_file = open(contexts_fname, 'r') 12 | answers_file = open(answers_fname, 'r') 13 | contexts = [] 14 | questions = [] 15 | answers = [] 16 | for line in contexts_file.readlines(): 17 | context, question = line.strip('\n').split('') 18 | contexts.append(context) 19 | 
questions.append(question) 20 | answers = [line.strip('\n') for line in answers_file.readlines()] 21 | data = [] 22 | for i in range(len(contexts)): 23 | data.append([contexts[i], questions[i], answers[i], 1]) 24 | r = random.randint(0, len(contexts)-1) 25 | data.append([contexts[i], questions[r], answers[r], 0]) 26 | random.shuffle(data) 27 | return data 28 | 29 | def main(args): 30 | train_data = load_data(args.train_contexts_fname, args.train_answers_fname) 31 | dev_data = load_data(args.tune_contexts_fname, args.tune_answers_fname) 32 | test_data = load_data(args.test_contexts_fname, args.test_answers_fname) 33 | 34 | p.dump(train_data, open(args.train_data, 'wb')) 35 | p.dump(dev_data, open(args.tune_data, 'wb')) 36 | p.dump(test_data, open(args.test_data, 'wb')) 37 | 38 | if __name__ == "__main__": 39 | argparser = argparse.ArgumentParser(sys.argv[0]) 40 | argparser.add_argument("--train_contexts_fname", type = str) 41 | argparser.add_argument("--train_answers_fname", type = str) 42 | argparser.add_argument("--tune_contexts_fname", type = str) 43 | argparser.add_argument("--tune_answers_fname", type = str) 44 | argparser.add_argument("--test_contexts_fname", type = str) 45 | argparser.add_argument("--test_answers_fname", type = str) 46 | argparser.add_argument("--train_data", type = str) 47 | argparser.add_argument("--tune_data", type = str) 48 | argparser.add_argument("--test_data", type = str) 49 | args = argparser.parse_args() 50 | print args 51 | print "" 52 | main(args) 53 | 54 | -------------------------------------------------------------------------------- /src/utility/evaluate_utility.py: -------------------------------------------------------------------------------- 1 | from helper_utility import * 2 | import numpy as np 3 | import torch 4 | from constants import * 5 | 6 | 7 | def evaluate(context_model, question_model, answer_model, utility_model, dev_data, criterion, args): 8 | epoch_loss = 0 9 | epoch_acc = 0 10 | num_batches = 0 11 | 12 | 
context_model.eval() 13 | question_model.eval() 14 | answer_model.eval() 15 | utility_model.eval() 16 | 17 | with torch.no_grad(): 18 | contexts, context_lens, questions, question_lens, answers, answer_lens, labels = dev_data 19 | context_masks = get_masks(context_lens, args.max_post_len) 20 | question_masks = get_masks(question_lens, args.max_ques_len) 21 | answer_masks = get_masks(answer_lens, args.max_ans_len) 22 | contexts = np.array(contexts) 23 | questions = np.array(questions) 24 | answers = np.array(answers) 25 | labels = np.array(labels) 26 | for c, cm, q, qm, a, am, l in iterate_minibatches(contexts, context_masks, 27 | questions, question_masks, 28 | answers, answer_masks, 29 | labels, args.batch_size): 30 | c = torch.tensor(c) 31 | cm = torch.FloatTensor(cm) 32 | q = torch.tensor(q) 33 | qm = torch.FloatTensor(qm) 34 | a = torch.tensor(a) 35 | am = torch.FloatTensor(am) 36 | if USE_CUDA: 37 | c = c.cuda() 38 | cm = cm.cuda() 39 | q = q.cuda() 40 | qm = qm.cuda() 41 | a = a.cuda() 42 | am = am.cuda() 43 | 44 | c_hid, c_out = context_model(torch.transpose(c, 0, 1)) 45 | cm = torch.transpose(cm, 0, 1).unsqueeze(2) 46 | cm = cm.expand(cm.shape[0], cm.shape[1], 2*HIDDEN_SIZE) 47 | c_out = torch.sum(c_out * cm, dim=0) 48 | 49 | # q_out: (sent_len, batch_size, num_directions*HIDDEN_DIM) 50 | q_hid, q_out = question_model(torch.transpose(q, 0, 1)) 51 | qm = torch.transpose(qm, 0, 1).unsqueeze(2) 52 | qm = qm.expand(qm.shape[0], qm.shape[1], 2*HIDDEN_SIZE) 53 | q_out = torch.sum(q_out * qm, dim=0) 54 | 55 | a_hid, a_out = answer_model(torch.transpose(a, 0, 1)) 56 | am = torch.transpose(am, 0, 1).unsqueeze(2) 57 | am = am.expand(am.shape[0], am.shape[1], 2*HIDDEN_SIZE) 58 | a_out = torch.sum(a_out * am, dim=0) 59 | 60 | predictions = utility_model(torch.cat((c_out, q_out, a_out), 1)).squeeze(1) 61 | predictions = torch.nn.functional.sigmoid(predictions) 62 | 63 | l = torch.FloatTensor([float(lab) for lab in l]) 64 | if USE_CUDA: 65 | l = l.cuda() 66 | loss = 
criterion(predictions, l) 67 | acc = binary_accuracy(predictions, l) 68 | epoch_loss += loss.item() 69 | epoch_acc += acc 70 | num_batches += 1 71 | 72 | return epoch_loss / num_batches, epoch_acc / num_batches 73 | 74 | -------------------------------------------------------------------------------- /src/utility/helper_utility.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def iterate_minibatches(c, cm, q, qm, a, am, l, batch_size, shuffle=True): 7 | if shuffle: 8 | indices = np.arange(len(c)) 9 | np.random.shuffle(indices) 10 | for start_idx in range(0, len(c) - batch_size + 1, batch_size): 11 | if shuffle: 12 | excerpt = indices[start_idx:start_idx + batch_size] 13 | else: 14 | excerpt = slice(start_idx, start_idx + batch_size) 15 | yield c[excerpt], cm[excerpt], q[excerpt], qm[excerpt], a[excerpt], am[excerpt], l[excerpt] 16 | 17 | 18 | def get_masks(lens, max_len): 19 | masks = [] 20 | for i in range(len(lens)): 21 | if lens[i] == None: 22 | lens[i] = 0 23 | masks.append([1]*int(lens[i])+[0]*int(max_len-lens[i])) 24 | return np.array(masks) 25 | 26 | 27 | def binary_accuracy(predictions, truth): 28 | """ 29 | Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8 30 | """ 31 | correct = 0. 
32 | for i in range(len(predictions)): 33 | if predictions[i] >= 0.5 and truth[i] == 1: 34 | correct += 1 35 | elif predictions[i] < 0.5 and truth[i] == 0: 36 | correct += 1 37 | acc = correct/len(predictions) 38 | return acc 39 | 40 | -------------------------------------------------------------------------------- /src/utility/main.py: -------------------------------------------------------------------------------- 1 | from constants import * 2 | import sys 3 | sys.path.append('src/utility') 4 | from FeedForward import * 5 | from evaluate_utility import * 6 | import random 7 | from RNN import * 8 | import time 9 | import torch 10 | from torch import optim 11 | import torch.nn as nn 12 | import torch.autograd as autograd 13 | from train_utility import * 14 | 15 | 16 | def update_neg_data(data, index2word): 17 | ids_seqs, post_seqs, post_lens, ques_seqs, ques_lens, ans_seqs, ans_lens = data 18 | N = 2 19 | labels = [0]*int(N*len(post_seqs)) 20 | new_post_seqs = [None]*int(N*len(post_seqs)) 21 | new_post_lens = [None]*int(N*len(post_lens)) 22 | new_ques_seqs = [None]*int(N*len(ques_seqs)) 23 | new_ques_lens = [None]*int(N*len(ques_lens)) 24 | new_ans_seqs = [None]*int(N*len(ans_seqs)) 25 | new_ans_lens = [None]*int(N*len(ans_lens)) 26 | for i in range(len(post_seqs)): 27 | new_post_seqs[N*i] = post_seqs[i] 28 | new_post_lens[N*i] = post_lens[i] 29 | new_ques_seqs[N*i] = ques_seqs[i] 30 | new_ques_lens[N*i] = ques_lens[i] 31 | new_ans_seqs[N*i] = ans_seqs[i] 32 | new_ans_lens[N*i] = ans_lens[i] 33 | labels[N*i] = 1 34 | for j in range(1, N): 35 | r = random.randint(0, len(post_seqs)-1) 36 | new_post_seqs[N*i+j] = post_seqs[i] 37 | new_post_lens[N*i+j] = post_lens[i] 38 | new_ques_seqs[N*i+j] = ques_seqs[r] 39 | new_ques_lens[N*i+j] = ques_lens[r] 40 | new_ans_seqs[N*i+j] = ans_seqs[r] 41 | new_ans_lens[N*i+j] = ans_lens[r] 42 | labels[N*i+j] = 0 43 | 44 | data = new_post_seqs, new_post_lens, \ 45 | new_ques_seqs, new_ques_lens, \ 46 | new_ans_seqs, new_ans_lens, 
labels 47 | 48 | return data 49 | 50 | 51 | def run_utility(train_data, test_data, word_embeddings, index2word, args, n_layers): 52 | context_model = RNN(len(word_embeddings), len(word_embeddings[0]), n_layers) 53 | question_model = RNN(len(word_embeddings), len(word_embeddings[0]), n_layers) 54 | answer_model = RNN(len(word_embeddings), len(word_embeddings[0]), n_layers) 55 | utility_model = FeedForward(HIDDEN_SIZE*3*2) 56 | 57 | if USE_CUDA: 58 | word_embeddings = autograd.Variable(torch.FloatTensor(word_embeddings).cuda()) 59 | else: 60 | word_embeddings = autograd.Variable(torch.FloatTensor(word_embeddings)) 61 | 62 | context_model.embedding.weight.data.copy_(word_embeddings) 63 | question_model.embedding.weight.data.copy_(word_embeddings) 64 | answer_model.embedding.weight.data.copy_(word_embeddings) 65 | 66 | # Fix word embeddings 67 | context_model.embedding.weight.requires_grad = False 68 | question_model.embedding.weight.requires_grad = False 69 | answer_model.embedding.weight.requires_grad = False 70 | 71 | optimizer = optim.Adam(list([par for par in context_model.parameters() if par.requires_grad]) + 72 | list([par for par in question_model.parameters() if par.requires_grad]) + 73 | list([par for par in answer_model.parameters() if par.requires_grad]) + 74 | list([par for par in utility_model.parameters() if par.requires_grad])) 75 | 76 | criterion = nn.BCELoss() 77 | # criterion = nn.BCEWithLogitsLoss() 78 | if USE_CUDA: 79 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 80 | context_model = context_model.to(device) 81 | question_model = question_model.to(device) 82 | answer_model = answer_model.to(device) 83 | utility_model = utility_model.to(device) 84 | criterion = criterion.to(device) 85 | 86 | train_data = update_neg_data(train_data, index2word) 87 | test_data = update_neg_data(test_data, index2word) 88 | 89 | for epoch in range(args.n_epochs): 90 | start_time = time.time() 91 | train_loss, train_acc = 
train_fn(context_model, question_model, answer_model, utility_model, 92 | train_data, optimizer, criterion, args) 93 | valid_loss, valid_acc = evaluate(context_model, question_model, answer_model, utility_model, 94 | test_data, criterion, args) 95 | print('Epoch %d: Train Loss: %.3f, Train Acc: %.3f, Val Loss: %.3f, Val Acc: %.3f' % \ 96 | (epoch, train_loss, train_acc, valid_loss, valid_acc)) 97 | print('Time taken: ', time.time()-start_time) 98 | # if epoch % 5 == 0: 99 | print('Saving model params') 100 | torch.save(context_model.state_dict(), args.context_params+'.epoch%d' % epoch) 101 | torch.save(question_model.state_dict(), args.question_params+'.epoch%d' % epoch) 102 | torch.save(answer_model.state_dict(), args.answer_params+'.epoch%d' % epoch) 103 | torch.save(utility_model.state_dict(), args.utility_params+'.epoch%d' % epoch) 104 | 105 | -------------------------------------------------------------------------------- /src/utility/run_combine_domains.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/utility 4 | UBUNTU=askubuntu.com 5 | UNIX=unix.stackexchange.com 6 | SUPERUSER=superuser.com 7 | SCRIPTS_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/utility 8 | SITE_NAME=askubuntu_unix_superuser 9 | 10 | mkdir $DATA_DIR/$SITE_NAME 11 | 12 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/train_data.p \ 13 | $DATA_DIR/$UNIX/train_data.p \ 14 | $DATA_DIR/$SUPERUSER/train_data.p \ 15 | $DATA_DIR/$SITE_NAME/train_data.p 16 | 17 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/tune_data.p \ 18 | $DATA_DIR/$UNIX/tune_data.p \ 19 | $DATA_DIR/$SUPERUSER/tune_data.p \ 20 | $DATA_DIR/$SITE_NAME/tune_data.p 21 | 22 | python $SCRIPTS_DIR/combine_pickle.py $DATA_DIR/$UBUNTU/test_data.p \ 23 | $DATA_DIR/$UNIX/test_data.p \ 24 | $DATA_DIR/$SUPERUSER/test_data.p \ 25 | $DATA_DIR/$SITE_NAME/test_data.p 26 | 27 | 
-------------------------------------------------------------------------------- /src/utility/run_data_loader.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=utility_data_aus 4 | #SBATCH --output=utility_data_aus 5 | #SBATCH --qos=batch 6 | #SBATCH --mem=36g 7 | #SBATCH --time=24:00:00 8 | 9 | #SITENAME=askubuntu.com 10 | #SITENAME=unix.stackexchange.com 11 | SITENAME=superuser.com 12 | #SITENAME=askubuntu_unix_superuser 13 | #SITENAME=Home_and_Kitchen 14 | QA_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/question_answering/$SITENAME 15 | UTILITY_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/utility/$SITENAME 16 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/utility 17 | 18 | export PATH="/fs/clip-amr/anaconda2/bin:$PATH" 19 | 20 | python $SCRIPT_DIR/data_loader.py --train_contexts_fname $QA_DATA_DIR/train_src \ 21 | --train_answers_fname $QA_DATA_DIR/train_tgt \ 22 | --tune_contexts_fname $QA_DATA_DIR/tune_src \ 23 | --tune_answers_fname $QA_DATA_DIR/tune_tgt \ 24 | --test_contexts_fname $QA_DATA_DIR/test_src \ 25 | --test_answers_fname $QA_DATA_DIR/test_tgt \ 26 | --train_data $UTILITY_DATA_DIR/train_data.p \ 27 | --tune_data $UTILITY_DATA_DIR/tune_data.p \ 28 | --test_data $UTILITY_DATA_DIR/test_data.p \ 29 | #--word_to_ix $UTILITY_DATA_DIR/word_to_ix.p \ 30 | 31 | 32 | -------------------------------------------------------------------------------- /src/utility/run_rnn_classifier.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=utility_aus_fullmodel_80K 4 | #SBATCH --output=utility_aus_fullmodel_80K 5 | #SBATCH --qos=gpu-short 6 | #SBATCH --partition=gpu 7 | #SBATCH --gres=gpu 8 | #SBATCH --mem=16g 9 | 10 | #SITENAME=askubuntu.com 11 | #SITENAME=unix.stackexchange.com 12 | #SITENAME=superuser.com 13 | SITENAME=askubuntu_unix_superuser 
14 | #SITENAME=Home_and_Kitchen 15 | DATA_DIR=/fs/clip-corpora/amazon_qa 16 | UTILITY_DATA_DIR=/fs/clip-scratch/raosudha/clarification_question_generation/utility/$SITENAME 17 | SCRIPT_DIR=/fs/clip-amr/clarification_question_generation_pytorch/src/utility 18 | EMB_DIR=/fs/clip-amr/clarification_question_generation_pytorch/embeddings/$SITENAME/200_5Kvocab 19 | #EMB_DIR=/fs/clip-amr/ranking_clarification_questions/embeddings 20 | 21 | source /fs/clip-amr/gpu_virtualenv/bin/activate 22 | export PATH="/fs/clip-amr/anaconda2/bin:$PATH" 23 | 24 | python $SCRIPT_DIR/rnn_classifier.py --train_data $UTILITY_DATA_DIR/train_data.p \ 25 | --tune_data $UTILITY_DATA_DIR/tune_data.p \ 26 | --test_data $UTILITY_DATA_DIR/test_data.p \ 27 | --word_embeddings $EMB_DIR/word_embeddings.p \ 28 | --vocab $EMB_DIR/vocab.p \ 29 | --cuda True \ 30 | 31 | 32 | -------------------------------------------------------------------------------- /src/utility/train_utility.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('src/utility') 3 | from helper_utility import * 4 | import numpy as np 5 | import torch 6 | from constants import * 7 | 8 | 9 | def train_fn(context_model, question_model, answer_model, utility_model, train_data, optimizer, criterion, args): 10 | epoch_loss = 0 11 | epoch_acc = 0 12 | 13 | context_model.train() 14 | question_model.train() 15 | answer_model.train() 16 | utility_model.train() 17 | 18 | contexts, context_lens, questions, question_lens, answers, answer_lens, labels = train_data 19 | context_masks = get_masks(context_lens, args.max_post_len) 20 | question_masks = get_masks(question_lens, args.max_ques_len) 21 | answer_masks = get_masks(answer_lens, args.max_ans_len) 22 | contexts = np.array(contexts) 23 | questions = np.array(questions) 24 | answers = np.array(answers) 25 | labels = np.array(labels) 26 | 27 | num_batches = 0 28 | for c, cm, q, qm, a, am, l in iterate_minibatches(contexts, 
context_masks, questions, question_masks, 29 | answers, answer_masks, labels, args.batch_size): 30 | optimizer.zero_grad() 31 | c = torch.tensor(c.tolist()) 32 | cm = torch.FloatTensor(cm) 33 | q = torch.tensor(q.tolist()) 34 | qm = torch.FloatTensor(qm) 35 | a = torch.tensor(a.tolist()) 36 | am = torch.FloatTensor(am) 37 | if USE_CUDA: 38 | c = c.cuda() 39 | cm = cm.cuda() 40 | q = q.cuda() 41 | qm = qm.cuda() 42 | a = a.cuda() 43 | am = am.cuda() 44 | 45 | # c_out: (sent_len, batch_size, num_directions*HIDDEN_DIM) 46 | c_hid, c_out = context_model(torch.transpose(c, 0, 1)) 47 | cm = torch.transpose(cm, 0, 1).unsqueeze(2) 48 | cm = cm.expand(cm.shape[0], cm.shape[1], 2*HIDDEN_SIZE) 49 | c_out = torch.sum(c_out * cm, dim=0) 50 | 51 | q_hid, q_out = question_model(torch.transpose(q, 0, 1)) 52 | qm = torch.transpose(qm, 0, 1).unsqueeze(2) 53 | qm = qm.expand(qm.shape[0], qm.shape[1], 2*HIDDEN_SIZE) 54 | q_out = torch.sum(q_out * qm, dim=0) 55 | 56 | a_hid, a_out = answer_model(torch.transpose(a, 0, 1)) 57 | am = torch.transpose(am, 0, 1).unsqueeze(2) 58 | am = am.expand(am.shape[0], am.shape[1], 2*HIDDEN_SIZE) 59 | a_out = torch.sum(a_out * am, dim=0) 60 | 61 | predictions = utility_model(torch.cat((c_out, q_out, a_out), 1)).squeeze(1) 62 | # predictions = utility_model(torch.cat((c_out, q_out), 1)).squeeze(1) 63 | # predictions = utility_model(q_out).squeeze(1) 64 | predictions = torch.nn.functional.sigmoid(predictions) 65 | 66 | l = torch.FloatTensor([float(lab) for lab in l]) 67 | if USE_CUDA: 68 | l = l.cuda() 69 | loss = criterion(predictions, l) 70 | acc = binary_accuracy(predictions, l) 71 | loss.backward() 72 | optimizer.step() 73 | epoch_loss += loss.item() 74 | epoch_acc += acc 75 | num_batches += 1 76 | 77 | return epoch_loss, epoch_acc / num_batches 78 | 79 | 80 | --------------------------------------------------------------------------------