├── .gitignore ├── Database.py ├── README.md ├── Tagme.py ├── WikiData.py ├── drqa_retriever ├── __init__.py ├── doc_db.py ├── tfidf_doc_ranker.py └── utils.py ├── drqa_tokenizers ├── __init__.py ├── corenlp_tokenizer.py ├── regexp_tokenizer.py ├── simple_tokenizer.py ├── spacy_tokenizer.py └── tokenizer.py ├── pytorch_transformers ├── __init__.py ├── __main__.py ├── convert_gpt2_checkpoint_to_pytorch.py ├── convert_openai_checkpoint_to_pytorch.py ├── convert_pytorch_checkpoint_to_tf.py ├── convert_roberta_checkpoint_to_pytorch.py ├── convert_tf_checkpoint_to_pytorch.py ├── convert_transfo_xl_checkpoint_to_pytorch.py ├── convert_xlm_checkpoint_to_pytorch.py ├── convert_xlnet_checkpoint_to_pytorch.py ├── file_utils.py ├── modeling_auto.py ├── modeling_bert.py ├── modeling_distilbert.py ├── modeling_gpt2.py ├── modeling_openai.py ├── modeling_roberta.py ├── modeling_transfo_xl.py ├── modeling_transfo_xl_utilities.py ├── modeling_utils.py ├── modeling_xlm.py ├── modeling_xlnet.py ├── optimization.py ├── tokenization_auto.py ├── tokenization_bert.py ├── tokenization_distilbert.py ├── tokenization_gpt2.py ├── tokenization_openai.py ├── tokenization_roberta.py ├── tokenization_transfo_xl.py ├── tokenization_utils.py ├── tokenization_xlm.py └── tokenization_xlnet.py ├── rank_bm25.py ├── requirements.txt ├── retrieve_hybrid.py └── retriever.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | */__pycache__ 3 | */*.pyc 4 | -------------------------------------------------------------------------------- /Database.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | 3 | class MyDatabase(object): 4 | 5 | def __init__(self, db_path, connect_each): 6 | self.db_path = db_path 7 | self.connect_each = connect_each 8 | if not connect_each: 9 | self.db = sqlite3.connect(db_path) 10 | self.cursor = self.db.cursor() 11 | self.tables = {} 12 | 13 | def create(self, create, table_name, keys): 14 | assert all([key[1] in ['INTEGER', 'TEXT'] for key in keys]) 15 | if create: 16 | assert table_name not in self.tables 17 | self.cursor.execute('''CREATE TABLE {}({})'''.format(table_name, 18 | ", ".join(["{} {} {}".format(key[0], key[1], 'KEY' if i==0 else '') 19 | for i, key in enumerate(keys)]))) 20 | 21 | self.db.commit() 22 | query = "CREATE INDEX index_{} ON {}({})".format(keys[0][0], table_name, keys[0][0]) 23 | self.cursor.execute(query) 24 | self.db.commit() 25 | self.tables[table_name] = [key[0] for key in keys] 26 | 27 | def insert(self, table_name, rows): 28 | assert table_name in self.tables 29 | assert all([len(self.tables[table_name])==len(row) for row in rows]) 30 | query = '''INSERT INTO {}({}) VALUES({})'''.format(table_name, 31 | ", ".join(self.tables[table_name]), 32 | ",".join(["?" 
for _ in range(len(self.tables[table_name]))])) 33 | self.cursor.executemany(query, rows) 34 | 35 | def commit(self): 36 | self.db.commit() 37 | return self.rowcount_all() 38 | 39 | def rowcount_all(self): 40 | return ["{} {}".format(table_name, self.rowcount(table_name)) for table_name in self.tables.keys()] 41 | 42 | def rowcount(self, table_name): 43 | self.cursor.execute("SELECT COUNT(*) FROM {}".format(table_name)) 44 | return self.cursor.fetchone()[0] 45 | 46 | def fetch(self, table_name, key, value): 47 | assert table_name in self.tables 48 | assert key in self.tables[table_name] 49 | if self.connect_each: 50 | db = sqlite3.connect(self.db_path) 51 | cursor = db.cursor() 52 | cursor.execute('''SELECT * FROM {} where {}=?'''.format(table_name, key), (value,)) 53 | rows = cursor.fetchall() 54 | else: 55 | self.cursor.execute('''SELECT * FROM {} where {}=?'''.format(table_name, key), (value,)) 56 | rows = self.cursor.fetchall() 57 | return rows 58 | 59 | def close(self): 60 | self.db.close() 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraphRetriever 2 | 3 | This repository contains the code for the GraphRetriever proposed in the paper [Knowledge Guided Text Retrieval and Reading for Open Domain Question Answering](https://arxiv.org/abs/1911.03868). 4 | 5 | ``` 6 | @article{ min2019knowledge , 7 | title={ Knowledge Guided Text Retrieval and Reading for Open Domain Question Answering }, 8 | author={ Min, Sewon and Chen, Danqi and Zettlemoyer, Luke and Hajishirzi, Hannaneh }, 9 | journal={ arXiv preprint arXiv:1911.03868 }, 10 | year={ 2019 } 11 | } 12 | ``` 13 | 14 | This README only describes the minimal set of commands needed to run the GraphRetriever; it does not cover the model or the inner workings of the code in detail. For more details, we recommend reading the paper or the code. 15 | 16 | ## 0. Download QA data 17 | 18 | Download the data for your QA task into `data/`. 19 | 20 | ``` 21 | mkdir data 22 | wget https://nlp.cs.washington.edu/ambigqa/data/{webquestions|nq|triviaqa}.zip 23 | unzip {webquestions|nq|triviaqa}.zip -d data/ 24 | rm {webquestions|nq|triviaqa}.zip 25 | ``` 26 | 27 | ## 1. Preprocessing 28 | 29 | For Wikipedia, you need the document DB and TF-IDF index built by following [DrQA](https://github.com/facebookresearch/DrQA). 30 | 31 | For Wikidata, please run the following command. 32 | ``` 33 | python3 WikiData.py --dump_path {path_to_wikidata_dump} --data_dir {dir_for_saving_preprocessed_wikidata} --n_processes {n_processes_for_preprocessing} 34 | ``` 35 | 36 | This preprocessing code looks complicated, but what it really does is store all entities (in the table called `Entities`), all relations (in the table called `Properties`) and all (entity, relation, entity) triples (in the table called `Claims`). We also store `Text2Entity` and `Text2Entity_norm` tables that map text forms back to entities. 37 | 38 | Preprocessing may take more than a few days depending on how many processes you can run. We recommend modifying the code if you want it to print progress during preprocessing. 39 | 40 | 41 | ## 2. Extracting entities from the question 42 | 43 | Now, run `Tagme.py` to extract entities from the question. It is a thin wrapper around the `tagme` package; a sketch of calling it programmatically is shown below, followed by the command-line usage.
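Below is a minimal sketch of what `Tagme.py` does under the hood, in case you want to call the wrapper from your own code. The question string and the token value are placeholders; the `TAGME` class itself is defined in `Tagme.py`.

```
# Minimal sketch: calling the TAGME wrapper from Tagme.py directly.
# "YOUR_GCUBE_TOKEN" and the question are placeholders.
from Tagme import TAGME

mytagme = TAGME(gcube_token="YOUR_GCUBE_TOKEN")

# extract_entities() accepts a single question (or a list of questions) and
# returns a score-sorted list of {'entity', 'entity_id', 'mention', 'score'}
# dicts per question.
for ann in mytagme.extract_entities("who wrote the harry potter books"):
    print(ann["entity"], ann["entity_id"], ann["score"])
```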
44 | 45 | ``` 46 | python3 Tagme.py --data {webquestions|nq|triviaqa} --data_type {train|dev|test} --gcube_token {your_gcube_token} 47 | ``` 48 | 49 | You need a GCUBE token in order to access TAGME. Please refer to [the tagme package page](https://pypi.org/project/tagme/) for details. 50 | 51 | Running `Tagme.py` will save the entity data in `data/{webquestions|nq|triviaqa}/{webquestions|nq|triviaqa}-{train|dev|test}.tagme.json`. 52 | 53 | This step can easily be replaced by any better entity extraction model. For instance, you can use [ELQ](https://github.com/facebookresearch/BLINK/tree/master/elq), which has shown much better performance than TAGME on entity linking for questions (see the [paper](https://arxiv.org/abs/2010.02413) for comparisons). 54 | 55 | ## 3. Running GraphRetriever 56 | 57 | Now, to actually run the GraphRetriever and build a paragraph graph from the question entities, Wikipedia, and Wikidata, run the following command. 58 | 59 | ``` 60 | python3 retrieve_hybrid.py --data {webquestions|nq|triviaqa} --data_type {train|dev|test} \ 61 | --wiki_db_path {path_to_wiki_db_from_drqa} \ 62 | --tfidf_path {path_to_tfidf_from_drqa} \ 63 | --data_dir {path_to_wikidata_dir_you_preprocessed} 64 | ``` 65 | 66 | It will save the retrieved paragraphs (along with the paragraph graph) in `data/{webquestions|nq|triviaqa}/{webquestions|nq|triviaqa}-{train|dev|test}.retrieved.json`. 67 | 68 | The format of the saved data is as follows (a minimal loading sketch is given at the end of this README). Each line contains 69 | - `question`: a string 70 | - `answers`: a list of strings containing all acceptable answers 71 | - `paragraphs`: a list of paragraphs, where each paragraph is represented by (title, index, tokenized context). "title" is the title of the Wikipedia page the paragraph comes from, "index" is the position of the paragraph within that page, and "tokenized context" is a list of tokens produced by the BERT tokenizer. 72 | - `graph`: a dictionary representing relationships between paragraphs. Each key is a string encoding a paragraph pair, with the two indices separated by a space (e.g. `0 2` means the pair of the 0th and the 2nd paragraphs in `paragraphs`), and each value is a relation. The relation is either a dictionary containing `texts` and `description` if it is a cross-document relation (a Wikidata relation), or ``/`` if it is an inner-document relation. 73 | 74 | Notes on tokenization: 75 | - We tokenize during retrieval because each paragraph is capped at 300 tokens. 76 | - Tokenization makes retrieval slower, so we implemented caching in `retriever.py`. 77 | - Our BERT tokenizer comes from an older version of Huggingface transformers, vendored in the `pytorch_transformers` directory. If you prefer, you can remove this directory and import the tokenizers from an up-to-date version of Huggingface transformers. 78 | - If you prefer not to tokenize, or want to use a simpler tokenization like `.split()`, please modify `retriever.py`.
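To make the format above concrete, here is a minimal sketch for loading a retrieved file and inspecting the paragraph graph. The path is only an example, and the sketch assumes the file stores one JSON object per line, as the per-line description above suggests; if your file is instead a single JSON list, replace the per-line `json.loads` with a single `json.load`.

```
import json

# Example path; substitute your dataset and split.
path = "data/nq/nq-dev.retrieved.json"

with open(path) as f:
    for line in f:
        d = json.loads(line)
        print(d["question"], d["answers"])
        # Each paragraph is (title, index, tokenized context).
        for title, index, tokens in d["paragraphs"][:3]:
            print(title, index, len(tokens))
        # Graph keys are "i j" paragraph-index pairs; values are relations.
        for pair, relation in d["graph"].items():
            i, j = map(int, pair.split())
            print(i, j, relation)
        break  # only inspect the first example
```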
79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /Tagme.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tagme 3 | import argparse 4 | 5 | from tqdm import tqdm 6 | 7 | class TAGME(object): 8 | 9 | def __init__(self, gcube_token): 10 | tagme.GCUBE_TOKEN = gcube_token 11 | 12 | def extract_entities(self, question, threshold=0.1): 13 | if type(question)==list: 14 | return [self.extract_entities(_question) for _question in question] 15 | annotations = tagme.annotate(question) 16 | return [{'entity': ann.entity_title, 17 | 'entity_id': ann.entity_id, 18 | 'mention': ann.mention, 19 | 'score': ann.score} 20 | for ann in sorted(annotations.get_annotations(threshold), key=lambda x: -x.score)] 21 | 22 | def extract_mentions(self, question): 23 | if type(question)==list: 24 | return [self.extract_mentions(_question) for _question in question] 25 | 26 | mentions = tagme.mentions(question) 27 | return [{'mention': mention.mention, 'score': mention.linkprob} \ 28 | for mention in sorted(mentions.mentions, key=lambda x: -x.linkprob)] 29 | 30 | def get_semantic_relations(self, entity_pairs, is_id=False): 31 | if is_id: 32 | rels = tagme.relatedness_wid(entity_pairs) 33 | else: 34 | rels = tagme.relatedness_title(entity_pairs) 35 | return [{'entity1': rel.title1, 'entity2': rel.title2, 'score': rel.rel} \ 36 | for rel in rels.relatedness] 37 | 38 | if __name__ == '__main__': 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('--data', type=str, default="webquestions") 41 | parser.add_argument('--data_type', type=str, default="dev") 42 | parser.add_argument('--gcube_token', type=str, default=None) 43 | args = parser.parse_args() 44 | 45 | mytagme = TAGME(args.gcube_token) 46 | with open('data/{}/{}-{}.qa.json'.format(args.data, args.data, args.data_type), 'r') as f: 47 | orig_data = json.load(f) 48 | data = [] 49 | for d in tqdm(orig_data): 50 | data.append(mytagme.extract_entities(d['question'])) 51 | with open('data/{}/{}-{}.tagme.json'.format(args.data, args.data, args.data_type), 'w') as f: 52 | json.dump(data, f) 53 | -------------------------------------------------------------------------------- /WikiData.py: -------------------------------------------------------------------------------- 1 | import os 2 | import bz2 3 | import json 4 | import argparse 5 | import numpy as np 6 | import time 7 | import re 8 | import string 9 | 10 | from tqdm import tqdm 11 | from collections import defaultdict, Counter 12 | from IPython import embed 13 | 14 | from qwikidata.entity import WikidataItem, WikidataProperty 15 | from qwikidata.json_dump import WikidataJsonDump 16 | from qwikidata.utils import dump_entities_to_json 17 | 18 | from Database import MyDatabase 19 | from joblib import Parallel, delayed 20 | from multiprocessing import Pool 21 | #from joblib import wrap_non_picklable_objects 22 | 23 | SEP = "@@MYSEP@@" 24 | PERIOD = 1000000 25 | 26 | class MyWikiData(object): 27 | 28 | def __init__(self, args, load=False, connect_each=False): 29 | self.data_dir = args.data_dir 30 | self.db = MyDatabase(os.path.join(args.data_dir, 'wikidata.db'), 31 | connect_each=connect_each) 32 | 33 | self.db.create(load, 'Entities', [('id', 'TEXT'), 34 | ('description', 'TEXT'), ('enwiki_link', 'TEXT'), 35 | ('texts', 'TEXT'), ('texts_norm', 'TEXT')]) 36 | self.db.create(load, 'Properties', [('id', 'TEXT'), 37 | ('description', 'TEXT'), ('texts', 'TEXT')]) 38 | self.db.create(load, 'Claims', 
[('id', 'TEXT'), ('property_id', 'TEXT'), 39 | ('value', 'TEXT'), ('valuetype', 'TEXT')]) 40 | self.db.create(load, 'Text2Entity', [('text', 'TEXT'), ('entity_id', 'TEXT')]) 41 | self.db.create(load, 'Text2Entity_norm', [('text', 'TEXT'), ('entity_id', 'TEXT')]) 42 | 43 | self.none_property_ids = set() 44 | self.non_none_property_ids = set() 45 | 46 | if load: 47 | self.wjd = WikidataJsonDump(args.dump_path) 48 | 49 | def process(ii, entity_dict): 50 | if args.start>ii: 51 | return [] 52 | 53 | if entity_dict["type"] == "item": 54 | return self.handle_entity(WikidataItem(entity_dict)) 55 | elif entity_dict["type"] == "property": 56 | return self.handle_property(WikidataProperty(entity_dict)) 57 | else: 58 | return [] 59 | 60 | if load and args.n_processes==1: 61 | start_time=time.time() 62 | queries = defaultdict(list) 63 | for i, b in enumerate(self.wjd): 64 | for table_name, row in process(i, b): 65 | #self.db.insert(table_name, row) 66 | queries[table_name] += row 67 | if (i+1) % PERIOD == 0: 68 | for table_name, rows in queries.items(): 69 | self.db.insert(table_name, rows) 70 | queries = defaultdict(list) 71 | self.save_data(i+1) 72 | if PERIOD>10000: 73 | print ("%d mins"%((time.time()-start_time)/60)) 74 | else: 75 | print ("%d secs"%(time.time()-start_time)) 76 | start_time=time.time() 77 | self.save_data(i+1) 78 | elif load: 79 | process = CloudpickleWrapper(process) 80 | outputs = [] 81 | start_time = time.time() 82 | with Pool(args.n_processes) as pool: 83 | for i, b in enumerate(self.wjd): 84 | outputs.append(pool.apply_async(process, (i, b))) 85 | if (i+1) % PERIOD == 0: 86 | queries = defaultdict(list) 87 | for output in outputs: 88 | for table_name, row in output.get(): 89 | queries[table_name]+=row 90 | for table_name, rorws in queries.items(): 91 | self.db.insert(table_name, rows) 92 | self.save_data(i+1) 93 | if PERIOD>10000: 94 | print ("%d mins"%((time.time()-start_time)/60)) 95 | else: 96 | print ("%d secs"%(time.time()-start_time)) 97 | outputs = [] 98 | start_time=time.time() 99 | if len(outputs)>0: 100 | for output in outputs: 101 | for table_name, row in output.get(): 102 | self.db.insert(table_name, row) 103 | self.save_data(i+1) 104 | 105 | def save_data(self, step): 106 | print ("=====\tSTEP = {}\t=====".format(step)) 107 | print ("\t".join(self.db.commit())) 108 | 109 | def handle_entity(self, entity): 110 | entity_id = entity.entity_id 111 | enwiki_link = entity.get_sitelinks().get('enwiki', {}).get('title', None) 112 | all_claims = defaultdict(list) 113 | for property_id, claims in entity.get_truthy_claim_groups().items(): 114 | for claim in claims: 115 | assert property_id==claim.mainsnak.property_id 116 | if not self.filter_mainsnak(claim.mainsnak): 117 | all_claims[property_id].append(self.handle_mainsnak(claim.mainsnak)) 118 | #assert entity_id not in self.entities 119 | texts = [entity.get_label()] + entity.get_aliases() + [enwiki_link] 120 | texts = list(set([t for t in texts if t is not None])) 121 | texts_norm = list(set([normalize_answer(t) for t in texts])) 122 | 123 | return [('Entities', [(entity_id, entity.get_description(), enwiki_link, 124 | self.list2string(texts), self.list2string(texts_norm))]), 125 | ('Claims', [(entity_id, property_id, 126 | self.value2string(statement['value'], statement['valuetype']), 127 | statement['valuetype']) 128 | for property_id, statements in all_claims.items() 129 | for statement in statements])] + \ 130 | [('Text2Entity', [(text, entity_id) for text in texts])] + \ 131 | [('Text2Entity_norm', [(text, entity_id) 
for text in texts_norm])] 132 | 133 | def list2string(self, _list): 134 | return SEP.join(_list) 135 | 136 | def value2string(self, value, valuetype): 137 | if type(value)==str: 138 | return value 139 | if type(value)==dict and valuetype=='wikibase-entityid': 140 | return value['id'] 141 | elif type(value)==dict and valuetype=='monolingualtext': 142 | return value['text'] 143 | return json.dumps(value) 144 | 145 | def handle_property(self, property): 146 | #assert property.entity_id not in self.properties 147 | texts = [property.get_label()] + property.get_aliases() 148 | #self.db.insert('Properties', [(property.entity_id, property.get_description(), texts)]) 149 | return [('Properties', [(property.entity_id, property.get_description(), self.list2string(texts))])] 150 | 151 | def filter_mainsnak(self, snak): 152 | if snak.snaktype!='value': 153 | return True 154 | return snak.snak_datatype in ['external-id', 'url', 'commonsMedia', 155 | 'globe-coordinate', 'math', 'musical-notation', 156 | 'time', 'geo-shape'] 157 | 158 | def handle_mainsnak(self, snak): 159 | return {'datatype': snak.snak_datatype, 160 | 'valuetype': snak.value_datatype, 161 | 'value': snak.datavalue.value} 162 | 163 | def filter_entity_ids(self, entity_ids): 164 | filtered_entity_ids = [] 165 | for e_id in entity_ids: 166 | e = self.get_entity(e_id) 167 | if e is not None and e['enwiki_link'] is not None and \ 168 | (not e['enwiki_link'].startswith('Category:')) and \ 169 | (not e['enwiki_link'].endswith(' (disambiguation)')) and \ 170 | (not e['enwiki_link'].endswith(' (word)')): 171 | filtered_entity_ids.append(e_id) 172 | return filtered_entity_ids 173 | 174 | def get_entities_from_text(self, text): 175 | def _get_entities_from_text(norm): 176 | rows = self.db.fetch('Text2Entity_norm' if norm else 'Text2Entity', 'text', 177 | normalize_answer(text) if norm else text) 178 | if len(rows)==0: 179 | return set() 180 | entity_ids = [text for row in rows for text in row[1].split(SEP)] 181 | return self.filter_entity_ids(entity_ids) 182 | result = _get_entities_from_text(norm=False) 183 | if len(result)==0: 184 | result = _get_entities_from_text(norm=True) 185 | return result 186 | 187 | def get_entities_from_title(self, text): 188 | rows = self.db.fetch('Entities', 'enwiki_link', text) 189 | if len(rows)==0: 190 | return set() 191 | return set([r[0] for r in rows]) 192 | 193 | def get_texts_from_entity(self, entity_id): 194 | entity = self.get_entity(entity_id) 195 | return set(entity['texts'])|set(entity['texts_norm']) 196 | 197 | def get_entity(self, entity_id): 198 | if type(entity_id)==list or type(entity_id)==set: 199 | return [self.get_entity(_entity_id) for _entity_id in entity_id] 200 | rows = self.db.fetch('Entities', 'id', entity_id) 201 | assert len(rows)<=1 202 | if len(rows)==0: 203 | return None 204 | return {'description': rows[0][1], 'enwiki_link': rows[0][2], 205 | 'texts': rows[0][3].split(SEP), 'texts_norm': rows[0][4].split(SEP)} 206 | 207 | def get_property(self, property_id): 208 | if type(property_id)==list: 209 | return [self.get_property(i) for i in property_id] 210 | rows = self.db.fetch('Properties', 'id', property_id) 211 | if len(rows)==0: 212 | return None 213 | assert len(rows)==1 214 | return {'description': rows[0][1], 'texts': rows[0][2].split(SEP)} 215 | 216 | def get_neighbors(self, entity_id): 217 | rows = self.db.fetch('Claims', 'id', entity_id) 218 | return [{'property_id': row[1], 'value': row[2], 'valuetype': row[3]} for row in rows] 219 | 220 | def populate(self, seed_texts, k=2, 
use_aliases=False): 221 | all_entities = [] 222 | collected_titles = [(text, 0, 'seed') for text in seed_texts] 223 | new_entities = [] 224 | for text in seed_texts: 225 | if use_aliases: 226 | for e in self.get_entities_from_text(text): 227 | if e not in all_entities and e not in new_entities: 228 | new_entities.append(e) 229 | else: 230 | for e in self.get_entities_from_title(text): 231 | if e not in all_entities and e not in new_entities: 232 | new_entities.append(e) 233 | for hop in range(k): 234 | added_entities = [] 235 | for entity_id1 in new_entities: 236 | results = self.get_neighbors(entity_id1) 237 | for result in results: 238 | if result['valuetype']=='wikibase-entityid' and \ 239 | result['value'].startswith('Q'): 240 | entity_id2 = result['value'] 241 | if entity_id2 in all_entities or entity_id2 in new_entities \ 242 | or result['property_id'] in self.none_property_ids: 243 | continue 244 | if result['property_id'] not in self.non_none_property_ids: 245 | if self.get_property(result['property_id']) is None: 246 | self.none_property_ids.add(result['property_id']) 247 | continue 248 | self.non_none_property_ids.add(result['property_id']) 249 | added_entities.append(entity_id2) 250 | ent2 = self.get_entity(entity_id2) 251 | if ent2 is not None and ent2['enwiki_link'] is not None and \ 252 | not any([ent2['enwiki_link']==a[0] for a in collected_titles]): 253 | collected_titles.append((ent2['enwiki_link'], hop+1, result['property_id'])) 254 | if len(collected_titles)>=80: 255 | break 256 | all_entities += new_entities 257 | new_entities = added_entities 258 | if len(collected_titles)>=80: 259 | break 260 | return collected_titles 261 | 262 | def get_graph(self, doc_names): 263 | graph = {} 264 | for i, doc_name in enumerate(doc_names): 265 | for entity in self.get_entities_from_title(doc_name): 266 | for result in self.get_neighbors(entity): 267 | if result['valuetype']=='wikibase-entityid' and \ 268 | result['value'].startswith('Q'): 269 | e = self.get_entity(result['value']) 270 | if e is None: 271 | continue 272 | if self.get_property(result['property_id']) is None: 273 | continue 274 | title = e['enwiki_link'] 275 | if title in doc_names: 276 | graph[(doc_name, title)] = self.get_property(result['property_id']) 277 | return graph 278 | 279 | def get_neighbors(self, entity_id): 280 | rows = self.db.fetch('Claims', 'id', entity_id) 281 | return [{'property_id': row[1], 'value': row[2], 'valuetype': row[3]} for row in rows] 282 | 283 | def normalize_answer(s): 284 | 285 | def remove_articles(text): 286 | return re.sub(r'\b(a|an|the)\b', ' ', text) 287 | 288 | def white_space_fix(text): 289 | return ' '.join(text.split()) 290 | 291 | def remove_punc(text): 292 | exclude = set(string.punctuation) 293 | return ''.join(ch for ch in text if ch not in exclude) 294 | 295 | def lower(text): 296 | return text.lower() 297 | 298 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 299 | 300 | if __name__=='__main__': 301 | parser = argparse.ArgumentParser() 302 | parser.add_argument('--dump_path', type=str, default="/data/home/sewon/wikidata-20190708-all.json.bz2") 303 | parser.add_argument('--data_dir', type=str, default="/data/home/sewon/MyWikidata") 304 | parser.add_argument('--start', type=int, default=0) 305 | parser.add_argument('--n_processes', type=int, default=10) 306 | args = parser.parse_args() 307 | wikidata = MyWikiData(args, load=True) 308 | 309 | 310 | 311 | 312 | 313 | -------------------------------------------------------------------------------- 
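A note on using the preprocessed Wikidata DB: once `WikiData.py` has finished, the same `MyWikiData` class can be reused for lookups by passing `load=False` (skip the dump pass) and `connect_each=True` (open a fresh sqlite connection per query). The sketch below is an illustration only; the `data_dir` path and the Wikipedia titles are placeholders.

```
# Minimal sketch: querying the preprocessed Wikidata DB built by WikiData.py.
from types import SimpleNamespace
from WikiData import MyWikiData

# Placeholder path to the directory passed as --data_dir during preprocessing.
args = SimpleNamespace(data_dir="/path/to/MyWikidata")
wikidata = MyWikiData(args, load=False, connect_each=True)

# Map a Wikipedia title to Wikidata entity ids.
entity_ids = wikidata.get_entities_from_title("Barack Obama")
print(list(entity_ids)[:5])

# Expand up to 2 hops from seed titles to collect linked Wikipedia pages;
# returns (title, hop, property_id or 'seed') tuples.
titles = wikidata.populate(["Barack Obama"], k=2, use_aliases=False)
print(titles[:10])

# Wikidata relations between a fixed set of Wikipedia pages.
graph = wikidata.get_graph(["Barack Obama", "Michelle Obama"])
for (doc1, doc2), prop in graph.items():
    print(doc1, "->", doc2, prop["texts"], prop["description"])
```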
/drqa_retriever/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | 10 | DATA_DIR = "/home/sewon/analysis_multiple/DrQA/data" 11 | DEFAULTS = { 12 | 'db_path': os.path.join(DATA_DIR, 'wikipedia/docs.db'), 13 | 'tfidf_path': os.path.join( 14 | DATA_DIR, 15 | 'wikipedia/docs-tfidf-ngram=2-hash=16777216-tokenizer=simple.npz' 16 | ), 17 | } 18 | 19 | 20 | def set_default(key, value): 21 | global DEFAULTS 22 | DEFAULTS[key] = value 23 | 24 | 25 | def get_class(name): 26 | if name == 'tfidf': 27 | return TfidfDocRanker 28 | if name == 'sqlite': 29 | return DocDB 30 | raise RuntimeError('Invalid retriever class: %s' % name) 31 | 32 | 33 | from .doc_db import DocDB 34 | from .tfidf_doc_ranker import TfidfDocRanker 35 | -------------------------------------------------------------------------------- /drqa_retriever/doc_db.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Documents, in a sqlite database.""" 8 | 9 | import sqlite3 10 | from . import utils 11 | 12 | 13 | class DocDB(object): 14 | """Sqlite backed document storage. 15 | 16 | Implements get_doc_text(doc_id). 17 | """ 18 | 19 | def __init__(self, db_path=None): 20 | self.path = db_path 21 | self.connection = sqlite3.connect(self.path, check_same_thread=False) 22 | cursor = self.connection.cursor() 23 | cursor.execute("SELECT COUNT(*) FROM documents") 24 | print (cursor.fetchone()[0]) 25 | 26 | def __enter__(self): 27 | return self 28 | 29 | def __exit__(self, *args): 30 | self.close() 31 | 32 | def path(self): 33 | """Return the path to the file that backs this database.""" 34 | return self.path 35 | 36 | def close(self): 37 | """Close the connection to the database.""" 38 | self.connection.close() 39 | 40 | def get_doc_ids(self): 41 | """Fetch all ids of docs stored in the db.""" 42 | cursor = self.connection.cursor() 43 | cursor.execute("SELECT id FROM documents") 44 | results = [r[0] for r in cursor.fetchall()] 45 | cursor.close() 46 | return results 47 | 48 | def get_doc_text(self, doc_id): 49 | """Fetch the raw text of the doc for 'doc_id'.""" 50 | cursor = self.connection.cursor() 51 | cursor.execute( 52 | "SELECT text FROM documents WHERE id = ?", 53 | (utils.normalize(doc_id),) 54 | ) 55 | result = cursor.fetchone() 56 | cursor.close() 57 | return result if result is None else result[0] 58 | -------------------------------------------------------------------------------- /drqa_retriever/tfidf_doc_ranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | """Rank documents with TF-IDF scores""" 8 | 9 | import logging 10 | import numpy as np 11 | import scipy.sparse as sp 12 | 13 | from multiprocessing.pool import ThreadPool 14 | from functools import partial 15 | 16 | from drqa_retriever import utils 17 | from drqa_retriever import DEFAULTS 18 | import drqa_tokenizers as tokenizers 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class TfidfDocRanker(object): 24 | """Loads a pre-weighted inverted index of token/document terms. 25 | Scores new queries by taking sparse dot products. 26 | """ 27 | 28 | def __init__(self, tfidf_path=None, strict=True): 29 | """ 30 | Args: 31 | tfidf_path: path to saved model file 32 | strict: fail on empty queries or continue (and return empty result) 33 | """ 34 | # Load from disk 35 | tfidf_path = tfidf_path or DEFAULTS['tfidf_path'] 36 | logger.info('Loading %s' % tfidf_path) 37 | matrix, metadata = utils.load_sparse_csr(tfidf_path) 38 | self.doc_mat = matrix 39 | self.ngrams = metadata['ngram'] 40 | self.hash_size = metadata['hash_size'] 41 | self.tokenizer = tokenizers.get_class(metadata['tokenizer'])() 42 | self.doc_freqs = metadata['doc_freqs'].squeeze() 43 | self.doc_dict = metadata['doc_dict'] 44 | self.num_docs = len(self.doc_dict[0]) 45 | self.strict = strict 46 | 47 | def get_doc_index(self, doc_id): 48 | """Convert doc_id --> doc_index""" 49 | return self.doc_dict[0][doc_id] 50 | 51 | def get_doc_id(self, doc_index): 52 | """Convert doc_index --> doc_id""" 53 | return self.doc_dict[1][doc_index] 54 | 55 | def closest_docs(self, query, k=1): 56 | """Closest docs by dot product between query and documents 57 | in tfidf weighted word vector space. 58 | """ 59 | spvec = self.text2spvec(query) 60 | res = spvec * self.doc_mat 61 | 62 | if len(res.data) <= k: 63 | o_sort = np.argsort(-res.data) 64 | else: 65 | o = np.argpartition(-res.data, k)[0:k] 66 | o_sort = o[np.argsort(-res.data[o])] 67 | 68 | doc_scores = res.data[o_sort] 69 | doc_ids = [self.get_doc_id(i) for i in res.indices[o_sort]] 70 | return doc_ids, doc_scores 71 | 72 | def batch_closest_docs(self, queries, k=1, num_workers=None): 73 | """Process a batch of closest_docs requests multithreaded. 74 | Note: we can use plain threads here as scipy is outside of the GIL. 75 | """ 76 | with ThreadPool(num_workers) as threads: 77 | closest_docs = partial(self.closest_docs, k=k) 78 | results = threads.map(closest_docs, queries) 79 | return results 80 | 81 | def parse(self, query): 82 | """Parse the query into tokens (either ngrams or tokens).""" 83 | tokens = self.tokenizer.tokenize(query) 84 | return tokens.ngrams(n=self.ngrams, uncased=True, 85 | filter_fn=utils.filter_ngram) 86 | 87 | def text2spvec(self, query): 88 | """Create a sparse tfidf-weighted word vector from query. 
89 | 90 | tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5)) 91 | """ 92 | # Get hashed ngrams 93 | words = self.parse(utils.normalize(query)) 94 | wids = [utils.hash(w, self.hash_size) for w in words] 95 | 96 | if len(wids) == 0: 97 | if self.strict: 98 | raise RuntimeError('No valid word in: %s' % query) 99 | else: 100 | logger.warning('No valid word in: %s' % query) 101 | return sp.csr_matrix((1, self.hash_size)) 102 | 103 | # Count TF 104 | wids_unique, wids_counts = np.unique(wids, return_counts=True) 105 | tfs = np.log1p(wids_counts) 106 | 107 | # Count IDF 108 | Ns = self.doc_freqs[wids_unique] 109 | idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5)) 110 | idfs[idfs < 0] = 0 111 | 112 | # TF-IDF 113 | data = np.multiply(tfs, idfs) 114 | 115 | # One row, sparse csr matrix 116 | indptr = np.array([0, len(wids_unique)]) 117 | spvec = sp.csr_matrix( 118 | (data, wids_unique, indptr), shape=(1, self.hash_size) 119 | ) 120 | 121 | return spvec 122 | -------------------------------------------------------------------------------- /drqa_retriever/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Various retriever utilities.""" 8 | 9 | import regex 10 | import unicodedata 11 | import numpy as np 12 | 13 | import scipy.sparse as sp 14 | #from sklearn.utils import murmurhash3_32 15 | 16 | 17 | # ------------------------------------------------------------------------------ 18 | # Sparse matrix saving/loading helpers. 19 | # ------------------------------------------------------------------------------ 20 | 21 | 22 | def save_sparse_csr(filename, matrix, metadata=None): 23 | data = { 24 | 'data': matrix.data, 25 | 'indices': matrix.indices, 26 | 'indptr': matrix.indptr, 27 | 'shape': matrix.shape, 28 | 'metadata': metadata, 29 | } 30 | np.savez(filename, **data) 31 | 32 | 33 | def load_sparse_csr(filename): 34 | loader = np.load(filename, allow_pickle=True) 35 | matrix = sp.csr_matrix((loader['data'], loader['indices'], 36 | loader['indptr']), shape=loader['shape']) 37 | return matrix, loader['metadata'].item(0) if 'metadata' in loader else None 38 | 39 | 40 | # ------------------------------------------------------------------------------ 41 | # Token hashing. 42 | # ------------------------------------------------------------------------------ 43 | 44 | 45 | def hash(token, num_buckets): 46 | """Unsigned 32 bit murmurhash for feature hashing.""" 47 | return murmurhash3_32(token, positive=True) % num_buckets 48 | 49 | 50 | # ------------------------------------------------------------------------------ 51 | # Text cleaning. 
52 | # ------------------------------------------------------------------------------ 53 | 54 | 55 | STOPWORDS = { 56 | 'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 57 | 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 58 | 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 59 | 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that', 60 | 'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 61 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 62 | 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until', 'while', 'of', 63 | 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into', 'through', 64 | 'during', 'before', 'after', 'above', 'below', 'to', 'from', 'up', 'down', 65 | 'in', 'out', 'on', 'off', 'over', 'under', 'again', 'further', 'then', 66 | 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 67 | 'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor', 68 | 'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very', 's', 't', 'can', 69 | 'will', 'just', 'don', 'should', 'now', 'd', 'll', 'm', 'o', 're', 've', 70 | 'y', 'ain', 'aren', 'couldn', 'didn', 'doesn', 'hadn', 'hasn', 'haven', 71 | 'isn', 'ma', 'mightn', 'mustn', 'needn', 'shan', 'shouldn', 'wasn', 'weren', 72 | 'won', 'wouldn', "'ll", "'re", "'ve", "n't", "'s", "'d", "'m", "''", "``" 73 | } 74 | 75 | 76 | def normalize(text): 77 | """Resolve different type of unicode encodings.""" 78 | return unicodedata.normalize('NFD', text) 79 | 80 | 81 | def filter_word(text): 82 | """Take out english stopwords, punctuation, and compound endings.""" 83 | text = normalize(text) 84 | if regex.match(r'^\p{P}+$', text): 85 | return True 86 | if text.lower() in STOPWORDS: 87 | return True 88 | return False 89 | 90 | 91 | def filter_ngram(gram, mode='any'): 92 | """Decide whether to keep or discard an n-gram. 93 | 94 | Args: 95 | gram: list of tokens (length N) 96 | mode: Option to throw out ngram if 97 | 'any': any single token passes filter_word 98 | 'all': all tokens pass filter_word 99 | 'ends': book-ended by filterable tokens 100 | """ 101 | filtered = [filter_word(w) for w in gram] 102 | if mode == 'any': 103 | return any(filtered) 104 | elif mode == 'all': 105 | return all(filtered) 106 | elif mode == 'ends': 107 | return filtered[0] or filtered[-1] 108 | else: 109 | raise ValueError('Invalid mode: %s' % mode) 110 | -------------------------------------------------------------------------------- /drqa_tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import os 9 | 10 | DEFAULTS = { 11 | 'corenlp_classpath': os.getenv('CLASSPATH') 12 | } 13 | 14 | 15 | def set_default(key, value): 16 | global DEFAULTS 17 | DEFAULTS[key] = value 18 | 19 | 20 | from .corenlp_tokenizer import CoreNLPTokenizer 21 | from .regexp_tokenizer import RegexpTokenizer 22 | from .simple_tokenizer import SimpleTokenizer 23 | 24 | # Spacy is optional 25 | try: 26 | from .spacy_tokenizer import SpacyTokenizer 27 | except ImportError: 28 | pass 29 | 30 | 31 | def get_class(name): 32 | if name == 'spacy': 33 | return SpacyTokenizer 34 | if name == 'corenlp': 35 | return CoreNLPTokenizer 36 | if name == 'regexp': 37 | return RegexpTokenizer 38 | if name == 'simple': 39 | return SimpleTokenizer 40 | 41 | raise RuntimeError('Invalid tokenizer: %s' % name) 42 | 43 | 44 | def get_annotators_for_args(args): 45 | annotators = set() 46 | if args.use_pos: 47 | annotators.add('pos') 48 | if args.use_lemma: 49 | annotators.add('lemma') 50 | if args.use_ner: 51 | annotators.add('ner') 52 | return annotators 53 | 54 | 55 | def get_annotators_for_model(model): 56 | return get_annotators_for_args(model.args) 57 | -------------------------------------------------------------------------------- /drqa_tokenizers/corenlp_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Simple wrapper around the Stanford CoreNLP pipeline. 8 | 9 | Serves commands to a java subprocess running the jar. Requires java 8. 10 | """ 11 | 12 | import copy 13 | import json 14 | import pexpect 15 | 16 | from .tokenizer import Tokens, Tokenizer 17 | from . import DEFAULTS 18 | 19 | 20 | class CoreNLPTokenizer(Tokenizer): 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: set that can include pos, lemma, and ner. 26 | classpath: Path to the corenlp directory of jars 27 | mem: Java heap memory 28 | """ 29 | self.classpath = (kwargs.get('classpath') or 30 | DEFAULTS['corenlp_classpath']) 31 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 32 | self.mem = kwargs.get('mem', '2g') 33 | self._launch() 34 | 35 | def _launch(self): 36 | """Start the CoreNLP jar with pexpect.""" 37 | annotators = ['tokenize', 'ssplit'] 38 | if 'ner' in self.annotators: 39 | annotators.extend(['pos', 'lemma', 'ner']) 40 | elif 'lemma' in self.annotators: 41 | annotators.extend(['pos', 'lemma']) 42 | elif 'pos' in self.annotators: 43 | annotators.extend(['pos']) 44 | annotators = ','.join(annotators) 45 | options = ','.join(['untokenizable=noneDelete', 46 | 'invertible=true']) 47 | cmd = ['java', '-mx' + self.mem, '-cp', '"%s"' % self.classpath, 48 | 'edu.stanford.nlp.pipeline.StanfordCoreNLP', '-annotators', 49 | annotators, '-tokenize.options', options, 50 | '-outputFormat', 'json', '-prettyPrint', 'false'] 51 | 52 | # We use pexpect to keep the subprocess alive and feed it commands. 53 | # Because we don't want to get hit by the max terminal buffer size, 54 | # we turn off canonical input processing to have unlimited bytes. 
55 | self.corenlp = pexpect.spawn('/bin/bash', maxread=100000, timeout=60) 56 | self.corenlp.setecho(False) 57 | self.corenlp.sendline('stty -icanon') 58 | self.corenlp.sendline(' '.join(cmd)) 59 | self.corenlp.delaybeforesend = 0 60 | self.corenlp.delayafterread = 0 61 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 62 | 63 | @staticmethod 64 | def _convert(token): 65 | if token == '-LRB-': 66 | return '(' 67 | if token == '-RRB-': 68 | return ')' 69 | if token == '-LSB-': 70 | return '[' 71 | if token == '-RSB-': 72 | return ']' 73 | if token == '-LCB-': 74 | return '{' 75 | if token == '-RCB-': 76 | return '}' 77 | return token 78 | 79 | def tokenize(self, text): 80 | # Since we're feeding text to the commandline, we're waiting on seeing 81 | # the NLP> prompt. Hacky! 82 | if 'NLP>' in text: 83 | raise RuntimeError('Bad token (NLP>) in text!') 84 | 85 | # Sending q will cause the process to quit -- manually override 86 | if text.lower().strip() == 'q': 87 | token = text.strip() 88 | index = text.index(token) 89 | data = [(token, text[index:], (index, index + 1), 'NN', 'q', 'O')] 90 | return Tokens(data, self.annotators) 91 | 92 | # Minor cleanup before tokenizing. 93 | clean_text = text.replace('\n', ' ') 94 | 95 | self.corenlp.sendline(clean_text.encode('utf-8')) 96 | self.corenlp.expect_exact('NLP>', searchwindowsize=100) 97 | 98 | # Skip to start of output (may have been stderr logging messages) 99 | output = self.corenlp.before 100 | start = output.find(b'{"sentences":') 101 | output = json.loads(output[start:].decode('utf-8')) 102 | 103 | data = [] 104 | tokens = [t for s in output['sentences'] for t in s['tokens']] 105 | for i in range(len(tokens)): 106 | # Get whitespace 107 | start_ws = tokens[i]['characterOffsetBegin'] 108 | if i + 1 < len(tokens): 109 | end_ws = tokens[i + 1]['characterOffsetBegin'] 110 | else: 111 | end_ws = tokens[i]['characterOffsetEnd'] 112 | 113 | data.append(( 114 | self._convert(tokens[i]['word']), 115 | text[start_ws: end_ws], 116 | (tokens[i]['characterOffsetBegin'], 117 | tokens[i]['characterOffsetEnd']), 118 | tokens[i].get('pos', None), 119 | tokens[i].get('lemma', None), 120 | tokens[i].get('ner', None) 121 | )) 122 | return Tokens(data, self.annotators) 123 | -------------------------------------------------------------------------------- /drqa_tokenizers/regexp_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Regex based tokenizer that emulates the Stanford/NLTK PTB tokenizers. 8 | 9 | However it is purely in Python, supports robust untokenization, unicode, 10 | and requires minimal dependencies. 
11 | """ 12 | 13 | import regex 14 | import logging 15 | from .tokenizer import Tokens, Tokenizer 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class RegexpTokenizer(Tokenizer): 21 | DIGIT = r'\p{Nd}+([:\.\,]\p{Nd}+)*' 22 | TITLE = (r'(dr|esq|hon|jr|mr|mrs|ms|prof|rev|sr|st|rt|messrs|mmes|msgr)' 23 | r'\.(?=\p{Z})') 24 | ABBRV = r'([\p{L}]\.){2,}(?=\p{Z}|$)' 25 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]++' 26 | HYPHEN = r'{A}([-\u058A\u2010\u2011]{A})+'.format(A=ALPHA_NUM) 27 | NEGATION = r"((?!n't)[\p{L}\p{N}\p{M}])++(?=n't)|n't" 28 | CONTRACTION1 = r"can(?=not\b)" 29 | CONTRACTION2 = r"'([tsdm]|re|ll|ve)\b" 30 | START_DQUOTE = r'(?<=[\p{Z}\(\[{<]|^)(``|["\u0093\u201C\u00AB])(?!\p{Z})' 31 | START_SQUOTE = r'(?<=[\p{Z}\(\[{<]|^)[\'\u0091\u2018\u201B\u2039](?!\p{Z})' 32 | END_DQUOTE = r'(?%s)|(?P%s)|(?P<abbr>%s)|(?P<neg>%s)|(?P<hyph>%s)|' 47 | '(?P<contr1>%s)|(?P<alphanum>%s)|(?P<contr2>%s)|(?P<sdquote>%s)|' 48 | '(?P<edquote>%s)|(?P<ssquote>%s)|(?P<esquote>%s)|(?P<dash>%s)|' 49 | '(?<ellipses>%s)|(?P<punct>%s)|(?P<nonws>%s)' % 50 | (self.DIGIT, self.TITLE, self.ABBRV, self.NEGATION, self.HYPHEN, 51 | self.CONTRACTION1, self.ALPHA_NUM, self.CONTRACTION2, 52 | self.START_DQUOTE, self.END_DQUOTE, self.START_SQUOTE, 53 | self.END_SQUOTE, self.DASH, self.ELLIPSES, self.PUNCT, 54 | self.NON_WS), 55 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 56 | ) 57 | if len(kwargs.get('annotators', {})) > 0: 58 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 59 | (type(self).__name__, kwargs.get('annotators'))) 60 | self.annotators = set() 61 | self.substitutions = kwargs.get('substitutions', True) 62 | 63 | def tokenize(self, text): 64 | data = [] 65 | matches = [m for m in self._regexp.finditer(text)] 66 | for i in range(len(matches)): 67 | # Get text 68 | token = matches[i].group() 69 | 70 | # Make normalizations for special token types 71 | if self.substitutions: 72 | groups = matches[i].groupdict() 73 | if groups['sdquote']: 74 | token = "``" 75 | elif groups['edquote']: 76 | token = "''" 77 | elif groups['ssquote']: 78 | token = "`" 79 | elif groups['esquote']: 80 | token = "'" 81 | elif groups['dash']: 82 | token = '--' 83 | elif groups['ellipses']: 84 | token = '...' 85 | 86 | # Get whitespace 87 | span = matches[i].span() 88 | start_ws = span[0] 89 | if i + 1 < len(matches): 90 | end_ws = matches[i + 1].span()[0] 91 | else: 92 | end_ws = span[1] 93 | 94 | # Format data 95 | data.append(( 96 | token, 97 | text[start_ws: end_ws], 98 | span, 99 | )) 100 | return Tokens(data, self.annotators) 101 | -------------------------------------------------------------------------------- /drqa_tokenizers/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Basic tokenizer that splits text into alpha-numeric tokens and 8 | non-whitespace tokens. 9 | """ 10 | 11 | import regex 12 | import logging 13 | from .tokenizer import Tokens, Tokenizer 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class SimpleTokenizer(Tokenizer): 19 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 20 | NON_WS = r'[^\p{Z}\p{C}]' 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: None or empty set (only tokenizes). 
26 | """ 27 | self._regexp = regex.compile( 28 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 29 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 30 | ) 31 | if len(kwargs.get('annotators', {})) > 0: 32 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 33 | (type(self).__name__, kwargs.get('annotators'))) 34 | self.annotators = set() 35 | 36 | def tokenize(self, text): 37 | data = [] 38 | matches = [m for m in self._regexp.finditer(text)] 39 | for i in range(len(matches)): 40 | # Get text 41 | token = matches[i].group() 42 | 43 | # Get whitespace 44 | span = matches[i].span() 45 | start_ws = span[0] 46 | if i + 1 < len(matches): 47 | end_ws = matches[i + 1].span()[0] 48 | else: 49 | end_ws = span[1] 50 | 51 | # Format data 52 | data.append(( 53 | token, 54 | text[start_ws: end_ws], 55 | span, 56 | )) 57 | return Tokens(data, self.annotators) 58 | -------------------------------------------------------------------------------- /drqa_tokenizers/spacy_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Tokenizer that is backed by spaCy (spacy.io). 8 | 9 | Requires spaCy package and the spaCy english model. 10 | """ 11 | 12 | import spacy 13 | import copy 14 | from .tokenizer import Tokens, Tokenizer 15 | 16 | 17 | class SpacyTokenizer(Tokenizer): 18 | 19 | def __init__(self, **kwargs): 20 | """ 21 | Args: 22 | annotators: set that can include pos, lemma, and ner. 23 | model: spaCy model to use (either path, or keyword like 'en'). 24 | """ 25 | model = kwargs.get('model', 'en') 26 | self.annotators = copy.deepcopy(kwargs.get('annotators', set())) 27 | nlp_kwargs = {'parser': False} 28 | if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 29 | nlp_kwargs['tagger'] = False 30 | if 'ner' not in self.annotators: 31 | nlp_kwargs['entity'] = False 32 | self.nlp = spacy.load(model, **nlp_kwargs) 33 | 34 | def tokenize(self, text): 35 | # We don't treat new lines as tokens. 36 | clean_text = text.replace('\n', ' ') 37 | tokens = self.nlp.tokenizer(clean_text) 38 | if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): 39 | self.nlp.tagger(tokens) 40 | if 'ner' in self.annotators: 41 | self.nlp.entity(tokens) 42 | 43 | data = [] 44 | for i in range(len(tokens)): 45 | # Get whitespace 46 | start_ws = tokens[i].idx 47 | if i + 1 < len(tokens): 48 | end_ws = tokens[i + 1].idx 49 | else: 50 | end_ws = tokens[i].idx + len(tokens[i].text) 51 | 52 | data.append(( 53 | tokens[i].text, 54 | text[start_ws: end_ws], 55 | (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), 56 | tokens[i].tag_, 57 | tokens[i].lemma_, 58 | tokens[i].ent_type_, 59 | )) 60 | 61 | # Set special option for non-entity tag: '' vs 'O' in spaCy 62 | return Tokens(data, self.annotators, opts={'non_ent': ''}) 63 | -------------------------------------------------------------------------------- /drqa_tokenizers/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | """Base tokenizer/tokens classes and utilities.""" 8 | 9 | import copy 10 | 11 | 12 | class Tokens(object): 13 | """A class to represent a list of tokenized text.""" 14 | TEXT = 0 15 | TEXT_WS = 1 16 | SPAN = 2 17 | POS = 3 18 | LEMMA = 4 19 | NER = 5 20 | 21 | def __init__(self, data, annotators, opts=None): 22 | self.data = data 23 | self.annotators = annotators 24 | self.opts = opts or {} 25 | 26 | def __len__(self): 27 | """The number of tokens.""" 28 | return len(self.data) 29 | 30 | def slice(self, i=None, j=None): 31 | """Return a view of the list of tokens from [i, j).""" 32 | new_tokens = copy.copy(self) 33 | new_tokens.data = self.data[i: j] 34 | return new_tokens 35 | 36 | def untokenize(self): 37 | """Returns the original text (with whitespace reinserted).""" 38 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 39 | 40 | def words(self, uncased=False): 41 | """Returns a list of the text of each token 42 | 43 | Args: 44 | uncased: lower cases text 45 | """ 46 | if uncased: 47 | return [t[self.TEXT].lower() for t in self.data] 48 | else: 49 | return [t[self.TEXT] for t in self.data] 50 | 51 | def offsets(self): 52 | """Returns a list of [start, end) character offsets of each token.""" 53 | return [t[self.SPAN] for t in self.data] 54 | 55 | def pos(self): 56 | """Returns a list of part-of-speech tags of each token. 57 | Returns None if this annotation was not included. 58 | """ 59 | if 'pos' not in self.annotators: 60 | return None 61 | return [t[self.POS] for t in self.data] 62 | 63 | def lemmas(self): 64 | """Returns a list of the lemmatized text of each token. 65 | Returns None if this annotation was not included. 66 | """ 67 | if 'lemma' not in self.annotators: 68 | return None 69 | return [t[self.LEMMA] for t in self.data] 70 | 71 | def entities(self): 72 | """Returns a list of named-entity-recognition tags of each token. 73 | Returns None if this annotation was not included. 74 | """ 75 | if 'ner' not in self.annotators: 76 | return None 77 | return [t[self.NER] for t in self.data] 78 | 79 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 80 | """Returns a list of all ngrams from length 1 to n. 
81 | 82 | Args: 83 | n: upper limit of ngram length 84 | uncased: lower cases text 85 | filter_fn: user function that takes in an ngram list and returns 86 | True or False to keep or not keep the ngram 87 | as_string: return the ngram as a string vs list 88 | """ 89 | def _skip(gram): 90 | if not filter_fn: 91 | return False 92 | return filter_fn(gram) 93 | 94 | words = self.words(uncased) 95 | ngrams = [(s, e + 1) 96 | for s in range(len(words)) 97 | for e in range(s, min(s + n, len(words))) 98 | if not _skip(words[s:e + 1])] 99 | 100 | # Concatenate into strings 101 | if as_strings: 102 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 103 | 104 | return ngrams 105 | 106 | def entity_groups(self): 107 | """Group consecutive entity tokens with the same NER tag.""" 108 | entities = self.entities() 109 | if not entities: 110 | return None 111 | non_ent = self.opts.get('non_ent', 'O') 112 | groups = [] 113 | idx = 0 114 | while idx < len(entities): 115 | ner_tag = entities[idx] 116 | # Check for entity tag 117 | if ner_tag != non_ent: 118 | # Chomp the sequence 119 | start = idx 120 | while (idx < len(entities) and entities[idx] == ner_tag): 121 | idx += 1 122 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 123 | else: 124 | idx += 1 125 | return groups 126 | 127 | 128 | class Tokenizer(object): 129 | """Base tokenizer class. 130 | Tokenizers implement tokenize, which should return a Tokens class. 131 | """ 132 | def tokenize(self, text): 133 | raise NotImplementedError 134 | 135 | def shutdown(self): 136 | pass 137 | 138 | def __del__(self): 139 | self.shutdown() 140 | -------------------------------------------------------------------------------- /pytorch_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.0" 2 | from .tokenization_auto import AutoTokenizer 3 | from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer 4 | from .tokenization_openai import OpenAIGPTTokenizer 5 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 6 | from .tokenization_gpt2 import GPT2Tokenizer 7 | from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE 8 | from .tokenization_xlm import XLMTokenizer 9 | from .tokenization_roberta import RobertaTokenizer 10 | from .tokenization_distilbert import DistilBertTokenizer 11 | 12 | from .tokenization_utils import (PreTrainedTokenizer) 13 | 14 | from .modeling_auto import (AutoConfig, AutoModel) 15 | 16 | from .modeling_bert import (BertConfig, BertPreTrainedModel, BertModel, BertForPreTraining, 17 | BertForMaskedLM, BertForNextSentencePrediction, 18 | BertForSequenceClassification, BertForMultipleChoice, 19 | BertForTokenClassification, BertForQuestionAnswering, 20 | load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, 21 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP) 22 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTPreTrainedModel, OpenAIGPTModel, 23 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 24 | load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, 25 | OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) 26 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel, 27 | load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, 28 | TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) 29 | from .modeling_gpt2 import (GPT2Config, GPT2PreTrainedModel, GPT2Model, 30 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 31 | 
load_tf_weights_in_gpt2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, 32 | GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) 33 | from .modeling_xlnet import (XLNetConfig, 34 | XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel, 35 | XLNetForSequenceClassification, XLNetForQuestionAnswering, 36 | load_tf_weights_in_xlnet, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, 37 | XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) 38 | from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel, 39 | XLMWithLMHeadModel, XLMForSequenceClassification, 40 | XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, 41 | XLM_PRETRAINED_MODEL_ARCHIVE_MAP) 42 | from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification, 43 | ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) 44 | from .modeling_distilbert import (DistilBertConfig, DistilBertForMaskedLM, DistilBertModel, 45 | DistilBertForSequenceClassification, DistilBertForQuestionAnswering, 46 | DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) 47 | from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, 48 | PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) 49 | 50 | from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule, 51 | WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) 52 | 53 | from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path) 54 | -------------------------------------------------------------------------------- /pytorch_transformers/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet", "xlm"]: 5 | print( 6 | "Should be used as one of: \n" 7 | ">> pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT, \n" 8 | ">> pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG], \n" 9 | ">> pytorch_transformers transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG] or \n" 10 | ">> pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG] or \n" 11 | ">> pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME] or \n" 12 | ">> pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT") 13 | else: 14 | if sys.argv[1] == "bert": 15 | try: 16 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 17 | except ImportError: 18 | print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 19 | "In that case, it requires TensorFlow to be installed. 
Please see " 20 | "https://www.tensorflow.org/install/ for installation instructions.") 21 | raise 22 | 23 | if len(sys.argv) != 5: 24 | # pylint: disable=line-too-long 25 | print("Should be used as `pytorch_transformers bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 26 | else: 27 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 28 | TF_CONFIG = sys.argv.pop() 29 | TF_CHECKPOINT = sys.argv.pop() 30 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 31 | elif sys.argv[1] == "gpt": 32 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 33 | if len(sys.argv) < 4 or len(sys.argv) > 5: 34 | # pylint: disable=line-too-long 35 | print("Should be used as `pytorch_transformers gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`") 36 | else: 37 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 38 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 39 | if len(sys.argv) == 5: 40 | OPENAI_GPT_CONFIG = sys.argv[4] 41 | else: 42 | OPENAI_GPT_CONFIG = "" 43 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 44 | OPENAI_GPT_CONFIG, 45 | PYTORCH_DUMP_OUTPUT) 46 | elif sys.argv[1] == "transfo_xl": 47 | try: 48 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 49 | except ImportError: 50 | print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 51 | "In that case, it requires TensorFlow to be installed. Please see " 52 | "https://www.tensorflow.org/install/ for installation instructions.") 53 | raise 54 | if len(sys.argv) < 4 or len(sys.argv) > 5: 55 | # pylint: disable=line-too-long 56 | print("Should be used as `pytorch_transformers transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") 57 | else: 58 | if 'ckpt' in sys.argv[2].lower(): 59 | TF_CHECKPOINT = sys.argv[2] 60 | TF_DATASET_FILE = "" 61 | else: 62 | TF_DATASET_FILE = sys.argv[2] 63 | TF_CHECKPOINT = "" 64 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 65 | if len(sys.argv) == 5: 66 | TF_CONFIG = sys.argv[4] 67 | else: 68 | TF_CONFIG = "" 69 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 70 | elif sys.argv[1] == "gpt2": 71 | try: 72 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 73 | except ImportError: 74 | print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 75 | "In that case, it requires TensorFlow to be installed. Please see " 76 | "https://www.tensorflow.org/install/ for installation instructions.") 77 | raise 78 | 79 | if len(sys.argv) < 4 or len(sys.argv) > 5: 80 | # pylint: disable=line-too-long 81 | print("Should be used as `pytorch_transformers gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`") 82 | else: 83 | TF_CHECKPOINT = sys.argv[2] 84 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 85 | if len(sys.argv) == 5: 86 | TF_CONFIG = sys.argv[4] 87 | else: 88 | TF_CONFIG = "" 89 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 90 | elif sys.argv[1] == "xlnet": 91 | try: 92 | from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch 93 | except ImportError: 94 | print("pytorch_transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " 95 | "In that case, it requires TensorFlow to be installed. 
Please see " 96 | "https://www.tensorflow.org/install/ for installation instructions.") 97 | raise 98 | 99 | if len(sys.argv) < 5 or len(sys.argv) > 6: 100 | # pylint: disable=line-too-long 101 | print("Should be used as `pytorch_transformers xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`") 102 | else: 103 | TF_CHECKPOINT = sys.argv[2] 104 | TF_CONFIG = sys.argv[3] 105 | PYTORCH_DUMP_OUTPUT = sys.argv[4] 106 | if len(sys.argv) == 6: 107 | FINETUNING_TASK = sys.argv[5] 108 | else: 109 | FINETUNING_TASK = None 110 | 111 | convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT, 112 | TF_CONFIG, 113 | PYTORCH_DUMP_OUTPUT, 114 | FINETUNING_TASK) 115 | elif sys.argv[1] == "xlm": 116 | from .convert_xlm_checkpoint_to_pytorch import convert_xlm_checkpoint_to_pytorch 117 | 118 | if len(sys.argv) != 4: 119 | # pylint: disable=line-too-long 120 | print("Should be used as `pytorch_transformers xlm XLM_CHECKPOINT_PATH PYTORCH_DUMP_OUTPUT`") 121 | else: 122 | XLM_CHECKPOINT_PATH = sys.argv[2] 123 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 124 | 125 | convert_xlm_checkpoint_to_pytorch(XLM_CHECKPOINT_PATH, PYTORCH_DUMP_OUTPUT) 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
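# Example invocation of this conversion script, run as a module (the paths below are
# placeholders; --gpt2_config_file is optional and, when left empty, a default
# GPT2Config() is constructed):
#
#   python -m pytorch_transformers.convert_gpt2_checkpoint_to_pytorch \
#       --gpt2_checkpoint_path /path/to/gpt2/model.ckpt \
#       --pytorch_dump_folder_path /path/to/output_dir \
#       --gpt2_config_file /path/to/gpt2_config.json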
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_transformers.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if gpt2_config_file == "": 36 | config = GPT2Config() 37 | else: 38 | config = GPT2Config.from_json_file(gpt2_config_file) 39 | model = GPT2Model(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--gpt2_checkpoint_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--gpt2_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 74 | args.gpt2_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_transformers.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | 33 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 34 | # Construct model 35 | if openai_config_file == "": 36 | config = OpenAIGPTConfig() 37 | else: 38 | config = OpenAIGPTConfig.from_json_file(openai_config_file) 39 | model = OpenAIGPTModel(config) 40 | 41 | # Load weights from numpy 42 | load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 48 | torch.save(model.state_dict(), pytorch_weights_dump_path) 49 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 50 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 51 | f.write(config.to_json_string()) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | ## Required parameters 57 | parser.add_argument("--openai_checkpoint_folder_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the TensorFlow checkpoint path.") 62 | parser.add_argument("--pytorch_dump_folder_path", 63 | default = None, 64 | type = str, 65 | required = True, 66 | help = "Path to the output PyTorch model.") 67 | parser.add_argument("--openai_config_file", 68 | default = "", 69 | type = str, 70 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 71 | "This specifies the model architecture.") 72 | args = parser.parse_args() 73 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 74 | args.openai_config_file, 75 | args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_pytorch_checkpoint_to_tf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Convert Huggingface Pytorch checkpoint to Tensorflow checkpoint.""" 17 | 18 | import os 19 | import argparse 20 | import torch 21 | import numpy as np 22 | import tensorflow as tf 23 | from pytorch_transformers.modeling import BertModel 24 | 25 | 26 | def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str): 27 | 28 | """ 29 | :param model:BertModel Pytorch model instance to be converted 30 | :param ckpt_dir: Tensorflow model directory 31 | :param model_name: model name 32 | :return: 33 | 34 | Currently supported HF models: 35 | Y BertModel 36 | N BertForMaskedLM 37 | N BertForPreTraining 38 | N BertForMultipleChoice 39 | N BertForNextSentencePrediction 40 | N BertForSequenceClassification 41 | N BertForQuestionAnswering 42 | """ 43 | 44 | tensors_to_transpose = ( 45 | "dense.weight", 46 | "attention.self.query", 47 | "attention.self.key", 48 | "attention.self.value" 49 | ) 50 | 51 | var_map = ( 52 | ('layer.', 'layer_'), 53 | ('word_embeddings.weight', 'word_embeddings'), 54 | ('position_embeddings.weight', 'position_embeddings'), 55 | ('token_type_embeddings.weight', 'token_type_embeddings'), 56 | ('.', '/'), 57 | ('LayerNorm/weight', 'LayerNorm/gamma'), 58 | ('LayerNorm/bias', 'LayerNorm/beta'), 59 | ('weight', 'kernel') 60 | ) 61 | 62 | if not os.path.isdir(ckpt_dir): 63 | os.makedirs(ckpt_dir) 64 | 65 | state_dict = model.state_dict() 66 | 67 | def to_tf_var_name(name:str): 68 | for patt, repl in iter(var_map): 69 | name = name.replace(patt, repl) 70 | return 'bert/{}'.format(name) 71 | 72 | def create_tf_var(tensor:np.ndarray, name:str, session:tf.Session): 73 | tf_dtype = tf.dtypes.as_dtype(tensor.dtype) 74 | tf_var = tf.get_variable(dtype=tf_dtype, shape=tensor.shape, name=name, initializer=tf.zeros_initializer()) 75 | session.run(tf.variables_initializer([tf_var])) 76 | session.run(tf_var) 77 | return tf_var 78 | 79 | tf.reset_default_graph() 80 | with tf.Session() as session: 81 | for var_name in state_dict: 82 | tf_name = to_tf_var_name(var_name) 83 | torch_tensor = state_dict[var_name].numpy() 84 | if any([x in var_name for x in tensors_to_transpose]): 85 | torch_tensor = torch_tensor.T 86 | tf_var = create_tf_var(tensor=torch_tensor, name=tf_name, session=session) 87 | tf.keras.backend.set_value(tf_var, torch_tensor) 88 | tf_weight = session.run(tf_var) 89 | print("Successfully created {}: {}".format(tf_name, np.allclose(tf_weight, torch_tensor))) 90 | 91 | saver = tf.train.Saver(tf.trainable_variables()) 92 | saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt")) 93 | 94 | 95 | def main(raw_args=None): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--model_name", 98 | type=str, 99 | required=True, 100 | help="model name e.g. 
bert-base-uncased") 101 | parser.add_argument("--cache_dir", 102 | type=str, 103 | default=None, 104 | required=False, 105 | help="Directory containing pytorch model") 106 | parser.add_argument("--pytorch_model_path", 107 | type=str, 108 | required=True, 109 | help="/path/to/<pytorch-model-name>.bin") 110 | parser.add_argument("--tf_cache_dir", 111 | type=str, 112 | required=True, 113 | help="Directory in which to save tensorflow model") 114 | args = parser.parse_args(raw_args) 115 | 116 | model = BertModel.from_pretrained( 117 | pretrained_model_name_or_path=args.model_name, 118 | state_dict=torch.load(args.pytorch_model_path), 119 | cache_dir=args.cache_dir 120 | ) 121 | 122 | convert_pytorch_checkpoint_to_tf( 123 | model=model, 124 | ckpt_dir=args.tf_cache_dir, 125 | model_name=args.model_name 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_roberta_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert RoBERTa checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import logging 21 | import numpy as np 22 | import torch 23 | 24 | from fairseq.models.roberta import RobertaModel as FairseqRobertaModel 25 | from fairseq.modules import TransformerSentenceEncoderLayer 26 | from pytorch_transformers.modeling_bert import (BertConfig, BertEncoder, 27 | BertIntermediate, BertLayer, 28 | BertModel, BertOutput, 29 | BertSelfAttention, 30 | BertSelfOutput) 31 | from pytorch_transformers.modeling_roberta import (RobertaEmbeddings, 32 | RobertaForMaskedLM, 33 | RobertaForSequenceClassification, 34 | RobertaModel) 35 | 36 | logging.basicConfig(level=logging.INFO) 37 | logger = logging.getLogger(__name__) 38 | 39 | SAMPLE_TEXT = 'Hello world! cécé herlolip' 40 | 41 | 42 | def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head): 43 | """ 44 | Copy/paste/tweak roberta's weights to our BERT structure. 
45 | """ 46 | roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path) 47 | roberta.eval() # disable dropout 48 | config = BertConfig( 49 | vocab_size_or_config_json_file=50265, 50 | hidden_size=roberta.args.encoder_embed_dim, 51 | num_hidden_layers=roberta.args.encoder_layers, 52 | num_attention_heads=roberta.args.encoder_attention_heads, 53 | intermediate_size=roberta.args.encoder_ffn_embed_dim, 54 | max_position_embeddings=514, 55 | type_vocab_size=1, 56 | ) 57 | if classification_head: 58 | config.num_labels = roberta.args.num_classes 59 | print("Our BERT config:", config) 60 | 61 | model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config) 62 | model.eval() 63 | 64 | # Now let's copy all the weights. 65 | # Embeddings 66 | roberta_sent_encoder = roberta.model.decoder.sentence_encoder 67 | model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight 68 | model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight 69 | model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them. 70 | model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight 71 | model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias 72 | model.roberta.embeddings.LayerNorm.variance_epsilon = roberta_sent_encoder.emb_layer_norm.eps 73 | 74 | for i in range(config.num_hidden_layers): 75 | # Encoder: start of layer 76 | layer: BertLayer = model.roberta.encoder.layer[i] 77 | roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i] 78 | 79 | ### self attention 80 | self_attn: BertSelfAttention = layer.attention.self 81 | assert( 82 | roberta_layer.self_attn.in_proj_weight.shape == torch.Size((3 * config.hidden_size, config.hidden_size)) 83 | ) 84 | # we use three distinct linear layers so we split the source layer here. 
85 | self_attn.query.weight.data = roberta_layer.self_attn.in_proj_weight[:config.hidden_size, :] 86 | self_attn.query.bias.data = roberta_layer.self_attn.in_proj_bias[:config.hidden_size] 87 | self_attn.key.weight.data = roberta_layer.self_attn.in_proj_weight[config.hidden_size:2*config.hidden_size, :] 88 | self_attn.key.bias.data = roberta_layer.self_attn.in_proj_bias[config.hidden_size:2*config.hidden_size] 89 | self_attn.value.weight.data = roberta_layer.self_attn.in_proj_weight[2*config.hidden_size:, :] 90 | self_attn.value.bias.data = roberta_layer.self_attn.in_proj_bias[2*config.hidden_size:] 91 | 92 | ### self-attention output 93 | self_output: BertSelfOutput = layer.attention.output 94 | assert( 95 | self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape 96 | ) 97 | self_output.dense.weight = roberta_layer.self_attn.out_proj.weight 98 | self_output.dense.bias = roberta_layer.self_attn.out_proj.bias 99 | self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight 100 | self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias 101 | self_output.LayerNorm.variance_epsilon = roberta_layer.self_attn_layer_norm.eps 102 | 103 | ### intermediate 104 | intermediate: BertIntermediate = layer.intermediate 105 | assert( 106 | intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape 107 | ) 108 | intermediate.dense.weight = roberta_layer.fc1.weight 109 | intermediate.dense.bias = roberta_layer.fc1.bias 110 | 111 | ### output 112 | bert_output: BertOutput = layer.output 113 | assert( 114 | bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape 115 | ) 116 | bert_output.dense.weight = roberta_layer.fc2.weight 117 | bert_output.dense.bias = roberta_layer.fc2.bias 118 | bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight 119 | bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias 120 | bert_output.LayerNorm.variance_epsilon = roberta_layer.final_layer_norm.eps 121 | #### end of layer 122 | 123 | if classification_head: 124 | model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight 125 | model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias 126 | model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight 127 | model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias 128 | else: 129 | # LM Head 130 | model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight 131 | model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias 132 | model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight 133 | model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias 134 | model.lm_head.layer_norm.variance_epsilon = roberta.model.decoder.lm_head.layer_norm.eps 135 | model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight 136 | model.lm_head.bias = roberta.model.decoder.lm_head.bias 137 | 138 | # Let's check that we get the same results. 
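# SAMPLE_TEXT is encoded with roberta.encode(), run through both the converted model
# and the original fairseq model, and the two outputs are compared with
# torch.allclose at an absolute tolerance of 1e-3; the conversion aborts if they differ.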
139 | input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1 140 | 141 | our_output = model(input_ids)[0] 142 | if classification_head: 143 | their_output = roberta.model.classification_heads['mnli'](roberta.extract_features(input_ids)) 144 | else: 145 | their_output = roberta.model(input_ids)[0] 146 | print(our_output.shape, their_output.shape) 147 | success = torch.allclose(our_output, their_output, atol=1e-3) 148 | print( 149 | "Do both models output the same tensors?", 150 | "🔥" if success else "💩" 151 | ) 152 | if not success: 153 | raise Exception("Something went wRoNg") 154 | 155 | print(f"Saving model to {pytorch_dump_folder_path}") 156 | model.save_pretrained(pytorch_dump_folder_path) 157 | 158 | 159 | if __name__ == "__main__": 160 | parser = argparse.ArgumentParser() 161 | ## Required parameters 162 | parser.add_argument("--roberta_checkpoint_path", 163 | default = None, 164 | type = str, 165 | required = True, 166 | help = "Path the official PyTorch dump.") 167 | parser.add_argument("--pytorch_dump_folder_path", 168 | default = None, 169 | type = str, 170 | required = True, 171 | help = "Path to the output PyTorch model.") 172 | parser.add_argument("--classification_head", 173 | action = "store_true", 174 | help = "Whether to convert a final classification head.") 175 | args = parser.parse_args() 176 | convert_roberta_checkpoint_to_pytorch( 177 | args.roberta_checkpoint_path, 178 | args.pytorch_dump_folder_path, 179 | args.classification_head 180 | ) 181 | 182 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
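# Example invocation (placeholder paths; all three arguments are required):
#
#   python -m pytorch_transformers.convert_tf_checkpoint_to_pytorch \
#       --tf_checkpoint_path /path/to/bert_model.ckpt \
#       --bert_config_file /path/to/bert_config.json \
#       --pytorch_dump_path /path/to/pytorch_model.bin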
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import torch 23 | 24 | from pytorch_transformers.modeling_bert import BertConfig, BertForPreTraining, load_tf_weights_in_bert 25 | 26 | import logging 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 30 | # Initialise PyTorch model 31 | config = BertConfig.from_json_file(bert_config_file) 32 | print("Building PyTorch model from configuration: {}".format(str(config))) 33 | model = BertForPreTraining(config) 34 | 35 | # Load weights from tf checkpoint 36 | load_tf_weights_in_bert(model, config, tf_checkpoint_path) 37 | 38 | # Save pytorch-model 39 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 40 | torch.save(model.state_dict(), pytorch_dump_path) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | ## Required parameters 46 | parser.add_argument("--tf_checkpoint_path", 47 | default = None, 48 | type = str, 49 | required = True, 50 | help = "Path to the TensorFlow checkpoint path.") 51 | parser.add_argument("--bert_config_file", 52 | default = None, 53 | type = str, 54 | required = True, 55 | help = "The config json file corresponding to the pre-trained BERT model. \n" 56 | "This specifies the model architecture.") 57 | parser.add_argument("--pytorch_dump_path", 58 | default = None, 59 | type = str, 60 | required = True, 61 | help = "Path to the output PyTorch model.") 62 | args = parser.parse_args() 63 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 64 | args.bert_config_file, 65 | args.pytorch_dump_path) 66 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_transformers.tokenization_transfo_xl as data_utils 27 | 28 | from pytorch_transformers import CONFIG_NAME, WEIGHTS_NAME 29 | from pytorch_transformers.modeling_transfo_xl import (TransfoXLConfig, TransfoXLLMHeadModel, 30 | load_tf_weights_in_transfo_xl) 31 | from pytorch_transformers.tokenization_transfo_xl import (CORPUS_NAME, VOCAB_FILES_NAMES) 32 | 33 | if sys.version_info[0] == 2: 34 | import cPickle as pickle 35 | else: 36 | import pickle 37 | 38 | import logging 39 | logging.basicConfig(level=logging.INFO) 40 | 41 | # We do this to be able to load python 2 datasets pickles 42 | # See e.g. 
https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 43 | data_utils.Vocab = data_utils.TransfoXLTokenizer 44 | data_utils.Corpus = data_utils.TransfoXLCorpus 45 | sys.modules['data_utils'] = data_utils 46 | sys.modules['vocabulary'] = data_utils 47 | 48 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 49 | transfo_xl_config_file, 50 | pytorch_dump_folder_path, 51 | transfo_xl_dataset_file): 52 | if transfo_xl_dataset_file: 53 | # Convert a pre-processed corpus (see original TensorFlow repo) 54 | with open(transfo_xl_dataset_file, "rb") as fp: 55 | corpus = pickle.load(fp, encoding="latin1") 56 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 57 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['pretrained_vocab_file'] 58 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 59 | corpus_vocab_dict = corpus.vocab.__dict__ 60 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 61 | 62 | corpus_dict_no_vocab = corpus.__dict__ 63 | corpus_dict_no_vocab.pop('vocab', None) 64 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 65 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 66 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 67 | 68 | if tf_checkpoint_path: 69 | # Convert a pre-trained TensorFlow model 70 | config_path = os.path.abspath(transfo_xl_config_file) 71 | tf_path = os.path.abspath(tf_checkpoint_path) 72 | 73 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 74 | # Initialise PyTorch model 75 | if transfo_xl_config_file == "": 76 | config = TransfoXLConfig() 77 | else: 78 | config = TransfoXLConfig.from_json_file(transfo_xl_config_file) 79 | print("Building PyTorch model from configuration: {}".format(str(config))) 80 | model = TransfoXLLMHeadModel(config) 81 | 82 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 83 | # Save pytorch-model 84 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 85 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 86 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 87 | torch.save(model.state_dict(), pytorch_weights_dump_path) 88 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 89 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 90 | f.write(config.to_json_string()) 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument("--pytorch_dump_folder_path", 96 | default = None, 97 | type = str, 98 | required = True, 99 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 100 | parser.add_argument("--tf_checkpoint_path", 101 | default = "", 102 | type = str, 103 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 104 | parser.add_argument("--transfo_xl_config_file", 105 | default = "", 106 | type = str, 107 | help = "An optional config json file corresponding to the pre-trained BERT model. 
\n" 108 | "This specifies the model architecture.") 109 | parser.add_argument("--transfo_xl_dataset_file", 110 | default = "", 111 | type = str, 112 | help = "An optional dataset file to be converted in a vocabulary.") 113 | args = parser.parse_args() 114 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 115 | args.transfo_xl_config_file, 116 | args.pytorch_dump_folder_path, 117 | args.transfo_xl_dataset_file) 118 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_xlm_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import json 21 | from io import open 22 | 23 | import torch 24 | import numpy 25 | 26 | from pytorch_transformers.modeling_utils import CONFIG_NAME, WEIGHTS_NAME 27 | from pytorch_transformers.tokenization_xlm import VOCAB_FILES_NAMES 28 | 29 | import logging 30 | logging.basicConfig(level=logging.INFO) 31 | 32 | def convert_xlm_checkpoint_to_pytorch(xlm_checkpoint_path, pytorch_dump_folder_path): 33 | # Load checkpoint 34 | chkpt = torch.load(xlm_checkpoint_path, map_location='cpu') 35 | 36 | model = chkpt['model'] 37 | 38 | config = chkpt['params'] 39 | config = dict((n, v) for n, v in config.items() if not isinstance(v, (torch.FloatTensor, numpy.ndarray))) 40 | 41 | vocab = chkpt['dico_word2id'] 42 | vocab = dict((s + '</w>' if s.find('@@') == -1 and i > 13 else s.replace('@@', ''), i) for s, i in vocab.items()) 43 | 44 | # Save pytorch-model 45 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 46 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 47 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES['vocab_file'] 48 | 49 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 50 | torch.save(model, pytorch_weights_dump_path) 51 | 52 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 53 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 54 | f.write(json.dumps(config, indent=2) + "\n") 55 | 56 | print("Save vocab file to {}".format(pytorch_config_dump_path)) 57 | with open(pytorch_vocab_dump_path, "w", encoding="utf-8") as f: 58 | f.write(json.dumps(vocab, indent=2) + "\n") 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | ## Required parameters 64 | parser.add_argument("--xlm_checkpoint_path", 65 | default = None, 66 | type = str, 67 | required = True, 68 | help = "Path the official PyTorch dump.") 69 | parser.add_argument("--pytorch_dump_folder_path", 70 | default = None, 71 | type = str, 72 | required = True, 73 | help = "Path to the output PyTorch model.") 74 | args = 
parser.parse_args() 75 | convert_xlm_checkpoint_to_pytorch(args.xlm_checkpoint_path, args.pytorch_dump_folder_path) 76 | -------------------------------------------------------------------------------- /pytorch_transformers/convert_xlnet_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import argparse 23 | import torch 24 | 25 | from pytorch_transformers.modeling_xlnet import (CONFIG_NAME, WEIGHTS_NAME, 26 | XLNetConfig, 27 | XLNetLMHeadModel, XLNetForQuestionAnswering, 28 | XLNetForSequenceClassification, 29 | load_tf_weights_in_xlnet) 30 | 31 | GLUE_TASKS_NUM_LABELS = { 32 | "cola": 2, 33 | "mnli": 3, 34 | "mrpc": 2, 35 | "sst-2": 2, 36 | "sts-b": 1, 37 | "qqp": 2, 38 | "qnli": 2, 39 | "rte": 2, 40 | "wnli": 2, 41 | } 42 | 43 | import logging 44 | logging.basicConfig(level=logging.INFO) 45 | 46 | def convert_xlnet_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_folder_path, finetuning_task=None): 47 | # Initialise PyTorch model 48 | config = XLNetConfig.from_json_file(bert_config_file) 49 | 50 | finetuning_task = finetuning_task.lower() if finetuning_task is not None else "" 51 | if finetuning_task in GLUE_TASKS_NUM_LABELS: 52 | print("Building PyTorch XLNetForSequenceClassification model from configuration: {}".format(str(config))) 53 | config.finetuning_task = finetuning_task 54 | config.num_labels = GLUE_TASKS_NUM_LABELS[finetuning_task] 55 | model = XLNetForSequenceClassification(config) 56 | elif 'squad' in finetuning_task: 57 | config.finetuning_task = finetuning_task 58 | model = XLNetForQuestionAnswering(config) 59 | else: 60 | model = XLNetLMHeadModel(config) 61 | 62 | # Load weights from tf checkpoint 63 | load_tf_weights_in_xlnet(model, config, tf_checkpoint_path) 64 | 65 | # Save pytorch-model 66 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 67 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 68 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 69 | torch.save(model.state_dict(), pytorch_weights_dump_path) 70 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 71 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 72 | f.write(config.to_json_string()) 73 | 74 | 75 | if __name__ == "__main__": 76 | parser = argparse.ArgumentParser() 77 | ## Required parameters 78 | parser.add_argument("--tf_checkpoint_path", 79 | default = None, 80 | type = str, 81 | required = True, 82 | help = "Path to the TensorFlow checkpoint path.") 83 | parser.add_argument("--xlnet_config_file", 84 | default = None, 85 | type = 
str, 86 | required = True, 87 | help = "The config json file corresponding to the pre-trained XLNet model. \n" 88 | "This specifies the model architecture.") 89 | parser.add_argument("--pytorch_dump_folder_path", 90 | default = None, 91 | type = str, 92 | required = True, 93 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 94 | parser.add_argument("--finetuning_task", 95 | default = None, 96 | type = str, 97 | help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned") 98 | args = parser.parse_args() 99 | print(args) 100 | 101 | convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path, 102 | args.xlnet_config_file, 103 | args.pytorch_dump_folder_path, 104 | args.finetuning_task) 105 | -------------------------------------------------------------------------------- /pytorch_transformers/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import sys 9 | import json 10 | import logging 11 | import os 12 | import shutil 13 | import tempfile 14 | import fnmatch 15 | from functools import wraps 16 | from hashlib import sha256 17 | from io import open 18 | 19 | import boto3 20 | from botocore.config import Config 21 | from botocore.exceptions import ClientError 22 | import requests 23 | from tqdm import tqdm 24 | 25 | try: 26 | from torch.hub import _get_torch_home 27 | torch_cache_home = _get_torch_home() 28 | except ImportError: 29 | torch_cache_home = os.path.expanduser( 30 | os.getenv('TORCH_HOME', os.path.join( 31 | os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'))) 32 | default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers') 33 | 34 | try: 35 | from urllib.parse import urlparse 36 | except ImportError: 37 | from urlparse import urlparse 38 | 39 | try: 40 | from pathlib import Path 41 | PYTORCH_PRETRAINED_BERT_CACHE = Path( 42 | os.getenv('PYTORCH_TRANSFORMERS_CACHE', os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path))) 43 | except (AttributeError, ImportError): 44 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_TRANSFORMERS_CACHE', 45 | os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 46 | default_cache_path)) 47 | 48 | PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward compatibility 49 | 50 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 51 | 52 | 53 | def url_to_filename(url, etag=None): 54 | """ 55 | Convert `url` into a hashed filename in a repeatable way. 56 | If `etag` is specified, append its hash to the url's, delimited 57 | by a period. 58 | """ 59 | url_bytes = url.encode('utf-8') 60 | url_hash = sha256(url_bytes) 61 | filename = url_hash.hexdigest() 62 | 63 | if etag: 64 | etag_bytes = etag.encode('utf-8') 65 | etag_hash = sha256(etag_bytes) 66 | filename += '.' + etag_hash.hexdigest() 67 | 68 | return filename 69 | 70 | 71 | def filename_to_url(filename, cache_dir=None): 72 | """ 73 | Return the url and etag (which may be ``None``) stored for `filename`. 74 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 
75 | """ 76 | if cache_dir is None: 77 | cache_dir = PYTORCH_TRANSFORMERS_CACHE 78 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 79 | cache_dir = str(cache_dir) 80 | 81 | cache_path = os.path.join(cache_dir, filename) 82 | if not os.path.exists(cache_path): 83 | raise EnvironmentError("file {} not found".format(cache_path)) 84 | 85 | meta_path = cache_path + '.json' 86 | if not os.path.exists(meta_path): 87 | raise EnvironmentError("file {} not found".format(meta_path)) 88 | 89 | with open(meta_path, encoding="utf-8") as meta_file: 90 | metadata = json.load(meta_file) 91 | url = metadata['url'] 92 | etag = metadata['etag'] 93 | 94 | return url, etag 95 | 96 | 97 | def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None): 98 | """ 99 | Given something that might be a URL (or might be a local path), 100 | determine which. If it's a URL, download the file and cache it, and 101 | return the path to the cached file. If it's already a local path, 102 | make sure the file exists and then return the path. 103 | Args: 104 | cache_dir: specify a cache directory to save the file to (overwrite the default cache dir). 105 | force_download: if True, re-dowload the file even if it's already cached in the cache dir. 106 | """ 107 | if cache_dir is None: 108 | cache_dir = PYTORCH_TRANSFORMERS_CACHE 109 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 110 | url_or_filename = str(url_or_filename) 111 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 112 | cache_dir = str(cache_dir) 113 | 114 | parsed = urlparse(url_or_filename) 115 | 116 | if parsed.scheme in ('http', 'https', 's3'): 117 | # URL, so get it from the cache (downloading if necessary) 118 | return get_from_cache(url_or_filename, cache_dir=cache_dir, force_download=force_download, proxies=proxies) 119 | elif os.path.exists(url_or_filename): 120 | # File, and it exists. 121 | return url_or_filename 122 | elif parsed.scheme == '': 123 | # File, but it doesn't exist. 124 | raise EnvironmentError("file {} not found".format(url_or_filename)) 125 | else: 126 | # Something unknown 127 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 128 | 129 | 130 | def split_s3_path(url): 131 | """Split a full s3 path into the bucket name and path.""" 132 | parsed = urlparse(url) 133 | if not parsed.netloc or not parsed.path: 134 | raise ValueError("bad s3 path {}".format(url)) 135 | bucket_name = parsed.netloc 136 | s3_path = parsed.path 137 | # Remove '/' at beginning of path. 138 | if s3_path.startswith("/"): 139 | s3_path = s3_path[1:] 140 | return bucket_name, s3_path 141 | 142 | 143 | def s3_request(func): 144 | """ 145 | Wrapper function for s3 requests in order to create more helpful error 146 | messages. 
147 | """ 148 | 149 | @wraps(func) 150 | def wrapper(url, *args, **kwargs): 151 | try: 152 | return func(url, *args, **kwargs) 153 | except ClientError as exc: 154 | if int(exc.response["Error"]["Code"]) == 404: 155 | raise EnvironmentError("file {} not found".format(url)) 156 | else: 157 | raise 158 | 159 | return wrapper 160 | 161 | 162 | @s3_request 163 | def s3_etag(url, proxies=None): 164 | """Check ETag on S3 object.""" 165 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) 166 | bucket_name, s3_path = split_s3_path(url) 167 | s3_object = s3_resource.Object(bucket_name, s3_path) 168 | return s3_object.e_tag 169 | 170 | 171 | @s3_request 172 | def s3_get(url, temp_file, proxies=None): 173 | """Pull a file directly from S3.""" 174 | s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) 175 | bucket_name, s3_path = split_s3_path(url) 176 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 177 | 178 | 179 | def http_get(url, temp_file, proxies=None): 180 | req = requests.get(url, stream=True, proxies=proxies) 181 | content_length = req.headers.get('Content-Length') 182 | total = int(content_length) if content_length is not None else None 183 | progress = tqdm(unit="B", total=total) 184 | for chunk in req.iter_content(chunk_size=1024): 185 | if chunk: # filter out keep-alive new chunks 186 | progress.update(len(chunk)) 187 | temp_file.write(chunk) 188 | progress.close() 189 | 190 | 191 | def get_from_cache(url, cache_dir=None, force_download=False, proxies=None): 192 | """ 193 | Given a URL, look for the corresponding dataset in the local cache. 194 | If it's not there, download it. Then return the path to the cached file. 195 | """ 196 | if cache_dir is None: 197 | cache_dir = PYTORCH_TRANSFORMERS_CACHE 198 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 199 | cache_dir = str(cache_dir) 200 | if sys.version_info[0] == 2 and not isinstance(cache_dir, str): 201 | cache_dir = str(cache_dir) 202 | 203 | if not os.path.exists(cache_dir): 204 | os.makedirs(cache_dir) 205 | 206 | # Get eTag to add to filename, if it exists. 207 | if url.startswith("s3://"): 208 | etag = s3_etag(url, proxies=proxies) 209 | else: 210 | try: 211 | response = requests.head(url, allow_redirects=True, proxies=proxies) 212 | if response.status_code != 200: 213 | etag = None 214 | else: 215 | etag = response.headers.get("ETag") 216 | except EnvironmentError: 217 | etag = None 218 | 219 | if sys.version_info[0] == 2 and etag is not None: 220 | etag = etag.decode('utf-8') 221 | filename = url_to_filename(url, etag) 222 | 223 | # get cache path to put the file 224 | cache_path = os.path.join(cache_dir, filename) 225 | 226 | # If we don't have a connection (etag is None) and can't identify the file 227 | # try to get the last downloaded one 228 | if not os.path.exists(cache_path) and etag is None: 229 | matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*') 230 | matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files)) 231 | if matching_files: 232 | cache_path = os.path.join(cache_dir, matching_files[-1]) 233 | 234 | if not os.path.exists(cache_path) or force_download: 235 | # Download to temporary file, then copy to cache dir once finished. 236 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
237 | with tempfile.NamedTemporaryFile() as temp_file: 238 | logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) 239 | 240 | # GET file object 241 | if url.startswith("s3://"): 242 | s3_get(url, temp_file, proxies=proxies) 243 | else: 244 | http_get(url, temp_file, proxies=proxies) 245 | 246 | # we are copying the file before closing it, so flush to avoid truncation 247 | temp_file.flush() 248 | # shutil.copyfileobj() starts at the current position, so go to the start 249 | temp_file.seek(0) 250 | 251 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 252 | with open(cache_path, 'wb') as cache_file: 253 | shutil.copyfileobj(temp_file, cache_file) 254 | 255 | logger.info("creating metadata file for %s", cache_path) 256 | meta = {'url': url, 'etag': etag} 257 | meta_path = cache_path + '.json' 258 | with open(meta_path, 'w') as meta_file: 259 | output_string = json.dumps(meta) 260 | if sys.version_info[0] == 2 and isinstance(output_string, str): 261 | output_string = unicode(output_string, 'utf-8') # The beauty of python 2 262 | meta_file.write(output_string) 263 | 264 | logger.info("removing temp file %s", temp_file.name) 265 | 266 | return cache_path 267 | -------------------------------------------------------------------------------- /pytorch_transformers/modeling_transfo_xl_utilities.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Utilities for PyTorch Transformer XL model. 17 | Directly adapted from https://github.com/kimiyoung/transformer-xl. 
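    The module provides ProjectedAdaptiveLogSoftmax (an adaptive softmax with
    per-cluster output layers and optional output projections) together with
    LogUniformSampler and sample_logits for sampled-softmax training.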
18 | """ 19 | 20 | from collections import defaultdict 21 | 22 | import numpy as np 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | # CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) 29 | # CUDA_MINOR = int(torch.version.cuda.split('.')[1]) 30 | 31 | class ProjectedAdaptiveLogSoftmax(nn.Module): 32 | def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, 33 | keep_order=False): 34 | super(ProjectedAdaptiveLogSoftmax, self).__init__() 35 | 36 | self.n_token = n_token 37 | self.d_embed = d_embed 38 | self.d_proj = d_proj 39 | 40 | self.cutoffs = cutoffs + [n_token] 41 | self.cutoff_ends = [0] + self.cutoffs 42 | self.div_val = div_val 43 | 44 | self.shortlist_size = self.cutoffs[0] 45 | self.n_clusters = len(self.cutoffs) - 1 46 | self.head_size = self.shortlist_size + self.n_clusters 47 | 48 | if self.n_clusters > 0: 49 | self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) 50 | self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) 51 | 52 | self.out_layers = nn.ModuleList() 53 | self.out_projs = nn.ParameterList() 54 | 55 | if div_val == 1: 56 | for i in range(len(self.cutoffs)): 57 | if d_proj != d_embed: 58 | self.out_projs.append( 59 | nn.Parameter(torch.FloatTensor(d_proj, d_embed)) 60 | ) 61 | else: 62 | self.out_projs.append(None) 63 | 64 | self.out_layers.append(nn.Linear(d_embed, n_token)) 65 | else: 66 | for i in range(len(self.cutoffs)): 67 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] 68 | d_emb_i = d_embed // (div_val ** i) 69 | 70 | self.out_projs.append( 71 | nn.Parameter(torch.FloatTensor(d_proj, d_emb_i)) 72 | ) 73 | 74 | self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) 75 | 76 | self.keep_order = keep_order 77 | 78 | def _compute_logit(self, hidden, weight, bias, proj): 79 | if proj is None: 80 | logit = F.linear(hidden, weight, bias=bias) 81 | else: 82 | # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: 83 | proj_hid = F.linear(hidden, proj.t().contiguous()) 84 | logit = F.linear(proj_hid, weight, bias=bias) 85 | # else: 86 | # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) 87 | # if bias is not None: 88 | # logit = logit + bias 89 | 90 | return logit 91 | 92 | def forward(self, hidden, labels=None, keep_order=False): 93 | ''' 94 | Params: 95 | hidden :: [len*bsz x d_proj] 96 | labels :: [len*bsz] 97 | Return: 98 | if labels is None: 99 | out :: [len*bsz] Negative log likelihood 100 | else: 101 | out :: [len*bsz x n_tokens] log probabilities of tokens over the vocabulary 102 | We could replace this implementation by the native PyTorch one 103 | if their's had an option to set bias on all clusters in the native one. 
104 | here: https://github.com/pytorch/pytorch/blob/dbe6a7a9ff1a364a8706bf5df58a1ca96d2fd9da/torch/nn/modules/adaptive.py#L138 105 | ''' 106 | 107 | if labels is not None: 108 | labels = labels.view(-1) 109 | if hidden.size(0) != labels.size(0): 110 | raise RuntimeError('Input and labels should have the same size ' 111 | 'in the batch dimension.') 112 | 113 | if self.n_clusters == 0: 114 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 115 | self.out_layers[0].bias, self.out_projs[0]) 116 | if labels is not None: 117 | out = -F.log_softmax(logit, dim=-1) \ 118 | .gather(1, labels.unsqueeze(1)).squeeze(1) 119 | else: 120 | out = F.log_softmax(logit, dim=-1) 121 | else: 122 | # construct weights and biases 123 | weights, biases = [], [] 124 | for i in range(len(self.cutoffs)): 125 | if self.div_val == 1: 126 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 127 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 128 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 129 | else: 130 | weight_i = self.out_layers[i].weight 131 | bias_i = self.out_layers[i].bias 132 | 133 | if i == 0: 134 | weight_i = torch.cat( 135 | [weight_i, self.cluster_weight], dim=0) 136 | bias_i = torch.cat( 137 | [bias_i, self.cluster_bias], dim=0) 138 | 139 | weights.append(weight_i) 140 | biases.append(bias_i) 141 | 142 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 143 | 144 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 145 | head_logprob = F.log_softmax(head_logit, dim=1) 146 | 147 | if labels is None: 148 | out = hidden.new_empty((head_logit.size(0), self.n_token)) 149 | else: 150 | out = torch.zeros_like(labels, dtype=hidden.dtype, device=hidden.device) 151 | 152 | offset = 0 153 | cutoff_values = [0] + self.cutoffs 154 | for i in range(len(cutoff_values) - 1): 155 | l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] 156 | 157 | if labels is not None: 158 | mask_i = (labels >= l_idx) & (labels < r_idx) 159 | indices_i = mask_i.nonzero().squeeze() 160 | 161 | if indices_i.numel() == 0: 162 | continue 163 | 164 | target_i = labels.index_select(0, indices_i) - l_idx 165 | head_logprob_i = head_logprob.index_select(0, indices_i) 166 | hidden_i = hidden.index_select(0, indices_i) 167 | else: 168 | hidden_i = hidden 169 | 170 | if i == 0: 171 | if labels is not None: 172 | logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1) 173 | else: 174 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] 175 | else: 176 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 177 | 178 | tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) 179 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 180 | cluster_prob_idx = self.cutoffs[0] + i - 1 # No probability for the head cluster 181 | if labels is not None: 182 | logprob_i = head_logprob_i[:, cluster_prob_idx] \ 183 | + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1) 184 | else: 185 | logprob_i = head_logprob[:, cluster_prob_idx, None] + tail_logprob_i 186 | out[:, l_idx:r_idx] = logprob_i 187 | 188 | if labels is not None: 189 | if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: 190 | out.index_copy_(0, indices_i, -logprob_i) 191 | else: 192 | out[offset:offset+logprob_i.size(0)].copy_(-logprob_i) 193 | offset += logprob_i.size(0) 194 | 195 | return out 196 | 197 | 198 | def log_prob(self, hidden): 199 | r""" Computes log probabilities for all :math:`n\_classes` 200 | From: 
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/adaptive.py 201 | Args: 202 | hidden (Tensor): a minibatch of examples 203 | Returns: 204 | log-probabilities of for each class :math:`c` 205 | in range :math:`0 <= c <= n\_classes`, where :math:`n\_classes` is a 206 | parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor. 207 | Shape: 208 | - Input: :math:`(N, in\_features)` 209 | - Output: :math:`(N, n\_classes)` 210 | """ 211 | if self.n_clusters == 0: 212 | logit = self._compute_logit(hidden, self.out_layers[0].weight, 213 | self.out_layers[0].bias, self.out_projs[0]) 214 | return F.log_softmax(logit, dim=-1) 215 | else: 216 | # construct weights and biases 217 | weights, biases = [], [] 218 | for i in range(len(self.cutoffs)): 219 | if self.div_val == 1: 220 | l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] 221 | weight_i = self.out_layers[0].weight[l_idx:r_idx] 222 | bias_i = self.out_layers[0].bias[l_idx:r_idx] 223 | else: 224 | weight_i = self.out_layers[i].weight 225 | bias_i = self.out_layers[i].bias 226 | 227 | if i == 0: 228 | weight_i = torch.cat( 229 | [weight_i, self.cluster_weight], dim=0) 230 | bias_i = torch.cat( 231 | [bias_i, self.cluster_bias], dim=0) 232 | 233 | weights.append(weight_i) 234 | biases.append(bias_i) 235 | 236 | head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] 237 | head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) 238 | 239 | out = hidden.new_empty((head_logit.size(0), self.n_token)) 240 | head_logprob = F.log_softmax(head_logit, dim=1) 241 | 242 | cutoff_values = [0] + self.cutoffs 243 | for i in range(len(cutoff_values) - 1): 244 | start_idx, stop_idx = cutoff_values[i], cutoff_values[i + 1] 245 | 246 | if i == 0: 247 | out[:, :self.cutoffs[0]] = head_logprob[:, :self.cutoffs[0]] 248 | else: 249 | weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] 250 | 251 | tail_logit_i = self._compute_logit(hidden, weight_i, bias_i, proj_i) 252 | tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) 253 | 254 | logprob_i = head_logprob[:, -i] + tail_logprob_i 255 | out[:, start_idx, stop_idx] = logprob_i 256 | 257 | return out 258 | 259 | 260 | class LogUniformSampler(object): 261 | def __init__(self, range_max, n_sample): 262 | """ 263 | Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py 264 | `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` 265 | 266 | expected count can be approximated by 1 - (1 - p)^n 267 | and we use a numerically stable version -expm1(num_tries * log1p(-p)) 268 | 269 | Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run 270 | """ 271 | with torch.no_grad(): 272 | self.range_max = range_max 273 | log_indices = torch.arange(1., range_max+2., 1.).log_() 274 | self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] 275 | 276 | self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() 277 | 278 | self.n_sample = n_sample 279 | 280 | def sample(self, labels): 281 | """ 282 | labels: [b1, b2] 283 | Return 284 | true_log_probs: [b1, b2] 285 | samp_log_probs: [n_sample] 286 | neg_samples: [n_sample] 287 | """ 288 | 289 | # neg_samples = torch.empty(0).long() 290 | n_sample = self.n_sample 291 | n_tries = 2 * n_sample 292 | 293 | with torch.no_grad(): 294 | neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() 295 | device = labels.device 296 | neg_samples = 
neg_samples.to(device) 297 | true_log_probs = self.log_q[labels].to(device) 298 | samp_log_probs = self.log_q[neg_samples].to(device) 299 | return true_log_probs, samp_log_probs, neg_samples 300 | 301 | def sample_logits(embedding, bias, labels, inputs, sampler): 302 | """ 303 | embedding: an nn.Embedding layer 304 | bias: [n_vocab] 305 | labels: [b1, b2] 306 | inputs: [b1, b2, n_emb] 307 | sampler: you may use a LogUniformSampler 308 | Return 309 | logits: [b1, b2, 1 + n_sample] 310 | """ 311 | true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) 312 | n_sample = neg_samples.size(0) 313 | b1, b2 = labels.size(0), labels.size(1) 314 | all_ids = torch.cat([labels.view(-1), neg_samples]) 315 | all_w = embedding(all_ids) 316 | true_w = all_w[: -n_sample].view(b1, b2, -1) 317 | sample_w = all_w[- n_sample:].view(n_sample, -1) 318 | 319 | all_b = bias[all_ids] 320 | true_b = all_b[: -n_sample].view(b1, b2) 321 | sample_b = all_b[- n_sample:] 322 | 323 | hit = (labels[:, :, None] == neg_samples).detach() 324 | 325 | true_logits = torch.einsum('ijk,ijk->ij', 326 | [true_w, inputs]) + true_b - true_log_probs 327 | sample_logits = torch.einsum('lk,ijk->ijl', 328 | [sample_w, inputs]) + sample_b - samp_log_probs 329 | sample_logits.masked_fill_(hit, -1e30) 330 | logits = torch.cat([true_logits[:, :, None], sample_logits], -1) 331 | 332 | return logits 333 | -------------------------------------------------------------------------------- /pytorch_transformers/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import logging 18 | import math 19 | 20 | import torch 21 | from torch.optim import Optimizer 22 | from torch.optim.lr_scheduler import LambdaLR 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | class ConstantLRSchedule(LambdaLR): 27 | """ Constant learning rate schedule. 28 | """ 29 | def __init__(self, optimizer, last_epoch=-1): 30 | super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) 31 | 32 | 33 | class WarmupConstantSchedule(LambdaLR): 34 | """ Linear warmup and then constant. 35 | Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps. 36 | Keeps learning rate schedule equal to 1. after warmup_steps. 37 | """ 38 | def __init__(self, optimizer, warmup_steps, last_epoch=-1): 39 | self.warmup_steps = warmup_steps 40 | super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 41 | 42 | def lr_lambda(self, step): 43 | if step < self.warmup_steps: 44 | return float(step) / float(max(1.0, self.warmup_steps)) 45 | return 1. 46 | 47 | 48 | class WarmupLinearSchedule(LambdaLR): 49 | """ Linear warmup and then linear decay. 
50 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 51 | Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps. 52 | """ 53 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): 54 | self.warmup_steps = warmup_steps 55 | self.t_total = t_total 56 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 57 | 58 | def lr_lambda(self, step): 59 | if step < self.warmup_steps: 60 | return float(step) / float(max(1, self.warmup_steps)) 61 | return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) 62 | 63 | 64 | class WarmupCosineSchedule(LambdaLR): 65 | """ Linear warmup and then cosine decay. 66 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 67 | Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. 68 | If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup. 69 | """ 70 | def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): 71 | self.warmup_steps = warmup_steps 72 | self.t_total = t_total 73 | self.cycles = cycles 74 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 75 | 76 | def lr_lambda(self, step): 77 | if step < self.warmup_steps: 78 | return float(step) / float(max(1.0, self.warmup_steps)) 79 | # progress after warmup 80 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 81 | return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 82 | 83 | 84 | class WarmupCosineWithHardRestartsSchedule(LambdaLR): 85 | """ Linear warmup and then cosine cycles with hard restarts. 86 | Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps. 87 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying 88 | learning rate (with hard restarts). 89 | """ 90 | def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): 91 | self.warmup_steps = warmup_steps 92 | self.t_total = t_total 93 | self.cycles = cycles 94 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 95 | 96 | def lr_lambda(self, step): 97 | if step < self.warmup_steps: 98 | return float(step) / float(max(1, self.warmup_steps)) 99 | # progress after warmup 100 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 101 | if progress >= 1.0: 102 | return 0.0 103 | return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0)))) 104 | 105 | 106 | 107 | class AdamW(Optimizer): 108 | """ Implements Adam algorithm with weight decay fix. 109 | 110 | Parameters: 111 | lr (float): learning rate. Default 1e-3. 112 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 113 | eps (float): Adams epsilon. Default: 1e-6 114 | weight_decay (float): Weight decay. Default: 0.0 115 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 
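Example (a minimal usage sketch, assuming `model` and `train_dataloader` are defined elsewhere; pairing AdamW with the WarmupLinearSchedule above is a common but optional choice)::

    optimizer = AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.01)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=100, t_total=1000)
    for batch in train_dataloader:
        loss = model(**batch)[0]   # assumes the model returns the loss first
        loss.backward()
        optimizer.step()           # update the parameters
        scheduler.step()           # then advance the per-step learning-rate schedule
        optimizer.zero_grad()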
116 | """ 117 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, 118 | correct_bias=True): 119 | if lr < 0.0: 120 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 121 | if not 0.0 <= betas[0] < 1.0: 122 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 123 | if not 0.0 <= betas[1] < 1.0: 124 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 125 | if not 0.0 <= eps: 126 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 127 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 128 | correct_bias=correct_bias) 129 | super(AdamW, self).__init__(params, defaults) 130 | 131 | def step(self, closure=None): 132 | """Performs a single optimization step. 133 | 134 | Arguments: 135 | closure (callable, optional): A closure that reevaluates the model 136 | and returns the loss. 137 | """ 138 | loss = None 139 | if closure is not None: 140 | loss = closure() 141 | 142 | for group in self.param_groups: 143 | for p in group['params']: 144 | if p.grad is None: 145 | continue 146 | grad = p.grad.data 147 | if grad.is_sparse: 148 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 149 | 150 | state = self.state[p] 151 | 152 | # State initialization 153 | if len(state) == 0: 154 | state['step'] = 0 155 | # Exponential moving average of gradient values 156 | state['exp_avg'] = torch.zeros_like(p.data) 157 | # Exponential moving average of squared gradient values 158 | state['exp_avg_sq'] = torch.zeros_like(p.data) 159 | 160 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 161 | beta1, beta2 = group['betas'] 162 | 163 | state['step'] += 1 164 | 165 | # Decay the first and second moment running average coefficient 166 | # In-place operations to update the averages at the same time 167 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 168 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 169 | denom = exp_avg_sq.sqrt().add_(group['eps']) 170 | 171 | step_size = group['lr'] 172 | if group['correct_bias']: # No bias correction for Bert 173 | bias_correction1 = 1.0 - beta1 ** state['step'] 174 | bias_correction2 = 1.0 - beta2 ** state['step'] 175 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 176 | 177 | p.data.addcdiv_(-step_size, exp_avg, denom) 178 | 179 | # Just adding the square of the weights to the loss function is *not* 180 | # the correct way of using L2 regularization/weight decay with Adam, 181 | # since that will interact with the m and v parameters in strange ways. 182 | # 183 | # Instead we want to decay the weights in a manner that doesn't interact 184 | # with the m/v parameters. This is equivalent to adding the square 185 | # of the weights to the loss with plain (non-momentum) SGD. 186 | # Add weight decay at the end (fixed version) 187 | if group['weight_decay'] > 0.0: 188 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 189 | 190 | return loss 191 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_auto.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Auto Model class. """ 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import logging 20 | 21 | from .tokenization_bert import BertTokenizer 22 | from .tokenization_openai import OpenAIGPTTokenizer 23 | from .tokenization_gpt2 import GPT2Tokenizer 24 | from .tokenization_transfo_xl import TransfoXLTokenizer 25 | from .tokenization_xlnet import XLNetTokenizer 26 | from .tokenization_xlm import XLMTokenizer 27 | from .tokenization_roberta import RobertaTokenizer 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | class AutoTokenizer(object): 32 | r""":class:`~pytorch_transformers.AutoTokenizer` is a generic tokenizer class 33 | that will be instantiated as one of the tokenizer classes of the library 34 | when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` 35 | class method. 36 | 37 | The `from_pretrained()` method take care of returning the correct tokenizer class instance 38 | using pattern matching on the `pretrained_model_name_or_path` string. 39 | 40 | The tokenizer class to instantiate is selected as the first pattern matching 41 | in the `pretrained_model_name_or_path` string (in the following order): 42 | - contains `bert`: BertTokenizer (Bert model) 43 | - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model) 44 | - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model) 45 | - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model) 46 | - contains `xlnet`: XLNetTokenizer (XLNet model) 47 | - contains `xlm`: XLMTokenizer (XLM model) 48 | - contains `roberta`: RobertaTokenizer (RoBERTa model) 49 | 50 | This class cannot be instantiated using `__init__()` (throw an error). 51 | """ 52 | def __init__(self): 53 | raise EnvironmentError("AutoTokenizer is designed to be instantiated " 54 | "using the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)` method.") 55 | 56 | @classmethod 57 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): 58 | r""" Instantiate a one of the tokenizer classes of the library 59 | from a pre-trained model vocabulary. 60 | 61 | The tokenizer class to instantiate is selected as the first pattern matching 62 | in the `pretrained_model_name_or_path` string (in the following order): 63 | - contains `bert`: BertTokenizer (Bert model) 64 | - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model) 65 | - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model) 66 | - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model) 67 | - contains `xlnet`: XLNetTokenizer (XLNet model) 68 | - contains `xlm`: XLMTokenizer (XLM model) 69 | - contains `roberta`: RobertaTokenizer (XLM model) 70 | 71 | Params: 72 | **pretrained_model_name_or_path**: either: 73 | - a string with the `shortcut name` of a pre-trained model configuration to load from cache 74 | or download and cache if not already stored in cache (e.g. 'bert-base-uncased'). 75 | - a path to a `directory` containing a configuration file saved 76 | using the `save_pretrained(save_directory)` method. 
77 | - a path or url to a saved configuration `file`. 78 | **cache_dir**: (`optional`) string: 79 | Path to a directory in which a downloaded pre-trained model 80 | configuration should be cached if the standard cache should not be used. 81 | 82 | Examples:: 83 | 84 | config = AutoTokenizer.from_pretrained('bert-base-uncased') # Download vocabulary from S3 and cache. 85 | config = AutoTokenizer.from_pretrained('./test/bert_saved_model/') # E.g. tokenizer was saved using `save_pretrained('./test/saved_model/')` 86 | 87 | """ 88 | if 'roberta' in pretrained_model_name_or_path: 89 | return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 90 | elif 'bert' in pretrained_model_name_or_path: 91 | return BertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 92 | elif 'openai-gpt' in pretrained_model_name_or_path: 93 | return OpenAIGPTTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 94 | elif 'gpt2' in pretrained_model_name_or_path: 95 | return GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 96 | elif 'transfo-xl' in pretrained_model_name_or_path: 97 | return TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 98 | elif 'xlnet' in pretrained_model_name_or_path: 99 | return XLNetTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 100 | elif 'xlm' in pretrained_model_name_or_path: 101 | return XLMTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 102 | 103 | raise ValueError("Unrecognized model identifier in {}. Should contains one of " 104 | "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " 105 | "'xlm', 'roberta'".format(pretrained_model_name_or_path)) 106 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_distilbert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
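# DistilBertTokenizer (below) subclasses BertTokenizer and only swaps in the DistilBERT
# pretrained vocabulary map, so its behaviour is plain BERT wordpiece tokenization.
# A minimal usage sketch, assuming the pretrained vocabulary is downloadable or cached:
#
#     tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
#     ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("a distilled model"))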
15 | """Tokenization classes for DistilBERT.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from .tokenization_bert import BertTokenizer 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | 'vocab_file': 33 | { 34 | 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 35 | 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 36 | } 37 | } 38 | 39 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 40 | 'distilbert-base-uncased': 512, 41 | 'distilbert-base-uncased-distilled-squad': 512, 42 | } 43 | 44 | 45 | class DistilBertTokenizer(BertTokenizer): 46 | r""" 47 | Constructs a DistilBertTokenizer. 48 | :class:`~pytorch_transformers.DistilBertTokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece 49 | 50 | Args: 51 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 52 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 53 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 54 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 55 | minimum of this value (if specified) and the underlying BERT model's sequence length. 56 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 57 | do_wordpiece_only=False 58 | """ 59 | 60 | vocab_files_names = VOCAB_FILES_NAMES 61 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 62 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 63 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | 26 | try: 27 | from functools import lru_cache 28 | except ImportError: 29 | # Just a dummy decorator to get the checks to run on python2 30 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 
31 | def lru_cache(): 32 | return lambda func: func 33 | 34 | from .tokenization_utils import PreTrainedTokenizer 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | VOCAB_FILES_NAMES = { 39 | 'vocab_file': 'vocab.json', 40 | 'merges_file': 'merges.txt', 41 | } 42 | 43 | PRETRAINED_VOCAB_FILES_MAP = { 44 | 'vocab_file': 45 | { 46 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 47 | 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json", 48 | 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json", 49 | }, 50 | 'merges_file': 51 | { 52 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 53 | 'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt", 54 | 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt", 55 | }, 56 | } 57 | 58 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 59 | 'gpt2': 1024, 60 | 'gpt2-medium': 1024, 61 | 'gpt2-large': 1024, 62 | } 63 | 64 | @lru_cache() 65 | def bytes_to_unicode(): 66 | """ 67 | Returns list of utf-8 byte and a corresponding list of unicode strings. 68 | The reversible bpe codes work on unicode strings. 69 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 70 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 71 | This is a signficant percentage of your normal, say, 32K bpe vocab. 72 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 73 | And avoids mapping to whitespace/control characters the bpe code barfs on. 74 | """ 75 | _chr = unichr if sys.version_info[0] == 2 else chr 76 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 77 | cs = bs[:] 78 | n = 0 79 | for b in range(2**8): 80 | if b not in bs: 81 | bs.append(b) 82 | cs.append(2**8+n) 83 | n += 1 84 | cs = [_chr(n) for n in cs] 85 | return dict(zip(bs, cs)) 86 | 87 | def get_pairs(word): 88 | """Return set of symbol pairs in a word. 89 | 90 | Word is represented as tuple of symbols (symbols being variable-length strings). 91 | """ 92 | pairs = set() 93 | prev_char = word[0] 94 | for char in word[1:]: 95 | pairs.add((prev_char, char)) 96 | prev_char = char 97 | return pairs 98 | 99 | class GPT2Tokenizer(PreTrainedTokenizer): 100 | """ 101 | GPT-2 BPE tokenizer. 
Peculiarities: 102 | - Byte-level BPE 103 | """ 104 | vocab_files_names = VOCAB_FILES_NAMES 105 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 106 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 107 | 108 | def __init__(self, vocab_file, merges_file, errors='replace', unk_token="<|endoftext|>", 109 | bos_token="<|endoftext|>", eos_token="<|endoftext|>", **kwargs): 110 | super(GPT2Tokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs) 111 | self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens 112 | self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens 113 | 114 | self.encoder = json.load(open(vocab_file)) 115 | self.decoder = {v:k for k,v in self.encoder.items()} 116 | self.errors = errors # how to handle errors in decoding 117 | self.byte_encoder = bytes_to_unicode() 118 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 119 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 120 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 121 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 122 | self.cache = {} 123 | 124 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 125 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 126 | 127 | @property 128 | def vocab_size(self): 129 | return len(self.encoder) 130 | 131 | def bpe(self, token): 132 | if token in self.cache: 133 | return self.cache[token] 134 | word = tuple(token) 135 | pairs = get_pairs(word) 136 | 137 | if not pairs: 138 | return token 139 | 140 | while True: 141 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 142 | if bigram not in self.bpe_ranks: 143 | break 144 | first, second = bigram 145 | new_word = [] 146 | i = 0 147 | while i < len(word): 148 | try: 149 | j = word.index(first, i) 150 | new_word.extend(word[i:j]) 151 | i = j 152 | except: 153 | new_word.extend(word[i:]) 154 | break 155 | 156 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 157 | new_word.append(first+second) 158 | i += 2 159 | else: 160 | new_word.append(word[i]) 161 | i += 1 162 | new_word = tuple(new_word) 163 | word = new_word 164 | if len(word) == 1: 165 | break 166 | else: 167 | pairs = get_pairs(word) 168 | word = ' '.join(word) 169 | self.cache[token] = word 170 | return word 171 | 172 | def _tokenize(self, text): 173 | """ Tokenize a string. """ 174 | bpe_tokens = [] 175 | for token in re.findall(self.pat, text): 176 | if sys.version_info[0] == 2: 177 | token = ''.join(self.byte_encoder[ord(b)] for b in token) 178 | else: 179 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 180 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 181 | return bpe_tokens 182 | 183 | def _convert_token_to_id(self, token): 184 | """ Converts a token (str/unicode) in an id using the vocab. """ 185 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 186 | 187 | def _convert_id_to_token(self, index): 188 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 189 | return self.decoder.get(index) 190 | 191 | def convert_tokens_to_string(self, tokens): 192 | """ Converts a sequence of tokens (string) in a single string. 
""" 193 | text = ''.join(tokens) 194 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 195 | return text 196 | 197 | def save_vocabulary(self, save_directory): 198 | """Save the tokenizer vocabulary and merge files to a directory.""" 199 | if not os.path.isdir(save_directory): 200 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 201 | return 202 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 203 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) 204 | 205 | with open(vocab_file, 'w', encoding='utf-8') as f: 206 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 207 | 208 | index = 0 209 | with open(merge_file, "w", encoding="utf-8") as writer: 210 | writer.write(u'#version: 0.2\n') 211 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 212 | if index != token_index: 213 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 214 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 215 | index = token_index 216 | writer.write(' '.join(bpe_tokens) + u'\n') 217 | index += 1 218 | 219 | return vocab_file, merge_file 220 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | from io import open 24 | 25 | from .tokenization_utils import PreTrainedTokenizer 26 | from .tokenization_bert import BasicTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | VOCAB_FILES_NAMES = { 31 | 'vocab_file': 'vocab.json', 32 | 'merges_file': 'merges.txt', 33 | } 34 | 35 | PRETRAINED_VOCAB_FILES_MAP = { 36 | 'vocab_file': 37 | { 38 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", 39 | }, 40 | 'merges_file': 41 | { 42 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 43 | }, 44 | } 45 | 46 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 47 | 'openai-gpt': 512, 48 | } 49 | 50 | def get_pairs(word): 51 | """ 52 | Return set of symbol pairs in a word. 
53 | word is represented as tuple of symbols (symbols being variable-length strings) 54 | """ 55 | pairs = set() 56 | prev_char = word[0] 57 | for char in word[1:]: 58 | pairs.add((prev_char, char)) 59 | prev_char = char 60 | return pairs 61 | 62 | def text_standardize(text): 63 | """ 64 | fixes some issues the spacy tokenizer had on books corpus 65 | also does some whitespace standardization 66 | """ 67 | text = text.replace('—', '-') 68 | text = text.replace('–', '-') 69 | text = text.replace('―', '-') 70 | text = text.replace('…', '...') 71 | text = text.replace('´', "'") 72 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 73 | text = re.sub(r'\s*\n\s*', ' \n ', text) 74 | text = re.sub(r'[^\S\n]+', ' ', text) 75 | return text.strip() 76 | 77 | class OpenAIGPTTokenizer(PreTrainedTokenizer): 78 | """ 79 | BPE tokenizer. Peculiarities: 80 | - lower case all inputs 81 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. 82 | """ 83 | vocab_files_names = VOCAB_FILES_NAMES 84 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 85 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 86 | 87 | def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs): 88 | super(OpenAIGPTTokenizer, self).__init__(unk_token=unk_token, **kwargs) 89 | 90 | self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens 91 | self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens 92 | 93 | try: 94 | import ftfy 95 | from spacy.lang.en import English 96 | _nlp = English() 97 | self.nlp = _nlp.Defaults.create_tokenizer(_nlp) 98 | self.fix_text = ftfy.fix_text 99 | except ImportError: 100 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") 101 | self.nlp = BasicTokenizer(do_lower_case=True) 102 | self.fix_text = None 103 | 104 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 105 | self.decoder = {v:k for k,v in self.encoder.items()} 106 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 107 | merges = [tuple(merge.split()) for merge in merges] 108 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 109 | self.cache = {} 110 | 111 | @property 112 | def vocab_size(self): 113 | return len(self.encoder) 114 | 115 | def bpe(self, token): 116 | word = tuple(token[:-1]) + (token[-1] + '</w>',) 117 | if token in self.cache: 118 | return self.cache[token] 119 | pairs = get_pairs(word) 120 | 121 | if not pairs: 122 | return token+'</w>' 123 | 124 | while True: 125 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 126 | if bigram not in self.bpe_ranks: 127 | break 128 | first, second = bigram 129 | new_word = [] 130 | i = 0 131 | while i < len(word): 132 | try: 133 | j = word.index(first, i) 134 | new_word.extend(word[i:j]) 135 | i = j 136 | except: 137 | new_word.extend(word[i:]) 138 | break 139 | 140 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 141 | new_word.append(first+second) 142 | i += 2 143 | else: 144 | new_word.append(word[i]) 145 | i += 1 146 | new_word = tuple(new_word) 147 | word = new_word 148 | if len(word) == 1: 149 | break 150 | else: 151 | pairs = get_pairs(word) 152 | word = ' '.join(word) 153 | if word == '\n </w>': 154 | word = '\n</w>' 155 | self.cache[token] = word 156 | return 
word 157 | 158 | def _tokenize(self, text): 159 | """ Tokenize a string. """ 160 | split_tokens = [] 161 | if self.fix_text is None: 162 | # Using BERT's BasicTokenizer 163 | text = self.nlp.tokenize(text) 164 | for token in text: 165 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 166 | else: 167 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 168 | text = self.nlp(text_standardize(self.fix_text(text))) 169 | for token in text: 170 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 171 | return split_tokens 172 | 173 | def _convert_token_to_id(self, token): 174 | """ Converts a token (str/unicode) in an id using the vocab. """ 175 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 176 | 177 | def _convert_id_to_token(self, index): 178 | """Converts an id in a token (BPE) using the vocab.""" 179 | return self.decoder.get(index, self.unk_token) 180 | 181 | def convert_tokens_to_string(self, tokens): 182 | """ Converts a sequence of tokens (string) in a single string. """ 183 | out_string = ''.join(tokens).replace('</w>', ' ').strip() 184 | return out_string 185 | 186 | def save_vocabulary(self, save_directory): 187 | """Save the tokenizer vocabulary and merge files to a directory.""" 188 | if not os.path.isdir(save_directory): 189 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 190 | return 191 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 192 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) 193 | 194 | with open(vocab_file, 'w', encoding='utf-8') as f: 195 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 196 | 197 | index = 0 198 | with open(merge_file, "w", encoding="utf-8") as writer: 199 | writer.write(u'#version: 0.2\n') 200 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 201 | if index != token_index: 202 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 203 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 204 | index = token_index 205 | writer.write(' '.join(bpe_tokens) + u'\n') 206 | index += 1 207 | 208 | return vocab_file, merge_file 209 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_roberta.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
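# RobertaTokenizer (below) reuses GPT-2's byte-level BPE (hence the imports of
# bytes_to_unicode and get_pairs from tokenization_gpt2) and adds BERT-style special
# tokens: a single sequence becomes "<s> X </s>" and a pair "<s> A </s></s> B </s>".
# A minimal usage sketch, assuming the pretrained files are downloadable or cached:
#
#     tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
#     ids = tokenizer.add_special_tokens_single_sentence(
#         tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Hello world")))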
15 | """Tokenization classes for RoBERTa.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import sys 20 | import json 21 | import logging 22 | import os 23 | import regex as re 24 | from io import open 25 | 26 | from .tokenization_gpt2 import bytes_to_unicode, get_pairs 27 | from .tokenization_utils import PreTrainedTokenizer 28 | 29 | try: 30 | from functools import lru_cache 31 | except ImportError: 32 | # Just a dummy decorator to get the checks to run on python2 33 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 34 | def lru_cache(): 35 | return lambda func: func 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | VOCAB_FILES_NAMES = { 40 | 'vocab_file': 'vocab.json', 41 | 'merges_file': 'merges.txt', 42 | } 43 | 44 | PRETRAINED_VOCAB_FILES_MAP = { 45 | 'vocab_file': 46 | { 47 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", 48 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json", 49 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json", 50 | }, 51 | 'merges_file': 52 | { 53 | 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", 54 | 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt", 55 | 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt", 56 | }, 57 | } 58 | 59 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 60 | 'roberta-base': 512, 61 | 'roberta-large': 512, 62 | 'roberta-large-mnli': 512, 63 | } 64 | 65 | 66 | class RobertaTokenizer(PreTrainedTokenizer): 67 | """ 68 | RoBERTa BPE tokenizer, derived from the GPT-2 tokenizer. 
Peculiarities: Byte-level BPE 69 | """ 70 | vocab_files_names = VOCAB_FILES_NAMES 71 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 72 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 73 | 74 | def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>", 75 | cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>', **kwargs): 76 | super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, 77 | sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, 78 | mask_token=mask_token, **kwargs) 79 | 80 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 81 | self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens 82 | 83 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 84 | self.decoder = {v: k for k, v in self.encoder.items()} 85 | self.errors = errors # how to handle errors in decoding 86 | self.byte_encoder = bytes_to_unicode() 87 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 88 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 89 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 90 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 91 | self.cache = {} 92 | 93 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 94 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 95 | 96 | @property 97 | def vocab_size(self): 98 | return len(self.encoder) 99 | 100 | def bpe(self, token): 101 | if token in self.cache: 102 | return self.cache[token] 103 | word = tuple(token) 104 | pairs = get_pairs(word) 105 | 106 | if not pairs: 107 | return token 108 | 109 | while True: 110 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 111 | if bigram not in self.bpe_ranks: 112 | break 113 | first, second = bigram 114 | new_word = [] 115 | i = 0 116 | while i < len(word): 117 | try: 118 | j = word.index(first, i) 119 | new_word.extend(word[i:j]) 120 | i = j 121 | except: 122 | new_word.extend(word[i:]) 123 | break 124 | 125 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 126 | new_word.append(first+second) 127 | i += 2 128 | else: 129 | new_word.append(word[i]) 130 | i += 1 131 | new_word = tuple(new_word) 132 | word = new_word 133 | if len(word) == 1: 134 | break 135 | else: 136 | pairs = get_pairs(word) 137 | word = ' '.join(word) 138 | self.cache[token] = word 139 | return word 140 | 141 | def _tokenize(self, text): 142 | """ Tokenize a string. """ 143 | bpe_tokens = [] 144 | for token in re.findall(self.pat, text): 145 | if sys.version_info[0] == 2: 146 | token = ''.join(self.byte_encoder[ord(b)] for b in token) 147 | else: 148 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 149 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 150 | return bpe_tokens 151 | 152 | def _convert_token_to_id(self, token): 153 | """ Converts a token (str/unicode) in an id using the vocab. 
""" 154 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 155 | 156 | def _convert_id_to_token(self, index): 157 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 158 | return self.decoder.get(index) 159 | 160 | def convert_tokens_to_string(self, tokens): 161 | """ Converts a sequence of tokens (string) in a single string. """ 162 | text = ''.join(tokens) 163 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 164 | return text 165 | 166 | def add_special_tokens_single_sentence(self, token_ids): 167 | """ 168 | Adds special tokens to a sequence for sequence classification tasks. 169 | A RoBERTa sequence has the following format: <s> X </s> 170 | """ 171 | return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)] 172 | 173 | def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): 174 | """ 175 | Adds special tokens to a sequence pair for sequence classification tasks. 176 | A RoBERTa sequence pair has the following format: <s> A </s></s> B </s> 177 | """ 178 | sep = [self._convert_token_to_id(self.sep_token)] 179 | cls = [self._convert_token_to_id(self.cls_token)] 180 | return cls + token_ids_0 + sep + sep + token_ids_1 + sep 181 | 182 | def save_vocabulary(self, save_directory): 183 | """Save the tokenizer vocabulary and merge files to a directory.""" 184 | if not os.path.isdir(save_directory): 185 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 186 | return 187 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 188 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) 189 | 190 | with open(vocab_file, 'w', encoding='utf-8') as f: 191 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 192 | 193 | index = 0 194 | with open(merge_file, "w", encoding="utf-8") as writer: 195 | writer.write(u'#version: 0.2\n') 196 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 197 | if index != token_index: 198 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 199 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 200 | index = token_index 201 | writer.write(' '.join(bpe_tokens) + u'\n') 202 | index += 1 203 | 204 | return vocab_file, merge_file 205 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_xlm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | from io import open 24 | 25 | from .tokenization_utils import PreTrainedTokenizer 26 | from .tokenization_bert import BasicTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | VOCAB_FILES_NAMES = { 31 | 'vocab_file': 'vocab.json', 32 | 'merges_file': 'merges.txt', 33 | } 34 | 35 | PRETRAINED_VOCAB_FILES_MAP = { 36 | 'vocab_file': 37 | { 38 | 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-vocab.json", 39 | 'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-vocab.json", 40 | 'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-vocab.json", 41 | 'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-vocab.json", 42 | 'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-vocab.json", 43 | 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-vocab.json", 44 | 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-vocab.json", 45 | 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-vocab.json", 46 | }, 47 | 'merges_file': 48 | { 49 | 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-merges.txt", 50 | 'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt", 51 | 'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt", 52 | 'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-merges.txt", 53 | 'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-merges.txt", 54 | 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-merges.txt", 55 | 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-merges.txt", 56 | 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-merges.txt", 57 | }, 58 | } 59 | 60 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 61 | 'xlm-mlm-en-2048': 512, 62 | 'xlm-mlm-ende-1024': 512, 63 | 'xlm-mlm-enfr-1024': 512, 64 | 'xlm-mlm-enro-1024': 512, 65 | 'xlm-mlm-tlm-xnli15-1024': 512, 66 | 'xlm-mlm-xnli15-1024': 512, 67 | 'xlm-clm-enfr-1024': 512, 68 | 'xlm-clm-ende-1024': 512, 69 | } 70 | 71 | def get_pairs(word): 72 | """ 73 | Return set of symbol pairs in a word. 
74 | word is represented as tuple of symbols (symbols being variable-length strings) 75 | """ 76 | pairs = set() 77 | prev_char = word[0] 78 | for char in word[1:]: 79 | pairs.add((prev_char, char)) 80 | prev_char = char 81 | return pairs 82 | 83 | def text_standardize(text): 84 | """ 85 | fixes some issues the spacy tokenizer had on books corpus 86 | also does some whitespace standardization 87 | """ 88 | text = text.replace('—', '-') 89 | text = text.replace('–', '-') 90 | text = text.replace('―', '-') 91 | text = text.replace('…', '...') 92 | text = text.replace('´', "'") 93 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 94 | text = re.sub(r'\s*\n\s*', ' \n ', text) 95 | text = re.sub(r'[^\S\n]+', ' ', text) 96 | return text.strip() 97 | 98 | class XLMTokenizer(PreTrainedTokenizer): 99 | """ 100 | BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities: 101 | 102 | - lower case all inputs 103 | 104 | - uses `SpaCy tokenizer <https://spacy.io/api/tokenizer/>`_ and \ 105 | `ftfy <https://ftfy.readthedocs.io/en/latest/>`_ for pre-BPE tokenization if they are installed, \ 106 | fallback to BERT's BasicTokenizer if not. 107 | 108 | - argument ``special_tokens`` and function ``set_special_tokens``, can be used to add additional symbols \ 109 | (ex: "__classify__") to a vocabulary. 110 | """ 111 | vocab_files_names = VOCAB_FILES_NAMES 112 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 113 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 114 | 115 | def __init__(self, vocab_file, merges_file, unk_token="<unk>", bos_token="<s>", 116 | sep_token="</s>", pad_token="<pad>", cls_token="</s>", 117 | mask_token="<special1>", additional_special_tokens=["<special0>", 118 | "<special1>", "<special2>", "<special3>", "<special4>", "<special5>", 119 | "<special6>", "<special7>", "<special8>", "<special9>"], **kwargs): 120 | super(XLMTokenizer, self).__init__(unk_token=unk_token, bos_token=bos_token, 121 | sep_token=sep_token, pad_token=pad_token, 122 | cls_token=cls_token, mask_token=mask_token, 123 | additional_special_tokens=additional_special_tokens, 124 | **kwargs) 125 | 126 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 127 | self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens 128 | 129 | try: 130 | import ftfy 131 | from spacy.lang.en import English 132 | _nlp = English() 133 | self.nlp = _nlp.Defaults.create_tokenizer(_nlp) 134 | self.fix_text = ftfy.fix_text 135 | except ImportError: 136 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") 137 | self.nlp = BasicTokenizer(do_lower_case=True) 138 | self.fix_text = None 139 | 140 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 141 | self.decoder = {v:k for k,v in self.encoder.items()} 142 | merges = open(merges_file, encoding='utf-8').read().split('\n')[:-1] 143 | merges = [tuple(merge.split()[:2]) for merge in merges] 144 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 145 | self.cache = {} 146 | 147 | @property 148 | def vocab_size(self): 149 | return len(self.encoder) 150 | 151 | def bpe(self, token): 152 | word = tuple(token[:-1]) + (token[-1] + '</w>',) 153 | if token in self.cache: 154 | return self.cache[token] 155 | pairs = get_pairs(word) 156 | 157 | if not pairs: 158 | return token+'</w>' 159 | 160 | while True: 161 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 162 | if 
bigram not in self.bpe_ranks: 163 | break 164 | first, second = bigram 165 | new_word = [] 166 | i = 0 167 | while i < len(word): 168 | try: 169 | j = word.index(first, i) 170 | new_word.extend(word[i:j]) 171 | i = j 172 | except: 173 | new_word.extend(word[i:]) 174 | break 175 | 176 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 177 | new_word.append(first+second) 178 | i += 2 179 | else: 180 | new_word.append(word[i]) 181 | i += 1 182 | new_word = tuple(new_word) 183 | word = new_word 184 | if len(word) == 1: 185 | break 186 | else: 187 | pairs = get_pairs(word) 188 | word = ' '.join(word) 189 | if word == '\n </w>': 190 | word = '\n</w>' 191 | self.cache[token] = word 192 | return word 193 | 194 | def _tokenize(self, text): 195 | """ Tokenize a string. """ 196 | split_tokens = [] 197 | if self.fix_text is None: 198 | # Using BERT's BasicTokenizer 199 | text = self.nlp.tokenize(text) 200 | for token in text: 201 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 202 | else: 203 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 204 | text = self.nlp(text_standardize(self.fix_text(text))) 205 | for token in text: 206 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 207 | return split_tokens 208 | 209 | def _convert_token_to_id(self, token): 210 | """ Converts a token (str/unicode) in an id using the vocab. """ 211 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 212 | 213 | def _convert_id_to_token(self, index): 214 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 215 | return self.decoder.get(index, self.unk_token) 216 | 217 | def convert_tokens_to_string(self, tokens): 218 | """ Converts a sequence of tokens (string) in a single string. """ 219 | out_string = ''.join(tokens).replace('</w>', ' ').strip() 220 | return out_string 221 | 222 | def add_special_tokens_single_sentence(self, token_ids): 223 | """ 224 | Adds special tokens to a sequence for sequence classification tasks. 225 | An XLM sequence has the following format: [CLS] X [SEP] 226 | """ 227 | return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)] 228 | 229 | def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): 230 | """ 231 | Adds special tokens to a sequence pair for sequence classification tasks. 232 | An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP] 233 | """ 234 | sep = [self._convert_token_to_id(self.sep_token)] 235 | cls = [self._convert_token_to_id(self.cls_token)] 236 | return cls + token_ids_0 + sep + token_ids_1 + sep 237 | 238 | def save_vocabulary(self, save_directory): 239 | """Save the tokenizer vocabulary and merge files to a directory.""" 240 | if not os.path.isdir(save_directory): 241 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 242 | return 243 | vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 244 | merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) 245 | 246 | with open(vocab_file, 'w', encoding='utf-8') as f: 247 | f.write(json.dumps(self.encoder, ensure_ascii=False)) 248 | 249 | index = 0 250 | with open(merge_file, "w", encoding="utf-8") as writer: 251 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 252 | if index != token_index: 253 | logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." 
254 | " Please check that the tokenizer is not corrupted!".format(merge_file)) 255 | index = token_index 256 | writer.write(' '.join(bpe_tokens) + u'\n') 257 | index += 1 258 | 259 | return vocab_file, merge_file 260 | -------------------------------------------------------------------------------- /pytorch_transformers/tokenization_xlnet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Tokenization classes for XLNet model.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import logging 20 | import os 21 | from shutil import copyfile 22 | 23 | import unicodedata 24 | import six 25 | 26 | from .tokenization_utils import PreTrainedTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'} 31 | 32 | PRETRAINED_VOCAB_FILES_MAP = { 33 | 'vocab_file': 34 | { 35 | 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model", 36 | 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model", 37 | } 38 | } 39 | 40 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 41 | 'xlnet-base-cased': None, 42 | 'xlnet-large-cased': None, 43 | } 44 | 45 | SPIECE_UNDERLINE = u'▁' 46 | 47 | # Segments (not really needed) 48 | SEG_ID_A = 0 49 | SEG_ID_B = 1 50 | SEG_ID_CLS = 2 51 | SEG_ID_SEP = 3 52 | SEG_ID_PAD = 4 53 | 54 | class XLNetTokenizer(PreTrainedTokenizer): 55 | """ 56 | SentencePiece based tokenizer. 
Peculiarities: 57 | 58 | - requires `SentencePiece <https://github.com/google/sentencepiece>`_ 59 | """ 60 | vocab_files_names = VOCAB_FILES_NAMES 61 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 62 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 63 | 64 | def __init__(self, vocab_file, max_len=None, 65 | do_lower_case=False, remove_space=True, keep_accents=False, 66 | bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>", 67 | pad_token="<pad>", cls_token="<cls>", mask_token="<mask>", 68 | additional_special_tokens=["<eop>", "<eod>"], **kwargs): 69 | super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, 70 | unk_token=unk_token, sep_token=sep_token, 71 | pad_token=pad_token, cls_token=cls_token, 72 | mask_token=mask_token, additional_special_tokens= 73 | additional_special_tokens, **kwargs) 74 | 75 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 76 | self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens 77 | 78 | try: 79 | import sentencepiece as spm 80 | except ImportError: 81 | logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece" 82 | "pip install sentencepiece") 83 | 84 | self.do_lower_case = do_lower_case 85 | self.remove_space = remove_space 86 | self.keep_accents = keep_accents 87 | self.vocab_file = vocab_file 88 | 89 | self.sp_model = spm.SentencePieceProcessor() 90 | self.sp_model.Load(vocab_file) 91 | 92 | @property 93 | def vocab_size(self): 94 | return len(self.sp_model) 95 | 96 | def __getstate__(self): 97 | state = self.__dict__.copy() 98 | state["sp_model"] = None 99 | return state 100 | 101 | def __setstate__(self, d): 102 | self.__dict__ = d 103 | try: 104 | import sentencepiece as spm 105 | except ImportError: 106 | logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece" 107 | "pip install sentencepiece") 108 | self.sp_model = spm.SentencePieceProcessor() 109 | self.sp_model.Load(self.vocab_file) 110 | 111 | def preprocess_text(self, inputs): 112 | if self.remove_space: 113 | outputs = ' '.join(inputs.strip().split()) 114 | else: 115 | outputs = inputs 116 | outputs = outputs.replace("``", '"').replace("''", '"') 117 | 118 | if six.PY2 and isinstance(outputs, str): 119 | outputs = outputs.decode('utf-8') 120 | 121 | if not self.keep_accents: 122 | outputs = unicodedata.normalize('NFKD', outputs) 123 | outputs = ''.join([c for c in outputs if not unicodedata.combining(c)]) 124 | if self.do_lower_case: 125 | outputs = outputs.lower() 126 | 127 | return outputs 128 | 129 | def _tokenize(self, text, return_unicode=True, sample=False): 130 | """ Tokenize a string. 
131 | return_unicode is used only for py2 132 | """ 133 | text = self.preprocess_text(text) 134 | # note(zhiliny): in some systems, sentencepiece only accepts str for py2 135 | if six.PY2 and isinstance(text, unicode): 136 | text = text.encode('utf-8') 137 | 138 | if not sample: 139 | pieces = self.sp_model.EncodeAsPieces(text) 140 | else: 141 | pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1) 142 | new_pieces = [] 143 | for piece in pieces: 144 | if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit(): 145 | cur_pieces = self.sp_model.EncodeAsPieces( 146 | piece[:-1].replace(SPIECE_UNDERLINE, '')) 147 | if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: 148 | if len(cur_pieces[0]) == 1: 149 | cur_pieces = cur_pieces[1:] 150 | else: 151 | cur_pieces[0] = cur_pieces[0][1:] 152 | cur_pieces.append(piece[-1]) 153 | new_pieces.extend(cur_pieces) 154 | else: 155 | new_pieces.append(piece) 156 | 157 | # note(zhiliny): convert back to unicode for py2 158 | if six.PY2 and return_unicode: 159 | ret_pieces = [] 160 | for piece in new_pieces: 161 | if isinstance(piece, str): 162 | piece = piece.decode('utf-8') 163 | ret_pieces.append(piece) 164 | new_pieces = ret_pieces 165 | 166 | return new_pieces 167 | 168 | def _convert_token_to_id(self, token): 169 | """ Converts a token (str/unicode) in an id using the vocab. """ 170 | return self.sp_model.PieceToId(token) 171 | 172 | def _convert_id_to_token(self, index, return_unicode=True): 173 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 174 | token = self.sp_model.IdToPiece(index) 175 | if six.PY2 and return_unicode and isinstance(token, str): 176 | token = token.decode('utf-8') 177 | return token 178 | 179 | def convert_tokens_to_string(self, tokens): 180 | """Converts a sequence of tokens (strings for sub-words) in a single string.""" 181 | out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip() 182 | return out_string 183 | 184 | def add_special_tokens_single_sentence(self, token_ids): 185 | """ 186 | Adds special tokens to a sequence pair for sequence classification tasks. 187 | An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS] 188 | """ 189 | sep = [self._convert_token_to_id(self.sep_token)] 190 | cls = [self._convert_token_to_id(self.cls_token)] 191 | return token_ids + sep + cls 192 | 193 | def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1): 194 | """ 195 | Adds special tokens to a sequence for sequence classification tasks. 196 | An XLNet sequence has the following format: X [SEP][CLS] 197 | """ 198 | sep = [self._convert_token_to_id(self.sep_token)] 199 | cls = [self._convert_token_to_id(self.cls_token)] 200 | return token_ids_0 + sep + token_ids_1 + sep + cls 201 | 202 | def save_vocabulary(self, save_directory): 203 | """ Save the sentencepiece vocabulary (copy original file) and special tokens file 204 | to a directory. 
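The original ``spiece.model`` file is copied into ``save_directory`` (unless it already lives there), and a one-element tuple with the resulting path is returned; ``None`` is returned if ``save_directory`` is not a directory.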
205 | """ 206 | if not os.path.isdir(save_directory): 207 | logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) 208 | return 209 | out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) 210 | 211 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): 212 | copyfile(self.vocab_file, out_vocab_file) 213 | 214 | return (out_vocab_file,) 215 | -------------------------------------------------------------------------------- /rank_bm25.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import math 4 | import numpy as np 5 | from multiprocessing import Pool, cpu_count 6 | 7 | """ 8 | All of these algorithms have been taken from the paper: 9 | Trotmam et al, Improvements to BM25 and Language Models Examined 10 | 11 | Here we implement all the BM25 variations mentioned. 12 | """ 13 | 14 | 15 | class BM25: 16 | def __init__(self, corpus, tokenizer=None): 17 | self.corpus_size = len(corpus) 18 | self.avgdl = 0 19 | self.doc_freqs = [] 20 | self.idf = {} 21 | self.doc_len = [] 22 | self.tokenizer = tokenizer 23 | 24 | if tokenizer: 25 | corpus = self._tokenize_corpus(corpus) 26 | 27 | nd = self._initialize(corpus) 28 | self._calc_idf(nd) 29 | 30 | def _initialize(self, corpus): 31 | nd = {} # word -> number of documents with word 32 | num_doc = 0 33 | for document in corpus: 34 | self.doc_len.append(len(document)) 35 | num_doc += len(document) 36 | 37 | frequencies = {} 38 | for word in document: 39 | if word not in frequencies: 40 | frequencies[word] = 0 41 | frequencies[word] += 1 42 | self.doc_freqs.append(frequencies) 43 | 44 | for word, freq in frequencies.items(): 45 | if word not in nd: 46 | nd[word] = 0 47 | nd[word] += 1 48 | 49 | self.avgdl = num_doc / self.corpus_size 50 | return nd 51 | 52 | def _tokenize_corpus(self, corpus): 53 | pool = Pool(cpu_count()) 54 | tokenized_corpus = pool.map(self.tokenizer, corpus) 55 | return tokenized_corpus 56 | 57 | def _calc_idf(self, nd): 58 | raise NotImplementedError() 59 | 60 | def get_scores(self, query): 61 | raise NotImplementedError() 62 | 63 | def get_top_n(self, query, documents, n=5): 64 | 65 | assert self.corpus_size == len(documents), "The documents given don't match the index corpus!" 66 | 67 | scores = self.get_scores(query) 68 | top_n = np.argsort(scores)[::-1][:n] 69 | return [documents[i] for i in top_n] 70 | 71 | 72 | class BM25Okapi(BM25): 73 | def __init__(self, corpus, tokenizer=None, k1=1.5, b=0.75, epsilon=0.25): 74 | self.k1 = k1 75 | self.b = b 76 | self.epsilon = epsilon 77 | super().__init__(corpus, tokenizer) 78 | 79 | def _calc_idf(self, nd): 80 | """ 81 | Calculates frequencies of terms in documents and in corpus. 82 | This algorithm sets a floor on the idf values to eps * average_idf 83 | """ 84 | # collect idf sum to calculate an average idf for epsilon value 85 | idf_sum = 0 86 | # collect words with negative idf to set them a special epsilon value. 
87 | # idf can be negative if word is contained in more than half of documents 88 | negative_idfs = [] 89 | for word, freq in nd.items(): 90 | idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5) 91 | self.idf[word] = idf 92 | idf_sum += idf 93 | if idf < 0: 94 | negative_idfs.append(word) 95 | self.average_idf = idf_sum / len(self.idf) 96 | 97 | eps = self.epsilon * self.average_idf 98 | for word in negative_idfs: 99 | self.idf[word] = eps 100 | 101 | def get_scores(self, query): 102 | """ 103 | The ATIRE BM25 variant uses an idf function which uses a log(idf) score. To prevent negative idf scores, 104 | this algorithm also adds a floor to the idf value of epsilon. 105 | See [Trotman, A., X. Jia, M. Crane, Towards an Efficient and Effective Search Engine] for more info 106 | :param query: 107 | :return: 108 | """ 109 | score = np.zeros(self.corpus_size) 110 | doc_len = np.array(self.doc_len) 111 | for q in query: 112 | q_freq = np.array([(doc.get(q) or 0) for doc in self.doc_freqs]) 113 | score += (self.idf.get(q) or 0) * (q_freq * (self.k1 + 1) / 114 | (q_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl))) 115 | return score 116 | 117 | 118 | class BM25L(BM25): 119 | def __init__(self, corpus, tokenizer=None, k1=1.5, b=0.75, delta=0.5): 120 | # Algorithm specific parameters 121 | self.k1 = k1 122 | self.b = b 123 | self.delta = delta 124 | super().__init__(corpus, tokenizer) 125 | 126 | def _calc_idf(self, nd): 127 | for word, freq in nd.items(): 128 | idf = math.log(self.corpus_size + 1) - math.log(freq + 0.5) 129 | self.idf[word] = idf 130 | 131 | def get_scores(self, query): 132 | score = np.zeros(self.corpus_size) 133 | doc_len = np.array(self.doc_len) 134 | for q in query: 135 | q_freq = np.array([(doc.get(q) or 0) for doc in self.doc_freqs]) 136 | ctd = q_freq / (1 - self.b + self.b * doc_len / self.avgdl) 137 | score += (self.idf.get(q) or 0) * q_freq * (self.k1 + 1) * (ctd + self.delta) / \ 138 | (self.k1 + ctd + self.delta) 139 | return score 140 | 141 | 142 | class BM25Plus(BM25): 143 | def __init__(self, corpus, tokenizer=None, k1=1.5, b=0.75, delta=1): 144 | # Algorithm specific parameters 145 | self.k1 = k1 146 | self.b = b 147 | self.delta = delta 148 | super().__init__(corpus, tokenizer) 149 | 150 | def _calc_idf(self, nd): 151 | for word, freq in nd.items(): 152 | idf = math.log((self.corpus_size + 1) / freq) 153 | self.idf[word] = idf 154 | 155 | def get_scores(self, query): 156 | score = np.zeros(self.corpus_size) 157 | doc_len = np.array(self.doc_len) 158 | for q in query: 159 | q_freq = np.array([(doc.get(q) or 0) for doc in self.doc_freqs]) 160 | score += (self.idf.get(q) or 0) * (self.delta + (q_freq * (self.k1 + 1)) / 161 | (self.k1 * (1 - self.b + self.b * doc_len / self.avgdl) + q_freq)) 162 | return score 163 | 164 | 165 | # BM25Adpt and BM25T are a bit more complicated than the previous algorithms here. 
Here a term-specific k1 166 | # parameter is calculated before scoring is done 167 | 168 | # class BM25Adpt(BM25): 169 | # def __init__(self, corpus, k1=1.5, b=0.75, delta=1): 170 | # # Algorithm specific parameters 171 | # self.k1 = k1 172 | # self.b = b 173 | # self.delta = delta 174 | # super().__init__(corpus) 175 | # 176 | # def _calc_idf(self, nd): 177 | # for word, freq in nd.items(): 178 | # idf = math.log((self.corpus_size + 1) / freq) 179 | # self.idf[word] = idf 180 | # 181 | # def get_scores(self, query): 182 | # score = np.zeros(self.corpus_size) 183 | # doc_len = np.array(self.doc_len) 184 | # for q in query: 185 | # q_freq = np.array([(doc.get(q) or 0) for doc in self.doc_freqs]) 186 | # score += (self.idf.get(q) or 0) * (self.delta + (q_freq * (self.k1 + 1)) / 187 | # (self.k1 * (1 - self.b + self.b * doc_len / self.avgdl) + q_freq)) 188 | # return score 189 | # 190 | # 191 | # class BM25T(BM25): 192 | # def __init__(self, corpus, k1=1.5, b=0.75, delta=1): 193 | # # Algorithm specific parameters 194 | # self.k1 = k1 195 | # self.b = b 196 | # self.delta = delta 197 | # super().__init__(corpus) 198 | # 199 | # def _calc_idf(self, nd): 200 | # for word, freq in nd.items(): 201 | # idf = math.log((self.corpus_size + 1) / freq) 202 | # self.idf[word] = idf 203 | # 204 | # def get_scores(self, query): 205 | # score = np.zeros(self.corpus_size) 206 | # doc_len = np.array(self.doc_len) 207 | # for q in query: 208 | # q_freq = np.array([(doc.get(q) or 0) for doc in self.doc_freqs]) 209 | # score += (self.idf.get(q) or 0) * (self.delta + (q_freq * (self.k1 + 1)) / 210 | # (self.k1 * (1 - self.b + self.b * doc_len / self.avgdl) + q_freq)) 211 | # return score 212 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tagme 2 | scipy 3 | qwikidata 4 | 5 | -------------------------------------------------------------------------------- /retrieve_hybrid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import json 5 | import numpy as np 6 | 7 | from tqdm import tqdm 8 | from collections import Counter, defaultdict 9 | from retriever import Retriever 10 | from WikiData import MyWikiData 11 | 12 | class Data(object): 13 | 14 | def __init__(self, args, wikidata): 15 | self.data_path = "data/{}/{}-{}.qa.json".format(args.data, args.data, args.data_type) 16 | self.tagme_path = "data/{}/{}-{}.tagme.json".format(args.data, args.data, args.data_type) 17 | 18 | self.wikidata = wikidata 19 | self.retriever = Retriever(args) 20 | 21 | with open(self.data_path, 'r') as f: 22 | orig_data = json.load(f) 23 | 24 | with open(self.tagme_path, 'r') as f: 25 | tagme_data = json.load(f) 26 | 27 | assert len(orig_data)==len(tagme_data) 28 | print ("Loaded {} QA data".format(len(orig_data))) 29 | 30 | self.save_path = "data/{}/{}-{}.retrieved.json".format(args.data, args.data, args.data_type) 31 | 32 | #### data to save ### 33 | n_cross_relations = [] 34 | n_inner_relations = [] 35 | n_total_relations = [] 36 | data_to_save = [] 37 | 38 | N_TFIDF = 5 if args.data=="webquestions" else 10 39 | N_BM25 = 40 if args.data=="webquestions" else 80 40 | 41 | for i, (d, tags) in tqdm(enumerate(zip(orig_data, tagme_data))): 42 | if len(tags)>0: 43 | sorted_tags = sorted(tags, key=lambda x: -x['score']) if 'score' in tags[0] else tags.copy() 44 | tags = [] 45 | for e in sorted_tags: 46 | # for some 
reason, tagme keeps tagging "The Who" for "who" questions. 47 | # we will exclude them. 48 | if not ((e['entity']=='The Who' and e['mention']=='who') or e["entity"]=="song"): 49 | if e['entity'] not in tags: 50 | tags.append(e['entity']) 51 | 52 | tfidf_docs = self.retriever.get_titles_from_query(d['question'], N_TFIDF) 53 | for t in tfidf_docs: 54 | if t not in tags: 55 | tags.append(t) 56 | keywords = self.wikidata.populate(tags, k=args.n_hops, use_aliases=False) 57 | collected_docs = set() 58 | collected_paragraphs = [] 59 | paragraphs_to_run_bm25 = [] 60 | for (doc_name, hop, relation) in keywords[:80]: 61 | if doc_name in collected_docs: 62 | continue 63 | collected_docs.add(doc_name) 64 | contents = self.retriever.get_contents_from_title(doc_name, 65 | n_words=self.retriever.get_n_words(d['question'], doc_name), 66 | only_first=hop>0) 67 | if len(contents)==0: 68 | continue 69 | collected_paragraphs.append((contents[0], hop, relation)) 70 | assert hop==0 or len(contents)==1 71 | paragraphs_to_run_bm25 += [(content, relation) for content in contents[1:]] 72 | 73 | collected_paragraphs = [par for i, par in sorted(enumerate(collected_paragraphs), 74 | key=lambda x: (x[1][1], x[1][0][1], x[0]))] 75 | bm25_paragraphs = self.retriever.get_paragraphs_from_documents(d['question'], 76 | paragraphs_to_run_bm25, 77 | N_BM25, 78 | only_first=False, 79 | is_tuple=True) 80 | pars = [(par, rel) for par, hop, rel in collected_paragraphs if hop==0] 81 | pars_1 = [(par, rel) for par, hop, rel in collected_paragraphs if hop==1] 82 | for p_i in range(len(bm25_paragraphs)): 83 | if len(pars_1)>p_i: 84 | pars.append(pars_1[p_i]) 85 | pars.append(bm25_paragraphs[p_i]) 86 | pars += self.retriever.get_paragraphs_from_documents(d['question'], 87 | pars_1[len(bm25_paragraphs):], 88 | 100, 89 | only_first=False, is_tuple=True) 90 | pars += self.retriever.get_paragraphs_from_documents(d['question'], 91 | [(par, rel) for par, hop, rel in collected_paragraphs if hop>1], 92 | 100, 93 | only_first=False, is_tuple=True) 94 | # truncate pars to be 100 at maximum 95 | pars = pars[:100] 96 | 97 | relations = [p[1] for p in pars] 98 | pars = [p[0] for p in pars] 99 | 100 | # get graph information for the GrpahReader 101 | collected_docs = set([par[0] for par in pars]) 102 | graph = self.wikidata.get_graph(collected_docs) 103 | constructed_graph = {} 104 | n_cross, n_inner = 0, 0 105 | for i1, (title1, index1, _) in enumerate(pars): 106 | for i2, (title2, index2, _) in enumerate(pars): 107 | if i1==i2: continue 108 | if (title1, title2) in graph and index1==index2==0: 109 | constructed_graph[(i1, i2)] = graph[(title1, title2)] 110 | n_cross += 1 111 | if title1==title2 and index1==0 and index2>0: 112 | constructed_graph[(i1, i2)] = ["<CHILD_PARAGRAPH>"] 113 | constructed_graph[(i2, i1)] = ["<PARENT_PARAGRAPH>"] 114 | n_inner += 2 115 | n_cross_relations.append(n_cross) 116 | n_inner_relations.append(n_inner) 117 | n_total_relations.append(n_cross+n_inner) 118 | data_to_save.append(json.dumps({ 119 | 'question': d['question'], 120 | 'answers': d['answers'], 121 | 'paragraphs': pars, 122 | 'graph': {'{} {}'.format(k[0], k[1]): v for k, v in constructed_graph.items()} 123 | })) 124 | 125 | print ("Cross", np.mean(n_cross_relations)) 126 | print ("Inner", np.mean(n_inner_relations)) 127 | print ("Total", np.mean(n_total_relations)) 128 | 129 | with open(self.save_path, 'w') as f: 130 | f.write("\n".join(data_to_save)) 131 | 132 | if __name__ == '__main__': 133 | parser = argparse.ArgumentParser() 134 | 
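# The flags below mirror the resources used above: --tfidf_path and --wiki_db_path point to the
# TF-IDF index and Wikipedia SQLite DB consumed by Retriever, --dump_path and --data_dir to the
# WikiData dump and the preprocessed WikiData directory used by MyWikiData, and --n_hops is the
# number of hops passed to wikidata.populate. Example invocation (paths are illustrative):
#   python3 retrieve_hybrid.py --data nq --data_type dev \
#       --tfidf_path /path/to/docs-tfidf.npz --wiki_db_path /path/to/docs.db \
#       --dump_path /path/to/wikidata-all.json.bz2 --data_dir /path/to/MyWikidata --n_hops 2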
parser.add_argument('--data', type=str, default="webquestions") 135 | parser.add_argument('--data_type', type=str, default="dev") 136 | parser.add_argument('--tfidf_path', type=str, 137 | default="/data/sewon/wikipedia/docs-tfidf.npz") 138 | parser.add_argument('--dump_path', type=str, default="/data/sewon/wikidata-20190708-all.json.bz2") 139 | parser.add_argument('--data_dir', type=str, default="/data/sewon/MyWikidata") 140 | parser.add_argument('--n_hops', type=int, default=2) 141 | parser.add_argument('--new', action="store_true") 142 | parser.add_argument('--wiki_db_path', type=str, 143 | default="/data/sewon/wikipedia/docs.db") 144 | #parser.add_argument('--vocab_file', type=str, 145 | # default="/data/home/sewon/bert_vocab.txt") 146 | args = parser.parse_args() 147 | wikidata = MyWikiData(args) 148 | data = Data(args, wikidata) 149 | 150 | 151 | -------------------------------------------------------------------------------- /retriever.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import numpy as np 4 | import argparse 5 | from tqdm import tqdm 6 | from collections import defaultdict, Counter 7 | 8 | import drqa_retriever as retriever 9 | from drqa_retriever import DocDB 10 | from rank_bm25 import BM25Okapi 11 | from Database import MyDatabase 12 | 13 | from pytorch_transformers import BertTokenizer, BasicTokenizer 14 | 15 | title_s = "<t>" 16 | title_e = "</t>" 17 | 18 | SEP1 = "<@@SEP@@>" 19 | SEP2 = "<##SEP##>" 20 | SEP3 = "<$$SEP$$>" 21 | 22 | 23 | class Retriever(object): 24 | 25 | def __init__(self, args, need_vocab=True): 26 | self.tfidf_path=args.tfidf_path 27 | self.ranker = retriever.get_class('tfidf')(tfidf_path=self.tfidf_path) 28 | self.first_para_only = False 29 | self.db = DocDB(args.wiki_db_path) 30 | self.L = 300 31 | self.first_para_only = False 32 | 33 | if need_vocab: 34 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 35 | btokenizer = BasicTokenizer() 36 | self.tokenize = lambda c, t_c: tokenizer.tokenize(c) 37 | self.btokenize = btokenizer.tokenize 38 | 39 | self.keyword2title = defaultdict(list) 40 | self.cache = {} 41 | 42 | def get_titles_from_query(self, query, n_docs): 43 | try: 44 | doc_names, doc_scores = self.ranker.closest_docs(query, n_docs) 45 | except Exception: 46 | return [] 47 | return doc_names 48 | 49 | def get_contents_from_title(self, doc_name, n_words, only_first): 50 | if doc_name in self.cache: 51 | contents = self.cache[doc_name] 52 | else: 53 | try: 54 | contents = self.db.get_doc_text(doc_name).split('\n\n') 55 | except Exception: 56 | return [] 57 | if contents[0]==doc_name: 58 | contents = contents[1:] 59 | contents = [c for c in contents if len(c.strip())>0] 60 | for i, c in enumerate(contents): 61 | t_c = self.btokenize(c) 62 | t_c2 = self.tokenize(c, t_c) 63 | contents[i] = "{}{}{}".format(SEP1.join(t_c), SEP2, SEP1.join(t_c2)) 64 | contents = SEP3.join(contents) 65 | self.cache[doc_name] = contents 66 | if len(contents)==0: 67 | return [] 68 | contents = [[ci.split(SEP1) for ci in c.split(SEP2)] for c in contents.split(SEP3)] 69 | return self.get_preprocessed_paragraphs(doc_name, contents.copy(), n_words=n_words, 70 | only_first=only_first) 71 | 72 | 73 | def get_preprocessed_paragraphs(self, doc_name, contents, n_words, only_first=False): 74 | curr_paragraphs = [] 75 | curr_lengths = [] 76 | for tokenized_par, tokenized_par2 in contents: 77 | l = len(tokenized_par2) 78 | if len(curr_lengths)>0 and l<=n_words-curr_lengths[-1]-3: 79 | 
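# the previous chunk still has room in the n_words budget (the extra 3 apparently
# reserves space for the "<p>" marker and special tokens), so merge this paragraph
# into the previous chunk rather than starting a new one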
curr_paragraphs[-1] += ["<p>"] 80 | offset = l-len(tokenized_par) 81 | assert offset>=0 82 | curr_paragraphs[-1] += tokenized_par.copy() 83 | curr_lengths[-1] += l if curr_lengths[-1]==0 else l+3 84 | else: 85 | if l>n_words: 86 | offset = n_words-len(tokenized_par2)+len(tokenized_par) 87 | if offset<=n_words/2.0: 88 | continue 89 | tokenized_par = tokenized_par[:offset].copy() 90 | curr_paragraphs.append(tokenized_par.copy()) 91 | curr_lengths.append(l) 92 | #assert curr_lengths[-1]<=n_words 93 | if only_first and len(curr_paragraphs)>1: 94 | curr_paragraphs = curr_paragraphs[:1] 95 | break 96 | tok_doc_name = self.btokenize(doc_name) 97 | return [[doc_name, i, [title_s] + tok_doc_name + [title_e] + t] 98 | for i, t in enumerate(curr_paragraphs)] 99 | 100 | def get_paragraphs_from_documents(self, query, _paragraphs, n_paragraphs, 101 | only_first=False, is_tuple=False): 102 | 103 | if len(_paragraphs)==0 or only_first: 104 | return _paragraphs 105 | 106 | if is_tuple: 107 | relations = [p[1] for p in _paragraphs] 108 | _paragraphs = [p[0] for p in _paragraphs] 109 | 110 | bm25 = BM25Okapi([p[2] for p in _paragraphs]) 111 | paragraphs = [] 112 | for index, score in sorted(enumerate(bm25.get_scores(self.btokenize(query)).tolist()), 113 | key=lambda x: (-x[1], x[0])): 114 | if score==0 or len(paragraphs)==n_paragraphs: 115 | break 116 | if is_tuple: 117 | paragraphs.append((_paragraphs[index], relations[index])) 118 | else: 119 | paragraphs.append(_paragraphs[index]) 120 | return paragraphs 121 | 122 | def get_n_words(self, query, doc_name): 123 | n_words = self.L - len(self.tokenize(query, self.btokenize(query))) - 7 - 12 124 | return 10*math.floor((n_words-len(self.tokenize(doc_name, self.btokenize(doc_name))))/10.0) 125 | 126 | def get_contents_from_query(self, query, n_docs, only_first=False): 127 | doc_names = self.get_titles_from_query(query, n_docs) 128 | return [self.get_contents_from_title(doc_name, 129 | n_words=self.get_n_words(query, doc_name), 130 | only_first=only_first) 131 | for doc_name in doc_names] 132 | 133 | def get_paragraphs_from_titles(self, query, doc_names, n_paragraphs, only_first, 134 | run_bm25=False): 135 | contents = [] 136 | for doc_name in doc_names: 137 | contents += self.get_contents_from_title(doc_name, 138 | n_words=self.get_n_words(query, doc_name), 139 | only_first=only_first) 140 | if len(contents)>=n_paragraphs and not run_bm25: 141 | break 142 | if not run_bm25: 143 | return contents[:n_paragraphs] 144 | paragraphs = self.get_paragraphs_from_documents(query, contents, n_paragraphs, 145 | only_first=only_first) 146 | return paragraphs #[:n_paragraphs] 147 | 148 | def get_paragraphs_from_query(self, query, n_docs, n_paragraphs, only_first=False): 149 | doc_names = self.get_titles_from_query(query, n_docs) 150 | return self.get_paragraphs_from_titles(query, doc_names, n_paragraphs, 151 | only_first=only_first, run_bm25=True) 152 | 153 | def get_paragraphs_from_keywords(self, query, keywords, n_paragraphs, only_first=True): 154 | doc_names = [] 155 | for keyword in keywords: 156 | if type(keyword)==tuple: 157 | keyword, _ = keyword 158 | assert keyword in self.keyword2title 159 | for doc_name in self.keyword2title[keyword]: 160 | if doc_name not in doc_names: 161 | doc_names.append(doc_name) 162 | return self.get_paragraphs_from_titles(query, doc_names, n_paragraphs, only_first=only_first) 163 | 164 | def get_keyword2title(self, keywords): 165 | keyword2title = defaultdict(list) 166 | for keyword in keywords: 167 | if type(keyword)==tuple: 168 | keyword, 
aliases = keyword 169 | else: 170 | aliases = [] 171 | if keyword in keyword2title: 172 | continue 173 | if keyword in self.keyword2title: 174 | keyword2title[keyword] = self.keyword2title[keyword] 175 | continue 176 | if self.db.get_doc_text(keyword) is not None: 177 | keyword2title[keyword].append(keyword) 178 | else: 179 | for t in aliases: 180 | if t!=keyword and self.db.get_doc_text(t) is not None: 181 | keyword2title[keyword].append(t) 182 | if len(keyword2title[keyword])==0: 183 | doc = self.get_titles_from_query(keyword, 1) 184 | if len(doc)>0 and doc[0]!=keyword and doc[0] not in aliases and self.db.get_doc_text(doc[0]) is not None: 185 | keyword2title[keyword].append(doc[0]) 186 | self.keyword2title.update(keyword2title) 187 | return keyword2title 188 | 189 | --------------------------------------------------------------------------------
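For reference, a minimal sketch of how the `Retriever` above can be driven on its own, outside `retrieve_hybrid.py`. The paths are illustrative: `tfidf_path` and `wiki_db_path` are the same TF-IDF index and Wikipedia DB that `retrieve_hybrid.py` expects, and constructing `Retriever` with `need_vocab=True` (the default) loads the `bert-base-uncased` vocabulary from the local cache or downloads it.

```
import argparse
from retriever import Retriever

# Illustrative paths -- substitute the TF-IDF index and Wikipedia DB built during preprocessing.
args = argparse.Namespace(tfidf_path="/path/to/docs-tfidf.npz",
                          wiki_db_path="/path/to/docs.db")
ret = Retriever(args)

# TF-IDF retrieves candidate titles; BM25 then ranks their paragraphs against the question.
paragraphs = ret.get_paragraphs_from_query(
    "who wrote the declaration of independence", n_docs=10, n_paragraphs=20)
for title, par_index, tokens in paragraphs[:3]:
    print(title, par_index, " ".join(tokens[:30]))
```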