├── tests
│   ├── __init__.py
│   ├── test_train.py
│   ├── test_zst.py
│   ├── test_dataset.py
│   ├── test_denoising.py
│   └── utils.py
├── text_denoising
│   ├── __init__.py
│   ├── utils.py
│   └── collate_fn.py
├── ul2.png
├── TODO.md
├── zero2_config.json
├── setup.py
├── .gitignore
├── examples
│   └── pretrain_example.py
├── README.md
└── reference_example_code.py

--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/text_denoising/__init__.py:
--------------------------------------------------------------------------------
from .collate_fn import DataCollatorForUL2
--------------------------------------------------------------------------------
/ul2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theblackcat102/unify-learning-paradigms/HEAD/ul2.png
--------------------------------------------------------------------------------
/tests/test_train.py:
--------------------------------------------------------------------------------
from transformers import AutoModelForSeq2SeqLM
from text_denoising.collate_fn import DataCollatorForUL2
from torch.utils.data import IterableDataset, DataLoader
from tests.test_dataset import ZstDataset
--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------

- [x] ~~add task token [S], [X], [D] token in front of each input sequence~~

  [Yi Tay: While we found mode switching to help, it was admittedly v inconvenient. This update also removes the need to use mode tokens :)](https://twitter.com/yitayml/status/1631359474421366784?s=21)

- [ ] Add eos and bos tokens to the target sequence
--------------------------------------------------------------------------------
/tests/test_zst.py:
--------------------------------------------------------------------------------
import json
import io
from tqdm import tqdm
import zstandard as zstd


if __name__ == "__main__":
    count = 0
    with open('/mnt/ssd/pythia/04.jsonl.zst', 'rb') as f:
        cctx = zstd.ZstdDecompressor()
        reader = io.BufferedReader(cctx.stream_reader(f))
        for line in tqdm(reader):
            count += 1
    print(count)
--------------------------------------------------------------------------------
/zero2_config.json:
--------------------------------------------------------------------------------
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 0.5
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 1e9,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 1e9,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
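Most fields in this config are set to "auto" on purpose: when the file is handed to the Hugging Face `Trainer`, its DeepSpeed integration fills them in from the `TrainingArguments`. A minimal sketch of that wiring (values mirror `examples/pretrain_example.py`; only the `deepspeed=` path is strictly required):

```python
from transformers import Seq2SeqTrainingArguments

# The HF DeepSpeed integration resolves the "auto" fields roughly as:
#   fp16 -> fp16.enabled, learning_rate -> scheduler.warmup_max_lr,
#   warmup_steps -> warmup_num_steps, max_steps -> total_num_steps,
#   max_grad_norm -> gradient_clipping, batch sizes -> train_*_batch_size
args = Seq2SeqTrainingArguments(
    output_dir="out",
    deepspeed="zero2_config.json",
    fp16=True,
    learning_rate=5e-4,
    warmup_steps=4000,
    max_steps=100000,
    per_device_train_batch_size=25,
    gradient_accumulation_steps=22,
)
```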
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setup(
    name="text-denoising",
    version="0.1.1",
    author="Zhi-Rui, Tam",
    author_email="theblackcat102@github.io",
    description="A package for UL2-style text denoising",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/theblackcat102/unify-learning-paradigms",
    packages=find_packages(include=["text_denoising"]),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    install_requires=[
        "numpy>=1.18.0",
        "torch>=1.7.0",
        "transformers",
    ],
    extras_require={
        "test": [
            "pytest>=6.0.0",
        ]
    },
    python_requires='>=3.6',
)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/

*.py[cod]
*$py.class
currents_api.log
# database
*.db
*.db*
sqlite.db*
sqlite_debug*
# C extensions
*.so
schema.json
# Distribution / packaging
.Python
env/
env2/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/
env/
# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

.idea
.DS_Store
--------------------------------------------------------------------------------
/examples/pretrain_example.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import os
import glob
os.environ["WANDB_PROJECT"] = "ul2_pretrain"
from transformers import (AutoTokenizer, Seq2SeqTrainer, MT5ForConditionalGeneration, MT5Config,
                          Seq2SeqTrainingArguments, trainer, AutoModelForSeq2SeqLM)
from text_denoising import DataCollatorForUL2
from tests.test_dataset import ZstDataset
from tests import utils

if __name__ == "__main__":
    model_name = "theblackcat102/mt0-chat-large"
    model_saved_name = model_name.split("/")[-1]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # mT5 tokenizers define no bos token by default; the collator uses it as decoder start
    tokenizer.add_special_tokens({'bos_token': '<s>'})

    # model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # model.resize_token_embeddings(len(tokenizer))

    config = MT5Config.from_pretrained(model_name)
    config.vocab_size = len(tokenizer)
    config.dropout_rate = 0.0
    print(config)
    model = MT5ForConditionalGeneration(config)
    # model = MT5ForConditionalGeneration.from_pretrained('theblackcat102/mt0-chat-large-ul2-2000')
    num_params = sum(param.numel() for param in model.parameters())
    print(num_params/1e6)
    print(model.config)
    print(tokenizer.convert_tokens_to_ids)
    collate_fn = DataCollatorForUL2(tokenizer)
    train_dataset = ZstDataset(list(glob.glob('/mnt/ssd/pythia/*.jsonl.zst')), tokenizer, max_length=600)
    val_dataset = ZstDataset('/mnt/ssd/pythia_val/val.jsonl.zst', tokenizer, max_length=600)
    args = Seq2SeqTrainingArguments(
        output_dir=f"{model_name}-ul2",
        fp16=True,
        deepspeed="zero2_config.json",
        max_steps=100000,
        warmup_steps=4000,
        learning_rate=5e-4,
        label_smoothing_factor=0,
        optim="adamw_hf",
        gradient_checkpointing=True,
        dataloader_num_workers=22,
        gradient_accumulation_steps=22,
        per_device_train_batch_size=25,
        per_device_eval_batch_size=8,
        weight_decay=0.01,
        max_grad_norm=2,
        logging_steps=10,
        save_total_limit=4,
        evaluation_strategy="steps",
        eval_steps=500,
        save_steps=500,
        report_to="wandb",
    )

    trainer = Seq2SeqTrainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collate_fn,
        tokenizer=tokenizer,
    )
    import wandb
    wandb.init(
        project="ul2_pretrain",
        name=model_name
    )
    trainer.train()
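Note on launching: with `deepspeed="zero2_config.json"` set, the script expects a distributed launcher. If plain `python` exits with a distributed-initialization error, running it through the DeepSpeed launcher, e.g. `deepspeed --num_gpus=1 examples/pretrain_example.py` (GPU count illustrative), is the usual fix.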
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Masking implementation for Unifying Language Learning Paradigms (UL2)

Want a better model on a limited budget? You are in the right place.

```
pip install text-denoising
```

![UL2 denoising objectives](ul2.png)

Notation from the UL2 paper: μ is the mean span length, r the corruption rate, and n the number of corrupted spans.

- R-Denoiser (μ=3, r=0.15, n) ∪ (μ=8, r=0.15, n)

  Regular denoising is the standard span corruption introduced in Raffel et al. (2019): it uses a range of 2 to 5 tokens as the span length and masks about 15% of input tokens.

- S-Denoiser (μ=L/4, r=0.25, 1)

  A specific case of denoising where we observe a strict sequential order when framing the inputs-to-targets task, i.e., prefix language modeling.

- X-Denoiser (μ=3, r=0.5, n) ∪ (μ=8, r=0.5, n) ∪ (μ=64, r=0.15, n) ∪ (μ=64, r=0.5, n)

  An extreme version of denoising where the model must recover a large part of the input, given a small to moderate part of it. This simulates a situation where a model needs to generate long targets from a memory with relatively limited information. To do so, we opt to include examples with aggressive denoising where approximately 50% of the input sequence is masked.

2022 paper: Transcending Scaling Laws with 0.1% Extra Compute

> we show an approximately 2x computational savings rate

- Regular denoising, whereby the noise is sampled as spans and replaced with sentinel tokens. This is also the standard span corruption task used in Raffel et al. (2019). Spans are typically uniformly sampled with a mean of 3 and a corruption rate of 15%.

- Extreme denoising, whereby the noise is increased to relatively 'extreme' amounts, either as a huge percentage of the original text or by being very long in nature. Spans are typically uniformly sampled with a mean length of 32 OR a corruption rate of up to 50%.

- Sequential denoising, whereby the noise is always sampled from the start of the text to a randomly sampled point in the text. This is also known as the PrefixLM objective (not to be confused with the architecture).

This repo will just aim to accomplish the simpler recipe below instead; full UL2 is way too complicated for my liking:

> 50% PrefixLM, 25% Long (extreme) span corruption, and 25% regular span corruption to be quite simple and efficient


## Experiments

Run mT5 encoder-decoder pretraining on a 3090 on Pythia jsonl.zst files:

```
pip install text-denoising
python examples/pretrain_example.py
```

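For reference, the quoted mixture maps directly onto `DataCollatorForUL2`'s task probabilities (they also happen to be the collator's defaults). A minimal sketch, assuming the tokenizer is given a bos token as in the examples:

```python
from transformers import AutoTokenizer
from text_denoising import DataCollatorForUL2

tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-small")
tokenizer.add_special_tokens({'bos_token': '<s>'})  # used as decoder start token

collate_fn = DataCollatorForUL2(
    tokenizer,
    s_probability=0.5,   # 50% PrefixLM (S-denoising)
    x_probability=0.25,  # 25% long/extreme span corruption (X-denoising)
    r_probability=0.25,  # 25% regular span corruption (R-denoising)
)
```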
<!-- figure: training loss curve -->

Training loss was stable, with no weird spikes.

## References

Core Papers

[Transcending Scaling Laws with 0.1% Extra Compute](https://arxiv.org/pdf/2210.11399.pdf)

[Unifying Language Learning Paradigms](https://arxiv.org/pdf/2205.05131.pdf)

Implementations of T5 noise masking in Hugging Face Transformers or plain Python

[OSLO](https://github.com/EleutherAI/oslo): very underrated, with tidy code and documentation; this will be a very useful tool

- [t5_pretraining.py](https://github.com/EleutherAI/oslo/blob/main/oslo/transformers/tasks/data_t5_pretraining.py)

Heavily inspired by this file

[Amazon Science: label-aware pretraining in Python](https://github.com/amazon-science/label-aware-pretrain/blob/main/models/preprocessor.py)

[Fairseq: span_mask_tokens_dataset.py](https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/span_mask_tokens_dataset.py)
--------------------------------------------------------------------------------
/tests/test_dataset.py:
--------------------------------------------------------------------------------
import json
import io
import torch
import zstandard as zstd
from tqdm import tqdm
from transformers import AutoTokenizer
from text_denoising.collate_fn import DataCollatorForUL2
from torch.utils.data import IterableDataset, DataLoader

def chunks(lst, chunk_size):
    for i in range(0, len(lst), chunk_size):
        yield lst[i:i+chunk_size]

class ZstDataset(IterableDataset):

    def __init__(self, file, tokenizer, max_length=512) -> None:
        super().__init__()
        self.file = [file] if isinstance(file, str) else file
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            for filename in self.file:
                with open(filename, 'rb') as f:
                    cctx = zstd.ZstdDecompressor()
                    reader = io.BufferedReader(cctx.stream_reader(f))
                    for line in reader:
                        if line:
                            raw_text = json.loads(line)['text']
                            # ~3.2 characters per token heuristic so each chunk tokenizes to >= max_length
                            for chunk in chunks(raw_text, int(self.max_length*3.2)):
                                tokens = self.tokenizer(chunk)['input_ids']
                                if len(tokens) >= self.max_length:
                                    yield { 'input_ids': tokens[:self.max_length] }
        elif len(self.file) == 1:
            # single file: stripe chunks across workers with a running counter
            cnt = 0
            worker_id = worker_info.id
            total_workers = worker_info.num_workers
            with open(self.file[0], 'rb') as f:
                cctx = zstd.ZstdDecompressor()
                reader = io.BufferedReader(cctx.stream_reader(f))
                for line in reader:
                    if line:
                        raw_text = json.loads(line)['text']
                        for chunk in chunks(raw_text, int(self.max_length*3.2)):
                            tokens = self.tokenizer(chunk)['input_ids']
                            cnt += 1
                            if len(tokens) >= self.max_length and cnt % total_workers == worker_id:
                                yield { 'input_ids': tokens[:self.max_length] }
        else:
            # multiple files: one file per worker
            worker_id = worker_info.id
            with open(self.file[worker_id], 'rb') as f:
                cctx = zstd.ZstdDecompressor()
                reader = io.BufferedReader(cctx.stream_reader(f))
                for line in reader:
                    if line:
                        raw_text = json.loads(line)['text']
                        for chunk in chunks(raw_text, int(self.max_length*3.2)):
                            tokens = self.tokenizer(chunk)['input_ids']
                            if len(tokens) >= self.max_length:
                                yield {'input_ids': tokens[:self.max_length]}

if __name__ == "__main__":
    import glob
    # download test.jsonl.zst from the-pile website

    tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-small")
    tokenizer.add_special_tokens({'bos_token': '<s>'})  # collator uses bos as decoder start token
    dataset = ZstDataset(list(glob.glob('/mnt/ssd/pythia/*.jsonl.zst')), tokenizer)
    # dataset = ZstDataset('/mnt/ssd/pythia_val/val.jsonl.zst', tokenizer, max_length=600)
    # mimic result from multiple dataset runs
    collate_fn = DataCollatorForUL2(tokenizer,
                                    r_probability=0.5, r_denoising=True,
                                    s_probability=0.5, s_denoising=True,  # S must be enabled so task probabilities sum to 1
                                    x_denoising=False, x_probability=0.0)
    dataloader = DataLoader(dataset, batch_size=256, collate_fn=collate_fn, num_workers=13)


    # benchmark iteration speed
    for batch in tqdm(dataloader):
        print(batch['input_ids'].shape)
        # print(batch["decoder_attention_mask"][0])
        # for (input, label) in zip(batch['decoder_input_ids'][:8], batch['labels'][:8]):
        #     print(tokenizer.decode(input[ input != 0]))
        #     print(tokenizer.decode(label[ label != -100]))
        #     print('----')

    # batch = [ ]
    # np_batch = collate_fn(batch, return_tensors='pt')
    # print(np_batch)
--------------------------------------------------------------------------------
/tests/test_denoising.py:
--------------------------------------------------------------------------------
from transformers import AutoTokenizer
from text_denoising import DataCollatorForUL2


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-small")
    tokenizer.add_special_tokens({'bos_token': '<s>'})  # collator uses bos as decoder start token

    # collate_fn = DataCollatorForUL2(tokenizer,
    #                                 r_probability=1.0, r_denoising=True,
    #                                 s_probability=0.0, s_denoising=False,
    #                                 x_denoising=False, x_probability=0.0)
    collate_fn = DataCollatorForUL2(tokenizer,
                                    r_probability=0.5, r_denoising=True,
                                    s_probability=0.0, s_denoising=False,
                                    x_denoising=True, x_probability=0.5)

    batch = [
        'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nulla faucibus turpis et elit malesuada, ac venenatis sapien bibendum. Mauris at ullamcorper libero. Donec interdum auctor nisi a luctus. Suspendisse potenti. Proin vitae tortor vel leo consectetur fermentum. Sed blandit, nulla ac lobortis dapibus, diam massa accumsan velit, non pharetra lectus lacus ac odio. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Sed viverra libero at est efficitur vestibulum. Sed ac tristique mauris, sit amet consequat leo. Sed imperdiet lectus a magna mollis auctor. Sed bibendum eget lacus vitae lobortis. Fusce ac eros eget libero scelerisque consequat. Cras id purus ornare, consectetur ipsum sed, semper nulla. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.',
        'Donec tincidunt enim quis felis lacinia, vitae ultricies nulla consectetur. Praesent ullamcorper ligula ac tincidunt rutrum. Vestibulum euismod ex vel quam porttitor, sit amet consequat velit sollicitudin. Pellentesque non mauris auctor, varius odio non, feugiat mi. Sed vulputate tincidunt arcu eget interdum. Duis convallis purus eget mauris euismod, non efficitur mi convallis. Sed varius massa nec luctus iaculis. Donec ornare, nunc a consequat pellentesque, nisi orci tincidunt quam, ac laoreet mauris orci in nunc. Fusce ut orci sit amet turpis vestibulum imperdiet.
Vivamus et enim vel lorem ultrices fringilla. Sed vehicula nibh id risus convallis, ac bibendum sapien vulputate. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Integer at risus quis magna blandit aliquam. Suspendisse varius malesuada mauris, vitae dictum metus aliquam vitae. Ut et ante at odio malesuada lobortis. . Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.', 22 | 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed eleifend et tortor ac vehicula. Fusce eu urna aliquet, fringilla odio vel, malesuada metus. Pellentesque bibendum mauris vel est faucibus posuere. Duis ultrices vestibulum nulla, at tempor enim bibendum a. In sit amet quam vel nunc tristique varius ac eu dui. Quisque semper nisi at enim aliquam facilisis. Sed pharetra risus sit amet libero sollicitudin, vel faucibus velit sagittis. Sed viverra magna quis metus malesuada posuere. Donec in ante in enim tristique vestibulum. Sed molestie posuere urna id rhoncus. Fusce sit amet neque ac mi dapibus sollicitudin. Fusce pharetra est sed massa feugiat euismod. Vestibulum eu aliquam nulla, eget varius eros.. Vestibulum eu aliquam nulla, eget varius eros.. Vestibulum eu aliquam nulla, eget varius eros.. Vestibulum eu aliquam nulla, eget varius eros.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.', 23 | 'Suspendisse malesuada nibh a enim blandit, a laoreet augue imperdiet. Suspendisse at ligula non risus feugiat blandit. Proin tristique mi sit amet ex laoreet, vel viverra ipsum fringilla. Donec at gravida nisi. Curabitur vel magna vitae lectus bibendum lacinia ac vel elit. Nam sit amet sem purus. Suspendisse at metus vitae ipsum viverra bibendum. Ut quis gravida libero. Suspendisse varius vel purus nec scelerisque. Nulla tincidunt enim in mollis eleifend. Donec tincidunt justo vitae diam congue vestibulum. Nam id lectus auctor, pellentesque tellus in, scelerisque tellus. Vivamus vel semper justo. Sed eget ipsum nec libero pellentesque tincidunt. Ut in efficitur purus, sit amet hendrerit ex.Ut in efficitur purus, sit amet hendrerit ex.Ut in efficitur purus, sit amet hendrerit ex.Ut in efficitur purus, sit amet hendrerit ex.Ut in efficitur purus, sit amet hendrerit ex.Ut in efficitur purus, sit amet hendrerit ex.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.', 24 | ] 25 | encode = collate_fn([ { 'input_ids': tokenizer(r)['input_ids'][:200] } for r in batch] ) 26 | print(tokenizer.decode(tokenizer(batch[0])['input_ids'])) 27 | print('-----') 28 | for input_ids, token_ids, label_ids in zip(encode['input_ids'], encode['decoder_input_ids'], encode['labels']): 29 | print('---------') 30 | print(tokenizer.decode(input_ids)) 31 | print(tokenizer.decode(token_ids)) 32 | print(tokenizer.decode(label_ids[label_ids!= -100])) 33 | print('---------') 34 | -------------------------------------------------------------------------------- /text_denoising/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def random_spans_noise_mask(length, mean_noise_span_length, noise_density): 4 | """ 5 | A copy from https://github.com/EleutherAI/oslo/blob/main/oslo/transformers/tasks/data_t5_pretraining.py#L230 (inception) 6 | This function is copy of `random_spans_helper `__ . 
    Noise mask consisting of random spans of noise tokens.
    The number of noise tokens and the number of noise spans and non-noise spans
    are determined deterministically as follows:
    num_noise_tokens = round(length * noise_density)
    num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
    Spans alternate between non-noise and noise, beginning with non-noise.
    Subject to the above restrictions, all masks are equally likely.
    Args:
        length: an int32 scalar (length of the incoming token sequence)
        noise_density: a float - approximate density of output mask
        mean_noise_span_length: a number
    Returns:
        a boolean tensor with shape [length]
    """

    orig_length = length

    num_noise_tokens = int(np.round(length * noise_density))
    # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
    num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
    num_noise_spans = int(np.round(num_noise_tokens / mean_noise_span_length))

    # avoid degeneracy by ensuring positive number of noise spans
    num_noise_spans = max(num_noise_spans, 1)
    num_nonnoise_tokens = length - num_noise_tokens

    # pick the lengths of the noise spans and the non-noise spans
    def _random_segmentation(num_items, num_segments):
        """Partition a sequence of items randomly into non-empty segments.
        Args:
            num_items: an integer scalar > 0
            num_segments: an integer scalar in [1, num_items]
        Returns:
            a Tensor with shape [num_segments] containing positive integers that add
            up to num_items
        """
        mask_indices = np.arange(num_items - 1) < (num_segments - 1)
        np.random.shuffle(mask_indices)
        first_in_segment = np.pad(mask_indices, [[1, 0]])
        segment_id = np.cumsum(first_in_segment)
        # count length of sub segments assuming that list is sorted
        _, segment_length = np.unique(segment_id, return_counts=True)
        return segment_length

    noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
    nonnoise_span_lengths = _random_segmentation(
        num_nonnoise_tokens, num_noise_spans
    )

    interleaved_span_lengths = np.reshape(
        np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
        [num_noise_spans * 2],
    )
    span_starts = np.cumsum(interleaved_span_lengths)[:-1]
    span_start_indicator = np.zeros((length,), dtype=np.int8)
    span_start_indicator[span_starts] = True
    span_num = np.cumsum(span_start_indicator)
    is_noise = np.equal(span_num % 2, 1)

    return is_noise[:orig_length]

def compute_input_and_target_lengths(
    inputs_length, noise_density, mean_noise_span_length
):
    """
    A copy of a copy from https://github.com/EleutherAI/oslo/blob/main/oslo/transformers/tasks/data_t5_pretraining.py#L76 (shit's getting meta)
    This function is a copy of ``random_spans_helper`` from t5.data.preprocessors.
    Training parameters to avoid padding with random_spans_noise_mask.
    When training a model with random_spans_noise_mask, we would like to set the other
    training hyperparameters in a way that avoids padding.
    This function helps us compute these hyperparameters.
    We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
    and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
80 | This function tells us the required number of tokens in the raw example (for split_tokens()) 81 | as well as the length of the encoded targets. Note that this function assumes 82 | the inputs and targets will have EOS appended and includes that in the reported length. 83 | Args: 84 | inputs_length: an integer - desired length of the tokenized inputs sequence 85 | noise_density: a float 86 | mean_noise_span_length: a float 87 | Returns: 88 | tokens_length: length of original text in tokens 89 | targets_length: an integer - length in tokens of encoded targets sequence 90 | """ 91 | def _tokens_length_to_inputs_length_targets_length(tokens_length): 92 | num_noise_tokens = int(round(tokens_length * noise_density)) 93 | num_nonnoise_tokens = tokens_length - num_noise_tokens 94 | num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length)) 95 | # inputs contain all nonnoise tokens, sentinels for all noise spans 96 | # and one EOS token. 97 | _input_length = num_nonnoise_tokens + num_noise_spans + 1 98 | _output_length = num_noise_tokens + num_noise_spans + 1 99 | return _input_length, _output_length 100 | 101 | 102 | tokens_length = inputs_length 103 | 104 | while ( 105 | _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] 106 | <= inputs_length 107 | ): 108 | tokens_length += 1 109 | 110 | inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length( 111 | tokens_length 112 | ) 113 | 114 | # minor hack to get the targets length to be equal to inputs length 115 | # which is more likely to have been set to a nice round number. 116 | if noise_density == 0.5 and targets_length > inputs_length: 117 | tokens_length -= 1 118 | targets_length -= 1 119 | return tokens_length, targets_length -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import string 3 | from random import choice 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def count_param(model): 10 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 11 | params = sum([np.prod(p.size()) for p in model_parameters]) 12 | return params 13 | 14 | 15 | def gen_random_sequence(length=128): 16 | chars = string.ascii_letters + " .!?'" 17 | return "".join([choice(chars) for i in range(length)]) 18 | 19 | 20 | def test_no_nan(model, tokenizer, iterations=100, device="cpu"): 21 | model_half = model.half().to(device) 22 | model_half.train() 23 | text = " Hello world! The Earth is a nice place to be. " 24 | train_text = " translate English to German: The house is wonderful. " 25 | train_label = " Das Haus ist wunderbar. 
" 26 | for idx in range(iterations): 27 | inputs = tokenizer.encode(str(idx) + text + str(idx), return_tensors="pt").to( 28 | device 29 | ) 30 | out = model_half(input_ids=inputs, decoder_input_ids=inputs) 31 | if torch.isnan(out[0]).any(): 32 | return False 33 | train_input_ids = tokenizer( 34 | str(idx) + train_text + str(idx), return_tensors="pt" 35 | ).input_ids.to(device) 36 | train_labels = tokenizer( 37 | str(idx) + train_label + str(idx), return_tensors="pt" 38 | ).input_ids.to(device) 39 | loss = model_half(input_ids=train_input_ids, labels=train_labels).loss 40 | if torch.isnan(loss): 41 | return False 42 | inputs = tokenizer.encode(gen_random_sequence(), return_tensors="pt").to(device) 43 | out = model_half(input_ids=inputs, decoder_input_ids=inputs) 44 | if torch.isnan(out[0]).any(): 45 | return False 46 | return True 47 | 48 | 49 | def scale_weights(layer, scale_down_factor): 50 | old_weights = layer.weight 51 | layer.weight = torch.nn.Parameter(old_weights / scale_down_factor) 52 | return old_weights 53 | 54 | 55 | def search_and_reset_layers( 56 | model, tokenizer, scale_down_factor=15, revert_old=False, device="cuda" 57 | ): 58 | model = model.float().to(device) 59 | total_params = count_param(model) 60 | param_reset_count = 0 61 | 62 | print("Testing encoder") 63 | for i, layer in enumerate(model.encoder.block[::-1]): 64 | fflayer0 = layer.layer[1].DenseReluDense.wi_0 65 | fflayer1 = layer.layer[1].DenseReluDense.wi_1 66 | fflayer2 = layer.layer[1].DenseReluDense.wo 67 | 68 | # fflayer2.reset_parameters() 69 | old_weights = scale_weights(fflayer2, scale_down_factor) 70 | param_reset_count += count_param(fflayer2) 71 | if test_no_nan(model, tokenizer, device=device): 72 | print("Success at encoder", len(model.encoder.block) - i, "FF2") 73 | return model.float(), int(param_reset_count / total_params * 100) 74 | else: 75 | if revert_old: 76 | param_reset_count -= count_param(fflayer2) 77 | fflayer2.weight = old_weights 78 | 79 | # fflayer1.reset_parameters() 80 | old_weights = scale_weights(fflayer1, scale_down_factor) 81 | param_reset_count += count_param(fflayer1) 82 | 83 | if test_no_nan(model, tokenizer, device=device): 84 | print("Success at encoder", len(model.encoder.block) - i, "FF1") 85 | return model.float(), int(param_reset_count / total_params * 100) 86 | else: 87 | if revert_old: 88 | param_reset_count -= count_param(fflayer1) 89 | fflayer1.weight = old_weights 90 | 91 | old_weights = scale_weights(fflayer0, scale_down_factor) 92 | param_reset_count += count_param(fflayer0) 93 | if test_no_nan(model, tokenizer, device=device): 94 | print("Success at encoder", len(model.encoder.block) - i, "FF0") 95 | return model.float(), int(param_reset_count / total_params * 100) 96 | else: 97 | if revert_old: 98 | param_reset_count -= count_param(fflayer0) 99 | fflayer0.weight = old_weights 100 | 101 | print("Testing decoder") 102 | for i, layer in enumerate(model.decoder.block[::-1]): 103 | fflayer0 = layer.layer[2].DenseReluDense.wi_0 104 | fflayer1 = layer.layer[2].DenseReluDense.wi_1 105 | fflayer2 = layer.layer[2].DenseReluDense.wo 106 | 107 | # fflayer2.reset_parameters() 108 | old_weights = scale_weights(fflayer2, scale_down_factor) 109 | param_reset_count += count_param(fflayer2) 110 | if test_no_nan(model, tokenizer, device=device): 111 | print("Success at decoder", len(model.decoder.block) - i, "FF2") 112 | return model.float(), int(param_reset_count / total_params * 100) 113 | else: 114 | if revert_old: 115 | param_reset_count -= count_param(fflayer2) 116 | 
fflayer2.weight = old_weights 117 | 118 | # fflayer1.reset_parameters() 119 | old_weights = scale_weights(fflayer1, scale_down_factor) 120 | param_reset_count += count_param(fflayer1) 121 | if test_no_nan(model, tokenizer, device=device): 122 | print("Success at decoder", len(model.decoder.block) - i, "FF1") 123 | return model.float(), int(param_reset_count / total_params * 100) 124 | else: 125 | if revert_old: 126 | param_reset_count -= count_param(fflayer1) 127 | fflayer1.weight = old_weights 128 | 129 | old_weights = scale_weights(fflayer0, scale_down_factor) 130 | param_reset_count += count_param(fflayer0) 131 | if test_no_nan(model, tokenizer, device=device): 132 | print("Success at decoder", len(model.decoder.block) - i, "FF0") 133 | return model.float(), int(param_reset_count / total_params * 100) 134 | else: 135 | if revert_old: 136 | param_reset_count -= count_param(fflayer0) 137 | fflayer0.weight = old_weights 138 | 139 | return model.float(), False 140 | 141 | 142 | def fix_rescale(model, scale_down_factor=10): 143 | model = model.float() 144 | total_params = count_param(model) 145 | param_reset_count = 0 146 | 147 | print("Testing encoder") 148 | for i, layer in enumerate(model.encoder.block[::-1]): 149 | fflayer0 = layer.layer[1].DenseReluDense.wi_0 150 | fflayer1 = layer.layer[1].DenseReluDense.wi_1 151 | fflayer2 = layer.layer[1].DenseReluDense.wo 152 | 153 | old_weights = scale_weights(fflayer2, scale_down_factor) 154 | param_reset_count += count_param(fflayer2) 155 | 156 | # fflayer1.reset_parameters() 157 | old_weights = scale_weights(fflayer1, scale_down_factor) 158 | param_reset_count += count_param(fflayer1) 159 | 160 | old_weights = scale_weights(fflayer0, scale_down_factor) 161 | param_reset_count += count_param(fflayer0) 162 | 163 | print("Testing decoder") 164 | for i, layer in enumerate(model.decoder.block[::-1]): 165 | fflayer0 = layer.layer[2].DenseReluDense.wi_0 166 | fflayer1 = layer.layer[2].DenseReluDense.wi_1 167 | fflayer2 = layer.layer[2].DenseReluDense.wo 168 | 169 | # fflayer2.reset_parameters() 170 | old_weights = scale_weights(fflayer2, scale_down_factor) 171 | param_reset_count += count_param(fflayer2) 172 | if len(model.decoder.block) - i == 2: 173 | print("decoder", len(model.decoder.block) - i, "FF2") 174 | return model.float() 175 | 176 | # fflayer1.reset_parameters() 177 | old_weights = scale_weights(fflayer1, scale_down_factor) 178 | param_reset_count += count_param(fflayer1) 179 | 180 | old_weights = scale_weights(fflayer0, scale_down_factor) 181 | param_reset_count += count_param(fflayer0) 182 | 183 | return model.float() 184 | -------------------------------------------------------------------------------- /reference_example_code.py: -------------------------------------------------------------------------------- 1 | dataset = tfds.load('wikipedia/20220620.en', split='train', shuffle_files=True) 2 | 3 | def ul2_objective(dataset: tf.data.Dataset, 4 | sequence_length: seqio.preprocessors.SequenceLengthType, 5 | output_features: seqio.preprocessors.OutputFeaturesType, 6 | use_prefix_lm_task: bool = False, 7 | rates: Optional[Sequence[float]] = None, 8 | mean_noise_span_lengths: Sequence[float] = (3.0,), 9 | noise_densities: Sequence[float] = (0.15,), 10 | shard_ds: bool = True, 11 | optional_task_prefixes: Optional[Sequence[str]] = None, 12 | input_feature_key: str = "inputs", 13 | merge_examples_to_reduce_padding: bool = True, 14 | reserved_for_packing: bool = None, 15 | seed: int = 7) -> tf.data.Dataset: 16 | 17 | """ 18 | UL2-like 
    pre-training objectives. This preprocessor amounts to calling the `span_corruption` function
    several times with different values of `noise_density` and `mean_noise_span_length`.
    We either shard or copy the dataset, then apply each function to each shard.
    Add S-denoising (prefixLM) using use_prefix_lm_task.

    Args:
        dataset: A tf.data.Dataset with dictionaries containing the key `input_feature_key`.
        sequence_length: dict mapping of feature key to int length for that feature.
        output_features: mapping of keys to features.
        use_prefix_lm_task: If True, include PrefixLM in the task mix.
        rates: List of rates per task. If None, tasks are sampled uniformly.
        mean_noise_span_lengths: List of mean number of tokens per masked span per example.
        noise_densities: List of what fraction of the tokens to mask.
        shard_ds: If True, shard dataset per objective.
        optional_task_prefixes: Strings to prepend for each corruption scheme.
            NOTE: If including the prefixLM task, it must be the last prefix.
        input_feature_key: which feature to use from the dataset as the input text tokens.
        merge_examples_to_reduce_padding: if True, combines multiple input examples to reduce padding.
        reserved_for_packing: if specified, reduces the desired inputs length by the specified
            amount to enable multiple examples to be packed together downstream.
        seed: tf.int64 for controlling the random choice of spans.
    Returns:
        a dataset
    """

    if optional_task_prefixes:  # Ensure each task has a prefix.
        num_tasks = len(noise_densities) + int(use_prefix_lm_task)
        valid_number_of_prefixes = num_tasks == len(optional_task_prefixes)
        if not valid_number_of_prefixes:
            raise ValueError("Number of task prefixes must match number of tasks.")
    inputs_length = sequence_length[input_feature_key]
    input_lengths, targets_lengths = [], []
    sequence_lengths = {x: y for x, y in sequence_length.items()}
    if reserved_for_packing:
        inputs_length -= reserved_for_packing
        for x, y in sequence_length.items():
            sequence_lengths[x] = y - reserved_for_packing
    hyperparams = list(zip(mean_noise_span_lengths, noise_densities))
    for mean_noise_span_length, noise_density in hyperparams:
        input_length, targets_length = t5.data.preprocessors.random_spans_helper(
            extra_tokens_per_span_inputs=1,
            extra_tokens_per_span_targets=1,
            inputs_length=inputs_length,
            mean_noise_span_length=mean_noise_span_length,
            noise_density=noise_density)
        input_lengths.append(input_length)
        targets_lengths.append(targets_length)

        if sequence_length["targets"] < targets_length:
            upper_bound = max(targets_lengths)
            raise ValueError(
                f"Targets length {sequence_length['targets']} is too small for the given "
                f"noise_density and mean_noise_span_length. "
                f"Please increase the targets length to at least {upper_bound}.")

    ds = dataset
    ds = t5.data.preprocessors.select_random_chunk(
        ds,
        output_features=output_features,
        feature_key="targets",
        max_length=65536)
    if merge_examples_to_reduce_padding:
        ds = t5.data.preprocessors.reduce_concat_tokens(
            ds,
            feature_key="targets",
            batch_size=128)
    num_shards = len(input_lengths) + int(use_prefix_lm_task)
    if shard_ds:
        ds_shards = [ds.shard(num_shards, i) for i in range(num_shards)]
    else:
        ds_shards = [ds for _ in range(num_shards)]
    processed_ds = []
    hyperparams = zip(input_lengths, hyperparams, range(num_shards))
    for input_length, (noise_span_length, noise_density), i in hyperparams:
        ds = ds_shards[i]
        ds = t5.data.preprocessors.split_tokens(
            ds,
            feature_key="targets",
            min_tokens_per_segment=None,
            max_tokens_per_segment=input_length)
        ds = t5.data.preprocessors.denoise(
            ds,
            output_features,
            inputs_fn=t5.data.preprocessors.noise_span_to_unique_sentinel,
            targets_fn=t5.data.preprocessors.nonnoise_span_to_unique_sentinel,
            noise_density=noise_density,
            noise_mask_fn=functools.partial(
                t5.data.preprocessors.random_spans_noise_mask,
                mean_noise_span_length=noise_span_length),
            input_feature_key=input_feature_key)
        if optional_task_prefixes:
            ds = prepend_prompt(
                ds,
                output_features,
                prompt_mode=optional_task_prefixes[i],
                mode=optional_task_prefixes[i])
        processed_ds.append(ds)
    if use_prefix_lm_task:
        ds = ds_shards[-1]
        ds = t5.data.preprocessors.prefix_lm(ds, sequence_lengths, output_features)
        if optional_task_prefixes:
            ds = prepend_prompt(
                ds,
                output_features,
                prompt_mode=optional_task_prefixes[-1],
                mode=optional_task_prefixes[-1])
        processed_ds.append(ds)
    ds = tf.data.experimental.sample_from_datasets(processed_ds, rates, seed)
    return ds

sequence_length = {
    "inputs": 512,
    "targets": 512,
}

output_features = {
    "inputs":
        seqio.Feature(
            vocabulary=t5.data.get_default_vocabulary(), add_eos=False),
    "targets":
        seqio.Feature(
            vocabulary=t5.data.get_default_vocabulary(), add_eos=False)
}

ul2_data = ul2_objective(
    dataset,
    sequence_length,
    output_features,
    use_prefix_lm_task=False,
    rates=None,
    mean_noise_span_lengths=(3.0,),
    noise_densities=(0.15,),
    shard_ds=True,
    optional_task_prefixes=None,
    input_feature_key="text",
    merge_examples_to_reduce_padding=True,
    reserved_for_packing=None,
    seed=7)


def prepend_prompt(dataset: tf.data.Dataset,
                   output_features: seqio.preprocessors.OutputFeaturesType,
                   sequence_length: Optional[
                       seqio.preprocessors.SequenceLengthType] = None,
                   prompt_mode: str = "",
                   key: str = "inputs",
                   mode: str = "") -> tf.data.Dataset:
    """Prepends a prompt at the beginning of an input sequence."""
    del sequence_length
    if prompt_mode and mode:
        # output_features may not have inputs key
        out_keys = list(output_features.keys())
        prompt_tokens = output_features[out_keys[0]].vocabulary.encode_tf(prompt_mode)
164 | 165 | def add_to_inputs(x): 166 | x[key] = tf.concat([prompt_tokens, x[key]], axis=0) 167 | return x 168 | 169 | dataset = dataset.map(add_to_inputs) 170 | return dataset -------------------------------------------------------------------------------- /text_denoising/collate_fn.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections.abc import Mapping 3 | import numpy as np 4 | import torch 5 | from torch.nn import functional as F 6 | from dataclasses import dataclass 7 | from typing import Any, Dict, List, Optional, Tuple, Union 8 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 9 | from transformers.data.data_collator import ( 10 | DataCollatorMixin, 11 | _torch_collate_batch, 12 | ) 13 | from text_denoising.utils import random_spans_noise_mask 14 | 15 | 16 | @dataclass 17 | class DataCollatorForUL2(DataCollatorMixin): 18 | """ 19 | 20 | Data collator used for UL2 21 | 22 | """ 23 | tokenizer: PreTrainedTokenizerBase 24 | r_denoising: bool = True 25 | r_probability: float = 0.25 26 | r_denoising_config: Tuple[Tuple] = ((3, 0.15),) 27 | s_denoising: bool = True 28 | s_probability: float = 0.5 29 | x_denoising: bool = True 30 | x_probability: float = 0.25 31 | x_denoising_config: Tuple[Tuple] = ((32, 0.5), (64, 0.2)) 32 | pad_to_multiple_of: Optional[int] = None 33 | tf_experimental_compile: bool = False 34 | return_tensors: str = "pt" 35 | label_pad_token_id: int = -100 36 | 37 | def __post_init__(self): 38 | self.total_task = [0, 1, 2] 39 | task_prob = [] 40 | task_prob.append(self.r_probability if self.r_denoising else 0.0) 41 | task_prob.append(self.s_probability if self.s_denoising else 0.0) 42 | task_prob.append(self.x_probability if self.x_denoising else 0.0) 43 | self.task_prob = task_prob 44 | self.pad_token_id = self.tokenizer.pad_token_id 45 | self.decoder_start_token_id = self.tokenizer.bos_token_id 46 | 47 | assert sum(task_prob) == 1.0 48 | 49 | def assign_task_type(self, batch_size: int): 50 | ''' 51 | Randomly assign S,R,X to each sentence based on weighted prob 52 | ''' 53 | return random.choices(self.total_task,weights=self.task_prob, k=batch_size) 54 | 55 | def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]: 56 | # Handle dict or lists with proper padding and conversion to tensor. 
57 | # print(examples) 58 | task_ids = self.assign_task_type(len(examples)) 59 | task_type = torch.tensor(task_ids) 60 | lengths = torch.tensor([ len(e['input_ids']) for e in examples ], dtype=torch.long) 61 | if isinstance(examples[0], Mapping): 62 | batch = self.tokenizer.pad(examples, return_tensors="pt", 63 | pad_to_multiple_of=self.pad_to_multiple_of) 64 | else: 65 | batch = { 66 | "input_ids": _torch_collate_batch(examples, self.tokenizer, 67 | pad_to_multiple_of=self.pad_to_multiple_of) 68 | } 69 | max_length = batch['input_ids'].shape[-1] 70 | 71 | new_batch = { 72 | "input_ids": torch.zeros(batch['input_ids'].shape, dtype=torch.long), 73 | "labels": torch.zeros(batch['input_ids'].shape, dtype=torch.long) 74 | } 75 | 76 | _, expanded_length = batch['input_ids'].shape 77 | input_ids = batch["input_ids"] 78 | r_denoising_idx = task_type == 0 79 | if r_denoising_idx.any(): 80 | mask_indices = None 81 | sub_input_ids = input_ids[r_denoising_idx] 82 | # union of different denoising settings 83 | for (mean_span, noise) in self.r_denoising_config: 84 | _mask_indices = np.array([ 85 | random_spans_noise_mask(expanded_length, mean_span, noise) for _ in range(len(sub_input_ids)) 86 | ]) 87 | if mask_indices is None: 88 | mask_indices = _mask_indices 89 | else: 90 | mask_indices = mask_indices | _mask_indices 91 | 92 | input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8)) 93 | labels_mask = ~mask_indices 94 | labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8)) 95 | _sub_input_ids = self.filter_input_ids(sub_input_ids, input_ids_sentinel) 96 | _labels = self.filter_input_ids(sub_input_ids, labels_sentinel) 97 | diff = max_length-_labels.shape[-1] 98 | _labels = np.pad(_labels, [(0,0), (0, diff)], 'constant', 99 | constant_values=self.label_pad_token_id) 100 | diff = max_length - _sub_input_ids.shape[-1] 101 | _sub_input_ids = np.pad(_sub_input_ids, [(0,0), (0, diff)], 'constant') 102 | new_batch['input_ids'][r_denoising_idx] = torch.from_numpy(_sub_input_ids).long() 103 | new_batch['labels'][r_denoising_idx] = torch.from_numpy(_labels).long() 104 | 105 | s_denoising_idx = task_type == 1 106 | if s_denoising_idx.any(): 107 | sub_input_ids = input_ids[s_denoising_idx] 108 | _labels = [] 109 | _input_ids = [] 110 | for input_id, len_ in zip(sub_input_ids, lengths[s_denoising_idx]): 111 | split = max(len_//2, 2) 112 | diff = expanded_length - split 113 | _input_ids.append(F.pad(input_id[:split], (0, diff), 'constant', self.pad_token_id)) 114 | past_seq = input_id[split:] 115 | if past_seq[-1] != self.tokenizer.eos_token_id: 116 | past_seq[-1] = self.tokenizer.eos_token_id 117 | _labels.append(F.pad(past_seq, (0, split), 'constant', self.label_pad_token_id)) 118 | 119 | new_batch['input_ids'][s_denoising_idx] = torch.stack(_input_ids) 120 | new_batch['labels'][s_denoising_idx] = torch.stack(_labels) 121 | 122 | 123 | x_denoising_idx = task_type == 2 124 | if x_denoising_idx.any(): 125 | mask_indices = None 126 | sub_input_ids = input_ids[x_denoising_idx] 127 | for (mean_span, noise) in self.x_denoising_config: 128 | _mask_indices = np.array([ 129 | random_spans_noise_mask(expanded_length, mean_span, noise) for _ in range(len(sub_input_ids)) 130 | ]) 131 | if mask_indices is None: 132 | mask_indices = _mask_indices 133 | else: 134 | mask_indices = mask_indices | _mask_indices 135 | 136 | labels_mask = ~mask_indices 137 | input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8)) 138 | labels_sentinel = 
self.create_sentinel_ids(labels_mask.astype(np.int8)) 139 | _sub_input_ids = self.filter_input_ids(sub_input_ids, input_ids_sentinel) 140 | _labels = self.filter_input_ids(sub_input_ids, labels_sentinel) 141 | diff = max_length-_labels.shape[-1] 142 | _labels = np.pad(_labels, [(0, 0), (0, diff)], 'constant', 143 | constant_values=self.label_pad_token_id) 144 | diff = max_length - _sub_input_ids.shape[-1] 145 | _sub_input_ids = np.pad(_sub_input_ids, [(0,0), (0, diff)], 'constant') 146 | new_batch['input_ids'][x_denoising_idx] = torch.from_numpy(_sub_input_ids).long() 147 | new_batch['labels'][x_denoising_idx] = torch.from_numpy(_labels).long() 148 | 149 | return self.prepare_decoder_inputs_from_labels(new_batch) 150 | 151 | 152 | def filter_input_ids(self, input_ids, sentinel_ids): 153 | """ 154 | Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting. 155 | This will reduce the sequence length from `expanded_inputs_length` to `input_length`. 156 | """ 157 | 158 | input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids) 159 | # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are 160 | # masked tokens coming after sentinel tokens and should be removed 161 | input_ids = [] 162 | for row in input_ids_full: 163 | collapsed_id = row[row >= 0] 164 | diff = len(row) - len(collapsed_id) 165 | collapsed_id = np.pad(collapsed_id, (0, diff), 'constant') 166 | input_ids.append(collapsed_id) 167 | return np.array(input_ids) 168 | 169 | def create_sentinel_ids(self, mask_indices): 170 | """ 171 | Sentinel ids creation given the indices that should be masked. 172 | The start indices of each mask are replaced by the sentinel ids in increasing 173 | order. Consecutive mask indices to be deleted are replaced with `-1`. 174 | """ 175 | start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices 176 | start_indices[:, 0] = mask_indices[:, 0] 177 | 178 | sentinel_ids = np.where( 179 | start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices 180 | ) 181 | sentinel_ids = np.where( 182 | sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0 183 | ) 184 | sentinel_ids -= mask_indices - start_indices 185 | 186 | return sentinel_ids 187 | 188 | 189 | def prepare_decoder_inputs_from_labels(self, batch): 190 | # decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. 
        # See the T5 docs for more information
        batch["labels"][batch["labels"] == self.pad_token_id] = self.label_pad_token_id
        shifted_labels = batch["labels"].new_zeros(batch["labels"].shape)
        shifted_labels[..., 1:] = batch["labels"][..., :-1].clone()
        shifted_labels[..., 0] = self.decoder_start_token_id

        batch["decoder_input_ids"] = torch.masked_fill(
            shifted_labels,
            shifted_labels == self.label_pad_token_id,
            self.pad_token_id
        )
        batch["decoder_attention_mask"] = torch.where(
            shifted_labels == self.label_pad_token_id,
            0,
            torch.ones_like(shifted_labels),
        )
        return batch

    def np_prepare_decoder_inputs_from_labels(self, batch):
        batch["labels"][batch["labels"] == self.pad_token_id] = self.label_pad_token_id
        # keep the labels' integer dtype: np.zeros defaults to float64, which would yield float token ids
        shifted_labels = np.zeros(batch["labels"].shape, dtype=batch["labels"].dtype)
        shifted_labels[..., 1:] = batch["labels"][..., :-1].copy()
        shifted_labels[..., 0] = self.decoder_start_token_id

        batch["decoder_input_ids"] = np.where(
            shifted_labels == self.label_pad_token_id,
            self.pad_token_id,
            shifted_labels
        )
        batch["decoder_attention_mask"] = np.where(
            shifted_labels == self.label_pad_token_id,
            0,
            np.ones_like(shifted_labels)
        )
        return batch
--------------------------------------------------------------------------------
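A minimal end-to-end sketch of the collator in use, condensed from `tests/test_denoising.py` (the input text and truncation length are arbitrary):

```python
from transformers import AutoTokenizer
from text_denoising import DataCollatorForUL2

tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-small")
tokenizer.add_special_tokens({'bos_token': '<s>'})

collate_fn = DataCollatorForUL2(tokenizer)  # default 25% R / 50% S / 25% X mixture
texts = ["The quick brown fox jumps over the lazy dog. " * 30]
batch = collate_fn([{'input_ids': tokenizer(t)['input_ids'][:200]} for t in texts])

# encoder inputs with sentinel tokens, shifted decoder inputs, and -100-padded labels
for key in ('input_ids', 'decoder_input_ids', 'decoder_attention_mask', 'labels'):
    print(key, batch[key].shape)
```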