├── LICENSE ├── Makefile ├── README.md ├── build_phrase_index.py ├── config.sh ├── densephrases ├── __init__.py ├── demo │ ├── __init__.py │ └── static │ │ ├── examples.txt │ │ ├── examples_context.txt │ │ ├── files │ │ ├── all.js │ │ ├── bootstrap.min.js │ │ ├── favicon.ico │ │ ├── jquery-3.3.1.min.js │ │ ├── overview_new.png │ │ ├── plogo.png │ │ ├── popper.min.js │ │ ├── preview-new.gif │ │ ├── steps.png │ │ └── style.css │ │ ├── index.html │ │ └── index_single.html ├── encoder.py ├── index.py ├── model.py ├── options.py └── utils │ ├── __init__.py │ ├── data_utils.py │ ├── embed_utils.py │ ├── eval_utils.py │ ├── file_utils.py │ ├── kilt │ ├── __init__.py │ ├── eval.py │ └── kilt_utils.py │ ├── open_utils.py │ ├── single_utils.py │ ├── squad_metrics.py │ └── squad_utils.py ├── download.sh ├── eval_phrase_retrieval.py ├── examples ├── README.md ├── create-custom-index │ ├── README.md │ ├── articles.json │ └── questions.json ├── entity-linking │ └── README.md ├── fusion-in-decoder │ └── README.md ├── knowledge-dialogue │ └── README.md └── slot-filling │ └── README.md ├── generate_phrase_vecs.py ├── requirements.txt ├── run_demo.py ├── scripts ├── analysis │ ├── run_analysis.py │ └── run_analysis_dpr.py ├── benchmark │ ├── benchmark_hdf5.py │ ├── create_benchmark_data.py │ └── data │ │ ├── nq_1000_dev_denspi.json │ │ ├── nq_1000_dev_dpr.csv │ │ └── nq_1000_dev_orqa.jsonl ├── dump │ ├── check_dump.py │ ├── filter_hdf5.py │ ├── filter_stats.py │ ├── save_meta.py │ └── split_hdf5.py ├── kilt │ ├── build_title2wikiid.py │ ├── sample_kilt.py │ └── strip_pred.py ├── parallel │ ├── add_to_index.py │ └── dump_phrases.py ├── postprocess │ ├── recall.py │ └── recall_transform.py ├── preprocess │ ├── README.md │ ├── build_db.py │ ├── build_wikisquad.py │ ├── compress_metadata.py │ ├── concat_wikisquad.py │ ├── create_nq_reader.py │ ├── create_nq_reader_doc_wiki.py │ ├── create_nq_reader_wiki.py │ ├── create_openqa.py │ ├── create_psg_hdf5.py │ ├── create_tqa_ds.py │ ├── doc_db.py │ ├── download_wikidump.py │ ├── filter_noans.py │ ├── filter_wiki.py │ ├── merge_openqa.py │ ├── merge_paq.py │ ├── merge_singleqa.py │ ├── nq_utils.py │ ├── prep_wikipedia.py │ ├── sample_nq_reader_doc_wiki.py │ ├── simple_tokenizer.py │ └── stat_entities.py └── question_generation │ ├── filter_qg.py │ └── generate_squad.py ├── setup.py ├── slides └── emnlp2021_slides.pdf ├── train_cross_encoder.py ├── train_query.py └── train_rc.py /config.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Modify below to your choice of directory 4 | export BASE_DIR=./ 5 | 6 | while read -p "Use to $BASE_DIR as the base directory (requires at least 220GB for the installation)? 
[yes/no]: " choice; do 7 | case "$choice" in 8 | yes ) 9 | break ;; 10 | no ) 11 | while read -p "Type in the directory: " choice; do 12 | case "$choice" in 13 | * ) 14 | export BASE_DIR=$choice; 15 | echo "Base directory set to $BASE_DIR"; 16 | break ;; 17 | esac 18 | done 19 | break ;; 20 | * ) echo "Please answer yes or no."; 21 | exit 0 ;; 22 | esac 23 | done 24 | 25 | # DATA_DIR: for datasets (including 'kilt', 'open-qa', 'single-qa', 'truecase', 'wikidump') 26 | # SAVE_DIR: for pre-trained models or dumps; new models and dumps will also be saved here 27 | # CACHE_DIR: for cache files from huggingface transformers 28 | export DATA_DIR=$BASE_DIR/densephrases-data 29 | export SAVE_DIR=$BASE_DIR/outputs 30 | export CACHE_DIR=$BASE_DIR/cache 31 | 32 | # Create directories 33 | mkdir -p $DATA_DIR 34 | mkdir -p $SAVE_DIR 35 | mkdir -p $SAVE_DIR/logs 36 | mkdir -p $CACHE_DIR 37 | 38 | printf "\nEnvironment variables are set as follows:\n" 39 | echo "DATA_DIR=$DATA_DIR" 40 | echo "SAVE_DIR=$SAVE_DIR" 41 | echo "CACHE_DIR=$CACHE_DIR" 42 | 43 | # Append to bashrc, instructions 44 | while read -p "Add to ~/.bashrc (recommended)? [yes/no]: " choice; do 45 | case "$choice" in 46 | yes ) 47 | echo -e "\n# DensePhrases setup" >> ~/.bashrc; 48 | echo "export DATA_DIR=$DATA_DIR" >> ~/.bashrc; 49 | echo "export SAVE_DIR=$SAVE_DIR" >> ~/.bashrc; 50 | echo "export CACHE_DIR=$CACHE_DIR" >> ~/.bashrc; 51 | break ;; 52 | no ) 53 | break ;; 54 | * ) echo "Please answer yes or no." ;; 55 | esac 56 | done 57 | -------------------------------------------------------------------------------- /densephrases/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import Encoder 2 | from .index import MIPS 3 | from .options import Options 4 | from .model import DensePhrases 5 | -------------------------------------------------------------------------------- /densephrases/demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/demo/__init__.py -------------------------------------------------------------------------------- /densephrases/demo/static/examples.txt: -------------------------------------------------------------------------------- 1 | who determines the size of the supreme court 2 | who won series 7 of great british bake off 3 | when does the new wheel of fortune season start 4 | when does season 3 of lucifer come out 5 | what kind of currency is used in new zealand 6 | who is the highest paid nba player in 2016 7 | who is el senor de los cielos based on 8 | who is paige on days of our lives 9 | who is the creator of star vs the forces of evil 10 | total number of articles in indian constitution at present 11 | who plays percy in the lost city of z 12 | what was uncle jesse's original last name on full house 13 | how many goals scored ronaldo in his career 14 | who plays male lead in far from the madding crowd 15 | when was a whiter shade of pale recorded 16 | when did medicare begin in the united states 17 | who sings don't stand so close to me 18 | where was war on the planet of the apes filmed 19 | who wrote love so soft by kelly clarkson 20 | who is the longest serving manager in man united 21 | Who is the fourth president of USA? 22 | the seventh president of USA 23 | What is South Korea known for? 24 | What tends to lead to more money? 25 | Who was defeated by computer in chess game? 
26 | Name three famous writers 27 | What makes a successful startup? 28 | Why did Oracle sue Google? 29 | Where can you find water in desert? 30 | What does AMI stand for? 31 | How heavy was the apollo 11? 32 | What is water consisted of? 33 | What makes a man great? 34 | Which city is famous for coffee? 35 | On which date was Genghis Khan's palace rediscovered by archeaologists? 36 | What is another term for x-ray imaging? 37 | Who scolded Luther about his rudeness? 38 | What was the Yuan's paper money called? 39 | -------------------------------------------------------------------------------- /densephrases/demo/static/files/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/demo/static/files/favicon.ico -------------------------------------------------------------------------------- /densephrases/demo/static/files/overview_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/demo/static/files/overview_new.png -------------------------------------------------------------------------------- /densephrases/demo/static/files/plogo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/demo/static/files/plogo.png -------------------------------------------------------------------------------- /densephrases/demo/static/files/preview-new.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/demo/static/files/preview-new.gif -------------------------------------------------------------------------------- /densephrases/demo/static/files/steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/demo/static/files/steps.png -------------------------------------------------------------------------------- /densephrases/demo/static/files/style.css: -------------------------------------------------------------------------------- 1 | html { position: relative; min-height: 100%; } 2 | body { margin-bottom: 60px; font-family: Verdana, sans-serif;} 3 | .footer { position: absolute; bottom: 0; width: 100%; height: 40px; line-height: 15px; background-color: #f5f5f5; padding-top: 5px; font-size: 12px; text-align: center;} 4 | label, footer { user-select: none; } 5 | .list-group-item:first-of-type { background-color: #BEE6FF; color: #000000; } 6 | .score { position:absolute; bottom:0; right:15px;} 7 | 8 | .list-group-mine .list-group-item { 9 | background-color: #DFDFDF; 10 | border-left-color: #fff; 11 | border-right-color: #fff; 12 | } 13 | .list-group-mine .list-group-item:first-child { 14 | display:none; 15 | } 16 | 17 | .paper_title { 18 | margin-top: 15px; 19 | margin-left: auto; 20 | margin-right: auto; 21 | margin-bottom: auto; 22 | width: 70%; 23 | text-align: center; 24 | } 25 | .detail { 26 | margin: auto; 27 | width: 50%; 28 | } 29 | .detail2 { 30 | margin-top: 8px; 31 | margin-left: auto; 32 | margin-right: auto; 33 | margin-bottom: auto; 34 | width: 50%; 35 | 
} 36 | .card { 37 | margin-top: -15px; 38 | } 39 | -------------------------------------------------------------------------------- /densephrases/demo/static/index.html: --------------------------------------------------------------------------------
DensePhrases Demo
Project by Jinhyuk Lee, Mujeen Sung, Alexander Wettig, Jaewoo Kang, and Danqi Chen (Korea University, Princeton University)
From 5 million Wikipedia articles, DensePhrases searches for phrase-level answers to your questions or retrieves relevant passages in real time. More details are in our ACL'21 paper and EMNLP'21 paper.
You can type in any natural language question below and get the results in real time. Retrieved phrases are denoted in boldface for each passage. The current model is case-sensitive, and the best results are obtained when queries have proper letter cases (e.g., "Name Apple's products", not "name apple's products"). Our current demo has the following specs:
• English Wikipedia (2018.12.20)
-------------------------------------------------------------------------------- /densephrases/demo/static/index_single.html: --------------------------------------------------------------------------------
73 | 74 | 82 | 83 | 84 | 202 | 203 | 204 | 205 | 206 | -------------------------------------------------------------------------------- /densephrases/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import numpy as np 4 | import os 5 | 6 | from densephrases import Options 7 | from densephrases.utils.single_utils import load_encoder 8 | from densephrases.utils.open_utils import load_phrase_index, get_query2vec, load_qa_pairs 9 | from densephrases.utils.squad_utils import TrueCaser 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class DensePhrases(object): 15 | def __init__(self, 16 | load_dir, 17 | dump_dir, 18 | index_name='start/1048576_flat_OPQ96', 19 | device='cuda', 20 | verbose=False, 21 | **kwargs): 22 | print("This could take up to 15 mins depending on the file reading speed of HDD/SSD") 23 | 24 | # Turn off loggers 25 | if not verbose: 26 | logging.getLogger("densephrases").setLevel(logging.WARNING) 27 | logging.getLogger("transformers").setLevel(logging.WARNING) 28 | 29 | # Get default options 30 | options = Options() 31 | options.add_model_options() 32 | options.add_index_options() 33 | options.add_retrieval_options() 34 | options.add_data_options() 35 | self.args = options.parse() 36 | 37 | # Set options 38 | self.args.load_dir = load_dir 39 | self.args.dump_dir = dump_dir 40 | self.args.cache_dir = os.environ['CACHE_DIR'] 41 | self.args.index_name = index_name 42 | self.args.cuda = True if device == 'cuda' else False 43 | self.args.__dict__.update(kwargs) 44 | 45 | # Load encoder 46 | self.set_encoder(load_dir, device) 47 | 48 | # Load MIPS 49 | self.mips = load_phrase_index(self.args, ignore_logging=not verbose) 50 | 51 | # Others 52 | self.truecase = TrueCaser(os.path.join(os.environ['DATA_DIR'], self.args.truecase_path)) 53 | print("Loading DensePhrases Completed!") 54 | 55 | def search(self, query='', retrieval_unit='phrase', top_k=10, truecase=True, return_meta=False): 56 | # If query is str, single query 57 | single_query = False 58 | if type(query) == str: 59 | batch_query = [query] 60 | single_query = True 61 | else: 62 | assert type(query) == list 63 | batch_query = query 64 | 65 | # Pre-processing 66 | if truecase: 67 | query = [self.truecase.get_true_case(query) if query == query.lower() else query for query in batch_query] 68 | 69 | # Get question vector 70 | outs = self.query2vec(batch_query) 71 | start = np.concatenate([out[0] for out in outs], 0) 72 | end = np.concatenate([out[1] for out in outs], 0) 73 | query_vec = np.concatenate([start, end], 1) 74 | 75 | # Search 76 | agg_strats = {'phrase': 'opt1', 'sentence': 'opt2', 'paragraph': 'opt2', 'document': 'opt3'} 77 | if retrieval_unit not in agg_strats: 78 | raise NotImplementedError(f'"{retrieval_unit}" not supported. 
Choose one of {agg_strats.keys()}.') 79 | search_top_k = top_k 80 | if retrieval_unit in ['sentence', 'paragraph', 'document']: 81 | search_top_k *= 2 82 | rets = self.mips.search( 83 | query_vec, q_texts=batch_query, nprobe=256, 84 | top_k=search_top_k, max_answer_length=10, 85 | return_idxs=False, aggregate=True, agg_strat=agg_strats[retrieval_unit], 86 | return_sent=True if retrieval_unit == 'sentence' else False 87 | ) 88 | 89 | # Gather results 90 | rets = [ret[:top_k] for ret in rets] 91 | if retrieval_unit == 'phrase': 92 | retrieved = [[rr['answer'] for rr in ret][:top_k] for ret in rets] 93 | elif retrieval_unit == 'sentence': 94 | retrieved = [[rr['context'] for rr in ret][:top_k] for ret in rets] 95 | elif retrieval_unit == 'paragraph': 96 | retrieved = [[rr['context'] for rr in ret][:top_k] for ret in rets] 97 | elif retrieval_unit == 'document': 98 | retrieved = [[rr['title'][0] for rr in ret][:top_k] for ret in rets] 99 | else: 100 | raise NotImplementedError() 101 | 102 | if single_query: 103 | rets = rets[0] 104 | retrieved = retrieved[0] 105 | 106 | if return_meta: 107 | return retrieved, rets 108 | else: 109 | return retrieved 110 | 111 | def set_encoder(self, load_dir, device='cuda'): 112 | self.args.load_dir = load_dir 113 | self.model, self.tokenizer, self.config = load_encoder(device, self.args) 114 | self.query2vec = get_query2vec( 115 | query_encoder=self.model, tokenizer=self.tokenizer, args=self.args, batch_size=64 116 | ) 117 | 118 | def evaluate(self, test_path, **kwargs): 119 | from eval_phrase_retrieval import evaluate as evaluate_fn 120 | 121 | # Set new arguments 122 | new_args = copy.deepcopy(self.args) 123 | new_args.test_path = test_path 124 | new_args.truecase = True 125 | new_args.__dict__.update(kwargs) 126 | 127 | # Run with new_arg 128 | evaluate_fn(new_args, self.mips, self.model, self.tokenizer) 129 | -------------------------------------------------------------------------------- /densephrases/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/utils/__init__.py -------------------------------------------------------------------------------- /densephrases/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import ujson as json 3 | import re 4 | import string 5 | import unicodedata 6 | import pickle 7 | from collections import Counter 8 | 9 | def normalize_answer(s): 10 | 11 | def remove_articles(text): 12 | return re.sub(r'\b(a|an|the)\b', ' ', text) 13 | 14 | def white_space_fix(text): 15 | return ' '.join(text.split()) 16 | 17 | def remove_punc(text): 18 | exclude = set(string.punctuation) 19 | return ''.join(ch for ch in text if ch not in exclude) 20 | 21 | def lower(text): 22 | return text.lower() 23 | 24 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 25 | 26 | 27 | def f1_score(prediction, ground_truth): 28 | normalized_prediction = normalize_answer(prediction) 29 | normalized_ground_truth = normalize_answer(ground_truth) 30 | 31 | ZERO_METRIC = (0, 0, 0) 32 | 33 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 34 | return ZERO_METRIC 35 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 36 | return ZERO_METRIC 37 | 38 | prediction_tokens = normalized_prediction.split() 39 | 
ground_truth_tokens = normalized_ground_truth.split() 40 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 41 | num_same = sum(common.values()) 42 | if num_same == 0: 43 | return ZERO_METRIC 44 | precision = 1.0 * num_same / len(prediction_tokens) 45 | recall = 1.0 * num_same / len(ground_truth_tokens) 46 | f1 = (2 * precision * recall) / (precision + recall) 47 | return f1, precision, recall 48 | 49 | 50 | def exact_match_score(prediction, ground_truth): 51 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 52 | 53 | 54 | def drqa_normalize(text): 55 | """Resolve different type of unicode encodings.""" 56 | return unicodedata.normalize('NFD', text) 57 | 58 | 59 | def drqa_exact_match_score(prediction, ground_truth): 60 | """Check if the prediction is a (soft) exact match with the ground truth.""" 61 | return normalize_answer(prediction) == normalize_answer(ground_truth) 62 | 63 | 64 | def drqa_regex_match_score(prediction, pattern): 65 | """Check if the prediction matches the given regular expression.""" 66 | try: 67 | compiled = re.compile( 68 | pattern, 69 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE 70 | ) 71 | except BaseException as e: 72 | # logger.warn('Regular expression failed to compile: %s' % pattern) 73 | # print('re failed to compile: [%s] due to [%s]' % (pattern, e)) 74 | return False 75 | return compiled.match(prediction) is not None 76 | 77 | 78 | def drqa_metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 79 | """Given a prediction and multiple valid answers, return the score of 80 | the best prediction-answer_n pair given a metric function. 81 | """ 82 | scores_for_ground_truths = [] 83 | for ground_truth in ground_truths: 84 | score = metric_fn(prediction, ground_truth) 85 | scores_for_ground_truths.append(score) 86 | return max(scores_for_ground_truths) 87 | 88 | 89 | def update_answer(metrics, prediction, gold): 90 | em = exact_match_score(prediction, gold) 91 | f1, prec, recall = f1_score(prediction, gold) 92 | metrics['em'] += em 93 | metrics['f1'] += f1 94 | metrics['prec'] += prec 95 | metrics['recall'] += recall 96 | return em, prec, recall 97 | 98 | 99 | def update_sp(metrics, prediction, gold): 100 | cur_sp_pred = set(map(tuple, prediction)) 101 | gold_sp_pred = set(map(tuple, gold)) 102 | tp, fp, fn = 0, 0, 0 103 | for e in cur_sp_pred: 104 | if e in gold_sp_pred: 105 | tp += 1 106 | else: 107 | fp += 1 108 | for e in gold_sp_pred: 109 | if e not in cur_sp_pred: 110 | fn += 1 111 | prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 112 | recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 113 | f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 114 | em = 1.0 if fp + fn == 0 else 0.0 115 | metrics['sp_em'] += em 116 | metrics['sp_f1'] += f1 117 | metrics['sp_prec'] += prec 118 | metrics['sp_recall'] += recall 119 | return em, prec, recall 120 | 121 | 122 | def eval(prediction_file, gold_file): 123 | with open(prediction_file) as f: 124 | prediction = json.load(f) 125 | with open(gold_file) as f: 126 | gold = json.load(f) 127 | 128 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 129 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0, 130 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0} 131 | 132 | for dp in gold: 133 | cur_id = dp['_id'] 134 | em, prec, recall = update_answer( 135 | metrics, prediction['answer'][cur_id], dp['answer']) 136 | 137 | N = len(gold) 138 | for k in metrics.keys(): 139 | metrics[k] /= N 140 | 141 | print(metrics) 142 | 
143 | 144 | def analyze(prediction_file, gold_file): 145 | with open(prediction_file) as f: 146 | prediction = json.load(f) 147 | with open(gold_file) as f: 148 | gold = json.load(f) 149 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 150 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0, 151 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0} 152 | 153 | for dp in gold: 154 | cur_id = dp['_id'] 155 | 156 | em, prec, recall = update_answer( 157 | metrics, prediction['answer'][cur_id], dp['answer']) 158 | if (prec + recall == 0): 159 | f1 = 0 160 | else: 161 | f1 = 2 * prec * recall / (prec+recall) 162 | 163 | print (dp['answer'], prediction['answer'][cur_id]) 164 | print (f1, em) 165 | a = input() 166 | 167 | 168 | if __name__ == '__main__': 169 | #eval(sys.argv[1], sys.argv[2]) 170 | analyze(sys.argv[1], sys.argv[2]) 171 | -------------------------------------------------------------------------------- /densephrases/utils/kilt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/utils/kilt/__init__.py -------------------------------------------------------------------------------- /densephrases/utils/kilt/kilt_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import nltk 9 | import json 10 | import os 11 | import logging 12 | import sys 13 | import time 14 | import string 15 | import random 16 | 17 | 18 | def normalize_answer(s): 19 | """Lower text and remove punctuation, articles and extra whitespace.""" 20 | 21 | def remove_punc(text): 22 | exclude = set(string.punctuation) 23 | return "".join(ch for ch in text if ch not in exclude) 24 | 25 | def lower(text): 26 | return text.lower() 27 | 28 | return remove_punc(lower(s)) 29 | 30 | 31 | def validate_datapoint(datapoint, logger): 32 | 33 | # input is a string 34 | if not isinstance(datapoint["input"], str): 35 | if logger: 36 | logger.warning( 37 | "[{}] input is not a string {}".format( 38 | datapoint["id"], datapoint["input"] 39 | ) 40 | ) 41 | return False 42 | 43 | # output is not empty 44 | if "output" in datapoint: 45 | if len(datapoint["output"]) == 0: 46 | if logger: 47 | logger.warning("[{}] empty output".format(datapoint["id"])) 48 | return False 49 | 50 | for output in datapoint["output"]: 51 | # answer is a string 52 | if "answer" in output: 53 | if not isinstance(output["answer"], str): 54 | if logger: 55 | logger.warning( 56 | "[{}] answer is not a string {}".format( 57 | datapoint["id"], output["answer"] 58 | ) 59 | ) 60 | return False 61 | 62 | # provenance is not empty 63 | # if len(output["provenance"]) == 0: 64 | # if logger: 65 | # logger.warning("[{}] empty provenance".format(datapoint["id"])) 66 | # return False 67 | 68 | if "provenance" in output: 69 | for provenance in output["provenance"]: 70 | # wikipedia_id is provided 71 | if not isinstance(provenance["wikipedia_id"], str): 72 | if logger: 73 | logger.warning( 74 | "[{}] wikipedia_id is not a string {}".format( 75 | datapoint["id"], provenance["wikipedia_id"] 76 | ) 77 | ) 78 | return False 79 | 80 | # title is provided 81 | if not isinstance(provenance["title"], str): 82 | if logger: 83 | logger.warning( 84 | "[{}] 
title is not a string {}".format( 85 | datapoint["id"], provenance["title"] 86 | ) 87 | ) 88 | return False 89 | 90 | return True 91 | 92 | 93 | def load_data(filename): 94 | data = [] 95 | with open(filename, "r") as fin: 96 | lines = fin.readlines() 97 | for line in lines: 98 | data.append(json.loads(line)) 99 | return data 100 | 101 | 102 | def store_data(filename, data): 103 | with open(filename, "w+") as outfile: 104 | for idx, element in enumerate(data): 105 | # print(round(idx * 100 / len(data), 2), "%", end="\r") 106 | # sys.stdout.flush() 107 | json.dump(element, outfile) 108 | outfile.write("\n") 109 | 110 | 111 | def get_bleu(candidate_tokens, gold_tokens): 112 | 113 | candidate_tokens = [x for x in candidate_tokens if len(x.strip()) > 0] 114 | gold_tokens = [x for x in gold_tokens if len(x.strip()) > 0] 115 | 116 | # The default BLEU calculates a score for up to 117 | # 4-grams using uniform weights (this is called BLEU-4) 118 | weights = (0.25, 0.25, 0.25, 0.25) 119 | 120 | if len(gold_tokens) < 4: 121 | # lower order ngrams 122 | weights = [1.0 / len(gold_tokens) for _ in range(len(gold_tokens))] 123 | 124 | BLEUscore = nltk.translate.bleu_score.sentence_bleu( 125 | [candidate_tokens], gold_tokens, weights=weights 126 | ) 127 | return BLEUscore 128 | 129 | 130 | # split a list in num parts evenly 131 | def chunk_it(seq, num): 132 | assert num > 0 133 | chunk_len = len(seq) // num 134 | chunks = [seq[i * chunk_len : i * chunk_len + chunk_len] for i in range(num)] 135 | 136 | diff = len(seq) - chunk_len * num # 0 <= diff < num 137 | for i in range(diff): 138 | chunks[i].append(seq[chunk_len * num + i]) 139 | 140 | return chunks 141 | 142 | 143 | def init_logging(base_logdir, modelname, logger=None): 144 | 145 | # logging format 146 | # "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 147 | formatter = logging.Formatter( 148 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 149 | ) 150 | 151 | log_directory = "{}/{}/".format(base_logdir, modelname) 152 | 153 | if logger == None: 154 | logger = logging.getLogger("KILT") 155 | 156 | logger.setLevel(logging.DEBUG) 157 | 158 | # console handler 159 | ch = logging.StreamHandler(sys.stdout) 160 | ch.setLevel(logging.DEBUG) 161 | ch.setFormatter(formatter) 162 | 163 | logger.addHandler(ch) 164 | 165 | else: 166 | # remove previous file handler 167 | logger.handlers.pop() 168 | 169 | os.makedirs(log_directory, exist_ok=True) 170 | 171 | # file handler 172 | fh = logging.FileHandler(str(log_directory) + "/info.log") 173 | fh.setLevel(logging.DEBUG) 174 | fh.setFormatter(formatter) 175 | 176 | logger.addHandler(fh) 177 | 178 | logger.propagate = False 179 | logger.info("logging in {}".format(log_directory)) 180 | return logger 181 | 182 | 183 | def create_logdir_with_timestamp(base_logdir): 184 | timestr = time.strftime("%Y%m%d_%H%M%S") 185 | # create new directory 186 | log_directory = "{}/{}_{}/".format(base_logdir, timestr, random.randint(0, 1000)) 187 | os.makedirs(log_directory) 188 | return log_directory -------------------------------------------------------------------------------- /densephrases/utils/open_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | import json 5 | import torch 6 | import numpy as np 7 | 8 | from densephrases import MIPS 9 | from densephrases.utils.single_utils import backward_compat 10 | from densephrases.utils.squad_utils import get_question_dataloader, TrueCaser 11 | from 
densephrases.utils.embed_utils import get_question_results 12 | 13 | from transformers import ( 14 | MODEL_MAPPING, 15 | AutoConfig, 16 | AutoTokenizer, 17 | AutoModel, 18 | ) 19 | 20 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', 21 | level=logging.INFO) 22 | logger = logging.getLogger(__name__) 23 | truecase = None 24 | 25 | 26 | def load_phrase_index(args, ignore_logging=False): 27 | # Configure paths for index serving 28 | phrase_dump_dir = os.path.join(args.dump_dir, args.phrase_dir) 29 | index_dir = os.path.join(args.dump_dir, args.index_name) 30 | index_path = os.path.join(index_dir, args.index_path) 31 | idx2id_path = os.path.join(index_dir, args.idx2id_path) 32 | 33 | # Load mips 34 | if 'aggregate' in args.__dict__.keys(): 35 | logger.info(f'Aggregate: {args.aggregate}') 36 | mips = MIPS( 37 | phrase_dump_dir=phrase_dump_dir, 38 | index_path=index_path, 39 | idx2id_path=idx2id_path, 40 | cuda=args.cuda, 41 | logging_level=logging.WARNING if ignore_logging else (logging.DEBUG if args.verbose_logging else logging.INFO), 42 | ) 43 | return mips 44 | 45 | 46 | def load_cross_encoder(device, args): 47 | 48 | # Configure paths for cross-encoder serving 49 | cross_encoder = torch.load( 50 | os.path.join(args.load_dir, "pytorch_model.bin"), map_location=torch.device('cpu') 51 | ) 52 | new_qd = {n[len('bert')+1:]: p for n, p in cross_encoder.items() if 'bert' in n} 53 | new_linear = {n[len('qa_outputs')+1:]: p for n, p in cross_encoder.items() if 'qa_outputs' in n} 54 | config, unused_kwargs = AutoConfig.from_pretrained( 55 | args.pretrained_name_or_path, 56 | cache_dir=args.cache_dir if args.cache_dir else None, 57 | return_unused_kwargs=True 58 | ) 59 | tokenizer = AutoTokenizer.from_pretrained( 60 | args.tokenizer_name if args.tokenizer_name else args.pretrained_name_or_path, 61 | do_lower_case=args.do_lower_case, 62 | cache_dir=args.cache_dir if args.cache_dir else None, 63 | ) 64 | model = AutoModel.from_pretrained( 65 | args.pretrained_name_or_path, 66 | from_tf=bool(".ckpt" in args.pretrained_name_or_path), 67 | config=config, 68 | cache_dir=args.cache_dir if args.cache_dir else None, 69 | ) 70 | model.load_state_dict(new_qd) 71 | qa_outputs = torch.nn.Linear(config.hidden_size, 2) 72 | qa_outputs.load_state_dict(new_linear) 73 | ce_model = torch.nn.ModuleList( 74 | [model, qa_outputs] 75 | ) 76 | ce_model.to(device) 77 | 78 | logger.info(f'CrossEncoder loaded from {args.load_dir} having {MODEL_MAPPING[config.__class__]}') 79 | logger.info('Number of model parameters: {:,}'.format(sum(p.numel() for p in ce_model.parameters()))) 80 | return ce_model, tokenizer 81 | 82 | 83 | def get_query2vec(query_encoder, tokenizer, args, batch_size=64): 84 | device = 'cuda' if args.cuda else 'cpu' 85 | def query2vec(queries): 86 | question_dataloader, question_examples, query_features = get_question_dataloader( 87 | queries, tokenizer, args.max_query_length, batch_size=batch_size 88 | ) 89 | question_results = get_question_results( 90 | question_examples, query_features, question_dataloader, device, query_encoder, batch_size=batch_size 91 | ) 92 | if args.verbose_logging: 93 | logger.info(f"{len(query_features)} queries: {' '.join(query_features[0].tokens_)}") 94 | outs = [] 95 | for qr_idx, question_result in enumerate(question_results): 96 | out = ( 97 | question_result.start_vec.tolist(), question_result.end_vec.tolist(), query_features[qr_idx].tokens_ 98 | ) 99 | outs.append(out) 100 | return outs 101 | return query2vec 102 | 
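# Note: get_query2vec (above) returns a closure, so the query encoder and tokenizer
# are bound once and reused for every batch of queries. A typical call pattern
# (see DensePhrases.search in densephrases/model.py) looks like:
#   query2vec = get_query2vec(query_encoder=model, tokenizer=tokenizer, args=args, batch_size=64)
#   outs = query2vec(batch_of_question_strings)
#   start = np.concatenate([out[0] for out in outs], 0)  # per-query start vectors
#   end = np.concatenate([out[1] for out in outs], 0)    # per-query end vectors
#   query_vec = np.concatenate([start, end], 1)          # what MIPS.search consumes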
103 | 104 | def load_qa_pairs(data_path, args, q_idx=None, draft_num_examples=100, shuffle=False): 105 | q_ids = [] 106 | questions = [] 107 | answers = [] 108 | titles = [] 109 | data = json.load(open(data_path))['data'] 110 | for data_idx, item in enumerate(data): 111 | if q_idx is not None: 112 | if data_idx != q_idx: 113 | continue 114 | q_id = item['id'] 115 | if 'origin' in item: 116 | q_id = item['origin'].split('.')[0] + '-' + q_id 117 | question = item['question'] 118 | if '[START_ENT]' in question: 119 | question = question[max(question.index('[START_ENT]')-300, 0):question.index('[END_ENT]')+300] 120 | answer = item['answers'] 121 | title = item.get('titles', ['']) 122 | if len(answer) == 0: 123 | continue 124 | q_ids.append(q_id) 125 | questions.append(question) 126 | answers.append(answer) 127 | titles.append(title) 128 | questions = [query[:-1] if query.endswith('?') else query for query in questions] 129 | # questions = [query.lower() for query in questions] # force lower query 130 | 131 | if args.do_lower_case: 132 | logger.info(f'Lowercasing queries') 133 | questions = [query.lower() for query in questions] 134 | 135 | if shuffle: 136 | qa_pairs = list(zip(q_ids, questions, answers, titles)) 137 | random.shuffle(qa_pairs) 138 | q_ids, questions, answers, titles = zip(*qa_pairs) 139 | logger.info(f'Shuffling QA pairs') 140 | 141 | if args.draft: 142 | q_ids = np.array(q_ids)[:draft_num_examples].tolist() 143 | questions = np.array(questions)[:draft_num_examples].tolist() 144 | answers = np.array(answers)[:draft_num_examples].tolist() 145 | titles = np.array(titles)[:draft_num_examples].tolist() 146 | 147 | if args.truecase: 148 | try: 149 | global truecase 150 | if truecase is None: 151 | logger.info('loading truecaser') 152 | truecase = TrueCaser(os.path.join(os.environ['DATA_DIR'], args.truecase_path)) 153 | logger.info('Truecasing queries') 154 | questions = [truecase.get_true_case(query) if query == query.lower() else query for query in questions] 155 | except Exception as e: 156 | print(e) 157 | 158 | logger.info(f'Loading {len(questions)} questions from {data_path}') 159 | logger.info(f'Sample Q ({q_ids[0]}): {questions[0]}, A: {answers[0]}, Title: {titles[0]}') 160 | return q_ids, questions, answers, titles 161 | 162 | -------------------------------------------------------------------------------- /densephrases/utils/single_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import logging 4 | import copy 5 | import os 6 | import numpy as np 7 | 8 | from functools import partial 9 | from transformers import ( 10 | MODEL_MAPPING, 11 | AutoConfig, 12 | AutoTokenizer, 13 | AutoModel, 14 | ) 15 | from densephrases import Encoder 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def set_seed(args): 21 | random.seed(args.seed) 22 | np.random.seed(args.seed) 23 | torch.manual_seed(args.seed) 24 | if torch.cuda.is_available(): 25 | torch.cuda.manual_seed_all(args.seed) 26 | 27 | 28 | def to_list(tensor): 29 | return tensor.detach().cpu().tolist() 30 | 31 | 32 | def to_numpy(tensor): 33 | return tensor.detach().cpu().numpy() 34 | 35 | 36 | def backward_compat(model_dict): 37 | # Remove teacher 38 | model_dict = {key: val for key, val in model_dict.items() if not key.startswith('cross_encoder')} 39 | model_dict = {key: val for key, val in model_dict.items() if not key.startswith('bert_qd')} 40 | model_dict = {key: val for key, val in model_dict.items() if not 
key.startswith('qa_outputs')} 41 | 42 | # Replace old names to current ones 43 | mapping = { 44 | 'bert_start': 'phrase_encoder', 45 | 'bert_q_start': 'query_start_encoder', 46 | 'bert_q_end': 'query_end_encoder', 47 | } 48 | new_model_dict = {} 49 | for key, val in model_dict.items(): 50 | for old_key, new_key in mapping.items(): 51 | if key.startswith(old_key): 52 | new_model_dict[key.replace(old_key, new_key)] = val 53 | elif all(not key.startswith(old_k) for old_k in mapping.keys()): 54 | new_model_dict[key] = val 55 | 56 | return new_model_dict 57 | 58 | 59 | def load_encoder(device, args, phrase_only=False): 60 | # Configure paths for DnesePhrases 61 | args.model_type = args.model_type.lower() 62 | config = AutoConfig.from_pretrained( 63 | args.config_name if args.config_name else args.pretrained_name_or_path, 64 | cache_dir=args.cache_dir if args.cache_dir else None, 65 | ) 66 | tokenizer = AutoTokenizer.from_pretrained( 67 | args.tokenizer_name if args.tokenizer_name else args.pretrained_name_or_path, 68 | do_lower_case=args.do_lower_case, 69 | cache_dir=args.cache_dir if args.cache_dir else None, 70 | ) 71 | 72 | # Prepare PLM if not load_dir 73 | pretrained = None 74 | if not args.load_dir: 75 | pretrained = AutoModel.from_pretrained( 76 | args.pretrained_name_or_path, 77 | config=config, 78 | cache_dir=args.cache_dir if args.cache_dir else None, 79 | ) 80 | load_class = Encoder 81 | logger.info(f'DensePhrases encoder initialized with {args.pretrained_name_or_path} ({pretrained.__class__})') 82 | else: 83 | # TODO: need to update transformers so that from_pretrained maps to model hub directly 84 | if args.load_dir.startswith('princeton-nlp'): 85 | hf_model_path = f"https://huggingface.co/{args.load_dir}/resolve/main/pytorch_model.bin" 86 | else: 87 | hf_model_path = args.load_dir 88 | load_class = partial( 89 | Encoder.from_pretrained, 90 | pretrained_model_name_or_path=hf_model_path, 91 | cache_dir=args.cache_dir if args.cache_dir else None, 92 | ) 93 | logger.info(f'DensePhrases encoder loaded from {args.load_dir}') 94 | 95 | # DensePhrases encoder object 96 | model = load_class( 97 | config=config, 98 | tokenizer=tokenizer, 99 | transformer_cls=MODEL_MAPPING[config.__class__], 100 | pretrained=copy.deepcopy(pretrained) if pretrained is not None else None, 101 | lambda_kl=getattr(args, 'lambda_kl', 0.0), 102 | lambda_neg=getattr(args, 'lambda_neg', 0.0), 103 | lambda_flt=getattr(args, 'lambda_flt', 0.0), 104 | ) 105 | 106 | # Phrase only (for phrase embedding) 107 | if phrase_only: 108 | if hasattr(model, "module"): 109 | del model.module.query_start_encoder 110 | del model.module.query_end_encoder 111 | else: 112 | del model.query_start_encoder 113 | del model.query_end_encoder 114 | logger.info("Load only phrase encoders for embedding phrases") 115 | 116 | model.to(device) 117 | logger.info('Number of model parameters: {:,}'.format(sum(p.numel() for p in model.parameters()))) 118 | return model, tokenizer, config 119 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | while read -p "Choose a resource to download [data/wiki/models/index]: " choice; do 4 | case "$choice" in 5 | data ) 6 | TARGET=$choice 7 | TARGET_DIR=$DATA_DIR 8 | break ;; 9 | wiki ) 10 | TARGET=$choice 11 | TARGET_DIR=$DATA_DIR 12 | break ;; 13 | models ) 14 | TARGET=$choice 15 | TARGET_DIR=$SAVE_DIR 16 | break ;; 17 | index ) 18 | TARGET=$choice 19 | 
TARGET_DIR=$SAVE_DIR 20 | break ;; 21 | * ) echo "Please type among [data/wiki/models/index]"; 22 | exit 0 ;; 23 | esac 24 | done 25 | 26 | echo "$TARGET will be downloaded at $TARGET_DIR" 27 | 28 | # Download + untar + rm 29 | case "$TARGET" in 30 | data ) 31 | wget -O "$TARGET_DIR/densephrases-data.tar.gz" "https://nlp.cs.princeton.edu/projects/densephrases/densephrases-data.tar.gz" 32 | tar -xzvf "$TARGET_DIR/densephrases-data.tar.gz" -C "$TARGET_DIR" --strip 1 33 | rm "$TARGET_DIR/densephrases-data.tar.gz" ;; 34 | wiki ) 35 | wget -O "$TARGET_DIR/wikidump.tar.gz" "https://nlp.cs.princeton.edu/projects/densephrases/wikidump.tar.gz" 36 | tar -xzvf "$TARGET_DIR/wikidump.tar.gz" -C "$TARGET_DIR" 37 | rm "$TARGET_DIR/wikidump.tar.gz" ;; 38 | models ) 39 | wget -O "$TARGET_DIR/outputs.tar.gz" "https://nlp.cs.princeton.edu/projects/densephrases/outputs.tar.gz" 40 | tar -xzvf "$TARGET_DIR/outputs.tar.gz" -C "$TARGET_DIR" --strip 1 41 | rm "$TARGET_DIR/outputs.tar.gz" ;; 42 | index ) 43 | wget -O "$TARGET_DIR/densephrases-multi_wiki-20181220.tar.gz" "https://nlp.cs.princeton.edu/projects/densephrases/densephrases-multi_wiki-20181220.tar.gz" 44 | tar -xzvf "$TARGET_DIR/densephrases-multi_wiki-20181220.tar.gz" -C "$TARGET_DIR" 45 | rm "$TARGET_DIR/densephrases-multi_wiki-20181220.tar.gz" ;; 46 | * ) echo "Wrong target $TARGET"; 47 | exit 0 ;; 48 | esac 49 | 50 | echo "Downloading $TARGET done!" 51 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # DensePhrases Examples 2 | 3 | We provide descriptions on how to use DensePhrases for different applications. 4 | For instance, based on the retrieved passages from DensePhrases, you can train a state-of-the-art open-domain question answering model called [Fusion-in-Decoder](https://arxiv.org/abs/2007.01282) by Izacard and Grave, 2021, or you can run entity linking with DensePhrases. 5 | 6 | * [Basics: Multi-Granularity Text Retrieval](#basics-multi-granularity-text-retrieval) 7 | * [Create a Custom Phrase Index](https://github.com/princeton-nlp/DensePhrases/tree/main/examples/create-custom-index) 8 | * [Open-Domain QA with Fusion-in-Decoder](https://github.com/princeton-nlp/DensePhrases/tree/main/examples/fusion-in-decoder) 9 | * [Entity Linking](https://github.com/princeton-nlp/DensePhrases/tree/main/examples/entity-linking) 10 | * [Knowledge-grounded Dialogue](https://github.com/princeton-nlp/DensePhrases/tree/main/examples/knowledge-dialogue) 11 | * [Slot Filling](https://github.com/princeton-nlp/DensePhrases/tree/main/examples/slot-filling) 12 | 13 | ## Basics: Multi-Granularity Text Retrieval 14 | The most basic use of DensePhrases is to retrieve phrases, sentences, paragraphs, or documents for your query. 15 | ```python 16 | >>> from densephrases import DensePhrases 17 | 18 | # Load DensePhrases 19 | >>> model = DensePhrases( 20 | ... load_dir='princeton-nlp/densephrases-multi-query-multi', 21 | ... dump_dir='/path/to/densephrases-multi_wiki-20181220/dump' 22 | ... 
) 23 | 24 | # Search phrases 25 | >>> model.search('Who won the Nobel Prize in peace?', retrieval_unit='phrase', top_k=5) 26 | ['Denis Mukwege,', 'Theodore Roosevelt', 'Denis Mukwege', 'John Mott', 'Mother Teresa'] 27 | 28 | # Search sentences 29 | >>> model.search('Why is the sky blue', retrieval_unit='sentence', top_k=1) 30 | ['The blue color is sometimes wrongly attributed to Rayleigh scattering, which is responsible for the color of the sky.'] 31 | 32 | # Search paragraphs 33 | >>> model.search('How to become a great researcher', retrieval_unit='paragraph', top_k=1) 34 | ['... Levine said he believes the key to being a great researcher is having passion for research in and working on questions that the researcher is truly curious about. He said: "Have patience, persistence and enthusiasm and you’ll be fine."'] 35 | 36 | # Search documents (Wikipedia titles) 37 | >>> model.search('What is the history of internet', retrieval_unit='document', top_k=3) 38 | ['Computer network', 'History of the World Wide Web', 'History of the Internet'] 39 | ``` 40 | 41 | For batch queries, simply feed a list of queries as ``query``. 42 | To get more detailed search results, set ``return_meta=True`` as follows: 43 | ```python 44 | # Search phrases and get detailed results 45 | >>> phrases, metadata = model.search(['Who won the Nobel Prize in peace?', 'Name products of Apple.'], retrieval_unit='phrase', return_meta=True) 46 | 47 | >>> phrases[0] 48 | ['Denis Mukwege,', 'Theodore Roosevelt', 'Denis Mukwege', 'John Mott', 'Muhammad Yunus', ...] 49 | 50 | >>> metadata[0] 51 | [{'context': '... The most recent as of 2018, Denis Mukwege, was awarded his Peace Prize in 2018. ...', 'title': ['List of black Nobel laureates'], 'doc_idx': 5433697, 'start_pos': 558, 'end_pos': 572, 'start_idx': 15, 'end_idx': 16, 'score': 99.670166015625, ..., 'answer': 'Denis Mukwege,'}, ...] 52 | ``` 53 | Note that when the model returns phrases, it also returns passages in its metadata as described in our [EMNLP paper](https://arxiv.org/abs/2109.08133).
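For instance, a minimal sketch that pairs each retrieved phrase with its supporting passage could look like the following; it only uses the metadata fields shown in the example output above (`score`, `title`, `context`) and assumes `model` is the DensePhrases instance loaded earlier:
```python
# Minimal sketch: print each retrieved phrase together with the passage that supports it.
# Uses only the metadata keys shown above ('score', 'title', 'context').
phrases, metadata = model.search(
    ['Who won the Nobel Prize in peace?', 'Name products of Apple.'],
    retrieval_unit='phrase',
    return_meta=True,
)
for query_phrases, query_meta in zip(phrases, metadata):
    for phrase, meta in zip(query_phrases, query_meta):
        print(f"{phrase}\t(score={meta['score']:.1f}, doc={meta['title'][0]})")
        print(f"  passage: {meta['context'][:120]}...")
```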
54 | 55 | ### CPU-only Mode 56 | ```python 57 | # Load DensePhrases in CPU-only mode 58 | >>> model = DensePhrases( 59 | ... load_dir='princeton-nlp/densephrases-multi-query-multi', 60 | ... dump_dir='/path/to/densephrases-multi_wiki-20181220/dump', 61 | ... device='cpu', 62 | ... max_query_length=24, # reduce the maximum query length for a faster query encoding (optional) 63 | ... ) 64 | ``` 65 | 66 | ### Changing the Index or the Encoder 67 | ```python 68 | # Load DensePhrases with a smaller phrase index 69 | >>> model = DensePhrases( 70 | ... load_dir='princeton-nlp/densephrases-multi-query-multi', 71 | ... dump_dir='/path/to/densephrases-multi_wiki-20181220/dump', 72 | ... index_name='start/1048576_flat_OPQ96_small' 73 | ... ) 74 | 75 | # Change the DensePhrases encoder to 'princeton-nlp/densephrases-multi-query-tqa' (trained on TriviaQA) 76 | >>> model.set_encoder('princeton-nlp/densephrases-multi-query-tqa') 77 | ``` 78 | 79 | ### Evaluation 80 | ```python 81 | >>> import os 82 | 83 | # Evaluate loaded DensePhrases on Natural Questions 84 | >>> model.evaluate(test_path=os.path.join(os.environ['DATA_DIR'], 'open-qa/nq-open/test_preprocessed.json')) 85 | ``` 86 | -------------------------------------------------------------------------------- /examples/create-custom-index/README.md: -------------------------------------------------------------------------------- 1 | # Creating a Custom Phrase Index with DensePhrases 2 | 3 | Basically, DensePhrases uses a text corpus pre-processed in the following format (a snippet from [articles.json](https://github.com/princeton-nlp/DensePhrases/blob/main/examples/create-custom-index/articles.json)): 4 | ``` 5 | { 6 | "data": [ 7 | { 8 | "title": "America's Got Talent (season 4)", 9 | "paragraphs": [ 10 | { 11 | "context": " The fourth season of \"America's Got Talent\", ... Country singer Kevin Skinner was named the winner on September 16, 2009 ..." 12 | }, 13 | { 14 | "context": " Season four was Hasselhoff's final season as a judge. This season started broadcasting live on August 4, 2009. ..." 15 | }, 16 | ... 17 | ] 18 | }, 19 | ] 20 | } 21 | ``` 22 | 23 | ## Building a Phrase Index 24 | Each `context` contains a single natural paragraph of a variable length. The following command creates phrase vectors for the custom corpus (`articles.json`) with the `densephrases-multi` model. 25 | 26 | ```bash 27 | python generate_phrase_vecs.py \ 28 | --model_type bert \ 29 | --pretrained_name_or_path SpanBERT/spanbert-base-cased \ 30 | --data_dir ./ \ 31 | --cache_dir $CACHE_DIR \ 32 | --predict_file examples/create-custom-index/articles.json \ 33 | --do_dump \ 34 | --max_seq_length 512 \ 35 | --doc_stride 500 \ 36 | --fp16 \ 37 | --filter_threshold -2.0 \ 38 | --append_title \ 39 | --load_dir $SAVE_DIR/densephrases-multi \ 40 | --output_dir $SAVE_DIR/densephrases-multi_sample 41 | ``` 42 | The phrase vectors (and their metadata) will be saved under `$SAVE_DIR/densephrases-multi_sample/dump/phrase`. 
Now you need to create a faiss index as follows: 43 | ```bash 44 | python build_phrase_index.py \ 45 | --dump_dir $SAVE_DIR/densephrases-multi_sample/dump \ 46 | --stage all \ 47 | --replace \ 48 | --num_clusters 32 \ 49 | --fine_quant OPQ96 \ 50 | --doc_sample_ratio 1.0 \ 51 | --vec_sample_ratio 1.0 \ 52 | --cuda 53 | 54 | # Compress metadata for faster inference 55 | python scripts/preprocess/compress_metadata.py \ 56 | --input_dump_dir $SAVE_DIR/densephrases-multi_sample/dump/phrase \ 57 | --output_dir $SAVE_DIR/densephrases-multi_sample/dump 58 | ``` 59 | Note that this example uses a very small text corpus and the hyperparameters for `build_phrase_index.py` in a larger scale corpus can be found [here](https://github.com/princeton-nlp/DensePhrases/tree/main#densephrases-training-indexing-and-inference). 60 | Depending on the size of the corpus, the hyperparameters should change as follows: 61 | * `num_clusters`: Set to make the number of vectors per cluster < 2000 (e.g., `--num_culsters 256` works well for `dev_wiki.json`). 62 | * `doc/vec_sample_ratio`: Use the default value (0.2) except for the small scale experiments (shown above). 63 | * `fine_quant`: Currently only OPQ96 is supported. 64 | 65 | The phrase index (with IVFOPQ) will be saved under `$SAVE_DIR/densephrases-multi_sample/dump/start`. 66 | For creating a large-scale phrase index (e.g., Wikipedia), see [dump_phrases.py](https://github.com/princeton-nlp/DensePhrases/blob/main/scripts/parallel/dump_phrases.py) for an example, which is also explained [here](https://github.com/princeton-nlp/DensePhrases/tree/main#2-creating-a-phrase-index). 67 | 68 | ## Testing a Phrase Index 69 | You can use this phrase index to run a [demo](https://github.com/princeton-nlp/DensePhrases/tree/main#playing-with-a-densephrases-demo) or evaluate your set of queries. 70 | For instance, you can feed a set of questions (`questions.json`) to the custom phrase index as follows: 71 | ```bash 72 | python eval_phrase_retrieval.py \ 73 | --run_mode eval \ 74 | --cuda \ 75 | --dump_dir $SAVE_DIR/densephrases-multi_sample/dump \ 76 | --index_name start/32_flat_OPQ96 \ 77 | --load_dir $SAVE_DIR/densephrases-multi \ 78 | --test_path examples/create-custom-index/questions.json \ 79 | --save_pred \ 80 | --truecase 81 | ``` 82 | The prediction file will be saved as `$SAVE_DIR/densephrases-multi/pred/questions_3_top10.pred`, which shows the answer phrases and the passages that contain the phrases: 83 | ``` 84 | { 85 | "1": { 86 | "question": "Who won season 4 of America's got talent", 87 | ... 88 | "prediction": [ 89 | "Kevin Skinner", 90 | ... 91 | ], 92 | "evidence": [ 93 | "The fourth season of \"America's Got Talent\", an American television reality show talent competition, premiered on the NBC network on June 23, 2009. Country singer Kevin Skinner was named the winner on September 16, 2009.", 94 | ... 95 | ], 96 | } 97 | ... 
98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /examples/create-custom-index/questions.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [ 3 | { 4 | "id": "1", 5 | "question": "who won season 4 of america's got talent", 6 | "answers": ["Kevin Skinner", "Country singer Kevin Skinner"] 7 | }, 8 | { 9 | "id": "2", 10 | "question": "how many goals scored ronaldo in 2014-2015 season", 11 | "answers": ["61"] 12 | }, 13 | { 14 | "id": "3", 15 | "question": "who plays william boldwood in far from the madding crowd", 16 | "answers": ["Michael Sheen"] 17 | } 18 | ] 19 | } 20 | -------------------------------------------------------------------------------- /examples/entity-linking/README.md: -------------------------------------------------------------------------------- 1 | # Entity Linking 2 | 3 | ## Pre-trained Models 4 | | Model | Query-FT. & Eval | R-Precision| Description | 5 | |:-------------------------------|:--------:|:--------:|:--------:| 6 | | [densephrases-multi-query-ay2](https://huggingface.co/princeton-nlp/densephrases-multi-query-ay2) | AIDA CoNLL-YAGO (AY2) | 61.6 | Result from [eval.ai](https://eval.ai/web/challenges/challenge-page/689/overview) | 7 | | [densephrases-multi-query-kilt-multi](https://huggingface.co/princeton-nlp/densephrases-multi-query-kilt-multi) | Multiple / AY2 | 68.4 | Trained on multiple KILT tasks | 8 | 9 | ## How to Use 10 | ```python 11 | >>> from densephrases import DensePhrases 12 | 13 | # Load densephraes-multi-query-ay2 14 | >>> model = DensePhrases( 15 | ... load_dir='princeton-nlp/densephrases-multi-query-ay2', 16 | ... dump_dir='/path/to/densephrases-multi_wiki-20181220/dump', 17 | ... ) 18 | 19 | # Entities need to be surrounded by [START_ENT] and [END_ENT] tags 20 | >>> model.search('West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat [START_ENT] Somerset [END_ENT] by an innings and 39 runs', retrieval_unit='document', top_k=1) 21 | ['Somerset County Cricket Club'] 22 | 23 | >>> model.search('[START_ENT] Security Council [END_ENT] members expressed concern on Thursday', retrieval_unit='document', top_k=1) 24 | ['United Nations Security Council'] 25 | ``` 26 | 27 | ### Evaluation 28 | ```python 29 | >>> import os 30 | 31 | # Evaluate loaded DensePhrases on AIDA CoNLL-YAGO (KILT) 32 | >>> model.evaluate( 33 | ... test_path=os.path.join(os.environ['DATA_DIR'], 'kilt/ay2/aidayago2-dev-kilt_open.json'), 34 | ... is_kilt=True, title2wikiid_path=os.path.join(os.environ['DATA_DIR'], 'wikidump/title2wikiid.json'), 35 | ... kilt_gold_path=os.path.join(os.environ['DATA_DIR'], 'kilt/ay2/aidayago2-dev-kilt.jsonl'), agg_strat='opt2', max_query_length=384 36 | ... ) 37 | ``` 38 | 39 | For test accuracy, use `aidayago2-test-kilt_open.json` instead and submit the prediction file (saved as `$SAVE_DIR/densephrases-multi-query-ay2/pred-kilt/*.jsonl`) to [eval.ai](https://eval.ai/web/challenges/challenge-page/689/overview). 40 | For WNED-WIKI and WNED-CWEB, follow the same process with files specified in the `wned-kilt-data` and `cweb-kilt-data` targets in [Makefile](https://github.com/princeton-nlp/DensePhrases/blob/main/Makefile). 41 | You can also evaluate the model with Makefile `eval-index` target by simply chaning the dependency. 
42 | -------------------------------------------------------------------------------- /examples/fusion-in-decoder/README.md: -------------------------------------------------------------------------------- 1 | # Fusion-in-Decoder with DensePhrases 2 | You can use retrieved passages from DensePhrases to build a state-of-the-art open-domain QA system called [Fusion-in-Decoder](https://arxiv.org/abs/2007.01282) (FiD). 3 | Note that DensePhrases (w/o reader) already provides phrase-level answers for end-to-end open-domain QA whose performance is comparable to DPR (w/ BERT reader). This section provides how you can further improve the performance using a generative reader model (T5). 4 | 5 | ## Getting Top Passages from DensePhrases 6 | First, you need to get passages from DensePhrases. 7 | Using DensePhrases-multi, you can retrieve passages for Natural Questions as follows: 8 | ``` 9 | TRAIN_DATA=open-qa/nq-open/train_preprocessed.json 10 | DEV_DATA=open-qa/nq-open/dev_preprocessed.json 11 | TEST_DATA=open-qa/nq-open/test_preprocessed.json 12 | 13 | # Change --test_path accordingly 14 | python eval_phrase_retrieval.py \ 15 | --run_mode eval \ 16 | --model_type bert \ 17 | --pretrained_name_or_path SpanBERT/spanbert-base-cased \ 18 | --cuda \ 19 | --dump_dir $SAVE_DIR/densephrases-multi_wiki-20181220/dump/ \ 20 | --index_name start/1048576_flat_OPQ96 \ 21 | --load_dir $SAVE_DIR/densephrases-multi-query-nq \ 22 | --test_path $DATA_DIR/$TEST_DATA \ 23 | --save_pred \ 24 | --aggregate \ 25 | --agg_strat opt2 \ 26 | --top_k 200 \ 27 | --eval_psg \ 28 | --psg_top_k 100 \ 29 | --truecase 30 | ``` 31 | Since FiD requires training passages, you need to change `--test_path` to `$TRAIN_DATA` or `$DEV_DATA` to get training or development passages, respectively. 32 | Equivalently, you can use `eval-index-psg` in our [Makefile](https://github.com/princeton-nlp/DensePhrases/blob/main/Makefile). 33 | For TriviaQA, simply change the dataset to `tqa-open-data` specified in Makefile. 34 | 35 | After the inference, you will be able to get the following three files used for training and evaluating FiD models: 36 | * train_preprocessed_79168_top200_psg-top100.json 37 | * dev_preprocessed_8757_top200_psg-top100.json 38 | * test_preprocessed_3610_top200_psg-top100.json 39 | 40 | We will assume that these files are saved under `$SAVE_DIR/fid-data`. 41 | Note that each retrieved passage in DensePhrases is a natural paragraph mostly in different lengths. For the exact replication of the experiments in our EMNLP paper, you need a phrase index created from Wikipedia pre-processed for DPR (100-word passages), which we plan to provide soonish. 42 | 43 | ## Installing Fusion-in-Decoder 44 | For Fusion-in-Decoder, we use [the official code](https://github.com/facebookresearch/FiD) provided by the authors. 45 | It is often better to use a separate conda environment to train FiD. 46 | See [here](https://github.com/facebookresearch/FiD#dependencies) for dependencies. 
47 | 48 | ```bash 49 | # Install torch with conda (please check your CUDA version) 50 | conda create -n fid python=3.7 51 | conda activate fid 52 | conda install pytorch=1.9.0 cudatoolkit=11.0 -c pytorch 53 | 54 | # Install Fusion-in-Decoder 55 | git clone https://github.com/facebookresearch/FiD.git 56 | cd FiD 57 | pip install -r requirements.txt 58 | ``` 59 | 60 | ## Training and Evaluation 61 | ```bash 62 | TRAIN_DATA=fid-data/train_preprocessed_79168_top200_psg-top100.json 63 | DEV_DATA=fid-data/dev_preprocessed_8757_top200_psg-top100.json 64 | TEST_DATA=fid-data/test_preprocessed_3610_top200_psg-top100.json 65 | 66 | # Train T5-base with top 5 passages (DDP using 4 GPUs) 67 | nohup python /path/to/miniconda3/envs/fid/lib/python3.6/site-packages/torch/distributed/launch.py \ 68 | --nnode=1 --node_rank=0 --nproc_per_node=4 train_reader.py \ 69 | --train_data $SAVE_DIR/$TRAIN_DATA \ 70 | --eval_data $SAVE_DIR/$DEV_DATA \ 71 | --model_size base \ 72 | --per_gpu_batch_size 1 \ 73 | --accumulation_steps 16 \ 74 | --total_steps 160000 \ 75 | --eval_freq 8000 \ 76 | --save_freq 8000 \ 77 | --n_context 5 \ 78 | --lr 0.00005 \ 79 | --text_maxlength 300 \ 80 | --name nq_reader_base-dph-c5-d4 \ 81 | --checkpoint_dir $SAVE_DIR/fid-data/pretrained_models > nq_reader_base-dph-c5-d4_out.log & 82 | 83 | # Test T5-base with top 5 passages (DDP using 4 GPUs) 84 | python /n/fs/nlp-jl5167/miniconda3/envs/fid/lib/python3.6/site-packages/torch/distributed/launch.py \ 85 | --nnode=1 --node_rank=0 --nproc_per_node=4 test_reader.py \ 86 | --model_path $SAVE_DIR/fid-data/pretrained_models/nq_reader_base-dph-c5-d4/checkpoint/best_dev \ 87 | --eval_data $SAVE_DIR/$TEST_DATA \ 88 | --per_gpu_batch_size 1 \ 89 | --n_context 5 \ 90 | --write_results \ 91 | --name nq_reader_base-dph-c5-d4 \ 92 | --checkpoint_dir $SAVE_DIR/fid-data/pretrained_models \ 93 | --text_maxlength 300 94 | ``` 95 | Note that most hyperparameters follow the original work and the only difference is the use of `--accumulation_steps 16` and proper adjustment to its training, save, evaluation steps. Larger `--text_maxlength` is used to cover natural paragraphs that are often longer than 100 words. 96 | -------------------------------------------------------------------------------- /examples/knowledge-dialogue/README.md: -------------------------------------------------------------------------------- 1 | # Knowledge-Grounded Dialogue 2 | 3 | ## Pre-trained Models 4 | | Model | Query-FT. & Eval | R-Precision| Description | 5 | |:-------------------------------|:--------:|:--------:|:--------:| 6 | | [densephrases-multi-query-wow](https://huggingface.co/princeton-nlp/densephrases-multi-query-wow) | Wizard of Wikipedia (WoW) | 47.0 | Result from [eval.ai](https://eval.ai/web/challenges/challenge-page/689/overview) | 7 | | [densephrases-multi-query-kilt-multi](https://huggingface.co/princeton-nlp/densephrases-multi-query-kilt-multi) | Multiple / WoW | 55.7 | Trained on multiple KILT tasks | 8 | 9 | ## How to Use 10 | ```python 11 | >>> from densephrases import DensePhrases 12 | 13 | # Load densephraes-multi-query-wow 14 | >>> model = DensePhrases( 15 | ... load_dir='princeton-nlp/densephrases-multi-query-wow', 16 | ... dump_dir='/path/to/densephrases-multi_wiki-20181220/dump', 17 | ... 
18 | 19 | # Feed a dialogue as a query 20 | >>> model.search('I love rap music.', retrieval_unit='document', top_k=10) 21 | ['Rapping', 'Hip hop', 'Rap metal', 'Hip hop music', 'Rapso', 'Battle rap', 'Rape', 'Eurodance', 'Chopper (rap)', 'Rape culture'] 22 | 23 | >>> model.search('Have you heard of Yamaha? They started as a piano manufacturer in 1887!', retrieval_unit='document', top_k=5) 24 | ['Yamaha Corporation', 'Yamaha Drums', 'Tōkai Gakki', 'Suzuki Musical Instrument Corporation', 'Supermoto'] 25 | 26 | # You can get more metadata on the document by setting return_meta=True 27 | >>> doc, meta = model.search('I love rap music.', retrieval_unit='document', top_k=1, return_meta=True) 28 | >>> meta 29 | [{'context': 'Rap is usually delivered over a beat, typically provided by a DJ, turntablist, ...', 'title': ['Rapping'], 'doc_idx': 4096192, 'start_pos': 647, 'end_pos': 660, 'start_idx': 91, 'end_idx': 93, 'score': 53.58412170410156, ... 'answer': 'hip-hop music'}] 30 | ``` 31 | 32 | ### Evaluation 33 | ```python 34 | >>> import os 35 | 36 | # Evaluate loaded DensePhrases on Wizard of Wikipedia 37 | >>> model.evaluate( 38 | ... test_path=os.path.join(os.environ['DATA_DIR'], 'kilt/wow/wow-dev-kilt_open.json'), 39 | ... is_kilt=True, title2wikiid_path=os.path.join(os.environ['DATA_DIR'], 'wikidump/title2wikiid.json'), 40 | ... kilt_gold_path=os.path.join(os.environ['DATA_DIR'], 'kilt/wow/wow-dev-kilt.jsonl'), agg_strat='opt2', max_query_length=384 41 | ... ) 42 | ``` 43 | 44 | For test accuracy, use `wow-test-kilt_open.json` instead and submit the prediction file (saved as `$SAVE_DIR/densephrases-multi-query-wow/pred-kilt/*.jsonl`) to [eval.ai](https://eval.ai/web/challenges/challenge-page/689/overview). 45 | You can also evaluate the model with the Makefile `eval-index` target by simply changing the dependency. 46 | -------------------------------------------------------------------------------- /examples/slot-filling/README.md: -------------------------------------------------------------------------------- 1 | # Slot Filling 2 | 3 | ## Pre-trained Models 4 | | Model | Query-FT. & Eval | KILT-Accuracy | Description | 5 | |:-------------------------------|:--------:|:--------:|:--------:| 6 | | [densephrases-multi-query-trex](https://nlp.cs.princeton.edu/projects/densephrases/models/densephrases-multi-query-trex.tar.gz) | T-REx | 22.3 | Result from [eval.ai](https://eval.ai/web/challenges/challenge-page/689/overview) | 7 | | [densephrases-multi-query-zsre](https://nlp.cs.princeton.edu/projects/densephrases/models/densephrases-multi-query-zsre.tar.gz) | Zero-shot RE | 40.0 | | 8 | 9 | ## How to Use 10 | ```python 11 | >>> from densephrases import DensePhrases 12 | 13 | # Load densephrases-multi-query-trex locally 14 | >>> model = DensePhrases( 15 | ... load_dir='/path/to/densephrases-multi-query-trex', 16 | ... dump_dir='/path/to/densephrases-multi_wiki-20181220/dump', 17 | ... ) 18 | 19 | # Slot filling queries are in the format of 'Subject [SEP] Relation' 20 | >>> model.search('Superman [SEP] father', retrieval_unit='phrase', top_k=5) 21 | ['Jor-El', 'Clark Kent', 'Jor-El', 'Jor-El', 'Jor-El'] 22 | 23 | >>> model.search('Cirith Ungol [SEP] genre', retrieval_unit='phrase', top_k=5) 24 | ['heavy metal', 'doom metal', 'metal', 'Elvish', 'madrigal comedy'] 25 | ``` 26 | 27 | ### Evaluation 28 | ```python 29 | >>> import os 30 | 31 | # Evaluate loaded DensePhrases on T-REx (KILT) 32 | >>> model.evaluate( 33 | ...
test_path=os.path.join(os.environ['DATA_DIR'], 'kilt/trex/trex-dev-kilt_open.json'), 34 | ... is_kilt=True, title2wikiid_path=os.path.join(os.environ['DATA_DIR'], 'wikidump/title2wikiid.json'), 35 | ... kilt_gold_path=os.path.join(os.environ['DATA_DIR'], 'kilt/trex/trex-dev-kilt.jsonl'), agg_strat='opt2', 36 | ... ) 37 | ``` 38 | 39 | For test accuracy, use `trex-test-kilt_open.json` instead and submit the prediction file (saved as `$SAVE_DIR/densephrases-multi-query-trex/pred-kilt/densephrases-multi-query-trex_trex-test-kilt_open_5000.jsonl`) to [eval.ai](https://eval.ai/web/challenges/challenge-page/689/overview). 40 | For zero-shot relation extraction, follow the same process with the files specified in the `zsre-kilt-data` target in the [Makefile](https://github.com/princeton-nlp/DensePhrases/blob/main/Makefile). 41 | You can also evaluate the model with the Makefile `eval-index` target by simply changing the dependency to `trex-kilt-data` or `zsre-kilt-data`. 42 | -------------------------------------------------------------------------------- /generate_phrase_vecs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
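# Example invocation (paths and the document range are placeholders; see
# scripts/parallel/dump_phrases.py for how this script is typically launched):
#
#   python generate_phrase_vecs.py \
#       --model_type bert \
#       --pretrained_name_or_path SpanBERT/spanbert-base-cased \
#       --data_dir $DATA_DIR/wikidump \
#       --cache_dir $CACHE_DIR \
#       --predict_file 0:8 \
#       --do_dump \
#       --max_seq_length 512 \
#       --doc_stride 500 \
#       --fp16 \
#       --load_dir $SAVE_DIR/densephrases-multi \
#       --output_dir $SAVE_DIR/densephrases-multi_wiki-20181220 \
#       --append_title
#
# The 'start:end' value of --predict_file selects the numbered document files to encode
# (see dump_phrases() below); the resulting phrase vectors are written as HDF5 files
# under <output_dir>/dump/phrase/.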
16 | """ Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet).""" 17 | 18 | 19 | import argparse 20 | import logging 21 | import os 22 | import timeit 23 | import copy 24 | import h5py 25 | import torch 26 | 27 | from tqdm import tqdm, trange 28 | from torch.utils.data import DataLoader, SequentialSampler 29 | from torch.utils.data.distributed import DistributedSampler 30 | 31 | from transformers import ( 32 | MODEL_MAPPING, 33 | AutoConfig, 34 | AutoModel, 35 | AutoTokenizer, 36 | ) 37 | from densephrases.utils.squad_utils import ContextResult, load_and_cache_examples 38 | from densephrases.utils.single_utils import set_seed, to_list, to_numpy, backward_compat, load_encoder 39 | from densephrases.utils.embed_utils import write_phrases, write_filter 40 | from densephrases import Options 41 | 42 | logger = logging.getLogger(__name__) 43 | 44 | 45 | def dump_phrases(args, model, tokenizer, filter_only=False): 46 | output_path = 'dump/phrase' if not filter_only else 'dump/filter' 47 | if not os.path.exists(os.path.join(args.output_dir, output_path)): 48 | os.makedirs(os.path.join(args.output_dir, output_path)) 49 | 50 | start_time = timeit.default_timer() 51 | if ':' not in args.predict_file: 52 | predict_files = [args.predict_file] 53 | offsets = [0] 54 | output_dump_file = os.path.join( 55 | args.output_dir, f"{output_path}/{os.path.splitext(os.path.basename(args.predict_file))[0]}.hdf5" 56 | ) 57 | else: 58 | dirname = os.path.dirname(args.predict_file) 59 | basename = os.path.basename(args.predict_file) 60 | start, end = list(map(int, basename.split(':'))) 61 | output_dump_file = os.path.join( 62 | args.output_dir, f"{output_path}/{start}-{end}.hdf5" 63 | ) 64 | 65 | # skip files if possible 66 | if os.path.exists(output_dump_file): 67 | with h5py.File(output_dump_file, 'r') as f: 68 | dids = list(map(int, f.keys())) 69 | start = int(max(dids) / 1000) 70 | logger.info('%s exists; starting from %d' % (output_dump_file, start)) 71 | 72 | names = [str(i).zfill(4) for i in range(start, end)] 73 | predict_files = [os.path.join(dirname, name) for name in names] 74 | offsets = [int(each) * 1000 for each in names] 75 | 76 | for offset, predict_file in zip(offsets, predict_files): 77 | args.predict_file = predict_file 78 | logger.info(f"***** Pre-processing contexts from {args.predict_file} *****") 79 | dataset, examples, features = load_and_cache_examples( 80 | args, tokenizer, evaluate=True, output_examples=True, context_only=True 81 | ) 82 | for example in examples: 83 | example.doc_idx += offset 84 | 85 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 86 | 87 | # Note that DistributedSampler samples randomly 88 | eval_sampler = SequentialSampler(dataset) 89 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 90 | 91 | logger.info(f"***** Dumping Phrases from {args.predict_file} *****") 92 | logger.info(" Num examples = %d", len(dataset)) 93 | logger.info(" Batch size = %d", args.eval_batch_size) 94 | start_time = timeit.default_timer() 95 | 96 | def get_phrase_results(): 97 | for batch in tqdm(eval_dataloader, desc="Dumping"): 98 | model.eval() 99 | batch = tuple(t.to(args.device) for t in batch) 100 | 101 | with torch.no_grad(): 102 | inputs = { 103 | "input_ids": batch[0], 104 | "attention_mask": batch[1], 105 | "token_type_ids": batch[2], 106 | "return_phrase": True, 107 | } 108 | feature_indices = batch[3] 109 | outputs = model(**inputs) 110 | 111 | for i, feature_index in 
enumerate(feature_indices): 112 | # TODO: i and feature_index are the same number! Simplify by removing enumerate? 113 | eval_feature = features[feature_index.item()] 114 | unique_id = int(eval_feature.unique_id) 115 | 116 | output = [ 117 | to_numpy(output[i]) if type(output) != dict else {k: to_numpy(v[i]) for k, v in output.items()} 118 | for output in outputs 119 | ] 120 | 121 | if len(output) != 4: 122 | raise NotImplementedError 123 | else: 124 | start_vecs, end_vecs, sft_logits, eft_logits = output 125 | result = ContextResult( 126 | unique_id, 127 | start_vecs=start_vecs, 128 | end_vecs=end_vecs, 129 | sft_logits=sft_logits, 130 | eft_logits=eft_logits, 131 | ) 132 | yield result 133 | 134 | if not filter_only: 135 | write_phrases( 136 | examples, features, get_phrase_results(), args.max_answer_length, args.do_lower_case, tokenizer, 137 | output_dump_file, args.filter_threshold, args.verbose_logging, 138 | args.dense_offset, args.dense_scale, has_title=args.append_title, 139 | ) 140 | else: 141 | write_filter( 142 | examples, features, get_phrase_results(), tokenizer, 143 | output_dump_file, args.filter_threshold, args.verbose_logging, has_title=args.append_title, 144 | ) 145 | 146 | evalTime = timeit.default_timer() - start_time 147 | logger.info("Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset)) 148 | 149 | 150 | def main(): 151 | # See options in densephrases.options 152 | options = Options() 153 | options.add_model_options() 154 | options.add_data_options() 155 | options.add_rc_options() 156 | args = options.parse() 157 | 158 | # Setup CUDA, GPU & distributed training 159 | if args.local_rank == -1 or args.no_cuda: 160 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 161 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() 162 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 163 | torch.cuda.set_device(args.local_rank) 164 | device = torch.device("cuda", args.local_rank) 165 | torch.distributed.init_process_group(backend="nccl") 166 | args.n_gpu = 1 167 | args.device = device 168 | 169 | # Setup logging 170 | logging.basicConfig( 171 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 172 | datefmt="%m/%d/%Y %H:%M:%S", 173 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 174 | ) 175 | logger.warning( 176 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 177 | args.local_rank, 178 | device, 179 | args.n_gpu, 180 | bool(args.local_rank != -1), 181 | args.fp16, 182 | ) 183 | 184 | # Set seed 185 | set_seed(args) 186 | 187 | # Load config, tokenizer 188 | if args.local_rank not in [-1, 0]: 189 | # Make sure only the first process in distributed training will download model & vocab 190 | torch.distributed.barrier() 191 | 192 | args.model_type = args.model_type.lower() 193 | config, unused_kwargs = AutoConfig.from_pretrained( 194 | args.config_name if args.config_name else args.pretrained_name_or_path, 195 | cache_dir=args.cache_dir if args.cache_dir else None, 196 | output_hidden_states=False, 197 | return_unused_kwargs=True 198 | ) 199 | tokenizer = AutoTokenizer.from_pretrained( 200 | args.tokenizer_name if args.tokenizer_name else args.pretrained_name_or_path, 201 | do_lower_case=args.do_lower_case, 202 | cache_dir=args.cache_dir if args.cache_dir else None, 203 | ) 204 | 205 | if args.local_rank == 0: 206 | # Make sure only the first process in distributed training 
will download model & vocab 207 | torch.distributed.barrier() 208 | 209 | logger.info("Dump parameters %s", args) 210 | 211 | # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set. 212 | # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` 213 | # will remove the need for this code, but it is still valid. 214 | if args.fp16: 215 | try: 216 | import apex 217 | apex.amp.register_half_function(torch, "einsum") 218 | except ImportError: 219 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 220 | 221 | # Create phrase vectors 222 | if args.do_dump: 223 | assert args.load_dir 224 | model, tokenizer, config = load_encoder(device, args, phrase_only=True) 225 | 226 | args.draft = False 227 | dump_phrases(args, model, tokenizer, filter_only=args.filter_only) 228 | 229 | 230 | if __name__ == "__main__": 231 | main() 232 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.0 2 | faiss-gpu==1.6.5 3 | transformers==2.9.0 4 | spacy==2.3.2 5 | h5py 6 | tqdm 7 | blosc 8 | ujson 9 | rouge 10 | wandb 11 | nltk 12 | flask 13 | flask_cors 14 | tornado 15 | requests-futures 16 | -------------------------------------------------------------------------------- /scripts/benchmark/benchmark_hdf5.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | 3 | from tqdm import tqdm 4 | 5 | 6 | paths = [ 7 | 'dumps/sbcd_sqdqgnqqg_inb64_s384_sqdnq_pinb2_0_20181220_concat/dump/phrase/0-200.hdf5', 8 | 'dumps/sbcd_sqdqgnqqg_inb64_s384_sqdnq_pinb2_0_20181220_concat/dump/phrase/200-400.hdf5' 9 | ] 10 | phrase_dumps = [h5py.File(path, 'r') for path in paths] 11 | 12 | 13 | # Just testing how fast it is to read hdf5 files from disk 14 | for phrase_dump in phrase_dumps: 15 | for doc_id, doc_val in tqdm(phrase_dump.items()): 16 | kk = doc_val['start'][-10:] 17 | -------------------------------------------------------------------------------- /scripts/benchmark/create_benchmark_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pdb 3 | 4 | data_for_denspi = [] 5 | data_for_dpr = [] 6 | 7 | with open('benchmark/nq_1000_dev_orqa.jsonl', encoding='utf-8') as f: 8 | idx = 0 9 | while True: 10 | line = f.readline() 11 | if line == "": 12 | break 13 | 14 | sample = json.loads(line) 15 | 16 | data_for_denspi.append({ 17 | 'id':f'dev_{idx}', 18 | 'question': sample['question'], 19 | 'answers': sample['answer'] 20 | }) 21 | data_for_dpr.append("\t".join([sample['question'], str(sample['answer'])])) 22 | 23 | idx += 1 24 | 25 | # save data_for_dpr as csv 26 | with open('benchmark/nq_1000_dev_dpr.csv', 'w', encoding='utf-8') as f: 27 | for line in data_for_dpr: 28 | f.writelines(line) 29 | f.writelines("\n") 30 | 31 | # save data_for_denspi as json 32 | with open('benchmark/nq_1000_dev_denspi.json', 'w', encoding='utf-8') as f: 33 | json.dump({'data': data_for_denspi}, f) 34 | -------------------------------------------------------------------------------- /scripts/dump/check_dump.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import h5py 4 | from tqdm import tqdm 5 | 6 | 7 | def get_range(name): 8 | # name = name.replace('_tfidf', 
'') 9 | return list(map(int, os.path.splitext(name)[0].split('-'))) 10 | 11 | 12 | def find_name(names, pos): 13 | for name in names: 14 | start, end = get_range(name) 15 | assert start != end, 'you have self-looping at %s' % name 16 | if start == pos: 17 | return name, end 18 | raise Exception('hdf5 file starting with %d not found.') 19 | 20 | 21 | def check_dump(args): 22 | print('checking dir contiguity...') 23 | names = os.listdir(args.dump_dir) 24 | pos = args.start 25 | while pos < args.end: 26 | name, pos = find_name(names, pos) 27 | assert pos == args.end, 'reached %d, which is different from the specified end %d' % (pos, args.end) 28 | print('dir contiguity test passed!') 29 | print('checking file corruption...') 30 | pos = args.start 31 | corrupted_paths = [] 32 | while pos < args.end: 33 | name, pos = find_name(names, pos) 34 | path = os.path.join(args.dump_dir, name) 35 | try: 36 | with h5py.File(path, 'r') as f: 37 | print('checking %s...' % path) 38 | for dk, group in tqdm(f.items()): 39 | keys = list(group.keys()) 40 | except Exception as e: 41 | print(e) 42 | print('%s corrupted!' % path) 43 | corrupted_paths.append(path) 44 | if len(corrupted_paths) > 0: 45 | print('following files are corrupted:') 46 | for path in corrupted_paths: 47 | print(path) 48 | else: 49 | print('file corruption test passed!') 50 | 51 | 52 | def get_args(): 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('dump_dir') 55 | parser.add_argument('start', type=int) 56 | parser.add_argument('end', type=int) 57 | 58 | return parser.parse_args() 59 | 60 | 61 | def main(): 62 | args = get_args() 63 | check_dump(args) 64 | 65 | 66 | if __name__ == '__main__': 67 | main() 68 | -------------------------------------------------------------------------------- /scripts/dump/filter_hdf5.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | from tqdm import tqdm 4 | 5 | input_dump_dir = 'dumps/sbcd_sqd_ftinb84_kl_x4_20181220_concat/dump/phrase/' 6 | select = 0 7 | print(f'************** {select} *****************') 8 | input_dump_paths = sorted( 9 | [os.path.join(input_dump_dir, name) for name in os.listdir(input_dump_dir) if 'hdf5' in name] 10 | )[select:] 11 | print(input_dump_paths) 12 | input_dumps = [h5py.File(path, 'r') for path in input_dump_paths] 13 | dump_names = [os.path.splitext(os.path.basename(path))[0] for path in input_dump_paths] 14 | print(input_dumps) 15 | 16 | # Filter dump for a lighter version 17 | ''' 18 | output_dumps = [ 19 | h5py.File(f'dumps/densephrases-multi_wiki-20181220/dump/phrase/{k}.hdf5', 'w') 20 | for k in dump_names 21 | ] 22 | print(output_dumps) 23 | 24 | 25 | for dump_idx, (input_dump, output_dump) in tqdm(enumerate(zip(input_dumps, output_dumps))): 26 | print(f'filtering {input_dump} to {output_dump}') 27 | for idx, (key, val) in tqdm(enumerate(input_dump.items())): 28 | 29 | dg = output_dump.create_group(key) 30 | dg.attrs['context'] = val.attrs['context'][:] 31 | dg.attrs['title'] = val.attrs['title'][:] 32 | for k_, v_ in val.items(): 33 | if k_ not in ['start', 'len_per_para', 'start2end']: 34 | dg.create_dataset(k_, data=v_[:]) 35 | 36 | input_dump.close() 37 | output_dump.close() 38 | 39 | print('filter done') 40 | ''' 41 | 42 | def load_doc_groups(phrase_dump_dir): 43 | phrase_dump_paths = sorted( 44 | [os.path.join(phrase_dump_dir, name) for name in os.listdir(phrase_dump_dir) if 'hdf5' in name] 45 | ) 46 | doc_groups = {} 47 | types = ['word2char_start', 'word2char_end', 'f2o_start'] 48 
| attrs = ['context', 'title'] 49 | phrase_dumps = [h5py.File(path, 'r') for path in phrase_dump_paths] 50 | for path in tqdm(phrase_dump_paths, desc='loading doc groups'): 51 | with h5py.File(path, 'r') as f: 52 | for key in tqdm(f): 53 | import pdb; pdb.set_trace() 54 | doc_group = {} 55 | for type_ in types: 56 | doc_group[type_] = f[key][type_][:] 57 | for attr in attrs: 58 | doc_group[attr] = f[key].attrs[attr] 59 | doc_groups[key] = doc_group 60 | return doc_groups 61 | 62 | # Save below as a pickle file and load it on memory for later use 63 | doc_groups = load_doc_groups(input_dump_dir) 64 | -------------------------------------------------------------------------------- /scripts/dump/filter_stats.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import h5py 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | 8 | def get_range(name): 9 | # name = name.replace('_tfidf', '') 10 | return list(map(int, os.path.splitext(name)[0].split('-'))) 11 | 12 | 13 | def find_name(names, pos): 14 | for name in names: 15 | start, end = get_range(name) 16 | assert start != end, 'you have self-looping at %s' % name 17 | if start == pos: 18 | return name, end 19 | raise Exception('hdf5 file starting with %d not found.') 20 | 21 | 22 | def check_dump(args): 23 | print('checking dir contiguity...') 24 | names = os.listdir(args.dump_dir) 25 | pos = args.start 26 | while pos < args.end: 27 | name, pos = find_name(names, pos) 28 | assert pos == args.end, 'reached %d, which is different from the specified end %d' % (pos, args.end) 29 | print('dir contiguity test passed!') 30 | print('checking file corruption...') 31 | pos = args.start 32 | corrupted_paths = [] 33 | 34 | all_count = 0 35 | thresholds = [0.0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5] 36 | save_bins = {th: 0 for th in thresholds} 37 | while pos < args.end: 38 | name, pos = find_name(names, pos) 39 | path = os.path.join(args.dump_dir, name) 40 | with h5py.File(path, 'r') as f: 41 | print('checking %s...' 
% path) 42 | for dk, group in tqdm(f.items()): 43 | filter_start = group['filter_start'][:] 44 | filter_end = group['filter_end'][:] 45 | for th in thresholds: 46 | start_idxs, = np.where(filter_start > th) 47 | end_idxs, = np.where(filter_end > th) 48 | num_save_vec = len(set(np.concatenate([start_idxs, end_idxs]))) 49 | save_bins[th] += num_save_vec 50 | all_count += len(filter_start) 51 | # break 52 | 53 | print(all_count) 54 | print(save_bins) 55 | comp_rate = {th: f'{save_num/all_count*100:.2f}%' for th, save_num in save_bins.items()} 56 | print(f'Compression rate: {comp_rate}') 57 | if len(corrupted_paths) > 0: 58 | print('following files are corrupted:') 59 | for path in corrupted_paths: 60 | print(path) 61 | else: 62 | print('file corruption test passed!') 63 | 64 | 65 | def get_args(): 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('dump_dir') 68 | parser.add_argument('start', type=int) 69 | parser.add_argument('end', type=int) 70 | 71 | return parser.parse_args() 72 | 73 | 74 | def main(): 75 | args = get_args() 76 | check_dump(args) 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /scripts/dump/save_meta.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import h5py 4 | import torch 5 | from tqdm import tqdm 6 | 7 | 8 | def get_range(name): 9 | # name = name.replace('_tfidf', '') 10 | return list(map(int, os.path.splitext(name)[0].split('-'))) 11 | 12 | 13 | def find_name(names, pos): 14 | for name in names: 15 | start, end = get_range(name) 16 | assert start != end, 'you have self-looping at %s' % name 17 | if start == pos: 18 | return name, end 19 | raise Exception('hdf5 file starting with %d not found.') 20 | 21 | 22 | def check_dump(args): 23 | print('checking dir contiguity...') 24 | names = os.listdir(args.dump_dir) 25 | pos = args.start 26 | while pos < args.end: 27 | name, pos = find_name(names, pos) 28 | assert pos == args.end, 'reached %d, which is different from the specified end %d' % (pos, args.end) 29 | print('dir contiguity test passed!') 30 | print('checking file corruption...') 31 | pos = args.start 32 | corrupted_paths = [] 33 | metadata = {} 34 | keys_to_save = ['f2o_end', 'f2o_start', 'span_logits', 'start2end', 'word2char_end', 'word2char_start'] 35 | while pos < args.end: 36 | name, pos = find_name(names, pos) 37 | path = os.path.join(args.dump_dir, name) 38 | try: 39 | with h5py.File(path, 'r') as f: 40 | print('checking %s...' % path) 41 | for dk, group in tqdm(f.items()): 42 | # keys = list(group.keys()) 43 | metadata[dk] = {save_key: group[save_key][:] for save_key in keys_to_save} 44 | metadata[dk]['context'] = group.attrs['context'] 45 | metadata[dk]['title'] = group.attrs['title'] 46 | except Exception as e: 47 | print(e) 48 | print('%s corrupted!' 
% path) 49 | corrupted_paths.append(path) 50 | 51 | break 52 | 53 | torch.save(metadata, 'tmp.bin') 54 | if len(corrupted_paths) > 0: 55 | print('following files are corrupted:') 56 | for path in corrupted_paths: 57 | print(path) 58 | else: 59 | print('file corruption test passed!') 60 | 61 | 62 | def get_args(): 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('dump_dir') 65 | parser.add_argument('start', type=int) 66 | parser.add_argument('end', type=int) 67 | 68 | return parser.parse_args() 69 | 70 | 71 | def main(): 72 | args = get_args() 73 | check_dump(args) 74 | 75 | 76 | if __name__ == '__main__': 77 | main() 78 | -------------------------------------------------------------------------------- /scripts/dump/split_hdf5.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import os 3 | from tqdm import tqdm 4 | 5 | input_dump_dir = 'dumps/sbcd_sqd_ftinb84_kl_x4_20181220_concat/dump/phrase/' 6 | select = 6 7 | print(f'************** {select} *****************') 8 | input_dump_paths = sorted( 9 | [os.path.join(input_dump_dir, name) for name in os.listdir(input_dump_dir) if 'hdf5' in name] 10 | )[select:select+1] 11 | print(input_dump_paths) 12 | input_dumps = [h5py.File(path, 'r') for path in input_dump_paths] 13 | 14 | dump_names = [os.path.splitext(os.path.basename(path))[0] for path in input_dump_paths] 15 | dump_ranges = [list(map(int, name.split('-'))) for name in dump_names] 16 | new_ranges = [] 17 | for range_ in dump_ranges: 18 | # print(range_) 19 | middle = sum(range_) // 2 # split by half 20 | new_range_ = [[range_[0], middle], [middle, range_[1]]] 21 | # print(new_range_) 22 | new_ranges.append(new_range_) 23 | 24 | output_dumps = [ 25 | [h5py.File(f'dumps/sbcd_sqd_ftinb84_kl_x4_20181220_concat/dump/phrase/{ra[0]}-{ra[1]}.hdf5', 'w') 26 | for ra in range_] 27 | for range_ in new_ranges 28 | ] 29 | 30 | print(input_dumps) 31 | print(output_dumps) 32 | print(new_ranges) 33 | 34 | # dev-100M-c 160408 35 | # dev_wiki_noise 250000 36 | 37 | for dump_idx, (input_dump, new_range, output_dump) in tqdm(enumerate(zip(input_dumps, new_ranges, output_dumps))): 38 | print(f'splitting {input_dump} to {output_dump}') 39 | for idx, (key, val) in tqdm(enumerate(input_dump.items())): 40 | # if idx < 250000/2: 41 | if int(key) < new_range[0][1] * 1000: 42 | output_dump[0].copy(val, key) 43 | else: 44 | output_dump[1].copy(val, key) 45 | 46 | input_dump.close() 47 | output_dump[0].close() 48 | output_dump[1].close() 49 | 50 | print('copy done') 51 | -------------------------------------------------------------------------------- /scripts/kilt/build_title2wikiid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 
4 | 5 | """A script to read in and store documents in a sqlite database.""" 6 | 7 | import argparse 8 | import sqlite3 9 | import json 10 | import os 11 | import logging 12 | import importlib.util 13 | import unicodedata 14 | import html 15 | 16 | from multiprocessing import Pool as ProcessPool 17 | from tqdm import tqdm 18 | 19 | logger = logging.getLogger() 20 | logger.setLevel(logging.INFO) 21 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 22 | console = logging.StreamHandler() 23 | console.setFormatter(fmt) 24 | logger.addHandler(console) 25 | 26 | 27 | # ------------------------------------------------------------------------------ 28 | # Import helper 29 | # ------------------------------------------------------------------------------ 30 | 31 | 32 | PREPROCESS_FN = None 33 | 34 | 35 | def init(filename): 36 | global PREPROCESS_FN 37 | if filename: 38 | PREPROCESS_FN = import_module(filename).preprocess 39 | 40 | 41 | def import_module(filename): 42 | """Import a module given a full path to the file.""" 43 | spec = importlib.util.spec_from_file_location('doc_filter', filename) 44 | module = importlib.util.module_from_spec(spec) 45 | spec.loader.exec_module(module) 46 | return module 47 | 48 | 49 | # ------------------------------------------------------------------------------ 50 | # Store corpus. 51 | # ------------------------------------------------------------------------------ 52 | 53 | def normalize(text): 54 | """Resolve different type of unicode encodings.""" 55 | return unicodedata.normalize('NFD', html.unescape(text)) 56 | 57 | def iter_files(path): 58 | """Walk through all files located under a root path.""" 59 | if os.path.isfile(path): 60 | yield path 61 | elif os.path.isdir(path): 62 | for dirpath, _, filenames in os.walk(path): 63 | for f in filenames: 64 | yield os.path.join(dirpath, f) 65 | else: 66 | raise RuntimeError('Path %s is invalid' % path) 67 | 68 | 69 | def get_contents(filename): 70 | """Parse the contents of a file. Each line is a JSON encoded document.""" 71 | # documents = [] 72 | results = {} 73 | with open(filename, encoding='utf-8') as f: 74 | for line in f: 75 | # Parse document 76 | doc = json.loads(line) 77 | # Skip if it is empty or None 78 | if not doc: 79 | continue 80 | # Add the document 81 | 82 | title = normalize(doc['title']) 83 | if '&' in title: 84 | import pdb; pdb.set_trace() 85 | 86 | if 'u0' in title: 87 | import pdb; pdb.set_trace() 88 | results[title] = doc['id'] 89 | return results 90 | 91 | 92 | def store_contents(data_path, save_path): 93 | results = {} 94 | files = [f for f in iter_files(data_path)] 95 | for file in tqdm(files): 96 | contents = get_contents(file) 97 | results.update(contents) 98 | 99 | print(f"len(results)={len(results)}") 100 | with open(save_path, 'w') as f: 101 | json.dump(results, f) 102 | 103 | # ------------------------------------------------------------------------------ 104 | # Main. 
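# Example invocation (paths are placeholders). Each input file under --data_path is
# expected to contain one JSON document per line with 'title' and 'id' fields; the
# output is a JSON mapping from normalized article title to Wikipedia page id, which
# other scripts load as $DATA_DIR/wikidump/title2wikiid.json for KILT evaluation:
#
#   python scripts/kilt/build_title2wikiid.py \
#       --data_path /path/to/wikipedia/json/docs \
#       --save_path $DATA_DIR/wikidump/title2wikiid.json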
105 | # ------------------------------------------------------------------------------ 106 | 107 | 108 | if __name__ == '__main__': 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument('--data_path', type=str, help='/path/to/data') 111 | parser.add_argument('--save_path', type=str, help='/path/to/saved/db.db') 112 | args = parser.parse_args() 113 | 114 | store_contents( 115 | args.data_path, args.save_path 116 | ) -------------------------------------------------------------------------------- /scripts/kilt/sample_kilt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | import random 5 | import time 6 | import numpy as np 7 | 8 | from tqdm import tqdm 9 | 10 | 11 | def main(input_file, num_sample, balanced): 12 | print('reading', input_file) 13 | random.seed(999) 14 | np.random.seed(999) 15 | 16 | examples = json.load(open(input_file))['data'] 17 | print(f'sampling from {len(examples)}') 18 | relation_dict = {} 19 | for example in tqdm(examples): 20 | relation = example['question'].split(' [SEP] ')[-1] 21 | if relation not in relation_dict: 22 | relation_dict[relation] = [] 23 | relation_dict[relation].append(example) 24 | 25 | top_relations = sorted(relation_dict.items(), key=lambda x: len(x[1]), reverse=True) 26 | print('There are', len(relation_dict), 'relations.') 27 | print([(rel, len(rel_list)) for rel, rel_list in top_relations]) 28 | print() 29 | exit() 30 | 31 | if not balanced: 32 | sample_per_relation = { 33 | rel: int((len(rel_list)/len(examples)) * num_sample) + 1 for rel, rel_list in top_relations 34 | } 35 | else: 36 | sample_per_relation = { 37 | rel: min(num_sample, len(rel_list)) for rel, rel_list in top_relations 38 | } 39 | print('Sample following number of relations') 40 | print(sample_per_relation) 41 | 42 | sample_examples = [] 43 | for rel, rel_list in relation_dict.items(): 44 | sample_idx = np.random.choice(len(rel_list), size=(sample_per_relation[rel]), replace=False) 45 | sample_examples += np.array(rel_list)[sample_idx].tolist() 46 | 47 | out_file = input_file.replace('.json', f'_{num_sample}_{"balanced" if balanced else "ratio"}.json') 48 | print(f'Saving {len(sample_examples)} examples to {out_file}') 49 | with open(out_file, 'w') as f: 50 | json.dump({'data': sample_examples}, f) 51 | 52 | 53 | if __name__ == '__main__': 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("input_file", type=str) 56 | parser.add_argument("--num_sample", type=int, required=True) 57 | parser.add_argument("--balanced", action='store_true', default=False) 58 | 59 | args = parser.parse_args() 60 | 61 | main(args.input_file, args.num_sample, args.balanced) 62 | -------------------------------------------------------------------------------- /scripts/kilt/strip_pred.py: -------------------------------------------------------------------------------- 1 | from densephrases.utils.kilt.eval import evaluate as kilt_evaluate 2 | from densephrases.utils.kilt.kilt_utils import load_data, store_data 3 | import string 4 | import argparse 5 | 6 | 7 | def strip_pred(input_file, gold_file): 8 | 9 | print('original evaluation result:', input_file) 10 | result = kilt_evaluate(gold=gold_file, guess=input_file) 11 | print(result) 12 | 13 | preds = load_data(input_file) 14 | for pred in preds: 15 | pred['output'][0]['answer'] = pred['output'][0]['answer'].strip(string.punctuation) 16 | 17 | out_file = input_file.replace('.jsonl', '_strip.jsonl') 18 | print('strip evaluation result:', 
out_file) 19 | store_data(out_file, preds) 20 | new_result = kilt_evaluate(gold=gold_file, guess=out_file) 21 | print(new_result) 22 | 23 | 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('input_file', type=str) 28 | parser.add_argument('gold_file', type=str) 29 | args = parser.parse_args() 30 | strip_pred(args.input_file, args.gold_file) 31 | -------------------------------------------------------------------------------- /scripts/parallel/add_to_index.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | 5 | import h5py 6 | from tqdm import tqdm 7 | 8 | 9 | def get_size(name): 10 | a, b = list(map(int, os.path.splitext(name)[0].split('-'))) 11 | return b - a 12 | 13 | 14 | def bin_names(dir_, names, num_bins): 15 | names = sorted(names, key=lambda name_: -os.path.getsize(os.path.join(dir_, name_))) 16 | bins = [] 17 | for name in names: 18 | if len(bins) < num_bins: 19 | bins.append([name]) 20 | else: 21 | smallest_bin = min(bins, key=lambda bin_: sum(get_size(name_) for name_ in bin_)) 22 | smallest_bin.append(name) 23 | return bins 24 | 25 | 26 | def run_add_to_index(args): 27 | def get_cmd(dump_paths, offset_): 28 | return ["python", 29 | "build_phrase_index.py", 30 | f"{args.dump_dir}", 31 | "add", 32 | "--fine_quant", "SQ4", 33 | "--dump_paths", f"{dump_paths}", 34 | "--offset", f"{offset_}", 35 | "--num_clusters", f"{args.num_clusters}", 36 | f"{'--cuda' if args.cuda else ''}"] 37 | 38 | 39 | dir_ = os.path.join(args.dump_dir, 'phrase') 40 | names = os.listdir(dir_) 41 | bins = bin_names(dir_, names, args.num_gpus) 42 | offsets = [args.max_num_per_file * each for each in range(len(bins))] 43 | 44 | print('adding with offset:') 45 | for offset, bin_ in zip(offsets, bins): 46 | print('%d: %s' % (offset, ','.join(bin_))) 47 | 48 | for kk, (bin_, offset) in enumerate(zip(bins, offsets)): 49 | if args.start <= kk < args.end: 50 | print(get_cmd(','.join(bin_), offset)) 51 | subprocess.run(get_cmd(','.join(bin_), offset)) 52 | if args.draft: 53 | break 54 | 55 | 56 | def get_args(): 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--dump_dir', default='dump/76_dev-1B-c') 59 | parser.add_argument('--num_cpus', default=4, type=int) 60 | parser.add_argument('--num_gpus', default=60, type=int) 61 | parser.add_argument('--mem_size', default=40, type=int, help='mem size in GB') 62 | parser.add_argument('--num_clusters', default=4096, type=int) 63 | parser.add_argument('--draft', default=False, action='store_true') 64 | parser.add_argument('--max_num_per_file', default=int(1e8), type=int, 65 | help='max num per file for setting up good offsets.') 66 | parser.add_argument('--cuda', default=False, action='store_true') 67 | parser.add_argument('--start', default=0, type=int) 68 | parser.add_argument('--end', default=3, type=int) 69 | args = parser.parse_args() 70 | 71 | return args 72 | 73 | 74 | def main(): 75 | args = get_args() 76 | run_add_to_index(args) 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | -------------------------------------------------------------------------------- /scripts/parallel/dump_phrases.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import subprocess 5 | 6 | 7 | def run_dump_phrase(args): 8 | do_lower_case = '--do_lower_case' if args.do_lower_case else '' 9 | append_title = '--append_title' if 
args.append_title else '' 10 | def get_cmd(start_doc, end_doc): 11 | return ["python", "generate_phrase_vecs.py", 12 | "--model_type", f"{args.model_type}", 13 | "--pretrained_name_or_path", f"{args.pretrained_name_or_path}", 14 | "--data_dir", f"{args.phrase_data_dir}", 15 | "--cache_dir", f"{args.cache_dir}", 16 | "--predict_file", f"{start_doc}:{end_doc}", 17 | "--do_dump", 18 | "--max_seq_length", "512", 19 | "--doc_stride", "500", 20 | "--fp16", 21 | "--load_dir", f"{args.load_dir}", 22 | "--output_dir", f"{args.output_dir}", 23 | "--filter_threshold", f"{args.filter_threshold:.2f}"] + \ 24 | ([f"{do_lower_case}"] if len(do_lower_case) > 0 else []) + \ 25 | ([f"{append_title}"] if len(append_title) > 0 else []) 26 | 27 | num_docs = args.end - args.start 28 | num_gpus = args.num_gpus 29 | num_docs_per_gpu = int(math.ceil(num_docs / num_gpus)) 30 | start_docs = list(range(args.start, args.end, num_docs_per_gpu)) 31 | end_docs = start_docs[1:] + [args.end] 32 | 33 | print(start_docs) 34 | print(end_docs) 35 | 36 | for device_idx, (start_doc, end_doc) in enumerate(zip(start_docs, end_docs)): 37 | print(get_cmd(start_doc, end_doc)) 38 | subprocess.Popen(get_cmd(start_doc, end_doc)) 39 | 40 | 41 | def get_args(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--model_type', default='bert') 44 | parser.add_argument('--pretrained_name_or_path', default='SpanBERT/spanbert-base-cased') 45 | parser.add_argument('--data_dir', default='') 46 | parser.add_argument('--cache_dir', default='') 47 | parser.add_argument('--data_name', default='') # for suffix 48 | parser.add_argument('--load_dir', default='') 49 | parser.add_argument('--output_dir', default='') 50 | parser.add_argument('--do_lower_case', default=False, action='store_true') 51 | parser.add_argument('--append_title', default=False, action='store_true') 52 | parser.add_argument('--filter_threshold', default=-1e9, type=float) 53 | parser.add_argument('--num_gpus', default=1, type=int) 54 | parser.add_argument('--start', default=0, type=int) 55 | parser.add_argument('--end', default=8, type=int) 56 | args = parser.parse_args() 57 | 58 | args.output_dir = args.output_dir + '_%s' % (os.path.basename(args.data_name)) 59 | args.phrase_data_dir = os.path.join(args.data_dir, args.data_name) 60 | 61 | return args 62 | 63 | 64 | def main(): 65 | args = get_args() 66 | run_dump_phrase(args) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /scripts/postprocess/recall.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import re 4 | import unicodedata 5 | from collections import defaultdict 6 | from tqdm import tqdm 7 | from scripts.preprocess.simple_tokenizer import SimpleTokenizer 8 | 9 | 10 | def read_file(infile, handle_file, log=False, skip_first_line=False): 11 | if log: 12 | print('Opening "{}"...'.format(infile)) 13 | data = None 14 | with open(infile) as f: 15 | if skip_first_line: 16 | f.readline() 17 | data = handle_file(f) 18 | if log: 19 | print(' Done.') 20 | return data 21 | 22 | 23 | def read_jsonl(infile, log=False): 24 | handler = lambda f: [json.loads(line) for line in f.readlines()] 25 | return read_file(infile, handler, log=log) 26 | 27 | 28 | def read_json(infile, log=False): 29 | handler = lambda f: json.load(f) 30 | return read_file(infile, handler, log=log) 31 | 32 | 33 | def _normalize(text): 34 | return unicodedata.normalize('NFD', text) 35 | 36 
| ############################################################################### 37 | ### HAS_ANSWER FUNCTIONS #################################################### 38 | ############################################################################### 39 | def has_answer_field(ctx, answers): 40 | return ctx['has_answer'] 41 | 42 | 43 | tokenizer = SimpleTokenizer(**{}) 44 | def string_match(ctx, answers): 45 | text = tokenizer.tokenize(ctx['text']).words(uncased=True) 46 | 47 | for single_answer in answers: 48 | single_answer = _normalize(single_answer) 49 | single_answer = tokenizer.tokenize(single_answer) 50 | single_answer = single_answer.words(uncased=True) 51 | 52 | for i in range(0, len(text) - len(single_answer) + 1): 53 | if single_answer == text[i: i + len(single_answer)]: 54 | return True 55 | return False 56 | 57 | 58 | def normalized_title(ctx, answers): 59 | for answer in answers: 60 | a = a.lower().strip() 61 | title = ctx['title'].lower().strip() 62 | if a == title[:len(a)]: 63 | return True 64 | return False 65 | 66 | 67 | def regex(ctx, answers): 68 | text = ctx['text'] 69 | for answer in answers: 70 | answer = _normalize(answer) 71 | if regex_match(text, answer): 72 | return True 73 | return False 74 | 75 | 76 | def regex_match(text, pattern): 77 | """Test if a regex pattern is contained within a text.""" 78 | try: 79 | pattern = re.compile( 80 | pattern, 81 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE, 82 | ) 83 | except BaseException: 84 | return False 85 | return pattern.search(text) is not None 86 | 87 | 88 | ############################################################################### 89 | ### CALCULATION FUNCTIONS ################################################### 90 | ############################################################################### 91 | def precision_fn(results, k_vals, has_answer): 92 | n_hits = {k: 0 for k in k_vals} 93 | mrrs = [] 94 | precs = [] 95 | PREC_K = 20 96 | MRR_K = 20 97 | 98 | for result in tqdm(results): 99 | ans = result['answers'] 100 | ctxs = result['ctxs'] 101 | found_k = len(ctxs) + 1 102 | found = False 103 | num_hit = 0 104 | for c_idx,ctx in enumerate(ctxs): 105 | if has_answer(ctx, ans): 106 | if not found: 107 | found_k = c_idx # record first one 108 | found = True 109 | 110 | if c_idx < PREC_K: # P@k 111 | num_hit += 1 112 | # break 113 | for k in k_vals: 114 | if found_k < k: 115 | n_hits[k] += 1 116 | 117 | if found_k >= MRR_K: 118 | mrrs.append(0) 119 | else: 120 | mrrs.append(1/(found_k + 1)) 121 | precs.append(num_hit/PREC_K) 122 | 123 | print('*'*50) 124 | for k in k_vals: 125 | if len(results) == 0: 126 | print('No results.') 127 | else: 128 | print('Top-{} = {:.2%}'.format(k, n_hits[k] / len(results))) 129 | 130 | print(f'Acc@{k_vals[0]} when Acc@{k_vals[-1]} = {n_hits[k_vals[0]]/n_hits[k_vals[-1]]*100:.2f}%') 131 | print(f'MRR@{MRR_K} = {sum(mrrs)/len(mrrs)*100:.2f}') 132 | print(f'P@{PREC_K} = {sum(precs)/len(precs)*100:.2f}') 133 | 134 | 135 | def precision_fn_file(infile, n_docs, k_vals, has_answer, args): 136 | results = read_jsonl(infile) if args.jsonl else read_json(infile) 137 | 138 | # stats 139 | ctx_lens = [sum([len(pp['text'].split()) for pp in re['ctxs']])/len(re['ctxs']) for re in results] 140 | print(f'ctx token length: {sum(ctx_lens)/len(ctx_lens):.2f}') 141 | 142 | # unique titles 143 | title_lens = [len(set(pp['title'] for pp in re['ctxs'])) for re in results] 144 | print(f'unique titles: {sum(title_lens)/len(title_lens):.2f}') 145 | 146 | precision_fn(results, k_vals, has_answer) 147 | 148 | 
149 | # Top-20 and Top-100 150 | def precision_per_bucket(results_file, longtail_file, n_docs, k_vals, longtail_tags, ans_fn): 151 | results = read_json(results_file) 152 | annotations = read_json(longtail_file) 153 | for tag in longtail_tags: 154 | bucket = [result for idx,result in enumerate(results) if tag == annotations[idx]['annotations']] 155 | print('==== Bucket={} ====='.format(tag)) 156 | precision_fn(bucket, n_docs, k_vals, ans_fn) 157 | print() 158 | 159 | 160 | if __name__ == '__main__': 161 | parser = argparse.ArgumentParser() 162 | parser.add_argument('--results_file', required=True, type=str, default=None, 163 | help="Location of the results file to parse.") 164 | parser.add_argument('--n_docs', type=int, default=100, 165 | help="Maximum number of docs retrieved.") 166 | parser.add_argument('--k_values', type=str, default='1,5,10,20,40,50,60,80,100', 167 | help="Top-K values to print out") 168 | parser.add_argument('--ans_fn', type=str, default='has_answer', 169 | help="How to check whether has the answer. title | has_answer") 170 | parser.add_argument('--jsonl', action='store_true', help='Set if results is a jsonl file.') 171 | 172 | # Longtail Entity Analysis 173 | parser.add_argument('--longtail', action='store_true', 174 | help='whether or not to include longtail buckets') 175 | parser.add_argument('--longtail_file', required=False, type=str, default=None, 176 | help='Mapping from question to longtail entity tags.') 177 | parser.add_argument('--longtail_tags', type=str, default='p10,p25,p50,p75,p90', 178 | help='Tags for the longtail entities within longtail_file') 179 | 180 | args = parser.parse_args() 181 | ks = [int(k) for k in args.k_values.split(',')] 182 | if args.ans_fn == 'has_answer': 183 | ans_fn = has_answer_field 184 | elif args.ans_fn == 'title': 185 | ans_fn = normalized_title 186 | elif args.ans_fn == 'string': 187 | ans_fn = string_match 188 | elif args.ans_fn == 'regex': 189 | ans_fn = regex 190 | else: 191 | raise Exception('Answer function not recognized') 192 | 193 | if args.longtail: 194 | longtail_tags = args.longtail_tags.split(',') 195 | precision_per_bucket(args.results_file, args.longtail_file, 196 | args.n_docs, ks, longtail_tags, ans_fn) 197 | else: 198 | precision_fn_file(args.results_file, args.n_docs, ks, ans_fn, args) 199 | -------------------------------------------------------------------------------- /scripts/postprocess/recall_transform.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import argparse 5 | import numpy as np 6 | from spacy.lang.en import English 7 | from tqdm import tqdm 8 | 9 | nlp = English() 10 | nlp.add_pipe(nlp.create_pipe('sentencizer')) 11 | 12 | 13 | def main(args): 14 | pred_file = os.path.join(args.model_dir, 'pred', args.pred_file) 15 | my_pred = json.load(open(pred_file)) 16 | 17 | my_target = [] 18 | avg_len = [] 19 | for qid, pred in tqdm(enumerate(my_pred.values())): 20 | my_dict = {"id": str(qid), "question": None, "answers": [], "ctxs": []} 21 | 22 | # truncate 23 | pred = {key: val[:args.psg_top_k] if key in ['evidence', 'title', 'se_pos', 'prediction'] else val for key, val in pred.items()} 24 | 25 | # TODO: need to add id for predictions.pred in the future 26 | my_dict["question"] = pred["question"] 27 | my_dict["answers"] = pred["answer"] 28 | pred["title"] = [titles[0] for titles in pred["title"]] 29 | 30 | assert len(set(pred["evidence"])) == len(pred["evidence"]) == len(pred["title"]), "Should use opt2 for 
aggregation" 31 | # assert all(pr in evd for pr, evd in zip(pred["prediction"], pred["evidence"])) # prediction included 32 | 33 | # Pad up to top-k 34 | if not(len(pred["prediction"]) == len(pred["evidence"]) == len(pred["title"]) == args.psg_top_k): 35 | assert len(pred["prediction"]) == len(pred["evidence"]) == len(pred["title"]) < args.psg_top_k, \ 36 | (len(pred["prediction"]), len(pred["evidence"]), len(pred["title"])) 37 | print(len(pred["prediction"]), len(pred["evidence"]), len(pred["title"])) 38 | 39 | pred["evidence"] += [pred["evidence"][-1]] * (args.psg_top_k - len(pred["prediction"])) 40 | pred["title"] += [pred["title"][-1]] * (args.psg_top_k - len(pred["prediction"])) 41 | pred["se_pos"] += [pred["se_pos"][-1]] * (args.psg_top_k - len(pred["prediction"])) 42 | pred["prediction"] += [pred["prediction"][-1]] * (args.psg_top_k - len(pred["prediction"])) 43 | assert len(pred["prediction"]) == len(pred["evidence"]) == len(pred["title"]) == args.psg_top_k 44 | 45 | # Used for markers 46 | START = '' 47 | END = '' 48 | se_idxs = [[se_pos[0], max(se_pos[0], se_pos[1])] for se_pos in pred["se_pos"]] 49 | 50 | # Return sentence 51 | if args.return_sent: 52 | sents = [[(X.text, X[0].idx) for X in nlp(evidence).sents] for evidence in pred['evidence']] 53 | sent_idxs = [ 54 | sorted(set([sum(np.array([st[1] for st in sent]) <= se_idx[0]) - 1] + [sum(np.array([st[1] for st in sent]) <= se_idx[1]-1) - 1])) 55 | for se_idx, sent in zip(se_idxs, sents) 56 | ] 57 | se_idxs = [[se_pos[0]-sent[sent_idx[0]][1], se_pos[1]-sent[sent_idx[0]][1]] for se_pos, sent_idx, sent in zip(se_idxs, sent_idxs, sents)] 58 | if not all(pred.replace(' ', '') in ' '.join([sent[sidx][0] for sidx in range(sent_idx[0], sent_idx[-1]+1)]).replace(' ', '') 59 | for pred, sent, sent_idx in zip(pred['prediction'], sents, sent_idxs)): 60 | import pdb; pdb.set_trace() 61 | pass 62 | 63 | # get sentence based on the window 64 | max_context_len = args.max_context_len - 2 if args.mark_phrase else args.max_context_len 65 | my_dict["ctxs"] = [ 66 | # {"title": title, "text": ' '.join(' '.join([sent[sidx][0] for sidx in range(sent_idx[0], sent_idx[-1]+1)]).split()[:max_context_len])} 67 | {"title": title, "text": ' '.join(' '.join([sent[sidx][0] for sidx in range( 68 | max(0, sent_idx[0]-args.sent_window), min(sent_idx[-1]+1+args.sent_window, len(sent)))] 69 | ).split()[:max_context_len]) 70 | } 71 | for title, sent, sent_idx in zip(pred["title"], sents, sent_idxs) 72 | ] 73 | # Return passagae 74 | else: 75 | my_dict["ctxs"] = [ 76 | {"title": title, "text": ' '.join(evd.split()[:args.max_context_len])} 77 | for evd, title in zip(pred["evidence"], pred["title"]) 78 | ] 79 | 80 | # Add markers for predicted phrases 81 | if args.mark_phrase: 82 | my_dict["ctxs"] = [ 83 | {"title": ctx["title"], "text": ctx["text"][:se[0]] + f"{START} " + ctx["text"][se[0]:se[1]] + f" {END}" + ctx["text"][se[1]:]} 84 | for ctx, se in zip(my_dict["ctxs"], se_idxs) 85 | ] 86 | 87 | my_target.append(my_dict) 88 | avg_len += [len(ctx['text'].split()) for ctx in my_dict["ctxs"]] 89 | assert len(my_dict["ctxs"]) == args.psg_top_k 90 | assert all(len(ctx['text'].split()) <= args.max_context_len for ctx in my_dict["ctxs"]) 91 | 92 | print(f"avg ctx len={sum(avg_len)/len(avg_len):.2f} for {len(my_pred)} preds") 93 | 94 | out_file = os.path.join( 95 | args.model_dir, 'pred', 96 | os.path.splitext(args.pred_file)[0] + 97 | f'_{"sent" if args.return_sent else "psg"}-top{args.psg_top_k}{"_mark" if args.mark_phrase else ""}.json' 98 | ) 99 | print(f"dump to 
{out_file}") 100 | json.dump(my_target, open(out_file, 'w'), indent=4) 101 | 102 | 103 | if __name__ == '__main__': 104 | parser = argparse.ArgumentParser() 105 | 106 | parser.add_argument('--model_dir', type=str, default='') 107 | parser.add_argument('--pred_file', type=str, default='') 108 | parser.add_argument('--psg_top_k', type=int, default=100) 109 | parser.add_argument('--max_context_len', type=int, default=999999999) 110 | parser.add_argument('--mark_phrase', default=False, action='store_true') 111 | parser.add_argument('--return_sent', default=False, action='store_true') 112 | parser.add_argument('--sent_window', type=int, default=0) 113 | args = parser.parse_args() 114 | 115 | main(args) 116 | -------------------------------------------------------------------------------- /scripts/preprocess/README.md: -------------------------------------------------------------------------------- 1 | ## Create SQuAD-Style Wiki Dump (20181220) 2 | 3 | ### Download wiki dump of 20181220 4 | ``` 5 | python download_wikidump.py \ 6 | --output_dir /hdd1/data/wikidump 7 | ``` 8 | 9 | ### Extract Wiki dump via Wikiextractor 10 | Use [Wikiextractor](https://github.com/attardi/wikiextractor) to convert wiki dump into the json style. 11 | 12 | ``` 13 | python WikiExtractor.py \ 14 | --filter_disambig_pages \ 15 | --json \ 16 | -o /hdd1/data/wikidump/extracted/ \ 17 | /hdd1/data/wikidump/enwiki-20181220-pages-articles.xml.bz2 18 | ``` 19 | 20 | ### Build docs.db in SQlite style 21 | ``` 22 | python build_db.py \ 23 | --data_path /hdd1/data/wikidump/extracted \ 24 | --save_path /hdd1/data/wikidump/docs_20181220.db \ 25 | --preprocess prep_wikipedia.py 26 | ``` 27 | 28 | ### Transform sqlite to squad-style 29 | ``` 30 | python build_wikisquad.py \ 31 | --db_path /hdd1/data/wikidump/docs_20181220.db \ 32 | --out_dir /hdd1/data/wikidump/20181220 33 | ``` 34 | 35 | ### Concatenate short length of paragraphs 36 | ``` 37 | python concat_wikisquad.py \ 38 | --input_dir /hdd1/data/wikidump/20181220 \ 39 | --output_dir /hdd1/data/wikidump/20181220_concat 40 | ``` 41 | -------------------------------------------------------------------------------- /scripts/preprocess/build_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | # https://github.com/facebookresearch/DrQA/blob/master/scripts/retriever/build_db.py 8 | """A script to read in and store documents in a sqlite database.""" 9 | 10 | import argparse 11 | import sqlite3 12 | import json 13 | import os 14 | import logging 15 | import importlib.util 16 | 17 | from multiprocessing import Pool as ProcessPool 18 | from tqdm import tqdm 19 | import unicodedata 20 | 21 | logger = logging.getLogger() 22 | logger.setLevel(logging.INFO) 23 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 24 | console = logging.StreamHandler() 25 | console.setFormatter(fmt) 26 | logger.addHandler(console) 27 | 28 | # ------------------------------------------------------------------------------ 29 | # Utils 30 | # ------------------------------------------------------------------------------ 31 | 32 | def normalize(text): 33 | """Resolve different type of unicode encodings.""" 34 | return unicodedata.normalize('NFD', text) 35 | 36 | # ------------------------------------------------------------------------------ 37 | # Import helper 38 | # ------------------------------------------------------------------------------ 39 | 40 | 41 | PREPROCESS_FN = None 42 | 43 | 44 | def init(filename): 45 | global PREPROCESS_FN 46 | if filename: 47 | PREPROCESS_FN = import_module(filename).preprocess 48 | 49 | 50 | def import_module(filename): 51 | """Import a module given a full path to the file.""" 52 | spec = importlib.util.spec_from_file_location('doc_filter', filename) 53 | module = importlib.util.module_from_spec(spec) 54 | spec.loader.exec_module(module) 55 | return module 56 | 57 | 58 | # ------------------------------------------------------------------------------ 59 | # Store corpus. 60 | # ------------------------------------------------------------------------------ 61 | 62 | 63 | def iter_files(path): 64 | """Walk through all files located under a root path.""" 65 | if os.path.isfile(path): 66 | yield path 67 | elif os.path.isdir(path): 68 | for dirpath, _, filenames in os.walk(path): 69 | for f in filenames: 70 | yield os.path.join(dirpath, f) 71 | else: 72 | raise RuntimeError('Path %s is invalid' % path) 73 | 74 | 75 | def get_contents(filename): 76 | """Parse the contents of a file. Each line is a JSON encoded document.""" 77 | global PREPROCESS_FN 78 | documents = [] 79 | with open(filename) as f: 80 | for line in f: 81 | # Parse document 82 | doc = json.loads(line) 83 | # Maybe preprocess the document with custom function 84 | if PREPROCESS_FN: 85 | doc = PREPROCESS_FN(doc) 86 | # Skip if it is empty or None 87 | if not doc: 88 | continue 89 | # Add the document 90 | documents.append((normalize(doc['id']), doc['text'])) 91 | return documents 92 | 93 | 94 | def store_contents(data_path, save_path, preprocess, num_workers=None): 95 | """Preprocess and store a corpus of documents in sqlite. 96 | Args: 97 | data_path: Root path to directory (or directory of directories) of files 98 | containing json encoded documents (must have `id` and `text` fields). 99 | save_path: Path to output sqlite db. 100 | preprocess: Path to file defining a custom `preprocess` function. Takes 101 | in and outputs a structured doc. 102 | num_workers: Number of parallel processes to use when reading docs. 103 | """ 104 | if os.path.isfile(save_path): 105 | raise RuntimeError('%s already exists! Not overwriting.' 
% save_path) 106 | 107 | logger.info('Reading into database...') 108 | conn = sqlite3.connect(save_path) 109 | c = conn.cursor() 110 | c.execute("CREATE TABLE documents (id PRIMARY KEY, text);") 111 | 112 | workers = ProcessPool(num_workers, initializer=init, initargs=(preprocess,)) 113 | files = [f for f in iter_files(data_path)] 114 | count = 0 115 | with tqdm(total=len(files)) as pbar: 116 | for pairs in tqdm(workers.imap_unordered(get_contents, files)): 117 | count += len(pairs) 118 | c.executemany("INSERT OR IGNORE INTO documents VALUES (?,?)", pairs) 119 | pbar.update() 120 | logger.info('Read %d docs.' % count) 121 | logger.info('Committing...') 122 | conn.commit() 123 | conn.close() 124 | 125 | 126 | # ------------------------------------------------------------------------------ 127 | # Main. 128 | # ------------------------------------------------------------------------------ 129 | 130 | 131 | if __name__ == '__main__': 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument('--data_path', type=str, help='/path/to/data') 134 | parser.add_argument('--save_path', type=str, help='/path/to/saved/db.db') 135 | parser.add_argument('--preprocess', type=str, default=None, 136 | help=('File path to a python module that defines ' 137 | 'a `preprocess` function')) 138 | parser.add_argument('--num-workers', type=int, default=None, 139 | help='Number of CPU processes (for tokenizing, etc)') 140 | args = parser.parse_args() 141 | 142 | store_contents( 143 | args.data_path, args.save_path, args.preprocess, args.num_workers 144 | ) -------------------------------------------------------------------------------- /scripts/preprocess/compress_metadata.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import os 3 | import h5py 4 | from tqdm import tqdm 5 | import sys 6 | import zlib 7 | import numpy as np 8 | import traceback 9 | import blosc 10 | import pickle 11 | import argparse 12 | 13 | # get size of the whole metadata 14 | def get_size(d): 15 | size = 0 16 | for i in d: 17 | word2char_start_size = sys.getsizeof(d[i]['word2char_start']) 18 | word2char_end_size = sys.getsizeof(d[i]['word2char_end']) 19 | f2o_start_size = sys.getsizeof(d[i]['f2o_start']) 20 | context_size = sys.getsizeof(d[i]['context']) 21 | title_size = sys.getsizeof(d[i]['title']) 22 | size+=word2char_start_size 23 | size+=word2char_end_size 24 | size+=f2o_start_size 25 | size+=context_size 26 | size+=title_size 27 | 28 | return size 29 | 30 | # compress metadata using zlib 31 | # http://python-blosc.blosc.org/tutorial.html 32 | def compress(d): 33 | for i in d: 34 | word2char_start = d[i]['word2char_start'] 35 | word2char_end = d[i]['word2char_end'] 36 | f2o_start = d[i]['f2o_start'] 37 | context=d[i]['context'] 38 | title=d[i]['title'] 39 | 40 | # save type to use when decompressing 41 | type1= word2char_start.dtype 42 | type2= word2char_end.dtype 43 | type3= f2o_start.dtype 44 | 45 | d[i]['word2char_start'] = blosc.compress(word2char_start, typesize=1,cname='zlib') 46 | d[i]['word2char_end'] = blosc.compress(word2char_end, typesize=1,cname='zlib') 47 | d[i]['f2o_start'] = blosc.compress(f2o_start, typesize=1,cname='zlib') 48 | d[i]['context'] = blosc.compress(context.encode('utf-8'),cname='zlib') 49 | d[i]['dtypes']={ 50 | 'word2char_start':type1, 51 | 'word2char_end':type2, 52 | 'f2o_start':type3 53 | } 54 | 55 | # check if compression is lossless 56 | try: 57 | decompressed_word2char_start = np.frombuffer(blosc.decompress(d[i]['word2char_start']), type1) 58 
| decompressed_word2char_end = np.frombuffer(blosc.decompress(d[i]['word2char_end']), type2) 59 | decompressed_f2o_start = np.frombuffer(blosc.decompress(d[i]['f2o_start']), type3) 60 | decompressed_context = blosc.decompress(d[i]['context']).decode('utf-8') 61 | 62 | assert ((word2char_start == decompressed_word2char_start).all()) 63 | assert ((word2char_end == decompressed_word2char_end).all()) 64 | assert ((f2o_start ==decompressed_f2o_start).all()) 65 | assert (context == decompressed_context) 66 | except Exception as e: 67 | print(e) 68 | traceback.print_exc() 69 | pdb.set_trace() 70 | return d 71 | 72 | def load_doc_groups(phrase_dump_dir): 73 | phrase_dump_paths = sorted( 74 | [os.path.join(phrase_dump_dir, name) for name in os.listdir(phrase_dump_dir) if 'hdf5' in name] 75 | ) 76 | doc_groups = {} 77 | types = ['word2char_start', 'word2char_end', 'f2o_start'] 78 | attrs = ['context', 'title'] 79 | phrase_dumps = [h5py.File(path, 'r') for path in phrase_dump_paths] 80 | phrase_dumps = phrase_dumps[:1] 81 | for path in tqdm(phrase_dump_paths, desc='loading doc groups'): 82 | with h5py.File(path, 'r') as f: 83 | for key in tqdm(f): 84 | doc_group = {} 85 | for type_ in types: 86 | doc_group[type_] = f[key][type_][:] 87 | for attr in attrs: 88 | doc_group[attr] = f[key].attrs[attr] 89 | doc_groups[key] = doc_group 90 | 91 | return doc_groups 92 | 93 | def main(args): 94 | # Use it for saving to memory 95 | doc_groups = load_doc_groups(args.input_dump_dir) 96 | 97 | # Get the size of meta data before compression 98 | size_before_compression = get_size(doc_groups) 99 | 100 | # compress metadata using zlib 101 | doc_groups = compress(doc_groups) 102 | 103 | # Get the size of meta data before compression 104 | size_after_compression = get_size(doc_groups) 105 | 106 | print(f"compressed by {round(size_after_compression/size_before_compression*100,2)}%") 107 | 108 | # save compressed meta as a pickle format 109 | output_file = os.path.join(args.output_dir, 'meta_compressed.pkl') 110 | with open(output_file,'wb') as f: 111 | pickle.dump(doc_groups, f) 112 | 113 | if __name__ == '__main__': 114 | parser = argparse.ArgumentParser() 115 | 116 | parser.add_argument('--input_dump_dir', type=str, default='dump/sbcd_sqdqgnqqg_inb64_s384_sqdnq_pinb2_0_20181220_concat/dump/phrase') 117 | parser.add_argument('--output_dir', type=str, default='./') 118 | args = parser.parse_args() 119 | 120 | main(args) 121 | -------------------------------------------------------------------------------- /scripts/preprocess/concat_wikisquad.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | 5 | from tqdm import tqdm 6 | import pdb 7 | 8 | def normalize(text): 9 | return text.lower().replace('_', ' ') 10 | 11 | 12 | def concat_wikisquad(args): 13 | names = os.listdir(args.input_dir) 14 | data = {'data': []} 15 | for name in tqdm(names): 16 | from_path = os.path.join(args.input_dir, name) 17 | with open(from_path, 'r') as fp: 18 | from_ = json.load(fp) 19 | 20 | for ai, article in enumerate(from_['data']): 21 | article['id'] = int(name) * 1000 + ai 22 | 23 | articles = [] 24 | for article in from_['data']: 25 | articles.append(article) 26 | 27 | for article in articles: 28 | to_article = {'title': article['title'], 'paragraphs': []} 29 | context = "" 30 | for para_idx, para in enumerate(article['paragraphs']): 31 | context = context + " " + para['context'] 32 | if args.min_num_chars <= len(context): 33 | 
to_article['paragraphs'].append({'context': context}) 34 | context = "" 35 | # if the length of the last paragraph is less than min_num_chars, 36 | # append it to the previous saving 37 | elif para_idx == len(article['paragraphs']) -1 : 38 | if len(to_article['paragraphs']): 39 | previous_context = to_article['paragraphs'][-1]['context'] 40 | previous_context = previous_context + " " + context 41 | to_article['paragraphs'][-1]['context'] = previous_context 42 | # if no previous saving exists, create it. 43 | else: 44 | to_article['paragraphs'].append({'context': context}) 45 | 46 | data['data'].append(to_article) 47 | 48 | if not os.path.exists(args.output_dir): 49 | os.makedirs(args.output_dir) 50 | for start_idx in range(0, len(data['data']), args.docs_per_file): 51 | to_path = os.path.join(args.output_dir, str(int(start_idx / args.docs_per_file)).zfill(4)) 52 | cur_data = {'data': data['data'][start_idx:start_idx + args.docs_per_file]} 53 | with open(to_path, 'w') as fp: 54 | json.dump(cur_data, fp) 55 | 56 | def get_args(): 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--input_dir') 59 | parser.add_argument('--output_dir') 60 | parser.add_argument('--min_num_chars', default=500, type=int) 61 | parser.add_argument('--docs_per_file', default=1000, type=int) 62 | 63 | return parser.parse_args() 64 | 65 | 66 | def main(): 67 | args = get_args() 68 | concat_wikisquad(args) 69 | 70 | if __name__ == '__main__': 71 | main() -------------------------------------------------------------------------------- /scripts/preprocess/create_nq_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import pdb 4 | import glob 5 | from nq_utils import load_examples 6 | 7 | def convert_tokens_to_answer(paragraph_tokens, answer_tokens): 8 | answer_token_indexes = [] 9 | for answer_token in answer_tokens: 10 | answer_token_index = paragraph_tokens.index(answer_token) 11 | answer_token_indexes.append(answer_token_index) 12 | 13 | if len(answer_token_indexes) != (answer_token_indexes[-1] - answer_token_indexes[0] + 1): 14 | print("answer_token_indexes=",answer_token_indexes) 15 | pdb.set_trace() 16 | 17 | 18 | context = "" 19 | answer_text = "" 20 | answer_start = -1 21 | for i, paragraph_token in enumerate(paragraph_tokens): 22 | # skip html token 23 | if not paragraph_token['html_token']: 24 | token = paragraph_token['token'] 25 | 26 | # prepare appending token with white space 27 | if context != "": context +=" " 28 | 29 | # update answer_start 30 | if i == answer_token_indexes[0]: 31 | answer_start = len(context) 32 | 33 | # append token 34 | context += token 35 | 36 | # update answer_end 37 | if i == answer_token_indexes[-1]: 38 | answer_end = len(context) 39 | 40 | answer_text = context[answer_start:answer_end] 41 | 42 | # sanity check 43 | assert context != "" 44 | assert answer_text != "" 45 | assert answer_start != -1 46 | 47 | return context, answer_text, answer_start 48 | 49 | def main(args): 50 | # load nq_open and get ids 51 | with open(args.nq_open_path, 'r') as f: 52 | nq_open_data = json.load(f)['data'] 53 | nq_open_ids = [qas['id'] for qas in nq_open_data] 54 | 55 | # load nq_orig 56 | nq_orig_paths = sorted(glob.glob(args.nq_orig_path_pattern)) 57 | nq_reader_data = [] 58 | for i, nq_orig_path in enumerate(nq_orig_paths): 59 | with open(nq_orig_path, mode='rb') as fileobj: 60 | examples = load_examples(fileobj, 'train', 'short_answers') 61 | 62 | # filter examples contained in nq_open ids 63 | 
examples = dict(filter(lambda x: int(x[0]) in nq_open_ids, list(examples.items()))) 64 | 65 | for example_id, example in examples.items(): 66 | # filter candidates with answers 67 | candidates = list(filter(lambda x: x.contains_answer, example.candidates)) 68 | if len(candidates) == 0: 69 | continue 70 | 71 | title = example.title 72 | # TODO! consider multi annotation for nq_orig_dev set 73 | short_answers = example.short_answers[0] # assume single annotation 74 | paragraphs=[] 75 | 76 | for candidate in candidates: 77 | # filter
<P> examples 78 | contents = candidate.contents 79 | is_paragraph = contents.startswith('<P>
') 80 | start_token = candidate.start_token 81 | end_token = candidate.end_token 82 | tokens = example.document_tokens[start_token:end_token] 83 | 84 | answers = [] 85 | for short_answer in short_answers: 86 | answer_start_token = short_answer['start_token'] 87 | answer_end_token = short_answer['end_token'] 88 | if answer_end_token-answer_start_token>5: 89 | continue 90 | answer_tokens = example.document_tokens[answer_start_token:answer_end_token] 91 | # convert tokens to context, answer_text, answer_start 92 | context, answer_text, answer_start = convert_tokens_to_answer(tokens, answer_tokens) 93 | answers.append({ 94 | 'text': answer_text, 95 | 'answer_start': answer_start 96 | }) 97 | 98 | qas = [{ 99 | 'question':example.question_text, 100 | 'is_impossible': False if is_paragraph else True, 101 | 'answers':answers, 102 | 'is_distant': False, 103 | 'id':int(example_id), 104 | }] 105 | paragraphs.append({ 106 | 'context':context, 107 | 'qas':qas 108 | }) 109 | nq_reader_data.append({ 110 | 'title': title, 111 | 'paragraphs':paragraphs 112 | }) 113 | 114 | nq_reader = { 115 | 'data' : nq_reader_data 116 | } 117 | # save nq_reader 118 | with open(args.output_path,'w') as f: 119 | json.dump(nq_reader, f, indent=2) 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | 124 | # Required parameters 125 | parser.add_argument( 126 | "--nq_open_path", 127 | default=None, 128 | type=str, 129 | required=True, 130 | help="nq-open path (eg. nq-open/dev.json)" 131 | ) 132 | parser.add_argument( 133 | "--nq_orig_path_pattern", 134 | default=None, 135 | type=str, 136 | required=True, 137 | help="nq-open path (eg. natural-questions/train/nq-train-*.jsonl.gz)" 138 | ) 139 | parser.add_argument( 140 | "--output_path", 141 | default=None, 142 | type=str, 143 | required=True, 144 | help="nq-reader directory (eg. 
nq-reader/dev.json)" 145 | ) 146 | 147 | args = parser.parse_args() 148 | 149 | main(args) 150 | 151 | -------------------------------------------------------------------------------- /scripts/preprocess/create_nq_reader_doc_wiki.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | import argparse 4 | from tqdm import tqdm 5 | import os 6 | 7 | def main(args): 8 | wiki_dir = args.wiki_dir 9 | nq_dir = args.nq_reader_docs_dir 10 | output_dir = args.output_dir 11 | 12 | wiki_file_list = glob.glob(os.path.join(wiki_dir,"*")) 13 | wiki_titles = [] 14 | num_wiki = 0 15 | wiki_title2paragraphs = {} 16 | for filename in tqdm(wiki_file_list, total=len(wiki_file_list)): 17 | with open(filename,'r') as f: 18 | data = json.load(f)['data'] 19 | 20 | for doc in data: 21 | title = doc['title'] 22 | wiki_titles.append(title) 23 | paragraph = doc['paragraphs'] 24 | wiki_title2paragraphs[title] = paragraph 25 | num_wiki += 1 26 | 27 | assert len(wiki_title2paragraphs) == num_wiki 28 | 29 | nq_file_list = glob.glob(os.path.join(nq_dir,"*")) 30 | nq_titles = [] 31 | unmatched_titles = [] 32 | num_matched = 0 33 | num_unmatched = 0 34 | for filename in tqdm(nq_file_list, total=len(nq_file_list)): 35 | with open(filename,'r') as f: 36 | data = json.load(f)['data'] 37 | 38 | for doc in data: 39 | title = doc['title'] 40 | nq_titles.append(title) 41 | if title in wiki_title2paragraphs: 42 | doc['paragraphs'] = wiki_title2paragraphs[title] 43 | num_matched += 1 44 | else: 45 | unmatched_titles.append(title) 46 | num_unmatched +=1 47 | 48 | new_paragraphs = [] 49 | for paragraph in doc['paragraphs']: 50 | if ('is_paragraph' in paragraph) and (not paragraph['is_paragraph']): 51 | continue 52 | 53 | new_paragraphs.append({ 54 | 'context': paragraph['context'] 55 | }) 56 | doc['paragraphs'] = new_paragraphs 57 | 58 | if not os.path.exists(output_dir): 59 | os.mkdir(output_dir) 60 | 61 | output_path = os.path.join(output_dir,os.path.basename(filename)) 62 | output = { 63 | 'data': data 64 | } 65 | 66 | with open(output_path, 'w') as f: 67 | json.dump(output, f, indent=2) 68 | 69 | # with open('unmatched_title.txt', 'w') as f: 70 | # for title in unmatched_titles: 71 | # if 'list of' in title: 72 | # continue 73 | # f.writelines(title) 74 | # f.writelines("\n") 75 | 76 | print("num_matched={} num_unmatched={}".format(num_matched, num_unmatched)) 77 | print("len(nq_titles)={} len(wiki_titles)={}".format(len(nq_titles), len(wiki_titles))) 78 | 79 | if __name__ == '__main__': 80 | parser = argparse.ArgumentParser() 81 | 82 | # Required parameters 83 | parser.add_argument("--wiki_dir", type=str, required=True) 84 | parser.add_argument("--nq_reader_docs_dir", type=str, required=True) 85 | parser.add_argument("--output_dir", type=str, required=True) 86 | 87 | args = parser.parse_args() 88 | 89 | main(args) 90 | 91 | -------------------------------------------------------------------------------- /scripts/preprocess/create_nq_reader_wiki.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | import glob 5 | import copy 6 | 7 | from tqdm import tqdm 8 | 9 | 10 | def nq_to_wiki(input_file, output_dir, wiki_dump): 11 | with open(input_file, 'r') as f: 12 | nq_data = json.load(f)['data'] 13 | 14 | para_cnt = 0 15 | match_cnt = 0 16 | title_not_found_cnt = 0 17 | answer_not_found_cnt = 0 18 | tokenize_error = 0 19 | WINDOW = 10 20 | new_data = [] 21 | for article in tqdm(nq_data): 
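        # For each NQ article the loop below tries to re-anchor every gold answer
        # inside the corresponding wikidump article: the answer (plus up to
        # WINDOW=10 characters of context on either side) and each wiki paragraph
        # are lowercased, quote-normalized and stripped of whitespace, and
        # `nosp_to_sp` maps positions in the stripped paragraph back to character
        # offsets in the original context so `answer_start` can be recomputed.
        # If no wiki paragraph matches, the original paragraph is kept and its
        # answers are tagged with `wiki_matched: False`.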
22 | title = article['title'] if type(article['title']) != list else article['title'][0] 23 | 24 | assert len(article['paragraphs']) == 1 25 | for paragraph in article['paragraphs']: 26 | para_cnt += 1 27 | new_paragraph = None 28 | answer_found = False 29 | assert len(paragraph['qas']) == 1, 'NQ only has single para for each Q' 30 | 31 | # We skip these cases and use existing paras 32 | qa = paragraph['qas'][0] 33 | if 'redundant' in str(qa['id']): 34 | break 35 | 36 | if qa['is_impossible'] or (title not in wiki_dump): 37 | pass 38 | else: 39 | # Or we find matching answers 40 | answers = qa['answers'] if type(qa['answers']) == list else [qa['answers']] 41 | for answer in answers: 42 | start_window = WINDOW if WINDOW < answer['answer_start'] else answer['answer_start'] 43 | 44 | answer_text = paragraph['context'][ 45 | answer['answer_start']:answer['answer_start']+len(answer['text']) 46 | ].replace('\'\'', '"').replace('``', '"').replace(' ', '').lower() 47 | 48 | answer_text_with_context = [ 49 | paragraph['context'][ # Front/Back 10 chars 50 | answer['answer_start']-start_window:answer['answer_start']+len(answer['text'])+WINDOW 51 | ].replace('\'\'', '"').replace('``', '"').replace(' ', '').lower(), 52 | paragraph['context'][ # Front 10 chars 53 | answer['answer_start']-start_window:answer['answer_start']+len(answer['text']) 54 | ].replace('\'\'', '"').replace('``', '"').replace(' ', '').lower(), 55 | paragraph['context'][ # Back 10 chars 56 | answer['answer_start']:answer['answer_start']+len(answer['text'])+WINDOW 57 | ].replace('\'\'', '"').replace('``', '"').replace(' ', '').lower(), 58 | ] 59 | 60 | new_start = None 61 | wiki_paragraph = None 62 | for wiki_par in wiki_dump[title]: 63 | wiki_par_char = ''.join([char.lower()[0] for char in wiki_par['context'].replace(' ', '')]) 64 | nosp_to_sp = {} 65 | for sp_idx, char in enumerate(wiki_par['context']): 66 | if char != ' ': 67 | nosp_to_sp[len(nosp_to_sp)] = sp_idx 68 | assert len(nosp_to_sp) == len(wiki_par_char) 69 | 70 | # Context match 71 | if any([at_with_context in wiki_par_char for at_with_context in answer_text_with_context]): 72 | at_with_context = [at for at in answer_text_with_context if at in wiki_par_char][0] 73 | tmp_start = wiki_par_char.index(at_with_context) 74 | if len([at for at in answer_text_with_context if at in wiki_par_char]) < 3: 75 | if at_with_context == answer_text: # There are some false negatives but we skip 76 | # print(paragraph['context']) 77 | # print(wiki_par['context']) 78 | # print(answer_text) 79 | # import pdb; pdb.set_trace() 80 | break 81 | # try: 82 | new_start = nosp_to_sp[wiki_par_char[tmp_start:].index(answer_text)+tmp_start] 83 | new_end = nosp_to_sp[wiki_par_char[tmp_start:].index(answer_text)+tmp_start+len(answer_text)-1] 84 | wiki_paragraph = copy.deepcopy(wiki_par['context']) 85 | # except ValueError as e: 86 | # print("Could not found start position after de-tokenize") 87 | # tokenize_error += 1 88 | # import pdb; pdb.set_trace() 89 | # continue 90 | answer_found = True 91 | break 92 | # elif answer_text in wiki_par_char: 93 | # answer_found = True 94 | 95 | # If answer is found, append 96 | if new_start is not None: 97 | if answer_text != wiki_par['context'][new_start:new_end+1].lower().replace(' ', ''): 98 | print('mismatch between original vs. new answer: {} vs. 
{}'.format( 99 | answer_text, wiki_par['context'][new_start:new_end+1].lower().replace(' ', '') 100 | )) 101 | 102 | if new_paragraph is None: 103 | new_paragraph = copy.deepcopy(paragraph) 104 | new_paragraph['context'] = wiki_paragraph 105 | new_paragraph['qas'][0]['answers'] = [{ 106 | 'text': wiki_paragraph[new_start:new_end+1], 107 | 'answer_start': new_start, 108 | 'wiki_matched': True, 109 | }] 110 | else: 111 | if new_paragraph['context'] != wiki_paragraph: # If other answers are in different para, we skip 112 | continue 113 | new_paragraph['qas'][0]['answers'].append({ 114 | 'text': wiki_paragraph[new_start:new_end+1], 115 | 'answer_start': new_start, 116 | 'wiki_matched': True, 117 | }) 118 | 119 | # Just use existing paragraph when no answer is found 120 | if not answer_found: 121 | answer_not_found_cnt += 1 122 | new_paragraph = copy.deepcopy(paragraph) 123 | for qas in new_paragraph['qas']: 124 | for ans in qas['answers']: 125 | ans['wiki_matched'] = False 126 | else: 127 | match_cnt += 1 128 | 129 | assert new_paragraph is not None 130 | new_data.append({ 131 | 'title': title, 132 | 'paragraphs': [new_paragraph], 133 | }) 134 | 135 | print(f'matched title: {para_cnt}') 136 | print(f'not found title: {title_not_found_cnt}') 137 | print(f'matched answer: {match_cnt}') 138 | print(f'answer not found: {answer_not_found_cnt}') 139 | print(f'tokenize error: {tokenize_error}') 140 | print(f'total saved data: {len(new_data)}') 141 | 142 | output_path = os.path.join( 143 | os.path.dirname(input_file), os.path.splitext(os.path.basename(input_file))[0] + '_wiki3.json' 144 | ) 145 | print(f'Saving into {output_path}') 146 | with open(output_path, 'w') as f: 147 | json.dump({'data': new_data}, f) 148 | print() 149 | 150 | 151 | if __name__ == '__main__': 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument('input_files', type=str, default=None) 154 | parser.add_argument('output_dir', type=str) 155 | parser.add_argument('wiki_dir', type=str, default=None) 156 | args = parser.parse_args() 157 | 158 | # Prepare wiki first 159 | wiki_files = sorted(glob.glob(args.wiki_dir + "*")) 160 | print(f'Matching with {len(wiki_files)} number of wikisquad files') 161 | wiki_dump = {} 162 | for wiki_file in tqdm(wiki_files): 163 | with open(wiki_file, 'r') as f: 164 | wiki_squad = json.load(f) 165 | for wiki_article in wiki_squad['data']: 166 | wiki_dump[wiki_article['title']] = wiki_article['paragraphs'] 167 | # break 168 | 169 | for input_file in args.input_files.split(','): 170 | print(f'Processing {input_file}') 171 | nq_to_wiki(input_file, args.output_dir, wiki_dump) 172 | -------------------------------------------------------------------------------- /scripts/preprocess/create_openqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | import csv 5 | 6 | from tqdm import tqdm 7 | # from drqa.retriever.utils import normalize 8 | 9 | def get_gold_answers_kilt(gold): 10 | ground_truths = set() 11 | for item in gold["output"]: 12 | if "answer" in item and item["answer"] and len(item["answer"].strip()) > 0: 13 | ground_truths.add(item["answer"].strip()) 14 | return ground_truths 15 | 16 | def preprocess_openqa(input_file, input_type, out_dir): 17 | data_to_save = [] 18 | # SQuAD 19 | if input_type == 'SQuAD': 20 | with open(input_file, 'r') as f: 21 | articles = json.load(f)['data'] 22 | for article in articles: 23 | for paragraph in article['paragraphs']: 24 | for qa in paragraph['qas']: 25 | if 
type(qa['answers']) == dict: 26 | qa['answers'] = [qa['answers']] 27 | data_to_save.append({ 28 | 'id': qa['id'], 29 | 'question': qa['question'], 30 | 'answers': [ans['text'] for ans in qa['answers']] 31 | }) 32 | # CuratedTrec / WebQuestions / WikiMovies 33 | elif input_type == 'DrQA': 34 | tag = os.path.splitext(os.path.basename(input_file))[0] 35 | for line_idx, line in tqdm(enumerate(open(input_file))): 36 | data = json.loads(line) 37 | # answers = [normalize(a) for a in data['answer']] # necessary? 38 | answers = [a for a in data['answer']] 39 | data_to_save.append({ 40 | 'id': f'{tag}_{line_idx}', 41 | 'question': data['question'], 42 | 'answers': answers 43 | }) 44 | # NaturalQuestions / TriviaQA 45 | elif input_type == 'HardEM': 46 | tag = os.path.splitext(os.path.basename(input_file))[0] 47 | data = json.load(open(input_file))['data'] 48 | for item_idx, item in tqdm(enumerate(data)): 49 | data_to_save.append({ 50 | 'id': f'{tag}_{item_idx}', 51 | 'question': item['question'], 52 | 'answers': item['answers'] 53 | }) 54 | # DPR style files 55 | elif input_type == 'DPR': 56 | tag = os.path.splitext(os.path.basename(input_file))[0] 57 | data = json.load(open(input_file)) 58 | for item_idx, item in tqdm(enumerate(data)): 59 | data_to_save.append({ 60 | 'id': f'{tag}_{item_idx}', 61 | 'question': item['question'], 62 | 'answers': item['answers'] 63 | }) 64 | # COVID-19 65 | elif input_type == 'COVID-19': 66 | assert os.path.isdir(input_file) 67 | for filename in os.listdir(input_file): 68 | if 'preprocessed' in filename: 69 | print(f'Skipping {filename}') 70 | continue 71 | file_path = os.path.join(input_file, filename) 72 | tag = os.path.splitext(os.path.basename(file_path))[0] 73 | with open(file_path, 'r') as f: 74 | with tqdm(enumerate(f)) as tq: 75 | tq.set_description(filename + '\t') 76 | for line_idx, line in tq: 77 | data_to_save.append({ 78 | 'id': f'{tag}_{line_idx}', 79 | 'question': line.strip(), 80 | 'answers': [''] 81 | }) 82 | # TREX, ZSRE (KILT) 83 | elif input_type.lower() in ['trex', 't-rex', 'zsre']: 84 | with open(input_file) as f: 85 | for line in tqdm(f): 86 | data = json.loads(line) 87 | id = data['id'] 88 | question = data['input'] 89 | answers = get_gold_answers_kilt(data) 90 | answers = list(answers) 91 | 92 | data_to_save.append({ 93 | 'id': id, 94 | 'question': question, 95 | 'answers': answers 96 | }) 97 | # Jsonl (LAMA) 98 | elif input_type.lower() in ['jsonl']: 99 | tag = os.path.splitext(os.path.basename(input_file))[0] 100 | with open(input_file) as f: 101 | for line_idx, line in tqdm(enumerate(f)): 102 | data = json.loads(line) 103 | question = data['question'] 104 | answers = data['answer'] 105 | 106 | data_to_save.append({ 107 | 'id': f'{tag}_{line_idx}', 108 | 'question': question, 109 | 'answers': answers 110 | }) 111 | # CSV 112 | elif input_type.lower() in ['csv']: 113 | import ast 114 | tag = os.path.splitext(os.path.basename(input_file))[0] 115 | with open(input_file) as f: 116 | csv_reader = csv.reader(f, delimiter='\t') 117 | for line_idx, line in tqdm(enumerate(csv_reader)): 118 | question = line[0] 119 | answers = ast.literal_eval(line[1]) 120 | 121 | data_to_save.append({ 122 | 'id': f'{tag}_{line_idx}', 123 | 'question': question, 124 | 'answers': answers 125 | }) 126 | else: 127 | raise NotImplementedError 128 | 129 | assert os.path.exists(out_dir) 130 | out_path = os.path.join(out_dir, os.path.splitext(os.path.basename(input_file))[0] + '_preprocessed.json') 131 | print(f'Saving {len(data_to_save)} questions.') 132 | print('Writing to 
%s\n'% out_path) 133 | with open(out_path, 'w') as f: 134 | json.dump({'data': data_to_save}, f) 135 | 136 | 137 | if __name__ == '__main__': 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument('input_file', type=str, default=None) 140 | parser.add_argument('out_dir', type=str) 141 | parser.add_argument('--input_type', type=str, default='SQuAD', help='SQuAD|DrQA|HardEM') 142 | args = parser.parse_args() 143 | preprocess_openqa(args.input_file, args.input_type, args.out_dir) 144 | -------------------------------------------------------------------------------- /scripts/preprocess/create_psg_hdf5.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | import h5py 5 | import csv 6 | 7 | from tqdm import tqdm 8 | 9 | 10 | def create_psg_hdf5(input_file, out_file): 11 | passages = {} 12 | with open(input_file) as f: 13 | psg_file = csv.reader(f, delimiter='\t') 14 | for data_idx, data in tqdm(enumerate(psg_file)): 15 | if data_idx == 0: 16 | print('Reading', data) 17 | continue 18 | id_, psg, title = data 19 | passages[id_] = [psg, title] 20 | # break 21 | 22 | # Must use bucket; otherwise writing to a hdf5 file is very slow with a large number of keys 23 | bucket_size = 1000000 24 | # buckets = [(start, min(start+bucket_size-1, 21015324)) for start in range(1, 21015325, bucket_size)] 25 | buckets = [(start, min(start+bucket_size-1, len(passages))) for start in range(1, len(passages)+1, bucket_size)] 26 | print(f'Putting {len(passages)} passages into {len(buckets)} buckets') 27 | print(buckets) 28 | with h5py.File(out_file, 'w') as f: 29 | for pid, data in tqdm(passages.items()): 30 | bucket_name = None 31 | for start, end in buckets: 32 | if (int(pid) >= start) and (int(pid) <= end): 33 | bucket_name = f'{start}-{end}' 34 | break 35 | assert bucket_name is not None 36 | # continue 37 | 38 | if bucket_name not in f: 39 | dg = f.create_group(bucket_name) 40 | else: 41 | dg = f[bucket_name] 42 | assert pid not in dg 43 | pg = dg.create_group(pid) 44 | pg.attrs['context'], pg.attrs['title'] = data 45 | 46 | print(f'Saving {out_file} done') 47 | 48 | 49 | if __name__ == '__main__': 50 | parser = argparse.ArgumentParser() 51 | parser.add_argument('input_file', type=str, default=None) 52 | parser.add_argument('out_file', type=str) 53 | args = parser.parse_args() 54 | create_psg_hdf5(args.input_file, args.out_file) 55 | -------------------------------------------------------------------------------- /scripts/preprocess/create_tqa_ds.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | import pdb 4 | import re 5 | import random 6 | from tqdm import tqdm 7 | import string 8 | import argparse 9 | 10 | try: 11 | from eval_utils import ( 12 | drqa_exact_match_score, 13 | drqa_regex_match_score, 14 | drqa_metric_max_over_ground_truths 15 | ) 16 | except ModuleNotFoundError: 17 | import sys 18 | import os 19 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) 20 | from eval_utils import ( 21 | drqa_exact_match_score, 22 | drqa_regex_match_score, 23 | drqa_metric_max_over_ground_truths 24 | ) 25 | 26 | # fix random seed 27 | random.seed(0) 28 | 29 | def find_substring_and_return_random_idx(substring, string): 30 | substring_idxs = [m.start() for m in re.finditer(re.escape(substring), string)] 31 | substring_idx = random.choice(substring_idxs) 32 | return substring_idx 33 | 34 | def main(args): 35 | print("loading input data") 36 | 
with open(args.input_path, encoding='utf-8') as f: 37 | data = json.load(f) 38 | 39 | output_data = [] 40 | 41 | for sample_id in tqdm(data): 42 | sample = data[sample_id] 43 | 44 | question = sample['question'] 45 | answers = sample['answer'] 46 | predictions = sample['prediction'] 47 | titles = sample['title'] 48 | evidences = sample['evidence'] 49 | 50 | match_fn = drqa_regex_match_score if args.regex else drqa_exact_match_score 51 | 52 | answer_text = "" 53 | answer_start = -1 54 | ds_context = "" 55 | ds_title = "" 56 | # is_from_context = False 57 | 58 | # check if prediction is matched in a golden answer in the answer list 59 | for pred_idx, pred in enumerate(predictions): 60 | if pred != "" and drqa_metric_max_over_ground_truths(match_fn, pred, answers): 61 | answer_text = pred 62 | ds_context = evidences[pred_idx] 63 | ds_title = titles[pred_idx][0] 64 | answer_start = find_substring_and_return_random_idx(answer_text, ds_context) 65 | break 66 | 67 | # NOTE! hide these lines because is_from_context contains too many noises 68 | # # in case prediction is not matched to any golden answer, 69 | # # check if golden answer is contained in the context 70 | # if answer_start < 0: 71 | # found = False 72 | # for evid_idx, evid in enumerate(evidences): 73 | # for ans in answers: 74 | # if ans != "" and ans in evid: 75 | # found = True 76 | # answer_text = ans 77 | # answer_start = find_substring_and_return_random_idx(ans, evid) 78 | # ds_context = evidences[evid_idx] 79 | # ds_title = titles[evid_idx][0] 80 | # is_from_context = True 81 | # if found: 82 | # break 83 | 84 | # no answer is found in 85 | is_impossible = False 86 | if answer_start < 0 or answer_text == "": 87 | ds_title = titles[0][0] 88 | ds_context = evidences[0] 89 | is_impossible = True 90 | else: 91 | assert answer_text == ds_context[answer_start:answer_start+len(answer_text)] 92 | 93 | output_data.append({ 94 | 'title': ds_title, 95 | 'paragraphs':[{ 96 | 'context': ds_context, 97 | 'qas':[{ 98 | 'question': question, 99 | 'is_impossible' : is_impossible, 100 | 'answers': [{ 101 | 'text': answer_text, 102 | 'answer_start': answer_start 103 | }] if is_impossible == False else [], 104 | # 'is_from_context':is_from_context 105 | }], 106 | 'id': sample_id 107 | }] 108 | }) 109 | 110 | with open(args.output_path, 'w', encoding='utf-8') as f: 111 | json.dump({ 112 | 'data': output_data 113 | },f) 114 | 115 | 116 | # ------------------------------------------------------------------------------ 117 | # Main. 118 | # ------------------------------------------------------------------------------ 119 | 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('input_path', type=str, default='/home/pred/sbcd_sqdqgnqqg_inb64_s384_sqdnq_pinb2_0_20181220_concat_train_preprocessed_78785.pred') 124 | parser.add_argument('output_path', type=str, default='tqa_ds_train.json') 125 | parser.add_argument('--regex', action='store_true') 126 | args = parser.parse_args() 127 | 128 | main(args) 129 | -------------------------------------------------------------------------------- /scripts/preprocess/doc_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
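# ------------------------------------------------------------------------------
# Usage sketch (run from a separate script, not from inside this module):
# DocDB below wraps the sqlite file produced by build_db.py and exposes
# get_doc_ids() / get_doc_text(doc_id). The db path is only the example
# location from the preprocessing README.
#
#   from doc_db import DocDB
#
#   with DocDB(db_path='/hdd1/data/wikidump/docs_20181220.db') as db:
#       doc_ids = db.get_doc_ids()
#       print(len(doc_ids), 'documents in the db')
#       print(db.get_doc_text(doc_ids[0])[:200])
# ------------------------------------------------------------------------------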
7 | """Documents, in a sqlite database.""" 8 | 9 | import sqlite3 10 | import unicodedata 11 | 12 | def normalize(text): 13 | """Resolve different type of unicode encodings.""" 14 | return unicodedata.normalize('NFD', text) 15 | 16 | class DocDB(object): 17 | """Sqlite backed document storage. 18 | Implements get_doc_text(doc_id). 19 | """ 20 | 21 | def __init__(self, db_path=None): 22 | # self.path = db_path or DEFAULTS['db_path'] 23 | self.path = db_path 24 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 25 | 26 | def __enter__(self): 27 | return self 28 | 29 | def __exit__(self, *args): 30 | self.close() 31 | 32 | def path(self): 33 | """Return the path to the file that backs this database.""" 34 | return self.path 35 | 36 | def close(self): 37 | """Close the connection to the database.""" 38 | self.connection.close() 39 | 40 | def get_doc_ids(self): 41 | """Fetch all ids of docs stored in the db.""" 42 | cursor = self.connection.cursor() 43 | cursor.execute("SELECT id FROM documents") 44 | results = [r[0] for r in cursor.fetchall()] 45 | cursor.close() 46 | return results 47 | 48 | def get_doc_text(self, doc_id): 49 | """Fetch the raw text of the doc for 'doc_id'.""" 50 | cursor = self.connection.cursor() 51 | cursor.execute( 52 | "SELECT text FROM documents WHERE id = ?", 53 | (normalize(doc_id),) 54 | ) 55 | result = cursor.fetchone() 56 | cursor.close() 57 | return result if result is None else result[0] -------------------------------------------------------------------------------- /scripts/preprocess/download_wikidump.py: -------------------------------------------------------------------------------- 1 | """ 2 | download wiki dump 20181220 checking md5sum 3 | """ 4 | 5 | import os 6 | import json 7 | import urllib.request 8 | import urllib.parse as urlparse 9 | import argparse 10 | import hashlib 11 | import logging 12 | import portalocker 13 | import pdb 14 | from tqdm import tqdm 15 | 16 | def parse_args(): 17 | """ 18 | Parse input arguments 19 | """ 20 | parser = argparse.ArgumentParser() 21 | 22 | # Required 23 | parser.add_argument('--output_dir', required=True) 24 | 25 | args = parser.parse_args() 26 | return args 27 | 28 | def download_file(url, output_dir, size, expected_md5sum=None): 29 | """ 30 | download file and check md5sum 31 | """ 32 | logging.info("url={}".format(url)) 33 | 34 | if not os.path.exists(output_dir): 35 | os.mkdir(output_dir) 36 | bz2file = os.path.join(output_dir, os.path.basename(url)) 37 | 38 | lockfile = '{}.lock'.format(bz2file) 39 | with portalocker.Lock(lockfile, 'w', timeout=60): 40 | if not os.path.exists(bz2file) or os.path.getsize(bz2file) != size: 41 | logging.info("Downloading {}".format(bz2file)) 42 | with urllib.request.urlopen(url) as f: 43 | with open(bz2file, 'wb') as out: 44 | for data in tqdm(f, unit='KB'): 45 | out.write(data) 46 | 47 | # Check md5sum 48 | if expected_md5sum is not None: 49 | md5 = hashlib.md5() 50 | with open(bz2file, 'rb') as infile: 51 | for line in infile: 52 | md5.update(line) 53 | if md5.hexdigest() != expected_md5sum: 54 | logging.error('Fatal: MD5 sum of downloaded file was incorrect (got {}, expected {}).'.format(md5.hexdigest(), expected_md5)) 55 | logging.error('Please manually delete "{}" and rerun the command.'.format(tarball)) 56 | logging.error('If the problem persists, the tarball may have changed, in which case, please contact the SacreBLEU maintainer.') 57 | sys.exit(1) 58 | else: 59 | logging.info('Checksum passed: {}'.format(md5.hexdigest())) 60 | 61 | def main(args): 
62 | url = 'https://archive.org/download/enwiki-20181220/enwiki-20181220-pages-articles.xml.bz2' 63 | expected_md5sum = 'ccf875b2af67109fe5b98b5b720ce322' 64 | size = 15712882238 65 | 66 | download_file( 67 | url=url, 68 | output_dir=args.output_dir, 69 | size=size, 70 | expected_md5sum=expected_md5sum 71 | ) 72 | 73 | if __name__ == '__main__': 74 | logging.basicConfig( 75 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 76 | datefmt="%m/%d/%Y %H:%M:%S", 77 | level=logging.INFO 78 | ) 79 | args = parse_args() 80 | main(args) -------------------------------------------------------------------------------- /scripts/preprocess/filter_noans.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import json 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from squad_metrics import compute_exact 7 | nlp = spacy.load("en_core_web_sm") 8 | 9 | doc = nlp('European authorities fined Google a record $5.1 billion on Wednesday for abusing its power in the mobile phone market and ordered the company to alter its practices') 10 | print([(X.text, X.label_) for X in doc.ents]) 11 | 12 | 13 | data_path = '/home/data/nq-reader/dev_wiki3.json' 14 | sample = False 15 | print(f'reading {data_path} with sampling: {sample}') 16 | train_set = json.load(open(data_path)) 17 | new_train_set = {'data': []} 18 | cnt = 0 19 | new_cnt = 0 20 | filtered_cnt = 0 21 | 22 | for article in tqdm(train_set['data']): 23 | new_article = { 24 | 'title': article['title'], 25 | 'paragraphs': [] 26 | } 27 | for p_idx, paragraph in enumerate(article['paragraphs']): 28 | new_paragraph = { 29 | 'context': paragraph['context'], 30 | 'qas' : [], 31 | } 32 | 33 | for qa in paragraph['qas']: 34 | question = qa['question'] 35 | id_ = qa['id'] 36 | assert type(qa["answers"]) == dict or type(qa["answers"]) == list, type(qa["answers"]) 37 | if type(qa["answers"]) == dict: 38 | qa["answers"] = [qa["answers"]] 39 | cnt += 1 40 | if len(qa["answers"]) == 0: 41 | filtered_cnt += 1 42 | continue 43 | 44 | new_paragraph['qas'].append(qa) 45 | new_cnt += 1 46 | new_article['paragraphs'].append(new_paragraph) 47 | 48 | new_train_set['data'].append(new_article) 49 | # break 50 | 51 | write_path = data_path.replace('.json', '_na_filtered.json') 52 | with open(write_path, 'w') as f: 53 | json.dump(new_train_set, f) 54 | 55 | assert filtered_cnt + new_cnt == cnt 56 | print(f'writing to {write_path} with {cnt} samples') 57 | print(f'all sample: {cnt}, new sample: {new_cnt}') 58 | -------------------------------------------------------------------------------- /scripts/preprocess/filter_wiki.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | 5 | from tqdm import tqdm 6 | 7 | 8 | def filter_wiki(args): 9 | if not os.path.exists(args.to_dir): 10 | os.makedirs(args.to_dir) 11 | 12 | names = os.listdir(args.from_dir) 13 | from_paths = [os.path.join(args.from_dir, name) for name in names] 14 | to_paths = [os.path.join(args.to_dir, name) for name in names] 15 | 16 | for from_path, to_path in zip(tqdm(from_paths), to_paths): 17 | with open(from_path, 'r') as fp: 18 | from_ = json.load(fp) 19 | to = {'data': []} 20 | for article in from_['data']: 21 | to_article = {'paragraphs': [], 'title': article['title']} 22 | for para in article['paragraphs']: 23 | if args.min_num_chars <= len(para['context']) < args.max_num_chars: 24 | to_article['paragraphs'].append(para) 25 | 
to['data'].append(to_article) 26 | 27 | with open(to_path, 'w') as fp: 28 | json.dump(to, fp) 29 | 30 | 31 | def get_args(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('from_dir') 34 | parser.add_argument('to_dir') 35 | parser.add_argument('--min_num_chars', default=250, type=int) 36 | parser.add_argument('--max_num_chars', default=2500, type=int) 37 | return parser.parse_args() 38 | 39 | 40 | def main(): 41 | args = get_args() 42 | filter_wiki(args) 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /scripts/preprocess/merge_openqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | 5 | from tqdm import tqdm 6 | 7 | 8 | def merge_openqa(input_dir, output_path): 9 | 10 | paths = [ 11 | 'open-qa/nq-open/train_preprocessed.json', 12 | 'open-qa/webq/WebQuestions-train-nodev_preprocessed.json', 13 | 'open-qa/trec/CuratedTrec-train-nodev_preprocessed.json', 14 | 'open-qa/triviaqa-unfiltered/train_preprocessed.json', 15 | 'open-qa/squad/train_preprocessed.json', 16 | 'kilt/trex/trex-train-kilt_open_10000.json', 17 | 'kilt/zsre/structured_zeroshot-train-kilt_open_10000.json', 18 | ] 19 | paths = [os.path.join(input_dir, path) for path in paths] 20 | assert all([os.path.exists(path) for path in paths]) 21 | 22 | data_to_save = [] 23 | sep_cnt = 0 24 | for path in paths: 25 | with open(path) as f: 26 | data = json.load(f)['data'] 27 | for item in data: 28 | if ' [SEP] ' in item['question']: 29 | item['question'] = item['question'].replace(' [SEP] ', ' ') 30 | sep_cnt += 1 31 | data_to_save += data 32 | print(f'{path} has {len(data)} QA pairs') 33 | 34 | print(f'Saving {len(data_to_save)} questions to output_path') 35 | print(f'Removed [SEP] for {sep_cnt} questions') 36 | print('Writing to %s\n'% output_path) 37 | with open(output_path, 'w') as f: 38 | json.dump({'data': data_to_save}, f) 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('input_dir', type=str, default=None) 44 | parser.add_argument('output_path', type=str) 45 | args = parser.parse_args() 46 | merge_openqa(args.input_dir, args.output_path) 47 | -------------------------------------------------------------------------------- /scripts/preprocess/merge_paq.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | import h5py 5 | import csv 6 | 7 | from tqdm import tqdm 8 | 9 | 10 | def merge_paq(input_dir, out_file): 11 | num_split = 8 12 | filenames = [f'PAQ.metadata.hard0-{k}.jsonl' for k in range(num_split)] 13 | print('reading', filenames) 14 | fps = [open(os.path.join(input_dir, filename), 'r') for filename in filenames] 15 | 16 | with open(out_file, 'w') as fw: 17 | fp_idx = 0 18 | total_cnt = 0 19 | hard_cnt = 0 20 | line = fps[fp_idx].readline() 21 | while line: 22 | # for stats 23 | meta = json.loads(line) 24 | if len(meta['hard_neg_pids']) > 0: 25 | hard_cnt += 1 26 | total_cnt += 1 27 | 28 | if total_cnt % 100000 == 0: 29 | print(f'Total: {total_cnt}, Hard neg: {hard_cnt}') 30 | 31 | # write it 32 | json.dump(meta, fw, separators=(',', ':')) 33 | fw.write('\n') 34 | fp_idx = (fp_idx + 1) % num_split 35 | line = fps[fp_idx].readline() 36 | 37 | print(f'Total: {total_cnt}, Hard neg: {hard_cnt}') 38 | print(f'Saving {out_file} done') 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = 
argparse.ArgumentParser() 43 | parser.add_argument('input_dir', type=str, default=None) 44 | parser.add_argument('out_file', type=str) 45 | args = parser.parse_args() 46 | merge_paq(args.input_dir, args.out_file) 47 | -------------------------------------------------------------------------------- /scripts/preprocess/merge_singleqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | 5 | from tqdm import tqdm 6 | 7 | 8 | def merge_single(input_dir, output_path): 9 | 10 | paths = [ 11 | 'single-qa/nq/train_wiki3.json', 12 | 'single-qa/webq/webq-train_ds.json', 13 | 'single-qa/trec/trec-train_ds.json', 14 | 'single-qa/tqa/tqa-train_ds.json', 15 | # 'single-qa/squad/train-v1.1.json', 16 | ] 17 | paths = [os.path.join(input_dir, path) for path in paths] 18 | assert all([os.path.exists(path) for path in paths]) 19 | 20 | data_to_save = [] 21 | sep_cnt = 0 22 | for path in paths: 23 | with open(path) as f: 24 | data = json.load(f)['data'] 25 | data_to_save += data 26 | print(f'{path} has {len(data)} PQA triples') 27 | 28 | print(f'Saving {len(data_to_save)} RC triples to output_path') 29 | print('Writing to %s\n'% output_path) 30 | with open(output_path, 'w') as f: 31 | json.dump({'data': data_to_save}, f) 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('input_dir', type=str, default=None) 37 | parser.add_argument('output_path', type=str) 38 | args = parser.parse_args() 39 | merge_single(args.input_dir, args.output_path) 40 | -------------------------------------------------------------------------------- /scripts/preprocess/nq_utils.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | class LongAnswerCandidate(object): 7 | """Representation of long answer candidate.""" 8 | 9 | def __init__(self, contents, index, is_answer, contains_answer, start_token, end_token): 10 | self.contents = contents 11 | self.index = index 12 | self.is_answer = is_answer 13 | self.contains_answer = contains_answer 14 | self.start_token = start_token 15 | self.end_token = end_token 16 | if is_answer: 17 | self.style = 'is_answer' 18 | elif contains_answer: 19 | self.style = 'contains_answer' 20 | else: 21 | self.style = 'not_answer' 22 | 23 | 24 | class Example(object): 25 | """Example representation.""" 26 | 27 | def __init__(self, json_example, dataset): 28 | self.json_example = json_example 29 | 30 | # Whole example info. 
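        # Sketch of the NQ jsonl schema this constructor assumes (truncated to the
        # keys that are actually read in this class and in create_nq_reader.py):
        #   {
        #     "example_id": 1234,
        #     "document_url": "...", "document_title": "...",
        #     "document_html": "...",
        #     "document_tokens": [{"token": "...", "html_token": false, ...}, ...],
        #     "question_text": "...",
        #     "annotations": [{"long_answer": {"start_byte": ..., "end_byte": ..., ...},
        #                      "short_answers": [...], "yes_no_answer": "NONE"}, ...],
        #     "long_answer_candidates": [{"start_token": ..., "end_token": ...,
        #                                 "start_byte": ..., "end_byte": ..., "top_level": true}, ...]
        #   }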
31 | self.url = json_example['document_url'] 32 | self.title = ( 33 | json_example['document_title'] 34 | if 'document_title' in json_example else 'Wikipedia') 35 | # self.example_id = base64.urlsafe_b64encode( 36 | # str(self.json_example['example_id'])) 37 | self.example_id = str(self.json_example['example_id']) 38 | self.document_html = self.json_example['document_html'].encode('utf-8') 39 | self.document_tokens = self.json_example['document_tokens'] 40 | self.question_text = json_example['question_text'] 41 | 42 | if dataset == 'train': 43 | if len(json_example['annotations']) != 1: 44 | raise ValueError( 45 | 'Train set json_examples should have a single annotation.') 46 | annotation = json_example['annotations'][0] 47 | self.has_long_answer = annotation['long_answer']['start_byte'] >= 0 48 | self.has_short_answer = annotation[ 49 | 'short_answers'] or annotation['yes_no_answer'] != 'NONE' 50 | 51 | elif dataset == 'dev': 52 | if len(json_example['annotations']) != 5: 53 | raise ValueError('Dev set json_examples should have five annotations.') 54 | self.has_long_answer = sum([ 55 | annotation['long_answer']['start_byte'] >= 0 56 | for annotation in json_example['annotations'] 57 | ]) >= 2 58 | self.has_short_answer = sum([ 59 | bool(annotation['short_answers']) or 60 | annotation['yes_no_answer'] != 'NONE' 61 | for annotation in json_example['annotations'] 62 | ]) >= 2 63 | 64 | self.long_answers = [ 65 | a['long_answer'] 66 | for a in json_example['annotations'] 67 | if a['long_answer']['start_byte'] >= 0 and self.has_long_answer 68 | ] 69 | self.short_answers = [ 70 | a['short_answers'] 71 | for a in json_example['annotations'] 72 | if a['short_answers'] and self.has_short_answer 73 | ] 74 | self.yes_no_answers = [ 75 | a['yes_no_answer'] 76 | for a in json_example['annotations'] 77 | if a['yes_no_answer'] != 'NONE' and self.has_short_answer 78 | ] 79 | 80 | if self.has_long_answer: 81 | long_answer_bounds = [ 82 | (la['start_byte'], la['end_byte']) for la in self.long_answers 83 | ] 84 | long_answer_counts = [ 85 | long_answer_bounds.count(la) for la in long_answer_bounds 86 | ] 87 | long_answer = self.long_answers[np.argmax(long_answer_counts)] 88 | self.long_answer_text = self.render_long_answer(long_answer) 89 | 90 | else: 91 | self.long_answer_text = '' 92 | 93 | if self.has_short_answer: 94 | short_answers_ids = [[ 95 | (s['start_byte'], s['end_byte']) for s in a 96 | ] for a in self.short_answers] + [a for a in self.yes_no_answers] 97 | short_answers_counts = [ 98 | short_answers_ids.count(a) for a in short_answers_ids 99 | ] 100 | 101 | self.short_answers_texts = [ 102 | b', '.join([ 103 | self.render_span(s['start_byte'], s['end_byte']) 104 | for s in short_answer 105 | ]) 106 | for short_answer in self.short_answers 107 | ] 108 | 109 | self.short_answers_texts += self.yes_no_answers 110 | self.short_answers_text = self.short_answers_texts[np.argmax( 111 | short_answers_counts)] 112 | self.short_answers_texts = set(self.short_answers_texts) 113 | 114 | else: 115 | self.short_answers_texts = [] 116 | self.short_answers_text = '' 117 | 118 | self.candidates = self.get_candidates( 119 | self.json_example['long_answer_candidates']) 120 | 121 | self.candidates_with_answer = [ 122 | i for i, c in enumerate(self.candidates) if c.contains_answer 123 | ] 124 | 125 | def render_long_answer(self, long_answer): 126 | """Wrap table rows and list items, and render the long answer. 127 | 128 | Args: 129 | long_answer: Long answer dictionary. 
130 | 131 | Returns: 132 | String representation of the long answer span. 133 | """ 134 | 135 | if long_answer['end_token'] - long_answer['start_token'] > 500: 136 | return 'Large long answer' 137 | 138 | html_tag = self.document_tokens[long_answer['end_token'] - 1]['token'] 139 | if html_tag == '</Table>' and self.render_span( 140 | long_answer['start_byte'], long_answer['end_byte']).count(b'<Tr>') > 30: 141 | return 'Large table long answer' 142 | 143 | elif html_tag == '</Li>': 144 | return '<ul>{}</ul>'.format( 145 | self.render_span(long_answer['start_byte'], long_answer['end_byte'])) 146 | 147 | elif html_tag in ['</Tr>', '</Td>', '</Th>']: 148 | return '<table><tr>{}</tr></table>
'.format( 149 | self.render_span(long_answer['start_byte'], long_answer['end_byte'])) 150 | 151 | else: 152 | return self.render_span(long_answer['start_byte'], 153 | long_answer['end_byte']) 154 | 155 | def render_span(self, start, end): 156 | return self.document_html[start:end] 157 | 158 | def get_candidates(self, json_candidates): 159 | """Returns a list of `LongAnswerCandidate` objects for top level candidates. 160 | 161 | Args: 162 | json_candidates: List of Json records representing candidates. 163 | 164 | Returns: 165 | List of `LongAnswerCandidate` objects. 166 | """ 167 | candidates = [] 168 | top_level_candidates = [c for c in json_candidates if c['top_level']] 169 | for candidate in top_level_candidates: 170 | tokenized_contents = ' '.join([ 171 | t['token'] for t in self.json_example['document_tokens'] 172 | [candidate['start_token']:candidate['end_token']] 173 | ]) 174 | 175 | start = candidate['start_byte'] 176 | end = candidate['end_byte'] 177 | start_token = candidate['start_token'] 178 | end_token = candidate['end_token'] 179 | is_answer = self.has_long_answer and np.any( 180 | [(start == ans['start_byte']) and (end == ans['end_byte']) 181 | for ans in self.long_answers]) 182 | contains_answer = self.has_long_answer and np.any( 183 | [(start <= ans['start_byte']) and (end >= ans['end_byte']) 184 | for ans in self.long_answers]) 185 | 186 | candidates.append( 187 | LongAnswerCandidate(tokenized_contents, len(candidates), is_answer, 188 | contains_answer, start_token, end_token)) 189 | 190 | return candidates 191 | 192 | def has_long_answer(json_example): 193 | for annotation in json_example['annotations']: 194 | if annotation['long_answer']['start_byte'] >= 0: 195 | return True 196 | return False 197 | 198 | 199 | def has_short_answer(json_example): 200 | for annotation in json_example['annotations']: 201 | if annotation['short_answers']: 202 | return True 203 | return False 204 | 205 | def load_examples(fileobj, dataset, mode): 206 | """Reads jsonlines containing NQ examples. 207 | 208 | Args: 209 | fileobj: File object containing NQ examples. 210 | 211 | Returns: 212 | Dictionary mapping example id to `Example` object. 213 | """ 214 | 215 | def _load(examples, f): 216 | """Read serialized json from `f`, create examples, and add to `examples`.""" 217 | 218 | for l in tqdm(f): 219 | json_example = json.loads(l) 220 | if mode == 'long_answers' and not has_long_answer(json_example): 221 | continue 222 | 223 | elif mode == 'short_answers' and not has_short_answer(json_example): 224 | continue 225 | 226 | example = Example(json_example, dataset) 227 | examples[example.example_id] = example 228 | 229 | examples = {} 230 | _load(examples, gzip.GzipFile(fileobj=fileobj)) 231 | 232 | return examples 233 | -------------------------------------------------------------------------------- /scripts/preprocess/prep_wikipedia.py: -------------------------------------------------------------------------------- 1 | # https://github.com/facebookresearch/DrQA/blob/master/scripts/retriever/prep_wikipedia.py 2 | # #!/usr/bin/env python3 3 | # Copyright 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree 8 | """Preprocess function to filter/prepare Wikipedia docs.""" 9 | 10 | import regex as re 11 | from html.parser import HTMLParser 12 | 13 | PARSER = HTMLParser() 14 | BLACKLIST = set(['23443579', '52643645']) # Conflicting disambig. 
pages 15 | 16 | 17 | def preprocess(article): 18 | # Take out HTML escaping WikiExtractor didn't clean 19 | for k, v in article.items(): 20 | article[k] = PARSER.unescape(v) 21 | 22 | # Filter some disambiguation pages not caught by the WikiExtractor 23 | if article['id'] in BLACKLIST: 24 | return None 25 | if '(disambiguation)' in article['title'].lower(): 26 | return None 27 | if '(disambiguation page)' in article['title'].lower(): 28 | return None 29 | 30 | # Take out List/Index/Outline pages (mostly links) 31 | if re.match(r'(List of .+)|(Index of .+)|(Outline of .+)', 32 | article['title']): 33 | return None 34 | 35 | # Return doc with `id` set to `title` 36 | return {'id': article['title'], 'text': article['text']} -------------------------------------------------------------------------------- /scripts/preprocess/sample_nq_reader_doc_wiki.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | import pdb 4 | import argparse 5 | from tqdm import tqdm 6 | import os 7 | import random 8 | import time 9 | 10 | def main(args): 11 | sampling_ratio = args.sampling_ratio 12 | wiki_dir = args.wiki_dir 13 | docs_wiki_dir = args.docs_wiki_dir 14 | output_dir = args.output_dir 15 | 16 | # count the number of total words in wikidump 17 | wiki_file_list = glob.glob(os.path.join(wiki_dir,"*")) 18 | # num_words_in_wiki = 0 19 | # for filename in tqdm(wiki_file_list, total=len(wiki_file_list)): 20 | # with open(filename,'r') as f: 21 | # data = json.load(f)['data'] 22 | 23 | # for doc in data: 24 | # for paragraph in doc['paragraphs']: 25 | # context = paragraph['context'] 26 | # num_words_in_wiki += len(context.split(" ")) 27 | 28 | # print(num_words_in_wiki) 29 | 30 | num_words_in_wiki = 2054581517 31 | num_sample_words = int(num_words_in_wiki * sampling_ratio) 32 | 33 | print("num_words_in_wiki={}".format(num_words_in_wiki)) 34 | 35 | # count the number of total words in docs_wiki 36 | docs_wiki_file_list = sorted(glob.glob(os.path.join(docs_wiki_dir,"*"))) 37 | num_words_in_docs_wiki = 0 38 | docs_wiki_titles = {} 39 | docs_wikis = [] 40 | for filename in tqdm(docs_wiki_file_list, total=len(docs_wiki_file_list)): 41 | with open(filename,'r') as f: 42 | data = json.load(f)['data'] 43 | 44 | for doc in data: 45 | docs_wikis.append(doc) 46 | docs_wiki_titles[doc['title']] = "" 47 | for paragraph in doc['paragraphs']: 48 | context = paragraph['context'] 49 | num_words_in_docs_wiki += len(context.split(" ")) 50 | 51 | print("num_words_in_docs_wiki={}".format(num_words_in_docs_wiki)) 52 | random.seed(2020) 53 | i = 0 54 | while True: 55 | if num_words_in_docs_wiki > num_sample_words: 56 | break 57 | 58 | # random pick from wiki filelist 59 | # start_time = time.time() 60 | random_wiki_file = random.sample(wiki_file_list, 1)[0] 61 | # if i % 100 == 0: 62 | # print("(1) ", time.time() - start_time) 63 | 64 | with open(random_wiki_file,'r') as f: 65 | data = json.load(f)['data'] 66 | 67 | # random pick from articles 68 | # start_time = time.time() 69 | random_articles = random.sample(data, 100) 70 | # if i % 100 == 0: 71 | # print("(2) ", time.time() - start_time) 72 | 73 | # start_time = time.time() 74 | for random_article in random_articles: 75 | # if already existing article in docs_wiki, then pass 76 | if random_article['title'] in docs_wiki_titles: 77 | continue 78 | docs_wikis.append(random_article) 79 | docs_wiki_titles[random_article['title']] = "" 80 | # if i % 100 == 0: 81 | # print("(3) ", time.time() - start_time) 82 | 83 | # 
start_time = time.time() 84 | for random_article in random_articles: 85 | for paragraph in random_article['paragraphs']: 86 | context = paragraph['context'] 87 | num_words_in_docs_wiki += len(context.split(" ")) 88 | # if i % 100 == 0: 89 | # print("(4) ", time.time() - start_time) 90 | 91 | if i % 100 == 0: 92 | print("title={} len(docs_wiki_titles)={} ratio={}".format(random_article['title'], len(docs_wiki_titles), num_words_in_docs_wiki/num_words_in_wiki)) 93 | i += 1 94 | 95 | if not os.path.exists(output_dir): 96 | os.mkdir(output_dir) 97 | 98 | # shuffle docs_wikis for balanced file size 99 | random.shuffle(docs_wikis) 100 | 101 | for i in range(int(len(docs_wikis)/1000) + 1): 102 | output_file = os.path.join(output_dir, '{:d}'.format(i).zfill(4)) 103 | local_docs_wikis = docs_wikis[i*1000:(i+1)*1000] 104 | 105 | output = { 106 | 'data' : local_docs_wikis 107 | } 108 | 109 | # save nq_reader 110 | with open(output_file,'w') as f: 111 | json.dump(output, f) 112 | 113 | # # pdb.set_trace() 114 | 115 | # wiki_titles = [] 116 | # wiki_title2paragraphs = {} 117 | # for filename in tqdm(wiki_file_list, total=len(wiki_file_list)): 118 | # with open(filename,'r') as f: 119 | # data = json.load(f)['data'] 120 | 121 | # for doc in data: 122 | # title = doc['title'] 123 | # wiki_titles.append(title) 124 | # paragraph = doc['paragraphs'] 125 | # wiki_title2paragraphs[title] = paragraph 126 | # num_wiki += 1 127 | 128 | # assert len(wiki_title2paragraphs) == num_wiki 129 | 130 | # nq_file_list = glob.glob(os.path.join(nq_dir,"*")) 131 | # nq_titles = [] 132 | # unmatched_titles = [] 133 | # num_matched = 0 134 | # num_unmatched = 0 135 | # for filename in tqdm(nq_file_list, total=len(nq_file_list)): 136 | # with open(filename,'r') as f: 137 | # data = json.load(f)['data'] 138 | 139 | # for doc in data: 140 | # title = doc['title'] 141 | # nq_titles.append(title) 142 | # if title in wiki_title2paragraphs and len(wiki_title2paragraphs[title])>0: 143 | # doc['paragraphs'] = wiki_title2paragraphs[title] 144 | # num_matched += 1 145 | # else: 146 | # unmatched_titles.append(title) 147 | # num_unmatched +=1 148 | 149 | # new_paragraphs = [] 150 | # for paragraph in doc['paragraphs']: 151 | # if ('is_paragraph' in paragraph) and (not paragraph['is_paragraph']): 152 | # continue 153 | 154 | # new_paragraphs.append({ 155 | # 'context': paragraph['context'] 156 | # }) 157 | # doc['paragraphs'] = new_paragraphs 158 | 159 | # if not os.path.exists(output_dir): 160 | # os.mkdir(output_dir) 161 | 162 | # output_path = os.path.join(output_dir,os.path.basename(filename)) 163 | # output = { 164 | # 'data': data 165 | # } 166 | 167 | # with open(output_path, 'w') as f: 168 | # json.dump(output, f, indent=2) 169 | 170 | # with open('unmatched_title_old_dev.txt', 'w') as f: 171 | # for title in unmatched_titles: 172 | # f.writelines(title) 173 | # f.writelines("\n") 174 | 175 | # print("num_matched={} num_unmatched={}".format(num_matched, num_unmatched)) 176 | # print("len(nq_titles)={} len(wiki_titles)={}".format(len(nq_titles), len(wiki_titles))) 177 | 178 | if __name__ == '__main__': 179 | parser = argparse.ArgumentParser() 180 | # Required parameters 181 | parser.add_argument("--sampling_ratio", type=float, required=True) 182 | parser.add_argument("--wiki_dir", type=str, required=True) 183 | parser.add_argument("--docs_wiki_dir", type=str, required=True) 184 | parser.add_argument("--output_dir", type=str, required=True) 185 | 186 | args = parser.parse_args() 187 | 188 | main(args) 189 | 
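The script above (scripts/preprocess/sample_nq_reader_doc_wiki.py) grows the NQ document set with randomly drawn Wikipedia articles until the set covers a target fraction of the full wikidump word count, skipping titles that are already present. Below is a minimal, self-contained sketch of that word-budget loop; the helper names (word_count, sample_articles_by_word_budget) and the toy articles are illustrative assumptions, not part of the repository, which instead streams candidate articles from the JSON files under --wiki_dir.

import random

def word_count(article):
    # Whitespace word count over all paragraphs, matching the counting in the script above.
    return sum(len(p['context'].split(' ')) for p in article['paragraphs'])

def sample_articles_by_word_budget(seed_articles, candidate_articles, budget, seed=2020):
    """Add unseen random candidates until the pool exceeds `budget` words."""
    random.seed(seed)
    kept = list(seed_articles)
    seen_titles = {a['title'] for a in kept}
    total_words = sum(word_count(a) for a in kept)
    pool = list(candidate_articles)
    random.shuffle(pool)  # stands in for randomly picking wiki files, then articles
    for article in pool:
        if total_words > budget:
            break
        if article['title'] in seen_titles:  # skip articles already in the pool
            continue
        kept.append(article)
        seen_titles.add(article['title'])
        total_words += word_count(article)
    return kept

# Toy usage: sample until the pool covers about half of a 20-word "wikidump".
toy = [{'title': 'Doc {}'.format(i), 'paragraphs': [{'context': 'a b c d e'}]} for i in range(4)]
sampled = sample_articles_by_word_budget([], toy, budget=int(20 * 0.5))
print(len(sampled), 'articles,', sum(word_count(a) for a in sampled), 'words')

As in the script, the pool may overshoot the budget by the last article added; the script then shuffles the final pool and writes it out in chunks of 1,000 articles per output file.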
-------------------------------------------------------------------------------- /scripts/preprocess/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | # https://github.com/facebookresearch/DrQA/blob/master/drqa/tokenizers/simple_tokenizer.py#L18 2 | 3 | #!/usr/bin/env python3 4 | # Copyright 2017-present, Facebook, Inc. 5 | # All rights reserved. 6 | # 7 | # This source code is licensed under the license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | """Basic tokenizer that splits text into alpha-numeric tokens and 10 | non-whitespace tokens. 11 | """ 12 | 13 | import copy 14 | import regex 15 | import logging 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class Tokens(object): 21 | """A class to represent a list of tokenized text.""" 22 | TEXT = 0 23 | TEXT_WS = 1 24 | SPAN = 2 25 | POS = 3 26 | LEMMA = 4 27 | NER = 5 28 | 29 | def __init__(self, data, annotators, opts=None): 30 | self.data = data 31 | self.annotators = annotators 32 | self.opts = opts or {} 33 | 34 | def __len__(self): 35 | """The number of tokens.""" 36 | return len(self.data) 37 | 38 | def slice(self, i=None, j=None): 39 | """Return a view of the list of tokens from [i, j).""" 40 | new_tokens = copy.copy(self) 41 | new_tokens.data = self.data[i: j] 42 | return new_tokens 43 | 44 | def untokenize(self): 45 | """Returns the original text (with whitespace reinserted).""" 46 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 47 | 48 | def words(self, uncased=False): 49 | """Returns a list of the text of each token 50 | Args: 51 | uncased: lower cases text 52 | """ 53 | if uncased: 54 | return [t[self.TEXT].lower() for t in self.data] 55 | else: 56 | return [t[self.TEXT] for t in self.data] 57 | 58 | def offsets(self): 59 | """Returns a list of [start, end) character offsets of each token.""" 60 | return [t[self.SPAN] for t in self.data] 61 | 62 | def pos(self): 63 | """Returns a list of part-of-speech tags of each token. 64 | Returns None if this annotation was not included. 65 | """ 66 | if 'pos' not in self.annotators: 67 | return None 68 | return [t[self.POS] for t in self.data] 69 | 70 | def lemmas(self): 71 | """Returns a list of the lemmatized text of each token. 72 | Returns None if this annotation was not included. 73 | """ 74 | if 'lemma' not in self.annotators: 75 | return None 76 | return [t[self.LEMMA] for t in self.data] 77 | 78 | def entities(self): 79 | """Returns a list of named-entity-recognition tags of each token. 80 | Returns None if this annotation was not included. 81 | """ 82 | if 'ner' not in self.annotators: 83 | return None 84 | return [t[self.NER] for t in self.data] 85 | 86 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 87 | """Returns a list of all ngrams from length 1 to n. 
88 | Args: 89 | n: upper limit of ngram length 90 | uncased: lower cases text 91 | filter_fn: user function that takes in an ngram list and returns 92 | True or False to keep or not keep the ngram 93 | as_string: return the ngram as a string vs list 94 | """ 95 | def _skip(gram): 96 | if not filter_fn: 97 | return False 98 | return filter_fn(gram) 99 | 100 | words = self.words(uncased) 101 | ngrams = [(s, e + 1) 102 | for s in range(len(words)) 103 | for e in range(s, min(s + n, len(words))) 104 | if not _skip(words[s:e + 1])] 105 | 106 | # Concatenate into strings 107 | if as_strings: 108 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 109 | 110 | return ngrams 111 | 112 | def entity_groups(self): 113 | """Group consecutive entity tokens with the same NER tag.""" 114 | entities = self.entities() 115 | if not entities: 116 | return None 117 | non_ent = self.opts.get('non_ent', 'O') 118 | groups = [] 119 | idx = 0 120 | while idx < len(entities): 121 | ner_tag = entities[idx] 122 | # Check for entity tag 123 | if ner_tag != non_ent: 124 | # Chomp the sequence 125 | start = idx 126 | while (idx < len(entities) and entities[idx] == ner_tag): 127 | idx += 1 128 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 129 | else: 130 | idx += 1 131 | return groups 132 | 133 | 134 | class Tokenizer(object): 135 | """Base tokenizer class. 136 | Tokenizers implement tokenize, which should return a Tokens class. 137 | """ 138 | def tokenize(self, text): 139 | raise NotImplementedError 140 | 141 | def shutdown(self): 142 | pass 143 | 144 | def __del__(self): 145 | self.shutdown() 146 | 147 | class SimpleTokenizer(Tokenizer): 148 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 149 | NON_WS = r'[^\p{Z}\p{C}]' 150 | 151 | def __init__(self, **kwargs): 152 | """ 153 | Args: 154 | annotators: None or empty set (only tokenizes). 155 | """ 156 | self._regexp = regex.compile( 157 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 158 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 159 | ) 160 | if len(kwargs.get('annotators', {})) > 0: 161 | logger.warning('%s only tokenizes! 
Skipping annotators: %s' % 162 | (type(self).__name__, kwargs.get('annotators'))) 163 | self.annotators = set() 164 | 165 | def tokenize(self, text): 166 | data = [] 167 | matches = [m for m in self._regexp.finditer(text)] 168 | for i in range(len(matches)): 169 | # Get text 170 | token = matches[i].group() 171 | 172 | # Get whitespace 173 | span = matches[i].span() 174 | start_ws = span[0] 175 | if i + 1 < len(matches): 176 | end_ws = matches[i + 1].span()[0] 177 | else: 178 | end_ws = span[1] 179 | 180 | # Format data 181 | data.append(( 182 | token, 183 | text[start_ws: end_ws], 184 | span, 185 | )) 186 | return Tokens(data, self.annotators) -------------------------------------------------------------------------------- /scripts/preprocess/stat_entities.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import json 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | nlp_sent = spacy.load("en_core_web_sm") 8 | doc = nlp_sent('European authorities fined Google a record $5.1 billion on Wednesday for abusing its power in the mobile phone market and ordered the company to alter its practices') 9 | print([(X.text, X.label_) for X in doc.ents]) 10 | 11 | 12 | # pred_file = '/n/fs/nlp-jl5167/outputs/pred/dev_preprocessed_8757.pred' 13 | pred_file = 'lama-test-P20_preprocessed_953.pred' 14 | with open(pred_file) as f: 15 | predictions = json.load(f) 16 | 17 | stat = {} 18 | ent_types = {} 19 | tokenizer_error_cnt = 0 20 | entity_error_cnt = 0 21 | for pid, result in predictions.items(): 22 | question = result['question'] 23 | q_sws = result['q_tokens'][1:-1] # except [CLS], [SEP] 24 | q_ents = [(X.text, X.label_, X[0].idx) for X in nlp_sent(question).ents] 25 | if len(q_ents) == 0: 26 | entity_error_cnt += 1 27 | continue 28 | 29 | word_idx = 0 30 | word_to_sw = {} 31 | for sw_idx, sw in enumerate(q_sws): 32 | if word_idx not in word_to_sw: 33 | word_to_sw[word_idx] = [] 34 | word_to_sw[word_idx].append(sw_idx) 35 | if sw_idx < len(q_sws) - 1: 36 | if not q_sws[sw_idx+1].startswith('##'): 37 | word_idx += 1 38 | try: 39 | assert word_idx == len(question.split(' ')) - 1 40 | except Exception as e: 41 | tokenizer_error_cnt += 1 42 | continue 43 | 44 | char_to_word = {} 45 | word_idx = 0 46 | for ch_idx, ch in enumerate(question): 47 | if ch == ' ': 48 | word_idx += 1 49 | continue 50 | char_to_word[ch_idx] = word_idx 51 | 52 | try: 53 | assert word_idx == len(question.split(' ')) - 1 54 | except Exception as e: 55 | tokenizer_error_cnt += 1 56 | continue 57 | 58 | num_sw = [] 59 | ent_list = [ 60 | 'EVENT', 'FAC', 'GPE', 'LANGUAGE', 'LAW', 'LOC', 61 | 'NORP', 'ORG', 'PERSON', 'PRODUCT', 'WORK_OF_ART' 62 | ] 63 | for ent_text, ent_label, ent_start in q_ents: 64 | if ent_label not in ent_list: 65 | continue 66 | char_start = ent_start 67 | char_end = ent_start + len(ent_text) - 1 68 | word_start = char_to_word[char_start] 69 | word_end = char_to_word[char_end] 70 | num_sw.append(sum([len(word_to_sw[word]) for word in range(word_start, word_end+1)])) 71 | # num_sw.append(max([len(word_to_sw[word]) for word in range(word_start, word_end+1)])) 72 | if ent_label not in ent_types: 73 | ent_types[ent_label] = 0 74 | print(ent_text, ent_label) 75 | ent_types[ent_label] += 1 76 | 77 | if len(num_sw) == 0: 78 | entity_error_cnt += 1 79 | continue 80 | 81 | num_sw = max(num_sw) 82 | if num_sw not in stat: 83 | print(num_sw, q_sws) 84 | stat[num_sw] = [] 85 | stat[num_sw].append(int(result['em_top1'])) 86 | 87 | output = 
sorted({key: (f'{sum(val)/len(val):.2f}', f'{len(val)} Qs') for key, val in stat.items()}.items()) 88 | print(f'exclude {tokenizer_error_cnt} questions for tokenization error') 89 | print(f'exclude {entity_error_cnt} questions for entity not found error') 90 | print(f'stat: {output} for {len(predictions) - tokenizer_error_cnt - entity_error_cnt} questions') 91 | print(sorted(ent_types.items())) 92 | -------------------------------------------------------------------------------- /scripts/question_generation/filter_qg.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import json 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from squad_metrics import compute_exact 7 | nlp = spacy.load("en_core_web_sm") 8 | 9 | doc = nlp('European authorities fined Google a record $5.1 billion on Wednesday for abusing its power in the mobile phone market and ordered the company to alter its practices') 10 | print([(X.text, X.label_) for X in doc.ents]) 11 | 12 | 13 | data_path = 'data/squad-nq/train-sqdqg_nqqg.json' 14 | sample = False 15 | print(f'reading {data_path} with sampling: {sample}') 16 | train_set = json.load(open(data_path)) 17 | new_train_set = {'data': []} 18 | cnt = 0 19 | new_cnt = 0 20 | orig_cnt = 0 21 | miss_cnt = 0 22 | 23 | prediction_path = 'models/spanbert-base-cased-sqdnq_qgfilter/predictions_.json' 24 | predictions = {str(id_): pred for id_, pred in json.load(open(prediction_path)).items()} 25 | 26 | for article in tqdm(train_set['data']): 27 | new_article = { 28 | 'title': article['title'], 29 | 'paragraphs': [] 30 | } 31 | for p_idx, paragraph in enumerate(article['paragraphs']): 32 | new_paragraph = { 33 | 'context': paragraph['context'], 34 | 'qas' : [], 35 | } 36 | 37 | for qa in paragraph['qas']: 38 | question = qa['question'] 39 | id_ = str(qa['id']) 40 | # assert id_ in predictions 41 | if id_ not in predictions: 42 | print('missing predictions', id_) 43 | miss_cnt += 1 44 | continue 45 | if all(kk in id_ for kk in['_p', '_s', '_a']): 46 | if not compute_exact(qa['answers'][0]['text'], predictions[id_]): 47 | continue 48 | else: 49 | new_cnt += 1 50 | else: 51 | orig_cnt += 1 52 | 53 | new_paragraph['qas'].append(qa) 54 | cnt += 1 55 | new_article['paragraphs'].append(new_paragraph) 56 | 57 | new_train_set['data'].append(new_article) 58 | # break 59 | 60 | write_path = data_path.replace('.json', '_filtered.json') 61 | with open(write_path, 'w') as f: 62 | json.dump(new_train_set, f) 63 | 64 | assert orig_cnt + new_cnt == cnt 65 | print(f'writing to {write_path} with {cnt} samples') 66 | print(f'orig sample: {orig_cnt}, new sample: {new_cnt}') 67 | print(f'missing sample: {miss_cnt}') 68 | -------------------------------------------------------------------------------- /scripts/question_generation/generate_squad.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import json 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from pipelines import pipeline 7 | 8 | nlp_sent = spacy.load("en_core_web_sm") 9 | doc = nlp_sent('European authorities fined Google a record $5.1 billion on Wednesday for abusing its power in the mobile phone market and ordered the company to alter its practices') 10 | print([(X.text, X.label_) for X in doc.ents]) 11 | 12 | # Please train your own model on SQuAD and load as below 13 | nlp = pipeline("multitask-qa-qg", model="t5-large-multi-hl/checkpoint-3500", qg_format="highlight") 14 | 15 | 16 | 
data_path = '/home/data/squad/train-v1.1.json' 17 | sample = False 18 | print(f'reading {data_path} with sampling: {sample}') 19 | train_set = json.load(open(data_path)) 20 | new_train_set = {'data': []} 21 | cnt = 0 22 | answer_stats = [] 23 | bs = 16 24 | tmp_path = data_path.replace('.json', '_qg_t5l35-sqd_tmp.json') 25 | tmp_file = open(tmp_path, 'a') 26 | 27 | for article in tqdm(train_set['data']): 28 | new_article = { 29 | 'title': article['title'], 30 | 'paragraphs': [] 31 | } 32 | for p_idx, paragraph in enumerate(article['paragraphs']): 33 | new_paragraph = { 34 | 'context': paragraph['context'], 35 | 'qas' : [], 36 | } 37 | 38 | # Add existing QA pairs 39 | for qa in paragraph['qas']: 40 | new_paragraph['qas'].append(qa) 41 | cnt += 1 42 | 43 | # Get sentences 44 | sents = [sent for sent in nlp_sent(paragraph['context']).sents] 45 | qa_pairs = [] 46 | try: 47 | qa_pairs = nlp(paragraph['context']) 48 | except Exception as e: 49 | print('Neural QG error:', paragraph['context'][:50], e) 50 | 51 | ents = [[] for _ in range(len(sents))] 52 | try: 53 | for sent_idx, sent in enumerate(sents): 54 | parse_list = [ent for ent in sent.ents] 55 | ents[sent_idx] += parse_list 56 | except Exception as e: 57 | print('NER error:', sent.text, e) 58 | 59 | cst_qa_pairs = [] 60 | try: 61 | flat_ents = [e for ent in ents for e in ent] 62 | qg_examples = nlp._prepare_inputs_for_qg_from_answers_hl( 63 | [sent.text.strip() for sent in sents], [[e.text for e in ent] for ent in ents] 64 | ) 65 | qg_inputs = [example['source_text'] for example in qg_examples] 66 | cst_qs = [] 67 | for i in range(0, len(qg_inputs), bs): 68 | cst_qs += nlp._generate_questions(qg_inputs[i:i+bs]) 69 | assert len(cst_qs) == len(qg_examples) 70 | cst_qa_pairs = [{'answer': example['answer'], 'question': que} for example, que in zip(qg_examples, cst_qs)] 71 | except Exception as e: 72 | print('Ent QG error:', e) 73 | 74 | orig_len = len(qa_pairs) 75 | qa_pairs = qa_pairs + cst_qa_pairs 76 | if len(qa_pairs) == 0: 77 | print('Skipping as no questions generated for:', sent.text) 78 | continue 79 | flat_ents = [None]*orig_len + flat_ents 80 | 81 | q_set = [] 82 | for qa_idx, qa_pair in enumerate(qa_pairs): 83 | ans = qa_pair['answer'] 84 | que = qa_pair['question'] 85 | if que in q_set: 86 | continue 87 | q_set.append(que) 88 | try: 89 | if flat_ents[qa_idx] is not None: 90 | ans_start = flat_ents[qa_idx][0].idx 91 | else: 92 | ans_start = paragraph['context'].index(ans) 93 | except Exception as e: 94 | print('Skipping ans:', ans, e) 95 | continue 96 | if ans != paragraph['context'][ans_start:ans_start+len(ans)]: 97 | print(f'skipping mis-match {ans}') 98 | continue 99 | new_paragraph['qas'].append({ 100 | 'answers': [{'answer_start': ans_start, 'text': ans}], 101 | 'question': que, 102 | 'id': f'{article["title"]}_p{p_idx}_s{sent_idx}_a{qa_idx}', 103 | }) 104 | tmp_file.write( 105 | f'{article["title"]}_p{p_idx}_s{sent_idx}_a{qa_idx}\t{que}\t{ans}\t{ans_start}\n' 106 | ) 107 | cnt += 1 108 | 109 | if len(qa_pairs) > 0: 110 | print(qa_pairs[0]) 111 | new_article['paragraphs'].append(new_paragraph) 112 | 113 | new_train_set['data'].append(new_article) 114 | 115 | write_path = data_path.replace('.json', '_qg_t5l35-sqd.json') 116 | with open(write_path, 'w') as f: 117 | json.dump(new_train_set, f) 118 | 119 | print(f'writing to {write_path} with {cnt} samples') 120 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('README.md', encoding='utf8') as f: 4 | readme = f.read() 5 | 6 | with open('LICENSE', encoding='utf8') as f: 7 | license = f.read() 8 | 9 | with open('requirements.txt', encoding='utf8') as f: 10 | reqs = f.read() 11 | 12 | setup( 13 | name='densephrases', 14 | version='1.0', 15 | description='Learning Dense Representations of Phrases at Scale', 16 | long_description=readme, 17 | license=license, 18 | url='https://github.com/princeton-nlp/DensePhrases', 19 | keywords=['phrase', 'embedding', 'retrieval', 'nlp', 'open-domain', 'qa'], 20 | python_requires='>=3.7', 21 | install_requires=reqs.strip().split('\n'), 22 | packages=find_packages(), 23 | ) 24 | -------------------------------------------------------------------------------- /slides/emnlp2021_slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/slides/emnlp2021_slides.pdf --------------------------------------------------------------------------------
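For reference, a short usage sketch of the SimpleTokenizer defined in scripts/preprocess/simple_tokenizer.py above; the sample sentence is illustrative, and the import assumes scripts/preprocess is on PYTHONPATH.

from simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()  # no annotators: tokenization only
tokens = tokenizer.tokenize("DensePhrases retrieves phrases from Wikipedia in real time.")

print(tokens.words())                       # ['DensePhrases', 'retrieves', ...]
print(tokens.words(uncased=True))           # lower-cased token texts
print(tokens.offsets())                     # [start, end) character spans per token
print(tokens.untokenize())                  # original text with whitespace restored
print(tokens.ngrams(n=2, as_strings=True))  # all unigrams and bigrams as strings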