├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── main.py └── src ├── __init__.py ├── data_handler.py ├── models ├── modeling_graphformers.py ├── modeling_graphsage.py └── tnlrv3 │ ├── config.py │ ├── configuration_tnlrv3.py │ ├── convert_state_dict.py │ ├── modeling.py │ ├── modeling_decoding.py │ ├── s2s_loader.py │ ├── tokenization_tnlrv3.py │ └── utils.py ├── parameters.py ├── run.py └── utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraphFormers 2 | 3 | ## Introduction 4 | 5 | Implementation of [GraphFormers: GNN-nested Transformers for Representation Learning on Textual Graph](https://arxiv.org/abs/2105.02605) 6 | 7 | ## Requirements 8 | ``` 9 | Python==3.6 10 | torch==1.6.0 11 | transformers==3.4.0 12 | ``` 13 | 14 | ## Data & Pretrained Language Model 15 | 16 | Please refer to [OneDrive](https://1drv.ms/u/s!Ag0vYLiCLJL3hTP91qvk1L61SzNn?e=RC1veX) 17 | 18 | ## Usage 19 | ``` 20 | python main.py 21 | ``` 22 | For more information on the available parameters, please refer to `src/parameters.py` 23 | 24 | ## Contributing 25 | 26 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 27 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 28 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
29 | 30 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 31 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 32 | provided by the bot. You will only need to do this once across all repos using our CLA. 33 | 34 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 35 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 36 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 37 | 38 | ## Trademarks 39 | 40 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 41 | trademarks or logos is subject to and must follow 42 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 43 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 44 | Any use of third-party trademarks or logos are subject to those third-party's policies. 45 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 
22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
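The repository's documented entry point is `python main.py` (see the README above and `main.py` below). As a smaller illustration, the following sketch shows how the data pipeline defined in `src/data_handler.py` (further below) could be driven on its own; the file name `train.tsv` and the hyper-parameter values are assumptions for illustration only, not values taken from the repository.

```
# Hypothetical stand-alone sketch; "train.tsv" and the hyper-parameters are illustrative assumptions.
from src.data_handler import DatasetForMatching, DataCollatorForMatching, SingleProcessDataLoader

dataset = DatasetForMatching(file_path="train.tsv")  # one "query<TAB>key" pair per line, neighbours separated by |'|
collator = DataCollatorForMatching(mlm=True, neighbor_num=5, token_length=32)  # assumed settings
loader = SingleProcessDataLoader(dataset=dataset, batch_size=8, collate_fn=collator, blocking=True)

for batch in loader:
    # input ids are padded to (batch, nodes, seq_len); nodes = centre node plus up to neighbor_num neighbours
    print(batch["input_ids_query_and_neighbors_batch"].shape)
    break
```

With `blocking=True` the batches are produced in the calling thread, so the sketch stays clear of the worker thread and the distributed barrier that `MultiProcessDataLoader` relies on during multi-GPU training.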
26 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch.multiprocessing as mp 5 | 6 | from src.parameters import parse_args 7 | from src.run import train, test 8 | from src.utils import setuplogging 9 | 10 | if __name__ == "__main__": 11 | 12 | setuplogging() 13 | gpus = ','.join([str(_ + 1) for _ in range(2)]) 14 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus 15 | os.environ['MASTER_ADDR'] = 'localhost' 16 | os.environ['MASTER_PORT'] = '12355' 17 | args = parse_args() 18 | print(os.getcwd()) 19 | args.log_steps = 5 20 | args.world_size = 2 # GPU number 21 | args.mode = 'train' 22 | Path(args.model_dir).mkdir(parents=True, exist_ok=True) 23 | 24 | cont = False 25 | if args.mode == 'train': 26 | print('-----------train------------') 27 | if args.world_size > 1: 28 | mp.freeze_support() 29 | mgr = mp.Manager() 30 | end = mgr.Value('b', False) 31 | mp.spawn(train, 32 | args=(args, end, cont), 33 | nprocs=args.world_size, 34 | join=True) 35 | else: 36 | end = None 37 | train(0, args, end, cont) 38 | 39 | if args.mode == 'test': 40 | args.load_ckpt_name = "/data/workspace/Share/junhan/TopoGram_ckpt/dblp/topogram-pretrain-finetune-dblp-best3.pt" 41 | print('-------------test--------------') 42 | test(args) 43 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/GraphFormers/06636cd7196cc2b3eecdb5a49fd2d405d2ea4e4b/src/__init__.py -------------------------------------------------------------------------------- /src/data_handler.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ThreadPoolExecutor 2 | from dataclasses import dataclass 3 | from queue import Queue 4 | from typing import Any, Dict, List, Tuple, Callable, Union 5 | 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | from torch.utils.data.dataset import IterableDataset 10 | from transformers import BertTokenizerFast 11 | 12 | 13 | class DatasetForMatching(IterableDataset): 14 | def __init__( 15 | self, 16 | file_path: str, 17 | tokenizer: Union[BertTokenizerFast, str] = "bert-base-uncased", 18 | ): 19 | 20 | self.data_file = open(file_path, "r", encoding="utf-8") 21 | if isinstance(tokenizer, str): 22 | self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer) 23 | else: 24 | self.tokenizer = tokenizer 25 | 26 | def process(self, input_line): 27 | # Input file format: 28 | # Example: 29 | # A simple algorithm for Boolean operations on polygons|'|Geometric modelling based on simplicial chains|'|Boolean operations on general planar polygons|'|Reentrant polygon clipping|'|Plane-sweep algorithms for intersecting geometric figures|'|A new algorithm for computing Boolean operations on polygons An analysis and algorithm for polygon clipping|'|Set Membership Classification: A Unified Approach to Geometric Intersection Problems|'|Reentrant polygon clipping|'|Hidden surface removal using polygon area sorting|'|Polygon comparison using a graph representation|'|A New Concept and Method for Line Clipping 30 | # Balanced Multifilter Banks for Multiple Description Coding|'|Balanced multiwavelets|'|On minimal lattice factorizations of symmetric-antisymmetric multifilterbanks|'|High-order balanced 
multiwavelets: theory, factorization, and design|'|Single-Trial Multiwavelet Coherence in Application to Neurophysiological Time Series|'|The application of multiwavelet filterbanks to image processing Armlets and balanced multiwavelets: flipping filter construction|'|Multiwavelet prefilters. II. Optimal orthogonal prefilters|'|Regularity of multiwavelets|'|Balanced GHM-like multiscaling functions|'|A new prefilter design for discrete multiwavelet transforms|'|Balanced multiwavelets with short filters 31 | 32 | query_and_neighbors, key_and_neighbors = input_line.strip('\n').split('\t')[:2] 33 | query_and_neighbors = query_and_neighbors.split('|\'|') 34 | key_and_neighbors = key_and_neighbors.split('|\'|') 35 | tokens_query_and_neighbors = self.tokenizer.batch_encode_plus(query_and_neighbors, add_special_tokens=False)[ 36 | 'input_ids'] 37 | tokens_key_and_neighbors = self.tokenizer.batch_encode_plus(key_and_neighbors, add_special_tokens=False)[ 38 | 'input_ids'] 39 | 40 | return tokens_query_and_neighbors, tokens_key_and_neighbors 41 | 42 | def __iter__(self): 43 | for line in self.data_file: 44 | yield self.process(line) 45 | 46 | 47 | @dataclass 48 | class DataCollatorForMatching: 49 | mlm: bool 50 | neighbor_num: int 51 | token_length: int 52 | tokenizer: Union[BertTokenizerFast, str] = "bert-base-uncased" 53 | mlm_probability: float = 0.15 54 | random_seed: int = 42 55 | 56 | def __post_init__(self): 57 | if isinstance(self.tokenizer, str): 58 | self.tokenizer = BertTokenizerFast.from_pretrained(self.tokenizer) 59 | self.random_state = np.random.RandomState(seed=self.random_seed) 60 | 61 | def __call__(self, samples: List[List[List[List[int]]]]) -> Dict[str, torch.Tensor]: 62 | input_ids_query_and_neighbors_batch = [] 63 | attention_mask_query_and_neighbors_batch = [] 64 | mask_query_and_neighbors_batch = [] 65 | input_ids_key_and_neighbors_batch = [] 66 | attention_mask_key_and_neighbors_batch = [] 67 | mask_key_and_neighbors_batch = [] 68 | for i, sample in (enumerate(samples)): 69 | input_ids_query_and_neighbors, attention_mask_query_and_neighbors, mask_query_and_neighbors, \ 70 | input_ids_key_and_neighbors, attention_mask_key_and_neighbors, mask_key_and_neighbors = self.create_training_sample( 71 | sample) 72 | 73 | input_ids_query_and_neighbors_batch.append(input_ids_query_and_neighbors) 74 | attention_mask_query_and_neighbors_batch.append(attention_mask_query_and_neighbors) 75 | mask_query_and_neighbors_batch.append(mask_query_and_neighbors) 76 | 77 | input_ids_key_and_neighbors_batch.append(input_ids_key_and_neighbors) 78 | attention_mask_key_and_neighbors_batch.append(attention_mask_key_and_neighbors) 79 | mask_key_and_neighbors_batch.append(mask_key_and_neighbors) 80 | 81 | if self.mlm: 82 | input_ids_query_and_neighbors_batch, mlm_labels_query_batch = self.mask_tokens( 83 | self._tensorize_batch(input_ids_query_and_neighbors_batch, self.tokenizer.pad_token_id), 84 | self.tokenizer.mask_token_id) 85 | input_ids_key_and_neighbors_batch, mlm_labels_key_batch = self.mask_tokens( 86 | self._tensorize_batch(input_ids_key_and_neighbors_batch, self.tokenizer.pad_token_id), 87 | self.tokenizer.mask_token_id) 88 | else: 89 | input_ids_query_and_neighbors_batch = self._tensorize_batch(input_ids_query_and_neighbors_batch, 90 | self.tokenizer.pad_token_id) 91 | input_ids_key_and_neighbors_batch = self._tensorize_batch(input_ids_key_and_neighbors_batch, 92 | self.tokenizer.pad_token_id) 93 | attention_mask_query_and_neighbors_batch = 
self._tensorize_batch(attention_mask_query_and_neighbors_batch, 0) 94 | attention_mask_key_and_neighbors_batch = self._tensorize_batch(attention_mask_key_and_neighbors_batch, 0) 95 | mask_query_and_neighbors_batch = self._tensorize_batch(mask_query_and_neighbors_batch, 0) 96 | mask_key_and_neighbors_batch = self._tensorize_batch(mask_key_and_neighbors_batch, 0) 97 | 98 | return { 99 | "input_ids_query_and_neighbors_batch": input_ids_query_and_neighbors_batch, 100 | "attention_mask_query_and_neighbors_batch": attention_mask_query_and_neighbors_batch, 101 | "mlm_labels_query_batch": mlm_labels_query_batch if self.mlm else None, 102 | "mask_query_and_neighbors_batch": mask_query_and_neighbors_batch, 103 | "input_ids_key_and_neighbors_batch": input_ids_key_and_neighbors_batch, 104 | "attention_mask_key_and_neighbors_batch": attention_mask_key_and_neighbors_batch, 105 | "mlm_labels_key_batch": mlm_labels_key_batch if self.mlm else None, 106 | "mask_key_and_neighbors_batch": mask_key_and_neighbors_batch, 107 | } 108 | 109 | def _tensorize_batch(self, sequences: Union[List[torch.Tensor], List[List[torch.Tensor]]], 110 | padding_value) -> torch.Tensor: 111 | if len(sequences[0].size()) == 1: 112 | max_len_1 = max([s.size(0) for s in sequences]) 113 | out_dims = (len(sequences), max_len_1) 114 | out_tensor = sequences[0].new_full(out_dims, padding_value) 115 | for i, tensor in enumerate(sequences): 116 | length_1 = tensor.size(0) 117 | out_tensor[i, :length_1] = tensor 118 | return out_tensor 119 | elif len(sequences[0].size()) == 2: 120 | max_len_1 = max([s.size(0) for s in sequences]) 121 | max_len_2 = max([s.size(1) for s in sequences]) 122 | out_dims = (len(sequences), max_len_1, max_len_2) 123 | out_tensor = sequences[0].new_full(out_dims, padding_value) 124 | for i, tensor in enumerate(sequences): 125 | length_1 = tensor.size(0) 126 | length_2 = tensor.size(1) 127 | out_tensor[i, :length_1, :length_2] = tensor 128 | return out_tensor 129 | else: 130 | raise 131 | 132 | def create_training_sample(self, sample: List[List[List[int]]]): 133 | 134 | def process_node_and_neighbors(tokens_node_and_neighbors): 135 | max_num_tokens = self.token_length - self.tokenizer.num_special_tokens_to_add(pair=False) 136 | input_ids_node_and_neighbors, attention_mask_node_and_neighbors, mask_node_and_neighbors = [], [], [] 137 | for i, tokens in enumerate(tokens_node_and_neighbors): 138 | if i > self.neighbor_num: break 139 | input_ids_node_and_neighbors.append( 140 | torch.tensor(self.tokenizer.build_inputs_with_special_tokens(tokens[:max_num_tokens]))) 141 | attention_mask_node_and_neighbors.append(torch.tensor([1] * len(input_ids_node_and_neighbors[-1]))) 142 | if len(tokens) == 0: 143 | mask_node_and_neighbors.append(torch.tensor(0)) 144 | else: 145 | mask_node_and_neighbors.append(torch.tensor(1)) 146 | input_ids_node_and_neighbors = self._tensorize_batch(input_ids_node_and_neighbors, 147 | self.tokenizer.pad_token_id) 148 | attention_mask_node_and_neighbors = self._tensorize_batch(attention_mask_node_and_neighbors, 0) 149 | mask_node_and_neighbors = torch.stack(mask_node_and_neighbors) 150 | return input_ids_node_and_neighbors, attention_mask_node_and_neighbors, mask_node_and_neighbors 151 | 152 | tokens_query_and_neighbors, tokens_key_and_neighbors = sample 153 | input_ids_query_and_neighbors, attention_mask_query_and_neighbors, mask_query_and_neighbors = process_node_and_neighbors( 154 | tokens_query_and_neighbors) 155 | input_ids_key_and_neighbors, attention_mask_key_and_neighbors, mask_key_and_neighbors 
= process_node_and_neighbors( 156 | tokens_key_and_neighbors) 157 | 158 | return input_ids_query_and_neighbors, attention_mask_query_and_neighbors, mask_query_and_neighbors, \ 159 | input_ids_key_and_neighbors, attention_mask_key_and_neighbors, mask_key_and_neighbors 160 | 161 | def mask_tokens(self, inputs_origin: torch.Tensor, mask_id: int) -> Tuple[torch.Tensor, torch.Tensor]: 162 | """ 163 | Prepare masked tokens inputs/labels for masked language modeling. 164 | """ 165 | inputs = inputs_origin.clone() 166 | labels = torch.zeros((inputs.shape[0], inputs.shape[2]), dtype=torch.long) - 100 167 | for i in range(len(inputs_origin)): 168 | input_origin = inputs_origin[i][0] 169 | input = inputs[i][0] 170 | mask_num, valid_length = 0, 0 171 | start_indexes = [] 172 | for index, x in enumerate(input_origin): 173 | if int(x) not in self.tokenizer.all_special_ids: 174 | valid_length += 1 175 | start_indexes.append(index) 176 | labels[i][index] = -99 177 | self.random_state.shuffle(start_indexes) 178 | if valid_length > 0: 179 | while mask_num / valid_length < self.mlm_probability: 180 | start_index = start_indexes.pop() 181 | span_length = 1e9 182 | while span_length > 10: span_length = np.random.geometric(0.2) 183 | for j in range(start_index, min(start_index + span_length, len(input_origin))): 184 | if labels[i][j] != -99: continue 185 | labels[i][j] = input_origin[j].clone() 186 | rand = self.random_state.random() 187 | if rand < 0.8: 188 | input[j] = mask_id 189 | elif rand < 0.9: 190 | input[j] = self.random_state.randint(0, self.tokenizer.vocab_size - 1) 191 | mask_num += 1 192 | if mask_num / valid_length >= self.mlm_probability: 193 | break 194 | labels[i] = torch.masked_fill(labels[i], labels[i] < 0, -100) 195 | return inputs, labels 196 | 197 | 198 | @dataclass 199 | class MultiProcessDataLoader: 200 | dataset: IterableDataset 201 | batch_size: int 202 | collate_fn: Callable 203 | local_rank: int 204 | world_size: int 205 | global_end: Any 206 | blocking: bool = False 207 | drop_last: bool = True 208 | 209 | def _start(self): 210 | self.local_end = False 211 | self.aval_count = 0 212 | self.outputs = Queue(10) 213 | self.pool = ThreadPoolExecutor(1) 214 | self.pool.submit(self._produce) 215 | 216 | def _produce(self): 217 | for batch in self._generate_batch(): 218 | self.outputs.put(batch) 219 | self.aval_count += 1 220 | self.pool.shutdown(wait=False) 221 | raise 222 | 223 | def _generate_batch(self): 224 | batch = [] 225 | for i, sample in enumerate(self.dataset): 226 | if i % self.world_size != self.local_rank: continue 227 | batch.append(sample) 228 | if len(batch) >= self.batch_size: 229 | yield self.collate_fn(batch[:self.batch_size]) 230 | batch = batch[self.batch_size:] 231 | else: 232 | if len(batch) > 0 and not self.drop_last: 233 | yield self.collate_fn(batch) 234 | batch = [] 235 | self.local_end = True 236 | 237 | def __iter__(self): 238 | if self.blocking: 239 | return self._generate_batch() 240 | self._start() 241 | return self 242 | 243 | def __next__(self): 244 | dist.barrier() 245 | while self.aval_count == 0: 246 | if self.local_end or self.global_end.value: 247 | self.global_end.value = True 248 | break 249 | dist.barrier() 250 | if self.global_end.value: 251 | raise StopIteration 252 | next_batch = self.outputs.get() 253 | self.aval_count -= 1 254 | return next_batch 255 | 256 | 257 | @dataclass 258 | class SingleProcessDataLoader: 259 | dataset: IterableDataset 260 | batch_size: int 261 | collate_fn: Callable 262 | blocking: bool = False 263 | drop_last: bool = 
True 264 | 265 | def _start(self): 266 | self.local_end = False 267 | self.aval_count = 0 268 | self.outputs = Queue(10) 269 | self.pool = ThreadPoolExecutor(1) 270 | self.pool.submit(self._produce) 271 | 272 | def _produce(self): 273 | for batch in self._generate_batch(): 274 | self.outputs.put(batch) 275 | self.aval_count += 1 276 | self.pool.shutdown(wait=False) 277 | raise 278 | 279 | def _generate_batch(self): 280 | batch = [] 281 | for i, sample in enumerate(self.dataset): 282 | batch.append(sample) 283 | if len(batch) >= self.batch_size: 284 | yield self.collate_fn(batch[:self.batch_size]) 285 | batch = batch[self.batch_size:] 286 | else: 287 | if len(batch) > 0 and not self.drop_last: 288 | yield self.collate_fn(batch) 289 | batch = [] 290 | self.local_end = True 291 | 292 | def __iter__(self): 293 | if self.blocking: 294 | return self._generate_batch() 295 | self._start() 296 | return self 297 | 298 | def __next__(self): 299 | while self.aval_count == 0: 300 | if self.local_end: raise StopIteration 301 | next_batch = self.outputs.get() 302 | self.aval_count -= 1 303 | return next_batch 304 | -------------------------------------------------------------------------------- /src/models/modeling_graphformers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from src.models.tnlrv3.convert_state_dict import get_checkpoint_from_transformer_cache, state_dict_convert 9 | from src.models.tnlrv3.modeling import TuringNLRv3PreTrainedModel, logger, BertSelfAttention, BertLayer, WEIGHTS_NAME, \ 10 | BertEmbeddings, relative_position_bucket 11 | from src.utils import roc_auc_score, mrr_score, ndcg_score 12 | 13 | 14 | class GraphTuringNLRPreTrainedModel(TuringNLRv3PreTrainedModel): 15 | @classmethod 16 | def from_pretrained( 17 | cls, pretrained_model_name_or_path, reuse_position_embedding=None, 18 | replace_prefix=None, *model_args, **kwargs, 19 | ): 20 | model_type = kwargs.pop('model_type', 'tnlrv3') 21 | if model_type is not None and "state_dict" not in kwargs: 22 | if model_type in cls.supported_convert_pretrained_model_archive_map: 23 | pretrained_model_archive_map = cls.supported_convert_pretrained_model_archive_map[model_type] 24 | if pretrained_model_name_or_path in pretrained_model_archive_map: 25 | state_dict = get_checkpoint_from_transformer_cache( 26 | archive_file=pretrained_model_archive_map[pretrained_model_name_or_path], 27 | pretrained_model_name_or_path=pretrained_model_name_or_path, 28 | pretrained_model_archive_map=pretrained_model_archive_map, 29 | cache_dir=kwargs.get("cache_dir", None), force_download=kwargs.get("force_download", None), 30 | proxies=kwargs.get("proxies", None), resume_download=kwargs.get("resume_download", None), 31 | ) 32 | state_dict = state_dict_convert[model_type](state_dict) 33 | kwargs["state_dict"] = state_dict 34 | logger.info("Load HF ckpts") 35 | elif os.path.isfile(pretrained_model_name_or_path): 36 | state_dict = torch.load(pretrained_model_name_or_path, map_location='cpu') 37 | kwargs["state_dict"] = state_dict_convert[model_type](state_dict) 38 | logger.info("Load local ckpts") 39 | elif os.path.isdir(pretrained_model_name_or_path): 40 | state_dict = torch.load(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME), 41 | map_location='cpu') 42 | kwargs["state_dict"] = state_dict_convert[model_type](state_dict) 43 | logger.info("Load local ckpts") 44 | else: 45 | raise 
RuntimeError("Not fined the pre-trained checkpoint !") 46 | 47 | if kwargs["state_dict"] is None: 48 | logger.info("TNLRv3 does't support the model !") 49 | raise NotImplementedError() 50 | 51 | config = kwargs["config"] 52 | state_dict = kwargs["state_dict"] 53 | # initialize new position embeddings (From Microsoft/UniLM) 54 | _k = 'bert.embeddings.position_embeddings.weight' 55 | if _k in state_dict: 56 | if config.max_position_embeddings > state_dict[_k].shape[0]: 57 | logger.info("Resize > position embeddings !") 58 | old_vocab_size = state_dict[_k].shape[0] 59 | new_postion_embedding = state_dict[_k].data.new_tensor(torch.ones( 60 | size=(config.max_position_embeddings, state_dict[_k].shape[1])), dtype=torch.float) 61 | new_postion_embedding = nn.Parameter(data=new_postion_embedding, requires_grad=True) 62 | new_postion_embedding.data.normal_(mean=0.0, std=config.initializer_range) 63 | max_range = config.max_position_embeddings if reuse_position_embedding else old_vocab_size 64 | shift = 0 65 | while shift < max_range: 66 | delta = min(old_vocab_size, max_range - shift) 67 | new_postion_embedding.data[shift: shift + delta, :] = state_dict[_k][:delta, :] 68 | logger.info(" CP [%d ~ %d] into [%d ~ %d] " % (0, delta, shift, shift + delta)) 69 | shift += delta 70 | state_dict[_k] = new_postion_embedding.data 71 | del new_postion_embedding 72 | elif config.max_position_embeddings < state_dict[_k].shape[0]: 73 | logger.info("Resize < position embeddings !") 74 | old_vocab_size = state_dict[_k].shape[0] 75 | new_postion_embedding = state_dict[_k].data.new_tensor(torch.ones( 76 | size=(config.max_position_embeddings, state_dict[_k].shape[1])), dtype=torch.float) 77 | new_postion_embedding = nn.Parameter(data=new_postion_embedding, requires_grad=True) 78 | new_postion_embedding.data.normal_(mean=0.0, std=config.initializer_range) 79 | new_postion_embedding.data.copy_(state_dict[_k][:config.max_position_embeddings, :]) 80 | state_dict[_k] = new_postion_embedding.data 81 | del new_postion_embedding 82 | 83 | # initialize new rel_pos weight 84 | _k = 'bert.rel_pos_bias.weight' 85 | if _k in state_dict and state_dict[_k].shape[1] != (config.rel_pos_bins + 2): 86 | logger.info( 87 | f"rel_pos_bias.weight.shape[1]:{state_dict[_k].shape[1]} != config.bus_num+config.rel_pos_bins:{config.rel_pos_bins + 2}") 88 | old_rel_pos_bias = state_dict[_k] 89 | new_rel_pos_bias = torch.cat( 90 | [old_rel_pos_bias, old_rel_pos_bias[:, -1:].expand(old_rel_pos_bias.size(0), 2)], -1) 91 | new_rel_pos_bias = nn.Parameter(data=new_rel_pos_bias, requires_grad=True) 92 | state_dict[_k] = new_rel_pos_bias.data 93 | del new_rel_pos_bias 94 | 95 | if replace_prefix is not None: 96 | new_state_dict = {} 97 | for key in state_dict: 98 | if key.startswith(replace_prefix): 99 | new_state_dict[key[len(replace_prefix):]] = state_dict[key] 100 | else: 101 | new_state_dict[key] = state_dict[key] 102 | kwargs["state_dict"] = new_state_dict 103 | del state_dict 104 | 105 | return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) 106 | 107 | 108 | class GraphAggregation(BertSelfAttention): 109 | def __init__(self, config): 110 | super(GraphAggregation, self).__init__(config) 111 | self.output_attentions = False 112 | 113 | def forward(self, hidden_states, attention_mask=None, rel_pos=None): 114 | query = self.query(hidden_states[:, :1]) # B 1 D 115 | key = self.key(hidden_states) 116 | value = self.value(hidden_states) 117 | station_embed = self.multi_head_attention(query=query, 118 | key=key, 119 | 
value=value, 120 | attention_mask=attention_mask, 121 | rel_pos=rel_pos)[0] # B 1 D 122 | station_embed = station_embed.squeeze(1) 123 | 124 | return station_embed 125 | 126 | 127 | class GraphBertEncoder(nn.Module): 128 | def __init__(self, config): 129 | super(GraphBertEncoder, self).__init__() 130 | 131 | self.output_attentions = config.output_attentions 132 | self.output_hidden_states = config.output_hidden_states 133 | self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) 134 | 135 | self.graph_attention = GraphAggregation(config=config) 136 | 137 | def forward(self, 138 | hidden_states, 139 | attention_mask, 140 | node_mask=None, 141 | node_rel_pos=None, 142 | rel_pos=None): 143 | 144 | all_hidden_states = () 145 | all_attentions = () 146 | 147 | all_nodes_num, seq_length, emb_dim = hidden_states.shape 148 | batch_size, _, _, subgraph_node_num = node_mask.shape 149 | 150 | for i, layer_module in enumerate(self.layer): 151 | if self.output_hidden_states: 152 | all_hidden_states = all_hidden_states + (hidden_states,) 153 | 154 | if i > 0: 155 | 156 | hidden_states = hidden_states.view(batch_size, subgraph_node_num, seq_length, emb_dim) # B SN L D 157 | cls_emb = hidden_states[:, :, 1].clone() # B SN D 158 | station_emb = self.graph_attention(hidden_states=cls_emb, attention_mask=node_mask, 159 | rel_pos=node_rel_pos) # B D 160 | 161 | # update the station in the query/key 162 | hidden_states[:, 0, 0] = station_emb 163 | hidden_states = hidden_states.view(all_nodes_num, seq_length, emb_dim) 164 | 165 | layer_outputs = layer_module(hidden_states, attention_mask=attention_mask, rel_pos=rel_pos) 166 | 167 | else: 168 | temp_attention_mask = attention_mask.clone() 169 | temp_attention_mask[::subgraph_node_num, :, :, 0] = -10000.0 170 | layer_outputs = layer_module(hidden_states, attention_mask=temp_attention_mask, rel_pos=rel_pos) 171 | 172 | hidden_states = layer_outputs[0] 173 | 174 | if self.output_attentions: 175 | all_attentions = all_attentions + (layer_outputs[1],) 176 | 177 | # Add last layer 178 | if self.output_hidden_states: 179 | all_hidden_states = all_hidden_states + (hidden_states,) 180 | 181 | outputs = (hidden_states,) 182 | if self.output_hidden_states: 183 | outputs = outputs + (all_hidden_states,) 184 | if self.output_attentions: 185 | outputs = outputs + (all_attentions,) 186 | 187 | return outputs # last-layer hidden state, (all hidden states), (all attentions) 188 | 189 | 190 | class GraphFormers(TuringNLRv3PreTrainedModel): 191 | def __init__(self, config): 192 | super(GraphFormers, self).__init__(config=config) 193 | self.config = config 194 | self.embeddings = BertEmbeddings(config=config) 195 | self.encoder = GraphBertEncoder(config=config) 196 | 197 | if self.config.rel_pos_bins > 0: 198 | self.rel_pos_bias = nn.Linear(self.config.rel_pos_bins + 2, 199 | config.num_attention_heads, 200 | bias=False) 201 | else: 202 | self.rel_pos_bias = None 203 | 204 | def forward(self, 205 | input_ids, 206 | attention_mask, 207 | neighbor_mask=None): 208 | all_nodes_num, seq_length = input_ids.shape 209 | batch_size, subgraph_node_num = neighbor_mask.shape 210 | 211 | embedding_output, position_ids = self.embeddings(input_ids=input_ids) 212 | 213 | # Add station attention mask 214 | station_mask = torch.zeros((all_nodes_num, 1), dtype=attention_mask.dtype, device=attention_mask.device) 215 | attention_mask = torch.cat([station_mask, attention_mask], dim=-1) # N 1+L 216 | attention_mask[::(subgraph_node_num), 0] = 1.0 # only use the station for 
main nodes 217 | 218 | node_mask = (1.0 - neighbor_mask[:, None, None, :]) * -10000.0 219 | extended_attention_mask = (1.0 - attention_mask[:, None, None, :]) * -10000.0 220 | 221 | if self.config.rel_pos_bins > 0: 222 | rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) 223 | rel_pos = relative_position_bucket(rel_pos_mat, num_buckets=self.config.rel_pos_bins, 224 | max_distance=self.config.max_rel_pos) 225 | 226 | # rel_pos: (N,L,L) -> (N,1+L,L) 227 | temp_pos = torch.zeros(all_nodes_num, 1, seq_length, dtype=rel_pos.dtype, device=rel_pos.device) 228 | rel_pos = torch.cat([temp_pos, rel_pos], dim=1) 229 | # rel_pos: (N,1+L,L) -> (N,1+L,1+L) 230 | station_relpos = torch.full((all_nodes_num, seq_length + 1, 1), self.config.rel_pos_bins, 231 | dtype=rel_pos.dtype, device=rel_pos.device) 232 | rel_pos = torch.cat([station_relpos, rel_pos], dim=-1) 233 | 234 | # node_rel_pos:(B:batch_size, Head_num, neighbor_num+1) 235 | node_pos = self.config.rel_pos_bins + 1 236 | node_rel_pos = torch.full((batch_size, subgraph_node_num), node_pos, dtype=rel_pos.dtype, 237 | device=rel_pos.device) 238 | node_rel_pos[:, 0] = 0 239 | node_rel_pos = F.one_hot(node_rel_pos, 240 | num_classes=self.config.rel_pos_bins + 2).type_as( 241 | embedding_output) 242 | node_rel_pos = self.rel_pos_bias(node_rel_pos).permute(0, 2, 1) # B head_num, neighbor_num 243 | node_rel_pos = node_rel_pos.unsqueeze(2) # B head_num 1 neighbor_num 244 | 245 | # rel_pos: (N,Head_num,1+L,1+L) 246 | rel_pos = F.one_hot(rel_pos, num_classes=self.config.rel_pos_bins + 2).type_as( 247 | embedding_output) 248 | rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2) 249 | 250 | else: 251 | node_rel_pos = None 252 | rel_pos = None 253 | 254 | # Add station_placeholder 255 | station_placeholder = torch.zeros(all_nodes_num, 1, embedding_output.size(-1)).type( 256 | embedding_output.dtype).to(embedding_output.device) 257 | embedding_output = torch.cat([station_placeholder, embedding_output], dim=1) # N 1+L D 258 | 259 | encoder_outputs = self.encoder( 260 | embedding_output, 261 | attention_mask=extended_attention_mask, 262 | node_mask=node_mask, 263 | node_rel_pos=node_rel_pos, 264 | rel_pos=rel_pos) 265 | 266 | return encoder_outputs 267 | 268 | 269 | class GraphFormersForNeighborPredict(GraphTuringNLRPreTrainedModel): 270 | def __init__(self, config): 271 | super().__init__(config) 272 | self.bert = GraphFormers(config) 273 | self.init_weights() 274 | 275 | def infer(self, input_ids_node_and_neighbors_batch, attention_mask_node_and_neighbors_batch, 276 | mask_node_and_neighbors_batch): 277 | B, N, L = input_ids_node_and_neighbors_batch.shape 278 | D = self.config.hidden_size 279 | input_ids = input_ids_node_and_neighbors_batch.view(B * N, L) 280 | attention_mask = attention_mask_node_and_neighbors_batch.view(B * N, L) 281 | hidden_states = self.bert(input_ids, attention_mask, mask_node_and_neighbors_batch) 282 | last_hidden_states = hidden_states[0] 283 | cls_embeddings = last_hidden_states[:, 1].view(B, N, D) # [B,N,D] 284 | node_embeddings = cls_embeddings[:, 0, :] # [B,D] 285 | return node_embeddings 286 | 287 | def test(self, input_ids_query_and_neighbors_batch, attention_mask_query_and_neighbors_batch, 288 | mask_query_and_neighbors_batch, \ 289 | input_ids_key_and_neighbors_batch, attention_mask_key_and_neighbors_batch, mask_key_and_neighbors_batch, 290 | **kwargs): 291 | query_embeddings = self.infer(input_ids_query_and_neighbors_batch, attention_mask_query_and_neighbors_batch, 292 | mask_query_and_neighbors_batch) 
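# The key side below goes through the same infer() path as the query side above. Evaluation then
# uses in-batch negatives: scores[i][j] is the dot product between query i and key j, the ground-truth
# key for query i sits on the diagonal (labels = arange(batch_size)), and accuracy/AUC/MRR/nDCG are
# computed row by row over this score matrix.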
293 | key_embeddings = self.infer(input_ids_key_and_neighbors_batch, attention_mask_key_and_neighbors_batch, 294 | mask_key_and_neighbors_batch) 295 | scores = torch.matmul(query_embeddings, key_embeddings.transpose(0, 1)) 296 | labels = torch.arange(start=0, end=scores.shape[0], dtype=torch.long, device=scores.device) 297 | 298 | predictions = torch.argmax(scores, dim=-1) 299 | acc = (torch.sum((predictions == labels)) / labels.shape[0]).item() 300 | 301 | scores = scores.cpu().numpy() 302 | labels = F.one_hot(labels).cpu().numpy() 303 | auc_all = [roc_auc_score(labels[i], scores[i]) for i in range(labels.shape[0])] 304 | auc = np.mean(auc_all) 305 | mrr_all = [mrr_score(labels[i], scores[i]) for i in range(labels.shape[0])] 306 | mrr = np.mean(mrr_all) 307 | ndcg_all = [ndcg_score(labels[i], scores[i], labels.shape[1]) for i in range(labels.shape[0])] 308 | ndcg = np.mean(ndcg_all) 309 | 310 | return { 311 | "main": acc, 312 | "acc": acc, 313 | "auc": auc, 314 | "mrr": mrr, 315 | "ndcg": ndcg 316 | } 317 | 318 | def forward(self, input_ids_query_and_neighbors_batch, attention_mask_query_and_neighbors_batch, 319 | mask_query_and_neighbors_batch, \ 320 | input_ids_key_and_neighbors_batch, attention_mask_key_and_neighbors_batch, mask_key_and_neighbors_batch, 321 | **kwargs): 322 | query_embeddings = self.infer(input_ids_query_and_neighbors_batch, attention_mask_query_and_neighbors_batch, 323 | mask_query_and_neighbors_batch) 324 | key_embeddings = self.infer(input_ids_key_and_neighbors_batch, attention_mask_key_and_neighbors_batch, 325 | mask_key_and_neighbors_batch) 326 | score = torch.matmul(query_embeddings, key_embeddings.transpose(0, 1)) 327 | labels = torch.arange(start=0, end=score.shape[0], dtype=torch.long, device=score.device) 328 | loss = F.cross_entropy(score, labels) 329 | return loss 330 | -------------------------------------------------------------------------------- /src/models/modeling_graphsage.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from src.models.tnlrv3.modeling import TuringNLRv3PreTrainedModel, TuringNLRv3Model 7 | from src.utils import roc_auc_score, mrr_score, ndcg_score 8 | 9 | 10 | class GraphSageMaxForNeighborPredict(TuringNLRv3PreTrainedModel): 11 | def __init__(self, config): 12 | super().__init__(config) 13 | self.bert = TuringNLRv3Model(config) 14 | self.graph_transform = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) 15 | self.pooling_transform = nn.Linear(config.hidden_size, config.hidden_size) 16 | self.init_weights() 17 | 18 | def aggregation(self, neighbor_embed, neighbor_mask): 19 | neighbor_embed = F.relu(self.pooling_transform(neighbor_embed)) 20 | neighbor_embed = neighbor_embed.masked_fill(neighbor_mask.unsqueeze(2) == 0, 0) 21 | return torch.max(neighbor_embed, dim=-2)[0] 22 | 23 | def graphsage(self, node_embed, node_mask): 24 | neighbor_embed = node_embed[:, 1:] # B N D 25 | neighbor_mask = node_mask[:, 1:] # B N 26 | center_embed = node_embed[:, 0] # B D 27 | neighbor_embed = self.aggregation(neighbor_embed, neighbor_mask) # B D 28 | main_embed = torch.cat([center_embed, neighbor_embed], dim=-1) # B 2D 29 | main_embed = self.graph_transform(main_embed) 30 | main_embed = F.relu(main_embed) 31 | return main_embed 32 | 33 | def infer(self, input_ids_node_and_neighbors_batch, attention_mask_node_and_neighbors_batch, 34 | mask_node_and_neighbors_batch): 35 | B, N, L = 
input_ids_node_and_neighbors_batch.shape 36 | D = self.config.hidden_size 37 | input_ids = input_ids_node_and_neighbors_batch.view(B * N, L) 38 | attention_mask = attention_mask_node_and_neighbors_batch.view(B * N, L) 39 | hidden_states = self.bert(input_ids, attention_mask) 40 | last_hidden_states = hidden_states[0] 41 | cls_embeddings = last_hidden_states[:, 0].view(B, N, D) # [B,N,D] 42 | node_embeddings = self.graphsage(cls_embeddings, mask_node_and_neighbors_batch) 43 | return node_embeddings 44 | 45 | def test(self, input_ids_query_and_neighbors_batch, attention_mask_query_and_neighbors_batch, 46 | mask_query_and_neighbors_batch, \ 47 | input_ids_key_and_neighbors_batch, attention_mask_key_and_neighbors_batch, mask_key_and_neighbors_batch, 48 | **kwargs): 49 | query_embeddings = self.infer(input_ids_query_and_neighbors_batch, attention_mask_query_and_neighbors_batch, 50 | mask_query_and_neighbors_batch) 51 | key_embeddings = self.infer(input_ids_key_and_neighbors_batch, attention_mask_key_and_neighbors_batch, 52 | mask_key_and_neighbors_batch) 53 | scores = torch.matmul(query_embeddings, key_embeddings.transpose(0, 1)) 54 | labels = torch.arange(start=0, end=scores.shape[0], dtype=torch.long, device=scores.device) 55 | 56 | predictions = torch.argmax(scores, dim=-1) 57 | acc = (torch.sum((predictions == labels)) / labels.shape[0]).item() 58 | 59 | scores = scores.cpu().numpy() 60 | labels = F.one_hot(labels).cpu().numpy() 61 | auc_all = [roc_auc_score(labels[i], scores[i]) for i in range(labels.shape[0])] 62 | auc = np.mean(auc_all) 63 | mrr_all = [mrr_score(labels[i], scores[i]) for i in range(labels.shape[0])] 64 | mrr = np.mean(mrr_all) 65 | ndcg_all = [ndcg_score(labels[i], scores[i], labels.shape[1]) for i in range(labels.shape[0])] 66 | ndcg = np.mean(ndcg_all) 67 | 68 | return { 69 | "main": acc, 70 | "acc": acc, 71 | "auc": auc, 72 | "mrr": mrr, 73 | "ndcg": ndcg 74 | } 75 | 76 | def forward(self, input_ids_query_and_neighbors_batch, attention_mask_query_and_neighbors_batch, 77 | mask_query_and_neighbors_batch, \ 78 | input_ids_key_and_neighbors_batch, attention_mask_key_and_neighbors_batch, mask_key_and_neighbors_batch, 79 | **kwargs): 80 | query_embeddings = self.infer(input_ids_query_and_neighbors_batch, attention_mask_query_and_neighbors_batch, 81 | mask_query_and_neighbors_batch) 82 | key_embeddings = self.infer(input_ids_key_and_neighbors_batch, attention_mask_key_and_neighbors_batch, 83 | mask_key_and_neighbors_batch) 84 | score = torch.matmul(query_embeddings, key_embeddings.transpose(0, 1)) 85 | labels = torch.arange(start=0, end=score.shape[0], dtype=torch.long, device=score.device) 86 | loss = F.cross_entropy(score, labels) 87 | return loss 88 | -------------------------------------------------------------------------------- /src/models/tnlrv3/config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import logging 4 | from transformers import BertConfig 5 | from .configuration_tnlrv3 import TuringNLRv3Config 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class TuringNLRv3ForSeq2SeqConfig(BertConfig): 11 | def __init__(self, label_smoothing=0.1, source_type_id=0, target_type_id=1, 12 | rel_pos_bins=0, max_rel_pos=0, fix_word_embedding=False, **kwargs): 13 | super(TuringNLRv3ForSeq2SeqConfig, self).__init__(**kwargs) 14 | self.label_smoothing = label_smoothing 15 | self.source_type_id = source_type_id 16 | self.target_type_id = 
target_type_id 17 | self.max_rel_pos = max_rel_pos 18 | self.rel_pos_bins = rel_pos_bins 19 | self.fix_word_embedding = fix_word_embedding 20 | 21 | @classmethod 22 | def from_exist_config(cls, config, label_smoothing=0.1, max_position_embeddings=None, fix_word_embedding=False): 23 | required_keys = [ 24 | "vocab_size", "hidden_size", "num_hidden_layers", "num_attention_heads", 25 | "hidden_act", "intermediate_size", "hidden_dropout_prob", "attention_probs_dropout_prob", 26 | "max_position_embeddings", "type_vocab_size", "initializer_range", "layer_norm_eps", 27 | ] 28 | 29 | kwargs = {} 30 | for key in required_keys: 31 | assert hasattr(config, key) 32 | kwargs[key] = getattr(config, key) 33 | 34 | kwargs["vocab_size_or_config_json_file"] = kwargs["vocab_size"] 35 | 36 | additional_keys = [ 37 | "source_type_id", "target_type_id", "rel_pos_bins", "max_rel_pos", 38 | ] 39 | for key in additional_keys: 40 | if hasattr(config, key): 41 | kwargs[key] = getattr(config, key) 42 | 43 | if max_position_embeddings is not None and max_position_embeddings > config.max_position_embeddings: 44 | kwargs["max_position_embeddings"] = max_position_embeddings 45 | logger.info(" ** Change max position embeddings to %d ** " % max_position_embeddings) 46 | 47 | return cls(label_smoothing=label_smoothing, fix_word_embedding=fix_word_embedding, **kwargs) 48 | -------------------------------------------------------------------------------- /src/models/tnlrv3/configuration_tnlrv3.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ TuringNLRv3 model configuration """ 3 | 4 | from __future__ import absolute_import, division, print_function, unicode_literals 5 | 6 | import json 7 | import logging 8 | import sys 9 | from io import open 10 | 11 | from transformers.configuration_utils import PretrainedConfig 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | TuringNLRv3_PRETRAINED_CONFIG_ARCHIVE_MAP = { 16 | } 17 | 18 | 19 | class TuringNLRv3Config(PretrainedConfig): 20 | r""" 21 | :class:`~transformers.TuringNLRv3Config` is the configuration class to store the configuration of a 22 | `TuringNLRv3Model`. 23 | Arguments: 24 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `TuringNLRv3Model`. 25 | hidden_size: Size of the encoder layers and the pooler layer. 26 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 27 | num_attention_heads: Number of attention heads for each attention layer in 28 | the Transformer encoder. 29 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 30 | layer in the Transformer encoder. 31 | hidden_act: The non-linear activation function (function or string) in the 32 | encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported. 33 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 34 | layers in the embeddings, encoder, and pooler. 35 | attention_probs_dropout_prob: The dropout ratio for the attention 36 | probabilities. 37 | max_position_embeddings: The maximum sequence length that this model might 38 | ever be used with. Typically set this to something large just in case 39 | (e.g., 512 or 1024 or 2048). 40 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 41 | `TuringNLRv3Model`. 42 | initializer_range: The sttdev of the truncated_normal_initializer for 43 | initializing all weight matrices. 44 | layer_norm_eps: The epsilon used by LayerNorm. 
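source_type_id: Token type id assigned to source segments (defaults to 0).
target_type_id: Token type id assigned to target segments (defaults to 1).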
45 | """ 46 | pretrained_config_archive_map = TuringNLRv3_PRETRAINED_CONFIG_ARCHIVE_MAP 47 | 48 | def __init__(self, 49 | vocab_size=28996, 50 | hidden_size=768, 51 | num_hidden_layers=12, 52 | num_attention_heads=12, 53 | intermediate_size=3072, 54 | hidden_act="gelu", 55 | hidden_dropout_prob=0.1, 56 | attention_probs_dropout_prob=0.1, 57 | max_position_embeddings=512, 58 | type_vocab_size=6, 59 | initializer_range=0.02, 60 | layer_norm_eps=1e-12, 61 | source_type_id=0, 62 | target_type_id=1, 63 | **kwargs): 64 | super(TuringNLRv3Config, self).__init__(**kwargs) 65 | if isinstance(vocab_size, str) or (sys.version_info[0] == 2 66 | and isinstance(vocab_size, unicode)): 67 | with open(vocab_size, "r", encoding='utf-8') as reader: 68 | json_config = json.loads(reader.read()) 69 | for key, value in json_config.items(): 70 | self.__dict__[key] = value 71 | elif isinstance(vocab_size, int): 72 | self.vocab_size = vocab_size 73 | self.hidden_size = hidden_size 74 | self.num_hidden_layers = num_hidden_layers 75 | self.num_attention_heads = num_attention_heads 76 | self.hidden_act = hidden_act 77 | self.intermediate_size = intermediate_size 78 | self.hidden_dropout_prob = hidden_dropout_prob 79 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 80 | self.max_position_embeddings = max_position_embeddings 81 | self.type_vocab_size = type_vocab_size 82 | self.initializer_range = initializer_range 83 | self.layer_norm_eps = layer_norm_eps 84 | self.source_type_id = source_type_id 85 | self.target_type_id = target_type_id 86 | else: 87 | raise ValueError("First argument must be either a vocabulary size (int)" 88 | " or the path to a pretrained model config file (str)") 89 | -------------------------------------------------------------------------------- /src/models/tnlrv3/convert_state_dict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | from transformers.modeling_utils import cached_path, WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def get_checkpoint_from_transformer_cache( 10 | archive_file, pretrained_model_name_or_path, pretrained_model_archive_map, 11 | cache_dir, force_download, proxies, resume_download, 12 | ): 13 | try: 14 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, 15 | proxies=proxies, resume_download=resume_download) 16 | except EnvironmentError: 17 | if pretrained_model_name_or_path in pretrained_model_archive_map: 18 | msg = "Couldn't reach server at '{}' to download pretrained weights.".format( 19 | archive_file) 20 | else: 21 | msg = "Model name '{}' was not found in model name list ({}). 
" \ 22 | "We assumed '{}' was a path or url to model weight files named one of {} but " \ 23 | "couldn't find any such file at this path or url.".format( 24 | pretrained_model_name_or_path, 25 | ', '.join(pretrained_model_archive_map.keys()), 26 | archive_file, 27 | [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME]) 28 | raise EnvironmentError(msg) 29 | 30 | if resolved_archive_file == archive_file: 31 | logger.info("loading weights file {}".format(archive_file)) 32 | else: 33 | logger.info("loading weights file {} from cache at {}".format( 34 | archive_file, resolved_archive_file)) 35 | 36 | return torch.load(resolved_archive_file, map_location='cpu') 37 | 38 | 39 | def load_model(state_dict): 40 | new_state_dict = {} 41 | 42 | for key in state_dict: 43 | value = state_dict[key] 44 | if key.endswith("attention.self.q_bias"): 45 | new_state_dict[key.replace("attention.self.q_bias", "attention.self.query.bias")] = value.view(-1) 46 | elif key.endswith("attention.self.v_bias"): 47 | new_state_dict[key.replace("attention.self.v_bias", "attention.self.value.bias")] = value.view(-1) 48 | new_state_dict[key.replace("attention.self.v_bias", "attention.self.key.bias")] = torch.zeros_like(value.view(-1)) 49 | elif key.endswith("attention.self.qkv_linear.weight"): 50 | l, _ = value.size() 51 | assert l % 3 == 0 52 | l = l // 3 53 | q, k, v = torch.split(value, split_size_or_sections=(l, l, l), dim=0) 54 | new_state_dict[key.replace("attention.self.qkv_linear.weight", "attention.self.query.weight")] = q 55 | new_state_dict[key.replace("attention.self.qkv_linear.weight", "attention.self.key.weight")] = k 56 | new_state_dict[key.replace("attention.self.qkv_linear.weight", "attention.self.value.weight")] = v 57 | elif key == "bert.encoder.rel_pos_bias.weight": 58 | new_state_dict["bert.rel_pos_bias.weight"] = value 59 | else: 60 | new_state_dict[key] = value 61 | 62 | del state_dict 63 | 64 | return new_state_dict 65 | 66 | 67 | state_dict_convert = { 68 | 'tnlrv3': load_model, 69 | } 70 | -------------------------------------------------------------------------------- /src/models/tnlrv3/modeling.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import logging 4 | import math 5 | import os 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn.modules.loss import _Loss 10 | import torch.nn.functional as F 11 | 12 | from transformers.modeling_bert import \ 13 | BertPreTrainedModel, BertSelfOutput, BertIntermediate, \ 14 | BertOutput, BertPredictionHeadTransform, BertPooler 15 | from transformers.file_utils import WEIGHTS_NAME 16 | 17 | from .config import TuringNLRv3ForSeq2SeqConfig 18 | from .convert_state_dict import get_checkpoint_from_transformer_cache, state_dict_convert 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | BertLayerNorm = torch.nn.LayerNorm 23 | 24 | TuringNLRv3_PRETRAINED_MODEL_ARCHIVE_MAP = { 25 | } 26 | 27 | 28 | class TuringNLRv3PreTrainedModel(BertPreTrainedModel): 29 | """ An abstract class to handle weights initialization and 30 | a simple interface for dowloading and loading pretrained models. 
31 | """ 32 | config_class = TuringNLRv3ForSeq2SeqConfig 33 | supported_convert_pretrained_model_archive_map = { 34 | "tnlrv3": TuringNLRv3_PRETRAINED_MODEL_ARCHIVE_MAP, 35 | } 36 | base_model_prefix = "TuringNLRv3_for_seq2seq" 37 | pretrained_model_archive_map = { 38 | **TuringNLRv3_PRETRAINED_MODEL_ARCHIVE_MAP, 39 | } 40 | 41 | def _init_weights(self, module): 42 | """ Initialize the weights """ 43 | if isinstance(module, (nn.Linear, nn.Embedding)): 44 | # Slightly different from the TF version which uses truncated_normal for initialization 45 | # cf https://github.com/pytorch/pytorch/pull/5617 46 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 47 | elif isinstance(module, BertLayerNorm): 48 | module.bias.data.zero_() 49 | module.weight.data.fill_(1.0) 50 | if isinstance(module, nn.Linear) and module.bias is not None: 51 | module.bias.data.zero_() 52 | 53 | @classmethod 54 | def from_pretrained( 55 | cls, pretrained_model_name_or_path, reuse_position_embedding=None, 56 | replace_prefix=None, *model_args, **kwargs, 57 | ): 58 | model_type = kwargs.pop('model_type', 'tnlrv3') 59 | if model_type is not None and "state_dict" not in kwargs: 60 | if model_type in cls.supported_convert_pretrained_model_archive_map: 61 | pretrained_model_archive_map = cls.supported_convert_pretrained_model_archive_map[model_type] 62 | if pretrained_model_name_or_path in pretrained_model_archive_map: 63 | state_dict = get_checkpoint_from_transformer_cache( 64 | archive_file=pretrained_model_archive_map[pretrained_model_name_or_path], 65 | pretrained_model_name_or_path=pretrained_model_name_or_path, 66 | pretrained_model_archive_map=pretrained_model_archive_map, 67 | cache_dir=kwargs.get("cache_dir", None), force_download=kwargs.get("force_download", None), 68 | proxies=kwargs.get("proxies", None), resume_download=kwargs.get("resume_download", None), 69 | ) 70 | state_dict = state_dict_convert[model_type](state_dict) 71 | kwargs["state_dict"] = state_dict 72 | logger.info("Load HF ckpts") 73 | elif os.path.isfile(pretrained_model_name_or_path): 74 | state_dict = torch.load(pretrained_model_name_or_path, map_location='cpu') 75 | kwargs["state_dict"] = state_dict_convert[model_type](state_dict) 76 | logger.info("Load local ckpts") 77 | elif os.path.isdir(pretrained_model_name_or_path): 78 | state_dict = torch.load(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME), map_location='cpu') 79 | kwargs["state_dict"] = state_dict_convert[model_type](state_dict) 80 | logger.info("Load local ckpts") 81 | else: 82 | raise RuntimeError("Not fined the pre-trained checkpoint !") 83 | 84 | if kwargs["state_dict"] is None: 85 | logger.info("TNLRv3 does't support the model !") 86 | raise NotImplementedError() 87 | 88 | config = kwargs["config"] 89 | state_dict = kwargs["state_dict"] 90 | # initialize new position embeddings (From Microsoft/UniLM) 91 | _k = 'bert.embeddings.position_embeddings.weight' 92 | if _k in state_dict: 93 | if config.max_position_embeddings > state_dict[_k].shape[0]: 94 | logger.info("Resize > position embeddings !") 95 | old_vocab_size = state_dict[_k].shape[0] 96 | new_postion_embedding = state_dict[_k].data.new_tensor(torch.ones( 97 | size=(config.max_position_embeddings, state_dict[_k].shape[1])), dtype=torch.float) 98 | new_postion_embedding = nn.Parameter(data=new_postion_embedding, requires_grad=True) 99 | new_postion_embedding.data.normal_(mean=0.0, std=config.initializer_range) 100 | max_range = config.max_position_embeddings if reuse_position_embedding else 
old_vocab_size 101 | shift = 0 102 | while shift < max_range: 103 | delta = min(old_vocab_size, max_range - shift) 104 | new_postion_embedding.data[shift: shift + delta, :] = state_dict[_k][:delta, :] 105 | logger.info(" CP [%d ~ %d] into [%d ~ %d] " % (0, delta, shift, shift + delta)) 106 | shift += delta 107 | state_dict[_k] = new_postion_embedding.data 108 | del new_postion_embedding 109 | elif config.max_position_embeddings < state_dict[_k].shape[0]: 110 | logger.info("Resize < position embeddings !") 111 | old_vocab_size = state_dict[_k].shape[0] 112 | new_postion_embedding = state_dict[_k].data.new_tensor(torch.ones( 113 | size=(config.max_position_embeddings, state_dict[_k].shape[1])), dtype=torch.float) 114 | new_postion_embedding = nn.Parameter(data=new_postion_embedding, requires_grad=True) 115 | new_postion_embedding.data.normal_(mean=0.0, std=config.initializer_range) 116 | new_postion_embedding.data.copy_(state_dict[_k][:config.max_position_embeddings, :]) 117 | state_dict[_k] = new_postion_embedding.data 118 | del new_postion_embedding 119 | 120 | if replace_prefix is not None: 121 | new_state_dict = {} 122 | for key in state_dict: 123 | if key.startswith(replace_prefix): 124 | new_state_dict[key[len(replace_prefix):]] = state_dict[key] 125 | else: 126 | new_state_dict[key] = state_dict[key] 127 | kwargs["state_dict"] = new_state_dict 128 | del state_dict 129 | 130 | return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) 131 | 132 | 133 | class BertEmbeddings(nn.Module): 134 | """Construct the embeddings from word, position and token_type embeddings. 135 | """ 136 | def __init__(self, config): 137 | super(BertEmbeddings, self).__init__() 138 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) 139 | fix_word_embedding = getattr(config, "fix_word_embedding", None) 140 | if fix_word_embedding: 141 | self.word_embeddings.weight.requires_grad = False 142 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 143 | if config.type_vocab_size > 0: 144 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) 145 | else: 146 | self.token_type_embeddings = None 147 | 148 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 149 | # any TensorFlow checkpoint file 150 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 151 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 152 | 153 | def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None): 154 | if input_ids is not None: 155 | input_shape = input_ids.size() 156 | else: 157 | input_shape = inputs_embeds.size()[:-1] 158 | 159 | seq_length = input_shape[1] 160 | device = input_ids.device if input_ids is not None else inputs_embeds.device 161 | if position_ids is None: 162 | position_ids = torch.arange(seq_length, dtype=torch.long, device=device) 163 | position_ids = position_ids.unsqueeze(0).expand(input_shape) 164 | if token_type_ids is None: 165 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 166 | 167 | if inputs_embeds is None: 168 | inputs_embeds = self.word_embeddings(input_ids) 169 | position_embeddings = self.position_embeddings(position_ids) 170 | 171 | embeddings = inputs_embeds + position_embeddings 172 | 173 | if self.token_type_embeddings: 174 | embeddings = embeddings + self.token_type_embeddings(token_type_ids) 175 | 176 | embeddings = 
self.LayerNorm(embeddings) 177 | embeddings = self.dropout(embeddings) 178 | return embeddings, position_ids 179 | 180 | 181 | class BertSelfAttention(nn.Module): 182 | def __init__(self, config): 183 | super(BertSelfAttention, self).__init__() 184 | if config.hidden_size % config.num_attention_heads != 0: 185 | raise ValueError( 186 | "The hidden size (%d) is not a multiple of the number of attention " 187 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 188 | self.output_attentions = config.output_attentions 189 | 190 | self.num_attention_heads = config.num_attention_heads 191 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 192 | self.all_head_size = self.num_attention_heads * self.attention_head_size 193 | 194 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 195 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 196 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 197 | 198 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 199 | 200 | def transpose_for_scores(self, x): 201 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 202 | x = x.view(*new_x_shape) 203 | return x.permute(0, 2, 1, 3) 204 | 205 | def multi_head_attention(self, query, key, value, attention_mask, rel_pos): 206 | query_layer = self.transpose_for_scores(query) 207 | key_layer = self.transpose_for_scores(key) 208 | value_layer = self.transpose_for_scores(value) 209 | 210 | # Take the dot product between "query" and "key" to get the raw attention scores. 211 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 212 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 213 | if attention_mask is not None: 214 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 215 | attention_scores = attention_scores + attention_mask 216 | if rel_pos is not None: 217 | attention_scores = attention_scores + rel_pos 218 | 219 | # Normalize the attention scores to probabilities. 220 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 221 | 222 | # This is actually dropping out entire tokens to attend to, which might 223 | # seem a bit unusual, but is taken from the original Transformer paper. 224 | attention_probs = self.dropout(attention_probs) 225 | context_layer = torch.matmul(attention_probs, value_layer) 226 | 227 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 228 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 229 | context_layer = context_layer.view(*new_context_layer_shape) 230 | 231 | return (context_layer, attention_probs) if self.output_attentions else (context_layer,) 232 | 233 | def forward(self, hidden_states, attention_mask=None, 234 | encoder_hidden_states=None, 235 | split_lengths=None, rel_pos=None): 236 | mixed_query_layer = self.query(hidden_states) 237 | if split_lengths: 238 | assert not self.output_attentions 239 | 240 | # If this is instantiated as a cross-attention module, the keys 241 | # and values come from an encoder; the attention mask needs to be 242 | # such that the encoder's padding tokens are not attended to. 
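        # For orientation (hypothetical shapes): hidden_states is (batch, tgt_len, hidden)
        # and always supplies the queries; when encoder_hidden_states of shape
        # (batch, src_len, hidden) is given, the keys/values take length src_len, so
        # attention_mask should broadcast to (batch, num_heads, tgt_len, src_len).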
243 | if encoder_hidden_states is not None: 244 | mixed_key_layer = self.key(encoder_hidden_states) 245 | mixed_value_layer = self.value(encoder_hidden_states) 246 | else: 247 | mixed_key_layer = self.key(hidden_states) 248 | mixed_value_layer = self.value(hidden_states) 249 | 250 | if split_lengths: 251 | query_parts = torch.split(mixed_query_layer, split_lengths, dim=1) 252 | key_parts = torch.split(mixed_key_layer, split_lengths, dim=1) 253 | value_parts = torch.split(mixed_value_layer, split_lengths, dim=1) 254 | 255 | key = None 256 | value = None 257 | outputs = [] 258 | sum_length = 0 259 | for (query, _key, _value, part_length) in zip(query_parts, key_parts, value_parts, split_lengths): 260 | key = _key if key is None else torch.cat((key, _key), dim=1) 261 | value = _value if value is None else torch.cat((value, _value), dim=1) 262 | sum_length += part_length 263 | outputs.append(self.multi_head_attention( 264 | query, key, value, attention_mask[:, :, sum_length - part_length: sum_length, :sum_length], 265 | rel_pos=None if rel_pos is None else rel_pos[:, :, sum_length - part_length: sum_length, :sum_length], 266 | )[0]) 267 | outputs = (torch.cat(outputs, dim=1), ) 268 | else: 269 | outputs = self.multi_head_attention( 270 | mixed_query_layer, mixed_key_layer, mixed_value_layer, 271 | attention_mask, rel_pos=rel_pos) 272 | return outputs 273 | 274 | 275 | class BertAttention(nn.Module): 276 | def __init__(self, config): 277 | super(BertAttention, self).__init__() 278 | self.self = BertSelfAttention(config) 279 | self.output = BertSelfOutput(config) 280 | 281 | def forward(self, hidden_states, attention_mask=None, encoder_hidden_states=None, 282 | split_lengths=None, rel_pos=None): 283 | self_outputs = self.self( 284 | hidden_states, attention_mask=attention_mask, 285 | encoder_hidden_states=encoder_hidden_states, 286 | split_lengths=split_lengths, rel_pos=rel_pos) 287 | attention_output = self.output(self_outputs[0], hidden_states) 288 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 289 | return outputs 290 | 291 | 292 | class BertLayer(nn.Module): 293 | def __init__(self, config): 294 | super(BertLayer, self).__init__() 295 | self.attention = BertAttention(config) 296 | self.intermediate = BertIntermediate(config) 297 | self.output = BertOutput(config) 298 | 299 | def forward(self, hidden_states, attention_mask=None, split_lengths=None, rel_pos=None): 300 | self_attention_outputs = self.attention( 301 | hidden_states, attention_mask, 302 | split_lengths=split_lengths, rel_pos=rel_pos) 303 | attention_output = self_attention_outputs[0] 304 | 305 | intermediate_output = self.intermediate(attention_output) 306 | layer_output = self.output(intermediate_output, attention_output) 307 | outputs = (layer_output,) + self_attention_outputs[1:] 308 | return outputs 309 | 310 | 311 | class BertEncoder(nn.Module): 312 | def __init__(self, config): 313 | super(BertEncoder, self).__init__() 314 | self.output_attentions = config.output_attentions 315 | self.output_hidden_states = config.output_hidden_states 316 | self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) 317 | 318 | def forward(self, hidden_states, attention_mask=None, split_lengths=None, rel_pos=None): 319 | all_hidden_states = () 320 | all_attentions = () 321 | for i, layer_module in enumerate(self.layer): 322 | if self.output_hidden_states: 323 | all_hidden_states = all_hidden_states + (hidden_states,) 324 | 325 | layer_outputs = layer_module( 326 | 
hidden_states, attention_mask, 327 | split_lengths=split_lengths, rel_pos=rel_pos) 328 | hidden_states = layer_outputs[0] 329 | 330 | if self.output_attentions: 331 | all_attentions = all_attentions + (layer_outputs[1],) 332 | 333 | # Add last layer 334 | if self.output_hidden_states: 335 | all_hidden_states = all_hidden_states + (hidden_states,) 336 | 337 | outputs = (hidden_states,) 338 | if self.output_hidden_states: 339 | outputs = outputs + (all_hidden_states,) 340 | if self.output_attentions: 341 | outputs = outputs + (all_attentions,) 342 | return outputs # last-layer hidden state, (all hidden states), (all attentions) 343 | 344 | 345 | def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): 346 | """ 347 | Adapted from Mesh Tensorflow: 348 | https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 349 | """ 350 | ret = 0 351 | if bidirectional: 352 | num_buckets //= 2 353 | # mtf.to_int32(mtf.less(n, 0)) * num_buckets 354 | ret += (relative_position > 0).long() * num_buckets 355 | n = torch.abs(relative_position) 356 | else: 357 | n = torch.max(-relative_position, torch.zeros_like(relative_position)) 358 | # now n is in the range [0, inf) 359 | 360 | # half of the buckets are for exact increments in positions 361 | max_exact = num_buckets // 2 362 | is_small = n < max_exact 363 | 364 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 365 | val_if_large = max_exact + ( 366 | torch.log(n.float() / max_exact) / math.log(max_distance / 367 | max_exact) * (num_buckets - max_exact) 368 | ).to(torch.long) 369 | val_if_large = torch.min( 370 | val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 371 | 372 | ret += torch.where(is_small, n, val_if_large) 373 | return ret 374 | 375 | 376 | class TuringNLRv3Model(TuringNLRv3PreTrainedModel): 377 | r""" 378 | Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: 379 | **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` 380 | Sequence of hidden-states at the output of the last layer of the model. 381 | **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` 382 | Last layer hidden-state of the first token of the sequence (classification token) 383 | further processed by a Linear layer and a Tanh activation function. The Linear 384 | layer weights are trained from the next sentence prediction (classification) 385 | objective during Bert pretraining. This output is usually *not* a good summary 386 | of the semantic content of the input, you're often better with averaging or pooling 387 | the sequence of hidden-states for the whole input sequence. 388 | **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) 389 | list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) 390 | of shape ``(batch_size, sequence_length, hidden_size)``: 391 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 392 | **attentions**: (`optional`, returned when ``config.output_attentions=True``) 393 | list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: 394 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
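    Before the generic usage example, a small sketch of how the `relative_position_bucket`
    helper above maps pairwise position offsets to bias buckets with the default bucket
    settings (assumes the repository root is on `sys.path` so the import resolves; otherwise
    copy the function out):

    ```
    import torch
    from src.models.tnlrv3.modeling import relative_position_bucket

    position_ids = torch.arange(6).unsqueeze(0)                             # (1, 6)
    rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)   # (1, 6, 6) pairwise offsets
    buckets = relative_position_bucket(rel_pos_mat, num_buckets=32, max_distance=128)
    # Small offsets get their own buckets; larger ones share logarithmically spaced buckets.
    print(buckets[0])
    ```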
395 | 396 | Examples:: 397 | 398 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 399 | model = BertModel.from_pretrained('bert-base-uncased') 400 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 401 | outputs = model(input_ids) 402 | last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple 403 | 404 | """ 405 | def __init__(self, config): 406 | super(TuringNLRv3Model, self).__init__(config) 407 | self.config = config 408 | 409 | self.embeddings = BertEmbeddings(config) 410 | self.encoder = BertEncoder(config) 411 | if not isinstance(config, TuringNLRv3ForSeq2SeqConfig): 412 | self.pooler = BertPooler(config) 413 | else: 414 | self.pooler = None 415 | 416 | if self.config.rel_pos_bins > 0: 417 | self.rel_pos_bias = nn.Linear(self.config.rel_pos_bins, config.num_attention_heads, bias=False) 418 | else: 419 | self.rel_pos_bias = None 420 | 421 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, 422 | position_ids=None, inputs_embeds=None, split_lengths=None): 423 | if input_ids is not None and inputs_embeds is not None: 424 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 425 | elif input_ids is not None: 426 | input_shape = input_ids.size() 427 | elif inputs_embeds is not None: 428 | input_shape = inputs_embeds.size()[:-1] 429 | else: 430 | raise ValueError("You have to specify either input_ids or inputs_embeds") 431 | 432 | device = input_ids.device if input_ids is not None else inputs_embeds.device 433 | 434 | if attention_mask is None: 435 | attention_mask = torch.ones(input_shape, device=device) 436 | 437 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 438 | # ourselves in which case we just need to make it broadcastable to all heads. 439 | if attention_mask.dim() == 3: 440 | extended_attention_mask = attention_mask[:, None, :, :] 441 | 442 | # Provided a padding mask of dimensions [batch_size, seq_length] 443 | # - if the model is a decoder, apply a causal mask in addition to the padding mask 444 | # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] 445 | if attention_mask.dim() == 2: 446 | extended_attention_mask = attention_mask[:, None, None, :] 447 | 448 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 449 | # masked positions, this operation will create a tensor which is 0.0 for 450 | # positions we want to attend and -10000.0 for masked positions. 451 | # Since we are adding it to the raw scores before the softmax, this is 452 | # effectively the same as removing these entirely. 
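        # Example: a padding mask [1, 1, 0] becomes [0.0, 0.0, -10000.0] after the two
        # lines below, so the masked position gets ~0 weight after the softmax.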
453 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 454 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 455 | 456 | embedding_output, position_ids = self.embeddings( 457 | input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds) 458 | if self.config.rel_pos_bins > 0: 459 | rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) 460 | rel_pos = relative_position_bucket( 461 | rel_pos_mat, num_buckets=self.config.rel_pos_bins, max_distance=self.config.max_rel_pos) 462 | # print(rel_pos.shape,self.config.rel_pos_bins) 463 | rel_pos = F.one_hot(rel_pos, num_classes=self.config.rel_pos_bins).type_as(embedding_output) 464 | rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2) 465 | else: 466 | rel_pos = None 467 | encoder_outputs = self.encoder( 468 | embedding_output, attention_mask=extended_attention_mask, 469 | split_lengths=split_lengths, rel_pos=rel_pos) 470 | sequence_output = encoder_outputs[0] 471 | 472 | outputs = (sequence_output, ) + encoder_outputs[1:] # add hidden_states and attentions if they are here 473 | if self.pooler is None: 474 | return outputs # sequence_output, pooled_output, (hidden_states), (attentions) 475 | else: 476 | pooled_output = self.pooler(sequence_output) 477 | return (sequence_output, pooled_output) + encoder_outputs[1:] 478 | 479 | 480 | class LabelSmoothingLoss(_Loss): 481 | """ 482 | With label smoothing, 483 | KL-divergence between q_{smoothed ground truth prob.}(w) 484 | and p_{prob. computed by model}(w) is minimized. 485 | """ 486 | 487 | def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0, size_average=None, reduce=None, reduction='mean'): 488 | assert 0.0 < label_smoothing <= 1.0 489 | self.ignore_index = ignore_index 490 | super(LabelSmoothingLoss, self).__init__( 491 | size_average=size_average, reduce=reduce, reduction=reduction) 492 | 493 | assert label_smoothing > 0 494 | assert tgt_vocab_size > 0 495 | 496 | smoothing_value = label_smoothing / (tgt_vocab_size - 2) 497 | one_hot = torch.full((tgt_vocab_size,), smoothing_value) 498 | one_hot[self.ignore_index] = 0 499 | self.register_buffer('one_hot', one_hot.unsqueeze(0)) 500 | self.confidence = 1.0 - label_smoothing 501 | self.tgt_vocab_size = tgt_vocab_size 502 | 503 | def forward(self, output, target): 504 | """ 505 | output (FloatTensor): batch_size * num_pos * n_classes 506 | target (LongTensor): batch_size * num_pos 507 | """ 508 | assert self.tgt_vocab_size == output.size(2) 509 | batch_size, num_pos = target.size(0), target.size(1) 510 | output = output.view(-1, self.tgt_vocab_size) 511 | target = target.view(-1) 512 | model_prob = self.one_hot.float().repeat(target.size(0), 1) 513 | model_prob.scatter_(1, target.unsqueeze(1), self.confidence) 514 | model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) 515 | 516 | return F.kl_div(output, model_prob, reduction='none').view(batch_size, num_pos, -1).sum(2) 517 | 518 | 519 | class BertLMPredictionHead(nn.Module): 520 | def __init__(self, config, decoder_weight): 521 | super(BertLMPredictionHead, self).__init__() 522 | self.transform = BertPredictionHeadTransform(config) 523 | 524 | # The output weights are the same as the input embeddings, but there is 525 | # an output-only bias for each token. 
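        # decoder_weight is the word-embedding matrix handed in by
        # TuringNLRv3ForSequenceToSequence below, so forward() computes
        # logits = hidden_states @ embedding_weight.T + bias via F.linear.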
526 | self.decoder_weight = decoder_weight 527 | 528 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 529 | 530 | def forward(self, hidden_states): 531 | hidden_states = self.transform(hidden_states) 532 | hidden_states = F.linear(hidden_states, weight=self.decoder_weight, bias=self.bias) 533 | return hidden_states 534 | 535 | 536 | class BertOnlyMLMHead(nn.Module): 537 | def __init__(self, config, decoder_weight): 538 | super(BertOnlyMLMHead, self).__init__() 539 | self.predictions = BertLMPredictionHead(config, decoder_weight) 540 | 541 | def forward(self, sequence_output): 542 | prediction_scores = self.predictions(sequence_output) 543 | return prediction_scores 544 | 545 | 546 | def create_mask_and_position_ids(num_tokens, max_len, offset=None): 547 | base_position_matrix = torch.arange( 548 | 0, max_len, dtype=num_tokens.dtype, device=num_tokens.device).view(1, -1) 549 | mask = (base_position_matrix < num_tokens.view(-1, 1)).type_as(num_tokens) 550 | if offset is not None: 551 | base_position_matrix = base_position_matrix + offset.view(-1, 1) 552 | position_ids = base_position_matrix * mask 553 | return mask, position_ids 554 | 555 | 556 | class TuringNLRv3ForSequenceToSequence(TuringNLRv3PreTrainedModel): 557 | MODEL_NAME = 'basic class' 558 | 559 | def __init__(self, config): 560 | super(TuringNLRv3ForSequenceToSequence, self).__init__(config) 561 | self.bert = TuringNLRv3Model(config) 562 | self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) 563 | self.init_weights() 564 | 565 | self.log_softmax = nn.LogSoftmax() 566 | 567 | self.source_type_id = config.source_type_id 568 | self.target_type_id = config.target_type_id 569 | 570 | if config.label_smoothing > 0: 571 | self.crit_mask_lm_smoothed = LabelSmoothingLoss( 572 | config.label_smoothing, config.vocab_size, ignore_index=0, reduction='none') 573 | self.crit_mask_lm = None 574 | else: 575 | self.crit_mask_lm_smoothed = None 576 | self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none') 577 | 578 | 579 | class TuringNLRv3ForSequenceToSequenceWithPseudoMask(TuringNLRv3ForSequenceToSequence): 580 | MODEL_NAME = "TuringNLRv3ForSequenceToSequenceWithPseudoMask" 581 | 582 | @staticmethod 583 | def create_attention_mask(source_mask, target_mask, source_position_ids, target_span_ids): 584 | weight = torch.cat((torch.zeros_like(source_position_ids), target_span_ids, -target_span_ids), dim=1) 585 | from_weight = weight.unsqueeze(-1) 586 | to_weight = weight.unsqueeze(1) 587 | 588 | true_tokens = (0 <= to_weight) & (torch.cat((source_mask, target_mask, target_mask), dim=1) == 1).unsqueeze(1) 589 | true_tokens_mask = (from_weight >= 0) & true_tokens & (to_weight <= from_weight) 590 | pseudo_tokens_mask = (from_weight < 0) & true_tokens & (-to_weight > from_weight) 591 | pseudo_tokens_mask = pseudo_tokens_mask | ((from_weight < 0) & (to_weight == from_weight)) 592 | 593 | return (true_tokens_mask | pseudo_tokens_mask).type_as(source_mask) 594 | 595 | def forward( 596 | self, source_ids, target_ids, label_ids, pseudo_ids, 597 | num_source_tokens, num_target_tokens, target_span_ids=None, target_no_offset=None): 598 | source_len = source_ids.size(1) 599 | target_len = target_ids.size(1) 600 | pseudo_len = pseudo_ids.size(1) 601 | assert target_len == pseudo_len 602 | assert source_len > 0 and target_len > 0 603 | split_lengths = (source_len, target_len, pseudo_len) 604 | 605 | input_ids = torch.cat((source_ids, target_ids, pseudo_ids), dim=1) 606 | 607 | token_type_ids = torch.cat( 608 | 
(torch.ones_like(source_ids) * self.source_type_id, 609 | torch.ones_like(target_ids) * self.target_type_id, 610 | torch.ones_like(pseudo_ids) * self.target_type_id), dim=1) 611 | 612 | source_mask, source_position_ids = \ 613 | create_mask_and_position_ids(num_source_tokens, source_len) 614 | target_mask, target_position_ids = \ 615 | create_mask_and_position_ids( 616 | num_target_tokens, target_len, offset=None if target_no_offset else num_source_tokens) 617 | 618 | position_ids = torch.cat((source_position_ids, target_position_ids, target_position_ids), dim=1) 619 | if target_span_ids is None: 620 | target_span_ids = target_position_ids 621 | attention_mask = self.create_attention_mask(source_mask, target_mask, source_position_ids, target_span_ids) 622 | 623 | outputs = self.bert( 624 | input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, 625 | position_ids=position_ids, split_lengths=split_lengths) 626 | 627 | sequence_output = outputs[0] 628 | pseudo_sequence_output = sequence_output[:, source_len + target_len:, ] 629 | 630 | def loss_mask_and_normalize(loss, mask): 631 | mask = mask.type_as(loss) 632 | loss = loss * mask 633 | denominator = torch.sum(mask) + 1e-5 634 | return (loss / denominator).sum() 635 | 636 | prediction_scores_masked = self.cls(pseudo_sequence_output) 637 | 638 | if self.crit_mask_lm_smoothed: 639 | masked_lm_loss = self.crit_mask_lm_smoothed( 640 | F.log_softmax(prediction_scores_masked.float(), dim=-1), label_ids) 641 | else: 642 | masked_lm_loss = self.crit_mask_lm( 643 | prediction_scores_masked.transpose(1, 2).float(), label_ids) 644 | pseudo_lm_loss = loss_mask_and_normalize( 645 | masked_lm_loss.float(), target_mask) 646 | 647 | return pseudo_lm_loss 648 | 649 | 650 | class TuringNLRv3ForSequenceToSequenceUniLMV1(TuringNLRv3ForSequenceToSequence): 651 | MODEL_NAME = "TuringNLRv3ForSequenceToSequenceUniLMV1" 652 | 653 | @staticmethod 654 | def create_attention_mask(source_mask, target_mask, source_position_ids, target_span_ids): 655 | weight = torch.cat((torch.zeros_like(source_position_ids), target_span_ids), dim=1) 656 | from_weight = weight.unsqueeze(-1) 657 | to_weight = weight.unsqueeze(1) 658 | 659 | true_tokens = torch.cat((source_mask, target_mask), dim=1).unsqueeze(1) 660 | return ((true_tokens == 1) & (to_weight <= from_weight)).type_as(source_mask) 661 | 662 | def forward(self, source_ids, target_ids, masked_ids, masked_pos, masked_weight, num_source_tokens, num_target_tokens): 663 | source_len = source_ids.size(1) 664 | target_len = target_ids.size(1) 665 | split_lengths = (source_len, target_len) 666 | 667 | input_ids = torch.cat((source_ids, target_ids), dim=1) 668 | 669 | token_type_ids = torch.cat( 670 | (torch.ones_like(source_ids) * self.source_type_id, 671 | torch.ones_like(target_ids) * self.target_type_id), dim=1) 672 | 673 | source_mask, source_position_ids = \ 674 | create_mask_and_position_ids(num_source_tokens, source_len) 675 | target_mask, target_position_ids = \ 676 | create_mask_and_position_ids( 677 | num_target_tokens, target_len, offset=num_source_tokens) 678 | 679 | position_ids = torch.cat((source_position_ids, target_position_ids), dim=1) 680 | attention_mask = self.create_attention_mask( 681 | source_mask, target_mask, source_position_ids, target_position_ids) 682 | 683 | outputs = self.bert( 684 | input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, 685 | position_ids=position_ids, split_lengths=split_lengths) 686 | 687 | def gather_seq_out_by_pos(seq, pos): 688 | return 
torch.gather(seq, 1, pos.unsqueeze(2).expand(-1, -1, seq.size(-1))) 689 | 690 | sequence_output = outputs[0] 691 | target_sequence_output = sequence_output[:, source_len:, ] 692 | masked_sequence_output = gather_seq_out_by_pos(target_sequence_output, masked_pos) 693 | 694 | def loss_mask_and_normalize(loss, mask): 695 | mask = mask.type_as(loss) 696 | loss = loss * mask 697 | denominator = torch.sum(mask) + 1e-5 698 | return (loss / denominator).sum() 699 | 700 | prediction_scores_masked = self.cls(masked_sequence_output) 701 | 702 | if self.crit_mask_lm_smoothed: 703 | masked_lm_loss = self.crit_mask_lm_smoothed( 704 | F.log_softmax(prediction_scores_masked.float(), dim=-1), masked_ids) 705 | else: 706 | masked_lm_loss = self.crit_mask_lm( 707 | prediction_scores_masked.transpose(1, 2).float(), masked_ids) 708 | pseudo_lm_loss = loss_mask_and_normalize( 709 | masked_lm_loss.float(), masked_weight) 710 | 711 | return pseudo_lm_loss 712 | 713 | 714 | class TuringNLRv3ForSequenceClassification(TuringNLRv3PreTrainedModel): 715 | def __init__(self, config): 716 | super().__init__(config) 717 | self.num_labels = config.num_labels 718 | 719 | self.bert = TuringNLRv3Model(config) 720 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 721 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 722 | 723 | self.init_weights() 724 | 725 | def forward( 726 | self, 727 | input_ids=None, 728 | attention_mask=None, 729 | token_type_ids=None, 730 | position_ids=None, 731 | head_mask=None, 732 | inputs_embeds=None, 733 | labels=None, 734 | ): 735 | r""" 736 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): 737 | Labels for computing the sequence classification/regression loss. 738 | Indices should be in :obj:`[0, ..., config.num_labels - 1]`. 739 | If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), 740 | If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 741 | 742 | Returns: 743 | :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: 744 | loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided): 745 | Classification (or regression if config.num_labels==1) loss. 746 | logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`): 747 | Classification (or regression if config.num_labels==1) scores (before SoftMax). 748 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): 749 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 750 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 751 | 752 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 753 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): 754 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape 755 | :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. 756 | 757 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 758 | heads. 
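        In addition to the generic example below, a tiny standalone sketch of the loss
        branching used in ``forward`` (regression when ``num_labels == 1``, cross-entropy
        otherwise; all shapes are made up):

        ```
        import torch
        from torch import nn

        num_labels = 3
        logits = torch.randn(4, num_labels)            # classifier output, (batch, num_labels)
        labels = torch.randint(0, num_labels, (4,))
        print(nn.CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1)))

        reg_logits = torch.randn(4, 1)                 # num_labels == 1 -> MSE regression
        reg_labels = torch.randn(4)
        print(nn.MSELoss()(reg_logits.view(-1), reg_labels.view(-1)))
        ```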
759 | 760 | Examples:: 761 | 762 | from transformers import BertTokenizer, BertForSequenceClassification 763 | import torch 764 | 765 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 766 | model = BertForSequenceClassification.from_pretrained('bert-base-uncased') 767 | 768 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 769 | labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 770 | outputs = model(input_ids, labels=labels) 771 | 772 | loss, logits = outputs[:2] 773 | 774 | """ 775 | 776 | outputs = self.bert( 777 | input_ids, 778 | attention_mask=attention_mask, 779 | token_type_ids=token_type_ids, 780 | position_ids=position_ids, 781 | # head_mask=head_mask, 782 | inputs_embeds=inputs_embeds, 783 | ) 784 | 785 | pooled_output = outputs[1] 786 | 787 | pooled_output = self.dropout(pooled_output) 788 | logits = self.classifier(pooled_output) 789 | 790 | outputs = (logits,) + outputs[:] # add hidden states and attention if they are here 791 | 792 | if labels is not None: 793 | if self.num_labels == 1: 794 | # We are doing regression 795 | loss_fct = nn.MSELoss() 796 | loss = loss_fct(logits.view(-1), labels.view(-1)) 797 | else: 798 | loss_fct = nn.CrossEntropyLoss() 799 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 800 | outputs = (loss,) + outputs 801 | 802 | return outputs # (loss), logits, last_hidden_state, pooled_output, (hidden_states), (attentions) 803 | -------------------------------------------------------------------------------- /src/models/tnlrv3/modeling_decoding.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import copy 9 | import json 10 | import math 11 | import logging 12 | import numpy as np 13 | 14 | import torch 15 | from torch import nn 16 | from torch.nn import CrossEntropyLoss 17 | import torch.nn.functional as F 18 | 19 | from torch.nn.modules.loss import _Loss 20 | 21 | 22 | class LabelSmoothingLoss(_Loss): 23 | """ 24 | With label smoothing, 25 | KL-divergence between q_{smoothed ground truth prob.}(w) 26 | and p_{prob. computed by model}(w) is minimized. 
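    A minimal usage sketch (hypothetical sizes; assumes the repository root is on `sys.path`
    so the import resolves): the loss expects log-probabilities of shape
    (batch, positions, vocab) and integer targets of shape (batch, positions), and returns a
    per-position loss when reduction='none'.

    ```
    import torch
    import torch.nn.functional as F
    from src.models.tnlrv3.modeling_decoding import LabelSmoothingLoss

    crit = LabelSmoothingLoss(label_smoothing=0.1, tgt_vocab_size=10, ignore_index=0, reduction='none')
    logits = torch.randn(2, 5, 10)                 # (batch, positions, vocab)
    targets = torch.randint(1, 10, (2, 5))         # (batch, positions)
    loss = crit(F.log_softmax(logits, dim=-1), targets)
    print(loss.shape)                              # torch.Size([2, 5]) -- per-position KL
    ```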
27 | """ 28 | 29 | def __init__(self, label_smoothing=0, tgt_vocab_size=0, ignore_index=0, size_average=None, reduce=None, 30 | reduction='mean'): 31 | assert 0.0 < label_smoothing <= 1.0 32 | self.ignore_index = ignore_index 33 | super(LabelSmoothingLoss, self).__init__( 34 | size_average=size_average, reduce=reduce, reduction=reduction) 35 | 36 | assert label_smoothing > 0 37 | assert tgt_vocab_size > 0 38 | 39 | smoothing_value = label_smoothing / (tgt_vocab_size - 2) 40 | one_hot = torch.full((tgt_vocab_size,), smoothing_value) 41 | one_hot[self.ignore_index] = 0 42 | self.register_buffer('one_hot', one_hot.unsqueeze(0)) 43 | self.confidence = 1.0 - label_smoothing 44 | self.tgt_vocab_size = tgt_vocab_size 45 | 46 | def forward(self, output, target): 47 | """ 48 | output (FloatTensor): batch_size * num_pos * n_classes 49 | target (LongTensor): batch_size * num_pos 50 | """ 51 | assert self.tgt_vocab_size == output.size(2) 52 | batch_size, num_pos = target.size(0), target.size(1) 53 | output = output.view(-1, self.tgt_vocab_size) 54 | target = target.view(-1) 55 | model_prob = self.one_hot.repeat(target.size(0), 1) 56 | model_prob.scatter_(1, target.unsqueeze(1), self.confidence) 57 | model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) 58 | 59 | return F.kl_div(output, model_prob, reduction='none').view(batch_size, num_pos, -1).sum(2) 60 | 61 | 62 | logger = logging.getLogger(__name__) 63 | 64 | from transformers import WEIGHTS_NAME 65 | 66 | 67 | def gelu(x): 68 | """Implementation of the gelu activation function. 69 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 70 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 71 | """ 72 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 73 | 74 | 75 | def swish(x): 76 | return x * torch.sigmoid(x) 77 | 78 | 79 | ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} 80 | 81 | 82 | class BertConfig(object): 83 | """Configuration class to store the configuration of a `BertModel`. 84 | """ 85 | 86 | def __init__(self, 87 | vocab_size_or_config_json_file, 88 | hidden_size=768, 89 | num_hidden_layers=12, 90 | num_attention_heads=12, 91 | intermediate_size=3072, 92 | hidden_act="gelu", 93 | hidden_dropout_prob=0.1, 94 | attention_probs_dropout_prob=0.1, 95 | max_position_embeddings=512, 96 | type_vocab_size=2, 97 | relax_projection=0, 98 | new_pos_ids=False, 99 | initializer_range=0.02, 100 | task_idx=None, 101 | fp32_embedding=False, 102 | ffn_type=0, 103 | label_smoothing=None, 104 | num_qkv=0, 105 | seg_emb=False, 106 | source_type_id=0, 107 | target_type_id=1, 108 | rel_pos_bins=0, 109 | max_rel_pos=0, **kwargs): 110 | """Constructs BertConfig. 111 | Args: 112 | vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. 113 | hidden_size: Size of the encoder layers and the pooler layer. 114 | num_hidden_layers: Number of hidden layers in the Transformer encoder. 115 | num_attention_heads: Number of attention heads for each attention layer in 116 | the Transformer encoder. 117 | intermediate_size: The size of the "intermediate" (i.e., feed-forward) 118 | layer in the Transformer encoder. 119 | hidden_act: The non-linear activation function (function or string) in the 120 | encoder and pooler. If string, "gelu", "relu" and "swish" are supported. 121 | hidden_dropout_prob: The dropout probabilitiy for all fully connected 122 | layers in the embeddings, encoder, and pooler. 
123 | attention_probs_dropout_prob: The dropout ratio for the attention 124 | probabilities. 125 | max_position_embeddings: The maximum sequence length that this model might 126 | ever be used with. Typically set this to something large just in case 127 | (e.g., 512 or 1024 or 2048). 128 | type_vocab_size: The vocabulary size of the `token_type_ids` passed into 129 | `BertModel`. 130 | initializer_range: The sttdev of the truncated_normal_initializer for 131 | initializing all weight matrices. 132 | """ 133 | if isinstance(vocab_size_or_config_json_file, str): 134 | with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: 135 | json_config = json.loads(reader.read()) 136 | for key, value in json_config.items(): 137 | self.__dict__[key] = value 138 | elif isinstance(vocab_size_or_config_json_file, int): 139 | self.vocab_size = vocab_size_or_config_json_file 140 | self.hidden_size = hidden_size 141 | self.num_hidden_layers = num_hidden_layers 142 | self.num_attention_heads = num_attention_heads 143 | self.hidden_act = hidden_act 144 | self.intermediate_size = intermediate_size 145 | self.hidden_dropout_prob = hidden_dropout_prob 146 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 147 | self.max_position_embeddings = max_position_embeddings 148 | self.type_vocab_size = type_vocab_size 149 | self.relax_projection = relax_projection 150 | self.new_pos_ids = new_pos_ids 151 | self.initializer_range = initializer_range 152 | self.task_idx = task_idx 153 | self.fp32_embedding = fp32_embedding 154 | self.ffn_type = ffn_type 155 | self.label_smoothing = label_smoothing 156 | self.num_qkv = num_qkv 157 | self.seg_emb = seg_emb 158 | self.source_type_id = source_type_id 159 | self.target_type_id = target_type_id 160 | self.max_rel_pos = max_rel_pos 161 | self.rel_pos_bins = rel_pos_bins 162 | else: 163 | raise ValueError("First argument must be either a vocabulary size (int)" 164 | "or the path to a pretrained model config file (str)") 165 | 166 | @classmethod 167 | def from_dict(cls, json_object): 168 | """Constructs a `BertConfig` from a Python dictionary of parameters.""" 169 | config = BertConfig(vocab_size_or_config_json_file=-1) 170 | for key, value in json_object.items(): 171 | config.__dict__[key] = value 172 | return config 173 | 174 | @classmethod 175 | def from_json_file(cls, json_file): 176 | """Constructs a `BertConfig` from a json file of parameters.""" 177 | with open(json_file, "r", encoding='utf-8') as reader: 178 | text = reader.read() 179 | return cls.from_dict(json.loads(text)) 180 | 181 | def __repr__(self): 182 | return str(self.to_json_string()) 183 | 184 | def to_dict(self): 185 | """Serializes this instance to a Python dictionary.""" 186 | output = copy.deepcopy(self.__dict__) 187 | return output 188 | 189 | def to_json_string(self): 190 | """Serializes this instance to a JSON string.""" 191 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 192 | 193 | 194 | try: 195 | from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm 196 | except ImportError: 197 | print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.") 198 | 199 | 200 | class BertLayerNorm(nn.Module): 201 | def __init__(self, hidden_size, eps=1e-5): 202 | """Construct a layernorm module in the TF style (epsilon inside the square root). 
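    A quick numerical sanity check (a sketch, not part of the original code): this TF-style
    normalization matches `torch.nn.LayerNorm` when the same `eps` and affine parameters
    are used.

    ```
    import torch
    from torch import nn

    x = torch.randn(2, 4, 16)
    ref = nn.LayerNorm(16, eps=1e-5)               # default weight=1, bias=0
    u = x.mean(-1, keepdim=True)
    s = (x - u).pow(2).mean(-1, keepdim=True)
    manual = ref.weight * ((x - u) / torch.sqrt(s + 1e-5)) + ref.bias
    print(torch.allclose(manual, ref(x), atol=1e-6))  # True
    ```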
203 | """ 204 | super(BertLayerNorm, self).__init__() 205 | self.weight = nn.Parameter(torch.ones(hidden_size)) 206 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 207 | self.variance_epsilon = eps 208 | 209 | def forward(self, x): 210 | u = x.mean(-1, keepdim=True) 211 | s = (x - u).pow(2).mean(-1, keepdim=True) 212 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 213 | return self.weight * x + self.bias 214 | 215 | 216 | class PositionalEmbedding(nn.Module): 217 | def __init__(self, demb): 218 | super(PositionalEmbedding, self).__init__() 219 | 220 | self.demb = demb 221 | 222 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) 223 | self.register_buffer('inv_freq', inv_freq) 224 | 225 | def forward(self, pos_seq, bsz=None): 226 | sinusoid_inp = torch.ger(pos_seq, self.inv_freq) 227 | pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) 228 | 229 | if bsz is not None: 230 | return pos_emb[:, None, :].expand(-1, bsz, -1) 231 | else: 232 | return pos_emb[:, None, :] 233 | 234 | 235 | class BertEmbeddings(nn.Module): 236 | """Construct the embeddings from word, position and token_type embeddings. 237 | """ 238 | 239 | def __init__(self, config): 240 | super(BertEmbeddings, self).__init__() 241 | self.word_embeddings = nn.Embedding( 242 | config.vocab_size, config.hidden_size) 243 | if config.type_vocab_size == 0: 244 | self.token_type_embeddings = None 245 | else: 246 | self.token_type_embeddings = nn.Embedding( 247 | config.type_vocab_size, config.hidden_size) 248 | if hasattr(config, 'fp32_embedding'): 249 | self.fp32_embedding = config.fp32_embedding 250 | else: 251 | self.fp32_embedding = False 252 | 253 | if hasattr(config, 'new_pos_ids') and config.new_pos_ids: 254 | self.num_pos_emb = 4 255 | else: 256 | self.num_pos_emb = 1 257 | self.position_embeddings = nn.Embedding( 258 | config.max_position_embeddings, config.hidden_size * self.num_pos_emb) 259 | 260 | # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load 261 | # any TensorFlow checkpoint file 262 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) 263 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 264 | 265 | def forward(self, input_ids, token_type_ids=None, position_ids=None, task_idx=None): 266 | seq_length = input_ids.size(1) 267 | if position_ids is None: 268 | position_ids = torch.arange( 269 | seq_length, dtype=torch.long, device=input_ids.device) 270 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 271 | if token_type_ids is None: 272 | token_type_ids = torch.zeros_like(input_ids) 273 | 274 | words_embeddings = self.word_embeddings(input_ids) 275 | position_embeddings = self.position_embeddings(position_ids) 276 | 277 | if self.num_pos_emb > 1: 278 | num_batch = position_embeddings.size(0) 279 | num_pos = position_embeddings.size(1) 280 | position_embeddings = position_embeddings.view( 281 | num_batch, num_pos, self.num_pos_emb, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] 282 | 283 | embeddings = words_embeddings + position_embeddings 284 | 285 | if self.token_type_embeddings is not None: 286 | embeddings = embeddings + self.token_type_embeddings(token_type_ids) 287 | 288 | if self.fp32_embedding: 289 | embeddings = embeddings.half() 290 | embeddings = self.LayerNorm(embeddings) 291 | embeddings = self.dropout(embeddings) 292 | return embeddings 293 | 294 | 295 | class BertSelfAttention(nn.Module): 296 | def __init__(self, config): 297 | super(BertSelfAttention, self).__init__() 298 | if 
config.hidden_size % config.num_attention_heads != 0: 299 | raise ValueError( 300 | "The hidden size (%d) is not a multiple of the number of attention " 301 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 302 | self.num_attention_heads = config.num_attention_heads 303 | self.attention_head_size = int( 304 | config.hidden_size / config.num_attention_heads) 305 | self.all_head_size = self.num_attention_heads * self.attention_head_size 306 | 307 | if hasattr(config, 'num_qkv') and (config.num_qkv > 1): 308 | self.num_qkv = config.num_qkv 309 | else: 310 | self.num_qkv = 1 311 | 312 | self.query = nn.Linear( 313 | config.hidden_size, self.all_head_size * self.num_qkv) 314 | self.key = nn.Linear(config.hidden_size, 315 | self.all_head_size * self.num_qkv) 316 | self.value = nn.Linear( 317 | config.hidden_size, self.all_head_size * self.num_qkv) 318 | 319 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 320 | 321 | self.uni_debug_flag = True if os.getenv( 322 | 'UNI_DEBUG_FLAG', '') else False 323 | if self.uni_debug_flag: 324 | self.register_buffer('debug_attention_probs', 325 | torch.zeros((512, 512))) 326 | if hasattr(config, 'seg_emb') and config.seg_emb: 327 | self.b_q_s = nn.Parameter(torch.zeros( 328 | 1, self.num_attention_heads, 1, self.attention_head_size)) 329 | self.seg_emb = nn.Embedding( 330 | config.type_vocab_size, self.all_head_size) 331 | else: 332 | self.b_q_s = None 333 | self.seg_emb = None 334 | 335 | def transpose_for_scores(self, x, mask_qkv=None): 336 | if self.num_qkv > 1: 337 | sz = x.size()[:-1] + (self.num_qkv, 338 | self.num_attention_heads, self.all_head_size) 339 | # (batch, pos, num_qkv, head, head_hid) 340 | x = x.view(*sz) 341 | if mask_qkv is None: 342 | x = x[:, :, 0, :, :] 343 | elif isinstance(mask_qkv, int): 344 | x = x[:, :, mask_qkv, :, :] 345 | else: 346 | # mask_qkv: (batch, pos) 347 | if mask_qkv.size(1) > sz[1]: 348 | mask_qkv = mask_qkv[:, :sz[1]] 349 | # -> x: (batch, pos, head, head_hid) 350 | x = x.gather(2, mask_qkv.view(sz[0], sz[1], 1, 1, 1).expand( 351 | sz[0], sz[1], 1, sz[3], sz[4])).squeeze(2) 352 | else: 353 | sz = x.size()[:-1] + (self.num_attention_heads, 354 | self.attention_head_size) 355 | # (batch, pos, head, head_hid) 356 | x = x.view(*sz) 357 | # (batch, head, pos, head_hid) 358 | return x.permute(0, 2, 1, 3) 359 | 360 | def forward(self, hidden_states, attention_mask, history_states=None, 361 | mask_qkv=None, seg_ids=None, key_history=None, value_history=None, 362 | key_cache=None, value_cache=None, rel_pos=None, 363 | ): 364 | if history_states is None: 365 | mixed_query_layer = self.query(hidden_states) 366 | # possible issue: https://github.com/NVIDIA/apex/issues/131 367 | mixed_key_layer = F.linear(hidden_states, self.key.weight) 368 | mixed_value_layer = self.value(hidden_states) 369 | else: 370 | x_states = torch.cat((history_states, hidden_states), dim=1) 371 | mixed_query_layer = self.query(hidden_states) 372 | # possible issue: https://github.com/NVIDIA/apex/issues/131 373 | mixed_key_layer = F.linear(x_states, self.key.weight) 374 | mixed_value_layer = self.value(x_states) 375 | 376 | if key_cache is not None and isinstance(key_cache, list): 377 | key_cache.append(mixed_key_layer) 378 | mixed_key_layer = torch.cat(key_cache, dim=1) 379 | 380 | if value_cache is not None and isinstance(value_cache, list): 381 | value_cache.append(mixed_value_layer) 382 | mixed_value_layer = torch.cat(value_cache, dim=1) 383 | 384 | query_layer = self.transpose_for_scores(mixed_query_layer, mask_qkv) 385 
| key_layer = self.transpose_for_scores(mixed_key_layer, mask_qkv) 386 | value_layer = self.transpose_for_scores(mixed_value_layer, mask_qkv) 387 | 388 | if key_history is not None and not isinstance(key_history, list): 389 | key_layer = torch.cat((key_history, key_layer), dim=-2) 390 | value_layer = torch.cat((value_history, value_layer), dim=-2) 391 | 392 | # Take the dot product between "query" and "key" to get the raw attention scores. 393 | # (batch, head, pos, pos) 394 | attention_scores = torch.matmul( 395 | query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2)) 396 | if rel_pos is not None: 397 | attention_scores = attention_scores + rel_pos 398 | 399 | if self.seg_emb is not None: 400 | seg_rep = self.seg_emb(seg_ids) 401 | # (batch, pos, head, head_hid) 402 | seg_rep = seg_rep.view(seg_rep.size(0), seg_rep.size( 403 | 1), self.num_attention_heads, self.attention_head_size) 404 | qs = torch.einsum('bnih,bjnh->bnij', 405 | query_layer + self.b_q_s, seg_rep) 406 | attention_scores = attention_scores + qs 407 | 408 | # attention_scores = attention_scores / math.sqrt(self.attention_head_size) 409 | 410 | # Apply the attention mask is (precomputed for all layers in BertModel forward() function) 411 | attention_scores = attention_scores + attention_mask 412 | 413 | # Normalize the attention scores to probabilities. 414 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 415 | 416 | if self.uni_debug_flag: 417 | _pos = attention_probs.size(-1) 418 | self.debug_attention_probs[:_pos, :_pos].copy_( 419 | attention_probs[0].mean(0).view(_pos, _pos)) 420 | 421 | # This is actually dropping out entire tokens to attend to, which might 422 | # seem a bit unusual, but is taken from the original Transformer paper. 423 | attention_probs = self.dropout(attention_probs) 424 | 425 | context_layer = torch.matmul(attention_probs, value_layer) 426 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 427 | new_context_layer_shape = context_layer.size()[ 428 | :-2] + (self.all_head_size,) 429 | context_layer = context_layer.view(*new_context_layer_shape) 430 | 431 | if isinstance(key_history, list): 432 | key_history.append(key_layer) 433 | if isinstance(value_history, list): 434 | value_history.append(value_layer) 435 | 436 | return context_layer 437 | 438 | 439 | class BertSelfOutput(nn.Module): 440 | def __init__(self, config): 441 | super(BertSelfOutput, self).__init__() 442 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 443 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) 444 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 445 | 446 | def forward(self, hidden_states, input_tensor): 447 | hidden_states = self.dense(hidden_states) 448 | hidden_states = self.dropout(hidden_states) 449 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 450 | return hidden_states 451 | 452 | 453 | class BertAttention(nn.Module): 454 | def __init__(self, config): 455 | super(BertAttention, self).__init__() 456 | self.self = BertSelfAttention(config) 457 | self.output = BertSelfOutput(config) 458 | 459 | def forward(self, input_tensor, attention_mask, history_states=None, 460 | mask_qkv=None, seg_ids=None, key_history=None, value_history=None, rel_pos=None): 461 | self_output = self.self( 462 | input_tensor, attention_mask, history_states=history_states, 463 | mask_qkv=mask_qkv, seg_ids=seg_ids, key_history=key_history, value_history=value_history, rel_pos=rel_pos) 464 | attention_output = self.output(self_output, input_tensor) 465 
| return attention_output 466 | 467 | 468 | class BertIntermediate(nn.Module): 469 | def __init__(self, config): 470 | super(BertIntermediate, self).__init__() 471 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 472 | self.intermediate_act_fn = ACT2FN[config.hidden_act] \ 473 | if isinstance(config.hidden_act, str) else config.hidden_act 474 | 475 | def forward(self, hidden_states): 476 | hidden_states = self.dense(hidden_states) 477 | hidden_states = self.intermediate_act_fn(hidden_states) 478 | return hidden_states 479 | 480 | 481 | class BertOutput(nn.Module): 482 | def __init__(self, config): 483 | super(BertOutput, self).__init__() 484 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 485 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) 486 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 487 | 488 | def forward(self, hidden_states, input_tensor): 489 | hidden_states = self.dense(hidden_states) 490 | hidden_states = self.dropout(hidden_states) 491 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 492 | return hidden_states 493 | 494 | 495 | class TransformerFFN(nn.Module): 496 | def __init__(self, config): 497 | super(TransformerFFN, self).__init__() 498 | self.ffn_type = config.ffn_type 499 | assert self.ffn_type in (1, 2) 500 | if self.ffn_type in (1, 2): 501 | self.wx0 = nn.Linear(config.hidden_size, config.hidden_size) 502 | if self.ffn_type in (2,): 503 | self.wx1 = nn.Linear(config.hidden_size, config.hidden_size) 504 | if self.ffn_type in (1, 2): 505 | self.output = nn.Linear(config.hidden_size, config.hidden_size) 506 | self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) 507 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 508 | 509 | def forward(self, x): 510 | if self.ffn_type in (1, 2): 511 | x0 = self.wx0(x) 512 | if self.ffn_type == 1: 513 | x1 = x 514 | elif self.ffn_type == 2: 515 | x1 = self.wx1(x) 516 | out = self.output(x0 * x1) 517 | out = self.dropout(out) 518 | out = self.LayerNorm(out + x) 519 | return out 520 | 521 | 522 | class BertLayer(nn.Module): 523 | def __init__(self, config): 524 | super(BertLayer, self).__init__() 525 | self.attention = BertAttention(config) 526 | self.ffn_type = config.ffn_type 527 | if self.ffn_type: 528 | self.ffn = TransformerFFN(config) 529 | else: 530 | self.intermediate = BertIntermediate(config) 531 | self.output = BertOutput(config) 532 | 533 | def forward(self, hidden_states, attention_mask, history_states=None, 534 | mask_qkv=None, seg_ids=None, key_history=None, value_history=None, rel_pos=None): 535 | attention_output = self.attention( 536 | hidden_states, attention_mask, history_states=history_states, 537 | mask_qkv=mask_qkv, seg_ids=seg_ids, key_history=key_history, value_history=value_history, rel_pos=rel_pos) 538 | if self.ffn_type: 539 | layer_output = self.ffn(attention_output) 540 | else: 541 | intermediate_output = self.intermediate(attention_output) 542 | layer_output = self.output(intermediate_output, attention_output) 543 | return layer_output 544 | 545 | 546 | class BertEncoder(nn.Module): 547 | def __init__(self, config): 548 | super(BertEncoder, self).__init__() 549 | layer = BertLayer(config) 550 | self.layer = nn.ModuleList([copy.deepcopy(layer) 551 | for _ in range(config.num_hidden_layers)]) 552 | 553 | def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, 554 | prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, 555 | seg_ids=None, key_history=None, value_history=None, 
rel_pos=None): 556 | # history embedding and encoded layer must be simultanously given 557 | assert (prev_embedding is None) == (prev_encoded_layers is None) 558 | 559 | all_encoder_layers = [] 560 | if (prev_embedding is not None) and (prev_encoded_layers is not None): 561 | history_states = prev_embedding 562 | for i, layer_module in enumerate(self.layer): 563 | hidden_states = layer_module( 564 | hidden_states, attention_mask, history_states=history_states, 565 | mask_qkv=mask_qkv, seg_ids=seg_ids, rel_pos=rel_pos) 566 | if output_all_encoded_layers: 567 | all_encoder_layers.append(hidden_states) 568 | if prev_encoded_layers is not None: 569 | history_states = prev_encoded_layers[i] 570 | else: 571 | for i, layer_module in enumerate(self.layer): 572 | set_key = None 573 | if isinstance(key_history, list): 574 | set_key = key_history if len(key_history) < len(self.layer) else key_history[i] 575 | set_value = None 576 | if isinstance(value_history, list): 577 | set_value = value_history if len(key_history) < len(self.layer) else value_history[i] 578 | hidden_states = layer_module( 579 | hidden_states, attention_mask, mask_qkv=mask_qkv, seg_ids=seg_ids, 580 | key_history=set_key, value_history=set_value, rel_pos=rel_pos) 581 | if output_all_encoded_layers: 582 | all_encoder_layers.append(hidden_states) 583 | if not output_all_encoded_layers: 584 | all_encoder_layers.append(hidden_states) 585 | return all_encoder_layers 586 | 587 | 588 | class BertPooler(nn.Module): 589 | def __init__(self, config): 590 | super(BertPooler, self).__init__() 591 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 592 | self.activation = nn.Tanh() 593 | 594 | def forward(self, hidden_states): 595 | # We "pool" the model by simply taking the hidden state corresponding 596 | # to the first token. 597 | first_token_tensor = hidden_states[:, 0] 598 | pooled_output = self.dense(first_token_tensor) 599 | pooled_output = self.activation(pooled_output) 600 | return pooled_output 601 | 602 | 603 | class BertPredictionHeadTransform(nn.Module): 604 | def __init__(self, config): 605 | super(BertPredictionHeadTransform, self).__init__() 606 | self.transform_act_fn = ACT2FN[config.hidden_act] \ 607 | if isinstance(config.hidden_act, str) else config.hidden_act 608 | hid_size = config.hidden_size 609 | if hasattr(config, 'relax_projection') and (config.relax_projection > 1): 610 | hid_size *= config.relax_projection 611 | self.dense = nn.Linear(config.hidden_size, hid_size) 612 | self.LayerNorm = BertLayerNorm(hid_size, eps=1e-5) 613 | 614 | def forward(self, hidden_states): 615 | hidden_states = self.dense(hidden_states) 616 | hidden_states = self.transform_act_fn(hidden_states) 617 | hidden_states = self.LayerNorm(hidden_states) 618 | return hidden_states 619 | 620 | 621 | class BertLMPredictionHead(nn.Module): 622 | def __init__(self, config, bert_model_embedding_weights): 623 | super(BertLMPredictionHead, self).__init__() 624 | self.transform = BertPredictionHeadTransform(config) 625 | 626 | # The output weights are the same as the input embeddings, but there is 627 | # an output-only bias for each token. 
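        # nn.Linear(embedding_dim, vocab_size, bias=False) has a weight of shape
        # (vocab_size, embedding_dim) -- the same shape as the embedding matrix --
        # so the assignment below ties the two directly; only self.bias is a new parameter.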
628 | self.decoder = nn.Linear(bert_model_embedding_weights.size(1), 629 | bert_model_embedding_weights.size(0), 630 | bias=False) 631 | self.decoder.weight = bert_model_embedding_weights 632 | self.bias = nn.Parameter(torch.zeros( 633 | bert_model_embedding_weights.size(0))) 634 | if hasattr(config, 'relax_projection') and (config.relax_projection > 1): 635 | self.relax_projection = config.relax_projection 636 | else: 637 | self.relax_projection = 0 638 | self.fp32_embedding = config.fp32_embedding 639 | 640 | def convert_to_type(tensor): 641 | if self.fp32_embedding: 642 | return tensor.half() 643 | else: 644 | return tensor 645 | 646 | self.type_converter = convert_to_type 647 | self.converted = False 648 | 649 | def forward(self, hidden_states, task_idx=None): 650 | if not self.converted: 651 | self.converted = True 652 | if self.fp32_embedding: 653 | self.transform.half() 654 | hidden_states = self.transform(self.type_converter(hidden_states)) 655 | if self.relax_projection > 1: 656 | num_batch = hidden_states.size(0) 657 | num_pos = hidden_states.size(1) 658 | # (batch, num_pos, relax_projection*hid) -> (batch, num_pos, relax_projection, hid) -> (batch, num_pos, hid) 659 | hidden_states = hidden_states.view( 660 | num_batch, num_pos, self.relax_projection, -1)[torch.arange(0, num_batch).long(), :, task_idx, :] 661 | if self.fp32_embedding: 662 | hidden_states = F.linear(self.type_converter(hidden_states), self.type_converter( 663 | self.decoder.weight), self.type_converter(self.bias)) 664 | else: 665 | hidden_states = self.decoder(hidden_states) + self.bias 666 | return hidden_states 667 | 668 | 669 | class BertOnlyMLMHead(nn.Module): 670 | def __init__(self, config, bert_model_embedding_weights): 671 | super(BertOnlyMLMHead, self).__init__() 672 | self.predictions = BertLMPredictionHead( 673 | config, bert_model_embedding_weights) 674 | 675 | def forward(self, sequence_output): 676 | prediction_scores = self.predictions(sequence_output) 677 | return prediction_scores 678 | 679 | 680 | class BertOnlyNSPHead(nn.Module): 681 | def __init__(self, config): 682 | super(BertOnlyNSPHead, self).__init__() 683 | self.seq_relationship = nn.Linear(config.hidden_size, 2) 684 | 685 | def forward(self, pooled_output): 686 | seq_relationship_score = self.seq_relationship(pooled_output) 687 | return seq_relationship_score 688 | 689 | 690 | class BertPreTrainingHeads(nn.Module): 691 | def __init__(self, config, bert_model_embedding_weights, num_labels=2): 692 | super(BertPreTrainingHeads, self).__init__() 693 | self.predictions = BertLMPredictionHead( 694 | config, bert_model_embedding_weights) 695 | self.seq_relationship = nn.Linear(config.hidden_size, num_labels) 696 | 697 | def forward(self, sequence_output, pooled_output, task_idx=None): 698 | prediction_scores = self.predictions(sequence_output, task_idx) 699 | if pooled_output is None: 700 | seq_relationship_score = None 701 | else: 702 | seq_relationship_score = self.seq_relationship(pooled_output) 703 | return prediction_scores, seq_relationship_score 704 | 705 | 706 | class PreTrainedBertModel(nn.Module): 707 | """ An abstract class to handle weights initialization and 708 | a simple interface for dowloading and loading pretrained models. 709 | """ 710 | 711 | def __init__(self, config, *inputs, **kwargs): 712 | super(PreTrainedBertModel, self).__init__() 713 | if not isinstance(config, BertConfig): 714 | raise ValueError( 715 | "Parameter config in `{}(config)` should be an instance of class `BertConfig`. 
" 716 | "To create a model from a Google pretrained model use " 717 | "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( 718 | self.__class__.__name__, self.__class__.__name__ 719 | )) 720 | self.config = config 721 | 722 | def init_bert_weights(self, module): 723 | """ Initialize the weights. 724 | """ 725 | if isinstance(module, (nn.Linear, nn.Embedding)): 726 | # Slightly different from the TF version which uses truncated_normal for initialization 727 | # cf https://github.com/pytorch/pytorch/pull/5617 728 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 729 | # module.weight.data.copy_(torch.Tensor( 730 | # truncnorm.rvs(-1, 1, size=list(module.weight.data.shape)) * self.config.initializer_range)) 731 | elif isinstance(module, BertLayerNorm): 732 | module.bias.data.zero_() 733 | module.weight.data.fill_(1.0) 734 | if isinstance(module, nn.Linear) and module.bias is not None: 735 | module.bias.data.zero_() 736 | 737 | @classmethod 738 | def from_pretrained(cls, pretrained_model_name, config, state_dict=None, cache_dir=None, *inputs, **kwargs): 739 | """ 740 | Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. 741 | Download and cache the pre-trained model file if needed. 742 | Params: 743 | pretrained_model_name: either: 744 | - a str with the name of a pre-trained model to load selected in the list of: 745 | . `bert-base-uncased` 746 | . `bert-large-uncased` 747 | . `bert-base-cased` 748 | . `bert-base-multilingual` 749 | . `bert-base-chinese` 750 | - a path or url to a pretrained model archive containing: 751 | . `bert_config.json` a configuration file for the model 752 | . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance 753 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. 754 | state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models 755 | *inputs, **kwargs: additional input for the specific Bert class 756 | (ex: num_labels for BertForSequenceClassification) 757 | """ 758 | logger.info("Model config {}".format(config)) 759 | 760 | # clean the arguments in kwargs 761 | for arg_clean in ('config_path', 'type_vocab_size', 'relax_projection', 'new_pos_ids', 'task_idx', 762 | 'max_position_embeddings', 'fp32_embedding', 'ffn_type', 'label_smoothing', 763 | 'hidden_dropout_prob', 'attention_probs_dropout_prob', 'num_qkv', 'seg_emb', 764 | 'word_emb_map', 'num_labels', 'num_rel', 'num_sentlvl_labels'): 765 | if arg_clean in kwargs: 766 | del kwargs[arg_clean] 767 | 768 | # Instantiate model. 
769 | model = cls(config, *inputs, **kwargs) 770 | if state_dict is None: 771 | weights_path = os.path.join(pretrained_model_name, WEIGHTS_NAME) 772 | state_dict = torch.load(weights_path) 773 | 774 | old_keys = [] 775 | new_keys = [] 776 | for key in state_dict.keys(): 777 | new_key = None 778 | if 'gamma' in key: 779 | new_key = key.replace('gamma', 'weight') 780 | if 'beta' in key: 781 | new_key = key.replace('beta', 'bias') 782 | if new_key: 783 | old_keys.append(key) 784 | new_keys.append(new_key) 785 | for old_key, new_key in zip(old_keys, new_keys): 786 | state_dict[new_key] = state_dict.pop(old_key) 787 | 788 | missing_keys = [] 789 | unexpected_keys = [] 790 | error_msgs = [] 791 | # copy state_dict so _load_from_state_dict can modify it 792 | metadata = getattr(state_dict, '_metadata', None) 793 | state_dict = state_dict.copy() 794 | if metadata is not None: 795 | state_dict._metadata = metadata 796 | 797 | def load(module, prefix=''): 798 | local_metadata = {} if metadata is None else metadata.get( 799 | prefix[:-1], {}) 800 | module._load_from_state_dict( 801 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 802 | for name, child in module._modules.items(): 803 | if child is not None: 804 | load(child, prefix + name + '.') 805 | 806 | load(model, prefix='' if hasattr(model, 'bert') else 'bert.') 807 | model.missing_keys = missing_keys 808 | if len(missing_keys) > 0: 809 | logger.info("Weights of {} not initialized from pretrained model: {}".format( 810 | model.__class__.__name__, missing_keys)) 811 | if len(unexpected_keys) > 0: 812 | logger.info("Weights from pretrained model not used in {}: {}".format( 813 | model.__class__.__name__, unexpected_keys)) 814 | if len(error_msgs) > 0: 815 | logger.info('\n'.join(error_msgs)) 816 | return model 817 | 818 | 819 | class BertModel(PreTrainedBertModel): 820 | """BERT model ("Bidirectional Encoder Representations from Transformers"). 821 | Params: 822 | config: a BertConfig class instance with the configuration to build a new model 823 | Inputs: 824 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 825 | with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts 826 | `extract_features.py`, `run_classifier.py`) 827 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 828 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 829 | a `sentence B` token (see BERT paper for more details). 830 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 831 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 832 | input sequence length in the current batch. It's the mask that we typically use for attention when 833 | a batch has varying length sentences. 834 | `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. 835 | Outputs: Tuple of (encoded_layers, pooled_output) 836 | `encoded_layers`: controlled by `output_all_encoded_layers` argument: 837 | - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end 838 | of each attention block (i.e.
12 full sequences for BERT-base, 24 for BERT-large), each 839 | encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], 840 | - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding 841 | to the last attention block of shape [batch_size, sequence_length, hidden_size], 842 | `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a 843 | classifier pretrained on top of the hidden state associated with the first token of the 844 | input (`CLS`) to train on the Next-Sentence task (see BERT's paper). 845 | 846 | """ 847 | 848 | def __init__(self, config): 849 | super(BertModel, self).__init__(config) 850 | self.embeddings = BertEmbeddings(config) 851 | self.encoder = BertEncoder(config) 852 | self.pooler = BertPooler(config) 853 | self.config = config 854 | self.apply(self.init_bert_weights) 855 | 856 | def rescale_some_parameters(self): 857 | for layer_id, layer in enumerate(self.encoder.layer): 858 | layer.attention.output.dense.weight.data.div_( 859 | math.sqrt(2.0 * (layer_id + 1))) 860 | layer.output.dense.weight.data.div_(math.sqrt(2.0 * (layer_id + 1))) 861 | 862 | def get_extended_attention_mask(self, input_ids, token_type_ids, attention_mask): 863 | if attention_mask is None: 864 | attention_mask = torch.ones_like(input_ids) 865 | if token_type_ids is None: 866 | token_type_ids = torch.zeros_like(input_ids) 867 | 868 | # We create a 3D attention mask from a 2D tensor mask. 869 | # Sizes are [batch_size, 1, 1, to_seq_length] 870 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 871 | # this attention mask is simpler than the triangular masking of causal attention 872 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 873 | if attention_mask.dim() == 2: 874 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 875 | elif attention_mask.dim() == 3: 876 | extended_attention_mask = attention_mask.unsqueeze(1) 877 | else: 878 | raise NotImplementedError 879 | 880 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 881 | # masked positions, this operation will create a tensor which is 0.0 for 882 | # positions we want to attend and -10000.0 for masked positions. 883 | # Since we are adding it to the raw scores before the softmax, this is 884 | # effectively the same as removing these entirely.
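# For example, a 2D mask [[1, 1, 0]] is first broadcast to shape [batch, 1, 1, seq] and then,
# after (1.0 - mask) * -10000.0, becomes [[[[0.0, 0.0, -10000.0]]]], so the padded position
# receives a large negative score and is effectively ignored by the softmax.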
885 | extended_attention_mask = extended_attention_mask.to( 886 | dtype=next(self.parameters()).dtype) # fp16 compatibility 887 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 888 | return extended_attention_mask 889 | 890 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, 891 | mask_qkv=None, task_idx=None, key_history=None, value_history=None, position_ids=None): 892 | extended_attention_mask = self.get_extended_attention_mask( 893 | input_ids, token_type_ids, attention_mask) 894 | 895 | embedding_output = self.embeddings( 896 | input_ids, token_type_ids, task_idx=task_idx, position_ids=position_ids) 897 | encoded_layers = self.encoder(embedding_output, extended_attention_mask, 898 | output_all_encoded_layers=output_all_encoded_layers, 899 | mask_qkv=mask_qkv, seg_ids=token_type_ids, 900 | key_history=key_history, value_history=value_history) 901 | sequence_output = encoded_layers[-1] 902 | pooled_output = self.pooler(sequence_output) 903 | if not output_all_encoded_layers: 904 | encoded_layers = encoded_layers[-1] 905 | return encoded_layers, pooled_output 906 | 907 | 908 | class BertModelIncr(BertModel): 909 | def __init__(self, config): 910 | super(BertModelIncr, self).__init__(config) 911 | 912 | if self.config.rel_pos_bins > 0: 913 | self.rel_pos_bias = nn.Linear(self.config.rel_pos_bins, config.num_attention_heads, bias=False) 914 | else: 915 | self.rel_pos_bias = None 916 | 917 | def forward(self, input_ids, token_type_ids, position_ids, attention_mask, output_all_encoded_layers=True, 918 | prev_embedding=None, prev_encoded_layers=None, mask_qkv=None, task_idx=None, rel_pos=None): 919 | extended_attention_mask = self.get_extended_attention_mask( 920 | input_ids, token_type_ids, attention_mask) 921 | 922 | embedding_output = self.embeddings( 923 | input_ids, token_type_ids, position_ids, task_idx=task_idx) 924 | 925 | if self.rel_pos_bias is not None: 926 | # print("Rel pos size = %s" % str(rel_pos.size())) 927 | rel_pos = F.one_hot(rel_pos, num_classes=self.config.rel_pos_bins).type_as(embedding_output) 928 | # print("Rel pos size = %s" % str(rel_pos.size())) 929 | rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2) 930 | # print("Rel pos size = %s" % str(rel_pos.size())) 931 | else: 932 | rel_pos = None 933 | encoded_layers = self.encoder(embedding_output, 934 | extended_attention_mask, 935 | output_all_encoded_layers=output_all_encoded_layers, 936 | prev_embedding=prev_embedding, 937 | prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv, 938 | seg_ids=token_type_ids, rel_pos=rel_pos) 939 | sequence_output = encoded_layers[-1] 940 | pooled_output = self.pooler(sequence_output) 941 | if not output_all_encoded_layers: 942 | encoded_layers = encoded_layers[-1] 943 | return embedding_output, encoded_layers, pooled_output 944 | 945 | 946 | class BertForPreTraining(PreTrainedBertModel): 947 | """BERT model with pre-training heads. 948 | This module comprises the BERT model followed by the two pre-training heads: 949 | - the masked language modeling head, and 950 | - the next sentence classification head. 951 | Params: 952 | config: a BertConfig class instance with the configuration to build a new model. 
953 | Inputs: 954 | `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] 955 | with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts 956 | `extract_features.py`, `run_classifier.py`) 957 | `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 958 | types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to 959 | a `sentence B` token (see BERT paper for more details). 960 | `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 961 | selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 962 | input sequence length in the current batch. It's the mask that we typically use for attention when 963 | a batch has varying length sentences. 964 | `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 965 | with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 966 | is only computed for the labels set in [0, ..., vocab_size] 967 | `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] 968 | with indices selected in [0, 1]. 969 | 0 => next sentence is the continuation, 1 => next sentence is a random sentence. 970 | Outputs: 971 | if `masked_lm_labels` and `next_sentence_label` are not `None`: 972 | Outputs the total_loss which is the sum of the masked language modeling loss and the next 973 | sentence classification loss. 974 | if `masked_lm_labels` or `next_sentence_label` is `None`: 975 | Outputs a tuple comprising 976 | - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 977 | - the next sentence classification logits of shape [batch_size, 2]. 
978 | Example usage: 979 | ```python 980 | # Already been converted into WordPiece token ids 981 | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 982 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 983 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 984 | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 985 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 986 | model = BertForPreTraining(config) 987 | masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) 988 | ``` 989 | """ 990 | 991 | def __init__(self, config): 992 | super(BertForPreTraining, self).__init__(config) 993 | self.bert = BertModel(config) 994 | self.cls = BertPreTrainingHeads( 995 | config, self.bert.embeddings.word_embeddings.weight) 996 | self.apply(self.init_bert_weights) 997 | 998 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, 999 | next_sentence_label=None, mask_qkv=None, task_idx=None): 1000 | sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, 1001 | output_all_encoded_layers=False, mask_qkv=mask_qkv, 1002 | task_idx=task_idx) 1003 | prediction_scores, seq_relationship_score = self.cls( 1004 | sequence_output, pooled_output) 1005 | 1006 | if masked_lm_labels is not None and next_sentence_label is not None: 1007 | loss_fct = CrossEntropyLoss(ignore_index=-1) 1008 | masked_lm_loss = loss_fct( 1009 | prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 1010 | next_sentence_loss = loss_fct( 1011 | seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) 1012 | total_loss = masked_lm_loss + next_sentence_loss 1013 | return total_loss 1014 | else: 1015 | return prediction_scores, seq_relationship_score 1016 | 1017 | 1018 | class BertPreTrainingPairTransform(nn.Module): 1019 | def __init__(self, config): 1020 | super(BertPreTrainingPairTransform, self).__init__() 1021 | self.dense = nn.Linear(config.hidden_size * 2, config.hidden_size) 1022 | self.transform_act_fn = ACT2FN[config.hidden_act] \ 1023 | if isinstance(config.hidden_act, str) else config.hidden_act 1024 | # self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-5) 1025 | 1026 | def forward(self, pair_x, pair_y): 1027 | hidden_states = torch.cat([pair_x, pair_y], dim=-1) 1028 | hidden_states = self.dense(hidden_states) 1029 | hidden_states = self.transform_act_fn(hidden_states) 1030 | # hidden_states = self.LayerNorm(hidden_states) 1031 | return hidden_states 1032 | 1033 | 1034 | def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): 1035 | """ 1036 | Adapted from Mesh Tensorflow: 1037 | https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 1038 | """ 1039 | ret = 0 1040 | if bidirectional: 1041 | num_buckets //= 2 1042 | # mtf.to_int32(mtf.less(n, 0)) * num_buckets 1043 | ret += (relative_position > 0).long() * num_buckets 1044 | n = torch.abs(relative_position) 1045 | else: 1046 | n = torch.max(-relative_position, torch.zeros_like(relative_position)) 1047 | # now n is in the range [0, inf) 1048 | 1049 | # half of the buckets are for exact increments in positions 1050 | max_exact = num_buckets // 2 1051 | is_small = n < max_exact 1052 | 1053 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 1054 | val_if_large = max_exact + ( 1055 | 
torch.log(n.float() / max_exact) / math.log(max_distance / 1056 | max_exact) * (num_buckets - max_exact) 1057 | ).to(torch.long) 1058 | val_if_large = torch.min( 1059 | val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 1060 | 1061 | ret += torch.where(is_small, n, val_if_large) 1062 | return ret 1063 | 1064 | 1065 | class BertForSeq2SeqDecoder(PreTrainedBertModel): 1066 | """refer to BertForPreTraining""" 1067 | 1068 | def __init__(self, config, mask_word_id=0, num_labels=2, num_rel=0, 1069 | search_beam_size=1, length_penalty=1.0, eos_id=0, sos_id=0, 1070 | forbid_duplicate_ngrams=False, forbid_ignore_set=None, ngram_size=3, min_len=0, mode="s2s", 1071 | pos_shift=False): 1072 | super(BertForSeq2SeqDecoder, self).__init__(config) 1073 | self.bert = BertModelIncr(config) 1074 | self.cls = BertPreTrainingHeads( 1075 | config, self.bert.embeddings.word_embeddings.weight, num_labels=num_labels) 1076 | self.apply(self.init_bert_weights) 1077 | self.crit_mask_lm = nn.CrossEntropyLoss(reduction='none') 1078 | self.crit_next_sent = nn.CrossEntropyLoss(ignore_index=-1) 1079 | self.mask_word_id = mask_word_id 1080 | self.num_labels = num_labels 1081 | self.search_beam_size = search_beam_size 1082 | self.length_penalty = length_penalty 1083 | self.eos_id = eos_id 1084 | self.sos_id = sos_id 1085 | self.forbid_duplicate_ngrams = forbid_duplicate_ngrams 1086 | self.forbid_ignore_set = forbid_ignore_set 1087 | self.ngram_size = ngram_size 1088 | self.min_len = min_len 1089 | assert mode in ("s2s", "l2r") 1090 | self.mode = mode 1091 | self.pos_shift = pos_shift 1092 | 1093 | def forward(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None): 1094 | if self.search_beam_size > 1: 1095 | return self.beam_search(input_ids, token_type_ids, position_ids, attention_mask, 1096 | task_idx=task_idx, mask_qkv=mask_qkv) 1097 | 1098 | input_shape = list(input_ids.size()) 1099 | batch_size = input_shape[0] 1100 | input_length = input_shape[1] 1101 | output_shape = list(token_type_ids.size()) 1102 | output_length = output_shape[1] 1103 | 1104 | output_ids = [] 1105 | prev_embedding = None 1106 | prev_encoded_layers = None 1107 | curr_ids = input_ids 1108 | mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) 1109 | next_pos = input_length 1110 | if self.pos_shift: 1111 | sep_ids = input_ids.new(batch_size, 1).fill_(self.eos_id) 1112 | 1113 | if self.bert.rel_pos_bias is not None: 1114 | rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) 1115 | rel_pos = relative_position_bucket( 1116 | rel_pos_mat, num_buckets=self.config.rel_pos_bins, max_distance=self.config.max_rel_pos) 1117 | else: 1118 | rel_pos = None 1119 | 1120 | while next_pos < output_length: 1121 | curr_length = list(curr_ids.size())[1] 1122 | 1123 | if self.pos_shift: 1124 | if next_pos == input_length: 1125 | x_input_ids = torch.cat((curr_ids, sep_ids), dim=1) 1126 | start_pos = 0 1127 | else: 1128 | x_input_ids = curr_ids 1129 | start_pos = next_pos 1130 | else: 1131 | start_pos = next_pos - curr_length 1132 | x_input_ids = torch.cat((curr_ids, mask_ids), dim=1) 1133 | 1134 | curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1] 1135 | curr_attention_mask = attention_mask[:, 1136 | start_pos:next_pos + 1, :next_pos + 1] 1137 | curr_position_ids = position_ids[:, start_pos:next_pos + 1] 1138 | 1139 | if rel_pos is not None: 1140 | cur_rel_pos = rel_pos[:, start_pos:next_pos + 1, :next_pos + 1] 1141 | else: 1142 | cur_rel_pos = None 1143 | 1144 | new_embedding, 
new_encoded_layers, _ = \ 1145 | self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask, 1146 | output_all_encoded_layers=True, prev_embedding=prev_embedding, 1147 | prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv, rel_pos=cur_rel_pos) 1148 | 1149 | last_hidden = new_encoded_layers[-1][:, -1:, :] 1150 | prediction_scores, _ = self.cls( 1151 | last_hidden, None, task_idx=task_idx) 1152 | _, max_ids = torch.max(prediction_scores, dim=-1) 1153 | output_ids.append(max_ids) 1154 | 1155 | if self.pos_shift: 1156 | if prev_embedding is None: 1157 | prev_embedding = new_embedding 1158 | else: 1159 | prev_embedding = torch.cat( 1160 | (prev_embedding, new_embedding), dim=1) 1161 | if prev_encoded_layers is None: 1162 | prev_encoded_layers = [x for x in new_encoded_layers] 1163 | else: 1164 | prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip( 1165 | prev_encoded_layers, new_encoded_layers)] 1166 | else: 1167 | if prev_embedding is None: 1168 | prev_embedding = new_embedding[:, :-1, :] 1169 | else: 1170 | prev_embedding = torch.cat( 1171 | (prev_embedding, new_embedding[:, :-1, :]), dim=1) 1172 | if prev_encoded_layers is None: 1173 | prev_encoded_layers = [x[:, :-1, :] 1174 | for x in new_encoded_layers] 1175 | else: 1176 | prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1) 1177 | for x in zip(prev_encoded_layers, new_encoded_layers)] 1178 | curr_ids = max_ids 1179 | next_pos += 1 1180 | 1181 | return torch.cat(output_ids, dim=1) 1182 | 1183 | def beam_search(self, input_ids, token_type_ids, position_ids, attention_mask, task_idx=None, mask_qkv=None): 1184 | input_shape = list(input_ids.size()) 1185 | batch_size = input_shape[0] 1186 | input_length = input_shape[1] 1187 | output_shape = list(token_type_ids.size()) 1188 | output_length = output_shape[1] 1189 | 1190 | output_ids = [] 1191 | prev_embedding = None 1192 | prev_encoded_layers = None 1193 | curr_ids = input_ids 1194 | mask_ids = input_ids.new(batch_size, 1).fill_(self.mask_word_id) 1195 | next_pos = input_length 1196 | if self.pos_shift: 1197 | sep_ids = input_ids.new(batch_size, 1).fill_(self.eos_id) 1198 | 1199 | K = self.search_beam_size 1200 | 1201 | total_scores = [] 1202 | beam_masks = [] 1203 | step_ids = [] 1204 | step_back_ptrs = [] 1205 | partial_seqs = [] 1206 | forbid_word_mask = None 1207 | buf_matrix = None 1208 | 1209 | if self.bert.rel_pos_bias is not None: 1210 | rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) 1211 | rel_pos = relative_position_bucket( 1212 | rel_pos_mat, num_buckets=self.config.rel_pos_bins, max_distance=self.config.max_rel_pos) 1213 | else: 1214 | rel_pos = None 1215 | # print("Rel pos size = %s" % str(rel_pos.size())) 1216 | 1217 | while next_pos < output_length: 1218 | curr_length = list(curr_ids.size())[1] 1219 | 1220 | if self.pos_shift: 1221 | if next_pos == input_length: 1222 | x_input_ids = torch.cat((curr_ids, sep_ids), dim=1) 1223 | start_pos = 0 1224 | else: 1225 | x_input_ids = curr_ids 1226 | start_pos = next_pos 1227 | else: 1228 | start_pos = next_pos - curr_length 1229 | x_input_ids = torch.cat((curr_ids, mask_ids), dim=1) 1230 | 1231 | curr_token_type_ids = token_type_ids[:, start_pos:next_pos + 1] 1232 | curr_attention_mask = attention_mask[:, start_pos:next_pos + 1, :next_pos + 1] 1233 | curr_position_ids = position_ids[:, start_pos:next_pos + 1] 1234 | if rel_pos is not None: 1235 | cur_rel_pos = rel_pos[:, start_pos:next_pos + 1, :next_pos + 1] 1236 | else: 1237 | cur_rel_pos = None 1238 | 
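# Incremental decoding step (same pattern as in forward above): curr_ids holds the full source
# at the first iteration and only the newly selected token(s) afterwards; a [MASK] placeholder is
# appended when pos_shift is False, while prev_embedding / prev_encoded_layers supply the cached
# history states to BertModelIncr.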
new_embedding, new_encoded_layers, _ = \ 1239 | self.bert(x_input_ids, curr_token_type_ids, curr_position_ids, curr_attention_mask, 1240 | output_all_encoded_layers=True, prev_embedding=prev_embedding, 1241 | prev_encoded_layers=prev_encoded_layers, mask_qkv=mask_qkv, rel_pos=cur_rel_pos) 1242 | 1243 | last_hidden = new_encoded_layers[-1][:, -1:, :] 1244 | prediction_scores, _ = self.cls( 1245 | last_hidden, None, task_idx=task_idx) 1246 | log_scores = torch.nn.functional.log_softmax( 1247 | prediction_scores, dim=-1) 1248 | 1249 | if forbid_word_mask is not None: 1250 | log_scores += (forbid_word_mask * -10000.0) 1251 | if self.min_len and (next_pos - input_length + 1 <= self.min_len): 1252 | log_scores[:, :, self.eos_id].fill_(-10000.0) 1253 | kk_scores, kk_ids = torch.topk(log_scores, k=K) 1254 | if len(total_scores) == 0: 1255 | k_ids = torch.reshape(kk_ids, [batch_size, K]) 1256 | back_ptrs = torch.zeros(batch_size, K, dtype=torch.long) 1257 | k_scores = torch.reshape(kk_scores, [batch_size, K]) 1258 | else: 1259 | last_eos = torch.reshape( 1260 | beam_masks[-1], [batch_size * K, 1, 1]) 1261 | last_seq_scores = torch.reshape( 1262 | total_scores[-1], [batch_size * K, 1, 1]) 1263 | kk_scores += last_eos * (-10000.0) + last_seq_scores 1264 | kk_scores = torch.reshape(kk_scores, [batch_size, K * K]) 1265 | k_scores, k_ids = torch.topk(kk_scores, k=K) 1266 | back_ptrs = torch.div(k_ids, K) 1267 | kk_ids = torch.reshape(kk_ids, [batch_size, K * K]) 1268 | k_ids = torch.gather(kk_ids, 1, k_ids) 1269 | step_back_ptrs.append(back_ptrs) 1270 | step_ids.append(k_ids) 1271 | beam_masks.append(torch.eq(k_ids, self.eos_id).type_as(kk_scores)) 1272 | total_scores.append(k_scores) 1273 | 1274 | def first_expand(x): 1275 | input_shape = list(x.size()) 1276 | expanded_shape = input_shape[:1] + [1] + input_shape[1:] 1277 | x = torch.reshape(x, expanded_shape) 1278 | repeat_count = [1, K] + [1] * (len(input_shape) - 1) 1279 | x = x.repeat(*repeat_count) 1280 | x = torch.reshape(x, [input_shape[0] * K] + input_shape[1:]) 1281 | return x 1282 | 1283 | def select_beam_items(x, ids): 1284 | id_shape = list(ids.size()) 1285 | id_rank = len(id_shape) 1286 | assert len(id_shape) == 2 1287 | x_shape = list(x.size()) 1288 | x = torch.reshape(x, [batch_size, K] + x_shape[1:]) 1289 | x_rank = len(x_shape) + 1 1290 | assert x_rank >= 2 1291 | if id_rank < x_rank: 1292 | ids = torch.reshape( 1293 | ids, id_shape + [1] * (x_rank - id_rank)) 1294 | ids = ids.expand(id_shape + x_shape[1:]) 1295 | y = torch.gather(x, 1, ids) 1296 | y = torch.reshape(y, x_shape) 1297 | return y 1298 | 1299 | is_first = (prev_embedding is None) 1300 | 1301 | if self.pos_shift: 1302 | if prev_embedding is None: 1303 | prev_embedding = first_expand(new_embedding) 1304 | else: 1305 | prev_embedding = torch.cat( 1306 | (prev_embedding, new_embedding), dim=1) 1307 | prev_embedding = select_beam_items( 1308 | prev_embedding, back_ptrs) 1309 | if prev_encoded_layers is None: 1310 | prev_encoded_layers = [first_expand( 1311 | x) for x in new_encoded_layers] 1312 | else: 1313 | prev_encoded_layers = [torch.cat((x[0], x[1]), dim=1) for x in zip( 1314 | prev_encoded_layers, new_encoded_layers)] 1315 | prev_encoded_layers = [select_beam_items( 1316 | x, back_ptrs) for x in prev_encoded_layers] 1317 | else: 1318 | if prev_embedding is None: 1319 | prev_embedding = first_expand(new_embedding[:, :-1, :]) 1320 | else: 1321 | prev_embedding = torch.cat( 1322 | (prev_embedding, new_embedding[:, :-1, :]), dim=1) 1323 | prev_embedding = select_beam_items( 
1324 | prev_embedding, back_ptrs) 1325 | if prev_encoded_layers is None: 1326 | prev_encoded_layers = [first_expand( 1327 | x[:, :-1, :]) for x in new_encoded_layers] 1328 | else: 1329 | prev_encoded_layers = [torch.cat((x[0], x[1][:, :-1, :]), dim=1) 1330 | for x in zip(prev_encoded_layers, new_encoded_layers)] 1331 | prev_encoded_layers = [select_beam_items( 1332 | x, back_ptrs) for x in prev_encoded_layers] 1333 | 1334 | curr_ids = torch.reshape(k_ids, [batch_size * K, 1]) 1335 | 1336 | if is_first: 1337 | token_type_ids = first_expand(token_type_ids) 1338 | position_ids = first_expand(position_ids) 1339 | attention_mask = first_expand(attention_mask) 1340 | if rel_pos is not None: 1341 | rel_pos = first_expand(rel_pos) 1342 | mask_ids = first_expand(mask_ids) 1343 | if mask_qkv is not None: 1344 | mask_qkv = first_expand(mask_qkv) 1345 | 1346 | if self.forbid_duplicate_ngrams: 1347 | wids = step_ids[-1].tolist() 1348 | ptrs = step_back_ptrs[-1].tolist() 1349 | if is_first: 1350 | partial_seqs = [] 1351 | for b in range(batch_size): 1352 | for k in range(K): 1353 | partial_seqs.append([wids[b][k]]) 1354 | else: 1355 | new_partial_seqs = [] 1356 | for b in range(batch_size): 1357 | for k in range(K): 1358 | new_partial_seqs.append( 1359 | partial_seqs[ptrs[b][k] + b * K] + [wids[b][k]]) 1360 | partial_seqs = new_partial_seqs 1361 | 1362 | def get_dup_ngram_candidates(seq, n): 1363 | cands = set() 1364 | if len(seq) < n: 1365 | return [] 1366 | tail = seq[-(n - 1):] 1367 | if self.forbid_ignore_set and any(tk in self.forbid_ignore_set for tk in tail): 1368 | return [] 1369 | for i in range(len(seq) - (n - 1)): 1370 | mismatch = False 1371 | for j in range(n - 1): 1372 | if tail[j] != seq[i + j]: 1373 | mismatch = True 1374 | break 1375 | if (not mismatch) and not ( 1376 | self.forbid_ignore_set and (seq[i + n - 1] in self.forbid_ignore_set)): 1377 | cands.add(seq[i + n - 1]) 1378 | return list(sorted(cands)) 1379 | 1380 | if len(partial_seqs[0]) >= self.ngram_size: 1381 | dup_cands = [] 1382 | for seq in partial_seqs: 1383 | dup_cands.append( 1384 | get_dup_ngram_candidates(seq, self.ngram_size)) 1385 | if max(len(x) for x in dup_cands) > 0: 1386 | if buf_matrix is None: 1387 | vocab_size = list(log_scores.size())[-1] 1388 | buf_matrix = np.zeros( 1389 | (batch_size * K, vocab_size), dtype=float) 1390 | else: 1391 | buf_matrix.fill(0) 1392 | for bk, cands in enumerate(dup_cands): 1393 | for i, wid in enumerate(cands): 1394 | buf_matrix[bk, wid] = 1.0 1395 | forbid_word_mask = torch.tensor( 1396 | buf_matrix, dtype=log_scores.dtype) 1397 | forbid_word_mask = torch.reshape( 1398 | forbid_word_mask, [batch_size * K, 1, vocab_size]).cuda() 1399 | else: 1400 | forbid_word_mask = None 1401 | next_pos += 1 1402 | 1403 | # [(batch, beam)] 1404 | total_scores = [x.tolist() for x in total_scores] 1405 | step_ids = [x.tolist() for x in step_ids] 1406 | step_back_ptrs = [x.tolist() for x in step_back_ptrs] 1407 | # back tracking 1408 | traces = {'pred_seq': [], 'scores': [], 'wids': [], 'ptrs': []} 1409 | for b in range(batch_size): 1410 | # [(beam,)] 1411 | scores = [x[b] for x in total_scores] 1412 | wids_list = [x[b] for x in step_ids] 1413 | ptrs = [x[b] for x in step_back_ptrs] 1414 | traces['scores'].append(scores) 1415 | traces['wids'].append(wids_list) 1416 | traces['ptrs'].append(ptrs) 1417 | # first we need to find the eos frame where all symbols are eos 1418 | # any frames after the eos frame are invalid 1419 | last_frame_id = len(scores) - 1 1420 | for i, wids in enumerate(wids_list): 
1421 | if all(wid == self.eos_id for wid in wids): 1422 | last_frame_id = i 1423 | break 1424 | max_score = -math.inf 1425 | frame_id = -1 1426 | pos_in_frame = -1 1427 | 1428 | for fid in range(last_frame_id + 1): 1429 | for i, wid in enumerate(wids_list[fid]): 1430 | if wid == self.eos_id or fid == last_frame_id: 1431 | s = scores[fid][i] 1432 | if self.length_penalty > 0: 1433 | s /= math.pow((5 + fid + 1) / 6.0, 1434 | self.length_penalty) 1435 | if s > max_score: 1436 | max_score = s 1437 | frame_id = fid 1438 | pos_in_frame = i 1439 | if frame_id == -1: 1440 | traces['pred_seq'].append([0]) 1441 | else: 1442 | seq = [wids_list[frame_id][pos_in_frame]] 1443 | for fid in range(frame_id, 0, -1): 1444 | pos_in_frame = ptrs[fid][pos_in_frame] 1445 | seq.append(wids_list[fid - 1][pos_in_frame]) 1446 | seq.reverse() 1447 | traces['pred_seq'].append(seq) 1448 | 1449 | def _pad_sequence(sequences, max_len, padding_value=0): 1450 | trailing_dims = sequences[0].size()[1:] 1451 | out_dims = (len(sequences), max_len) + trailing_dims 1452 | 1453 | out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) 1454 | for i, tensor in enumerate(sequences): 1455 | length = tensor.size(0) 1456 | # use index notation to prevent duplicate references to the tensor 1457 | out_tensor[i, :length, ...] = tensor 1458 | return out_tensor 1459 | 1460 | # convert to tensors for DataParallel 1461 | for k in ('pred_seq', 'scores', 'wids', 'ptrs'): 1462 | ts_list = traces[k] 1463 | if not isinstance(ts_list[0], torch.Tensor): 1464 | dt = torch.float if k == 'scores' else torch.long 1465 | ts_list = [torch.tensor(it, dtype=dt) for it in ts_list] 1466 | traces[k] = _pad_sequence( 1467 | ts_list, output_length, padding_value=0).to(input_ids.device) 1468 | 1469 | return traces 1470 | -------------------------------------------------------------------------------- /src/models/tnlrv3/s2s_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from random import randint 4 | import logging 5 | import torch 6 | import torch.utils.data 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def get_random_word(vocab_words): 13 | i = randint(0, len(vocab_words)-1) 14 | return vocab_words[i] 15 | 16 | 17 | def batch_list_to_batch_tensors(batch): 18 | batch_tensors = [] 19 | for x in zip(*batch): 20 | if x[0] is None: 21 | batch_tensors.append(None) 22 | elif isinstance(x[0], torch.Tensor): 23 | batch_tensors.append(torch.stack(x)) 24 | else: 25 | batch_tensors.append(torch.tensor(x, dtype=torch.long)) 26 | return batch_tensors 27 | 28 | 29 | def _get_word_split_index(tokens, st, end): 30 | split_idx = [] 31 | i = st 32 | while i < end: 33 | if (not tokens[i].startswith('##')) or (i == st): 34 | split_idx.append(i) 35 | i += 1 36 | split_idx.append(end) 37 | return split_idx 38 | 39 | 40 | def _expand_whole_word(tokens, st, end): 41 | new_st, new_end = st, end 42 | while (new_st >= 0) and tokens[new_st].startswith('##'): 43 | new_st -= 1 44 | while (new_end < len(tokens)) and tokens[new_end].startswith('##'): 45 | new_end += 1 46 | return new_st, new_end 47 | 48 | 49 | class Pipeline(): 50 | """ Pre-process Pipeline Class : callable """ 51 | 52 | def __init__(self): 53 | super().__init__() 54 | self.skipgram_prb = None 55 | self.skipgram_size = None 56 | self.pre_whole_word = None 57 | self.mask_whole_word = None 58 | self.word_subsample_prb = None 59 | self.sp_prob = None 60 | self.pieces_dir = None 61 | self.vocab_words = None 62 | 
self.pieces_threshold = 10 63 | self.call_count = 0 64 | self.offline_mode = False 65 | self.skipgram_size_geo_list = None 66 | self.span_same_mask = False 67 | 68 | def __call__(self, instance): 69 | raise NotImplementedError 70 | 71 | 72 | class Preprocess4Seq2seqDecoder(Pipeline): 73 | """ Pre-processing steps for pretraining transformer """ 74 | 75 | def __init__(self, vocab_words, indexer, max_len=512, max_tgt_length=128, 76 | mode="s2s", pos_shift=False, source_type_id=0, target_type_id=1, 77 | cls_token='[CLS]', sep_token='[SEP]', pad_token='[PAD]'): 78 | super().__init__() 79 | self.max_len = max_len 80 | self.vocab_words = vocab_words # vocabulary (sub)words 81 | self.indexer = indexer # function from token to token index 82 | self.max_len = max_len 83 | self._tril_matrix = torch.tril(torch.ones((max_len, max_len), dtype=torch.long)) 84 | self.task_idx = 3 # relax projection layer for different tasks 85 | assert mode in ("s2s", "l2r") 86 | self.mode = mode 87 | self.max_tgt_length = max_tgt_length 88 | self.pos_shift = pos_shift 89 | 90 | self.delta = 1 if pos_shift else 2 91 | 92 | self.cls_token = cls_token 93 | self.sep_token = sep_token 94 | self.pad_token = pad_token 95 | 96 | self.source_type_id = source_type_id 97 | self.target_type_id = target_type_id 98 | 99 | self.cc = 0 100 | 101 | def __call__(self, instance): 102 | tokens_a, max_a_len = instance 103 | 104 | padded_tokens_a = [self.cls_token] + tokens_a 105 | if not self.pos_shift: 106 | padded_tokens_a = padded_tokens_a + [self.sep_token] 107 | assert len(padded_tokens_a) <= max_a_len + self.delta 108 | if max_a_len + self.delta > len(padded_tokens_a): 109 | padded_tokens_a += [self.pad_token] * \ 110 | (max_a_len + self.delta - len(padded_tokens_a)) 111 | assert len(padded_tokens_a) == max_a_len + self.delta 112 | max_len_in_batch = min(self.max_tgt_length + 113 | max_a_len + self.delta, self.max_len) 114 | tokens = padded_tokens_a 115 | segment_ids = [self.source_type_id] * (len(padded_tokens_a)) \ 116 | + [self.target_type_id] * (max_len_in_batch - len(padded_tokens_a)) 117 | 118 | mask_qkv = None 119 | 120 | position_ids = [] 121 | for i in range(len(tokens_a) + self.delta): 122 | position_ids.append(i) 123 | for i in range(len(tokens_a) + self.delta, max_a_len + self.delta): 124 | position_ids.append(0) 125 | for i in range(max_a_len + self.delta, max_len_in_batch): 126 | position_ids.append(i - (max_a_len + self.delta) + len(tokens_a) + self.delta) 127 | 128 | # Token Indexing 129 | input_ids = self.indexer(tokens) 130 | 131 | self.cc += 1 132 | if self.cc < 20: 133 | # print("Vocab size = %d" % len(self.vocab_words)) 134 | # for tk_id in input_ids: 135 | # print(u"trans %d -> %s" % (tk_id, self.vocab_words[tk_id])) 136 | logger.info(u"Input src = %s" % " ".join((self.vocab_words[tk_id]) for tk_id in input_ids)) 137 | 138 | # Zero Padding 139 | input_mask = torch.zeros( 140 | max_len_in_batch, max_len_in_batch, dtype=torch.long) 141 | if self.mode == "s2s": 142 | input_mask[:, :len(tokens_a) + self.delta].fill_(1) 143 | else: 144 | st, end = 0, len(tokens_a) + self.delta 145 | input_mask[st:end, st:end].copy_( 146 | self._tril_matrix[:end, :end]) 147 | input_mask[end:, :len(tokens_a) + self.delta].fill_(1) 148 | second_st, second_end = len(padded_tokens_a), max_len_in_batch 149 | 150 | input_mask[second_st:second_end, second_st:second_end].copy_( 151 | self._tril_matrix[:second_end-second_st, :second_end-second_st]) 152 | 153 | return (input_ids, segment_ids, position_ids, input_mask, mask_qkv, self.task_idx) 
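# Usage note (a sketch; the calling code is not shown in this file): each tuple returned by
# __call__ is typically collated with batch_list_to_batch_tensors above, and the resulting batch
# tensors are fed to BertForSeq2SeqDecoder together with the mask_qkv and task_idx entries.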
154 | -------------------------------------------------------------------------------- /src/models/tnlrv3/tokenization_tnlrv3.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """Tokenization classes for TuringNLRv3.""" 3 | 4 | from __future__ import absolute_import, division, print_function, unicode_literals 5 | 6 | import collections 7 | import logging 8 | import os 9 | import unicodedata 10 | from io import open 11 | 12 | from transformers.tokenization_bert import BertTokenizer, whitespace_tokenize 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 17 | 18 | PRETRAINED_VOCAB_FILES_MAP = { 19 | 'vocab_file': 20 | { 21 | } 22 | } 23 | 24 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 25 | } 26 | 27 | 28 | class TuringNLRv3Tokenizer(BertTokenizer): 29 | r""" 30 | Constructs a TuringNLRv3Tokenizer. 31 | :class:`~transformers.TuringNLRv3Tokenizer` is identical to BertTokenizer and runs end-to-end tokenization: punctuation splitting + wordpiece 32 | Args: 33 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 34 | do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False 35 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 36 | max_len: An artificial maximum length to truncate tokenized sequences to; Effective maximum length is always the 37 | minimum of this value (if specified) and the underlying model's sequence length. 38 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 39 | do_wordpiece_only=False 40 | """ 41 | 42 | vocab_files_names = VOCAB_FILES_NAMES 43 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 44 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 45 | 46 | 47 | class WhitespaceTokenizer(object): 48 | def tokenize(self, text): 49 | return whitespace_tokenize(text) 50 | -------------------------------------------------------------------------------- /src/models/tnlrv3/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import logging 4 | import os 5 | import json 6 | import random 7 | import glob 8 | import torch 9 | import tqdm 10 | import array 11 | import collections 12 | import torch.utils.data 13 | from transformers.file_utils import WEIGHTS_NAME 14 | try: 15 | import lmdb 16 | except: 17 | pass 18 | 19 | OPTIM_NAME = "optimizer.bin" 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class TrainingExample(object): 26 | def __init__(self, source_ids, target_ids, example_id): 27 | self.source_ids = source_ids 28 | self.target_ids = target_ids 29 | self.example_id = example_id 30 | 31 | 32 | class Seq2seqDatasetForTuringNLRv3(torch.utils.data.Dataset): 33 | def __init__( 34 | self, features, max_source_len, max_target_len, 35 | vocab_size, cls_id, sep_id, pad_id, mask_id, 36 | random_prob, keep_prob, offset, num_training_instances, 37 | finetuning_method='v1', target_mask_prob=-1.0, num_max_mask_token=0, 38 | source_mask_prob=-1.0, 39 | ): 40 | self.features = features 41 | self.max_source_len = max_source_len 42 | self.max_target_len = max_target_len 43 | self.offset = offset 44 | if offset > 0: 45 | logger.info(" **** Set offset %d in Seq2seqDatasetForBert **** ", offset) 46 | self.cls_id = cls_id 47 | self.sep_id = sep_id 48 | self.pad_id = pad_id 49 | self.random_prob = random_prob 50 | 
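# keep_prob and random_prob drive get_masked_token below: with probability keep_prob the original
# token id is kept, with probability random_prob it is replaced by a random vocabulary id, and
# otherwise it becomes the [MASK] id (the standard BERT-style corruption split).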
self.keep_prob = keep_prob 51 | self.mask_id = mask_id 52 | self.vocab_size = vocab_size 53 | self.num_training_instances = num_training_instances 54 | self.target_mask_prob = target_mask_prob 55 | if finetuning_method == 'v0': 56 | num_max_mask_token = self.max_target_len 57 | logger.info("Mask way v0: set num_max_mask_token = %d" % num_max_mask_token) 58 | self.num_max_mask_token = num_max_mask_token 59 | self.finetuning_method = finetuning_method 60 | assert finetuning_method in ('v0', 'v1', 'v2') 61 | self.source_mask_prob = source_mask_prob 62 | 63 | def __len__(self): 64 | return self.num_training_instances 65 | 66 | def __trunk(self, ids, max_len, append_sep=True): 67 | if append_sep: 68 | max_len -= 1 69 | if len(ids) > max_len: 70 | ids = ids[:max_len] 71 | if append_sep: 72 | ids = ids + [self.sep_id] 73 | return ids 74 | 75 | def __pad(self, ids, max_len): 76 | if len(ids) < max_len: 77 | return ids + [self.pad_id] * (max_len - len(ids)) 78 | else: 79 | assert len(ids) == max_len 80 | return ids 81 | 82 | def get_masked_token(self, tk_id): 83 | p = random.random() 84 | if p < self.keep_prob: 85 | return tk_id 86 | elif p < self.keep_prob + self.random_prob: 87 | return random.randint(0, self.vocab_size - 1) 88 | else: 89 | return self.mask_id 90 | 91 | def __getitem__(self, _idx): 92 | idx = (self.offset + _idx) % len(self.features) 93 | # print("%d get %d" % (_idx, idx)) 94 | feature = self.features[idx] 95 | source_ids = self.__trunk([self.cls_id] + feature.source_ids, self.max_source_len, append_sep=self.finetuning_method != 'v0') 96 | target_ids = feature.target_ids 97 | if self.finetuning_method == 'v0': 98 | target_ids = [self.sep_id] + target_ids 99 | target_ids = self.__trunk(target_ids, self.max_target_len, append_sep=self.finetuning_method != 'v0') 100 | 101 | num_source_tokens = len(source_ids) 102 | num_target_tokens = len(target_ids) 103 | 104 | if self.source_mask_prob > 0: 105 | for i in range(num_source_tokens): 106 | tk_id = source_ids[i] 107 | if tk_id != self.cls_id and tk_id != self.sep_id: 108 | r = random.random() 109 | if r < self.source_mask_prob: 110 | source_ids[i] = self.get_masked_token(tk_id) 111 | 112 | source_ids = self.__pad(source_ids, self.max_source_len) 113 | target_ids = self.__pad(target_ids, self.max_target_len) 114 | 115 | if self.finetuning_method == 'v0': 116 | masked_pos = [] 117 | masked_ids = [] 118 | masked_weights = [] 119 | for pos in range(num_target_tokens): 120 | if pos + 1 != num_target_tokens: 121 | masked_ids.append(target_ids[pos + 1]) 122 | else: 123 | masked_ids.append(self.sep_id) 124 | masked_pos.append(pos) 125 | masked_weights.append(1) 126 | 127 | r = random.random() 128 | if r < self.target_mask_prob and pos > 0: 129 | target_ids[pos] = self.get_masked_token(target_ids[pos]) 130 | 131 | masked_ids = self.__pad(masked_ids, self.num_max_mask_token) 132 | masked_pos = self.__pad(masked_pos, self.num_max_mask_token) 133 | masked_weights = self.__pad(masked_weights, self.num_max_mask_token) 134 | 135 | return source_ids, target_ids, masked_ids, masked_pos, masked_weights, num_source_tokens, num_target_tokens 136 | elif self.finetuning_method == 'v1': 137 | masked_pos = list(range(num_target_tokens)) 138 | random.shuffle(masked_pos) 139 | 140 | num_masked_token = \ 141 | min(self.num_max_mask_token, int(self.target_mask_prob * num_target_tokens)) 142 | if num_masked_token <= 0: 143 | num_masked_token = 1 144 | 145 | masked_pos = masked_pos[:num_masked_token] 146 | 147 | masked_ids = [] 148 | masked_weights = [] 149 | for 
pos in masked_pos: 150 | masked_ids.append(target_ids[pos]) 151 | target_ids[pos] = self.get_masked_token(target_ids[pos]) 152 | masked_weights.append(1) 153 | 154 | masked_ids = self.__pad(masked_ids, self.num_max_mask_token) 155 | masked_pos = self.__pad(masked_pos, self.num_max_mask_token) 156 | masked_weights = self.__pad(masked_weights, self.num_max_mask_token) 157 | 158 | return source_ids, target_ids, masked_ids, masked_pos, masked_weights, num_source_tokens, num_target_tokens 159 | elif self.finetuning_method == 'v2': 160 | pseudo_ids = [] 161 | label_ids = [] 162 | for pos in range(num_target_tokens): 163 | tk_id = target_ids[pos] 164 | masked_tk_id = self.get_masked_token(tk_id) 165 | pseudo_ids.append(masked_tk_id) 166 | label_ids.append(tk_id) 167 | r = random.random() 168 | if r < self.target_mask_prob: 169 | target_ids[pos] = masked_tk_id 170 | label_ids = self.__pad(label_ids, self.max_target_len) 171 | pseudo_ids = self.__pad(pseudo_ids, self.max_target_len) 172 | 173 | return source_ids, target_ids, label_ids, pseudo_ids, num_source_tokens, num_target_tokens 174 | 175 | 176 | def batch_list_to_batch_tensors(batch): 177 | batch_tensors = [] 178 | for x in zip(*batch): 179 | if isinstance(x[0], torch.Tensor): 180 | batch_tensors.append(torch.stack(x)) 181 | else: 182 | batch_tensors.append(torch.tensor(x, dtype=torch.long)) 183 | return batch_tensors 184 | 185 | 186 | def get_max_epoch_model(output_dir): 187 | fn_model_list = glob.glob(os.path.join(output_dir, "ckpt-*/%s" % WEIGHTS_NAME)) 188 | fn_optim_list = glob.glob(os.path.join(output_dir, "ckpt-*/%s" % OPTIM_NAME)) 189 | if (not fn_model_list) or (not fn_optim_list): 190 | return None 191 | both_set = set([int(os.path.dirname(fn).split('-')[-1]) for fn in fn_model_list] 192 | ) & set([int(os.path.dirname(fn).split('-')[-1]) for fn in fn_optim_list]) 193 | if both_set: 194 | return max(both_set) 195 | else: 196 | return None 197 | 198 | 199 | def get_checkpoint_state_dict(output_dir, ckpt): 200 | model_recover_checkpoint = os.path.join(output_dir, "ckpt-%d" % ckpt, WEIGHTS_NAME) 201 | logger.info(" ** Recover model checkpoint in %s ** ", model_recover_checkpoint) 202 | model_state_dict = torch.load(model_recover_checkpoint, map_location='cpu') 203 | optimizer_recover_checkpoint = os.path.join(output_dir, "ckpt-%d" % ckpt, OPTIM_NAME) 204 | checkpoint_state_dict = torch.load(optimizer_recover_checkpoint, map_location='cpu') 205 | checkpoint_state_dict['model'] = model_state_dict 206 | return checkpoint_state_dict 207 | 208 | 209 | def report_length(length_counter, total_count): 210 | max_len = max(length_counter.keys()) 211 | a = 0 212 | tc = 0 213 | while a < max_len: 214 | cc = 0 215 | for i in range(16): 216 | cc += length_counter[a + i] 217 | 218 | tc += cc 219 | if cc > 0: 220 | logger.info("%d ~ %d = %d, %.2f%%" % (a, a + 16, cc, (tc * 100.0) / total_count)) 221 | a += 16 222 | 223 | 224 | def serialize_str(x): 225 | return u"{}".format(x).encode('ascii') 226 | 227 | 228 | def serialize_array(x, dtype): 229 | data = array.array(dtype) 230 | data.fromlist(x) 231 | return data.tobytes() 232 | 233 | def write_to_lmdb(db, key, value): 234 | success = False 235 | while not success: 236 | txn = db.begin(write=True) 237 | try: 238 | txn.put(key, value) 239 | txn.commit() 240 | success = True 241 | except lmdb.MapFullError: 242 | txn.abort() 243 | # double the map_size 244 | curr_limit = db.info()['map_size'] 245 | new_limit = curr_limit*2 246 | print('>>> Doubling LMDB map size to %sMB ...' 
% 247 | (new_limit >> 20,)) 248 | db.set_mapsize(new_limit) # double it 249 | 250 | 251 | def deserialize_str(x): 252 | return x.decode('ascii') 253 | 254 | 255 | class DocDB(object): 256 | def __init__(self, db_path): 257 | self.db_path = db_path 258 | self.env = lmdb.open(db_path, readonly=True, lock=False, readahead=False, meminit=False) 259 | with self.env.begin(write=False) as txn: 260 | self.start_key_index = int(deserialize_str(txn.get(b'__start__'))) 261 | self.size = int(deserialize_str(txn.get(b'__size__'))) 262 | self.dtype = deserialize_str(txn.get(b'__dtype__')) 263 | 264 | def _deserialize_array(self, x): 265 | data = array.array(self.dtype) 266 | data.frombytes(x) 267 | return data.tolist() 268 | 269 | def __getitem__(self, doc_id): 270 | with self.env.begin(write=False) as txn: 271 | # example = { 272 | # "source_ids": self._deserialize_array(txn.get(b"src_ids_%d" % doc_id)), 273 | # "target_ids": self._deserialize_array(txn.get(b"tgt_ids_%d" % doc_id)), 274 | # } 275 | example = TrainingExample( 276 | source_ids=self._deserialize_array(txn.get(b"src_ids_%d" % doc_id)), 277 | target_ids=self._deserialize_array(txn.get(b"tgt_ids_%d" % doc_id)), 278 | example_id=None, 279 | ) 280 | return example 281 | 282 | def __len__(self): 283 | return self.size 284 | 285 | 286 | def load_and_cache_examples( 287 | example_file, tokenizer, local_rank, cached_features_file, shuffle=True, 288 | lmdb_cache=None, lmdb_dtype='h', eval_mode=False): 289 | # Make sure only the first process in distributed training process the dataset, and the others will use the cache 290 | if local_rank not in [-1, 0]: 291 | torch.distributed.barrier() 292 | 293 | if cached_features_file is not None and os.path.isfile(cached_features_file): 294 | logger.info("Loading features from cached file %s", cached_features_file) 295 | features = torch.load(cached_features_file) 296 | elif cached_features_file is not None and os.path.isdir(cached_features_file) \ 297 | and os.path.exists(os.path.join(cached_features_file, 'lock.mdb')): 298 | logger.info("Loading features from cached LMDB %s", cached_features_file) 299 | features = DocDB(cached_features_file) 300 | else: 301 | logger.info("Creating features from dataset file at %s", example_file) 302 | 303 | examples = [] 304 | with open(example_file, mode="r", encoding="utf-8") as reader: 305 | for line in reader: 306 | examples.append(json.loads(line)) 307 | features = [] 308 | 309 | slc = collections.defaultdict(int) 310 | tlc = collections.defaultdict(int) 311 | 312 | for example in tqdm.tqdm(examples): 313 | if isinstance(example["src"], list): 314 | source_tokens = example["src"] 315 | target_tokens = [] if eval_mode else example["tgt"] 316 | else: 317 | source_tokens = tokenizer.tokenize(example["src"]) 318 | target_tokens = [] if eval_mode else tokenizer.tokenize(example["tgt"]) 319 | source_ids = tokenizer.convert_tokens_to_ids(source_tokens) 320 | target_ids = tokenizer.convert_tokens_to_ids(target_tokens) 321 | 322 | slc[len(source_ids)] += 1 323 | tlc[len(target_ids)] += 1 324 | 325 | features.append( 326 | TrainingExample( 327 | source_ids=source_ids, 328 | target_ids=target_ids, 329 | example_id=len(features), 330 | ) 331 | ) 332 | 333 | if shuffle: 334 | random.shuffle(features) 335 | logger.info("Shuffle the features !") 336 | 337 | logger.info("Source length:") 338 | report_length(slc, total_count=len(examples)) 339 | logger.info("Target length:") 340 | report_length(tlc, total_count=len(examples)) 341 | 342 | if local_rank in [-1, 0] and 
cached_features_file is not None: 343 | if lmdb_cache: 344 | db = lmdb.open(cached_features_file, readonly=False, map_async=True) 345 | for idx, feature in enumerate(features): 346 | write_to_lmdb( 347 | db, b"src_ids_%d" % idx, 348 | serialize_array(feature.source_ids, dtype=lmdb_dtype)) 349 | write_to_lmdb( 350 | db, b"tgt_ids_%d" % idx, 351 | serialize_array(feature.target_ids, dtype=lmdb_dtype)) 352 | write_to_lmdb(db, b"__start__", serialize_str(0)) 353 | write_to_lmdb(db, b"__size__", serialize_str(len(features))) 354 | write_to_lmdb(db, b"__dtype__", serialize_str(lmdb_dtype)) 355 | db.sync() 356 | db.close() 357 | logger.info("db_key_idx = %d" % len(features)) 358 | del features 359 | features = cached_features_file 360 | logger.info("Saving features into cached lmdb dir %s", cached_features_file) 361 | else: 362 | logger.info("Saving features into cached file %s", cached_features_file) 363 | torch.save(features, cached_features_file) 364 | 365 | # Make sure only the first process in distributed training process the dataset, and the others will use the cache 366 | if local_rank == 0: 367 | torch.distributed.barrier() 368 | 369 | return features 370 | -------------------------------------------------------------------------------- /src/parameters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import src.utils as utils 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--mode", type=str, default="train", choices=['train', 'test']) 10 | parser.add_argument("--train_data_path", type=str, default="./data/dblp_graph_data/train.tsv", ) 11 | parser.add_argument("--train_batch_size", type=int, default=30) 12 | parser.add_argument("--valid_data_path", type=str, default="./data/dblp_graph_data/valid.tsv") 13 | parser.add_argument("--valid_batch_size", type=int, default=300) 14 | parser.add_argument("--test_data_path", type=str, default="./data/dblp_graph_data/test.tsv") 15 | parser.add_argument("--test_batch_size", type=int, default=300) 16 | 17 | parser.add_argument("--model_dir", type=str, default='./ckpt') # path to save 18 | parser.add_argument("--enable_gpu", type=utils.str2bool, default=True) 19 | 20 | parser.add_argument("--savename", type=str, default='GraphFormers') 21 | parser.add_argument("--world_size", type=int, default=8) 22 | parser.add_argument("--token_length", type=int, default=32) 23 | parser.add_argument("--neighbor_num", type=int, default=5) 24 | 25 | # model training 26 | parser.add_argument("--epochs", type=int, default=100) 27 | parser.add_argument("--log_steps", type=int, default=1000) 28 | parser.add_argument("--mlm", type=utils.str2bool, default=False) 29 | parser.add_argument("--random_seed", type=int, default=42) 30 | 31 | # turing 32 | parser.add_argument("--model_type", default="GraphFormers", type=str) 33 | parser.add_argument("--model_name_or_path", default="./TuringModels/graphformers-dblp.pt", type=str, 34 | help="Path to pre-trained model or shortcut name. 
") 35 | parser.add_argument("--config_name", default="./TuringModels/unilm2-base-uncased-config.json", type=str, 36 | help="Pretrained config name or path if not the same as model_name") 37 | 38 | parser.add_argument( 39 | "--load_ckpt_name", 40 | type=str, 41 | default='ckpt/GraphFormers/GraphFormers-epoch-1.pt', 42 | help="choose which ckpt to load and test" 43 | ) 44 | 45 | parser.add_argument("--lr", type=float, default=1e-5) 46 | 47 | # half float 48 | parser.add_argument("--fp16", type=utils.str2bool, default=True) 49 | 50 | args = parser.parse_args() 51 | logging.info(args) 52 | return args 53 | 54 | 55 | if __name__ == "__main__": 56 | args = parse_args() 57 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import time 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | import torch.optim as optim 11 | from torch.nn.parallel import DistributedDataParallel as DDP 12 | 13 | from src.data_handler import DatasetForMatching, DataCollatorForMatching, SingleProcessDataLoader, \ 14 | MultiProcessDataLoader 15 | from src.models.tnlrv3.configuration_tnlrv3 import TuringNLRv3Config 16 | 17 | 18 | def setup(rank, args): 19 | # initialize the process group 20 | dist.init_process_group("nccl", rank=rank, world_size=args.world_size) 21 | torch.cuda.set_device(rank) 22 | # Explicitly setting seed 23 | torch.manual_seed(args.random_seed) 24 | np.random.seed(args.random_seed) 25 | random.seed(args.random_seed) 26 | 27 | 28 | def cleanup(): 29 | dist.destroy_process_group() 30 | 31 | 32 | def load_bert(args): 33 | config = TuringNLRv3Config.from_pretrained( 34 | args.config_name if args.config_name else args.model_name_or_path, 35 | output_hidden_states=True) 36 | if args.model_type == "GraphFormers": 37 | from src.models.modeling_graphformers import GraphFormersForNeighborPredict 38 | model = GraphFormersForNeighborPredict(config) 39 | model.load_state_dict(torch.load(args.model_name_or_path, map_location="cpu")['model_state_dict'], strict=False) 40 | # model = GraphFormersForNeighborPredict.from_pretrained(args.model_name_or_path, config=config) 41 | elif args.model_type == "GraphSageMax": 42 | from src.models.modeling_graphsage import GraphSageMaxForNeighborPredict 43 | model = GraphSageMaxForNeighborPredict.from_pretrained(args.model_name_or_path, config=config) 44 | return model 45 | 46 | 47 | def train(local_rank, args, end, load): 48 | try: 49 | if local_rank == 0: 50 | from src.utils import setuplogging 51 | setuplogging() 52 | os.environ["RANK"] = str(local_rank) 53 | setup(local_rank, args) 54 | if args.fp16: 55 | from torch.cuda.amp import autocast 56 | scaler = torch.cuda.amp.GradScaler() 57 | 58 | model = load_bert(args) 59 | logging.info('loading model: {}'.format(args.model_type)) 60 | model = model.cuda() 61 | 62 | if load: 63 | model.load_state_dict(torch.load(args.load_ckpt_name, map_location="cpu")) 64 | logging.info('load ckpt:{}'.format(args.load_ckpt_name)) 65 | 66 | if args.world_size > 1: 67 | ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True) 68 | else: 69 | ddp_model = model 70 | 71 | optimizer = optim.Adam([{'params': ddp_model.parameters(), 'lr': args.lr}]) 72 | 73 | data_collator = DataCollatorForMatching(mlm=args.mlm, neighbor_num=args.neighbor_num, 74 | 
75 |         loss = 0.0
76 |         global_step = 0
77 |         best_acc, best_count = 0.0, 0  # best_count drives early stopping below
78 |         for ep in range(args.epochs):
79 |             start_time = time.time()
80 |             ddp_model.train()
81 |             dataset = DatasetForMatching(file_path=args.train_data_path)
82 |             if args.world_size > 1:
83 |                 end.value = False
84 |                 dataloader = MultiProcessDataLoader(dataset,
85 |                                                     batch_size=args.train_batch_size,
86 |                                                     collate_fn=data_collator,
87 |                                                     local_rank=local_rank,
88 |                                                     world_size=args.world_size,
89 |                                                     global_end=end)
90 |             else:
91 |                 dataloader = SingleProcessDataLoader(dataset, batch_size=args.train_batch_size,
92 |                                                      collate_fn=data_collator, blocking=True)
93 |             for step, batch in enumerate(dataloader):
94 |                 if args.enable_gpu:
95 |                     for k, v in batch.items():
96 |                         if v is not None:
97 |                             batch[k] = v.cuda(non_blocking=True)
98 | 
99 |                 if args.fp16:
100 |                     with autocast():
101 |                         batch_loss = ddp_model(**batch)
102 |                 else:
103 |                     batch_loss = ddp_model(**batch)
104 |                 loss += batch_loss.item()
105 |                 optimizer.zero_grad()
106 |                 if args.fp16:
107 |                     scaler.scale(batch_loss).backward()
108 |                     scaler.step(optimizer)
109 |                     scaler.update()
110 |                 else:
111 |                     batch_loss.backward()
112 |                     optimizer.step()
113 | 
114 |                 global_step += 1
115 | 
116 |                 if local_rank == 0 and global_step % args.log_steps == 0:
117 |                     logging.info(
118 |                         '[{}] cost_time:{} step:{}, lr:{}, train_loss: {:.5f}'.format(
119 |                             local_rank, time.time() - start_time, global_step, optimizer.param_groups[0]['lr'],
120 |                             loss / args.log_steps))
121 |                     loss = 0.0
122 | 
123 |             dist.barrier()
124 |             logging.info("train time:{}".format(time.time() - start_time))
125 | 
126 |             if local_rank == 0:
127 |                 ckpt_path = os.path.join(args.model_dir, '{}-epoch-{}.pt'.format(args.savename, ep + 1))
128 |                 torch.save(model.state_dict(), ckpt_path)
129 |                 logging.info(f"Model saved to {ckpt_path}")
130 | 
131 |                 logging.info("Start validation for epoch-{}".format(ep + 1))
132 |                 acc = test_single_process(model, args, "valid")
133 |                 logging.info("validation time:{}".format(time.time() - start_time))
134 |                 if acc > best_acc:
135 |                     ckpt_path = os.path.join(args.model_dir, '{}-best.pt'.format(args.savename))
136 |                     torch.save(model.state_dict(), ckpt_path)
137 |                     logging.info(f"Model saved to {ckpt_path}")
138 |                     best_acc = acc
139 |                     best_count = 0
140 |                 else:
141 |                     best_count += 1
142 |                     if best_count >= 2:  # early stop after two epochs without improvement
143 |                         start_time = time.time()
144 |                         ckpt_path = os.path.join(args.model_dir, '{}-best.pt'.format(args.savename))
145 |                         model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
146 |                         logging.info("Start testing for best")
147 |                         acc = test_single_process(model, args, "test")
148 |                         logging.info("test time:{}".format(time.time() - start_time))
149 |                         exit()
150 |             dist.barrier()
151 | 
152 |         if local_rank == 0:
153 |             start_time = time.time()
154 |             ckpt_path = os.path.join(args.model_dir, '{}-best.pt'.format(args.savename))
155 |             model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
156 |             logging.info("Start testing for best")
157 |             acc = test_single_process(model, args, "test")
158 |             logging.info("test time:{}".format(time.time() - start_time))
159 |         dist.barrier()
160 |         cleanup()
161 |     except Exception:  # avoid swallowing the SystemExit raised by exit() above
162 |         import sys
163 |         import traceback
164 |         error_type, error_value, error_trace = sys.exc_info()
165 |         traceback.print_tb(error_trace)
166 |         logging.info(error_value)
167 | 
168 | 
169 | @torch.no_grad()
170 | def test_single_process(model, args, mode):
171 |     assert mode in {"valid", "test"}
172 |     model.eval()
173 | 
174 |     data_collator = DataCollatorForMatching(mlm=args.mlm, neighbor_num=args.neighbor_num,
175 |                                             token_length=args.token_length, random_seed=args.random_seed)
176 |     if mode == "valid":
177 |         dataset = DatasetForMatching(file_path=args.valid_data_path)
178 |         dataloader = SingleProcessDataLoader(dataset, batch_size=args.valid_batch_size, collate_fn=data_collator)
179 |     elif mode == "test":
180 |         dataset = DatasetForMatching(file_path=args.test_data_path)
181 |         dataloader = SingleProcessDataLoader(dataset, batch_size=args.test_batch_size, collate_fn=data_collator)
182 | 
183 |     count = 0
184 |     metrics_total = defaultdict(float)
185 |     for step, batch in enumerate(dataloader):
186 |         if args.enable_gpu:
187 |             for k, v in batch.items():
188 |                 if v is not None:
189 |                     batch[k] = v.cuda(non_blocking=True)
190 | 
191 |         metrics = model.test(**batch)
192 |         for k, v in metrics.items():
193 |             metrics_total[k] += v
194 |         count += 1
195 |     for key in metrics_total:
196 |         metrics_total[key] /= count
197 |         logging.info("mode: {}, {}:{}".format(mode, key, metrics_total[key]))
198 |     model.train()
199 |     return metrics_total['main']
200 | 
201 | 
202 | def test(args):
203 |     model = load_bert(args)
204 |     logging.info('loading model: {}'.format(args.model_type))
205 |     model = model.cuda()
206 | 
207 |     checkpoint = torch.load(args.load_ckpt_name, map_location="cpu")
208 |     model.load_state_dict(checkpoint['model_state_dict'], strict=False)
209 |     logging.info('load ckpt:{}'.format(args.load_ckpt_name))
210 | 
211 |     test_single_process(model, args, "test")
212 | 
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | import torch
4 | import numpy as np
5 | import argparse
6 | from sklearn.metrics import roc_auc_score
7 | 
8 | def str2bool(v):
9 |     if isinstance(v, bool):
10 |         return v
11 |     if v.lower() in ("yes", "true", "t", "y", "1"):
12 |         return True
13 |     elif v.lower() in ("no", "false", "f", "n", "0"):
14 |         return False
15 |     else:
16 |         raise argparse.ArgumentTypeError("Boolean value expected.")
17 | 
18 | 
19 | 
20 | def setuplogging():
21 |     root = logging.getLogger()
22 |     root.setLevel(logging.INFO)
23 |     handler = logging.StreamHandler(sys.stdout)
24 |     handler.setLevel(logging.INFO)
25 |     formatter = logging.Formatter("[%(levelname)s %(asctime)s] %(message)s")
26 |     handler.setFormatter(formatter)
27 |     root.addHandler(handler)
28 | 
29 | def acc(y_true, y_hat):
30 |     y_hat = torch.argmax(y_hat, dim=-1)
31 |     tot = y_true.shape[0]
32 |     hit = torch.sum(y_true == y_hat)
33 |     return hit.data.float() * 1.0 / tot
34 | 
35 | 
36 | def dcg_score(y_true, y_score, k=10):
37 |     order = np.argsort(y_score)[::-1]
38 |     y_true = np.take(y_true, order[:k])
39 |     gains = 2**y_true - 1
40 |     discounts = np.log2(np.arange(len(y_true)) + 2)
41 |     return np.sum(gains / discounts)
42 | 
43 | 
44 | def ndcg_score(y_true, y_score, k=10):
45 |     best = dcg_score(y_true, y_true, k)
46 |     actual = dcg_score(y_true, y_score, k)
47 |     return actual / best
48 | 
49 | 
50 | def mrr_score(y_true, y_score):
51 |     order = np.argsort(y_score)[::-1]
52 |     y_true = np.take(y_true, order)
53 |     rr_score = y_true / (np.arange(len(y_true)) + 1)
54 |     return np.sum(rr_score) / np.sum(y_true)
55 | 
56 | 
--------------------------------------------------------------------------------
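
As a quick sanity check of the ranking metrics defined in `src/utils.py`, the minimal snippet below (not part of the repository) can be run from the repo root. It assumes the project dependencies are installed (`src/utils.py` imports `torch` and `scikit-learn` at module level) and that the working directory puts `src` on the import path; the expected values in the comments are worked out by hand.

```
# Toy example: one relevant item among four candidates.
import numpy as np

from src.utils import mrr_score, ndcg_score

y_true = np.array([0, 1, 0, 0])  # index 1 is the only relevant item

# Model ranks the relevant item first -> both metrics are perfect.
print(mrr_score(y_true, np.array([0.1, 0.9, 0.3, 0.2])))   # 1.0
print(ndcg_score(y_true, np.array([0.1, 0.9, 0.3, 0.2])))  # 1.0

# Model ranks the relevant item last -> MRR = 1/4, NDCG = 1/log2(5) ~ 0.43.
print(mrr_score(y_true, np.array([0.9, 0.1, 0.3, 0.2])))
print(ndcg_score(y_true, np.array([0.9, 0.1, 0.3, 0.2])))
```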