├── .gitignore
├── README.md
├── args.py
├── models.py
├── run.sh
├── tokenizer.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # test module (excluded by personal preference)
10 | /test/
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 | .idea/
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # celery beat schedule file
90 | celerybeat-schedule
91 |
92 | # SageMath parsed files
93 | *.sage.py
94 |
95 | # Environments
96 | .env
97 | .venv
98 | env/
99 | venv/
100 | ENV/
101 | env.bak/
102 | venv.bak/
103 |
104 | # Spyder project settings
105 | .spyderproject
106 | .spyproject
107 |
108 | # Rope project settings
109 | .ropeproject
110 |
111 | # mkdocs documentation
112 | /site
113 |
114 | # mypy
115 | .mypy_cache/
116 | .dmypy.json
117 | dmypy.json
118 |
119 | # Pyre type checker
120 | .pyre/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # t5-pegasus pytorch
2 | ## Latest updates
3 | - Refactored the code to support more models
4 | - Supports the latest version of transformers
5 | [The legacy code is here](https://github.com/renmada/t5-pegasus-pytorch/tree/legacy)
6 | ## Model performance comparison
7 | Dataset: [LCSTS_new](https://www.luge.ai/#/luge/dataDetail?id=10)
8 | The first 10,000 samples are used for training and the first 1,000 for validation.
9 |
10 | | model | bleu | rouge-1 | rouge-2 | rouge-l |
11 | |----------------------|-------------|---------------|--------------|--------------|
12 | | t5-pegasus-base | 0.1276 | 0.3490 | 0.2123 | 0.3155 |
13 | | t5-copy | 0.0938 | 0.3369 | 0.1955 | 0.3086 |
14 | | Pegasus-238M-Chinese | 0.1200 | 0.3252 | 0.1957 | 0.2924 |
15 | | Pegasus-523M-Chinese | 0.1233 | 0.3313 | 0.2032 | 0.2996 |
16 | | cpt-large | **0.1366** | **0.3550** | **0.2242** | **0.3220** |
17 | | prophet-zh | 0.1240 | 0.3419 | 0.2109 | 0.3107 |
18 |
19 | ## Data format
20 | [Sample data](https://github.com/renmada/t5-pegasus-pytorch/blob/legacy/examples/sample_data.json)
21 | ## Hugging Face models
22 |
23 | | model_type | model_path |
24 | |-------------|----------------------------------------|
25 | | t5-pegasus | imxly/t5-pegasus |
26 | | t5copy | imxly/t5-copy |
27 | | Pegasus | IDEA-CCNL/Randeng-Pegasus-238M-Chinese |
28 | | Pegasus | IDEA-CCNL/Randeng-Pegasus-523M-Chinese |
29 | | cpt | fnlp/cpt-large |
30 | | prophet | imxly/prophetnet-zh |
31 |
32 |
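For a quick generation check with one of the checkpoints above, a minimal sketch built on the `LightModel` wrapper from `models.py` could look like the following. This is illustrative only: it assumes the checkpoint downloads cleanly and that the tokenizer class registered for the chosen `model_type` in `tokenizer.py` exposes the usual Hugging Face tokenizer interface, as `models.py` expects.

```python
# Illustrative sketch, not part of the repository; adjust model_type/model_path as needed.
from argparse import Namespace
from models import LightModel

args = Namespace(model_type='t5-pegasus', model_path='imxly/t5-pegasus',
                 recompute=False, beams=3, max_target_length=300)
lm = LightModel(args)  # resolves the model/tokenizer classes via MODEL_CLASSES

text = '这里放一段需要摘要的中文文本。'
enc = lm.tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
batch = {'input_ids': enc['input_ids'], 'attention_mask': enc['attention_mask']}
print(lm.predict_batch(batch))  # beam-search generation + decoding, see models.py
```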
33 | ## Training command
34 | ### requirements
35 | For a reference environment, see this [issue](https://github.com/renmada/t5-pegasus-pytorch/issues/58).
36 | ```
37 | torch >=1.10.0
38 | transformers
39 | pytorch_lightning==1.4.9
40 | torchmetrics==0.5.0
41 | ```
42 | See the table above for the available model_type values.
43 | ```shell
44 | python train.py \
45 | --train_file train.json \
46 | --dev_file dev.json \
47 | --batch_size 6 \
48 | --max_epochs 10 \
49 | --max_source_length 512 \
50 | --max_target_length 300 \
51 | --model_path imxly/t5-pegasus \
52 | --gpus 4 \
53 | --lr 5e-5 \
54 | --model_type t5-pegasus
55 | ```
56 | ## References
57 | https://github.com/ZhuiyiTechnology/t5-pegasus
58 | https://github.com/fastnlp/CPT
59 | https://github.com/IDEA-CCNL/Fengshenbang-LM
60 | https://github.com/microsoft/ProphetNet
61 |
62 |
63 |
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser()
3 | # ========================= Training ==========================
4 | parser.add_argument('--warmup_steps', type=int, default=1000)
5 | parser.add_argument('--warmup_proportion', default=0.1, type=float)
6 | parser.add_argument('--weight_decay', default=0.01, type=float)
7 | parser.add_argument('--lr', default=1e-4, type=float, help='initial learning rate')
8 | parser.add_argument('--batch_size', default=16, type=int)
9 | parser.add_argument('--max_epochs', default=10, type=int)
10 | parser.add_argument('--accumulate_grad_batches', default=1, type=int)
11 | parser.add_argument('--seed', default=12, type=int)
12 | parser.add_argument('--eval_delay', default=0, type=int)
13 | parser.add_argument('--precision', default=32, type=int)
14 | parser.add_argument('--plugins', type=str, default='ddp_sharded')
15 | parser.add_argument('--gpus', type=int, default=1)
16 | parser.add_argument('--kfold', type=int, default=1)
17 | parser.add_argument('--recompute', action='store_true')
18 | parser.add_argument('--ls_eps', default=0., type=float)
19 |
20 | # ========================= Data ==========================
21 | parser.add_argument('--train_file', type=str, required=False)
22 | parser.add_argument('--dev_file', type=str, required=False)
23 | parser.add_argument('--predict_file', type=str, required=False)
24 | parser.add_argument('--noise_prob', default=0., type=float)
25 | parser.add_argument('--max_source_length', default=512, type=int)
26 | parser.add_argument('--max_target_length', default=300, type=int)
27 | parser.add_argument('--beams', default=3, type=int)
28 | parser.add_argument('--num_workers', type=int, default=4)
29 |
30 | # ========================= Model ==========================
31 | parser.add_argument('--model_path', type=str)
32 | parser.add_argument('--model_type', type=str)
33 | parser.add_argument('--rdrop', action='store_true')
34 | parser.add_argument('--rdrop_coef', default=5, type=float)
35 | parser.add_argument('--output_dir', type=str, default='./output')
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import copy
4 | from typing import Optional, Tuple
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.nn import LayerNorm
11 | import torch.utils.checkpoint
12 | from torch.nn import
CrossEntropyLoss 13 | 14 | from transformers.activations import ACT2FN 15 | from transformers.file_utils import ( 16 | add_code_sample_docstrings, 17 | add_start_docstrings, 18 | add_start_docstrings_to_model_forward, 19 | replace_return_docstrings, 20 | ) 21 | from transformers.modeling_outputs import ( 22 | BaseModelOutput, 23 | BaseModelOutputWithPastAndCrossAttentions, 24 | Seq2SeqLMOutput, 25 | Seq2SeqModelOutput, 26 | 27 | ) 28 | from transformers.modeling_utils import PreTrainedModel 29 | from transformers.utils import logging 30 | from transformers import BartConfig as CPTConfig 31 | from transformers import BertModel, BertConfig 32 | from transformers import T5ForConditionalGeneration, PegasusForConditionalGeneration 33 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, get_linear_schedule_with_warmup 34 | from transformers import ( 35 | ProphetNetEncoder, 36 | ProphetNetDecoder, 37 | ProphetNetPreTrainedModel, 38 | ProphetNetModel as OldProphetNetModel, 39 | ProphetNetForConditionalGeneration as OldProphetNetForConditionalGeneration) 40 | import pytorch_lightning as pl 41 | 42 | from tokenizer import * 43 | from utils import * 44 | 45 | logger = logging.get_logger(__name__) 46 | _CHECKPOINT_FOR_DOC = "fnlp/cpt-large" 47 | _CONFIG_FOR_DOC = "CPTConfig" 48 | _TOKENIZER_FOR_DOC = "CPTTokenizer" 49 | CPT_PRETRAINED_MODEL_ARCHIVE_LIST = [ 50 | "fnlp/cpt-large"] 51 | 52 | 53 | class LightModel(pl.LightningModule): 54 | def __init__(self, args): 55 | super().__init__() 56 | self.args = args 57 | if self.args.model_type is not None: 58 | model_cls, tokenizer_cls = MODEL_CLASSES[self.args.model_type] 59 | else: 60 | model_cls, tokenizer_cls = AutoModelForSeq2SeqLM, AutoTokenizer 61 | self.model = model_cls.from_pretrained(self.args.model_path) 62 | self.tokenizer = tokenizer_cls.from_pretrained(self.args.model_path) 63 | 64 | if args.recompute: 65 | self.model.gradient_checkpointing_enable() 66 | 67 | def forward(self, batch): 68 | labels = batch.pop('labels') 69 | model_output = self.model(**batch) 70 | batch['labels'] = labels 71 | return model_output 72 | 73 | def training_step(self, batch, batch_idx): 74 | model_output = self(batch) 75 | logits = model_output.logits 76 | is_prob = True if hasattr(self.model, 'generator') else False 77 | loss = ce_loss(logits, batch['labels'], is_prob=is_prob, eps=self.args.ls_eps) 78 | if self.args.rdrop: 79 | logits2 = self(batch).logits 80 | loss = (loss + ce_loss(logits2, batch['labels'], is_prob=is_prob, eps=self.args.ls_eps)) / 2 81 | loss = loss + self.args.rdrop_coef * kl_loss(logits, logits2, batch['decoder_attention_mask']) 82 | 83 | return loss 84 | 85 | def predict_batch(self, batch): 86 | model_kwargs = dict(eos_token_id=self.tokenizer.eos_token_id, 87 | decoder_start_token_id=self.tokenizer.bos_token_id, 88 | num_beams=self.args.beams, 89 | input_ids=batch['input_ids'], 90 | attention_mask=batch['attention_mask'], 91 | use_cache=True, 92 | max_length=self.args.max_target_length, 93 | ) 94 | if hasattr(self.model, 'generator'): 95 | model_kwargs['src'] = batch['input_ids'] 96 | pred = self.model.generate(**model_kwargs) 97 | pred = self.ids2text(pred) 98 | return pred 99 | 100 | def ids2text(self, ids): 101 | ids = ids.cpu().numpy() 102 | text = self.tokenizer.batch_decode(ids, skip_special_tokens=True) 103 | text = [s.replace(' ', '').lower() for s in text] 104 | return text 105 | 106 | def predict_step(self, batch, batch_idx, dataloader_idx=None): 107 | return self.predict_batch(batch) 108 | 109 | def 
validation_step(self, batch, batch_idx): 110 | ret = {'bleu': 0, 'rouge': 0, 'rouge-1': 0, 'rouge-2': 0, 'rouge-l': 0} 111 | if self.current_epoch < self.args.eval_delay: 112 | return ret 113 | pred = self.predict_batch(batch) 114 | labels = self.ids2text( 115 | batch['labels'].masked_fill_(batch['labels'] == -100, self.tokenizer.pad_token_id) 116 | ) 117 | rouge_score = compute_rouge(labels, pred) 118 | ret.update(rouge_score) 119 | bleu_score = compute_bleu(labels, pred) 120 | ret['bleu'] = bleu_score 121 | return ret 122 | 123 | def validation_epoch_end(self, outputs): 124 | ret = {'bleu': 0, 'rouge': 0, 'rouge-1': 0, 'rouge-2': 0, 'rouge-l': 0} 125 | if self.current_epoch < self.args.eval_delay: 126 | return ret 127 | keys = outputs[0].keys() 128 | ret = {k: np.mean([x[k] for x in outputs]) for k in keys} 129 | for k, v in ret.items(): 130 | self.log(k, v, prog_bar=True) 131 | return ret 132 | 133 | def configure_optimizers(self): 134 | optimizer = create_optimizer(self.model, self.args.lr, self.args.weight_decay) 135 | if self.args.max_epochs == -1: 136 | t_total = self.args.max_steps // self.args.accumulate_grad_batches 137 | else: 138 | t_total = len(self.train_dataloader()) // self.args.accumulate_grad_batches * self.args.max_epochs 139 | if self.args.warmup_steps != -1: 140 | warmup_steps = self.args.warmup_steps 141 | else: 142 | warmup_steps = int(self.args.warmup_proportion * t_total) 143 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, 144 | num_training_steps=t_total) 145 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}] 146 | 147 | 148 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 149 | """ 150 | Shift input ids one token to the right. 151 | """ 152 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 153 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 154 | shifted_input_ids[:, 0] = decoder_start_token_id 155 | 156 | assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 157 | # replace possible -100 values in labels by `pad_token_id` 158 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 159 | 160 | return shifted_input_ids 161 | 162 | 163 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): 164 | """ 165 | Make causal mask used for bi-directional self-attention. 166 | """ 167 | bsz, tgt_len = input_ids_shape 168 | mask = torch.full((tgt_len, tgt_len), float("-inf")) 169 | mask_cond = torch.arange(mask.size(-1)) 170 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 171 | mask = mask.to(dtype) 172 | 173 | if past_key_values_length > 0: 174 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) 175 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 176 | 177 | 178 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 179 | """ 180 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
181 | """ 182 | bsz, src_len = mask.size() 183 | tgt_len = tgt_len if tgt_len is not None else src_len 184 | 185 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 186 | 187 | inverted_mask = 1.0 - expanded_mask 188 | 189 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) 190 | 191 | 192 | def attention_mask_func(attention_scores, attention_mask): 193 | return attention_scores + attention_mask 194 | 195 | 196 | def init_method(std): 197 | def init_(tensor): 198 | return torch.nn.init.normal_(tensor, mean=0.0, std=std) 199 | 200 | return init_ 201 | 202 | 203 | class CPTLearnedPositionalEmbedding(nn.Embedding): 204 | """ 205 | This module learns positional embeddings up to a fixed maximum size. 206 | """ 207 | 208 | def __init__(self, num_embeddings: int, embedding_dim: int): 209 | # CPT is set up so that if padding_idx is specified then offset the embedding ids by 2 210 | # and adjust num_embeddings appropriately. Other models dont have this hack 211 | self.offset = 2 212 | super().__init__(num_embeddings + self.offset, embedding_dim) 213 | 214 | def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): 215 | """`input_ids_shape` is expected to be [bsz x seqlen].""" 216 | bsz, seq_len = input_ids_shape[:2] 217 | positions = torch.arange( 218 | past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device 219 | ) 220 | return super().forward(positions + self.offset) 221 | 222 | 223 | class CPTAttention(nn.Module): 224 | """Multi-headed attention from 'Attention Is All You Need' paper""" 225 | 226 | def __init__( 227 | self, 228 | embed_dim: int, 229 | num_heads: int, 230 | dropout: float = 0.0, 231 | is_decoder: bool = False, 232 | bias: bool = True, 233 | ): 234 | super().__init__() 235 | self.embed_dim = embed_dim 236 | self.num_heads = num_heads 237 | self.dropout = dropout 238 | self.head_dim = embed_dim // num_heads 239 | assert ( 240 | self.head_dim * num_heads == self.embed_dim 241 | ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." 
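# Note: queries are pre-multiplied by self.scaling (1 / sqrt(head_dim)) in forward(),
# the standard scaled dot-product attention normalization.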
242 | self.scaling = self.head_dim ** -0.5 243 | self.is_decoder = is_decoder 244 | 245 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 246 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 247 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 248 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 249 | 250 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 251 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 252 | 253 | def forward( 254 | self, 255 | hidden_states: torch.Tensor, 256 | key_value_states: Optional[torch.Tensor] = None, 257 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 258 | attention_mask: Optional[torch.Tensor] = None, 259 | layer_head_mask: Optional[torch.Tensor] = None, 260 | output_attentions: bool = False, 261 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 262 | """Input shape: Batch x Time x Channel""" 263 | 264 | # if key_value_states are provided this layer is used as a cross-attention layer 265 | # for the decoder 266 | is_cross_attention = key_value_states is not None 267 | bsz, tgt_len, embed_dim = hidden_states.size() 268 | 269 | # get query proj 270 | query_states = self.q_proj(hidden_states) * self.scaling 271 | # get key, value proj 272 | if is_cross_attention and past_key_value is not None: 273 | # reuse k,v, cross_attentions 274 | key_states = past_key_value[0] 275 | value_states = past_key_value[1] 276 | elif is_cross_attention: 277 | # cross_attentions 278 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 279 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 280 | elif past_key_value is not None: 281 | # reuse k, v, self_attention 282 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 283 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 284 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 285 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 286 | else: 287 | # self_attention 288 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 289 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 290 | 291 | if self.is_decoder: 292 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 293 | # Further calls to cross_attention layer can then reuse all cross-attention 294 | # key/value_states (first "if" case) 295 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 296 | # all previous decoder key/value_states. 
Further calls to uni-directional self-attention 297 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 298 | # if encoder bi-directional self-attention `past_key_value` is always `None` 299 | past_key_value = (key_states, value_states) 300 | 301 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 302 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 303 | key_states = key_states.view(*proj_shape) 304 | value_states = value_states.view(*proj_shape) 305 | 306 | src_len = key_states.size(1) 307 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 308 | 309 | assert attn_weights.size() == ( 310 | bsz * self.num_heads, 311 | tgt_len, 312 | src_len, 313 | ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" 314 | 315 | if attention_mask is not None: 316 | assert attention_mask.size() == ( 317 | bsz, 318 | 1, 319 | tgt_len, 320 | src_len, 321 | ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 322 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 323 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 324 | 325 | attn_weights = F.softmax(attn_weights, dim=-1) 326 | 327 | if layer_head_mask is not None: 328 | assert layer_head_mask.size() == ( 329 | self.num_heads, 330 | ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" 331 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 332 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 333 | 334 | if output_attentions: 335 | # this operation is a bit akward, but it's required to 336 | # make sure that attn_weights keeps its gradient. 
337 | # In order to do so, attn_weights have to reshaped 338 | # twice and have to be reused in the following 339 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 340 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 341 | else: 342 | attn_weights_reshaped = None 343 | 344 | # with mpu.get_cuda_rng_tracker().fork(): 345 | attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) 346 | 347 | attn_output = torch.bmm(attn_probs, value_states) 348 | 349 | assert attn_output.size() == ( 350 | bsz * self.num_heads, 351 | tgt_len, 352 | self.head_dim, 353 | ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" 354 | 355 | attn_output = ( 356 | attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 357 | .transpose(1, 2) 358 | .reshape(bsz, tgt_len, embed_dim) 359 | ) 360 | 361 | attn_output = self.out_proj(attn_output) 362 | 363 | return attn_output, attn_weights_reshaped, past_key_value 364 | 365 | 366 | class CPTDecoderLayer(nn.Module): 367 | def __init__(self, config: CPTConfig): 368 | super().__init__() 369 | self.embed_dim = config.d_model 370 | 371 | self.self_attn = CPTAttention( 372 | embed_dim=self.embed_dim, 373 | num_heads=config.decoder_attention_heads, 374 | dropout=config.attention_dropout, 375 | is_decoder=True, 376 | ) 377 | self.dropout = config.dropout 378 | self.activation_fn = ACT2FN[config.activation_function] 379 | self.activation_dropout = config.activation_dropout 380 | 381 | self.self_attn_layer_norm = LayerNorm(self.embed_dim) 382 | self.encoder_attn = CPTAttention( 383 | self.embed_dim, 384 | config.decoder_attention_heads, 385 | dropout=config.attention_dropout, 386 | is_decoder=True, 387 | ) 388 | self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) 389 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) 390 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) 391 | self.final_layer_norm = LayerNorm(self.embed_dim) 392 | 393 | def forward( 394 | self, 395 | hidden_states: torch.Tensor, 396 | attention_mask: Optional[torch.Tensor] = None, 397 | encoder_hidden_states: Optional[torch.Tensor] = None, 398 | encoder_attention_mask: Optional[torch.Tensor] = None, 399 | layer_head_mask: Optional[torch.Tensor] = None, 400 | encoder_layer_head_mask: Optional[torch.Tensor] = None, 401 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 402 | output_attentions: Optional[bool] = False, 403 | use_cache: Optional[bool] = True, 404 | ): 405 | """ 406 | Args: 407 | hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` 408 | attention_mask (:obj:`torch.FloatTensor`): attention mask of size 409 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 410 | encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` 411 | encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size 412 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 413 | layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size 414 | `(config.encoder_attention_heads,)`. 415 | encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of 416 | size `(config.encoder_attention_heads,)`. 
417 | past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states 418 | output_attentions (:obj:`bool`, `optional`): 419 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 420 | returned tensors for more detail. 421 | """ 422 | residual = hidden_states 423 | 424 | # Self Attention 425 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 426 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 427 | # add present self-attn cache to positions 1,2 of present_key_value tuple 428 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 429 | hidden_states=hidden_states, 430 | past_key_value=self_attn_past_key_value, 431 | attention_mask=attention_mask, 432 | layer_head_mask=layer_head_mask, 433 | output_attentions=output_attentions, 434 | ) 435 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 436 | hidden_states = residual + hidden_states 437 | hidden_states = self.self_attn_layer_norm(hidden_states) 438 | 439 | # Cross-Attention Block 440 | cross_attn_present_key_value = None 441 | cross_attn_weights = None 442 | if encoder_hidden_states is not None: 443 | residual = hidden_states 444 | 445 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple 446 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 447 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( 448 | hidden_states=hidden_states, 449 | key_value_states=encoder_hidden_states, 450 | attention_mask=encoder_attention_mask, 451 | layer_head_mask=encoder_layer_head_mask, 452 | past_key_value=cross_attn_past_key_value, 453 | output_attentions=output_attentions, 454 | ) 455 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 456 | hidden_states = residual + hidden_states 457 | hidden_states = self.encoder_attn_layer_norm(hidden_states) 458 | 459 | # add cross-attn to positions 3,4 of present_key_value tuple 460 | present_key_value = present_key_value + cross_attn_present_key_value 461 | 462 | # Fully Connected 463 | residual = hidden_states 464 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 465 | hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) 466 | hidden_states = self.fc2(hidden_states) 467 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 468 | hidden_states = residual + hidden_states 469 | hidden_states = self.final_layer_norm(hidden_states) 470 | 471 | outputs = (hidden_states,) 472 | 473 | if output_attentions: 474 | outputs += (self_attn_weights, cross_attn_weights) 475 | 476 | if use_cache: 477 | outputs += (present_key_value,) 478 | 479 | return outputs 480 | 481 | 482 | class CPTClassificationHead(nn.Module): 483 | """Head for sentence-level classification tasks.""" 484 | 485 | def __init__( 486 | self, 487 | input_dim: int, 488 | inner_dim: int, 489 | num_classes: int, 490 | pooler_dropout: float, 491 | ): 492 | super().__init__() 493 | self.dense = nn.Linear(input_dim, inner_dim) 494 | self.dropout = nn.Dropout(p=pooler_dropout) 495 | self.out_proj = nn.Linear(inner_dim, num_classes) 496 | 497 | def forward(self, hidden_states: torch.Tensor): 498 | hidden_states = self.dropout(hidden_states) 499 | hidden_states = self.dense(hidden_states) 500 | hidden_states = torch.tanh(hidden_states) 501 | hidden_states = 
self.dropout(hidden_states) 502 | hidden_states = self.out_proj(hidden_states) 503 | return hidden_states 504 | 505 | 506 | class CPTPretrainedModel(PreTrainedModel): 507 | config_class = CPTConfig 508 | base_model_prefix = "model" 509 | 510 | def _init_weights(self, module): 511 | std = self.config.init_std 512 | if isinstance(module, nn.Linear): 513 | module.weight.data.normal_(mean=0.0, std=std) 514 | if module.bias is not None: 515 | module.bias.data.zero_() 516 | elif isinstance(module, nn.Embedding): 517 | module.weight.data.normal_(mean=0.0, std=std) 518 | if module.padding_idx is not None: 519 | module.weight.data[module.padding_idx].zero_() 520 | 521 | @property 522 | def dummy_inputs(self): 523 | pad_token = self.config.pad_token_id 524 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) 525 | dummy_inputs = { 526 | "attention_mask": input_ids.ne(pad_token), 527 | "input_ids": input_ids, 528 | } 529 | return dummy_inputs 530 | 531 | 532 | CPT_START_DOCSTRING = r""" 533 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic 534 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 535 | pruning heads etc.) 536 | 537 | This model is also a PyTorch `torch.nn.Module `__ 538 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 539 | general usage and behavior. 540 | 541 | Parameters: 542 | config (:class:`~transformers.CPTConfig`): 543 | Model configuration class with all the parameters of the model. Initializing with a config file does not 544 | load the weights associated with the model, only the configuration. Check out the 545 | :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. 546 | """ 547 | 548 | CPT_INPUTS_DOCSTRING = r""" 549 | Args: 550 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 551 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 552 | it. 553 | 554 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See 555 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 556 | details. 557 | 558 | `What are input IDs? <../glossary.html#input-ids>`__ 559 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 560 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 561 | 562 | - 1 for tokens that are **not masked**, 563 | - 0 for tokens that are **masked**. 564 | 565 | `What are attention masks? <../glossary.html#attention-mask>`__ 566 | decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): 567 | Indices of decoder input sequence tokens in the vocabulary. 568 | 569 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See 570 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 571 | details. 572 | 573 | `What are input IDs? <../glossary.html#input-ids>`__ 574 | 575 | CPT uses the :obj:`eos_token_id` as the starting token for :obj:`decoder_input_ids` generation. If 576 | :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see 577 | :obj:`past_key_values`). 
578 | 579 | For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no 580 | :obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to 581 | the right for denoising pre-training following the paper. 582 | decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): 583 | Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will 584 | also be used by default. 585 | 586 | If you want to change padding behavior, you should read :func:`modeling_cpt._prepare_decoder_inputs` and 587 | modify to your needs. See diagram 1 in `the paper `__ for more 588 | information on the default strategy. 589 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 590 | Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: 591 | 592 | - 1 indicates the head is **not masked**, 593 | - 0 indicates the heas is **masked**. 594 | 595 | decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 596 | Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: 597 | 598 | - 1 indicates the head is **not masked**, 599 | - 0 indicates the head is **masked**. 600 | 601 | encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): 602 | Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: 603 | :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, 604 | `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the 605 | cross-attention of the decoder. 606 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 607 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. 608 | 609 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 610 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 611 | instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. 612 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 613 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 614 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 615 | vectors than the model's internal embedding lookup matrix. 616 | decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): 617 | Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded 618 | representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` 619 | have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert 620 | :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 
621 | 622 | If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` 623 | takes the value of :obj:`inputs_embeds`. 624 | use_cache (:obj:`bool`, `optional`): 625 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 626 | decoding (see :obj:`past_key_values`). 627 | output_attentions (:obj:`bool`, `optional`): 628 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 629 | tensors for more detail. 630 | output_hidden_states (:obj:`bool`, `optional`): 631 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 632 | more detail. 633 | return_dict (:obj:`bool`, `optional`): 634 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 635 | """ 636 | 637 | 638 | class CPTDecoder(CPTPretrainedModel): 639 | """ 640 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`CPTDecoderLayer` 641 | 642 | Args: 643 | config: CPTConfig 644 | embed_tokens (torch.nn.Embedding): output embedding 645 | """ 646 | 647 | def __init__(self, config: CPTConfig, embed_tokens: Optional[nn.Embedding] = None): 648 | super().__init__(config) 649 | self.dropout = config.dropout 650 | self.layerdrop = config.decoder_layerdrop 651 | self.padding_idx = config.pad_token_id 652 | self.max_target_positions = config.max_position_embeddings 653 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 654 | 655 | if embed_tokens is not None: 656 | self.embed_tokens = embed_tokens 657 | else: 658 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) 659 | 660 | self.embed_positions = CPTLearnedPositionalEmbedding( 661 | config.max_position_embeddings, 662 | config.d_model, 663 | ) 664 | self.layers = nn.ModuleList([CPTDecoderLayer(config) for _ in range(config.decoder_layers)]) 665 | self.layernorm_embedding = LayerNorm(config.d_model) 666 | 667 | self.init_weights() 668 | 669 | def get_input_embeddings(self): 670 | return self.embed_tokens 671 | 672 | def set_input_embeddings(self, value): 673 | self.embed_tokens = value 674 | 675 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 676 | # create causal mask 677 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 678 | combined_attention_mask = None 679 | if input_shape[-1] > 1: 680 | combined_attention_mask = _make_causal_mask( 681 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length 682 | ).to(self.device) 683 | 684 | if attention_mask is not None: 685 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 686 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 687 | combined_attention_mask = ( 688 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 689 | ) 690 | 691 | return combined_attention_mask 692 | 693 | def forward( 694 | self, 695 | input_ids=None, 696 | attention_mask=None, 697 | encoder_hidden_states=None, 698 | encoder_attention_mask=None, 699 | head_mask=None, 700 | encoder_head_mask=None, 701 | past_key_values=None, 702 | inputs_embeds=None, 703 | use_cache=None, 704 | output_attentions=None, 705 | output_hidden_states=None, 706 | return_dict=None, 707 | ): 708 | r""" 709 | Args: 710 | input_ids (:obj:`torch.LongTensor` of 
shape :obj:`(batch_size, sequence_length)`): 711 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 712 | provide it. 713 | 714 | Indices can be obtained using :class:`~transformers.CPTTokenizer`. See 715 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` 716 | for details. 717 | 718 | `What are input IDs? <../glossary.html#input-ids>`__ 719 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 720 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 721 | 722 | - 1 for tokens that are **not masked**, 723 | - 0 for tokens that are **masked**. 724 | 725 | `What are attention masks? <../glossary.html#attention-mask>`__ 726 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): 727 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 728 | of the decoder. 729 | encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): 730 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values 731 | selected in ``[0, 1]``: 732 | 733 | - 1 for tokens that are **not masked**, 734 | - 0 for tokens that are **masked**. 735 | 736 | `What are attention masks? <../glossary.html#attention-mask>`__ 737 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 738 | Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: 739 | 740 | - 1 indicates the head is **not masked**, 741 | - 0 indicates the heas is **masked**. 742 | 743 | encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 744 | Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention 745 | on hidden heads. Mask values selected in ``[0, 1]``: 746 | 747 | - 1 indicates the head is **not masked**, 748 | - 0 indicates the heas is **masked**. 749 | 750 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 751 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up 752 | decoding. 753 | 754 | If :obj:`past_key_values` are used, the user can optionally input only the last 755 | :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of 756 | shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, 757 | sequence_length)`. 758 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 759 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded 760 | representation. This is useful if you want more control over how to convert :obj:`input_ids` indices 761 | into associated vectors than the model's internal embedding lookup matrix. 762 | output_attentions (:obj:`bool`, `optional`): 763 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 764 | returned tensors for more detail. 
765 | output_hidden_states (:obj:`bool`, `optional`): 766 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors 767 | for more detail. 768 | return_dict (:obj:`bool`, `optional`): 769 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 770 | """ 771 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 772 | output_hidden_states = ( 773 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 774 | ) 775 | use_cache = use_cache if use_cache is not None else self.config.use_cache 776 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 777 | 778 | # retrieve input_ids and inputs_embeds 779 | if input_ids is not None and inputs_embeds is not None: 780 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 781 | elif input_ids is not None: 782 | input_shape = input_ids.size() 783 | input_ids = input_ids.view(-1, input_shape[-1]) 784 | elif inputs_embeds is not None: 785 | input_shape = inputs_embeds.size()[:-1] 786 | else: 787 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 788 | 789 | # past_key_values_length 790 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 791 | 792 | if inputs_embeds is None: 793 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 794 | 795 | attention_mask = self._prepare_decoder_attention_mask( 796 | attention_mask, input_shape, inputs_embeds, past_key_values_length 797 | ) 798 | 799 | # expand encoder attention mask 800 | if encoder_hidden_states is not None and encoder_attention_mask is not None: 801 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 802 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 803 | 804 | # embed positions 805 | positions = self.embed_positions(input_shape, past_key_values_length) 806 | 807 | hidden_states = inputs_embeds + positions 808 | hidden_states = self.layernorm_embedding(hidden_states) 809 | 810 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 811 | 812 | # decoder layers 813 | all_hidden_states = () if output_hidden_states else None 814 | all_self_attns = () if output_attentions else None 815 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None 816 | next_decoder_cache = () if use_cache else None 817 | 818 | # check if head_mask has a correct number of layers specified if desired 819 | if head_mask is not None: 820 | assert head_mask.size()[0] == ( 821 | len(self.layers) 822 | ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
823 | for idx, decoder_layer in enumerate(self.layers): 824 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 825 | if output_hidden_states: 826 | all_hidden_states += (hidden_states,) 827 | dropout_probability = random.uniform(0, 1) 828 | if self.training and (dropout_probability < self.layerdrop): 829 | continue 830 | 831 | past_key_value = past_key_values[idx] if past_key_values is not None else None 832 | 833 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 834 | 835 | if use_cache: 836 | logger.warn( 837 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " 838 | "`use_cache=False`..." 839 | ) 840 | use_cache = False 841 | 842 | def create_custom_forward(module): 843 | def custom_forward(*inputs): 844 | # None for past_key_value 845 | return module(*inputs, output_attentions, use_cache) 846 | 847 | return custom_forward 848 | 849 | # layer_outputs = mpu.checkpoint( 850 | layer_outputs = torch.utils.checkpoint.checkpoint( 851 | create_custom_forward(decoder_layer), 852 | hidden_states, 853 | attention_mask, 854 | encoder_hidden_states, 855 | encoder_attention_mask, 856 | head_mask[idx] if head_mask is not None else None, 857 | encoder_head_mask[idx] if encoder_head_mask is not None else None, 858 | None, 859 | ) 860 | else: 861 | 862 | layer_outputs = decoder_layer( 863 | hidden_states, 864 | attention_mask=attention_mask, 865 | encoder_hidden_states=encoder_hidden_states, 866 | encoder_attention_mask=encoder_attention_mask, 867 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 868 | encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), 869 | past_key_value=past_key_value, 870 | output_attentions=output_attentions, 871 | use_cache=use_cache, 872 | ) 873 | hidden_states = layer_outputs[0] 874 | 875 | if use_cache: 876 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) 877 | 878 | if output_attentions: 879 | all_self_attns += (layer_outputs[1],) 880 | 881 | if encoder_hidden_states is not None: 882 | all_cross_attentions += (layer_outputs[2],) 883 | 884 | # add hidden states from the last decoder layer 885 | if output_hidden_states: 886 | all_hidden_states += (hidden_states,) 887 | 888 | next_cache = next_decoder_cache if use_cache else None 889 | if not return_dict: 890 | return tuple( 891 | v 892 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] 893 | if v is not None 894 | ) 895 | return BaseModelOutputWithPastAndCrossAttentions( 896 | last_hidden_state=hidden_states, 897 | past_key_values=next_cache, 898 | hidden_states=all_hidden_states, 899 | attentions=all_self_attns, 900 | cross_attentions=all_cross_attentions, 901 | ) 902 | 903 | 904 | @add_start_docstrings( 905 | "The bare CPT Model outputting raw hidden-states without any specific head on top.", 906 | CPT_START_DOCSTRING, 907 | ) 908 | class CPTModel(CPTPretrainedModel): 909 | def __init__(self, config: CPTConfig): 910 | super().__init__(config) 911 | encoder_config = BertConfig( 912 | vocab_size=config.vocab_size, 913 | hidden_size=config.d_model, 914 | num_hidden_layers=config.encoder_layers, 915 | num_attention_heads=config.encoder_attention_heads, 916 | intermediate_size=config.encoder_ffn_dim, 917 | hidden_dropout_prob=config.activation_dropout, 918 | attention_probs_dropout_prob=config.attention_dropout, 919 | max_position_embeddings=config.max_position_embeddings, 920 | ) 921 | config.vocab_size = 
encoder_config.vocab_size 922 | self.encoder = BertModel(encoder_config, add_pooling_layer=False) 923 | self.shared = self.encoder.get_input_embeddings() 924 | self.decoder = CPTDecoder(config, self.shared) 925 | self.num_decoder_layers = config.decoder_layers 926 | self.init_weights() 927 | 928 | def get_input_embeddings(self): 929 | return self.shared 930 | 931 | def set_input_embeddings(self, value): 932 | self.shared = value 933 | self.encoder.set_input_embeddings(self.shared) 934 | self.decoder.embed_tokens = self.shared 935 | 936 | def get_encoder(self): 937 | class _Encoder(torch.nn.Module): 938 | def __init__(self, encoder): 939 | super().__init__() 940 | self.encoder = encoder 941 | 942 | def forward(self, *args, **kwargs): 943 | kwargs['output_hidden_states'] = True 944 | return self.encoder(*args, **kwargs) 945 | 946 | return _Encoder(self.encoder) 947 | 948 | def get_decoder(self): 949 | return self.decoder 950 | 951 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING) 952 | @add_code_sample_docstrings( 953 | checkpoint=_CHECKPOINT_FOR_DOC, 954 | output_type=Seq2SeqModelOutput, 955 | config_class=_CONFIG_FOR_DOC, 956 | ) 957 | def forward( 958 | self, 959 | input_ids=None, 960 | attention_mask=None, 961 | decoder_input_ids=None, 962 | decoder_attention_mask=None, 963 | head_mask=None, 964 | decoder_head_mask=None, 965 | encoder_outputs=None, 966 | past_key_values=None, 967 | inputs_embeds=None, 968 | decoder_inputs_embeds=None, 969 | use_cache=None, 970 | output_attentions=None, 971 | output_hidden_states=None, 972 | return_dict=None, 973 | ): 974 | 975 | # different to other models, CPT automatically creates decoder_input_ids from 976 | # input_ids if no decoder_input_ids are provided 977 | if decoder_input_ids is None and decoder_inputs_embeds is None: 978 | decoder_input_ids = shift_tokens_right( 979 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id 980 | ) 981 | 982 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 983 | output_hidden_states = ( 984 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 985 | ) 986 | use_cache = use_cache if use_cache is not None else self.config.use_cache 987 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 988 | 989 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 990 | # mpu.reset_checkpointed_activations_memory_buffer() 991 | use_cache = False 992 | 993 | if encoder_outputs is None: 994 | encoder_outputs = self.encoder( 995 | input_ids=input_ids, 996 | attention_mask=attention_mask, 997 | token_type_ids=torch.ones_like(input_ids), 998 | head_mask=head_mask, 999 | inputs_embeds=inputs_embeds, 1000 | output_attentions=output_attentions, 1001 | output_hidden_states=True, 1002 | return_dict=return_dict, 1003 | ) 1004 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True 1005 | elif return_dict and isinstance(encoder_outputs, (tuple, list)): 1006 | encoder_outputs = BaseModelOutput( 1007 | last_hidden_state=encoder_outputs[0], 1008 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 1009 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 1010 | ) 1011 | 1012 | if isinstance(encoder_outputs, (torch.Tensor)): 1013 | encoder_hidden_states = encoder_outputs 1014 | else: 1015 | encoder_hidden_states = encoder_outputs[1][-self.num_decoder_layers - 1] 1016 | 
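# NOTE: the decoder does not cross-attend to the encoder's final output; it uses the hidden state
# located num_decoder_layers + 1 entries from the end of the encoder's hidden_states tuple
# (an intermediate encoder layer), which is why the encoder is always run with output_hidden_states=True.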
1017 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 1018 | decoder_outputs = self.decoder( 1019 | input_ids=decoder_input_ids, 1020 | attention_mask=decoder_attention_mask, 1021 | encoder_hidden_states=encoder_hidden_states, 1022 | encoder_attention_mask=attention_mask, 1023 | head_mask=decoder_head_mask, 1024 | encoder_head_mask=head_mask, 1025 | past_key_values=past_key_values, 1026 | inputs_embeds=decoder_inputs_embeds, 1027 | use_cache=use_cache, 1028 | output_attentions=output_attentions, 1029 | output_hidden_states=output_hidden_states, 1030 | return_dict=return_dict, 1031 | ) 1032 | 1033 | if not return_dict: 1034 | return decoder_outputs + encoder_outputs 1035 | 1036 | return Seq2SeqModelOutput( 1037 | last_hidden_state=decoder_outputs.last_hidden_state, 1038 | past_key_values=decoder_outputs.past_key_values, 1039 | decoder_hidden_states=decoder_outputs.hidden_states, 1040 | decoder_attentions=decoder_outputs.attentions, 1041 | cross_attentions=decoder_outputs.cross_attentions, 1042 | encoder_last_hidden_state=encoder_outputs.last_hidden_state if isinstance(encoder_outputs, dict) else None, 1043 | encoder_hidden_states=encoder_outputs.hidden_states if isinstance(encoder_outputs, dict) else None, 1044 | encoder_attentions=encoder_outputs.attentions if isinstance(encoder_outputs, dict) else None, 1045 | ) 1046 | 1047 | 1048 | @add_start_docstrings( 1049 | "The CPT Model with a language modeling head. Can be used for summarization.", CPT_START_DOCSTRING 1050 | ) 1051 | class CPTForConditionalGeneration(CPTPretrainedModel): 1052 | base_model_prefix = "model" 1053 | _keys_to_ignore_on_load_missing = [ 1054 | r"final_logits_bias", 1055 | r"encoder\.version", 1056 | r"decoder\.version", 1057 | r"lm_head\.weight", 1058 | ] 1059 | 1060 | def __init__(self, config): 1061 | super().__init__(config) 1062 | self.model = CPTModel(config) 1063 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 1064 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) 1065 | 1066 | self.init_weights() 1067 | 1068 | def get_encoder(self): 1069 | return self.model.get_encoder() 1070 | 1071 | def get_decoder(self): 1072 | return self.model.get_decoder() 1073 | 1074 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: 1075 | new_embeddings = super().resize_token_embeddings(new_num_tokens) 1076 | self._resize_final_logits_bias(new_num_tokens) 1077 | return new_embeddings 1078 | 1079 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None: 1080 | old_num_tokens = self.final_logits_bias.shape[-1] 1081 | if new_num_tokens <= old_num_tokens: 1082 | new_bias = self.final_logits_bias[:, :new_num_tokens] 1083 | else: 1084 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 1085 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 1086 | self.register_buffer("final_logits_bias", new_bias) 1087 | 1088 | def get_output_embeddings(self): 1089 | return self.lm_head 1090 | 1091 | def set_output_embeddings(self, new_embeddings): 1092 | self.lm_head = new_embeddings 1093 | 1094 | @add_start_docstrings_to_model_forward(CPT_INPUTS_DOCSTRING) 1095 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 1096 | def forward( 1097 | self, 1098 | input_ids=None, 1099 | attention_mask=None, 1100 | decoder_input_ids=None, 1101 | decoder_attention_mask=None, 1102 | head_mask=None, 1103 | 
decoder_head_mask=None, 1104 | encoder_outputs=None, 1105 | past_key_values=None, 1106 | inputs_embeds=None, 1107 | decoder_inputs_embeds=None, 1108 | labels=None, 1109 | use_cache=None, 1110 | output_attentions=None, 1111 | output_hidden_states=None, 1112 | return_dict=None, 1113 | ): 1114 | r""" 1115 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1116 | Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., 1117 | config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored 1118 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. 1119 | 1120 | Returns: 1121 | """ 1122 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1123 | 1124 | if labels is not None: 1125 | if decoder_input_ids is None: 1126 | decoder_input_ids = shift_tokens_right( 1127 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 1128 | ) 1129 | 1130 | outputs = self.model( 1131 | input_ids, 1132 | attention_mask=attention_mask, 1133 | decoder_input_ids=decoder_input_ids, 1134 | encoder_outputs=encoder_outputs, 1135 | decoder_attention_mask=decoder_attention_mask, 1136 | head_mask=head_mask, 1137 | decoder_head_mask=decoder_head_mask, 1138 | past_key_values=past_key_values, 1139 | inputs_embeds=inputs_embeds, 1140 | decoder_inputs_embeds=decoder_inputs_embeds, 1141 | use_cache=use_cache, 1142 | output_attentions=output_attentions, 1143 | output_hidden_states=output_hidden_states, 1144 | return_dict=return_dict, 1145 | ) 1146 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias 1147 | 1148 | masked_lm_loss = None 1149 | if labels is not None: 1150 | loss_fct = CrossEntropyLoss() 1151 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) 1152 | 1153 | if not return_dict: 1154 | output = (lm_logits,) + outputs[1:] 1155 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1156 | 1157 | return Seq2SeqLMOutput( 1158 | loss=masked_lm_loss, 1159 | logits=lm_logits, 1160 | past_key_values=outputs.past_key_values, 1161 | decoder_hidden_states=outputs.decoder_hidden_states, 1162 | decoder_attentions=outputs.decoder_attentions, 1163 | cross_attentions=outputs.cross_attentions, 1164 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1165 | encoder_hidden_states=outputs.encoder_hidden_states, 1166 | encoder_attentions=outputs.encoder_attentions, 1167 | ) 1168 | 1169 | def prepare_inputs_for_generation( 1170 | self, 1171 | decoder_input_ids, 1172 | past=None, 1173 | attention_mask=None, 1174 | head_mask=None, 1175 | use_cache=None, 1176 | encoder_outputs=None, 1177 | **kwargs 1178 | ): 1179 | # cut decoder_input_ids if past is used 1180 | if past is not None: 1181 | decoder_input_ids = decoder_input_ids[:, -1:] 1182 | 1183 | return { 1184 | "input_ids": None, # encoder_outputs is defined. 
input_ids not needed 1185 | "encoder_outputs": encoder_outputs, 1186 | "past_key_values": past, 1187 | "decoder_input_ids": decoder_input_ids, 1188 | "attention_mask": attention_mask, 1189 | "head_mask": head_mask, 1190 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 1191 | } 1192 | 1193 | @staticmethod 1194 | def _expand_inputs_for_generation( 1195 | input_ids: torch.LongTensor, 1196 | expand_size: int = 1, 1197 | is_encoder_decoder: bool = False, 1198 | attention_mask: torch.LongTensor = None, 1199 | encoder_outputs=None, 1200 | **model_kwargs, 1201 | ): 1202 | expanded_return_idx = ( 1203 | torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) 1204 | ) 1205 | input_ids = input_ids.index_select(0, expanded_return_idx) 1206 | 1207 | if "token_type_ids" in model_kwargs: 1208 | token_type_ids = model_kwargs["token_type_ids"] 1209 | model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) 1210 | 1211 | if attention_mask is not None: 1212 | model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) 1213 | 1214 | if is_encoder_decoder: 1215 | assert encoder_outputs is not None 1216 | device = encoder_outputs.last_hidden_state.device 1217 | encoder_outputs["hidden_states"] = tuple(h.index_select(0, expanded_return_idx.to(device)) \ 1218 | for h in encoder_outputs["hidden_states"]) 1219 | model_kwargs["encoder_outputs"] = encoder_outputs 1220 | return input_ids, model_kwargs 1221 | 1222 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 1223 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) 1224 | 1225 | @staticmethod 1226 | def _reorder_cache(past, beam_idx): 1227 | reordered_past = () 1228 | for layer_past in past: 1229 | # cached cross_attention states don't have to be reordered -> they are always the same 1230 | reordered_past += ( 1231 | tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], 1232 | ) 1233 | return reordered_past 1234 | 1235 | 1236 | class CopyGenerator(nn.Module): 1237 | def __init__(self, config): 1238 | super().__init__() 1239 | self.vocab_size = config.vocab_size 1240 | self.prob_proj = nn.Linear(config.d_model * 2, 1) 1241 | 1242 | def forward(self, src, decode_output, decode_attn, memory, gen_logits): 1243 | decode_attn = torch.mean(decode_attn, dim=1) 1244 | batch_size, steps, seq = decode_attn.size() 1245 | src = src.unsqueeze(1).repeat([1, steps, 1]) 1246 | # vocab 1247 | copy_logits = torch.zeros_like(gen_logits) 1248 | context = torch.matmul(decode_attn, memory) 1249 | copy_logits = copy_logits.scatter_add(2, src, decode_attn) 1250 | prob = self.prob_proj(torch.cat([context, decode_output], -1)).sigmoid() 1251 | 1252 | gen_logits = prob * gen_logits.softmax(-1) 1253 | copy_logits = (1 - prob) * copy_logits.softmax(-1) 1254 | final_logits = gen_logits + copy_logits 1255 | return final_logits 1256 | 1257 | 1258 | class T5Copy(T5ForConditionalGeneration): 1259 | def __init__(self, config): 1260 | super().__init__(config) 1261 | self.generator = CopyGenerator(config) 1262 | 1263 | def _prepare_encoder_decoder_kwargs_for_generation( 1264 | self, inputs_tensor, model_kwargs, model_input_name=None): 1265 | # 1. get encoder 1266 | encoder = self.get_encoder() 1267 | # 2. 
prepare encoder args and encoder kwargs from model kwargs 1268 | irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] 1269 | encoder_kwargs = { 1270 | argument: value 1271 | for argument, value in model_kwargs.items() 1272 | if not any(argument.startswith(p) for p in irrelevant_prefix) 1273 | } 1274 | # 3. make sure that encoder returns `ModelOutput` 1275 | model_input_name = model_input_name if model_input_name is not None else self.main_input_name 1276 | encoder_kwargs["return_dict"] = True 1277 | encoder_kwargs[model_input_name] = inputs_tensor 1278 | src = encoder_kwargs.pop('src') 1279 | model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs) 1280 | model_kwargs['src'] = src 1281 | return model_kwargs 1282 | 1283 | def prepare_inputs_for_generation( 1284 | self, 1285 | input_ids, 1286 | past_key_values=None, 1287 | attention_mask=None, 1288 | head_mask=None, 1289 | decoder_head_mask=None, 1290 | cross_attn_head_mask=None, 1291 | use_cache=None, 1292 | encoder_outputs=None, 1293 | **kwargs 1294 | ): 1295 | res = super().prepare_inputs_for_generation(input_ids, 1296 | past_key_values, 1297 | attention_mask, 1298 | head_mask, 1299 | decoder_head_mask, 1300 | cross_attn_head_mask, 1301 | use_cache, 1302 | encoder_outputs, 1303 | **kwargs) 1304 | res['src'] = kwargs['src'] 1305 | return res 1306 | 1307 | def forward( 1308 | self, 1309 | input_ids=None, 1310 | attention_mask=None, 1311 | decoder_input_ids=None, 1312 | decoder_attention_mask=None, 1313 | head_mask=None, 1314 | decoder_head_mask=None, 1315 | cross_attn_head_mask=None, 1316 | encoder_outputs=None, 1317 | past_key_values=None, 1318 | inputs_embeds=None, 1319 | decoder_inputs_embeds=None, 1320 | labels=None, 1321 | use_cache=None, 1322 | output_attentions=None, 1323 | output_hidden_states=None, 1324 | return_dict=None, 1325 | src=None 1326 | ): 1327 | outputs = super().forward(input_ids, 1328 | attention_mask, 1329 | decoder_input_ids, 1330 | decoder_attention_mask, 1331 | head_mask, 1332 | decoder_head_mask, 1333 | cross_attn_head_mask, 1334 | encoder_outputs, 1335 | past_key_values, 1336 | inputs_embeds, 1337 | decoder_inputs_embeds, 1338 | labels, 1339 | use_cache, 1340 | output_attentions=True, 1341 | output_hidden_states=True, 1342 | return_dict=True) 1343 | 1344 | memory = outputs.encoder_last_hidden_state 1345 | decode_attn = outputs.cross_attentions[-1] 1346 | decode_output = outputs.decoder_hidden_states[-1] 1347 | gen_logits = outputs.logits 1348 | if self.training: 1349 | prob = self.generator(input_ids, decode_output, decode_attn, memory, gen_logits) 1350 | else: 1351 | prob = self.generator(src, decode_output, decode_attn, memory, gen_logits) 1352 | outputs.logits = prob 1353 | return outputs 1354 | 1355 | 1356 | class ProphetNetModel(OldProphetNetModel): 1357 | def __init__(self, config): 1358 | super(ProphetNetPreTrainedModel, self).__init__(config) 1359 | self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) 1360 | 1361 | encoder_config = copy.deepcopy(config) 1362 | encoder_config.is_encoder_decoder = False 1363 | encoder_config.use_cache = False 1364 | encoder_config.max_position_embeddings = config.encoder_max_position_embeddings 1365 | self.encoder = ProphetNetEncoder(encoder_config, self.word_embeddings) 1366 | 1367 | decoder_config = copy.deepcopy(config) 1368 | decoder_config.is_decoder = True 1369 | decoder_config.is_encoder_decoder = False 1370 | decoder_config.max_position_embeddings = config.decoder_max_position_embeddings 1371 | 
self.decoder = ProphetNetDecoder(decoder_config, self.word_embeddings) 1372 | 1373 | # Initialize weights and apply final processing 1374 | self.post_init() 1375 | 1376 | 1377 | class ProphetNetForConditionalGeneration(OldProphetNetForConditionalGeneration): 1378 | def __init__(self, config): 1379 | super(ProphetNetPreTrainedModel, self).__init__(config) 1380 | self.prophetnet = ProphetNetModel(config) 1381 | self.padding_idx = config.pad_token_id 1382 | self.disable_ngram_loss = config.disable_ngram_loss 1383 | 1384 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1385 | 1386 | self.post_init() 1387 | 1388 | 1389 | MODEL_CLASSES = { 1390 | 'cpt': [CPTForConditionalGeneration, CPTTokenizer], 1391 | 't5copy': [T5Copy, JieBaTokenizer], 1392 | 't5-pegasus': [T5ForConditionalGeneration, JieBaTokenizer], 1393 | 'pegasus': [PegasusForConditionalGeneration, PegasusTokenizer], 1394 | 'lm': [AutoModelForCausalLM, AutoTokenizer], 1395 | 'seq2seq': [AutoModelForSeq2SeqLM, AutoTokenizer], 1396 | 't5': [T5ForConditionalGeneration, T5Tokenizer], 1397 | 'prophet': [ProphetNetForConditionalGeneration, CPTTokenizer] 1398 | } 1399 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --train_file /home/xianglingyang/data/faith_gen/LCSTS_new/small_train.json \ 3 | --dev_file /home/xianglingyang/data/faith_gen/LCSTS_new/small_dev.json \ 4 | --batch_size 6 \ 5 | --max_epochs 10 \ 6 | --max_source_length 512 \ 7 | --max_target_length 300 \ 8 | --model_path /home/xianglingyang/pretrained_models/torch/t5-copy \ 9 | --gpus 4 \ 10 | --lr 5e-5 --model_type t5copy 11 | 12 | 13 | python train.py \ 14 | --train_file /home/xianglingyang/data/faith_gen/LCSTS_new/small_train.json \ 15 | --dev_file /home/xianglingyang/data/faith_gen/LCSTS_new/small_dev.json \ 16 | --batch_size 6 \ 17 | --max_epochs 10 \ 18 | --max_source_length 512 \ 19 | --max_target_length 150 \ 20 | --model_path /home/xianglingyang/pretrained_models/torch/t5-copy \ 21 | --gpus 4 \ 22 | --lr 5e-5 --model_type t5-pegasus 23 | 24 | 25 | python train.py \ 26 | --train_file /home/xianglingyang/data/faith_gen/LCSTS_new/small_train.json \ 27 | --dev_file /home/xianglingyang/data/faith_gen/LCSTS_new/small_dev.json \ 28 | --batch_size 6 \ 29 | --max_epochs 10 \ 30 | --max_source_length 512 \ 31 | --max_target_length 300 \ 32 | --model_path /home/xianglingyang/pretrained_models/torch/cpt-large \ 33 | --gpus 4 \ 34 | --lr 5e-5 --model_type cpt --rdrop 35 | 36 | python train.py \ 37 | --train_file /home/xianglingyang/data/faith_gen/LCSTS_new/small_train.json \ 38 | --dev_file /home/xianglingyang/data/faith_gen/LCSTS_new/small_dev.json \ 39 | --batch_size 6 \ 40 | --max_epochs 10 \ 41 | --max_source_length 512 \ 42 | --max_target_length 300 \ 43 | --model_path /home/xianglingyang/pretrained_models/torch/prophet \ 44 | --gpus 4 \ 45 | --lr 5e-5 --model_type prophet -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import unicodedata 4 | import collections 5 | from functools import partial 6 | from typing import List, Optional, Tuple, Union 7 | 8 | import jieba 9 | from transformers import PreTrainedTokenizer, BertTokenizer, T5Tokenizer as OldT5Tokenizer, AutoTokenizer 10 | from transformers import logging 11 | 12 | jieba.initialize() 13 | 
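# Note on the custom tokenizers in this module: PegasusTokenizer and JieBaTokenizer
# first segment the input with jieba, keep any whole word that already exists in the
# vocab as a single token, and only fall back to BERT-style WordPiece for the
# out-of-vocabulary pieces.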
14 | logger = logging.get_logger(__name__) 15 | 16 | VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} 17 | 18 | 19 | class PegasusTokenizer(PreTrainedTokenizer): 20 | r""" 21 | Construct a Pegasus tokenizer. Based on WordPiece. 22 | This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to 23 | this superclass for more information regarding those methods. 24 | Args: 25 | vocab_file (`str`): 26 | File containing the vocabulary. 27 | do_lower_case (`bool`, *optional*, defaults to `True`): 28 | Whether or not to lowercase the input when tokenizing. 29 | do_basic_tokenize (`bool`, *optional*, defaults to `True`): 30 | Whether or not to do basic tokenization before WordPiece. 31 | never_split (`Iterable`, *optional*): 32 | Collection of tokens which will never be split during tokenization. Only has an effect when 33 | `do_basic_tokenize=True` 34 | unk_token (`str`, *optional*, defaults to `"[UNK]"`): 35 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 36 | token instead. 37 | sep_token (`str`, *optional*, defaults to `"[SEP]"`): 38 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 39 | sequence classification or for a text and a question for question answering. It is also used as the last 40 | token of a sequence built with special tokens. 41 | pad_token (`str`, *optional*, defaults to `"[PAD]"`): 42 | The token used for padding, for example when batching sequences of different lengths. 43 | cls_token (`str`, *optional*, defaults to `"[CLS]"`): 44 | The classifier token which is used when doing sequence classification (classification of the whole sequence 45 | instead of per-token classification). It is the first token of the sequence when built with special tokens. 46 | mask_token (`str`, *optional*, defaults to `"[MASK]"`): 47 | The token used for masking values. This is the token used when training this model with masked language 48 | modeling. This is the token which the model will try to predict. 49 | tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): 50 | Whether or not to tokenize Chinese characters. 51 | This should likely be deactivated for Japanese (see this 52 | [issue](https://github.com/huggingface/transformers/issues/328)). 53 | strip_accents (`bool`, *optional*): 54 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 55 | value for `lowercase` (as in the original BERT). 
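    Example (the checkpoint directory below is a placeholder for any local directory
    containing a compatible `vocab.txt`, not something shipped with this repo):

        >>> tokenizer = PegasusTokenizer.from_pretrained("/path/to/chinese-pegasus")
        >>> tokenizer.tokenize("今天天气不错")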
56 | """ 57 | 58 | vocab_files_names = VOCAB_FILES_NAMES 59 | model_input_names = ["input_ids", "attention_mask"] 60 | 61 | def __init__(self, 62 | vocab_file, 63 | do_lower_case=True, 64 | do_basic_tokenize=True, 65 | never_split=None, 66 | pad_token="", 67 | eos_token="", 68 | unk_token="", 69 | mask_token="", 70 | mask_token_sent="", 71 | additional_special_tokens=None, 72 | sep_token="[SEP]", 73 | cls_token="[CLS]", 74 | tokenize_chinese_chars=True, 75 | strip_accents=None, 76 | offset=100, 77 | pre_tokenizer=lambda x: jieba.cut(x, HMM=False), 78 | **kwargs): 79 | self.offset = offset 80 | 81 | if additional_special_tokens is not None: 82 | if not isinstance(additional_special_tokens, list): 83 | raise TypeError( 84 | f"additional_special_tokens should be of type {type(list)}, \ 85 | but is {type(additional_special_tokens)}" 86 | ) 87 | 88 | additional_special_tokens_extended = ( 89 | ([mask_token_sent] + additional_special_tokens) 90 | if mask_token_sent not in additional_special_tokens 91 | and mask_token_sent is not None else additional_special_tokens) 92 | 93 | # fill additional tokens with ..., in case not all additional tokens are already taken 94 | additional_special_tokens_extended += [ 95 | f"" for i in range( 96 | len(additional_special_tokens_extended), self.offset - 1) 97 | ] 98 | 99 | if len(set(additional_special_tokens_extended)) != len( 100 | additional_special_tokens_extended): 101 | raise ValueError( 102 | f"Please make sure that the provided additional_special_tokens \ 103 | do not contain an incorrectly shifted list of tokens. \ 104 | Found {additional_special_tokens_extended}." 105 | ) 106 | additional_special_tokens = additional_special_tokens_extended 107 | else: 108 | additional_special_tokens = [ 109 | mask_token_sent 110 | ] if mask_token_sent is not None else [] 111 | # additional_special_tokens += [f"" for i in range(3, self.offset)] 112 | 113 | # print("additional_special_tokens: ", additional_special_tokens) 114 | 115 | if not os.path.isfile(vocab_file): 116 | raise ValueError( 117 | f"Can't find a vocabulary file at path '{vocab_file}'. 
\ 118 | To load the vocabulary from a Google pretrained " 119 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`" 120 | ) 121 | 122 | super().__init__( 123 | do_lower_case=do_lower_case, 124 | do_basic_tokenize=do_basic_tokenize, 125 | never_split=never_split, 126 | unk_token=unk_token, 127 | sep_token=sep_token, 128 | pad_token=pad_token, 129 | cls_token=cls_token, 130 | mask_token=mask_token, 131 | eos_token=eos_token, 132 | tokenize_chinese_chars=tokenize_chinese_chars, 133 | additional_special_tokens=additional_special_tokens, 134 | strip_accents=strip_accents, 135 | **kwargs, 136 | ) 137 | 138 | self.pre_tokenizer = pre_tokenizer 139 | self.mask_token_sent = mask_token_sent 140 | self.vocab = load_vocab(vocab_file) 141 | 142 | self.vocab[self.eos_token] = self.vocab.pop("[unused1]") 143 | # self.vocab[self.eos_token] = self.vocab.pop("[unused2]") 144 | self.vocab[self.pad_token] = self.vocab.pop("[PAD]") 145 | self.vocab[self.unk_token] = self.vocab.pop("[UNK]") 146 | 147 | if self.mask_token_sent is not None: 148 | self.vocab[self.mask_token] = self.vocab.pop("[unused3]") 149 | self.vocab[self.mask_token_sent] = self.vocab.pop("[unused2]") 150 | 151 | self.ids_to_tokens = collections.OrderedDict([ 152 | (ids, tok) for tok, ids in self.vocab.items() 153 | ]) 154 | self.do_basic_tokenize = do_basic_tokenize 155 | if do_basic_tokenize: 156 | self.basic_tokenizer = BasicTokenizer( 157 | do_lower_case=do_lower_case, 158 | never_split=never_split, 159 | tokenize_chinese_chars=tokenize_chinese_chars, 160 | strip_accents=strip_accents, 161 | ) 162 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, 163 | unk_token=self.unk_token) 164 | self.target_mode = False 165 | 166 | @property 167 | def do_lower_case(self): 168 | return self.basic_tokenizer.do_lower_case 169 | 170 | @property 171 | def vocab_size(self): 172 | return len(self.vocab) 173 | 174 | def get_vocab(self): 175 | return dict(self.vocab, **self.added_tokens_encoder) 176 | 177 | def _tokenize(self, text): 178 | split_tokens = [] 179 | # print("pegasus_tokenizer: ", text) 180 | for text in self.pre_tokenizer(text): 181 | if text in self.vocab: 182 | split_tokens.append(text) 183 | else: 184 | if self.do_basic_tokenize: 185 | for token in self.basic_tokenizer.tokenize( 186 | text, never_split=self.all_special_tokens): 187 | 188 | # If the token is part of the never_split set 189 | if token in self.basic_tokenizer.never_split: 190 | split_tokens.append(token) 191 | else: 192 | split_tokens += self.wordpiece_tokenizer.tokenize( 193 | token) 194 | else: 195 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 196 | return split_tokens 197 | 198 | def _convert_token_to_id(self, token): 199 | """Converts a token (str) in an id using the vocab.""" 200 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 201 | 202 | def _convert_id_to_token(self, index): 203 | """Converts an index (integer) in a token (str) using the vocab.""" 204 | return self.ids_to_tokens.get(index, self.unk_token) 205 | 206 | @staticmethod 207 | def _cjk_punctuation(): 208 | return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\ 209 | \uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\ 210 | \uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\ 211 | \u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\ 212 | 
\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002' 213 | 214 | def convert_ids_to_tokens( 215 | self, 216 | ids: Union[int, List[int]], 217 | skip_special_tokens: bool = False) -> Union[str, List[str]]: 218 | """ 219 | Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and 220 | added tokens. 221 | Args: 222 | ids (`int` or `List[int]`): 223 | The token id (or token ids) to convert to tokens. 224 | skip_special_tokens (`bool`, *optional*, defaults to `False`): 225 | Whether or not to remove special tokens in the decoding. 226 | Returns: 227 | `str` or `List[str]`: The decoded token(s). 228 | """ 229 | if isinstance(ids, int): 230 | if ids in self.added_tokens_decoder: 231 | return self.added_tokens_decoder[ids] 232 | else: 233 | return self._convert_id_to_token(ids) 234 | tokens = [] 235 | for index in ids: 236 | index = int(index) 237 | if skip_special_tokens and index in self.all_special_ids and index != 2: 238 | continue 239 | if index in self.added_tokens_decoder: 240 | tokens.append(self.added_tokens_decoder[index]) 241 | else: 242 | tokens.append(self._convert_id_to_token(index)) 243 | return tokens 244 | 245 | def convert_tokens_to_string(self, tokens): 246 | """Converts a sequence of tokens (string) in a single string.""" 247 | # for token in 248 | # tokens = tokens or self.ids_to_tokens(ids) 249 | # tokens = [token for token in tokens if not self._is_special(token)] 250 | 251 | text = '' 252 | for i, token in enumerate(tokens): 253 | if token[:2] == '##': 254 | text += token[2:] 255 | elif len(token) == 1 and _is_chinese_char(ord(token)): 256 | text += token 257 | elif len(token) == 1 and _is_punctuation(token): 258 | text += token 259 | text += ' ' 260 | elif i > 0 and _is_chinese_char(ord(text[-1])): 261 | text += token 262 | elif tokens == "": 263 | continue 264 | else: 265 | text += ' ' 266 | text += token 267 | 268 | text = re.sub(' +', ' ', text) 269 | text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text) 270 | punctuation = re.sub(' +', '', self._cjk_punctuation()).strip() + '+-/={(<[' 271 | punctuation_regex = '|'.join([re.escape(p) for p in punctuation]) 272 | punctuation_regex = '(%s) ' % punctuation_regex 273 | text = re.sub(punctuation_regex, '\\1', text) 274 | text = re.sub(r'(\d\.) (\d)', '\\1\\2', text) 275 | 276 | return text.strip() 277 | # out_string = " ".join(tokens).replace(" ##", "").strip() 278 | 279 | def build_inputs_with_special_tokens( 280 | self, 281 | token_ids_0: List[int], 282 | token_ids_1: Optional[List[int]] = None) -> List[int]: 283 | if not self.target_mode: 284 | return token_ids_0 + [self.eos_token_id] 285 | return [self.pad_token_id] + token_ids_0 + [self.eos_token_id] 286 | 287 | def _special_token_mask(self, seq): 288 | all_special_ids = set( 289 | self.all_special_ids) # call it once instead of inside list comp 290 | # all_special_ids.remove(self.unk_token_id) # is only sometimes special 291 | 292 | return [1 if x in all_special_ids else 0 for x in seq] 293 | 294 | def get_special_tokens_mask( 295 | self, 296 | token_ids_0: List[int], 297 | token_ids_1: Optional[List[int]] = None, 298 | already_has_special_tokens: bool = False) -> List[int]: 299 | """ 300 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 301 | special tokens using the tokenizer `prepare_for_model` method. 302 | Args: 303 | token_ids_0 (`List[int]`): 304 | List of IDs. 
305 | token_ids_1 (`List[int]`, *optional*): 306 | Optional second list of IDs for sequence pairs. 307 | already_has_special_tokens (`bool`, *optional*, defaults to `False`): 308 | Whether or not the token list is already formatted with special tokens for the model. 309 | Returns: 310 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 311 | """ 312 | 313 | if already_has_special_tokens: 314 | return self._special_token_mask(token_ids_0) 315 | elif token_ids_1 is None: 316 | return self._special_token_mask(token_ids_0) + [self.eos_token_id] 317 | else: 318 | return self._special_token_mask(token_ids_0 + 319 | token_ids_1) + [self.eos_token_id] 320 | 321 | def num_special_tokens_to_add(self, pair=False): 322 | """Just EOS""" 323 | return 1 324 | 325 | def save_vocabulary(self, 326 | save_directory: str, 327 | filename_prefix: Optional[str] = None) -> Tuple[str]: 328 | index = 0 329 | if os.path.isdir(save_directory): 330 | vocab_file = os.path.join( 331 | save_directory, 332 | (filename_prefix + "-" if filename_prefix else "") + 333 | VOCAB_FILES_NAMES["vocab_file"]) 334 | else: 335 | vocab_file = (filename_prefix + 336 | "-" if filename_prefix else "") + save_directory 337 | with open(vocab_file, "w", encoding="utf-8") as writer: 338 | for token, token_index in sorted(self.vocab.items(), 339 | key=lambda kv: kv[1]): 340 | if index != token_index: 341 | logger.warning( 342 | f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive." 343 | " Please check that the vocabulary is not corrupted!") 344 | index = token_index 345 | writer.write(token + "\n") 346 | index += 1 347 | return (vocab_file,) 348 | 349 | def _switch_to_input_mode(self): 350 | self.target_mode = False 351 | 352 | def _switch_to_target_mode(self): 353 | self.target_mode = True 354 | 355 | @property 356 | def bos_token_id(self) -> Optional[int]: 357 | return 0 358 | 359 | 360 | class CPTTokenizer(BertTokenizer): 361 | 362 | def __init__(self, *args, **kwargs): 363 | super().__init__(*args, **kwargs) 364 | self.target_mode = False 365 | 366 | def build_inputs_with_special_tokens( 367 | self, 368 | token_ids_0: List[int], 369 | token_ids_1: Optional[List[int]] = None) -> List[int]: 370 | if not self.target_mode: 371 | return token_ids_0 + [self.eos_token_id] 372 | return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] 373 | 374 | def _switch_to_input_mode(self): 375 | self.target_mode = False 376 | 377 | def _switch_to_target_mode(self): 378 | self.target_mode = True 379 | 380 | @property 381 | def bos_token_id(self) -> Optional[int]: 382 | return self.cls_token_id 383 | 384 | @property 385 | def eos_token_id(self) -> Optional[int]: 386 | return self.sep_token_id 387 | 388 | 389 | class JieBaTokenizer(CPTTokenizer): 390 | def __init__(self, *args, **kwargs): 391 | super().__init__(*args, **kwargs) 392 | self.pre_tokenizer = partial(jieba.cut, HMM=False) 393 | 394 | def _tokenize(self, text, *arg, **kwargs): 395 | split_tokens = [] 396 | for text in self.pre_tokenizer(text): 397 | if text in self.vocab: 398 | split_tokens.append(text) 399 | else: 400 | split_tokens.extend(super()._tokenize(text)) 401 | return split_tokens 402 | 403 | 404 | class T5Tokenizer(OldT5Tokenizer): 405 | 406 | def __init__(self, *args, **kwargs): 407 | super().__init__(*args, **kwargs) 408 | self.target_mode = False 409 | 410 | def build_inputs_with_special_tokens( 411 | self, 412 | token_ids_0: List[int], 413 | token_ids_1: Optional[List[int]] = None) -> List[int]: 414 | if 
not self.target_mode: 415 | return token_ids_0 + [self.eos_token_id] 416 | return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] 417 | 418 | def _switch_to_input_mode(self): 419 | self.target_mode = False 420 | 421 | def _switch_to_target_mode(self): 422 | self.target_mode = True 423 | 424 | @property 425 | def bos_token_id(self) -> Optional[int]: 426 | return 0 427 | 428 | 429 | class BasicTokenizer(object): 430 | """ 431 | Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). 432 | Args: 433 | do_lower_case (`bool`, *optional*, defaults to `True`): 434 | Whether or not to lowercase the input when tokenizing. 435 | never_split (`Iterable`, *optional*): 436 | Collection of tokens which will never be split during tokenization. Only has an effect when 437 | `do_basic_tokenize=True` 438 | tokenize_chinese_chars (`bool`, *optional*, defaults to `True`): 439 | Whether or not to tokenize Chinese characters. 440 | This should likely be deactivated for Japanese (see this 441 | [issue](https://github.com/huggingface/transformers/issues/328)). 442 | strip_accents: (`bool`, *optional*): 443 | Whether or not to strip all accents. If this option is not specified, then it will be determined by the 444 | value for `lowercase` (as in the original BERT). 445 | """ 446 | 447 | def __init__(self, 448 | do_lower_case=True, 449 | never_split=None, 450 | tokenize_chinese_chars=True, 451 | strip_accents=None): 452 | if never_split is None: 453 | never_split = [] 454 | self.do_lower_case = do_lower_case 455 | self.never_split = set(never_split) 456 | self.tokenize_chinese_chars = tokenize_chinese_chars 457 | self.strip_accents = strip_accents 458 | 459 | def tokenize(self, text, never_split=None): 460 | """ 461 | Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see 462 | WordPieceTokenizer. 463 | Args: 464 | never_split (`List[str]`, *optional*) 465 | Kept for backward compatibility purposes. Now implemented directly at the base class level (see 466 | [`PreTrainedTokenizer.tokenize`]) List of token not to split. 467 | """ 468 | # union() returns a new set by concatenating the two sets. 469 | never_split = self.never_split.union( 470 | set(never_split)) if never_split else self.never_split 471 | text = self._clean_text(text) 472 | 473 | # This was added on November 1st, 2018 for the multilingual and Chinese 474 | # models. This is also applied to the English models now, but it doesn't 475 | # matter since the English models were not trained on any Chinese data 476 | # and generally don't have any Chinese data in them (there are Chinese 477 | # characters in the vocabulary because Wikipedia does have some Chinese 478 | # words in the English Wikipedia.). 
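        # _tokenize_chinese_chars pads every CJK codepoint with spaces, e.g. "今天NLP"
        # becomes " 今  天 NLP", so the whitespace split below yields ["今", "天", "NLP"].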
479 | if self.tokenize_chinese_chars: 480 | text = self._tokenize_chinese_chars(text) 481 | orig_tokens = whitespace_tokenize(text) 482 | split_tokens = [] 483 | for token in orig_tokens: 484 | if token not in never_split: 485 | if self.do_lower_case: 486 | token = token.lower() 487 | if self.strip_accents is not False: 488 | token = self._run_strip_accents(token) 489 | elif self.strip_accents: 490 | token = self._run_strip_accents(token) 491 | split_tokens.extend(self._run_split_on_punc(token, never_split)) 492 | 493 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 494 | return output_tokens 495 | 496 | def _run_strip_accents(self, text): 497 | """Strips accents from a piece of text.""" 498 | text = unicodedata.normalize("NFD", text) 499 | output = [] 500 | for char in text: 501 | cat = unicodedata.category(char) 502 | if cat == "Mn": 503 | continue 504 | output.append(char) 505 | return "".join(output) 506 | 507 | def _run_split_on_punc(self, text, never_split=None): 508 | """Splits punctuation on a piece of text.""" 509 | if never_split is not None and text in never_split: 510 | return [text] 511 | chars = list(text) 512 | i = 0 513 | start_new_word = True 514 | output = [] 515 | while i < len(chars): 516 | char = chars[i] 517 | if _is_punctuation(char): 518 | output.append([char]) 519 | start_new_word = True 520 | else: 521 | if start_new_word: 522 | output.append([]) 523 | start_new_word = False 524 | output[-1].append(char) 525 | i += 1 526 | 527 | return ["".join(x) for x in output] 528 | 529 | def _tokenize_chinese_chars(self, text): 530 | """Adds whitespace around any CJK character.""" 531 | output = [] 532 | for char in text: 533 | cp = ord(char) 534 | if self._is_chinese_char(cp): 535 | output.append(" ") 536 | output.append(char) 537 | output.append(" ") 538 | else: 539 | output.append(char) 540 | return "".join(output) 541 | 542 | def _is_chinese_char(self, cp): 543 | """Checks whether CP is the codepoint of a CJK character.""" 544 | # This defines a "chinese character" as anything in the CJK Unicode block: 545 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 546 | # 547 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 548 | # despite its name. The modern Korean Hangul alphabet is a different block, 549 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 550 | # space-separated words, so they are not treated specially and handled 551 | # like the all of the other languages. 
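        # The ranges below are the CJK Unified Ideographs block (4E00-9FFF), Extension A
        # (3400-4DBF), Extensions B through E, and the two CJK Compatibility Ideographs
        # blocks (F900-FAFF and 2F800-2FA1F).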
552 | if ((cp >= 0x4E00 and cp <= 0x9FFF) 553 | or (cp >= 0x3400 and cp <= 0x4DBF) # 554 | or (cp >= 0x20000 and cp <= 0x2A6DF) # 555 | or (cp >= 0x2A700 and cp <= 0x2B73F) # 556 | or (cp >= 0x2B740 and cp <= 0x2B81F) # 557 | or (cp >= 0x2B820 and cp <= 0x2CEAF) # 558 | or (cp >= 0xF900 and cp <= 0xFAFF) 559 | or (cp >= 0x2F800 and cp <= 0x2FA1F)): # 560 | return True 561 | 562 | return False 563 | 564 | def _clean_text(self, text): 565 | """Performs invalid character removal and whitespace cleanup on text.""" 566 | output = [] 567 | for char in text: 568 | cp = ord(char) 569 | if cp == 0 or cp == 0xFFFD or _is_control(char): 570 | continue 571 | if _is_whitespace(char): 572 | output.append(" ") 573 | else: 574 | output.append(char) 575 | return "".join(output) 576 | 577 | 578 | class WordpieceTokenizer(object): 579 | """Runs WordPiece tokenization.""" 580 | 581 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 582 | self.vocab = vocab 583 | self.unk_token = unk_token 584 | self.max_input_chars_per_word = max_input_chars_per_word 585 | 586 | def tokenize(self, text): 587 | """ 588 | Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform 589 | tokenization using the given vocabulary. 590 | For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`. 591 | Args: 592 | text: A single token or whitespace separated tokens. This should have 593 | already been passed through *BasicTokenizer*. 594 | Returns: 595 | A list of wordpiece tokens. 596 | """ 597 | 598 | output_tokens = [] 599 | for token in whitespace_tokenize(text): 600 | chars = list(token) 601 | if len(chars) > self.max_input_chars_per_word: 602 | output_tokens.append(self.unk_token) 603 | continue 604 | 605 | is_bad = False 606 | start = 0 607 | sub_tokens = [] 608 | while start < len(chars): 609 | end = len(chars) 610 | cur_substr = None 611 | while start < end: 612 | substr = "".join(chars[start:end]) 613 | if start > 0: 614 | substr = "##" + substr 615 | if substr in self.vocab: 616 | cur_substr = substr 617 | break 618 | end -= 1 619 | if cur_substr is None: 620 | is_bad = True 621 | break 622 | sub_tokens.append(cur_substr) 623 | start = end 624 | 625 | if is_bad: 626 | output_tokens.append(self.unk_token) 627 | else: 628 | output_tokens.extend(sub_tokens) 629 | return output_tokens 630 | 631 | 632 | def _is_chinese_char(cp): 633 | """Checks whether CP is the codepoint of a CJK character.""" 634 | # This defines a "chinese character" as anything in the CJK Unicode block: 635 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 636 | # 637 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 638 | # despite its name. The modern Korean Hangul alphabet is a different block, 639 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 640 | # space-separated words, so they are not treated specially and handled 641 | # like the all of the other languages. 
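    # Module-level twin of BasicTokenizer._is_chinese_char; PegasusTokenizer.convert_tokens_to_string
    # above calls this version when deciding whether a token should be glued to the preceding character.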
642 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF) 643 | or (cp >= 0x20000 and cp <= 0x2A6DF) 644 | or (cp >= 0x2A700 and cp <= 0x2B73F) 645 | or (cp >= 0x2B740 and cp <= 0x2B81F) 646 | or (cp >= 0x2B820 and cp <= 0x2CEAF) 647 | or (cp >= 0xF900 and cp <= 0xFAFF) 648 | or (cp >= 0x2F800 and cp <= 0x2FA1F)): 649 | return True 650 | 651 | return False 652 | 653 | 654 | def _is_whitespace(char): 655 | """Checks whether `char` is a whitespace character.""" 656 | # \t, \n, and \r are technically control characters but we treat them 657 | # as whitespace since they are generally considered as such. 658 | if char == " " or char == "\t" or char == "\n" or char == "\r": 659 | return True 660 | cat = unicodedata.category(char) 661 | if cat == "Zs": 662 | return True 663 | return False 664 | 665 | 666 | def _is_control(char): 667 | """Checks whether `char` is a control character.""" 668 | # These are technically control characters but we count them as whitespace 669 | # characters. 670 | if char == "\t" or char == "\n" or char == "\r": 671 | return False 672 | cat = unicodedata.category(char) 673 | if cat.startswith("C"): 674 | return True 675 | return False 676 | 677 | 678 | def _is_punctuation(char): 679 | """Checks whether `char` is a punctuation character.""" 680 | cp = ord(char) 681 | # We treat all non-letter/number ASCII as punctuation. 682 | # Characters such as "^", "$", and "`" are not in the Unicode 683 | # Punctuation class but we treat them as punctuation anyways, for 684 | # consistency. 685 | if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or ( 686 | cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): 687 | return True 688 | cat = unicodedata.category(char) 689 | if cat.startswith("P"): 690 | return True 691 | return False 692 | 693 | 694 | def load_vocab(vocab_file): 695 | """Loads a vocabulary file into a dictionary.""" 696 | vocab = collections.OrderedDict() 697 | with open(vocab_file, "r", encoding="utf-8") as reader: 698 | tokens = reader.readlines() 699 | for index, token in enumerate(tokens): 700 | token = token.rstrip("\n") 701 | vocab[token] = index 702 | return vocab 703 | 704 | 705 | def whitespace_tokenize(text): 706 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 707 | text = text.strip() 708 | if not text: 709 | return [] 710 | tokens = text.split() 711 | return tokens 712 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytorch_lightning as pl 3 | from utils import * 4 | from models import LightModel 5 | from args import parser 6 | 7 | if __name__ == '__main__': 8 | 9 | args = parser.parse_args() 10 | 11 | model = LightModel(args) 12 | data = EncoderDecoderData(args, model.tokenizer) 13 | dataloaders = data.get_dataloader() 14 | 15 | for fold in range(args.kfold): 16 | pl.seed_everything(args.seed + fold) 17 | train_data, dev_data = dataloaders['train'][fold], dataloaders['dev'][fold] 18 | if fold > 0: 19 | model = LightModel(args) 20 | checkpoint = pl.callbacks.ModelCheckpoint( 21 | dirpath=args.output_dir, 22 | filename='{fold:02d}-{epoch:02d}-{bleu:.4f}-{rouge:.4f}-{rouge-1:.4f}-{rouge-2:.4f}-{rouge-l:.4f}', 23 | save_weights_only=True, 24 | save_on_train_epoch_end=True, 25 | monitor='rouge', 26 | mode='max', 27 | ) 28 | trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint], logger=False) 29 | trainer.fit(model, train_data, dev_data) 30 | del model 
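        # the trainer is dropped as well and the CUDA cache cleared so that the next fold
        # starts from freshly released GPU memory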
31 |         del trainer
32 |         torch.cuda.empty_cache()
33 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import BertTokenizer
2 | from functools import partial
3 | import json
4 | import jieba
5 | import torch
6 | import torch.nn.functional as F
7 | import numpy as np
8 | from torch.utils.data import DataLoader, Dataset, Subset
9 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
10 | import rouge
11 | import re
12 | from transformers import AdamW
13 | import collections
14 | 
15 | rouge = rouge.Rouge()
16 | smooth = SmoothingFunction().method1
17 | 
18 | 
19 | class EncoderDecoderData:
20 |     def __init__(self, args, tokenizer, ):
21 |         self.train_data = self.read_file(args.train_file) if args.train_file else None
22 |         self.dev_data = self.read_file(args.dev_file) if args.dev_file else None
23 |         self.predict_data = self.read_file(args.predict_file) if args.predict_file else None
24 |         self.args = args
25 |         self.tokenizer = tokenizer
26 | 
27 |         if self.args.noise_prob > 0:
28 |             self.vocab_pool = list(set(range(len(tokenizer))) - set(tokenizer.all_special_ids))
29 | 
30 |     def get_predict_dataloader(self):
31 |         predict_dataset = KeyDataset(self.predict_data)
32 |         predict_dataloader = DataLoader(predict_dataset,
33 |                                         batch_size=self.args.batch_size * 2,
34 |                                         collate_fn=self.predict_collate)
35 |         return predict_dataloader
36 | 
37 |     def read_file(self, file):
38 |         return [json.loads(x) for x in open(file, encoding='utf-8')]
39 | 
40 |     def encode_src(self, src):
41 |         res = self.tokenizer(src,
42 |                              padding=True,
43 |                              return_tensors='pt',
44 |                              max_length=self.args.max_source_length,
45 |                              truncation='longest_first',
46 |                              return_attention_mask=True,
47 |                              return_token_type_ids=False)
48 |         return res
49 | 
50 |     def train_collate(self, batch):
51 |         if isinstance(batch[0], list):
52 |             batch = batch[0] # max_token_dataset
53 |         src = [b['src'] for b in batch]
54 |         tgt = [b['tgt'] for b in batch]
55 | 
56 |         src_tokenized = self.encode_src(src)
57 |         with self.tokenizer.as_target_tokenizer():
58 |             tgt_tokenized = self.tokenizer(
59 |                 tgt,
60 |                 max_length=self.args.max_target_length,
61 |                 padding=True,
62 |                 return_tensors='pt',
63 |                 truncation='longest_first')
64 | 
65 |         decoder_attention_mask = tgt_tokenized['attention_mask'][:, :-1]
66 |         decoder_input_ids = tgt_tokenized['input_ids'][:, :-1]
67 | 
68 |         labels = tgt_tokenized['input_ids'][:, 1:].clone()
69 |         labels.masked_fill_(labels == self.tokenizer.pad_token_id, -100)
70 | 
71 |         if self.args.noise_prob > 0:
72 |             noise_indices = torch.rand_like(labels, dtype=torch.float) < self.args.noise_prob  # rand_like needs a float dtype here, labels is a LongTensor
73 |             noise_indices = noise_indices & (decoder_input_ids != self.tokenizer.bos_token_id) \
74 |                             & (labels != self.tokenizer.eos_token_id) & decoder_attention_mask.bool()
75 |             noise_inp = torch.from_numpy(np.random.choice(self.vocab_pool, decoder_input_ids.shape))  # torch.where expects tensors, not numpy arrays
76 |             decoder_input_ids = torch.where(noise_indices, noise_inp, decoder_input_ids)
77 | 
78 |         res = {'input_ids': src_tokenized['input_ids'],
79 |                'attention_mask': src_tokenized['attention_mask'],
80 |                'decoder_input_ids': decoder_input_ids,
81 |                'decoder_attention_mask': decoder_attention_mask,
82 |                'labels': labels}
83 |         return res
84 | 
85 |     def dev_collate(self, batch):
86 |         return self.train_collate(batch)
87 | 
88 |     def predict_collate(self, batch):
89 |         src = [x['src'] for x in batch]
90 |         return self.encode_src(src)
91 | 
92 |     def get_dataloader(self):
93 |         ret = {'train': [], 'dev': []}
94 | 
base_dataset = KeyDataset(self.train_data) 95 | if self.args.kfold > 1: 96 | from sklearn.model_selection import KFold 97 | for train_idx, dev_idx in KFold(n_splits=self.args.kfold, shuffle=True, 98 | random_state=self.args.seed).split(range(len(self.train_data))): 99 | train_dataset = Subset(base_dataset, train_idx) 100 | dev_dataset = Subset(base_dataset, dev_idx) 101 | train_dataloader = DataLoader(train_dataset, 102 | batch_size=self.args.batch_size, 103 | collate_fn=self.train_collate, 104 | num_workers=self.args.num_workers, 105 | shuffle=True) 106 | dev_dataloader = DataLoader(dev_dataset, 107 | batch_size=self.args.batch_size * 2, 108 | collate_fn=self.dev_collate) 109 | ret['train'].append(train_dataloader) 110 | ret['dev'].append(dev_dataloader) 111 | else: 112 | if self.args.kfold == 1 and self.dev_data is None: 113 | from sklearn.model_selection import train_test_split 114 | train_idx, dev_idx = train_test_split(range(len(self.train_data)), 115 | test_size=0.2, 116 | random_state=self.args.seed) 117 | train_dataset = Subset(base_dataset, train_idx) 118 | dev_dataset = Subset(base_dataset, dev_idx) 119 | else: 120 | assert self.dev_data is not None, 'When no kfold, dev data must be targeted' 121 | train_dataset = base_dataset 122 | dev_dataset = KeyDataset(self.dev_data) 123 | 124 | train_dataloader = DataLoader(train_dataset, 125 | batch_size=self.args.batch_size, 126 | collate_fn=self.train_collate, 127 | num_workers=self.args.num_workers, shuffle=True) 128 | dev_dataloader = DataLoader(dev_dataset, 129 | batch_size=self.args.batch_size * 2, 130 | collate_fn=self.dev_collate) 131 | ret['train'].append(train_dataloader) 132 | ret['dev'].append(dev_dataloader) 133 | 134 | return ret 135 | 136 | 137 | class KeyDataset(Dataset): 138 | def __init__(self, dict_data): 139 | self.data = dict_data 140 | 141 | def __len__(self): 142 | return len(self.data) 143 | 144 | def __getitem__(self, index): 145 | return self.data[index] 146 | 147 | 148 | def compute_bleu(label, pred, weights=None): 149 | weights = weights or (0.25, 0.25, 0.25, 0.25) 150 | 151 | return np.mean([sentence_bleu(references=[list(a)], hypothesis=list(b), smoothing_function=smooth, weights=weights) 152 | for a, b in zip(label, pred)]) 153 | 154 | 155 | def compute_rouge(label, pred, weights=None): 156 | weights = weights or (0.2, 0.4, 0.4) 157 | if isinstance(label, str): 158 | label = [label] 159 | if isinstance(pred, str): 160 | pred = [pred] 161 | label = [' '.join(x) for x in label] 162 | pred = [' '.join(x) for x in pred] 163 | 164 | def _compute_rouge(label, pred): 165 | try: 166 | scores = rouge.get_scores(hyps=label, refs=pred)[0] 167 | scores = [scores['rouge-1']['f'], scores['rouge-2']['f'], scores['rouge-l']['f']] 168 | except ValueError: 169 | scores = [0, 0, 0] 170 | return scores 171 | 172 | scores = np.mean([_compute_rouge(*x) for x in zip(label, pred)], axis=0) 173 | return { 174 | 'rouge': sum(s * w for s, w in zip(scores, weights)), 175 | 'rouge-1': scores[0], 'rouge-2': scores[1], 'rouge-l': scores[2] 176 | } 177 | 178 | 179 | def ce_loss(logits, labels, is_prob=False, eps=0): 180 | logits = logits.view(-1, logits.size(-1)) 181 | labels = labels.view(-1) 182 | if not is_prob: 183 | loss = F.cross_entropy(logits, labels, label_smoothing=eps) 184 | else: 185 | lprob = (logits + 1e-9).log() 186 | loss = F.nll_loss(lprob, labels) 187 | return loss 188 | 189 | 190 | def kl_loss(logtis, logits2, mask): 191 | prob1 = F.softmax(logtis, -1) 192 | prob2 = F.softmax(logits2, -1) 193 | lprob1 = prob1.log() 194 | 
lprob2 = prob2.log() 195 | loss1 = F.kl_div(lprob1, prob2, reduction='none') 196 | loss2 = F.kl_div(lprob2, prob1, reduction='none') 197 | mask = (mask == 0).bool().unsqueeze(-1) 198 | loss1 = loss1.masked_fill_(mask, 0.0).sum() 199 | loss2 = loss2.masked_fill_(mask, 0.0).sum() 200 | loss = (loss1 + loss2) / 2 201 | 202 | return loss 203 | 204 | 205 | # def mask_select(inputs, mask): 206 | # input_dim = inputs.ndim 207 | # mask_dim = mask.ndim 208 | # mask = mask.reshape(-1).bool() 209 | # if input_dim > mask_dim: 210 | # inputs = inputs.reshape((int(mask.size(-1)), -1))[mask] 211 | # else: 212 | # inputs = inputs.reshape(-1)[mask] 213 | # return inputs 214 | 215 | 216 | # def copy_loss(inputs, targets, mask, eps=1e-6): 217 | # mask = mask[:, 1:] 218 | # inputs = inputs[:, :-1] 219 | # targets = targets[:, 1:] 220 | # inputs = mask_select(inputs, mask) 221 | # targets = mask_select(targets, mask) 222 | # log_preds = (inputs + eps).log() 223 | # loss = F.nll_loss(log_preds, targets) 224 | # return loss 225 | # 226 | # 227 | # def ce_loss(inputs, targets, mask): 228 | # mask = mask[:, 1:] 229 | # inputs = inputs[:, :-1] 230 | # targets = targets[:, 1:] 231 | # inputs = mask_select(inputs, mask) 232 | # targets = mask_select(targets, mask) 233 | # loss = F.cross_entropy(inputs, targets) 234 | # return loss 235 | 236 | 237 | def create_optimizer(model, lr, weight_decay, custom_lr=None): 238 | no_decay = 'bias|norm' 239 | params = collections.defaultdict(list) 240 | custom_lr = custom_lr or dict() 241 | for name, param in model.named_parameters(): 242 | if not param.requires_grad: 243 | continue 244 | in_custom = False 245 | for custom_name, _ in custom_lr.items(): 246 | if custom_name in name: 247 | if re.search(no_decay, name.lower()): 248 | params[custom_name].append(param) 249 | else: 250 | params[custom_name + '_decay'].append(param) 251 | in_custom = True 252 | break 253 | if not in_custom: 254 | if re.search(no_decay, name.lower()): 255 | params['normal'].append(param) 256 | else: 257 | params['normal_decay'].append(param) 258 | 259 | optimizer_grouped_parameters = [] 260 | for k, v in params.items(): 261 | param_lr = custom_lr.get(k.split('_')[0], lr) 262 | decay = weight_decay if 'decay' in k else 0.0 263 | optimizer_grouped_parameters.append({'params': v, 'weight_decay': decay, 'lr': param_lr}, ) 264 | 265 | optimizer = AdamW(optimizer_grouped_parameters, lr=lr) 266 | return optimizer 267 | --------------------------------------------------------------------------------