├── LICENSE
├── Makefile
├── README.md
├── build_phrase_index.py
├── config.sh
├── densephrases
│   ├── __init__.py
│   ├── demo
│   │   ├── __init__.py
│   │   └── static
│   │       ├── examples.txt
│   │       ├── examples_context.txt
│   │       ├── files
│   │       │   ├── all.js
│   │       │   ├── bootstrap.min.js
│   │       │   ├── favicon.ico
│   │       │   ├── jquery-3.3.1.min.js
│   │       │   ├── overview_new.png
│   │       │   ├── plogo.png
│   │       │   ├── popper.min.js
│   │       │   ├── preview-new.gif
│   │       │   ├── steps.png
│   │       │   └── style.css
│   │       ├── index.html
│   │       └── index_single.html
│   ├── encoder.py
│   ├── index.py
│   ├── model.py
│   ├── options.py
│   └── utils
│       ├── __init__.py
│       ├── data_utils.py
│       ├── embed_utils.py
│       ├── eval_utils.py
│       ├── file_utils.py
│       ├── kilt
│       │   ├── __init__.py
│       │   ├── eval.py
│       │   └── kilt_utils.py
│       ├── open_utils.py
│       ├── single_utils.py
│       ├── squad_metrics.py
│       └── squad_utils.py
├── download.sh
├── eval_phrase_retrieval.py
├── examples
│   ├── README.md
│   ├── create-custom-index
│   │   ├── README.md
│   │   ├── articles.json
│   │   └── questions.json
│   ├── entity-linking
│   │   └── README.md
│   ├── fusion-in-decoder
│   │   └── README.md
│   ├── knowledge-dialogue
│   │   └── README.md
│   └── slot-filling
│       └── README.md
├── generate_phrase_vecs.py
├── requirements.txt
├── run_demo.py
├── scripts
│   ├── analysis
│   │   ├── run_analysis.py
│   │   └── run_analysis_dpr.py
│   ├── benchmark
│   │   ├── benchmark_hdf5.py
│   │   ├── create_benchmark_data.py
│   │   └── data
│   │       ├── nq_1000_dev_denspi.json
│   │       ├── nq_1000_dev_dpr.csv
│   │       └── nq_1000_dev_orqa.jsonl
│   ├── dump
│   │   ├── check_dump.py
│   │   ├── filter_hdf5.py
│   │   ├── filter_stats.py
│   │   ├── save_meta.py
│   │   └── split_hdf5.py
│   ├── kilt
│   │   ├── build_title2wikiid.py
│   │   ├── sample_kilt.py
│   │   └── strip_pred.py
│   ├── parallel
│   │   ├── add_to_index.py
│   │   └── dump_phrases.py
│   ├── postprocess
│   │   ├── recall.py
│   │   └── recall_transform.py
│   ├── preprocess
│   │   ├── README.md
│   │   ├── build_db.py
│   │   ├── build_wikisquad.py
│   │   ├── compress_metadata.py
│   │   ├── concat_wikisquad.py
│   │   ├── create_nq_reader.py
│   │   ├── create_nq_reader_doc_wiki.py
│   │   ├── create_nq_reader_wiki.py
│   │   ├── create_openqa.py
│   │   ├── create_psg_hdf5.py
│   │   ├── create_tqa_ds.py
│   │   ├── doc_db.py
│   │   ├── download_wikidump.py
│   │   ├── filter_noans.py
│   │   ├── filter_wiki.py
│   │   ├── merge_openqa.py
│   │   ├── merge_paq.py
│   │   ├── merge_singleqa.py
│   │   ├── nq_utils.py
│   │   ├── prep_wikipedia.py
│   │   ├── sample_nq_reader_doc_wiki.py
│   │   ├── simple_tokenizer.py
│   │   └── stat_entities.py
│   └── question_generation
│       ├── filter_qg.py
│       └── generate_squad.py
├── setup.py
├── slides
│   └── emnlp2021_slides.pdf
├── train_cross_encoder.py
├── train_query.py
└── train_rc.py
/config.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Modify below to your choice of directory
4 | export BASE_DIR=./
5 |
6 | while read -p "Use to $BASE_DIR as the base directory (requires at least 220GB for the installation)? [yes/no]: " choice; do
7 | case "$choice" in
8 | yes )
9 | break ;;
10 | no )
11 | while read -p "Type in the directory: " choice; do
12 | case "$choice" in
13 | * )
14 | export BASE_DIR=$choice;
15 | echo "Base directory set to $BASE_DIR";
16 | break ;;
17 | esac
18 | done
19 | break ;;
20 | * ) echo "Please answer yes or no.";
21 | exit 0 ;;
22 | esac
23 | done
24 |
25 | # DATA_DIR: for datasets (including 'kilt', 'open-qa', 'single-qa', 'truecase', 'wikidump')
26 | # SAVE_DIR: for pre-trained models or dumps; new models and dumps will also be saved here
27 | # CACHE_DIR: for cache files from huggingface transformers
28 | export DATA_DIR=$BASE_DIR/densephrases-data
29 | export SAVE_DIR=$BASE_DIR/outputs
30 | export CACHE_DIR=$BASE_DIR/cache
31 |
32 | # Create directories
33 | mkdir -p $DATA_DIR
34 | mkdir -p $SAVE_DIR
35 | mkdir -p $SAVE_DIR/logs
36 | mkdir -p $CACHE_DIR
37 |
38 | printf "\nEnvironment variables are set as follows:\n"
39 | echo "DATA_DIR=$DATA_DIR"
40 | echo "SAVE_DIR=$SAVE_DIR"
41 | echo "CACHE_DIR=$CACHE_DIR"
42 |
43 | # Append to bashrc, instructions
44 | while read -p "Add to ~/.bashrc (recommended)? [yes/no]: " choice; do
45 | case "$choice" in
46 | yes )
47 | echo -e "\n# DensePhrases setup" >> ~/.bashrc;
48 | echo "export DATA_DIR=$DATA_DIR" >> ~/.bashrc;
49 | echo "export SAVE_DIR=$SAVE_DIR" >> ~/.bashrc;
50 | echo "export CACHE_DIR=$CACHE_DIR" >> ~/.bashrc;
51 | break ;;
52 | no )
53 | break ;;
54 | * ) echo "Please answer yes or no." ;;
55 | esac
56 | done
57 |
--------------------------------------------------------------------------------
/densephrases/__init__.py:
--------------------------------------------------------------------------------
1 | from .encoder import Encoder
2 | from .index import MIPS
3 | from .options import Options
4 | from .model import DensePhrases
5 |
--------------------------------------------------------------------------------
/densephrases/demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/demo/__init__.py
--------------------------------------------------------------------------------
/densephrases/demo/static/examples.txt:
--------------------------------------------------------------------------------
1 | who determines the size of the supreme court
2 | who won series 7 of great british bake off
3 | when does the new wheel of fortune season start
4 | when does season 3 of lucifer come out
5 | what kind of currency is used in new zealand
6 | who is the highest paid nba player in 2016
7 | who is el senor de los cielos based on
8 | who is paige on days of our lives
9 | who is the creator of star vs the forces of evil
10 | total number of articles in indian constitution at present
11 | who plays percy in the lost city of z
12 | what was uncle jesse's original last name on full house
13 | how many goals scored ronaldo in his career
14 | who plays male lead in far from the madding crowd
15 | when was a whiter shade of pale recorded
16 | when did medicare begin in the united states
17 | who sings don't stand so close to me
18 | where was war on the planet of the apes filmed
19 | who wrote love so soft by kelly clarkson
20 | who is the longest serving manager in man united
21 | Who is the fourth president of USA?
22 | the seventh president of USA
23 | What is South Korea known for?
24 | What tends to lead to more money?
25 | Who was defeated by computer in chess game?
26 | Name three famous writers
27 | What makes a successful startup?
28 | Why did Oracle sue Google?
29 | Where can you find water in desert?
30 | What does AMI stand for?
31 | How heavy was the apollo 11?
32 | What is water consisted of?
33 | What makes a man great?
34 | Which city is famous for coffee?
35 | On which date was Genghis Khan's palace rediscovered by archaeologists?
36 | What is another term for x-ray imaging?
37 | Who scolded Luther about his rudeness?
38 | What was the Yuan's paper money called?
39 |
--------------------------------------------------------------------------------
/densephrases/demo/static/files/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/demo/static/files/favicon.ico
--------------------------------------------------------------------------------
/densephrases/demo/static/files/overview_new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/demo/static/files/overview_new.png
--------------------------------------------------------------------------------
/densephrases/demo/static/files/plogo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/demo/static/files/plogo.png
--------------------------------------------------------------------------------
/densephrases/demo/static/files/preview-new.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/demo/static/files/preview-new.gif
--------------------------------------------------------------------------------
/densephrases/demo/static/files/steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/demo/static/files/steps.png
--------------------------------------------------------------------------------
/densephrases/demo/static/files/style.css:
--------------------------------------------------------------------------------
1 | html { position: relative; min-height: 100%; }
2 | body { margin-bottom: 60px; font-family: Verdana, sans-serif;}
3 | .footer { position: absolute; bottom: 0; width: 100%; height: 40px; line-height: 15px; background-color: #f5f5f5; padding-top: 5px; font-size: 12px; text-align: center;}
4 | label, footer { user-select: none; }
5 | .list-group-item:first-of-type { background-color: #BEE6FF; color: #000000; }
6 | .score { position:absolute; bottom:0; right:15px;}
7 |
8 | .list-group-mine .list-group-item {
9 | background-color: #DFDFDF;
10 | border-left-color: #fff;
11 | border-right-color: #fff;
12 | }
13 | .list-group-mine .list-group-item:first-child {
14 | display:none;
15 | }
16 |
17 | .paper_title {
18 | margin-top: 15px;
19 | margin-left: auto;
20 | margin-right: auto;
21 | margin-bottom: auto;
22 | width: 70%;
23 | text-align: center;
24 | }
25 | .detail {
26 | margin: auto;
27 | width: 50%;
28 | }
29 | .detail2 {
30 | margin-top: 8px;
31 | margin-left: auto;
32 | margin-right: auto;
33 | margin-bottom: auto;
34 | width: 50%;
35 | }
36 | .card {
37 | margin-top: -15px;
38 | }
39 |
--------------------------------------------------------------------------------
/densephrases/demo/static/index.html:
--------------------------------------------------------------------------------
(Only the visible page text survives from the stripped HTML markup:)

From 5 million Wikipedia articles, DensePhrases searches phrase-level answers to your questions or retrieves relevant passages in real time. More details are in our ACL'21 paper and EMNLP'21 paper.

You can type in any natural language question below and get the results in real time. Retrieved phrases are denoted in boldface for each passage. The current model is case-sensitive, and the best results are obtained when queries use proper letter casing (e.g., "Name Apple's products", not "name apple's products"). Our current demo has the following specs:

Accuracy: 40.8% on Natural Questions (open), Latency: ≈100ms/Q (with at least top 10 results)
Resources: 11GB GPU, 100GB RAM
Code link | Contact: Jinhyuk Lee (lee.jnhk@gmail.com)

Real-time Search
English Wikipedia (2018.12.20)
--------------------------------------------------------------------------------
/densephrases/demo/static/index_single.html:
--------------------------------------------------------------------------------
(Only the visible page text survives from the stripped HTML markup: the "DensePhrases" title, a "Latency:" readout, and a "Single passage" label.)
--------------------------------------------------------------------------------
/densephrases/model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import numpy as np
4 | import os
5 |
6 | from densephrases import Options
7 | from densephrases.utils.single_utils import load_encoder
8 | from densephrases.utils.open_utils import load_phrase_index, get_query2vec, load_qa_pairs
9 | from densephrases.utils.squad_utils import TrueCaser
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class DensePhrases(object):
15 | def __init__(self,
16 | load_dir,
17 | dump_dir,
18 | index_name='start/1048576_flat_OPQ96',
19 | device='cuda',
20 | verbose=False,
21 | **kwargs):
22 | print("This could take up to 15 mins depending on the file reading speed of HDD/SSD")
23 |
24 | # Turn off loggers
25 | if not verbose:
26 | logging.getLogger("densephrases").setLevel(logging.WARNING)
27 | logging.getLogger("transformers").setLevel(logging.WARNING)
28 |
29 | # Get default options
30 | options = Options()
31 | options.add_model_options()
32 | options.add_index_options()
33 | options.add_retrieval_options()
34 | options.add_data_options()
35 | self.args = options.parse()
36 |
37 | # Set options
38 | self.args.load_dir = load_dir
39 | self.args.dump_dir = dump_dir
40 | self.args.cache_dir = os.environ['CACHE_DIR']
41 | self.args.index_name = index_name
42 | self.args.cuda = True if device == 'cuda' else False
43 | self.args.__dict__.update(kwargs)
44 |
45 | # Load encoder
46 | self.set_encoder(load_dir, device)
47 |
48 | # Load MIPS
49 | self.mips = load_phrase_index(self.args, ignore_logging=not verbose)
50 |
51 | # Others
52 | self.truecase = TrueCaser(os.path.join(os.environ['DATA_DIR'], self.args.truecase_path))
53 | print("Loading DensePhrases Completed!")
54 |
55 | def search(self, query='', retrieval_unit='phrase', top_k=10, truecase=True, return_meta=False):
56 | # If query is str, single query
57 | single_query = False
58 | if type(query) == str:
59 | batch_query = [query]
60 | single_query = True
61 | else:
62 | assert type(query) == list
63 | batch_query = query
64 |
65 | # Pre-processing
66 | if truecase:
67 | batch_query = [self.truecase.get_true_case(q) if q == q.lower() else q for q in batch_query]
68 |
69 | # Get question vector
70 | outs = self.query2vec(batch_query)
71 | start = np.concatenate([out[0] for out in outs], 0)
72 | end = np.concatenate([out[1] for out in outs], 0)
73 | query_vec = np.concatenate([start, end], 1)
74 |
75 | # Search
76 | agg_strats = {'phrase': 'opt1', 'sentence': 'opt2', 'paragraph': 'opt2', 'document': 'opt3'}
77 | if retrieval_unit not in agg_strats:
78 | raise NotImplementedError(f'"{retrieval_unit}" not supported. Choose one of {agg_strats.keys()}.')
79 | search_top_k = top_k
80 | if retrieval_unit in ['sentence', 'paragraph', 'document']:
81 | search_top_k *= 2
82 | rets = self.mips.search(
83 | query_vec, q_texts=batch_query, nprobe=256,
84 | top_k=search_top_k, max_answer_length=10,
85 | return_idxs=False, aggregate=True, agg_strat=agg_strats[retrieval_unit],
86 | return_sent=True if retrieval_unit == 'sentence' else False
87 | )
88 |
89 | # Gather results
90 | rets = [ret[:top_k] for ret in rets]
91 | if retrieval_unit == 'phrase':
92 | retrieved = [[rr['answer'] for rr in ret][:top_k] for ret in rets]
93 | elif retrieval_unit == 'sentence':
94 | retrieved = [[rr['context'] for rr in ret][:top_k] for ret in rets]
95 | elif retrieval_unit == 'paragraph':
96 | retrieved = [[rr['context'] for rr in ret][:top_k] for ret in rets]
97 | elif retrieval_unit == 'document':
98 | retrieved = [[rr['title'][0] for rr in ret][:top_k] for ret in rets]
99 | else:
100 | raise NotImplementedError()
101 |
102 | if single_query:
103 | rets = rets[0]
104 | retrieved = retrieved[0]
105 |
106 | if return_meta:
107 | return retrieved, rets
108 | else:
109 | return retrieved
110 |
111 | def set_encoder(self, load_dir, device='cuda'):
112 | self.args.load_dir = load_dir
113 | self.model, self.tokenizer, self.config = load_encoder(device, self.args)
114 | self.query2vec = get_query2vec(
115 | query_encoder=self.model, tokenizer=self.tokenizer, args=self.args, batch_size=64
116 | )
117 |
118 | def evaluate(self, test_path, **kwargs):
119 | from eval_phrase_retrieval import evaluate as evaluate_fn
120 |
121 | # Set new arguments
122 | new_args = copy.deepcopy(self.args)
123 | new_args.test_path = test_path
124 | new_args.truecase = True
125 | new_args.__dict__.update(kwargs)
126 |
127 | # Run with new_arg
128 | evaluate_fn(new_args, self.mips, self.model, self.tokenizer)
129 |
--------------------------------------------------------------------------------
/densephrases/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/utils/__init__.py
--------------------------------------------------------------------------------
/densephrases/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import ujson as json
3 | import re
4 | import string
5 | import unicodedata
6 | import pickle
7 | from collections import Counter
8 |
9 | def normalize_answer(s):
10 |
11 | def remove_articles(text):
12 | return re.sub(r'\b(a|an|the)\b', ' ', text)
13 |
14 | def white_space_fix(text):
15 | return ' '.join(text.split())
16 |
17 | def remove_punc(text):
18 | exclude = set(string.punctuation)
19 | return ''.join(ch for ch in text if ch not in exclude)
20 |
21 | def lower(text):
22 | return text.lower()
23 |
24 | return white_space_fix(remove_articles(remove_punc(lower(s))))
25 |
26 |
27 | def f1_score(prediction, ground_truth):
28 | normalized_prediction = normalize_answer(prediction)
29 | normalized_ground_truth = normalize_answer(ground_truth)
30 |
31 | ZERO_METRIC = (0, 0, 0)
32 |
33 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
34 | return ZERO_METRIC
35 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
36 | return ZERO_METRIC
37 |
38 | prediction_tokens = normalized_prediction.split()
39 | ground_truth_tokens = normalized_ground_truth.split()
40 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
41 | num_same = sum(common.values())
42 | if num_same == 0:
43 | return ZERO_METRIC
44 | precision = 1.0 * num_same / len(prediction_tokens)
45 | recall = 1.0 * num_same / len(ground_truth_tokens)
46 | f1 = (2 * precision * recall) / (precision + recall)
47 | return f1, precision, recall
48 |
49 |
50 | def exact_match_score(prediction, ground_truth):
51 | return (normalize_answer(prediction) == normalize_answer(ground_truth))
52 |
53 |
54 | def drqa_normalize(text):
55 | """Resolve different type of unicode encodings."""
56 | return unicodedata.normalize('NFD', text)
57 |
58 |
59 | def drqa_exact_match_score(prediction, ground_truth):
60 | """Check if the prediction is a (soft) exact match with the ground truth."""
61 | return normalize_answer(prediction) == normalize_answer(ground_truth)
62 |
63 |
64 | def drqa_regex_match_score(prediction, pattern):
65 | """Check if the prediction matches the given regular expression."""
66 | try:
67 | compiled = re.compile(
68 | pattern,
69 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE
70 | )
71 | except BaseException as e:
72 | # logger.warn('Regular expression failed to compile: %s' % pattern)
73 | # print('re failed to compile: [%s] due to [%s]' % (pattern, e))
74 | return False
75 | return compiled.match(prediction) is not None
76 |
77 |
78 | def drqa_metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
79 | """Given a prediction and multiple valid answers, return the score of
80 | the best prediction-answer_n pair given a metric function.
81 | """
82 | scores_for_ground_truths = []
83 | for ground_truth in ground_truths:
84 | score = metric_fn(prediction, ground_truth)
85 | scores_for_ground_truths.append(score)
86 | return max(scores_for_ground_truths)
87 |
88 |
89 | def update_answer(metrics, prediction, gold):
90 | em = exact_match_score(prediction, gold)
91 | f1, prec, recall = f1_score(prediction, gold)
92 | metrics['em'] += em
93 | metrics['f1'] += f1
94 | metrics['prec'] += prec
95 | metrics['recall'] += recall
96 | return em, prec, recall
97 |
98 |
99 | def update_sp(metrics, prediction, gold):
100 | cur_sp_pred = set(map(tuple, prediction))
101 | gold_sp_pred = set(map(tuple, gold))
102 | tp, fp, fn = 0, 0, 0
103 | for e in cur_sp_pred:
104 | if e in gold_sp_pred:
105 | tp += 1
106 | else:
107 | fp += 1
108 | for e in gold_sp_pred:
109 | if e not in cur_sp_pred:
110 | fn += 1
111 | prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
112 | recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
113 | f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
114 | em = 1.0 if fp + fn == 0 else 0.0
115 | metrics['sp_em'] += em
116 | metrics['sp_f1'] += f1
117 | metrics['sp_prec'] += prec
118 | metrics['sp_recall'] += recall
119 | return em, prec, recall
120 |
121 |
122 | def eval(prediction_file, gold_file):
123 | with open(prediction_file) as f:
124 | prediction = json.load(f)
125 | with open(gold_file) as f:
126 | gold = json.load(f)
127 |
128 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
129 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
130 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
131 |
132 | for dp in gold:
133 | cur_id = dp['_id']
134 | em, prec, recall = update_answer(
135 | metrics, prediction['answer'][cur_id], dp['answer'])
136 |
137 | N = len(gold)
138 | for k in metrics.keys():
139 | metrics[k] /= N
140 |
141 | print(metrics)
142 |
143 |
144 | def analyze(prediction_file, gold_file):
145 | with open(prediction_file) as f:
146 | prediction = json.load(f)
147 | with open(gold_file) as f:
148 | gold = json.load(f)
149 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
150 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
151 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
152 |
153 | for dp in gold:
154 | cur_id = dp['_id']
155 |
156 | em, prec, recall = update_answer(
157 | metrics, prediction['answer'][cur_id], dp['answer'])
158 | if (prec + recall == 0):
159 | f1 = 0
160 | else:
161 | f1 = 2 * prec * recall / (prec+recall)
162 |
163 | print (dp['answer'], prediction['answer'][cur_id])
164 | print (f1, em)
165 | a = input()
166 |
167 |
168 | if __name__ == '__main__':
169 | #eval(sys.argv[1], sys.argv[2])
170 | analyze(sys.argv[1], sys.argv[2])
171 |
--------------------------------------------------------------------------------
/densephrases/utils/kilt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/densephrases/utils/kilt/__init__.py
--------------------------------------------------------------------------------
/densephrases/utils/kilt/kilt_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import nltk
9 | import json
10 | import os
11 | import logging
12 | import sys
13 | import time
14 | import string
15 | import random
16 |
17 |
18 | def normalize_answer(s):
19 | """Lower text and remove punctuation, articles and extra whitespace."""
20 |
21 | def remove_punc(text):
22 | exclude = set(string.punctuation)
23 | return "".join(ch for ch in text if ch not in exclude)
24 |
25 | def lower(text):
26 | return text.lower()
27 |
28 | return remove_punc(lower(s))
29 |
30 |
31 | def validate_datapoint(datapoint, logger):
32 |
33 | # input is a string
34 | if not isinstance(datapoint["input"], str):
35 | if logger:
36 | logger.warning(
37 | "[{}] input is not a string {}".format(
38 | datapoint["id"], datapoint["input"]
39 | )
40 | )
41 | return False
42 |
43 | # output is not empty
44 | if "output" in datapoint:
45 | if len(datapoint["output"]) == 0:
46 | if logger:
47 | logger.warning("[{}] empty output".format(datapoint["id"]))
48 | return False
49 |
50 | for output in datapoint["output"]:
51 | # answer is a string
52 | if "answer" in output:
53 | if not isinstance(output["answer"], str):
54 | if logger:
55 | logger.warning(
56 | "[{}] answer is not a string {}".format(
57 | datapoint["id"], output["answer"]
58 | )
59 | )
60 | return False
61 |
62 | # provenance is not empty
63 | # if len(output["provenance"]) == 0:
64 | # if logger:
65 | # logger.warning("[{}] empty provenance".format(datapoint["id"]))
66 | # return False
67 |
68 | if "provenance" in output:
69 | for provenance in output["provenance"]:
70 | # wikipedia_id is provided
71 | if not isinstance(provenance["wikipedia_id"], str):
72 | if logger:
73 | logger.warning(
74 | "[{}] wikipedia_id is not a string {}".format(
75 | datapoint["id"], provenance["wikipedia_id"]
76 | )
77 | )
78 | return False
79 |
80 | # title is provided
81 | if not isinstance(provenance["title"], str):
82 | if logger:
83 | logger.warning(
84 | "[{}] title is not a string {}".format(
85 | datapoint["id"], provenance["title"]
86 | )
87 | )
88 | return False
89 |
90 | return True
91 |
92 |
93 | def load_data(filename):
94 | data = []
95 | with open(filename, "r") as fin:
96 | lines = fin.readlines()
97 | for line in lines:
98 | data.append(json.loads(line))
99 | return data
100 |
101 |
102 | def store_data(filename, data):
103 | with open(filename, "w+") as outfile:
104 | for idx, element in enumerate(data):
105 | # print(round(idx * 100 / len(data), 2), "%", end="\r")
106 | # sys.stdout.flush()
107 | json.dump(element, outfile)
108 | outfile.write("\n")
109 |
110 |
111 | def get_bleu(candidate_tokens, gold_tokens):
112 |
113 | candidate_tokens = [x for x in candidate_tokens if len(x.strip()) > 0]
114 | gold_tokens = [x for x in gold_tokens if len(x.strip()) > 0]
115 |
116 | # The default BLEU calculates a score for up to
117 | # 4-grams using uniform weights (this is called BLEU-4)
118 | weights = (0.25, 0.25, 0.25, 0.25)
119 |
120 | if len(gold_tokens) < 4:
121 | # lower order ngrams
122 | weights = [1.0 / len(gold_tokens) for _ in range(len(gold_tokens))]
123 |
124 | BLEUscore = nltk.translate.bleu_score.sentence_bleu(
125 | [candidate_tokens], gold_tokens, weights=weights
126 | )
127 | return BLEUscore
128 |
129 |
130 | # split a list in num parts evenly
131 | def chunk_it(seq, num):
132 | assert num > 0
133 | chunk_len = len(seq) // num
134 | chunks = [seq[i * chunk_len : i * chunk_len + chunk_len] for i in range(num)]
135 |
136 | diff = len(seq) - chunk_len * num # 0 <= diff < num
137 | for i in range(diff):
138 | chunks[i].append(seq[chunk_len * num + i])
139 |
140 | return chunks
141 |
142 |
143 | def init_logging(base_logdir, modelname, logger=None):
144 |
145 | # logging format
146 | # "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
147 | formatter = logging.Formatter(
148 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
149 | )
150 |
151 | log_directory = "{}/{}/".format(base_logdir, modelname)
152 |
153 | if logger == None:
154 | logger = logging.getLogger("KILT")
155 |
156 | logger.setLevel(logging.DEBUG)
157 |
158 | # console handler
159 | ch = logging.StreamHandler(sys.stdout)
160 | ch.setLevel(logging.DEBUG)
161 | ch.setFormatter(formatter)
162 |
163 | logger.addHandler(ch)
164 |
165 | else:
166 | # remove previous file handler
167 | logger.handlers.pop()
168 |
169 | os.makedirs(log_directory, exist_ok=True)
170 |
171 | # file handler
172 | fh = logging.FileHandler(str(log_directory) + "/info.log")
173 | fh.setLevel(logging.DEBUG)
174 | fh.setFormatter(formatter)
175 |
176 | logger.addHandler(fh)
177 |
178 | logger.propagate = False
179 | logger.info("logging in {}".format(log_directory))
180 | return logger
181 |
182 |
183 | def create_logdir_with_timestamp(base_logdir):
184 | timestr = time.strftime("%Y%m%d_%H%M%S")
185 | # create new directory
186 | log_directory = "{}/{}_{}/".format(base_logdir, timestr, random.randint(0, 1000))
187 | os.makedirs(log_directory)
188 | return log_directory
--------------------------------------------------------------------------------
/densephrases/utils/open_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import logging
4 | import json
5 | import torch
6 | import numpy as np
7 |
8 | from densephrases import MIPS
9 | from densephrases.utils.single_utils import backward_compat
10 | from densephrases.utils.squad_utils import get_question_dataloader, TrueCaser
11 | from densephrases.utils.embed_utils import get_question_results
12 |
13 | from transformers import (
14 | MODEL_MAPPING,
15 | AutoConfig,
16 | AutoTokenizer,
17 | AutoModel,
18 | )
19 |
20 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
21 | level=logging.INFO)
22 | logger = logging.getLogger(__name__)
23 | truecase = None
24 |
25 |
26 | def load_phrase_index(args, ignore_logging=False):
27 | # Configure paths for index serving
28 | phrase_dump_dir = os.path.join(args.dump_dir, args.phrase_dir)
29 | index_dir = os.path.join(args.dump_dir, args.index_name)
30 | index_path = os.path.join(index_dir, args.index_path)
31 | idx2id_path = os.path.join(index_dir, args.idx2id_path)
32 |
33 | # Load mips
34 | if 'aggregate' in args.__dict__.keys():
35 | logger.info(f'Aggregate: {args.aggregate}')
36 | mips = MIPS(
37 | phrase_dump_dir=phrase_dump_dir,
38 | index_path=index_path,
39 | idx2id_path=idx2id_path,
40 | cuda=args.cuda,
41 | logging_level=logging.WARNING if ignore_logging else (logging.DEBUG if args.verbose_logging else logging.INFO),
42 | )
43 | return mips
44 |
45 |
46 | def load_cross_encoder(device, args):
47 |
48 | # Configure paths for cross-encoder serving
49 | cross_encoder = torch.load(
50 | os.path.join(args.load_dir, "pytorch_model.bin"), map_location=torch.device('cpu')
51 | )
52 | new_qd = {n[len('bert')+1:]: p for n, p in cross_encoder.items() if 'bert' in n}
53 | new_linear = {n[len('qa_outputs')+1:]: p for n, p in cross_encoder.items() if 'qa_outputs' in n}
54 | config, unused_kwargs = AutoConfig.from_pretrained(
55 | args.pretrained_name_or_path,
56 | cache_dir=args.cache_dir if args.cache_dir else None,
57 | return_unused_kwargs=True
58 | )
59 | tokenizer = AutoTokenizer.from_pretrained(
60 | args.tokenizer_name if args.tokenizer_name else args.pretrained_name_or_path,
61 | do_lower_case=args.do_lower_case,
62 | cache_dir=args.cache_dir if args.cache_dir else None,
63 | )
64 | model = AutoModel.from_pretrained(
65 | args.pretrained_name_or_path,
66 | from_tf=bool(".ckpt" in args.pretrained_name_or_path),
67 | config=config,
68 | cache_dir=args.cache_dir if args.cache_dir else None,
69 | )
70 | model.load_state_dict(new_qd)
71 | qa_outputs = torch.nn.Linear(config.hidden_size, 2)
72 | qa_outputs.load_state_dict(new_linear)
73 | ce_model = torch.nn.ModuleList(
74 | [model, qa_outputs]
75 | )
76 | ce_model.to(device)
77 |
78 | logger.info(f'CrossEncoder loaded from {args.load_dir} having {MODEL_MAPPING[config.__class__]}')
79 | logger.info('Number of model parameters: {:,}'.format(sum(p.numel() for p in ce_model.parameters())))
80 | return ce_model, tokenizer
81 |
82 |
83 | def get_query2vec(query_encoder, tokenizer, args, batch_size=64):
84 | device = 'cuda' if args.cuda else 'cpu'
85 | def query2vec(queries):
86 | question_dataloader, question_examples, query_features = get_question_dataloader(
87 | queries, tokenizer, args.max_query_length, batch_size=batch_size
88 | )
89 | question_results = get_question_results(
90 | question_examples, query_features, question_dataloader, device, query_encoder, batch_size=batch_size
91 | )
92 | if args.verbose_logging:
93 | logger.info(f"{len(query_features)} queries: {' '.join(query_features[0].tokens_)}")
94 | outs = []
95 | for qr_idx, question_result in enumerate(question_results):
96 | out = (
97 | question_result.start_vec.tolist(), question_result.end_vec.tolist(), query_features[qr_idx].tokens_
98 | )
99 | outs.append(out)
100 | return outs
101 | return query2vec
102 |
103 |
104 | def load_qa_pairs(data_path, args, q_idx=None, draft_num_examples=100, shuffle=False):
105 | q_ids = []
106 | questions = []
107 | answers = []
108 | titles = []
109 | data = json.load(open(data_path))['data']
110 | for data_idx, item in enumerate(data):
111 | if q_idx is not None:
112 | if data_idx != q_idx:
113 | continue
114 | q_id = item['id']
115 | if 'origin' in item:
116 | q_id = item['origin'].split('.')[0] + '-' + q_id
117 | question = item['question']
118 | if '[START_ENT]' in question:
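# Entity-linking (KILT) queries mark the mention with [START_ENT]/[END_ENT]; keep only a window of text around it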
119 | question = question[max(question.index('[START_ENT]')-300, 0):question.index('[END_ENT]')+300]
120 | answer = item['answers']
121 | title = item.get('titles', [''])
122 | if len(answer) == 0:
123 | continue
124 | q_ids.append(q_id)
125 | questions.append(question)
126 | answers.append(answer)
127 | titles.append(title)
128 | questions = [query[:-1] if query.endswith('?') else query for query in questions]
129 | # questions = [query.lower() for query in questions] # force lower query
130 |
131 | if args.do_lower_case:
132 | logger.info(f'Lowercasing queries')
133 | questions = [query.lower() for query in questions]
134 |
135 | if shuffle:
136 | qa_pairs = list(zip(q_ids, questions, answers, titles))
137 | random.shuffle(qa_pairs)
138 | q_ids, questions, answers, titles = zip(*qa_pairs)
139 | logger.info(f'Shuffling QA pairs')
140 |
141 | if args.draft:
142 | q_ids = np.array(q_ids)[:draft_num_examples].tolist()
143 | questions = np.array(questions)[:draft_num_examples].tolist()
144 | answers = np.array(answers)[:draft_num_examples].tolist()
145 | titles = np.array(titles)[:draft_num_examples].tolist()
146 |
147 | if args.truecase:
148 | try:
149 | global truecase
150 | if truecase is None:
151 | logger.info('loading truecaser')
152 | truecase = TrueCaser(os.path.join(os.environ['DATA_DIR'], args.truecase_path))
153 | logger.info('Truecasing queries')
154 | questions = [truecase.get_true_case(query) if query == query.lower() else query for query in questions]
155 | except Exception as e:
156 | print(e)
157 |
158 | logger.info(f'Loading {len(questions)} questions from {data_path}')
159 | logger.info(f'Sample Q ({q_ids[0]}): {questions[0]}, A: {answers[0]}, Title: {titles[0]}')
160 | return q_ids, questions, answers, titles
161 |
162 |
--------------------------------------------------------------------------------
/densephrases/utils/single_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import logging
4 | import copy
5 | import os
6 | import numpy as np
7 |
8 | from functools import partial
9 | from transformers import (
10 | MODEL_MAPPING,
11 | AutoConfig,
12 | AutoTokenizer,
13 | AutoModel,
14 | )
15 | from densephrases import Encoder
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | def set_seed(args):
21 | random.seed(args.seed)
22 | np.random.seed(args.seed)
23 | torch.manual_seed(args.seed)
24 | if torch.cuda.is_available():
25 | torch.cuda.manual_seed_all(args.seed)
26 |
27 |
28 | def to_list(tensor):
29 | return tensor.detach().cpu().tolist()
30 |
31 |
32 | def to_numpy(tensor):
33 | return tensor.detach().cpu().numpy()
34 |
35 |
36 | def backward_compat(model_dict):
37 | # Remove teacher
38 | model_dict = {key: val for key, val in model_dict.items() if not key.startswith('cross_encoder')}
39 | model_dict = {key: val for key, val in model_dict.items() if not key.startswith('bert_qd')}
40 | model_dict = {key: val for key, val in model_dict.items() if not key.startswith('qa_outputs')}
41 |
42 | # Replace old names to current ones
43 | mapping = {
44 | 'bert_start': 'phrase_encoder',
45 | 'bert_q_start': 'query_start_encoder',
46 | 'bert_q_end': 'query_end_encoder',
47 | }
48 | new_model_dict = {}
49 | for key, val in model_dict.items():
50 | for old_key, new_key in mapping.items():
51 | if key.startswith(old_key):
52 | new_model_dict[key.replace(old_key, new_key)] = val
53 | elif all(not key.startswith(old_k) for old_k in mapping.keys()):
54 | new_model_dict[key] = val
55 |
56 | return new_model_dict
57 |
58 |
59 | def load_encoder(device, args, phrase_only=False):
60 | # Configure paths for DensePhrases
61 | args.model_type = args.model_type.lower()
62 | config = AutoConfig.from_pretrained(
63 | args.config_name if args.config_name else args.pretrained_name_or_path,
64 | cache_dir=args.cache_dir if args.cache_dir else None,
65 | )
66 | tokenizer = AutoTokenizer.from_pretrained(
67 | args.tokenizer_name if args.tokenizer_name else args.pretrained_name_or_path,
68 | do_lower_case=args.do_lower_case,
69 | cache_dir=args.cache_dir if args.cache_dir else None,
70 | )
71 |
72 | # Prepare PLM if not load_dir
73 | pretrained = None
74 | if not args.load_dir:
75 | pretrained = AutoModel.from_pretrained(
76 | args.pretrained_name_or_path,
77 | config=config,
78 | cache_dir=args.cache_dir if args.cache_dir else None,
79 | )
80 | load_class = Encoder
81 | logger.info(f'DensePhrases encoder initialized with {args.pretrained_name_or_path} ({pretrained.__class__})')
82 | else:
83 | # TODO: need to update transformers so that from_pretrained maps to model hub directly
84 | if args.load_dir.startswith('princeton-nlp'):
85 | hf_model_path = f"https://huggingface.co/{args.load_dir}/resolve/main/pytorch_model.bin"
86 | else:
87 | hf_model_path = args.load_dir
88 | load_class = partial(
89 | Encoder.from_pretrained,
90 | pretrained_model_name_or_path=hf_model_path,
91 | cache_dir=args.cache_dir if args.cache_dir else None,
92 | )
93 | logger.info(f'DensePhrases encoder loaded from {args.load_dir}')
94 |
95 | # DensePhrases encoder object
96 | model = load_class(
97 | config=config,
98 | tokenizer=tokenizer,
99 | transformer_cls=MODEL_MAPPING[config.__class__],
100 | pretrained=copy.deepcopy(pretrained) if pretrained is not None else None,
101 | lambda_kl=getattr(args, 'lambda_kl', 0.0),
102 | lambda_neg=getattr(args, 'lambda_neg', 0.0),
103 | lambda_flt=getattr(args, 'lambda_flt', 0.0),
104 | )
105 |
106 | # Phrase only (for phrase embedding)
107 | if phrase_only:
108 | if hasattr(model, "module"):
109 | del model.module.query_start_encoder
110 | del model.module.query_end_encoder
111 | else:
112 | del model.query_start_encoder
113 | del model.query_end_encoder
114 | logger.info("Load only phrase encoders for embedding phrases")
115 |
116 | model.to(device)
117 | logger.info('Number of model parameters: {:,}'.format(sum(p.numel() for p in model.parameters())))
118 | return model, tokenizer, config
119 |
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | while read -p "Choose a resource to download [data/wiki/models/index]: " choice; do
4 | case "$choice" in
5 | data )
6 | TARGET=$choice
7 | TARGET_DIR=$DATA_DIR
8 | break ;;
9 | wiki )
10 | TARGET=$choice
11 | TARGET_DIR=$DATA_DIR
12 | break ;;
13 | models )
14 | TARGET=$choice
15 | TARGET_DIR=$SAVE_DIR
16 | break ;;
17 | index )
18 | TARGET=$choice
19 | TARGET_DIR=$SAVE_DIR
20 | break ;;
21 | * ) echo "Please type among [data/wiki/models/index]";
22 | exit 0 ;;
23 | esac
24 | done
25 |
26 | echo "$TARGET will be downloaded at $TARGET_DIR"
27 |
28 | # Download + untar + rm
29 | case "$TARGET" in
30 | data )
31 | wget -O "$TARGET_DIR/densephrases-data.tar.gz" "https://nlp.cs.princeton.edu/projects/densephrases/densephrases-data.tar.gz"
32 | tar -xzvf "$TARGET_DIR/densephrases-data.tar.gz" -C "$TARGET_DIR" --strip 1
33 | rm "$TARGET_DIR/densephrases-data.tar.gz" ;;
34 | wiki )
35 | wget -O "$TARGET_DIR/wikidump.tar.gz" "https://nlp.cs.princeton.edu/projects/densephrases/wikidump.tar.gz"
36 | tar -xzvf "$TARGET_DIR/wikidump.tar.gz" -C "$TARGET_DIR"
37 | rm "$TARGET_DIR/wikidump.tar.gz" ;;
38 | models )
39 | wget -O "$TARGET_DIR/outputs.tar.gz" "https://nlp.cs.princeton.edu/projects/densephrases/outputs.tar.gz"
40 | tar -xzvf "$TARGET_DIR/outputs.tar.gz" -C "$TARGET_DIR" --strip 1
41 | rm "$TARGET_DIR/outputs.tar.gz" ;;
42 | index )
43 | wget -O "$TARGET_DIR/densephrases-multi_wiki-20181220.tar.gz" "https://nlp.cs.princeton.edu/projects/densephrases/densephrases-multi_wiki-20181220.tar.gz"
44 | tar -xzvf "$TARGET_DIR/densephrases-multi_wiki-20181220.tar.gz" -C "$TARGET_DIR"
45 | rm "$TARGET_DIR/densephrases-multi_wiki-20181220.tar.gz" ;;
46 | * ) echo "Wrong target $TARGET";
47 | exit 0 ;;
48 | esac
49 |
50 | echo "Downloading $TARGET done!"
51 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # DensePhrases Examples
2 |
3 | We describe how to use DensePhrases for different applications.
4 | For instance, you can use the passages retrieved by DensePhrases to train a state-of-the-art open-domain question answering model called [Fusion-in-Decoder](https://arxiv.org/abs/2007.01282) (Izacard and Grave, 2021), or you can run entity linking with DensePhrases.
5 |
6 | * [Basics: Multi-Granularity Text Retrieval](#basics-multi-granularity-text-retrieval)
7 | * [Create a Custom Phrase Index](https://github.com/princeton-nlp/DensePhrases/tree/main/examples/create-custom-index)
8 | * [Open-Domain QA with Fusion-in-Decoder](https://github.com/princeton-nlp/DensePhrases/tree/main/examples/fusion-in-decoder)
9 | * [Entity Linking](https://github.com/princeton-nlp/DensePhrases/tree/main/examples/entity-linking)
10 | * [Knowledge-grounded Dialogue](https://github.com/princeton-nlp/DensePhrases/tree/main/examples/knowledge-dialogue)
11 | * [Slot Filling](https://github.com/princeton-nlp/DensePhrases/tree/main/examples/slot-filling)
12 |
13 | ## Basics: Multi-Granularity Text Retrieval
14 | The most basic use of DensePhrases is to retrieve phrases, sentences, paragraphs, or documents for your query.
15 | ```python
16 | >>> from densephrases import DensePhrases
17 |
18 | # Load DensePhrases
19 | >>> model = DensePhrases(
20 | ... load_dir='princeton-nlp/densephrases-multi-query-multi',
21 | ... dump_dir='/path/to/densephrases-multi_wiki-20181220/dump'
22 | ... )
23 |
24 | # Search phrases
25 | >>> model.search('Who won the Nobel Prize in peace?', retrieval_unit='phrase', top_k=5)
26 | ['Denis Mukwege,', 'Theodore Roosevelt', 'Denis Mukwege', 'John Mott', 'Mother Teresa']
27 |
28 | # Search sentences
29 | >>> model.search('Why is the sky blue', retrieval_unit='sentence', top_k=1)
30 | ['The blue color is sometimes wrongly attributed to Rayleigh scattering, which is responsible for the color of the sky.']
31 |
32 | # Search paragraphs
33 | >>> model.search('How to become a great researcher', retrieval_unit='paragraph', top_k=1)
34 | ['... Levine said he believes the key to being a great researcher is having passion for research in and working on questions that the researcher is truly curious about. He said: "Have patience, persistence and enthusiasm and you’ll be fine."']
35 |
36 | # Search documents (Wikipedia titles)
37 | >>> model.search('What is the history of internet', retrieval_unit='document', top_k=3)
38 | ['Computer network', 'History of the World Wide Web', 'History of the Internet']
39 | ```
40 |
41 | For batch queries, simply feed a list of queries as ``query``.
42 | To get more detailed search results, set ``return_meta=True`` as follows:
43 | ```python
44 | # Search phrases and get detailed results
45 | >>> phrases, metadata = model.search(['Who won the Nobel Prize in peace?', 'Name products of Apple.'], retrieval_unit='phrase', return_meta=True)
46 |
47 | >>> phrases[0]
48 | ['Denis Mukwege,', 'Theodore Roosevelt', 'Denis Mukwege', 'John Mott', 'Muhammad Yunus', ...]
49 |
50 | >>> metadata[0]
51 | [{'context': '... The most recent as of 2018, Denis Mukwege, was awarded his Peace Prize in 2018. ...', 'title': ['List of black Nobel laureates'], 'doc_idx': 5433697, 'start_pos': 558, 'end_pos': 572, 'start_idx': 15, 'end_idx': 16, 'score': 99.670166015625, ..., 'answer': 'Denis Mukwege,'}, ...]
52 | ```
53 | Note that when the model returns phrases, it also returns passages in its metadata as described in our [EMNLP paper](https://arxiv.org/abs/2109.08133).
54 |
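For instance, using the metadata fields shown above (`answer`, `title`, `score`, `context`), you can pair each retrieved phrase with its supporting passage; a minimal sketch:
```python
# Print each phrase with its source title, score, and a snippet of the supporting passage
# (assumes `metadata` was obtained with return_meta=True as in the example above)
>>> for entry in metadata[0][:3]:
...     print(entry['answer'], '|', entry['title'][0], '|', round(entry['score'], 2))
...     print(' ', entry['context'][:80] + '...')
```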
55 | ### CPU-only Mode
56 | ```python
57 | # Load DensePhrases in CPU-only mode
58 | >>> model = DensePhrases(
59 | ... load_dir='princeton-nlp/densephrases-multi-query-multi',
60 | ... dump_dir='/path/to/densephrases-multi_wiki-20181220/dump',
61 | ... device='cpu',
62 | ... max_query_length=24, # reduce the maximum query length for a faster query encoding (optional)
63 | ... )
64 | ```
65 |
66 | ### Changing the Index or the Encoder
67 | ```python
68 | # Load DensePhrases with a smaller phrase index
69 | >>> model = DensePhrases(
70 | ... load_dir='princeton-nlp/densephrases-multi-query-multi',
71 | ... dump_dir='/path/to/densephrases-multi_wiki-20181220/dump',
72 | ... index_name='start/1048576_flat_OPQ96_small'
73 | ... )
74 |
75 | # Change the DensePhrases encoder to 'princeton-nlp/densephrases-multi-query-tqa' (trained on TriviaQA)
76 | >>> model.set_encoder('princeton-nlp/densephrases-multi-query-tqa')
77 | ```
78 |
79 | ### Evaluation
80 | ```python
81 | >>> import os
82 |
83 | # Evaluate loaded DensePhrases on Natural Questions
84 | >>> model.evaluate(test_path=os.path.join(os.environ['DATA_DIR'], 'open-qa/nq-open/test_preprocessed.json'))
85 | ```
86 |
--------------------------------------------------------------------------------
/examples/create-custom-index/README.md:
--------------------------------------------------------------------------------
1 | # Creating a Custom Phrase Index with DensePhrases
2 |
3 | Basically, DensePhrases uses a text corpus pre-processed in the following format (a snippet from [articles.json](https://github.com/princeton-nlp/DensePhrases/blob/main/examples/create-custom-index/articles.json)):
4 | ```
5 | {
6 | "data": [
7 | {
8 | "title": "America's Got Talent (season 4)",
9 | "paragraphs": [
10 | {
11 | "context": " The fourth season of \"America's Got Talent\", ... Country singer Kevin Skinner was named the winner on September 16, 2009 ..."
12 | },
13 | {
14 | "context": " Season four was Hasselhoff's final season as a judge. This season started broadcasting live on August 4, 2009. ..."
15 | },
16 | ...
17 | ]
18 | },
19 | ]
20 | }
21 | ```
22 |
23 | ## Building a Phrase Index
24 | Each `context` contains a single natural paragraph of variable length. The following command creates phrase vectors for the custom corpus (`articles.json`) with the `densephrases-multi` model.
25 |
26 | ```bash
27 | python generate_phrase_vecs.py \
28 | --model_type bert \
29 | --pretrained_name_or_path SpanBERT/spanbert-base-cased \
30 | --data_dir ./ \
31 | --cache_dir $CACHE_DIR \
32 | --predict_file examples/create-custom-index/articles.json \
33 | --do_dump \
34 | --max_seq_length 512 \
35 | --doc_stride 500 \
36 | --fp16 \
37 | --filter_threshold -2.0 \
38 | --append_title \
39 | --load_dir $SAVE_DIR/densephrases-multi \
40 | --output_dir $SAVE_DIR/densephrases-multi_sample
41 | ```
42 | The phrase vectors (and their metadata) will be saved under `$SAVE_DIR/densephrases-multi_sample/dump/phrase`. Now you need to create a faiss index as follows:
43 | ```bash
44 | python build_phrase_index.py \
45 | --dump_dir $SAVE_DIR/densephrases-multi_sample/dump \
46 | --stage all \
47 | --replace \
48 | --num_clusters 32 \
49 | --fine_quant OPQ96 \
50 | --doc_sample_ratio 1.0 \
51 | --vec_sample_ratio 1.0 \
52 | --cuda
53 |
54 | # Compress metadata for faster inference
55 | python scripts/preprocess/compress_metadata.py \
56 | --input_dump_dir $SAVE_DIR/densephrases-multi_sample/dump/phrase \
57 | --output_dir $SAVE_DIR/densephrases-multi_sample/dump
58 | ```
59 | Note that this example uses a very small text corpus; the hyperparameters for `build_phrase_index.py` on a larger-scale corpus can be found [here](https://github.com/princeton-nlp/DensePhrases/tree/main#densephrases-training-indexing-and-inference).
60 | Depending on the size of the corpus, the hyperparameters should change as follows (see the sketch after this list):
61 | * `num_clusters`: Set so that the number of vectors per cluster is < 2000 (e.g., `--num_clusters 256` works well for `dev_wiki.json`).
62 | * `doc/vec_sample_ratio`: Use the default value (0.2) except for small-scale experiments like the one shown above.
63 | * `fine_quant`: Currently only OPQ96 is supported.
64 |
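As a concrete sketch, a hypothetical `build_phrase_index.py` invocation for a mid-sized corpus (e.g., `dev_wiki.json`) could look like the following, where the dump directory is a placeholder for wherever your phrase vectors were saved:
```bash
# Sketch for a mid-sized corpus: more clusters, default sampling ratios
# ($SAVE_DIR/densephrases-multi_custom/dump is a hypothetical dump directory)
python build_phrase_index.py \
    --dump_dir $SAVE_DIR/densephrases-multi_custom/dump \
    --stage all \
    --replace \
    --num_clusters 256 \
    --fine_quant OPQ96 \
    --doc_sample_ratio 0.2 \
    --vec_sample_ratio 0.2 \
    --cuda
```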
65 | The phrase index (with IVFOPQ) will be saved under `$SAVE_DIR/densephrases-multi_sample/dump/start`.
66 | For creating a large-scale phrase index (e.g., Wikipedia), see [dump_phrases.py](https://github.com/princeton-nlp/DensePhrases/blob/main/scripts/parallel/dump_phrases.py) for an example, which is also explained [here](https://github.com/princeton-nlp/DensePhrases/tree/main#2-creating-a-phrase-index).
67 |
68 | ## Testing a Phrase Index
69 | You can use this phrase index to run a [demo](https://github.com/princeton-nlp/DensePhrases/tree/main#playing-with-a-densephrases-demo) or evaluate your set of queries.
70 | For instance, you can feed a set of questions (`questions.json`) to the custom phrase index as follows:
71 | ```bash
72 | python eval_phrase_retrieval.py \
73 | --run_mode eval \
74 | --cuda \
75 | --dump_dir $SAVE_DIR/densephrases-multi_sample/dump \
76 | --index_name start/32_flat_OPQ96 \
77 | --load_dir $SAVE_DIR/densephrases-multi \
78 | --test_path examples/create-custom-index/questions.json \
79 | --save_pred \
80 | --truecase
81 | ```
82 | The prediction file will be saved as `$SAVE_DIR/densephrases-multi/pred/questions_3_top10.pred`, which shows the answer phrases and the passages that contain the phrases:
83 | ```
84 | {
85 | "1": {
86 | "question": "Who won season 4 of America's got talent",
87 | ...
88 | "prediction": [
89 | "Kevin Skinner",
90 | ...
91 | ],
92 | "evidence": [
93 | "The fourth season of \"America's Got Talent\", an American television reality show talent competition, premiered on the NBC network on June 23, 2009. Country singer Kevin Skinner was named the winner on September 16, 2009.",
94 | ...
95 | ],
96 | }
97 | ...
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
/examples/create-custom-index/questions.json:
--------------------------------------------------------------------------------
1 | {
2 | "data": [
3 | {
4 | "id": "1",
5 | "question": "who won season 4 of america's got talent",
6 | "answers": ["Kevin Skinner", "Country singer Kevin Skinner"]
7 | },
8 | {
9 | "id": "2",
10 | "question": "how many goals scored ronaldo in 2014-2015 season",
11 | "answers": ["61"]
12 | },
13 | {
14 | "id": "3",
15 | "question": "who plays william boldwood in far from the madding crowd",
16 | "answers": ["Michael Sheen"]
17 | }
18 | ]
19 | }
20 |
--------------------------------------------------------------------------------
/examples/entity-linking/README.md:
--------------------------------------------------------------------------------
1 | # Entity Linking
2 |
3 | ## Pre-trained Models
4 | | Model | Query-FT. & Eval | R-Precision| Description |
5 | |:-------------------------------|:--------:|:--------:|:--------:|
6 | | [densephrases-multi-query-ay2](https://huggingface.co/princeton-nlp/densephrases-multi-query-ay2) | AIDA CoNLL-YAGO (AY2) | 61.6 | Result from [eval.ai](https://eval.ai/web/challenges/challenge-page/689/overview) |
7 | | [densephrases-multi-query-kilt-multi](https://huggingface.co/princeton-nlp/densephrases-multi-query-kilt-multi) | Multiple / AY2 | 68.4 | Trained on multiple KILT tasks |
8 |
9 | ## How to Use
10 | ```python
11 | >>> from densephrases import DensePhrases
12 |
13 | # Load densephrases-multi-query-ay2
14 | >>> model = DensePhrases(
15 | ... load_dir='princeton-nlp/densephrases-multi-query-ay2',
16 | ... dump_dir='/path/to/densephrases-multi_wiki-20181220/dump',
17 | ... )
18 |
19 | # Entities need to be surrounded by [START_ENT] and [END_ENT] tags
20 | >>> model.search('West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat [START_ENT] Somerset [END_ENT] by an innings and 39 runs', retrieval_unit='document', top_k=1)
21 | ['Somerset County Cricket Club']
22 |
23 | >>> model.search('[START_ENT] Security Council [END_ENT] members expressed concern on Thursday', retrieval_unit='document', top_k=1)
24 | ['United Nations Security Council']
25 | ```
26 |
27 | ### Evaluation
28 | ```python
29 | >>> import os
30 |
31 | # Evaluate loaded DensePhrases on AIDA CoNLL-YAGO (KILT)
32 | >>> model.evaluate(
33 | ... test_path=os.path.join(os.environ['DATA_DIR'], 'kilt/ay2/aidayago2-dev-kilt_open.json'),
34 | ... is_kilt=True, title2wikiid_path=os.path.join(os.environ['DATA_DIR'], 'wikidump/title2wikiid.json'),
35 | ... kilt_gold_path=os.path.join(os.environ['DATA_DIR'], 'kilt/ay2/aidayago2-dev-kilt.jsonl'), agg_strat='opt2', max_query_length=384
36 | ... )
37 | ```
38 |
39 | For test accuracy, use `aidayago2-test-kilt_open.json` instead and submit the prediction file (saved as `$SAVE_DIR/densephrases-multi-query-ay2/pred-kilt/*.jsonl`) to [eval.ai](https://eval.ai/web/challenges/challenge-page/689/overview).
40 | For WNED-WIKI and WNED-CWEB, follow the same process with files specified in the `wned-kilt-data` and `cweb-kilt-data` targets in [Makefile](https://github.com/princeton-nlp/DensePhrases/blob/main/Makefile).
41 | You can also evaluate the model with the Makefile `eval-index` target by simply changing the dependency.
42 |
--------------------------------------------------------------------------------
/examples/fusion-in-decoder/README.md:
--------------------------------------------------------------------------------
1 | # Fusion-in-Decoder with DensePhrases
2 | You can use retrieved passages from DensePhrases to build a state-of-the-art open-domain QA system called [Fusion-in-Decoder](https://arxiv.org/abs/2007.01282) (FiD).
3 | Note that DensePhrases (w/o reader) already provides phrase-level answers for end-to-end open-domain QA, with performance comparable to DPR (w/ BERT reader). This section describes how you can further improve the performance using a generative reader model (T5).
4 |
5 | ## Getting Top Passages from DensePhrases
6 | First, you need to get passages from DensePhrases.
7 | Using DensePhrases-multi, you can retrieve passages for Natural Questions as follows:
8 | ```
9 | TRAIN_DATA=open-qa/nq-open/train_preprocessed.json
10 | DEV_DATA=open-qa/nq-open/dev_preprocessed.json
11 | TEST_DATA=open-qa/nq-open/test_preprocessed.json
12 |
13 | # Change --test_path accordingly
14 | python eval_phrase_retrieval.py \
15 | --run_mode eval \
16 | --model_type bert \
17 | --pretrained_name_or_path SpanBERT/spanbert-base-cased \
18 | --cuda \
19 | --dump_dir $SAVE_DIR/densephrases-multi_wiki-20181220/dump/ \
20 | --index_name start/1048576_flat_OPQ96 \
21 | --load_dir $SAVE_DIR/densephrases-multi-query-nq \
22 | --test_path $DATA_DIR/$TEST_DATA \
23 | --save_pred \
24 | --aggregate \
25 | --agg_strat opt2 \
26 | --top_k 200 \
27 | --eval_psg \
28 | --psg_top_k 100 \
29 | --truecase
30 | ```
31 | Since FiD requires training passages, you need to change `--test_path` to `$TRAIN_DATA` or `$DEV_DATA` to get training or development passages, respectively.
32 | Equivalently, you can use `eval-index-psg` in our [Makefile](https://github.com/princeton-nlp/DensePhrases/blob/main/Makefile).
33 | For TriviaQA, simply change the dataset to `tqa-open-data` specified in Makefile.
34 |
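For reference, a minimal sketch that retrieves passages for all three splits in one loop (same flags as above; adjust the load/dump directories to your setup):
```bash
# Hypothetical loop over the NQ splits; each run writes its own prediction file
for SPLIT_DATA in $TRAIN_DATA $DEV_DATA $TEST_DATA; do
    python eval_phrase_retrieval.py \
        --run_mode eval \
        --model_type bert \
        --pretrained_name_or_path SpanBERT/spanbert-base-cased \
        --cuda \
        --dump_dir $SAVE_DIR/densephrases-multi_wiki-20181220/dump/ \
        --index_name start/1048576_flat_OPQ96 \
        --load_dir $SAVE_DIR/densephrases-multi-query-nq \
        --test_path $DATA_DIR/$SPLIT_DATA \
        --save_pred \
        --aggregate \
        --agg_strat opt2 \
        --top_k 200 \
        --eval_psg \
        --psg_top_k 100 \
        --truecase
done
```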
35 | After the inference, you will be able to get the following three files used for training and evaluating FiD models:
36 | * train_preprocessed_79168_top200_psg-top100.json
37 | * dev_preprocessed_8757_top200_psg-top100.json
38 | * test_preprocessed_3610_top200_psg-top100.json
39 |
40 | We will assume that these files are saved under `$SAVE_DIR/fid-data`.
41 | Note that each retrieved passage in DensePhrases is a natural paragraph mostly in different lengths. For the exact replication of the experiments in our EMNLP paper, you need a phrase index created from Wikipedia pre-processed for DPR (100-word passages), which we plan to provide soonish.
42 |
43 | ## Installing Fusion-in-Decoder
44 | For Fusion-in-Decoder, we use [the official code](https://github.com/facebookresearch/FiD) provided by the authors.
45 | It is often better to use a separate conda environment to train FiD.
46 | See [here](https://github.com/facebookresearch/FiD#dependencies) for dependencies.
47 |
48 | ```bash
49 | # Install torch with conda (please check your CUDA version)
50 | conda create -n fid python=3.7
51 | conda activate fid
52 | conda install pytorch=1.9.0 cudatoolkit=11.0 -c pytorch
53 |
54 | # Install Fusion-in-Decoder
55 | git clone https://github.com/facebookresearch/FiD.git
56 | cd FiD
57 | pip install -r requirements.txt
58 | ```
59 |
60 | ## Training and Evaluation
61 | ```bash
62 | TRAIN_DATA=fid-data/train_preprocessed_79168_top200_psg-top100.json
63 | DEV_DATA=fid-data/dev_preprocessed_8757_top200_psg-top100.json
64 | TEST_DATA=fid-data/test_preprocessed_3610_top200_psg-top100.json
65 |
66 | # Train T5-base with top 5 passages (DDP using 4 GPUs)
67 | nohup python /path/to/miniconda3/envs/fid/lib/python3.7/site-packages/torch/distributed/launch.py \
68 | --nnode=1 --node_rank=0 --nproc_per_node=4 train_reader.py \
69 | --train_data $SAVE_DIR/$TRAIN_DATA \
70 | --eval_data $SAVE_DIR/$DEV_DATA \
71 | --model_size base \
72 | --per_gpu_batch_size 1 \
73 | --accumulation_steps 16 \
74 | --total_steps 160000 \
75 | --eval_freq 8000 \
76 | --save_freq 8000 \
77 | --n_context 5 \
78 | --lr 0.00005 \
79 | --text_maxlength 300 \
80 | --name nq_reader_base-dph-c5-d4 \
81 | --checkpoint_dir $SAVE_DIR/fid-data/pretrained_models > nq_reader_base-dph-c5-d4_out.log &
82 |
83 | # Test T5-base with top 5 passages (DDP using 4 GPUs)
84 | python /path/to/miniconda3/envs/fid/lib/python3.7/site-packages/torch/distributed/launch.py \
85 | --nnode=1 --node_rank=0 --nproc_per_node=4 test_reader.py \
86 | --model_path $SAVE_DIR/fid-data/pretrained_models/nq_reader_base-dph-c5-d4/checkpoint/best_dev \
87 | --eval_data $SAVE_DIR/$TEST_DATA \
88 | --per_gpu_batch_size 1 \
89 | --n_context 5 \
90 | --write_results \
91 | --name nq_reader_base-dph-c5-d4 \
92 | --checkpoint_dir $SAVE_DIR/fid-data/pretrained_models \
93 | --text_maxlength 300
94 | ```
95 | Note that most hyperparameters follow the original work; the only differences are the use of `--accumulation_steps 16` and the corresponding adjustment of the training, save, and evaluation steps. A larger `--text_maxlength` is used to cover natural paragraphs, which are often longer than 100 words.
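For reference, the effective batch size implied by the training command above is simple to work out: 4 GPUs x `--per_gpu_batch_size 1` x `--accumulation_steps 16` gives 64 examples per optimizer step.

```python
# Effective batch size implied by the training command above.
n_gpus = 4                # --nproc_per_node
per_gpu_batch_size = 1    # --per_gpu_batch_size
accumulation_steps = 16   # --accumulation_steps
print(n_gpus * per_gpu_batch_size * accumulation_steps)  # 64
```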
96 |
--------------------------------------------------------------------------------
/examples/knowledge-dialogue/README.md:
--------------------------------------------------------------------------------
1 | # Knowledge-Grounded Dialogue
2 |
3 | ## Pre-trained Models
4 | | Model | Query-FT. & Eval | R-Precision| Description |
5 | |:-------------------------------|:--------:|:--------:|:--------:|
6 | | [densephrases-multi-query-wow](https://huggingface.co/princeton-nlp/densephrases-multi-query-wow) | Wizard of Wikipedia (WoW) | 47.0 | Result from [eval.ai](https://eval.ai/web/challenges/challenge-page/689/overview) |
7 | | [densephrases-multi-query-kilt-multi](https://huggingface.co/princeton-nlp/densephrases-multi-query-kilt-multi) | Multiple / WoW | 55.7 | Trained on multiple KILT tasks |
8 |
9 | ## How to Use
10 | ```python
11 | >>> from densephrases import DensePhrases
12 |
13 | # Load densephrases-multi-query-wow
14 | >>> model = DensePhrases(
15 | ... load_dir='princeton-nlp/densephrases-multi-query-wow',
16 | ... dump_dir='/path/to/densephrases-multi_wiki-20181220/dump',
17 | ... )
18 |
19 | # Feed a dialogue as a query
20 | >>> model.search('I love rap music.', retrieval_unit='document', top_k=10)
21 | ['Rapping', 'Hip hop', 'Rap metal', 'Hip hop music', 'Rapso', 'Battle rap', 'Rape', 'Eurodance', 'Chopper (rap)', 'Rape culture']
22 |
23 | >>> model.search('Have you heard of Yamaha? They started as a piano manufacturer in 1887!', retrieval_unit='document', top_k=5)
24 | ['Yamaha Corporation', 'Yamaha Drums', 'Tōkai Gakki', 'Suzuki Musical Instrument Corporation', 'Supermoto']
25 |
26 | # You can get more metadata on the document by setting return_meta=True
27 | >>> doc, meta = model.search('I love rap music.', retrieval_unit='document', top_k=1, return_meta=True)
28 | >>> meta
29 | [{'context': 'Rap is usually delivered over a beat, typically provided by a DJ, turntablist, ...', 'title': ['Rapping'], 'doc_idx': 4096192, 'start_pos': 647, 'end_pos': 660, 'start_idx': 91, 'end_idx': 93, 'score': 53.58412170410156, ... 'answer': 'hip-hop music'}]
30 | ```
31 |
32 | ### Evaluation
33 | ```python
34 | >>> import os
35 |
36 | # Evaluate loaded DensePhrases on Wizard of Wikipedia
37 | >>> model.evaluate(
38 | ... test_path=os.path.join(os.environ['DATA_DIR'], 'kilt/wow/wow-dev-kilt_open.json'),
39 | ... is_kilt=True, title2wikiid_path=os.path.join(os.environ['DATA_DIR'], 'wikidump/title2wikiid.json'),
40 | ... kilt_gold_path=os.path.join(os.environ['DATA_DIR'], 'kilt/wow/wow-dev-kilt.jsonl'), agg_strat='opt2', max_query_length=384
41 | ... )
42 | ```
43 |
44 | For test accuracy, use `wow-test-kilt_open.json` instead and submit the prediction file (saved as `$SAVE_DIR/densephrases-multi-query-wow/pred-kilt/*.jsonl`) to [eval.ai](https://eval.ai/web/challenges/challenge-page/689/overview).
45 | You can also evaluate the model with the Makefile `eval-index` target by simply changing the dependency.
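As a rough sketch, test-split predictions can be generated with the same `evaluate` call as above; only `wow-test-kilt_open.json` and the output location are stated in this README, so treat the gold-file name below as an assumption and check the corresponding Makefile target for the exact paths.

```python
# A sketch for the test split; kilt/wow/wow-test-kilt.jsonl is an assumed file name --
# take the exact paths from the corresponding Makefile target.
import os

model.evaluate(
    test_path=os.path.join(os.environ['DATA_DIR'], 'kilt/wow/wow-test-kilt_open.json'),
    is_kilt=True, title2wikiid_path=os.path.join(os.environ['DATA_DIR'], 'wikidump/title2wikiid.json'),
    kilt_gold_path=os.path.join(os.environ['DATA_DIR'], 'kilt/wow/wow-test-kilt.jsonl'),  # assumed name
    agg_strat='opt2', max_query_length=384,
)
# The KILT-format prediction file is then written under
# $SAVE_DIR/densephrases-multi-query-wow/pred-kilt/ for submission to eval.ai.
```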
46 |
--------------------------------------------------------------------------------
/examples/slot-filling/README.md:
--------------------------------------------------------------------------------
1 | # Slot Filling
2 |
3 | ## Pre-trained Models
4 | | Model | Query-FT. & Eval | KILT-Accuracy | Description |
5 | |:-------------------------------|:--------:|:--------:|:--------:|
6 | | [densephrases-multi-query-trex](https://nlp.cs.princeton.edu/projects/densephrases/models/densephrases-multi-query-trex.tar.gz) | T-REx | 22.3 | Result from [eval.ai](https://eval.ai/web/challenges/challenge-page/689/overview) |
7 | | [densephrases-multi-query-zsre](https://nlp.cs.princeton.edu/projects/densephrases/models/densephrases-multi-query-zsre.tar.gz) | Zero-shot RE | 40.0 | |
8 |
9 | ## How to Use
10 | ```python
11 | >>> from densephrases import DensePhrases
12 |
13 | # Load densephrases-multi-query-trex locally
14 | >>> model = DensePhrases(
15 | ... load_dir='/path/to/densephrases-multi-query-trex',
16 | ... dump_dir='/path/to/densephrases-multi_wiki-20181220/dump',
17 | ... )
18 |
19 | # Slot filling queries are in the format of 'Subject [SEP] Relation'
20 | >>> model.search('Superman [SEP] father', retrieval_unit='phrase', top_k=5)
21 | ['Jor-El', 'Clark Kent', 'Jor-El', 'Jor-El', 'Jor-El']
22 |
23 | >>> model.search('Cirith Ungol [SEP] genre', retrieval_unit='phrase', top_k=5)
24 | ['heavy metal', 'doom metal', 'metal', 'Elvish', 'madrigal comedy']
25 | ```
26 |
27 | ### Evaluation
28 | ```python
29 | >>> import os
30 |
31 | # Evaluate loaded DensePhrases on T-REx (KILT)
32 | >>> model.evaluate(
33 | ... test_path=os.path.join(os.environ['DATA_DIR'], 'kilt/trex/trex-dev-kilt_open.json'),
34 | ... is_kilt=True, title2wikiid_path=os.path.join(os.environ['DATA_DIR'], 'wikidump/title2wikiid.json'),
35 | ... kilt_gold_path=os.path.join(os.environ['DATA_DIR'], 'kilt/trex/trex-dev-kilt.jsonl'), agg_strat='opt2',
36 | ... )
37 | ```
38 |
39 | For test accuracy, use `trex-test-kilt_open.json` instead and submit the prediction file (saved as `$SAVE_DIR/densephrases-multi-query-trex/pred-kilt/densephrases-multi-query-trex_trex-test-kilt_open_5000.jsonl`) to [eval.ai](https://eval.ai/web/challenges/challenge-page/689/overview).
40 | For zero-shot relation extraction, follow the same process with files specified in the `zsre-kilt-data` target in [Makefile](https://github.com/princeton-nlp/DensePhrases/blob/main/Makefile).
41 | You can also evaluate the model with the Makefile `eval-index` target by simply changing the dependency to `trex-kilt-data` or `zsre-kilt-data`.
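As a concrete example, a zero-shot RE evaluation might look like the sketch below, assuming `model` was loaded with `load_dir='/path/to/densephrases-multi-query-zsre'`; the `kilt/zsre/...` file names are assumptions here, so take the exact paths from the `zsre-kilt-data` target in the Makefile.

```python
# A sketch for zsRE (KILT); the file names under kilt/zsre/ are assumptions --
# check the zsre-kilt-data target in the Makefile for the exact paths.
import os

model.evaluate(
    test_path=os.path.join(os.environ['DATA_DIR'], 'kilt/zsre/structured_zeroshot-dev-kilt_open.json'),
    is_kilt=True, title2wikiid_path=os.path.join(os.environ['DATA_DIR'], 'wikidump/title2wikiid.json'),
    kilt_gold_path=os.path.join(os.environ['DATA_DIR'], 'kilt/zsre/structured_zeroshot-dev-kilt.jsonl'),
    agg_strat='opt2',
)
```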
42 |
--------------------------------------------------------------------------------
/generate_phrase_vecs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
17 |
18 |
19 | import argparse
20 | import logging
21 | import os
22 | import timeit
23 | import copy
24 | import h5py
25 | import torch
26 |
27 | from tqdm import tqdm, trange
28 | from torch.utils.data import DataLoader, SequentialSampler
29 | from torch.utils.data.distributed import DistributedSampler
30 |
31 | from transformers import (
32 | MODEL_MAPPING,
33 | AutoConfig,
34 | AutoModel,
35 | AutoTokenizer,
36 | )
37 | from densephrases.utils.squad_utils import ContextResult, load_and_cache_examples
38 | from densephrases.utils.single_utils import set_seed, to_list, to_numpy, backward_compat, load_encoder
39 | from densephrases.utils.embed_utils import write_phrases, write_filter
40 | from densephrases import Options
41 |
42 | logger = logging.getLogger(__name__)
43 |
44 |
45 | def dump_phrases(args, model, tokenizer, filter_only=False):
46 | output_path = 'dump/phrase' if not filter_only else 'dump/filter'
47 | if not os.path.exists(os.path.join(args.output_dir, output_path)):
48 | os.makedirs(os.path.join(args.output_dir, output_path))
49 |
50 | start_time = timeit.default_timer()
51 | if ':' not in args.predict_file:
52 | predict_files = [args.predict_file]
53 | offsets = [0]
54 | output_dump_file = os.path.join(
55 | args.output_dir, f"{output_path}/{os.path.splitext(os.path.basename(args.predict_file))[0]}.hdf5"
56 | )
57 | else:
58 | dirname = os.path.dirname(args.predict_file)
59 | basename = os.path.basename(args.predict_file)
60 | start, end = list(map(int, basename.split(':')))
61 | output_dump_file = os.path.join(
62 | args.output_dir, f"{output_path}/{start}-{end}.hdf5"
63 | )
64 |
65 | # skip files if possible
66 | if os.path.exists(output_dump_file):
67 | with h5py.File(output_dump_file, 'r') as f:
68 | dids = list(map(int, f.keys()))
69 | start = int(max(dids) / 1000)
70 | logger.info('%s exists; starting from %d' % (output_dump_file, start))
71 |
72 | names = [str(i).zfill(4) for i in range(start, end)]
73 | predict_files = [os.path.join(dirname, name) for name in names]
74 | offsets = [int(each) * 1000 for each in names]
75 |
76 | for offset, predict_file in zip(offsets, predict_files):
77 | args.predict_file = predict_file
78 | logger.info(f"***** Pre-processing contexts from {args.predict_file} *****")
79 | dataset, examples, features = load_and_cache_examples(
80 | args, tokenizer, evaluate=True, output_examples=True, context_only=True
81 | )
82 | for example in examples:
83 | example.doc_idx += offset
84 |
85 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
86 |
87 | # Note that DistributedSampler samples randomly
88 | eval_sampler = SequentialSampler(dataset)
89 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
90 |
91 | logger.info(f"***** Dumping Phrases from {args.predict_file} *****")
92 | logger.info(" Num examples = %d", len(dataset))
93 | logger.info(" Batch size = %d", args.eval_batch_size)
94 | start_time = timeit.default_timer()
95 |
96 | def get_phrase_results():
97 | for batch in tqdm(eval_dataloader, desc="Dumping"):
98 | model.eval()
99 | batch = tuple(t.to(args.device) for t in batch)
100 |
101 | with torch.no_grad():
102 | inputs = {
103 | "input_ids": batch[0],
104 | "attention_mask": batch[1],
105 | "token_type_ids": batch[2],
106 | "return_phrase": True,
107 | }
108 | feature_indices = batch[3]
109 | outputs = model(**inputs)
110 |
111 | for i, feature_index in enumerate(feature_indices):
112 | # TODO: i and feature_index are the same number! Simplify by removing enumerate?
113 | eval_feature = features[feature_index.item()]
114 | unique_id = int(eval_feature.unique_id)
115 |
116 | output = [
117 | to_numpy(output[i]) if type(output) != dict else {k: to_numpy(v[i]) for k, v in output.items()}
118 | for output in outputs
119 | ]
120 |
121 | if len(output) != 4:
122 | raise NotImplementedError
123 | else:
124 | start_vecs, end_vecs, sft_logits, eft_logits = output
125 | result = ContextResult(
126 | unique_id,
127 | start_vecs=start_vecs,
128 | end_vecs=end_vecs,
129 | sft_logits=sft_logits,
130 | eft_logits=eft_logits,
131 | )
132 | yield result
133 |
134 | if not filter_only:
135 | write_phrases(
136 | examples, features, get_phrase_results(), args.max_answer_length, args.do_lower_case, tokenizer,
137 | output_dump_file, args.filter_threshold, args.verbose_logging,
138 | args.dense_offset, args.dense_scale, has_title=args.append_title,
139 | )
140 | else:
141 | write_filter(
142 | examples, features, get_phrase_results(), tokenizer,
143 | output_dump_file, args.filter_threshold, args.verbose_logging, has_title=args.append_title,
144 | )
145 |
146 | evalTime = timeit.default_timer() - start_time
147 | logger.info("Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))
148 |
149 |
150 | def main():
151 | # See options in densephrases.options
152 | options = Options()
153 | options.add_model_options()
154 | options.add_data_options()
155 | options.add_rc_options()
156 | args = options.parse()
157 |
158 | # Setup CUDA, GPU & distributed training
159 | if args.local_rank == -1 or args.no_cuda:
160 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
161 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
162 |     else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
163 | torch.cuda.set_device(args.local_rank)
164 | device = torch.device("cuda", args.local_rank)
165 | torch.distributed.init_process_group(backend="nccl")
166 | args.n_gpu = 1
167 | args.device = device
168 |
169 | # Setup logging
170 | logging.basicConfig(
171 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
172 | datefmt="%m/%d/%Y %H:%M:%S",
173 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
174 | )
175 | logger.warning(
176 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
177 | args.local_rank,
178 | device,
179 | args.n_gpu,
180 | bool(args.local_rank != -1),
181 | args.fp16,
182 | )
183 |
184 | # Set seed
185 | set_seed(args)
186 |
187 | # Load config, tokenizer
188 | if args.local_rank not in [-1, 0]:
189 | # Make sure only the first process in distributed training will download model & vocab
190 | torch.distributed.barrier()
191 |
192 | args.model_type = args.model_type.lower()
193 | config, unused_kwargs = AutoConfig.from_pretrained(
194 | args.config_name if args.config_name else args.pretrained_name_or_path,
195 | cache_dir=args.cache_dir if args.cache_dir else None,
196 | output_hidden_states=False,
197 | return_unused_kwargs=True
198 | )
199 | tokenizer = AutoTokenizer.from_pretrained(
200 | args.tokenizer_name if args.tokenizer_name else args.pretrained_name_or_path,
201 | do_lower_case=args.do_lower_case,
202 | cache_dir=args.cache_dir if args.cache_dir else None,
203 | )
204 |
205 | if args.local_rank == 0:
206 | # Make sure only the first process in distributed training will download model & vocab
207 | torch.distributed.barrier()
208 |
209 | logger.info("Dump parameters %s", args)
210 |
211 | # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set.
212 | # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"`
213 | # will remove the need for this code, but it is still valid.
214 | if args.fp16:
215 | try:
216 | import apex
217 | apex.amp.register_half_function(torch, "einsum")
218 | except ImportError:
219 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
220 |
221 | # Create phrase vectors
222 | if args.do_dump:
223 | assert args.load_dir
224 | model, tokenizer, config = load_encoder(device, args, phrase_only=True)
225 |
226 | args.draft = False
227 | dump_phrases(args, model, tokenizer, filter_only=args.filter_only)
228 |
229 |
230 | if __name__ == "__main__":
231 | main()
232 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.9.0
2 | faiss-gpu==1.6.5
3 | transformers==2.9.0
4 | spacy==2.3.2
5 | h5py
6 | tqdm
7 | blosc
8 | ujson
9 | rouge
10 | wandb
11 | nltk
12 | flask
13 | flask_cors
14 | tornado
15 | requests-futures
16 |
--------------------------------------------------------------------------------
/scripts/benchmark/benchmark_hdf5.py:
--------------------------------------------------------------------------------
1 | import h5py
2 |
3 | from tqdm import tqdm
4 |
5 |
6 | paths = [
7 | 'dumps/sbcd_sqdqgnqqg_inb64_s384_sqdnq_pinb2_0_20181220_concat/dump/phrase/0-200.hdf5',
8 | 'dumps/sbcd_sqdqgnqqg_inb64_s384_sqdnq_pinb2_0_20181220_concat/dump/phrase/200-400.hdf5'
9 | ]
10 | phrase_dumps = [h5py.File(path, 'r') for path in paths]
11 |
12 |
13 | # Just testing how fast it is to read hdf5 files from disk
14 | for phrase_dump in phrase_dumps:
15 | for doc_id, doc_val in tqdm(phrase_dump.items()):
16 | kk = doc_val['start'][-10:]
17 |
--------------------------------------------------------------------------------
/scripts/benchmark/create_benchmark_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pdb
3 |
4 | data_for_denspi = []
5 | data_for_dpr = []
6 |
7 | with open('benchmark/nq_1000_dev_orqa.jsonl', encoding='utf-8') as f:
8 | idx = 0
9 | while True:
10 | line = f.readline()
11 | if line == "":
12 | break
13 |
14 | sample = json.loads(line)
15 |
16 | data_for_denspi.append({
17 | 'id':f'dev_{idx}',
18 | 'question': sample['question'],
19 | 'answers': sample['answer']
20 | })
21 | data_for_dpr.append("\t".join([sample['question'], str(sample['answer'])]))
22 |
23 | idx += 1
24 |
25 | # save data_for_dpr as csv
26 | with open('benchmark/nq_1000_dev_dpr.csv', 'w', encoding='utf-8') as f:
27 | for line in data_for_dpr:
28 |         f.write(line)
29 |         f.write("\n")
30 |
31 | # save data_for_denspi as json
32 | with open('benchmark/nq_1000_dev_denspi.json', 'w', encoding='utf-8') as f:
33 | json.dump({'data': data_for_denspi}, f)
34 |
--------------------------------------------------------------------------------
/scripts/dump/check_dump.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import h5py
4 | from tqdm import tqdm
5 |
6 |
7 | def get_range(name):
8 | # name = name.replace('_tfidf', '')
9 | return list(map(int, os.path.splitext(name)[0].split('-')))
10 |
11 |
12 | def find_name(names, pos):
13 | for name in names:
14 | start, end = get_range(name)
15 | assert start != end, 'you have self-looping at %s' % name
16 | if start == pos:
17 | return name, end
18 |     raise Exception('hdf5 file starting with %d not found.' % pos)
19 |
20 |
21 | def check_dump(args):
22 | print('checking dir contiguity...')
23 | names = os.listdir(args.dump_dir)
24 | pos = args.start
25 | while pos < args.end:
26 | name, pos = find_name(names, pos)
27 | assert pos == args.end, 'reached %d, which is different from the specified end %d' % (pos, args.end)
28 | print('dir contiguity test passed!')
29 | print('checking file corruption...')
30 | pos = args.start
31 | corrupted_paths = []
32 | while pos < args.end:
33 | name, pos = find_name(names, pos)
34 | path = os.path.join(args.dump_dir, name)
35 | try:
36 | with h5py.File(path, 'r') as f:
37 | print('checking %s...' % path)
38 | for dk, group in tqdm(f.items()):
39 | keys = list(group.keys())
40 | except Exception as e:
41 | print(e)
42 | print('%s corrupted!' % path)
43 | corrupted_paths.append(path)
44 | if len(corrupted_paths) > 0:
45 | print('following files are corrupted:')
46 | for path in corrupted_paths:
47 | print(path)
48 | else:
49 | print('file corruption test passed!')
50 |
51 |
52 | def get_args():
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument('dump_dir')
55 | parser.add_argument('start', type=int)
56 | parser.add_argument('end', type=int)
57 |
58 | return parser.parse_args()
59 |
60 |
61 | def main():
62 | args = get_args()
63 | check_dump(args)
64 |
65 |
66 | if __name__ == '__main__':
67 | main()
68 |
--------------------------------------------------------------------------------
/scripts/dump/filter_hdf5.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import os
3 | from tqdm import tqdm
4 |
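# Utility for phrase dumps: the commented-out block below writes a lighter copy of each
# dump file (dropping 'start', 'len_per_para', 'start2end'), and load_doc_groups() collects
# per-document metadata (word/char offsets, context, title) into memory.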
5 | input_dump_dir = 'dumps/sbcd_sqd_ftinb84_kl_x4_20181220_concat/dump/phrase/'
6 | select = 0
7 | print(f'************** {select} *****************')
8 | input_dump_paths = sorted(
9 | [os.path.join(input_dump_dir, name) for name in os.listdir(input_dump_dir) if 'hdf5' in name]
10 | )[select:]
11 | print(input_dump_paths)
12 | input_dumps = [h5py.File(path, 'r') for path in input_dump_paths]
13 | dump_names = [os.path.splitext(os.path.basename(path))[0] for path in input_dump_paths]
14 | print(input_dumps)
15 |
16 | # Filter dump for a lighter version
17 | '''
18 | output_dumps = [
19 | h5py.File(f'dumps/densephrases-multi_wiki-20181220/dump/phrase/{k}.hdf5', 'w')
20 | for k in dump_names
21 | ]
22 | print(output_dumps)
23 |
24 |
25 | for dump_idx, (input_dump, output_dump) in tqdm(enumerate(zip(input_dumps, output_dumps))):
26 | print(f'filtering {input_dump} to {output_dump}')
27 | for idx, (key, val) in tqdm(enumerate(input_dump.items())):
28 |
29 | dg = output_dump.create_group(key)
30 | dg.attrs['context'] = val.attrs['context'][:]
31 | dg.attrs['title'] = val.attrs['title'][:]
32 | for k_, v_ in val.items():
33 | if k_ not in ['start', 'len_per_para', 'start2end']:
34 | dg.create_dataset(k_, data=v_[:])
35 |
36 | input_dump.close()
37 | output_dump.close()
38 |
39 | print('filter done')
40 | '''
41 |
42 | def load_doc_groups(phrase_dump_dir):
43 | phrase_dump_paths = sorted(
44 | [os.path.join(phrase_dump_dir, name) for name in os.listdir(phrase_dump_dir) if 'hdf5' in name]
45 | )
46 | doc_groups = {}
47 | types = ['word2char_start', 'word2char_end', 'f2o_start']
48 | attrs = ['context', 'title']
49 | phrase_dumps = [h5py.File(path, 'r') for path in phrase_dump_paths]
50 | for path in tqdm(phrase_dump_paths, desc='loading doc groups'):
51 | with h5py.File(path, 'r') as f:
52 | for key in tqdm(f):
53 |                 # import pdb; pdb.set_trace()  # debugging hook; uncomment to inspect each doc group
54 | doc_group = {}
55 | for type_ in types:
56 | doc_group[type_] = f[key][type_][:]
57 | for attr in attrs:
58 | doc_group[attr] = f[key].attrs[attr]
59 | doc_groups[key] = doc_group
60 | return doc_groups
61 |
62 | # Save below as a pickle file and load it on memory for later use
63 | doc_groups = load_doc_groups(input_dump_dir)
64 |
--------------------------------------------------------------------------------
/scripts/dump/filter_stats.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import h5py
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 |
8 | def get_range(name):
9 | # name = name.replace('_tfidf', '')
10 | return list(map(int, os.path.splitext(name)[0].split('-')))
11 |
12 |
13 | def find_name(names, pos):
14 | for name in names:
15 | start, end = get_range(name)
16 | assert start != end, 'you have self-looping at %s' % name
17 | if start == pos:
18 | return name, end
19 |     raise Exception('hdf5 file starting with %d not found.' % pos)
20 |
21 |
22 | def check_dump(args):
23 | print('checking dir contiguity...')
24 | names = os.listdir(args.dump_dir)
25 | pos = args.start
26 | while pos < args.end:
27 | name, pos = find_name(names, pos)
28 | assert pos == args.end, 'reached %d, which is different from the specified end %d' % (pos, args.end)
29 | print('dir contiguity test passed!')
30 | print('checking file corruption...')
31 | pos = args.start
32 | corrupted_paths = []
33 |
34 | all_count = 0
35 | thresholds = [0.0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4, 4.5, 5]
36 | save_bins = {th: 0 for th in thresholds}
37 | while pos < args.end:
38 | name, pos = find_name(names, pos)
39 | path = os.path.join(args.dump_dir, name)
40 | with h5py.File(path, 'r') as f:
41 | print('checking %s...' % path)
42 | for dk, group in tqdm(f.items()):
43 | filter_start = group['filter_start'][:]
44 | filter_end = group['filter_end'][:]
45 | for th in thresholds:
46 | start_idxs, = np.where(filter_start > th)
47 | end_idxs, = np.where(filter_end > th)
48 | num_save_vec = len(set(np.concatenate([start_idxs, end_idxs])))
49 | save_bins[th] += num_save_vec
50 | all_count += len(filter_start)
51 | # break
52 |
53 | print(all_count)
54 | print(save_bins)
55 | comp_rate = {th: f'{save_num/all_count*100:.2f}%' for th, save_num in save_bins.items()}
56 | print(f'Compression rate: {comp_rate}')
57 | if len(corrupted_paths) > 0:
58 | print('following files are corrupted:')
59 | for path in corrupted_paths:
60 | print(path)
61 | else:
62 | print('file corruption test passed!')
63 |
64 |
65 | def get_args():
66 | parser = argparse.ArgumentParser()
67 | parser.add_argument('dump_dir')
68 | parser.add_argument('start', type=int)
69 | parser.add_argument('end', type=int)
70 |
71 | return parser.parse_args()
72 |
73 |
74 | def main():
75 | args = get_args()
76 | check_dump(args)
77 |
78 |
79 | if __name__ == '__main__':
80 | main()
81 |
--------------------------------------------------------------------------------
/scripts/dump/save_meta.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import h5py
4 | import torch
5 | from tqdm import tqdm
6 |
7 |
8 | def get_range(name):
9 | # name = name.replace('_tfidf', '')
10 | return list(map(int, os.path.splitext(name)[0].split('-')))
11 |
12 |
13 | def find_name(names, pos):
14 | for name in names:
15 | start, end = get_range(name)
16 | assert start != end, 'you have self-looping at %s' % name
17 | if start == pos:
18 | return name, end
19 |     raise Exception('hdf5 file starting with %d not found.' % pos)
20 |
21 |
22 | def check_dump(args):
23 | print('checking dir contiguity...')
24 | names = os.listdir(args.dump_dir)
25 | pos = args.start
26 | while pos < args.end:
27 | name, pos = find_name(names, pos)
28 | assert pos == args.end, 'reached %d, which is different from the specified end %d' % (pos, args.end)
29 | print('dir contiguity test passed!')
30 | print('checking file corruption...')
31 | pos = args.start
32 | corrupted_paths = []
33 | metadata = {}
34 | keys_to_save = ['f2o_end', 'f2o_start', 'span_logits', 'start2end', 'word2char_end', 'word2char_start']
35 | while pos < args.end:
36 | name, pos = find_name(names, pos)
37 | path = os.path.join(args.dump_dir, name)
38 | try:
39 | with h5py.File(path, 'r') as f:
40 | print('checking %s...' % path)
41 | for dk, group in tqdm(f.items()):
42 | # keys = list(group.keys())
43 | metadata[dk] = {save_key: group[save_key][:] for save_key in keys_to_save}
44 | metadata[dk]['context'] = group.attrs['context']
45 | metadata[dk]['title'] = group.attrs['title']
46 | except Exception as e:
47 | print(e)
48 | print('%s corrupted!' % path)
49 | corrupted_paths.append(path)
50 |
51 | break
52 |
53 | torch.save(metadata, 'tmp.bin')
54 | if len(corrupted_paths) > 0:
55 | print('following files are corrupted:')
56 | for path in corrupted_paths:
57 | print(path)
58 | else:
59 | print('file corruption test passed!')
60 |
61 |
62 | def get_args():
63 | parser = argparse.ArgumentParser()
64 | parser.add_argument('dump_dir')
65 | parser.add_argument('start', type=int)
66 | parser.add_argument('end', type=int)
67 |
68 | return parser.parse_args()
69 |
70 |
71 | def main():
72 | args = get_args()
73 | check_dump(args)
74 |
75 |
76 | if __name__ == '__main__':
77 | main()
78 |
--------------------------------------------------------------------------------
/scripts/dump/split_hdf5.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import os
3 | from tqdm import tqdm
4 |
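# Split a selected phrase-dump hdf5 file into two halves by document id
# (documents with id < midpoint * 1000 go to the first half).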
5 | input_dump_dir = 'dumps/sbcd_sqd_ftinb84_kl_x4_20181220_concat/dump/phrase/'
6 | select = 6
7 | print(f'************** {select} *****************')
8 | input_dump_paths = sorted(
9 | [os.path.join(input_dump_dir, name) for name in os.listdir(input_dump_dir) if 'hdf5' in name]
10 | )[select:select+1]
11 | print(input_dump_paths)
12 | input_dumps = [h5py.File(path, 'r') for path in input_dump_paths]
13 |
14 | dump_names = [os.path.splitext(os.path.basename(path))[0] for path in input_dump_paths]
15 | dump_ranges = [list(map(int, name.split('-'))) for name in dump_names]
16 | new_ranges = []
17 | for range_ in dump_ranges:
18 | # print(range_)
19 | middle = sum(range_) // 2 # split by half
20 | new_range_ = [[range_[0], middle], [middle, range_[1]]]
21 | # print(new_range_)
22 | new_ranges.append(new_range_)
23 |
24 | output_dumps = [
25 | [h5py.File(f'dumps/sbcd_sqd_ftinb84_kl_x4_20181220_concat/dump/phrase/{ra[0]}-{ra[1]}.hdf5', 'w')
26 | for ra in range_]
27 | for range_ in new_ranges
28 | ]
29 |
30 | print(input_dumps)
31 | print(output_dumps)
32 | print(new_ranges)
33 |
34 | # dev-100M-c 160408
35 | # dev_wiki_noise 250000
36 |
37 | for dump_idx, (input_dump, new_range, output_dump) in tqdm(enumerate(zip(input_dumps, new_ranges, output_dumps))):
38 | print(f'splitting {input_dump} to {output_dump}')
39 | for idx, (key, val) in tqdm(enumerate(input_dump.items())):
40 | # if idx < 250000/2:
41 | if int(key) < new_range[0][1] * 1000:
42 | output_dump[0].copy(val, key)
43 | else:
44 | output_dump[1].copy(val, key)
45 |
46 | input_dump.close()
47 | output_dump[0].close()
48 | output_dump[1].close()
49 |
50 | print('copy done')
51 |
--------------------------------------------------------------------------------
/scripts/kilt/build_title2wikiid.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 |
5 | """A script to read in and store documents in a sqlite database."""
6 |
7 | import argparse
8 | import sqlite3
9 | import json
10 | import os
11 | import logging
12 | import importlib.util
13 | import unicodedata
14 | import html
15 |
16 | from multiprocessing import Pool as ProcessPool
17 | from tqdm import tqdm
18 |
19 | logger = logging.getLogger()
20 | logger.setLevel(logging.INFO)
21 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p')
22 | console = logging.StreamHandler()
23 | console.setFormatter(fmt)
24 | logger.addHandler(console)
25 |
26 |
27 | # ------------------------------------------------------------------------------
28 | # Import helper
29 | # ------------------------------------------------------------------------------
30 |
31 |
32 | PREPROCESS_FN = None
33 |
34 |
35 | def init(filename):
36 | global PREPROCESS_FN
37 | if filename:
38 | PREPROCESS_FN = import_module(filename).preprocess
39 |
40 |
41 | def import_module(filename):
42 | """Import a module given a full path to the file."""
43 | spec = importlib.util.spec_from_file_location('doc_filter', filename)
44 | module = importlib.util.module_from_spec(spec)
45 | spec.loader.exec_module(module)
46 | return module
47 |
48 |
49 | # ------------------------------------------------------------------------------
50 | # Store corpus.
51 | # ------------------------------------------------------------------------------
52 |
53 | def normalize(text):
54 | """Resolve different type of unicode encodings."""
55 | return unicodedata.normalize('NFD', html.unescape(text))
56 |
57 | def iter_files(path):
58 | """Walk through all files located under a root path."""
59 | if os.path.isfile(path):
60 | yield path
61 | elif os.path.isdir(path):
62 | for dirpath, _, filenames in os.walk(path):
63 | for f in filenames:
64 | yield os.path.join(dirpath, f)
65 | else:
66 | raise RuntimeError('Path %s is invalid' % path)
67 |
68 |
69 | def get_contents(filename):
70 | """Parse the contents of a file. Each line is a JSON encoded document."""
71 | # documents = []
72 | results = {}
73 | with open(filename, encoding='utf-8') as f:
74 | for line in f:
75 | # Parse document
76 | doc = json.loads(line)
77 | # Skip if it is empty or None
78 | if not doc:
79 | continue
80 | # Add the document
81 |
82 | title = normalize(doc['title'])
83 | if '&' in title:
84 | import pdb; pdb.set_trace()
85 |
86 | if 'u0' in title:
87 | import pdb; pdb.set_trace()
88 | results[title] = doc['id']
89 | return results
90 |
91 |
92 | def store_contents(data_path, save_path):
93 | results = {}
94 | files = [f for f in iter_files(data_path)]
95 | for file in tqdm(files):
96 | contents = get_contents(file)
97 | results.update(contents)
98 |
99 | print(f"len(results)={len(results)}")
100 | with open(save_path, 'w') as f:
101 | json.dump(results, f)
102 |
103 | # ------------------------------------------------------------------------------
104 | # Main.
105 | # ------------------------------------------------------------------------------
106 |
107 |
108 | if __name__ == '__main__':
109 | parser = argparse.ArgumentParser()
110 | parser.add_argument('--data_path', type=str, help='/path/to/data')
111 | parser.add_argument('--save_path', type=str, help='/path/to/saved/db.db')
112 | args = parser.parse_args()
113 |
114 | store_contents(
115 | args.data_path, args.save_path
116 | )
--------------------------------------------------------------------------------
/scripts/kilt/sample_kilt.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 | import random
5 | import time
6 | import numpy as np
7 |
8 | from tqdm import tqdm
9 |
10 |
11 | def main(input_file, num_sample, balanced):
12 | print('reading', input_file)
13 | random.seed(999)
14 | np.random.seed(999)
15 |
16 | examples = json.load(open(input_file))['data']
17 | print(f'sampling from {len(examples)}')
18 | relation_dict = {}
19 | for example in tqdm(examples):
20 | relation = example['question'].split(' [SEP] ')[-1]
21 | if relation not in relation_dict:
22 | relation_dict[relation] = []
23 | relation_dict[relation].append(example)
24 |
25 | top_relations = sorted(relation_dict.items(), key=lambda x: len(x[1]), reverse=True)
26 | print('There are', len(relation_dict), 'relations.')
27 | print([(rel, len(rel_list)) for rel, rel_list in top_relations])
28 | print()
29 |     # exit()  # uncomment to stop after printing the relation statistics
30 |
31 | if not balanced:
32 | sample_per_relation = {
33 | rel: int((len(rel_list)/len(examples)) * num_sample) + 1 for rel, rel_list in top_relations
34 | }
35 | else:
36 | sample_per_relation = {
37 | rel: min(num_sample, len(rel_list)) for rel, rel_list in top_relations
38 | }
39 | print('Sample following number of relations')
40 | print(sample_per_relation)
41 |
42 | sample_examples = []
43 | for rel, rel_list in relation_dict.items():
44 | sample_idx = np.random.choice(len(rel_list), size=(sample_per_relation[rel]), replace=False)
45 | sample_examples += np.array(rel_list)[sample_idx].tolist()
46 |
47 | out_file = input_file.replace('.json', f'_{num_sample}_{"balanced" if balanced else "ratio"}.json')
48 | print(f'Saving {len(sample_examples)} examples to {out_file}')
49 | with open(out_file, 'w') as f:
50 | json.dump({'data': sample_examples}, f)
51 |
52 |
53 | if __name__ == '__main__':
54 | parser = argparse.ArgumentParser()
55 | parser.add_argument("input_file", type=str)
56 | parser.add_argument("--num_sample", type=int, required=True)
57 | parser.add_argument("--balanced", action='store_true', default=False)
58 |
59 | args = parser.parse_args()
60 |
61 | main(args.input_file, args.num_sample, args.balanced)
62 |
--------------------------------------------------------------------------------
/scripts/kilt/strip_pred.py:
--------------------------------------------------------------------------------
1 | from densephrases.utils.kilt.eval import evaluate as kilt_evaluate
2 | from densephrases.utils.kilt.kilt_utils import load_data, store_data
3 | import string
4 | import argparse
5 |
6 |
7 | def strip_pred(input_file, gold_file):
8 |
9 | print('original evaluation result:', input_file)
10 | result = kilt_evaluate(gold=gold_file, guess=input_file)
11 | print(result)
12 |
13 | preds = load_data(input_file)
14 | for pred in preds:
15 | pred['output'][0]['answer'] = pred['output'][0]['answer'].strip(string.punctuation)
16 |
17 | out_file = input_file.replace('.jsonl', '_strip.jsonl')
18 | print('strip evaluation result:', out_file)
19 | store_data(out_file, preds)
20 | new_result = kilt_evaluate(gold=gold_file, guess=out_file)
21 | print(new_result)
22 |
23 |
24 |
25 | if __name__ == '__main__':
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('input_file', type=str)
28 | parser.add_argument('gold_file', type=str)
29 | args = parser.parse_args()
30 | strip_pred(args.input_file, args.gold_file)
31 |
--------------------------------------------------------------------------------
/scripts/parallel/add_to_index.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import subprocess
4 |
5 | import h5py
6 | from tqdm import tqdm
7 |
8 |
9 | def get_size(name):
10 | a, b = list(map(int, os.path.splitext(name)[0].split('-')))
11 | return b - a
12 |
13 |
14 | def bin_names(dir_, names, num_bins):
15 | names = sorted(names, key=lambda name_: -os.path.getsize(os.path.join(dir_, name_)))
16 | bins = []
17 | for name in names:
18 | if len(bins) < num_bins:
19 | bins.append([name])
20 | else:
21 | smallest_bin = min(bins, key=lambda bin_: sum(get_size(name_) for name_ in bin_))
22 | smallest_bin.append(name)
23 | return bins
24 |
25 |
26 | def run_add_to_index(args):
27 | def get_cmd(dump_paths, offset_):
28 | return ["python",
29 | "build_phrase_index.py",
30 | f"{args.dump_dir}",
31 | "add",
32 | "--fine_quant", "SQ4",
33 | "--dump_paths", f"{dump_paths}",
34 | "--offset", f"{offset_}",
35 | "--num_clusters", f"{args.num_clusters}",
36 | f"{'--cuda' if args.cuda else ''}"]
37 |
38 |
39 | dir_ = os.path.join(args.dump_dir, 'phrase')
40 | names = os.listdir(dir_)
41 | bins = bin_names(dir_, names, args.num_gpus)
42 | offsets = [args.max_num_per_file * each for each in range(len(bins))]
43 |
44 | print('adding with offset:')
45 | for offset, bin_ in zip(offsets, bins):
46 | print('%d: %s' % (offset, ','.join(bin_)))
47 |
48 | for kk, (bin_, offset) in enumerate(zip(bins, offsets)):
49 | if args.start <= kk < args.end:
50 | print(get_cmd(','.join(bin_), offset))
51 | subprocess.run(get_cmd(','.join(bin_), offset))
52 | if args.draft:
53 | break
54 |
55 |
56 | def get_args():
57 | parser = argparse.ArgumentParser()
58 | parser.add_argument('--dump_dir', default='dump/76_dev-1B-c')
59 | parser.add_argument('--num_cpus', default=4, type=int)
60 | parser.add_argument('--num_gpus', default=60, type=int)
61 | parser.add_argument('--mem_size', default=40, type=int, help='mem size in GB')
62 | parser.add_argument('--num_clusters', default=4096, type=int)
63 | parser.add_argument('--draft', default=False, action='store_true')
64 | parser.add_argument('--max_num_per_file', default=int(1e8), type=int,
65 | help='max num per file for setting up good offsets.')
66 | parser.add_argument('--cuda', default=False, action='store_true')
67 | parser.add_argument('--start', default=0, type=int)
68 | parser.add_argument('--end', default=3, type=int)
69 | args = parser.parse_args()
70 |
71 | return args
72 |
73 |
74 | def main():
75 | args = get_args()
76 | run_add_to_index(args)
77 |
78 |
79 | if __name__ == '__main__':
80 | main()
81 |
--------------------------------------------------------------------------------
/scripts/parallel/dump_phrases.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 | import subprocess
5 |
6 |
7 | def run_dump_phrase(args):
8 | do_lower_case = '--do_lower_case' if args.do_lower_case else ''
9 | append_title = '--append_title' if args.append_title else ''
10 | def get_cmd(start_doc, end_doc):
11 | return ["python", "generate_phrase_vecs.py",
12 | "--model_type", f"{args.model_type}",
13 | "--pretrained_name_or_path", f"{args.pretrained_name_or_path}",
14 | "--data_dir", f"{args.phrase_data_dir}",
15 | "--cache_dir", f"{args.cache_dir}",
16 | "--predict_file", f"{start_doc}:{end_doc}",
17 | "--do_dump",
18 | "--max_seq_length", "512",
19 | "--doc_stride", "500",
20 | "--fp16",
21 | "--load_dir", f"{args.load_dir}",
22 | "--output_dir", f"{args.output_dir}",
23 | "--filter_threshold", f"{args.filter_threshold:.2f}"] + \
24 | ([f"{do_lower_case}"] if len(do_lower_case) > 0 else []) + \
25 | ([f"{append_title}"] if len(append_title) > 0 else [])
26 |
27 | num_docs = args.end - args.start
28 | num_gpus = args.num_gpus
29 | num_docs_per_gpu = int(math.ceil(num_docs / num_gpus))
30 | start_docs = list(range(args.start, args.end, num_docs_per_gpu))
31 | end_docs = start_docs[1:] + [args.end]
32 |
33 | print(start_docs)
34 | print(end_docs)
35 |
36 | for device_idx, (start_doc, end_doc) in enumerate(zip(start_docs, end_docs)):
37 | print(get_cmd(start_doc, end_doc))
38 | subprocess.Popen(get_cmd(start_doc, end_doc))
39 |
40 |
41 | def get_args():
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument('--model_type', default='bert')
44 | parser.add_argument('--pretrained_name_or_path', default='SpanBERT/spanbert-base-cased')
45 | parser.add_argument('--data_dir', default='')
46 | parser.add_argument('--cache_dir', default='')
47 | parser.add_argument('--data_name', default='') # for suffix
48 | parser.add_argument('--load_dir', default='')
49 | parser.add_argument('--output_dir', default='')
50 | parser.add_argument('--do_lower_case', default=False, action='store_true')
51 | parser.add_argument('--append_title', default=False, action='store_true')
52 | parser.add_argument('--filter_threshold', default=-1e9, type=float)
53 | parser.add_argument('--num_gpus', default=1, type=int)
54 | parser.add_argument('--start', default=0, type=int)
55 | parser.add_argument('--end', default=8, type=int)
56 | args = parser.parse_args()
57 |
58 | args.output_dir = args.output_dir + '_%s' % (os.path.basename(args.data_name))
59 | args.phrase_data_dir = os.path.join(args.data_dir, args.data_name)
60 |
61 | return args
62 |
63 |
64 | def main():
65 | args = get_args()
66 | run_dump_phrase(args)
67 |
68 |
69 | if __name__ == '__main__':
70 | main()
71 |
--------------------------------------------------------------------------------
/scripts/postprocess/recall.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import re
4 | import unicodedata
5 | from collections import defaultdict
6 | from tqdm import tqdm
7 | from scripts.preprocess.simple_tokenizer import SimpleTokenizer
8 |
9 |
10 | def read_file(infile, handle_file, log=False, skip_first_line=False):
11 | if log:
12 | print('Opening "{}"...'.format(infile))
13 | data = None
14 | with open(infile) as f:
15 | if skip_first_line:
16 | f.readline()
17 | data = handle_file(f)
18 | if log:
19 | print(' Done.')
20 | return data
21 |
22 |
23 | def read_jsonl(infile, log=False):
24 | handler = lambda f: [json.loads(line) for line in f.readlines()]
25 | return read_file(infile, handler, log=log)
26 |
27 |
28 | def read_json(infile, log=False):
29 | handler = lambda f: json.load(f)
30 | return read_file(infile, handler, log=log)
31 |
32 |
33 | def _normalize(text):
34 | return unicodedata.normalize('NFD', text)
35 |
36 | ###############################################################################
37 | ### HAS_ANSWER FUNCTIONS ####################################################
38 | ###############################################################################
39 | def has_answer_field(ctx, answers):
40 | return ctx['has_answer']
41 |
42 |
43 | tokenizer = SimpleTokenizer(**{})
44 | def string_match(ctx, answers):
45 | text = tokenizer.tokenize(ctx['text']).words(uncased=True)
46 |
47 | for single_answer in answers:
48 | single_answer = _normalize(single_answer)
49 | single_answer = tokenizer.tokenize(single_answer)
50 | single_answer = single_answer.words(uncased=True)
51 |
52 | for i in range(0, len(text) - len(single_answer) + 1):
53 | if single_answer == text[i: i + len(single_answer)]:
54 | return True
55 | return False
56 |
57 |
58 | def normalized_title(ctx, answers):
59 | for answer in answers:
60 |         a = answer.lower().strip()
61 | title = ctx['title'].lower().strip()
62 | if a == title[:len(a)]:
63 | return True
64 | return False
65 |
66 |
67 | def regex(ctx, answers):
68 | text = ctx['text']
69 | for answer in answers:
70 | answer = _normalize(answer)
71 | if regex_match(text, answer):
72 | return True
73 | return False
74 |
75 |
76 | def regex_match(text, pattern):
77 | """Test if a regex pattern is contained within a text."""
78 | try:
79 | pattern = re.compile(
80 | pattern,
81 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE,
82 | )
83 | except BaseException:
84 | return False
85 | return pattern.search(text) is not None
86 |
87 |
88 | ###############################################################################
89 | ### CALCULATION FUNCTIONS ###################################################
90 | ###############################################################################
91 | def precision_fn(results, k_vals, has_answer):
92 | n_hits = {k: 0 for k in k_vals}
93 | mrrs = []
94 | precs = []
95 | PREC_K = 20
96 | MRR_K = 20
97 |
98 | for result in tqdm(results):
99 | ans = result['answers']
100 | ctxs = result['ctxs']
101 | found_k = len(ctxs) + 1
102 | found = False
103 | num_hit = 0
104 | for c_idx,ctx in enumerate(ctxs):
105 | if has_answer(ctx, ans):
106 | if not found:
107 | found_k = c_idx # record first one
108 | found = True
109 |
110 | if c_idx < PREC_K: # P@k
111 | num_hit += 1
112 | # break
113 | for k in k_vals:
114 | if found_k < k:
115 | n_hits[k] += 1
116 |
117 | if found_k >= MRR_K:
118 | mrrs.append(0)
119 | else:
120 | mrrs.append(1/(found_k + 1))
121 | precs.append(num_hit/PREC_K)
122 |
123 | print('*'*50)
124 | for k in k_vals:
125 | if len(results) == 0:
126 | print('No results.')
127 | else:
128 | print('Top-{} = {:.2%}'.format(k, n_hits[k] / len(results)))
129 |
130 | print(f'Acc@{k_vals[0]} when Acc@{k_vals[-1]} = {n_hits[k_vals[0]]/n_hits[k_vals[-1]]*100:.2f}%')
131 | print(f'MRR@{MRR_K} = {sum(mrrs)/len(mrrs)*100:.2f}')
132 | print(f'P@{PREC_K} = {sum(precs)/len(precs)*100:.2f}')
133 |
134 |
135 | def precision_fn_file(infile, n_docs, k_vals, has_answer, args):
136 | results = read_jsonl(infile) if args.jsonl else read_json(infile)
137 |
138 | # stats
139 | ctx_lens = [sum([len(pp['text'].split()) for pp in re['ctxs']])/len(re['ctxs']) for re in results]
140 | print(f'ctx token length: {sum(ctx_lens)/len(ctx_lens):.2f}')
141 |
142 | # unique titles
143 | title_lens = [len(set(pp['title'] for pp in re['ctxs'])) for re in results]
144 | print(f'unique titles: {sum(title_lens)/len(title_lens):.2f}')
145 |
146 | precision_fn(results, k_vals, has_answer)
147 |
148 |
149 | # Top-20 and Top-100
150 | def precision_per_bucket(results_file, longtail_file, n_docs, k_vals, longtail_tags, ans_fn):
151 | results = read_json(results_file)
152 | annotations = read_json(longtail_file)
153 | for tag in longtail_tags:
154 | bucket = [result for idx,result in enumerate(results) if tag == annotations[idx]['annotations']]
155 | print('==== Bucket={} ====='.format(tag))
156 |         precision_fn(bucket, k_vals, ans_fn)
157 | print()
158 |
159 |
160 | if __name__ == '__main__':
161 | parser = argparse.ArgumentParser()
162 | parser.add_argument('--results_file', required=True, type=str, default=None,
163 | help="Location of the results file to parse.")
164 | parser.add_argument('--n_docs', type=int, default=100,
165 | help="Maximum number of docs retrieved.")
166 | parser.add_argument('--k_values', type=str, default='1,5,10,20,40,50,60,80,100',
167 | help="Top-K values to print out")
168 | parser.add_argument('--ans_fn', type=str, default='has_answer',
169 | help="How to check whether has the answer. title | has_answer")
170 | parser.add_argument('--jsonl', action='store_true', help='Set if results is a jsonl file.')
171 |
172 | # Longtail Entity Analysis
173 | parser.add_argument('--longtail', action='store_true',
174 | help='whether or not to include longtail buckets')
175 | parser.add_argument('--longtail_file', required=False, type=str, default=None,
176 | help='Mapping from question to longtail entity tags.')
177 | parser.add_argument('--longtail_tags', type=str, default='p10,p25,p50,p75,p90',
178 | help='Tags for the longtail entities within longtail_file')
179 |
180 | args = parser.parse_args()
181 | ks = [int(k) for k in args.k_values.split(',')]
182 | if args.ans_fn == 'has_answer':
183 | ans_fn = has_answer_field
184 | elif args.ans_fn == 'title':
185 | ans_fn = normalized_title
186 | elif args.ans_fn == 'string':
187 | ans_fn = string_match
188 | elif args.ans_fn == 'regex':
189 | ans_fn = regex
190 | else:
191 | raise Exception('Answer function not recognized')
192 |
193 | if args.longtail:
194 | longtail_tags = args.longtail_tags.split(',')
195 | precision_per_bucket(args.results_file, args.longtail_file,
196 | args.n_docs, ks, longtail_tags, ans_fn)
197 | else:
198 | precision_fn_file(args.results_file, args.n_docs, ks, ans_fn, args)
199 |
--------------------------------------------------------------------------------
/scripts/postprocess/recall_transform.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import sys
4 | import argparse
5 | import numpy as np
6 | from spacy.lang.en import English
7 | from tqdm import tqdm
8 |
9 | nlp = English()
10 | nlp.add_pipe(nlp.create_pipe('sentencizer'))
11 |
12 |
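# Convert a DensePhrases prediction file into the DPR/FiD-style retrieval format
# ({"id", "question", "answers", "ctxs": [{"title", "text"}, ...]}), optionally truncating
# passages to sentence windows and marking the predicted phrase with START/END tokens.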
13 | def main(args):
14 | pred_file = os.path.join(args.model_dir, 'pred', args.pred_file)
15 | my_pred = json.load(open(pred_file))
16 |
17 | my_target = []
18 | avg_len = []
19 | for qid, pred in tqdm(enumerate(my_pred.values())):
20 | my_dict = {"id": str(qid), "question": None, "answers": [], "ctxs": []}
21 |
22 | # truncate
23 | pred = {key: val[:args.psg_top_k] if key in ['evidence', 'title', 'se_pos', 'prediction'] else val for key, val in pred.items()}
24 |
25 | # TODO: need to add id for predictions.pred in the future
26 | my_dict["question"] = pred["question"]
27 | my_dict["answers"] = pred["answer"]
28 | pred["title"] = [titles[0] for titles in pred["title"]]
29 |
30 | assert len(set(pred["evidence"])) == len(pred["evidence"]) == len(pred["title"]), "Should use opt2 for aggregation"
31 | # assert all(pr in evd for pr, evd in zip(pred["prediction"], pred["evidence"])) # prediction included
32 |
33 | # Pad up to top-k
34 | if not(len(pred["prediction"]) == len(pred["evidence"]) == len(pred["title"]) == args.psg_top_k):
35 | assert len(pred["prediction"]) == len(pred["evidence"]) == len(pred["title"]) < args.psg_top_k, \
36 | (len(pred["prediction"]), len(pred["evidence"]), len(pred["title"]))
37 | print(len(pred["prediction"]), len(pred["evidence"]), len(pred["title"]))
38 |
39 | pred["evidence"] += [pred["evidence"][-1]] * (args.psg_top_k - len(pred["prediction"]))
40 | pred["title"] += [pred["title"][-1]] * (args.psg_top_k - len(pred["prediction"]))
41 | pred["se_pos"] += [pred["se_pos"][-1]] * (args.psg_top_k - len(pred["prediction"]))
42 | pred["prediction"] += [pred["prediction"][-1]] * (args.psg_top_k - len(pred["prediction"]))
43 | assert len(pred["prediction"]) == len(pred["evidence"]) == len(pred["title"]) == args.psg_top_k
44 |
45 | # Used for markers
46 |         START = '<p_start>'  # placeholder marker token; the exact string is arbitrary
47 |         END = '<p_end>'  # placeholder marker token
48 | se_idxs = [[se_pos[0], max(se_pos[0], se_pos[1])] for se_pos in pred["se_pos"]]
49 |
50 | # Return sentence
51 | if args.return_sent:
52 | sents = [[(X.text, X[0].idx) for X in nlp(evidence).sents] for evidence in pred['evidence']]
53 | sent_idxs = [
54 | sorted(set([sum(np.array([st[1] for st in sent]) <= se_idx[0]) - 1] + [sum(np.array([st[1] for st in sent]) <= se_idx[1]-1) - 1]))
55 | for se_idx, sent in zip(se_idxs, sents)
56 | ]
57 | se_idxs = [[se_pos[0]-sent[sent_idx[0]][1], se_pos[1]-sent[sent_idx[0]][1]] for se_pos, sent_idx, sent in zip(se_idxs, sent_idxs, sents)]
58 | if not all(pred.replace(' ', '') in ' '.join([sent[sidx][0] for sidx in range(sent_idx[0], sent_idx[-1]+1)]).replace(' ', '')
59 | for pred, sent, sent_idx in zip(pred['prediction'], sents, sent_idxs)):
60 | import pdb; pdb.set_trace()
61 | pass
62 |
63 | # get sentence based on the window
64 | max_context_len = args.max_context_len - 2 if args.mark_phrase else args.max_context_len
65 | my_dict["ctxs"] = [
66 | # {"title": title, "text": ' '.join(' '.join([sent[sidx][0] for sidx in range(sent_idx[0], sent_idx[-1]+1)]).split()[:max_context_len])}
67 | {"title": title, "text": ' '.join(' '.join([sent[sidx][0] for sidx in range(
68 | max(0, sent_idx[0]-args.sent_window), min(sent_idx[-1]+1+args.sent_window, len(sent)))]
69 | ).split()[:max_context_len])
70 | }
71 | for title, sent, sent_idx in zip(pred["title"], sents, sent_idxs)
72 | ]
73 | # Return passagae
74 | else:
75 | my_dict["ctxs"] = [
76 | {"title": title, "text": ' '.join(evd.split()[:args.max_context_len])}
77 | for evd, title in zip(pred["evidence"], pred["title"])
78 | ]
79 |
80 | # Add markers for predicted phrases
81 | if args.mark_phrase:
82 | my_dict["ctxs"] = [
83 | {"title": ctx["title"], "text": ctx["text"][:se[0]] + f"{START} " + ctx["text"][se[0]:se[1]] + f" {END}" + ctx["text"][se[1]:]}
84 | for ctx, se in zip(my_dict["ctxs"], se_idxs)
85 | ]
86 |
87 | my_target.append(my_dict)
88 | avg_len += [len(ctx['text'].split()) for ctx in my_dict["ctxs"]]
89 | assert len(my_dict["ctxs"]) == args.psg_top_k
90 | assert all(len(ctx['text'].split()) <= args.max_context_len for ctx in my_dict["ctxs"])
91 |
92 | print(f"avg ctx len={sum(avg_len)/len(avg_len):.2f} for {len(my_pred)} preds")
93 |
94 | out_file = os.path.join(
95 | args.model_dir, 'pred',
96 | os.path.splitext(args.pred_file)[0] +
97 | f'_{"sent" if args.return_sent else "psg"}-top{args.psg_top_k}{"_mark" if args.mark_phrase else ""}.json'
98 | )
99 | print(f"dump to {out_file}")
100 | json.dump(my_target, open(out_file, 'w'), indent=4)
101 |
102 |
103 | if __name__ == '__main__':
104 | parser = argparse.ArgumentParser()
105 |
106 | parser.add_argument('--model_dir', type=str, default='')
107 | parser.add_argument('--pred_file', type=str, default='')
108 | parser.add_argument('--psg_top_k', type=int, default=100)
109 | parser.add_argument('--max_context_len', type=int, default=999999999)
110 | parser.add_argument('--mark_phrase', default=False, action='store_true')
111 | parser.add_argument('--return_sent', default=False, action='store_true')
112 | parser.add_argument('--sent_window', type=int, default=0)
113 | args = parser.parse_args()
114 |
115 | main(args)
116 |
--------------------------------------------------------------------------------
/scripts/preprocess/README.md:
--------------------------------------------------------------------------------
1 | ## Create SQuAD-Style Wiki Dump (20181220)
2 |
3 | ### Download wiki dump of 20181220
4 | ```
5 | python download_wikidump.py \
6 | --output_dir /hdd1/data/wikidump
7 | ```
8 |
9 | ### Extract Wiki dump via Wikiextractor
10 | Use [Wikiextractor](https://github.com/attardi/wikiextractor) to convert the wiki dump into JSON format.
11 |
12 | ```
13 | python WikiExtractor.py \
14 | --filter_disambig_pages \
15 | --json \
16 | -o /hdd1/data/wikidump/extracted/ \
17 | /hdd1/data/wikidump/enwiki-20181220-pages-articles.xml.bz2
18 | ```
19 |
20 | ### Build docs.db in SQlite style
21 | ```
22 | python build_db.py \
23 | --data_path /hdd1/data/wikidump/extracted \
24 | --save_path /hdd1/data/wikidump/docs_20181220.db \
25 | --preprocess prep_wikipedia.py
26 | ```
27 |
28 | ### Transform the SQLite DB to SQuAD style
29 | ```
30 | python build_wikisquad.py \
31 | --db_path /hdd1/data/wikidump/docs_20181220.db \
32 | --out_dir /hdd1/data/wikidump/20181220
33 | ```
34 |
35 | ### Concatenate short paragraphs
36 | ```
37 | python concat_wikisquad.py \
38 | --input_dir /hdd1/data/wikidump/20181220 \
39 | --output_dir /hdd1/data/wikidump/20181220_concat
40 | ```
41 |
--------------------------------------------------------------------------------
/scripts/preprocess/build_db.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | # https://github.com/facebookresearch/DrQA/blob/master/scripts/retriever/build_db.py
8 | """A script to read in and store documents in a sqlite database."""
9 |
10 | import argparse
11 | import sqlite3
12 | import json
13 | import os
14 | import logging
15 | import importlib.util
16 |
17 | from multiprocessing import Pool as ProcessPool
18 | from tqdm import tqdm
19 | import unicodedata
20 |
21 | logger = logging.getLogger()
22 | logger.setLevel(logging.INFO)
23 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p')
24 | console = logging.StreamHandler()
25 | console.setFormatter(fmt)
26 | logger.addHandler(console)
27 |
28 | # ------------------------------------------------------------------------------
29 | # Utils
30 | # ------------------------------------------------------------------------------
31 |
32 | def normalize(text):
33 | """Resolve different type of unicode encodings."""
34 | return unicodedata.normalize('NFD', text)
35 |
36 | # ------------------------------------------------------------------------------
37 | # Import helper
38 | # ------------------------------------------------------------------------------
39 |
40 |
41 | PREPROCESS_FN = None
42 |
43 |
44 | def init(filename):
45 | global PREPROCESS_FN
46 | if filename:
47 | PREPROCESS_FN = import_module(filename).preprocess
48 |
49 |
50 | def import_module(filename):
51 | """Import a module given a full path to the file."""
52 | spec = importlib.util.spec_from_file_location('doc_filter', filename)
53 | module = importlib.util.module_from_spec(spec)
54 | spec.loader.exec_module(module)
55 | return module
56 |
57 |
58 | # ------------------------------------------------------------------------------
59 | # Store corpus.
60 | # ------------------------------------------------------------------------------
61 |
62 |
63 | def iter_files(path):
64 | """Walk through all files located under a root path."""
65 | if os.path.isfile(path):
66 | yield path
67 | elif os.path.isdir(path):
68 | for dirpath, _, filenames in os.walk(path):
69 | for f in filenames:
70 | yield os.path.join(dirpath, f)
71 | else:
72 | raise RuntimeError('Path %s is invalid' % path)
73 |
74 |
75 | def get_contents(filename):
76 | """Parse the contents of a file. Each line is a JSON encoded document."""
77 | global PREPROCESS_FN
78 | documents = []
79 | with open(filename) as f:
80 | for line in f:
81 | # Parse document
82 | doc = json.loads(line)
83 | # Maybe preprocess the document with custom function
84 | if PREPROCESS_FN:
85 | doc = PREPROCESS_FN(doc)
86 | # Skip if it is empty or None
87 | if not doc:
88 | continue
89 | # Add the document
90 | documents.append((normalize(doc['id']), doc['text']))
91 | return documents
92 |
93 |
94 | def store_contents(data_path, save_path, preprocess, num_workers=None):
95 | """Preprocess and store a corpus of documents in sqlite.
96 | Args:
97 | data_path: Root path to directory (or directory of directories) of files
98 | containing json encoded documents (must have `id` and `text` fields).
99 | save_path: Path to output sqlite db.
100 | preprocess: Path to file defining a custom `preprocess` function. Takes
101 | in and outputs a structured doc.
102 | num_workers: Number of parallel processes to use when reading docs.
103 | """
104 | if os.path.isfile(save_path):
105 | raise RuntimeError('%s already exists! Not overwriting.' % save_path)
106 |
107 | logger.info('Reading into database...')
108 | conn = sqlite3.connect(save_path)
109 | c = conn.cursor()
110 | c.execute("CREATE TABLE documents (id PRIMARY KEY, text);")
111 |
112 | workers = ProcessPool(num_workers, initializer=init, initargs=(preprocess,))
113 | files = [f for f in iter_files(data_path)]
114 | count = 0
115 | with tqdm(total=len(files)) as pbar:
116 | for pairs in tqdm(workers.imap_unordered(get_contents, files)):
117 | count += len(pairs)
118 | c.executemany("INSERT OR IGNORE INTO documents VALUES (?,?)", pairs)
119 | pbar.update()
120 | logger.info('Read %d docs.' % count)
121 | logger.info('Committing...')
122 | conn.commit()
123 | conn.close()
124 |
125 |
126 | # ------------------------------------------------------------------------------
127 | # Main.
128 | # ------------------------------------------------------------------------------
129 |
130 |
131 | if __name__ == '__main__':
132 | parser = argparse.ArgumentParser()
133 | parser.add_argument('--data_path', type=str, help='/path/to/data')
134 | parser.add_argument('--save_path', type=str, help='/path/to/saved/db.db')
135 | parser.add_argument('--preprocess', type=str, default=None,
136 | help=('File path to a python module that defines '
137 | 'a `preprocess` function'))
138 | parser.add_argument('--num-workers', type=int, default=None,
139 | help='Number of CPU processes (for tokenizing, etc)')
140 | args = parser.parse_args()
141 |
142 | store_contents(
143 | args.data_path, args.save_path, args.preprocess, args.num_workers
144 | )
--------------------------------------------------------------------------------
/scripts/preprocess/compress_metadata.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import os
3 | import h5py
4 | from tqdm import tqdm
5 | import sys
6 | import zlib
7 | import numpy as np
8 | import traceback
9 | import blosc
10 | import pickle
11 | import argparse
12 |
13 | # get size of the whole metadata
14 | def get_size(d):
15 | size = 0
16 | for i in d:
17 | word2char_start_size = sys.getsizeof(d[i]['word2char_start'])
18 | word2char_end_size = sys.getsizeof(d[i]['word2char_end'])
19 | f2o_start_size = sys.getsizeof(d[i]['f2o_start'])
20 | context_size = sys.getsizeof(d[i]['context'])
21 | title_size = sys.getsizeof(d[i]['title'])
22 | size+=word2char_start_size
23 | size+=word2char_end_size
24 | size+=f2o_start_size
25 | size+=context_size
26 | size+=title_size
27 |
28 | return size
29 |
30 | # compress metadata using zlib
31 | # http://python-blosc.blosc.org/tutorial.html
32 | def compress(d):
33 | for i in d:
34 | word2char_start = d[i]['word2char_start']
35 | word2char_end = d[i]['word2char_end']
36 | f2o_start = d[i]['f2o_start']
37 | context=d[i]['context']
38 | title=d[i]['title']
39 |
40 | # save type to use when decompressing
41 | type1= word2char_start.dtype
42 | type2= word2char_end.dtype
43 | type3= f2o_start.dtype
44 |
45 | d[i]['word2char_start'] = blosc.compress(word2char_start, typesize=1,cname='zlib')
46 | d[i]['word2char_end'] = blosc.compress(word2char_end, typesize=1,cname='zlib')
47 | d[i]['f2o_start'] = blosc.compress(f2o_start, typesize=1,cname='zlib')
48 | d[i]['context'] = blosc.compress(context.encode('utf-8'),cname='zlib')
49 | d[i]['dtypes']={
50 | 'word2char_start':type1,
51 | 'word2char_end':type2,
52 | 'f2o_start':type3
53 | }
54 |
55 | # check if compression is lossless
56 | try:
57 | decompressed_word2char_start = np.frombuffer(blosc.decompress(d[i]['word2char_start']), type1)
58 | decompressed_word2char_end = np.frombuffer(blosc.decompress(d[i]['word2char_end']), type2)
59 | decompressed_f2o_start = np.frombuffer(blosc.decompress(d[i]['f2o_start']), type3)
60 | decompressed_context = blosc.decompress(d[i]['context']).decode('utf-8')
61 |
62 | assert ((word2char_start == decompressed_word2char_start).all())
63 | assert ((word2char_end == decompressed_word2char_end).all())
64 | assert ((f2o_start ==decompressed_f2o_start).all())
65 | assert (context == decompressed_context)
66 | except Exception as e:
67 | print(e)
68 | traceback.print_exc()
69 | pdb.set_trace()
70 | return d
71 |
72 | def load_doc_groups(phrase_dump_dir):
73 | phrase_dump_paths = sorted(
74 | [os.path.join(phrase_dump_dir, name) for name in os.listdir(phrase_dump_dir) if 'hdf5' in name]
75 | )
76 | doc_groups = {}
77 | types = ['word2char_start', 'word2char_end', 'f2o_start']
78 | attrs = ['context', 'title']
79 | phrase_dumps = [h5py.File(path, 'r') for path in phrase_dump_paths]
80 | phrase_dumps = phrase_dumps[:1]
81 | for path in tqdm(phrase_dump_paths, desc='loading doc groups'):
82 | with h5py.File(path, 'r') as f:
83 | for key in tqdm(f):
84 | doc_group = {}
85 | for type_ in types:
86 | doc_group[type_] = f[key][type_][:]
87 | for attr in attrs:
88 | doc_group[attr] = f[key].attrs[attr]
89 | doc_groups[key] = doc_group
90 |
91 | return doc_groups
92 |
93 | def main(args):
94 |     # Load the metadata of all documents into memory
95 | doc_groups = load_doc_groups(args.input_dump_dir)
96 |
97 | # Get the size of meta data before compression
98 | size_before_compression = get_size(doc_groups)
99 |
100 | # compress metadata using zlib
101 | doc_groups = compress(doc_groups)
102 |
103 |     # Get the size of meta data after compression
104 | size_after_compression = get_size(doc_groups)
105 |
106 | print(f"compressed by {round(size_after_compression/size_before_compression*100,2)}%")
107 |
108 | # save compressed meta as a pickle format
109 | output_file = os.path.join(args.output_dir, 'meta_compressed.pkl')
110 | with open(output_file,'wb') as f:
111 | pickle.dump(doc_groups, f)
112 |
113 | if __name__ == '__main__':
114 | parser = argparse.ArgumentParser()
115 |
116 | parser.add_argument('--input_dump_dir', type=str, default='dump/sbcd_sqdqgnqqg_inb64_s384_sqdnq_pinb2_0_20181220_concat/dump/phrase')
117 | parser.add_argument('--output_dir', type=str, default='./')
118 | args = parser.parse_args()
119 |
120 | main(args)
121 |
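122 | # The pickle written above maps each document key to its blosc-compressed arrays plus their
123 | # original dtypes. A minimal read-back sketch (field choice is an example), mirroring the
124 | # lossless check in compress():
125 | #
126 | #   import pickle
127 | #   import blosc
128 | #   import numpy as np
129 | #
130 | #   with open('meta_compressed.pkl', 'rb') as f:
131 | #       meta = pickle.load(f)
132 | #   doc = next(iter(meta.values()))
133 | #   word2char_start = np.frombuffer(
134 | #       blosc.decompress(doc['word2char_start']), doc['dtypes']['word2char_start'])
135 | #   context = blosc.decompress(doc['context']).decode('utf-8')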
--------------------------------------------------------------------------------
/scripts/preprocess/concat_wikisquad.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import argparse
4 |
5 | from tqdm import tqdm
6 | import pdb
7 |
8 | def normalize(text):
9 | return text.lower().replace('_', ' ')
10 |
11 |
12 | def concat_wikisquad(args):
13 | names = os.listdir(args.input_dir)
14 | data = {'data': []}
15 | for name in tqdm(names):
16 | from_path = os.path.join(args.input_dir, name)
17 | with open(from_path, 'r') as fp:
18 | from_ = json.load(fp)
19 |
20 | for ai, article in enumerate(from_['data']):
21 | article['id'] = int(name) * 1000 + ai
22 |
23 | articles = []
24 | for article in from_['data']:
25 | articles.append(article)
26 |
27 | for article in articles:
28 | to_article = {'title': article['title'], 'paragraphs': []}
29 | context = ""
30 | for para_idx, para in enumerate(article['paragraphs']):
31 | context = context + " " + para['context']
32 | if args.min_num_chars <= len(context):
33 | to_article['paragraphs'].append({'context': context})
34 | context = ""
35 |             # if this is the last paragraph and the accumulated context is still shorter
36 |             # than min_num_chars, append it to the previously saved paragraph
37 |             elif para_idx == len(article['paragraphs']) - 1:
38 | if len(to_article['paragraphs']):
39 | previous_context = to_article['paragraphs'][-1]['context']
40 | previous_context = previous_context + " " + context
41 | to_article['paragraphs'][-1]['context'] = previous_context
42 |                 # if no previously saved paragraph exists, save this one as is
43 | else:
44 | to_article['paragraphs'].append({'context': context})
45 |
46 | data['data'].append(to_article)
47 |
48 | if not os.path.exists(args.output_dir):
49 | os.makedirs(args.output_dir)
50 | for start_idx in range(0, len(data['data']), args.docs_per_file):
51 | to_path = os.path.join(args.output_dir, str(int(start_idx / args.docs_per_file)).zfill(4))
52 | cur_data = {'data': data['data'][start_idx:start_idx + args.docs_per_file]}
53 | with open(to_path, 'w') as fp:
54 | json.dump(cur_data, fp)
55 |
56 | def get_args():
57 | parser = argparse.ArgumentParser()
58 | parser.add_argument('--input_dir')
59 | parser.add_argument('--output_dir')
60 | parser.add_argument('--min_num_chars', default=500, type=int)
61 | parser.add_argument('--docs_per_file', default=1000, type=int)
62 |
63 | return parser.parse_args()
64 |
65 |
66 | def main():
67 | args = get_args()
68 | concat_wikisquad(args)
69 |
70 | if __name__ == '__main__':
71 | main()
--------------------------------------------------------------------------------
/scripts/preprocess/create_nq_reader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import pdb
4 | import glob
5 | from nq_utils import load_examples
6 |
7 | def convert_tokens_to_answer(paragraph_tokens, answer_tokens):
8 | answer_token_indexes = []
9 | for answer_token in answer_tokens:
10 | answer_token_index = paragraph_tokens.index(answer_token)
11 | answer_token_indexes.append(answer_token_index)
12 |
13 | if len(answer_token_indexes) != (answer_token_indexes[-1] - answer_token_indexes[0] + 1):
14 | print("answer_token_indexes=",answer_token_indexes)
15 | pdb.set_trace()
16 |
17 |
18 | context = ""
19 | answer_text = ""
20 | answer_start = -1
21 | for i, paragraph_token in enumerate(paragraph_tokens):
22 | # skip html token
23 | if not paragraph_token['html_token']:
24 | token = paragraph_token['token']
25 |
26 | # prepare appending token with white space
27 | if context != "": context +=" "
28 |
29 | # update answer_start
30 | if i == answer_token_indexes[0]:
31 | answer_start = len(context)
32 |
33 | # append token
34 | context += token
35 |
36 | # update answer_end
37 | if i == answer_token_indexes[-1]:
38 | answer_end = len(context)
39 |
40 | answer_text = context[answer_start:answer_end]
41 |
42 | # sanity check
43 | assert context != ""
44 | assert answer_text != ""
45 | assert answer_start != -1
46 |
47 | return context, answer_text, answer_start
48 |
49 | def main(args):
50 | # load nq_open and get ids
51 | with open(args.nq_open_path, 'r') as f:
52 | nq_open_data = json.load(f)['data']
53 | nq_open_ids = [qas['id'] for qas in nq_open_data]
54 |
55 | # load nq_orig
56 | nq_orig_paths = sorted(glob.glob(args.nq_orig_path_pattern))
57 | nq_reader_data = []
58 | for i, nq_orig_path in enumerate(nq_orig_paths):
59 | with open(nq_orig_path, mode='rb') as fileobj:
60 | examples = load_examples(fileobj, 'train', 'short_answers')
61 |
62 | # filter examples contained in nq_open ids
63 | examples = dict(filter(lambda x: int(x[0]) in nq_open_ids, list(examples.items())))
64 |
65 | for example_id, example in examples.items():
66 | # filter candidates with answers
67 | candidates = list(filter(lambda x: x.contains_answer, example.candidates))
68 | if len(candidates) == 0:
69 | continue
70 |
71 | title = example.title
72 | # TODO! consider multi annotation for nq_orig_dev set
73 | short_answers = example.short_answers[0] # assume single annotation
74 | paragraphs=[]
75 |
76 | for candidate in candidates:
77 | # filter examples
78 | contents = candidate.contents
79 |                 is_paragraph = contents.startswith('<P>')
80 | start_token = candidate.start_token
81 | end_token = candidate.end_token
82 | tokens = example.document_tokens[start_token:end_token]
83 |
84 | answers = []
85 | for short_answer in short_answers:
86 | answer_start_token = short_answer['start_token']
87 | answer_end_token = short_answer['end_token']
88 | if answer_end_token-answer_start_token>5:
89 | continue
90 | answer_tokens = example.document_tokens[answer_start_token:answer_end_token]
91 | # convert tokens to context, answer_text, answer_start
92 | context, answer_text, answer_start = convert_tokens_to_answer(tokens, answer_tokens)
93 | answers.append({
94 | 'text': answer_text,
95 | 'answer_start': answer_start
96 | })
97 |
98 | qas = [{
99 | 'question':example.question_text,
100 | 'is_impossible': False if is_paragraph else True,
101 | 'answers':answers,
102 | 'is_distant': False,
103 | 'id':int(example_id),
104 | }]
105 | paragraphs.append({
106 | 'context':context,
107 | 'qas':qas
108 | })
109 | nq_reader_data.append({
110 | 'title': title,
111 | 'paragraphs':paragraphs
112 | })
113 |
114 | nq_reader = {
115 | 'data' : nq_reader_data
116 | }
117 | # save nq_reader
118 | with open(args.output_path,'w') as f:
119 | json.dump(nq_reader, f, indent=2)
120 |
121 | if __name__ == '__main__':
122 | parser = argparse.ArgumentParser()
123 |
124 | # Required parameters
125 | parser.add_argument(
126 | "--nq_open_path",
127 | default=None,
128 | type=str,
129 | required=True,
130 | help="nq-open path (eg. nq-open/dev.json)"
131 | )
132 | parser.add_argument(
133 | "--nq_orig_path_pattern",
134 | default=None,
135 | type=str,
136 | required=True,
137 | help="nq-open path (eg. natural-questions/train/nq-train-*.jsonl.gz)"
138 | )
139 | parser.add_argument(
140 | "--output_path",
141 | default=None,
142 | type=str,
143 | required=True,
144 | help="nq-reader directory (eg. nq-reader/dev.json)"
145 | )
146 |
147 | args = parser.parse_args()
148 |
149 | main(args)
150 |
151 |
--------------------------------------------------------------------------------
/scripts/preprocess/create_nq_reader_doc_wiki.py:
--------------------------------------------------------------------------------
1 | import json
2 | import glob
3 | import argparse
4 | from tqdm import tqdm
5 | import os
6 |
7 | def main(args):
8 | wiki_dir = args.wiki_dir
9 | nq_dir = args.nq_reader_docs_dir
10 | output_dir = args.output_dir
11 |
12 | wiki_file_list = glob.glob(os.path.join(wiki_dir,"*"))
13 | wiki_titles = []
14 | num_wiki = 0
15 | wiki_title2paragraphs = {}
16 | for filename in tqdm(wiki_file_list, total=len(wiki_file_list)):
17 | with open(filename,'r') as f:
18 | data = json.load(f)['data']
19 |
20 | for doc in data:
21 | title = doc['title']
22 | wiki_titles.append(title)
23 | paragraph = doc['paragraphs']
24 | wiki_title2paragraphs[title] = paragraph
25 | num_wiki += 1
26 |
27 | assert len(wiki_title2paragraphs) == num_wiki
28 |
29 | nq_file_list = glob.glob(os.path.join(nq_dir,"*"))
30 | nq_titles = []
31 | unmatched_titles = []
32 | num_matched = 0
33 | num_unmatched = 0
34 | for filename in tqdm(nq_file_list, total=len(nq_file_list)):
35 | with open(filename,'r') as f:
36 | data = json.load(f)['data']
37 |
38 | for doc in data:
39 | title = doc['title']
40 | nq_titles.append(title)
41 | if title in wiki_title2paragraphs:
42 | doc['paragraphs'] = wiki_title2paragraphs[title]
43 | num_matched += 1
44 | else:
45 | unmatched_titles.append(title)
46 | num_unmatched +=1
47 |
48 | new_paragraphs = []
49 | for paragraph in doc['paragraphs']:
50 | if ('is_paragraph' in paragraph) and (not paragraph['is_paragraph']):
51 | continue
52 |
53 | new_paragraphs.append({
54 | 'context': paragraph['context']
55 | })
56 | doc['paragraphs'] = new_paragraphs
57 |
58 | if not os.path.exists(output_dir):
59 | os.mkdir(output_dir)
60 |
61 | output_path = os.path.join(output_dir,os.path.basename(filename))
62 | output = {
63 | 'data': data
64 | }
65 |
66 | with open(output_path, 'w') as f:
67 | json.dump(output, f, indent=2)
68 |
69 | # with open('unmatched_title.txt', 'w') as f:
70 | # for title in unmatched_titles:
71 | # if 'list of' in title:
72 | # continue
73 | # f.writelines(title)
74 | # f.writelines("\n")
75 |
76 | print("num_matched={} num_unmatched={}".format(num_matched, num_unmatched))
77 | print("len(nq_titles)={} len(wiki_titles)={}".format(len(nq_titles), len(wiki_titles)))
78 |
79 | if __name__ == '__main__':
80 | parser = argparse.ArgumentParser()
81 |
82 | # Required parameters
83 | parser.add_argument("--wiki_dir", type=str, required=True)
84 | parser.add_argument("--nq_reader_docs_dir", type=str, required=True)
85 | parser.add_argument("--output_dir", type=str, required=True)
86 |
87 | args = parser.parse_args()
88 |
89 | main(args)
90 |
91 |
--------------------------------------------------------------------------------
/scripts/preprocess/create_nq_reader_wiki.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 | import glob
5 | import copy
6 |
7 | from tqdm import tqdm
8 |
9 |
10 | def nq_to_wiki(input_file, output_dir, wiki_dump):
11 | with open(input_file, 'r') as f:
12 | nq_data = json.load(f)['data']
13 |
14 | para_cnt = 0
15 | match_cnt = 0
16 | title_not_found_cnt = 0
17 | answer_not_found_cnt = 0
18 | tokenize_error = 0
19 | WINDOW = 10
20 | new_data = []
21 | for article in tqdm(nq_data):
22 | title = article['title'] if type(article['title']) != list else article['title'][0]
23 |
24 | assert len(article['paragraphs']) == 1
25 | for paragraph in article['paragraphs']:
26 | para_cnt += 1
27 | new_paragraph = None
28 | answer_found = False
29 | assert len(paragraph['qas']) == 1, 'NQ only has single para for each Q'
30 |
31 | # We skip these cases and use existing paras
32 | qa = paragraph['qas'][0]
33 | if 'redundant' in str(qa['id']):
34 | break
35 |
36 | if qa['is_impossible'] or (title not in wiki_dump):
37 | pass
38 | else:
39 | # Or we find matching answers
40 | answers = qa['answers'] if type(qa['answers']) == list else [qa['answers']]
41 | for answer in answers:
42 | start_window = WINDOW if WINDOW < answer['answer_start'] else answer['answer_start']
43 |
44 | answer_text = paragraph['context'][
45 | answer['answer_start']:answer['answer_start']+len(answer['text'])
46 | ].replace('\'\'', '"').replace('``', '"').replace(' ', '').lower()
47 |
48 | answer_text_with_context = [
49 | paragraph['context'][ # Front/Back 10 chars
50 | answer['answer_start']-start_window:answer['answer_start']+len(answer['text'])+WINDOW
51 | ].replace('\'\'', '"').replace('``', '"').replace(' ', '').lower(),
52 | paragraph['context'][ # Front 10 chars
53 | answer['answer_start']-start_window:answer['answer_start']+len(answer['text'])
54 | ].replace('\'\'', '"').replace('``', '"').replace(' ', '').lower(),
55 | paragraph['context'][ # Back 10 chars
56 | answer['answer_start']:answer['answer_start']+len(answer['text'])+WINDOW
57 | ].replace('\'\'', '"').replace('``', '"').replace(' ', '').lower(),
58 | ]
59 |
60 | new_start = None
61 | wiki_paragraph = None
62 | for wiki_par in wiki_dump[title]:
63 | wiki_par_char = ''.join([char.lower()[0] for char in wiki_par['context'].replace(' ', '')])
64 | nosp_to_sp = {}
65 | for sp_idx, char in enumerate(wiki_par['context']):
66 | if char != ' ':
67 | nosp_to_sp[len(nosp_to_sp)] = sp_idx
68 | assert len(nosp_to_sp) == len(wiki_par_char)
69 |
70 | # Context match
71 | if any([at_with_context in wiki_par_char for at_with_context in answer_text_with_context]):
72 | at_with_context = [at for at in answer_text_with_context if at in wiki_par_char][0]
73 | tmp_start = wiki_par_char.index(at_with_context)
74 | if len([at for at in answer_text_with_context if at in wiki_par_char]) < 3:
75 | if at_with_context == answer_text: # There are some false negatives but we skip
76 | # print(paragraph['context'])
77 | # print(wiki_par['context'])
78 | # print(answer_text)
79 | # import pdb; pdb.set_trace()
80 | break
81 | # try:
82 | new_start = nosp_to_sp[wiki_par_char[tmp_start:].index(answer_text)+tmp_start]
83 | new_end = nosp_to_sp[wiki_par_char[tmp_start:].index(answer_text)+tmp_start+len(answer_text)-1]
84 | wiki_paragraph = copy.deepcopy(wiki_par['context'])
85 | # except ValueError as e:
86 | # print("Could not found start position after de-tokenize")
87 | # tokenize_error += 1
88 | # import pdb; pdb.set_trace()
89 | # continue
90 | answer_found = True
91 | break
92 | # elif answer_text in wiki_par_char:
93 | # answer_found = True
94 |
95 | # If answer is found, append
96 | if new_start is not None:
97 | if answer_text != wiki_par['context'][new_start:new_end+1].lower().replace(' ', ''):
98 | print('mismatch between original vs. new answer: {} vs. {}'.format(
99 | answer_text, wiki_par['context'][new_start:new_end+1].lower().replace(' ', '')
100 | ))
101 |
102 | if new_paragraph is None:
103 | new_paragraph = copy.deepcopy(paragraph)
104 | new_paragraph['context'] = wiki_paragraph
105 | new_paragraph['qas'][0]['answers'] = [{
106 | 'text': wiki_paragraph[new_start:new_end+1],
107 | 'answer_start': new_start,
108 | 'wiki_matched': True,
109 | }]
110 | else:
111 | if new_paragraph['context'] != wiki_paragraph: # If other answers are in different para, we skip
112 | continue
113 | new_paragraph['qas'][0]['answers'].append({
114 | 'text': wiki_paragraph[new_start:new_end+1],
115 | 'answer_start': new_start,
116 | 'wiki_matched': True,
117 | })
118 |
119 | # Just use existing paragraph when no answer is found
120 | if not answer_found:
121 | answer_not_found_cnt += 1
122 | new_paragraph = copy.deepcopy(paragraph)
123 | for qas in new_paragraph['qas']:
124 | for ans in qas['answers']:
125 | ans['wiki_matched'] = False
126 | else:
127 | match_cnt += 1
128 |
129 | assert new_paragraph is not None
130 | new_data.append({
131 | 'title': title,
132 | 'paragraphs': [new_paragraph],
133 | })
134 |
135 |     print(f'total paragraphs: {para_cnt}')
136 | print(f'not found title: {title_not_found_cnt}')
137 | print(f'matched answer: {match_cnt}')
138 | print(f'answer not found: {answer_not_found_cnt}')
139 | print(f'tokenize error: {tokenize_error}')
140 | print(f'total saved data: {len(new_data)}')
141 |
142 | output_path = os.path.join(
143 | os.path.dirname(input_file), os.path.splitext(os.path.basename(input_file))[0] + '_wiki3.json'
144 | )
145 | print(f'Saving into {output_path}')
146 | with open(output_path, 'w') as f:
147 | json.dump({'data': new_data}, f)
148 | print()
149 |
150 |
151 | if __name__ == '__main__':
152 | parser = argparse.ArgumentParser()
153 | parser.add_argument('input_files', type=str, default=None)
154 | parser.add_argument('output_dir', type=str)
155 | parser.add_argument('wiki_dir', type=str, default=None)
156 | args = parser.parse_args()
157 |
158 | # Prepare wiki first
159 | wiki_files = sorted(glob.glob(args.wiki_dir + "*"))
160 | print(f'Matching with {len(wiki_files)} number of wikisquad files')
161 | wiki_dump = {}
162 | for wiki_file in tqdm(wiki_files):
163 | with open(wiki_file, 'r') as f:
164 | wiki_squad = json.load(f)
165 | for wiki_article in wiki_squad['data']:
166 | wiki_dump[wiki_article['title']] = wiki_article['paragraphs']
167 | # break
168 |
169 | for input_file in args.input_files.split(','):
170 | print(f'Processing {input_file}')
171 | nq_to_wiki(input_file, args.output_dir, wiki_dump)
172 |
--------------------------------------------------------------------------------
/scripts/preprocess/create_openqa.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 | import csv
5 |
6 | from tqdm import tqdm
7 | # from drqa.retriever.utils import normalize
8 |
9 | def get_gold_answers_kilt(gold):
10 | ground_truths = set()
11 | for item in gold["output"]:
12 | if "answer" in item and item["answer"] and len(item["answer"].strip()) > 0:
13 | ground_truths.add(item["answer"].strip())
14 | return ground_truths
15 |
16 | def preprocess_openqa(input_file, input_type, out_dir):
17 | data_to_save = []
18 | # SQuAD
19 | if input_type == 'SQuAD':
20 | with open(input_file, 'r') as f:
21 | articles = json.load(f)['data']
22 | for article in articles:
23 | for paragraph in article['paragraphs']:
24 | for qa in paragraph['qas']:
25 | if type(qa['answers']) == dict:
26 | qa['answers'] = [qa['answers']]
27 | data_to_save.append({
28 | 'id': qa['id'],
29 | 'question': qa['question'],
30 | 'answers': [ans['text'] for ans in qa['answers']]
31 | })
32 | # CuratedTrec / WebQuestions / WikiMovies
33 | elif input_type == 'DrQA':
34 | tag = os.path.splitext(os.path.basename(input_file))[0]
35 | for line_idx, line in tqdm(enumerate(open(input_file))):
36 | data = json.loads(line)
37 | # answers = [normalize(a) for a in data['answer']] # necessary?
38 | answers = [a for a in data['answer']]
39 | data_to_save.append({
40 | 'id': f'{tag}_{line_idx}',
41 | 'question': data['question'],
42 | 'answers': answers
43 | })
44 | # NaturalQuestions / TriviaQA
45 | elif input_type == 'HardEM':
46 | tag = os.path.splitext(os.path.basename(input_file))[0]
47 | data = json.load(open(input_file))['data']
48 | for item_idx, item in tqdm(enumerate(data)):
49 | data_to_save.append({
50 | 'id': f'{tag}_{item_idx}',
51 | 'question': item['question'],
52 | 'answers': item['answers']
53 | })
54 | # DPR style files
55 | elif input_type == 'DPR':
56 | tag = os.path.splitext(os.path.basename(input_file))[0]
57 | data = json.load(open(input_file))
58 | for item_idx, item in tqdm(enumerate(data)):
59 | data_to_save.append({
60 | 'id': f'{tag}_{item_idx}',
61 | 'question': item['question'],
62 | 'answers': item['answers']
63 | })
64 | # COVID-19
65 | elif input_type == 'COVID-19':
66 | assert os.path.isdir(input_file)
67 | for filename in os.listdir(input_file):
68 | if 'preprocessed' in filename:
69 | print(f'Skipping {filename}')
70 | continue
71 | file_path = os.path.join(input_file, filename)
72 | tag = os.path.splitext(os.path.basename(file_path))[0]
73 | with open(file_path, 'r') as f:
74 | with tqdm(enumerate(f)) as tq:
75 | tq.set_description(filename + '\t')
76 | for line_idx, line in tq:
77 | data_to_save.append({
78 | 'id': f'{tag}_{line_idx}',
79 | 'question': line.strip(),
80 | 'answers': ['']
81 | })
82 | # TREX, ZSRE (KILT)
83 | elif input_type.lower() in ['trex', 't-rex', 'zsre']:
84 | with open(input_file) as f:
85 | for line in tqdm(f):
86 | data = json.loads(line)
87 | id = data['id']
88 | question = data['input']
89 | answers = get_gold_answers_kilt(data)
90 | answers = list(answers)
91 |
92 | data_to_save.append({
93 | 'id': id,
94 | 'question': question,
95 | 'answers': answers
96 | })
97 | # Jsonl (LAMA)
98 | elif input_type.lower() in ['jsonl']:
99 | tag = os.path.splitext(os.path.basename(input_file))[0]
100 | with open(input_file) as f:
101 | for line_idx, line in tqdm(enumerate(f)):
102 | data = json.loads(line)
103 | question = data['question']
104 | answers = data['answer']
105 |
106 | data_to_save.append({
107 | 'id': f'{tag}_{line_idx}',
108 | 'question': question,
109 | 'answers': answers
110 | })
111 | # CSV
112 | elif input_type.lower() in ['csv']:
113 | import ast
114 | tag = os.path.splitext(os.path.basename(input_file))[0]
115 | with open(input_file) as f:
116 | csv_reader = csv.reader(f, delimiter='\t')
117 | for line_idx, line in tqdm(enumerate(csv_reader)):
118 | question = line[0]
119 | answers = ast.literal_eval(line[1])
120 |
121 | data_to_save.append({
122 | 'id': f'{tag}_{line_idx}',
123 | 'question': question,
124 | 'answers': answers
125 | })
126 | else:
127 | raise NotImplementedError
128 |
129 | assert os.path.exists(out_dir)
130 | out_path = os.path.join(out_dir, os.path.splitext(os.path.basename(input_file))[0] + '_preprocessed.json')
131 | print(f'Saving {len(data_to_save)} questions.')
132 | print('Writing to %s\n'% out_path)
133 | with open(out_path, 'w') as f:
134 | json.dump({'data': data_to_save}, f)
135 |
136 |
137 | if __name__ == '__main__':
138 | parser = argparse.ArgumentParser()
139 | parser.add_argument('input_file', type=str, default=None)
140 | parser.add_argument('out_dir', type=str)
141 |     parser.add_argument('--input_type', type=str, default='SQuAD', help='SQuAD|DrQA|HardEM|DPR|COVID-19|trex|zsre|jsonl|csv')
142 | args = parser.parse_args()
143 | preprocess_openqa(args.input_file, args.input_type, args.out_dir)
144 |
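145 | # Output format (sketch): the <input_basename>_preprocessed.json written above holds
146 | # {"data": [{"id": ..., "question": ..., "answers": ["...", ...]}, ...]},
147 | # i.e. the open-QA format that merge_openqa.py consumes.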
--------------------------------------------------------------------------------
/scripts/preprocess/create_psg_hdf5.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 | import h5py
5 | import csv
6 |
7 | from tqdm import tqdm
8 |
9 |
10 | def create_psg_hdf5(input_file, out_file):
11 | passages = {}
12 | with open(input_file) as f:
13 | psg_file = csv.reader(f, delimiter='\t')
14 | for data_idx, data in tqdm(enumerate(psg_file)):
15 | if data_idx == 0:
16 | print('Reading', data)
17 | continue
18 | id_, psg, title = data
19 | passages[id_] = [psg, title]
20 | # break
21 |
22 |     # Must use buckets; otherwise writing to an hdf5 file is very slow with a large number of keys
23 | bucket_size = 1000000
24 | # buckets = [(start, min(start+bucket_size-1, 21015324)) for start in range(1, 21015325, bucket_size)]
25 | buckets = [(start, min(start+bucket_size-1, len(passages))) for start in range(1, len(passages)+1, bucket_size)]
26 | print(f'Putting {len(passages)} passages into {len(buckets)} buckets')
27 | print(buckets)
28 | with h5py.File(out_file, 'w') as f:
29 | for pid, data in tqdm(passages.items()):
30 | bucket_name = None
31 | for start, end in buckets:
32 | if (int(pid) >= start) and (int(pid) <= end):
33 | bucket_name = f'{start}-{end}'
34 | break
35 | assert bucket_name is not None
36 | # continue
37 |
38 | if bucket_name not in f:
39 | dg = f.create_group(bucket_name)
40 | else:
41 | dg = f[bucket_name]
42 | assert pid not in dg
43 | pg = dg.create_group(pid)
44 | pg.attrs['context'], pg.attrs['title'] = data
45 |
46 | print(f'Saving {out_file} done')
47 |
48 |
49 | if __name__ == '__main__':
50 | parser = argparse.ArgumentParser()
51 | parser.add_argument('input_file', type=str, default=None)
52 | parser.add_argument('out_file', type=str)
53 | args = parser.parse_args()
54 | create_psg_hdf5(args.input_file, args.out_file)
55 |
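56 | # Passages are stored under buckets named '<start>-<end>' (1-based ids, bucket_size as above),
57 | # each holding per-passage groups with 'context' and 'title' attributes. A minimal look-up
58 | # sketch for a single passage id (file name and id are examples):
59 | #
60 | #   import h5py
61 | #
62 | #   pid = '12345'
63 | #   with h5py.File('psgs_w100.hdf5', 'r') as f:
64 | #       for bucket_name, bucket in f.items():
65 | #           start, end = map(int, bucket_name.split('-'))
66 | #           if start <= int(pid) <= end:
67 | #               print(bucket[pid].attrs['title'], bucket[pid].attrs['context'])
68 | #               break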
--------------------------------------------------------------------------------
/scripts/preprocess/create_tqa_ds.py:
--------------------------------------------------------------------------------
1 |
2 | import json
3 | import pdb
4 | import re
5 | import random
6 | from tqdm import tqdm
7 | import string
8 | import argparse
9 |
10 | try:
11 | from eval_utils import (
12 | drqa_exact_match_score,
13 | drqa_regex_match_score,
14 | drqa_metric_max_over_ground_truths
15 | )
16 | except ModuleNotFoundError:
17 | import sys
18 | import os
19 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
20 | from eval_utils import (
21 | drqa_exact_match_score,
22 | drqa_regex_match_score,
23 | drqa_metric_max_over_ground_truths
24 | )
25 |
26 | # fix random seed
27 | random.seed(0)
28 |
29 | def find_substring_and_return_random_idx(substring, string):
30 | substring_idxs = [m.start() for m in re.finditer(re.escape(substring), string)]
31 | substring_idx = random.choice(substring_idxs)
32 | return substring_idx
33 |
34 | def main(args):
35 | print("loading input data")
36 | with open(args.input_path, encoding='utf-8') as f:
37 | data = json.load(f)
38 |
39 | output_data = []
40 |
41 | for sample_id in tqdm(data):
42 | sample = data[sample_id]
43 |
44 | question = sample['question']
45 | answers = sample['answer']
46 | predictions = sample['prediction']
47 | titles = sample['title']
48 | evidences = sample['evidence']
49 |
50 | match_fn = drqa_regex_match_score if args.regex else drqa_exact_match_score
51 |
52 | answer_text = ""
53 | answer_start = -1
54 | ds_context = ""
55 | ds_title = ""
56 | # is_from_context = False
57 |
58 |         # check whether the prediction matches a gold answer in the answer list
59 | for pred_idx, pred in enumerate(predictions):
60 | if pred != "" and drqa_metric_max_over_ground_truths(match_fn, pred, answers):
61 | answer_text = pred
62 | ds_context = evidences[pred_idx]
63 | ds_title = titles[pred_idx][0]
64 | answer_start = find_substring_and_return_random_idx(answer_text, ds_context)
65 | break
66 |
67 |         # NOTE: the lines below are disabled because is_from_context introduces too much noise
68 | # # in case prediction is not matched to any golden answer,
69 | # # check if golden answer is contained in the context
70 | # if answer_start < 0:
71 | # found = False
72 | # for evid_idx, evid in enumerate(evidences):
73 | # for ans in answers:
74 | # if ans != "" and ans in evid:
75 | # found = True
76 | # answer_text = ans
77 | # answer_start = find_substring_and_return_random_idx(ans, evid)
78 | # ds_context = evidences[evid_idx]
79 | # ds_title = titles[evid_idx][0]
80 | # is_from_context = True
81 | # if found:
82 | # break
83 |
84 |         # no answer is found in any prediction
85 | is_impossible = False
86 | if answer_start < 0 or answer_text == "":
87 | ds_title = titles[0][0]
88 | ds_context = evidences[0]
89 | is_impossible = True
90 | else:
91 | assert answer_text == ds_context[answer_start:answer_start+len(answer_text)]
92 |
93 | output_data.append({
94 | 'title': ds_title,
95 | 'paragraphs':[{
96 | 'context': ds_context,
97 | 'qas':[{
98 | 'question': question,
99 | 'is_impossible' : is_impossible,
100 | 'answers': [{
101 | 'text': answer_text,
102 | 'answer_start': answer_start
103 | }] if is_impossible == False else [],
104 | # 'is_from_context':is_from_context
105 | }],
106 | 'id': sample_id
107 | }]
108 | })
109 |
110 | with open(args.output_path, 'w', encoding='utf-8') as f:
111 | json.dump({
112 | 'data': output_data
113 | },f)
114 |
115 |
116 | # ------------------------------------------------------------------------------
117 | # Main.
118 | # ------------------------------------------------------------------------------
119 |
120 |
121 | if __name__ == '__main__':
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument('input_path', type=str, default='/home/pred/sbcd_sqdqgnqqg_inb64_s384_sqdnq_pinb2_0_20181220_concat_train_preprocessed_78785.pred')
124 | parser.add_argument('output_path', type=str, default='tqa_ds_train.json')
125 | parser.add_argument('--regex', action='store_true')
126 | args = parser.parse_args()
127 |
128 | main(args)
129 |
--------------------------------------------------------------------------------
/scripts/preprocess/doc_db.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Documents, in a sqlite database."""
8 |
9 | import sqlite3
10 | import unicodedata
11 |
12 | def normalize(text):
13 | """Resolve different type of unicode encodings."""
14 | return unicodedata.normalize('NFD', text)
15 |
16 | class DocDB(object):
17 | """Sqlite backed document storage.
18 | Implements get_doc_text(doc_id).
19 | """
20 |
21 | def __init__(self, db_path=None):
22 | # self.path = db_path or DEFAULTS['db_path']
23 | self.path = db_path
24 | self.connection = sqlite3.connect(self.path, check_same_thread=False)
25 |
26 | def __enter__(self):
27 | return self
28 |
29 | def __exit__(self, *args):
30 | self.close()
31 |
32 | def path(self):
33 | """Return the path to the file that backs this database."""
34 | return self.path
35 |
36 | def close(self):
37 | """Close the connection to the database."""
38 | self.connection.close()
39 |
40 | def get_doc_ids(self):
41 | """Fetch all ids of docs stored in the db."""
42 | cursor = self.connection.cursor()
43 | cursor.execute("SELECT id FROM documents")
44 | results = [r[0] for r in cursor.fetchall()]
45 | cursor.close()
46 | return results
47 |
48 | def get_doc_text(self, doc_id):
49 | """Fetch the raw text of the doc for 'doc_id'."""
50 | cursor = self.connection.cursor()
51 | cursor.execute(
52 | "SELECT text FROM documents WHERE id = ?",
53 | (normalize(doc_id),)
54 | )
55 | result = cursor.fetchone()
56 | cursor.close()
57 | return result if result is None else result[0]
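58 |
59 | # Example usage (sketch; the db path is the one produced by build_db.py in scripts/preprocess/README.md):
60 | #
61 | #   with DocDB('/hdd1/data/wikidump/docs_20181220.db') as db:
62 | #       doc_ids = db.get_doc_ids()
63 | #       text = db.get_doc_text(doc_ids[0])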
--------------------------------------------------------------------------------
/scripts/preprocess/download_wikidump.py:
--------------------------------------------------------------------------------
1 | """
2 | Download the 20181220 Wikipedia dump and verify its md5sum.
3 | """
4 |
5 | import os
6 | import json
7 | import urllib.request
8 | import urllib.parse as urlparse
9 | import argparse
10 | import hashlib
11 | import logging
12 | import portalocker
13 | import pdb
14 | from tqdm import tqdm
15 |
16 | def parse_args():
17 | """
18 | Parse input arguments
19 | """
20 | parser = argparse.ArgumentParser()
21 |
22 | # Required
23 | parser.add_argument('--output_dir', required=True)
24 |
25 | args = parser.parse_args()
26 | return args
27 |
28 | def download_file(url, output_dir, size, expected_md5sum=None):
29 | """
30 | download file and check md5sum
31 | """
32 | logging.info("url={}".format(url))
33 |
34 | if not os.path.exists(output_dir):
35 | os.mkdir(output_dir)
36 | bz2file = os.path.join(output_dir, os.path.basename(url))
37 |
38 | lockfile = '{}.lock'.format(bz2file)
39 | with portalocker.Lock(lockfile, 'w', timeout=60):
40 | if not os.path.exists(bz2file) or os.path.getsize(bz2file) != size:
41 | logging.info("Downloading {}".format(bz2file))
42 | with urllib.request.urlopen(url) as f:
43 | with open(bz2file, 'wb') as out:
44 | for data in tqdm(f, unit='KB'):
45 | out.write(data)
46 |
47 | # Check md5sum
48 | if expected_md5sum is not None:
49 | md5 = hashlib.md5()
50 | with open(bz2file, 'rb') as infile:
51 | for line in infile:
52 | md5.update(line)
53 | if md5.hexdigest() != expected_md5sum:
54 |                 logging.error('Fatal: MD5 sum of downloaded file was incorrect (got {}, expected {}).'.format(md5.hexdigest(), expected_md5sum))
55 |                 logging.error('Please manually delete "{}" and rerun the command.'.format(bz2file))
56 |                 logging.error('If the problem persists, the dump hosted at the URL may have changed; please verify it manually.')
57 |                 raise SystemExit(1)
58 | else:
59 | logging.info('Checksum passed: {}'.format(md5.hexdigest()))
60 |
61 | def main(args):
62 | url = 'https://archive.org/download/enwiki-20181220/enwiki-20181220-pages-articles.xml.bz2'
63 | expected_md5sum = 'ccf875b2af67109fe5b98b5b720ce322'
64 | size = 15712882238
65 |
66 | download_file(
67 | url=url,
68 | output_dir=args.output_dir,
69 | size=size,
70 | expected_md5sum=expected_md5sum
71 | )
72 |
73 | if __name__ == '__main__':
74 | logging.basicConfig(
75 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
76 | datefmt="%m/%d/%Y %H:%M:%S",
77 | level=logging.INFO
78 | )
79 | args = parse_args()
80 | main(args)
--------------------------------------------------------------------------------
/scripts/preprocess/filter_noans.py:
--------------------------------------------------------------------------------
1 | import spacy
2 | import json
3 | import random
4 | import numpy as np
5 | from tqdm import tqdm
6 | from squad_metrics import compute_exact
7 | nlp = spacy.load("en_core_web_sm")
8 |
9 | doc = nlp('European authorities fined Google a record $5.1 billion on Wednesday for abusing its power in the mobile phone market and ordered the company to alter its practices')
10 | print([(X.text, X.label_) for X in doc.ents])
11 |
12 |
13 | data_path = '/home/data/nq-reader/dev_wiki3.json'
14 | sample = False
15 | print(f'reading {data_path} with sampling: {sample}')
16 | train_set = json.load(open(data_path))
17 | new_train_set = {'data': []}
18 | cnt = 0
19 | new_cnt = 0
20 | filtered_cnt = 0
21 |
22 | for article in tqdm(train_set['data']):
23 | new_article = {
24 | 'title': article['title'],
25 | 'paragraphs': []
26 | }
27 | for p_idx, paragraph in enumerate(article['paragraphs']):
28 | new_paragraph = {
29 | 'context': paragraph['context'],
30 | 'qas' : [],
31 | }
32 |
33 | for qa in paragraph['qas']:
34 | question = qa['question']
35 | id_ = qa['id']
36 | assert type(qa["answers"]) == dict or type(qa["answers"]) == list, type(qa["answers"])
37 | if type(qa["answers"]) == dict:
38 | qa["answers"] = [qa["answers"]]
39 | cnt += 1
40 | if len(qa["answers"]) == 0:
41 | filtered_cnt += 1
42 | continue
43 |
44 | new_paragraph['qas'].append(qa)
45 | new_cnt += 1
46 | new_article['paragraphs'].append(new_paragraph)
47 |
48 | new_train_set['data'].append(new_article)
49 | # break
50 |
51 | write_path = data_path.replace('.json', '_na_filtered.json')
52 | with open(write_path, 'w') as f:
53 | json.dump(new_train_set, f)
54 |
55 | assert filtered_cnt + new_cnt == cnt
56 | print(f'writing to {write_path} with {new_cnt} samples')
57 | print(f'all sample: {cnt}, new sample: {new_cnt}')
58 |
--------------------------------------------------------------------------------
/scripts/preprocess/filter_wiki.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import argparse
4 |
5 | from tqdm import tqdm
6 |
7 |
8 | def filter_wiki(args):
9 | if not os.path.exists(args.to_dir):
10 | os.makedirs(args.to_dir)
11 |
12 | names = os.listdir(args.from_dir)
13 | from_paths = [os.path.join(args.from_dir, name) for name in names]
14 | to_paths = [os.path.join(args.to_dir, name) for name in names]
15 |
16 | for from_path, to_path in zip(tqdm(from_paths), to_paths):
17 | with open(from_path, 'r') as fp:
18 | from_ = json.load(fp)
19 | to = {'data': []}
20 | for article in from_['data']:
21 | to_article = {'paragraphs': [], 'title': article['title']}
22 | for para in article['paragraphs']:
23 | if args.min_num_chars <= len(para['context']) < args.max_num_chars:
24 | to_article['paragraphs'].append(para)
25 | to['data'].append(to_article)
26 |
27 | with open(to_path, 'w') as fp:
28 | json.dump(to, fp)
29 |
30 |
31 | def get_args():
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument('from_dir')
34 | parser.add_argument('to_dir')
35 | parser.add_argument('--min_num_chars', default=250, type=int)
36 | parser.add_argument('--max_num_chars', default=2500, type=int)
37 | return parser.parse_args()
38 |
39 |
40 | def main():
41 | args = get_args()
42 | filter_wiki(args)
43 |
44 |
45 | if __name__ == '__main__':
46 | main()
47 |
--------------------------------------------------------------------------------
/scripts/preprocess/merge_openqa.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 |
5 | from tqdm import tqdm
6 |
7 |
8 | def merge_openqa(input_dir, output_path):
9 |
10 | paths = [
11 | 'open-qa/nq-open/train_preprocessed.json',
12 | 'open-qa/webq/WebQuestions-train-nodev_preprocessed.json',
13 | 'open-qa/trec/CuratedTrec-train-nodev_preprocessed.json',
14 | 'open-qa/triviaqa-unfiltered/train_preprocessed.json',
15 | 'open-qa/squad/train_preprocessed.json',
16 | 'kilt/trex/trex-train-kilt_open_10000.json',
17 | 'kilt/zsre/structured_zeroshot-train-kilt_open_10000.json',
18 | ]
19 | paths = [os.path.join(input_dir, path) for path in paths]
20 | assert all([os.path.exists(path) for path in paths])
21 |
22 | data_to_save = []
23 | sep_cnt = 0
24 | for path in paths:
25 | with open(path) as f:
26 | data = json.load(f)['data']
27 | for item in data:
28 | if ' [SEP] ' in item['question']:
29 | item['question'] = item['question'].replace(' [SEP] ', ' ')
30 | sep_cnt += 1
31 | data_to_save += data
32 | print(f'{path} has {len(data)} QA pairs')
33 |
34 |     print(f'Saving {len(data_to_save)} questions to {output_path}')
35 | print(f'Removed [SEP] for {sep_cnt} questions')
36 | print('Writing to %s\n'% output_path)
37 | with open(output_path, 'w') as f:
38 | json.dump({'data': data_to_save}, f)
39 |
40 |
41 | if __name__ == '__main__':
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument('input_dir', type=str, default=None)
44 | parser.add_argument('output_path', type=str)
45 | args = parser.parse_args()
46 | merge_openqa(args.input_dir, args.output_path)
47 |
--------------------------------------------------------------------------------
/scripts/preprocess/merge_paq.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 | import h5py
5 | import csv
6 |
7 | from tqdm import tqdm
8 |
9 |
10 | def merge_paq(input_dir, out_file):
11 | num_split = 8
12 | filenames = [f'PAQ.metadata.hard0-{k}.jsonl' for k in range(num_split)]
13 | print('reading', filenames)
14 | fps = [open(os.path.join(input_dir, filename), 'r') for filename in filenames]
15 |
16 | with open(out_file, 'w') as fw:
17 | fp_idx = 0
18 | total_cnt = 0
19 | hard_cnt = 0
20 | line = fps[fp_idx].readline()
21 | while line:
22 | # for stats
23 | meta = json.loads(line)
24 | if len(meta['hard_neg_pids']) > 0:
25 | hard_cnt += 1
26 | total_cnt += 1
27 |
28 | if total_cnt % 100000 == 0:
29 | print(f'Total: {total_cnt}, Hard neg: {hard_cnt}')
30 |
31 | # write it
32 | json.dump(meta, fw, separators=(',', ':'))
33 | fw.write('\n')
34 | fp_idx = (fp_idx + 1) % num_split
35 | line = fps[fp_idx].readline()
36 |
37 | print(f'Total: {total_cnt}, Hard neg: {hard_cnt}')
38 | print(f'Saving {out_file} done')
39 |
40 |
41 | if __name__ == '__main__':
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument('input_dir', type=str, default=None)
44 | parser.add_argument('out_file', type=str)
45 | args = parser.parse_args()
46 | merge_paq(args.input_dir, args.out_file)
47 |
--------------------------------------------------------------------------------
/scripts/preprocess/merge_singleqa.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 |
5 | from tqdm import tqdm
6 |
7 |
8 | def merge_single(input_dir, output_path):
9 |
10 | paths = [
11 | 'single-qa/nq/train_wiki3.json',
12 | 'single-qa/webq/webq-train_ds.json',
13 | 'single-qa/trec/trec-train_ds.json',
14 | 'single-qa/tqa/tqa-train_ds.json',
15 | # 'single-qa/squad/train-v1.1.json',
16 | ]
17 | paths = [os.path.join(input_dir, path) for path in paths]
18 | assert all([os.path.exists(path) for path in paths])
19 |
20 | data_to_save = []
21 | sep_cnt = 0
22 | for path in paths:
23 | with open(path) as f:
24 | data = json.load(f)['data']
25 | data_to_save += data
26 | print(f'{path} has {len(data)} PQA triples')
27 |
28 |     print(f'Saving {len(data_to_save)} RC triples to {output_path}')
29 | print('Writing to %s\n'% output_path)
30 | with open(output_path, 'w') as f:
31 | json.dump({'data': data_to_save}, f)
32 |
33 |
34 | if __name__ == '__main__':
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument('input_dir', type=str, default=None)
37 | parser.add_argument('output_path', type=str)
38 | args = parser.parse_args()
39 | merge_single(args.input_dir, args.output_path)
40 |
--------------------------------------------------------------------------------
/scripts/preprocess/nq_utils.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import json
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 | class LongAnswerCandidate(object):
7 | """Representation of long answer candidate."""
8 |
9 | def __init__(self, contents, index, is_answer, contains_answer, start_token, end_token):
10 | self.contents = contents
11 | self.index = index
12 | self.is_answer = is_answer
13 | self.contains_answer = contains_answer
14 | self.start_token = start_token
15 | self.end_token = end_token
16 | if is_answer:
17 | self.style = 'is_answer'
18 | elif contains_answer:
19 | self.style = 'contains_answer'
20 | else:
21 | self.style = 'not_answer'
22 |
23 |
24 | class Example(object):
25 | """Example representation."""
26 |
27 | def __init__(self, json_example, dataset):
28 | self.json_example = json_example
29 |
30 | # Whole example info.
31 | self.url = json_example['document_url']
32 | self.title = (
33 | json_example['document_title']
34 | if 'document_title' in json_example else 'Wikipedia')
35 | # self.example_id = base64.urlsafe_b64encode(
36 | # str(self.json_example['example_id']))
37 | self.example_id = str(self.json_example['example_id'])
38 | self.document_html = self.json_example['document_html'].encode('utf-8')
39 | self.document_tokens = self.json_example['document_tokens']
40 | self.question_text = json_example['question_text']
41 |
42 | if dataset == 'train':
43 | if len(json_example['annotations']) != 1:
44 | raise ValueError(
45 | 'Train set json_examples should have a single annotation.')
46 | annotation = json_example['annotations'][0]
47 | self.has_long_answer = annotation['long_answer']['start_byte'] >= 0
48 | self.has_short_answer = annotation[
49 | 'short_answers'] or annotation['yes_no_answer'] != 'NONE'
50 |
51 | elif dataset == 'dev':
52 | if len(json_example['annotations']) != 5:
53 | raise ValueError('Dev set json_examples should have five annotations.')
54 | self.has_long_answer = sum([
55 | annotation['long_answer']['start_byte'] >= 0
56 | for annotation in json_example['annotations']
57 | ]) >= 2
58 | self.has_short_answer = sum([
59 | bool(annotation['short_answers']) or
60 | annotation['yes_no_answer'] != 'NONE'
61 | for annotation in json_example['annotations']
62 | ]) >= 2
63 |
64 | self.long_answers = [
65 | a['long_answer']
66 | for a in json_example['annotations']
67 | if a['long_answer']['start_byte'] >= 0 and self.has_long_answer
68 | ]
69 | self.short_answers = [
70 | a['short_answers']
71 | for a in json_example['annotations']
72 | if a['short_answers'] and self.has_short_answer
73 | ]
74 | self.yes_no_answers = [
75 | a['yes_no_answer']
76 | for a in json_example['annotations']
77 | if a['yes_no_answer'] != 'NONE' and self.has_short_answer
78 | ]
79 |
80 | if self.has_long_answer:
81 | long_answer_bounds = [
82 | (la['start_byte'], la['end_byte']) for la in self.long_answers
83 | ]
84 | long_answer_counts = [
85 | long_answer_bounds.count(la) for la in long_answer_bounds
86 | ]
87 | long_answer = self.long_answers[np.argmax(long_answer_counts)]
88 | self.long_answer_text = self.render_long_answer(long_answer)
89 |
90 | else:
91 | self.long_answer_text = ''
92 |
93 | if self.has_short_answer:
94 | short_answers_ids = [[
95 | (s['start_byte'], s['end_byte']) for s in a
96 | ] for a in self.short_answers] + [a for a in self.yes_no_answers]
97 | short_answers_counts = [
98 | short_answers_ids.count(a) for a in short_answers_ids
99 | ]
100 |
101 | self.short_answers_texts = [
102 | b', '.join([
103 | self.render_span(s['start_byte'], s['end_byte'])
104 | for s in short_answer
105 | ])
106 | for short_answer in self.short_answers
107 | ]
108 |
109 | self.short_answers_texts += self.yes_no_answers
110 | self.short_answers_text = self.short_answers_texts[np.argmax(
111 | short_answers_counts)]
112 | self.short_answers_texts = set(self.short_answers_texts)
113 |
114 | else:
115 | self.short_answers_texts = []
116 | self.short_answers_text = ''
117 |
118 | self.candidates = self.get_candidates(
119 | self.json_example['long_answer_candidates'])
120 |
121 | self.candidates_with_answer = [
122 | i for i, c in enumerate(self.candidates) if c.contains_answer
123 | ]
124 |
125 | def render_long_answer(self, long_answer):
126 | """Wrap table rows and list items, and render the long answer.
127 |
128 | Args:
129 | long_answer: Long answer dictionary.
130 |
131 | Returns:
132 | String representation of the long answer span.
133 | """
134 |
135 | if long_answer['end_token'] - long_answer['start_token'] > 500:
136 | return 'Large long answer'
137 |
138 | html_tag = self.document_tokens[long_answer['end_token'] - 1]['token']
139 |         if html_tag == '</Table>' and self.render_span(
140 |             long_answer['start_byte'], long_answer['end_byte']).count(b'<TR>') > 30:
141 |             return 'Large table long answer'
142 |
143 |         elif html_tag == '</Tr>':
144 |             return '<table>{}</table>'.format(
145 |                 self.render_span(long_answer['start_byte'], long_answer['end_byte']))
146 |
147 |         elif html_tag in ['</Li>', '</Dd>', '</Dt>']:
148 |             return '<ul>{}</ul>'.format(
149 |                 self.render_span(long_answer['start_byte'], long_answer['end_byte']))
150 |
151 | else:
152 | return self.render_span(long_answer['start_byte'],
153 | long_answer['end_byte'])
154 |
155 | def render_span(self, start, end):
156 | return self.document_html[start:end]
157 |
158 | def get_candidates(self, json_candidates):
159 | """Returns a list of `LongAnswerCandidate` objects for top level candidates.
160 |
161 | Args:
162 | json_candidates: List of Json records representing candidates.
163 |
164 | Returns:
165 | List of `LongAnswerCandidate` objects.
166 | """
167 | candidates = []
168 | top_level_candidates = [c for c in json_candidates if c['top_level']]
169 | for candidate in top_level_candidates:
170 | tokenized_contents = ' '.join([
171 | t['token'] for t in self.json_example['document_tokens']
172 | [candidate['start_token']:candidate['end_token']]
173 | ])
174 |
175 | start = candidate['start_byte']
176 | end = candidate['end_byte']
177 | start_token = candidate['start_token']
178 | end_token = candidate['end_token']
179 | is_answer = self.has_long_answer and np.any(
180 | [(start == ans['start_byte']) and (end == ans['end_byte'])
181 | for ans in self.long_answers])
182 | contains_answer = self.has_long_answer and np.any(
183 | [(start <= ans['start_byte']) and (end >= ans['end_byte'])
184 | for ans in self.long_answers])
185 |
186 | candidates.append(
187 | LongAnswerCandidate(tokenized_contents, len(candidates), is_answer,
188 | contains_answer, start_token, end_token))
189 |
190 | return candidates
191 |
192 | def has_long_answer(json_example):
193 | for annotation in json_example['annotations']:
194 | if annotation['long_answer']['start_byte'] >= 0:
195 | return True
196 | return False
197 |
198 |
199 | def has_short_answer(json_example):
200 | for annotation in json_example['annotations']:
201 | if annotation['short_answers']:
202 | return True
203 | return False
204 |
205 | def load_examples(fileobj, dataset, mode):
206 | """Reads jsonlines containing NQ examples.
207 |
208 | Args:
209 | fileobj: File object containing gzipped NQ jsonlines examples. `dataset` is passed through to each `Example`; `mode` ('long_answers' or 'short_answers') keeps only examples with that answer type.
210 |
211 | Returns:
212 | Dictionary mapping example id to `Example` object.
213 | """
214 |
215 | def _load(examples, f):
216 | """Read serialized json from `f`, create examples, and add to `examples`."""
217 |
218 | for l in tqdm(f):
219 | json_example = json.loads(l)
220 | if mode == 'long_answers' and not has_long_answer(json_example):
221 | continue
222 |
223 | elif mode == 'short_answers' and not has_short_answer(json_example):
224 | continue
225 |
226 | example = Example(json_example, dataset)
227 | examples[example.example_id] = example
228 |
229 | examples = {}
230 | _load(examples, gzip.GzipFile(fileobj=fileobj))
231 |
232 | return examples
233 |
--------------------------------------------------------------------------------
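A minimal usage sketch for `load_examples` above. The file name is a placeholder; the function expects a file object wrapping a gzipped NQ jsonlines shard, and `dataset` is simply forwarded to each `Example`:

    # Sketch: keep only NQ dev examples that have a short answer.
    with open('v1.0-simplified-nq-dev-all.jsonl.gz', 'rb') as fileobj:  # placeholder path
        examples = load_examples(fileobj, dataset='dev', mode='short_answers')
    some_id, example = next(iter(examples.items()))
    print(some_id, example.short_answers_text)  # most frequent short answer for this example
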
/scripts/preprocess/prep_wikipedia.py:
--------------------------------------------------------------------------------
1 | # https://github.com/facebookresearch/DrQA/blob/master/scripts/retriever/prep_wikipedia.py
2 | # #!/usr/bin/env python3
3 | # Copyright 2017-present, Facebook, Inc.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree
8 | """Preprocess function to filter/prepare Wikipedia docs."""
9 |
10 | import regex as re
11 | from html.parser import HTMLParser
12 |
13 | PARSER = HTMLParser()
14 | BLACKLIST = set(['23443579', '52643645']) # Conflicting disambig. pages
15 |
16 |
17 | def preprocess(article):
18 | # Take out HTML escaping WikiExtractor didn't clean
19 | for k, v in article.items():
20 | article[k] = PARSER.unescape(v)
21 |
22 | # Filter some disambiguation pages not caught by the WikiExtractor
23 | if article['id'] in BLACKLIST:
24 | return None
25 | if '(disambiguation)' in article['title'].lower():
26 | return None
27 | if '(disambiguation page)' in article['title'].lower():
28 | return None
29 |
30 | # Take out List/Index/Outline pages (mostly links)
31 | if re.match(r'(List of .+)|(Index of .+)|(Outline of .+)',
32 | article['title']):
33 | return None
34 |
35 | # Return doc with `id` set to `title`
36 | return {'id': article['title'], 'text': article['text']}
--------------------------------------------------------------------------------
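A quick sketch of how `preprocess` behaves on WikiExtractor-style article dicts (the inputs below are made up). One portability note: `HTMLParser.unescape` was removed in Python 3.9, so `html.unescape` is the drop-in replacement on newer interpreters:

    # Regular articles are kept, with the title promoted to `id`.
    doc = preprocess({'id': '42', 'title': 'Dense retrieval', 'text': 'Dense retrieval is ...'})
    assert doc == {'id': 'Dense retrieval', 'text': 'Dense retrieval is ...'}

    # Disambiguation and "List of ..." pages are filtered out.
    assert preprocess({'id': '7', 'title': 'Mercury (disambiguation)', 'text': '...'}) is None
    assert preprocess({'id': '8', 'title': 'List of rivers', 'text': '...'}) is None
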
/scripts/preprocess/sample_nq_reader_doc_wiki.py:
--------------------------------------------------------------------------------
1 | import json
2 | import glob
3 | import pdb
4 | import argparse
5 | from tqdm import tqdm
6 | import os
7 | import random
8 | import time
9 |
10 | def main(args):
11 | sampling_ratio = args.sampling_ratio
12 | wiki_dir = args.wiki_dir
13 | docs_wiki_dir = args.docs_wiki_dir
14 | output_dir = args.output_dir
15 |
16 | # count the number of total words in wikidump
17 | wiki_file_list = glob.glob(os.path.join(wiki_dir,"*"))
18 | # num_words_in_wiki = 0
19 | # for filename in tqdm(wiki_file_list, total=len(wiki_file_list)):
20 | # with open(filename,'r') as f:
21 | # data = json.load(f)['data']
22 |
23 | # for doc in data:
24 | # for paragraph in doc['paragraphs']:
25 | # context = paragraph['context']
26 | # num_words_in_wiki += len(context.split(" "))
27 |
28 | # print(num_words_in_wiki)
29 |
30 | num_words_in_wiki = 2054581517
31 | num_sample_words = int(num_words_in_wiki * sampling_ratio)
32 |
33 | print("num_words_in_wiki={}".format(num_words_in_wiki))
34 |
35 | # count the number of total words in docs_wiki
36 | docs_wiki_file_list = sorted(glob.glob(os.path.join(docs_wiki_dir,"*")))
37 | num_words_in_docs_wiki = 0
38 | docs_wiki_titles = {}
39 | docs_wikis = []
40 | for filename in tqdm(docs_wiki_file_list, total=len(docs_wiki_file_list)):
41 | with open(filename,'r') as f:
42 | data = json.load(f)['data']
43 |
44 | for doc in data:
45 | docs_wikis.append(doc)
46 | docs_wiki_titles[doc['title']] = ""
47 | for paragraph in doc['paragraphs']:
48 | context = paragraph['context']
49 | num_words_in_docs_wiki += len(context.split(" "))
50 |
51 | print("num_words_in_docs_wiki={}".format(num_words_in_docs_wiki))
52 | random.seed(2020)
53 | i = 0
54 | while True:
55 | if num_words_in_docs_wiki > num_sample_words:
56 | break
57 |
58 | # random pick from wiki filelist
59 | # start_time = time.time()
60 | random_wiki_file = random.sample(wiki_file_list, 1)[0]
61 | # if i % 100 == 0:
62 | # print("(1) ", time.time() - start_time)
63 |
64 | with open(random_wiki_file,'r') as f:
65 | data = json.load(f)['data']
66 |
67 | # random pick from articles
68 | # start_time = time.time()
69 | random_articles = random.sample(data, 100)
70 | # if i % 100 == 0:
71 | # print("(2) ", time.time() - start_time)
72 |
73 | # start_time = time.time()
74 | for random_article in random_articles:
75 | # if already existing article in docs_wiki, then pass
76 | if random_article['title'] in docs_wiki_titles:
77 | continue
78 | docs_wikis.append(random_article)
79 | docs_wiki_titles[random_article['title']] = ""
80 | # if i % 100 == 0:
81 | # print("(3) ", time.time() - start_time)
82 |
83 | # start_time = time.time()
84 | for random_article in random_articles:
85 | for paragraph in random_article['paragraphs']:
86 | context = paragraph['context']
87 | num_words_in_docs_wiki += len(context.split(" "))
88 | # if i % 100 == 0:
89 | # print("(4) ", time.time() - start_time)
90 |
91 | if i % 100 == 0:
92 | print("title={} len(docs_wiki_titles)={} ratio={}".format(random_article['title'], len(docs_wiki_titles), num_words_in_docs_wiki/num_words_in_wiki))
93 | i += 1
94 |
95 | if not os.path.exists(output_dir):
96 | os.mkdir(output_dir)
97 |
98 | # shuffle docs_wikis for balanced file size
99 | random.shuffle(docs_wikis)
100 |
101 | for i in range(int(len(docs_wikis)/1000) + 1):
102 | output_file = os.path.join(output_dir, '{:d}'.format(i).zfill(4))
103 | local_docs_wikis = docs_wikis[i*1000:(i+1)*1000]
104 |
105 | output = {
106 | 'data' : local_docs_wikis
107 | }
108 |
109 | # save nq_reader
110 | with open(output_file,'w') as f:
111 | json.dump(output, f)
112 |
113 | # # pdb.set_trace()
114 |
115 | # wiki_titles = []
116 | # wiki_title2paragraphs = {}
117 | # for filename in tqdm(wiki_file_list, total=len(wiki_file_list)):
118 | # with open(filename,'r') as f:
119 | # data = json.load(f)['data']
120 |
121 | # for doc in data:
122 | # title = doc['title']
123 | # wiki_titles.append(title)
124 | # paragraph = doc['paragraphs']
125 | # wiki_title2paragraphs[title] = paragraph
126 | # num_wiki += 1
127 |
128 | # assert len(wiki_title2paragraphs) == num_wiki
129 |
130 | # nq_file_list = glob.glob(os.path.join(nq_dir,"*"))
131 | # nq_titles = []
132 | # unmatched_titles = []
133 | # num_matched = 0
134 | # num_unmatched = 0
135 | # for filename in tqdm(nq_file_list, total=len(nq_file_list)):
136 | # with open(filename,'r') as f:
137 | # data = json.load(f)['data']
138 |
139 | # for doc in data:
140 | # title = doc['title']
141 | # nq_titles.append(title)
142 | # if title in wiki_title2paragraphs and len(wiki_title2paragraphs[title])>0:
143 | # doc['paragraphs'] = wiki_title2paragraphs[title]
144 | # num_matched += 1
145 | # else:
146 | # unmatched_titles.append(title)
147 | # num_unmatched +=1
148 |
149 | # new_paragraphs = []
150 | # for paragraph in doc['paragraphs']:
151 | # if ('is_paragraph' in paragraph) and (not paragraph['is_paragraph']):
152 | # continue
153 |
154 | # new_paragraphs.append({
155 | # 'context': paragraph['context']
156 | # })
157 | # doc['paragraphs'] = new_paragraphs
158 |
159 | # if not os.path.exists(output_dir):
160 | # os.mkdir(output_dir)
161 |
162 | # output_path = os.path.join(output_dir,os.path.basename(filename))
163 | # output = {
164 | # 'data': data
165 | # }
166 |
167 | # with open(output_path, 'w') as f:
168 | # json.dump(output, f, indent=2)
169 |
170 | # with open('unmatched_title_old_dev.txt', 'w') as f:
171 | # for title in unmatched_titles:
172 | # f.writelines(title)
173 | # f.writelines("\n")
174 |
175 | # print("num_matched={} num_unmatched={}".format(num_matched, num_unmatched))
176 | # print("len(nq_titles)={} len(wiki_titles)={}".format(len(nq_titles), len(wiki_titles)))
177 |
178 | if __name__ == '__main__':
179 | parser = argparse.ArgumentParser()
180 | # Required parameters
181 | parser.add_argument("--sampling_ratio", type=float, required=True)
182 | parser.add_argument("--wiki_dir", type=str, required=True)
183 | parser.add_argument("--docs_wiki_dir", type=str, required=True)
184 | parser.add_argument("--output_dir", type=str, required=True)
185 |
186 | args = parser.parse_args()
187 |
188 | main(args)
189 |
--------------------------------------------------------------------------------
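An invocation sketch for the sampler above (all paths are placeholders; only the flag names come from the argument parser). The script keeps every document already present in `--docs_wiki_dir` and then adds randomly sampled Wikipedia articles until their total word count reaches `--sampling_ratio` of the full dump:

    python scripts/preprocess/sample_nq_reader_doc_wiki.py \
        --sampling_ratio 0.1 \
        --wiki_dir $DATA_DIR/wikidump/<preprocessed-wiki-dir> \
        --docs_wiki_dir $DATA_DIR/single-qa/<nq-doc-wiki-dir> \
        --output_dir $DATA_DIR/single-qa/<sampled-output-dir>
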
/scripts/preprocess/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | # https://github.com/facebookresearch/DrQA/blob/master/drqa/tokenizers/simple_tokenizer.py#L18
2 |
3 | #!/usr/bin/env python3
4 | # Copyright 2017-present, Facebook, Inc.
5 | # All rights reserved.
6 | #
7 | # This source code is licensed under the license found in the
8 | # LICENSE file in the root directory of this source tree.
9 | """Basic tokenizer that splits text into alpha-numeric tokens and
10 | non-whitespace tokens.
11 | """
12 |
13 | import copy
14 | import regex
15 | import logging
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class Tokens(object):
21 | """A class to represent a list of tokenized text."""
22 | TEXT = 0
23 | TEXT_WS = 1
24 | SPAN = 2
25 | POS = 3
26 | LEMMA = 4
27 | NER = 5
28 |
29 | def __init__(self, data, annotators, opts=None):
30 | self.data = data
31 | self.annotators = annotators
32 | self.opts = opts or {}
33 |
34 | def __len__(self):
35 | """The number of tokens."""
36 | return len(self.data)
37 |
38 | def slice(self, i=None, j=None):
39 | """Return a view of the list of tokens from [i, j)."""
40 | new_tokens = copy.copy(self)
41 | new_tokens.data = self.data[i: j]
42 | return new_tokens
43 |
44 | def untokenize(self):
45 | """Returns the original text (with whitespace reinserted)."""
46 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip()
47 |
48 | def words(self, uncased=False):
49 | """Returns a list of the text of each token
50 | Args:
51 | uncased: lower cases text
52 | """
53 | if uncased:
54 | return [t[self.TEXT].lower() for t in self.data]
55 | else:
56 | return [t[self.TEXT] for t in self.data]
57 |
58 | def offsets(self):
59 | """Returns a list of [start, end) character offsets of each token."""
60 | return [t[self.SPAN] for t in self.data]
61 |
62 | def pos(self):
63 | """Returns a list of part-of-speech tags of each token.
64 | Returns None if this annotation was not included.
65 | """
66 | if 'pos' not in self.annotators:
67 | return None
68 | return [t[self.POS] for t in self.data]
69 |
70 | def lemmas(self):
71 | """Returns a list of the lemmatized text of each token.
72 | Returns None if this annotation was not included.
73 | """
74 | if 'lemma' not in self.annotators:
75 | return None
76 | return [t[self.LEMMA] for t in self.data]
77 |
78 | def entities(self):
79 | """Returns a list of named-entity-recognition tags of each token.
80 | Returns None if this annotation was not included.
81 | """
82 | if 'ner' not in self.annotators:
83 | return None
84 | return [t[self.NER] for t in self.data]
85 |
86 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
87 | """Returns a list of all ngrams from length 1 to n.
88 | Args:
89 | n: upper limit of ngram length
90 | uncased: lower cases text
91 | filter_fn: user function that takes in an ngram list and returns
92 | True or False to keep or not keep the ngram
93 | as_strings: return each ngram as a string rather than a list
94 | """
95 | def _skip(gram):
96 | if not filter_fn:
97 | return False
98 | return filter_fn(gram)
99 |
100 | words = self.words(uncased)
101 | ngrams = [(s, e + 1)
102 | for s in range(len(words))
103 | for e in range(s, min(s + n, len(words)))
104 | if not _skip(words[s:e + 1])]
105 |
106 | # Concatenate into strings
107 | if as_strings:
108 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]
109 |
110 | return ngrams
111 |
112 | def entity_groups(self):
113 | """Group consecutive entity tokens with the same NER tag."""
114 | entities = self.entities()
115 | if not entities:
116 | return None
117 | non_ent = self.opts.get('non_ent', 'O')
118 | groups = []
119 | idx = 0
120 | while idx < len(entities):
121 | ner_tag = entities[idx]
122 | # Check for entity tag
123 | if ner_tag != non_ent:
124 | # Chomp the sequence
125 | start = idx
126 | while (idx < len(entities) and entities[idx] == ner_tag):
127 | idx += 1
128 | groups.append((self.slice(start, idx).untokenize(), ner_tag))
129 | else:
130 | idx += 1
131 | return groups
132 |
133 |
134 | class Tokenizer(object):
135 | """Base tokenizer class.
136 | Tokenizers implement tokenize, which should return a Tokens class.
137 | """
138 | def tokenize(self, text):
139 | raise NotImplementedError
140 |
141 | def shutdown(self):
142 | pass
143 |
144 | def __del__(self):
145 | self.shutdown()
146 |
147 | class SimpleTokenizer(Tokenizer):
148 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
149 | NON_WS = r'[^\p{Z}\p{C}]'
150 |
151 | def __init__(self, **kwargs):
152 | """
153 | Args:
154 | annotators: None or empty set (only tokenizes).
155 | """
156 | self._regexp = regex.compile(
157 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
158 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
159 | )
160 | if len(kwargs.get('annotators', {})) > 0:
161 | logger.warning('%s only tokenizes! Skipping annotators: %s' %
162 | (type(self).__name__, kwargs.get('annotators')))
163 | self.annotators = set()
164 |
165 | def tokenize(self, text):
166 | data = []
167 | matches = [m for m in self._regexp.finditer(text)]
168 | for i in range(len(matches)):
169 | # Get text
170 | token = matches[i].group()
171 |
172 | # Get whitespace
173 | span = matches[i].span()
174 | start_ws = span[0]
175 | if i + 1 < len(matches):
176 | end_ws = matches[i + 1].span()[0]
177 | else:
178 | end_ws = span[1]
179 |
180 | # Format data
181 | data.append((
182 | token,
183 | text[start_ws: end_ws],
184 | span,
185 | ))
186 | return Tokens(data, self.annotators)
--------------------------------------------------------------------------------
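A short usage sketch for `SimpleTokenizer` (no external resources needed):

    tokenizer = SimpleTokenizer()
    tokens = tokenizer.tokenize('DensePhrases retrieves phrases, not passages.')
    print(tokens.words(uncased=True))  # ['densephrases', 'retrieves', 'phrases', ',', 'not', 'passages', '.']
    print(tokens.offsets()[0])         # (0, 12): [start, end) character span of the first token
    print(tokens.ngrams(n=2)[:3])      # unigrams and bigrams as strings
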
/scripts/preprocess/stat_entities.py:
--------------------------------------------------------------------------------
1 | import spacy
2 | import json
3 | import random
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | nlp_sent = spacy.load("en_core_web_sm")
8 | doc = nlp_sent('European authorities fined Google a record $5.1 billion on Wednesday for abusing its power in the mobile phone market and ordered the company to alter its practices')
9 | print([(X.text, X.label_) for X in doc.ents])
10 |
11 |
12 | # pred_file = '/n/fs/nlp-jl5167/outputs/pred/dev_preprocessed_8757.pred'
13 | pred_file = 'lama-test-P20_preprocessed_953.pred'
14 | with open(pred_file) as f:
15 | predictions = json.load(f)
16 |
17 | stat = {}
18 | ent_types = {}
19 | tokenizer_error_cnt = 0
20 | entity_error_cnt = 0
21 | for pid, result in predictions.items():
22 | question = result['question']
23 | q_sws = result['q_tokens'][1:-1] # except [CLS], [SEP]
24 | q_ents = [(X.text, X.label_, X[0].idx) for X in nlp_sent(question).ents]
25 | if len(q_ents) == 0:
26 | entity_error_cnt += 1
27 | continue
28 |
29 | word_idx = 0
30 | word_to_sw = {}
31 | for sw_idx, sw in enumerate(q_sws):
32 | if word_idx not in word_to_sw:
33 | word_to_sw[word_idx] = []
34 | word_to_sw[word_idx].append(sw_idx)
35 | if sw_idx < len(q_sws) - 1:
36 | if not q_sws[sw_idx+1].startswith('##'):
37 | word_idx += 1
38 | try:
39 | assert word_idx == len(question.split(' ')) - 1
40 | except Exception as e:
41 | tokenizer_error_cnt += 1
42 | continue
43 |
44 | char_to_word = {}
45 | word_idx = 0
46 | for ch_idx, ch in enumerate(question):
47 | if ch == ' ':
48 | word_idx += 1
49 | continue
50 | char_to_word[ch_idx] = word_idx
51 |
52 | try:
53 | assert word_idx == len(question.split(' ')) - 1
54 | except Exception as e:
55 | tokenizer_error_cnt += 1
56 | continue
57 |
58 | num_sw = []
59 | ent_list = [
60 | 'EVENT', 'FAC', 'GPE', 'LANGUAGE', 'LAW', 'LOC',
61 | 'NORP', 'ORG', 'PERSON', 'PRODUCT', 'WORK_OF_ART'
62 | ]
63 | for ent_text, ent_label, ent_start in q_ents:
64 | if ent_label not in ent_list:
65 | continue
66 | char_start = ent_start
67 | char_end = ent_start + len(ent_text) - 1
68 | word_start = char_to_word[char_start]
69 | word_end = char_to_word[char_end]
70 | num_sw.append(sum([len(word_to_sw[word]) for word in range(word_start, word_end+1)]))
71 | # num_sw.append(max([len(word_to_sw[word]) for word in range(word_start, word_end+1)]))
72 | if ent_label not in ent_types:
73 | ent_types[ent_label] = 0
74 | print(ent_text, ent_label)
75 | ent_types[ent_label] += 1
76 |
77 | if len(num_sw) == 0:
78 | entity_error_cnt += 1
79 | continue
80 |
81 | num_sw = max(num_sw)
82 | if num_sw not in stat:
83 | print(num_sw, q_sws)
84 | stat[num_sw] = []
85 | stat[num_sw].append(int(result['em_top1']))
86 |
87 | output = sorted({key: (f'{sum(val)/len(val):.2f}', f'{len(val)} Qs') for key, val in stat.items()}.items())
88 | print(f'exclude {tokenizer_error_cnt} questions for tokenization error')
89 | print(f'exclude {entity_error_cnt} questions for entity not found error')
90 | print(f'stat: {output} for {len(predictions) - tokenizer_error_cnt - entity_error_cnt} questions')
91 | print(sorted(ent_types.items()))
92 |
--------------------------------------------------------------------------------
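The script above relies on a two-step alignment: characters are mapped to whitespace-separated words, and words are mapped to their WordPiece subwords, so that each detected entity can be measured in number of subwords. A minimal sketch of the word-to-subword step with toy inputs (the subword list is made up):

    q_sws = ['who', 'founded', 'dense', '##ph', '##rases']  # toy WordPiece tokens, [CLS]/[SEP] removed
    word_to_sw, word_idx = {}, 0
    for sw_idx, sw in enumerate(q_sws):
        word_to_sw.setdefault(word_idx, []).append(sw_idx)
        if sw_idx < len(q_sws) - 1 and not q_sws[sw_idx + 1].startswith('##'):
            word_idx += 1
    # word_to_sw == {0: [0], 1: [1], 2: [2, 3, 4]}: the last word spans three subwords.
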
/scripts/question_generation/filter_qg.py:
--------------------------------------------------------------------------------
1 | import spacy
2 | import json
3 | import random
4 | import numpy as np
5 | from tqdm import tqdm
6 | from squad_metrics import compute_exact
7 | nlp = spacy.load("en_core_web_sm")
8 |
9 | doc = nlp('European authorities fined Google a record $5.1 billion on Wednesday for abusing its power in the mobile phone market and ordered the company to alter its practices')
10 | print([(X.text, X.label_) for X in doc.ents])
11 |
12 |
13 | data_path = 'data/squad-nq/train-sqdqg_nqqg.json'
14 | sample = False
15 | print(f'reading {data_path} with sampling: {sample}')
16 | train_set = json.load(open(data_path))
17 | new_train_set = {'data': []}
18 | cnt = 0
19 | new_cnt = 0
20 | orig_cnt = 0
21 | miss_cnt = 0
22 |
23 | prediction_path = 'models/spanbert-base-cased-sqdnq_qgfilter/predictions_.json'
24 | predictions = {str(id_): pred for id_, pred in json.load(open(prediction_path)).items()}
25 |
26 | for article in tqdm(train_set['data']):
27 | new_article = {
28 | 'title': article['title'],
29 | 'paragraphs': []
30 | }
31 | for p_idx, paragraph in enumerate(article['paragraphs']):
32 | new_paragraph = {
33 | 'context': paragraph['context'],
34 | 'qas' : [],
35 | }
36 |
37 | for qa in paragraph['qas']:
38 | question = qa['question']
39 | id_ = str(qa['id'])
40 | # assert id_ in predictions
41 | if id_ not in predictions:
42 | print('missing predictions', id_)
43 | miss_cnt += 1
44 | continue
45 | if all(kk in id_ for kk in ['_p', '_s', '_a']):
46 | if not compute_exact(qa['answers'][0]['text'], predictions[id_]):
47 | continue
48 | else:
49 | new_cnt += 1
50 | else:
51 | orig_cnt += 1
52 |
53 | new_paragraph['qas'].append(qa)
54 | cnt += 1
55 | new_article['paragraphs'].append(new_paragraph)
56 |
57 | new_train_set['data'].append(new_article)
58 | # break
59 |
60 | write_path = data_path.replace('.json', '_filtered.json')
61 | with open(write_path, 'w') as f:
62 | json.dump(new_train_set, f)
63 |
64 | assert orig_cnt + new_cnt == cnt
65 | print(f'writing to {write_path} with {cnt} samples')
66 | print(f'orig sample: {orig_cnt}, new sample: {new_cnt}')
67 | print(f'missing sample: {miss_cnt}')
68 |
--------------------------------------------------------------------------------
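The filter above keeps a generated question only when the reader's prediction exactly matches the generated answer; questions whose ids lack the `_p`/`_s`/`_a` suffixes (i.e., the original, non-generated ones) are always kept. A small sketch of the exact-match check (import path may differ depending on how the script is run):

    from squad_metrics import compute_exact  # same module the script imports

    keep = bool(compute_exact('Barack Obama', 'barack obama.'))  # SQuAD normalization ignores case/punctuation -> True
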
/scripts/question_generation/generate_squad.py:
--------------------------------------------------------------------------------
1 | import spacy
2 | import json
3 | import random
4 | import numpy as np
5 | from tqdm import tqdm
6 | from pipelines import pipeline
7 |
8 | nlp_sent = spacy.load("en_core_web_sm")
9 | doc = nlp_sent('European authorities fined Google a record $5.1 billion on Wednesday for abusing its power in the mobile phone market and ordered the company to alter its practices')
10 | print([(X.text, X.label_) for X in doc.ents])
11 |
12 | # Please train your own model on SQuAD and load as below
13 | nlp = pipeline("multitask-qa-qg", model="t5-large-multi-hl/checkpoint-3500", qg_format="highlight")
14 |
15 |
16 | data_path = '/home/data/squad/train-v1.1.json'
17 | sample = False
18 | print(f'reading {data_path} with sampling: {sample}')
19 | train_set = json.load(open(data_path))
20 | new_train_set = {'data': []}
21 | cnt = 0
22 | answer_stats = []
23 | bs = 16
24 | tmp_path = data_path.replace('.json', '_qg_t5l35-sqd_tmp.json')
25 | tmp_file = open(tmp_path, 'a')
26 |
27 | for article in tqdm(train_set['data']):
28 | new_article = {
29 | 'title': article['title'],
30 | 'paragraphs': []
31 | }
32 | for p_idx, paragraph in enumerate(article['paragraphs']):
33 | new_paragraph = {
34 | 'context': paragraph['context'],
35 | 'qas' : [],
36 | }
37 |
38 | # Add existing QA pairs
39 | for qa in paragraph['qas']:
40 | new_paragraph['qas'].append(qa)
41 | cnt += 1
42 |
43 | # Get sentences
44 | sents = [sent for sent in nlp_sent(paragraph['context']).sents]
45 | qa_pairs = []
46 | try:
47 | qa_pairs = nlp(paragraph['context'])
48 | except Exception as e:
49 | print('Neural QG error:', paragraph['context'][:50], e)
50 |
51 | ents = [[] for _ in range(len(sents))]
52 | try:
53 | for sent_idx, sent in enumerate(sents):
54 | parse_list = [ent for ent in sent.ents]
55 | ents[sent_idx] += parse_list
56 | except Exception as e:
57 | print('NER error:', sent.text, e)
58 |
59 | cst_qa_pairs = []
60 | try:
61 | flat_ents = [e for ent in ents for e in ent]
62 | qg_examples = nlp._prepare_inputs_for_qg_from_answers_hl(
63 | [sent.text.strip() for sent in sents], [[e.text for e in ent] for ent in ents]
64 | )
65 | qg_inputs = [example['source_text'] for example in qg_examples]
66 | cst_qs = []
67 | for i in range(0, len(qg_inputs), bs):
68 | cst_qs += nlp._generate_questions(qg_inputs[i:i+bs])
69 | assert len(cst_qs) == len(qg_examples)
70 | cst_qa_pairs = [{'answer': example['answer'], 'question': que} for example, que in zip(qg_examples, cst_qs)]
71 | except Exception as e:
72 | print('Ent QG error:', e)
73 |
74 | orig_len = len(qa_pairs)
75 | qa_pairs = qa_pairs + cst_qa_pairs
76 | if len(qa_pairs) == 0:
77 | print('Skipping as no questions generated for:', sent.text)
78 | continue
79 | flat_ents = [None]*orig_len + flat_ents
80 |
81 | q_set = []
82 | for qa_idx, qa_pair in enumerate(qa_pairs):
83 | ans = qa_pair['answer']
84 | que = qa_pair['question']
85 | if que in q_set:
86 | continue
87 | q_set.append(que)
88 | try:
89 | if flat_ents[qa_idx] is not None:
90 | ans_start = flat_ents[qa_idx][0].idx
91 | else:
92 | ans_start = paragraph['context'].index(ans)
93 | except Exception as e:
94 | print('Skipping ans:', ans, e)
95 | continue
96 | if ans != paragraph['context'][ans_start:ans_start+len(ans)]:
97 | print(f'skipping mis-match {ans}')
98 | continue
99 | new_paragraph['qas'].append({
100 | 'answers': [{'answer_start': ans_start, 'text': ans}],
101 | 'question': que,
102 | 'id': f'{article["title"]}_p{p_idx}_s{sent_idx}_a{qa_idx}',
103 | })
104 | tmp_file.write(
105 | f'{article["title"]}_p{p_idx}_s{sent_idx}_a{qa_idx}\t{que}\t{ans}\t{ans_start}\n'
106 | )
107 | cnt += 1
108 |
109 | if len(qa_pairs) > 0:
110 | print(qa_pairs[0])
111 | new_article['paragraphs'].append(new_paragraph)
112 |
113 | new_train_set['data'].append(new_article)
114 |
115 | write_path = data_path.replace('.json', '_qg_t5l35-sqd.json')
116 | with open(write_path, 'w') as f:
117 | json.dump(new_train_set, f)
118 |
119 | print(f'writing to {write_path} with {cnt} samples')
120 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import io
2 | from setuptools import setup, find_packages
3 |
4 | with open('README.md', encoding='utf8') as f:
5 | readme = f.read()
6 |
7 | with open('LICENSE', encoding='utf8') as f:
8 | license = f.read()
9 |
10 | with open('requirements.txt', encoding='utf8') as f:
11 | reqs = f.read()
12 |
13 | setup(
14 | name='densephrases',
15 | version='1.0',
16 | description='Learning Dense Representations of Phrases at Scale',
17 | long_description=readme,
18 | license=license,
19 | url='https://github.com/princeton-nlp/DensePhrases',
20 | keywords=['phrase', 'embedding', 'retrieval', 'nlp', 'open-domain', 'qa'],
21 | python_requires='>=3.7',
22 | install_requires=reqs.strip().split('\n'),
23 | )
24 |
--------------------------------------------------------------------------------
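Given this setup.py, a typical editable install from the repository root looks like the following (standard pip usage, not a command prescribed by this file):

    pip install -r requirements.txt
    pip install -e .
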
/slides/emnlp2021_slides.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/princeton-nlp/DensePhrases/9583883ea9390b0308e806c3e72fa5831afa445b/slides/emnlp2021_slides.pdf
--------------------------------------------------------------------------------