├── README.md
├── convert_msmarco_to_duobert_tfrecord.py
├── duobert_architecture.svg
├── metrics.py
├── modeling.py
├── modeling_test.py
├── msmarco_eval.py
├── optimization.py
├── optimization_test.py
├── requirements.txt
├── run_duobert_msmarco.py
├── tokenization.py
└── tokenization_test.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# duoBERT

duoBERT is a pairwise ranking model based on BERT that is the last stage of a multi-stage retrieval pipeline:

![duobert](duobert_architecture.svg)

To train and re-rank with monoBERT, please check [this repository](https://github.com/nyu-dl/dl4marco-bert).

As of Jan 13th 2020, our MS MARCO leaderboard entry is the top-scoring model with available code:

MSMARCO Passage Re-Ranking Leaderboard (Jan 13th 2020) | Eval MRR@10 | Dev MRR@10
------------------------------------- | :------: | :------:
SOTA - Enriched BERT base + AOA index + CAS | 0.393 | 0.408
BM25 + monoBERT + duoBERT + TCP (this code) | 0.379 | 0.390

For more details, check out our paper:

+ Rodrigo Nogueira, Wei Yang, Kyunghyun Cho, and Jimmy Lin. [Multi-Stage Document Ranking with BERT.](https://arxiv.org/abs/1910.14424) _arXiv:1910.14424_, October 2019.

**NOTE!** The duoBERT model is no longer under active development and this repo is no longer being maintained.
We have shifted our efforts to [ranking with sequence-to-sequence models](https://www.aclweb.org/anthology/2020.findings-emnlp.63/).
A T5-based variant of the mono/duo design is described in [an overview of our submissions to the TREC-COVID challenge](https://www.aclweb.org/anthology/2020.sdp-1.5/), and a more detailed description of mono/duoT5 is in preparation.

## Data and Trained Models

We make the following data available for download:

+ `bert-large-msmarco-pretrained-only.zip`: monoBERT large pretrained on the MS MARCO corpus but not finetuned on the ranking task. We pretrained this model starting from the original BERT-large WWM (Whole Word Mask) checkpoint. It was pretrained for 100k iterations with a batch size of 128, a learning rate of 3e-6, and 10k warmup steps. We finetuned monoBERT and duoBERT from this checkpoint.
+ `monobert-large-msmarco-pretrained-and-finetuned.zip`: monoBERT large pretrained on the MS MARCO corpus and finetuned on the MS MARCO ranking task.
+ `duobert-large-msmarco-pretrained-and-finetuned.zip`: duoBERT large pretrained on the MS MARCO corpus and finetuned on the MS MARCO ranking task.
+ `run.bm25.dev.small.tsv`: Approximately 6,980,000 pairs of dev set queries and passages retrieved using BM25. In this tsv file, the first column is the query id, the second column is the passage id, and the third column is the rank of the passage. There are 1000 passages per query in this file.
+ `run.bm25.test.small.tsv`: Approximately 6,837,000 pairs of test set queries and passages retrieved using BM25.
+ `run.monobert.dev.small.tsv`: Approximately 6,980,000 pairs of dev set queries and passages retrieved using BM25 and re-ranked with monoBERT. In this tsv file, the first column is the query id, the second column is the passage id, and the third column is the rank of the passage. There are 1000 passages per query in this file.
+ `run.monobert.test.small.tsv`: Approximately 6,837,000 pairs of test set queries and passages retrieved using BM25 and re-ranked with monoBERT.
+ `run.duobert.dev.small.tsv`: Approximately 6,980 x 30 pairs of dev set queries and passages re-ranked using duoBERT. In this run, the input to duoBERT was the top-30 passages re-ranked by monoBERT.
+ `run.duobert.test.small.tsv`: Approximately 6,837 x 30 pairs of test set queries and passages re-ranked using duoBERT. In this run, the input to duoBERT was the top-30 passages re-ranked by monoBERT.
+ `dataset_train.tf`: Approximately 80M pairs of training set queries and passages (40M relevant and 40M non-relevant) in the TF Record format.
+ `dataset_dev.tf`: Approximately 6,980 x 30 pairs of dev set queries and passages in the TF Record format. These top-30 passages will be re-ranked by duoBERT.
+ `dataset_test.tf`: Approximately 6,837 x 30 pairs of test set queries and passages in the TF Record format. These top-30 passages will be re-ranked by duoBERT.
+ `query_doc_ids_dev.txt`: Approximately 6,980 x 30 pairs of query and doc ids that are used during inference.
+ `query_doc_ids_test.txt`: Approximately 6,837 x 30 pairs of query and doc ids that are used during inference.
+ `queries.dev.small.tsv`: 6,980 queries from the MS MARCO dev set. In this tsv file, the first column is the query id, and the second is the query text.
+ `queries.eval.small.tsv`: 6,837 queries from the MS MARCO test (eval) set. In this tsv file, the first column is the query id, and the second is the query text.
+ `qrels.dev.small.tsv`: 7,437 pairs of query and relevant passage ids from the MS MARCO dev set. In this tsv file, the first column is the query id, and the third column is the passage id. The other two columns (second and fourth) are not used.
+ `collection.tar.gz`: All passages (8,841,823) in the MS MARCO passage corpus. In this tsv file, the first column is the passage id, and the second is the passage text.
+ `triples.train.small.tar.gz`: Approximately 40M triples of query, relevant passage, and non-relevant passage that are used to train duoBERT.
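For reference, the duoBERT runs above are produced by scoring every ordered pair among the top-30 monoBERT candidates and then aggregating the pairwise probabilities into a single score per passage; the paper describes several aggregation methods, of which SUM is the simplest. Below is a minimal sketch of that final aggregation step only (names are hypothetical; the actual inference logic lives in `run_duobert_msmarco.py`):

```
# Sketch of the SUM aggregation from the paper: a candidate's score is the
# sum over all other candidates of P(candidate is more relevant).
from collections import defaultdict

def aggregate_sum(pairwise_probs):
    """pairwise_probs: dict mapping (doc_a, doc_b) -> P(doc_a beats doc_b)."""
    scores = defaultdict(float)
    for (doc_a, _doc_b), prob in pairwise_probs.items():
        scores[doc_a] += prob
    # Higher aggregated score means higher final rank.
    return sorted(scores, key=scores.get, reverse=True)

ranking = aggregate_sum({('d1', 'd2'): 0.9, ('d2', 'd1'): 0.1,
                         ('d1', 'd3'): 0.6, ('d3', 'd1'): 0.4,
                         ('d2', 'd3'): 0.2, ('d3', 'd2'): 0.8})
print(ranking)  # ['d1', 'd3', 'd2']
```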
Download the files above and verify them against their MD5 checksums using the table below:

File | Size | MD5 | Download
:----|-----:|:----|:-----
`bert-large-msmarco-pretrained-only.zip` | 3.44 GB | `88f1d0bd351058b1da1eb49b60c2e750` | [[Dropbox](https://www.dropbox.com/s/nvqs8qk7q63qr0s/bert-large-msmarco-pretrained-only.zip?dl=1)]
`monobert-large-msmarco-pretrained-and-finetuned.zip` | 3.42 GB | `db201b6433b3e605201746bda6b7723b` | [[Dropbox](https://www.dropbox.com/s/fhy7vf5488muz9u/monobert-large-msmarco-pretrained-and-finetuned.zip?dl=1)]
`duobert-large-msmarco-pretrained-and-finetuned.zip` | 3.43 GB | `dcae7441103ae8241f16df743b75337b` | [[Dropbox](https://www.dropbox.com/s/kxd8fitk4ax1hb5/duobert-large-msmarco-pretrained-and-finetuned.zip?dl=1)]
`run.bm25.dev.small.tsv.gz` | 44 MB | `0a7802ab41999161339087186dda4145` | [[Dropbox](https://www.dropbox.com/s/5pqpcnlzlib2b3a/run.bm25.dev.small.tsv.gz?dl=1)]
`run.bm25.test.small.tsv.gz` | 43 MB | `1ea465405f6a2467cb62015454bc88c7` | [[Dropbox](https://www.dropbox.com/s/6fzxajh79dkw8s1/run.bm25.test.small.tsv.gz?dl=1)]
`run.monobert.dev.small.tsv.gz` | 44 MB | `dee6065e7177facb7c740f607e40ac63` | [[Dropbox](https://www.dropbox.com/s/h5kiff0ofn3djvf/run.monobert.dev.small.tsv.gz?dl=1)]
`run.monobert.test.small.tsv.gz` | 43 MB | `f0e16234351a0a81d83f188e72662fbd` | [[Dropbox](https://www.dropbox.com/s/ctccble07k7lvlc/run.monobert.test.small.tsv.gz?dl=1)]
`run.duobert.dev.small.tsv.gz` | 2.0 MB | `0be1f12ab7c7bd2d913d31756a8f0a19` | [[Dropbox](https://www.dropbox.com/s/fffu74voideid5p/run.duobert.dev.small.tsv.gz?dl=1)]
`run.duobert.test.small.tsv.gz` | 2.0 MB | `0d4f1770f8be20411ed8c00fb727103d` | [[Dropbox](https://www.dropbox.com/s/93bj0ehhse3fbuv/run.duobert.test.small.tsv.gz?dl=1)]
`dataset_train.tf.gz` | 8.8 GB | `7a3a6705f3662837a1e874d7ed970d27` | [[Dropbox](https://www.dropbox.com/s/zi46r0905d2y908/dataset_train.tf.gz?dl=1)]
`dataset_dev.tf.gz` | 241 MB | `f4966bd5426092564a59c1a1c8e34539` | [[Dropbox](https://www.dropbox.com/s/yykiop01sto1fzf/dataset_dev.tf.gz?dl=1)]
`dataset_test.tf.gz` | 236 MB | `5387a926950b112616926fe3d475a22f` | [[Dropbox](https://www.dropbox.com/s/qx97yhq34ndtc7p/dataset_test.tf.gz?dl=1)]
`query_doc_ids_dev.txt.gz` | 19 MB | `05361aead605c1b8a8cc8d71ef3ff0f8` | [[Dropbox](https://www.dropbox.com/s/ttml8v0irfsmqcv/query_doc_ids_dev.txt.gz?dl=1)]
`query_doc_ids_test.txt.gz` | 19 MB | `5e657dff1e1f0748d29b291e5c731f9f` | [[Dropbox](https://www.dropbox.com/s/jvtf3qa8ux3wma8/query_doc_ids_test.txt.gz?dl=1)]
`queries.dev.small.tsv` | 283 KB | `41e980d881317a4a323129d482e9f5e5` | [[Dropbox](https://www.dropbox.com/s/iyw98nof7omynst/queries.dev.small.tsv?dl=1)]
`queries.eval.small.tsv` | 274 KB | `bafaf0b9eb23503d2a5948709f34fc3a` | [[Dropbox](https://www.dropbox.com/s/yst2tz1s9i2z5mx/queries.eval.small.tsv?dl=1)]
`qrels.dev.small.tsv` | 140 KB | `38a80559a561707ac2ec0f150ecd1e8a` | [[Dropbox](https://www.dropbox.com/s/ie27l0mzcjb5fbc/qrels.dev.small.tsv?dl=1)]
`collection.tar.gz` | 987 MB | `87dd01826da3e2ad45447ba5af577628` | [[Dropbox](https://www.dropbox.com/s/m1n2wf80l1lb9j1/collection.tar.gz?dl=1)]
`triples.train.small.tar.gz` | 7.4 GB | `c13bf99ff23ca691105ad12eab837f84` | [[Dropbox](https://www.dropbox.com/s/6r4a8hpcgq0szep/triples.train.small.tar.gz?dl=1)]

All of the above files are stored in [this repo](https://git.uwaterloo.ca/jimmylin/duobert-data).
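If you download files individually, you can check them against the MD5 column; a minimal sketch, using `qrels.dev.small.tsv` and its checksum from the table as an example:

```
# Sketch: compute a local file's MD5 in chunks and compare it to the table.
import hashlib

def md5sum(path, chunk_size=1 << 20):
    h = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

assert md5sum('qrels.dev.small.tsv') == '38a80559a561707ac2ec0f150ecd1e8a'
```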
As an alternative to downloading each file separately, clone the repo and you'll have everything.

## Replicating our MS MARCO results with duoBERT
Here we provide instructions on how to replicate our BM25 + monoBERT + duoBERT + TCP dev run on the MS MARCO leaderboard.

NOTE 1: We will run these experiments using a TPU; thus, you will need a Google Cloud account. Alternatively, you can use a GPU, but we haven't tried that ourselves.

NOTE 2: For instructions on how to train and run inference using monoBERT, please check this [repository](https://github.com/nyu-dl/dl4marco-bert).

First download the following files (using the links in the table above):
- `qrels.dev.small.tsv`
- `dataset_dev.tf`
- `duobert-large-msmarco-pretrained-and-finetuned.zip`

Unzip `duobert-large-msmarco-pretrained-and-finetuned.zip` and upload the files to a Google Cloud Storage bucket.

Create a virtual machine with a TPU in Google Cloud. We provide below a command-line example that should be executed in the Google Cloud Shell (change `your-tpu` accordingly):
```
ctpu up --zone=us-central1-b --name your-tpu --tpu-size=v3-8 --disk-size-gb=250 \
  --machine-type=n1-standard-4 --preemptible --tf-version=1.15 --noconf
```

SSH into the virtual machine and clone the git repo:
```
git clone https://github.com/castorini/duobert.git
```

Run duoBERT in evaluation mode (change `your-tpu` and `your-bucket` accordingly):
```
python run_duobert_msmarco.py \
  --data_dir=gs://your-bucket \
  --bert_config_file=gs://your-bucket/bert_config.json \
  --output_dir=. \
  --init_checkpoint=gs://your-bucket/model.ckpt-100000 \
  --max_seq_length=512 \
  --do_train=False \
  --do_eval=True \
  --eval_batch_size=128 \
  --num_eval_docs=30 \
  --use_tpu=True \
  --tpu_name=your-tpu \
  --tpu_zone=us-central1-b
```

This inference takes approximately 4 hours on a TPU v3.
Once finished, run the evaluation script:
```
python3 msmarco_eval.py qrels.dev.small.tsv ./msmarco_predictions_dev.tsv
```

The output should look like this:
```
#####################
MRR @10: 0.3904377586755809
QueriesRanked: 6980
#####################
```

## Training duoBERT
Here we provide instructions to train duoBERT. Note that a fully trained model is available in the table above.

First download the following files (using the links in the table above):
- `qrels.dev.small.tsv`
- `dataset_train.tf`
- `bert-large-msmarco-pretrained-only.zip`

Unzip `bert-large-msmarco-pretrained-only.zip` and upload all files to your Google Cloud Storage bucket.
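If you'd rather script the upload than use the console or `gsutil`, here is a minimal sketch using the `google-cloud-storage` Python client (not part of this repo; the bucket name and folder are placeholders):

```
# Sketch: upload every file from the unzipped folder to a GCS bucket root.
import os
from google.cloud import storage  # pip install google-cloud-storage

BUCKET = 'your-bucket'                         # placeholder
FOLDER = 'bert-large-msmarco-pretrained-only'  # placeholder

client = storage.Client()
bucket = client.bucket(BUCKET)
for filename in os.listdir(FOLDER):
    bucket.blob(filename).upload_from_filename(os.path.join(FOLDER, filename))
    print(f'Uploaded {filename} to gs://{BUCKET}/{filename}')
```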
Run duoBERT in training mode (change `your-tpu` and `your-bucket` accordingly):
```
python run_duobert_msmarco.py \
  --data_dir=gs://your-bucket \
  --bert_config_file=gs://your-bucket/bert_config.json \
  --output_dir=gs://your-bucket/output \
  --init_checkpoint=gs://your-bucket/model.ckpt-100000 \
  --max_seq_length=512 \
  --do_train=True \
  --do_eval=False \
  --learning_rate=3e-6 \
  --train_batch_size=128 \
  --num_train_steps=100000 \
  --num_warmup_steps=10000 \
  --use_tpu=True \
  --tpu_name=your-tpu \
  --tpu_zone=us-central1-b
```

This training should take approximately 30 hours on a TPU v3.


## Creating a TF Record dataset
Here we provide instructions to create the training, dev, and test TF Record files that are consumed by duoBERT. Note that these files are available in the table above.

Use the links from the table above to download the following files:
- `collection.tar.gz` (needs to be uncompressed)
- `triples.train.small.tar.gz` (needs to be uncompressed)
- `queries.dev.small.tsv`
- `queries.eval.small.tsv`
- `run.monobert.dev.small.tsv`
- `run.monobert.test.small.tsv`
- `qrels.dev.small.tsv`
- `vocab.txt` (available in `duobert-large-msmarco-pretrained-and-finetuned.zip`)

```
python convert_msmarco_to_duobert_tfrecord.py \
  --output_folder=. \
  --corpus=collection.tsv \
  --vocab_file=vocab.txt \
  --triples_train=triples.train.small.tsv \
  --queries_dev=queries.dev.small.tsv \
  --queries_test=queries.eval.small.tsv \
  --run_dev=run.monobert.dev.small.tsv \
  --run_test=run.monobert.test.small.tsv \
  --qrels_dev=qrels.dev.small.tsv \
  --num_dev_docs=30 \
  --num_test_docs=30 \
  --max_seq_length=512 \
  --max_query_length=64
```

This conversion takes approximately 30-50 hours and will produce the following files:
- `dataset_train.tf`
- `dataset_dev.tf`
- `dataset_test.tf`
- `query_doc_ids_dev.txt`
- `query_doc_ids_test.txt`
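To sanity-check a generated (or downloaded) TF Record file, you can decode a few examples with the TF 1.x record iterator; a minimal sketch, assuming `dataset_dev.tf` is in the current directory:

```
import tensorflow as tf  # TF 1.15, as used elsewhere in this repo

# Decode the first serialized example and print the features written by
# convert_msmarco_to_duobert_tfrecord.py (input_ids, segment_ids, label).
for record in tf.python_io.tf_record_iterator('dataset_dev.tf'):
    example = tf.train.Example.FromString(record)
    features = example.features.feature
    print('label:', features['label'].int64_list.value[0])
    print('input_ids length:', len(features['input_ids'].int64_list.value))
    print('segment_ids length:', len(features['segment_ids'].int64_list.value))
    break
```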
## How do I cite this work?
```
@article{nogueira2019multi,
  title={Multi-stage document ranking with BERT},
  author={Nogueira, Rodrigo and Yang, Wei and Cho, Kyunghyun and Lin, Jimmy},
  journal={arXiv preprint arXiv:1910.14424},
  year={2019}
}
```

--------------------------------------------------------------------------------
/convert_msmarco_to_duobert_tfrecord.py:
--------------------------------------------------------------------------------
"""Converts MS MARCO data into TF Records that will be consumed by duoBERT."""
import collections
import json
import os
import tensorflow as tf
import time
import tokenization

from tqdm import tqdm


flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string(
    'output_folder', None, 'Folder where the TFRecord files will be written.')

flags.DEFINE_string(
    'vocab_file', None,
    'The vocabulary file that the BERT model was trained on.')

flags.DEFINE_string(
    'triples_train', None,
    'TSV file containing query, relevant and non-relevant docs.')

flags.DEFINE_string(
    'corpus', None, 'Path to the tsv file containing the paragraphs.')

flags.DEFINE_string(
    'queries_dev', None, 'Path to the queries for dev.')

flags.DEFINE_string(
    'queries_test', None,
    'Path to the queries for test.')

flags.DEFINE_string(
    'run_dev', None, 'Path to the query id / candidate doc ids pairs for dev.')

flags.DEFINE_string(
    'run_test', None,
    'Path to the query id / candidate doc ids pairs for test.')

flags.DEFINE_string(
    'qrels_dev', None,
    'Path to the query id / relevant doc ids pairs for dev.')

flags.DEFINE_integer(
    'num_dev_docs', 1000,
    'The number of docs per query for the development set.')

flags.DEFINE_integer(
    'num_test_docs', 1000,
    'The number of docs per query for the test set.')

flags.DEFINE_integer(
    'max_seq_length', 512,
    'The maximum total input sequence length after WordPiece tokenization. '
    'Sequences longer than this will be truncated, and sequences shorter than '
    'this will be padded.')

flags.DEFINE_integer(
    'max_query_length', 64,
    'The maximum query sequence length after WordPiece tokenization. '
    'Sequences longer than this will be truncated.')

flags.DEFINE_string(
    'pad_doc_id', '5500000',
    'ID of the pad document. This pad document is added to the TF Records '
    'whenever the number of retrieved documents is lower than '
    'num_dev_docs/num_test_docs. '
    'This can be any valid doc id.')


def convert_train(tokenizer):
    """Convert triples train to a TF Record file."""
    start_time = time.time()

    print('Counting the number of training examples...')
    num_examples = sum(1 for _ in open(FLAGS.triples_train))

    print('Converting to tfrecord...')
    with tf.python_io.TFRecordWriter(
            FLAGS.output_folder + '/dataset_train.tf') as writer:
        for i, line in tqdm(enumerate(open(FLAGS.triples_train)),
                            total=num_examples):
            query, relevant_doc, non_relevant_doc = line.rstrip().split('\t')

            query = tokenization.convert_to_unicode(query)
            query_ids = tokenization.convert_to_bert_input(
                text=query,
                max_seq_length=FLAGS.max_query_length,
                tokenizer=tokenizer,
                add_cls=True)

            labels = [1, 0]

            if i % 1000 == 0:
                print(f'query: {query}')
                print(f'Relevant doc: {relevant_doc}')
                print(f'Non-Relevant doc: {non_relevant_doc}\n')

            doc_token_ids = [
                tokenization.convert_to_bert_input(
                    text=tokenization.convert_to_unicode(doc_text),
                    max_seq_length=(
                        FLAGS.max_seq_length - len(query_ids)) // 2,
                    tokenizer=tokenizer,
                    add_cls=False)
                for doc_text in [relevant_doc, non_relevant_doc]
            ]

            input_ids = [
                query_ids + doc_token_ids[0] + doc_token_ids[1],
                query_ids + doc_token_ids[1] + doc_token_ids[0]
            ]
            segment_ids = [
                ([0] * len(query_ids) + [1] * len(doc_token_ids[0]) +
                 [2] * len(doc_token_ids[1])),
                ([0] * len(query_ids) + [1] * len(doc_token_ids[1]) +
                 [2] * len(doc_token_ids[0]))
            ]

            for input_id, segment_id, label in zip(
                    input_ids, segment_ids, labels):

                input_id_tf = tf.train.Feature(
                    int64_list=tf.train.Int64List(value=input_id))

                segment_id_tf = tf.train.Feature(
                    int64_list=tf.train.Int64List(value=segment_id))

                labels_tf = tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[label]))

                features = tf.train.Features(feature={
                    'input_ids': input_id_tf,
                    'segment_ids': segment_id_tf,
                    'label': labels_tf,
                })
                example = tf.train.Example(features=features)
                writer.write(example.SerializeToString())


def convert_dataset(data, corpus, set_name, max_docs, tokenizer):
    """Convert dev or test dataset to a TF Record file."""
    ids_file = open(
        FLAGS.output_folder + '/query_doc_ids_' + set_name + '.txt', 'w')
    output_path = FLAGS.output_folder + '/dataset_' + set_name + '.tf'

    print(f'Converting {set_name} to tfrecord')
    start_time = time.time()

    with tf.python_io.TFRecordWriter(output_path) as writer:
        for i, query_id in tqdm(enumerate(data), total=len(data)):
            query, qrels, doc_ids = data[query_id]

            query = tokenization.convert_to_unicode(query)
            query_ids = tokenization.convert_to_bert_input(
                text=query,
                max_seq_length=FLAGS.max_query_length,
                tokenizer=tokenizer,
                add_cls=True)

            doc_ids = doc_ids[:max_docs]

            # Add fake docs so we always have max_docs per query.
            doc_ids += max(0, max_docs - len(doc_ids)) * [FLAGS.pad_doc_id]

            labels = [
                1 if doc_id in qrels else 0
                for doc_id in doc_ids
            ]

            if i % 1000 == 0:
                print(f'query: {query}; len qrels: {len(qrels)}')
                print(f'sum labels: {sum(labels)}')
                for j, (label, doc_id) in enumerate(zip(labels, doc_ids)):
                    print(f'doc {j}, label {label}, id: {doc_id}\n'
                          f'{corpus[doc_id]}\n\n')
                print()

            doc_token_ids = [
                tokenization.convert_to_bert_input(
                    text=tokenization.convert_to_unicode(
                        corpus[doc_id]),
                    max_seq_length=(
                        FLAGS.max_seq_length - len(query_ids)) // 2,
                    tokenizer=tokenizer,
                    add_cls=False)
                for doc_id in doc_ids
            ]
            input_ids = []
            segment_ids = []
            pair_doc_ids = []
            labels_pair = []
            for num_a, (doc_id_a, doc_token_id_a, label_a) in enumerate(
                    zip(doc_ids, doc_token_ids, labels)):
                for num_b, (doc_id_b, doc_token_id_b) in enumerate(
                        zip(doc_ids, doc_token_ids)):
                    if num_a == num_b:
                        continue
                    input_ids.append(
                        query_ids + doc_token_id_a + doc_token_id_b)
                    segment_ids.append((
                        [0] * len(query_ids) +
                        [1] * len(doc_token_id_a) +
                        [2] * len(doc_token_id_b)))
                    pair_doc_ids.append((doc_id_a, doc_id_b))
                    labels_pair.append(label_a)

            for input_id, segment_id, label, pair_doc_id in zip(
                    input_ids, segment_ids, labels_pair, pair_doc_ids):

                ids_file.write(
                    f'{query_id}\t{pair_doc_id[0]}\t{pair_doc_id[1]}\n')

                input_id_tf = tf.train.Feature(
                    int64_list=tf.train.Int64List(value=input_id))

                segment_id_tf = tf.train.Feature(
                    int64_list=tf.train.Int64List(value=segment_id))

                labels_tf = tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[label]))

                features = tf.train.Features(feature={
                    'input_ids': input_id_tf,
                    'segment_ids': segment_id_tf,
                    'label': labels_tf,
                })

                example = tf.train.Example(features=features)
                writer.write(example.SerializeToString())

    ids_file.close()


def load_qrels(path):
    """Loads qrels into a dict of key: query_id, value: set of relevant doc
    ids."""
    qrels = collections.defaultdict(set)
    print(f'Loading qrels: {path}')
    with open(path) as f:
        for line in tqdm(f):
            query_id, _, doc_id, relevance = line.rstrip().split('\t')
            if int(relevance) >= 1:
                qrels[query_id].add(doc_id)
    return qrels


def load_queries(path):
    """Loads queries into a dict of key: query_id, value: query text."""
    queries = {}
    print(f'Loading queries: {path}')
    with open(path) as f:
        for line in tqdm(f):
            query_id, query = line.rstrip().split('\t')
            queries[query_id] = query
    return queries


def load_run(path):
    """Loads run into a dict of key: query_id, value: list of candidate doc
    ids."""

    # We want to preserve the order of runs so we can pair the run file with
    # the TFRecord file.
    run = collections.OrderedDict()
    print(f'Loading run: {path}')
    with open(path) as f:
        for line in tqdm(f):
            query_id, doc_id, rank = line.rstrip().split('\t')
            if query_id not in run:
                run[query_id] = []
            run[query_id].append((doc_id, int(rank)))

    # Sort candidate docs by rank.
    sorted_run = collections.OrderedDict()
    for query_id, doc_ids_ranks in run.items():
        # Sort in place; the result of sorted() would otherwise be discarded.
        doc_ids_ranks.sort(key=lambda x: x[1])
        doc_ids = [doc_id for doc_id, _ in doc_ids_ranks]
        sorted_run[query_id] = doc_ids

    return sorted_run


def merge(qrels, run, queries):
    """Merge qrels and runs into a single dict of key: query_id,
    value: tuple(query, relevant_doc_ids, candidate_doc_ids)"""
    data = collections.OrderedDict()
    for query_id, candidate_doc_ids in run.items():
        query = queries[query_id]
        relevant_doc_ids = set()
        if qrels:
            relevant_doc_ids = qrels[query_id]
        data[query_id] = (query, relevant_doc_ids, candidate_doc_ids)
    return data


def load_corpus(path):
    """Load corpus into a dictionary with keys as doc ids and values as doc
    texts."""
    corpus = {}
    with open(path) as f:
        for line in tqdm(f):
            doc_id, doc_text = line.strip().split('\t')
            corpus[doc_id] = doc_text
    return corpus


def main():
    if not os.path.exists(FLAGS.output_folder):
        os.makedirs(FLAGS.output_folder)

    print('Loading Tokenizer...')
    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=True)

    print('Loading Corpus...')
    corpus = load_corpus(FLAGS.corpus)

    print('Converting Training Set...')
    convert_train(tokenizer=tokenizer)

    for set_name, queries_path, qrels_path, run_path, max_docs in [
            ('dev', FLAGS.queries_dev, FLAGS.qrels_dev, FLAGS.run_dev,
             FLAGS.num_dev_docs),
            ('test', FLAGS.queries_test, None, FLAGS.run_test,
             FLAGS.num_test_docs)]:

        print(f'Converting {set_name}')
        qrels = None
        if set_name != 'test':
            qrels = load_qrels(path=qrels_path)

        queries = load_queries(queries_path)
        run = load_run(path=run_path)
        data = merge(qrels=qrels, run=run, queries=queries)

        convert_dataset(data=data,
                        corpus=corpus,
                        set_name=set_name,
                        max_docs=max_docs,
                        tokenizer=tokenizer)
    print('Done!')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
import numpy as np


def average_precision(gt, pred):
    """
    Computes the average precision.

    This function computes the average precision between two lists of
    items.

    Parameters
    ----------
    gt: set
        A set of ground-truth elements (order doesn't matter)
    pred: list
        A list of predicted elements (order does matter)

    Returns
    -------
    score: double
        The average precision over the input lists
    """

    if not gt:
        return 0.0

    score = 0.0
    num_hits = 0.0
    for i, p in enumerate(pred):
        if p in gt and p not in pred[:i]:
            num_hits += 1.0
            score += num_hits / (i + 1.0)

    return score / max(1.0, len(gt))


def NDCG(gt, pred, use_graded_scores=False):
    score = 0.0
    for rank, item in enumerate(pred):
        if item in gt:
            if use_graded_scores:
                grade = 1.0 / (gt.index(item) + 1)
            else:
                grade = 1.0
            score += grade / np.log2(rank + 2)

    norm = 0.0
    for rank in range(len(gt)):
        if use_graded_scores:
            grade = 1.0 / (rank + 1)
        else:
            grade = 1.0
        norm += grade / np.log2(rank + 2)
    return score / max(0.3, norm)


def metrics(gt, pred, metrics_map):
    '''
    Returns a numpy array containing metrics specified by metrics_map.
    gt: ground-truth items
    pred: predicted items
    '''
    out = np.zeros((len(metrics_map),), np.float32)

    if 'MAP' in metrics_map:
        avg_precision = average_precision(gt=gt, pred=pred)
        out[metrics_map.index('MAP')] = avg_precision

    if 'RPrec' in metrics_map:
        intersec = len(gt & set(pred[:len(gt)]))
        out[metrics_map.index('RPrec')] = intersec / max(1., float(len(gt)))

    if 'MRR' in metrics_map:
        score = 0.0
        for rank, item in enumerate(pred):
            if item in gt:
                score = 1.0 / (rank + 1.0)
                break
        out[metrics_map.index('MRR')] = score

    if 'MRR@10' in metrics_map:
        score = 0.0
        for rank, item in enumerate(pred[:10]):
            if item in gt:
                score = 1.0 / (rank + 1.0)
                break
        out[metrics_map.index('MRR@10')] = score

    if 'NDCG' in metrics_map:
        out[metrics_map.index('NDCG')] = NDCG(gt, pred)

    return out

--------------------------------------------------------------------------------
/modeling.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | """The main BERT model and related functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import copy 23 | import json 24 | import math 25 | import re 26 | import six 27 | import tensorflow as tf 28 | 29 | 30 | class BertConfig(object): 31 | """Configuration for `BertModel`.""" 32 | 33 | def __init__(self, 34 | vocab_size, 35 | hidden_size=768, 36 | num_hidden_layers=12, 37 | num_attention_heads=12, 38 | intermediate_size=3072, 39 | hidden_act="gelu", 40 | hidden_dropout_prob=0.1, 41 | attention_probs_dropout_prob=0.1, 42 | max_position_embeddings=512, 43 | type_vocab_size=16, 44 | initializer_range=0.02): 45 | """Constructs BertConfig. 46 | 47 | Args: 48 | vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. 49 | hidden_size: Size of the encoder layers and the pooler layer. 50 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 51 | num_attention_heads: Number of attention heads for each attention layer in 52 | the Transformer encoder. 53 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 54 | layer in the Transformer encoder. 55 | hidden_act: The non-linear activation function (function or string) in the 56 | encoder and pooler. 57 | hidden_dropout_prob: The dropout probability for all fully connected 58 | layers in the embeddings, encoder, and pooler. 59 | attention_probs_dropout_prob: The dropout ratio for the attention 60 | probabilities. 61 | max_position_embeddings: The maximum sequence length that this model might 62 | ever be used with. Typically set this to something large just in case 63 | (e.g., 512 or 1024 or 2048). 64 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 65 | `BertModel`. 66 | initializer_range: The stdev of the truncated_normal_initializer for 67 | initializing all weight matrices. 68 | """ 69 | self.vocab_size = vocab_size 70 | self.hidden_size = hidden_size 71 | self.num_hidden_layers = num_hidden_layers 72 | self.num_attention_heads = num_attention_heads 73 | self.hidden_act = hidden_act 74 | self.intermediate_size = intermediate_size 75 | self.hidden_dropout_prob = hidden_dropout_prob 76 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 77 | self.max_position_embeddings = max_position_embeddings 78 | self.type_vocab_size = type_vocab_size 79 | self.initializer_range = initializer_range 80 | 81 | @classmethod 82 | def from_dict(cls, json_object): 83 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 84 | config = BertConfig(vocab_size=None) 85 | for (key, value) in six.iteritems(json_object): 86 | config.__dict__[key] = value 87 | return config 88 | 89 | @classmethod 90 | def from_json_file(cls, json_file): 91 | """Constructs a `BertConfig` from a json file of parameters.""" 92 | with tf.gfile.GFile(json_file, "r") as reader: 93 | text = reader.read() 94 | return cls.from_dict(json.loads(text)) 95 | 96 | def to_dict(self): 97 | """Serializes this instance to a Python dictionary.""" 98 | output = copy.deepcopy(self.__dict__) 99 | return output 100 | 101 | def to_json_string(self): 102 | """Serializes this instance to a JSON string.""" 103 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 104 | 105 | 106 | class BertModel(object): 107 | """BERT model ("Bidirectional Embedding Representations from a Transformer"). 

  Example usage:

  ```python
  # Already been converted into WordPiece token ids
  input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
  input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
  token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])

  config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
    num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)

  model = modeling.BertModel(config=config, is_training=True,
    input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)

  label_embeddings = tf.get_variable(...)
  pooled_output = model.get_pooled_output()
  logits = tf.matmul(pooled_output, label_embeddings)
  ...
  ```
  """

  def __init__(self,
               config,
               is_training,
               input_ids,
               input_mask=None,
               token_type_ids=None,
               use_one_hot_embeddings=True,
               scope=None):
    """Constructor for BertModel.

    Args:
      config: `BertConfig` instance.
      is_training: bool. True for training model, false for eval model. Controls
        whether dropout will be applied.
      input_ids: int32 Tensor of shape [batch_size, seq_length].
      input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
      token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
      use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
        embeddings or tf.embedding_lookup() for the word embeddings. On the TPU,
        it is much faster if this is True, on the CPU or GPU, it is faster if
        this is False.
      scope: (optional) variable scope. Defaults to "bert".

    Raises:
      ValueError: The config is invalid or one of the input tensor shapes
        is invalid.
    """
    config = copy.deepcopy(config)
    if not is_training:
      config.hidden_dropout_prob = 0.0
      config.attention_probs_dropout_prob = 0.0

    input_shape = get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    if input_mask is None:
      input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

    if token_type_ids is None:
      token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

    with tf.variable_scope(scope, default_name="bert"):
      with tf.variable_scope("embeddings"):
        # Perform embedding lookup on the word ids.
        (self.embedding_output, self.embedding_table) = embedding_lookup(
            input_ids=input_ids,
            vocab_size=config.vocab_size,
            embedding_size=config.hidden_size,
            initializer_range=config.initializer_range,
            word_embedding_name="word_embeddings",
            use_one_hot_embeddings=use_one_hot_embeddings)

        # Add positional embeddings and token type embeddings, then layer
        # normalize and perform dropout.
        self.embedding_output = embedding_postprocessor(
            input_tensor=self.embedding_output,
            use_token_type=True,
            token_type_ids=token_type_ids,
            token_type_vocab_size=config.type_vocab_size,
            token_type_embedding_name="token_type_embeddings",
            use_position_embeddings=True,
            position_embedding_name="position_embeddings",
            initializer_range=config.initializer_range,
            max_position_embeddings=config.max_position_embeddings,
            dropout_prob=config.hidden_dropout_prob)

      with tf.variable_scope("encoder"):
        # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
        # mask of shape [batch_size, seq_length, seq_length] which is used
        # for the attention scores.
        attention_mask = create_attention_mask_from_input_mask(
            input_ids, input_mask)

        # Run the stacked transformer.
        # `sequence_output` shape = [batch_size, seq_length, hidden_size].
        self.all_encoder_layers = transformer_model(
            input_tensor=self.embedding_output,
            attention_mask=attention_mask,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            intermediate_act_fn=get_activation(config.hidden_act),
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            initializer_range=config.initializer_range,
            do_return_all_layers=True)

      self.sequence_output = self.all_encoder_layers[-1]
      # The "pooler" converts the encoded sequence tensor of shape
      # [batch_size, seq_length, hidden_size] to a tensor of shape
      # [batch_size, hidden_size]. This is necessary for segment-level
      # (or segment-pair-level) classification tasks where we need a fixed
      # dimensional representation of the segment.
      with tf.variable_scope("pooler"):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. We assume that this has been pre-trained.
        first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
        self.pooled_output = tf.layers.dense(
            first_token_tensor,
            config.hidden_size,
            activation=tf.tanh,
            kernel_initializer=create_initializer(config.initializer_range))

  def get_pooled_output(self):
    return self.pooled_output

  def get_sequence_output(self):
    """Gets final hidden layer of encoder.

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
      to the final hidden layer of the transformer encoder.
    """
    return self.sequence_output

  def get_all_encoder_layers(self):
    return self.all_encoder_layers

  def get_embedding_output(self):
    """Gets output of the embedding lookup (i.e., input to the transformer).

    Returns:
      float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
      to the output of the embedding layer, after summing the word
      embeddings with the positional embeddings and the token type embeddings,
      then performing layer normalization. This is the input to the transformer.
    """
    return self.embedding_output

  def get_embedding_table(self):
    return self.embedding_table


def gelu(input_tensor):
  """Gaussian Error Linear Unit.

  This is a smoother version of the RELU.
  Original paper: https://arxiv.org/abs/1606.08415

  Args:
    input_tensor: float Tensor to perform activation.

  Returns:
    `input_tensor` with the GELU activation applied.
  """
  cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
  return input_tensor * cdf


def get_activation(activation_string):
  """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.

  Args:
    activation_string: String name of the activation function.

  Returns:
    A Python function corresponding to the activation function. If
    `activation_string` is None, empty, or "linear", this will return None.
    If `activation_string` is not a string, it will return `activation_string`.

  Raises:
    ValueError: The `activation_string` does not correspond to a known
      activation.
  """

  # We assume that anything that's not a string is already an activation
  # function, so we just return it.
  if not isinstance(activation_string, six.string_types):
    return activation_string

  if not activation_string:
    return None

  act = activation_string.lower()
  if act == "linear":
    return None
  elif act == "relu":
    return tf.nn.relu
  elif act == "gelu":
    return gelu
  elif act == "tanh":
    return tf.tanh
  else:
    raise ValueError("Unsupported activation: %s" % act)


def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
  """Compute the union of the current variables and checkpoint variables."""
  assignment_map = {}
  initialized_variable_names = {}

  name_to_variable = collections.OrderedDict()
  for var in tvars:
    name = var.name
    m = re.match("^(.*):\\d+$", name)
    if m is not None:
      name = m.group(1)
    name_to_variable[name] = var

  init_vars = tf.train.list_variables(init_checkpoint)

  assignment_map = collections.OrderedDict()
  for x in init_vars:
    (name, var) = (x[0], x[1])
    if name not in name_to_variable:
      continue
    assignment_map[name] = name
    initialized_variable_names[name] = 1
    initialized_variable_names[name + ":0"] = 1

  return (assignment_map, initialized_variable_names)


def dropout(input_tensor, dropout_prob):
  """Perform dropout.

  Args:
    input_tensor: float Tensor.
    dropout_prob: Python float. The probability of dropping out a value (NOT of
      *keeping* a dimension as in `tf.nn.dropout`).

  Returns:
    A version of `input_tensor` with dropout applied.
355 | """ 356 | if dropout_prob is None or dropout_prob == 0.0: 357 | return input_tensor 358 | 359 | output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob) 360 | return output 361 | 362 | 363 | def layer_norm(input_tensor, name=None): 364 | """Run layer normalization on the last dimension of the tensor.""" 365 | return tf.contrib.layers.layer_norm( 366 | inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name) 367 | 368 | 369 | def layer_norm_and_dropout(input_tensor, dropout_prob, name=None): 370 | """Runs layer normalization followed by dropout.""" 371 | output_tensor = layer_norm(input_tensor, name) 372 | output_tensor = dropout(output_tensor, dropout_prob) 373 | return output_tensor 374 | 375 | 376 | def create_initializer(initializer_range=0.02): 377 | """Creates a `truncated_normal_initializer` with the given range.""" 378 | return tf.truncated_normal_initializer(stddev=initializer_range) 379 | 380 | 381 | def embedding_lookup(input_ids, 382 | vocab_size, 383 | embedding_size=128, 384 | initializer_range=0.02, 385 | word_embedding_name="word_embeddings", 386 | use_one_hot_embeddings=False): 387 | """Looks up words embeddings for id tensor. 388 | 389 | Args: 390 | input_ids: int32 Tensor of shape [batch_size, seq_length] containing word 391 | ids. 392 | vocab_size: int. Size of the embedding vocabulary. 393 | embedding_size: int. Width of the word embeddings. 394 | initializer_range: float. Embedding initialization range. 395 | word_embedding_name: string. Name of the embedding table. 396 | use_one_hot_embeddings: bool. If True, use one-hot method for word 397 | embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better 398 | for TPUs. 399 | 400 | Returns: 401 | float Tensor of shape [batch_size, seq_length, embedding_size]. 402 | """ 403 | # This function assumes that the input is of shape [batch_size, seq_length, 404 | # num_inputs]. 405 | # 406 | # If the input is a 2D tensor of shape [batch_size, seq_length], we 407 | # reshape to [batch_size, seq_length, 1]. 408 | if input_ids.shape.ndims == 2: 409 | input_ids = tf.expand_dims(input_ids, axis=[-1]) 410 | 411 | embedding_table = tf.get_variable( 412 | name=word_embedding_name, 413 | shape=[vocab_size, embedding_size], 414 | initializer=create_initializer(initializer_range)) 415 | 416 | if use_one_hot_embeddings: 417 | flat_input_ids = tf.reshape(input_ids, [-1]) 418 | one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size) 419 | output = tf.matmul(one_hot_input_ids, embedding_table) 420 | else: 421 | output = tf.nn.embedding_lookup(embedding_table, input_ids) 422 | 423 | input_shape = get_shape_list(input_ids) 424 | 425 | output = tf.reshape(output, 426 | input_shape[0:-1] + [input_shape[-1] * embedding_size]) 427 | return (output, embedding_table) 428 | 429 | 430 | def embedding_postprocessor(input_tensor, 431 | use_token_type=False, 432 | token_type_ids=None, 433 | token_type_vocab_size=16, 434 | token_type_embedding_name="token_type_embeddings", 435 | use_position_embeddings=True, 436 | position_embedding_name="position_embeddings", 437 | initializer_range=0.02, 438 | max_position_embeddings=512, 439 | dropout_prob=0.1): 440 | """Performs various post-processing on a word embedding tensor. 441 | 442 | Args: 443 | input_tensor: float Tensor of shape [batch_size, seq_length, 444 | embedding_size]. 445 | use_token_type: bool. Whether to add embeddings for `token_type_ids`. 446 | token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length]. 
      Must be specified if `use_token_type` is True.
    token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
    token_type_embedding_name: string. The name of the embedding table variable
      for token type ids.
    use_position_embeddings: bool. Whether to add position embeddings for the
      position of each token in the sequence.
    position_embedding_name: string. The name of the embedding table variable
      for positional embeddings.
    initializer_range: float. Range of the weight initialization.
    max_position_embeddings: int. Maximum sequence length that might ever be
      used with this model. This can be longer than the sequence length of
      input_tensor, but cannot be shorter.
    dropout_prob: float. Dropout probability applied to the final output tensor.

  Returns:
    float tensor with same shape as `input_tensor`.

  Raises:
    ValueError: One of the tensor shapes or input values is invalid.
  """
  input_shape = get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape[0]
  seq_length = input_shape[1]
  width = input_shape[2]

  if seq_length > max_position_embeddings:
    raise ValueError("The seq length (%d) cannot be greater than "
                     "`max_position_embeddings` (%d)" %
                     (seq_length, max_position_embeddings))

  output = input_tensor

  if use_token_type:
    if token_type_ids is None:
      raise ValueError("`token_type_ids` must be specified if "
                       "`use_token_type` is True.")
    token_type_table = tf.get_variable(
        name=token_type_embedding_name,
        shape=[token_type_vocab_size, width],
        initializer=create_initializer(initializer_range))
    # This vocab will be small so we always do one-hot here, since it is always
    # faster for a small vocabulary.
    flat_token_type_ids = tf.reshape(token_type_ids, [-1])
    one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
    token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
    token_type_embeddings = tf.reshape(token_type_embeddings,
                                       [batch_size, seq_length, width])
    output += token_type_embeddings

  if use_position_embeddings:
    full_position_embeddings = tf.get_variable(
        name=position_embedding_name,
        shape=[max_position_embeddings, width],
        initializer=create_initializer(initializer_range))
    # Since the position embedding table is a learned variable, we create it
    # using a (long) sequence length `max_position_embeddings`. The actual
    # sequence length might be shorter than this, for faster training of
    # tasks that do not have long sequences.
    #
    # So `full_position_embeddings` is effectively an embedding table
    # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
    # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
    # perform a slice.
    if seq_length < max_position_embeddings:
      position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                     [seq_length, -1])
    else:
      position_embeddings = full_position_embeddings

    num_dims = len(output.shape.as_list())

    # Only the last two dimensions are relevant (`seq_length` and `width`), so
    # we broadcast among the first dimensions, which is typically just
    # the batch size.
    position_broadcast_shape = []
    for _ in range(num_dims - 2):
      position_broadcast_shape.append(1)
    position_broadcast_shape.extend([seq_length, width])
    position_embeddings = tf.reshape(position_embeddings,
                                     position_broadcast_shape)
    output += position_embeddings

  output = layer_norm_and_dropout(output, dropout_prob)
  return output


def create_attention_mask_from_input_mask(from_tensor, to_mask):
  """Create 3D attention mask from a 2D tensor mask.

  Args:
    from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
    to_mask: int32 Tensor of shape [batch_size, to_seq_length].

  Returns:
    float Tensor of shape [batch_size, from_seq_length, to_seq_length].
  """
  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  batch_size = from_shape[0]
  from_seq_length = from_shape[1]

  to_shape = get_shape_list(to_mask, expected_rank=2)
  to_seq_length = to_shape[1]

  to_mask = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.float32)

  # We don't assume that `from_tensor` is a mask (although it could be). We
  # don't actually care if we attend *from* padding tokens (only *to* padding
  # tokens) so we create a tensor of all ones.
  #
  # `broadcast_ones` = [batch_size, from_seq_length, 1]
  broadcast_ones = tf.ones(
      shape=[batch_size, from_seq_length, 1], dtype=tf.float32)

  # Here we broadcast along two dimensions to create the mask.
  mask = broadcast_ones * to_mask

  return mask


def attention_layer(from_tensor,
                    to_tensor,
                    attention_mask=None,
                    num_attention_heads=1,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None):
  """Performs multi-headed attention from `from_tensor` to `to_tensor`.

  This is an implementation of multi-headed attention based on "Attention
  Is All You Need". If `from_tensor` and `to_tensor` are the same, then
  this is self-attention. Each timestep in `from_tensor` attends to the
  corresponding sequence in `to_tensor`, and returns a fixed-width vector.

  This function first projects `from_tensor` into a "query" tensor and
  `to_tensor` into "key" and "value" tensors. These are (effectively) a list
  of tensors of length `num_attention_heads`, where each tensor is of shape
  [batch_size, seq_length, size_per_head].

  Then, the query and key tensors are dot-producted and scaled. These are
  softmaxed to obtain attention probabilities. The value tensors are then
  interpolated by these probabilities, then concatenated back to a single
  tensor and returned.

  In practice, the multi-headed attention is done with transposes and
  reshapes rather than actual separate tensors.

  Args:
    from_tensor: float Tensor of shape [batch_size, from_seq_length,
      from_width].
    to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
    attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length]. The values should be 1 or 0. The
      attention scores will effectively be set to -infinity for any positions in
      the mask that are 0, and will be unchanged for positions that are 1.
    num_attention_heads: int. Number of attention heads.
    size_per_head: int. Size of each attention head.
    query_act: (optional) Activation function for the query transform.
    key_act: (optional) Activation function for the key transform.
    value_act: (optional) Activation function for the value transform.
    attention_probs_dropout_prob: (optional) float. Dropout probability of the
      attention probabilities.
    initializer_range: float. Range of the weight initializer.
    do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
      * from_seq_length, num_attention_heads * size_per_head]. If False, the
      output will be of shape [batch_size, from_seq_length, num_attention_heads
      * size_per_head].
    batch_size: (Optional) int. If the input is 2D, this might be the batch size
      of the 3D version of the `from_tensor` and `to_tensor`.
    from_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `from_tensor`.
    to_seq_length: (Optional) If the input is 2D, this might be the seq length
      of the 3D version of the `to_tensor`.

  Returns:
    float Tensor of shape [batch_size, from_seq_length,
      num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
      true, this will be of shape [batch_size * from_seq_length,
      num_attention_heads * size_per_head]).

  Raises:
    ValueError: Any of the arguments or tensor shapes are invalid.
  """

  def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                           seq_length, width):
    output_tensor = tf.reshape(
        input_tensor, [batch_size, seq_length, num_attention_heads, width])

    output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
    return output_tensor

  from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
  to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

  if len(from_shape) != len(to_shape):
    raise ValueError(
        "The rank of `from_tensor` must match the rank of `to_tensor`.")

  if len(from_shape) == 3:
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]
    to_seq_length = to_shape[1]
  elif len(from_shape) == 2:
    if (batch_size is None or from_seq_length is None or to_seq_length is None):
      raise ValueError(
          "When passing in rank 2 tensors to attention_layer, the values "
          "for `batch_size`, `from_seq_length`, and `to_seq_length` "
          "must all be specified.")

  # Scalar dimensions referenced here:
  #   B = batch size (number of sequences)
  #   F = `from_tensor` sequence length
  #   T = `to_tensor` sequence length
  #   N = `num_attention_heads`
  #   H = `size_per_head`

  from_tensor_2d = reshape_to_matrix(from_tensor)
  to_tensor_2d = reshape_to_matrix(to_tensor)

  # `query_layer` = [B*F, N*H]
  query_layer = tf.layers.dense(
      from_tensor_2d,
      num_attention_heads * size_per_head,
      activation=query_act,
      name="query",
      kernel_initializer=create_initializer(initializer_range))

  # `key_layer` = [B*T, N*H]
  key_layer = tf.layers.dense(
      to_tensor_2d,
      num_attention_heads * size_per_head,
      activation=key_act,
      name="key",
      kernel_initializer=create_initializer(initializer_range))

  # `value_layer` = [B*T, N*H]
  value_layer = tf.layers.dense(
      to_tensor_2d,
      num_attention_heads * size_per_head,
      activation=value_act,
      name="value",
      kernel_initializer=create_initializer(initializer_range))

  # `query_layer` = [B, N, F, H]
  query_layer = transpose_for_scores(query_layer, batch_size,
                                     num_attention_heads, from_seq_length,
                                     size_per_head)

  # `key_layer` = [B, N, T, H]
  key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
                                   to_seq_length, size_per_head)

  # Take the dot product between "query" and "key" to get the raw
  # attention scores.
  # `attention_scores` = [B, N, F, T]
  attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
  attention_scores = tf.multiply(attention_scores,
                                 1.0 / math.sqrt(float(size_per_head)))

  if attention_mask is not None:
    # `attention_mask` = [B, 1, F, T]
    attention_mask = tf.expand_dims(attention_mask, axis=[1])

    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and -10000.0 for masked positions.
    adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    attention_scores += adder

  # Normalize the attention scores to probabilities.
  # `attention_probs` = [B, N, F, T]
  attention_probs = tf.nn.softmax(attention_scores)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

  # `value_layer` = [B, T, N, H]
  value_layer = tf.reshape(
      value_layer,
      [batch_size, to_seq_length, num_attention_heads, size_per_head])

  # `value_layer` = [B, N, T, H]
  value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

  # `context_layer` = [B, N, F, H]
  context_layer = tf.matmul(attention_probs, value_layer)

  # `context_layer` = [B, F, N, H]
  context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

  if do_return_2d_tensor:
    # `context_layer` = [B*F, N*H]
    context_layer = tf.reshape(
        context_layer,
        [batch_size * from_seq_length, num_attention_heads * size_per_head])
  else:
    # `context_layer` = [B, F, N*H]
    context_layer = tf.reshape(
        context_layer,
        [batch_size, from_seq_length, num_attention_heads * size_per_head])

  return context_layer


def transformer_model(input_tensor,
                      attention_mask=None,
                      hidden_size=768,
                      num_hidden_layers=12,
                      num_attention_heads=12,
                      intermediate_size=3072,
                      intermediate_act_fn=gelu,
                      hidden_dropout_prob=0.1,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False):
  """Multi-headed, multi-layer Transformer from "Attention Is All You Need".

  This is almost an exact implementation of the original Transformer encoder.
777 | 778 | See the original paper: 779 | https://arxiv.org/abs/1706.03762 780 | 781 | Also see: 782 | https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py 783 | 784 | Args: 785 | input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size]. 786 | attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length, 787 | seq_length], with 1 for positions that can be attended to and 0 in 788 | positions that should not be. 789 | hidden_size: int. Hidden size of the Transformer. 790 | num_hidden_layers: int. Number of layers (blocks) in the Transformer. 791 | num_attention_heads: int. Number of attention heads in the Transformer. 792 | intermediate_size: int. The size of the "intermediate" (a.k.a., feed 793 | forward) layer. 794 | intermediate_act_fn: function. The non-linear activation function to apply 795 | to the output of the intermediate/feed-forward layer. 796 | hidden_dropout_prob: float. Dropout probability for the hidden layers. 797 | attention_probs_dropout_prob: float. Dropout probability of the attention 798 | probabilities. 799 | initializer_range: float. Range of the initializer (stddev of truncated 800 | normal). 801 | do_return_all_layers: Whether to also return all layers or just the final 802 | layer. 803 | 804 | Returns: 805 | float Tensor of shape [batch_size, seq_length, hidden_size], the final 806 | hidden layer of the Transformer. 807 | 808 | Raises: 809 | ValueError: A Tensor shape or parameter is invalid. 810 | """ 811 | if hidden_size % num_attention_heads != 0: 812 | raise ValueError( 813 | "The hidden size (%d) is not a multiple of the number of attention " 814 | "heads (%d)" % (hidden_size, num_attention_heads)) 815 | 816 | attention_head_size = int(hidden_size / num_attention_heads) 817 | input_shape = get_shape_list(input_tensor, expected_rank=3) 818 | batch_size = input_shape[0] 819 | seq_length = input_shape[1] 820 | input_width = input_shape[2] 821 | 822 | # The Transformer performs sum residuals on all layers so the input needs 823 | # to be the same as the hidden size. 824 | if input_width != hidden_size: 825 | raise ValueError("The width of the input tensor (%d) != hidden size (%d)" % 826 | (input_width, hidden_size)) 827 | 828 | # We keep the representation as a 2D tensor to avoid re-shaping it back and 829 | # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on 830 | # the GPU/CPU but may not be free on the TPU, so we want to minimize them to 831 | # help the optimizer. 
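# A quick shape sketch of the 2D bookkeeping used below (illustrative only;
# B=2, S=128, H=768 are made-up values):
#
#   >>> x = tf.zeros([2, 128, 768])               # [batch_size, seq_length, hidden]
#   >>> m = reshape_to_matrix(x)                  # shape [256, 768], i.e. [B*S, H]
#   >>> reshape_from_matrix(m, get_shape_list(x)) # back to [2, 128, 768]
#
# Every layer below reads and writes this [B*S, H] matrix; the output is
# reshaped back to 3D only once, at the end of the function.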
832 |   prev_output = reshape_to_matrix(input_tensor)
833 | 
834 |   all_layer_outputs = []
835 |   for layer_idx in range(num_hidden_layers):
836 |     with tf.variable_scope("layer_%d" % layer_idx):
837 |       layer_input = prev_output
838 | 
839 |       with tf.variable_scope("attention"):
840 |         attention_heads = []
841 |         with tf.variable_scope("self"):
842 |           attention_head = attention_layer(
843 |               from_tensor=layer_input,
844 |               to_tensor=layer_input,
845 |               attention_mask=attention_mask,
846 |               num_attention_heads=num_attention_heads,
847 |               size_per_head=attention_head_size,
848 |               attention_probs_dropout_prob=attention_probs_dropout_prob,
849 |               initializer_range=initializer_range,
850 |               do_return_2d_tensor=True,
851 |               batch_size=batch_size,
852 |               from_seq_length=seq_length,
853 |               to_seq_length=seq_length)
854 |           attention_heads.append(attention_head)
855 | 
856 |         attention_output = None
857 |         if len(attention_heads) == 1:
858 |           attention_output = attention_heads[0]
859 |         else:
860 |           # In the case where we have other sequences, we just concatenate
861 |           # them to the self-attention head before the projection.
862 |           attention_output = tf.concat(attention_heads, axis=-1)
863 | 
864 |         # Run a linear projection of `hidden_size` then add a residual
865 |         # with `layer_input`.
866 |         with tf.variable_scope("output"):
867 |           attention_output = tf.layers.dense(
868 |               attention_output,
869 |               hidden_size,
870 |               kernel_initializer=create_initializer(initializer_range))
871 |           attention_output = dropout(attention_output, hidden_dropout_prob)
872 |           attention_output = layer_norm(attention_output + layer_input)
873 | 
874 |       # The activation is only applied to the "intermediate" hidden layer.
875 |       with tf.variable_scope("intermediate"):
876 |         intermediate_output = tf.layers.dense(
877 |             attention_output,
878 |             intermediate_size,
879 |             activation=intermediate_act_fn,
880 |             kernel_initializer=create_initializer(initializer_range))
881 | 
882 |       # Down-project back to `hidden_size` then add the residual.
883 |       with tf.variable_scope("output"):
884 |         layer_output = tf.layers.dense(
885 |             intermediate_output,
886 |             hidden_size,
887 |             kernel_initializer=create_initializer(initializer_range))
888 |         layer_output = dropout(layer_output, hidden_dropout_prob)
889 |         layer_output = layer_norm(layer_output + attention_output)
890 |         prev_output = layer_output
891 |         all_layer_outputs.append(layer_output)
892 | 
893 |   if do_return_all_layers:
894 |     final_outputs = []
895 |     for layer_output in all_layer_outputs:
896 |       final_output = reshape_from_matrix(layer_output, input_shape)
897 |       final_outputs.append(final_output)
898 |     return final_outputs
899 |   else:
900 |     final_output = reshape_from_matrix(prev_output, input_shape)
901 |     return final_output
902 | 
903 | 
904 | def get_shape_list(tensor, expected_rank=None, name=None):
905 |   """Returns a list of the shape of tensor, preferring static dimensions.
906 | 
907 |   Args:
908 |     tensor: A tf.Tensor object to find the shape of.
909 |     expected_rank: (optional) int. The expected rank of `tensor`. If this is
910 |       specified and the `tensor` has a different rank, an exception will be
911 |       thrown.
912 |     name: Optional name of the tensor for the error message.
913 | 
914 |   Returns:
915 |     A list of dimensions of the shape of tensor. All static dimensions will
916 |     be returned as python integers, and dynamic dimensions will be returned
917 |     as tf.Tensor scalars.
918 | """ 919 | if name is None: 920 | name = tensor.name 921 | 922 | if expected_rank is not None: 923 | assert_rank(tensor, expected_rank, name) 924 | 925 | shape = tensor.shape.as_list() 926 | 927 | non_static_indexes = [] 928 | for (index, dim) in enumerate(shape): 929 | if dim is None: 930 | non_static_indexes.append(index) 931 | 932 | if not non_static_indexes: 933 | return shape 934 | 935 | dyn_shape = tf.shape(tensor) 936 | for index in non_static_indexes: 937 | shape[index] = dyn_shape[index] 938 | return shape 939 | 940 | 941 | def reshape_to_matrix(input_tensor): 942 | """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix).""" 943 | ndims = input_tensor.shape.ndims 944 | if ndims < 2: 945 | raise ValueError("Input tensor must have at least rank 2. Shape = %s" % 946 | (input_tensor.shape)) 947 | if ndims == 2: 948 | return input_tensor 949 | 950 | width = input_tensor.shape[-1] 951 | output_tensor = tf.reshape(input_tensor, [-1, width]) 952 | return output_tensor 953 | 954 | 955 | def reshape_from_matrix(output_tensor, orig_shape_list): 956 | """Reshapes a rank 2 tensor back to its original rank >= 2 tensor.""" 957 | if len(orig_shape_list) == 2: 958 | return output_tensor 959 | 960 | output_shape = get_shape_list(output_tensor) 961 | 962 | orig_dims = orig_shape_list[0:-1] 963 | width = output_shape[-1] 964 | 965 | return tf.reshape(output_tensor, orig_dims + [width]) 966 | 967 | 968 | def assert_rank(tensor, expected_rank, name=None): 969 | """Raises an exception if the tensor rank is not of the expected rank. 970 | 971 | Args: 972 | tensor: A tf.Tensor to check the rank of. 973 | expected_rank: Python integer or list of integers, expected rank. 974 | name: Optional name of the tensor for the error message. 975 | 976 | Raises: 977 | ValueError: If the expected shape doesn't match the actual shape. 978 | """ 979 | if name is None: 980 | name = tensor.name 981 | 982 | expected_rank_dict = {} 983 | if isinstance(expected_rank, six.integer_types): 984 | expected_rank_dict[expected_rank] = True 985 | else: 986 | for x in expected_rank: 987 | expected_rank_dict[x] = True 988 | 989 | actual_rank = tensor.shape.ndims 990 | if actual_rank not in expected_rank_dict: 991 | scope_name = tf.get_variable_scope().name 992 | raise ValueError( 993 | "For the tensor `%s` in scope `%s`, the actual rank " 994 | "`%d` (shape = %s) is not equal to the expected rank `%s`" % 995 | (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank))) 996 | -------------------------------------------------------------------------------- /modeling_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
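# A reminder of the shape utility these tests lean on (illustrative
# doctest-style sketch; the placeholder shape is made up):
# modeling.get_shape_list returns static dims as Python ints and dynamic
# dims as scalar int32 Tensors.
#
#   >>> t = tf.placeholder(tf.float32, shape=[None, 7, 32])
#   >>> shape = modeling.get_shape_list(t, expected_rank=3)
#   >>> shape[1], shape[2]   # static dims come back as plain Python ints
#   (7, 32)
#   >>> shape[0]             # the dynamic batch dim is a scalar Tensor, not an int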
15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import collections 20 | import json 21 | import random 22 | import re 23 | 24 | import modeling 25 | import six 26 | import tensorflow as tf 27 | 28 | 29 | class BertModelTest(tf.test.TestCase): 30 | 31 | class BertModelTester(object): 32 | 33 | def __init__(self, 34 | parent, 35 | batch_size=13, 36 | seq_length=7, 37 | is_training=True, 38 | use_input_mask=True, 39 | use_token_type_ids=True, 40 | vocab_size=99, 41 | hidden_size=32, 42 | num_hidden_layers=5, 43 | num_attention_heads=4, 44 | intermediate_size=37, 45 | hidden_act="gelu", 46 | hidden_dropout_prob=0.1, 47 | attention_probs_dropout_prob=0.1, 48 | max_position_embeddings=512, 49 | type_vocab_size=16, 50 | initializer_range=0.02, 51 | scope=None): 52 | self.parent = parent 53 | self.batch_size = batch_size 54 | self.seq_length = seq_length 55 | self.is_training = is_training 56 | self.use_input_mask = use_input_mask 57 | self.use_token_type_ids = use_token_type_ids 58 | self.vocab_size = vocab_size 59 | self.hidden_size = hidden_size 60 | self.num_hidden_layers = num_hidden_layers 61 | self.num_attention_heads = num_attention_heads 62 | self.intermediate_size = intermediate_size 63 | self.hidden_act = hidden_act 64 | self.hidden_dropout_prob = hidden_dropout_prob 65 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 66 | self.max_position_embeddings = max_position_embeddings 67 | self.type_vocab_size = type_vocab_size 68 | self.initializer_range = initializer_range 69 | self.scope = scope 70 | 71 | def create_model(self): 72 | input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], 73 | self.vocab_size) 74 | 75 | input_mask = None 76 | if self.use_input_mask: 77 | input_mask = BertModelTest.ids_tensor( 78 | [self.batch_size, self.seq_length], vocab_size=2) 79 | 80 | token_type_ids = None 81 | if self.use_token_type_ids: 82 | token_type_ids = BertModelTest.ids_tensor( 83 | [self.batch_size, self.seq_length], self.type_vocab_size) 84 | 85 | config = modeling.BertConfig( 86 | vocab_size=self.vocab_size, 87 | hidden_size=self.hidden_size, 88 | num_hidden_layers=self.num_hidden_layers, 89 | num_attention_heads=self.num_attention_heads, 90 | intermediate_size=self.intermediate_size, 91 | hidden_act=self.hidden_act, 92 | hidden_dropout_prob=self.hidden_dropout_prob, 93 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 94 | max_position_embeddings=self.max_position_embeddings, 95 | type_vocab_size=self.type_vocab_size, 96 | initializer_range=self.initializer_range) 97 | 98 | model = modeling.BertModel( 99 | config=config, 100 | is_training=self.is_training, 101 | input_ids=input_ids, 102 | input_mask=input_mask, 103 | token_type_ids=token_type_ids, 104 | scope=self.scope) 105 | 106 | outputs = { 107 | "embedding_output": model.get_embedding_output(), 108 | "sequence_output": model.get_sequence_output(), 109 | "pooled_output": model.get_pooled_output(), 110 | "all_encoder_layers": model.get_all_encoder_layers(), 111 | } 112 | return outputs 113 | 114 | def check_output(self, result): 115 | self.parent.assertAllEqual( 116 | result["embedding_output"].shape, 117 | [self.batch_size, self.seq_length, self.hidden_size]) 118 | 119 | self.parent.assertAllEqual( 120 | result["sequence_output"].shape, 121 | [self.batch_size, self.seq_length, self.hidden_size]) 122 | 123 | self.parent.assertAllEqual(result["pooled_output"].shape, 124 | [self.batch_size, 
self.hidden_size]) 125 | 126 | def test_default(self): 127 | self.run_tester(BertModelTest.BertModelTester(self)) 128 | 129 | def test_config_to_json_string(self): 130 | config = modeling.BertConfig(vocab_size=99, hidden_size=37) 131 | obj = json.loads(config.to_json_string()) 132 | self.assertEqual(obj["vocab_size"], 99) 133 | self.assertEqual(obj["hidden_size"], 37) 134 | 135 | def run_tester(self, tester): 136 | with self.test_session() as sess: 137 | ops = tester.create_model() 138 | init_op = tf.group(tf.global_variables_initializer(), 139 | tf.local_variables_initializer()) 140 | sess.run(init_op) 141 | output_result = sess.run(ops) 142 | tester.check_output(output_result) 143 | 144 | self.assert_all_tensors_reachable(sess, [init_op, ops]) 145 | 146 | @classmethod 147 | def ids_tensor(cls, shape, vocab_size, rng=None, name=None): 148 | """Creates a random int32 tensor of the shape within the vocab size.""" 149 | if rng is None: 150 | rng = random.Random() 151 | 152 | total_dims = 1 153 | for dim in shape: 154 | total_dims *= dim 155 | 156 | values = [] 157 | for _ in range(total_dims): 158 | values.append(rng.randint(0, vocab_size - 1)) 159 | 160 | return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name) 161 | 162 | def assert_all_tensors_reachable(self, sess, outputs): 163 | """Checks that all the tensors in the graph are reachable from outputs.""" 164 | graph = sess.graph 165 | 166 | ignore_strings = [ 167 | "^.*/dilation_rate$", 168 | "^.*/Tensordot/concat$", 169 | "^.*/Tensordot/concat/axis$", 170 | "^testing/.*$", 171 | ] 172 | 173 | ignore_regexes = [re.compile(x) for x in ignore_strings] 174 | 175 | unreachable = self.get_unreachable_ops(graph, outputs) 176 | filtered_unreachable = [] 177 | for x in unreachable: 178 | do_ignore = False 179 | for r in ignore_regexes: 180 | m = r.match(x.name) 181 | if m is not None: 182 | do_ignore = True 183 | if do_ignore: 184 | continue 185 | filtered_unreachable.append(x) 186 | unreachable = filtered_unreachable 187 | 188 | self.assertEqual( 189 | len(unreachable), 0, "The following ops are unreachable: %s" % 190 | (" ".join([x.name for x in unreachable]))) 191 | 192 | @classmethod 193 | def get_unreachable_ops(cls, graph, outputs): 194 | """Finds all of the tensors in graph that are unreachable from outputs.""" 195 | outputs = cls.flatten_recursive(outputs) 196 | output_to_op = collections.defaultdict(list) 197 | op_to_all = collections.defaultdict(list) 198 | assign_out_to_in = collections.defaultdict(list) 199 | 200 | for op in graph.get_operations(): 201 | for x in op.inputs: 202 | op_to_all[op.name].append(x.name) 203 | for y in op.outputs: 204 | output_to_op[y.name].append(op.name) 205 | op_to_all[op.name].append(y.name) 206 | if str(op.type) == "Assign": 207 | for y in op.outputs: 208 | for x in op.inputs: 209 | assign_out_to_in[y.name].append(x.name) 210 | 211 | assign_groups = collections.defaultdict(list) 212 | for out_name in assign_out_to_in.keys(): 213 | name_group = assign_out_to_in[out_name] 214 | for n1 in name_group: 215 | assign_groups[n1].append(out_name) 216 | for n2 in name_group: 217 | if n1 != n2: 218 | assign_groups[n1].append(n2) 219 | 220 | seen_tensors = {} 221 | stack = [x.name for x in outputs] 222 | while stack: 223 | name = stack.pop() 224 | if name in seen_tensors: 225 | continue 226 | seen_tensors[name] = True 227 | 228 | if name in output_to_op: 229 | for op_name in output_to_op[name]: 230 | if op_name in op_to_all: 231 | for input_name in op_to_all[op_name]: 232 | if input_name not in 
stack:
233 |                 stack.append(input_name)
234 | 
235 |       expanded_names = []
236 |       if name in assign_groups:
237 |         for assign_name in assign_groups[name]:
238 |           expanded_names.append(assign_name)
239 | 
240 |       for expanded_name in expanded_names:
241 |         if expanded_name not in stack:
242 |           stack.append(expanded_name)
243 | 
244 |     unreachable_ops = []
245 |     for op in graph.get_operations():
246 |       is_unreachable = False
247 |       all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
248 |       for name in all_names:
249 |         if name not in seen_tensors:
250 |           is_unreachable = True
251 |       if is_unreachable:
252 |         unreachable_ops.append(op)
253 |     return unreachable_ops
254 | 
255 |   @classmethod
256 |   def flatten_recursive(cls, item):
257 |     """Flattens (potentially nested) a tuple/dictionary/list to a list."""
258 |     output = []
259 |     if isinstance(item, list):
260 |       output.extend(item)
261 |     elif isinstance(item, tuple):
262 |       output.extend(list(item))
263 |     elif isinstance(item, dict):
264 |       for (_, v) in six.iteritems(item):
265 |         output.append(v)
266 |     else:
267 |       return [item]
268 | 
269 |     flat_output = []
270 |     for x in output:
271 |       flat_output.extend(cls.flatten_recursive(x))
272 |     return flat_output
273 | 
274 | 
275 | if __name__ == "__main__":
276 |   tf.test.main()
277 | 
--------------------------------------------------------------------------------
/msmarco_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | This module computes evaluation metrics for the MS MARCO dataset on the ranking task.
3 | Command line:
4 | python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
5 | 
6 | Creation Date : 06/12/2018
7 | Last Modified : 1/21/2019
8 | Authors : Daniel Campos, Rutger van Haasteren
9 | """
10 | import sys
11 | import statistics
12 | 
13 | from collections import Counter
14 | 
15 | MaxMRRRank = 10
16 | 
17 | def load_reference_from_stream(f):
18 |     """Load the reference relevant passages.
19 |     Args: f (stream): stream to load.
20 |     Returns: qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
21 |     """
22 |     qids_to_relevant_passageids = {}
23 |     for l in f:
24 |         try:
25 |             l = l.strip().split('\t')
26 |             qid = int(l[0])
27 |             if qid in qids_to_relevant_passageids:
28 |                 pass
29 |             else:
30 |                 qids_to_relevant_passageids[qid] = []
31 |             qids_to_relevant_passageids[qid].append(int(l[2]))
32 |         except:
33 |             raise IOError('\"%s\" is not valid format' % l)
34 |     return qids_to_relevant_passageids
35 | 
36 | def load_reference(path_to_reference):
37 |     """Load the reference relevant passages.
38 |     Args: path_to_reference (str): path to a file to load.
39 |     Returns: qids_to_relevant_passageids (dict): dictionary mapping from query_id (int) to relevant passages (list of ints).
40 |     """
41 |     with open(path_to_reference, 'r') as f:
42 |         qids_to_relevant_passageids = load_reference_from_stream(f)
43 |     return qids_to_relevant_passageids
44 | 
45 | def load_candidate_from_stream(f):
46 |     """Load candidate data from a stream.
47 |     Args: f (stream): stream to load.
48 |     Returns: qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids (int) ranked by relevance and importance.
49 |     """
50 |     qid_to_ranked_candidate_passages = {}
51 |     for l in f:
52 |         try:
53 |             l = l.strip().split('\t')
54 |             qid = int(l[0])
55 |             pid = int(l[1])
56 |             rank = int(l[2])
57 |             if qid in qid_to_ranked_candidate_passages:
58 |                 pass
59 |             else:
60 |                 # By default, all PIDs in the list of 1000 are 0.
Only override those that are given 61 | tmp = [0] * 1000 62 | qid_to_ranked_candidate_passages[qid] = tmp 63 | qid_to_ranked_candidate_passages[qid][rank-1]=pid 64 | except: 65 | raise IOError('\"%s\" is not valid format' % l) 66 | return qid_to_ranked_candidate_passages 67 | 68 | def load_candidate(path_to_candidate): 69 | """Load candidate data from a file. 70 | Args:path_to_candidate (str): path to file to load. 71 | Returns:qid_to_ranked_candidate_passages (dict): dictionary mapping from query_id (int) to a list of 1000 passage ids(int) ranked by relevance and importance 72 | """ 73 | 74 | with open(path_to_candidate,'r') as f: 75 | qid_to_ranked_candidate_passages = load_candidate_from_stream(f) 76 | return qid_to_ranked_candidate_passages 77 | 78 | def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages): 79 | """Perform quality checks on the dictionaries 80 | 81 | Args: 82 | p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping 83 | Dict as read in with load_reference or load_reference_from_stream 84 | p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates 85 | Returns: 86 | bool,str: Boolean whether allowed, message to be shown in case of a problem 87 | """ 88 | message = '' 89 | allowed = True 90 | 91 | # Create sets of the QIDs for the submitted and reference queries 92 | candidate_set = set(qids_to_ranked_candidate_passages.keys()) 93 | ref_set = set(qids_to_relevant_passageids.keys()) 94 | 95 | # Check that we do not have multiple passages per query 96 | for qid in qids_to_ranked_candidate_passages: 97 | # Remove all zeros from the candidates 98 | duplicate_pids = set([item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1]) 99 | 100 | if len(duplicate_pids-set([0])) > 0: 101 | message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format( 102 | qid=qid, pid=list(duplicate_pids)[0]) 103 | allowed = False 104 | 105 | return allowed, message 106 | 107 | def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages): 108 | """Compute MRR metric 109 | Args: 110 | p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping 111 | Dict as read in with load_reference or load_reference_from_stream 112 | p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates 113 | Returns: 114 | dict: dictionary of metrics {'MRR': } 115 | """ 116 | all_scores = {} 117 | MRR = 0 118 | qids_with_relevant_passages = 0 119 | ranking = [] 120 | for qid in qids_to_ranked_candidate_passages: 121 | if qid in qids_to_relevant_passageids: 122 | ranking.append(0) 123 | target_pid = qids_to_relevant_passageids[qid] 124 | candidate_pid = qids_to_ranked_candidate_passages[qid] 125 | for i in range(0,MaxMRRRank): 126 | if candidate_pid[i] in target_pid: 127 | MRR += 1/(i + 1) 128 | ranking.pop() 129 | ranking.append(i+1) 130 | break 131 | if len(ranking) == 0: 132 | raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?") 133 | 134 | MRR = MRR/len(qids_to_relevant_passageids) 135 | all_scores['MRR @10'] = MRR 136 | all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages) 137 | return all_scores 138 | 139 | def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True): 140 | """Compute MRR metric 141 | Args: 142 | p_path_to_reference_file (str): path to reference file. 
143 |         Reference file should contain lines in the following format:
144 |         QUERYID\tPASSAGEID
145 |         Where PASSAGEID is a relevant passage for a query. Note QUERYID can repeat on different lines with different PASSAGEIDs.
146 |     p_path_to_candidate_file (str): path to candidate file.
147 |         Candidate file should contain lines in the following format:
148 |         QUERYID\tPASSAGEID1\tRank
149 |         If a user wishes to use the TREC format, please run the script with the -t flag. If this flag is used, the expected format is
150 |         QUERYID\tITER\tDOCNO\tRANK\tSIM\tRUNID
151 |         Where the values are separated by tabs and ranked in order of relevance.
152 |     Returns:
153 |         dict: dictionary of metrics {'MRR': }
154 |     """
155 | 
156 |     qids_to_relevant_passageids = load_reference(path_to_reference)
157 |     qids_to_ranked_candidate_passages = load_candidate(path_to_candidate)
158 |     if perform_checks:
159 |         allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
160 |         if message != '': print(message)
161 | 
162 |     return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
163 | 
164 | def main():
165 |     """Command line:
166 |     python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
167 |     """
168 | 
169 |     if len(sys.argv) == 3:
170 |         path_to_reference = sys.argv[1]
171 |         path_to_candidate = sys.argv[2]
172 |         metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
173 |         print('#####################')
174 |         for metric in sorted(metrics):
175 |             print('{}: {}'.format(metric, metrics[metric]))
176 |         print('#####################')
177 | 
178 |     else:
179 |         print('Usage: msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>')
180 |         exit()
181 | 
182 | if __name__ == '__main__':
183 |     main()
184 | 
185 | 
--------------------------------------------------------------------------------
/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Functions and classes related to optimization (weight updates)."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import re
22 | import tensorflow as tf
23 | 
24 | 
25 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
26 |   """Creates an optimizer training op."""
27 |   global_step = tf.train.get_or_create_global_step()
28 | 
29 |   learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)
30 | 
31 |   # Implements linear decay of the learning rate.
32 |   learning_rate = tf.train.polynomial_decay(
33 |       learning_rate,
34 |       global_step,
35 |       num_train_steps,
36 |       end_learning_rate=0.0,
37 |       power=1.0,
38 |       cycle=False)
39 | 
40 |   # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
41 |   # learning rate will be `global_step/num_warmup_steps * init_lr`.
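  # As a concrete sketch with the defaults used elsewhere in this repo
  # (init_lr=3e-6, num_warmup_steps=10000, num_train_steps=100000):
  #
  #   step      0:  lr = 0.0       (linear warmup starts)
  #   step  5,000:  lr = 1.5e-6    (5000/10000 * 3e-6)
  #   step 10,000:  lr = 2.7e-6    (warmup ends; decay has been running
  #                                 since step 0, so 3e-6 * 0.9)
  #   step 55,000:  lr = 1.35e-6   (linear decay reaches 0.0 at step 100,000)
  #
  # Note the small drop at the warmup boundary: the decayed rate, not the
  # peak warmup rate, takes over at step num_warmup_steps.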
42 |   if num_warmup_steps:
43 |     global_steps_int = tf.cast(global_step, tf.int32)
44 |     warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)
45 | 
46 |     global_steps_float = tf.cast(global_steps_int, tf.float32)
47 |     warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)
48 | 
49 |     warmup_percent_done = global_steps_float / warmup_steps_float
50 |     warmup_learning_rate = init_lr * warmup_percent_done
51 | 
52 |     is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
53 |     learning_rate = (
54 |         (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)
55 | 
56 |   # It is recommended that you use this optimizer for fine-tuning, since this
57 |   # is how the model was trained (note that the Adam m/v variables are NOT
58 |   # loaded from init_checkpoint).
59 |   optimizer = AdamWeightDecayOptimizer(
60 |       learning_rate=learning_rate,
61 |       weight_decay_rate=0.01,
62 |       beta_1=0.9,
63 |       beta_2=0.999,
64 |       epsilon=1e-6,
65 |       exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
66 | 
67 |   if use_tpu:
68 |     optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
69 | 
70 |   tvars = tf.trainable_variables()
71 |   grads = tf.gradients(loss, tvars)
72 | 
73 |   # This is how the model was pre-trained.
74 |   (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
75 | 
76 |   train_op = optimizer.apply_gradients(
77 |       zip(grads, tvars), global_step=global_step)
78 | 
79 |   new_global_step = global_step + 1
80 |   train_op = tf.group(train_op, [global_step.assign(new_global_step)])
81 |   return train_op
82 | 
83 | 
84 | class AdamWeightDecayOptimizer(tf.train.Optimizer):
85 |   """A basic Adam optimizer that includes "correct" L2 weight decay."""
86 | 
87 |   def __init__(self,
88 |                learning_rate,
89 |                weight_decay_rate=0.0,
90 |                beta_1=0.9,
91 |                beta_2=0.999,
92 |                epsilon=1e-6,
93 |                exclude_from_weight_decay=None,
94 |                name="AdamWeightDecayOptimizer"):
95 |     """Constructs an AdamWeightDecayOptimizer."""
96 |     super(AdamWeightDecayOptimizer, self).__init__(False, name)
97 | 
98 |     self.learning_rate = learning_rate
99 |     self.weight_decay_rate = weight_decay_rate
100 |     self.beta_1 = beta_1
101 |     self.beta_2 = beta_2
102 |     self.epsilon = epsilon
103 |     self.exclude_from_weight_decay = exclude_from_weight_decay
104 | 
105 |   def apply_gradients(self, grads_and_vars, global_step=None, name=None):
106 |     """See base class."""
107 |     assignments = []
108 |     for (grad, param) in grads_and_vars:
109 |       if grad is None or param is None:
110 |         continue
111 | 
112 |       param_name = self._get_variable_name(param.name)
113 | 
114 |       m = tf.get_variable(
115 |           name=param_name + "/adam_m",
116 |           shape=param.shape.as_list(),
117 |           dtype=tf.float32,
118 |           trainable=False,
119 |           initializer=tf.zeros_initializer())
120 |       v = tf.get_variable(
121 |           name=param_name + "/adam_v",
122 |           shape=param.shape.as_list(),
123 |           dtype=tf.float32,
124 |           trainable=False,
125 |           initializer=tf.zeros_initializer())
126 | 
127 |       # Standard Adam update.
128 |       next_m = (
129 |           tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
130 |       next_v = (
131 |           tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
132 |                                                     tf.square(grad)))
133 | 
134 |       update = next_m / (tf.sqrt(next_v) + self.epsilon)
135 | 
136 |       # Just adding the square of the weights to the loss function is *not*
137 |       # the correct way of using L2 regularization/weight decay with Adam,
138 |       # since that will interact with the m and v parameters in strange ways.
139 |       #
140 |       # Instead we want to decay the weights in a manner that doesn't interact
141 |       # with the m/v parameters. This is equivalent to adding the square
142 |       # of the weights to the loss with plain (non-momentum) SGD.
143 |       if self._do_use_weight_decay(param_name):
144 |         update += self.weight_decay_rate * param
145 | 
146 |       update_with_lr = self.learning_rate * update
147 | 
148 |       next_param = param - update_with_lr
149 | 
150 |       assignments.extend(
151 |           [param.assign(next_param),
152 |            m.assign(next_m),
153 |            v.assign(next_v)])
154 |     return tf.group(*assignments, name=name)
155 | 
156 |   def _do_use_weight_decay(self, param_name):
157 |     """Whether to use L2 weight decay for `param_name`."""
158 |     if not self.weight_decay_rate:
159 |       return False
160 |     if self.exclude_from_weight_decay:
161 |       for r in self.exclude_from_weight_decay:
162 |         if re.search(r, param_name) is not None:
163 |           return False
164 |     return True
165 | 
166 |   def _get_variable_name(self, param_name):
167 |     """Get the variable name from the tensor name."""
168 |     m = re.match("^(.*):\\d+$", param_name)
169 |     if m is not None:
170 |       param_name = m.group(1)
171 |     return param_name
172 | 
--------------------------------------------------------------------------------
/optimization_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from __future__ import absolute_import
16 | from __future__ import division
17 | from __future__ import print_function
18 | 
19 | import optimization
20 | import tensorflow as tf
21 | 
22 | 
23 | class OptimizationTest(tf.test.TestCase):
24 | 
25 |   def test_adam(self):
26 |     with self.test_session() as sess:
27 |       w = tf.get_variable(
28 |           "w",
29 |           shape=[3],
30 |           initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
31 |       x = tf.constant([0.4, 0.2, -0.5])
32 |       loss = tf.reduce_mean(tf.square(x - w))
33 |       tvars = tf.trainable_variables()
34 |       grads = tf.gradients(loss, tvars)
35 |       global_step = tf.train.get_or_create_global_step()
36 |       optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
37 |       train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)
38 |       init_op = tf.group(tf.global_variables_initializer(),
39 |                          tf.local_variables_initializer())
40 |       sess.run(init_op)
41 |       for _ in range(100):
42 |         sess.run(train_op)
43 |       w_np = sess.run(w)
44 |       self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
45 | 
46 | 
47 | if __name__ == "__main__":
48 |   tf.test.main()
49 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow >= 1.11.0   # CPU Version of TensorFlow.
2 | # tensorflow-gpu >= 1.11.0  # GPU version of TensorFlow.
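# Note: this codebase targets the TF 1.x API (tf.contrib, tf.layers,
# tf.estimator); it will not run unmodified on TensorFlow 2.x.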
3 | tqdm
4 | 
--------------------------------------------------------------------------------
/run_duobert_msmarco.py:
--------------------------------------------------------------------------------
1 | """Code to train and eval a duoBERT re-ranker on the MS MARCO dataset."""
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 | 
6 | import time
7 | 
8 | import numpy as np
9 | import tensorflow as tf
10 | 
11 | # local modules
12 | import metrics
13 | import modeling
14 | import optimization
15 | 
16 | flags = tf.flags
17 | 
18 | FLAGS = flags.FLAGS
19 | 
20 | # Required parameters
21 | flags.DEFINE_string(
22 |     'data_dir',
23 |     './data/tfrecord/',
24 |     'The input data dir. Should contain the .tfrecord files and the '
25 |     'supporting query-docids mapping files.')
26 | 
27 | flags.DEFINE_string(
28 |     'bert_config_file',
29 |     './data/bert/pretrained_models/uncased_L-24_H-1024_A-16/bert_config.json',
30 |     'The config json file corresponding to the pre-trained BERT model. '
31 |     'This specifies the model architecture.')
32 | 
33 | flags.DEFINE_string(
34 |     'output_dir', './data/output',
35 |     'The output directory where the model checkpoints will be written.')
36 | 
37 | flags.DEFINE_string(
38 |     'init_checkpoint',
39 |     './data/bert/pretrained_models/uncased_L-24_H-1024_A-16/bert_model.ckpt',
40 |     'Initial checkpoint (usually from a pre-trained BERT model).')
41 | 
42 | flags.DEFINE_integer(
43 |     'max_seq_length', 512,
44 |     'The maximum total input sequence length after WordPiece tokenization. '
45 |     'Sequences longer than this will be truncated, and sequences shorter '
46 |     'than this will be padded.')
47 | 
48 | flags.DEFINE_bool('do_train', True, 'Whether to run training.')
49 | 
50 | flags.DEFINE_bool('do_eval', True, 'Whether to run eval on the dev set.')
51 | 
52 | flags.DEFINE_integer('train_batch_size', 128, 'Total batch size for training.')
53 | 
54 | flags.DEFINE_integer('eval_batch_size', 128, 'Total batch size for eval.')
55 | 
56 | flags.DEFINE_float(
57 |     'learning_rate', 3e-6, 'The initial learning rate for Adam.')
58 | 
59 | flags.DEFINE_integer(
60 |     'num_train_steps', 100000, 'Total number of training steps to perform.')
61 | 
62 | flags.DEFINE_integer(
63 |     'max_eval_examples', None, 'Maximum number of examples to be evaluated.')
64 | 
65 | flags.DEFINE_integer(
66 |     'num_eval_docs', 30,
67 |     'Number of docs per query in the dev and eval files.')
68 | 
69 | flags.DEFINE_string(
70 |     'pad_doc_id', '5500000',
71 |     'ID of the pad document that will be removed from the predictions. This '
72 |     'pad document is added to TF Records whenever the number of retrieved '
73 |     'documents is lower than num_eval_docs. This doc id must be the same used '
74 |     'in the convert_msmarco_to_duobert_tfrecord.py script.')
75 | 
76 | flags.DEFINE_integer(
77 |     'num_warmup_steps', 10000,
78 |     'Number of training steps to perform linear learning rate warmup.')
79 | 
80 | flags.DEFINE_integer(
81 |     'save_checkpoints_steps', 1000,
82 |     'How often to save the model checkpoint.')
83 | 
84 | flags.DEFINE_integer(
85 |     'iterations_per_loop', 1000,
86 |     'How many steps to make in each estimator call.')
87 | 
88 | flags.DEFINE_bool('use_tpu', False, 'Whether to use TPU or GPU/CPU.')
89 | 
90 | tf.flags.DEFINE_string(
91 |     'tpu_name', None,
92 |     'The Cloud TPU to use for training. 
This should be either the name ' 93 | 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 ' 94 | 'url.') 95 | 96 | tf.flags.DEFINE_string( 97 | 'tpu_zone', None, 98 | '[Optional] GCE zone where the Cloud TPU is located in. If not ' 99 | 'specified, we will attempt to automatically detect the GCE project from ' 100 | 'metadata.') 101 | 102 | tf.flags.DEFINE_string( 103 | 'gcp_project', None, 104 | '[Optional] Project name for the Cloud TPU-enabled project. If not ' 105 | 'specified, we will attempt to automatically detect the GCE project from ' 106 | 'metadata.') 107 | 108 | tf.flags.DEFINE_string('master', None, '[Optional] TensorFlow master URL.') 109 | 110 | flags.DEFINE_integer( 111 | 'num_tpu_cores', 8, 112 | 'Only used if `use_tpu` is True. Total number of TPU cores to use.') 113 | 114 | 115 | METRICS_MAP = ['MAP', 'RPrec', 'MRR', 'NDCG', 'MRR@10'] 116 | 117 | 118 | def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, 119 | labels, num_labels, use_one_hot_embeddings): 120 | """Creates a classification model.""" 121 | model = modeling.BertModel( 122 | config=bert_config, 123 | is_training=is_training, 124 | input_ids=input_ids, 125 | input_mask=input_mask, 126 | token_type_ids=segment_ids, 127 | use_one_hot_embeddings=use_one_hot_embeddings) 128 | 129 | output_layer = model.get_pooled_output() 130 | 131 | hidden_size = output_layer.shape[-1].value 132 | 133 | output_weights = tf.get_variable( 134 | "output_weights", [num_labels, hidden_size], 135 | initializer=tf.truncated_normal_initializer(stddev=0.02)) 136 | 137 | output_bias = tf.get_variable( 138 | "output_bias", [num_labels], initializer=tf.zeros_initializer()) 139 | 140 | with tf.variable_scope("loss"): 141 | if is_training: 142 | # I.e., 0.1 dropout 143 | output_layer = tf.nn.dropout(output_layer, keep_prob=0.9) 144 | 145 | logits = tf.matmul(output_layer, output_weights, transpose_b=True) 146 | logits = tf.nn.bias_add(logits, output_bias) 147 | log_probs = tf.nn.log_softmax(logits, axis=-1) 148 | probs = tf.nn.softmax(logits, axis=-1) 149 | 150 | one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32) 151 | 152 | per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1) 153 | loss = tf.reduce_mean(per_example_loss) 154 | 155 | return (loss, per_example_loss, probs) 156 | 157 | 158 | def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate, 159 | num_train_steps, num_warmup_steps, use_tpu, 160 | use_one_hot_embeddings): 161 | """Returns `model_fn` closure for TPUEstimator.""" 162 | 163 | def model_fn(features, labels, mode, params): 164 | """The `model_fn` for TPUEstimator.""" 165 | 166 | tf.logging.info('*** Features ***') 167 | for name in sorted(features.keys()): 168 | tf.logging.info( 169 | ' name = %s, shape = %s' % (name, features[name].shape)) 170 | 171 | input_ids = features['input_ids'] 172 | input_mask = features['input_mask'] 173 | segment_ids = features['segment_ids'] 174 | label_ids = features['label_ids'] 175 | 176 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 177 | (total_loss, per_example_loss, probs) = create_model( 178 | bert_config, is_training, input_ids, input_mask, segment_ids, 179 | label_ids, num_labels, use_one_hot_embeddings) 180 | 181 | tvars = tf.trainable_variables() 182 | 183 | scaffold_fn = None 184 | initialized_variable_names = [] 185 | if init_checkpoint: 186 | (assignment_map, initialized_variable_names 187 | ) = modeling.get_assignment_map_from_checkpoint( 188 | tvars, init_checkpoint) 189 | 
if use_tpu: 190 | def tpu_scaffold(): 191 | tf.train.init_from_checkpoint(init_checkpoint, 192 | assignment_map) 193 | return tf.train.Scaffold() 194 | 195 | scaffold_fn = tpu_scaffold 196 | else: 197 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 198 | 199 | tf.logging.info('**** Trainable Variables ****') 200 | for var in tvars: 201 | init_string = '' 202 | if var.name in initialized_variable_names: 203 | init_string = ', *INIT_FROM_CKPT*' 204 | tf.logging.info(' name = %s, shape = %s%s', var.name, var.shape, 205 | init_string) 206 | 207 | output_spec = None 208 | if mode == tf.estimator.ModeKeys.TRAIN: 209 | 210 | train_op = optimization.create_optimizer( 211 | total_loss, learning_rate, num_train_steps, num_warmup_steps, 212 | use_tpu) 213 | 214 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 215 | mode=mode, 216 | loss=total_loss, 217 | train_op=train_op, 218 | scaffold_fn=scaffold_fn) 219 | 220 | elif mode == tf.estimator.ModeKeys.PREDICT: 221 | output_spec = tf.contrib.tpu.TPUEstimatorSpec( 222 | mode=mode, 223 | predictions={ 224 | 'probs': probs, 225 | 'label_ids': label_ids, 226 | }, 227 | scaffold_fn=scaffold_fn) 228 | 229 | else: 230 | raise ValueError( 231 | 'Only TRAIN and PREDICT modes are supported: %s' % (mode)) 232 | 233 | return output_spec 234 | 235 | return model_fn 236 | 237 | 238 | def input_fn_builder(dataset_path, seq_length, is_training, 239 | max_eval_examples=None, num_skip=0): 240 | """Creates an `input_fn` closure to be passed to TPUEstimator.""" 241 | 242 | def input_fn(params): 243 | """The actual input function.""" 244 | 245 | batch_size = params['batch_size'] 246 | output_buffer_size = batch_size * 1000 247 | 248 | def extract_fn(data_record): 249 | features = { 250 | 'input_ids': tf.FixedLenSequenceFeature( 251 | [], tf.int64, allow_missing=True), 252 | 'segment_ids': tf.FixedLenSequenceFeature( 253 | [], tf.int64, allow_missing=True), 254 | 'label': tf.FixedLenFeature([], tf.int64) 255 | } 256 | sample = tf.parse_single_example(data_record, features) 257 | 258 | input_ids = tf.cast(sample['input_ids'], tf.int32) 259 | segment_ids = tf.cast(sample['segment_ids'], tf.int32) 260 | label_ids = tf.cast(sample['label'], tf.int32) 261 | 262 | input_mask = tf.ones_like(input_ids) 263 | 264 | features = { 265 | 'input_ids': input_ids, 266 | 'segment_ids': segment_ids, 267 | 'input_mask': input_mask, 268 | 'label_ids': label_ids 269 | } 270 | return features 271 | 272 | dataset = tf.data.TFRecordDataset([dataset_path]) 273 | dataset = dataset.map( 274 | extract_fn, num_parallel_calls=4).prefetch(output_buffer_size) 275 | 276 | if is_training: 277 | dataset = dataset.repeat() 278 | dataset = dataset.shuffle(buffer_size=1000) 279 | else: 280 | if num_skip > 0: 281 | dataset = dataset.skip(num_skip) 282 | 283 | if max_eval_examples: 284 | # Use at most this number of examples (debugging only). 
285 | dataset = dataset.take(max_eval_examples) 286 | # pass 287 | 288 | dataset = dataset.padded_batch( 289 | batch_size=batch_size, 290 | padded_shapes={ 291 | 'input_ids': [seq_length], 292 | 'segment_ids': [seq_length], 293 | 'input_mask': [seq_length], 294 | 'label_ids': [] 295 | }, 296 | padding_values={ 297 | 'input_ids': 0, 298 | 'segment_ids': 0, 299 | 'input_mask': 0, 300 | 'label_ids': 0 301 | }, 302 | drop_remainder=True) 303 | 304 | return dataset 305 | return input_fn 306 | 307 | 308 | def main(_): 309 | tf.logging.set_verbosity(tf.logging.INFO) 310 | 311 | if not FLAGS.do_train and not FLAGS.do_eval: 312 | raise ValueError( 313 | 'At least one of `FLAGS.do_train` or `FLAGS.do_eval` must be ' 314 | 'True.') 315 | 316 | bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) 317 | 318 | if FLAGS.max_seq_length > bert_config.max_position_embeddings: 319 | raise ValueError( 320 | 'Cannot use sequence length %d because the BERT model ' 321 | 'was only trained up to sequence length %d' % 322 | (FLAGS.max_seq_length, bert_config.max_position_embeddings)) 323 | 324 | tpu_cluster_resolver = None 325 | if FLAGS.use_tpu and FLAGS.tpu_name: 326 | tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( 327 | FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) 328 | 329 | is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 330 | run_config = tf.contrib.tpu.RunConfig( 331 | cluster=tpu_cluster_resolver, 332 | master=FLAGS.master, 333 | model_dir=FLAGS.output_dir, 334 | save_checkpoints_steps=FLAGS.save_checkpoints_steps, 335 | tpu_config=tf.contrib.tpu.TPUConfig( 336 | iterations_per_loop=FLAGS.iterations_per_loop, 337 | num_shards=FLAGS.num_tpu_cores, 338 | per_host_input_for_training=is_per_host)) 339 | 340 | model_fn = model_fn_builder( 341 | bert_config=bert_config, 342 | num_labels=2, 343 | init_checkpoint=FLAGS.init_checkpoint, 344 | learning_rate=FLAGS.learning_rate, 345 | num_train_steps=FLAGS.num_train_steps, 346 | num_warmup_steps=FLAGS.num_warmup_steps, 347 | use_tpu=FLAGS.use_tpu, 348 | use_one_hot_embeddings=FLAGS.use_tpu) 349 | 350 | # If TPU is not available, this will fall back to normal Estimator on CPU 351 | # or GPU. 
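# A note on the pairwise evaluation performed further down: duoBERT scores
# every ordered pair of candidates, so with num_eval_docs=30 each query
# yields 30 * 29 = 870 pair examples (num_eval_docs2). A document's final
# score is the sum of P(doc_i beats doc_j) over the 29 pairs in which it
# appears first (the SUM aggregation from the paper); the ranking is then
# the argsort of those per-document sums (see pred_docs).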
352 | estimator = tf.contrib.tpu.TPUEstimator( 353 | use_tpu=FLAGS.use_tpu, 354 | model_fn=model_fn, 355 | config=run_config, 356 | train_batch_size=FLAGS.train_batch_size, 357 | eval_batch_size=FLAGS.eval_batch_size, 358 | predict_batch_size=FLAGS.eval_batch_size) 359 | 360 | if FLAGS.do_train: 361 | tf.logging.info('***** Running training *****') 362 | tf.logging.info(' Batch size = %d', FLAGS.train_batch_size) 363 | tf.logging.info(' Num steps = %d', FLAGS.num_train_steps) 364 | train_input_fn = input_fn_builder( 365 | dataset_path=FLAGS.data_dir + '/dataset_train.tf', 366 | seq_length=FLAGS.max_seq_length, 367 | is_training=True) 368 | estimator.train(input_fn=train_input_fn, 369 | max_steps=FLAGS.num_train_steps) 370 | tf.logging.info('Done Training!') 371 | 372 | if FLAGS.do_eval: 373 | num_eval_docs2 = FLAGS.num_eval_docs * (FLAGS.num_eval_docs - 1) 374 | for set_name in ['dev']: 375 | tf.logging.info('***** Running evaluation *****') 376 | tf.logging.info(' Batch size = %d', FLAGS.eval_batch_size) 377 | 378 | predictions_path = ( 379 | FLAGS.output_dir + '/msmarco_predictions_' + set_name + '.tsv') 380 | total_count = 0 381 | if tf.gfile.Exists(predictions_path): 382 | with tf.gfile.Open(predictions_path, 'r') as predictions_file: 383 | total_count = sum(1 for line in predictions_file) 384 | tf.logging.info( 385 | '{} examples already processed. Skipping them.'.format( 386 | total_count / FLAGS.num_eval_docs)) 387 | total_count = total_count * (FLAGS.num_eval_docs - 1) 388 | 389 | query_docids_map = [] 390 | with tf.gfile.Open(FLAGS.data_dir + '/query_doc_ids_' + set_name + 391 | '.txt') as ref_file: 392 | 393 | for line in ref_file: 394 | query_docids_map.append(line.strip().split('\t')) 395 | 396 | max_eval_examples = None 397 | if FLAGS.max_eval_examples: 398 | max_eval_examples = FLAGS.max_eval_examples * num_eval_docs2 399 | 400 | eval_input_fn = input_fn_builder( 401 | dataset_path=(FLAGS.data_dir + '/dataset_' + set_name + 402 | '.tf'), 403 | seq_length=FLAGS.max_seq_length, 404 | is_training=False, 405 | max_eval_examples=max_eval_examples, 406 | num_skip=total_count) 407 | 408 | # ***IMPORTANT NOTE*** 409 | # The logging output produced by the feed queues during evaluation 410 | # is very large (~14M lines for the dev set), which causes the tab 411 | # to crash if you don't have enough memory on your local machine. 412 | # We suppress this frequent logging by setting the verbosity to 413 | # WARN during the evaluation phase. 
414 | tf.logging.set_verbosity(tf.logging.WARN) 415 | 416 | result = estimator.predict(input_fn=eval_input_fn, 417 | yield_single_examples=True) 418 | start_time = time.time() 419 | results = [] 420 | all_metrics = np.zeros(len(METRICS_MAP)) 421 | example_idx = 0 422 | 423 | for item in result: 424 | results.append((item['probs'], item['label_ids'])) 425 | total_count += 1 426 | 427 | if len(results) == num_eval_docs2: 428 | 429 | probs, labels = zip(*results) 430 | probs = np.stack(probs).reshape( 431 | FLAGS.num_eval_docs, FLAGS.num_eval_docs - 1, 2) 432 | labels = np.stack(labels).reshape( 433 | FLAGS.num_eval_docs, FLAGS.num_eval_docs - 1) 434 | 435 | for labels_i in labels: 436 | assert len(set(list(labels_i))) == 1, ( 437 | 'Labels must be all the same.') 438 | 439 | labels = labels[:, 0] 440 | 441 | scores = probs[:, :, 1] 442 | 443 | pred_docs = scores.sum(1).argsort()[::-1] 444 | 445 | gt = set(list(np.where(labels > 0)[0])) 446 | 447 | all_metrics += metrics.metrics( 448 | gt=gt, pred=pred_docs, metrics_map=METRICS_MAP) 449 | 450 | start_idx = total_count - num_eval_docs2 451 | end_idx = total_count 452 | query_ids, doc_ids, _ = zip( 453 | *query_docids_map[start_idx:end_idx]) 454 | assert len(set(query_ids)) == 1, ( 455 | 'Query ids must be all the same.') 456 | query_id = query_ids[0] 457 | 458 | # Unique doc ids are every FLAGS.num_eval_docs - 1 459 | doc_ids = doc_ids[::FLAGS.num_eval_docs - 1] 460 | # Workaround to make mode=a work when the file was not yet 461 | # created. 462 | mode = 'w' 463 | if tf.gfile.Exists(predictions_path): 464 | mode = 'a' 465 | with tf.gfile.Open( 466 | predictions_path, mode) as predictions_file: 467 | for rank, doc_idx in enumerate(pred_docs): 468 | doc_id = doc_ids[doc_idx] 469 | if doc_id != FLAGS.pad_doc_id: 470 | predictions_file.write('{}\t{}\t{}\n'.format( 471 | query_id, doc_id, rank + 1)) 472 | example_idx += 1 473 | results = [] 474 | 475 | if example_idx % 100 == 0: 476 | tf.logging.warn( 477 | 'Read {} examples in {} secs. Metrics so ' 478 | 'far:'.format(example_idx, 479 | int(time.time() - start_time))) 480 | tf.logging.warn(' '.join(METRICS_MAP)) 481 | tf.logging.warn(all_metrics / example_idx) 482 | 483 | # Once the feed queues are finished, we can set the verbosity back 484 | # to INFO. 485 | tf.logging.set_verbosity(tf.logging.INFO) 486 | 487 | all_metrics /= example_idx 488 | 489 | tf.logging.info('Eval {}:'.format(set_name)) 490 | tf.logging.info(' '.join(METRICS_MAP)) 491 | tf.logging.info(all_metrics) 492 | tf.logging.info('Done evaluating {}'.format(set_name)) 493 | 494 | 495 | if __name__ == '__main__': 496 | tf.app.run() 497 | -------------------------------------------------------------------------------- /tokenization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
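# A minimal usage sketch for the classes below (illustrative; "vocab.txt"
# stands in for a standard BERT WordPiece vocabulary file, and the exact
# word pieces and ids depend on that vocabulary):
#
#   >>> tokenizer = FullTokenizer(vocab_file="vocab.txt", do_lower_case=True)
#   >>> tokens = tokenizer.tokenize("hello world")   # e.g. ['hello', 'world']
#   >>> convert_to_bert_input("hello world", 128, tokenizer, add_cls=True)
#   # -> ids for [CLS] hello world [SEP], e.g. [101, 7592, 2088, 102] with
#   #    the standard uncased vocab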
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import unicodedata 23 | import six 24 | import tensorflow as tf 25 | 26 | 27 | def convert_to_bert_input(text, max_seq_length, tokenizer, add_cls): 28 | 29 | tokens = tokenizer.tokenize(text) 30 | 31 | # Account for [CLS] and [SEP] with "- 2" 32 | if len(tokens) > max_seq_length - 2: 33 | tokens = tokens[:max_seq_length - 2] 34 | 35 | # The convention in BERT is: 36 | # (a) For sequence pairs: 37 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 38 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 39 | # (b) For single sequences: 40 | # tokens: [CLS] the dog is hairy . [SEP] 41 | # type_ids: 0 0 0 0 0 0 0 42 | # 43 | # Where "type_ids" are used to indicate whether this is the first 44 | # sequence or the second sequence. The embedding vectors for `type=0` and 45 | # `type=1` were learned during pre-training and are added to the wordpiece 46 | # embedding vector (and position vector). This is not *strictly* necessary 47 | # since the [SEP] token unambigiously separates the sequences, but it makes 48 | # it easier for the model to learn the concept of sequences. 49 | # 50 | # For classification tasks, the first vector (corresponding to [CLS]) is 51 | # used as as the "sentence vector". Note that this only makes sense because 52 | # the entire model is fine-tuned. 53 | if add_cls: 54 | tokens = ["[CLS]"] + tokens 55 | tokens += ["[SEP]"] 56 | 57 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 58 | 59 | return input_ids 60 | 61 | 62 | def convert_to_unicode(text): 63 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 64 | if six.PY3: 65 | if isinstance(text, str): 66 | return text 67 | elif isinstance(text, bytes): 68 | return text.decode("utf-8", "ignore") 69 | else: 70 | raise ValueError("Unsupported string type: %s" % (type(text))) 71 | elif six.PY2: 72 | if isinstance(text, str): 73 | return text.decode("utf-8", "ignore") 74 | elif isinstance(text, unicode): 75 | return text 76 | else: 77 | raise ValueError("Unsupported string type: %s" % (type(text))) 78 | else: 79 | raise ValueError("Not running on Python2 or Python 3?") 80 | 81 | 82 | def printable_text(text): 83 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 84 | 85 | # These functions want `str` for both Python2 and Python3, but in one case 86 | # it's a Unicode string and in the other it's a byte string. 
87 |   if six.PY3:
88 |     if isinstance(text, str):
89 |       return text
90 |     elif isinstance(text, bytes):
91 |       return text.decode("utf-8", "ignore")
92 |     else:
93 |       raise ValueError("Unsupported string type: %s" % (type(text)))
94 |   elif six.PY2:
95 |     if isinstance(text, str):
96 |       return text
97 |     elif isinstance(text, unicode):
98 |       return text.encode("utf-8")
99 |     else:
100 |       raise ValueError("Unsupported string type: %s" % (type(text)))
101 |   else:
102 |     raise ValueError("Not running on Python2 or Python 3?")
103 | 
104 | 
105 | def load_vocab(vocab_file):
106 |   """Loads a vocabulary file into a dictionary."""
107 |   vocab = collections.OrderedDict()
108 |   index = 0
109 |   with tf.gfile.GFile(vocab_file, "r") as reader:
110 |     while True:
111 |       token = convert_to_unicode(reader.readline())
112 |       if not token:
113 |         break
114 |       token = token.strip()
115 |       vocab[token] = index
116 |       index += 1
117 |   return vocab
118 | 
119 | 
120 | def convert_tokens_to_ids(vocab, tokens):
121 |   """Converts a sequence of tokens into ids using the vocab."""
122 |   return [vocab[token] for token in tokens]
123 | 
124 | 
125 | def whitespace_tokenize(text):
126 |   """Runs basic whitespace cleaning and splitting on a piece of text."""
127 |   text = text.strip()
128 |   if not text:
129 |     return []
130 |   tokens = text.split()
131 |   return tokens
132 | 
133 | 
134 | class FullTokenizer(object):
135 |   """Runs end-to-end tokenization."""
136 | 
137 |   def __init__(self, vocab_file, do_lower_case=True):
138 |     self.vocab = load_vocab(vocab_file)
139 |     self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
140 |     self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
141 | 
142 |   def tokenize(self, text):
143 |     return [
144 |         sub_token
145 |         for token in self.basic_tokenizer.tokenize(text)
146 |         for sub_token in self.wordpiece_tokenizer.tokenize(token)
147 |     ]
148 | 
149 |   def convert_tokens_to_ids(self, tokens):
150 |     return convert_tokens_to_ids(self.vocab, tokens)
151 | 
152 | 
153 | class BasicTokenizer(object):
154 |   """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
155 | 
156 |   def __init__(self, do_lower_case=True):
157 |     """Constructs a BasicTokenizer.
158 | 
159 |     Args:
160 |       do_lower_case: Whether to lower case the input.
161 |     """
162 |     self.do_lower_case = do_lower_case
163 | 
164 |   def tokenize(self, text):
165 |     """Tokenizes a piece of text."""
166 |     text = convert_to_unicode(text)
167 |     text = self._clean_text(text)
168 |     orig_tokens = whitespace_tokenize(text)
169 |     split_tokens = []
170 |     for token in orig_tokens:
171 |       if self.do_lower_case:
172 |         token = token.lower()
173 |         token = self._run_strip_accents(token)
174 |       split_tokens.extend(self._run_split_on_punc(token))
175 | 
176 |     output_tokens = whitespace_tokenize(" ".join(split_tokens))
177 |     return output_tokens
178 | 
179 |   def _run_strip_accents(self, text):
180 |     """Strips accents from a piece of text."""
181 |     text = unicodedata.normalize("NFD", text)
182 |     output = []
183 |     for char in text:
184 |       cat = unicodedata.category(char)
185 |       if cat == "Mn":
186 |         continue
187 |       output.append(char)
188 |     return "".join(output)
189 | 
190 |   def _run_split_on_punc(self, text):
191 |     """Splits punctuation on a piece of text."""
192 |     chars = list(text)
193 |     i = 0
194 |     start_new_word = True
195 |     output = []
196 |     while i < len(chars):
197 |       char = chars[i]
198 |       if _is_punctuation(char):
199 |         output.append([char])
200 |         start_new_word = True
201 |       else:
202 |         if start_new_word:
203 |           output.append([])
204 |         start_new_word = False
205 |         output[-1].append(char)
206 |       i += 1
207 | 
208 |     return ["".join(x) for x in output]
209 | 
210 |   def _clean_text(self, text):
211 |     """Performs invalid character removal and whitespace cleanup on text."""
212 |     output = []
213 |     for char in text:
214 |       cp = ord(char)
215 |       if cp == 0 or cp == 0xfffd or _is_control(char):
216 |         continue
217 |       if _is_whitespace(char):
218 |         output.append(" ")
219 |       else:
220 |         output.append(char)
221 |     return "".join(output)
222 | 
223 | 
224 | class WordpieceTokenizer(object):
225 |   """Runs WordPiece tokenization."""
226 | 
227 |   def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
228 |     self.vocab = vocab
229 |     self.unk_token = unk_token
230 |     self.max_input_chars_per_word = max_input_chars_per_word
231 | 
232 |   def tokenize(self, text):
233 |     """Tokenizes a piece of text into its word pieces.
234 | 
235 |     This uses a greedy longest-match-first algorithm to perform tokenization
236 |     using the given vocabulary.
237 | 
238 |     For example:
239 |       input = "unaffable"
240 |       output = ["un", "##aff", "##able"]
241 | 
242 |     Args:
243 |       text: A single token or whitespace separated tokens. This should have
244 |         already been passed through `BasicTokenizer`.
245 | 
246 |     Returns:
247 |       A list of wordpiece tokens.
248 | """ 249 | 250 | text = convert_to_unicode(text) 251 | 252 | output_tokens = [] 253 | for token in whitespace_tokenize(text): 254 | chars = list(token) 255 | if len(chars) > self.max_input_chars_per_word: 256 | output_tokens.append(self.unk_token) 257 | continue 258 | 259 | is_bad = False 260 | start = 0 261 | sub_tokens = [] 262 | while start < len(chars): 263 | end = len(chars) 264 | cur_substr = None 265 | while start < end: 266 | substr = "".join(chars[start:end]) 267 | if start > 0: 268 | substr = "##" + substr 269 | if substr in self.vocab: 270 | cur_substr = substr 271 | break 272 | end -= 1 273 | if cur_substr is None: 274 | is_bad = True 275 | break 276 | sub_tokens.append(cur_substr) 277 | start = end 278 | 279 | if is_bad: 280 | output_tokens.append(self.unk_token) 281 | else: 282 | output_tokens.extend(sub_tokens) 283 | return output_tokens 284 | 285 | 286 | def _is_whitespace(char): 287 | """Checks whether `chars` is a whitespace character.""" 288 | # \t, \n, and \r are technically contorl characters but we treat them 289 | # as whitespace since they are generally considered as such. 290 | if char == " " or char == "\t" or char == "\n" or char == "\r": 291 | return True 292 | cat = unicodedata.category(char) 293 | if cat == "Zs": 294 | return True 295 | return False 296 | 297 | 298 | def _is_control(char): 299 | """Checks whether `chars` is a control character.""" 300 | # These are technically control characters but we count them as whitespace 301 | # characters. 302 | if char == "\t" or char == "\n" or char == "\r": 303 | return False 304 | cat = unicodedata.category(char) 305 | if cat.startswith("C"): 306 | return True 307 | return False 308 | 309 | 310 | def _is_punctuation(char): 311 | """Checks whether `chars` is a punctuation character.""" 312 | cp = ord(char) 313 | # We treat all non-letter/number ASCII as punctuation. 314 | # Characters such as "^", "$", and "`" are not in the Unicode 315 | # Punctuation class but we treat them as punctuation anyways, for 316 | # consistency. 317 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 318 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 319 | return True 320 | cat = unicodedata.category(char) 321 | if cat.startswith("P"): 322 | return True 323 | return False 324 | -------------------------------------------------------------------------------- /tokenization_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile
import six
import tensorflow as tf
import tokenization


class TokenizationTest(tf.test.TestCase):

  def test_full_tokenizer(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
        "runn", "##ing", ","
    ]
    with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
      if six.PY2:
        vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
      else:
        vocab_writer.write("".join(
            [x + "\n" for x in vocab_tokens]).encode("utf-8"))

      vocab_file = vocab_writer.name

    tokenizer = tokenization.FullTokenizer(vocab_file)
    os.unlink(vocab_file)

    tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
    self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

    self.assertAllEqual(
        tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

  def test_chinese(self):
    tokenizer = tokenization.BasicTokenizer()

    self.assertAllEqual(
        tokenizer.tokenize(u"ah\u535A\u63A8zz"),
        [u"ah", u"\u535A", u"\u63A8", u"zz"])

  def test_basic_tokenizer_lower(self):
    tokenizer = tokenization.BasicTokenizer(do_lower_case=True)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
        ["hello", "!", "how", "are", "you", "?"])
    self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

  def test_basic_tokenizer_no_lower(self):
    tokenizer = tokenization.BasicTokenizer(do_lower_case=False)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
"), 71 | ["HeLLo", "!", "how", "Are", "yoU", "?"]) 72 | 73 | def test_wordpiece_tokenizer(self): 74 | vocab_tokens = [ 75 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 76 | "##ing" 77 | ] 78 | 79 | vocab = {} 80 | for (i, token) in enumerate(vocab_tokens): 81 | vocab[token] = i 82 | tokenizer = tokenization.WordpieceTokenizer(vocab=vocab) 83 | 84 | self.assertAllEqual(tokenizer.tokenize(""), []) 85 | 86 | self.assertAllEqual( 87 | tokenizer.tokenize("unwanted running"), 88 | ["un", "##want", "##ed", "runn", "##ing"]) 89 | 90 | self.assertAllEqual( 91 | tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"]) 92 | 93 | def test_convert_tokens_to_ids(self): 94 | vocab_tokens = [ 95 | "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", 96 | "##ing" 97 | ] 98 | 99 | vocab = {} 100 | for (i, token) in enumerate(vocab_tokens): 101 | vocab[token] = i 102 | 103 | self.assertAllEqual( 104 | tokenization.convert_tokens_to_ids( 105 | vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9]) 106 | 107 | def test_is_whitespace(self): 108 | self.assertTrue(tokenization._is_whitespace(u" ")) 109 | self.assertTrue(tokenization._is_whitespace(u"\t")) 110 | self.assertTrue(tokenization._is_whitespace(u"\r")) 111 | self.assertTrue(tokenization._is_whitespace(u"\n")) 112 | self.assertTrue(tokenization._is_whitespace(u"\u00A0")) 113 | 114 | self.assertFalse(tokenization._is_whitespace(u"A")) 115 | self.assertFalse(tokenization._is_whitespace(u"-")) 116 | 117 | def test_is_control(self): 118 | self.assertTrue(tokenization._is_control(u"\u0005")) 119 | 120 | self.assertFalse(tokenization._is_control(u"A")) 121 | self.assertFalse(tokenization._is_control(u" ")) 122 | self.assertFalse(tokenization._is_control(u"\t")) 123 | self.assertFalse(tokenization._is_control(u"\r")) 124 | 125 | def test_is_punctuation(self): 126 | self.assertTrue(tokenization._is_punctuation(u"-")) 127 | self.assertTrue(tokenization._is_punctuation(u"$")) 128 | self.assertTrue(tokenization._is_punctuation(u"`")) 129 | self.assertTrue(tokenization._is_punctuation(u".")) 130 | 131 | self.assertFalse(tokenization._is_punctuation(u"A")) 132 | self.assertFalse(tokenization._is_punctuation(u" ")) 133 | 134 | 135 | if __name__ == "__main__": 136 | tf.test.main() 137 | --------------------------------------------------------------------------------