├── .gitignore ├── README.md ├── examples ├── prepare │ └── save_batch_generator.py ├── run_bidaf │ └── main.py ├── run_rnet_hkust │ └── main.py └── run_rnet_sogou │ └── main.py ├── pytorch_mrc ├── __init__.py ├── data │ ├── __init__.py │ ├── batch_generator.py │ └── vocabulary.py ├── dataset │ ├── __init__.py │ ├── base_dataset.py │ └── squad.py ├── model │ ├── __init__.py │ ├── base_model.py │ ├── bidaf.py │ ├── rnet_hkust.py │ └── rnet_sogou.py ├── nn │ ├── __init__.py │ ├── attention.py │ ├── dropout.py │ ├── layers.py │ ├── recurrent.py │ ├── similarity_function.py │ └── util.py ├── train │ ├── __init__.py │ └── trainer.py └── utils │ ├── __init__.py │ └── tokenizer.py └── unit_tests ├── data ├── batch_generator_test.py └── vocabulary_test.py └── model ├── bidaf_test.py ├── rnet_hkust_test.py └── rnet_sogou_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | .venv/ 84 | venv/ 85 | ENV/ 86 | 87 | # Spyder project settings 88 | .spyderproject 89 | 90 | # Rope project settings 91 | .ropeproject 92 | 93 | # pycharm 94 | .idea 95 | 96 | # nohup 97 | *.out 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Machine Reading Comprehension Toolkit 2 | ## Introduction 3 | **The PyTorch Machine Reading Comprehension (PyTorch-MRC)** toolkit, which was rewritten on the basis of Sogou Machine Reading Comprehension (SMRC), was designed for the fast and efficient development of modern machine comprehension models, including both published models and original prototypes. 4 | 5 | ## Need Teammates! 6 | The whole project is written and maintained by me alone, so I hope that some friends who like NLP and are interested in MRC will work with me to maintain it. Please contact me by email at yingzq0116@163.com. 7 | 8 | ## Toolkit Architecture 9 | 10 | ## Installation 11 | 12 | ## Quick Start 13 | 14 | ## Modules 15 | 1. `data` 16 | - vocabulary.py: Vocabulary building, word/char index mapping and pretrained word embedding building. 
17 | - batch_generator.py: Mapping words and tags to indices, wrapping them in a [*PyTorch Dataset*](https://pytorch.org/docs/stable/data.html?highlight=dataset#torch.utils.data.Dataset), padding length-variable features dynamically, transforming all of the features into tensors, and batching them with a [*PyTorch DataLoader*](https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader). 18 | 2. `dataset` 19 | - squad.py: Dataset reader and evaluator (from the official code) for SQuAD 1.1 20 | 3. `examples` 21 | - Examples for running different models; the specified data paths must be provided to run them 22 | 4. `model` 23 | - Base class and subclasses of models; every model should inherit the base class 24 | - Built-in models such as BiDAF and R-Net 25 | 5. `nn` 26 | - attention.py: Attention functions such as BiAttention, Trilinear and MultiHeadAttention 27 | - layers.py: Commonly used layers for machine reading comprehension, such as VariationalDropout, Highway and PointerNetwork 28 | - recurrent.py: Special wrappers for LSTM and GRU 29 | - similarity\_function.py: Similarity functions for attention, such as dot_product, trilinear, and symmetric_nolinear 30 | - util.py: Useful utility functions such as sequence_mask, weighted_sum and masked_softmax 31 | 6. `utils` 32 | - tokenizer.py: Tokenizers that can be used for both English and Chinese 33 | - feature_extractor.py: Extracts linguistic features used in some papers, e.g., POS, NER, and Lemma 34 | 35 | ## Custom Model and Dataset 36 | 37 | ## Performance 38 | 39 | ### F1/EM score on SQuAD 1.1 dev set 40 | | Model | Toolkit implementation | Original paper | 41 | | --- | --- | --- | 42 | | BiDAF | 77.8/68.1 | 77.3/67.7 | 43 | | R-Net (sogou) | 79.0/70.5 | 79.5/71.1 | 44 | | R-Net (hkust) | 78.3/69.8 | 79.5/71.1 | 45 | | IARNN-Word | - | - | 46 | | IARNN-hidden | - | - | 47 | | DrQA | - | 78.8/69.5 | 48 | | FusionNet | - | 82.5/74.1 | 49 | | QANet | - | 82.7/73.6 | 50 | | BERT-Base | - | 88.5/80.8 | 51 | 52 | ### F1/EM score on SQuAD 2.0 dev set 53 | 54 | ### F1 score on CoQA dev set 55 | 56 | ## Contact information 57 | For help or issues using this toolkit, please submit a GitHub issue or contact me by email at yingzq0116@163.com. 58 | 59 | ## Additional information 60 | When implementing an MRC model, **I sometimes did not reproduce the paper's model exactly**, because some parts of the paper were unclear to me or did not seem to play a decisive role. The deviations are noted in the corresponding model files (see, e.g., the docstring of `rnet_hkust.py`).
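For a quick impression of how the modules above fit together, here is a minimal end-to-end sketch, condensed from `examples/prepare/save_batch_generator.py` and `examples/run_bidaf/main.py` (the data paths are placeholders that you must point at your own files):

```python
import torch
from pytorch_mrc.dataset.squad import SquadReader, SquadEvaluator
from pytorch_mrc.data.vocabulary import Vocabulary
from pytorch_mrc.data.batch_generator import BatchGenerator
from pytorch_mrc.model.bidaf import BiDAF

train_file, dev_file = 'train-v1.1.json', 'dev-v1.1.json'  # placeholder paths
embedding_file = 'glove.840B.300d.txt'                     # placeholder path

# read the raw data and build the vocabulary plus pretrained embeddings
reader = SquadReader()
train_data, eval_data = reader.read(train_file), reader.read(dev_file)
vocab = Vocabulary(do_lowercase=True)
vocab.build_vocab(train_data + eval_data, min_word_count=3, min_char_count=10)
vocab.make_word_embedding(embedding_file)

# wrap the data in batch generators
train_bg, eval_bg = BatchGenerator(), BatchGenerator()
train_bg.build(vocab, train_data, batch_size=32, shuffle=True)
eval_bg.build(vocab, eval_data, batch_size=32)

# build, compile and train the model (embedding size must match the GloVe file)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = BiDAF(vocab, device, pretrained_word_embedding=vocab.get_word_embedding(),
              word_embedding_size=300)
model.compile()
evaluator = SquadEvaluator(dev_file)
model.train_and_evaluate(train_bg, eval_bg, evaluator, epochs=20, episodes=2)
```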
61 | 62 | ## Reference Code 63 | - [sogou MRCToolkit](https://github.com/sogou/SMRCToolkit) 64 | - [allenai bi-att-flow](https://github.com/allenai/bi-att-flow) 65 | - [BiDAF-pytorch](https://github.com/galsang/BiDAF-pytorch.git) 66 | 67 | ## Reference Paper 68 | - [Match-LSTM](https://arxiv.org/pdf/1608.07905.pdf) 69 | - [BIDAF](https://arxiv.org/pdf/1611.01603.pdf) 70 | - [R-NET](https://www.microsoft.com/en-us/research/wp-content/uploads/2017/05/r-net.pdf) 71 | - [Highway Networks](https://arxiv.org/pdf/1505.00387.pdf) 72 | - [CNN for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf) 73 | - [Pointer Networks](https://arxiv.org/pdf/1506.03134.pdf) 74 | - [Variational Dropout](https://arxiv.org/pdf/1512.05287.pdf) 75 | 76 | ## License 77 | -------------------------------------------------------------------------------- /examples/prepare/save_batch_generator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../..') 3 | import logging 4 | from pytorch_mrc.data.batch_generator import BatchGenerator 5 | from pytorch_mrc.dataset.squad import SquadReader 6 | from pytorch_mrc.data.vocabulary import Vocabulary 7 | 8 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 9 | 10 | # define some variables 11 | EMB_DIM = 300 12 | BATCH_SIZE = 50 13 | FINE_GRAINED = True 14 | DO_LOWERCASE = True 15 | 16 | # define the data paths 17 | train_file = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/train-v1.1.json' 18 | dev_file = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/dev-v1.1.json' 19 | embedding_file = '/home/len/yingzq/nlp/mrc_dataset/word_embeddings/glove.840B.300d.txt' 20 | 21 | # the paths to save the files; a lowercased vocabulary corresponds to 'uncased' files 22 | vocab_file = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/vocab_data/vocab_{}d_{}.pkl'.format(EMB_DIM, 'uncased' if DO_LOWERCASE else 'cased') 23 | bg_train_file = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/bg_data/bg_train_{}b_{}d_{}.pkl'.format( 24 | BATCH_SIZE, EMB_DIM, 'uncased' if DO_LOWERCASE else 'cased') 25 | bg_eval_file = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/bg_data/bg_eval_{}b_{}d_{}.pkl'.format( 26 | BATCH_SIZE, EMB_DIM, 'uncased' if DO_LOWERCASE else 'cased') 27 | 28 | # read data 29 | reader = SquadReader(fine_grained=FINE_GRAINED) 30 | train_data = reader.read(train_file) 31 | eval_data = reader.read(dev_file) 32 | 33 | # build vocab and embedding 34 | vocab = Vocabulary(do_lowercase=DO_LOWERCASE) 35 | vocab.build_vocab(train_data + eval_data, min_word_count=3, min_char_count=10) 36 | vocab.make_word_embedding(embedding_file) 37 | vocab.save(vocab_file) 38 | 39 | logging.info("building train batch generator...") 40 | train_batch_generator = BatchGenerator() 41 | train_batch_generator.build(vocab, train_data, batch_size=BATCH_SIZE, shuffle=True) 42 | 43 | logging.info("building eval batch generator...") 44 | eval_batch_generator = BatchGenerator() 45 | eval_batch_generator.build(vocab, eval_data, batch_size=BATCH_SIZE) 46 | 47 | train_batch_generator.save(bg_train_file) 48 | eval_batch_generator.save(bg_eval_file) 49 | 50 | print('done!') 51 | -------------------------------------------------------------------------------- /examples/run_bidaf/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../..') 3 | import logging 4 | import torch 5 | from pytorch_mrc.dataset.squad import SquadReader, SquadEvaluator 6 | from pytorch_mrc.model.bidaf import BiDAF 7 | from
pytorch_mrc.data.batch_generator import BatchGenerator 8 | 9 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 10 | 11 | bg_folder = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/bg_data/' 12 | train_bg_file = bg_folder + "bg_train_32b_100d.pkl" 13 | eval_bg_file = bg_folder + "bg_eval_32b_100d.pkl" 14 | dev_file = "/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/dev-v1.1.json" 15 | 16 | reader = SquadReader() 17 | eval_data = reader.read(dev_file) 18 | evaluator = SquadEvaluator(dev_file) 19 | 20 | train_batch_generator = BatchGenerator() 21 | eval_batch_generator = BatchGenerator() 22 | train_batch_generator.load(train_bg_file) 23 | eval_batch_generator.load(eval_bg_file) 24 | vocab = train_batch_generator.get_vocab() 25 | 26 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 27 | model = BiDAF(vocab, device, pretrained_word_embedding=vocab.get_word_embedding()) 28 | model.compile() 29 | model.train_and_evaluate(train_batch_generator, eval_batch_generator, evaluator, epochs=20, episodes=2) 30 | -------------------------------------------------------------------------------- /examples/run_rnet_hkust/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../..') 3 | import logging 4 | import torch 5 | from pytorch_mrc.dataset.squad import SquadReader, SquadEvaluator 6 | from pytorch_mrc.model.rnet_hkust import RNET 7 | from pytorch_mrc.data.batch_generator import BatchGenerator 8 | 9 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 10 | bg_folder = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/bg_data/' 11 | train_bg_file = bg_folder + "bg_train_32b_300d_cased.pkl" 12 | eval_bg_file = bg_folder + "bg_eval_32b_300d_cased.pkl" 13 | dev_file = "/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/dev-v1.1.json" 14 | 15 | reader = SquadReader(fine_grained=True) 16 | eval_data = reader.read(dev_file) 17 | evaluator = SquadEvaluator(dev_file) 18 | 19 | train_batch_generator = BatchGenerator() 20 | eval_batch_generator = BatchGenerator() 21 | train_batch_generator.load(train_bg_file) 22 | eval_batch_generator.load(eval_bg_file) 23 | vocab = train_batch_generator.get_vocab() 24 | 25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | model = RNET(vocab, device, pretrained_word_embedding=vocab.get_word_embedding(), word_embedding_size=300) 27 | model.compile(initial_lr=0.001) 28 | model.train_and_evaluate(train_batch_generator, eval_batch_generator, evaluator, epochs=100, episodes=2) 29 | -------------------------------------------------------------------------------- /examples/run_rnet_sogou/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../..') 3 | import logging 4 | import torch 5 | from pytorch_mrc.dataset.squad import SquadReader, SquadEvaluator 6 | from pytorch_mrc.model.rnet_sogou import RNET 7 | from pytorch_mrc.data.batch_generator import BatchGenerator 8 | 9 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 10 | bg_folder = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/bg_data/' 11 | train_bg_file = bg_folder + "bg_train_50b_300d_uncased.pkl" 12 | eval_bg_file = bg_folder + "bg_eval_50b_300d_uncased.pkl" 13 | dev_file = "/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/dev-v1.1.json" 14 | 15 | reader = SquadReader(fine_grained=True) 16 | eval_data = 
reader.read(dev_file) 17 | evaluator = SquadEvaluator(dev_file) 18 | 19 | train_batch_generator = BatchGenerator() 20 | eval_batch_generator = BatchGenerator() 21 | train_batch_generator.load(train_bg_file) 22 | eval_batch_generator.load(eval_bg_file) 23 | vocab = train_batch_generator.get_vocab() 24 | 25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | model = RNET(vocab, device, pretrained_word_embedding=vocab.get_word_embedding(), word_embedding_size=300) 27 | model.compile('adam', 0.001) 28 | model.train_and_evaluate(train_batch_generator, eval_batch_generator, evaluator, epochs=200, episodes=2) 29 | -------------------------------------------------------------------------------- /pytorch_mrc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZiqiang/PyTorch-MRCToolkit/23ec374338509a2e61d0060a43f2d6a32fe337d3/pytorch_mrc/__init__.py -------------------------------------------------------------------------------- /pytorch_mrc/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZiqiang/PyTorch-MRCToolkit/23ec374338509a2e61d0060a43f2d6a32fe337d3/pytorch_mrc/data/__init__.py -------------------------------------------------------------------------------- /pytorch_mrc/data/batch_generator.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import logging 3 | import multiprocessing 4 | 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | 8 | 9 | class BatchGenerator(object): 10 | def __init__(self): 11 | pass 12 | 13 | def build(self, vocab, instances, 14 | batch_size=32, 15 | shuffle=False, 16 | max_context_len=400, 17 | max_question_len=50, 18 | use_char=True, 19 | max_word_len=30, 20 | additional_fields=None, 21 | feature_vocab=None, 22 | num_parallel_calls=0): 23 | """ 24 | Build the batch generator: build the dataset, then build the dataloader 25 | """ 26 | self.vocab = vocab 27 | self.instances = instances 28 | self.batch_size = batch_size 29 | self.shuffle = shuffle 30 | self.max_context_len = max_context_len 31 | self.max_question_len = max_question_len 32 | self.use_char = use_char 33 | self.max_word_len = max_word_len 34 | self.additional_fields = additional_fields if additional_fields is not None else list() 35 | self.feature_vocab = feature_vocab if feature_vocab is not None else dict() 36 | self.num_parallel_calls = num_parallel_calls if num_parallel_calls > 0 else multiprocessing.cpu_count() // 2 37 | if self.instances is None or len(self.instances) == 0: 38 | raise ValueError('empty instances!!') 39 | 40 | self.dataset = self._build_dataset_pipeline() 41 | self.dataloader = self._build_dataloader_pipeline() 42 | 43 | def save(self, file_path): 44 | """ 45 | Save the attributes of the BatchGenerator 46 | """ 47 | logging.info("Saving BatchGenerator at {}".format(file_path)) 48 | # pickle can't save generator and dataloader, so we skip those fields 49 | dataloader_tmp = self.dataloader 50 | self.generator, self.dataloader = None, None 51 | with open(file_path, "wb") as f: 52 | pickle.dump(self.__dict__, f) 53 | self.dataloader = dataloader_tmp 54 | 55 | def load(self, file_path): 56 | """ 57 | Load the saved file and rebuild the BatchGenerator 58 | """ 59 | logging.info("Loading BatchGenerator at {}".format(file_path)) 60 | with open(file_path, 'rb') as f: 61 | bg_data = pickle.load(f) 62 | self.__dict__.update(bg_data) 63 | # we don't save the generator and the dataloader, so we rebuild them here 64 | self.generator = None 65 | self.dataloader = self._build_dataloader_pipeline()
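    # Usage sketch (an illustrative note, not part of the original file):
    # a BatchGenerator is either built from scratch and saved, or loaded from
    # a saved pickle; in both cases, call init() to (re)create the underlying
    # generator -- typically once per epoch -- before draining batches with next():
    #
    #     bg = BatchGenerator()
    #     bg.load('bg_train.pkl')    # or: bg.build(vocab, instances, ...)
    #     bg.init()
    #     for _ in range(len(bg.dataloader)):
    #         batch = bg.next()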
66 | 67 | def init(self): 68 | """ 69 | Initialize the dataloader generator 70 | """ 71 | self.generator = BatchGenerator._generator(self.dataloader) 72 | 73 | def next(self): 74 | """ 75 | Get the next batch of data from the dataloader 76 | """ 77 | if self.generator is None: 78 | raise Exception('you must call init() before calling next().') 79 | return next(self.generator) 80 | 81 | def get_dataset_size(self): 82 | return len(self.dataset) 83 | 84 | def get_batch_size(self): 85 | return self.batch_size 86 | 87 | def get_raw_dataset(self): 88 | """ 89 | When evaluating and predicting, you may need the raw dataset to generate answers 90 | """ 91 | return self.instances 92 | 93 | def get_vocab(self): 94 | return self.vocab 95 | 96 | @staticmethod 97 | def _generator(dataloader): 98 | for batch_data in dataloader: 99 | yield batch_data 100 | 101 | @staticmethod 102 | def _dynamic_padding(example, pad_len, pad_thing): 103 | example = (example + [pad_thing] * (pad_len - len(example)))[:pad_len] 104 | return example 105 | 106 | @staticmethod 107 | def _detect_input_type(instance, additional_fields=None): 108 | instance_keys = instance.keys() 109 | fields = ['context_tokens', 'question_tokens', 'answer_start', 'answer_end'] 110 | try: 111 | for f in fields: 112 | assert f in instance_keys 113 | except Exception: 114 | raise ValueError('An instance must contain at least the four fields "context_tokens", "question_tokens", \ 115 | "answer_start" and "answer_end"!') 116 | 117 | if additional_fields is not None and isinstance(additional_fields, list): 118 | fields.extend(additional_fields) 119 | 120 | def get_type(value): 121 | if isinstance(value, float): 122 | return torch.float32 123 | elif isinstance(value, int): 124 | return torch.int64 125 | elif isinstance(value, str): 126 | return str 127 | elif isinstance(value, bool): 128 | return bool 129 | else: 130 | return None 131 | 132 | input_type = {'answer_start': None, 'answer_end': None} 133 | 134 | for field in fields: 135 | if instance[field] is None: 136 | if field not in ('answer_start', 'answer_end'): 137 | logging.warning('Data type of field "%s" not detected! Skip this field.', field) 138 | continue 139 | elif isinstance(instance[field], list): 140 | if len(instance[field]) == 0: 141 | logging.warning('Data shape of field "%s" not detected! Skip this field.', field) 142 | continue 143 | 144 | field_type = get_type(instance[field][0]) 145 | if field_type is not None: 146 | input_type[field] = field_type 147 | else: 148 | logging.warning('Data type of field "%s" not detected! Skip this field.', field) 149 | else: 150 | field_type = get_type(instance[field]) 151 | if field_type is not None: 152 | input_type[field] = field_type 153 | else: 154 | logging.warning('Data type of field "%s" not detected! Skip this field.', field) 155 | 156 | return input_type 157 | 158 | def _build_dataset_pipeline(self): 159 | # 1. Check the input-data type and filter invalid keys 160 | input_type_dict = BatchGenerator._detect_input_type(self.instances[0], self.additional_fields) 161 | filtered_instances = [{field: instance[field] for field in input_type_dict} for instance in self.instances] 162 | 163 | # 2.
Some preprocessing, including char extraction, lowercasing, length 164 | def transform_new_instance(instance): 165 | context_tokens = instance['context_tokens'] 166 | question_tokens = instance['question_tokens'] 167 | 168 | if self.use_char: 169 | def get_seq_char_ids(word_tokens): 170 | result = [] 171 | for word in word_tokens: 172 | word_char_ids = [self.vocab.get_char_idx(char) for char in word] 173 | result.append(word_char_ids) 174 | return result 175 | instance['context_char_ids'] = get_seq_char_ids(context_tokens) 176 | instance['question_char_ids'] = get_seq_char_ids(question_tokens) 177 | instance['context_word_len'] = [len(word) for word in context_tokens] 178 | instance['question_word_len'] = [len(word) for word in question_tokens] 179 | 180 | # if do_lowercasing, we will do it in `get_word_idx` function 181 | instance['context_ids'] = [self.vocab.get_word_idx(token) for token in context_tokens] 182 | instance['question_ids'] = [self.vocab.get_word_idx(token) for token in question_tokens] 183 | instance['context_len'] = len(context_tokens) 184 | instance['question_len'] = len(question_tokens) 185 | 186 | # filter the str data, because we don't need them when running neural network 187 | for field, field_type in input_type_dict.items(): 188 | if field_type == str: 189 | del instance[field] 190 | 191 | return instance 192 | 193 | new_instances = [transform_new_instance(instance) for instance in filtered_instances] 194 | 195 | return MRCDataset(new_instances) 196 | 197 | def _build_dataloader_pipeline(self): 198 | word_pad_idx = self.vocab.get_word_pad_idx() 199 | if self.use_char: 200 | char_pad_idx = self.vocab.get_char_pad_idx() 201 | 202 | def mrc_collate(batch): 203 | result = {} 204 | for key in batch[0].keys(): 205 | result[key] = [] 206 | 207 | # 1. Handle the word level sequence data 208 | # 1.1 Get batch pad length 209 | pad_context_len = min(self.max_context_len, max([sample['context_len'] for sample in batch])) 210 | pad_question_len = min(self.max_question_len, max([sample['question_len'] for sample in batch])) 211 | 212 | # 1.2 Padding context and question 213 | for sample in batch: 214 | sample['context_ids'] = BatchGenerator._dynamic_padding(sample['context_ids'], pad_context_len, word_pad_idx) 215 | sample['question_ids'] = BatchGenerator._dynamic_padding(sample['question_ids'], pad_question_len, word_pad_idx) 216 | sample['context_len'] = min(sample['context_len'], pad_context_len) 217 | sample['question_len'] = min(sample['question_len'], pad_question_len) 218 | 219 | # 2. Handle the char level data 220 | if self.use_char: 221 | # 2.1 Padding sample `char ids` and `word len` to batch max length 222 | # TODO padding with 1 length is ok ? 
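                # (A note on the TODO above, added for illustration: padding the
                # per-word lengths with 1 rather than 0 appears deliberate, because
                # the char-level BiGRUs in the R-Net models pack sequences by word
                # length, and a zero-length word would produce an empty packed
                # sequence; the matching char ids are padded with `char_pad_idx`,
                # so the extra length-1 "word" only ever reads padding.)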
223 | for sample in batch: 224 | sample['context_char_ids'] = BatchGenerator._dynamic_padding( 225 | sample['context_char_ids'], pad_context_len, [char_pad_idx]) 226 | sample['question_char_ids'] = BatchGenerator._dynamic_padding( 227 | sample['question_char_ids'], pad_question_len, [char_pad_idx]) 228 | sample['context_word_len'] = BatchGenerator._dynamic_padding( 229 | sample['context_word_len'], pad_context_len, 1) 230 | sample['question_word_len'] = BatchGenerator._dynamic_padding( 231 | sample['question_word_len'], pad_question_len, 1) 232 | 233 | # 2.2 Get batch pad word length 234 | pad_context_word_len = min(self.max_word_len, max([max(sample['context_word_len']) for sample in batch])) 235 | pad_question_word_len = min(self.max_word_len, max([max(sample['question_word_len']) for sample in batch])) 236 | 237 | # 2.3 Padding batch word len to pad word length 238 | for sample in batch: 239 | sample['context_char_ids'] = [BatchGenerator._dynamic_padding(char_ids, pad_context_word_len, char_pad_idx) 240 | for char_ids in sample['context_char_ids']] 241 | sample['question_char_ids'] = [BatchGenerator._dynamic_padding(char_ids, pad_question_word_len, char_pad_idx) 242 | for char_ids in sample['question_char_ids']] 243 | sample['context_word_len'] = [min(word_len, pad_context_word_len) 244 | for word_len in sample['context_word_len']] 245 | sample['question_word_len'] = [min(word_len, pad_question_word_len) 246 | for word_len in sample['question_word_len']] 247 | 248 | # 3. Convert batch data to `torch tensor` 249 | for sample in batch: 250 | for key, value in sample.items(): 251 | result[key].append(value) 252 | for key, value in result.items(): 253 | result[key] = torch.tensor(value) 254 | 255 | return result 256 | 257 | return DataLoader(dataset=self.dataset, shuffle=self.shuffle, 258 | batch_size=self.batch_size, 259 | collate_fn=mrc_collate, 260 | num_workers=self.num_parallel_calls) 261 | 262 | 263 | class MRCDataset(Dataset): 264 | def __init__(self, instances): 265 | self.instances = instances 266 | 267 | def __getitem__(self, idx): 268 | return self.instances[idx] 269 | 270 | def __len__(self): 271 | return len(self.instances) 272 | -------------------------------------------------------------------------------- /pytorch_mrc/data/vocabulary.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import logging 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | from collections import Counter 7 | 8 | 9 | class Vocabulary(object): 10 | def __init__(self, do_lowercase=True, special_tokens=None): 11 | self.word_vocab = None 12 | self.char_vocab = None 13 | self.word2idx = None 14 | self.char2idx = None 15 | self.word_embedding_matrix = None 16 | self.special_tokens = special_tokens 17 | self.do_lowercase = do_lowercase # only for word 18 | 19 | # Initial Tokens 20 | self.pad_token = "<pad>" 21 | self.unk_token = "<unk>" 22 | self.initial_tokens = (self.pad_token, self.unk_token) 23 | 24 | def build_vocab(self, instances, min_word_count=-1, min_char_count=-1): 25 | self.word_vocab = [token for token in self.initial_tokens] 26 | self.char_vocab = [token for token in self.initial_tokens] 27 | 28 | self.word_counter = Counter() 29 | char_counter = Counter() 30 | if self.special_tokens is not None and isinstance(self.special_tokens, list): 31 | self.word_vocab.extend(self.special_tokens) 32 | 33 | logging.info("Building vocabulary.") 34 | for instance in tqdm(instances): 35 | for token in instance['context_tokens']: 36 | for char in token:
37 | char_counter[char] += 1 38 | token = token.lower() if self.do_lowercase else token 39 | self.word_counter[token] += 1 40 | for token in instance['question_tokens']: 41 | for char in token: 42 | char_counter[char] += 1 43 | token = token.lower() if self.do_lowercase else token 44 | self.word_counter[token] += 1 45 | for w, v in self.word_counter.most_common(): 46 | if v >= min_word_count: 47 | self.word_vocab.append(w) 48 | for c, v in char_counter.most_common(): 49 | if v >= min_char_count: 50 | self.char_vocab.append(c) 51 | 52 | self._build_index_mapper() 53 | 54 | def set_vocab(self, word_vocab, char_vocab): 55 | self.word_vocab = [token for token in self.initial_tokens] 56 | self.char_vocab = [token for token in self.initial_tokens] 57 | if self.special_tokens is not None and isinstance(self.special_tokens, list): 58 | self.word_vocab.extend(self.special_tokens) 59 | 60 | self.word_vocab += word_vocab 61 | self.char_vocab += char_vocab 62 | 63 | self._build_index_mapper() 64 | 65 | def _build_index_mapper(self): 66 | self.word2idx = dict(zip(self.word_vocab, range(len(self.word_vocab)))) 67 | self.char2idx = dict(zip(self.char_vocab, range(len(self.char_vocab)))) 68 | 69 | def make_word_embedding(self, embedding_file, init_scale=0.02): 70 | if self.word_vocab is None or self.word2idx is None: 71 | raise ValueError("make_word_embedding must be called after build_vocab/set_vocab") 72 | 73 | # 1. Parse pretrained embedding 74 | embedding_dict = dict() 75 | with open(embedding_file) as f: 76 | for line in f: 77 | if len(line.rstrip().split(" ")) <= 2: 78 | continue 79 | word, vector = line.rstrip().split(" ", 1) 80 | embedding_dict[word] = np.asarray(vector.split(), dtype=float) # avoids the deprecated np.fromstring/np.float 81 | 82 | # 2. Update word vocab according to pretrained word embedding 83 | new_word_vocab = [] 84 | special_tokens_set = set(self.special_tokens if self.special_tokens is not None else []) 85 | for word in self.word_vocab: 86 | if word in self.initial_tokens or word in special_tokens_set or word in embedding_dict: 87 | new_word_vocab.append(word) 88 | self.word_vocab = new_word_vocab 89 | self._build_index_mapper() 90 |
91 | # 3. Make word embedding matrix 92 | embedding_size = next(iter(embedding_dict.values())).shape[0] 93 | embedding_list = [] 94 | for word in self.word_vocab: 95 | if word == self.pad_token: 96 | embedding_list.append(np.zeros([1, embedding_size], dtype=float)) 97 | elif word == self.unk_token or word in special_tokens_set: 98 | embedding_list.append(np.random.uniform(-init_scale, init_scale, [1, embedding_size])) 99 | else: 100 | embedding_list.append(np.reshape(embedding_dict[word], [1, embedding_size])) 101 | 102 | self.word_embedding_matrix = np.concatenate(embedding_list, axis=0) 103 | 104 | def get_word_pad_idx(self): 105 | return self.word2idx[self.pad_token] 106 | 107 | def get_char_pad_idx(self): 108 | return self.char2idx[self.pad_token] 109 | 110 | def get_word_unk_idx(self): 111 | return self.word2idx[self.unk_token] 112 | 113 | def get_char_unk_idx(self): 114 | return self.char2idx[self.unk_token] 115 | 116 | def get_word_idx(self, token): 117 | token = token.lower() if self.do_lowercase else token 118 | if token in self.word2idx: # dict lookup instead of scanning the vocab list 119 | return self.word2idx[token] 120 | else: 121 | return self.get_word_unk_idx() 122 | 123 | def get_char_idx(self, token): 124 | if token in self.char2idx: 125 | return self.char2idx[token] 126 | else: 127 | return self.get_char_unk_idx() 128 | 129 | def get_word_vocab(self): 130 | return self.word_vocab 131 | 132 | def get_char_vocab(self): 133 | return self.char_vocab 134 | 135 | def get_word_counter(self): 136 | return self.word_counter 137 | 138 | def get_word_embedding(self): 139 | if self.word_embedding_matrix is None: 140 | raise ValueError("get_word_embedding must be called after make_word_embedding") 141 | return self.word_embedding_matrix 142 | 143 | def save(self, file_path, include_word_embedding=True): 144 | logging.info("Saving vocabulary at {}".format(file_path)) 145 | # if include_word_embedding is False, we will not save the word embedding matrix 146 | if not include_word_embedding: 147 | tmp_word_emb = self.word_embedding_matrix 148 | self.word_embedding_matrix = None 149 | with open(file_path, "wb") as f: 150 | pickle.dump(self.__dict__, f) 151 | if not include_word_embedding: 152 | self.word_embedding_matrix = tmp_word_emb 153 | 154 | def load(self, file_path): 155 | logging.info("Loading vocabulary at {}".format(file_path)) 156 | with open(file_path, 'rb') as f: 157 | vocab_data = pickle.load(f) 158 | self.__dict__.update(vocab_data) 159 | -------------------------------------------------------------------------------- /pytorch_mrc/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZiqiang/PyTorch-MRCToolkit/23ec374338509a2e61d0060a43f2d6a32fe337d3/pytorch_mrc/dataset/__init__.py -------------------------------------------------------------------------------- /pytorch_mrc/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | 4 | class BaseReader(object): 5 | def read(self, *input): 6 | raise NotImplementedError 7 | 8 | 9 | class BaseEvaluator(object): 10 | def get_score(self, *input): 11 | raise NotImplementedError 12 | -------------------------------------------------------------------------------- /pytorch_mrc/dataset/squad.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import json 3 | import logging 4 | import re 5 | import string 6 | from collections import OrderedDict, Counter 7 |
from tqdm import tqdm 8 | from pytorch_mrc.utils.tokenizer import SpacyTokenizer 9 | from pytorch_mrc.dataset.base_dataset import BaseReader, BaseEvaluator 10 | 11 | 12 | class SquadReader(BaseReader): 13 | def __init__(self, fine_grained=False): 14 | self.tokenizer = SpacyTokenizer(fine_grained) 15 | 16 | def read(self, file_path, context_limit=-1): 17 | logging.info("Reading file at %s", file_path) 18 | logging.info("Processing the dataset.") 19 | instances = self._read(file_path, context_limit) 20 | instances = [instance for instance in tqdm(instances)] 21 | return instances 22 | 23 | def _read(self, file_path, context_limit=-1): 24 | with open(file_path) as dataset_file: 25 | dataset_json = json.load(dataset_file) 26 | dataset = dataset_json['data'] 27 | for article in dataset: 28 | for paragraph in article['paragraphs']: 29 | context = paragraph["context"] 30 | context_tokens, context_token_spans = self.tokenizer.word_tokenizer(context) 31 | for question_answer in paragraph['qas']: 32 | question = question_answer["question"].strip() 33 | question_tokens, _ = self.tokenizer.word_tokenizer(question) 34 | 35 | answers, span_starts, span_ends = [], [], [] 36 | if "answers" in question_answer: 37 | answers = [answer['text'] for answer in question_answer['answers']] 38 | span_starts = [answer['answer_start'] 39 | for answer in question_answer['answers']] 40 | span_ends = [start + len(answer) 41 | for start, answer in zip(span_starts, answers)] 42 | 43 | answer_char_spans = zip(span_starts, span_ends) if len( 44 | span_starts) > 0 and len(span_ends) > 0 else None 45 | answers = answers if len(answers) > 0 else None 46 | qid = question_answer['id'] 47 | instance = self._make_instance(context, context_tokens, context_token_spans, 48 | question, question_tokens, answer_char_spans, answers, qid) 49 | if context_limit > 0 and len(instance['context_tokens']) > context_limit: 50 | if instance['answer_start'] > context_limit or instance['answer_end'] > context_limit: 51 | continue 52 | else: 53 | instance['context_tokens'] = instance['context_tokens'][:context_limit] 54 | 55 | yield instance 56 | 57 | def _make_instance(self, context, context_tokens, context_token_spans, question, question_tokens, 58 | answer_char_spans=None, answers=None, qid=None): 59 | answer_token_starts, answer_token_ends = [], [] 60 | if answers is not None: 61 | for answer_char_start, answer_char_end in answer_char_spans: 62 | answer_token_span = [] 63 | for idx, span in enumerate(context_token_spans): 64 | if not (answer_char_end <= span[0] or answer_char_start >= span[1]): 65 | answer_token_span.append(idx) 66 | 67 | assert len(answer_token_span) > 0 68 | answer_token_starts.append(answer_token_span[0]) 69 | answer_token_ends.append(answer_token_span[-1]) 70 | 71 | return OrderedDict({ 72 | "context": context, 73 | "context_tokens": context_tokens, 74 | "context_token_spans": context_token_spans, 75 | "context_word_len": [len(word) for word in context_tokens], 76 | "question_word_len": [len(word) for word in question_tokens], 77 | "question": question, 78 | 'qid': qid, 79 | "question_tokens": question_tokens, 80 | "answer": answers[0] if answers is not None else None, 81 | "answer_start": answer_token_starts[0] if answers is not None else None, 82 | "answer_end": answer_token_ends[0] if answers is not None else None, 83 | }) 84 | 85 | 86 | class SquadEvaluator(BaseEvaluator): 87 | def __init__(self, file_path, monitor='f1'): 88 | self.ground_dict = dict() 89 | self.id_list = [] 90 | self.monitor = monitor 91 | 92 | with 
open(file_path) as dataset_file: 93 | dataset_json = json.load(dataset_file) 94 | dataset = dataset_json['data'] 95 | for article in dataset: 96 | for paragraph in article['paragraphs']: 97 | for question_answer in paragraph['qas']: 98 | id = question_answer["id"] 99 | self.ground_dict[id] = [answer['text'] for answer in question_answer['answers']] 100 | self.id_list.append(id) 101 | 102 | def get_monitor(self): 103 | return self.monitor 104 | 105 | def get_score(self, pred_answer): 106 | if isinstance(pred_answer, list): 107 | assert len(self.id_list) == len(pred_answer) 108 | answer_dict = dict(zip(self.id_list, pred_answer)) 109 | else: 110 | answer_dict = pred_answer 111 | 112 | f1 = exact_match = total = 0 113 | for key, value in answer_dict.items(): 114 | total += 1 115 | ground_truths = self.ground_dict[key] 116 | prediction = value 117 | exact_match += SquadEvaluator.metric_max_over_ground_truths( 118 | SquadEvaluator.exact_match_score, prediction, ground_truths) 119 | f1 += SquadEvaluator.metric_max_over_ground_truths( 120 | SquadEvaluator.f1_score, prediction, ground_truths) 121 | exact_match = 100.0 * exact_match / total 122 | f1 = 100.0 * f1 / total 123 | return {'exact_match': exact_match, 'f1': f1} 124 | 125 | @staticmethod 126 | def normalize_answer(s): 127 | def remove_articles(text): 128 | return re.sub(r'\b(a|an|the)\b', ' ', text) 129 | 130 | def white_space_fix(text): 131 | return ' '.join(text.split()) 132 | 133 | def remove_punc(text): 134 | exclude = set(string.punctuation) 135 | return ''.join(ch for ch in text if ch not in exclude) 136 | 137 | def lower(text): 138 | return text.lower() 139 | 140 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 141 | 142 | @staticmethod 143 | def exact_match_score(prediction, ground_truth): 144 | return (SquadEvaluator.normalize_answer(prediction) == SquadEvaluator.normalize_answer(ground_truth)) 145 | 146 | @staticmethod 147 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 148 | scores_for_ground_truths = [] 149 | for ground_truth in ground_truths: 150 | score = metric_fn(prediction, ground_truth) 151 | scores_for_ground_truths.append(score) 152 | return max(scores_for_ground_truths) 153 | 154 | @staticmethod 155 | def f1_score(prediction, ground_truth): 156 | prediction_tokens = SquadEvaluator.normalize_answer(prediction).split() 157 | ground_truth_tokens = SquadEvaluator.normalize_answer(ground_truth).split() 158 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 159 | num_same = sum(common.values()) 160 | if num_same == 0: 161 | return 0 162 | precision = 1.0 * num_same / len(prediction_tokens) 163 | recall = 1.0 * num_same / len(ground_truth_tokens) 164 | f1 = (2 * precision * recall) / (precision + recall) 165 | return f1 166 | -------------------------------------------------------------------------------- /pytorch_mrc/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZiqiang/PyTorch-MRCToolkit/23ec374338509a2e61d0060a43f2d6a32fe337d3/pytorch_mrc/model/__init__.py -------------------------------------------------------------------------------- /pytorch_mrc/model/base_model.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | from pytorch_mrc.train.trainer import Trainer 6 | 7 | 8 | class BaseModel(nn.Module): 9 | def __init__(self, vocab=None, device=None): 10 | 
super(BaseModel, self).__init__() 11 | 12 | self.vocab = vocab 13 | 14 | if device is None: 15 | device = torch.device('cpu') 16 | logging.warning('No device was assigned; defaulting to `cpu`') 17 | if not isinstance(device, torch.device): 18 | raise TypeError('device must be an instance of `torch.device`, not `{}`'.format(type(device))) 19 | self.device = device 20 | 21 | # self.initialized = False 22 | self.ema_decay = 0 23 | 24 | def __del__(self): 25 | # TODO 26 | pass 27 | 28 | def load(self, path, var_list=None): 29 | # TODO 30 | # var_list = None returns the list of all saveable variables 31 | pass 32 | # self.initialized = True 33 | 34 | def save(self, path, global_step=None, var_list=None): 35 | # TODO 36 | pass 37 | 38 | def forward(self, *input): 39 | raise NotImplementedError 40 | 41 | def compile(self, *input): 42 | raise NotImplementedError 43 | 44 | def update(self): 45 | # TODO There are still some problems with logic. 46 | if not self.training: 47 | raise Exception("Weights can only be updated in train mode") 48 | if self.optimizer is None: 49 | raise Exception("The model must be compiled first (call compile())!") 50 | 51 | self.optimizer.step() 52 | # self.optimizer.zero_grad() 53 | 54 | def get_best_answer(self, *input): 55 | raise NotImplementedError 56 | 57 | def train_and_evaluate(self, train_generator, eval_generator, evaluator, epochs=1, episodes=1, 58 | save_dir=None, summary_dir=None, save_summary_steps=10, log_every_n_batch=100): 59 | Trainer.train_and_evaluate(self, self.device, train_generator, eval_generator, evaluator, 60 | epochs=epochs, episodes=episodes, 61 | save_dir=save_dir, summary_dir=summary_dir, save_summary_steps=save_summary_steps, 62 | log_every_n_batch=log_every_n_batch) 63 | 64 | def evaluate(self, batch_generator, evaluator): 65 | Trainer.evaluate(self, self.device, batch_generator, evaluator) 66 | 67 | def inference(self, batch_generator): 68 | Trainer.inference(self, self.device, batch_generator) 69 | -------------------------------------------------------------------------------- /pytorch_mrc/model/bidaf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import CrossEntropyLoss 4 | 5 | from pytorch_mrc.model.base_model import BaseModel 6 | from pytorch_mrc.nn.layers import Embedding, Conv1DAndMaxPooling, Highway, VariationalDropout 7 | from pytorch_mrc.nn.recurrent import BiLSTM 8 | from pytorch_mrc.nn.attention import BiAttention 9 | from pytorch_mrc.nn.similarity_function import TriLinearSimilarity 10 | from pytorch_mrc.nn.util import sequence_mask, masked_softmax, mask_logits, weighted_sum 11 | 12 | 13 | class BiDAF(BaseModel): 14 | def __init__(self, vocab, device, 15 | pretrained_word_embedding=None, 16 | word_embedding_trainable=False, 17 | word_embedding_size=100, 18 | char_embedding_size=8, 19 | char_conv_filters=100, 20 | char_conv_kernel_size=5, 21 | use_elmo=False, 22 | elmo_local_path=None, 23 | rnn_hidden_size=100, 24 | dropout_prob=0.2, 25 | max_answer_len=17, 26 | enable_na_answer=False): 27 | super(BiDAF, self).__init__(vocab, device) 28 | self.use_elmo = use_elmo 29 | self.elmo_local_path = elmo_local_path 30 | self.max_answer_len = max_answer_len 31 | self.enable_na_answer = enable_na_answer # for squad2.0 32 | 33 | # Embedding 34 | # TODO the UNK token needs to be handled 35 | self.word_embedding = Embedding(pretrained_embedding=pretrained_word_embedding, 36 | embedding_shape=(len(vocab.get_word_vocab()), word_embedding_size), 37 |
trainable=word_embedding_trainable) 38 | self.char_embedding = Embedding(embedding_shape=(len(vocab.get_char_vocab()), char_embedding_size), 39 | trainable=True, init_scale=0.2) 40 | embedding_dim = word_embedding_size + char_conv_filters 41 | self.conv1d = Conv1DAndMaxPooling(char_embedding_size, char_conv_filters, char_conv_kernel_size) 42 | if use_elmo: 43 | # TODO 44 | pass 45 | # embedding_dim += ?? 46 | self.highway = Highway(input_dim=embedding_dim, num_layers=2) 47 | 48 | self.encode_phrase_lstm = BiLSTM(embedding_dim, rnn_hidden_size) 49 | 50 | # Attention Flow Layer 51 | self.bi_attention = BiAttention(TriLinearSimilarity(input_dim=2 * rnn_hidden_size)) 52 | 53 | # Modeling Layer 54 | self.modeling_lstm1 = BiLSTM(8 * rnn_hidden_size, rnn_hidden_size) 55 | self.modeling_lstm2 = BiLSTM(2 * rnn_hidden_size, rnn_hidden_size) 56 | 57 | # Output Layer 58 | self.start_pred_layer = nn.Linear(10 * rnn_hidden_size, 1, bias=False) 59 | self.end_lstm = BiLSTM(14 * rnn_hidden_size, rnn_hidden_size) 60 | self.end_pred_layer = nn.Linear(10 * rnn_hidden_size, 1, bias=False) 61 | 62 | self.dropout = VariationalDropout(p=dropout_prob) 63 | 64 | def forward(self, data): 65 | # Parsing data 66 | context_ids, context_len = data['context_ids'], data['context_len'] 67 | question_ids, question_len = data['question_ids'], data['question_len'] 68 | answer_start, answer_end = data['answer_start'], data['answer_end'] 69 | context_char_ids, context_word_len = data['context_char_ids'], data['context_word_len'] 70 | question_char_ids, question_word_len = data['question_char_ids'], data['question_word_len'] 71 | 72 | # compute mask 73 | context_mask = sequence_mask(context_len, maxlen=context_ids.size(1)) 74 | question_mask = sequence_mask(question_len, maxlen=question_ids.size(1)) 75 | 76 | # 1.1 Embedding 77 | context_word_repr = self.word_embedding(context_ids) 78 | context_char_repr = self.char_embedding(context_char_ids) 79 | question_word_repr = self.word_embedding(question_ids) 80 | question_char_repr = self.char_embedding(question_char_ids) 81 | 82 | # 1.2 Char convolution 83 | context_char_repr = self.dropout(self.conv1d(context_char_repr)) 84 | question_char_repr = self.dropout(self.conv1d(question_char_repr)) 85 | 86 | # 1.3 Concat word and char 87 | context_repr = torch.cat([context_word_repr, context_char_repr], dim=-1) 88 | question_repr = torch.cat([question_word_repr, question_char_repr], dim=-1) 89 | 90 | # 1.4 ELMo embedding 91 | if self.use_elmo: 92 | # TODO 93 | pass 94 | 95 | # 1.5 Highway network 96 | context_repr = self.highway(context_repr) 97 | question_repr = self.highway(question_repr) 98 | 99 | # 2. Phrase encoding 100 | context_repr, _ = self.encode_phrase_lstm(self.dropout(context_repr), context_len) 101 | question_repr, _ = self.encode_phrase_lstm(self.dropout(question_repr), question_len) 102 | 103 | # 3. Bi-Attention 104 | c2q, q2c = self.bi_attention(context_repr, question_repr, context_mask, question_mask) 105 | 106 | # 4. Modeling layer 107 | final_merged_context = torch.cat([context_repr, c2q, context_repr * c2q, context_repr * q2c], dim=-1) 108 | modeled_context1, _ = self.modeling_lstm1(self.dropout(final_merged_context), context_len) 109 | modeled_context2, _ = self.modeling_lstm2(self.dropout(modeled_context1), context_len) 110 | modeled_context = modeled_context1 + modeled_context2 111 | 112 | # 5. 
Start prediction 113 | start_logits = self.start_pred_layer(self.dropout(torch.cat([final_merged_context, modeled_context], dim=-1))) 114 | start_logits = start_logits.squeeze(-1) 115 | start_prob = masked_softmax(start_logits, context_mask) 116 | 117 | # 6. End prediction 118 | start_repr = weighted_sum(modeled_context, start_prob) 119 | tiled_start_repr = start_repr.unsqueeze(1).repeat(1, modeled_context.size(1), 1) 120 | end_repr = torch.cat([final_merged_context, 121 | modeled_context, 122 | tiled_start_repr, 123 | modeled_context * tiled_start_repr], 124 | dim=-1) 125 | encoded_end_repr, _ = self.end_lstm(self.dropout(end_repr), context_len) 126 | end_logits = self.end_pred_layer(self.dropout(torch.cat([final_merged_context, encoded_end_repr], dim=-1))) 127 | end_logits = end_logits.squeeze(-1) 128 | end_prob = masked_softmax(end_logits, context_mask) 129 | 130 | # 7. Retured Things. 131 | # TODO for squad2.0 and for multi GPUs 132 | if answer_start is not None and answer_end is not None: 133 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 134 | ignored_index = start_logits.size(1) 135 | answer_start.clamp_(0, ignored_index) 136 | answer_end.clamp_(0, ignored_index) 137 | 138 | # compute loss 139 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 140 | start_loss = loss_fct(mask_logits(start_logits, context_mask), answer_start) 141 | end_loss = loss_fct(mask_logits(end_logits, context_mask), answer_end) 142 | total_loss = (start_loss + end_loss) / 2 143 | 144 | if self.training: 145 | # if train, we return loss only 146 | return total_loss 147 | else: 148 | # if eval, we return a tuple (loss, output_dict) 149 | output_dict = { 150 | "start_prob": start_prob.cpu().numpy(), 151 | "end_prob": end_prob.cpu().numpy() 152 | } 153 | return total_loss, output_dict 154 | else: 155 | # if inference, we return output_dict only 156 | output_dict = { 157 | "start_prob": start_prob.cpu().numpy(), 158 | "end_prob": end_prob.cpu().numpy() 159 | } 160 | return output_dict 161 | 162 | def compile(self, optimizer=torch.optim.Adam, initial_lr=0.001): 163 | self.optimizer = optimizer(self.parameters(), lr=initial_lr) 164 | 165 | def get_best_answer(self, output, instances): 166 | na_prob = {} 167 | preds_dict = {} 168 | for i in range(len(instances)): 169 | instance = instances[i] 170 | max_prob, max_start, max_end = 0, 0, 0 171 | for end in range(output['end_prob'][i].shape[0]): 172 | for start in range(max(0, end - self.max_answer_len + 1), end + 1): 173 | prob = output["start_prob"][i][start] * output["end_prob"][i][end] 174 | if prob > max_prob: 175 | max_start, max_end = start, end 176 | max_prob = prob 177 | 178 | char_start_position = instance["context_token_spans"][max_start][0] 179 | char_end_position = instance["context_token_spans"][max_end][1] 180 | pred_answer = instance["context"][char_start_position:char_end_position] 181 | if not self.enable_na_answer: 182 | preds_dict[instance['qid']] = pred_answer 183 | else: 184 | preds_dict[instance['qid']] = pred_answer if max_prob > output['na_prob'][i] else '' 185 | na_prob[instance['qid']] = output['na_prob'][i] 186 | 187 | return preds_dict if not self.enable_na_answer else (preds_dict, na_prob) 188 | -------------------------------------------------------------------------------- /pytorch_mrc/model/rnet_hkust.py: -------------------------------------------------------------------------------- 1 | """ 2 | implements "HKUST R-Net"(https://github.com/HKUST-KnowComp/R-Net) with PyTorch. 
3 | There are some slight deference with hkust model: 4 | 1. We use dynamic batch and padding, rather than static 5 | 2. When encoding, we just use zero to init Bi-GRU, but hkust uses a trainable variable to init 6 | """ 7 | 8 | from collections import deque 9 | import torch 10 | from torch.nn import CrossEntropyLoss 11 | from torch.optim import Adam, Adadelta 12 | 13 | from pytorch_mrc.model.base_model import BaseModel 14 | from pytorch_mrc.nn.dropout import VariationalDropout 15 | from pytorch_mrc.nn.layers import Embedding, StaticPairEncoder, StaticSelfMatchEncoder, PointerNetwork 16 | from pytorch_mrc.nn.recurrent import BiGRU, MultiLayerBiGRU 17 | from pytorch_mrc.nn.util import sequence_mask, masked_softmax, mask_logits 18 | 19 | 20 | class RNET(BaseModel): 21 | def __init__(self, vocab, device, 22 | pretrained_word_embedding=None, 23 | word_embedding_trainable=False, 24 | word_embedding_size=300, 25 | char_embedding_size=8, 26 | char_hidden_size=100, 27 | encoder_layers_num=3, 28 | hidden_size=75, 29 | dropout_prob=0.3): 30 | super(RNET, self).__init__(vocab, device) 31 | self.pretrained_word_embedding = pretrained_word_embedding 32 | self.word_embedding_trainable = word_embedding_trainable 33 | self.word_embedding_size = word_embedding_size 34 | self.char_embedding_size = char_embedding_size 35 | self.char_hidden_size = char_hidden_size 36 | self.encoder_layers_num = encoder_layers_num 37 | self.hidden_size = hidden_size 38 | self.dropout_prob = dropout_prob 39 | 40 | # Embedding 41 | self.word_embedding = Embedding(pretrained_embedding=pretrained_word_embedding, 42 | embedding_shape=(len(vocab.get_word_vocab()), word_embedding_size), 43 | trainable=word_embedding_trainable) 44 | self.char_embedding = Embedding(embedding_shape=(len(vocab.get_char_vocab()), char_embedding_size), 45 | trainable=True, init_scale=0.2) 46 | self.char_bigru = BiGRU(char_embedding_size, char_hidden_size) 47 | 48 | # Encoder 49 | self.encoder_multi_bigru = MultiLayerBiGRU(word_embedding_size + 2 * char_hidden_size, hidden_size, 50 | num_layers=encoder_layers_num, input_drop_prob=dropout_prob) 51 | 52 | # Gated attention RNNs in the paper 53 | self.gated_att_bigru = StaticPairEncoder(2 * encoder_layers_num * hidden_size, 54 | 2 * encoder_layers_num * hidden_size, 55 | hidden_dim=hidden_size, drop_prob=dropout_prob) 56 | 57 | # Self matching attention 58 | self.self_match_att = StaticSelfMatchEncoder(2 * hidden_size, 2 * hidden_size, 59 | hidden_dim=hidden_size, drop_prob=dropout_prob) 60 | 61 | # Output Layer 62 | self.pointer_net = PointerNetwork(2 * hidden_size, 2 * encoder_layers_num * hidden_size, 63 | hidden_dim=hidden_size, drop_prob=dropout_prob) 64 | 65 | # RNN Dropout 66 | self.dropout = VariationalDropout(dropout_prob, batch_first=True) 67 | 68 | def forward(self, data): 69 | # Parsing data 70 | context_ids, context_len = data['context_ids'], data['context_len'] 71 | question_ids, question_len = data['question_ids'], data['question_len'] 72 | answer_start, answer_end = data['answer_start'], data['answer_end'] 73 | context_char_ids, context_word_len = data['context_char_ids'], data['context_word_len'] 74 | question_char_ids, question_word_len = data['question_char_ids'], data['question_word_len'] 75 | 76 | # Record maximum length info and generate mask matrix 77 | max_context_len = context_ids.size(1) 78 | max_context_word_len = context_char_ids.size(2) 79 | max_question_len = question_ids.size(1) 80 | max_question_word_len = question_char_ids.size(2) 81 | context_mask = sequence_mask(context_len, 
maxlen=max_context_len) 82 | question_mask = sequence_mask(question_len, maxlen=max_question_len) 83 | 84 | # 1. Context and Question Encoder 85 | # 1.1 Word and char embedding 86 | context_word_repr = self.word_embedding(context_ids) # B*CL*WD 87 | context_char_embedding = self.dropout(self.char_embedding(context_char_ids).reshape( 88 | [-1, max_context_word_len, self.char_embedding_size])) # (B*CL)*WL*CD 89 | question_word_repr = self.word_embedding(question_ids) # B*QL*WD 90 | question_char_embedding = self.dropout(self.char_embedding(question_char_ids).reshape( 91 | [-1, max_question_word_len, self.char_embedding_size])) # (B*QL)*WL*CD 92 | 93 | # 1.2 Char-level representation 94 | _, last_hidden_state = self.char_bigru(context_char_embedding, context_word_len.reshape([-1])) # 2*(B*CL)*CH 95 | context_char_repr = torch.cat([last_hidden_state[0], last_hidden_state[1]], dim=-1) # (B*CL)*2CH 96 | context_char_repr = context_char_repr.reshape([-1, max_context_len, 2 * self.char_hidden_size]) # B*CL*2CH 97 | 98 | _, last_hidden_state = self.char_bigru(question_char_embedding, question_word_len.reshape([-1])) # 2*(B*QL)*CH 99 | question_char_repr = torch.cat([last_hidden_state[0], last_hidden_state[1]], dim=-1) # (B*QL)*2CH 100 | question_char_repr = question_char_repr.reshape([-1, max_question_len, 2 * self.char_hidden_size]) # B*QL*2CH 101 | 102 | # 1.3 Concat word and char representation 103 | context_repr = torch.cat([context_word_repr, context_char_repr], dim=-1) # B*CL*(WD+2CH) 104 | question_repr = torch.cat([question_word_repr, question_char_repr], dim=-1) # B*QL*(WD+2CH) 105 | 106 | # 2. Encoder 107 | encoder_context, _ = self.encoder_multi_bigru(context_repr, context_len, concat_layers=True) # B*CL*(H*2*num_layers) 108 | encoder_question, _ = self.encoder_multi_bigru(question_repr, question_len, concat_layers=True) # B*QL*(H*2*num_layers) 109 | 110 | # 3. Gated attention RNNs in the paper 111 | gated_att_repr = self.gated_att_bigru(encoder_context, encoder_question, context_len, question_mask) # B*CL*2H 112 | 113 | # 4. Self matching attention 114 | self_att_repr = self.self_match_att(gated_att_repr, gated_att_repr, context_len, context_mask) # B*CL*2H 115 | 116 | # 5. Pointer Network 117 | start_logits, end_logits = self.pointer_net(self_att_repr, encoder_question, context_mask, question_mask) 118 | start_prob = masked_softmax(start_logits, context_mask) 119 | end_prob = masked_softmax(end_logits, context_mask) 120 | 121 | # 6. Retured Things. 
If train return loss, if eval/inference return a dict 122 | # TODO for squad2.0 and for multi GPUs 123 | if answer_start is not None and answer_end is not None: 124 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 125 | ignored_index = start_logits.size(1) 126 | answer_start.clamp_(0, ignored_index) 127 | answer_end.clamp_(0, ignored_index) 128 | 129 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 130 | start_loss = loss_fct(mask_logits(start_logits, context_mask), answer_start) 131 | end_loss = loss_fct(mask_logits(end_logits, context_mask), answer_end) 132 | total_loss = (start_loss + end_loss) / 2 133 | 134 | if self.training: 135 | return total_loss 136 | else: 137 | output_dict = { 138 | "start_prob": start_prob.cpu().numpy(), 139 | "end_prob": end_prob.cpu().numpy() 140 | } 141 | return total_loss, output_dict 142 | else: 143 | output_dict = { 144 | "start_prob": start_prob.cpu().numpy(), 145 | "end_prob": end_prob.cpu().numpy() 146 | } 147 | return output_dict 148 | 149 | def compile(self, optimizer='adam', initial_lr=0.002): 150 | if optimizer.lower() == 'adam': 151 | self.optimizer = Adam(self.parameters(), lr=initial_lr) 152 | elif optimizer.lower() == 'adadelta': 153 | self.optimizer = Adadelta(self.parameters(), lr=initial_lr, rho=0.95, eps=1e-08) 154 | else: 155 | raise NotImplementedError("the optimizer hasn't been implemented") 156 | 157 | def update(self, grad_clip=5.0): 158 | torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=grad_clip) 159 | super().update() 160 | 161 | def get_best_answer(self, output, instances, max_len=15): 162 | answer_list = [] 163 | for i in range(len(output['start_prob'])): 164 | instance = instances[i] 165 | max_prob = 0.0 166 | start_position = 0 167 | end_position = 0 168 | d = deque() 169 | start_prob, end_prob = output['start_prob'][i], output['end_prob'][i] 170 | for idx in range(len(start_prob)): 171 | while len(d) > 0 and idx - d[0] >= max_len: 172 | d.popleft() 173 | while len(d) > 0 and start_prob[d[-1]] <= start_prob[idx]: 174 | d.pop() 175 | d.append(idx) 176 | if start_prob[d[0]] * end_prob[idx] > max_prob: 177 | start_position = d[0] 178 | end_position = idx 179 | max_prob = start_prob[d[0]] * end_prob[idx] 180 | char_start_position = instance["context_token_spans"][start_position][0] 181 | char_end_position = instance["context_token_spans"][end_position][1] 182 | pred_answer = instance["context"][char_start_position:char_end_position] 183 | answer_list.append(pred_answer) 184 | return answer_list 185 | -------------------------------------------------------------------------------- /pytorch_mrc/model/rnet_sogou.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import CrossEntropyLoss 5 | from torch.optim import Adam, Adadelta 6 | 7 | from pytorch_mrc.model.base_model import BaseModel 8 | from pytorch_mrc.nn.layers import VariationalDropout, Embedding, PointerNetwork 9 | from pytorch_mrc.nn.recurrent import BiGRU, MultiLayerBiGRU 10 | from pytorch_mrc.nn.util import sequence_mask, masked_softmax, mask_logits 11 | from pytorch_mrc.nn.attention import RnetCoAttention, MultiHeadAttention 12 | 13 | 14 | class RNET(BaseModel): 15 | def __init__(self, vocab, device, 16 | pretrained_word_embedding=None, 17 | word_embedding_trainable=False, 18 | word_embedding_size=300, 19 | char_embedding_size=100, 20 | heads=3, 21 | encoder_layers_num=1, 22 | 
-------------------------------------------------------------------------------- /pytorch_mrc/model/rnet_sogou.py: --------------------------------------------------------------------------------
1 | from collections import deque
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import CrossEntropyLoss
5 | from torch.optim import Adam, Adadelta
6 | 
7 | from pytorch_mrc.model.base_model import BaseModel
8 | from pytorch_mrc.nn.layers import VariationalDropout, Embedding, PointerNetwork
9 | from pytorch_mrc.nn.recurrent import BiGRU, MultiLayerBiGRU
10 | from pytorch_mrc.nn.util import sequence_mask, masked_softmax, mask_logits
11 | from pytorch_mrc.nn.attention import RnetCoAttention, MultiHeadAttention
12 | 
13 | 
14 | class RNET(BaseModel):
15 |     def __init__(self, vocab, device,
16 |                  pretrained_word_embedding=None,
17 |                  word_embedding_trainable=False,
18 |                  word_embedding_size=300,
19 |                  char_embedding_size=100,
20 |                  heads=3,
21 |                  encoder_layers_num=1,
22 |                  hidden_size=75,
23 |                  dropout_prob=0.2):
24 |         super(RNET, self).__init__(vocab, device)
25 |         self.pretrained_word_embedding = pretrained_word_embedding
26 |         self.word_embedding_trainable = word_embedding_trainable
27 |         self.word_embedding_size = word_embedding_size
28 |         self.char_embedding_size = char_embedding_size
29 |         self.encoder_layers_num = encoder_layers_num
30 |         self.hidden_size = hidden_size
31 |         self.dropout_prob = dropout_prob
32 | 
33 |         # Embedding
34 |         self.word_embedding = Embedding(pretrained_embedding=pretrained_word_embedding,
35 |                                         embedding_shape=(len(vocab.get_word_vocab()), word_embedding_size),
36 |                                         trainable=word_embedding_trainable)
37 |         self.char_embedding = Embedding(embedding_shape=(len(vocab.get_char_vocab()), char_embedding_size),
38 |                                         trainable=True, init_scale=0.2)
39 |         self.char_bigru = BiGRU(char_embedding_size, hidden_size)
40 | 
41 |         # Encoder
42 |         self.encoder_multi_bigru = MultiLayerBiGRU(word_embedding_size + 2 * hidden_size, hidden_size,
43 |                                                    num_layers=encoder_layers_num,
44 |                                                    input_drop_prob=dropout_prob)
45 | 
46 |         # Gated attention RNNs in the paper, here we use co-attention
47 |         self.co_attention_layer = RnetCoAttention(2 * hidden_size * encoder_layers_num,
48 |                                                   2 * hidden_size * encoder_layers_num,
49 |                                                   hidden_dim=hidden_size)
50 | 
51 |         # Self matching attention, here we use multi-head Attention
52 |         self.multi_head_att = MultiHeadAttention(heads, hidden_size, hidden_size, attention_on_itself=False)
53 |         self.gate_dense = nn.Linear(2 * hidden_size, 2 * hidden_size)
54 |         self.self_att_bigru = BiGRU(2 * hidden_size, hidden_size)
55 | 
56 |         # Output Layer
57 |         self.pointer_net = PointerNetwork(2 * hidden_size, 2 * hidden_size, hidden_size)
58 | 
59 |         # RNN Dropout
60 |         self.dropout = VariationalDropout(dropout_prob, batch_first=True)
61 | 
62 |     def forward(self, data):
63 |         # Parsing data
64 |         context_ids, context_len = data['context_ids'], data['context_len']
65 |         question_ids, question_len = data['question_ids'], data['question_len']
66 |         answer_start, answer_end = data['answer_start'], data['answer_end']
67 |         context_char_ids, context_word_len = data['context_char_ids'], data['context_word_len']
68 |         question_char_ids, question_word_len = data['question_char_ids'], data['question_word_len']
69 | 
70 |         # Record maximum length info and generate mask matrix
71 |         max_context_len = context_ids.size(1)
72 |         max_context_word_len = context_char_ids.size(2)
73 |         max_question_len = question_ids.size(1)
74 |         max_question_word_len = question_char_ids.size(2)
75 |         context_mask = sequence_mask(context_len, maxlen=max_context_len)
76 |         question_mask = sequence_mask(question_len, maxlen=max_question_len)
77 | 
78 |         # 1. Context and Question Encoder
79 |         # 1.1 Word and char embedding
80 |         context_word_repr = self.word_embedding(context_ids)  # B*CL*WD
81 |         context_char_embedding = self.dropout(self.char_embedding(context_char_ids).reshape(
82 |             [-1, max_context_word_len, self.char_embedding_size]))  # (B*CL)*WL*CD
83 |         question_word_repr = self.word_embedding(question_ids)  # B*QL*WD
84 |         question_char_embedding = self.dropout(self.char_embedding(question_char_ids).reshape(
85 |             [-1, max_question_word_len, self.char_embedding_size]))  # (B*QL)*WL*CD
86 | 
87 |         # 1.2 Char-level representation
88 |         _, last_hidden_state = self.char_bigru(context_char_embedding, context_word_len.reshape([-1]))  # 2*(B*CL)*H
89 |         context_char_repr = torch.cat([last_hidden_state[0], last_hidden_state[1]], dim=-1)  # (B*CL)*2H
90 |         context_char_repr = context_char_repr.reshape([-1, max_context_len, 2 * self.hidden_size])  # B*CL*2H
91 | 
92 |         _, last_hidden_state = self.char_bigru(question_char_embedding, question_word_len.reshape([-1]))  # 2*(B*QL)*H
93 |         question_char_repr = torch.cat([last_hidden_state[0], last_hidden_state[1]], dim=-1)  # (B*QL)*2H
94 |         question_char_repr = question_char_repr.reshape([-1, max_question_len, 2 * self.hidden_size])  # B*QL*2H
95 | 
96 |         # 1.3 Concat word and char representation
97 |         context_repr = torch.cat([context_word_repr, context_char_repr], dim=-1)  # B*CL*(WD+2H)
98 |         question_repr = torch.cat([question_word_repr, question_char_repr], dim=-1)  # B*QL*(WD+2H)
99 | 
100 |         # 2. Encoder
101 |         encoder_context, _ = self.encoder_multi_bigru(context_repr, context_len)  # B*CL*(H*2*num_layers)
102 |         encoder_question, _ = self.encoder_multi_bigru(question_repr, question_len)  # B*QL*(H*2*num_layers)
103 |         encoder_context = self.dropout(encoder_context)
104 |         encoder_question = self.dropout(encoder_question)
105 | 
106 |         # 3. Gated attention RNNs in the paper
107 |         co_att_output = self.dropout(self.co_attention_layer(
108 |             encoder_context, encoder_question, context_len, question_mask))  # B*CL*H
109 | 
110 |         # 4. Self matching attention
111 |         self_att_repr = self.dropout(self.multi_head_att(co_att_output, co_att_output, co_att_output, context_mask))  # B*CL*H
112 |         self_att_rnn_input = torch.cat([co_att_output, self_att_repr], dim=-1)  # B*CL*(H*2)
113 |         self_att_rnn_input = self_att_rnn_input * torch.sigmoid(self.gate_dense(self_att_rnn_input))
114 |         self_att_output, _ = self.self_att_bigru(self_att_rnn_input, context_len)  # B*CL*(H*2)
115 |         self_att_output = self.dropout(self_att_output)  # B*CL*(H*2)
116 | 
117 |         # 5. Pointer Network
118 |         start_logits, end_logits = self.pointer_net(self_att_output, encoder_question, context_mask, question_mask)
119 |         self.start_prob = masked_softmax(start_logits, context_mask)
120 |         self.end_prob = masked_softmax(end_logits, context_mask)
121 | 
122 |         # 6. Returned values:
if training, return the loss; if eval/inference, return a dict
123 |         # TODO for squad2.0 and for multi GPUs
124 |         if answer_start is not None and answer_end is not None:
125 |             # sometimes the start/end positions are outside our model inputs, we ignore these terms
126 |             ignored_index = start_logits.size(1)
127 |             answer_start.clamp_(0, ignored_index)
128 |             answer_end.clamp_(0, ignored_index)
129 | 
130 |             loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
131 |             start_loss = loss_fct(mask_logits(start_logits, context_mask), answer_start)
132 |             end_loss = loss_fct(mask_logits(end_logits, context_mask), answer_end)
133 |             total_loss = (start_loss + end_loss) / 2
134 | 
135 |             if self.training:
136 |                 return total_loss
137 |             else:
138 |                 output_dict = {
139 |                     "start_prob": self.start_prob.cpu().numpy(),
140 |                     "end_prob": self.end_prob.cpu().numpy()
141 |                 }
142 |                 return total_loss, output_dict
143 |         else:
144 |             output_dict = {
145 |                 "start_prob": self.start_prob.cpu().numpy(),
146 |                 "end_prob": self.end_prob.cpu().numpy()
147 |             }
148 |             return output_dict
149 | 
150 |     def update(self, grad_clip=5.0):
151 |         torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=grad_clip)
152 |         super().update()
153 | 
154 |     def compile(self, optimizer='adam', initial_lr=0.002):
155 |         if optimizer.lower() == 'adam':
156 |             self.optimizer = Adam(self.parameters(), lr=initial_lr)
157 |         elif optimizer.lower() == 'adadelta':
158 |             self.optimizer = Adadelta(self.parameters(), lr=initial_lr, rho=0.95, eps=1e-08)
159 |         else:
160 |             raise NotImplementedError("the optimizer hasn't been implemented")
161 | 
162 |     def get_best_answer(self, output, instances, max_len=15):
163 |         answer_list = []
164 |         for i in range(len(output['start_prob'])):
165 |             instance = instances[i]
166 |             max_prob = 0.0
167 |             start_position = 0
168 |             end_position = 0
169 |             d = deque()
170 |             start_prob, end_prob = output['start_prob'][i], output['end_prob'][i]
171 |             for idx in range(len(start_prob)):
172 |                 while len(d) > 0 and idx - d[0] >= max_len:
173 |                     d.popleft()
174 |                 while len(d) > 0 and start_prob[d[-1]] <= start_prob[idx]:
175 |                     d.pop()
176 |                 d.append(idx)
177 |                 if start_prob[d[0]] * end_prob[idx] > max_prob:
178 |                     start_position = d[0]
179 |                     end_position = idx
180 |                     max_prob = start_prob[d[0]] * end_prob[idx]
181 |             char_start_position = instance["context_token_spans"][start_position][0]
182 |             char_end_position = instance["context_token_spans"][end_position][1]
183 |             pred_answer = instance["context"][char_start_position:char_end_position]
184 |             answer_list.append(pred_answer)
185 |         return answer_list
186 | 
-------------------------------------------------------------------------------- /pytorch_mrc/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZiqiang/PyTorch-MRCToolkit/23ec374338509a2e61d0060a43f2d6a32fe337d3/pytorch_mrc/nn/__init__.py -------------------------------------------------------------------------------- /pytorch_mrc/nn/attention.py: --------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | from .recurrent import GRU
7 | from .dropout import VariationalDropout
8 | from .util import sequence_mask, masked_softmax
9 | 
10 | VERY_NEGATIVE_NUMBER = -1e29
11 | 
12 | 
13 | class BiAttention(nn.Module):
14 |     """ Bi-Directional Attention from the BiDAF paper (https://arxiv.org/abs/1611.01603)"""
15 | 
16 |     def __init__(self, similarity_function):
17 |         super(BiAttention, self).__init__()
18 |         self.similarity_function = similarity_function
19 | 
20 |     def forward(self, context_repr, question_repr, context_mask, question_mask):
21 |         """
22 |         Args:
23 |             context_repr: a 3D torch.Tensor, shape `(batch_size, max_context_len, context_dim)`
24 |             question_repr: a 3D torch.Tensor, shape `(batch_size, max_question_len, question_dim)`
25 |             context_mask: a 1D or 2D torch.Tensor. If 1D it holds the context lengths and
26 |                 `sequence_mask` generates the mask; if 2D it is used as the mask directly.
27 |             question_mask: a 1D or 2D torch.Tensor, handled the same way as `context_mask`
28 |         Returns:
29 |             context2query and query2context attention
30 |         """
31 |         sim_mat = self.similarity_function(context_repr, question_repr)
32 | 
33 |         if context_mask.dim() == 1:
34 |             context_mask = sequence_mask(context_mask, maxlen=context_repr.size(1))
35 |         if question_mask.dim() == 1:
36 |             question_mask = sequence_mask(question_mask, maxlen=question_repr.size(1))
37 |         mask = context_mask.unsqueeze(2) * question_mask.unsqueeze(1)  # B*CL*QL
38 |         sim_mat = sim_mat + (1. - mask) * VERY_NEGATIVE_NUMBER
39 | 
40 |         # Context-to-query Attention in the paper
41 |         context2query_prob = F.softmax(sim_mat, dim=-1)
42 |         context2query_attention = torch.bmm(context2query_prob, question_repr)
43 | 
44 |         # Query-to-context Attention in the paper
45 |         query2context_prob = F.softmax(sim_mat.max(-1).values, dim=-1)
46 |         query2context_attention = torch.bmm(query2context_prob.unsqueeze(1), context_repr)
47 |         query2context_attention = query2context_attention.repeat(1, context_repr.size(1), 1)
48 | 
49 |         return context2query_attention, query2context_attention
50 | 
51 | 
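# A minimal shape-check sketch of BiAttention (sizes are hypothetical; assumes a
# DotProductSimilarity-style callable from nn/similarity_function.py):
#
#     from .similarity_function import DotProductSimilarity
#     bi_att = BiAttention(DotProductSimilarity())
#     context = torch.randn(4, 30, 64)                  # B*CL*D
#     question = torch.randn(4, 10, 64)                 # B*QL*D
#     c2q, q2c = bi_att(context, question,
#                       torch.tensor([30, 25, 20, 30]),  # 1D masks = lengths
#                       torch.tensor([10, 8, 10, 9]))
#     # c2q.shape == q2c.shape == (4, 30, 64)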
52 | class DotAttention(nn.Module):
53 |     """
54 |     DotAttention does the following:
55 |     1. use a similarity function to compute the similarity scores
56 |     2. use masked softmax to obtain the similarity between the inputs and the valid memory
57 |     """
58 | 
59 |     def __init__(self, input_dim, memory_dim, hidden_dim, drop_prob=0.0, batch_first=True):
60 |         super(DotAttention, self).__init__()
61 |         self.hidden_dim = hidden_dim
62 |         self.batch_first = batch_first
63 |         self.input_linear = nn.Sequential(
64 |             VariationalDropout(drop_prob, batch_first=True),
65 |             nn.Linear(input_dim, hidden_dim, bias=False),
66 |             nn.ReLU()
67 |         )
68 | 
69 |         self.memory_linear = nn.Sequential(
70 |             VariationalDropout(drop_prob, batch_first=True),
71 |             nn.Linear(memory_dim, hidden_dim, bias=False),
72 |             nn.ReLU()
73 |         )
74 | 
75 |     def forward(self, inputs, memory, memory_mask):
76 |         if not self.batch_first:
77 |             inputs = inputs.transpose(0, 1)
78 |             memory = memory.transpose(0, 1)
79 |             memory_mask = memory_mask.transpose(0, 1)
80 | 
81 |         input_ = self.input_linear(inputs)  # B*L1*H
82 |         memory_ = self.memory_linear(memory)  # B*L2*H
83 | 
84 |         logits = torch.bmm(input_, memory_.transpose(1, 2)) / (self.hidden_dim ** 0.5)  # B*L1*L2
85 |         memory_mask = memory_mask.unsqueeze(1).expand(-1, inputs.size(1), -1)  # B*L1*L2
86 |         score = masked_softmax(logits, memory_mask, dim=-1)  # B*L1*L2
87 | 
88 |         context = torch.bmm(score, memory)  # B*L1*D_M
89 |         new_input = torch.cat([context, inputs], dim=-1)  # B*L1*(D_IN+D_M)
90 | 
91 |         if not self.batch_first:
92 |             return new_input.transpose(0, 1)
93 |         return new_input
94 | 
95 | 
96 | class RnetCoAttention(nn.Module):
97 |     """
98 |     Comes from the Sogou R-Net module; it is similar to MLPSimilarity.
99 |     """
100 | 
101 |     def __init__(self, context_dim, question_dim, hidden_dim):
102 |         super(RnetCoAttention, self).__init__()
103 |         self.context_linear = nn.Linear(context_dim, hidden_dim)
104 |         self.question_linear = nn.Linear(question_dim, hidden_dim)
105 |         self.reduce_linear = nn.Linear(hidden_dim, 1)
106 |         self.gate_linear = nn.Linear(context_dim + question_dim, context_dim + question_dim)
107 |         self.gru = GRU(context_dim + question_dim, hidden_dim)
108 | 
109 |     def forward(self, context_repr, question_repr, context_len, question_mask):
110 |         co_att_context = self.context_linear(context_repr).unsqueeze(2)  # B*CL*1*H
111 |         co_att_question = self.question_linear(question_repr).unsqueeze(1)  # B*1*QL*H
112 |         co_attention_score = self.reduce_linear(torch.tanh(co_att_context + co_att_question)).squeeze(-1)  # B*CL*QL
113 |         co_attention_score += (1. - question_mask.unsqueeze(1)) * VERY_NEGATIVE_NUMBER
114 |         co_attention_similarity = F.softmax(co_attention_score, -1)  # B*CL*QL
115 | 
116 |         new_input = torch.cat([context_repr, torch.bmm(co_attention_similarity, question_repr)], -1)  # B*CL*(CD+QD)
117 |         new_input = new_input * torch.sigmoid(self.gate_linear(new_input))
118 | 
119 |         outputs, _ = self.gru(new_input, context_len)  # B*CL*H
120 |         return outputs
121 | 
122 | 
123 | class MultiHeadAttention(nn.Module):
124 |     def __init__(self, heads, input_dim, units, attention_on_itself=True):
125 |         super(MultiHeadAttention, self).__init__()
126 |         self.heads = heads
127 |         self.input_dim = input_dim
128 |         self.units = units
129 |         self.attention_on_itself = attention_on_itself  # only takes effect when query == key
130 |         self.dense_layers = nn.ModuleList([nn.Linear(input_dim, units) for _ in range(3)])
131 | 
132 |     def forward(self, query, key, value, key_mask=None):
133 |         batch_size, max_query_len, max_key_len = query.size(0), query.size(1), key.size(1)
134 |         wq = self.dense_layers[0](query).reshape(
135 |             [batch_size, max_query_len, self.heads, self.units // self.heads]).permute(2, 0, 1, 3)  # Head*B*QL*(U/Head)
136 |         wk = self.dense_layers[1](key).reshape(
137 |             [batch_size, max_key_len, self.heads, self.units // self.heads]).permute(2, 0, 1, 3)  # Head*B*KL*(U/Head)
138 |         wv = self.dense_layers[2](value).reshape(
139 |             [batch_size, max_key_len, self.heads, self.units // self.heads]).permute(2, 0, 1, 3)  # Head*B*KL*(U/Head)
140 | 
141 |         attention_score = torch.matmul(wq, wk.transpose(2, 3)) / math.sqrt(float(self.units) / self.heads)  # Head*B*QL*KL
142 |         if torch.equal(query, key) and not self.attention_on_itself:
143 |             attention_score += torch.diag(wq.new_zeros(max_key_len) + VERY_NEGATIVE_NUMBER)
144 |         if key_mask is not None:
145 |             if key_mask.dim() == 1:
146 |                 key_mask = sequence_mask(key_mask, maxlen=max_key_len)
147 |             attention_score += (1.0 - key_mask.unsqueeze(1).unsqueeze(0)) * VERY_NEGATIVE_NUMBER
148 |         similarity = F.softmax(attention_score, dim=-1)  # Head*B*QL*KL
149 |         return torch.matmul(similarity, wv).permute(1, 2, 0, 3).reshape([batch_size, max_query_len, self.units])  # B*QL*U
-------------------------------------------------------------------------------- /pytorch_mrc/nn/dropout.py: --------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class VariationalDropout(nn.Module):
5 |     """Variational Dropout presented in https://arxiv.org/pdf/1512.05287.pdf"""
6 | 
7 |     def __init__(self, p, batch_first=True):
8 |         super().__init__()
9 |         self.batch_first = batch_first
10 |         self.dropout = nn.Dropout(p)
11 | 
12 |     def forward(self, x):
13 |         if not self.training:
14 |             return x
15 |         if self.batch_first:
16 |             mask = x.new_ones(x.size(0), 1, x.size(2), requires_grad=False)
17 |         else:
18 |             mask = x.new_ones(1, x.size(1), x.size(2), requires_grad=False)
19 |         return self.dropout(mask) * x
-------------------------------------------------------------------------------- /pytorch_mrc/nn/layers.py: --------------------------------------------------------------------------------
1 | """
2 | This module implements the most useful network layers in the Machine Reading Comprehension (MRC) field,
3 | such as Highway Networks, Pointer Networks and so on.
4 | """
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import numpy as np
10 | 
11 | from .attention import DotAttention
12 | from .recurrent import BiGRU
13 | from .dropout import VariationalDropout
14 | from .util import sequence_mask, masked_softmax
15 | 
16 | VERY_NEGATIVE_NUMBER = -1e29
17 | 
18 | 
19 | class Embedding(nn.Module):
20 |     # TODO the unk token should always stay trainable
21 |     def __init__(self, pretrained_embedding=None, embedding_shape=None, trainable=True, init_scale=0.02, dtype='float'):
22 |         super(Embedding, self).__init__()
23 |         if pretrained_embedding is None and embedding_shape is None:
24 |             raise ValueError("At least one of pretrained_embedding and embedding_shape must be specified!")
25 | 
26 |         if pretrained_embedding is not None:
27 |             if isinstance(pretrained_embedding, np.ndarray):
28 |                 pretrained_embedding = torch.from_numpy(pretrained_embedding)
29 |             self.embedding = nn.Embedding.from_pretrained(pretrained_embedding)
30 |         else:
31 |             self.embedding = nn.Embedding(embedding_shape[0], embedding_shape[1])
32 |             nn.init.uniform_(self.embedding.weight, -init_scale, init_scale)
33 | 
34 |         if dtype == 'float':
35 |             self.embedding = self.embedding.float()
36 |         elif dtype == 'double':
37 |             self.embedding = self.embedding.double()
38 |         else:
39 |             raise NotImplementedError('the dtype must be either `float` or `double`.')
40 | 
41 |         self.embedding.weight.requires_grad = trainable
42 | 
43 |     def forward(self, indices):
44 |         return self.embedding(indices)
45 | 
46 | 
47 | class Highway(nn.Module):
48 |     """
49 |     Implements Highway Networks (https://arxiv.org/pdf/1505.00387.pdf)
50 |     y = g * x + (1 - g) * f(A(x)) where `A` is a linear transformation, `f` is an element-wise
51 |     non-linearity, `g` is an element-wise gate computed as `sigmoid(B(x))`.
52 |     """
53 | 
54 |     def __init__(self,
55 |                  input_dim,
56 |                  num_layers=1,
57 |                  activation=F.relu):
58 |         super(Highway, self).__init__()
59 |         self.activation = activation
60 |         self.layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)])
61 |         for layer in self.layers:
62 |             # We should bias the highway layer to just carry its input forward. We do that by
63 |             # setting the bias on `B(x)` to be positive, because that means `g` will be biased to
64 |             # be high, so we will carry the input forward. The bias on `B(x)` is the second half
65 |             # of the bias vector in each Linear layer.
66 |             layer.bias[input_dim:].data.fill_(1)
67 | 
68 |     def forward(self, inputs):
69 |         current_input = inputs
70 |         for layer in self.layers:
71 |             projected_input = layer(current_input)
72 |             linear_part = current_input
73 |             nonlinear_part, gate = projected_input.chunk(2, dim=-1)
74 |             nonlinear_part = self.activation(nonlinear_part)
75 |             gate = torch.sigmoid(gate)
76 |             current_input = gate * linear_part + (1 - gate) * nonlinear_part
77 |         return current_input
78 | 
79 | 
80 | class Conv1DAndMaxPooling(nn.Module):
81 |     """ Conv1D for a 3D or 4D input tensor; the second-to-last dimension is regarded as the timestep """
82 | 
83 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, activation=F.relu):
84 |         super(Conv1DAndMaxPooling, self).__init__()
85 |         self.conv_layer = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
86 |         self.activation = activation
87 | 
88 |     def forward(self, x, seq_word_len=None):
89 |         input_shape = x.size()
90 |         if len(input_shape) == 4:
91 |             batch_size, max_seq_len = input_shape[0], input_shape[1]
92 |             x = x.reshape([-1, input_shape[-2], input_shape[-1]])
93 |             x = self.activation(self.conv_layer(x.transpose(1, 2))).transpose(1, 2)
94 | 
95 |             if seq_word_len is not None:
96 |                 x = x.reshape([batch_size, max_seq_len, x.size(1), x.size(-1)])
97 |                 x = self._masked_max_pooling(x, seq_word_len)
98 |             else:
99 |                 x = x.max(1).values
100 |                 x = x.reshape([batch_size, -1, x.size(-1)])
101 |         elif len(input_shape) == 3:
102 |             x = self.activation(self.conv_layer(x.transpose(1, 2))).transpose(1, 2)
103 |             x = x.max(1).values
104 |         else:
105 |             raise ValueError('input tensor shape/size must be 3D or 4D')
106 | 
107 |         return x
108 | 
109 |     def _masked_max_pooling(self, input, seq_word_len=None):
110 |         # TODO can be improved
111 |         rank = len(input.size()) - 2
112 |         if seq_word_len is not None:
113 |             shape = input.size()
114 |             mask = sequence_mask(seq_word_len.reshape([-1]), maxlen=shape[-2])
115 |             mask = mask.reshape([shape[0], shape[1], shape[2], 1])
116 |             input = input * mask + (1 - mask) * VERY_NEGATIVE_NUMBER
117 |         return input.max(dim=rank).values
118 | 
119 | 
120 | class Gate(nn.Module):
121 |     def __init__(self, input_dim, drop_prob=0.0):
122 |         super().__init__()
123 |         self.gate = nn.Sequential(
124 |             VariationalDropout(drop_prob, batch_first=True),
125 |             nn.Linear(input_dim, input_dim, bias=False),
126 |             nn.Sigmoid()
127 |         )
128 | 
129 |     def forward(self, inputs):
130 |         return inputs * self.gate(inputs)
131 | 
132 | 
133 | class StaticPairEncoder(nn.Module):
134 |     def __init__(self, input_dim, memory_dim, hidden_dim, drop_prob=0.0, batch_first=True):
135 |         super(StaticPairEncoder, self).__init__()
136 |         self.attention = DotAttention(input_dim, memory_dim, hidden_dim,
137 |                                       drop_prob=drop_prob, batch_first=batch_first)
138 |         self.gate = nn.Sequential(
139 |             Gate(input_dim + memory_dim, drop_prob=drop_prob),
140 |             VariationalDropout(drop_prob, batch_first=batch_first)
141 |         )
142 |         self.encoder = BiGRU(input_dim + memory_dim, hidden_dim, batch_first=batch_first)
143 | 
144 |     def forward(self, inputs, memory, inputs_len, memory_mask):
145 |         new_inputs = self.gate(self.attention(inputs, memory, memory_mask))
146 |         outputs, _ = self.encoder(new_inputs, inputs_len)
147 |         return outputs
148 | 
149 | 
150 | class StaticSelfMatchEncoder(StaticPairEncoder):
151 |     """
152 |     Exactly the same as `StaticPairEncoder`.
153 |     """
154 |     pass
155 | 
156 | 
157 | class PointerNetwork(nn.Module):
158 |     """
159 |     Implements the Pointer Network.
160 |     """
161 | 
162 |     def __init__(self, context_dim, question_dim, hidden_dim,
163 |                  cell_type='gru', drop_prob=0.0, batch_first=True):
164 |         super(PointerNetwork, self).__init__()
165 |         self.batch_first = batch_first
166 |         self.cell_type = cell_type.lower()
167 | 
168 |         if self.cell_type == 'gru':
169 |             cell_cls = nn.GRUCell
170 |         elif self.cell_type == 'lstm':
171 |             cell_cls = nn.LSTMCell
172 |         elif self.cell_type == 'rnn':
173 |             cell_cls = nn.RNNCell
174 |         else:
175 |             raise NotImplementedError('cell_type must be one of rnn/gru/lstm')
176 |         self.cell = cell_cls(context_dim, question_dim)
177 | 
178 |         self.question_linear = nn.Sequential(
179 |             VariationalDropout(drop_prob),
180 |             nn.Linear(2 * question_dim, hidden_dim, bias=False),
181 |             nn.Tanh(),
182 |             nn.Linear(hidden_dim, 1, bias=False),
183 |         )
184 |         self.context_linear = nn.Sequential(
185 |             VariationalDropout(drop_prob),
186 |             nn.Linear(question_dim + context_dim, hidden_dim, bias=False),
187 |             nn.Tanh(),
188 |             nn.Linear(hidden_dim, 1, bias=False),
189 |         )
190 | 
191 |         self.random_attn_vector = nn.Parameter(torch.randn(1, 1, question_dim))
192 | 
193 |     def forward(self, context_repr, question_repr, context_mask, question_mask):
194 |         """
195 |         Use the Pointer Network to compute the probability of each position
196 |         being the start or the end of the answer.
197 |         Returns:
198 |             the logits of every position being the start and the end of the answer
199 |         """
200 |         if not self.batch_first:
201 |             context_repr = context_repr.transpose(0, 1)
202 |             question_repr = question_repr.transpose(0, 1)
203 |             context_mask = context_mask.transpose(0, 1)
204 |             question_mask = question_mask.transpose(0, 1)
205 | 
206 |         state = self._question_pooling(question_repr, question_mask)  # B*QD
207 |         cell_input, ans_start_logits = self._context_attention(context_repr, context_mask, state)  # B*CD, B*CL
208 |         if self.cell_type == 'lstm':
209 |             state, _ = self.cell(cell_input, hx=(state, state))  # B*QD
210 |         else:
211 |             state = self.cell(cell_input, hx=state)  # B*QD
212 |         _, ans_end_logits = self._context_attention(context_repr, context_mask, state)  # _, B*CL
213 | 
214 |         return ans_start_logits, ans_end_logits
215 | 
216 |     def _question_pooling(self, question_repr, question_mask):
217 |         """use attention-pooling over the question with a random trainable vector"""
218 |         expanded_att_vector = self.random_attn_vector.expand(question_repr.size(0), question_repr.size(1), -1)  # B*QL*QD
219 |         logits = self.question_linear(torch.cat([question_repr, expanded_att_vector], dim=-1)).squeeze(-1)  # B*QL
220 |         score = masked_softmax(logits, question_mask, dim=-1)  # B*QL
221 |         state = torch.sum(score.unsqueeze(-1) * question_repr, dim=1)  # B*QD
222 |         return state
223 | 
224 |     def _context_attention(self, context_repr, context_mask, state):
225 |         expanded_state = state.unsqueeze(1).expand(-1, context_repr.size(1), -1)  # B*CL*QD
226 |         logits = self.context_linear(torch.cat([context_repr, expanded_state], dim=-1)).squeeze(-1)  # B*CL
227 |         score = masked_softmax(logits, context_mask, dim=-1)  # B*CL
228 |         cell_input = torch.sum(score.unsqueeze(-1) * context_repr, dim=1)  # B*CD
229 |         return cell_input, logits
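# A minimal shape-check sketch of PointerNetwork (sizes are hypothetical):
#
#     ptr = PointerNetwork(context_dim=150, question_dim=150, hidden_dim=75)
#     context = torch.randn(4, 30, 150)                       # B*CL*CD
#     question = torch.randn(4, 10, 150)                      # B*QL*QD
#     context_mask = sequence_mask(torch.tensor([30, 25, 20, 30]))
#     question_mask = sequence_mask(torch.tensor([10, 8, 10, 9]))
#     start_logits, end_logits = ptr(context, question, context_mask, question_mask)
#     # both have shape (4, 30); feed them to masked_softmax for probabilities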
-------------------------------------------------------------------------------- /pytorch_mrc/nn/recurrent.py: --------------------------------------------------------------------------------
1 | """
2 | This module wraps the recurrent neural networks (RNNs) of PyTorch.
3 | 1. we implement BaseRNN, which wraps ``pack_padded_sequence``,
4 |    ``pad_packed_sequence`` and the standard RNNs
5 | 2. we implement BaseMultiLayerRNN, which applies VariationalDropout to each RNN layer's input
6 |    rather than using standard Dropout between RNN layers.
7 | 3. we inherit the classes described above and implement some common RNN classes, e.g. BiGRU, MultiLayerBiGRU.
8 | """
9 | 
10 | import torch
11 | import torch.nn as nn
12 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
13 | 
14 | from .dropout import VariationalDropout
15 | 
16 | 
17 | class BaseRNN(nn.Module):
18 |     """Base RNN Module, which packs inputs with `pack_padded_sequence` and unpacks them with `pad_packed_sequence`"""
19 | 
20 |     def __init__(self, rnn_type, input_size, hidden_size, batch_first=True,
21 |                  num_layers=1, bidirectional=False, drop_prob=0.0):
22 |         super(BaseRNN, self).__init__()
23 |         self.batch_first = batch_first
24 |         self.rnn_type = rnn_type.lower()
25 |         if self.rnn_type == 'rnn':
26 |             rnn_cls = nn.RNN
27 |         elif self.rnn_type == 'gru':
28 |             rnn_cls = nn.GRU
29 |         elif self.rnn_type == 'lstm':
30 |             rnn_cls = nn.LSTM
31 |         else:
32 |             raise NotImplementedError('rnn_type must be one of `RNN/rnn`, `GRU/gru`, `LSTM/lstm`')
33 | 
34 |         self.rnn = rnn_cls(input_size, hidden_size, batch_first=True, num_layers=num_layers,
35 |                            bidirectional=bidirectional, dropout=(0 if num_layers == 1 else drop_prob))
36 | 
37 |     def forward(self, inputs, lengths=None, initial_state=None):
38 |         """
39 |         Args:
40 |             inputs(Tensor): tensor containing the features of the input sequence.
41 |                 If `batch_first=True (default)`, the shape is `(batch_size, seq_len, input_size)`,
42 |                 else `(seq_len, batch_size, input_size)`.
43 |             lengths(Tensor): tensor containing the real length of each sequence.
44 |             initial_state(Tensor or tuple): tensor containing the initial hidden state for each element in the batch.
45 |                 If rnn_type is not `lstm`, it means h_0, whose shape is `(num_layers * num_directions, batch, hidden_size)`;
46 |                 for lstm it means a tuple (h_0, c_0), where both h_0 and c_0 have shape `(num_layers * num_directions, batch, hidden_size)`
47 |         Returns:
48 |             a tuple of (outputs, last_state).
49 |             outputs: the shape of outputs is `(batch_size, seq_len, num_directions * hidden_size)` if `batch_first=True (default)`,
50 |                 else `(seq_len, batch_size, num_directions * hidden_size)`.
51 |             last_state: contains the hidden state for `t = seq_len`. The shape of last_state
52 |                 is `(num_layers * num_directions, batch_size, hidden_size)`.
53 |         """
54 |         # Ensure inputs is batch_first
55 |         if not self.batch_first:
56 |             inputs.transpose_(0, 1)
57 | 
58 |         if lengths is None:
59 |             outputs, last_state = self.rnn(inputs, initial_state)
60 |         else:
61 |             orig_len = inputs.size(1)
62 |             # Sort and Pack
63 |             lengths, sort_idx = lengths.sort(dim=0, descending=True)
64 |             inputs = inputs[sort_idx]
65 |             inputs = pack_padded_sequence(inputs, lengths, batch_first=True)
66 |             # Apply RNNs
67 |             outputs, last_state = self.rnn(inputs, initial_state)
68 |             # Unpack and Unsort
69 |             outputs, _ = pad_packed_sequence(outputs, batch_first=True, total_length=orig_len)
70 |             _, unsort_idx = sort_idx.sort(dim=0)
71 |             outputs = outputs[unsort_idx]
72 |             if self.rnn_type == 'lstm':
73 |                 last_state = (last_state[0][:, unsort_idx, :], last_state[1][:, unsort_idx, :])
74 |             else:
75 |                 last_state = last_state[:, unsort_idx, :]
76 | 
77 |         # Restore the original outputs shape
78 |         if not self.batch_first:
79 |             outputs.transpose_(0, 1)
80 | 
81 |         return outputs, last_state
82 | 
83 | 
84 | class LSTM(BaseRNN):
85 |     """Unidirectional LSTM"""
86 | 
87 |     def __init__(self, input_size, hidden_size, num_layers=1, drop_prob=0.0, batch_first=True):
88 |         super(LSTM, self).__init__('LSTM', input_size, hidden_size,
89 |                                    num_layers=num_layers, drop_prob=drop_prob,
90 |                                    batch_first=batch_first, bidirectional=False)
91 | 
92 | 
93 | class GRU(BaseRNN):
94 |     """Unidirectional GRU"""
95 | 
96 |     def __init__(self, input_size, hidden_size, num_layers=1, drop_prob=0.0, batch_first=True):
97 |         super(GRU, self).__init__('GRU', input_size, hidden_size,
98 |                                   num_layers=num_layers, drop_prob=drop_prob,
99 |                                   batch_first=batch_first, bidirectional=False)
100 | 
101 | 
102 | class BiLSTM(BaseRNN):
103 |     """Bidirectional LSTM"""
104 | 
105 |     def __init__(self, input_size, hidden_size, num_layers=1, drop_prob=0.0, batch_first=True):
106 |         super(BiLSTM, self).__init__('LSTM', input_size, hidden_size,
107 |                                      num_layers=num_layers, drop_prob=drop_prob,
108 |                                      batch_first=batch_first, bidirectional=True)
109 | 
110 | 
111 | class BiGRU(BaseRNN):
112 |     """Bidirectional GRU"""
113 | 
114 |     def __init__(self, input_size, hidden_size, num_layers=1, drop_prob=0.0, batch_first=True):
115 |         super(BiGRU, self).__init__('GRU', input_size, hidden_size,
116 |                                     num_layers=num_layers, drop_prob=drop_prob,
117 |                                     batch_first=batch_first, bidirectional=True)
118 | 
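# A minimal sketch of the length-aware wrappers above (sizes are hypothetical):
#
#     bigru = BiGRU(input_size=100, hidden_size=75)
#     x = torch.randn(4, 30, 100)                    # B*L*D, batch_first
#     outputs, last_state = bigru(x, lengths=torch.tensor([30, 25, 20, 12]))
#     # outputs.shape == (4, 30, 150), padded steps are zeros
#     # last_state.shape == (2, 4, 75), one state per direction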
119 | 
120 | class BaseMultiLayerRNN(nn.Module):
121 |     """Multi-Layer RNNs Base Model. In particular, the input of each RNN layer uses `Variational Dropout`"""
122 | 
123 |     def __init__(self, rnn_type, input_size, hidden_size, num_layers,
124 |                  batch_first=True, bidirectional=False, input_drop_prob=0.0):
125 |         super(BaseMultiLayerRNN, self).__init__()
126 |         self.rnn_type = rnn_type.lower()
127 |         self.batch_first = batch_first
128 | 
129 |         self.rnn_list = nn.ModuleList(
130 |             [BaseRNN(self.rnn_type, input_size, hidden_size, batch_first=True, bidirectional=bidirectional)])
131 | 
132 |         input_size_ = 2 * hidden_size if bidirectional else hidden_size
133 |         for _ in range(num_layers - 1):
134 |             self.rnn_list.append(BaseRNN(self.rnn_type, input_size_, hidden_size,
135 |                                          batch_first=True, bidirectional=bidirectional))
136 | 
137 |         self.dropout = VariationalDropout(p=input_drop_prob, batch_first=True)
138 | 
139 |     def forward(self, inputs, lengths=None, initial_state=None, concat_layers=True):
140 |         """
141 |         Args:
142 |             concat_layers(bool): whether to concat all layers' outputs when `num_layers > 1`
143 |         Returns:
144 |             a tuple of (outputs, last_state).
145 |             outputs(Tensor): if `concat_layers=True`, returns all layers' outputs,
146 |                 whose last dim is num_directions * hidden_size * num_layers.
147 |             last_state(Tensor or tuple): if rnn_type is not lstm, returns a Tensor h_n, else a tuple (h_n, c_n);
148 |                 the tensor shape is `(num_layers * num_directions, batch_size, hidden_size)`.
149 |         """
150 |         # Ensure inputs is batch_first
151 |         if not self.batch_first:
152 |             inputs.transpose_(0, 1)
153 | 
154 |         # Apply RNNs
155 |         outputs_list, last_state_list = [], []
156 |         for rnn in self.rnn_list:
157 |             outputs, last_state = rnn(self.dropout(inputs), lengths, initial_state)
158 |             outputs_list.append(outputs)
159 |             last_state_list.append(last_state)
160 |             inputs = outputs
161 | 
162 |         # Prepare the return values
163 |         outputs, last_state = None, None
164 |         if concat_layers:
165 |             outputs = torch.cat(outputs_list, dim=-1)
166 |         else:
167 |             outputs = outputs_list[-1]
168 |         if self.rnn_type == 'lstm':
169 |             hn_state = torch.cat([layer_state[0] for layer_state in last_state_list], dim=0)
170 |             cn_state = torch.cat([layer_state[1] for layer_state in last_state_list], dim=0)
171 |             last_state = (hn_state, cn_state)
172 |         else:
173 |             last_state = torch.cat(last_state_list, dim=0)
174 | 
175 |         # Restore the original outputs shape
176 |         if not self.batch_first:
177 |             outputs.transpose_(0, 1)
178 | 
179 |         return outputs, last_state
180 | 
181 | 
182 | class MultiLayerBiGRU(BaseMultiLayerRNN):
183 |     def __init__(self, input_size, hidden_size, num_layers, input_drop_prob=0.0, batch_first=True):
184 |         super(MultiLayerBiGRU, self).__init__('GRU', input_size, hidden_size,
185 |                                               num_layers=num_layers,
186 |                                               input_drop_prob=input_drop_prob,
187 |                                               batch_first=batch_first,
188 |                                               bidirectional=True)
189 | 
190 | 
191 | class MultiLayerBiLSTM(BaseMultiLayerRNN):
192 |     def __init__(self, input_size, hidden_size, num_layers, input_drop_prob=0.0, batch_first=True):
193 |         super(MultiLayerBiLSTM, self).__init__('LSTM', input_size, hidden_size,
194 |                                                num_layers=num_layers,
195 |                                                input_drop_prob=input_drop_prob,
196 |                                                batch_first=batch_first,
197 |                                                bidirectional=True)
-------------------------------------------------------------------------------- /pytorch_mrc/nn/similarity_function.py: --------------------------------------------------------------------------------
1 | """
2 | This module implements similarity calculations between two matrices.
3 | Here we assume the input shapes are ``(batch_size, seq_len1, dim1)`` and ``(batch_size, seq_len2, dim2)``, and
4 | we return an output whose shape is ``(batch_size, seq_len1, seq_len2)``.
5 | """
6 | 
7 | import torch
8 | import torch.nn as nn
9 | import math
10 | 
11 | 
12 | class CosineSimilarity(nn.Module):
13 |     """
14 |     This similarity function simply computes the cosine similarity between two matrices.
15 |     It has no parameters.
16 |     """
17 | 
18 |     def forward(self, tensor_1, tensor_2):
19 |         normalized_tensor_1 = tensor_1 / tensor_1.norm(dim=-1, keepdim=True)
20 |         normalized_tensor_2 = tensor_2 / tensor_2.norm(dim=-1, keepdim=True)
21 |         return torch.bmm(normalized_tensor_1, normalized_tensor_2.transpose(1, 2))
22 | 
23 | 
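# A minimal shape-check sketch (sizes are hypothetical); the same contract holds
# for every similarity class in this module:
#
#     sim = CosineSimilarity()
#     t1 = torch.randn(4, 30, 64)
#     t2 = torch.randn(4, 10, 64)
#     sim(t1, t2).shape  # -> (4, 30, 10)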
24 | class DotProductSimilarity(nn.Module):
25 |     """
26 |     This similarity function simply computes the dot product between two matrices, with an
27 |     optional scaling to reduce the variance of the output elements.
28 |     """
29 | 
30 |     def __init__(self, scale_output=False):
31 |         super(DotProductSimilarity, self).__init__()
32 |         self.scale_output = scale_output
33 | 
34 |     def forward(self, tensor_1, tensor_2):
35 |         result = torch.bmm(tensor_1, tensor_2.transpose(1, 2))
36 |         if self.scale_output:
37 |             # TODO why does allennlp do multiplication here?
38 |             result /= math.sqrt(tensor_1.size(-1))
39 |         return result
40 | 
41 | 
42 | class ProjectedDotProductSimilarity(nn.Module):
43 |     """
44 |     This similarity function does a projection and then computes the dot product between two matrices.
45 |     It's computed as ``x^T W_1 (y^T W_2)^T + b (optional)``. An activation function is applied after the calculation.
46 |     Default is no activation.
47 |     """
48 | 
49 |     def __init__(self, tensor_1_dim, tensor_2_dim, projected_dim,
50 |                  reuse_weight=False, bias=False, activation=None):
51 |         super(ProjectedDotProductSimilarity, self).__init__()
52 |         self.reuse_weight = reuse_weight
53 |         self.projecting_weight_1 = nn.Parameter(torch.Tensor(tensor_1_dim, projected_dim))
54 |         if self.reuse_weight:
55 |             if tensor_1_dim != tensor_2_dim:
56 |                 raise ValueError('if reuse_weight=True, tensor_1_dim must equal tensor_2_dim')
57 |         else:
58 |             self.projecting_weight_2 = nn.Parameter(torch.Tensor(tensor_2_dim, projected_dim))
59 |         self.bias = nn.Parameter(torch.Tensor(1)) if bias else None
60 |         self.activation = activation
61 |         self.reset_parameters()  # initialize the weights, as BiLinearSimilarity does below
62 |     def reset_parameters(self):
63 |         nn.init.xavier_uniform_(self.projecting_weight_1)
64 |         if not self.reuse_weight:
65 |             nn.init.xavier_uniform_(self.projecting_weight_2)
66 |         if self.bias is not None:
67 |             self.bias.data.fill_(0)
68 | 
69 |     def forward(self, tensor_1, tensor_2):
70 |         projected_tensor_1 = torch.matmul(tensor_1, self.projecting_weight_1)
71 |         if self.reuse_weight:
72 |             projected_tensor_2 = torch.matmul(tensor_2, self.projecting_weight_1)
73 |         else:
74 |             projected_tensor_2 = torch.matmul(tensor_2, self.projecting_weight_2)
75 |         result = torch.bmm(projected_tensor_1, projected_tensor_2.transpose(1, 2))
76 |         if self.bias is not None:
77 |             result += self.bias
78 |         if self.activation is not None:
79 |             result = self.activation(result)
80 |         return result
81 | 
82 | 
83 | class BiLinearSimilarity(nn.Module):
84 |     """
85 |     This similarity function performs a bilinear transformation of the two input matrices. It's
86 |     computed as ``x^T W y + b (optional)``. An activation function is applied after the calculation.
87 |     Default is no activation.
88 |     """
89 | 
90 |     def __init__(self, tensor_1_dim, tensor_2_dim, bias=False, activation=None):
91 |         super(BiLinearSimilarity, self).__init__()
92 |         self.weight_matrix = nn.Parameter(torch.Tensor(tensor_1_dim, tensor_2_dim))
93 |         self.bias = nn.Parameter(torch.Tensor(1)) if bias else None
94 |         self.activation = activation
95 |         self.reset_parameters()
96 | 
97 |     def reset_parameters(self):
98 |         nn.init.xavier_uniform_(self.weight_matrix)
99 |         if self.bias is not None:
100 |             self.bias.data.fill_(0)
101 | 
102 |     def forward(self, tensor_1, tensor_2):
103 |         intermediate = torch.matmul(tensor_1, self.weight_matrix)
104 |         result = torch.bmm(intermediate, tensor_2.transpose(1, 2))
105 |         if self.bias is not None:
106 |             result += self.bias
107 |         if self.activation is not None:
108 |             result = self.activation(result)
109 |         return result
110 | 
111 | 
112 | class TriLinearSimilarity(nn.Module):
113 |     """
114 |     This similarity function performs a trilinear transformation of the two input matrices. It's
115 |     computed as ``w^T [x; y; x*y] + b (optional)``. An activation function is applied after the calculation.
116 |     Default is no activation.
117 |     """
118 | 
119 |     def __init__(self, input_dim, bias=False, activation=None):
120 |         super(TriLinearSimilarity, self).__init__()
121 |         self.input_dim = input_dim
122 |         self.weight_vector = nn.Parameter(torch.Tensor(3 * input_dim))
123 |         self.bias = nn.Parameter(torch.Tensor(1)) if bias else None
124 |         self.activation = activation
125 |         self.reset_parameters()
126 | 
127 |     def reset_parameters(self):
128 |         std = math.sqrt(6 / (self.weight_vector.size(0) + 1))
129 |         self.weight_vector.data.uniform_(-std, std)
130 |         if self.bias is not None:
131 |             self.bias.data.fill_(0)
132 | 
133 |     def forward(self, tensor_1, tensor_2):
134 |         w1, w2, w12 = self.weight_vector.chunk(3, dim=-1)
135 |         tensor_1_score = torch.matmul(tensor_1, w1)  # B*L1
136 |         tensor_2_score = torch.matmul(tensor_2, w2)  # B*L2
137 |         combined_score = torch.bmm(tensor_1 * w12, tensor_2.transpose(1, 2))  # B*L1*L2
138 |         result = combined_score + tensor_1_score.unsqueeze(2) + tensor_2_score.unsqueeze(1)
139 |         if self.bias is not None:
140 |             result += self.bias
141 |         if self.activation is not None:
142 |             result = self.activation(result)
143 |         return result
144 | 
145 | 
146 | class MLPSimilarity(nn.Module):
147 |     """
148 |     This similarity function uses a multi-layer perceptron to compute similarity. It's
149 |     computed as ``w^T f(linear(x) + linear(y)) + b (optional)``. Note that the
150 |     activation (default tanh) is applied between the two projection layers rather than at the output layer.
151 |     """
152 |     def __init__(self, tensor_1_dim, tensor_2_dim, hidden_dim, bias=False, activation=torch.tanh):
153 |         super(MLPSimilarity, self).__init__()
154 |         self.projecting_layers = nn.ModuleList([nn.Linear(tensor_1_dim, hidden_dim),
155 |                                                 nn.Linear(tensor_2_dim, hidden_dim)])
156 |         self.score_weight = nn.Parameter(torch.Tensor(hidden_dim))
157 |         self.score_bias = nn.Parameter(torch.Tensor(1)) if bias else None
158 |         self.activation = activation
159 |         self.reset_parameters()
160 | 
161 |     def reset_parameters(self):
162 |         std = math.sqrt(6 / (self.score_weight.size(0) + 1))
163 |         self.score_weight.data.uniform_(-std, std)
164 |         if self.score_bias is not None:
165 |             self.score_bias.data.fill_(0)
166 | 
167 |     def forward(self, tensor_1, tensor_2):
168 |         projected_tensor_1 = self.projecting_layers[0](tensor_1)  # B*L1*H
169 |         projected_tensor_2 = self.projecting_layers[1](tensor_2)  # B*L2*H
170 |         combined_tensor = projected_tensor_1.unsqueeze(2) + projected_tensor_2.unsqueeze(1)  # B*L1*L2*H
171 |         result = torch.matmul(self.activation(combined_tensor), self.score_weight)  # B*L1*L2
172 |         if self.score_bias is not None:
173 |             result += self.score_bias
174 |         return result
175 | 
176 | 
177 | # class SymmetricProject(nn.Module):
178 | #     def __init__(self, tensor_1_dim, tensor_2_dim, hidden_dim, reuse_weight=True, activation=F.relu):
179 | #         super(SymmetricProject, self).__init__()
180 | #         self.reuse_weight = reuse_weight
181 | #         with tf.variable_scope(self.name):
182 | #             diagonal = tf.get_variable('diagonal_matrix', shape=[self.hidden_dim], initializer=tf.ones_initializer, dtype=tf.float32)
183 | #             self.diagonal_matrix = tf.diag(diagonal)
184 | #         self.projecting_layer = tf.keras.layers.Dense(hidden_dim, activation=activation,
185 | #                                                       use_bias=False)
186 | #         if not reuse_weight:
187 | #             self.projecting_layer2 = tf.keras.layers.Dense(hidden_dim, activation=activation, use_bias=False)
188 | #
189 | #     def __call__(self, t0, t1):
190 | #         trans_t0 = self.projecting_layer(t0)
191 | #         trans_t1 = self.projecting_layer(t1)
192 | #         return tf.matmul(tf.tensordot(trans_t0, self.diagonal_matrix, [[2], [0]]), trans_t1, transpose_b=True)
-------------------------------------------------------------------------------- /pytorch_mrc/nn/util.py: --------------------------------------------------------------------------------
1 | """
2 | This module implements some useful utility functions for handling sequence tensors.
3 | """
4 | 
5 | import torch
6 | import torch.nn.functional as F
7 | 
8 | VERY_NEGATIVE_NUMBER = -1e29
9 | 
10 | 
11 | def sequence_mask(lengths, maxlen=None, dtype=torch.float32):
12 |     """
13 |     Args:
14 |         lengths: 1D torch.Tensor, shape is `(batch_size)`
15 |     Returns:
16 |         mask: 2D torch.Tensor, shape is `(batch_size, maxlen)`
17 |     """
18 |     # TODO comes from tf.sequence_mask; there should be a better implementation (see the sketch below)
19 |     if maxlen is None:
20 |         maxlen = lengths.max().item()
21 |     mask = torch.zeros(len(lengths), maxlen, device=lengths.device, dtype=dtype)
22 |     for idx, real_len in enumerate(lengths):
23 |         mask[idx, :real_len] = 1
24 |     return mask
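# A vectorized alternative for the TODO above (a sketch, not wired into the
# toolkit): broadcasting a position range against the lengths gives the same
# mask without the Python loop.
#
#     def sequence_mask_vectorized(lengths, maxlen=None, dtype=torch.float32):
#         if maxlen is None:
#             maxlen = lengths.max().item()
#         positions = torch.arange(maxlen, device=lengths.device)
#         return (positions.unsqueeze(0) < lengths.unsqueeze(1)).to(dtype)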
25 | 
26 | 
27 | def weighted_sum(matrix, attention, dim=1):
28 |     """
29 |     matrix: 3D torch.Tensor
30 |     attention: 2D or 3D torch.Tensor
31 |     dim: which dimension to reduce-sum over
32 | 
33 |     Note: In fact, it also supports higher-order tensors, but the dimensions need to match exactly.
34 |     """
35 |     if attention.dim() == 2:
36 |         attention = attention.unsqueeze(2)
37 |     return (matrix * attention).sum(dim)
38 | 
39 | 
40 | def mask_logits(logits, mask):
41 |     """
42 |     logits is a 2D torch.Tensor; its shape usually means (batch_size, max_seq_len).
43 |     mask can be a 1D or 2D torch.Tensor. If 1D it holds `seq_len` and `sequence_mask` generates the mask;
44 |     if 2D the mask is used directly.
45 | 
46 |     Note: In fact, it also supports higher-order tensors, but the dimensions need to match exactly.
47 |     """
48 |     if mask.dim() == 1:
49 |         mask = sequence_mask(mask, maxlen=logits.size(1), dtype=torch.float32)
50 |     return logits + (1.0 - mask) * VERY_NEGATIVE_NUMBER
51 | 
52 | 
53 | def masked_softmax(logits, mask, dim=-1):
54 |     """
55 |     First, we do the same thing as `mask_logits` (logits is a 2D torch.Tensor and mask can be 1D or 2D).
56 |     Then we apply `softmax` along the selected dimension.
57 | 
58 |     Note: In fact, it also supports higher-order tensors, but the dimensions need to match exactly.
59 |     """
60 |     return F.softmax(mask_logits(logits, mask), dim=dim)
61 | 
62 | 
63 | def add_seq_mask(inputs, seq_len, mode='mul', max_len=None):
64 |     mask = sequence_mask(seq_len, maxlen=max_len, dtype=torch.float32).unsqueeze(2)
65 |     if mode == 'mul':
66 |         return inputs * mask
67 |     if mode == 'add':
68 |         mask = (1 - mask) * VERY_NEGATIVE_NUMBER
69 |         return inputs + mask
-------------------------------------------------------------------------------- /pytorch_mrc/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZiqiang/PyTorch-MRCToolkit/23ec374338509a2e61d0060a43f2d6a32fe337d3/pytorch_mrc/train/__init__.py -------------------------------------------------------------------------------- /pytorch_mrc/train/trainer.py: --------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from collections import defaultdict
4 | 
5 | import torch
6 | import numpy as np
7 | 
8 | 
9 | class Trainer(object):
10 |     def __init__(self):
11 |         pass
12 | 
13 |     @staticmethod
14 |     def _train(model, device, batch_generator, steps, summary_writer, save_summary_steps, log_every_n_batch):
15 |         model.train()
16 |         # TODO handle the summary_writer and save_summary_steps
17 |         total_loss, n_batch_loss = 0.0, 0.0
18 |         for i in range(steps):
19 |             train_batch = batch_generator.next()
20 |             for key, value in train_batch.items():
21 |                 train_batch[key] = value.to(device)
22 | 
23 |             # forward + backward + optimize
24 |             model.zero_grad()
25 |             loss = model(train_batch)
26 |             loss.backward()
27 |             model.update()
28 | 
29 |             if np.isnan(loss.item()):
30 |                 raise ValueError("NaN loss!")
31 |             total_loss += loss.item()
32 |             n_batch_loss += loss.item()
33 |             if log_every_n_batch > 0 and i > 0 and i % log_every_n_batch == 0:
34 |                 logging.info("- Average loss from batch {} to {} is {:05.3f}".format(
35 |                     i - log_every_n_batch, i, n_batch_loss / log_every_n_batch))
36 |                 n_batch_loss = 0.0
37 | 
38 |         logging.info("- Train mean loss: {:05.3f}".format(total_loss / steps))
39 | 
40 |     @staticmethod
41 |     def _eval(model, device, batch_generator, steps, summary_writer=None):
42 |         model.eval()
43 |         total_loss = 0.0
44 |         final_output = defaultdict(list)
45 | 
46 |         with torch.no_grad():
47 |             for _ in range(steps):
48 |                 eval_batch = batch_generator.next()
49 |                 for key, value in eval_batch.items():
50 |                     eval_batch[key] = value.to(device)
51 | 
52 |                 loss, output = model(eval_batch)
53 |                 total_loss += loss.item()
54 |                 for key in output.keys():
55 |                     final_output[key] += [v for v in output[key]]
56 | 
57 |         # Get Eval Mean Loss
58 |         logging.info("- Eval mean loss: {:05.3f}".format(total_loss / steps))
59 | 
60 |         # Add summaries manually to writer at global_step_val
61 |         if summary_writer is not None:
62 |             # TODO
63 |             pass
64 |             # global_step_val = model.session.run(global_step)
65 |             # for tag, val in metrics_val.items():
66 |             #     summ = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=val)])
67 |             #     summary_writer.add_summary(summ, global_step_val)
68 | 
69 |         return final_output
70 | 
71 |     @staticmethod
72 |     def _inference(model, device, batch_generator, steps):
73 |         model.eval()
74 |         final_output = defaultdict(list)
75 | 
76 |         with torch.no_grad():
77 |             for _ in range(steps):
78 |                 eval_batch = {key: value.to(device) for key, value in batch_generator.next().items()}  # keep batches on the model's device
79 |                 output = model(eval_batch)
80 |                 for key in output.keys():
81 |                     final_output[key] += [v for v in output[key]]
82 | 
83 |         return final_output
84 | 
85 |     @staticmethod
86 |     def train_and_evaluate(model, device, train_batch_generator, eval_batch_generator, evaluator, epochs=1, episodes=1,
87 |                            save_dir=None, summary_dir=None, save_summary_steps=10, log_every_n_batch=100):
88 |         model.to(device)
89 | 
90 |         # TODO use tensorboardX
91 |         train_summary = None
92 |         eval_summary = None
93 |         # train_summary = tf.summary.FileWriter(os.path.join(summary_dir, 'train_summaries')) if summary_dir else None
94 |         # eval_summary = tf.summary.FileWriter(os.path.join(summary_dir, 'eval_summaries')) if summary_dir else None
95 | 
96 |         best_eval_score = 0.0
97 |         for epoch in range(epochs):
98 |             logging.info("Epoch {}/{}".format(epoch + 1, epochs))
99 |             train_batch_generator.init()
100 |             train_num_steps = (train_batch_generator.get_dataset_size() +
101 |                                train_batch_generator.get_batch_size() - 1) // train_batch_generator.get_batch_size()
102 | 
103 |             # one epoch consists of several episodes
104 |             assert isinstance(episodes, int)
105 |             num_steps_per_episode = (train_num_steps + episodes - 1) // episodes
106 |             for episode in range(episodes):
107 |                 logging.info("episode {}/{}".format(episode + 1, episodes))
108 |                 current_step_num = min(num_steps_per_episode, train_num_steps - episode * num_steps_per_episode)
109 |                 episode_id = epoch * episodes + episode + 1
110 |                 Trainer._train(model, device, train_batch_generator, current_step_num,
111 |                                train_summary, save_summary_steps, log_every_n_batch)
112 | 
113 |                 if model.ema_decay > 0:
114 |                     # TODO how to do it
115 |                     pass
116 | 
117 |                 # Save weights
118 |                 if save_dir is not None:
119 |                     last_save_path = os.path.join(save_dir, 'last_weights', 'after-episode')
120 |                     model.save(last_save_path, global_step=episode_id)
121 | 
122 |                 # Evaluate for one episode on dev set, TODO
123 |                 eval_batch_generator.init()
124 |                 eval_raw_dataset = eval_batch_generator.get_raw_dataset()
125 |                 eval_num_steps = (eval_batch_generator.get_dataset_size() +
126 |                                   eval_batch_generator.get_batch_size() - 1) // eval_batch_generator.get_batch_size()
127 |                 output = Trainer._eval(model, device, eval_batch_generator, eval_num_steps, eval_summary)
128 |                 score = evaluator.get_score(model.get_best_answer(output, eval_raw_dataset))
129 |                 metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in score.items())
130 |                 logging.info("- Eval metrics: " + metrics_string)
131 | 
132 |                 if model.ema_decay > 0:
133 |                     # TODO how to do it
134 |                     pass
135 | 
136 |                 # Save best weights
137 |                 eval_score = score[evaluator.get_monitor()]
138 |                 if eval_score > best_eval_score:
139 |                     logging.info("- epoch %d episode %d: Found new best score: %f" % (epoch + 1, episode + 1, eval_score))
140 |                     best_eval_score = eval_score
141 |                     # Save best weights
142 |                     if save_dir is not None:
143 |                         best_save_path = os.path.join(save_dir, 'best_weights', 'after-episode')
144 |                         # TODO the best save path should hold exactly one model; needs to be improved
145 |                         for file in os.listdir(best_save_path):
146 |                             os.remove(os.path.join(best_save_path, file))
147 |                         model.save(best_save_path, global_step=episode_id)
148 |                         logging.info("- Found new best model, saving in {}".format(best_save_path))
149 | 
150 |     @staticmethod
151 |     def evaluate(model, device, batch_generator, evaluator):
152 |         model.to(device)
153 |         batch_generator.init()
154 |         eval_raw_dataset = batch_generator.get_raw_dataset()
155 | 
156 |         eval_num_steps = (batch_generator.get_dataset_size() +
157 |                           batch_generator.get_batch_size() - 1) // batch_generator.get_batch_size()
158 |         output = Trainer._eval(model, device, batch_generator, eval_num_steps, None)
159 |         score = evaluator.get_score(model.get_best_answer(output, eval_raw_dataset))
160 |         metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in score.items())
161 |         logging.info("- Eval metrics: " + metrics_string)
162 | 
163 |     @staticmethod
164 |     def inference(model, device, batch_generator):
165 |         model.to(device)
166 |         batch_generator.init()
167 |         test_raw_dataset = batch_generator.get_raw_dataset()
168 |         eval_num_steps = (batch_generator.get_dataset_size() +
169 |                           batch_generator.get_batch_size() - 1) // batch_generator.get_batch_size()
170 |         output = Trainer._inference(model, device, batch_generator, eval_num_steps)
171 |         pred_answers = model.get_best_answer(output, test_raw_dataset)
172 |         return pred_answers
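# A minimal end-to-end wiring sketch for the Trainer (the evaluator object and
# the pre-built batch generators are illustrative assumptions; see examples/
# for the real scripts):
#
#     model = RNET(vocab, device, pretrained_word_embedding=vocab.get_word_embedding())
#     model.compile(optimizer='adam', initial_lr=0.002)
#     Trainer.train_and_evaluate(model, device,
#                                train_batch_generator, eval_batch_generator,
#                                evaluator, epochs=2, episodes=2,
#                                save_dir='models/rnet')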
-------------------------------------------------------------------------------- /pytorch_mrc/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZiqiang/PyTorch-MRCToolkit/23ec374338509a2e61d0060a43f2d6a32fe337d3/pytorch_mrc/utils/__init__.py -------------------------------------------------------------------------------- /pytorch_mrc/utils/tokenizer.py: --------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import spacy
3 | from stanfordnlp.server.client import CoreNLPClient
4 | import jieba
5 | import multiprocessing
6 | import re
7 | 
8 | 
9 | class SpacyTokenizer(object):
10 |     def __init__(self, fine_grained=False):
11 |         self.nlp = spacy.load('en', disable=['parser', 'tagger', 'entity'])
12 |         self.fine_grained = fine_grained
13 | 
14 |     def word_tokenizer(self, doc):
15 |         if not self.fine_grained:
16 |             doc = self.nlp(doc)
17 |             tokens = [token.text for token in doc]
18 |             token_spans = [(token.idx, token.idx + len(token.text)) for token in doc]
19 |             return tokens, token_spans
20 |         sentence = doc
21 |         tokens = []
22 |         token_spans = []
23 |         cur = 0
24 |         pattern = u'-|–|—|:|’|\.|,|\[|\?|\(|\)|~|\$|/'
25 |         for next in re.finditer(pattern, sentence):
26 |             for token in self.nlp(sentence[cur:next.regs[0][0]]):
27 |                 if token.text.strip() != '':
28 |                     tokens.append(token.text)
29 |                     token_spans.append((cur + token.idx, cur + token.idx + len(token.text)))
30 |             tokens.append(sentence[next.regs[0][0]:next.regs[0][1]])
31 |             token_spans.append((next.regs[0][0], next.regs[0][1]))
32 |             cur = next.regs[0][1]
33 |         for token in self.nlp(sentence[cur:]):
34 |             if token.text.strip() != '':
35 |                 tokens.append(token.text)
36 |                 token_spans.append((cur + token.idx, cur + token.idx + len(token.text)))
37 |         return tokens, token_spans
38 | 
39 |     def word_tokenizer_parallel(self, docs):
40 |         docs = [doc for doc in self.nlp.pipe(docs, batch_size=64, n_threads=multiprocessing.cpu_count())]
41 |         tokens = [[token.text for token in doc] for doc in docs]
42 |         token_spans = [[(token.idx, token.idx + len(token.text)) for token in doc] for doc in docs]
43 |         return tokens, token_spans
44 | 
45 | 
46 | class JieBaTokenizer(object):
47 |     """
48 |     Only for Chinese tokenization; there are no POS/NER feature functions.
49 |     """
50 |     def __init__(self):
51 |         self.tokenizer = jieba
52 | 
53 |     def word_tokenizer(self, doc):
54 |         tokens = self.tokenizer.cut(doc)
55 |         tokens = list(tokens)  # jieba.cut returns a generator
56 |         start = 0
57 |         token_spans = []
58 |         for token in tokens:
59 |             token_spans.append((start, start + len(token)))
60 |             start += len(token)
61 |         return tokens, token_spans
62 | 
63 | 
64 | class StanfordTokenizer(object):
65 |     def __init__(self, language='zh', annotators='ssplit tokenize', timeout=30000, memory="4G"):
66 |         if language == 'zh':
67 |             CHINESE_PROPERTIES = {
68 |                 "tokenize.language": "zh",
69 |                 "segment.model": "edu/stanford/nlp/models/segmenter/chinese/ctb.gz",
70 |                 "segment.sighanCorporaDict": "edu/stanford/nlp/models/segmenter/chinese",
71 |                 "segment.serDictionary": "edu/stanford/nlp/models/segmenter/chinese/dict-chris6.ser.gz",
72 |                 "segment.sighanPostProcessing": "true",
73 |                 "ssplit.boundaryTokenRegex": "[.。]|[!?!?]+",
74 |             }
75 |         else:
76 |             CHINESE_PROPERTIES = {}
77 |         self.client = CoreNLPClient(annotators=annotators, timeout=timeout, memory=memory, properties=CHINESE_PROPERTIES)
78 | 
79 |     def word_tokenizer(self, doc):
80 |         try:
81 |             annotated = self.client.annotate(doc)
82 |             tokens, token_spans = [], []
83 |             for sentence in annotated.sentence:
84 |                 for token in sentence.token:
85 |                     tokens.append(token.word)
86 |                     token_spans.append((token.beginChar, token.endChar))
87 |             return tokens, token_spans
88 |         except Exception:
89 |             return None, None
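# A minimal usage sketch for the tokenizers (the sentence and outputs are
# illustrative):
#
#     tokenizer = SpacyTokenizer()
#     tokens, spans = tokenizer.word_tokenizer("PyTorch-MRC is a toolkit.")
#     # `tokens` holds the token strings; `spans` holds their (start_char, end_char)
#     # offsets into the raw text, which `get_best_answer` uses to recover answer spans.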
18 | for k, v in batch_sample.items(): 19 | print('{} -> shape: {}'.format(k, list(v.size()))) 20 | print('*' * 10) 21 | 22 | 23 | # define data path 24 | tiny_file = "/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/tiny-v1.1.json" 25 | vocab_file = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/vocab_data/vocab_tiny_100d.pkl' 26 | bg_save_file = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/bg_data/bg_tiny_32b_100d.pkl' 27 | 28 | # read data 29 | reader = SquadReader() 30 | print('reading data from {} ...'.format(tiny_file)) 31 | tiny_data = reader.read(tiny_file) 32 | 33 | # load vocabulary 34 | vocab = Vocabulary() 35 | print('***loading vocabulary...***') 36 | vocab.load(vocab_file) 37 | word_embedding = vocab.get_word_embedding() 38 | print('word vocab size: {}, word embedding shape: {}'.format(len(vocab.get_word_vocab()), word_embedding.shape)) 39 | 40 | # build batch generator 41 | print('***building batch generator***') 42 | batch_generator = BatchGenerator() 43 | batch_generator.build(vocab, tiny_data, batch_size=32, shuffle=False) 44 | print_info(batch_generator) 45 | 46 | # save batch generator 47 | print('***saving BatchGenerator...***') 48 | batch_generator.save(bg_save_file) 49 | print('saved successfully!') 50 | 51 | # load batch generator 52 | print('***loading BatchGenerator***') 53 | batch_generator = BatchGenerator() 54 | batch_generator.load(bg_save_file) 55 | print_info(batch_generator) 56 | 57 | print('done!') 58 | -------------------------------------------------------------------------------- /unit_tests/data/vocabulary_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../..') 3 | from pytorch_mrc.data.vocabulary import Vocabulary 4 | from pytorch_mrc.dataset.squad import SquadReader 5 | 6 | 7 | # helper that prints vocab statistics and a few embeddings so we can sanity-check them 8 | def print_info(vocab, word_embedding): 9 | print('word vocab size: {}, word embedding shape: {}'.format(len(vocab.get_word_vocab()), word_embedding.shape)) 10 | print('word pad token idx: {}, embedding is: \n{}'.format( 11 | vocab.get_word_pad_idx(), word_embedding[vocab.get_word_pad_idx()])) 12 | print('word unk token idx: {}, embedding is: \n{}'.format( 13 | vocab.get_word_unk_idx(), word_embedding[vocab.get_word_unk_idx()])) 14 | print('word `code` idx: {}, embedding is: \n{}'.format( 15 | vocab.get_word_idx('code'), word_embedding[vocab.get_word_idx('code')])) 16 | print('word `randomrandom` idx: {}, embedding is: \n{}'.format( 17 | vocab.get_word_idx('randomrandom'), word_embedding[vocab.get_word_idx('randomrandom')])) 18 | 19 | 20 | # define data path 21 | tiny_file = "/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/tiny-v1.1.json" 22 | embedding_file = '/home/len/yingzq/nlp/mrc_dataset/word_embeddings/glove.6B.100d.txt' 23 | vocab_save_file = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/vocab_data/vocab_tiny_100d.pkl' # where to save vocab data 24 | 25 | # read data 26 | reader = SquadReader() 27 | print('reading data from {} ...'.format(tiny_file)) 28 | tiny_data = reader.read(tiny_file) 29 | 30 | # build the vocabulary 31 | vocab = Vocabulary() 32 | print('building vocabulary...') 33 | vocab.build_vocab(tiny_data, min_word_count=3, min_char_count=10) 34 | print('making word embedding...') 35 | vocab.make_word_embedding(embedding_file) 36 | word_embedding = vocab.get_word_embedding() 37 | print_info(vocab, word_embedding) 38 | 39 | # save vocabulary 40 | print('***saving vocabulary...***') 41 | vocab.save(vocab_save_file) 42 | print('saved successfully!') 43 
| 44 | # load vocabulary 45 | print('***loading vocabulary...***') 46 | vocab = Vocabulary() 47 | vocab.load(vocab_save_file) 48 | word_embedding = vocab.get_word_embedding() 49 | print_info(vocab, word_embedding) 50 | 51 | print('done!') 52 | -------------------------------------------------------------------------------- /unit_tests/model/bidaf_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../..') 3 | import logging 4 | import torch 5 | from pytorch_mrc.data.vocabulary import Vocabulary 6 | from pytorch_mrc.dataset.squad import SquadReader, SquadEvaluator 7 | from pytorch_mrc.model.bidaf import BiDAF 8 | from pytorch_mrc.data.batch_generator import BatchGenerator 9 | 10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 11 | data_folder = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/' 12 | embedding_folder = '/home/len/yingzq/nlp/mrc_dataset/word_embeddings/' 13 | tiny_file = data_folder + "tiny-v1.1.json" 14 | embedding_file = embedding_folder + 'glove.6B.100d.txt' 15 | 16 | reader = SquadReader() 17 | tiny_data = reader.read(tiny_file) 18 | evaluator = SquadEvaluator(tiny_file) 19 | 20 | logging.info('building vocab and making embedding...') 21 | vocab = Vocabulary() 22 | vocab.build_vocab(tiny_data, min_word_count=3, min_char_count=10) 23 | vocab.make_word_embedding(embedding_file) 24 | word_embedding = vocab.get_word_embedding() 25 | logging.info('word vocab size: {}, word embedding shape: {}'.format(len(vocab.get_word_vocab()), word_embedding.shape)) 26 | 27 | train_batch_generator = BatchGenerator() 28 | train_batch_generator.build(vocab, tiny_data, batch_size=32, shuffle=True) 29 | eval_batch_generator = BatchGenerator() 30 | eval_batch_generator.build(vocab, tiny_data, batch_size=32) 31 | 32 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 33 | model = BiDAF(vocab, device, pretrained_word_embedding=word_embedding) 34 | model.compile() 35 | model.train_and_evaluate(train_batch_generator, eval_batch_generator, evaluator, epochs=2, episodes=2, log_every_n_batch=10) 36 | -------------------------------------------------------------------------------- /unit_tests/model/rnet_hkust_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../..') 3 | import logging 4 | import torch 5 | from pytorch_mrc.data.vocabulary import Vocabulary 6 | from pytorch_mrc.dataset.squad import SquadReader, SquadEvaluator 7 | from pytorch_mrc.model.rnet_hkust import RNET 8 | from pytorch_mrc.data.batch_generator import BatchGenerator 9 | 10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 11 | data_folder = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/' 12 | embedding_folder = '/home/len/yingzq/nlp/mrc_dataset/word_embeddings/' 13 | tiny_file = data_folder + "tiny-v1.1.json" 14 | embedding_file = embedding_folder + 'glove.6B.100d.txt' 15 | 16 | reader = SquadReader(fine_grained=True) 17 | tiny_data = reader.read(tiny_file) 18 | evaluator = SquadEvaluator(tiny_file) 19 | 20 | logging.info('building vocab and making embedding...') 21 | vocab = Vocabulary() 22 | vocab.build_vocab(tiny_data, min_word_count=3, min_char_count=10) 23 | vocab.make_word_embedding(embedding_file) 24 | word_embedding = vocab.get_word_embedding() 25 | logging.info('word vocab size: {}, word embedding shape: {}'.format(len(vocab.get_word_vocab()), 
word_embedding.shape)) 26 | 27 | train_batch_generator = BatchGenerator() 28 | train_batch_generator.build(vocab, tiny_data, batch_size=32, shuffle=True) 29 | eval_batch_generator = BatchGenerator() 30 | eval_batch_generator.build(vocab, tiny_data, batch_size=32) 31 | 32 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 33 | model = RNET(vocab, device, pretrained_word_embedding=word_embedding, word_embedding_size=100) 34 | model.compile() 35 | model.train_and_evaluate(train_batch_generator, eval_batch_generator, evaluator, epochs=2, episodes=2, log_every_n_batch=10) 36 | -------------------------------------------------------------------------------- /unit_tests/model/rnet_sogou_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../..') 3 | import logging 4 | import torch 5 | from pytorch_mrc.data.vocabulary import Vocabulary 6 | from pytorch_mrc.dataset.squad import SquadReader, SquadEvaluator 7 | from pytorch_mrc.model.rnet_sogou import RNET 8 | from pytorch_mrc.data.batch_generator import BatchGenerator 9 | 10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 11 | data_folder = '/home/len/yingzq/nlp/mrc_dataset/squad-v1.1/' 12 | embedding_folder = '/home/len/yingzq/nlp/mrc_dataset/word_embeddings/' 13 | tiny_file = data_folder + "tiny-v1.1.json" 14 | embedding_file = embedding_folder + 'glove.6B.100d.txt' 15 | 16 | reader = SquadReader(fine_grained=True) 17 | tiny_data = reader.read(tiny_file) 18 | evaluator = SquadEvaluator(tiny_file) 19 | 20 | logging.info('building vocab and making embedding...') 21 | vocab = Vocabulary(do_lowercase=False) 22 | vocab.build_vocab(tiny_data, min_word_count=3, min_char_count=10) 23 | vocab.make_word_embedding(embedding_file) 24 | word_embedding = vocab.get_word_embedding() 25 | logging.info('word vocab size: {}, word embedding shape: {}'.format(len(vocab.get_word_vocab()), word_embedding.shape)) 26 | 27 | train_batch_generator = BatchGenerator() 28 | train_batch_generator.build(vocab, tiny_data, batch_size=32, shuffle=True) 29 | eval_batch_generator = BatchGenerator() 30 | eval_batch_generator.build(vocab, tiny_data, batch_size=32) 31 | 32 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 33 | model = RNET(vocab, device, pretrained_word_embedding=word_embedding, word_embedding_size=100) 34 | model.compile() 35 | model.train_and_evaluate(train_batch_generator, eval_batch_generator, evaluator, epochs=2, episodes=2, log_every_n_batch=10) 36 | --------------------------------------------------------------------------------
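
The unit tests above all follow the same pipeline: read and tokenize SQuAD, build a `Vocabulary` with pretrained GloVe embeddings, wrap the data in a `BatchGenerator`, then compile and train a model. Below is a minimal end-to-end sketch of that flow, assuming the same APIs the tests use; the data paths are placeholders that you must point at your own copies of SQuAD 1.1 and GloVe.

```python
import torch
from pytorch_mrc.data.vocabulary import Vocabulary
from pytorch_mrc.data.batch_generator import BatchGenerator
from pytorch_mrc.dataset.squad import SquadReader, SquadEvaluator
from pytorch_mrc.model.bidaf import BiDAF

# Placeholder paths -- replace with your local data locations.
train_file = '/path/to/squad-v1.1/train-v1.1.json'
embedding_file = '/path/to/glove.6B.100d.txt'

# 1. Read and tokenize the raw dataset.
reader = SquadReader()
train_data = reader.read(train_file)

# 2. Build the vocabulary and attach pretrained word embeddings.
vocab = Vocabulary()
vocab.build_vocab(train_data, min_word_count=3, min_char_count=10)
vocab.make_word_embedding(embedding_file)

# 3. Batch the data: word/char indexing, dynamic padding, and
#    tensor conversion all happen inside the BatchGenerator.
train_bg = BatchGenerator()
train_bg.build(vocab, train_data, batch_size=32, shuffle=True)
eval_bg = BatchGenerator()
eval_bg.build(vocab, train_data, batch_size=32)

# 4. Compile the model and alternate training with evaluation.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = BiDAF(vocab, device, pretrained_word_embedding=vocab.get_word_embedding())
model.compile()
evaluator = SquadEvaluator(train_file)
model.train_and_evaluate(train_bg, eval_bg, evaluator, epochs=2, episodes=2)
```

For prediction alone, `Trainer.inference(model, device, batch_generator)` (shown at the top of this section) moves the model to the device, runs the forward pass over all batches, and returns the predicted answers via `model.get_best_answer`.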