├── tests
│   ├── __init__.py
│   ├── test_train.py
│   ├── test_zst.py
│   ├── test_dataset.py
│   ├── test_denoising.py
│   └── utils.py
├── text_denoising
│   ├── __init__.py
│   ├── utils.py
│   └── collate_fn.py
├── ul2.png
├── TODO.md
├── zero2_config.json
├── setup.py
├── .gitignore
├── examples
│   └── pretrain_example.py
├── README.md
└── reference_example_code.py
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/text_denoising/__init__.py:
--------------------------------------------------------------------------------
1 | from .collate_fn import DataCollatorForUL2
--------------------------------------------------------------------------------
/ul2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/theblackcat102/unify-learning-paradigms/HEAD/ul2.png
--------------------------------------------------------------------------------
/tests/test_train.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForSeq2SeqLM
2 | from text_denoising.collate_fn import DataCollatorForUL2
3 | from torch.utils.data import IterableDataset, DataLoader
4 | from tests.test_dataset import ZstDataset
5 |
6 |
7 |
8 |
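# NOTE: placeholder smoke test; the original file contained only the imports
# above. This sketch wires them together the same way tests/test_dataset.py
# does (the model name, bos token and data path are assumptions).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-small")
    tokenizer.add_special_tokens({'bos_token': '<s>'})
    model = AutoModelForSeq2SeqLM.from_pretrained("bigscience/mt0-small")
    model.resize_token_embeddings(len(tokenizer))  # account for the added bos token

    dataset = ZstDataset('/mnt/ssd/pythia_val/val.jsonl.zst', tokenizer, max_length=600)
    collate_fn = DataCollatorForUL2(tokenizer)
    dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
    for batch in dataloader:
        out = model(**batch)  # single forward pass as a shape/NaN sanity check
        print(out.loss)
        break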
--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------
1 |
2 | - [x] ~~Add task tokens [S], [X], [D] in front of each input sequence~~
3 |
4 | [Yi Tay: While we found mode switching to help, it was admittedly v inconvenient. This update also removes the need to use mode tokens :)](https://twitter.com/yitayml/status/1631359474421366784?s=21)
5 |
6 | - [ ] Add eos and bos tokens to the target sequence
--------------------------------------------------------------------------------
/tests/test_zst.py:
--------------------------------------------------------------------------------
1 | import json
2 | import io
3 | from tqdm import tqdm
4 | import zstandard as zstd
5 |
6 |
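# Sanity check: stream-decode one Pythia .jsonl.zst shard and count its JSON lines.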
7 | if __name__ == "__main__":
8 |
9 | count = 0
10 | with open('/mnt/ssd/pythia/04.jsonl.zst', 'rb') as f:
11 | cctx = zstd.ZstdDecompressor()
12 | reader = io.BufferedReader(cctx.stream_reader(f))
13 | for line in tqdm(reader):
14 | count += 1
15 | print(count)
16 |
17 |
--------------------------------------------------------------------------------
/zero2_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 0.5
9 | },
10 | "scheduler": {
11 | "type": "WarmupDecayLR",
12 | "params": {
13 | "warmup_min_lr": "auto",
14 | "warmup_max_lr": "auto",
15 | "warmup_num_steps": "auto",
16 | "total_num_steps": "auto"
17 | }
18 | },
19 | "zero_optimization": {
20 | "stage": 2,
21 | "offload_optimizer": {
22 | "device": "cpu",
23 | "pin_memory": true
24 | },
25 | "allgather_partitions": true,
26 | "allgather_bucket_size": 1e9,
27 | "overlap_comm": true,
28 | "reduce_scatter": true,
29 | "reduce_bucket_size": 1e9,
30 | "contiguous_gradients": true
31 | },
32 | "gradient_accumulation_steps": "auto",
33 | "gradient_clipping": "auto",
34 | "steps_per_print": 2000,
35 | "train_batch_size": "auto",
36 | "train_micro_batch_size_per_gpu": "auto",
37 | "wall_clock_breakdown": false
38 | }
39 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 | long_description = fh.read()
5 |
6 | setup(
7 | name="text-denoising",
8 | version="0.1.1", # Start with a small version number, you can increment it for subsequent releases
9 | author="Zhi-Rui, Tam",
10 | author_email="theblackcat102@github.io",
11 | description="A package for text denoising : UL2",
12 | long_description=long_description,
13 | long_description_content_type="text/markdown",
14 |     url="https://github.com/theblackcat102/unify-learning-paradigms",
15 | packages=find_packages(include=["text_denoising"]),
16 | classifiers=[
17 | "Programming Language :: Python :: 3",
18 | "License :: OSI Approved :: MIT License", # You can change this if you have another preferred license
19 | "Operating System :: OS Independent",
20 | ],
21 | install_requires=[
22 | "numpy>=1.18.0",
23 | "torch>=1.7.0",
24 | "transformers",
25 | ],
26 | extras_require={
27 | "test": [
28 | # Add any additional testing dependencies here
29 | "pytest>=6.0.0",
30 | ]
31 | },
32 | python_requires='>=3.6',
33 | )
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 |
4 | *.py[cod]
5 | *$py.class
6 | currents_api.log
7 | # database
8 | *.db
9 | *.db*
10 | sqlite.db*
11 | sqlite_debug*
12 | # C extensions
13 | *.so
14 | schema.json
15 | # Distribution / packaging
16 | .Python
17 | env/
18 | env2/
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # IPython Notebook
76 | .ipynb_checkpoints
77 |
78 | # pyenv
79 | .python-version
80 |
81 | # celery beat schedule file
82 | celerybeat-schedule
83 |
84 | # dotenv
85 | .env
86 |
87 | # virtualenv
88 | venv/
89 | ENV/
90 | env/
91 | # Spyder project settings
92 | .spyderproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | .idea
98 | .DS_Store
99 |
--------------------------------------------------------------------------------
/examples/pretrain_example.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import glob
4 | os.environ["WANDB_PROJECT"] = "ul2_pretrain"
5 | from transformers import (AutoTokenizer, Seq2SeqTrainer, MT5ForConditionalGeneration, MT5Config,
6 | Seq2SeqTrainingArguments, AutoModelForSeq2SeqLM)
7 | from text_denoising import DataCollatorForUL2
8 | from tests.test_dataset import ZstDataset
9 | from tests import utils
10 |
11 | if __name__ == "__main__":
12 | model_name = "theblackcat102/mt0-chat-large"
13 | model_saved_name = model_name.split("/")[-1]
14 | tokenizer = AutoTokenizer.from_pretrained(model_name)
15 | tokenizer.add_special_tokens({'bos_token': '<s>'})
16 |
17 | # model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
18 | # model.resize_token_embeddings(len(tokenizer))
19 |
20 | config = MT5Config.from_pretrained(model_name)
21 | config.vocab_size = len(tokenizer)
22 | config.dropout_rate = 0.0
23 | print(config)
24 | model = MT5ForConditionalGeneration(config)
25 | # model = MT5ForConditionalGeneration.from_pretrained('theblackcat102/mt0-chat-large-ul2-2000')
26 | num_params = sum(param.numel() for param in model.parameters())
27 | print(num_params/1e6)
28 | print(model.config)
29 | print(len(tokenizer))  # vocab size after adding the bos token
30 | collate_fn = DataCollatorForUL2(tokenizer)
31 | train_dataset = ZstDataset(list(glob.glob('/mnt/ssd/pythia/*.jsonl.zst')), tokenizer, max_length=600)
32 | val_dataset = ZstDataset('/mnt/ssd/pythia_val/val.jsonl.zst', tokenizer, max_length=600)
33 | args = Seq2SeqTrainingArguments(
34 | output_dir=f"{model_name}-ul2",
35 | fp16=True,
36 | deepspeed="zero2_config.json",
37 | max_steps=100000,
38 | warmup_steps=4000,
39 | learning_rate=5e-4,
40 | label_smoothing_factor=0,
41 | optim="adamw_hf",
42 | gradient_checkpointing=True,
43 | dataloader_num_workers=22,
44 | gradient_accumulation_steps=22,
45 | per_device_train_batch_size=25,
46 | per_device_eval_batch_size=8,
47 | weight_decay=0.01,
48 | max_grad_norm=2,
49 | logging_steps=10,
50 | save_total_limit=4,
51 | evaluation_strategy="steps",
52 | eval_steps=500,
53 | save_steps=500,
54 | report_to="wandb",
55 | )
56 |
57 | trainer = Seq2SeqTrainer(
58 | model,
59 | args,
60 | train_dataset=train_dataset,
61 | eval_dataset=val_dataset,
62 | data_collator=collate_fn,
63 | tokenizer=tokenizer,
64 | )
65 | import wandb
66 | wandb.init(
67 | project="ul2_pretrain",
68 | name=model_name
69 | )
70 | trainer.train()
71 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Masking implementation for Unifying Language Learning Paradigms (UL2)
2 |
3 | Want a better model on a limited budget? You are in the right place.
4 |
5 | ```
6 | pip install text-denoising
7 | ```
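
Minimal usage (a sketch based on `tests/test_denoising.py`; the model name and the `<s>` bos token are placeholders for whatever your setup uses):

```
from transformers import AutoTokenizer
from text_denoising import DataCollatorForUL2

tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-small")
tokenizer.add_special_tokens({'bos_token': '<s>'})  # the collator uses bos as decoder start

corpus = ["a reasonably long pretraining document ...", "another document ..."]
collate_fn = DataCollatorForUL2(tokenizer)
batch = collate_fn([{'input_ids': tokenizer(text)['input_ids']} for text in corpus])
# batch contains input_ids, labels, decoder_input_ids and decoder_attention_mask
```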
8 |
9 |
10 | ![UL2 denoising objectives](ul2.png)
11 |
12 |
13 | - R-Denoiser (μ=3, r=0.15, n) ∪ (μ=8, r=0.15, n)
14 | 
15 | Regular denoising is the standard span corruption introduced in Raffel et al. (2019); it uses span lengths of 2 to 5 tokens and masks about 15% of the input tokens.
16 | 
17 | - S-Denoiser (μ=L/4, r=0.25, 1)
18 | 
19 | A specific case of denoising where we observe a strict sequential order when framing the inputs-to-targets task, i.e., prefix language modeling.
20 | 
21 | - X-Denoiser (μ=3, r=0.5, n) ∪ (μ=8, r=0.5, n) ∪ (μ=64, r=0.15, n) ∪ (μ=64, r=0.5, n)
22 | 
23 | An extreme version of denoising where the model must recover a large part of the input given a small to moderate part of it. This simulates a situation where a model needs to generate a long target from a memory with relatively limited information. To do so, we include examples with aggressive denoising where approximately 50% of the input sequence is masked.
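
In this package each denoiser's (μ, r) settings are expressed as `(mean_span_length, noise_density)` pairs in `r_denoising_config` / `x_denoising_config`; all pairs in one tuple are OR-ed into a single mask per example, which approximates the unions above. A sketch (tokenizer as in the example further up):

```
collate_fn = DataCollatorForUL2(
    tokenizer,
    r_denoising_config=((3, 0.15), (8, 0.15)),  # R: (μ=3, r=0.15) ∪ (μ=8, r=0.15)
    x_denoising_config=((3, 0.5), (8, 0.5), (64, 0.15), (64, 0.5)),  # X union
)
```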
24 |
25 | 2022 paper: Transcending Scaling Laws with 0.1% Extra Compute
26 |
27 | > we show an approximately 2x computational savings rate
28 |
29 | - Regular denoising whereby the noise is sampled as spans, replaced with sentinel tokens. This is also the standard span corruption task used in Raffel et al. (2019). Spans are typically uniformly sampled with a mean of 3 and a corruption rate of 15%.
30 |
31 | - Extreme denoising whereby the noise is increased to relatively ‘extreme‘ amounts in either a huge percentage of the original text or being very long in nature. Spans are typically uniformly sampled with a mean length of 32 OR a corruption rate of up to 50%.
32 |
33 | - Sequential denoising whereby the noise is always sampled from the start of the text to a randomly sampled point in the text. This is also known as the PrefixLM objective (not to be confused with the architecture).
34 |
35 | This repo just aims to accomplish this task instead; full UL2 is way too complicated for my liking:
36 |
37 | > 50% PrefixLM, 25% Long (extreme) span corruption, and 25% regular span corruption to be quite simple and efficient
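
Those rates match the collator defaults (`s_probability=0.5`, `x_probability=0.25`, `r_probability=0.25`), so the plain constructor already gives this mix; spelled out:

```
collate_fn = DataCollatorForUL2(
    tokenizer,
    s_denoising=True, s_probability=0.5,   # 50% PrefixLM (S)
    x_denoising=True, x_probability=0.25,  # 25% extreme span corruption (X)
    r_denoising=True, r_probability=0.25,  # 25% regular span corruption (R)
)
```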
38 |
39 |
40 | ## Experiments
41 |
42 | Run mT5 pretraining on an RTX 3090 over the Pythia .jsonl.zst files:
43 |
44 | ```
45 | pip install text-denoising
46 | python examples/pretrain_example.py
47 | ```
48 |
49 |
50 |
51 |
52 |
53 | Training loss was stable, with no weird spikes.
54 |
55 | ## References
56 |
57 | Core Papers
58 |
59 | [Transcending Scaling Laws with 0.1% Extra Compute](https://arxiv.org/pdf/2210.11399.pdf)
60 |
61 | [Unifying Language Learning Paradigms](https://arxiv.org/pdf/2205.05131.pdf)
62 |
63 | Implementations of T5 noise masking in Hugging Face Transformers or plain Python
64 |
65 | [OSLO](https://github.com/EleutherAI/oslo): very underrated; with some tidying and documentation this will be a very useful tool
66 |
67 | - [t5_pretraining.py](https://github.com/EleutherAI/oslo/blob/main/oslo/transformers/tasks/data_t5_pretraining.py)
68 |
69 | The implementation here is heavily inspired by that file
70 |
71 | [Amazon science : label aware pretrain in python](https://github.com/amazon-science/label-aware-pretrain/blob/main/models/preprocessor.py)
72 |
73 | [Fairseq : span_mask_tokens_dataset.py](https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/span_mask_tokens_dataset.py)
74 |
--------------------------------------------------------------------------------
/tests/test_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import io
3 | import torch
4 | import zstandard as zstd
5 | from tqdm import tqdm
6 | from transformers import AutoTokenizer
7 | from text_denoising.collate_fn import DataCollatorForUL2
8 | from torch.utils.data import IterableDataset, DataLoader
9 |
10 | def chunks(lst, chunk_size):
11 | for i in range(0, len(lst), chunk_size):
12 | yield lst[i:i+chunk_size]
13 |
14 | class ZstDataset(IterableDataset):
15 |
16 | def __init__(self, file, tokenizer, max_length=512) -> None:
17 | super().__init__()
18 | self.file = [file] if isinstance(file, str) else file
19 | self.tokenizer = tokenizer
20 | self.max_length = max_length
21 |
22 |
23 | def __iter__(self):
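# Sharding strategy:
#   - no worker processes: stream every file sequentially
#   - a single file with N workers: round-robin chunks across workers by count
#   - multiple files: one file per worker (assumes num_workers matches the file count)
# The chunk size of max_length * 3.2 characters is a rough chars-per-token budget,
# so most chunks tokenize to at least max_length tokens.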
24 | worker_info = torch.utils.data.get_worker_info()
25 | if worker_info is None:
26 | for filename in self.file:
27 | with open(filename, 'rb') as f:
28 | cctx = zstd.ZstdDecompressor()
29 | reader = io.BufferedReader(cctx.stream_reader(f))
30 | for line in reader:
31 | if line:
32 | raw_text = json.loads(line)['text']
33 | for chunk in chunks(raw_text, int(self.max_length*3.2)):
34 | tokens = self.tokenizer(chunk)['input_ids']
35 | if len(tokens) >= self.max_length:
36 | yield { 'input_ids': tokens[:self.max_length] }
37 | elif len(self.file) == 1:
38 | cnt = 0
39 | worker_id = worker_info.id
40 | total_workers = worker_info.num_workers
41 | with open(self.file[0], 'rb') as f:
42 | cctx = zstd.ZstdDecompressor()
43 | reader = io.BufferedReader(cctx.stream_reader(f))
44 | for line in reader:
45 | if line:
46 | raw_text = json.loads(line)['text']
47 | for chunk in chunks(raw_text, int(self.max_length*3.2)):
48 | tokens = self.tokenizer(chunk)['input_ids']
49 | cnt += 1
50 | if len(tokens) >= self.max_length and cnt % total_workers == worker_id:
51 | yield { 'input_ids': tokens[:self.max_length] }
52 | else:
53 | worker_id = worker_info.id
54 | with open(self.file[worker_id], 'rb') as f:
55 | cctx = zstd.ZstdDecompressor()
56 | reader = io.BufferedReader(cctx.stream_reader(f))
57 | for line in reader:
58 | if line:
59 | raw_text = json.loads(line)['text']
60 | for chunk in chunks(raw_text, int(self.max_length*3.2)):
61 | tokens = self.tokenizer(chunk)['input_ids']
62 | if len(tokens) >= self.max_length:
63 | yield {'input_ids': tokens[:self.max_length]}
64 |
65 | if __name__ == "__main__":
66 | import glob
67 | # download test.jsonl.zst from the-pile website
68 |
69 | tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-small")
70 | tokenizer.add_special_tokens({'bos_token': '<s>'})
71 | dataset = ZstDataset(list(glob.glob('/mnt/ssd/pythia/*.jsonl.zst')), tokenizer)
72 | # dataset = ZstDataset('/mnt/ssd/pythia_val/val.jsonl.zst', tokenizer, max_length=600)
73 | # mimic result from multiple dataset runs
74 | collate_fn = DataCollatorForUL2(tokenizer,
75 | r_probability=0.5, r_denoising=True,
76 | s_probability=0.5, s_denoising=False,
77 | x_denoising=False, x_probability=0.0)
78 | dataloader = DataLoader(dataset, batch_size=256, collate_fn=collate_fn, num_workers=13)
79 |
80 |
81 | # benchmark iteration speed
82 | for batch in tqdm(dataloader):
83 | print(batch['input_ids'].shape)
84 | # print(batch["decoder_attention_mask"][0])
85 | # for (input, label) in zip(batch['decoder_input_ids'][:8], batch['labels'][:8]):
86 | # print(tokenizer.decode(input[ input != 0]))
87 | # print(tokenizer.decode(label[ label != -100]))
88 | # print('----')
89 |
90 | # batch = [ ]
91 | # np_batch = collate_fn(batch, return_tensors='pt')
92 | # print(np_batch)
93 |
94 |
--------------------------------------------------------------------------------
/tests/test_denoising.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from text_denoising import DataCollatorForUL2
3 |
4 |
5 |
6 | if __name__ == "__main__":
7 | tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-small")
8 | tokenizer.add_special_tokens({'bos_token': '<s>'})
9 |
10 | # collate_fn = DataCollatorForUL2(tokenizer,
11 | # r_probability=1.0, r_denoising=True,
12 | # s_probability=0.0, s_denoising=False,
13 | # x_denoising=False, x_probability=0.0)
14 | collate_fn = DataCollatorForUL2(tokenizer,
15 | r_probability=0.5, r_denoising=True,
16 | s_probability=0.0, s_denoising=False,
17 | x_denoising=True, x_probability=0.5)
18 |
19 | batch = [
20 | 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nulla faucibus turpis et elit malesuada, ac venenatis sapien bibendum. Mauris at ullamcorper libero. Donec interdum auctor nisi a luctus. Suspendisse potenti. Proin vitae tortor vel leo consectetur fermentum. Sed blandit, nulla ac lobortis dapibus, diam massa accumsan velit, non pharetra lectus lacus ac odio. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Sed viverra libero at est efficitur vestibulum. Sed ac tristique mauris, sit amet consequat leo. Sed imperdiet lectus a magna mollis auctor. Sed bibendum eget lacus vitae lobortis. Fusce ac eros eget libero scelerisque consequat. Cras id purus ornare, consectetur ipsum sed, semper nulla. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.',
21 | 'Donec tincidunt enim quis felis lacinia, vitae ultricies nulla consectetur. Praesent ullamcorper ligula ac tincidunt rutrum. Vestibulum euismod ex vel quam porttitor, sit amet consequat velit sollicitudin. Pellentesque non mauris auctor, varius odio non, feugiat mi. Sed vulputate tincidunt arcu eget interdum. Duis convallis purus eget mauris euismod, non efficitur mi convallis. Sed varius massa nec luctus iaculis. Donec ornare, nunc a consequat pellentesque, nisi orci tincidunt quam, ac laoreet mauris orci in nunc. Fusce ut orci sit amet turpis vestibulum imperdiet. Vivamus et enim vel lorem ultrices fringilla. Sed vehicula nibh id risus convallis, ac bibendum sapien vulputate. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia Curae; Integer at risus quis magna blandit aliquam. Suspendisse varius malesuada mauris, vitae dictum metus aliquam vitae. Ut et ante at odio malesuada lobortis. . Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.',
22 | 'Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed eleifend et tortor ac vehicula. Fusce eu urna aliquet, fringilla odio vel, malesuada metus. Pellentesque bibendum mauris vel est faucibus posuere. Duis ultrices vestibulum nulla, at tempor enim bibendum a. In sit amet quam vel nunc tristique varius ac eu dui. Quisque semper nisi at enim aliquam facilisis. Sed pharetra risus sit amet libero sollicitudin, vel faucibus velit sagittis. Sed viverra magna quis metus malesuada posuere. Donec in ante in enim tristique vestibulum. Sed molestie posuere urna id rhoncus. Fusce sit amet neque ac mi dapibus sollicitudin. Fusce pharetra est sed massa feugiat euismod. Vestibulum eu aliquam nulla, eget varius eros.. Vestibulum eu aliquam nulla, eget varius eros.. Vestibulum eu aliquam nulla, eget varius eros.. Vestibulum eu aliquam nulla, eget varius eros.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.',
23 | 'Suspendisse malesuada nibh a enim blandit, a laoreet augue imperdiet. Suspendisse at ligula non risus feugiat blandit. Proin tristique mi sit amet ex laoreet, vel viverra ipsum fringilla. Donec at gravida nisi. Curabitur vel magna vitae lectus bibendum lacinia ac vel elit. Nam sit amet sem purus. Suspendisse at metus vitae ipsum viverra bibendum. Ut quis gravida libero. Suspendisse varius vel purus nec scelerisque. Nulla tincidunt enim in mollis eleifend. Donec tincidunt justo vitae diam congue vestibulum. Nam id lectus auctor, pellentesque tellus in, scelerisque tellus. Vivamus vel semper justo. Sed eget ipsum nec libero pellentesque tincidunt. Ut in efficitur purus, sit amet hendrerit ex.Ut in efficitur purus, sit amet hendrerit ex.Ut in efficitur purus, sit amet hendrerit ex.Ut in efficitur purus, sit amet hendrerit ex.Ut in efficitur purus, sit amet hendrerit ex.Ut in efficitur purus, sit amet hendrerit ex.. Sed tincidunt eget neque eu pulvinar.. Sed tincidunt eget neque eu pulvinar.',
24 | ]
25 | encode = collate_fn([ { 'input_ids': tokenizer(r)['input_ids'][:200] } for r in batch] )
26 | print(tokenizer.decode(tokenizer(batch[0])['input_ids']))
27 | print('-----')
28 | for input_ids, token_ids, label_ids in zip(encode['input_ids'], encode['decoder_input_ids'], encode['labels']):
29 | print('---------')
30 | print(tokenizer.decode(input_ids))
31 | print(tokenizer.decode(token_ids))
32 | print(tokenizer.decode(label_ids[label_ids!= -100]))
33 | print('---------')
34 |
--------------------------------------------------------------------------------
/text_denoising/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def random_spans_noise_mask(length, mean_noise_span_length, noise_density):
4 | """
5 | A copy from https://github.com/EleutherAI/oslo/blob/main/oslo/transformers/tasks/data_t5_pretraining.py#L230 (inception)
6 | This function is a copy of ``random_spans_helper`` from the T5 codebase (google-research/text-to-text-transfer-transformer).
7 | Noise mask consisting of random spans of noise tokens.
8 | The number of noise tokens and the number of noise spans and non-noise spans
9 | are determined deterministically as follows:
10 | num_noise_tokens = round(length * noise_density)
11 | num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
12 | Spans alternate between non-noise and noise, beginning with non-noise.
13 | Subject to the above restrictions, all masks are equally likely.
14 | Args:
15 | length: an int32 scalar (length of the incoming token sequence)
16 | mean_noise_span_length: a float, the average length of a noise span
17 | noise_density: a float, the approximate density of the output mask
18 | Returns:
19 | a boolean tensor with shape [length]
20 | """
21 |
22 | orig_length = length
23 |
24 | num_noise_tokens = int(np.round(length * noise_density))
25 | # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
26 | num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
27 | num_noise_spans = int(np.round(num_noise_tokens / mean_noise_span_length))
28 |
29 | # avoid degeneracy by ensuring positive number of noise spans
30 | num_noise_spans = max(num_noise_spans, 1)
31 | num_nonnoise_tokens = length - num_noise_tokens
32 |
33 | # pick the lengths of the noise spans and the non-noise spans
34 | def _random_segmentation(num_items, num_segments):
35 | """Partition a sequence of items randomly into non-empty segments.
36 | Args:
37 | num_items: an integer scalar > 0
38 | num_segments: an integer scalar in [1, num_items]
39 | Returns:
40 | a Tensor with shape [num_segments] containing positive integers that add
41 | up to num_items
42 | """
43 | mask_indices = np.arange(num_items - 1) < (num_segments - 1)
44 | np.random.shuffle(mask_indices)
45 | first_in_segment = np.pad(mask_indices, [[1, 0]])
46 | segment_id = np.cumsum(first_in_segment)
47 | # count length of sub segments assuming that list is sorted
48 | _, segment_length = np.unique(segment_id, return_counts=True)
49 | return segment_length
50 |
51 | noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
52 | nonnoise_span_lengths = _random_segmentation(
53 | num_nonnoise_tokens, num_noise_spans
54 | )
55 |
56 | interleaved_span_lengths = np.reshape(
57 | np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
58 | [num_noise_spans * 2],
59 | )
60 | span_starts = np.cumsum(interleaved_span_lengths)[:-1]
61 | span_start_indicator = np.zeros((length,), dtype=np.int8)
62 | span_start_indicator[span_starts] = True
63 | span_num = np.cumsum(span_start_indicator)
64 | is_noise = np.equal(span_num % 2, 1)
65 |
66 | return is_noise[:orig_length]
67 |
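# Example: random_spans_noise_mask(100, mean_noise_span_length=3, noise_density=0.15)
# returns a length-100 boolean mask with ~15 noise positions grouped into ~5 spans.
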
68 | def compute_input_and_target_lengths(
69 | inputs_length, noise_density, mean_noise_span_length
70 | ):
71 | """
72 | A copy of a copy from https://github.com/EleutherAI/oslo/blob/main/oslo/transformers/tasks/data_t5_pretraining.py#L76 (it's getting meta)
73 | This function is a copy of ``random_spans_helper`` from the T5 codebase (google-research/text-to-text-transfer-transformer).
74 | Training parameters to avoid padding with random_spans_noise_mask.
75 | When training a model with random_spans_noise_mask, we would like to set the other
76 | training hyperparameters in a way that avoids padding.
77 | This function helps us compute these hyperparameters.
78 | We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
79 | and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
80 | This function tells us the required number of tokens in the raw example (for split_tokens())
81 | as well as the length of the encoded targets. Note that this function assumes
82 | the inputs and targets will have EOS appended and includes that in the reported length.
83 | Args:
84 | inputs_length: an integer - desired length of the tokenized inputs sequence
85 | noise_density: a float
86 | mean_noise_span_length: a float
87 | Returns:
88 | tokens_length: length of original text in tokens
89 | targets_length: an integer - length in tokens of encoded targets sequence
90 | """
91 | def _tokens_length_to_inputs_length_targets_length(tokens_length):
92 | num_noise_tokens = int(round(tokens_length * noise_density))
93 | num_nonnoise_tokens = tokens_length - num_noise_tokens
94 | num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
95 | # inputs contain all nonnoise tokens, sentinels for all noise spans
96 | # and one EOS token.
97 | _input_length = num_nonnoise_tokens + num_noise_spans + 1
98 | _output_length = num_noise_tokens + num_noise_spans + 1
99 | return _input_length, _output_length
100 |
101 |
102 | tokens_length = inputs_length
103 |
104 | while (
105 | _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0]
106 | <= inputs_length
107 | ):
108 | tokens_length += 1
109 |
110 | inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(
111 | tokens_length
112 | )
113 |
114 | # minor hack to get the targets length to be equal to inputs length
115 | # which is more likely to have been set to a nice round number.
116 | if noise_density == 0.5 and targets_length > inputs_length:
117 | tokens_length -= 1
118 | targets_length -= 1
119 | return tokens_length, targets_length
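
# Example: compute_input_and_target_lengths(512, 0.15, 3) returns (568, 114):
# feed 568 raw tokens to end up with 512-token inputs and 114-token targets
# after sentinel replacement plus EOS.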
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
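# Helpers for diagnosing fp16 overflow in T5-style models: test_no_nan runs a
# half-precision model until a NaN appears, and search_and_reset_layers scales
# down feed-forward weights block by block (starting from the last block) until
# the forward pass stays NaN-free.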
2 | import string
3 | from random import choice
4 |
5 | import numpy as np
6 | import torch
7 |
8 |
9 | def count_param(model):
10 | model_parameters = filter(lambda p: p.requires_grad, model.parameters())
11 | params = sum([np.prod(p.size()) for p in model_parameters])
12 | return params
13 |
14 |
15 | def gen_random_sequence(length=128):
16 | chars = string.ascii_letters + " .!?'"
17 | return "".join([choice(chars) for i in range(length)])
18 |
19 |
20 | def test_no_nan(model, tokenizer, iterations=100, device="cpu"):
21 | model_half = model.half().to(device)
22 | model_half.train()
23 | text = " Hello world! The Earth is a nice place to be. "
24 | train_text = " translate English to German: The house is wonderful. "
25 | train_label = " Das Haus ist wunderbar. "
26 | for idx in range(iterations):
27 | inputs = tokenizer.encode(str(idx) + text + str(idx), return_tensors="pt").to(
28 | device
29 | )
30 | out = model_half(input_ids=inputs, decoder_input_ids=inputs)
31 | if torch.isnan(out[0]).any():
32 | return False
33 | train_input_ids = tokenizer(
34 | str(idx) + train_text + str(idx), return_tensors="pt"
35 | ).input_ids.to(device)
36 | train_labels = tokenizer(
37 | str(idx) + train_label + str(idx), return_tensors="pt"
38 | ).input_ids.to(device)
39 | loss = model_half(input_ids=train_input_ids, labels=train_labels).loss
40 | if torch.isnan(loss):
41 | return False
42 | inputs = tokenizer.encode(gen_random_sequence(), return_tensors="pt").to(device)
43 | out = model_half(input_ids=inputs, decoder_input_ids=inputs)
44 | if torch.isnan(out[0]).any():
45 | return False
46 | return True
47 |
48 |
49 | def scale_weights(layer, scale_down_factor):
50 | old_weights = layer.weight
51 | layer.weight = torch.nn.Parameter(old_weights / scale_down_factor)
52 | return old_weights
53 |
54 |
55 | def search_and_reset_layers(
56 | model, tokenizer, scale_down_factor=15, revert_old=False, device="cuda"
57 | ):
58 | model = model.float().to(device)
59 | total_params = count_param(model)
60 | param_reset_count = 0
61 |
62 | print("Testing encoder")
63 | for i, layer in enumerate(model.encoder.block[::-1]):
64 | fflayer0 = layer.layer[1].DenseReluDense.wi_0
65 | fflayer1 = layer.layer[1].DenseReluDense.wi_1
66 | fflayer2 = layer.layer[1].DenseReluDense.wo
67 |
68 | # fflayer2.reset_parameters()
69 | old_weights = scale_weights(fflayer2, scale_down_factor)
70 | param_reset_count += count_param(fflayer2)
71 | if test_no_nan(model, tokenizer, device=device):
72 | print("Success at encoder", len(model.encoder.block) - i, "FF2")
73 | return model.float(), int(param_reset_count / total_params * 100)
74 | else:
75 | if revert_old:
76 | param_reset_count -= count_param(fflayer2)
77 | fflayer2.weight = old_weights
78 |
79 | # fflayer1.reset_parameters()
80 | old_weights = scale_weights(fflayer1, scale_down_factor)
81 | param_reset_count += count_param(fflayer1)
82 |
83 | if test_no_nan(model, tokenizer, device=device):
84 | print("Success at encoder", len(model.encoder.block) - i, "FF1")
85 | return model.float(), int(param_reset_count / total_params * 100)
86 | else:
87 | if revert_old:
88 | param_reset_count -= count_param(fflayer1)
89 | fflayer1.weight = old_weights
90 |
91 | old_weights = scale_weights(fflayer0, scale_down_factor)
92 | param_reset_count += count_param(fflayer0)
93 | if test_no_nan(model, tokenizer, device=device):
94 | print("Success at encoder", len(model.encoder.block) - i, "FF0")
95 | return model.float(), int(param_reset_count / total_params * 100)
96 | else:
97 | if revert_old:
98 | param_reset_count -= count_param(fflayer0)
99 | fflayer0.weight = old_weights
100 |
101 | print("Testing decoder")
102 | for i, layer in enumerate(model.decoder.block[::-1]):
103 | fflayer0 = layer.layer[2].DenseReluDense.wi_0
104 | fflayer1 = layer.layer[2].DenseReluDense.wi_1
105 | fflayer2 = layer.layer[2].DenseReluDense.wo
106 |
107 | # fflayer2.reset_parameters()
108 | old_weights = scale_weights(fflayer2, scale_down_factor)
109 | param_reset_count += count_param(fflayer2)
110 | if test_no_nan(model, tokenizer, device=device):
111 | print("Success at decoder", len(model.decoder.block) - i, "FF2")
112 | return model.float(), int(param_reset_count / total_params * 100)
113 | else:
114 | if revert_old:
115 | param_reset_count -= count_param(fflayer2)
116 | fflayer2.weight = old_weights
117 |
118 | # fflayer1.reset_parameters()
119 | old_weights = scale_weights(fflayer1, scale_down_factor)
120 | param_reset_count += count_param(fflayer1)
121 | if test_no_nan(model, tokenizer, device=device):
122 | print("Success at decoder", len(model.decoder.block) - i, "FF1")
123 | return model.float(), int(param_reset_count / total_params * 100)
124 | else:
125 | if revert_old:
126 | param_reset_count -= count_param(fflayer1)
127 | fflayer1.weight = old_weights
128 |
129 | old_weights = scale_weights(fflayer0, scale_down_factor)
130 | param_reset_count += count_param(fflayer0)
131 | if test_no_nan(model, tokenizer, device=device):
132 | print("Success at decoder", len(model.decoder.block) - i, "FF0")
133 | return model.float(), int(param_reset_count / total_params * 100)
134 | else:
135 | if revert_old:
136 | param_reset_count -= count_param(fflayer0)
137 | fflayer0.weight = old_weights
138 |
139 | return model.float(), False
140 |
141 |
142 | def fix_rescale(model, scale_down_factor=10):
143 | model = model.float()
144 | total_params = count_param(model)
145 | param_reset_count = 0
146 |
147 | print("Testing encoder")
148 | for i, layer in enumerate(model.encoder.block[::-1]):
149 | fflayer0 = layer.layer[1].DenseReluDense.wi_0
150 | fflayer1 = layer.layer[1].DenseReluDense.wi_1
151 | fflayer2 = layer.layer[1].DenseReluDense.wo
152 |
153 | old_weights = scale_weights(fflayer2, scale_down_factor)
154 | param_reset_count += count_param(fflayer2)
155 |
156 | # fflayer1.reset_parameters()
157 | old_weights = scale_weights(fflayer1, scale_down_factor)
158 | param_reset_count += count_param(fflayer1)
159 |
160 | old_weights = scale_weights(fflayer0, scale_down_factor)
161 | param_reset_count += count_param(fflayer0)
162 |
163 | print("Testing decoder")
164 | for i, layer in enumerate(model.decoder.block[::-1]):
165 | fflayer0 = layer.layer[2].DenseReluDense.wi_0
166 | fflayer1 = layer.layer[2].DenseReluDense.wi_1
167 | fflayer2 = layer.layer[2].DenseReluDense.wo
168 |
169 | # fflayer2.reset_parameters()
170 | old_weights = scale_weights(fflayer2, scale_down_factor)
171 | param_reset_count += count_param(fflayer2)
172 | if len(model.decoder.block) - i == 2:
173 | print("decoder", len(model.decoder.block) - i, "FF2")
174 | return model.float()
175 |
176 | # fflayer1.reset_parameters()
177 | old_weights = scale_weights(fflayer1, scale_down_factor)
178 | param_reset_count += count_param(fflayer1)
179 |
180 | old_weights = scale_weights(fflayer0, scale_down_factor)
181 | param_reset_count += count_param(fflayer0)
182 |
183 | return model.float()
184 |
--------------------------------------------------------------------------------
/reference_example_code.py:
--------------------------------------------------------------------------------
# Assumed imports (the original snippet, adapted from the UL2 paper appendix, omitted them):
import functools
from typing import Optional, Sequence

import tensorflow as tf
import tensorflow_datasets as tfds
import seqio
import t5.data.preprocessors
1 | dataset = tfds.load('wikipedia/20220620.en', split='train', shuffle_files=True)
2 |
3 | def ul2_objective(dataset: tf.data.Dataset,
4 | sequence_length: seqio.preprocessors.SequenceLengthType,
5 | output_features: seqio.preprocessors.OutputFeaturesType,
6 | use_prefix_lm_task: bool = False,
7 | rates: Optional[Sequence[float]] = None,
8 | mean_noise_span_lengths: Sequence[float] = (3.0,),
9 | noise_densities: Sequence[float] = (0.15,),
10 | shard_ds: bool = True,
11 | optional_task_prefixes: Optional[Sequence[str]] = None,
12 | input_feature_key: str = "inputs",
13 | merge_examples_to_reduce_padding: bool = True,
14 | reserved_for_packing: Optional[int] = None,
15 | seed: int = 7) -> tf.data.Dataset:
16 |
17 | """
18 | UL2-like pre-training objectives. This preprocessor amounts to calling the `span_corruption` function several times with different values of `noise_density` and `mean_noise_span_length`.
19 | We either shard or copy the dataset, then apply each function to each shard. Add S-denoising (prefixLM) using use_prefix_lm_task.
20 |
21 | Args:
22 | dataset: A tf.data.Dataset with dictionaries containing the key ‘input_feature_key‘.
23 | sequence_length: dict mapping of feature key to int length for that feature.
24 | output_features: mapping of keys to features.
25 | use_prefix_lm_task: If True, include PrefixLM in the task mix.
26 | rates: List of sampling rates per task. If None, tasks are sampled uniformly.
27 | mean_noise_span_lengths: List of mean number of tokens per masked span per example.
28 | noise_densities: List of what fraction of the tokens to mask.
29 | shard_ds: If True, shard dataset per objective.
30 | optional_task_prefixes: Strings to prepend for each corruption scheme.
31 | NOTE: If including prefixLM task, it must be the last prefix.
32 | input_feature_key: which feature to use from the dataset as the input text tokens.
33 | merge_examples_to_reduce_padding: if True, combines multiple input examples to reduce padding.
reserved_for_packing: if specified, reduces the desired inputs length by the specified amount to enable multiple examples to be packed together downstream.
34 | seed: tf.int64 for controlling the random choice of spans.
Returns:
a dataset
35 | """
36 |
37 | if optional_task_prefixes: # Ensure each task has a prefix.
38 | num_tasks = len(noise_densities) + int(use_prefix_lm_task)
39 | valid_number_of_prefixes = num_tasks == len(optional_task_prefixes)
40 | if not valid_number_of_prefixes:
41 | raise ValueError("Number of task prefixes must match number of tasks.")
42 | inputs_length = sequence_length[input_feature_key]
43 | input_lengths, targets_lengths = [], []
44 | sequence_lengths = {x: y for x, y in sequence_length.items()}
45 | if reserved_for_packing:
46 | inputs_length -= reserved_for_packing
47 | for x, y in sequence_length.items():
48 | sequence_lengths[x] = y - reserved_for_packing
49 | hyperparams = list(zip(mean_noise_span_lengths, noise_densities))
50 | for mean_noise_span_length, noise_density in hyperparams:
51 | input_length, targets_length = t5.data.preprocessors.random_spans_helper(
52 | extra_tokens_per_span_inputs=1,
53 | extra_tokens_per_span_targets=1,
54 | inputs_length=inputs_length,
55 | mean_noise_span_length=mean_noise_span_length,
56 | noise_density=noise_density)
57 | input_lengths.append(input_length)
58 | targets_lengths.append(targets_length)
59 |
60 | if sequence_length["targets"] < targets_length:
61 | upper_bound = max(targets_lengths)
62 | raise ValueError(f"Targets length {sequence_length['targets']} is too small for the given noise_density and mean_noise_span_length. Please increase the targets length to at least {upper_bound}.")
64 |
65 | ds = dataset
66 | ds = t5.data.preprocessors.select_random_chunk(
67 | ds,
68 | output_features=output_features,
69 | feature_key="targets",
70 | max_length=65536)
71 | if merge_examples_to_reduce_padding:
72 | ds = t5.data.preprocessors.reduce_concat_tokens(
73 | ds,
74 | feature_key="targets",
75 | batch_size=128)
76 | num_shards = len(input_lengths) + int(use_prefix_lm_task)
77 | if shard_ds:
78 | ds_shards = [ds.shard(num_shards, i) for i in range(num_shards)]
79 | else:
80 | ds_shards = [ds for _ in range(num_shards)]
81 | processed_ds = []
82 | hyperparams = zip(input_lengths, hyperparams, range(num_shards))
83 | for input_length, (noise_span_length, noise_density), i in hyperparams:
84 | ds = ds_shards[i]
85 | ds = t5.data.preprocessors.split_tokens(
86 | ds,
87 | feature_key="targets",
88 | min_tokens_per_segment=None,
89 | max_tokens_per_segment=input_length)
90 | ds = t5.data.preprocessors.denoise(
91 | ds,
92 | output_features,
93 | inputs_fn=t5.data.preprocessors.noise_span_to_unique_sentinel,
94 | targets_fn=t5.data.preprocessors.nonnoise_span_to_unique_sentinel,
95 | noise_density=noise_density,
96 | noise_mask_fn=functools.partial(
97 | t5.data.preprocessors.random_spans_noise_mask,
98 | mean_noise_span_length=noise_span_length),
99 | input_feature_key=input_feature_key)
100 | if optional_task_prefixes:
101 | ds = prepend_prompt(
102 | ds,
103 | output_features,
104 | prompt_mode=optional_task_prefixes[i],
105 | mode=optional_task_prefixes[i])
106 | processed_ds.append(ds)
107 | if use_prefix_lm_task:
108 | ds = ds_shards[-1]
109 | ds = t5.data.preprocessors.prefix_lm(ds, sequence_lengths, output_features)
110 | if optional_task_prefixes:
111 | ds = prepend_prompt(
112 | ds,
113 | output_features,
114 | prompt_mode=optional_task_prefixes[-1],
115 | mode=optional_task_prefixes[-1])
116 | processed_ds.append(ds)
117 | ds = tf.data.experimental.sample_from_datasets(processed_ds, rates, seed)
118 | return ds
119 |
120 | sequence_length = {
121 | "inputs": 512,
122 | "targets": 512,
123 | }
124 |
125 | output_features = {
126 | "inputs":
127 | seqio.Feature(
128 | vocabulary=t5.data.get_default_vocabulary(), add_eos=False),
129 | "targets":
130 | seqio.Feature(
131 | vocabulary=t5.data.get_default_vocabulary(), add_eos=False)
132 | }
133 |
134 | ul2_data = ul2_objective(
135 | dataset,
136 | sequence_length,
137 | output_features,
138 | use_prefix_lm_task=False,
139 | rates=None,
140 | mean_noise_span_lengths=(3.0,),
141 | noise_densities=(0.15,),
142 | shard_ds=True,
143 | optional_task_prefixes=None,
144 | input_feature_key="text",
145 | merge_examples_to_reduce_padding=True,
146 | reserved_for_packing=None,
147 | seed=7)
148 |
149 |
150 | def prepend_prompt(dataset: tf.data.Dataset,
151 | output_features: seqio.preprocessors.OutputFeaturesType,
152 | sequence_length: Optional[
153 | seqio.preprocessors.SequenceLengthType] = None,
154 | prompt_mode: str = "",
155 | key: str = "inputs",
156 | mode: str = "") -> tf.data.Dataset:
157 | """Prepends a prompt at the beginning of an input sequence."""
158 | del sequence_length
159 | if prompt_mode and mode:
160 | # output_features may not have inputs key
161 | out_keys = list(output_features.keys())
162 | prompt_tokens = output_features[out_keys[0]
163 | ].vocabulary.encode_tf(prompt_mode)
164 |
165 | def add_to_inputs(x):
166 | x[key] = tf.concat([prompt_tokens, x[key]], axis=0)
167 | return x
168 |
169 | dataset = dataset.map(add_to_inputs)
170 | return dataset
--------------------------------------------------------------------------------
/text_denoising/collate_fn.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections.abc import Mapping
3 | import numpy as np
4 | import torch
5 | from torch.nn import functional as F
6 | from dataclasses import dataclass
7 | from typing import Any, Dict, List, Optional, Tuple, Union
8 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase
9 | from transformers.data.data_collator import (
10 | DataCollatorMixin,
11 | _torch_collate_batch,
12 | )
13 | from text_denoising.utils import random_spans_noise_mask
14 |
15 |
16 | @dataclass
17 | class DataCollatorForUL2(DataCollatorMixin):
18 | """
19 |
20 | Data collator used for UL2
21 |
22 | """
23 | tokenizer: PreTrainedTokenizerBase
24 | r_denoising: bool = True
25 | r_probability: float = 0.25
26 | r_denoising_config: Tuple[Tuple[int, float], ...] = ((3, 0.15),)
27 | s_denoising: bool = True
28 | s_probability: float = 0.5
29 | x_denoising: bool = True
30 | x_probability: float = 0.25
31 | x_denoising_config: Tuple[Tuple[int, float], ...] = ((32, 0.5), (64, 0.2))
32 | pad_to_multiple_of: Optional[int] = None
33 | tf_experimental_compile: bool = False
34 | return_tensors: str = "pt"
35 | label_pad_token_id: int = -100
36 |
37 | def __post_init__(self):
38 | self.total_task = [0, 1, 2]
39 | task_prob = []
40 | task_prob.append(self.r_probability if self.r_denoising else 0.0)
41 | task_prob.append(self.s_probability if self.s_denoising else 0.0)
42 | task_prob.append(self.x_probability if self.x_denoising else 0.0)
43 | self.task_prob = task_prob
44 | self.pad_token_id = self.tokenizer.pad_token_id
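# NOTE: the tokenizer must define a bos token (the examples add one explicitly);
# it is used as the decoder start token below.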
45 | self.decoder_start_token_id = self.tokenizer.bos_token_id
46 |
47 | assert abs(sum(task_prob) - 1.0) < 1e-6, "R/S/X probabilities must sum to 1"
48 |
49 | def assign_task_type(self, batch_size: int):
50 | '''
51 | Randomly assign S,R,X to each sentence based on weighted prob
52 | '''
53 | return random.choices(self.total_task, weights=self.task_prob, k=batch_size)
54 |
55 | def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
56 | # Handle dict or lists with proper padding and conversion to tensor.
57 | # print(examples)
58 | task_ids = self.assign_task_type(len(examples))
59 | task_type = torch.tensor(task_ids)
60 | lengths = torch.tensor([ len(e['input_ids']) for e in examples ], dtype=torch.long)
61 | if isinstance(examples[0], Mapping):
62 | batch = self.tokenizer.pad(examples, return_tensors="pt",
63 | pad_to_multiple_of=self.pad_to_multiple_of)
64 | else:
65 | batch = {
66 | "input_ids": _torch_collate_batch(examples, self.tokenizer,
67 | pad_to_multiple_of=self.pad_to_multiple_of)
68 | }
69 | max_length = batch['input_ids'].shape[-1]
70 |
71 | new_batch = {
72 | "input_ids": torch.zeros(batch['input_ids'].shape, dtype=torch.long),
73 | "labels": torch.zeros(batch['input_ids'].shape, dtype=torch.long)
74 | }
75 |
76 | _, expanded_length = batch['input_ids'].shape
77 | input_ids = batch["input_ids"]
78 | r_denoising_idx = task_type == 0
79 | if r_denoising_idx.any():
80 | mask_indices = None
81 | sub_input_ids = input_ids[r_denoising_idx]
82 | # union of different denoising settings
83 | for (mean_span, noise) in self.r_denoising_config:
84 | _mask_indices = np.array([
85 | random_spans_noise_mask(expanded_length, mean_span, noise) for _ in range(len(sub_input_ids))
86 | ])
87 | if mask_indices is None:
88 | mask_indices = _mask_indices
89 | else:
90 | mask_indices = mask_indices | _mask_indices
91 |
92 | input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
93 | labels_mask = ~mask_indices
94 | labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
95 | _sub_input_ids = self.filter_input_ids(sub_input_ids, input_ids_sentinel)
96 | _labels = self.filter_input_ids(sub_input_ids, labels_sentinel)
97 | diff = max_length-_labels.shape[-1]
98 | _labels = np.pad(_labels, [(0,0), (0, diff)], 'constant',
99 | constant_values=self.label_pad_token_id)
100 | diff = max_length - _sub_input_ids.shape[-1]
101 | _sub_input_ids = np.pad(_sub_input_ids, [(0,0), (0, diff)], 'constant')
102 | new_batch['input_ids'][r_denoising_idx] = torch.from_numpy(_sub_input_ids).long()
103 | new_batch['labels'][r_denoising_idx] = torch.from_numpy(_labels).long()
104 |
105 | s_denoising_idx = task_type == 1
106 | if s_denoising_idx.any():
107 | sub_input_ids = input_ids[s_denoising_idx]
108 | _labels = []
109 | _input_ids = []
110 | for input_id, len_ in zip(sub_input_ids, lengths[s_denoising_idx]):
111 | split = max(len_//2, 2)
112 | diff = expanded_length - split
113 | _input_ids.append(F.pad(input_id[:split], (0, diff), 'constant', self.pad_token_id))
114 | past_seq = input_id[split:]
115 | if past_seq[-1] != self.tokenizer.eos_token_id:
116 | past_seq[-1] = self.tokenizer.eos_token_id
117 | _labels.append(F.pad(past_seq, (0, split), 'constant', self.label_pad_token_id))
118 |
119 | new_batch['input_ids'][s_denoising_idx] = torch.stack(_input_ids)
120 | new_batch['labels'][s_denoising_idx] = torch.stack(_labels)
121 |
122 |
123 | x_denoising_idx = task_type == 2
124 | if x_denoising_idx.any():
125 | mask_indices = None
126 | sub_input_ids = input_ids[x_denoising_idx]
127 | for (mean_span, noise) in self.x_denoising_config:
128 | _mask_indices = np.array([
129 | random_spans_noise_mask(expanded_length, mean_span, noise) for _ in range(len(sub_input_ids))
130 | ])
131 | if mask_indices is None:
132 | mask_indices = _mask_indices
133 | else:
134 | mask_indices = mask_indices | _mask_indices
135 |
136 | labels_mask = ~mask_indices
137 | input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
138 | labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))
139 | _sub_input_ids = self.filter_input_ids(sub_input_ids, input_ids_sentinel)
140 | _labels = self.filter_input_ids(sub_input_ids, labels_sentinel)
141 | diff = max_length-_labels.shape[-1]
142 | _labels = np.pad(_labels, [(0, 0), (0, diff)], 'constant',
143 | constant_values=self.label_pad_token_id)
144 | diff = max_length - _sub_input_ids.shape[-1]
145 | _sub_input_ids = np.pad(_sub_input_ids, [(0,0), (0, diff)], 'constant')
146 | new_batch['input_ids'][x_denoising_idx] = torch.from_numpy(_sub_input_ids).long()
147 | new_batch['labels'][x_denoising_idx] = torch.from_numpy(_labels).long()
148 |
149 | return self.prepare_decoder_inputs_from_labels(new_batch)
150 |
151 |
152 | def filter_input_ids(self, input_ids, sentinel_ids):
153 | """
154 | Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
155 | This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
156 | """
157 |
158 | input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
159 | # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
160 | # masked tokens coming after sentinel tokens and should be removed
161 | input_ids = []
162 | for row in input_ids_full:
163 | collapsed_id = row[row >= 0]
164 | diff = len(row) - len(collapsed_id)
165 | collapsed_id = np.pad(collapsed_id, (0, diff), 'constant')
166 | input_ids.append(collapsed_id)
167 | return np.array(input_ids)
168 |
169 | def create_sentinel_ids(self, mask_indices):
170 | """
171 | Sentinel ids creation given the indices that should be masked.
172 | The start indices of each mask are replaced by the sentinel ids in increasing
173 | order. Consecutive mask indices to be deleted are replaced with `-1`.
174 | """
175 | start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
176 | start_indices[:, 0] = mask_indices[:, 0]
177 |
178 | sentinel_ids = np.where(
179 | start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
180 | )
181 | sentinel_ids = np.where(
182 | sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0
183 | )
184 | sentinel_ids -= mask_indices - start_indices
185 |
186 | return sentinel_ids
187 |
188 |
189 | def prepare_decoder_inputs_from_labels(self, batch):
190 | # decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id.
191 | # See T5 docs for more information
192 | batch["labels"][ batch["labels"] == self.pad_token_id ] = self.label_pad_token_id
193 | shifted_labels = batch["labels"].new_zeros(batch["labels"].shape)
194 | shifted_labels[..., 1:] = batch["labels"][..., :-1].clone()
195 | shifted_labels[..., 0] = self.decoder_start_token_id # decoder_start_token_id
196 |
197 | batch["decoder_input_ids"] = torch.masked_fill(
198 | shifted_labels,
199 | shifted_labels == self.label_pad_token_id,
200 | self.pad_token_id
201 | )
202 | batch["decoder_attention_mask"] = torch.where(
203 | shifted_labels == self.label_pad_token_id,
204 | 0,
205 | torch.ones_like(shifted_labels),
206 | )
207 | return batch
208 |
209 | def np_prepare_decoder_inputs_from_labels(self, batch):
210 | batch["labels"][ batch["labels"] == self.pad_token_id ] = self.label_pad_token_id
211 | shifted_labels = np.zeros_like(batch["labels"])
212 | shifted_labels[..., 1:] = batch["labels"][..., :-1].copy()
213 | shifted_labels[..., 0] = self.decoder_start_token_id
214 |
215 | batch["decoder_input_ids"] = np.where(
216 | shifted_labels == self.label_pad_token_id,
217 | self.pad_token_id,
218 | shifted_labels
219 | )
220 | batch["decoder_attention_mask"] = np.where(
221 | shifted_labels == self.label_pad_token_id,
222 | 0,
223 | np.ones_like(shifted_labels)
224 | )
225 | return batch
226 |
227 |
--------------------------------------------------------------------------------