├── .gitignore ├── .vscode └── launch.json ├── LICENSE ├── drqa ├── __init__.py ├── features │ ├── map.txt │ ├── readme.md │ └── stopword_zh.txt ├── pipeline │ ├── __init__.py │ ├── drqa.py │ └── simpleDrQA.py ├── reader │ ├── __init__.py │ ├── config.py │ ├── data.py │ ├── layers.py │ ├── model.py │ ├── predictor.py │ ├── rnn_reader.py │ ├── utils.py │ └── vector.py ├── retriever │ ├── __init__.py │ ├── doc_db.py │ ├── net_retriever.py │ ├── tfidf_doc_ranker.py │ └── utils.py └── tokenizers │ ├── Zh_tokenizer.py │ ├── __init__.py │ ├── corenlp_tokenizer.py │ ├── regexp_tokenizer.py │ ├── simple_tokenizer.py │ ├── spacy_tokenizer.py │ ├── tokenizer.py │ └── zh_features.py ├── readme.md ├── requirements.txt ├── scripts ├── convert │ ├── squad.py │ └── webquestions.py ├── distant │ ├── check_data.py │ └── generate.py ├── pipeline │ ├── eval.py │ ├── interactive.py │ ├── predict.py │ └── sinteractive.py ├── reader │ ├── README.md │ ├── interactive.py │ ├── predict.py │ ├── preprocess.py │ └── train.py └── retriever │ ├── README.md │ ├── build_db.py │ ├── build_tfidf.py │ ├── eval.py │ ├── interactive.py │ └── prep_wikipedia.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.DS_Store 3 | __pycache__/ 4 | *~ 5 | data/ 6 | .vscode/ 7 | *.tar.gz 8 | *.egg-info 9 | *.sh 10 | img/ 11 | zh_dict.json -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | 5 | { 6 | "name": "Python", 7 | "type": "python", 8 | "request": "launch", 9 | "stopOnEntry": true, 10 | "pythonPath": "${config:python.pythonPath}", 11 | "program": "${file}", 12 | "cwd": "${workspaceRoot}", 13 | "env": {}, 14 | "envFile": "${workspaceRoot}/.env", 15 | "debugOptions": [ 16 | "WaitOnAbnormalExit", 17 | "WaitOnNormalExit", 18 | "RedirectOutput" 19 | ] 20 | }, 21 | { 22 | "name": "PySpark", 23 | "type": "python", 24 | "request": "launch", 25 | "stopOnEntry": true, 26 | "osx": { 27 | "pythonPath": "${env:SPARK_HOME}/bin/spark-submit" 28 | }, 29 | "windows": { 30 | "pythonPath": "${env:SPARK_HOME}/bin/spark-submit.cmd" 31 | }, 32 | "linux": { 33 | "pythonPath": "${env:SPARK_HOME}/bin/spark-submit" 34 | }, 35 | "program": "${file}", 36 | "cwd": "${workspaceRoot}", 37 | "env": {}, 38 | "envFile": "${workspaceRoot}/.env", 39 | "debugOptions": [ 40 | "WaitOnAbnormalExit", 41 | "WaitOnNormalExit", 42 | "RedirectOutput" 43 | ] 44 | }, 45 | { 46 | "name": "Python Module", 47 | "type": "python", 48 | "request": "launch", 49 | "stopOnEntry": true, 50 | "pythonPath": "${config:python.pythonPath}", 51 | "module": "module.name", 52 | "cwd": "${workspaceRoot}", 53 | "env": {}, 54 | "envFile": "${workspaceRoot}/.env", 55 | "debugOptions": [ 56 | "WaitOnAbnormalExit", 57 | "WaitOnNormalExit", 58 | "RedirectOutput" 59 | ] 60 | }, 61 | { 62 | "name": "Integrated Terminal/Console", 63 | "type": "python", 64 | "request": "launch", 65 | "stopOnEntry": true, 66 | "pythonPath": "${config:python.pythonPath}", 67 | "program": "${file}", 68 | "cwd": "", 69 | "console": "integratedTerminal", 70 | "env": {}, 71 | "envFile": "${workspaceRoot}/.env", 72 | "debugOptions": [ 73 | "WaitOnAbnormalExit", 74 | "WaitOnNormalExit" 75 | ] 76 | }, 77 | { 78 | "name": "External Terminal/Console", 79 | "type": "python", 80 | "request": "launch", 81 | "stopOnEntry": true, 82 | "pythonPath": "${config:python.pythonPath}", 
83 | "program": "${file}", 84 | "cwd": "", 85 | "console": "externalTerminal", 86 | "env": {}, 87 | "envFile": "${workspaceRoot}/.env", 88 | "debugOptions": [ 89 | "WaitOnAbnormalExit", 90 | "WaitOnNormalExit" 91 | ] 92 | }, 93 | { 94 | "name": "Django", 95 | "type": "python", 96 | "request": "launch", 97 | "stopOnEntry": true, 98 | "pythonPath": "${config:python.pythonPath}", 99 | "program": "${workspaceRoot}/manage.py", 100 | "cwd": "${workspaceRoot}", 101 | "args": [ 102 | "runserver", 103 | "--noreload", 104 | "--nothreading" 105 | ], 106 | "env": {}, 107 | "envFile": "${workspaceRoot}/.env", 108 | "debugOptions": [ 109 | "WaitOnAbnormalExit", 110 | "WaitOnNormalExit", 111 | "RedirectOutput", 112 | "DjangoDebugging" 113 | ] 114 | }, 115 | { 116 | "name": "Flask", 117 | "type": "python", 118 | "request": "launch", 119 | "stopOnEntry": false, 120 | "pythonPath": "${config:python.pythonPath}", 121 | "program": "fully qualified path fo 'flask' executable. Generally located along with python interpreter", 122 | "cwd": "${workspaceRoot}", 123 | "env": { 124 | "FLASK_APP": "${workspaceRoot}/quickstart/app.py" 125 | }, 126 | "args": [ 127 | "run", 128 | "--no-debugger", 129 | "--no-reload" 130 | ], 131 | "envFile": "${workspaceRoot}/.env", 132 | "debugOptions": [ 133 | "WaitOnAbnormalExit", 134 | "WaitOnNormalExit", 135 | "RedirectOutput" 136 | ] 137 | }, 138 | { 139 | "name": "Flask (old)", 140 | "type": "python", 141 | "request": "launch", 142 | "stopOnEntry": false, 143 | "pythonPath": "${config:python.pythonPath}", 144 | "program": "${workspaceRoot}/run.py", 145 | "cwd": "${workspaceRoot}", 146 | "args": [], 147 | "env": {}, 148 | "envFile": "${workspaceRoot}/.env", 149 | "debugOptions": [ 150 | "WaitOnAbnormalExit", 151 | "WaitOnNormalExit", 152 | "RedirectOutput" 153 | ] 154 | }, 155 | { 156 | "name": "Pyramid", 157 | "type": "python", 158 | "request": "launch", 159 | "stopOnEntry": true, 160 | "pythonPath": "${config:python.pythonPath}", 161 | "cwd": "${workspaceRoot}", 162 | "env": {}, 163 | "envFile": "${workspaceRoot}/.env", 164 | "args": [ 165 | "${workspaceRoot}/development.ini" 166 | ], 167 | "debugOptions": [ 168 | "WaitOnAbnormalExit", 169 | "WaitOnNormalExit", 170 | "RedirectOutput", 171 | "Pyramid" 172 | ] 173 | }, 174 | { 175 | "name": "Watson", 176 | "type": "python", 177 | "request": "launch", 178 | "stopOnEntry": true, 179 | "pythonPath": "${config:python.pythonPath}", 180 | "program": "${workspaceRoot}/console.py", 181 | "cwd": "${workspaceRoot}", 182 | "args": [ 183 | "dev", 184 | "runserver", 185 | "--noreload=True" 186 | ], 187 | "env": {}, 188 | "envFile": "${workspaceRoot}/.env", 189 | "debugOptions": [ 190 | "WaitOnAbnormalExit", 191 | "WaitOnNormalExit", 192 | "RedirectOutput" 193 | ] 194 | }, 195 | { 196 | "name": "Attach (Remote Debug)", 197 | "type": "python", 198 | "request": "attach", 199 | "localRoot": "${workspaceRoot}", 200 | "remoteRoot": "${workspaceRoot}", 201 | "port": 3000, 202 | "secret": "my_secret", 203 | "host": "localhost" 204 | } 205 | ] 206 | } -------------------------------------------------------------------------------- /drqa/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import os 9 | import sys 10 | from pathlib import PosixPath 11 | 12 | if sys.version_info < (3, 5): 13 | raise RuntimeError('DrQA supports Python 3.5 or higher.') 14 | 15 | DATA_DIR = ( 16 | os.getenv('DRQA_DATA') or 17 | os.path.join(PosixPath(__file__).absolute().parents[1].as_posix(), 'data') 18 | ) 19 | 20 | from . import tokenizers 21 | from . import reader 22 | from . import retriever 23 | from . import pipeline 24 | -------------------------------------------------------------------------------- /drqa/features/map.txt: -------------------------------------------------------------------------------- 1 | 西交大 西安交通大学 2 | 西交 西安交通大学 3 | 西安交大 西安交通大学 4 | XJTU 西安交通大学 5 | xjtu 西安交通大学 6 | Xjtu 西安交通大学 7 | 数学院 数学与统计学院 8 | 数学学院 数学与统计学院 9 | 前沿 前沿科学技术研 10 | 机械 机械工程 11 | 电气 电气工程 12 | 能动 能源与动力工程 13 | 电信 电子与信息工程 14 | 材料 材料科学与工程 15 | 人居学院 人居环境与建筑工程学院 16 | 生命 生命科学与技术 17 | 航天 航天航空 18 | 化工 化学工程与技术 19 | 经金 经济与金融 20 | 公管 公共政策与管理 21 | 人文 人文社会学科 22 | 新闻学院 新闻与新媒体学院 23 | -------------------------------------------------------------------------------- /drqa/features/readme.md: -------------------------------------------------------------------------------- 1 | 1. map provided simple transformation both in question and in docs 2 | 2. stop word stored chinese stop words 3 | 3. zh\_dict.json is used in zh\_features.py for Chinese English translation. 4 | 5 | 6 | -------------------------------------------------------------------------------- /drqa/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | from ..tokenizers import CoreNLPTokenizer 10 | from ..retriever import TfidfDocRanker 11 | from ..retriever import DocDB 12 | from .. import DATA_DIR 13 | 14 | DEFAULTS = { 15 | 'tokenizer': CoreNLPTokenizer, 16 | 'ranker': TfidfDocRanker, 17 | 'db': DocDB, 18 | 'reader_model': os.path.join(DATA_DIR, 'reader/multitask.mdl'), 19 | } 20 | 21 | 22 | def set_default(key, value): 23 | global DEFAULTS 24 | DEFAULTS[key] = value 25 | 26 | 27 | from .drqa import DrQA 28 | from .simpleDrQA import SDrQA -------------------------------------------------------------------------------- /drqa/pipeline/drqa.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Full DrQA pipeline.""" 8 | # original pipline: you may want to check yourself if you need this one 9 | 10 | 11 | import torch 12 | import regex 13 | import heapq 14 | import math 15 | import time 16 | import logging 17 | 18 | from multiprocessing import Pool as ProcessPool 19 | from multiprocessing.util import Finalize 20 | 21 | from ..reader.vector import batchify 22 | from ..reader.data import ReaderDataset, SortedBatchSampler 23 | from .. import reader 24 | from .. import tokenizers 25 | from . 
import DEFAULTS 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | # ------------------------------------------------------------------------------ 31 | # Multiprocessing functions to fetch and tokenize text 32 | # ------------------------------------------------------------------------------ 33 | 34 | PROCESS_TOK = None 35 | PROCESS_DB = None 36 | PROCESS_CANDS = None 37 | 38 | 39 | def init(tokenizer_class, tokenizer_opts, db_class, db_opts, candidates=None): 40 | global PROCESS_TOK, PROCESS_DB, PROCESS_CANDS 41 | PROCESS_TOK = tokenizer_class(**tokenizer_opts) 42 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 43 | PROCESS_DB = db_class(**db_opts) 44 | Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100) 45 | PROCESS_CANDS = candidates 46 | 47 | 48 | def fetch_text(doc_id): 49 | global PROCESS_DB 50 | return PROCESS_DB.get_doc_text(doc_id) 51 | 52 | 53 | def tokenize_text(text): 54 | global PROCESS_TOK 55 | return PROCESS_TOK.tokenize(text) 56 | 57 | 58 | # ------------------------------------------------------------------------------ 59 | # Main DrQA pipeline 60 | # ------------------------------------------------------------------------------ 61 | 62 | 63 | class DrQA(object): 64 | # Target size for squashing short paragraphs together. 65 | # 0 = read every paragraph independently 66 | # infty = read all paragraphs together 67 | GROUP_LENGTH = 0 68 | 69 | def __init__( 70 | self, 71 | reader_model=None, 72 | embedding_file=None, 73 | tokenizer=None, 74 | fixed_candidates=None, 75 | batch_size=128, 76 | cuda=True, 77 | data_parallel=False, 78 | max_loaders=5, 79 | num_workers=None, 80 | db_config=None, 81 | ranker_config=None 82 | ): 83 | """Initialize the pipeline. 84 | 85 | Args: 86 | reader_model: model file from which to load the DocReader. 87 | embedding_file: if given, will expand DocReader dictionary to use 88 | all available pretrained embeddings. 89 | tokenizer: string option to specify tokenizer used on docs. 90 | fixed_candidates: if given, all predictions will be constrated to 91 | the set of candidates contained in the file. One entry per line. 92 | batch_size: batch size when processing paragraphs. 93 | cuda: whether to use the gpu. 94 | data_parallel: whether to use multile gpus. 95 | max_loaders: max number of async data loading workers when reading. 96 | (default is fine). 97 | num_workers: number of parallel CPU processes to use for tokenizing 98 | and post processing resuls. 99 | db_config: config for doc db. 100 | ranker_config: config for ranker. 
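        Illustrative usage (the model path is just the project default under
        DATA_DIR and may differ in your setup):
            pipeline = DrQA(reader_model='data/reader/multitask.mdl', cuda=False)
            prediction = pipeline.process("Where is Xi'an Jiaotong University?",
                                          top_n=1, n_docs=5)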
101 | """ 102 | self.batch_size = batch_size 103 | self.max_loaders = max_loaders 104 | self.fixed_candidates = fixed_candidates is not None 105 | self.cuda = cuda 106 | 107 | logger.info('Initializing document ranker...') 108 | ranker_config = ranker_config or {} 109 | ranker_class = ranker_config.get('class', DEFAULTS['ranker']) 110 | ranker_opts = ranker_config.get('options', {}) 111 | self.ranker = ranker_class(**ranker_opts) 112 | 113 | logger.info('Initializing model...') 114 | reader_model = reader_model or DEFAULTS['reader_model'] 115 | self.reader = reader.DocReader.load(reader_model, normalize=False) 116 | if embedding_file: 117 | logger.info('Expanding dictionary...') 118 | words = reader.utils.index_embedding_words(embedding_file) 119 | added = self.reader.expand_dictionary(words) 120 | self.reader.load_embeddings(added, embedding_file) 121 | if cuda: 122 | self.reader.cuda() 123 | if data_parallel: 124 | self.reader.parallelize() 125 | 126 | if not tokenizer: 127 | tok_class = DEFAULTS['tokenizer'] 128 | else: 129 | tok_class = tokenizers.get_class(tokenizer) 130 | annotators = tokenizers.get_annotators_for_model(self.reader) 131 | tok_opts = {'annotators': annotators} 132 | 133 | db_config = db_config or {} 134 | db_class = db_config.get('class', DEFAULTS['db']) 135 | db_opts = db_config.get('options', {}) 136 | 137 | logger.info('Initializing tokenizers and document retrievers...') 138 | self.num_workers = num_workers 139 | self.processes = ProcessPool( 140 | num_workers, 141 | initializer=init, 142 | initargs=(tok_class, tok_opts, db_class, db_opts, fixed_candidates) 143 | ) 144 | 145 | def _split_doc(self, doc): 146 | """Given a doc, split it into chunks (by paragraph).""" 147 | curr = [] 148 | curr_len = 0 149 | for split in regex.split(r'\n+', doc): 150 | split = split.strip() 151 | if len(split) == 0: 152 | continue 153 | # Maybe group paragraphs together until we hit a length limit 154 | if len(curr) > 0 and curr_len + len(split) > self.GROUP_LENGTH: 155 | yield ' '.join(curr) 156 | curr = [] 157 | curr_len = 0 158 | curr.append(split) 159 | curr_len += len(split) 160 | if len(curr) > 0: 161 | yield ' '.join(curr) 162 | 163 | def _get_loader(self, data, num_loaders): 164 | """Return a pytorch data iterator for provided examples.""" 165 | dataset = ReaderDataset(data, self.reader) 166 | sampler = SortedBatchSampler( 167 | dataset.lengths(), 168 | self.batch_size, 169 | shuffle=False 170 | ) 171 | loader = torch.utils.data.DataLoader( 172 | dataset, 173 | batch_size=self.batch_size, 174 | sampler=sampler, 175 | num_workers=num_loaders, 176 | collate_fn=batchify, 177 | pin_memory=self.cuda, 178 | ) 179 | return loader 180 | 181 | def process(self, query, candidates=None, top_n=1, n_docs=5, 182 | return_context=False): 183 | """Run a single query.""" 184 | predictions = self.process_batch( 185 | [query], [candidates] if candidates else None, 186 | top_n, n_docs, return_context 187 | ) 188 | return predictions[0] 189 | 190 | def process_batch(self, queries, candidates=None, top_n=1, n_docs=5, 191 | return_context=False): 192 | """Run a batch of queries (more efficient).""" 193 | t0 = time.time() 194 | logger.info('Processing %d queries...' % len(queries)) 195 | logger.info('Retrieving top %d docs...' % n_docs) 196 | 197 | # Rank documents for queries. 
198 | if len(queries) == 1: 199 | ranked = [self.ranker.closest_docs(queries[0], k=n_docs)] 200 | else: 201 | ranked = self.ranker.batch_closest_docs( 202 | queries, k=n_docs, num_workers=self.num_workers 203 | ) 204 | all_docids, all_doc_scores = zip(*ranked) 205 | 206 | # Flatten document ids and retrieve text from database. 207 | # We remove duplicates for processing efficiency. 208 | flat_docids = list({d for docids in all_docids for d in docids}) 209 | did2didx = {did: didx for didx, did in enumerate(flat_docids)} 210 | doc_texts = self.processes.map(fetch_text, flat_docids) 211 | 212 | # Split and flatten documents. Maintain a mapping from doc (index in 213 | # flat list) to split (index in flat list). 214 | flat_splits = [] 215 | didx2sidx = [] 216 | for text in doc_texts: 217 | splits = self._split_doc(text) 218 | didx2sidx.append([len(flat_splits), -1]) 219 | for split in splits: 220 | flat_splits.append(split) 221 | didx2sidx[-1][1] = len(flat_splits) 222 | 223 | # Push through the tokenizers as fast as possible. 224 | q_tokens = self.processes.map_async(tokenize_text, queries) 225 | s_tokens = self.processes.map_async(tokenize_text, flat_splits) 226 | q_tokens = q_tokens.get() 227 | s_tokens = s_tokens.get() 228 | 229 | # Group into structured example inputs. Examples' ids represent 230 | # mappings to their question, document, and split ids. 231 | examples = [] 232 | for qidx in range(len(queries)): 233 | for rel_didx, did in enumerate(all_docids[qidx]): 234 | start, end = didx2sidx[did2didx[did]] 235 | for sidx in range(start, end): 236 | if (len(q_tokens[qidx].words()) > 0 and 237 | len(s_tokens[sidx].words()) > 0): 238 | examples.append({ 239 | 'id': (qidx, rel_didx, sidx), 240 | 'question': q_tokens[qidx].words(), 241 | 'qlemma': q_tokens[qidx].lemmas(), 242 | 'document': s_tokens[sidx].words(), 243 | 'lemma': s_tokens[sidx].lemmas(), 244 | 'pos': s_tokens[sidx].pos(), 245 | 'ner': s_tokens[sidx].entities(), 246 | }) 247 | 248 | logger.info('Reading %d paragraphs...' % len(examples)) 249 | 250 | # Push all examples through the document reader. 251 | # We decode argmax start/end indices asychronously on CPU. 252 | result_handles = [] 253 | num_loaders = min(self.max_loaders, math.floor(len(examples) / 1e3)) 254 | for batch in self._get_loader(examples, num_loaders): 255 | if candidates or self.fixed_candidates: 256 | batch_cands = [] 257 | for ex_id in batch[-1]: 258 | batch_cands.append({ 259 | 'input': s_tokens[ex_id[2]], 260 | 'cands': candidates[ex_id[0]] if candidates else None 261 | }) 262 | handle = self.reader.predict( 263 | batch, batch_cands, async_pool=self.processes 264 | ) 265 | else: 266 | handle = self.reader.predict(batch, async_pool=self.processes) 267 | result_handles.append((handle, batch[-1], batch[0].size(0))) 268 | 269 | # Iterate through the predictions, and maintain priority queues for 270 | # top scored answers for each question in the batch. 271 | queues = [[] for _ in range(len(queries))] 272 | for result, ex_ids, batch_size in result_handles: 273 | s, e, score = result.get() 274 | for i in range(batch_size): 275 | # We take the top prediction per split. 276 | if len(score[i]) > 0: 277 | item = (score[i][0], ex_ids[i], s[i][0], e[i][0]) 278 | queue = queues[ex_ids[i][0]] 279 | if len(queue) < top_n: 280 | heapq.heappush(queue, item) 281 | else: 282 | heapq.heappushpop(queue, item) 283 | 284 | # Arrange final top prediction data. 
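        # (Each queue is a min-heap holding at most top_n items, so popping
        #  yields answers in ascending span score and the [-1::-1] reversal
        #  below returns them best-first. Every prediction records doc_id, the
        #  untokenized answer span, the retriever's doc_score and the reader's
        #  span_score, plus the surrounding context when return_context=True.)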
285 | all_predictions = [] 286 | for queue in queues: 287 | predictions = [] 288 | while len(queue) > 0: 289 | score, (qidx, rel_didx, sidx), s, e = heapq.heappop(queue) 290 | prediction = { 291 | 'doc_id': all_docids[qidx][rel_didx], 292 | 'span': s_tokens[sidx].slice(s, e + 1).untokenize(), 293 | 'doc_score': float(all_doc_scores[qidx][rel_didx]), 294 | 'span_score': float(score), 295 | } 296 | if return_context: 297 | prediction['context'] = { 298 | 'text': s_tokens[sidx].untokenize(), 299 | 'start': s_tokens[sidx].offsets()[s][0], 300 | 'end': s_tokens[sidx].offsets()[e][1], 301 | } 302 | predictions.append(prediction) 303 | all_predictions.append(predictions[-1::-1]) 304 | 305 | logger.info('Processed %d queries in %.4f (s)' % 306 | (len(queries), time.time() - t0)) 307 | 308 | return all_predictions 309 | -------------------------------------------------------------------------------- /drqa/pipeline/simpleDrQA.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sqlite3 3 | from drqa import retriever 4 | from drqa.retriever.net_retriever import retriver 5 | from drqa.tokenizers.zh_features import normalize, STOPWORDS 6 | import jieba 7 | import logging 8 | import Levenshtein 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | # simplesingel thread DrQA agent 13 | 14 | 15 | class SDrQA(object): 16 | def __init__(self, predictor, rankerPath, dbPath, ebdPath=None): 17 | self.predictor = predictor 18 | self.ranker = retriever.get_class('tfidf')(tfidf_path=rankerPath) 19 | conn = sqlite3.connect(dbPath) 20 | self.db = conn.cursor() 21 | self.filter = filtText('drqa/features/map.txt') 22 | self.score = contextScore(ebdPath) 23 | 24 | def predict(self, query, qasTopN=1, docTopN=1, netTopN=1): 25 | def process(text): 26 | ans = [] 27 | print('=================raw text==================') 28 | print(text) 29 | print('===================================') 30 | lines = self.BrealLine(text) 31 | for line in lines: 32 | predictions = self.predictor.predict( 33 | line, query, candidates=None, top_n=qasTopN) 34 | for p in predictions: 35 | ans.append({ 36 | 'text': line, 37 | 'contextScore': self.score.releventScore(line, query), 38 | 'answer': p[0], 39 | 'answerScore': p[1] 40 | }) 41 | return ans 42 | query = self.NormAndFilt(query) 43 | logger.info('[question after filting : %s ]' % query) 44 | ans = [] 45 | if netTopN > 0: 46 | docs = self.retrieveFromNet(query, k=netTopN) 47 | logger.info('[retreive from net : %s | expect : %s]' % 48 | (len(docs), netTopN)) 49 | for i, text in enumerate(docs): 50 | ans.extend(process(text)) 51 | 52 | logger.info('[retreive from db]') 53 | doc_names, doc_scores = self.ranker.closest_docs(query, k=docTopN) 54 | for i, doc in enumerate(doc_names): 55 | cursor = self.db.execute( 56 | 'SELECT text from documents WHERE id = "%s"' % doc) 57 | for row in cursor: 58 | text = row[0] 59 | ans.extend(process(text)) 60 | return ans 61 | 62 | def retrieveFromNet(self, text, k=1): 63 | texts = retriver(text, k) 64 | return [self.NormAndFilt(t) for t in texts] 65 | 66 | def BrealLine(self, text, minLen=64, maxLen=128): 67 | curr = [] 68 | curr_len = 0 69 | 70 | def replace(match): 71 | s = text[match.start():match.end()] 72 | return s.replace('.', '$$$') 73 | text = re.sub('[[0-9]+\.[0-9]+]', replace, text) 74 | for split in re.split('[\n+\.+\?+\!+]', text): 75 | split = split.strip().replace('$$$', '.') 76 | if len(split) == 0: 77 | continue 78 | # Maybe group paragraphs together until we hit a length limit 79 | if 
len(curr) > 0 and curr_len + len(split) > maxLen: 80 | yield ' '.join(curr) 81 | curr = [] 82 | curr_len = 0 83 | curr.append(split) 84 | curr_len += len(split) 85 | if len(curr) > 0: 86 | yield ' '.join(curr) 87 | 88 | def NormAndFilt(self, text): 89 | return self.filter.filt(normalize(text)) 90 | 91 | 92 | class filtText(object): 93 | def __init__(self, path): 94 | self.table = {} 95 | if not path: 96 | return 97 | 98 | with open(path, encoding='utf-8') as f: 99 | for line in f: 100 | l = line.split(' ') 101 | self.table[l[0]] = l[1] 102 | 103 | def filt(self, text, ng=1): 104 | for key in self.table.keys(): 105 | val = self.table[key].replace('\n', '') 106 | l = text.split(key) 107 | ngram = None 108 | bngram = None 109 | if key in val: 110 | ngram = val[val.find(key) + 111 | len(key): val.find(key) + len(key) + ng] 112 | bngram = val[ngram if val.find( 113 | key) - ng >= 0 else 0:val.find(key)] 114 | idx = 0 115 | tout = '' 116 | for sep in l: 117 | tout += sep 118 | if idx + 1 < len(l) and key in val and not\ 119 | (ngram and len(l[idx + 1]) > len(ngram) and not 120 | l[idx + 1][0:len(ngram)] == ngram) and not\ 121 | (bngram and len(text) > len(bngram) and not 122 | text[-len(bngram):] == bngram): 123 | tout += key 124 | elif idx + 1 < len(l): 125 | tout += val 126 | idx += 1 127 | text = tout 128 | 129 | return text 130 | 131 | 132 | class contextScore(object): 133 | def __init__(self, dictpath=None): 134 | self.dic = {} 135 | if not dictpath: 136 | return 137 | 138 | logger.info('[ loading embedding for text score ]') 139 | with open(dictpath) as f: 140 | for line in f: 141 | parsed = line.rstrip().split(' ') 142 | w = normalize(parsed[0]) 143 | vec = [float(i) for i in parsed[1:]] 144 | self.dic[w] = vec 145 | 146 | def releventScore(self, text, ques, tfidf={}): 147 | def filtWord(li): 148 | # filt out stop words 149 | nl = [] 150 | for l in li: 151 | if l not in STOPWORDS: 152 | nl.append(l) 153 | return nl 154 | 155 | def sims(t, q): 156 | if t in self.dic.keys() and q in self.dic.keys(): 157 | vector1 = self.dic[t] 158 | vector2 = self.dic[q] 159 | dot_product = 0.0 160 | normA = 0.0 161 | normB = 0.0 162 | for a, b in zip(vector1, vector2): 163 | dot_product += a * b 164 | normA += a**2 165 | normB += b**2 166 | if normA == 0.0 or normB == 0.0: 167 | return 0 168 | else: 169 | return dot_product / ((normA * normB)**0.5) 170 | else: 171 | l = max([len(t), len(q)]) 172 | if Levenshtein.distance(t, q) < l: 173 | return (l - Levenshtein.distance(t, q)) / l * 0.7 174 | else: 175 | return 0 176 | 177 | ttoks = filtWord(jieba.lcut_for_search(text)) 178 | qtoks = filtWord(jieba.lcut_for_search(ques)) 179 | 180 | score = 0 181 | if len(ttoks) == 0: 182 | return 0 183 | for tword in ttoks: 184 | for qword in qtoks: 185 | 186 | if tword in tfidf.keys(): 187 | rate = tfidf[tword] 188 | else: 189 | rate = 1 190 | 191 | if tword == qword: 192 | # exact match 193 | score += rate * 2.5 194 | elif sims(tword, qword) > 0.4: 195 | # similar 196 | score += sims(tword, qword) * rate 197 | # remove advantage of length 198 | return score / len(ttoks) / len(qtoks) * 100 199 | -------------------------------------------------------------------------------- /drqa/reader/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
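# Reader subpackage entry point. DEFAULTS below points at the CoreNLP
# tokenizer and the single-task reader model under DATA_DIR; set_default()
# lets callers swap either one before building a Predictor, e.g. with an
# illustrative path: reader.set_default('model', '/path/to/my_reader.mdl')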
7 | 8 | import os 9 | from ..tokenizers import CoreNLPTokenizer 10 | from .. import DATA_DIR 11 | 12 | 13 | DEFAULTS = { 14 | 'tokenizer': CoreNLPTokenizer, 15 | 'model': os.path.join(DATA_DIR, 'reader/single.mdl'), 16 | } 17 | 18 | 19 | def set_default(key, value): 20 | global DEFAULTS 21 | DEFAULTS[key] = value 22 | 23 | from .model import DocReader 24 | from .predictor import Predictor 25 | from . import config 26 | from . import vector 27 | from . import data 28 | from . import utils 29 | -------------------------------------------------------------------------------- /drqa/reader/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Model architecture/optimization options for DrQA document reader.""" 8 | 9 | import argparse 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | # Index of arguments concerning the core model architecture 15 | MODEL_ARCHITECTURE = { 16 | 'model_type', 'embedding_dim', 'hidden_size', 'doc_layers', 17 | 'question_layers', 'rnn_type', 'concat_rnn_layers', 'question_merge', 18 | 'use_qemb', 'use_in_question', 'use_pos', 'use_ner', 'use_lemma', 'use_tf', 19 | 'use_similarity' 20 | } 21 | 22 | # Index of arguments concerning the model optimizer/training 23 | MODEL_OPTIMIZER = { 24 | 'fix_embeddings', 'optimizer', 'learning_rate', 'momentum', 'weight_decay', 25 | 'rnn_padding', 'dropout_rnn', 'dropout_rnn_output', 'dropout_emb', 26 | 'max_len', 'grad_clipping', 'tune_partial', 'use_similarity' 27 | } 28 | 29 | 30 | def str2bool(v): 31 | return v.lower() in ('yes', 'true', 't', '1', 'y') 32 | 33 | 34 | def add_model_args(parser): 35 | parser.register('type', 'bool', str2bool) 36 | 37 | # Model architecture 38 | model = parser.add_argument_group('DrQA Reader Model Architecture') 39 | model.add_argument('--model-type', type=str, default='rnn', 40 | help='Model architecture type') 41 | model.add_argument('--embedding-dim', type=int, default=300, 42 | help='Embedding size if embedding_file is not given') 43 | model.add_argument('--hidden-size', type=int, default=128, 44 | help='Hidden size of RNN units') 45 | model.add_argument('--doc-layers', type=int, default=3, 46 | help='Number of encoding layers for document') 47 | model.add_argument('--question-layers', type=int, default=3, 48 | help='Number of encoding layers for question') 49 | model.add_argument('--rnn-type', type=str, default='lstm', 50 | help='RNN type: LSTM, GRU, or RNN') 51 | 52 | # Model specific details 53 | detail = parser.add_argument_group('DrQA Reader Model Details') 54 | detail.add_argument('--concat-rnn-layers', type='bool', default=True, 55 | help='Combine hidden states from each encoding layer') 56 | detail.add_argument('--question-merge', type=str, default='self_attn', 57 | help='The way of computing the question representation') 58 | detail.add_argument('--use-qemb', type='bool', default=True, 59 | help='Whether to use weighted question embeddings') 60 | detail.add_argument('--use-in-question', type='bool', default=True, 61 | help='Whether to use in_question_* features (including pinyin digit)') 62 | detail.add_argument('--use-pos', type='bool', default=True, 63 | help='Whether to use pos features') 64 | detail.add_argument('--use-ner', type='bool', default=True, 65 | help='Whether to use ner 
features') 66 | detail.add_argument('--use-lemma', type='bool', default=True, 67 | help='Whether to use lemma features (translation in chinese)') 68 | detail.add_argument('--use-tf', type='bool', default=True, 69 | help='Whether to use term frequency features') 70 | detail.add_argument('--use-similarity', type='bool', default=True, 71 | help='Whether to use highest similarity between words ' 72 | + 'as extra lemma feature (experimental)') 73 | 74 | # Optimization details 75 | optim = parser.add_argument_group('DrQA Reader Optimization') 76 | optim.add_argument('--dropout-emb', type=float, default=0.4, 77 | help='Dropout rate for word embeddings') 78 | optim.add_argument('--dropout-rnn', type=float, default=0.4, 79 | help='Dropout rate for RNN states') 80 | optim.add_argument('--dropout-rnn-output', type='bool', default=True, 81 | help='Whether to dropout the RNN output') 82 | optim.add_argument('--optimizer', type=str, default='adamax', 83 | help='Optimizer: sgd or adamax') 84 | optim.add_argument('--learning-rate', type=float, default=0.1, 85 | help='Learning rate for SGD only') 86 | optim.add_argument('--grad-clipping', type=float, default=10, 87 | help='Gradient clipping') 88 | optim.add_argument('--weight-decay', type=float, default=0, 89 | help='Weight decay factor') 90 | optim.add_argument('--momentum', type=float, default=0, 91 | help='Momentum factor') 92 | optim.add_argument('--fix-embeddings', type='bool', default=True, 93 | help='Keep word embeddings fixed (use pretrained)') 94 | optim.add_argument('--tune-partial', type=int, default=0, 95 | help='Backprop through only the top N question words') 96 | optim.add_argument('--rnn-padding', type='bool', default=False, 97 | help='Explicitly account for padding in RNN encoding') 98 | optim.add_argument('--max-len', type=int, default=15, 99 | help='The max span allowed during decoding') 100 | 101 | 102 | def get_model_args(args): 103 | """Filter args for model ones. 104 | 105 | From a args Namespace, return a new Namespace with *only* the args specific 106 | to the model architecture or optimization. (i.e. the ones defined here.) 107 | """ 108 | global MODEL_ARCHITECTURE, MODEL_OPTIMIZER 109 | required_args = MODEL_ARCHITECTURE | MODEL_OPTIMIZER 110 | arg_values = {k: v for k, v in vars(args).items() if k in required_args} 111 | return argparse.Namespace(**arg_values) 112 | 113 | 114 | def override_model_args(old_args, new_args): 115 | """Set args to new parameters. 116 | 117 | Decide which model args to keep and which to override when resolving a set 118 | of saved args and new args. 119 | 120 | We keep the new optimation, but leave the model architecture alone. 121 | """ 122 | global MODEL_OPTIMIZER 123 | old_args, new_args = vars(old_args), vars(new_args) 124 | for k in old_args.keys(): 125 | if k in new_args and old_args[k] != new_args[k]: 126 | if k in MODEL_OPTIMIZER: 127 | logger.info('Overriding saved %s: %s --> %s' % 128 | (k, old_args[k], new_args[k])) 129 | old_args[k] = new_args[k] 130 | else: 131 | logger.info('Keeping saved %s: %s' % (k, old_args[k])) 132 | return argparse.Namespace(**old_args) 133 | -------------------------------------------------------------------------------- /drqa/reader/data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
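# Summary of this module: Dictionary maps normalized tokens to indices with
# the NULL/UNK slots reserved at 0/1, ReaderDataset vectorizes one example per
# __getitem__ via vector.vectorize, and SortedBatchSampler orders examples by
# (document length, question length) descending with a random tie-break so
# that each minibatch groups similarly sized sequences.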
7 | """Data processing/loading helpers.""" 8 | 9 | import numpy as np 10 | import logging 11 | import unicodedata 12 | 13 | from torch.utils.data import Dataset 14 | from torch.utils.data.sampler import Sampler 15 | from .vector import vectorize 16 | from ..tokenizers.zh_features import normalize 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | # ------------------------------------------------------------------------------ 21 | # Dictionary class for tokens. 22 | # ------------------------------------------------------------------------------ 23 | 24 | 25 | class Dictionary(object): 26 | NULL = '' 27 | UNK = '' 28 | START = 2 29 | 30 | @staticmethod 31 | def normalize(token): 32 | return normalize(token) 33 | 34 | def __init__(self): 35 | self.tok2ind = {self.NULL: 0, self.UNK: 1} 36 | self.ind2tok = {0: self.NULL, 1: self.UNK} 37 | 38 | def __len__(self): 39 | return len(self.tok2ind) 40 | 41 | def __iter__(self): 42 | return iter(self.tok2ind) 43 | 44 | def __contains__(self, key): 45 | if type(key) == int: 46 | return key in self.ind2tok 47 | elif type(key) == str: 48 | return self.normalize(key) in self.tok2ind 49 | 50 | def __getitem__(self, key): 51 | if type(key) == int: 52 | return self.ind2tok.get(key, self.UNK) 53 | if type(key) == str: 54 | return self.tok2ind.get(self.normalize(key), 55 | self.tok2ind.get(self.UNK)) 56 | 57 | def __setitem__(self, key, item): 58 | if type(key) == int and type(item) == str: 59 | self.ind2tok[key] = item 60 | elif type(key) == str and type(item) == int: 61 | self.tok2ind[key] = item 62 | else: 63 | raise RuntimeError('Invalid (key, item) types.') 64 | 65 | def add(self, token): 66 | token = self.normalize(token) 67 | if token not in self.tok2ind: 68 | index = len(self.tok2ind) 69 | self.tok2ind[token] = index 70 | self.ind2tok[index] = token 71 | 72 | def tokens(self): 73 | """Get dictionary tokens. 74 | 75 | Return all the words indexed by this dictionary, except for special 76 | tokens. 77 | """ 78 | tokens = [k for k in self.tok2ind.keys() 79 | if k not in {'', ''}] 80 | return tokens 81 | 82 | 83 | # ------------------------------------------------------------------------------ 84 | # PyTorch dataset class for SQuAD (and SQuAD-like) data. 85 | # ------------------------------------------------------------------------------ 86 | 87 | 88 | class ReaderDataset(Dataset): 89 | 90 | def __init__(self, examples, model, single_answer=False): 91 | self.model = model 92 | self.examples = examples 93 | self.single_answer = single_answer 94 | 95 | def __len__(self): 96 | return len(self.examples) 97 | 98 | def __getitem__(self, index): 99 | return vectorize(self.examples[index], self.model, self.single_answer) 100 | 101 | def lengths(self): 102 | return [(len(ex['document']), len(ex['question'])) 103 | for ex in self.examples] 104 | 105 | 106 | # ------------------------------------------------------------------------------ 107 | # PyTorch sampler returning batched of sorted lengths (by doc and question). 
108 | # ------------------------------------------------------------------------------ 109 | 110 | 111 | class SortedBatchSampler(Sampler): 112 | 113 | def __init__(self, lengths, batch_size, shuffle=True): 114 | self.lengths = lengths 115 | self.batch_size = batch_size 116 | self.shuffle = shuffle 117 | 118 | def __iter__(self): 119 | lengths = np.array( 120 | [(-l[0], -l[1], np.random.random()) for l in self.lengths], 121 | dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float_)] 122 | ) 123 | indices = np.argsort(lengths, order=('l1', 'l2', 'rand')) 124 | batches = [indices[i:i + self.batch_size] 125 | for i in range(0, len(indices), self.batch_size)] 126 | if self.shuffle: 127 | np.random.shuffle(batches) 128 | return iter([i for batch in batches for i in batch]) 129 | 130 | def __len__(self): 131 | return len(self.lengths) 132 | -------------------------------------------------------------------------------- /drqa/reader/layers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Definitions of model layers/NN modules""" 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | 14 | 15 | # ------------------------------------------------------------------------------ 16 | # Modules 17 | # ------------------------------------------------------------------------------ 18 | 19 | 20 | class StackedBRNN(nn.Module): 21 | """Stacked Bi-directional RNNs. 22 | 23 | Differs from standard PyTorch library in that it has the option to save 24 | and concat the hidden states between layers. (i.e. the output hidden size 25 | for each sequence input is num_layers * hidden_size). 26 | """ 27 | 28 | def __init__(self, input_size, hidden_size, num_layers, 29 | dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM, 30 | concat_layers=False, padding=False): 31 | super(StackedBRNN, self).__init__() 32 | self.padding = padding 33 | self.dropout_output = dropout_output 34 | self.dropout_rate = dropout_rate 35 | self.num_layers = num_layers 36 | self.concat_layers = concat_layers 37 | self.rnns = nn.ModuleList() 38 | for i in range(num_layers): 39 | input_size = input_size if i == 0 else 2 * hidden_size 40 | self.rnns.append(rnn_type(input_size, hidden_size, 41 | num_layers=1, 42 | bidirectional=True)) 43 | 44 | def forward(self, x, x_mask): 45 | """Encode either padded or non-padded sequences. 46 | 47 | Can choose to either handle or ignore variable length sequences. 48 | Always handle padding in eval. 49 | 50 | Args: 51 | x: batch * len * hdim 52 | x_mask: batch * len (1 for padding, 0 for true) 53 | Output: 54 | x_encoded: batch * len * hdim_encoded 55 | """ 56 | if x_mask.data.sum() == 0: 57 | # No padding necessary. 58 | output = self._forward_unpadded(x, x_mask) 59 | elif self.padding or not self.training: 60 | # Pad if we care or if its during eval. 61 | output = self._forward_padded(x, x_mask) 62 | else: 63 | # We don't care. 
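            # (Training-time fast path: with rnn_padding off, a padded batch
            #  is still run through the unpadded encoder, trading exact hidden
            #  states at padded positions for speed; evaluation always takes
            #  the packed-sequence branch above, as the docstring notes.)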
64 | output = self._forward_unpadded(x, x_mask) 65 | 66 | return output.contiguous() 67 | 68 | def _forward_unpadded(self, x, x_mask): 69 | """Faster encoding that ignores any padding.""" 70 | # Transpose batch and sequence dims 71 | x = x.transpose(0, 1) 72 | 73 | # Encode all layers 74 | outputs = [x] 75 | for i in range(self.num_layers): 76 | rnn_input = outputs[-1] 77 | 78 | # Apply dropout to hidden input 79 | if self.dropout_rate > 0: 80 | rnn_input = F.dropout(rnn_input, 81 | p=self.dropout_rate, 82 | training=self.training) 83 | # Forward 84 | rnn_output = self.rnns[i](rnn_input)[0] 85 | outputs.append(rnn_output) 86 | 87 | # Concat hidden layers 88 | if self.concat_layers: 89 | output = torch.cat(outputs[1:], 2) 90 | else: 91 | output = outputs[-1] 92 | 93 | # Transpose back 94 | output = output.transpose(0, 1) 95 | 96 | # Dropout on output layer 97 | if self.dropout_output and self.dropout_rate > 0: 98 | output = F.dropout(output, 99 | p=self.dropout_rate, 100 | training=self.training) 101 | return output 102 | 103 | def _forward_padded(self, x, x_mask): 104 | """Slower (significantly), but more precise, encoding that handles 105 | padding. 106 | """ 107 | # Compute sorted sequence lengths 108 | lengths = x_mask.data.eq(0).long().sum(1).squeeze() 109 | _, idx_sort = torch.sort(lengths, dim=0, descending=True) 110 | _, idx_unsort = torch.sort(idx_sort, dim=0) 111 | 112 | lengths = list(lengths[idx_sort]) 113 | idx_sort = Variable(idx_sort) 114 | idx_unsort = Variable(idx_unsort) 115 | 116 | # Sort x 117 | x = x.index_select(0, idx_sort) 118 | 119 | # Transpose batch and sequence dims 120 | x = x.transpose(0, 1) 121 | 122 | # Pack it up 123 | rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths) 124 | 125 | # Encode all layers 126 | outputs = [rnn_input] 127 | for i in range(self.num_layers): 128 | rnn_input = outputs[-1] 129 | 130 | # Apply dropout to input 131 | if self.dropout_rate > 0: 132 | dropout_input = F.dropout(rnn_input.data, 133 | p=self.dropout_rate, 134 | training=self.training) 135 | rnn_input = nn.utils.rnn.PackedSequence(dropout_input, 136 | rnn_input.batch_sizes) 137 | outputs.append(self.rnns[i](rnn_input)[0]) 138 | 139 | # Unpack everything 140 | for i, o in enumerate(outputs[1:], 1): 141 | outputs[i] = nn.utils.rnn.pad_packed_sequence(o)[0] 142 | 143 | # Concat hidden layers or take final 144 | if self.concat_layers: 145 | output = torch.cat(outputs[1:], 2) 146 | else: 147 | output = outputs[-1] 148 | 149 | # Transpose and unsort 150 | output = output.transpose(0, 1) 151 | output = output.index_select(0, idx_unsort) 152 | 153 | # Pad up to original batch sequence length 154 | if output.size(1) != x_mask.size(1): 155 | padding = torch.zeros(output.size(0), 156 | x_mask.size(1) - output.size(1), 157 | output.size(2)).type(output.data.type()) 158 | output = torch.cat([output, Variable(padding)], 1) 159 | 160 | # Dropout on output layer 161 | if self.dropout_output and self.dropout_rate > 0: 162 | output = F.dropout(output, 163 | p=self.dropout_rate, 164 | training=self.training) 165 | return output 166 | 167 | 168 | class SeqAttnMatch(nn.Module): 169 | """Given sequences X and Y, match sequence Y to each element in X. 
170 | 171 | * o_i = sum(alpha_j * y_j) for i in X 172 | * alpha_j = softmax(y_j * x_i) 173 | """ 174 | 175 | def __init__(self, input_size, identity=False): 176 | super(SeqAttnMatch, self).__init__() 177 | if not identity: 178 | self.linear = nn.Linear(input_size, input_size) 179 | else: 180 | self.linear = None 181 | 182 | def forward(self, x, y, y_mask): 183 | """ 184 | Args: 185 | x: batch * len1 * hdim 186 | y: batch * len2 * hdim 187 | y_mask: batch * len2 (1 for padding, 0 for true) 188 | Output: 189 | matched_seq: batch * len1 * hdim 190 | """ 191 | # Project vectors 192 | if self.linear: 193 | x_proj = self.linear(x.view(-1, x.size(2))).view(x.size()) 194 | x_proj = F.relu(x_proj) 195 | y_proj = self.linear(y.view(-1, y.size(2))).view(y.size()) 196 | y_proj = F.relu(y_proj) 197 | else: 198 | x_proj = x 199 | y_proj = y 200 | 201 | # Compute scores 202 | scores = x_proj.bmm(y_proj.transpose(2, 1)) 203 | 204 | # Mask padding 205 | y_mask = y_mask.unsqueeze(1).expand(scores.size()) 206 | scores.data.masked_fill_(y_mask.data, -float('inf')) 207 | 208 | # Normalize with softmax 209 | alpha_flat = F.softmax(scores.view(-1, y.size(1))) 210 | alpha = alpha_flat.view(-1, x.size(1), y.size(1)) 211 | 212 | # Take weighted average 213 | matched_seq = alpha.bmm(y) 214 | return matched_seq 215 | 216 | 217 | class BilinearSeqAttn(nn.Module): 218 | """A bilinear attention layer over a sequence X w.r.t y: 219 | 220 | * o_i = softmax(x_i'Wy) for x_i in X. 221 | 222 | Optionally don't normalize output weights. 223 | """ 224 | 225 | def __init__(self, x_size, y_size, identity=False, normalize=True): 226 | super(BilinearSeqAttn, self).__init__() 227 | self.normalize = normalize 228 | 229 | # If identity is true, we just use a dot product without transformation. 230 | if not identity: 231 | self.linear = nn.Linear(y_size, x_size) 232 | else: 233 | self.linear = None 234 | 235 | def forward(self, x, y, x_mask): 236 | """ 237 | Args: 238 | x: batch * len * hdim1 239 | y: batch * hdim2 240 | x_mask: batch * len (1 for padding, 0 for true) 241 | Output: 242 | alpha = batch * len 243 | """ 244 | Wy = self.linear(y) if self.linear is not None else y 245 | xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2) 246 | xWy.data.masked_fill_(x_mask.data, -float('inf')) 247 | if self.normalize: 248 | if self.training: 249 | # In training we output log-softmax for NLL 250 | alpha = F.log_softmax(xWy) 251 | else: 252 | # ...Otherwise 0-1 probabilities 253 | alpha = F.softmax(xWy) 254 | else: 255 | alpha = xWy.exp() 256 | return alpha 257 | 258 | 259 | class LinearSeqAttn(nn.Module): 260 | """Self attention over a sequence: 261 | 262 | * o_i = softmax(Wx_i) for x_i in X. 263 | """ 264 | 265 | def __init__(self, input_size): 266 | super(LinearSeqAttn, self).__init__() 267 | self.linear = nn.Linear(input_size, 1) 268 | 269 | def forward(self, x, x_mask): 270 | """ 271 | Args: 272 | x: batch * len * hdim 273 | x_mask: batch * len (1 for padding, 0 for true) 274 | Output: 275 | alpha: batch * len 276 | """ 277 | x_flat = x.view(-1, x.size(-1)) 278 | scores = self.linear(x_flat).view(x.size(0), x.size(1)) 279 | scores.data.masked_fill_(x_mask.data, -float('inf')) 280 | alpha = F.softmax(scores) 281 | return alpha 282 | 283 | 284 | # ------------------------------------------------------------------------------ 285 | # Functional 286 | # ------------------------------------------------------------------------------ 287 | 288 | 289 | def uniform_weights(x, x_mask): 290 | """Return uniform weights over non-masked x (a sequence of vectors). 
291 | 292 | Args: 293 | x: batch * len * hdim 294 | x_mask: batch * len (1 for padding, 0 for true) 295 | Output: 296 | x_avg: batch * hdim 297 | """ 298 | alpha = Variable(torch.ones(x.size(0), x.size(1))) 299 | if x.data.is_cuda: 300 | alpha = alpha.cuda() 301 | alpha = alpha * x_mask.eq(0).float() 302 | alpha = alpha / alpha.sum(1).expand(alpha.size()) 303 | return alpha 304 | 305 | 306 | def weighted_avg(x, weights): 307 | """Return a weighted average of x (a sequence of vectors). 308 | 309 | Args: 310 | x: batch * len * hdim 311 | weights: batch * len, sum(dim = 1) = 1 312 | Output: 313 | x_avg: batch * hdim 314 | """ 315 | return weights.unsqueeze(1).bmm(x).squeeze(1) 316 | -------------------------------------------------------------------------------- /drqa/reader/predictor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """DrQA Document Reader predictor""" 8 | 9 | import logging 10 | 11 | from multiprocessing import Pool as ProcessPool 12 | from multiprocessing.util import Finalize 13 | 14 | from .vector import vectorize, batchify 15 | from .model import DocReader 16 | from . import DEFAULTS, utils 17 | from .. import tokenizers 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | # ------------------------------------------------------------------------------ 23 | # Tokenize + annotate 24 | # ------------------------------------------------------------------------------ 25 | 26 | PROCESS_TOK = None 27 | 28 | 29 | def init(tokenizer_class, annotators): 30 | global PROCESS_TOK 31 | PROCESS_TOK = tokenizer_class(annotators=annotators) 32 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 33 | 34 | 35 | def tokenize(text): 36 | global PROCESS_TOK 37 | return PROCESS_TOK.tokenize(text) 38 | 39 | 40 | # ------------------------------------------------------------------------------ 41 | # Predictor class. 42 | # ------------------------------------------------------------------------------ 43 | 44 | 45 | class Predictor(object): 46 | """Load a pretrained DocReader model and predict inputs on the fly.""" 47 | 48 | def __init__(self, model=None, tokenizer=None, 49 | embedding_file=None, num_workers=None): 50 | """ 51 | Args: 52 | model: path to saved model file 53 | tokenizer: option string to select tokenizer class 54 | embedding_file: if provided, will expand dictionary to use all 55 | available pretrained vectors in this file. 56 | num_workers: number of CPU processes to use to preprocess batches. 
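        Illustrative usage (the model path mirrors the reader default and is
        only a placeholder; num_workers=0 skips the multiprocessing pool):
            p = Predictor(model='data/reader/single.mdl', num_workers=0)
            spans = p.predict(document_text, question_text, top_n=1)
            # spans is a list of (answer_span, score) tuples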
57 | """ 58 | logger.info('Initializing model...') 59 | self.model = DocReader.load(model or DEFAULTS['model']) 60 | 61 | if embedding_file: 62 | logger.info('Expanding dictionary...') 63 | words = utils.index_embedding_words(embedding_file) 64 | added = self.model.expand_dictionary(words) 65 | self.model.load_embeddings(added, embedding_file) 66 | 67 | logger.info('Initializing tokenizer...') 68 | annotators = tokenizers.get_annotators_for_model(self.model) 69 | if not tokenizer: 70 | tokenizer_class = DEFAULTS['tokenizer'] 71 | else: 72 | tokenizer_class = tokenizers.get_class(tokenizer) 73 | 74 | if num_workers is None or num_workers > 0: 75 | self.workers = ProcessPool( 76 | num_workers, 77 | initializer=init, 78 | initargs=(tokenizer_class, annotators), 79 | ) 80 | else: 81 | self.workers = None 82 | self.tokenizer = tokenizer_class(annotators=annotators) 83 | 84 | def predict(self, document, question, candidates=None, top_n=1): 85 | """Predict a single document - question pair.""" 86 | results = self.predict_batch([(document, question, candidates,)], top_n) 87 | return results[0] 88 | 89 | def predict_batch(self, batch, top_n=1): 90 | """Predict a batch of document - question pairs.""" 91 | documents, questions, candidates = [], [], [] 92 | for b in batch: 93 | documents.append(b[0]) 94 | questions.append(b[1]) 95 | candidates.append(b[2] if len(b) == 3 else None) 96 | candidates = candidates if any(candidates) else None 97 | 98 | # Tokenize the inputs, perhaps multi-processed. 99 | if self.workers: 100 | q_tokens = self.workers.map_async(tokenize, questions) 101 | d_tokens = self.workers.map_async(tokenize, documents) 102 | q_tokens = list(q_tokens.get()) 103 | d_tokens = list(d_tokens.get()) 104 | else: 105 | q_tokens = list(map(self.tokenizer.tokenize, questions)) 106 | d_tokens = list(map(self.tokenizer.tokenize, documents)) 107 | 108 | examples = [] 109 | for i in range(len(questions)): 110 | examples.append({ 111 | 'id': i, 112 | 'question': q_tokens[i].words(), 113 | 'qlemma': q_tokens[i].lemmas(), 114 | 'document': d_tokens[i].words(), 115 | 'lemma': d_tokens[i].lemmas(), 116 | 'pos': d_tokens[i].pos(), 117 | 'ner': d_tokens[i].entities(), 118 | }) 119 | 120 | # Stick document tokens in candidates for decoding 121 | if candidates: 122 | candidates = [{'input': d_tokens[i], 'cands': candidates[i]} 123 | for i in range(len(candidates))] 124 | 125 | # Build the batch and run it through the model 126 | batch_exs = batchify([vectorize(e, self.model) for e in examples]) 127 | s, e, score = self.model.predict(batch_exs, candidates, top_n) 128 | 129 | # Retrieve the predicted spans 130 | results = [] 131 | for i in range(len(s)): 132 | predictions = [] 133 | for j in range(len(s[i])): 134 | span = d_tokens[i].slice(s[i][j], e[i][j] + 1).untokenize() 135 | predictions.append((span, score[i][j])) 136 | results.append(predictions) 137 | return results 138 | 139 | def cuda(self): 140 | self.model.cuda() 141 | 142 | def cpu(self): 143 | self.model.cpu() 144 | -------------------------------------------------------------------------------- /drqa/reader/rnn_reader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Implementation of the RNN based DrQA reader.""" 8 | 9 | import torch 10 | import torch.nn as nn 11 | from . 
import layers 12 | 13 | 14 | # ------------------------------------------------------------------------------ 15 | # Network 16 | # ------------------------------------------------------------------------------ 17 | 18 | 19 | class RnnDocReader(nn.Module): 20 | RNN_TYPES = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN} 21 | 22 | def __init__(self, args, normalize=True): 23 | super(RnnDocReader, self).__init__() 24 | # Store config 25 | self.args = args 26 | 27 | # Word embeddings (+1 for padding) 28 | self.embedding = nn.Embedding(args.vocab_size, 29 | args.embedding_dim, 30 | padding_idx=0) 31 | 32 | # Projection for attention weighted question 33 | if args.use_qemb: 34 | self.qemb_match = layers.SeqAttnMatch(args.embedding_dim) 35 | 36 | # Input size to RNN: word emb + question emb + manual features 37 | doc_input_size = args.embedding_dim + args.num_features 38 | if args.use_qemb: 39 | doc_input_size += args.embedding_dim 40 | 41 | # RNN document encoder 42 | self.doc_rnn = layers.StackedBRNN( 43 | input_size=doc_input_size, 44 | hidden_size=args.hidden_size, 45 | num_layers=args.doc_layers, 46 | dropout_rate=args.dropout_rnn, 47 | dropout_output=args.dropout_rnn_output, 48 | concat_layers=args.concat_rnn_layers, 49 | rnn_type=self.RNN_TYPES[args.rnn_type], 50 | padding=args.rnn_padding, 51 | ) 52 | 53 | # RNN question encoder 54 | self.question_rnn = layers.StackedBRNN( 55 | input_size=args.embedding_dim, 56 | hidden_size=args.hidden_size, 57 | num_layers=args.question_layers, 58 | dropout_rate=args.dropout_rnn, 59 | dropout_output=args.dropout_rnn_output, 60 | concat_layers=args.concat_rnn_layers, 61 | rnn_type=self.RNN_TYPES[args.rnn_type], 62 | padding=args.rnn_padding, 63 | ) 64 | 65 | # Output sizes of rnn encoders 66 | doc_hidden_size = 2 * args.hidden_size 67 | question_hidden_size = 2 * args.hidden_size 68 | if args.concat_rnn_layers: 69 | doc_hidden_size *= args.doc_layers 70 | question_hidden_size *= args.question_layers 71 | 72 | # Question merging 73 | if args.question_merge not in ['avg', 'self_attn']: 74 | raise NotImplementedError('merge_mode = %s' % args.merge_mode) 75 | if args.question_merge == 'self_attn': 76 | self.self_attn = layers.LinearSeqAttn(question_hidden_size) 77 | 78 | # Bilinear attention for span start/end 79 | self.start_attn = layers.BilinearSeqAttn( 80 | doc_hidden_size, 81 | question_hidden_size, 82 | normalize=normalize, 83 | ) 84 | self.end_attn = layers.BilinearSeqAttn( 85 | doc_hidden_size, 86 | question_hidden_size, 87 | normalize=normalize, 88 | ) 89 | 90 | def forward(self, x1, x1_f, x1_mask, x2, x2_mask): 91 | """Inputs: 92 | x1 = document word indices [batch * len_d] 93 | x1_f = document word features indices [batch * len_d * nfeat] 94 | x1_mask = document padding mask [batch * len_d] 95 | x2 = question word indices [batch * len_q] 96 | x2_mask = question padding mask [batch * len_q] 97 | """ 98 | # Embed both document and question 99 | x1_emb = self.embedding(x1) 100 | x2_emb = self.embedding(x2) 101 | 102 | # Dropout on embeddings 103 | if self.args.dropout_emb > 0: 104 | x1_emb = nn.functional.dropout(x1_emb, p=self.args.dropout_emb, 105 | training=self.training) 106 | x2_emb = nn.functional.dropout(x2_emb, p=self.args.dropout_emb, 107 | training=self.training) 108 | 109 | # Form document encoding inputs 110 | drnn_input = [x1_emb] 111 | 112 | # Add attention-weighted question representation 113 | if self.args.use_qemb: 114 | x2_weighted_emb = self.qemb_match(x1_emb, x2_emb, x2_mask) 115 | drnn_input.append(x2_weighted_emb) 116 | 117 | # 
Add manual features 118 | if self.args.num_features > 0: 119 | drnn_input.append(x1_f) 120 | 121 | # Encode document with RNN 122 | doc_hiddens = self.doc_rnn(torch.cat(drnn_input, 2), x1_mask) 123 | 124 | # Encode question with RNN + merge hiddens 125 | question_hiddens = self.question_rnn(x2_emb, x2_mask) 126 | if self.args.question_merge == 'avg': 127 | q_merge_weights = layers.uniform_weights(question_hiddens, x2_mask) 128 | elif self.args.question_merge == 'self_attn': 129 | q_merge_weights = self.self_attn(question_hiddens, x2_mask) 130 | question_hidden = layers.weighted_avg(question_hiddens, q_merge_weights) 131 | 132 | # Predict start and end positions 133 | start_scores = self.start_attn(doc_hiddens, question_hidden, x1_mask) 134 | end_scores = self.end_attn(doc_hiddens, question_hidden, x1_mask) 135 | return start_scores, end_scores 136 | -------------------------------------------------------------------------------- /drqa/reader/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """DrQA reader utilities.""" 8 | 9 | import json 10 | import time 11 | import logging 12 | import string 13 | import regex as re 14 | import random 15 | 16 | from collections import Counter 17 | from .data import Dictionary 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | # ------------------------------------------------------------------------------ 23 | # Data loading 24 | # ------------------------------------------------------------------------------ 25 | 26 | 27 | def load_data(args, filename, skip_no_answer=False): 28 | """Load examples from preprocessed file. 29 | One example per line, JSON encoded. 30 | """ 31 | # Load JSON lines 32 | with open(filename) as f: 33 | examples = [json.loads(line) for line in f] 34 | random.shuffle(examples) 35 | # Make case insensitive? 36 | if args.uncased_question or args.uncased_doc: 37 | for ex in examples: 38 | if args.uncased_question: 39 | ex['question'] = [w.lower() for w in ex['question']] 40 | if args.uncased_doc: 41 | ex['document'] = [w.lower() for w in ex['document']] 42 | 43 | # Skip unparsed (start/end) examples 44 | if skip_no_answer: 45 | examples = [ex for ex in examples if len(ex['answers']) > 0] 46 | 47 | return examples 48 | 49 | 50 | def load_text(filename): 51 | """Load the paragraphs only of a SQuAD dataset. Store as qid -> text.""" 52 | # Load JSON file 53 | with open(filename) as f: 54 | examples = json.load(f)['data'] 55 | 56 | texts = {} 57 | for article in examples: 58 | for paragraph in article['paragraphs']: 59 | for qa in paragraph['qas']: 60 | texts[qa['id']] = paragraph['context'] 61 | return texts 62 | 63 | 64 | def load_answers(filename): 65 | """Load the answers only of a SQuAD dataset. 
Store as qid -> [answers].""" 66 | # Load JSON file 67 | with open(filename) as f: 68 | examples = json.load(f)['data'] 69 | 70 | ans = {} 71 | for article in examples: 72 | for paragraph in article['paragraphs']: 73 | for qa in paragraph['qas']: 74 | ans[qa['id']] = list(map(lambda x: x['text'], qa['answers'])) 75 | return ans 76 | 77 | 78 | # ------------------------------------------------------------------------------ 79 | # Dictionary building 80 | # ------------------------------------------------------------------------------ 81 | 82 | 83 | def index_embedding_words(embedding_file): 84 | """Put all the words in embedding_file into a set.""" 85 | words = set() 86 | with open(embedding_file) as f: 87 | for line in f: 88 | w = Dictionary.normalize(line.rstrip().split(' ')[0]) 89 | words.add(w) 90 | return words 91 | 92 | 93 | def load_words(args, examples): 94 | """Iterate and index all the words in examples (documents + questions).""" 95 | def _insert(iterable): 96 | for w in iterable: 97 | w = Dictionary.normalize(w) 98 | if valid_words and w not in valid_words: 99 | continue 100 | words.add(w) 101 | 102 | if args.restrict_vocab and args.embedding_file: 103 | logger.info('Restricting to words in %s' % args.embedding_file) 104 | valid_words = index_embedding_words(args.embedding_file) 105 | logger.info('Num words in set = %d' % len(valid_words)) 106 | else: 107 | valid_words = None 108 | 109 | words = set() 110 | for ex in examples: 111 | _insert(ex['question']) 112 | _insert(ex['document']) 113 | return words 114 | 115 | 116 | def build_word_dict(args, examples): 117 | """Return a dictionary from question and document words in 118 | provided examples. 119 | """ 120 | word_dict = Dictionary() 121 | for w in load_words(args, examples): 122 | word_dict.add(w) 123 | return word_dict 124 | 125 | 126 | def top_question_words(args, examples, word_dict): 127 | """Count and return the most common question words in provided examples.""" 128 | word_count = Counter() 129 | for ex in examples: 130 | for w in ex['question']: 131 | w = Dictionary.normalize(w) 132 | if w in word_dict: 133 | word_count.update([w]) 134 | return word_count.most_common(args.tune_partial) 135 | 136 | 137 | def build_feature_dict(args, examples): 138 | """Index features (one hot) from fields in examples and options.""" 139 | def _insert(feature): 140 | if feature not in feature_dict: 141 | feature_dict[feature] = len(feature_dict) 142 | 143 | feature_dict = {} 144 | 145 | # Exact match features 146 | if args.use_in_question: 147 | _insert('in_question') 148 | _insert('in_question_uncased') 149 | # this is transfered for chinese pinyin and number verification 150 | if args.use_lemma: 151 | _insert('in_question_lemma') 152 | 153 | # Part of speech tag features 154 | if args.use_pos: 155 | for ex in examples: 156 | for w in ex['pos']: 157 | _insert('pos=%s' % w) 158 | 159 | # Named entity tag features 160 | if args.use_ner: 161 | for ex in examples: 162 | for w in ex['ner']: 163 | _insert('ner=%s' % w) 164 | 165 | # Term frequency feature 166 | if args.use_tf: 167 | _insert('tf') 168 | return feature_dict 169 | 170 | 171 | # ------------------------------------------------------------------------------ 172 | # Evaluation. Follows official evalutation script for v1.1 of the SQuAD dataset. 
173 | # ------------------------------------------------------------------------------ 174 | 175 | 176 | def normalize_answer(s): 177 | """Lower text and remove punctuation, articles and extra whitespace.""" 178 | def remove_articles(text): 179 | return re.sub(r'\b(a|an|the)\b', ' ', text) 180 | 181 | def white_space_fix(text): 182 | return ' '.join(text.split()) 183 | 184 | def remove_punc(text): 185 | exclude = set(string.punctuation) 186 | return ''.join(ch for ch in text if ch not in exclude) 187 | 188 | def lower(text): 189 | return text.lower() 190 | 191 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 192 | 193 | 194 | def f1_score(prediction, ground_truth): 195 | """Compute the geometric mean of precision and recall for answer tokens.""" 196 | prediction_tokens = normalize_answer(prediction).split() 197 | ground_truth_tokens = normalize_answer(ground_truth).split() 198 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 199 | num_same = sum(common.values()) 200 | if num_same == 0: 201 | return 0 202 | precision = 1.0 * num_same / len(prediction_tokens) 203 | recall = 1.0 * num_same / len(ground_truth_tokens) 204 | f1 = (2 * precision * recall) / (precision + recall) 205 | return f1 206 | 207 | 208 | def exact_match_score(prediction, ground_truth): 209 | """Check if the prediction is a (soft) exact match with the ground truth.""" 210 | return normalize_answer(prediction) == normalize_answer(ground_truth) 211 | 212 | 213 | def regex_match_score(prediction, pattern): 214 | """Check if the prediction matches the given regular expression.""" 215 | try: 216 | compiled = re.compile( 217 | pattern, 218 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE 219 | ) 220 | except BaseException: 221 | logger.warn('Regular expression failed to compile: %s' % pattern) 222 | return False 223 | return compiled.match(prediction) is not None 224 | 225 | 226 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 227 | """Given a prediction and multiple valid answers, return the score of 228 | the best prediction-answer_n pair given a metric function. 
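    For example (illustrative values only):

        metric_max_over_ground_truths(f1_score, 'the Qian Xuesen Library',
                                      ['Qian Xuesen Library', 'Main Library'])

    returns 1.0, the score of the best-matching answer once normalize_answer
    has lowercased both strings and stripped the leading article.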
229 | """ 230 | scores_for_ground_truths = [] 231 | for ground_truth in ground_truths: 232 | score = metric_fn(prediction, ground_truth) 233 | scores_for_ground_truths.append(score) 234 | return max(scores_for_ground_truths) 235 | 236 | 237 | # ------------------------------------------------------------------------------ 238 | # Utility classes 239 | # ------------------------------------------------------------------------------ 240 | 241 | 242 | class AverageMeter(object): 243 | """Computes and stores the average and current value.""" 244 | 245 | def __init__(self): 246 | self.reset() 247 | 248 | def reset(self): 249 | self.val = 0 250 | self.avg = 0 251 | self.sum = 0 252 | self.count = 0 253 | 254 | def update(self, val, n=1): 255 | self.val = val 256 | self.sum += val * n 257 | self.count += n 258 | self.avg = self.sum / self.count 259 | 260 | 261 | class Timer(object): 262 | """Computes elapsed time.""" 263 | 264 | def __init__(self): 265 | self.running = True 266 | self.total = 0 267 | self.start = time.time() 268 | 269 | def reset(self): 270 | self.running = True 271 | self.total = 0 272 | self.start = time.time() 273 | return self 274 | 275 | def resume(self): 276 | if not self.running: 277 | self.running = True 278 | self.start = time.time() 279 | return self 280 | 281 | def stop(self): 282 | if self.running: 283 | self.running = False 284 | self.total += time.time() - self.start 285 | return self 286 | 287 | def time(self): 288 | if self.running: 289 | return self.total + time.time() - self.start 290 | return self.total 291 | -------------------------------------------------------------------------------- /drqa/reader/vector.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
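# The two helpers below turn preprocessed examples into model inputs:
# vectorize() maps a single example to document/question index tensors plus a
# dense per-token feature matrix, and batchify() pads a list of vectorized
# examples into the (x1, x1_f, x1_mask, x2, x2_mask, ...) tensors consumed by
# RnnDocReader.forward().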
7 | """Functions for putting examples into torch format.""" 8 | 9 | from collections import Counter 10 | import torch 11 | from ..tokenizers.zh_features import similar, STOPWORDS 12 | compareHan = similar().compare 13 | 14 | 15 | def vectorize(ex, model, single_answer=False): 16 | """Torchify a single example.""" 17 | args = model.args 18 | word_dict = model.word_dict 19 | feature_dict = model.feature_dict 20 | embedding = model.network.embedding.weight.data 21 | # Index words 22 | document = torch.LongTensor([word_dict[w] for w in ex['document']]) 23 | question = torch.LongTensor([word_dict[w] for w in ex['question']]) 24 | 25 | # Create extra features vector 26 | if len(feature_dict) > 0: 27 | features = torch.zeros(len(ex['document']), len(feature_dict)) 28 | else: 29 | features = None 30 | 31 | # f_{exact_match} 32 | if args.use_in_question: 33 | def cos(vector1, vector2): 34 | dot_product = 0.0 35 | normA = 0.0 36 | normB = 0.0 37 | for a, b in zip(vector1, vector2): 38 | dot_product += a * b 39 | normA += a**2 40 | normB += b**2 41 | if normA == 0.0 or normB == 0.0: 42 | return None 43 | else: 44 | return dot_product / ((normA * normB)**0.5) 45 | q_words_cased = {w for w in ex['question']} 46 | q_words_uncased = {w.lower() for w in ex['question']} 47 | q_lemma = {w for w in ex['qlemma']} if args.use_lemma else None 48 | for i in range(len(ex['document'])): 49 | if ex['document'][i] in STOPWORDS: 50 | continue 51 | if ex['document'][i] in q_words_cased: 52 | features[i][feature_dict['in_question']] = 1.0 53 | 54 | for _w2 in q_words_uncased: 55 | # abandoned function : use lowest distance bewtween question 56 | # and answer as lemma. 57 | # if args.use_lemma: 58 | # # if args.use_lemma: 59 | # v1 = embedding[word_dict[ex['document'][i].lower()]] 60 | # v2 = embedding[word_dict[_w2]] 61 | # score = cos(v1, v2) 62 | # if score > features[i][feature_dict['in_question_lemma']]: 63 | # features[i][feature_dict['in_question_lemma']] = score 64 | 65 | if compareHan(ex['document'][i].lower(), _w2) == 1.0: 66 | features[i][feature_dict['in_question_uncased']] = 1.0 67 | break 68 | 69 | if q_lemma and ex['lemma'][i] in q_lemma: 70 | # lemma in Chinese is defined (replaced) as English translation 71 | features[i][feature_dict['in_question_lemma']] = 1.0 72 | 73 | # f_{token} (POS) 74 | if args.use_pos: 75 | for i, w in enumerate(ex['pos']): 76 | f = 'pos=%s' % w 77 | if f in feature_dict: 78 | features[i][feature_dict[f]] = 1.0 79 | 80 | # f_{token} (NER) 81 | if args.use_ner: 82 | for i, w in enumerate(ex['ner']): 83 | f = 'ner=%s' % w 84 | if f in feature_dict: 85 | features[i][feature_dict[f]] = 1.0 86 | 87 | # f_{token} (TF) 88 | if args.use_tf: 89 | counter = Counter([w.lower() for w in ex['document']]) 90 | l = len(ex['document']) 91 | for i, w in enumerate(ex['document']): 92 | features[i][feature_dict['tf']] = counter[w.lower()] * 1.0 / l 93 | 94 | # Maybe return without target 95 | if 'answers' not in ex: 96 | return document, features, question, ex['id'] 97 | 98 | # ...or with target(s) (might still be empty if answers is empty) 99 | if single_answer: 100 | assert(len(ex['answers']) > 0) 101 | start = torch.LongTensor(1).fill_(ex['answers'][0][0]) 102 | end = torch.LongTensor(1).fill_(ex['answers'][0][1]) 103 | else: 104 | start = [a[0] for a in ex['answers']] 105 | end = [a[1] for a in ex['answers']] 106 | 107 | # print(start) 108 | # print(end) 109 | return document, features, question, start, end, ex['id'] 110 | 111 | 112 | def batchify(batch): 113 | """Gather a batch of individual 
examples into one batch.""" 114 | NUM_INPUTS = 3 115 | NUM_TARGETS = 2 116 | NUM_EXTRA = 1 117 | 118 | ids = [ex[-1] for ex in batch] 119 | docs = [ex[0] for ex in batch] 120 | features = [ex[1] for ex in batch] 121 | questions = [ex[2] for ex in batch] 122 | 123 | # Batch documents and features 124 | max_length = max([d.size(0) for d in docs]) 125 | x1 = torch.LongTensor(len(docs), max_length).zero_() 126 | x1_mask = torch.ByteTensor(len(docs), max_length).fill_(1) 127 | if features[0] is None: 128 | x1_f = None 129 | else: 130 | x1_f = torch.zeros(len(docs), max_length, features[0].size(1)) 131 | for i, d in enumerate(docs): 132 | x1[i, :d.size(0)].copy_(d) 133 | x1_mask[i, :d.size(0)].fill_(0) 134 | if x1_f is not None: 135 | x1_f[i, :d.size(0)].copy_(features[i]) 136 | 137 | # Batch questions 138 | max_length = max([q.size(0) for q in questions]) 139 | x2 = torch.LongTensor(len(questions), max_length).zero_() 140 | x2_mask = torch.ByteTensor(len(questions), max_length).fill_(1) 141 | for i, q in enumerate(questions): 142 | x2[i, :q.size(0)].copy_(q) 143 | x2_mask[i, :q.size(0)].fill_(0) 144 | 145 | # Maybe return without targets 146 | if len(batch[0]) == NUM_INPUTS + NUM_EXTRA: 147 | return x1, x1_f, x1_mask, x2, x2_mask, ids 148 | 149 | elif len(batch[0]) == NUM_INPUTS + NUM_EXTRA + NUM_TARGETS: 150 | # ...Otherwise add targets 151 | if torch.is_tensor(batch[0][3]): 152 | y_s = torch.cat([ex[3] for ex in batch]) 153 | y_e = torch.cat([ex[4] for ex in batch]) 154 | else: 155 | y_s = [ex[3] for ex in batch] 156 | y_e = [ex[4] for ex in batch] 157 | else: 158 | raise RuntimeError('Incorrect number of inputs per example.') 159 | 160 | return x1, x1_f, x1_mask, x2, x2_mask, y_s, y_e, ids 161 | -------------------------------------------------------------------------------- /drqa/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | from .. import DATA_DIR 10 | 11 | DEFAULTS = { 12 | 'db_path': os.path.join(DATA_DIR, 'wikipedia/docs.db'), 13 | 'tfidf_path': os.path.join( 14 | DATA_DIR, 15 | 'wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz' 16 | ), 17 | } 18 | 19 | 20 | def set_default(key, value): 21 | global DEFAULTS 22 | DEFAULTS[key] = value 23 | 24 | 25 | def get_class(name): 26 | if name == 'tfidf': 27 | return TfidfDocRanker 28 | if name == 'sqlite': 29 | return DocDB 30 | raise RuntimeError('Invalid retriever class: %s' % name) 31 | 32 | 33 | from .doc_db import DocDB 34 | from .tfidf_doc_ranker import TfidfDocRanker 35 | -------------------------------------------------------------------------------- /drqa/retriever/doc_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Documents, in a sqlite database.""" 8 | 9 | import sqlite3 10 | from . import utils 11 | from . import DEFAULTS 12 | 13 | 14 | class DocDB(object): 15 | """Sqlite backed document storage. 16 | 17 | Implements get_doc_text(doc_id). 
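    Example (a sketch; the db path is an assumption, use any database built
    by scripts/retriever/build_db.py):

        with DocDB(db_path='data/wikipedia/docs.db') as db:
            doc_ids = db.get_doc_ids()
            text = db.get_doc_text(doc_ids[0])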
18 | """ 19 | 20 | def __init__(self, db_path=None): 21 | self.path = db_path or DEFAULTS['db_path'] 22 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 23 | 24 | def __enter__(self): 25 | return self 26 | 27 | def __exit__(self, *args): 28 | self.close() 29 | 30 | def path(self): 31 | """Return the path to the file that backs this database.""" 32 | return self.path 33 | 34 | def close(self): 35 | """Close the connection to the database.""" 36 | self.connection.close() 37 | 38 | def get_doc_ids(self): 39 | """Fetch all ids of docs stored in the db.""" 40 | cursor = self.connection.cursor() 41 | cursor.execute("SELECT id FROM documents") 42 | results = [r[0] for r in cursor.fetchall()] 43 | cursor.close() 44 | return results 45 | 46 | def get_doc_text(self, doc_id): 47 | """Fetch the raw text of the doc for 'doc_id'.""" 48 | cursor = self.connection.cursor() 49 | cursor.execute( 50 | "SELECT text FROM documents WHERE id = ?", 51 | (utils.normalize(doc_id),) 52 | ) 53 | result = cursor.fetchone() 54 | cursor.close() 55 | return result if result is None else result[0] 56 | -------------------------------------------------------------------------------- /drqa/retriever/net_retriever.py: -------------------------------------------------------------------------------- 1 | # encoding='utf-8' 2 | # by yuan xin jie 3 | 4 | import json 5 | import requests 6 | from urllib.parse import quote 7 | from bs4 import BeautifulSoup 8 | import bs4 9 | import re 10 | import uuid 11 | import os 12 | 13 | 14 | def get_hrefs(soup, doc_num): 15 | count = 0 16 | href = [] 17 | for tr in soup.find_all('h3'): 18 | if isinstance(tr, bs4.element.Tag): 19 | tar = tr.a 20 | href.append(tar.attrs['href']) 21 | count += 1 22 | if count >= doc_num: 23 | break 24 | return href 25 | 26 | 27 | def get_html(url): 28 | res = requests.get(url) 29 | res.encoding = 'utf-8' 30 | return res.text 31 | 32 | 33 | def get_content_by_vsb(soup): 34 | content = [] 35 | for tr in soup.find_all('div', id=re.compile('vsb_')): 36 | if isinstance(tr, bs4.element.Tag): 37 | for p in tr.find_all('p'): 38 | content.append(p.text) 39 | return content 40 | 41 | 42 | def get_jsnr_content(soup): 43 | # 先把名字找出来,这个就很恶心 44 | name = '' 45 | for tr in soup.find_all('td', width=re.compile('21%')): 46 | if isinstance(tr, bs4.element.Tag): 47 | name = tr.text 48 | if name == '' or name is None: 49 | return '' 50 | title = [] 51 | text = [] 52 | content = [] 53 | # 开头处一个title对应一个text 54 | for tr in soup.find_all('div', class_=re.compile('jiaoshi_title')): 55 | if isinstance(tr, bs4.element.Tag): 56 | title.append(tr.text) 57 | for tr in soup.find_all('div', class_=re.compile('jstext')): 58 | if isinstance(tr, bs4.element.Tag): 59 | temp = ''.join(tr.text) 60 | temp = temp.replace(' ', ' ') 61 | if 'function' not in temp: 62 | text.append(temp) 63 | for i in range(len(text)): 64 | if i >= len(title): 65 | content[len(title)] += name + text[i] 66 | else: 67 | content.append(name + title[i] + text[i]) 68 | return content 69 | 70 | 71 | def get_content_by_p(soup): 72 | content = [] 73 | for tr in soup.find_all('p'): 74 | if isinstance(tr, bs4.element.Tag): 75 | content.append(''.join(tr.text.split())) 76 | 77 | return content 78 | 79 | 80 | def get_content_by_indent(soup): 81 | content = [] 82 | try: 83 | for tr in soup.find_all('p', class_=re.compile('indent')): 84 | if isinstance(tr, bs4.element.Tag): 85 | content.append(tr.text) 86 | return content 87 | except: 88 | return [] 89 | 90 | 91 | def get_content(link): 92 | html = get_html(link) 
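    # Parse the fetched page; the branches below choose an extraction strategy
    # based on the resolved URL (jsnr.jsp teacher-profile pages) or on the page
    # structure ('vsb_' divs, then 'indent' paragraphs, then any <p> tag).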
93 | soup = BeautifulSoup(html, 'html.parser') 94 | real_url = requests.get(link).url 95 | if 'jsnr.jsp' in real_url: 96 | content = get_jsnr_content(soup) 97 | return '\n'.join([w for w in content if len(w) > 10]) 98 | 99 | content = get_content_by_vsb(soup) 100 | if content == [] or content == '': 101 | content = get_content_by_indent(soup) 102 | if content == [] or content == '': 103 | content = get_content_by_p(soup) 104 | for i in range(len(content)): 105 | content[i] = content[i].replace(' ', ' ') 106 | return '\n'.join([w for w in content if len(w) > 10]) 107 | 108 | 109 | def save_content_to_files(content): 110 | if os.path.isdir('data') is False: 111 | os.mkdir('data') 112 | for w in content: 113 | if len(w) < 30: 114 | continue 115 | id = str(uuid.uuid1()) 116 | with open('data/' + id + '.txt', 'w', encoding='utf-8') as f: 117 | f.write(json.dumps({'id': id, 'text': w}, ensure_ascii=False)) 118 | 119 | 120 | def retriver(question, doc_num): 121 | if question == '': 122 | return False 123 | 124 | if doc_num == '' or doc_num is None: 125 | doc_num = 5 126 | elif type(doc_num) == str: 127 | doc_num = eval(doc_num) 128 | 129 | # url = 'http://www.baidu.com/s?wd=' + quote(question + ' site:xjtu.edu.cn') 130 | url = 'http://www.baidu.com/s?wd=' + quote(question) 131 | # remove target specification 132 | html = get_html(url) 133 | soup = BeautifulSoup(html, 'html.parser') 134 | hrefs = get_hrefs(soup, doc_num) 135 | content = [] 136 | for link in hrefs: 137 | content.append(get_content(link)) 138 | 139 | # for w in content : 140 | # if len(w)>20: 141 | # print(w) 142 | # print([w for w in content if len(w) > 20]) 143 | return ([w for w in content if len(w) > 20]) 144 | 145 | 146 | if __name__ == '__main__': 147 | retriver('交大哪年办学', 5) 148 | -------------------------------------------------------------------------------- /drqa/retriever/tfidf_doc_ranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Rank documents with TF-IDF scores""" 8 | 9 | import logging 10 | import numpy as np 11 | import scipy.sparse as sp 12 | 13 | from multiprocessing.pool import ThreadPool 14 | from functools import partial 15 | 16 | from . import utils 17 | from . import DEFAULTS 18 | from .. import tokenizers 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class TfidfDocRanker(object): 24 | """Loads a pre-weighted inverted index of token/document terms. 25 | Scores new queries by taking sparse dot products. 
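    Example (a sketch; the .npz path is an assumption, use the model written
    by scripts/retriever/build_tfidf.py):

        ranker = TfidfDocRanker(tfidf_path='data/wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz')
        doc_ids, doc_scores = ranker.closest_docs('西安交通大学图书馆', k=5)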
26 | """ 27 | 28 | def __init__(self, tfidf_path=None, strict=True): 29 | """ 30 | Args: 31 | tfidf_path: path to saved model file 32 | strict: fail on empty queries or continue (and return empty result) 33 | """ 34 | # Load from disk 35 | tfidf_path = tfidf_path or DEFAULTS['tfidf_path'] 36 | logger.info('Loading %s' % tfidf_path) 37 | matrix, metadata = utils.load_sparse_csr(tfidf_path) 38 | self.doc_mat = matrix 39 | self.ngrams = metadata['ngram'] 40 | self.hash_size = metadata['hash_size'] 41 | self.tokenizer = tokenizers.get_class(metadata['tokenizer'])() 42 | self.doc_freqs = metadata['doc_freqs'].squeeze() 43 | self.doc_dict = metadata['doc_dict'] 44 | self.num_docs = len(self.doc_dict[0]) 45 | self.strict = strict 46 | 47 | def get_doc_index(self, doc_id): 48 | """Convert doc_id --> doc_index""" 49 | return self.doc_dict[0][doc_id] 50 | 51 | def get_doc_id(self, doc_index): 52 | """Convert doc_index --> doc_id""" 53 | return self.doc_dict[1][doc_index] 54 | 55 | def closest_docs(self, query, k=1): 56 | """Closest docs by dot product between query and documents 57 | in tfidf weighted word vector space. 58 | """ 59 | spvec = self.text2spvec(query) 60 | res = spvec * self.doc_mat 61 | 62 | if len(res.data) <= k: 63 | o_sort = np.argsort(-res.data) 64 | else: 65 | o = np.argpartition(-res.data, k)[0:k] 66 | o_sort = o[np.argsort(-res.data[o])] 67 | 68 | doc_scores = res.data[o_sort] 69 | doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]] 70 | return doc_ids, doc_scores 71 | 72 | def batch_closest_docs(self, queries, k=1, num_workers=None): 73 | """Process a batch of closest_docs requests multithreaded. 74 | Note: we can use plain threads here as scipy is outside of the GIL. 75 | """ 76 | with ThreadPool(num_workers) as threads: 77 | closest_docs = partial(self.closest_docs, k=k) 78 | results = threads.map(closest_docs, queries) 79 | return results 80 | 81 | def parse(self, query): 82 | """Parse the query into tokens (either ngrams or tokens).""" 83 | tokens = self.tokenizer.tokenize(query) 84 | return tokens.ngrams(n=self.ngrams, uncased=True, 85 | filter_fn=utils.filter_ngram) 86 | 87 | def text2spvec(self, query): 88 | """Create a sparse tfidf-weighted word vector from query. 89 | 90 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 91 | """ 92 | # Get hashed ngrams 93 | words = self.parse(utils.normalize(query)) 94 | wids = [utils.hash(w, self.hash_size) for w in words] 95 | 96 | if len(wids) == 0: 97 | if self.strict: 98 | raise RuntimeError('No valid word in: %s' % query) 99 | else: 100 | logger.warning('No valid word in: %s' % query) 101 | return sp.csr_matrix((1, self.hash_size)) 102 | 103 | # Count TF 104 | wids_unique, wids_counts = np.unique(wids, return_counts=True) 105 | tfs = np.log1p(wids_counts) 106 | 107 | # Count IDF 108 | Ns = self.doc_freqs[wids_unique] 109 | idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5)) 110 | idfs[idfs < 0] = 0 111 | 112 | # TF-IDF 113 | data = np.multiply(tfs, idfs) 114 | 115 | # One row, sparse csr matrix 116 | indptr = np.array([0, len(wids_unique)]) 117 | spvec = sp.csr_matrix( 118 | (data, wids_unique, indptr), shape=(1, self.hash_size) 119 | ) 120 | 121 | return spvec 122 | -------------------------------------------------------------------------------- /drqa/retriever/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Various retriever utilities.""" 8 | 9 | import regex 10 | import unicodedata 11 | import numpy as np 12 | import scipy.sparse as sp 13 | from sklearn.utils import murmurhash3_32 14 | from hanziconv import HanziConv 15 | from ..tokenizers.zh_features import STOPWORDS, normalize 16 | # ------------------------------------------------------------------------------ 17 | # Sparse matrix saving/loading helpers. 18 | # ------------------------------------------------------------------------------ 19 | 20 | 21 | def save_sparse_csr(filename, matrix, metadata=None): 22 | data = { 23 | 'data': matrix.data, 24 | 'indices': matrix.indices, 25 | 'indptr': matrix.indptr, 26 | 'shape': matrix.shape, 27 | 'metadata': metadata, 28 | } 29 | np.savez(filename, **data) 30 | 31 | 32 | def load_sparse_csr(filename): 33 | loader = np.load(filename) 34 | matrix = sp.csr_matrix((loader['data'], loader['indices'], 35 | loader['indptr']), shape=loader['shape']) 36 | return matrix, loader['metadata'].item(0) if 'metadata' in loader else None 37 | 38 | 39 | # ------------------------------------------------------------------------------ 40 | # Token hashing. 41 | # ------------------------------------------------------------------------------ 42 | 43 | 44 | def hash(token, num_buckets): 45 | """Unsigned 32 bit murmurhash for feature hashing.""" 46 | return murmurhash3_32(token, positive=True) % num_buckets 47 | 48 | 49 | # ------------------------------------------------------------------------------ 50 | # Text cleaning. 51 | # ------------------------------------------------------------------------------ 52 | 53 | 54 | def filter_word(text): 55 | """Take out english stopwords, punctuation, and compound endings.""" 56 | text = normalize(text) 57 | if regex.match(r'^\p{P}+$', text): 58 | return True 59 | if text.lower() in STOPWORDS: 60 | return True 61 | return False 62 | 63 | 64 | def filter_ngram(gram, mode='any'): 65 | """Decide whether to keep or discard an n-gram. 66 | 67 | Args: 68 | gram: list of tokens (length N) 69 | mode: Option to throw out ngram if 70 | 'any': any single token passes filter_word 71 | 'all': all tokens pass filter_word 72 | 'ends': book-ended by filterable tokens 73 | """ 74 | filtered = [filter_word(w) for w in gram] 75 | if mode == 'any': 76 | return any(filtered) 77 | elif mode == 'all': 78 | return all(filtered) 79 | elif mode == 'ends': 80 | return filtered[0] or filtered[-1] 81 | else: 82 | raise ValueError('Invalid mode: %s' % mode) 83 | -------------------------------------------------------------------------------- /drqa/tokenizers/Zh_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # modified chinese tokenizer, use core_nlp 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Simple wrapper around the Stanford CoreNLP pipeline. 7 | 8 | Serves commands to a java subprocess running the jar. Requires java 8. 9 | """ 10 | 11 | 12 | import copy 13 | import json 14 | import pexpect 15 | 16 | from .tokenizer import Tokens, Tokenizer 17 | from . import DEFAULTS 18 | from .zh_features import trans, normalize 19 | 20 | 21 | class ZhTokenizer(Tokenizer): 22 | 23 | def __init__(self, **kwargs): 24 | """ 25 | Args: 26 | annotators: set that can include pos, lemma, and ner. 
27 | classpath: Path to the corenlp directory of jars 28 | mem: Java heap memory 29 | """ 30 | # doesn't seem to work ... 31 | self.classpath = (kwargs.get('classpath') or 32 | DEFAULTS['corenlp_classpath']) 33 | 34 | # fixme : specific a path by yourself 35 | # self.classpath = '/home/amose/corenlp/*' # fixme : preset classPath 36 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 37 | self.mem = kwargs.get('mem', '2g') 38 | self._launch() 39 | self.trans = trans('drqa/features/zh_dict.json') 40 | 41 | def _launch(self): 42 | """Start the CoreNLP jar with pexpect.""" 43 | annotators = ['tokenize', 'ssplit'] 44 | if 'ner' in self.annotators: 45 | annotators.extend(['pos', 'lemma', 'ner']) 46 | elif 'lemma' in self.annotators: 47 | annotators.extend(['pos', 'lemma']) 48 | elif 'pos' in self.annotators: 49 | annotators.extend(['pos']) 50 | annotators = ','.join(annotators) 51 | options = ','.join(['untokenizable=noneDelete', 52 | 'invertible=true']) 53 | cmd = ['java', '-mx' + self.mem, '-cp', '\'%s\'' % self.classpath, 54 | 'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-props', 55 | 'StanfordCoreNLP-chinese.properties', 56 | '-annotators', annotators, '-tokenize.options', options, 57 | '-outputFormat', 'json', '-prettyPrint', 'false'] 58 | # print(cmd) 59 | # We use pexpect to keep the subprocess alive and feed it commands. 60 | # Because we don't want to get hit by the max terminal buffer size, 61 | # we turn off canonical input processing to have unlimited bytes. 62 | self.corenlp = pexpect.spawn('/bin/bash', maxread=100000, timeout=60) 63 | self.corenlp.setecho(False) 64 | self.corenlp.sendline('stty -icanon') 65 | self.corenlp.sendline(' '.join(cmd)) 66 | # print(' '.join(cmd)) 67 | self.corenlp.delaybeforesend = 0 68 | self.corenlp.delayafterread = 0 69 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 70 | print('[init tokenizer done]') 71 | 72 | @staticmethod 73 | def _convert(token): 74 | if token == '-LRB-': 75 | return '(' 76 | if token == '-RRB-': 77 | return ')' 78 | if token == '-LSB-': 79 | return '[' 80 | if token == '-RSB-': 81 | return ']' 82 | if token == '-LCB-': 83 | return '{' 84 | if token == '-RCB-': 85 | return '}' 86 | return token 87 | 88 | def tokenize(self, text): 89 | # Since we're feeding text to the commandline, we're waiting on seeing 90 | # the NLP> prompt. Hacky! 91 | if 'NLP>' in text: 92 | raise RuntimeError('Bad token (NLP>) in text!') 93 | 94 | # Sending q will cause the process to quit -- manually override 95 | if text.lower().strip() == 'q': 96 | token = text.strip() 97 | index = text.index(token) 98 | data = [(token, text[index:], (index, index + 1), 'NN', 'q', 'O')] 99 | return Tokens(data, self.annotators) 100 | 101 | # Minor cleanup before tokenizing. 
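        # normalize() (defined in zh_features.py) replaces newlines with spaces,
        # converts traditional characters to simplified ones, applies NFKC
        # normalization and maps full-width punctuation to ASCII equivalents.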
102 | clean_text = normalize(text) 103 | 104 | self.corenlp.sendline(clean_text.encode('utf-8')) 105 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 106 | 107 | # Skip to start of output (may have been stderr logging messages) 108 | output = self.corenlp.before 109 | start = output.find(b'{"sentences":') 110 | output = json.loads(output[start:].decode('utf-8')) 111 | 112 | data = [] 113 | tokens = [t for s in output['sentences'] for t in s['tokens']] 114 | for i in range(len(tokens)): 115 | # Get whitespace 116 | start_ws = tokens[i]['characterOffsetBegin'] 117 | if i + 1 < len(tokens): 118 | end_ws = tokens[i + 1]['characterOffsetBegin'] 119 | else: 120 | end_ws = tokens[i]['characterOffsetEnd'] 121 | 122 | data.append(( 123 | self._convert(tokens[i]['word']), 124 | text[start_ws: end_ws], 125 | (tokens[i]['characterOffsetBegin'], 126 | tokens[i]['characterOffsetEnd']), 127 | tokens[i].get('pos', None), 128 | # lemma : translation or pinyin 129 | self.trans.translate(tokens[i].get('lemma', None), 130 | tokens[i].get('pos', None)), 131 | # tokens[i].get('lemma', None), 132 | tokens[i].get('ner', None), 133 | # self.trans.pinyin(text[start_ws: end_ws]) 134 | )) 135 | # print(data) 136 | return Tokens(data, self.annotators) 137 | -------------------------------------------------------------------------------- /drqa/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | DEFAULTS = { 11 | 'corenlp_classpath': os.getenv('CLASSPATH') 12 | } 13 | 14 | 15 | def set_default(key, value): 16 | global DEFAULTS 17 | DEFAULTS[key] = value 18 | 19 | 20 | from .corenlp_tokenizer import CoreNLPTokenizer 21 | from .regexp_tokenizer import RegexpTokenizer 22 | from .simple_tokenizer import SimpleTokenizer 23 | from .Zh_tokenizer import ZhTokenizer 24 | # Spacy is optional 25 | try: 26 | from .spacy_tokenizer import SpacyTokenizer 27 | except ImportError: 28 | pass 29 | 30 | 31 | def get_class(name): 32 | if name == 'spacy': 33 | return SpacyTokenizer 34 | if name == 'corenlp': 35 | return CoreNLPTokenizer 36 | if name == 'regexp': 37 | return RegexpTokenizer 38 | if name == 'simple': 39 | return SimpleTokenizer 40 | if name == 'zh': 41 | return ZhTokenizer 42 | raise RuntimeError('Invalid tokenizer: %s' % name) 43 | 44 | 45 | def get_annotators_for_args(args): 46 | annotators = set() 47 | if args.use_pos: 48 | annotators.add('pos') 49 | if args.use_lemma: 50 | annotators.add('lemma') 51 | if args.use_ner: 52 | annotators.add('ner') 53 | # if args.use_pinyin: 54 | # annotators.add('pinyin') 55 | return annotators 56 | 57 | 58 | def get_annotators_for_model(model): 59 | return get_annotators_for_args(model.args) 60 | -------------------------------------------------------------------------------- /drqa/tokenizers/corenlp_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Simple wrapper around the Stanford CoreNLP pipeline. 8 | 9 | Serves commands to a java subprocess running the jar. Requires java 8. 
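Example (a sketch; assumes the CoreNLP jars are on CLASSPATH):

    tok = CoreNLPTokenizer(annotators={'pos', 'lemma', 'ner'})
    tokens = tok.tokenize('Who designed the main library?')
    words, tags, ents = tokens.words(), tokens.pos(), tokens.entities()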
10 | """ 11 | 12 | import copy 13 | import json 14 | import pexpect 15 | 16 | from .tokenizer import Tokens, Tokenizer 17 | from . import DEFAULTS 18 | 19 | 20 | class CoreNLPTokenizer(Tokenizer): 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: set that can include pos, lemma, and ner. 26 | classpath: Path to the corenlp directory of jars 27 | mem: Java heap memory 28 | """ 29 | self.classpath = (kwargs.get('classpath') or 30 | DEFAULTS['corenlp_classpath']) 31 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 32 | self.mem = kwargs.get('mem', '2g') 33 | self._launch() 34 | 35 | def _launch(self): 36 | """Start the CoreNLP jar with pexpect.""" 37 | annotators = ['tokenize', 'ssplit'] 38 | if 'ner' in self.annotators: 39 | annotators.extend(['pos', 'lemma', 'ner']) 40 | elif 'lemma' in self.annotators: 41 | annotators.extend(['pos', 'lemma']) 42 | elif 'pos' in self.annotators: 43 | annotators.extend(['pos']) 44 | annotators = ','.join(annotators) 45 | options = ','.join(['untokenizable=noneDelete', 46 | 'invertible=true']) 47 | cmd = ['java', '-mx' + self.mem, '-cp', '"%s"' % self.classpath, 48 | 'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators', 49 | annotators, '-tokenize.options', options, 50 | '-outputFormat', 'json', '-prettyPrint', 'false'] 51 | 52 | # We use pexpect to keep the subprocess alive and feed it commands. 53 | # Because we don't want to get hit by the max terminal buffer size, 54 | # we turn off canonical input processing to have unlimited bytes. 55 | self.corenlp = pexpect.spawn('/bin/bash', maxread=100000, timeout=60) 56 | self.corenlp.setecho(False) 57 | self.corenlp.sendline('stty -icanon') 58 | self.corenlp.sendline(' '.join(cmd)) 59 | self.corenlp.delaybeforesend = 0 60 | self.corenlp.delayafterread = 0 61 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 62 | 63 | @staticmethod 64 | def _convert(token): 65 | if token == '-LRB-': 66 | return '(' 67 | if token == '-RRB-': 68 | return ')' 69 | if token == '-LSB-': 70 | return '[' 71 | if token == '-RSB-': 72 | return ']' 73 | if token == '-LCB-': 74 | return '{' 75 | if token == '-RCB-': 76 | return '}' 77 | return token 78 | 79 | def tokenize(self, text): 80 | # Since we're feeding text to the commandline, we're waiting on seeing 81 | # the NLP> prompt. Hacky! 82 | if 'NLP>' in text: 83 | raise RuntimeError('Bad token (NLP>) in text!') 84 | 85 | # Sending q will cause the process to quit -- manually override 86 | if text.lower().strip() == 'q': 87 | token = text.strip() 88 | index = text.index(token) 89 | data = [(token, text[index:], (index, index + 1), 'NN', 'q', 'O')] 90 | return Tokens(data, self.annotators) 91 | 92 | # Minor cleanup before tokenizing. 
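        # Newlines are replaced because the text is sent to the interactive
        # CoreNLP prompt with a single sendline(); an embedded newline would be
        # read as a separate command and desynchronize the expected output.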
93 | clean_text = text.replace('\n', ' ') 94 | 95 | self.corenlp.sendline(clean_text.encode('utf-8')) 96 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 97 | 98 | # Skip to start of output (may have been stderr logging messages) 99 | output = self.corenlp.before 100 | start = output.find(b'{"sentences":') 101 | output = json.loads(output[start:].decode('utf-8')) 102 | 103 | data = [] 104 | tokens = [t for s in output['sentences'] for t in s['tokens']] 105 | for i in range(len(tokens)): 106 | # Get whitespace 107 | start_ws = tokens[i]['characterOffsetBegin'] 108 | if i + 1 < len(tokens): 109 | end_ws = tokens[i + 1]['characterOffsetBegin'] 110 | else: 111 | end_ws = tokens[i]['characterOffsetEnd'] 112 | 113 | data.append(( 114 | self._convert(tokens[i]['word']), 115 | text[start_ws: end_ws], 116 | (tokens[i]['characterOffsetBegin'], 117 | tokens[i]['characterOffsetEnd']), 118 | tokens[i].get('pos', None), 119 | tokens[i].get('lemma', None), 120 | tokens[i].get('ner', None) 121 | )) 122 | return Tokens(data, self.annotators) 123 | -------------------------------------------------------------------------------- /drqa/tokenizers/regexp_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Regex based tokenizer that emulates the Stanford/NLTK PTB tokenizers. 8 | 9 | However it is purely in Python, supports robust untokenization, unicode, 10 | and requires minimal dependencies. 11 | """ 12 | 13 | import regex 14 | import logging 15 | from .tokenizer import Tokens, Tokenizer 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class RegexpTokenizer(Tokenizer): 21 | DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*' 22 | TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)' 23 | r'\.(?=\p{Z})') 24 | ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)' 25 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++' 26 | HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM) 27 | NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't" 28 | CONTRACTION1 = r"can(?=not\b)" 29 | CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b" 30 | START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})' 31 | START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})' 32 | END_DQUOTE = r'(?%s)|(?P%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|' 47 | '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|' 48 | '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|' 49 | '(?<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' % 50 | (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN, 51 | self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2, 52 | self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE, 53 | self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT, 54 | self.NON_WS), 55 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 56 | ) 57 | if len(kwargs.get('annotators', {})) > 0: 58 | logger.warning('%s only tokenizes! 
Skipping annotators: %s' % 59 | (type(self).__name__, kwargs.get('annotators'))) 60 | self.annotators = set() 61 | self.substitutions = kwargs.get('substitutions', True) 62 | 63 | def tokenize(self, text): 64 | data = [] 65 | matches = [m for m in self._regexp.finditer(text)] 66 | for i in range(len(matches)): 67 | # Get text 68 | token = matches[i].group() 69 | 70 | # Make normalizations for special token types 71 | if self.substitutions: 72 | groups = matches[i].groupdict() 73 | if groups['sdquote']: 74 | token = "``" 75 | elif groups['edquote']: 76 | token = "''" 77 | elif groups['ssquote']: 78 | token = "`" 79 | elif groups['esquote']: 80 | token = "'" 81 | elif groups['dash']: 82 | token = '--' 83 | elif groups['ellipses']: 84 | token = '...' 85 | 86 | # Get whitespace 87 | span = matches[i].span() 88 | start_ws = span[0] 89 | if i + 1 < len(matches): 90 | end_ws = matches[i + 1].span()[0] 91 | else: 92 | end_ws = span[1] 93 | 94 | # Format data 95 | data.append(( 96 | token, 97 | text[start_ws: end_ws], 98 | span, 99 | )) 100 | return Tokens(data, self.annotators) 101 | -------------------------------------------------------------------------------- /drqa/tokenizers/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Basic tokenizer that splits text into alpha-numeric tokens and 8 | non-whitespace tokens. 9 | """ 10 | 11 | import regex 12 | import logging 13 | from .tokenizer import Tokens, Tokenizer 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class SimpleTokenizer(Tokenizer): 19 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 20 | NON_WS = r'[^\p{Z}\p{C}]' 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: None or empty set (only tokenizes). 26 | """ 27 | self._regexp = regex.compile( 28 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 29 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 30 | ) 31 | if len(kwargs.get('annotators', {})) > 0: 32 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 33 | (type(self).__name__, kwargs.get('annotators'))) 34 | self.annotators = set() 35 | 36 | def tokenize(self, text): 37 | data = [] 38 | matches = [m for m in self._regexp.finditer(text)] 39 | for i in range(len(matches)): 40 | # Get text 41 | token = matches[i].group() 42 | 43 | # Get whitespace 44 | span = matches[i].span() 45 | start_ws = span[0] 46 | if i + 1 < len(matches): 47 | end_ws = matches[i + 1].span()[0] 48 | else: 49 | end_ws = span[1] 50 | 51 | # Format data 52 | data.append(( 53 | token, 54 | text[start_ws: end_ws], 55 | span, 56 | )) 57 | return Tokens(data, self.annotators) 58 | -------------------------------------------------------------------------------- /drqa/tokenizers/spacy_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Tokenizer that is backed by spaCy (spacy.io). 8 | 9 | Requires spaCy package and the spaCy english model. 
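Example (a sketch; assumes the spaCy 'en' model is installed):

    tok = SpacyTokenizer(annotators={'lemma'})
    lemmas = tok.tokenize('DrQA reads Wikipedia.').lemmas()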
10 | """ 11 | 12 | import spacy 13 | import copy 14 | from .tokenizer import Tokens, Tokenizer 15 | 16 | 17 | class SpacyTokenizer(Tokenizer): 18 | 19 | def __init__(self, **kwargs): 20 | """ 21 | Args: 22 | annotators: set that can include pos, lemma, and ner. 23 | model: spaCy model to use (either path, or keyword like 'en'). 24 | """ 25 | model = kwargs.get('model', 'en') 26 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 27 | nlp_kwargs = {'parser': False} 28 | if not {'lemma', 'pos', 'ner'} & self.annotators: 29 | nlp_kwargs['tagger'] = False 30 | if not {'ner'} & self.annotators: 31 | nlp_kwargs['entity'] = False 32 | self.nlp = spacy.load(model, **nlp_kwargs) 33 | 34 | def tokenize(self, text): 35 | # We don't treat new lines as tokens. 36 | clean_text = text.replace('\n', ' ') 37 | tokens = self.nlp.tokenizer(clean_text) 38 | if {'lemma', 'pos', 'ner'} & self.annotators: 39 | self.nlp.tagger(tokens) 40 | if {'ner'} & self.annotators: 41 | self.nlp.entity(tokens) 42 | 43 | data = [] 44 | for i in range(len(tokens)): 45 | # Get whitespace 46 | start_ws = tokens[i].idx 47 | if i + 1 < len(tokens): 48 | end_ws = tokens[i + 1].idx 49 | else: 50 | end_ws = tokens[i].idx + len(tokens[i].text) 51 | 52 | data.append(( 53 | tokens[i].text, 54 | text[start_ws: end_ws], 55 | (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), 56 | tokens[i].tag_, 57 | tokens[i].lemma_, 58 | tokens[i].ent_type_, 59 | )) 60 | 61 | # Set special option for non-entity tag: '' vs 'O' in spaCy 62 | return Tokens(data, self.annotators, opts={'non_ent': ''}) 63 | -------------------------------------------------------------------------------- /drqa/tokenizers/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Base tokenizer/tokens classes and utilities.""" 8 | 9 | import copy 10 | 11 | 12 | class Tokens(object): 13 | """A class to represent a list of tokenized text.""" 14 | TEXT = 0 15 | TEXT_WS = 1 16 | SPAN = 2 17 | POS = 3 18 | LEMMA = 4 19 | NER = 5 20 | 21 | def __init__(self, data, annotators, opts=None): 22 | self.data = data 23 | self.annotators = annotators 24 | self.opts = opts or {} 25 | 26 | def __len__(self): 27 | """The number of tokens.""" 28 | return len(self.data) 29 | 30 | def slice(self, i=None, j=None): 31 | """Return a view of the list of tokens from [i, j).""" 32 | new_tokens = copy.copy(self) 33 | new_tokens.data = self.data[i: j] 34 | return new_tokens 35 | 36 | def untokenize(self): 37 | """Returns the original text (with whitespace reinserted).""" 38 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 39 | 40 | def words(self, uncased=False): 41 | """Returns a list of the text of each token 42 | 43 | Args: 44 | uncased: lower cases text 45 | """ 46 | if uncased: 47 | return [t[self.TEXT].lower() for t in self.data] 48 | else: 49 | return [t[self.TEXT] for t in self.data] 50 | 51 | def offsets(self): 52 | """Returns a list of [start, end) character offsets of each token.""" 53 | return [t[self.SPAN] for t in self.data] 54 | 55 | def pos(self): 56 | """Returns a list of part-of-speech tags of each token. 57 | Returns None if this annotation was not included. 
58 | """ 59 | if 'pos' not in self.annotators: 60 | return None 61 | return [t[self.POS] for t in self.data] 62 | 63 | def lemmas(self): 64 | """Returns a list of the lemmatized text of each token. 65 | Returns None if this annotation was not included. 66 | """ 67 | if 'lemma' not in self.annotators: 68 | return None 69 | return [t[self.LEMMA] for t in self.data] 70 | 71 | def entities(self): 72 | """Returns a list of named-entity-recognition tags of each token. 73 | Returns None if this annotation was not included. 74 | """ 75 | if 'ner' not in self.annotators: 76 | return None 77 | return [t[self.NER] for t in self.data] 78 | 79 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 80 | """Returns a list of all ngrams from length 1 to n. 81 | 82 | Args: 83 | n: upper limit of ngram length 84 | uncased: lower cases text 85 | filter_fn: user function that takes in an ngram list and returns 86 | True or False to keep or not keep the ngram 87 | as_string: return the ngram as a string vs list 88 | """ 89 | def _skip(gram): 90 | if not filter_fn: 91 | return False 92 | return filter_fn(gram) 93 | 94 | words = self.words(uncased) 95 | ngrams = [(s, e + 1) 96 | for s in range(len(words)) 97 | for e in range(s, min(s + n, len(words))) 98 | if not _skip(words[s:e + 1])] 99 | 100 | # Concatenate into strings 101 | if as_strings: 102 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 103 | 104 | return ngrams 105 | 106 | def entity_groups(self): 107 | """Group consecutive entity tokens with the same NER tag.""" 108 | entities = self.entities() 109 | if not entities: 110 | return None 111 | non_ent = self.opts.get('non_ent', 'O') 112 | groups = [] 113 | idx = 0 114 | while idx < len(entities): 115 | ner_tag = entities[idx] 116 | # Check for entity tag 117 | if ner_tag != non_ent: 118 | # Chomp the sequence 119 | start = idx 120 | while (idx < len(entities) and entities[idx] == ner_tag): 121 | idx += 1 122 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 123 | else: 124 | idx += 1 125 | return groups 126 | 127 | 128 | class Tokenizer(object): 129 | """Base tokenizer class. 130 | Tokenizers implement tokenize, which should return a Tokens class. 
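    A minimal subclass (illustrative only) just needs to return a Tokens
    object built from (TEXT, TEXT_WS, SPAN) tuples:

        class CharTokenizer(Tokenizer):
            def tokenize(self, text):
                data = [(ch, ch, (i, i + 1)) for i, ch in enumerate(text)]
                return Tokens(data, annotators=set())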
131 | """ 132 | def tokenize(self, text): 133 | raise NotImplementedError 134 | 135 | def shutdown(self): 136 | pass 137 | 138 | def __del__(self): 139 | self.shutdown() 140 | -------------------------------------------------------------------------------- /drqa/tokenizers/zh_features.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -* 3 | # file witch exports important features for Chinese 4 | # include translation, normalization, similarty comparation and stopwords 5 | 6 | 7 | import os 8 | import re 9 | import json 10 | from pypinyin import lazy_pinyin 11 | from hanziconv import HanziConv 12 | import unicodedata 13 | 14 | 15 | def loadDict(path): 16 | dict = {} 17 | with open(path, 'r', encoding='utf-8') as f: 18 | for line in f: 19 | arr = line.split(':::') 20 | para = json.loads(arr[1])['paraphrase'] 21 | pp = {} 22 | for word in para: 23 | word = word.replace(' ', '') 24 | if word.find('.') != -1: 25 | t = word.split('.')[0] 26 | w = word.split('.')[1].split(';')[0].strip() 27 | pp[t] = w 28 | else: 29 | pp['*'] = word.strip() 30 | dict[arr[0]] = pp 31 | return dict 32 | 33 | 34 | class Youdao(object): 35 | 36 | def __init__(self): 37 | super(Youdao, self).__init__() 38 | 39 | def query(self, word): 40 | import requests 41 | try: 42 | import xml.etree.cElementTree as ET 43 | except ImportError: 44 | import xml.etree.ElementTree as ET 45 | sess = requests.Session() 46 | headers = { 47 | 'Host': 'dict.youdao.com', 48 | 'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:50.0) Gecko/20100101 Firefox/50.0', 49 | 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8', 50 | 'Accept-Language': 'en-US,en;q=0.5', 51 | 'Accept-Encoding': 'gzip, deflate' 52 | } 53 | sess.headers.update(headers) 54 | url = 'http://dict.youdao.com/fsearch?q=%s' % (word) 55 | try: 56 | resp = sess.get(url, timeout=100) 57 | except: 58 | return None 59 | text = resp.content 60 | if (resp.status_code == 200) and (text): 61 | tree = ET.ElementTree(ET.fromstring(text)) 62 | returnPhrase = tree.find('return-phrase') 63 | if returnPhrase.text.strip() != word: 64 | return None 65 | customTranslation = tree.find('custom-translation') 66 | if not customTranslation: 67 | return None 68 | trans = '' 69 | for t in customTranslation.findall('translation'): 70 | transText = t[0].text 71 | if transText: 72 | trans = transText 73 | return trans 74 | return None 75 | else: 76 | return None 77 | 78 | 79 | class trans(object): 80 | # translation for Chinese word 81 | def __init__(self, path): 82 | # fixme : there are some bugs while loading dictionary 83 | self.dict = loadDict(path) 84 | 85 | def translate(self, word, pos, use_online=False): 86 | if self.dict.get(word): 87 | d = self.dict.get(word) 88 | # find if pos tag match dict result 89 | # give a random one if no one match 90 | if d.get(pos.lower()): 91 | return d.get(pos.lower()) 92 | return next(iter(d.values())) 93 | elif use_online: 94 | # never use_online so far 95 | pass 96 | elif word: 97 | return ' '.join(lazy_pinyin(word)) 98 | else: 99 | return word 100 | 101 | def pinyin(self, word): 102 | if word: 103 | return ' '.join(lazy_pinyin(word)) 104 | else: 105 | return word 106 | 107 | 108 | STOPWORDS = { 109 | 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 110 | 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 111 | 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 112 | 'theirs', 
'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', 113 | 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 114 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 115 | 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 116 | 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 117 | 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 118 | 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 119 | 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 120 | 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 121 | 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 122 | 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', 123 | 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', 124 | 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', 125 | 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``" 126 | } 127 | with open('drqa/features/stopword_zh.txt') as f: 128 | # load chinese stop word 129 | for line in f: 130 | STOPWORDS.add(line.replace('\n', '')) 131 | 132 | 133 | class similar(object): 134 | ''' a very rough method to evaluate the similarity between context and 135 | questions. Not productive at all :( 136 | # fixme : should be replaced by a complex module 137 | ps: this one is also used to get the best position for answers 138 | (where is the best show up in context)''' 139 | def __init__(self): 140 | self.chs_arabic_map = {u'零': 0, u'一': 1, u'二': 2, u'三': 3, u'四': 4, 141 | u'五': 5, u'六': 6, u'七': 7, u'八': 8, u'九': 9, 142 | u'十': 10, u'百': 100, u'千': 10 ** 3, u'万': 10 ** 4, 143 | u'亿': 10 ** 8} 144 | 145 | def compare(self, word0, word1): 146 | # print(word0 + '|' + word1) 147 | word0 = normalize(word0) 148 | word1 = normalize(word1) 149 | if word0 not in STOPWORDS and word1 not in STOPWORDS: 150 | if ' '.join(lazy_pinyin(word0)) == ' '.join(lazy_pinyin(word1)): 151 | return 1.0 152 | elif self.convertHan(word0) == self.convertHan(word1): 153 | return 1.0 154 | else: 155 | return 0.0 156 | else: 157 | return 0.0 158 | 159 | def convertHan(self, text): 160 | ls = re.finditer( 161 | '[零|一|二|三|四|五|六|七|八|九|十][零|一|二|三|四|五|六|七|八|九|十|百|千|万|亿]+', text) 162 | for i in ls: 163 | s = text[i.span()[0]:i.span()[1]] 164 | try: 165 | text = text.replace( 166 | s, (str)(self.convertChineseDigitsToArabic(s))) 167 | except: 168 | return text 169 | return text 170 | 171 | def convertChineseDigitsToArabic(self, chinese_digits): 172 | result = 0 173 | tmp = 0 174 | hnd_mln = 0 175 | for count in range(len(chinese_digits)): 176 | curr_char = chinese_digits[count] 177 | curr_digit = self.chs_arabic_map.get(curr_char, None) 178 | # meet 「亿」 or 「億」 179 | if curr_digit == 10 ** 8: 180 | result = result + tmp 181 | result = result * curr_digit 182 | # get result before 「亿」 and store it into hnd_mln 183 | # reset `result` 184 | hnd_mln = hnd_mln * 10 ** 8 + result 185 | result = 0 186 | tmp = 0 187 | # meet 「万」 or 「萬」 188 | elif curr_digit == 10 ** 4: 189 | result = result + tmp 190 | result = result * curr_digit 191 | tmp = 0 192 | # meet 「十」, 「百」, 「千」 or their traditional version 193 | elif curr_digit >= 10: 194 | tmp = 1 if tmp == 0 else tmp 195 | result = result + curr_digit * tmp 196 | tmp = 0 197 | # meet single digit 198 | elif curr_digit is not None: 199 | tmp = tmp * 10 + curr_digit 200 | else: 201 | return 
result 202 | result = result + tmp 203 | result = result + hnd_mln 204 | return result 205 | 206 | 207 | def normalize(text): 208 | toSim = HanziConv.toSimplified(text.replace('\n', ' ')) 209 | t2 = unicodedata.normalize('NFKC', toSim) 210 | table = {ord(f): ord(t) for f, t in zip( 211 | u',。!?【】()%#@&1234567890', 212 | u',.!?[]()%#@&1234567890')} 213 | t3 = t2.translate(table) 214 | return t3 215 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # DrQA Chinese implementation 2 | 3 | ## Introduction 4 | This is a modified version of the Facebook [DrQA](https://github.com/facebookresearch/DrQA) module that supports the Chinese language. The repository is for study only. The module can be used to answer questions over any given context; the initial optimization targets the domain of a specific university. This project is neither fully tested nor complete, and the online retriever based on the Baidu service is broken. 5 | 6 | ## DrQA Introduction 7 | >DrQA is a system for reading comprehension applied to open-domain question answering. In particular, DrQA is targeted at the task of "machine reading at scale" (MRS). In this setting, we are searching for an answer to a question in a potentially very large corpus of unstructured documents (that may not be redundant). Thus the system has to combine the challenges of document retrieval (finding the relevant documents) with that of machine comprehension of text (identifying the answers from those documents). 8 | 9 | >Our experiments with DrQA focus on answering factoid questions while using Wikipedia as the unique knowledge source for documents. Wikipedia is a well-suited source of large-scale, rich, detailed information. In order to answer any question, one must first retrieve the few potentially relevant articles among more than 5 million, and then scan them carefully to identify the answer. 10 | 11 | >Note that DrQA treats Wikipedia as a generic collection of articles and does not rely on its internal graph structure. As a result, DrQA can be straightforwardly applied to any collection of documents, as described in the retriever README. 12 | 13 | 14 | ## Installation 15 | This is a modified version of the Facebook [DrQA](https://github.com/facebookresearch/DrQA) module, intended for study only. 16 | To install it, install PyTorch according to [pytorch.org](http://pytorch.org/) and run setup.py in a Python 3 environment (3.5 and 3.6 both work; note that the setup may overwrite the Facebook DrQA installation). If any requirement is missing, simply install it with pip. Then install CoreNLP with the Chinese package according to [CoreNLP official](https://stanfordnlp.github.io/CoreNLP/); you can specify the classpath in the environment or in the file `drqa/tokenizers/Zh_tokenizer.py`. Finally, download the vectors and training sets to start your work. 17 | Download link: [Data](https://pan.baidu.com/s/1geMDxMN), secret: 232d 18 | Merge the downloaded drqa folder with the original one; it contains the common data files and zh_dict.json for Chinese-English translation. 19 | 20 | ## Structure 21 | /data : stores all the data 22 | /vector 23 | /'training set' 24 | /'db' : retriever db 25 | /'module' : saved module 26 | ... 27 | /drqa : main modules 28 | /features : common feature files shared across the project 29 | /pipeline : connects the reader and the retriever 30 | drqa.py : the original pipeline manager 31 | simpleDrQA.py : a simplified pipeline manager 32 | /reader : reader module 33 | ... 
34 | /retriever : retriever module
35 | ...
36 | net_retriever.py : retrieves context by querying a search engine (Baidu) and uses the results as context
37 | /tokenizers : tokenizer features
38 | /Zh_tokenizer.py : CoreNLP Chinese tokenizer (select it with the flag '--tokenizer zh')
39 | /zh_features.py : common Chinese features
40 | ...
41 | /scripts : common command-line scripts
42 | ...
43 | /pipeline
44 | sinteractive.py : interactive mode using the simple DrQA agent
45 | ...
46 | Files shared with the original project are not listed here; please refer to the Facebook DrQA documentation.
47 | 
48 | ## Features
49 | Please see the Facebook module for the original design of the features.
50 | As a Chinese implementation of the original module, this project provides full Chinese support with complete Chinese linguistic tags. The Chinese lemma tag is replaced by its English translation. Every expression is passed through Chinese normalization (symbols, simplified and traditional characters); a usage sketch of the normalization helper appears at the end of this readme.
51 | It includes functions for various Chinese feature transformations:
52 | 1. simplified to traditional Chinese
53 | 2. Chinese characters to pinyin
54 | 3. Chinese numerals to Arabic numerals
55 | 4. SBC (full-width) to DBC (half-width) characters
56 | 5. common symbol transformations
57 | 
58 | The module ships with a common word mapping (abbreviation <-> full spelling, etc.).
59 | It provides a simple context scoring function for better answer ranking.
60 | It provides a simple context retriever (working with the Baidu search engine).
61 | It provides a parsed and tested training set (based on WebQA) and word embeddings (60-dimensional and 200-dimensional).
62 | It provides a testing module.
63 | 
64 | 
65 | ## Result
66 | sinteractive.py result example:
67 | 
68 | \>\>\> process("西交图书馆的全名?", doc_n=1, net_n=3)
69 | 09/27/2017 04:45:38 PM: [ [question after filting : 西安交通大学图书馆的全名? ] ]
70 | 09/27/2017 04:45:39 PM: [ [retreive from net : 3 | expect : 3] ]
71 | 
72 | ...
73 | 
74 | 09/27/2017 04:45:43 PM: [ [retreive from db] ]
75 | \=================raw text==================
76 | ...侧,目前为工程训练中心、实验室及艺术庭院.西安交大图书馆北楼始建于1961年7月,共三层,建筑面积11200平方米,是和老教学主楼一并设计建设“中苏风格”建筑,风格朴实宏伟.和北楼相连的南楼建筑面积18000平方米,于1991年3月投入使用,地上13层,地下2层.设计上南楼保留了北楼的设计元素,外形呈金字塔形,被部分同学们戏称为“铁甲小宝”.图书馆南楼顶部有报时的大钟,报时音乐为“东方红”,2010年曾改为中国名曲“茉莉花”,后因国际形势变化改回“东方红”.1995年,图书馆经钱学森本人同意及中宣部批准改名钱学森图书馆,并由时任中共中央总书记、国家主席江泽民题写馆名.现今该图书馆拥有阅览座位3518席,累计藏书522.8万册(件),报刊10053种,现刊4089种.
77 | \===================================
78 | 
79 | ....
80 | 
81 | 图书馆南楼顶部有报时的大钟,报时音乐为“东方红”,2010年曾改为中国名曲“茉莉花”,后因国际形势变化改回“东方红” 1995年,图书馆经钱学森本人同意及中宣部批准改名钱学森图书馆,并由时任中共中央总书记、国家主席江泽民题写馆名
82 | ======== answer :钱学森图书馆
83 | answer score : 0.0819935
84 | context score : 9.164698700898501
85 | Time: 12.1489
86 | 
87 | Trained on the WebQA training set, the reader reaches a 65% exact match rate on the validation set.
88 | The retriever module and the full pipeline have not been evaluated (our document set is far from complete and the retriever appears to perform poorly). The processing of the retrieved context is vital to the final performance.
89 | 
90 | 
91 | ## License
92 | DrQA_cn is BSD-licensed, based on [DrQA](https://github.com/facebookresearch/DrQA).
93 | 
94 | The training set is licensed by Baidu: [WebQA](http://idl.baidu.com/WebQA.html). This dataset is released for research purposes only. Copyright (C) 2016 Baidu.com, Inc. All Rights Reserved.
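
For reference, here is a minimal usage sketch of the Chinese normalization helper mentioned under Features. It assumes the package has been installed via `setup.py` so that `drqa.tokenizers.zh_features` is importable, and that its dependencies (e.g. `hanziconv`) are available:

```python
# Minimal sketch (assumes the package is installed and hanziconv is available).
# normalize() replaces newlines with spaces, converts traditional characters
# to simplified ones, applies NFKC, and maps the remaining CJK punctuation to
# its ASCII equivalents.
from drqa.tokenizers.zh_features import normalize

raw = '西安交大圖書館的全名是什麼?'
print(normalize(raw))
# expected: '西安交大图书馆的全名是什么?'
# (traditional -> simplified, full-width '?' -> ASCII '?')
```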
95 | 96 | 97 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-learn 3 | termcolor 4 | regex 5 | tqdm 6 | prettytable 7 | scipy 8 | nltk 9 | pexpect 10 | pypinyin 11 | hanziconv 12 | -------------------------------------------------------------------------------- /scripts/convert/squad.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """A script to convert the default SQuAD dataset to the format: 8 | 9 | '{"question": "q1", "answer": ["a11", ..., "a1i"]}' 10 | ... 11 | '{"question": "qN", "answer": ["aN1", ..., "aNi"]}' 12 | 13 | """ 14 | 15 | import argparse 16 | import json 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('input', type=str) 20 | parser.add_argument('output', type=str) 21 | args = parser.parse_args() 22 | 23 | # Read dataset 24 | with open(args.input) as f: 25 | dataset = json.load(f) 26 | 27 | # Iterate and write question-answer pairs 28 | with open(args.output, 'w') as f: 29 | for article in dataset['data']: 30 | for paragraph in article['paragraphs']: 31 | for qa in paragraph['qas']: 32 | question = qa['question'] 33 | answer = [a['text'] for a in qa['answers']] 34 | f.write(json.dumps({'question': question, 'answer': answer})) 35 | f.write('\n') 36 | -------------------------------------------------------------------------------- /scripts/convert/webquestions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """A script to convert the default WebQuestions dataset to the format: 8 | 9 | '{"question": "q1", "answer": ["a11", ..., "a1i"]}' 10 | ... 11 | '{"question": "qN", "answer": ["aN1", ..., "aNi"]}' 12 | 13 | """ 14 | 15 | import argparse 16 | import re 17 | import json 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('input', type=str) 21 | parser.add_argument('output', type=str) 22 | args = parser.parse_args() 23 | 24 | # Read dataset 25 | with open(args.input) as f: 26 | dataset = json.load(f) 27 | 28 | # Iterate and write question-answer pairs 29 | with open(args.output, 'w') as f: 30 | for ex in dataset: 31 | question = ex['utterance'] 32 | answer = ex['targetValue'] 33 | answer = re.findall( 34 | r'(?<=\(description )(.+?)(?=\) \(description|\)\)$)', answer 35 | ) 36 | answer = [a.replace('"', '') for a in answer] 37 | f.write(json.dumps({'question': question, 'answer': answer})) 38 | f.write('\n') 39 | -------------------------------------------------------------------------------- /scripts/distant/check_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | """A script to visually inspect generated data.""" 8 | 9 | import argparse 10 | import json 11 | from termcolor import colored 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('file', type=str) 15 | args = parser.parse_args() 16 | 17 | with open(args.file) as f: 18 | lines = f.readlines() 19 | for line in lines: 20 | data = json.loads(line) 21 | question = ' '.join(data['question']) 22 | start, end = data['answers'][0] 23 | doc = data['document'] 24 | pre = ' '.join(doc[:start]) 25 | ans = colored(' '.join(doc[start: end + 1]), 'red', attrs=['bold']) 26 | post = ' '.join(doc[end + 1:]) 27 | print('-' * 50) 28 | print('Question: %s' % question) 29 | print('') 30 | print('Document: %s' % (' '.join([pre, ans, post]))) 31 | input() 32 | -------------------------------------------------------------------------------- /scripts/distant/generate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """A script to generate distantly supervised training data. 8 | 9 | Using Wikipedia and available QA datasets, we search for a paragraph 10 | that can be used as a supporting context. 11 | """ 12 | 13 | import argparse 14 | import uuid 15 | import heapq 16 | import logging 17 | import regex as re 18 | import os 19 | import json 20 | import random 21 | 22 | from functools import partial 23 | from collections import Counter 24 | from multiprocessing import Pool, cpu_count 25 | from multiprocessing.util import Finalize 26 | 27 | from nltk.tokenize import word_tokenize 28 | from nltk.chunk import ne_chunk 29 | from nltk.tag import pos_tag 30 | 31 | from drqa import tokenizers 32 | from drqa import retriever 33 | from drqa.retriever import utils 34 | 35 | logger = logging.getLogger() 36 | 37 | 38 | # ------------------------------------------------------------------------------ 39 | # Fetch text, tokenize + annotate 40 | # ------------------------------------------------------------------------------ 41 | 42 | PROCESS_TOK = None 43 | PROCESS_DB = None 44 | 45 | 46 | def init(tokenizer_class, tokenizer_opts, db_class=None, db_opts=None): 47 | global PROCESS_TOK, PROCESS_DB 48 | PROCESS_TOK = tokenizer_class(**tokenizer_opts) 49 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 50 | 51 | # optionally open a db connection 52 | if db_class: 53 | PROCESS_DB = db_class(**db_opts) 54 | Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100) 55 | 56 | 57 | def fetch_text(doc_id): 58 | global PROCESS_DB 59 | return PROCESS_DB.get_doc_text(doc_id) 60 | 61 | 62 | def tokenize_text(text): 63 | global PROCESS_TOK 64 | return PROCESS_TOK.tokenize(text) 65 | 66 | 67 | def nltk_entity_groups(text): 68 | """Return all contiguous NER tagged chunks by NLTK.""" 69 | parse_tree = ne_chunk(pos_tag(word_tokenize(text))) 70 | ner_chunks = [' '.join([l[0] for l in t.leaves()]) 71 | for t in parse_tree.subtrees() if t.label() != 'S'] 72 | return ner_chunks 73 | 74 | 75 | # ------------------------------------------------------------------------------ 76 | # Find answer candidates. 77 | # ------------------------------------------------------------------------------ 78 | 79 | 80 | def find_answer(paragraph, q_tokens, answer, opts): 81 | """Return the best matching answer offsets from a paragraph. 
82 | 83 | The paragraph is skipped if: 84 | * It is too long or short. 85 | * It doesn't contain the answer at all. 86 | * It doesn't contain named entities found in the question. 87 | * The answer context match score is too low. 88 | - This is the unigram + bigram overlap within +/- window_sz. 89 | """ 90 | # Length check 91 | if len(paragraph) > opts['char_max'] or len(paragraph) < opts['char_min']: 92 | return 93 | 94 | # Answer check 95 | if opts['regex']: 96 | # Add group around the whole answer 97 | answer = '(%s)' % answer[0] 98 | ans_regex = re.compile(answer, flags=re.IGNORECASE + re.UNICODE) 99 | answers = ans_regex.findall(paragraph) 100 | answers = {a[0] if isinstance(a, tuple) else a for a in answers} 101 | answers = {a.strip() for a in answers if len(a.strip()) > 0} 102 | else: 103 | answers = {a for a in answer if a in paragraph} 104 | if len(answers) == 0: 105 | return 106 | 107 | # Entity check. Default tokenizer + NLTK to minimize falling through cracks 108 | q_tokens, q_nltk_ner = q_tokens 109 | for ne in q_tokens.entity_groups(): 110 | if ne[0] not in paragraph: 111 | return 112 | for ne in q_nltk_ner: 113 | if ne not in paragraph: 114 | return 115 | 116 | # Search... 117 | p_tokens = tokenize_text(paragraph) 118 | p_words = p_tokens.words(uncased=True) 119 | q_grams = Counter(q_tokens.ngrams( 120 | n=2, uncased=True, filter_fn=utils.filter_ngram 121 | )) 122 | 123 | best_score = 0 124 | best_ex = None 125 | for ans in answers: 126 | try: 127 | a_words = tokenize_text(ans).words(uncased=True) 128 | except RuntimeError: 129 | logger.warn('Failed to tokenize answer: %s' % ans) 130 | continue 131 | for idx in range(len(p_words)): 132 | if p_words[idx:idx + len(a_words)] == a_words: 133 | # Overlap check 134 | w_s = max(idx - opts['window_sz'], 0) 135 | w_e = min(idx + opts['window_sz'] + len(a_words), len(p_words)) 136 | w_tokens = p_tokens.slice(w_s, w_e) 137 | w_grams = Counter(w_tokens.ngrams( 138 | n=2, uncased=True, filter_fn=utils.filter_ngram 139 | )) 140 | score = sum((w_grams & q_grams).values()) 141 | if score > best_score: 142 | # Success! Set new score + formatted example 143 | best_score = score 144 | best_ex = { 145 | 'id': uuid.uuid4().hex, 146 | 'question': q_tokens.words(), 147 | 'document': p_tokens.words(), 148 | 'offsets': p_tokens.offsets(), 149 | 'answers': [(idx, idx + len(a_words) - 1)], 150 | 'qlemma': q_tokens.lemmas(), 151 | 'lemma': p_tokens.lemmas(), 152 | 'pos': p_tokens.pos(), 153 | 'ner': p_tokens.entities(), 154 | } 155 | if best_score >= opts['match_threshold']: 156 | return best_score, best_ex 157 | 158 | 159 | def search_docs(inputs, max_ex=5, opts=None): 160 | """Given a set of document ids (returned by ranking for a question), search 161 | for top N best matching (by heuristic) paragraphs that contain the answer. 
162 | """ 163 | if not opts: 164 | raise RuntimeError('Options dict must be supplied.') 165 | 166 | doc_ids, q_tokens, answer = inputs 167 | examples = [] 168 | for i, doc_id in enumerate(doc_ids): 169 | for j, paragraph in enumerate(re.split(r'\n+', fetch_text(doc_id))): 170 | found = find_answer(paragraph, q_tokens, answer, opts) 171 | if found: 172 | # Reverse ranking, giving priority to early docs + paragraphs 173 | score = (found[0], -i, -j, random.random()) 174 | if len(examples) < max_ex: 175 | heapq.heappush(examples, (score, found[1])) 176 | else: 177 | heapq.heappushpop(examples, (score, found[1])) 178 | return [e[1] for e in examples] 179 | 180 | 181 | def process(questions, answers, outfile, opts): 182 | """Generate examples for all questions.""" 183 | logger.info('Processing %d question answer pairs...' % len(questions)) 184 | logger.info('Will save to %s.dstrain and %s.dsdev' % (outfile, outfile)) 185 | 186 | # Load ranker 187 | ranker = opts['ranker_class'](strict=False) 188 | logger.info('Ranking documents (top %d per question)...' % opts['n_docs']) 189 | ranked = ranker.batch_closest_docs(questions, k=opts['n_docs']) 190 | ranked = [r[0] for r in ranked] 191 | 192 | # Start pool of tokenizers with ner enabled 193 | workers = Pool(opts['workers'], initializer=init, 194 | initargs=(opts['tokenizer_class'], {'annotators': {'ner'}})) 195 | 196 | logger.info('Pre-tokenizing questions...') 197 | q_tokens = workers.map(tokenize_text, questions) 198 | q_ner = workers.map(nltk_entity_groups, questions) 199 | q_tokens = list(zip(q_tokens, q_ner)) 200 | workers.close() 201 | workers.join() 202 | 203 | # Start pool of simple tokenizers + db connections 204 | workers = Pool(opts['workers'], initializer=init, 205 | initargs=(opts['tokenizer_class'], {}, 206 | opts['db_class'], {})) 207 | 208 | logger.info('Searching documents...') 209 | cnt = 0 210 | inputs = [(ranked[i], q_tokens[i], answers[i]) for i in range(len(ranked))] 211 | search_fn = partial(search_docs, max_ex=opts['max_ex'], opts=opts['search']) 212 | with open(outfile + '.dstrain', 'w') as f_train, \ 213 | open(outfile + '.dsdev', 'w') as f_dev: 214 | for res in workers.imap_unordered(search_fn, inputs): 215 | for ex in res: 216 | cnt += 1 217 | f = f_dev if random.random() < opts['dev_split'] else f_train 218 | f.write(json.dumps(ex)) 219 | f.write('\n') 220 | if cnt % 1000 == 0: 221 | logging.info('%d results so far...' % cnt) 222 | workers.close() 223 | workers.join() 224 | logging.info('Finished. 
Total = %d' % cnt) 225 | 226 | 227 | # ------------------------------------------------------------------------------ 228 | # Main & commandline options 229 | # ------------------------------------------------------------------------------ 230 | 231 | 232 | if __name__ == "__main__": 233 | parser = argparse.ArgumentParser() 234 | parser.add_argument('data_dir', type=str, help='Dataset directory') 235 | parser.add_argument('data_name', type=str, help='Dataset name') 236 | parser.add_argument('out_dir', type=str, help='Output directory') 237 | 238 | dataset = parser.add_argument_group('Dataset') 239 | dataset.add_argument('--regex', action='store_true', 240 | help='Flag if answers are expressed as regexps') 241 | dataset.add_argument('--dev-split', type=float, default=0, 242 | help='Hold out for ds dev set (0.X)') 243 | 244 | search = parser.add_argument_group('Search Heuristic') 245 | search.add_argument('--match-threshold', type=int, default=1, 246 | help='Minimum context overlap with question') 247 | search.add_argument('--char-max', type=int, default=1500, 248 | help='Maximum allowed context length') 249 | search.add_argument('--char-min', type=int, default=25, 250 | help='Minimum allowed context length') 251 | search.add_argument('--window-sz', type=int, default=20, 252 | help='Use context on +/- window_sz for overlap measure') 253 | 254 | general = parser.add_argument_group('General') 255 | general.add_argument('--max-ex', type=int, default=5, 256 | help='Maximum matches generated per question') 257 | general.add_argument('--n-docs', type=int, default=5, 258 | help='Number of docs retrieved per question') 259 | general.add_argument('--tokenizer', type=str, default='corenlp') 260 | general.add_argument('--ranker', type=str, default='tfidf') 261 | general.add_argument('--db', type=str, default='sqlite') 262 | general.add_argument('--workers', type=int, default=cpu_count()) 263 | args = parser.parse_args() 264 | 265 | # Logging 266 | logger.setLevel(logging.INFO) 267 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', 268 | '%m/%d/%Y %I:%M:%S %p') 269 | console = logging.StreamHandler() 270 | console.setFormatter(fmt) 271 | logger.addHandler(console) 272 | 273 | # Read dataset 274 | dataset = os.path.join(args.data_dir, args.data_name) 275 | questions = [] 276 | answers = [] 277 | for line in open(dataset): 278 | data = json.loads(line) 279 | question = data['question'] 280 | answer = data['answer'] 281 | 282 | # Make sure the regex compiles 283 | if args.regex: 284 | try: 285 | re.compile(answer[0]) 286 | except BaseException: 287 | logger.warning('Regex failed to compile: %s' % answer) 288 | continue 289 | 290 | questions.append(question) 291 | answers.append(answer) 292 | 293 | # Get classes 294 | ranker_class = retriever.get_class(args.ranker) 295 | db_class = retriever.get_class(args.db) 296 | tokenizer_class = tokenizers.get_class(args.tokenizer) 297 | 298 | # Form options 299 | search_keys = ('regex', 'match_threshold', 'char_max', 300 | 'char_min', 'window_sz') 301 | opts = { 302 | 'ranker_class': retriever.get_class(args.ranker), 303 | 'tokenizer_class': tokenizers.get_class(args.tokenizer), 304 | 'db_class': retriever.get_class(args.db), 305 | 'search': {k: vars(args)[k] for k in search_keys}, 306 | } 307 | opts.update(vars(args)) 308 | 309 | # Process! 
310 | outname = os.path.splitext(args.data_name)[0] 311 | outfile = os.path.join(args.out_dir, outname) 312 | process(questions, answers, outfile, opts) 313 | -------------------------------------------------------------------------------- /scripts/pipeline/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Follows official evaluation script for v1.1 of the SQuAD dataset.""" 8 | 9 | import argparse 10 | import json 11 | from drqa.retriever.utils import normalize 12 | from drqa.reader.utils import ( 13 | exact_match_score, 14 | regex_match_score, 15 | metric_max_over_ground_truths 16 | ) 17 | 18 | 19 | def evaluate(dataset_file, prediction_file, regex=False): 20 | print('-' * 50) 21 | print('Dataset: %s' % dataset_file) 22 | print('Predictions: %s' % prediction_file) 23 | 24 | answers = [] 25 | for line in open(args.dataset): 26 | data = json.loads(line) 27 | answer = [normalize(a) for a in data['answer']] 28 | answers.append(answer) 29 | 30 | predictions = [] 31 | with open(prediction_file) as f: 32 | for line in f: 33 | data = json.loads(line) 34 | prediction = normalize(data[0]['span']) 35 | predictions.append(prediction) 36 | 37 | exact_match = 0 38 | for i in range(len(predictions)): 39 | match_fn = regex_match_score if regex else exact_match_score 40 | exact_match += metric_max_over_ground_truths( 41 | match_fn, predictions[i], answers[i] 42 | ) 43 | total = len(predictions) 44 | exact_match = 100.0 * exact_match / total 45 | print({'exact_match': exact_match}) 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('dataset', type=str) 51 | parser.add_argument('predictions', type=str) 52 | parser.add_argument('--regex', action='store_true') 53 | args = parser.parse_args() 54 | evaluate(args.dataset, args.predictions, args.regex) 55 | -------------------------------------------------------------------------------- /scripts/pipeline/interactive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Interactive interface to full DrQA pipeline.""" 8 | ''' word with original pipline. not working well in my 9 | context. 
plz use sinteractive instead''' 10 | 11 | 12 | import torch 13 | import argparse 14 | import code 15 | import prettytable 16 | import logging 17 | 18 | from termcolor import colored 19 | from drqa import pipeline 20 | from drqa.retriever import utils 21 | 22 | logger = logging.getLogger() 23 | logger.setLevel(logging.INFO) 24 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 25 | console = logging.StreamHandler() 26 | console.setFormatter(fmt) 27 | logger.addHandler(console) 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--reader-model', type=str, default=None, 31 | help='Path to trained Document Reader model') 32 | parser.add_argument('--retriever-model', type=str, default=None, 33 | help='Path to Document Retriever model (tfidf)') 34 | parser.add_argument('--doc-db', type=str, default=None, 35 | help='Path to Document DB') 36 | parser.add_argument('--tokenizer', type=str, default=None, 37 | help=("String option specifying tokenizer type to " 38 | "use (e.g. 'corenlp')")) 39 | parser.add_argument('--candidate-file', type=str, default=None, 40 | help=("List of candidates to restrict predictions to, " 41 | "one candidate per line")) 42 | parser.add_argument('--no-cuda', action='store_true', 43 | help="Use CPU only") 44 | parser.add_argument('--gpu', type=int, default=-1, 45 | help="Specify GPU device id to use") 46 | args = parser.parse_args() 47 | 48 | args.cuda = not args.no_cuda and torch.cuda.is_available() 49 | if args.cuda: 50 | torch.cuda.set_device(args.gpu) 51 | logger.info('CUDA enabled (GPU %d)' % args.gpu) 52 | else: 53 | logger.info('Running on CPU only.') 54 | 55 | if args.candidate_file: 56 | logger.info('Loading candidates from %s' % args.candidate_file) 57 | candidates = set() 58 | with open(args.candidate_file) as f: 59 | for line in f: 60 | line = utils.normalize(line.strip()).lower() 61 | candidates.add(line) 62 | logger.info('Loaded %d candidates.' 
% len(candidates)) 63 | else: 64 | candidates = None 65 | 66 | logger.info('Initializing pipeline...') 67 | DrQA = pipeline.DrQA( 68 | cuda=args.cuda, 69 | fixed_candidates=candidates, 70 | reader_model=args.reader_model, 71 | ranker_config={'options': {'tfidf_path': args.retriever_model}}, 72 | db_config={'options': {'db_path': args.doc_db}}, 73 | tokenizer=args.tokenizer, 74 | num_workers=1, 75 | max_loaders=1, 76 | embedding_file='data/vector/zh200.vec' 77 | ) 78 | 79 | 80 | # ------------------------------------------------------------------------------ 81 | # Drop in to interactive mode 82 | # ------------------------------------------------------------------------------ 83 | 84 | 85 | def process(question, candidates=None, top_n=1, n_docs=5): 86 | predictions = DrQA.process( 87 | question, candidates, top_n, n_docs, return_context=True 88 | ) 89 | table = prettytable.PrettyTable( 90 | ['Rank', 'Answer', 'Doc', 'Answer Score', 'Doc Score'] 91 | ) 92 | for i, p in enumerate(predictions, 1): 93 | table.add_row([i, p['span'], p['doc_id'], 94 | '%.5g' % p['span_score'], 95 | '%.5g' % p['doc_score']]) 96 | print('Top Predictions:') 97 | print(table) 98 | print('\nContexts:') 99 | for p in predictions: 100 | text = p['context']['text'] 101 | start = p['context']['start'] 102 | end = p['context']['end'] 103 | output = (text[:start] + 104 | colored(text[start: end], 'green', attrs=['bold']) + 105 | text[end:]) 106 | print('[ Doc = %s ]' % p['doc_id']) 107 | print(output + '\n') 108 | 109 | 110 | banner = """ 111 | Interactive DrQA 112 | >> process(question, candidates=None, top_n=1, n_docs=5) 113 | >> usage() 114 | """ 115 | 116 | 117 | def usage(): 118 | print(banner) 119 | 120 | 121 | code.interact(banner=banner, local=locals()) 122 | -------------------------------------------------------------------------------- /scripts/pipeline/predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | """Run predictions using the full DrQA retriever-reader pipeline.""" 8 | 9 | import torch 10 | import os 11 | import time 12 | import json 13 | import argparse 14 | import logging 15 | 16 | from drqa import pipeline 17 | from drqa.retriever import utils 18 | 19 | 20 | logger = logging.getLogger() 21 | logger.setLevel(logging.INFO) 22 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 23 | console = logging.StreamHandler() 24 | console.setFormatter(fmt) 25 | logger.addHandler(console) 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('dataset', type=str) 29 | parser.add_argument('--out-dir', type=str, default='/tmp', 30 | help=("Directory to write prediction file to " 31 | "(<dataset>-<model>-pipeline.preds)")) 32 | parser.add_argument('--reader-model', type=str, default=None, 33 | help="Path to trained Document Reader model") 34 | parser.add_argument('--retriever-model', type=str, default=None, 35 | help="Path to Document Retriever model (tfidf)") 36 | parser.add_argument('--doc-db', type=str, default=None, 37 | help='Path to Document DB') 38 | parser.add_argument('--embedding-file', type=str, default=None, 39 | help=("Expand dictionary to use all pretrained " 40 | "embeddings in this file")) 41 | parser.add_argument('--candidate-file', type=str, default=None, 42 | help=("List of candidates to restrict predictions to, " 43 | "one candidate per line")) 44 | parser.add_argument('--n-docs', type=int, default=5, 45 | help="Number of docs to retrieve per query") 46 | parser.add_argument('--top-n', type=int, default=1, 47 | help="Number of predictions to make per query") 48 | parser.add_argument('--tokenizer', type=str, default=None, 49 | help=("String option specifying tokenizer type to use " 50 | "(e.g. 'corenlp')")) 51 | parser.add_argument('--no-cuda', action='store_true', 52 | help="Use CPU only") 53 | parser.add_argument('--gpu', type=int, default=-1, 54 | help="Specify GPU device id to use") 55 | parser.add_argument('--parallel', action='store_true', 56 | help='Use data parallel (split across gpus)') 57 | parser.add_argument('--num-workers', type=int, default=None, 58 | help='Number of CPU processes (for tokenizing, etc)') 59 | parser.add_argument('--batch-size', type=int, default=128, 60 | help='Document paragraph batching size') 61 | parser.add_argument('--predict-batch-size', type=int, default=1000, 62 | help='Question batching size') 63 | args = parser.parse_args() 64 | t0 = time.time() 65 | 66 | args.cuda = not args.no_cuda and torch.cuda.is_available() 67 | if args.cuda: 68 | torch.cuda.set_device(args.gpu) 69 | logger.info('CUDA enabled (GPU %d)' % args.gpu) 70 | else: 71 | logger.info('Running on CPU only.') 72 | 73 | if args.candidate_file: 74 | logger.info('Loading candidates from %s' % args.candidate_file) 75 | candidates = set() 76 | with open(args.candidate_file) as f: 77 | for line in f: 78 | line = utils.normalize(line.strip()).lower() 79 | candidates.add(line) 80 | logger.info('Loaded %d candidates.' 
% len(candidates)) 81 | else: 82 | candidates = None 83 | 84 | logger.info('Initializing pipeline...') 85 | DrQA = pipeline.DrQA( 86 | reader_model=args.reader_model, 87 | fixed_candidates=candidates, 88 | embedding_file=args.embedding_file, 89 | tokenizer=args.tokenizer, 90 | batch_size=args.batch_size, 91 | cuda=args.cuda, 92 | data_parallel=args.parallel, 93 | ranker_config={'options': {'tfidf_path': args.retriever_model, 94 | 'strict': False}}, 95 | db_config={'options': {'db_path': args.doc_db}}, 96 | num_workers=args.num_workers, 97 | ) 98 | 99 | 100 | # ------------------------------------------------------------------------------ 101 | # Read in dataset and make predictions 102 | # ------------------------------------------------------------------------------ 103 | 104 | 105 | logger.info('Loading queries from %s' % args.dataset) 106 | queries = [] 107 | for line in open(args.dataset): 108 | data = json.loads(line) 109 | queries.append(data['question']) 110 | 111 | model = os.path.splitext(os.path.basename(args.reader_model or 'default'))[0] 112 | basename = os.path.splitext(os.path.basename(args.dataset))[0] 113 | outfile = os.path.join(args.out_dir, basename + '-' + model + '-pipeline.preds') 114 | 115 | logger.info('Writing results to %s' % outfile) 116 | with open(outfile, 'w') as f: 117 | batches = [queries[i: i + args.predict_batch_size] 118 | for i in range(0, len(queries), args.predict_batch_size)] 119 | for i, batch in enumerate(batches): 120 | logger.info( 121 | '-' * 25 + ' Batch %d/%d ' % (i + 1, len(batches)) + '-' * 25 122 | ) 123 | predictions = DrQA.process_batch( 124 | batch, 125 | n_docs=args.n_docs, 126 | top_n=args.top_n, 127 | ) 128 | for p in predictions: 129 | f.write(json.dumps(p) + '\n') 130 | 131 | logger.info('Total time: %.2f' % (time.time() - t0)) 132 | -------------------------------------------------------------------------------- /scripts/pipeline/sinteractive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # simple interactive model. only run a single thread process to void troubles 4 | 5 | import torch 6 | import code 7 | import argparse 8 | import logging 9 | import prettytable 10 | import time 11 | 12 | from drqa.reader import Predictor 13 | from drqa.pipeline.simpleDrQA import SDrQA 14 | 15 | logger = logging.getLogger() 16 | logger.setLevel(logging.INFO) 17 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 18 | console = logging.StreamHandler() 19 | console.setFormatter(fmt) 20 | logger.addHandler(console) 21 | 22 | 23 | # ------------------------------------------------------------------------------ 24 | # Commandline arguments & init 25 | # ------------------------------------------------------------------------------ 26 | 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--model', type=str, default=None, 30 | help='Path to model to use') 31 | parser.add_argument('--tokenizer', type=str, default='zh', 32 | help=("String option specifying tokenizer type to use " 33 | "(e.g. 
'corenlp')")) 34 | parser.add_argument('--no-cuda', action='store_true', 35 | help='Use CPU only') 36 | parser.add_argument('--embedding-file', default=None, 37 | help='embedding') 38 | parser.add_argument('--gpu', type=int, default=-1, 39 | help='Specify GPU device id to use') 40 | parser.add_argument('--tfidf-model', type=str, default=None) 41 | parser.add_argument('--db', type=str, default=None) 42 | args = parser.parse_args() 43 | 44 | args.cuda = not args.no_cuda and torch.cuda.is_available() 45 | if args.cuda: 46 | torch.cuda.set_device(args.gpu) 47 | logger.info('CUDA enabled (GPU %d)' % args.gpu) 48 | else: 49 | logger.info('Running on CPU only.') 50 | 51 | predictor = Predictor(args.model, args.tokenizer, num_workers=0, 52 | embedding_file=args.embedding_file) 53 | if args.cuda: 54 | predictor.cuda() 55 | 56 | # maybe a different embedding to save memory 57 | drqa = SDrQA(predictor, args.tfidf_model, args.db, ebdPath=args.embedding_file) 58 | # ------------------------------------------------------------------------------ 59 | # Drop in to interactive mode 60 | # ------------------------------------------------------------------------------ 61 | 62 | 63 | def process(question, doc_n=1, pred_n=1, net_n=1): 64 | t0 = time.time() 65 | answers = drqa.predict(question, qasTopN=pred_n, 66 | docTopN=doc_n, netTopN=net_n) 67 | 68 | def sort(a): 69 | return (0.2 + a['answerScore']) * a['contextScore'] 70 | answers = sorted(answers, key=sort) 71 | for ans in answers: 72 | print('---------------------------------------------------------') 73 | print(ans['text']) 74 | print("======== answer :" + ans['answer']) 75 | print('answer score : ' + str(ans['answerScore'])) 76 | print('context score : ' + str(ans['contextScore'])) 77 | # predictions = predictor.predict(document, question, candidates, top_n) 78 | # table = prettytable.PrettyTable(['Rank', 'Span', 'Score']) 79 | # for i, p in enumerate(predictions, 1): 80 | # table.add_row([i, p[0], p[1]]) 81 | # print(table) 82 | print('Time: %.4f' % (time.time() - t0)) 83 | 84 | 85 | banner = """ 86 | DrQA Interactive Document Reader Module 87 | >> process(question, doc_n=1, pred_n=1, net_n=1): 88 | doc_n: doc number in database 89 | pred_n: answer number for every context 90 | net_n: doc number in search engine 91 | >> usage() 92 | """ 93 | 94 | 95 | def usage(): 96 | print(banner) 97 | 98 | 99 | code.interact(banner=banner, local=locals()) 100 | -------------------------------------------------------------------------------- /scripts/reader/README.md: -------------------------------------------------------------------------------- 1 | # Document Reader 2 | 3 | ## Preprocessing 4 | 5 | `preprocess.py` takes a SQuAD-formatted dataset and outputs a preprocessed, training-ready file. Specifically, it handles tokenization, mapping character offsets to token offsets, and any additional featurization such as lemmatization, part-of-speech tagging, and named entity recognition. 6 | 7 | To preprocess SQuAD (assuming both input and output files are in `data/datasets`): 8 | 9 | ```bash 10 | python scripts/reader/preprocess.py data/datasets data/datasets --split SQuAD-v1.1-train 11 | ``` 12 | ```bash 13 | python scripts/reader/preprocess.py data/datasets data/datasets --split SQuAD-v1.1-dev 14 | ``` 15 | - _You need to have [SQuAD](../../README.md#qa-datasets) train-v1.1.json and dev-v1.1.json in data/datasets (here renamed as SQuAD-v1.1-<train/dev>.json)_ 16 | 17 | ## Training 18 | 19 | `train.py` is the main train script for the Document Reader. 
20 | 21 | To get started with training a model on SQuAD with our best hyper parameters: 22 | 23 | ```bash 24 | python scripts/reader/train.py --embedding-file glove.840B.300d.txt --tune-partial 1000 25 | ``` 26 | - _You need to have the [glove embeddings](#note-on-word-embeddings) downloaded to data/embeddings/glove.840B.300d.txt._ 27 | - _You need to have done the preprocessing above._ 28 | 29 | The training has many options that you can tune: 30 | 31 | ``` 32 | Environment: 33 | --no-cuda Train on CPU, even if GPUs are available. (default: False) 34 | --gpu Run on a specific GPU (default: -1) 35 | --data-workers Number of subprocesses for data loading (default: 5) 36 | --parallel Use DataParallel on all available GPUs (default: False) 37 | --random-seed Random seed for all numpy/torch/cuda operations (for reproducibility). 38 | --num-epochs Train data iterations. 39 | --batch-size Batch size for training. 40 | --test-batch-size Batch size during validation/testing. 41 | 42 | Filesystem: 43 | --model-dir Directory for saved models/checkpoints/logs (default: /tmp/drqa-models). 44 | --model-name Unique model identifier (.mdl, .txt, .checkpoint) (default: <generated uuid>). 45 | --data-dir Directory of training/validation data (default: data/datasets). 46 | --train-file Preprocessed train file (default: SQuAD-v1.1-train-processed-corenlp.txt). 47 | --dev-file Preprocessed dev file (default: SQuAD-v1.1-dev-processed-corenlp.txt). 48 | --dev-json Unprocessed dev file to run validation while training on (used to get original text for getting spans and answer texts) (default: SQuAD-v1.1-dev.json). 49 | --embed-dir Directory of pre-trained embedding files (default: data/embeddings). 50 | --embedding-file Space-separated pretrained embeddings file (default: None). 51 | 52 | Saving/Loading: 53 | --checkpoint Save model + optimizer state after each epoch (default: False). 54 | --pretrained Path to a pretrained model to warm-start with (default: <empty>). 55 | --expand-dictionary Expand dictionary of pretrained (--pretrained) model to include training/dev words of new data (default: False). 56 | 57 | Preprocessing: 58 | --uncased-question Question words will be lower-cased (default: False). 59 | --uncased-doc Document words will be lower-cased (default: False). 60 | --restrict-vocab Only use pre-trained words in embedding_file (default: True). 61 | 62 | General: 63 | --official-eval Validate with official SQuAD eval (default: True). 64 | --valid-metric The evaluation metric used for model selection (default: f1). 65 | --display-iter Log state after every <display_iter> epochs (default: 25). 66 | --sort-by-len Sort batches by length for speed (default: True). 67 | 68 | DrQA Reader Model Architecture: 69 | --model-type Model architecture type (default: rnn). 70 | --embedding-dim Embedding size if embedding_file is not given (default: 300). 71 | --hidden-size Hidden size of RNN units (default: 128). 72 | --doc-layers Number of encoding layers for document (default: 3). 73 | --question-layers Number of encoding layers for question (default: 3). 74 | --rnn-type RNN type: LSTM, GRU, or RNN (default: lstm). 75 | 76 | DrQA Reader Model Details: 77 | --concat-rnn-layers Combine hidden states from each encoding layer (default: True). 78 | --question-merge The way of computing the question representation (default: self_attn). 79 | --use-qemb Whether to use weighted question embeddings (default: True). 80 | --use-in-question Whether to use in_question_* (cased, uncased, lemma) features (default: True). 
81 | --use-pos Whether to use pos features (default: True). 82 | --use-ner Whether to use ner features (default: True). 83 | --use-lemma Whether to use lemma features (default: True). 84 | --use-tf Whether to use term frequency features (default: True). 85 | 86 | DrQA Reader Optimization: 87 | --dropout-emb Dropout rate for word embeddings (default: 0.4). 88 | --dropout-rnn Dropout rate for RNN states (default: 0.4). 89 | --dropout-rnn-output Whether to dropout the RNN output (default: True). 90 | --optimizer Optimizer: sgd or adamax (default: adamax). 91 | --learning-rate Learning rate for SGD only (default: 0.1). 92 | --grad-clipping Gradient clipping (default: 10). 93 | --weight-decay Weight decay factor (default: 0). 94 | --momentum Momentum factor (default: 0). 95 | --fix-embeddings Keep word embeddings fixed (use pretrained) (default: True). 96 | --tune-partial Backprop through only the top N question words (default: 0). 97 | --rnn-padding Explicitly account for padding (and skip it) in RNN encoding (default: False). 98 | --max-len MAX_LEN The max span allowed during decoding (default: 15). 99 | ``` 100 | 101 | ### Note on Word Embeddings 102 | 103 | Using pre-trained word embeddings is very important for performance. The models we provide were trained with cased GloVe embeddings trained on Common Crawl, however we have also found that other embeddings such as FastText do quite well. 104 | 105 | We suggest downloading the embeddings files and storing them under `data/embeddings/<file>.txt` (this is the default for `--embedding-dir`). The code expects space separated plain text files (\<token\> \<d1\> ... \<dN\>). 106 | 107 | - [GloVe: Common Crawl (cased)](http://nlp.stanford.edu/data/wordvecs/glove.840B.300d.zip) 108 | - [FastText: Wikipedia (uncased)](https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.en.vec) 109 | 110 | ## Predicting 111 | 112 | `predict.py` uses a trained Document Reader model to make predictions for an input dataset. 113 | 114 | Required arguments: 115 | ``` 116 | dataset SQuAD-like dataset to evaluate on (format B). 117 | ``` 118 | 119 | Optional arguments: 120 | ``` 121 | --model Path to model to use. 122 | --embedding-file Expand dictionary to use all pretrained embeddings in this file. 123 | --out-dir Directory to write prediction file to (<dataset>-<model>.preds). 124 | --tokenizer String option specifying tokenizer type to use (e.g. 'corenlp'). 125 | --num-workers Number of CPU processes (for tokenizing, etc). 126 | --no-cuda Use CPU only. 127 | --gpu Specify GPU device id to use. 128 | --batch-size Example batching size (Reduce in case of OOM). 129 | --top-n Store top N predicted spans per example. 130 | --official Only store single top span instead of top N list. (The SQuAD eval script takes a dict of qid: span). 131 | ``` 132 | 133 | Note: The CoreNLP NER annotator is not fully deterministic (depends on the order examples are processed). Predictions may fluctuate very slightly between runs if `num-workers` > 1 and the model was trained with `use-ner` on. 134 | 135 | ## Interactive 136 | 137 | The Document Reader can also be used interactively (like the [full pipeline](../../README.md#quick-start-demo)). 138 | 139 | ```bash 140 | python scripts/reader/interactive.py --model /path/to/model 141 | ``` 142 | 143 | ``` 144 | >>> text = "Mary had a little lamb, whose fleece was white as snow. And everywhere that Mary went the lamb was sure to go." 145 | >>> question = "What color is Mary's lamb?" 
146 | >>> process(text, question) 147 | 148 | +------+-------+---------+ 149 | | Rank | Span | Score | 150 | +------+-------+---------+ 151 | | 1 | white | 0.78002 | 152 | +------+-------+---------+ 153 | ``` -------------------------------------------------------------------------------- /scripts/reader/interactive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """A script to run the DrQA reader model interactively.""" 8 | 9 | import torch 10 | import code 11 | import argparse 12 | import logging 13 | import prettytable 14 | import time 15 | 16 | from drqa.reader import Predictor 17 | 18 | logger = logging.getLogger() 19 | logger.setLevel(logging.INFO) 20 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 21 | console = logging.StreamHandler() 22 | console.setFormatter(fmt) 23 | logger.addHandler(console) 24 | 25 | 26 | # ------------------------------------------------------------------------------ 27 | # Commandline arguments & init 28 | # ------------------------------------------------------------------------------ 29 | 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--model', type=str, default=None, 33 | help='Path to model to use') 34 | parser.add_argument('--tokenizer', type=str, default=None, 35 | help=("String option specifying tokenizer type to use " 36 | "(e.g. 'corenlp')")) 37 | parser.add_argument('--no-cuda', action='store_true', 38 | help='Use CPU only') 39 | parser.add_argument('--embedding-file', default='data/vector/zh200.vec', 40 | help='embedding') 41 | parser.add_argument('--gpu', type=int, default=-1, 42 | help='Specify GPU device id to use') 43 | args = parser.parse_args() 44 | 45 | args.cuda = not args.no_cuda and torch.cuda.is_available() 46 | if args.cuda: 47 | torch.cuda.set_device(args.gpu) 48 | logger.info('CUDA enabled (GPU %d)' % args.gpu) 49 | else: 50 | logger.info('Running on CPU only.') 51 | 52 | predictor = Predictor(args.model, args.tokenizer, num_workers=0, 53 | embedding_file=args.embedding_file) 54 | if args.cuda: 55 | predictor.cuda() 56 | 57 | 58 | # ------------------------------------------------------------------------------ 59 | # Drop in to interactive mode 60 | # ------------------------------------------------------------------------------ 61 | 62 | 63 | def process(document, question, candidates=None, top_n=1): 64 | t0 = time.time() 65 | predictions = predictor.predict(document, question, candidates, top_n) 66 | table = prettytable.PrettyTable(['Rank', 'Span', 'Score']) 67 | for i, p in enumerate(predictions, 1): 68 | table.add_row([i, p[0], p[1]]) 69 | print(table) 70 | print('Time: %.4f' % (time.time() - t0)) 71 | 72 | 73 | banner = """ 74 | DrQA Interactive Document Reader Module 75 | >> process(document, question, candidates=None, top_n=1) 76 | >> usage() 77 | """ 78 | 79 | 80 | def usage(): 81 | print(banner) 82 | 83 | 84 | code.interact(banner=banner, local=locals()) 85 | -------------------------------------------------------------------------------- /scripts/reader/predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """A script to make and save model predictions on an input dataset.""" 8 | 9 | import os 10 | import time 11 | import torch 12 | import argparse 13 | import logging 14 | import json 15 | 16 | from tqdm import tqdm 17 | from drqa.reader import Predictor 18 | 19 | logger = logging.getLogger() 20 | logger.setLevel(logging.INFO) 21 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 22 | console = logging.StreamHandler() 23 | console.setFormatter(fmt) 24 | logger.addHandler(console) 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('dataset', type=str, default=None, 28 | help='SQuAD-like dataset to evaluate on') 29 | parser.add_argument('--model', type=str, default=None, 30 | help='Path to model to use') 31 | parser.add_argument('--embedding-file', type=str, default=None, 32 | help=('Expand dictionary to use all pretrained ' 33 | 'embeddings in this file.')) 34 | parser.add_argument('--out-dir', type=str, default='/tmp', 35 | help=('Directory to write prediction file to ' 36 | '(<dataset>-<model>.preds)')) 37 | parser.add_argument('--tokenizer', type=str, default=None, 38 | help=("String option specifying tokenizer type to use " 39 | "(e.g. 'corenlp')")) 40 | parser.add_argument('--num-workers', type=int, default=None, 41 | help='Number of CPU processes (for tokenizing, etc)') 42 | parser.add_argument('--no-cuda', action='store_true', 43 | help='Use CPU only') 44 | parser.add_argument('--gpu', type=int, default=-1, 45 | help='Specify GPU device id to use') 46 | parser.add_argument('--batch-size', type=int, default=128, 47 | help='Example batching size') 48 | parser.add_argument('--top-n', type=int, default=1, 49 | help='Store top N predicted spans per example') 50 | parser.add_argument('--official', action='store_true', 51 | help='Only store single top span instead of top N list') 52 | args = parser.parse_args() 53 | t0 = time.time() 54 | 55 | args.cuda = not args.no_cuda and torch.cuda.is_available() 56 | if args.cuda: 57 | torch.cuda.set_device(args.gpu) 58 | logger.info('CUDA enabled (GPU %d)' % args.gpu) 59 | else: 60 | logger.info('Running on CPU only.') 61 | 62 | predictor = Predictor( 63 | args.model, 64 | args.tokenizer, 65 | args.embedding_file, 66 | args.num_workers, 67 | ) 68 | if args.cuda: 69 | predictor.cuda() 70 | 71 | 72 | # ------------------------------------------------------------------------------ 73 | # Read in dataset and make predictions. 
74 | # ------------------------------------------------------------------------------
75 | 
76 | 
77 | examples = []
78 | qids = []
79 | answer = []
80 | with open(args.dataset) as f:
81 |     data = json.load(f)['data']
82 |     for article in data:
83 |         for paragraph in article['paragraphs']:
84 |             context = paragraph['context']
85 |             for qa in paragraph['qas']:
86 |                 qids.append(qa['id'])
87 |                 examples.append((context, qa['question']))
88 |                 answer.append(qa['answers'])
89 | 
90 | results = {}
91 | accurateIndex = 0
92 | accuracy = 0
93 | for i in tqdm(range(0, len(examples), args.batch_size)):
94 |     predictions = predictor.predict_batch(
95 |         examples[i:i + args.batch_size], top_n=args.top_n
96 |     )
97 | 
98 |     for j in range(len(predictions)):
99 |         # Official eval expects just a qid --> span
100 |         if args.official:
101 |             top_span = predictions[j][0][0]
102 |             results[qids[i + j]] = {'predictions': top_span,
103 |                                     'answers': answer[i + j]}
104 |             # Count the prediction as correct if the top span exactly
105 |             # matches any gold answer text.
106 |             if any(top_span == a['text'] for a in answer[i + j]):
107 |                 accuracy = (accuracy * accurateIndex + 1) / \
108 |                     (accurateIndex + 1)
109 |             else:
110 |                 accuracy = (accuracy * accurateIndex) / \
111 |                     (accurateIndex + 1)
112 |             accurateIndex += 1
113 | 
114 |         # Otherwise we store top N and scores for debugging.
115 |         else:
116 |             preds = [(p[0], float(p[1])) for p in predictions[j]]
117 |             results[qids[i + j]] = {'predictions': preds,
118 |                                     'answers': answer[i + j]}
119 |             # print(str({'predictions': preds, 'answers': answer[i + j]}))
120 |             predCount = 0
121 |             for pred in preds:
122 |                 if predCount < len(answer[i + j]) and pred[0] == answer[i + j][predCount]['text']:
123 |                     accuracy = (accuracy * accurateIndex + 1) / \
124 |                         (accurateIndex + 1)
125 |                     accurateIndex += 1
126 |                 else:
127 |                     accuracy = (accuracy * accurateIndex) / \
128 |                         (accurateIndex + 1)
129 |                     accurateIndex += 1
130 |                 predCount += 1
131 | 
132 | 
133 | model = os.path.splitext(os.path.basename(args.model or 'default'))[0]
134 | basename = os.path.splitext(os.path.basename(args.dataset))[0]
135 | outfile = os.path.join(args.out_dir, basename + '-' + model + '.preds')
136 | 
137 | logger.info('Writing results to %s' % outfile)
138 | with open(outfile, 'w') as f:
139 |     json.dump(results, f)
140 | 
141 | logger.info('Total time: %.2f' % (time.time() - t0))
142 | logger.info('Total items: %s' % accurateIndex)
143 | logger.info('Accuracy: %s' % accuracy)
144 | 
--------------------------------------------------------------------------------
/scripts/reader/preprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | """Preprocess the SQuAD dataset for training."""
8 | 
9 | import argparse
10 | import os
11 | import sys
12 | import json
13 | import time
14 | 
15 | from multiprocessing import Pool
16 | from multiprocessing.util import Finalize
17 | from functools import partial
18 | from drqa import tokenizers
19 | 
20 | # ------------------------------------------------------------------------------
21 | # Tokenize + annotate.
22 | # ------------------------------------------------------------------------------ 23 | 24 | TOK = None 25 | 26 | 27 | def init(tokenizer_class, options): 28 | global TOK 29 | TOK = tokenizer_class(**options) 30 | Finalize(TOK, TOK.shutdown, exitpriority=100) 31 | 32 | 33 | def tokenize(text): 34 | """Call the global process tokenizer on the input text.""" 35 | global TOK 36 | tokens = TOK.tokenize(text) 37 | output = { 38 | 'words': tokens.words(), 39 | 'offsets': tokens.offsets(), 40 | 'pos': tokens.pos(), 41 | 'lemma': tokens.lemmas(), 42 | 'ner': tokens.entities(), 43 | } 44 | return output 45 | 46 | 47 | # ------------------------------------------------------------------------------ 48 | # Process dataset examples 49 | # ------------------------------------------------------------------------------ 50 | 51 | 52 | def load_dataset(path): 53 | """Load json file and store fields separately.""" 54 | with open(path) as f: 55 | data = json.load(f)['data'] 56 | output = {'qids': [], 'questions': [], 'answers': [], 57 | 'contexts': [], 'qid2cid': []} 58 | for article in data: 59 | for paragraph in article['paragraphs']: 60 | output['contexts'].append(paragraph['context']) 61 | for qa in paragraph['qas']: 62 | output['qids'].append(qa['id']) 63 | output['questions'].append(qa['question']) 64 | output['qid2cid'].append(len(output['contexts']) - 1) 65 | if 'answers' in qa: 66 | output['answers'].append(qa['answers']) 67 | return output 68 | 69 | 70 | def find_answer(offsets, begin_offset, end_offset): 71 | """Match token offsets with the char begin/end offsets of the answer.""" 72 | start = [i for i, tok in enumerate(offsets) if tok[0] == begin_offset] 73 | end = [i for i, tok in enumerate(offsets) if tok[1] == end_offset] 74 | assert(len(start) <= 1) 75 | assert(len(end) <= 1) 76 | if len(start) == 1 and len(end) == 1: 77 | return start[0], end[0] 78 | 79 | 80 | def process_dataset(data, tokenizer, workers=None): 81 | """Iterate processing (tokenize, parse, etc) dataset multithreaded.""" 82 | tokenizer_class = tokenizers.get_class(tokenizer) 83 | make_pool = partial(Pool, workers, initializer=init) 84 | workers = make_pool(initargs=(tokenizer_class, {'annotators': {'lemma'}})) 85 | q_tokens = workers.map(tokenize, data['questions']) 86 | workers.close() 87 | workers.join() 88 | 89 | workers = make_pool( 90 | initargs=(tokenizer_class, {'annotators': {'lemma', 'pos', 'ner'}}) 91 | ) 92 | c_tokens = workers.map(tokenize, data['contexts']) 93 | workers.close() 94 | workers.join() 95 | 96 | for idx in range(len(data['qids'])): 97 | question = q_tokens[idx]['words'] 98 | qlemma = q_tokens[idx]['lemma'] 99 | document = c_tokens[data['qid2cid'][idx]]['words'] 100 | offsets = c_tokens[data['qid2cid'][idx]]['offsets'] 101 | lemma = c_tokens[data['qid2cid'][idx]]['lemma'] 102 | pos = c_tokens[data['qid2cid'][idx]]['pos'] 103 | ner = c_tokens[data['qid2cid'][idx]]['ner'] 104 | ans_tokens = [] 105 | if len(data['answers']) > 0: 106 | for ans in data['answers'][idx]: 107 | found = find_answer(offsets, 108 | ans['answer_start'], 109 | ans['answer_start'] + len(ans['text'])) 110 | if found: 111 | ans_tokens.append(found) 112 | yield { 113 | 'id': data['qids'][idx], 114 | 'question': question, 115 | 'document': document, 116 | 'offsets': offsets, 117 | 'answers': ans_tokens, 118 | 'qlemma': qlemma, 119 | 'lemma': lemma, 120 | 'pos': pos, 121 | 'ner': ner, 122 | } 123 | 124 | 125 | # ----------------------------------------------------------------------------- 126 | # Commandline options 127 | # 
----------------------------------------------------------------------------- 128 | 129 | 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument('data_dir', type=str, help='Path to SQuAD data directory') 132 | parser.add_argument('out_dir', type=str, help='Path to output file dir') 133 | parser.add_argument('--split', type=str, help='Filename for train/dev split', 134 | default='SQuAD-v1.1-train') 135 | parser.add_argument('--workers', type=int, default=None) 136 | parser.add_argument('--tokenizer', type=str, default='corenlp') 137 | args = parser.parse_args() 138 | 139 | t0 = time.time() 140 | 141 | in_file = os.path.join(args.data_dir, args.split + '.json') 142 | print('Loading dataset %s' % in_file, file=sys.stderr) 143 | dataset = load_dataset(in_file) 144 | 145 | out_file = os.path.join( 146 | args.out_dir, '%s-processed-%s.txt' % (args.split, args.tokenizer) 147 | ) 148 | print('Will write to file %s' % out_file, file=sys.stderr) 149 | with open(out_file, 'w') as f: 150 | for ex in process_dataset(dataset, args.tokenizer, args.workers): 151 | f.write(json.dumps(ex) + '\n') 152 | print('Total time: %.4f (s)' % (time.time() - t0)) 153 | -------------------------------------------------------------------------------- /scripts/retriever/README.md: -------------------------------------------------------------------------------- 1 | # Document Retriever 2 | 3 | ## Storing the Documents 4 | 5 | To efficiently store and access our documents, we store them in a sqlite database. The key is the `doc_id` and the value is the `text`. 6 | 7 | To create a sqlite db from a corpus of documents, run: 8 | 9 | ```bash 10 | python build_db.py /path/to/data /path/to/saved/db.db 11 | ``` 12 | 13 | Optional arguments: 14 | ``` 15 | --preprocess File path to a python module that defines a `preprocess` function. 16 | --num-workers Number of CPU processes (for tokenizing, etc). 17 | ``` 18 | 19 | The data path can either be a path to a nested directory of files (such as what the [WikiExtractor](https://github.com/attardi/wikiextractor) script outputs) or a single file. Each file should consist of JSON-encoded documents that have `id` and `text` fields, one per line: 20 | 21 | ```python 22 | {"id": "doc1", "text": "text of doc1"} 23 | ... 24 | {"id": "docN", "text": "text of docN"} 25 | ``` 26 | 27 | `--preprocess /path/to/.py/file` is another optional argument that allows you to supply a python module that defines a `preprocess(doc_object)` function to filter/process documents before they are put in the db. See `prep_wikipedia.py` for an example. 28 | 29 | ## Building the TF-IDF N-grams 30 | 31 | To build a TF-IDF weighted word-doc sparse matrix from the documents stored in the sqlite db, run: 32 | 33 | ```bash 34 | python build_tfidf.py /path/to/doc/db /path/to/output/dir 35 | ``` 36 | 37 | Optional arguments: 38 | ``` 39 | --ngram Use up to N-size n-grams (e.g. 2 = unigrams + bigrams). By default only ngrams without stopwords or punctuation are kept. 40 | --hash-size Number of buckets to use for hashing ngrams. 41 | --tokenizer String option specifying tokenizer type to use (e.g. 'corenlp'). 42 | --num-workers Number of CPU processes (for tokenizing, etc). 43 | ``` 44 | 45 | The sparse matrix and its associated metadata will be saved to the output directory under `<db-name>-tfidf-ngram=<N>-hash=<N>-tokenizer=<T>.npz`. 46 | 47 | ## Interactive 48 | 49 | The Document Retriever can also be used interactively (like the [full pipeline](../../README.md#quick-start-demo)). 
50 | 51 | ```bash 52 | python scripts/retriever/interactive.py --model /path/to/model 53 | ``` 54 | 55 | ``` 56 | >>> process('question answering', k=5) 57 | 58 | +------+-------------------------------+-----------+ 59 | | Rank | Doc Id | Doc Score | 60 | +------+-------------------------------+-----------+ 61 | | 1 | Question answering | 327.89 | 62 | | 2 | Watson (computer) | 217.26 | 63 | | 3 | Eric Nyberg | 214.36 | 64 | | 4 | Social information seeking | 212.63 | 65 | | 5 | Language Computer Corporation | 184.64 | 66 | +------+-------------------------------+-----------+ 67 | ``` -------------------------------------------------------------------------------- /scripts/retriever/build_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """A script to read in and store documents in a sqlite database.""" 8 | 9 | import argparse 10 | import sqlite3 11 | import json 12 | import os 13 | import logging 14 | import importlib.util 15 | 16 | from multiprocessing import Pool as ProcessPool 17 | from tqdm import tqdm 18 | from drqa.retriever import utils 19 | from drqa.pipeline.simpleDrQA import filtText 20 | 21 | 22 | filt = filtText('drqa/features/map.txt').filt 23 | logger = logging.getLogger() 24 | logger.setLevel(logging.INFO) 25 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 26 | console = logging.StreamHandler() 27 | console.setFormatter(fmt) 28 | logger.addHandler(console) 29 | 30 | 31 | # ------------------------------------------------------------------------------ 32 | # Import helper 33 | # ------------------------------------------------------------------------------ 34 | 35 | 36 | PREPROCESS_FN = None 37 | 38 | 39 | def init(filename): 40 | global PREPROCESS_FN 41 | if filename: 42 | PREPROCESS_FN = import_module(filename).preprocess 43 | 44 | 45 | def import_module(filename): 46 | """Import a module given a full path to the file.""" 47 | spec = importlib.util.spec_from_file_location('doc_filter', filename) 48 | module = importlib.util.module_from_spec(spec) 49 | spec.loader.exec_module(module) 50 | return module 51 | 52 | 53 | # ------------------------------------------------------------------------------ 54 | # Store corpus. 55 | # ------------------------------------------------------------------------------ 56 | 57 | 58 | def iter_files(path): 59 | """Walk through all files located under a root path.""" 60 | if os.path.isfile(path): 61 | yield path 62 | elif os.path.isdir(path): 63 | for dirpath, _, filenames in os.walk(path): 64 | for f in filenames: 65 | yield os.path.join(dirpath, f) 66 | else: 67 | raise RuntimeError('Path %s is invalid' % path) 68 | 69 | 70 | def get_contents(filename): 71 | """Parse the contents of a file. 
Each line is a JSON encoded document.""" 72 | global PREPROCESS_FN 73 | documents = [] 74 | with open(filename) as f: 75 | for line in f: 76 | # Parse document 77 | doc = json.loads(line) 78 | # Maybe preprocess the document with custom function 79 | if PREPROCESS_FN: 80 | doc = PREPROCESS_FN(doc) 81 | # Skip if it is empty or None 82 | if not doc: 83 | continue 84 | # Add the document 85 | documents.append( 86 | (utils.normalize(doc['id']), extraNormalize(doc['text']))) 87 | return documents 88 | 89 | 90 | def extraNormalize(text): 91 | # normalize and filt multi-spelling text 92 | return filt(utils.normalize(text)) 93 | 94 | 95 | def store_contents(data_path, save_path, preprocess, num_workers=None): 96 | """Preprocess and store a corpus of documents in sqlite. 97 | 98 | Args: 99 | data_path: Root path to directory (or directory of directories) of files 100 | containing json encoded documents (must have `id` and `text` fields). 101 | save_path: Path to output sqlite db. 102 | preprocess: Path to file defining a custom `preprocess` function. Takes 103 | in and outputs a structured doc. 104 | num_workers: Number of parallel processes to use when reading docs. 105 | """ 106 | if os.path.isfile(save_path): 107 | raise RuntimeError('%s already exists! Not overwriting.' % save_path) 108 | 109 | logger.info('Reading into database...') 110 | conn = sqlite3.connect(save_path) 111 | c = conn.cursor() 112 | c.execute("CREATE TABLE documents (id PRIMARY KEY, text);") 113 | 114 | workers = ProcessPool(num_workers, initializer=init, 115 | initargs=(preprocess,)) 116 | files = [f for f in iter_files(data_path)] 117 | count = 0 118 | with tqdm(total=len(files)) as pbar: 119 | for pairs in tqdm(workers.imap_unordered(get_contents, files)): 120 | count += len(pairs) 121 | c.executemany("INSERT INTO documents VALUES (?,?)", pairs) 122 | pbar.update() 123 | logger.info('Read %d docs.' % count) 124 | logger.info('Committing...') 125 | conn.commit() 126 | conn.close() 127 | 128 | 129 | # ------------------------------------------------------------------------------ 130 | # Main. 131 | # ------------------------------------------------------------------------------ 132 | 133 | 134 | if __name__ == '__main__': 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument('data_path', type=str, help='/path/to/data') 137 | parser.add_argument('save_path', type=str, help='/path/to/saved/db.db') 138 | parser.add_argument('--preprocess', type=str, default=None, 139 | help=('File path to a python module that defines ' 140 | 'a `preprocess` function')) 141 | parser.add_argument('--num-workers', type=int, default=None, 142 | help='Number of CPU processes (for tokenizing, etc)') 143 | args = parser.parse_args() 144 | 145 | store_contents( 146 | args.data_path, args.save_path, args.preprocess, args.num_workers 147 | ) 148 | -------------------------------------------------------------------------------- /scripts/retriever/build_tfidf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | """A script to build the tf-idf document matrices for retrieval.""" 8 | 9 | import numpy as np 10 | import scipy.sparse as sp 11 | import argparse 12 | import os 13 | import math 14 | import logging 15 | 16 | from multiprocessing import Pool as ProcessPool 17 | from multiprocessing.util import Finalize 18 | from functools import partial 19 | from collections import Counter 20 | 21 | from drqa import retriever 22 | from drqa import tokenizers 23 | 24 | logger = logging.getLogger() 25 | logger.setLevel(logging.INFO) 26 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 27 | console = logging.StreamHandler() 28 | console.setFormatter(fmt) 29 | logger.addHandler(console) 30 | 31 | 32 | # ------------------------------------------------------------------------------ 33 | # Multiprocessing functions 34 | # ------------------------------------------------------------------------------ 35 | 36 | DOC2IDX = None 37 | PROCESS_TOK = None 38 | PROCESS_DB = None 39 | 40 | 41 | def init(tokenizer_class, db_class, db_opts): 42 | global PROCESS_TOK, PROCESS_DB 43 | PROCESS_TOK = tokenizer_class() 44 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 45 | PROCESS_DB = db_class(**db_opts) 46 | Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100) 47 | 48 | 49 | def fetch_text(doc_id): 50 | global PROCESS_DB 51 | return PROCESS_DB.get_doc_text(doc_id) 52 | 53 | 54 | def tokenize(text): 55 | global PROCESS_TOK 56 | return PROCESS_TOK.tokenize(text) 57 | 58 | 59 | # ------------------------------------------------------------------------------ 60 | # Build article --> word count sparse matrix. 61 | # ------------------------------------------------------------------------------ 62 | 63 | 64 | def count(ngram, hash_size, doc_id): 65 | """Fetch the text of a document and compute hashed ngrams counts.""" 66 | global DOC2IDX 67 | row, col, data = [], [], [] 68 | # Tokenize 69 | tokens = tokenize(retriever.utils.normalize(fetch_text(doc_id))) 70 | 71 | # Get ngrams from tokens, with stopword/punctuation filtering. 72 | ngrams = tokens.ngrams( 73 | n=ngram, uncased=True, filter_fn=retriever.utils.filter_ngram 74 | ) 75 | 76 | # Hash ngrams and count occurences 77 | counts = Counter([retriever.utils.hash(gram, hash_size) for gram in ngrams]) 78 | 79 | # Return in sparse matrix data format. 80 | row.extend(counts.keys()) 81 | col.extend([DOC2IDX[doc_id]] * len(counts)) 82 | data.extend(counts.values()) 83 | return row, col, data 84 | 85 | 86 | def get_count_matrix(args, db, db_opts): 87 | """Form a sparse word to document count matrix (inverted index). 88 | 89 | M[i, j] = # times word i appears in document j. 
90 | """ 91 | # Map doc_ids to indexes 92 | global DOC2IDX 93 | db_class = retriever.get_class(db) 94 | with db_class(**db_opts) as doc_db: 95 | doc_ids = doc_db.get_doc_ids() 96 | DOC2IDX = {doc_id: i for i, doc_id in enumerate(doc_ids)} 97 | 98 | # Setup worker pool 99 | tok_class = tokenizers.get_class(args.tokenizer) 100 | workers = ProcessPool( 101 | args.num_workers, 102 | initializer=init, 103 | initargs=(tok_class, db_class, db_opts) 104 | ) 105 | 106 | # Compute the count matrix in batches (to keep memory usage manageable) 107 | logger.info('Mapping...') 108 | row, col, data = [], [], [] 109 | step = max(int(len(doc_ids) / 10), 1) 110 | batches = [doc_ids[i:i + step] for i in range(0, len(doc_ids), step)] 111 | _count = partial(count, args.ngram, args.hash_size) 112 | for i, batch in enumerate(batches): 113 | logger.info('-' * 25 + 'Batch %d/%d' % (i + 1, len(batches)) + '-' * 25) 114 | for b_row, b_col, b_data in workers.imap_unordered(_count, batch): 115 | row.extend(b_row) 116 | col.extend(b_col) 117 | data.extend(b_data) 118 | workers.close() 119 | workers.join() 120 | 121 | logger.info('Creating sparse matrix...') 122 | count_matrix = sp.csr_matrix( 123 | (data, (row, col)), shape=(args.hash_size, len(doc_ids)) 124 | ) 125 | count_matrix.sum_duplicates() 126 | return count_matrix, (DOC2IDX, doc_ids) 127 | 128 | 129 | # ------------------------------------------------------------------------------ 130 | # Transform count matrix to different forms. 131 | # ------------------------------------------------------------------------------ 132 | 133 | 134 | def get_tfidf_matrix(cnts): 135 | """Convert the word count matrix into a tfidf matrix. 136 | 137 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 138 | * tf = term frequency in document 139 | * N = number of documents 140 | * Nt = number of documents containing the term 141 | """ 142 | Ns = get_doc_freqs(cnts) 143 | idfs = np.log((cnts.shape[1] - Ns + 0.5) / (Ns + 0.5)) 144 | idfs[idfs < 0] = 0 145 | idfs = sp.diags(idfs, 0) 146 | tfs = cnts.log1p() 147 | tfidfs = idfs.dot(tfs) 148 | return tfidfs 149 | 150 | 151 | def get_doc_freqs(cnts): 152 | """Return word --> # of docs it appears in.""" 153 | binary = (cnts > 0).astype(int) 154 | freqs = np.array(binary.sum(1)).squeeze() 155 | return freqs 156 | 157 | 158 | # ------------------------------------------------------------------------------ 159 | # Main. 160 | # ------------------------------------------------------------------------------ 161 | 162 | 163 | if __name__ == '__main__': 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument('db_path', type=str, default=None, 166 | help='Path to sqlite db holding document texts') 167 | parser.add_argument('out_dir', type=str, default=None, 168 | help='Directory for saving output files') 169 | parser.add_argument('--ngram', type=int, default=2, 170 | help=('Use up to N-size n-grams ' 171 | '(e.g. 2 = unigrams + bigrams)')) 172 | parser.add_argument('--hash-size', type=int, default=int(math.pow(2, 24)), 173 | help='Number of buckets to use for hashing ngrams') 174 | parser.add_argument('--tokenizer', type=str, default='simple', 175 | help=("String option specifying tokenizer type to use " 176 | "(e.g. 
'corenlp')")) 177 | parser.add_argument('--num-workers', type=int, default=None, 178 | help='Number of CPU processes (for tokenizing, etc)') 179 | args = parser.parse_args() 180 | 181 | logging.info('Counting words...') 182 | count_matrix, doc_dict = get_count_matrix( 183 | args, 'sqlite', {'db_path': args.db_path} 184 | ) 185 | 186 | logger.info('Making tfidf vectors...') 187 | tfidf = get_tfidf_matrix(count_matrix) 188 | 189 | logger.info('Getting word-doc frequencies...') 190 | freqs = get_doc_freqs(count_matrix) 191 | 192 | basename = os.path.splitext(os.path.basename(args.db_path))[0] 193 | basename += ('-tfidf-ngram=%d-hash=%d-tokenizer=%s' % 194 | (args.ngram, args.hash_size, args.tokenizer)) 195 | filename = os.path.join(args.out_dir, basename) 196 | 197 | logger.info('Saving to %s.npz' % filename) 198 | metadata = { 199 | 'doc_freqs': freqs, 200 | 'tokenizer': args.tokenizer, 201 | 'hash_size': args.hash_size, 202 | 'ngram': args.ngram, 203 | 'doc_dict': doc_dict 204 | } 205 | retriever.utils.save_sparse_csr(filename, tfidf, metadata) 206 | -------------------------------------------------------------------------------- /scripts/retriever/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Evaluate the accuracy of the DrQA retriever module.""" 8 | 9 | import regex as re 10 | import logging 11 | import argparse 12 | import json 13 | import time 14 | import os 15 | 16 | from multiprocessing import Pool as ProcessPool 17 | from multiprocessing.util import Finalize 18 | from functools import partial 19 | from drqa import retriever, tokenizers 20 | from drqa.retriever import utils 21 | 22 | # ------------------------------------------------------------------------------ 23 | # Multiprocessing target functions. 24 | # ------------------------------------------------------------------------------ 25 | 26 | PROCESS_TOK = None 27 | PROCESS_DB = None 28 | 29 | 30 | def init(tokenizer_class, tokenizer_opts, db_class, db_opts): 31 | global PROCESS_TOK, PROCESS_DB 32 | PROCESS_TOK = tokenizer_class(**tokenizer_opts) 33 | Finalize(PROCESS_TOK, PROCESS_TOK.shutdown, exitpriority=100) 34 | PROCESS_DB = db_class(**db_opts) 35 | Finalize(PROCESS_DB, PROCESS_DB.close, exitpriority=100) 36 | 37 | 38 | def regex_match(text, pattern): 39 | """Test if a regex pattern is contained within a text.""" 40 | try: 41 | pattern = re.compile( 42 | pattern, 43 | flags=re.IGNORECASE + re.UNICODE + re.MULTILINE, 44 | ) 45 | except BaseException: 46 | return False 47 | return pattern.search(text) is not None 48 | 49 | 50 | def has_answer(answer, doc_id, match): 51 | """Check if a document contains an answer string. 52 | 53 | If `match` is string, token matching is done between the text and answer. 54 | If `match` is regex, we search the whole text with the regex. 
55 | """ 56 | global PROCESS_DB, PROCESS_TOK 57 | text = PROCESS_DB.get_doc_text(doc_id) 58 | text = utils.normalize(text) 59 | if match == 'string': 60 | # Answer is a list of possible strings 61 | text = PROCESS_TOK.tokenize(text).words(uncased=True) 62 | for single_answer in answer: 63 | single_answer = utils.normalize(single_answer) 64 | single_answer = PROCESS_TOK.tokenize(single_answer) 65 | single_answer = single_answer.words(uncased=True) 66 | for i in range(0, len(text) - len(single_answer) + 1): 67 | if single_answer == text[i: i + len(single_answer)]: 68 | return True 69 | elif match == 'regex': 70 | # Answer is a regex 71 | single_answer = utils.normalize(answer[0]) 72 | if regex_match(text, single_answer): 73 | return True 74 | return False 75 | 76 | 77 | def get_score(answer_doc, match): 78 | """Search through all the top docs to see if they have the answer.""" 79 | answer, (doc_ids, doc_scores) = answer_doc 80 | for doc_id in doc_ids: 81 | if has_answer(answer, doc_id, match): 82 | return 1 83 | return 0 84 | 85 | 86 | # ------------------------------------------------------------------------------ 87 | # Main 88 | # ------------------------------------------------------------------------------ 89 | 90 | 91 | if __name__ == '__main__': 92 | logger = logging.getLogger() 93 | logger.setLevel(logging.INFO) 94 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', 95 | '%m/%d/%Y %I:%M:%S %p') 96 | console = logging.StreamHandler() 97 | console.setFormatter(fmt) 98 | logger.addHandler(console) 99 | 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('dataset', type=str, default=None) 102 | parser.add_argument('--model', type=str, default=None) 103 | parser.add_argument('--tokenizer', type=str, default='regexp') 104 | parser.add_argument('--n-docs', type=int, default=5) 105 | parser.add_argument('--num-workers', type=int, default=None) 106 | parser.add_argument('--match', type=str, default='string', 107 | choices=['regex', 'string']) 108 | args = parser.parse_args() 109 | 110 | # start time 111 | start = time.time() 112 | 113 | # read all the data and store it 114 | logger.info('Reading data ...') 115 | questions = [] 116 | answers = [] 117 | for line in open(args.dataset): 118 | data = json.loads(line) 119 | question = data['question'] 120 | answer = data['answer'] 121 | questions.append(question) 122 | answers.append(answer) 123 | 124 | # get the closest docs for each question. 
125 | logger.info('Initializing ranker...') 126 | ranker = retriever.get_class('tfidf')(tfidf_path=args.model) 127 | 128 | logger.info('Ranking...') 129 | closest_docs = ranker.batch_closest_docs( 130 | questions, k=args.n_docs, num_workers=args.num_workers 131 | ) 132 | answers_docs = zip(answers, closest_docs) 133 | 134 | # define processes 135 | tok_class = tokenizers.get_class(args.tokenizer) 136 | tok_opts = {} 137 | db_class = retriever.DocDB 138 | db_opts = {} 139 | processes = ProcessPool( 140 | processes=args.num_workers, 141 | initializer=init, 142 | initargs=(tok_class, tok_opts, db_class, db_opts) 143 | ) 144 | 145 | # compute the scores for each pair, and print the statistics 146 | logger.info('Retrieving and computing scores...') 147 | get_score_partial = partial(get_score, match=args.match) 148 | scores = processes.map(get_score_partial, answers_docs) 149 | 150 | filename = os.path.basename(args.dataset) 151 | stats = ( 152 | "\n" + "-" * 50 + "\n" + 153 | "{filename}\n" + 154 | "Examples:\t\t\t{total}\n" + 155 | "Matches in top {k}:\t\t{m}\n" + 156 | "Match % in top {k}:\t\t{p:2.2f}\n" + 157 | "Total time:\t\t\t{t:2.4f} (s)\n" 158 | ).format( 159 | filename=filename, 160 | total=len(scores), 161 | k=args.n_docs, 162 | m=sum(scores), 163 | p=(sum(scores) / len(scores) * 100), 164 | t=time.time() - start, 165 | ) 166 | 167 | print(stats) 168 | -------------------------------------------------------------------------------- /scripts/retriever/interactive.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | """Interactive mode for the tfidf DrQA retriever module.""" 8 | 9 | import argparse 10 | import sqlite3 11 | import code 12 | import prettytable 13 | import logging 14 | from drqa import retriever 15 | 16 | logger = logging.getLogger() 17 | logger.setLevel(logging.INFO) 18 | fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p') 19 | console = logging.StreamHandler() 20 | console.setFormatter(fmt) 21 | logger.addHandler(console) 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--model', type=str, default=None) 25 | parser.add_argument('--db', type=str, default=None) 26 | args = parser.parse_args() 27 | 28 | logger.info('Initializing ranker...') 29 | ranker = retriever.get_class('tfidf')(tfidf_path=args.model) 30 | conn = sqlite3.connect(args.db) 31 | c = conn.cursor() 32 | 33 | # ------------------------------------------------------------------------------ 34 | # Drop in to interactive 35 | # ------------------------------------------------------------------------------ 36 | 37 | 38 | def process(query, k=1): 39 | doc_names, doc_scores = ranker.closest_docs(query, k) 40 | table = prettytable.PrettyTable( 41 | ['Rank', 'Doc Id', 'Doc Score'] 42 | ) 43 | for i in range(len(doc_names)): 44 | table.add_row([i + 1, doc_names[i], '%.5g' % doc_scores[i]]) 45 | print(table) 46 | for i in range(len(doc_names)): 47 | print('Document %d (%s):' % (i + 1, doc_names[i])) 48 | cursor = c.execute('SELECT text FROM documents WHERE id = ?', (doc_names[i],)) 49 | for row in cursor: 50 | print("text = " + row[0]) 51 | 52 | banner = """ 53 | Interactive TF-IDF DrQA Retriever 54 | >> process(question, k=1) 55 | >> usage() 56 | """ 57 | 58 | 59 | def usage(): 60 | print(banner) 61 | 62 | 63 | code.interact(banner=banner, local=locals()) 64 | -------------------------------------------------------------------------------- /scripts/retriever/prep_wikipedia.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree 7 | """Preprocess function to filter/prepare Wikipedia docs.""" 8 | 9 | import regex as re 10 | from html.parser import HTMLParser 11 | 12 | PARSER = HTMLParser() 13 | BLACKLIST = set(['23443579', '52643645']) # Conflicting disambig. pages 14 | 15 | 16 | def preprocess(article): 17 | # Take out HTML escaping WikiExtractor didn't clean 18 | for k, v in article.items(): 19 | article[k] = PARSER.unescape(v) 20 | 21 | # Filter some disambiguation pages not caught by the WikiExtractor 22 | if article['id'] in BLACKLIST: 23 | return None 24 | if '(disambiguation)' in article['title'].lower(): 25 | return None 26 | if '(disambiguation page)' in article['title'].lower(): 27 | return None 28 | 29 | # Take out List/Index/Outline pages (mostly links) 30 | if re.match(r'(List of .+)|(Index of .+)|(Outline of .+)', 31 | article['title']): 32 | return None 33 | 34 | # Return doc with `id` set to `title` 35 | return {'id': article['title'], 'text': article['text']} 36 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This source code is licensed under the license found in the 3 | # LICENSE file in the root directory of this source tree.
4 | 5 | from setuptools import setup, find_packages 6 | import sys 7 | 8 | with open('readme.md') as f: 9 | readme = f.read() 10 | 11 | with open('LICENSE') as f: 12 | license = f.read() 13 | 14 | with open('requirements.txt') as f: 15 | reqs = f.read() 16 | 17 | setup( 18 | name='drqa', 19 | version='0.1.0', 20 | description='Reading Wikipedia to Answer Open-Domain Questions', 21 | long_description=readme, 22 | license=license, 23 | python_requires='>=3.5', 24 | packages=find_packages(exclude=('data',)), 25 | install_requires=reqs.strip().split('\n'), 26 | ) --------------------------------------------------------------------------------
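The BM25-style weighting implemented by `get_tfidf_matrix` in `scripts/retriever/build_tfidf.py` is easiest to see on a toy example. The sketch below is not part of the repository; it recomputes `log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))` on a small hand-made term-document count matrix (the counts are invented purely for illustration):

```python
# Toy illustration of the TF-IDF weighting used in build_tfidf.py.
# The 3x3 count matrix below is made up; rows = terms, columns = documents.
import numpy as np
import scipy.sparse as sp

cnts = sp.csr_matrix(np.array([
    [3, 0, 1],   # term 0 appears in docs 0 and 2
    [0, 2, 0],   # term 1 appears only in doc 1
    [1, 1, 1],   # term 2 appears in every doc
]))

# Nt: number of documents containing each term (document frequency).
Ns = np.array((cnts > 0).astype(int).sum(1)).squeeze()

# idf = log((N - Nt + 0.5) / (Nt + 0.5)), clipped at zero as in get_tfidf_matrix.
N = cnts.shape[1]
idfs = np.log((N - Ns + 0.5) / (Ns + 0.5))
idfs[idfs < 0] = 0

# tfidf = log(tf + 1) * idf, kept sparse by scaling rows with a diagonal matrix.
tfidfs = sp.diags(idfs, 0).dot(cnts.log1p())
print(tfidfs.toarray())  # term 2 gets zero weight: it occurs in every document
```

Clipping negative idf values keeps terms that occur in most documents from contributing negative scores, which is why `idfs[idfs < 0] = 0` appears in the script.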