├── data │   ├── requirements.txt │   └── download.sh ├── .gitignore ├── additional_documents.npy ├── .gitmodules ├── predictor.py ├── checkpoint_converter.py ├── benchmark.py ├── tests.py ├── data.py ├── model.py ├── README.md ├── LICENSE └── run_finetune.py /data/requirements.txt: -------------------------------------------------------------------------------- 1 | gsutil -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | checkpoint.pt 3 | 4 | /data/* 5 | !/data/download.sh 6 | !/data/requirements.txt -------------------------------------------------------------------------------- /additional_documents.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qqaatw/pytorch-realm-orqa/HEAD/additional_documents.npy -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "transformers"] 2 | path = transformers 3 | url = https://github.com/qqaatw/transformers 4 | branch = add_realmqa 5 | -------------------------------------------------------------------------------- /data/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | gsutil -m cp -r \ 3 | "gs://realm-data/cc_news_pretrained/" \ 4 | "gs://realm-data/orqa_nq_model_from_realm" \ 5 | "gs://realm-data/orqa_wq_model_from_realm" \ 6 | "gs://orqa-data/enwiki-20181220/" \ 7 | . -------------------------------------------------------------------------------- /predictor.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | 5 | from model import get_openqa, add_additional_documents 6 | 7 | from transformers.models.realm.modeling_realm import logger 8 | from transformers.utils import logging 9 | 10 | logger.setLevel(logging.INFO) 11 | torch.set_printoptions(precision=8) 12 | 13 | 14 | def get_arg_parser(): 15 | parser = ArgumentParser() 16 | 17 | parser.add_argument("--question", type=str, required=True, 18 | help="Input question.") 19 | parser.add_argument("--checkpoint_pretrained_name", type=str, default=r"google/realm-orqa-nq-openqa", 20 | help="Checkpoint name or path.") 21 | parser.add_argument("--additional_documents_path", type=str, default=None, 22 | help="Additional document entries for retrieval. 
Must be .npy format.") 23 | 24 | return parser 25 | 26 | def main(args): 27 | openqa = get_openqa(args) 28 | tokenizer = openqa.retriever.tokenizer 29 | 30 | if args.additional_documents_path is not None: 31 | add_additional_documents(openqa, args.additional_documents_path) 32 | 33 | question_ids = tokenizer(args.question, return_tensors="pt").input_ids 34 | 35 | with torch.no_grad(): 36 | outputs = openqa( 37 | input_ids=question_ids, 38 | return_dict=True, 39 | ) 40 | 41 | predicted_answer = tokenizer.decode(outputs.predicted_answer_ids) 42 | 43 | print(f"Question: {args.question}\nAnswer: {predicted_answer}") 44 | 45 | return predicted_answer 46 | 47 | if __name__ == "__main__": 48 | parser = get_arg_parser() 49 | args = parser.parse_args() 50 | main(args) -------------------------------------------------------------------------------- /checkpoint_converter.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from argparse import ArgumentParser 3 | from transformers import RealmConfig 4 | from transformers.models.realm.modeling_realm import logger 5 | 6 | from model import get_openqa_tf_finetuned, get_openqa_tf_pretrained 7 | 8 | logger.setLevel(logging.INFO) 9 | 10 | def get_arg_parser(): 11 | parser = ArgumentParser() 12 | 13 | # ./data/enwiki-20181220/blocks.tfr 14 | parser.add_argument("--block_records_path", type=str, required=True, 15 | help="Block records path.") 16 | # ./data/cc_news_pretrained/embedder/encoded/encoded.ckpt 17 | parser.add_argument("--block_emb_path", type=str, required=True, 18 | help="Block embeddings path.") 19 | 20 | pretrained_group = parser.add_argument_group("pretrained conversion") 21 | 22 | pretrained_group.add_argument("--embedder_path", type=str, default=r"./data/cc_news_pretrained/embedder/variables/variables", 23 | help="Pretrained embedder path.") 24 | pretrained_group.add_argument("--bert_path", type=str, default=r"./data/cc_news_pretrained/bert/variables/variables", 25 | help="Pretrained bert path.") 26 | 27 | finetuned_group = parser.add_argument_group("finetuned conversion") 28 | 29 | finetuned_group.add_argument("--checkpoint_path", type=str, default=r"./data/orqa_nq_model_from_realm/export/best_default/checkpoint/model.ckpt-300000", 30 | help="Finetuned checkpoint path.") 31 | 32 | parser.add_argument("--output_path", type=str, default=r"./converted_model/", 33 | help="Converted checkpoint path.") 34 | parser.add_argument("--from_pretrained", action="store_true", 35 | help="Whether to convert from a pretrained checkpoint or a finetuned checkpoint.") 36 | 37 | return parser 38 | 39 | def main(args): 40 | config = RealmConfig() 41 | 42 | if args.from_pretrained: 43 | model = get_openqa_tf_pretrained(args, config) 44 | else: 45 | model = get_openqa_tf_finetuned(args, config) 46 | 47 | model.save_pretrained(args.output_path) 48 | model.retriever.save_pretrained(args.output_path) 49 | 50 | if __name__ == "__main__": 51 | parser = get_arg_parser() 52 | args = parser.parse_args() 53 | main(args) -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from argparse import ArgumentParser 3 | from transformers import RealmConfig 4 | from tqdm import tqdm 5 | 6 | from data import load as load_dataset 7 | from data import DataCollator 8 | from run_finetune import compute_eval_metrics 9 | from model import get_openqa 10 | 11 | 12 | def get_arg_parser(): 13 | 
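    """Build the command-line argument parser for the benchmark script."""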
parser = ArgumentParser() 14 | 15 | # Data 16 | parser.add_argument("--dataset_name_path", type=str, default=r"natural_questions", 17 | choices=["natural_questions", "web_questions"]) 18 | parser.add_argument("--dataset_cache_dir", type=str, default=r"./data/dataset_cache_dir/") 19 | parser.add_argument("--dev_ratio", type=float, default=0.1) 20 | parser.add_argument("--max_answer_tokens", type=int, default=5) 21 | 22 | # Model 23 | parser.add_argument("--checkpoint_pretrained_name", type=str, default=r"google/realm-orqa-nq-openqa") 24 | 25 | # Config 26 | parser.add_argument("--device", type=str, default="cpu") 27 | 28 | return parser 29 | 30 | def main(args): 31 | config = RealmConfig(searcher_beam_size=10) 32 | 33 | openqa = get_openqa(args, config=config) 34 | openqa.to(args.device) 35 | tokenizer = openqa.retriever.tokenizer 36 | 37 | # Setup data 38 | _, _, eval_dataset = load_dataset(args) 39 | data_collector = DataCollator(args, tokenizer) 40 | eval_dataloader = torch.utils.data.DataLoader( 41 | dataset=eval_dataset, 42 | batch_size=1, 43 | shuffle=False, 44 | collate_fn=data_collector 45 | ) 46 | print(eval_dataset) 47 | 48 | all_metrics = [] 49 | for batch in tqdm(eval_dataloader): 50 | question, answer_texts, answer_ids = batch 51 | question_ids = tokenizer(question, return_tensors="pt").input_ids 52 | 53 | with torch.no_grad(): 54 | outputs = openqa( 55 | input_ids=question_ids.to(args.device), 56 | answer_ids=answer_ids, 57 | return_dict=True, 58 | ) 59 | 60 | predicted_answer = tokenizer.decode(outputs.predicted_answer_ids) 61 | all_metrics.append(compute_eval_metrics(answer_texts, predicted_answer, outputs.reader_output)) 62 | 63 | stacked_metrics = { 64 | metric_key : torch.stack((*map(lambda metrics: metrics[metric_key], all_metrics),)) for metric_key in all_metrics[0].keys() 65 | } 66 | 67 | print('\n'.join(map(lambda metric: f"{metric[0]}:{metric[1].type(torch.float32).mean()}", stacked_metrics.items()))) 68 | 69 | if __name__ == "__main__": 70 | parser = get_arg_parser() 71 | args = parser.parse_args() 72 | main(args) -------------------------------------------------------------------------------- /tests.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | import shutil 4 | import tempfile 5 | import torch 6 | 7 | import predictor 8 | import run_finetune 9 | 10 | 11 | class Tester(unittest.TestCase): 12 | def setUp(self): 13 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 14 | self.pt_checkpoint_pretrained_name = "./export_nq_newqa/realm-orqa-nq-openqa" 15 | self.model_dir = "./" 16 | self.temp_dir = tempfile.mkdtemp(dir="./") 17 | self.pt_checkpoint_name = "checkpoint" 18 | self.pt_checkpoint_step = 158330 19 | 20 | def tearDown(self) -> None: 21 | shutil.rmtree(self.temp_dir) 22 | 23 | def test_predictor(self): 24 | parser = predictor.get_arg_parser() 25 | args = parser.parse_args([ 26 | "--question", "Who is the pioneer in modern computer science?", 27 | "--checkpoint_pretrained_name", self.pt_checkpoint_pretrained_name, 28 | ]) 29 | answer = predictor.main(args) 30 | 31 | self.assertEqual(answer, "alan mathison turing") 32 | 33 | def test_predictor_with_additional_documents(self): 34 | parser = predictor.get_arg_parser() 35 | args = parser.parse_args([ 36 | "--question", "What is the previous name of Meta Platform, Inc.?", 37 | "--checkpoint_pretrained_name", self.pt_checkpoint_pretrained_name, 38 | "--additional_documents_path", "additional_documents.npy", 39 | ]) 40 | answer = 
predictor.main(args) 41 | 42 | self.assertEqual(answer, "facebook, inc.") 43 | 44 | def test_finetune(self): 45 | parser = run_finetune.get_arg_parser() 46 | args = parser.parse_args([ 47 | "--is_train", 48 | "--num_training_steps", "5", 49 | "--dataset_name_path", "dummy", 50 | "--model_dir", self.temp_dir, 51 | "--checkpoint_pretrained_name", self.pt_checkpoint_pretrained_name, 52 | "--checkpoint_name", self.pt_checkpoint_name, 53 | "--device", self.device, 54 | ]) 55 | run_finetune.main(args) 56 | 57 | self.assertTrue(os.path.isdir(os.path.join(self.temp_dir, f"{self.pt_checkpoint_name}-5"))) 58 | self.assertEqual(len(os.listdir(os.path.join(self.temp_dir, f"{self.pt_checkpoint_name}-5"))), 7) 59 | self.assertTrue(os.path.isfile("fine-tuning.log")) 60 | 61 | def test_finetune_eval(self): 62 | parser = run_finetune.get_arg_parser() 63 | args = parser.parse_args([ 64 | "--dataset_name_path", "dummy", 65 | "--model_dir", self.model_dir, 66 | "--checkpoint_name", self.pt_checkpoint_name, 67 | "--checkpoint_step", str(self.pt_checkpoint_step), 68 | "--device", self.device, 69 | ]) 70 | run_finetune.main(args) 71 | 72 | self.assertTrue(os.path.isfile("fine-tuning.log")) -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import string 4 | import unicodedata 5 | 6 | from datasets import load_dataset 7 | 8 | 9 | def normalize_answer(s): 10 | """Normalize answer. (Directly copied from ORQA codebase)""" 11 | s = unicodedata.normalize("NFD", s) 12 | 13 | def remove_articles(text): 14 | return re.sub(r"\b(a|an|the)\b", " ", text) 15 | 16 | def white_space_fix(text): 17 | return " ".join(text.split()) 18 | 19 | def remove_punc(text): 20 | exclude = set(string.punctuation) 21 | return "".join(ch for ch in text if ch not in exclude) 22 | 23 | def lower(text): 24 | return text.lower() 25 | 26 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 27 | 28 | def load(args): 29 | """Load dataset.""" 30 | if os.path.isdir(args.dataset_name_path): 31 | raise ValueError("Dataset path currently not supported.") 32 | 33 | if args.dataset_name_path == "natural_questions": 34 | return load_nq(args) 35 | elif args.dataset_name_path == "web_questions": 36 | return load_wq(args) 37 | elif args.dataset_name_path == "dummy": 38 | return load_dummy(args) 39 | else: 40 | raise ValueError("Invalid dataset name or path") 41 | 42 | def load_nq(args): 43 | """Load Natural Questions (NQ).""" 44 | def filter_fn(example): 45 | """Keep examples with at least one short answer of at most max_answer_tokens tokens.""" 46 | for short_answer in example['annotations.short_answers']: 47 | if len(short_answer) != 0: 48 | for i in range(len(short_answer['text'])): 49 | if short_answer['end_token'][i] - short_answer['start_token'][i] <= args.max_answer_tokens: 50 | return True 51 | return False 52 | 53 | def map_fn(example): 54 | """Unify dataset structures.""" 55 | return { 56 | "question": example["question.text"], 57 | "answers": [answer["text"] for answer in example["annotations.short_answers"]] 58 | } 59 | 60 | dataset = load_dataset(args.dataset_name_path, cache_dir=os.path.abspath(args.dataset_cache_dir)) 61 | 62 | # Remove unused columns and flatten structure.
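# NQ's test set is hidden, so a development set is split off from the training set and the official validation split serves as the eval set.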
63 | training_dev_dataset = dataset['train'].train_test_split(test_size=args.dev_ratio, shuffle=False) 64 | training_dataset = training_dev_dataset['train'].remove_columns(['id', 'document']).flatten() 65 | dev_dataset = training_dev_dataset['test'].remove_columns(['id', 'document']).flatten() 66 | eval_dataset = dataset['validation'].remove_columns(['id', 'document']).flatten() 67 | 68 | # Perform filtering and mapping 69 | filtered_training_dataset = training_dataset.filter(filter_fn).map(map_fn) 70 | filtered_dev_dataset = dev_dataset.filter(filter_fn).map(map_fn) 71 | filtered_eval_dataset = eval_dataset.filter(filter_fn).map(map_fn) 72 | 73 | # An example of each dataset should contain the following columns: 74 | # example["question"] 75 | # example["answers"][num_answers] 76 | return filtered_training_dataset, filtered_dev_dataset, filtered_eval_dataset 77 | 78 | def load_wq(args): 79 | """Load WebQuestions (WQ).""" 80 | dataset = load_dataset(args.dataset_name_path, cache_dir=os.path.abspath(args.dataset_cache_dir)) 81 | 82 | # Remove unused columns and flatten structure. 83 | training_dev_dataset = dataset['train'].train_test_split(test_size=args.dev_ratio, shuffle=False) 84 | training_dataset = training_dev_dataset['train'].remove_columns(['url']) 85 | dev_dataset = training_dev_dataset['test'].remove_columns(['url']) 86 | eval_dataset = dataset['test'].remove_columns(['url']) 87 | 88 | # No need to filter 89 | filtered_training_dataset = training_dataset 90 | filtered_dev_dataset = dev_dataset 91 | filtered_eval_dataset = eval_dataset 92 | 93 | # An example of each dataset should contain the following columns: 94 | # example["question"] 95 | # example["answers"][num_answers] 96 | return filtered_training_dataset, filtered_dev_dataset, filtered_eval_dataset 97 | 98 | def load_dummy(args): 99 | dataset = [ 100 | { 101 | "question": "What is the previous name of Meta Platform, Inc.?", 102 | "answers": [ 103 | "facebook, inc.", 104 | ], 105 | }, 106 | { 107 | "question": "Who is the pioneer in modern computer science?", 108 | "answers": [ 109 | "alan mathison turing", 110 | ], 111 | }, 112 | ] 113 | return dataset, dataset, dataset 114 | 115 | class DataCollator(object): 116 | def __init__(self, args, tokenizer): 117 | self.args = args 118 | self.tokenizer = tokenizer 119 | def __call__(self, batch): 120 | example = batch[0] 121 | question = example["question"] 122 | answer_texts = [] 123 | for answer in example["answers"]: 124 | answer_texts += [answer] if isinstance(answer, str) else answer 125 | answer_texts = list(set(answer_texts)) 126 | if len(answer_texts) != 0: 127 | answer_ids = self.tokenizer( 128 | answer_texts, 129 | add_special_tokens=False, 130 | return_token_type_ids=False, 131 | return_attention_mask=False, 132 | ).input_ids 133 | else: 134 | answer_ids = [[-1]] 135 | return question, answer_texts, answer_ids -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from transformers import ( 4 | RealmConfig, 5 | RealmReader, 6 | RealmRetriever, 7 | RealmScorer, 8 | RealmForOpenQA, 9 | RealmTokenizerFast, 10 | load_tf_weights_in_realm, 11 | ) 12 | from transformers.models.realm.retrieval_realm import convert_tfrecord_to_np 13 | 14 | 15 | def add_additional_documents(openqa, additional_documents_path): 16 | documents = np.load(additional_documents_path, allow_pickle=True) 17 | total_documents = documents.shape[0]
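# Extend both the raw text blocks (block_records) and their embeddings (block_emb) below, and update num_block_records, so the searcher can score and retrieve the new documents.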
18 | 19 | retriever = openqa.retriever 20 | tokenizer = openqa.retriever.tokenizer 21 | 22 | # docs 23 | retriever.block_records = np.concatenate((retriever.block_records, documents), axis=0) 24 | 25 | # embeds 26 | documents = [doc.decode() for doc in documents] 27 | inputs = tokenizer(documents, padding=True, truncation=True, return_tensors="pt") 28 | 29 | with torch.no_grad(): 30 | projected_score = openqa.embedder(**inputs, return_dict=True).projected_score 31 | openqa.block_emb = torch.cat((openqa.block_emb, projected_score), dim=0) 32 | 33 | openqa.config.num_block_records += total_documents 34 | 35 | def get_openqa_tf_finetuned(args, config=None): 36 | if config is None: 37 | config = RealmConfig(hidden_act="gelu_new") 38 | 39 | tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-embedder", do_lower_case=True) 40 | 41 | block_records = convert_tfrecord_to_np(args.block_records_path, config.num_block_records) 42 | retriever = RealmRetriever(block_records, tokenizer) 43 | 44 | openqa = RealmForOpenQA(config, retriever) 45 | 46 | openqa = load_tf_weights_in_realm( 47 | openqa, 48 | config, 49 | args.checkpoint_path, 50 | ) 51 | 52 | openqa = load_tf_weights_in_realm( 53 | openqa, 54 | config, 55 | args.block_emb_path, 56 | ) 57 | 58 | openqa.eval() 59 | 60 | return openqa 61 | 62 | def get_openqa_tf_pretrained(args, config=None): 63 | if config is None: 64 | config = RealmConfig(hidden_act="gelu_new") 65 | 66 | tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-embedder", do_lower_case=True) 67 | 68 | block_records = convert_tfrecord_to_np(args.block_records_path, config.num_block_records) 69 | retriever = RealmRetriever(block_records, tokenizer) 70 | 71 | openqa = RealmForOpenQA(config, retriever) 72 | 73 | openqa = load_tf_weights_in_realm( 74 | openqa, 75 | config, 76 | args.bert_path, 77 | ) 78 | 79 | openqa = load_tf_weights_in_realm( 80 | openqa, 81 | config, 82 | args.embedder_path, 83 | ) 84 | 85 | openqa = load_tf_weights_in_realm( 86 | openqa, 87 | config, 88 | args.block_emb_path, 89 | ) 90 | 91 | openqa.eval() 92 | 93 | return openqa 94 | 95 | def get_openqa(args, config=None): 96 | if config is None: 97 | config = RealmConfig(hidden_act="gelu_new") 98 | 99 | retriever = RealmRetriever.from_pretrained(args.checkpoint_pretrained_name) 100 | 101 | openqa = RealmForOpenQA.from_pretrained( 102 | args.checkpoint_pretrained_name, 103 | retriever=retriever, 104 | config=config, 105 | ) 106 | openqa.eval() 107 | 108 | return openqa 109 | 110 | def get_scorer_reader_tokenizer_tf(args, config=None): 111 | if config is None: 112 | config = RealmConfig(hidden_act="gelu_new") 113 | scorer = RealmScorer(config, args.block_records_path) 114 | 115 | # Load retriever weights 116 | scorer = load_tf_weights_in_realm( 117 | scorer, 118 | config, 119 | args.retriever_path, 120 | ) 121 | 122 | # Load block_emb weights 123 | scorer = load_tf_weights_in_realm( 124 | scorer, 125 | config, 126 | args.block_emb_path, 127 | ) 128 | scorer.eval() 129 | 130 | reader = RealmReader.from_pretrained( 131 | args.checkpoint_path, 132 | config=config, 133 | from_tf=True, 134 | ) 135 | reader.eval() 136 | 137 | tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-embedder", do_lower_case=True) 138 | 139 | return scorer, reader, tokenizer 140 | 141 | def get_scorer_reader_tokenizer_pt_pretrained(args, config=None): 142 | if config is None: 143 | config = RealmConfig(hidden_act="gelu_new") 144 | scorer = 
RealmScorer.from_pretrained(args.retriever_pretrained_name, args.block_records_path, config=config) 145 | 146 | # Load block_emb weights 147 | scorer = load_tf_weights_in_realm( 148 | scorer, 149 | config, 150 | args.block_emb_path, 151 | ) 152 | scorer.eval() 153 | 154 | reader = RealmReader.from_pretrained(args.checkpoint_pretrained_name, config=config) 155 | reader.eval() 156 | 157 | tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-embedder", do_lower_case=True) 158 | 159 | return scorer, reader, tokenizer 160 | 161 | def get_scorer_reader_tokenizer_pt_finetuned(args, config=None): 162 | if config is None: 163 | config = RealmConfig(hidden_act="gelu_new") 164 | scorer = RealmScorer.from_pretrained(args.retriever_pretrained_name, args.block_records_path, config=config) 165 | scorer.eval() 166 | 167 | reader = RealmReader.from_pretrained(args.checkpoint_pretrained_name, config=config) 168 | reader.eval() 169 | 170 | tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-embedder", do_lower_case=True) 171 | 172 | return scorer, reader, tokenizer 173 | 174 | def get_scorer_reader_tokenizer(args, config=None): 175 | if config is None: 176 | config = RealmConfig(hidden_act="gelu_new") 177 | 178 | scorer = RealmScorer(config, args.block_records_path) 179 | reader = RealmReader(config) 180 | tokenizer = RealmTokenizerFast.from_pretrained("google/realm-cc-news-pretrained-embedder", do_lower_case=True) 181 | 182 | return scorer, reader, tokenizer -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Reimplementation of REALM and ORQA 2 | 3 | This is a PyTorch reimplementation of REALM ([paper](https://arxiv.org/abs/2002.08909), [codebase](https://github.com/google-research/language/tree/master/language/realm)) and ORQA ([paper](https://arxiv.org/abs/1906.00300), [codebase](https://github.com/google-research/language/tree/master/language/orqa)).
4 | 5 | 6 | *The term `Scorer` corresponds to the pretraining `Retriever` in the REALM paper; we renamed it to `Scorer` to avoid a conflict with the fine-tuning `Retriever`.* 7 | 8 | 9 | ## Prerequisites 10 | 11 | ```bash 12 | pip install -U transformers apache_beam 13 | ``` 14 | 15 | ## Data 16 | 17 | To download TensorFlow checkpoints and preprocessed data, please follow the instructions below: 18 | 19 | ```bash 20 | cd data 21 | pip install -U -r requirements.txt 22 | sh download.sh 23 | ``` 24 | 25 | To convert pretrained TensorFlow checkpoints like **CC-News** to PyTorch checkpoints: 26 | 27 | ```bash 28 | python checkpoint_converter.py \ 29 | --block_records_path "data/enwiki-20181220/blocks.tfr" \ 30 | --block_emb_path "./data/cc_news_pretrained/embedder/encoded/encoded.ckpt" \ 31 | --embedder_path "./data/cc_news_pretrained/embedder/variables/variables" \ 32 | --bert_path "./data/cc_news_pretrained/bert/variables/variables" \ 33 | --output_path path_to_save_converted_model \ 34 | --from_pretrained 35 | ``` 36 | 37 | To convert finetuned TensorFlow checkpoints like **Natural Questions (NQ)** to PyTorch checkpoints: 38 | 39 | ```bash 40 | python checkpoint_converter.py \ 41 | --block_records_path "data/enwiki-20181220/blocks.tfr" \ 42 | --block_emb_path "./data/cc_news_pretrained/embedder/encoded/encoded.ckpt" \ 43 | --checkpoint_path "./data/orqa_nq_model_from_realm/export/best_default/checkpoint/model.ckpt-300000" \ 44 | --output_path path_to_save_converted_model 45 | ``` 46 | 47 | Additional documents are stored as a NumPy array of byte strings, built like this: 48 | 49 | ```python 50 | array( 51 | [b"Meta Platforms, Inc., doing business as Meta and formerly known as Facebook, Inc., is an American multinational technology conglomerate based in Menlo Park, California. The company is the parent organization of Facebook, Instagram, and WhatsApp, among other subsidiaries. Meta is one of the world's most valuable companies. It is one of the Big Five American information technology companies, alongside Google (Alphabet Inc.), Amazon, Apple, and Microsoft", 52 | b"Coronavirus disease 2019 (COVID-19) is a contagious disease caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2). The first known case was identified in Wuhan, China, in December 2019. The disease has since spread worldwide, leading to an ongoing pandemic."], 53 | dtype=object 54 | ) 55 | ``` 56 | 57 | ## Predict 58 | 59 | The default checkpoint is `google/realm-orqa-nq-openqa`. To change it, specify `--checkpoint_pretrained_name`, which can be a local path or a model name on the Hugging Face model hub. 60 | 61 | ```bash 62 | python predictor.py --question "Who is the pioneer in modern computer science?" 63 | 64 | Output: alan mathison turing 65 | ``` 66 | 67 | Loading additional documents for retrieval: 68 | 69 | ```bash 70 | python predictor.py \ 71 | --question "What is the previous name of Meta Platform, Inc.?" \ 72 | --additional_documents_path "additional_documents.npy" 73 | 74 | Output: facebook, inc. 75 | ``` 76 | 77 | ## Finetune (Experimental) 78 | 79 | The default finetuning dataset is **Natural Questions (NQ)**. To load a custom dataset, change the loading function in `data.py`, as sketched below.
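A minimal sketch of such a loader, modeled on `load_dummy` in `data.py`; the `load_custom` name, the `my_dataset.json` path, and its schema are hypothetical:

```python
import json

def load_custom(args):
    """Sketch: read a JSON list of {"question": str, "answers": [str, ...]} records."""
    with open("my_dataset.json") as f:  # hypothetical file
        examples = json.load(f)
    # Each example must expose example["question"] and example["answers"],
    # the structure that DataCollator and the evaluation code expect.
    return examples, examples, examples  # training / dev / eval splits

# Register it in load() with an extra branch, e.g.:
#     elif args.dataset_name_path == "custom":
#         return load_custom(args)
```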
80 | 81 | Training: 82 | 83 | ```bash 84 | python run_finetune.py --is_train \ 85 | --checkpoint_pretrained_name "google/realm-cc-news-pretrained-openqa" \ 86 | --checkpoint_name "checkpoint" \ 87 | --dataset_name_path "natural_questions" \ 88 | --model_dir "./out/" \ 89 | --num_epochs 2 \ 90 | --device cuda 91 | ``` 92 | 93 | Loading additional documents for retrieval: 94 | 95 | ```bash 96 | --additional_documents_path "additional_documents.npy" 97 | ``` 98 | 99 | The output model and the additional documents will be stored in the `./out/checkpoint-x` directory, where `x` is the training step at save time. So if you added additional documents during training, there is no need to specify them again during evaluation. 100 | 101 | Evaluation: 102 | 103 | ```bash 104 | python run_finetune.py \ 105 | --checkpoint_name "checkpoint" \ 106 | --checkpoint_step 50000 \ 107 | --dataset_name_path "natural_questions" \ 108 | --model_dir "./out/" \ 109 | --device cuda 110 | ``` 111 | 112 | ## Benchmark 113 | 114 | ### Natural Questions (NQ) 115 | 116 | Using brute-force matrix multiplication searcher: 117 | 118 | ```bash 119 | python benchmark.py \ 120 | --dataset_name_path natural_questions \ 121 | --checkpoint_pretrained_name google/realm-orqa-nq-openqa \ 122 | --device cuda 123 | ``` 124 | 125 | Outputs with brute-force matrix multiplication searcher: 126 | 127 | ``` 128 | exact_match:0.410526305437088 129 | official_exact_match:0.4041551351547241 # value in the paper: ~0.404 130 | reader_oracle:0.7193905711174011 131 | top_5_match:0.7218836545944214 132 | top_10_match:0.7218836545944214 133 | top_50_match:0.7218836545944214 134 | top_100_match:0.7218836545944214 135 | top_500_match:0.7218836545944214 136 | top_1000_match:0.7218836545944214 137 | top_5000_match:0.7218836545944214 138 | ``` 139 | 140 | ~Using ScaNN searcher~ (currently not available): 141 | 142 | ```bash 143 | python run_finetune.py --benchmark --use_scann 144 | ``` 145 | 146 | Outputs with ScaNN searcher: 147 | 148 | ``` 149 | exact_match:0.4019390642642975 150 | official_exact_match:0.3972299098968506 # value in the paper: ~0.404 151 | reader_oracle:0.7041551470756531 152 | top_5_match:0.7058171629905701 153 | top_10_match:0.7058171629905701 154 | top_50_match:0.7058171629905701 155 | top_100_match:0.7058171629905701 156 | top_500_match:0.7058171629905701 157 | top_1000_match:0.7058171629905701 158 | top_5000_match:0.7058171629905701 159 | ``` 160 | 161 | ### Web Questions (WQ) 162 | 163 | Using brute-force matrix multiplication searcher: 164 | 165 | ```bash 166 | python benchmark.py \ 167 | --dataset_name_path web_questions \ 168 | --checkpoint_pretrained_name google/realm-orqa-wq-openqa \ 169 | --device cuda 170 | ``` 171 | 172 | Outputs with brute-force matrix multiplication searcher: 173 | 174 | ``` 175 | exact_match:0.4345472455024719 176 | official_exact_match:0.41683071851730347 # value in the paper: ~0.407 177 | reader_oracle:0.6929134130477905 178 | top_5_match:0.6934055089950562 179 | top_10_match:0.6934055089950562 180 | top_50_match:0.6934055089950562 181 | top_100_match:0.6934055089950562 182 | top_500_match:0.6934055089950562 183 | top_1000_match:0.6934055089950562 184 | top_5000_match:0.6934055089950562 185 | ``` 186 | 187 | ~Using ScaNN searcher~ (currently not available): 188 | 189 | ```bash 190 | python run_finetune.py \ 191 | --benchmark \ 192 | --use_scann \ 193 | --dataset_name_path web_questions \ 194 | --retriever_path ./data/orqa_wq_model_from_realm/export/best_default/checkpoint/model.ckpt-205020 \ 195 |
--checkpoint_path ./data/orqa_wq_model_from_realm/export/best_default/checkpoint/model.ckpt-205020 196 | ``` 197 | 198 | Outputs with ScaNN searcher: 199 | 200 | ``` 201 | exact_match:0.42814961075782776 202 | official_exact_match:0.4114173352718353 # value in the paper: ~0.407 203 | reader_oracle:0.6840550899505615 204 | top_5_match:0.6840550899505615 205 | top_10_match:0.6840550899505615 206 | top_50_match:0.6840550899505615 207 | top_100_match:0.6840550899505615 208 | top_500_match:0.6840550899505615 209 | top_1000_match:0.6840550899505615 210 | top_5000_match:0.6840550899505615 211 | ``` 212 | 213 | ## License 214 | 215 | Apache License 2.0 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021 qqaatw 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /run_finetune.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from argparse import ArgumentParser 4 | 5 | import torch 6 | from torch.nn.utils import clip_grad_norm_ 7 | from tqdm import tqdm 8 | 9 | from data import DataCollator, normalize_answer 10 | from data import load as load_dataset 11 | from model import get_openqa, add_additional_documents 12 | from transformers import RealmConfig, RealmForOpenQA, RealmRetriever, get_linear_schedule_with_warmup 13 | from transformers.models.realm.modeling_realm import logger as model_logger 14 | 15 | model_logger.setLevel(logging.INFO) 16 | 17 | formatter = logging.Formatter('%(asctime)s - %(message)s') 18 | file_handler = logging.FileHandler('fine-tuning.log') 19 | file_handler.setLevel(logging.INFO) 20 | file_handler.setFormatter(formatter) 21 | stream_handler = logging.StreamHandler() 22 | stream_handler.setLevel(logging.INFO) 23 | stream_handler.setFormatter(formatter) 24 | logger = logging.getLogger() 25 | logger.setLevel(logging.INFO) 26 | logger.addHandler(file_handler) 27 | logger.addHandler(stream_handler) 28 | 29 | torch.set_printoptions(precision=8) 30 | 31 | MAX_EPOCHS = 2 32 | 33 | 34 | def get_arg_parser(): 35 | parser = ArgumentParser() 36 | 37 | # Data processing 38 | parser.add_argument("--dev_ratio", type=float, default=0.1, 39 | help="The ratio of the development set that will be split from the training set.") 40 | parser.add_argument("--max_answer_tokens", type=int, default=5, 41 | help="Only answers of at most max_answer_tokens tokens will be used for training and evaluation.") 42 | 43 | # Datasets and directories 44 | parser.add_argument("--dataset_name_path", type=str, default=r"natural_questions", 45 | help="Dataset name or path. Currently available datasets: natural_questions and web_questions. See data.py for more details.") 46 | parser.add_argument("--dataset_cache_dir", type=str, default=r"./data/dataset_cache_dir/", 47 | help="Directory storing dataset caches.") 48 | parser.add_argument("--model_dir", type=str, default=r"./out/", 49 | help="Directory storing resulting models. 
") 50 | 51 | # Training hparams 52 | parser.add_argument("--ckpt_interval", type=int, default=5000, 53 | help="Number of steps the checkpoint will be saved.") 54 | parser.add_argument("--device", type=str, default='cpu', 55 | help="Device used for training and evaluation.") 56 | parser.add_argument("--is_train", action="store_true", 57 | help="If specified, training mode is set; otherwise, evaluation mode is set.") 58 | parser.add_argument("--learning_rate", type=float, default=1e-5, 59 | help="Learning rate.") 60 | parser.add_argument("--searcher_beam_size", type=int, default=5000, 61 | help="Searcher (Retriever) beam size.") 62 | parser.add_argument("--reader_beam_size", type=int, default=5, 63 | help="Reader beam size.") 64 | group = parser.add_mutually_exclusive_group() 65 | group.add_argument("--num_training_steps", type=int, default=100, 66 | help="Number of training steps.") 67 | group.add_argument("--num_epochs", type=int, default=0, 68 | help="Number of training epochs.") 69 | 70 | # Evaluation hparams 71 | parser.add_argument("--checkpoint_name", type=str, default="checkpoint", 72 | help="Checkpoint name for evalutaion.") 73 | parser.add_argument("--checkpoint_step", type=int, default=5000, 74 | help="Checkpoint step for evalutaion.") 75 | 76 | # Model path 77 | parser.add_argument("--checkpoint_pretrained_name", type=str, default=r"google/realm-cc-news-pretrained-openqa", 78 | help="Pretrained checkpoint for fine-tuning.") 79 | parser.add_argument("--additional_documents_path", type=str, default=None, 80 | help="Additional document entries for retrieval. Must be .npy format.") 81 | 82 | return parser 83 | 84 | def compute_eval_metrics(labels, predicted_answer, reader_output): 85 | """Compute eval metrics.""" 86 | # [] 87 | exact_match = torch.index_select( 88 | torch.index_select( 89 | reader_output.reader_correct, 90 | dim=0, 91 | index=reader_output.block_idx 92 | ), 93 | dim=1, 94 | index=reader_output.candidate 95 | ) 96 | 97 | def _official_exact_match(predicted_answer, references): 98 | return torch.tensor(max( 99 | [normalize_answer(predicted_answer) == normalize_answer(reference) for reference in references] 100 | )) 101 | 102 | official_exact_match = _official_exact_match(predicted_answer, labels) 103 | 104 | eval_metric = dict( 105 | exact_match=exact_match[0][0], 106 | official_exact_match=official_exact_match, 107 | reader_oracle=torch.any(reader_output.reader_correct) 108 | ) 109 | 110 | for k in (5, 10, 50, 100, 500, 1000, 5000): 111 | eval_metric["top_{}_match".format(k)] = torch.any(reader_output.retriever_correct[:k]) 112 | return eval_metric 113 | 114 | def main(args): 115 | 116 | training_dataset, dev_dataset, eval_dataset = load_dataset(args) 117 | 118 | if args.is_train: 119 | global_step = 1 120 | starting_epoch = 1 121 | 122 | config = RealmConfig( 123 | searcher_beam_size=args.searcher_beam_size, 124 | reader_beam_size=args.reader_beam_size, 125 | ) 126 | 127 | openqa = get_openqa(args, config) 128 | retriever = openqa.retriever 129 | tokenizer = openqa.retriever.tokenizer 130 | 131 | if args.additional_documents_path is not None: 132 | add_additional_documents(openqa, args.additional_documents_path) 133 | 134 | openqa.to(args.device) 135 | 136 | # Setup data 137 | logging.info(training_dataset) 138 | logging.info(dev_dataset) 139 | 140 | data_collector = DataCollator(args, tokenizer) 141 | train_dataloader = torch.utils.data.DataLoader( 142 | dataset=training_dataset, 143 | batch_size=1, 144 | shuffle=True, 145 | collate_fn=data_collector 146 | ) 147 | 
eval_dataloader = torch.utils.data.DataLoader( 148 | dataset=dev_dataset, 149 | batch_size=1, 150 | shuffle=False, 151 | collate_fn=data_collector 152 | ) 153 | 154 | if args.num_epochs == 0: 155 | args.num_epochs = MAX_EPOCHS 156 | else: 157 | args.num_training_steps = args.num_epochs * len(train_dataloader) 158 | 159 | # Optimizer 160 | # See: https://github.com/huggingface/transformers/blob/e239fc3b0baf1171079a5e0177a69254350a063b/examples/pytorch/language-modeling/run_mlm_no_trainer.py#L456-L468 161 | no_decay = ["bias", "LayerNorm.weight"] 162 | optimizer_grouped_parameters = [ 163 | { 164 | "params": [p for n, p in openqa.named_parameters() if not any(nd in n for nd in no_decay)], 165 | "weight_decay": 0.01, 166 | }, 167 | { 168 | "params": [p for n, p in openqa.named_parameters() if any(nd in n for nd in no_decay)], 169 | "weight_decay": 0.0, 170 | }, 171 | ] 172 | 173 | optimizer = torch.optim.AdamW( 174 | optimizer_grouped_parameters, 175 | lr=args.learning_rate, 176 | weight_decay=0.01, 177 | eps=1e-6, 178 | ) 179 | lr_scheduler = get_linear_schedule_with_warmup( 180 | optimizer=optimizer, 181 | num_warmup_steps=min(10000, max(100, 182 | int(args.num_training_steps / 10))), 183 | num_training_steps=args.num_training_steps, 184 | ) 185 | 186 | for epoch in range(starting_epoch, args.num_epochs + 1): 187 | 188 | # Setup training mode 189 | openqa.train() 190 | 191 | for batch in train_dataloader: 192 | optimizer.zero_grad() 193 | question, answer_texts, answer_ids = batch 194 | 195 | question_ids = tokenizer(question, return_tensors="pt").input_ids 196 | reader_output, predicted_answer_ids = openqa( 197 | input_ids=question_ids.to(args.device), 198 | answer_ids=answer_ids, 199 | return_dict=False, 200 | ) 201 | 202 | predicted_answer = tokenizer.decode(predicted_answer_ids) 203 | 204 | reader_output.loss.backward() 205 | clip_grad_norm_(openqa.parameters(), 1.0, norm_type=2.0, error_if_nonfinite=False) 206 | 207 | optimizer.step() 208 | lr_scheduler.step() 209 | 210 | logging.info( 211 | f"Epoch: {epoch}, Step: {global_step}, Retriever Loss: {reader_output.retriever_loss.mean()}, Reader Loss: {reader_output.reader_loss.mean()}\nQuestion: {question}, Gold Answer: {tokenizer.batch_decode(answer_ids) if answer_ids != [[-1]] else None}, Predicted Answer: {predicted_answer}" 212 | ) 213 | 214 | if global_step % args.ckpt_interval == 0: 215 | logging.info(f"Saving checkpoint at step {global_step}") 216 | openqa.save_pretrained(os.path.join(args.model_dir, f"{args.checkpoint_name}-{global_step}")) 217 | 218 | global_step += 1 219 | if global_step >= args.num_training_steps: 220 | break 221 | 222 | # Setup eval mode 223 | openqa.eval() 224 | all_metrics = [] 225 | 226 | for batch in tqdm(eval_dataloader): 227 | question, answer_texts, answer_ids = batch 228 | 229 | question_ids = tokenizer(question, return_tensors="pt").input_ids 230 | with torch.no_grad(): 231 | outputs = openqa( 232 | input_ids=question_ids.to(args.device), 233 | answer_ids=answer_ids, 234 | return_dict=True, 235 | ) 236 | 237 | predicted_answer = tokenizer.decode(outputs.predicted_answer_ids) 238 | all_metrics.append(compute_eval_metrics(answer_texts, predicted_answer, outputs.reader_output)) 239 | 240 | stacked_metrics = { 241 | metric_key : torch.stack((*map(lambda metrics: metrics[metric_key], all_metrics),)) for metric_key in all_metrics[0].keys() 242 | } 243 | logging.info(f"Step: {global_step}, Epoch: {epoch}") 244 | logging.info('\n'.join(map(lambda metric: f"{metric[0]}:{metric[1].type(torch.float32).mean()}", 
stacked_metrics.items()))) 245 | 246 | if global_step >= args.num_training_steps: 247 | break 248 | 249 | logging.info(f"Saving final checkpoint at step {global_step}") 250 | openqa.save_pretrained(os.path.join(args.model_dir, f"{args.checkpoint_name}-{global_step}")) 251 | retriever.save_pretrained(os.path.join(args.model_dir, f"{args.checkpoint_name}-{global_step}")) 252 | else: 253 | retriever = RealmRetriever.from_pretrained(os.path.join(args.model_dir, f"{args.checkpoint_name}-{args.checkpoint_step}")) 254 | tokenizer = retriever.tokenizer 255 | openqa = RealmForOpenQA.from_pretrained(os.path.join(args.model_dir, f"{args.checkpoint_name}-{args.checkpoint_step}"), retriever) 256 | 257 | openqa.config.searcher_beam_size = args.searcher_beam_size 258 | openqa.config.reader_beam_size = args.reader_beam_size 259 | 260 | if args.additional_documents_path is not None: 261 | add_additional_documents(openqa, args.additional_documents_path) 262 | 263 | # Setup eval mode 264 | openqa.eval() 265 | openqa.to(args.device) 266 | 267 | # Setup data 268 | logging.info(eval_dataset) 269 | data_collector = DataCollator(args, tokenizer) 270 | eval_dataloader = torch.utils.data.DataLoader( 271 | dataset=eval_dataset, 272 | batch_size=1, 273 | shuffle=False, 274 | collate_fn=data_collector 275 | ) 276 | 277 | all_metrics = [] 278 | for batch in tqdm(eval_dataloader): 279 | question, answer_texts, answer_ids = batch 280 | question_ids = tokenizer(question, return_tensors="pt").input_ids 281 | 282 | with torch.no_grad(): 283 | outputs = openqa( 284 | input_ids=question_ids.to(args.device), 285 | answer_ids=answer_ids, 286 | return_dict=True, 287 | ) 288 | 289 | predicted_answer = tokenizer.decode(outputs.predicted_answer_ids) 290 | all_metrics.append(compute_eval_metrics(answer_texts, predicted_answer, outputs.reader_output)) 291 | 292 | 293 | stacked_metrics = { 294 | metric_key : torch.stack((*map(lambda metrics: metrics[metric_key], all_metrics),)) for metric_key in all_metrics[0].keys() 295 | } 296 | 297 | logging.info('\n'.join(map(lambda metric: f"{metric[0]}:{metric[1].type(torch.float32).mean()}", stacked_metrics.items()))) 298 | 299 | 300 | if __name__ == "__main__": 301 | logging.info("Test logging") 302 | 303 | parser = get_arg_parser() 304 | args = parser.parse_args() 305 | main(args) --------------------------------------------------------------------------------