├── .github
│   └── scaling_contours.png
├── pyproject.toml
├── LICENSE
├── README.md
├── training.py
├── .gitignore
├── data_utils.py
├── main.py
├── fsdp_training.py
├── data_gen.py
└── gzip_difficulty.py
/.github/scaling_contours.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KhoomeiK/complexity-scaling/HEAD/.github/scaling_contours.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "complexity-scaling"
3 | version = "0.1.0"
4 | description = "data-sensitive scaling laws"
5 | authors = ["khoomeik <32777448+KhoomeiK@users.noreply.github.com>"]
6 | license = "MIT"
7 | readme = "README.md"
8 |
9 | [tool.poetry.dependencies]
10 | python = "^3.10"
11 | transformers = "^4.39.1"
12 | torch = "^2.2.1"
14 | pcfg = "^0.1.5"
15 | datasets = "^2.18.0"
16 |
17 |
18 | [build-system]
19 | requires = ["poetry-core"]
20 | build-backend = "poetry.core.masonry.api"
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Rohan Pandey
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ![scaling contours](.github/scaling_contours.png)
3 |
4 | 🐦 Twitter • 📄 Arxiv • 🤗 Datasets
5 |
6 | 🔗 Multimodal CodeGen for Web Data Extraction
7 |
18 | # `gzip` Predicts Data-dependent Scaling Laws
19 |
20 | This is the official code for *`gzip` Predicts Data-dependent Scaling Laws* (under review at NeurIPS 2024).
21 |
22 | We find that:
23 | 1. scaling laws are sensitive to differences in data complexity
24 | 2. `gzip`, a compression algorithm, is an effective predictor of how data complexity impacts scaling properties
25 |
26 | The compute-optimal frontier of our data-dependent scaling law increasingly favors dataset size over parameter count as training data becomes more complex (i.e., harder to compress).
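
For intuition, the toy sketch below (not the law fitted in the paper; the constants and the dependence on `gzip`-compressibility are placeholders) shows how a Chinchilla-style parametric loss can be made data-dependent and how that shifts the compute-optimal allocation toward tokens as data gets harder to compress:

```python
def compute_optimal_allocation(C, H, A=406.0, B=411.0, alpha=0.34):
    """Minimize L(N, D) = E + A/N**alpha + B/D**beta subject to C ~= 6*N*D (E drops out)."""
    beta = 0.28 * (1.0 - 0.4 * H)  # placeholder: harder-to-compress data => smaller data exponent
    N_opt = (alpha * A / (beta * B)) ** (1 / (alpha + beta)) * (C / 6) ** (beta / (alpha + beta))
    D_opt = C / (6 * N_opt)
    return N_opt, D_opt

# As H (compressed size / raw size) rises, the optimum shifts toward more tokens per parameter:
for H in (0.2, 0.4, 0.6):
    N, D = compute_optimal_allocation(1e19, H)
    print(f"H={H:.1f}: N*={N:.2e} params, D*={D:.2e} tokens, D/N={D / N:.0f}")
```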
27 |
28 | ## Code Overview
29 | - `data_gen.py`: create PCFGs with specified syntactic properties and sample text datasets from them
30 | - `data_utils.py`: `gzip`-compressibility measurement, tokenization & HuggingFace tooling, dataloaders, etc.
31 | - `training.py`: run a single training run given model and dataset, returning loss at each train step
32 | - `main.py`: run a set of training runs across datasets & model sizes (hackily GPU-parallelized with threading)
33 | - `fsdp_training.py`: for running bigger jobs with cleaner data loading & FSDP training
34 |
35 | Upon request via email, we can also provide:
36 | - JSONL records of all training runs (these are too large to host on GitHub)
37 | - the Jupyter Notebook used to fit scaling laws from training runs and generate all visuals
38 |
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.nn import CrossEntropyLoss
4 | from tqdm.auto import tqdm
5 | from torch.optim.lr_scheduler import CosineAnnealingLR
6 |
7 |
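# Mean per-sequence perplexity over a dataloader (not called by run_training below,
# which records the raw per-step loss instead).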
8 | def compute_perplexity(dataloader, model, device="cuda"):
9 | # adapted from: https://github.com/huggingface/evaluate/blob/main/metrics/perplexity/perplexity.py
10 | model = model.to(device)
11 |
12 | ppls = []
13 | loss_fct = CrossEntropyLoss(reduction="none")
14 |
15 | for batch in dataloader:
16 | batch.to(device)
17 | encoded_batch = batch["input_ids"]
18 | attn_mask = batch["attention_mask"]
19 |
20 | labels = encoded_batch
21 |
22 | with torch.no_grad():
23 | out_logits = model(encoded_batch, attention_mask=attn_mask).logits
24 |
25 | shift_logits = out_logits[
26 | ..., :-1, :
27 | ].contiguous() # TODO: double check that all this logic is correct
28 | shift_labels = labels[..., 1:].contiguous()
29 | shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
30 |
31 | perplexity_batch = torch.exp(
32 | (
33 | loss_fct(shift_logits.transpose(1, 2), shift_labels)
34 | * shift_attention_mask_batch
35 | ).sum(1)
36 | / shift_attention_mask_batch.sum(1)
37 | )
38 |
39 | ppls += perplexity_batch.tolist()
40 |
41 | return np.mean(ppls)
42 |
43 |
44 | def run_training(
45 | model, train_dataloader, valid_dataloader, optimizer, num_epochs=10, device="cuda"
46 | ):
47 | train_loss = []
48 | valid_loss = []
49 |
50 | for epoch in range(num_epochs):
51 | lr_scheduler = CosineAnnealingLR(optimizer, T_max=len(train_dataloader))
52 | progress_bar = tqdm(
53 | range(len(train_dataloader)), desc=f"Epoch {epoch + 1}/{num_epochs}"
54 | )
55 |
56 | model.train()
57 | for batch in train_dataloader:
58 | batch = {k: v.to(device) for k, v in batch.items()}
59 | outputs = model(**batch)
60 | loss = outputs.loss
61 | loss.backward()
62 | train_loss.append(loss.item())
63 |
64 | optimizer.step()
65 | optimizer.zero_grad()
66 | progress_bar.update(1)
67 |
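        # The scheduler is stepped once per epoch even though T_max is the number of
        # batches, so in a single-epoch run the learning rate barely decays from its
        # initial value.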
68 | lr_scheduler.step() # NOTE: all single-epoch scaling experiments before 4/30 mistakenly did not step the scheduler
69 |
70 | model.eval()
71 | with torch.no_grad():
72 | for batch in valid_dataloader:
73 | batch = {k: v.to(device) for k, v in batch.items()}
74 | outputs = model(**batch)
75 | loss = outputs.loss
76 | valid_loss.append(loss.item())
77 |
78 | print(
79 | f"Train Loss: {np.median(train_loss):.3f}, Valid Loss: {np.median(valid_loss):.3f}"
80 | )
81 |
82 | return train_loss, valid_loss
83 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | results/
2 | .DS_Store
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/#use-with-ide
113 | .pdm.toml
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import io
3 | import random
4 | from typing import List, Union
5 | from statistics import median, stdev
6 | from torch.utils.data import DataLoader
7 | from transformers import DataCollatorWithPadding
8 | from datasets import Dataset, DatasetDict, load_dataset
9 | from huggingface_hub import HfApi, HfFolder
10 |
11 |
12 | def count_total_tokens(dataloader):
13 | total_tokens = 0
14 | for batch in dataloader:
15 | total_tokens += sum(batch["attention_mask"].flatten().tolist())
16 | return total_tokens
17 |
18 |
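# Right-pad (or truncate) a token sequence to `sequence_length` with 32000, the id of the
# pad token added on top of Llama-2's 32,000-token vocabulary (hence vocab_size=32001 in
# the model configs); padding positions get 0 in the attention mask.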
19 | def pad_and_mask(sequence, sequence_length):
20 | if sequence_length - len(sequence) == 0:
21 | padded_sequence = sequence
22 | elif sequence_length - len(sequence) > 0:
23 | padded_sequence = sequence + [32000] * (sequence_length - len(sequence))
24 | elif sequence_length - len(sequence) < 0:
25 | padded_sequence = sequence[:sequence_length]
26 | mask = [1 if token != 32000 else 0 for token in padded_sequence]
27 | return padded_sequence, mask
28 |
29 |
30 | def pcfg_dataset_to_dataloader(
31 | pcfg_dataset, padder_tokenizer, batch_size=8, context_length=256, dataset_name=""
32 | ):
33 | if 'code' in dataset_name:
34 | tok_seqs = pcfg_dataset
35 | else:
36 | tok_seqs = [[int(tok) for tok in doc.split(" ")] for doc in pcfg_dataset]
37 |
38 | input_ids, attention_masks = [], []
39 | for seq in tok_seqs:
40 | padded_seq, mask = pad_and_mask(seq, context_length)
41 | input_ids.append(padded_seq)
42 | attention_masks.append(mask)
43 |
44 | tokenized_dataset = Dataset.from_dict(
45 | {"input_ids": input_ids, "attention_mask": attention_masks}
46 | )
47 | tokenized_dataset = tokenized_dataset.map(
48 | lambda x: {"labels": x["input_ids"].copy()}, batched=True
49 | )
50 | tokenized_dataset.set_format("torch")
51 |
52 | data_collator = DataCollatorWithPadding(tokenizer=padder_tokenizer)
53 |
54 | dataloader = DataLoader(
55 | tokenized_dataset, shuffle=True, batch_size=batch_size, collate_fn=data_collator
56 | )
57 |
58 | return dataloader
59 |
60 |
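# gzip-compressibility = gzip-compressed size / raw size (lower = more redundant data).
# Strings are UTF-8 encoded, or (by default) parsed as whitespace-separated token ids
# and serialized as 4-byte big-endian integers before compression.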
61 | def calculate_gzipability(
62 | input_data: Union[str, List[int]], gzip_toks: bool = True
63 | ) -> float:
64 | if type(input_data) == str and not gzip_toks:
65 | input_bytes = input_data.encode("utf-8")
66 | else: # token list
67 | if type(input_data) == str:
68 | input_data = [int(tok) for tok in input_data.split(" ")]
69 | input_bytes = b"".join(
70 | int.to_bytes(i, length=4, byteorder="big", signed=True) for i in input_data
71 | )
72 |
73 | buf = io.BytesIO()
74 | with gzip.GzipFile(fileobj=buf, mode="wb") as f:
75 | f.write(input_bytes)
76 |
77 | compressed_size = buf.tell()
78 | gzipability = compressed_size / len(input_bytes)
79 |
80 | return gzipability
81 |
82 |
83 | def calculate_median_stdev_gzipability(pcfg_dataset):
84 | gzipability_scores = [
85 | calculate_gzipability(row)
86 | for row in random.sample(pcfg_dataset, min(1000, len(pcfg_dataset)))
87 | ]
88 | med = median(gzipability_scores)
89 |
90 | if len(gzipability_scores) > 1:
91 | std_dev = stdev(gzipability_scores)
92 | else:
93 | std_dev = 0 # Default to 0 if there's only one element to avoid division by zero in stdev calculation
94 |
95 | return med, std_dev
96 |
97 |
98 | def upload_to_huggingface(pcfg_dataset, gzip_score, dataset_stats=None):
99 | api = HfApi()
100 | token = HfFolder.get_token()
101 | if token is None:
102 | raise ValueError(
103 | "Hugging Face Hub token not found. Please login using `huggingface-cli login`."
104 | )
105 | username = api.whoami(token)["name"]
106 |
107 | dataset = [{"text": seq} for seq in pcfg_dataset]
108 | dataset_dict = {
109 | "train": Dataset.from_list(dataset), # map to list of dicts?
110 | }
111 |
112 | dataset = DatasetDict(dataset_dict)
113 | dataset.push_to_hub(f"{username}/gzipscale-{gzip_score:0.2f}-{('_'.join([str(x) for x in dataset_stats[:-1]])) + '-' if dataset_stats else ''}100M")
114 |
115 |
116 | def download_from_huggingface(dataset_name):
117 | dataset_dict = load_dataset(dataset_name)
118 | dataset = dataset_dict["train"]
119 |
120 | if 'code' in dataset_name:
121 | return dataset['input_ids']
122 |
123 | pcfg_dataset = dataset["text"]
124 | return pcfg_dataset
125 |
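# Stream C files from codeparrot/github-code, tokenize them, split each file into
# non-overlapping context_length-token chunks, and push the chunks to the Hub once
# ~8B tokens have been collected (or the stream connection drops).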
126 | def sample_code_dataset(tokenizer, context_length=256):
127 | ds = load_dataset("codeparrot/github-code", streaming=True, split="train")
128 |
129 | seqs = []
130 | try:
131 | for row in ds:
132 | if row['language'] == 'C':
133 | outputs = tokenizer(row['code'], add_special_tokens=False)
134 | if len(outputs['input_ids']) < context_length:
135 | continue
136 |
137 | input_ids = outputs['input_ids']
138 |
139 | for i, subseq in enumerate(input_ids[::context_length]):
140 | if (i+1)*context_length > len(input_ids):
141 | break
142 | seq = input_ids[i*context_length : (i+1)*context_length]
143 | seqs.append({'input_ids': seq})
144 |
145 | if len(seqs) % 10_000 < 10:
146 | print(len(seqs))
147 |
148 | if len(seqs) > 31_250_000: # 8B tokens
149 | break
150 | # except http.client.RemoteDisconnected as e:
151 | except Exception as e:
152 | print(e)
153 | print("Connection to HuggingFace Hub was lost. Saving current progress...")
154 |
155 | print(len(seqs))
156 | tokenized_code_dataset = DatasetDict({"train": Dataset.from_list(seqs)})
157 | tokenized_code_dataset.set_format("torch")
158 | tokenized_code_dataset.push_to_hub(f"khoomeik/gzipscale-code-C-{(len(seqs) * context_length / 1_000_000):0.0f}M")
159 |
160 | if __name__ == '__main__':
161 | from transformers import AutoTokenizer
162 | tokenizer = AutoTokenizer.from_pretrained(
163 | "meta-llama/Llama-2-7b-chat-hf", token="[REDACTED]"
164 | )
165 | tokenizer.add_special_tokens({"pad_token": "<pad>"})
166 |
167 | sample_code_dataset(tokenizer)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from training import run_training
3 | from transformers import AdamW, AutoTokenizer, LlamaForCausalLM, LlamaConfig
4 | import json
5 | from data_utils import (
6 | calculate_median_stdev_gzipability,
7 | count_total_tokens,
8 | pcfg_dataset_to_dataloader,
9 | download_from_huggingface,
10 | )
11 |
12 |
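# One worker per GPU: download the assigned dataset(s), measure their gzip-compressibility,
# then for each dataset fraction train every model size in `model_sizes` for one epoch and
# append each run's losses to a per-GPU results JSONL.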
13 | def run_scaling_exps(cuda_idx=None):
14 | context_length = 256
15 | llm_configuration = {
16 | "vocab_size": 32001,
17 | "hidden_size": 256,
18 | "intermediate_size": 512,
19 | "num_hidden_layers": 4,
20 | "num_attention_heads": 4,
21 | "max_position_embeddings": context_length,
22 | }
23 | tokenizer = AutoTokenizer.from_pretrained(
24 | "meta-llama/Llama-2-7b-chat-hf", token="[REDACTED]"
25 | )
26 | tokenizer.add_special_tokens({"pad_token": "<pad>"})
27 |
28 | model_sizes = {
29 | "hidden_size": [64, 128, 256, 512, 1024, 2048],
30 | "intermediate_size": [128, 256, 512, 1024, 2048, 4096],
31 | "num_hidden_layers": [2, 4, 6, 10, 20, 30],
32 | "num_attention_heads": [1, 2, 4, 8, 16, 32],
33 | }
34 |
35 | dataset_names = [
36 | # "khoomeik/gzipscale-0.32-10_500_5_10-100M",
37 | # "khoomeik/gzipscale-0.36-20_300_10_5-100M",
38 | # "khoomeik/gzipscale-0.40-30_200_15_20-100M",
39 | # "khoomeik/gzipscale-0.38-(50,100,30,15)-100M",
40 |
41 | # "khoomeik/gzipscale-code-C-256M",
42 | # "khoomeik/gzipscale-code-python-256M",
43 | # "khoomeik/gzipscale-code-html-256M",
44 |
45 | # "khoomeik/gzipscale-0.11-100M",
46 | # "khoomeik/gzipscale-0.22-100M",
47 | # "khoomeik/gzipscale-0.35-100M",
48 | # "khoomeik/gzipscale-0.42-100M",
49 | # "khoomeik/gzipscale-0.51-100M",
50 | # "khoomeik/gzipscale-0.61-100M",
51 |
52 | # "khoomeik/gzipscale-0.12-10M",
53 | # "khoomeik/gzipscale-0.23-10M",
54 | # "khoomeik/gzipscale-0.33-10M",
55 | # "khoomeik/gzipscale-0.45-10M",
56 | # "khoomeik/gzipscale-0.61-10M",
57 |
58 | "khoomeik/gzipscale-0.11-3_300_2_2-100M",
59 | "khoomeik/gzipscale-0.25-10_300_5_3-100M",
60 | "khoomeik/gzipscale-0.36-20_300_10_5-100M",
61 | "khoomeik/gzipscale-0.47-50_300_20_10-100M"
62 | ]
63 | if cuda_idx is not None:
64 | if cuda_idx == torch.cuda.device_count(): # NOTE: this is only for handling dataset #5 and will likely break on systems with >4 GPUs
65 | dataset_names = [dataset_names[cuda_idx]]
66 | cuda_idx = torch.cuda.device_count() - 1
67 | else:
68 | dataset_names = [dataset_names[cuda_idx]]
69 | # cuda_idx = 1
70 | pcfg_datasets = [download_from_huggingface(name) for name in dataset_names]
71 | med_std_gzips = [
72 | calculate_median_stdev_gzipability(pcfg_dataset)
73 | for pcfg_dataset in pcfg_datasets
74 | ]
75 | for i, pcfg_dataset in enumerate(pcfg_datasets):
76 | med, std = med_std_gzips[i]
77 | total_toks = count_total_tokens(
78 | pcfg_dataset_to_dataloader(pcfg_dataset, padder_tokenizer=tokenizer, dataset_name=dataset_names[i])
79 | )
80 | print(f"{i}: {med:.3f} +- {std:.3f} ({total_toks}) | {dataset_names[i]}")
81 |
82 | device = f"cuda:{cuda_idx}" if cuda_idx is not None else "cpu"
83 | results = []
84 |
85 | torch.cuda.empty_cache()
86 |
87 | for i, pcfg_dataset in enumerate(pcfg_datasets):
88 | for data_portion in (0.001, 0.01, 0.1, 0.2, 0.5, 0.95):
89 | med_gzip, std_gzip = med_std_gzips[i]
90 |
91 | train_data_size = int(len(pcfg_dataset) * data_portion)
92 | valid_data_size = min(100, int(train_data_size / 10))
93 | train_dataloader = pcfg_dataset_to_dataloader(
94 | pcfg_dataset[:train_data_size],
95 | padder_tokenizer=tokenizer,
96 | batch_size=32,
97 | dataset_name=dataset_names[i]
98 | )
99 | valid_dataloader = pcfg_dataset_to_dataloader(
100 | pcfg_dataset[-valid_data_size:],
101 | padder_tokenizer=tokenizer,
102 | batch_size=32,
103 | dataset_name=dataset_names[i]
104 | )
105 | train_token_ct = count_total_tokens(train_dataloader)
106 |
107 | for j in range(len(list(model_sizes.values())[0])):
108 | print("-" * 20)
109 |
110 | model_stats = {key: val[j] for key, val in model_sizes.items()}
111 | model_config_dict = {
112 | **llm_configuration,
113 | **model_stats,
114 | } # NOTE: update vocab_size and new tokenizer?
115 | model_config = LlamaConfig(**model_config_dict)
116 | model = LlamaForCausalLM(model_config)
117 | model_size = sum(p.numel() for p in model.parameters())
118 |
119 | print(f"Dataset Stats: {med_gzip:.3f} +- {std_gzip:.3f}")
120 | print(f"Model Size: {model_size/1_000_000:.1f}M")
121 | print(f"Train Token Count: {train_token_ct}")
122 |
123 | model.to(device)
124 | optimizer = AdamW(model.parameters(), lr=5e-5)
125 | num_epochs = 1
126 |
127 | train_loss, valid_loss = run_training(
128 | model,
129 | train_dataloader,
130 | valid_dataloader,
131 | optimizer,
132 | num_epochs=num_epochs,
133 | device=device,
134 | )
135 |
136 | row = {
137 | "dataset_name": dataset_names[i],
138 | "dataset_gzip": (med_gzip, std_gzip),
139 | "token_ct": train_token_ct,
140 | "model_stats": model_config_dict,
141 | "model_size": model_size,
142 | "num_epochs": num_epochs,
143 | "train_loss": train_loss,
144 | "valid_loss": valid_loss,
145 | }
146 | results.append(row)
147 |
148 | with open(f"results_cuda:{cuda_idx}.jsonl", "a") as file:
149 | file.write(json.dumps(row) + "\n")
150 |
151 |
152 | if __name__ == "__main__":
153 | from concurrent.futures import ThreadPoolExecutor, wait
154 |
155 | with ThreadPoolExecutor(max_workers=torch.cuda.device_count()) as executor:
156 | futures = [executor.submit(run_scaling_exps, i) for i in range(torch.cuda.device_count())]
157 | wait(futures)
158 | # run_scaling_exps(4) # NOTE: for running dataset 5
159 |
160 | # run_scaling_exps(0)
--------------------------------------------------------------------------------
/fsdp_training.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from tqdm.auto import tqdm
4 | from datasets import Dataset, load_dataset
5 | from transformers import DataCollatorWithPadding, AdamW, AutoTokenizer, LlamaForCausalLM, LlamaConfig
6 |
7 | import torch
8 | from torch.optim.lr_scheduler import CosineAnnealingLR
9 | from torch.utils.data import DataLoader
10 | from torch.utils.data.distributed import DistributedSampler
11 | import torch.distributed as dist
12 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
13 | import torch.multiprocessing as mp
14 |
15 | from data_utils import pad_and_mask, download_from_huggingface
16 |
17 |
18 | def create_dataloader(
19 | dataset_name,
20 | padder_tokenizer,
21 | batch_size=32,
22 | context_length=256,
23 | rank=0,
24 | world_size=1,
25 | ):
26 | dataset = load_dataset(dataset_name)["train"]
27 | data_collator = DataCollatorWithPadding(tokenizer=padder_tokenizer)
28 | data_sampler = DistributedSampler(
29 | dataset, rank=rank, num_replicas=world_size, shuffle=True
30 | )
31 |
32 | dataloader = DataLoader(
33 | dataset,
34 | batch_size=batch_size,
35 | collate_fn=data_collator,
36 | sampler=data_sampler,
37 | # CUDA args:
38 | num_workers=2,
39 | pin_memory=True,
40 | shuffle=False,
41 | )
42 |
43 | return dataloader
44 |
45 | def pcfg_dataset_to_dataloader(
46 | pcfg_dataset,
47 | padder_tokenizer,
48 | batch_size=8,
49 | context_length=256,
50 | dataset_name="",
51 | rank=0,
52 | world_size=1,
53 | ):
54 | if "code" in dataset_name:
55 | tok_seqs = pcfg_dataset
56 | else:
57 | tok_seqs = [[int(tok) for tok in doc.split(" ")] for doc in pcfg_dataset]
58 |
59 | input_ids, attention_masks = [], []
60 | for seq in tok_seqs:
61 | padded_seq, mask = pad_and_mask(seq, context_length)
62 | input_ids.append(padded_seq)
63 | attention_masks.append(mask)
64 |
65 | tokenized_dataset = Dataset.from_dict(
66 | {"input_ids": input_ids, "attention_mask": attention_masks}
67 | )
68 | tokenized_dataset = tokenized_dataset.map(
69 | lambda x: {"labels": x["input_ids"].copy()}, batched=True
70 | )
71 | tokenized_dataset.set_format("torch")
72 |
73 | data_collator = DataCollatorWithPadding(tokenizer=padder_tokenizer)
74 | data_sampler = DistributedSampler( # TODO: refactor via `distributed` flag back into original pcfg_dataset_to_dataloader
75 | tokenized_dataset, rank=rank, num_replicas=world_size, shuffle=True
76 | )
77 |
78 | dataloader = DataLoader(
79 | tokenized_dataset,
80 | batch_size=batch_size,
81 | collate_fn=data_collator,
82 | sampler=data_sampler,
83 | # CUDA args:
84 | num_workers=2,
85 | pin_memory=True,
86 | shuffle=False,
87 | )
88 |
89 | return dataloader
90 |
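# Single-node multi-GPU bootstrap: one process per GPU (spawned in __main__), NCCL backend,
# rendezvous at localhost:12355.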
91 | def setup(rank, world_size):
92 | os.environ["MASTER_ADDR"] = "localhost"
93 | os.environ["MASTER_PORT"] = "12355"
94 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
95 |
96 |
97 | def cleanup():
98 | dist.destroy_process_group()
99 |
100 |
101 | def run_fsdp_training(
102 | model, train_dataloader, valid_dataloader, optimizer, num_epochs=10, rank=0
103 | ):
104 | train_loss = []
105 | valid_loss = []
106 |
107 | for epoch in range(num_epochs):
108 | lr_scheduler = CosineAnnealingLR(optimizer, T_max=len(train_dataloader))
109 | progress_bar = tqdm(
110 | range(len(train_dataloader)), desc=f"Epoch {epoch + 1}/{num_epochs}"
111 | )
112 | ddp_loss = torch.zeros(2).to(rank)
113 |
114 | model.train()
115 | for batch in train_dataloader:
116 | batch = {k: v.to(rank) for k, v in batch.items()}
117 | if 'labels' not in batch: # NOTE: hack to get around DataLoader not calling CodeDataset.__iter__
118 | batch['labels'] = batch['input_ids'].clone()
119 | optimizer.zero_grad()
120 | outputs = model(**batch)
121 | loss = outputs.loss
122 | loss.backward()
123 |
124 | optimizer.step()
125 | progress_bar.update(1)
126 |
127 | train_loss.append(loss.item())
128 | ddp_loss[0] += loss.item()
129 | ddp_loss[1] += len(batch)
130 |
131 | lr_scheduler.step()
132 |
133 | dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
134 |
135 | return train_loss, valid_loss
136 |
137 |
138 | def fsdp_main(rank, world_size, args):
139 | setup(rank, world_size)
140 |
141 | tokenizer = AutoTokenizer.from_pretrained(
142 | "meta-llama/Llama-2-7b-chat-hf", token="[REDACTED]"
143 | )
144 | tokenizer.add_special_tokens({"pad_token": "<pad>"})
145 |
146 | dataset_name = "khoomeik/gzipscale-code-C-8000M"
147 | pcfg_dataset = download_from_huggingface(dataset_name)
148 | train_dataloader = create_dataloader(
149 | dataset_name,
150 | padder_tokenizer=tokenizer,
151 | batch_size=32,
152 | # dataset_name=dataset_name,
153 | rank=rank,
154 | world_size=world_size,
155 | )
156 |
157 | torch.cuda.set_device(rank)
158 |
159 | model_config_dict = {
160 | "vocab_size": 32001,
161 | "hidden_size": 1024,
162 | "intermediate_size": 2048,
163 | "num_hidden_layers": 32,
164 | "num_attention_heads": 16,
165 | "max_position_embeddings": 256,
166 | }
167 | model_config = LlamaConfig(**model_config_dict)
168 | model = LlamaForCausalLM(model_config)
169 | model_size = sum(p.numel() for p in model.parameters())
170 | print(f"Model Size: {model_size/1_000_000:.1f}M")
171 |
172 | model.to(rank)
173 | model = FSDP(model)
174 |
175 | optimizer = AdamW(model.parameters(), lr=5e-5)
176 | num_epochs = 1
177 |
178 | train_loss, valid_loss = run_fsdp_training(
179 | model,
180 | train_dataloader,
181 | None,
182 | optimizer,
183 | num_epochs=num_epochs,
184 | rank=rank,
185 | )
186 |
187 | row = {
188 | "dataset_name": dataset_name,
189 | # "token_ct": train_token_ct,
190 | "model_stats": model_config_dict,
191 | "model_size": model_size,
192 | "num_epochs": num_epochs,
193 | "train_loss": train_loss,
194 | # "valid_loss": valid_loss,
195 | "cuda_rank": rank,
196 | }
197 |
198 | with open("results_fsdp.jsonl", "a") as file:
199 | file.write(json.dumps(row) + "\n")
200 |
201 | dist.barrier()
202 | states = model.state_dict()
203 | if rank == 0:
204 | torch.save(states, f"./model_{model_size}_{dataset_name.split('/')[-1]}.pt")
205 |
206 | cleanup()
207 |
208 |
209 | if __name__ == "__main__":
210 | args = {}
211 |
212 | WORLD_SIZE = torch.cuda.device_count()
213 | mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
214 |
--------------------------------------------------------------------------------
/data_gen.py:
--------------------------------------------------------------------------------
1 | import random
2 | from pcfg import PCFG
3 | import threading
4 |
5 | def generate_probs(num_options):
6 | if num_options <= 0:
7 | raise ValueError("Number of options must be positive")
8 |
9 | # Generate random integers for each option
10 | random_ints = [random.randint(1, 100) for _ in range(num_options)]
11 |
12 | # Calculate the total sum
13 | total = sum(random_ints)
14 |
15 | # Normalize each integer by the total sum to get probabilities
16 | probs = [i / total for i in random_ints]
17 |
18 | return probs
19 |
20 |
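# Build a random PCFG: each nonterminal gets 1..rhs_max_options productions with randomly
# normalized probabilities, right-hand sides of 1..rhs_max_len symbols drawn from the
# nonterminals and integer terminals (a single symbol if constrain_to_pfsa), and a start
# symbol S that expands uniformly to any nonterminal.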
21 | def create_random_pcfg(
22 | num_nonterminals,
23 | num_terminals,
24 | rhs_max_options=5,
25 | rhs_max_len=5,
26 | constrain_to_pfsa=False,
27 | ):
28 | # Create non-terminal symbols
29 | nonterminals = [f"N{i}" for i in range(num_nonterminals)]
30 |
31 | # Create terminal symbols as consecutive integers
32 | terminals = [f"'{i}'" for i in range(num_terminals)]
33 |
34 | # Initialize production rules
35 | productions = []
36 |
37 | for lhs in nonterminals:
38 | rhs_options_ct = random.randint(1, rhs_max_options)
39 | rhs_option_probs = generate_probs(rhs_options_ct)
40 |
41 | rhs_options = []
42 |
43 | for rhs_option_prob in rhs_option_probs:
44 | rhs = []
45 |
46 | if constrain_to_pfsa:
47 | rhs.append(
48 | random.choice(nonterminals + terminals)
49 | ) # TODO: is this the right constraint?
50 | else:
51 | # Randomly decide the length of the right-hand side (at least 1)
52 | rhs_len = random.randint(1, rhs_max_len)
53 | for _ in range(rhs_len):
54 | rhs.append(random.choice(nonterminals + terminals))
55 |
56 | rhs_option = f"{' '.join(rhs)} [{rhs_option_prob}]"
57 | rhs_options.append(rhs_option)
58 |
59 | production = f"{lhs} -> {' | '.join(rhs_options)}"
60 | productions.append(production)
61 |
62 | start_production = f"S -> {' | '.join([f'{nonterminal} [{1/len(nonterminals)}]' for nonterminal in nonterminals])}"
63 | productions.insert(0, start_production)
64 |
65 | # Create the PCFG
66 | grammar = PCFG.fromstring("\n".join(productions))
67 |
68 | return grammar
69 |
70 |
71 | def generate_dataset(
72 | num_nonterminals,
73 | num_terminals,
74 | rhs_max_options,
75 | rhs_max_len,
76 | constrain_to_pfsa,
77 | num_toks_total,
78 | num_toks_per_seq=256,
79 | ) -> list[str]:
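    # Pack sampled sentences into documents of ~num_toks_per_seq tokens; sentences within a
    # document are joined by the separator token "0", so each sentence is budgeted as
    # (space count + 2) tokens: its words plus the separator.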
80 | print(num_nonterminals, num_terminals, rhs_max_options, rhs_max_len)
81 |
82 | grammar = create_random_pcfg(
83 | num_nonterminals,
84 | num_terminals,
85 | rhs_max_options=rhs_max_options,
86 | rhs_max_len=rhs_max_len,
87 | constrain_to_pfsa=constrain_to_pfsa,
88 | )
89 |
90 | dataset = []
91 | total_tokens_generated = 0
92 |
93 | while total_tokens_generated < num_toks_total:
94 | document_tokens = 0
95 | document = []
96 |
97 | while document_tokens < num_toks_per_seq:
98 | try:
99 | sentence = next(grammar.generate(1))
100 | except RecursionError:
101 | continue
102 | except StopIteration:
103 | print('No more sentences to generate')
104 | break # No more sentences can be generated
105 |
106 | sentence_token_count = sentence.count(" ") + 2
107 |
108 | available_space = num_toks_per_seq - document_tokens
109 | if sentence_token_count <= available_space:
110 | document.append(sentence)
111 | document_tokens += sentence_token_count
112 | else:
113 | # Split the sentence into words and add words until the document is full
114 | words = sentence.split()
115 | words_to_add = words[:available_space]
116 | truncated_sentence = " ".join(words_to_add)
117 |
118 | document.append(truncated_sentence)
119 | document_tokens += len(words_to_add)
120 |
121 | if document_tokens == num_toks_per_seq:
122 | break
123 |
124 | if document:
125 | dataset.append(" 0 ".join(document))
126 | total_tokens_generated += document_tokens
127 |
128 | if total_tokens_generated >= num_toks_total or not document:
129 | break # Stop if we've met the total token count or can't generate more documents
130 |
131 | return dataset
132 |
133 | def generate_dataset_part(grammar, num_toks_per_seq, target_tokens, dataset, total_tokens_generated, lock):
134 | local_dataset = []
135 | local_tokens_generated = 0
136 | while local_tokens_generated < target_tokens:
137 | document_tokens = 0
138 | document = []
139 | while document_tokens < num_toks_per_seq:
140 | try:
141 | sentence = next(grammar.generate(1))
142 | except RecursionError:
143 | continue
144 | except StopIteration:
145 | print('No more sentences to generate')
146 | break
147 |
148 | print(sentence)
149 | sentence_token_count = sentence.count(" ") + 2
150 | available_space = num_toks_per_seq - document_tokens
151 | if sentence_token_count <= available_space:
152 | document.append(sentence)
153 | document_tokens += sentence_token_count
154 | else:
155 | words = sentence.split()
156 | words_to_add = words[:available_space]
157 | truncated_sentence = " ".join(words_to_add)
158 | document.append(truncated_sentence)
159 | document_tokens += len(words_to_add)
160 |
161 | if document_tokens == num_toks_per_seq:
162 | break
163 |
164 | if document:
165 | local_dataset.append(" 0 ".join(document))
166 | local_tokens_generated += document_tokens
167 |
168 | if local_tokens_generated >= target_tokens or not document:
169 | break
170 |
171 | with lock:
172 | dataset.extend(local_dataset)
173 | total_tokens_generated[0] += local_tokens_generated
174 |
175 | def generate_dataset_threaded(
176 | num_nonterminals,
177 | num_terminals,
178 | rhs_max_options,
179 | rhs_max_len,
180 | constrain_to_pfsa,
181 | num_toks_total,
182 | num_toks_per_seq=256,
183 | ) -> list[str]:
184 | # NOTE: threaded dataset generation isn't noticeably faster
185 | print(num_nonterminals, num_terminals, rhs_max_options, rhs_max_len)
186 |
187 | num_threads = 16
188 | threads = []
189 | lock = threading.Lock()
190 | dataset = []
191 | total_tokens_generated = [0]  # list so worker threads can update the shared count in place
192 | target_tokens_per_thread = num_toks_total // num_threads
193 |
194 | grammar = create_random_pcfg(
195 | num_nonterminals,
196 | num_terminals,
197 | rhs_max_options=rhs_max_options,
198 | rhs_max_len=rhs_max_len,
199 | constrain_to_pfsa=constrain_to_pfsa,
200 | )
201 |
202 | for _ in range(num_threads):
203 | thread = threading.Thread(target=generate_dataset_part, args=(grammar, num_toks_per_seq, target_tokens_per_thread, dataset, total_tokens_generated, lock))
204 | threads.append(thread)
205 | thread.start()
206 |
207 | print(dataset)
208 |
209 | for thread in threads:
210 | thread.join()
211 |
212 | return dataset
213 |
214 | if __name__ == "__main__":
215 | from data_utils import (
216 | calculate_median_stdev_gzipability,
217 | count_total_tokens,
218 | pcfg_dataset_to_dataloader,
219 | upload_to_huggingface,
220 | )
221 | from transformers import AutoTokenizer
222 |
223 | context_length = 256
224 | tokenizer = AutoTokenizer.from_pretrained(
225 | "meta-llama/Llama-2-7b-chat-hf", token="[REDACTED]"
226 | )
227 | tokenizer.add_special_tokens({"pad_token": "<pad>"})
228 |
229 | dataset_stats = [
230 | # (3, 20, 2, 2, False),
231 | # (10, 150, 5, 3, False),
232 | # (20, 300, 10, 5, False),
233 | # (30, 400, 10, 8, False),
234 | # (100, 2000, 100, 30, False),
235 | # (50, 500, 20, 15, False),
236 |
237 | # (10, 600, 5, 10, False), # .32
238 | # (20, 300, 15, 5, False), # .36
239 | # (30, 200, 10, 15, False), # .38
240 | # (50, 100, 20, 20, False), # .34
241 |
242 | # (3, 300, 2, 2, False), # isovocab
243 | # (10, 300, 5, 3, False),
244 | # (20, 300, 10, 5, False),
245 | # (50, 300, 20, 10, False),
246 | (100, 300, 100, 30, False),
247 | (200, 300, 200, 50, False),
248 | ]
249 | for row in dataset_stats: # NOTE: runs one dataset generation + upload at a time
250 | dataset_stats = [row]
251 |
252 | pcfg_datasets = [
253 | generate_dataset(*row, 100_000_000, num_toks_per_seq=context_length)
254 | for row in dataset_stats
255 | ]
256 | med_std_gzips = [
257 | calculate_median_stdev_gzipability(pcfg_dataset)
258 | for pcfg_dataset in pcfg_datasets
259 | ]
260 | for i, pcfg_dataset in enumerate(pcfg_datasets):
261 | med, std = med_std_gzips[i]
262 | total_toks = count_total_tokens(
263 | pcfg_dataset_to_dataloader(pcfg_dataset, padder_tokenizer=tokenizer)
264 | )
265 |
266 | print(
267 | f"{i}: {med:.3f} +- {std:.3f} ({total_toks}) | [{' '.join([str(x) for x in dataset_stats[i]])}]"
268 | )
269 | upload_to_huggingface(pcfg_dataset, med, dataset_stats[i])
270 |
--------------------------------------------------------------------------------
/gzip_difficulty.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """gzip-difficulty.ipynb
3 |
4 | Automatically generated by Colaboratory.
5 |
6 | Original file is located at
7 | https://colab.research.google.com/drive/[REDACTED]
8 |
9 | # How well does compressibility predict the learnability of a dataset?
10 |
11 | - compressibility: gzipability ~= length of gzipped string / length of original string
12 | - learnability: learning difficulty ~= integral of perplexity across training steps
13 | - datasets will be synthetically generated by PCFGs and taken from standard natural language & code datasets
14 | - hopefully the real-world datasets are in the PCFG's gzipability distribution
15 |
16 |
17 |
18 | Training data preparation [reference](https://huggingface.co/learn/nlp-course/chapter3/4)
19 |
20 | ## Setup
21 | """
22 |
23 | # ! pip install nltk pcfg
24 | # ! pip install accelerate -U
25 | # ! pip install transformers[torch] datasets wandb
26 |
27 | # Commented out IPython magic to ensure Python compatibility.
28 | # ! wandb login --relogin # [REDACTED]
29 |
30 | # # %env WANDB_ENTITY=rspandey
31 | # # %env WANDB_PROJECT=LM-Training
32 |
33 | """## Load Model"""
34 |
35 | from transformers import LlamaForCausalLM, LlamaConfig
36 |
37 | configuration = {
38 | "vocab_size": 32001,
39 | "hidden_size": 256,
40 | "intermediate_size": 512,
41 | "num_hidden_layers": 4,
42 | "num_attention_heads": 4,
43 | "max_position_embeddings": 256,
44 | }
45 | context_length = configuration["max_position_embeddings"]
46 |
47 | config = LlamaConfig(**configuration)
48 | model = LlamaForCausalLM(config)
49 |
50 | print(f"Param Count: {sum(p.numel() for p in model.parameters()) / 1_000_000:.1f}M")
51 |
52 | from transformers import AutoTokenizer
53 |
54 | tokenizer = AutoTokenizer.from_pretrained(
55 | "meta-llama/Llama-2-7b-chat-hf", token="[REDACTED]"
56 | ) # TODO: replace with actual model name
57 |
58 | tokenizer.add_special_tokens({"pad_token": "<pad>"})
59 | model.resize_token_embeddings(len(tokenizer))
60 |
61 | """## Real Data"""
62 |
63 | from torch.utils.data import DataLoader
64 | from transformers import DataCollatorWithPadding
65 | from datasets import load_dataset, DatasetDict
66 |
67 |
68 | def count_total_tokens(dataloader):
69 | total_tokens = 0
70 | for batch in dataloader:
71 | total_tokens += sum(batch["attention_mask"].flatten().tolist())
72 | return total_tokens
73 |
74 |
75 | """## CFG Data
76 |
77 | https://www.nltk.org/api/nltk.grammar.PCFG.html
78 |
79 | https://www.nltk.org/_modules/nltk/parse/generate.html
80 | """
81 |
82 |
83 | def generate_probs(num_options):
84 | if num_options <= 0:
85 | raise ValueError("Number of options must be positive")
86 |
87 | # Generate random integers for each option
88 | random_ints = [random.randint(1, 100) for _ in range(num_options)]
89 |
90 | # Calculate the total sum
91 | total = sum(random_ints)
92 |
93 | # Normalize each integer by the total sum to get probabilities
94 | probs = [i / total for i in random_ints]
95 |
96 | return probs
97 |
98 |
99 | import random
100 | import math
101 | from nltk import Nonterminal
102 | from pcfg import PCFG
103 |
104 |
105 | def create_random_pcfg(
106 | num_nonterminals,
107 | num_terminals,
108 | rhs_max_options=5,
109 | rhs_max_len=5,
110 | constrain_to_pfsa=False,
111 | ):
112 | # Create non-terminal symbols
113 | nonterminals = [f"N{i}" for i in range(num_nonterminals)]
114 |
115 | # Create terminal symbols as consecutive integers
116 | terminals = [f"'{i}'" for i in range(num_terminals)]
117 |
118 | # Initialize production rules
119 | productions = []
120 |
121 | for lhs in nonterminals:
122 | rhs_options_ct = random.randint(1, rhs_max_options)
123 | rhs_option_probs = generate_probs(rhs_options_ct)
124 |
125 | rhs_options = []
126 |
127 | for rhs_option_prob in rhs_option_probs:
128 | rhs = []
129 |
130 | if constrain_to_pfsa:
131 | rhs.append(
132 | random.choice(nonterminals + terminals)
133 | ) # TODO: is this the right constraint?
134 | else:
135 | # Randomly decide the length of the right-hand side (at least 1)
136 | rhs_len = random.randint(1, rhs_max_len)
137 | for _ in range(rhs_len):
138 | rhs.append(random.choice(nonterminals + terminals))
139 |
140 | rhs_option = f"{' '.join(rhs)} [{rhs_option_prob}]"
141 | rhs_options.append(rhs_option)
142 |
143 | production = f"{lhs} -> {' | '.join(rhs_options)}"
144 | productions.append(production)
145 |
146 | start_production = f"S -> {' | '.join([f'{nonterminal} [{1/len(nonterminals)}]' for nonterminal in nonterminals])}"
147 | productions.insert(0, start_production)
148 |
149 | # Create the PCFG
150 | grammar = PCFG.fromstring("\n".join(productions))
151 |
152 | return grammar
153 |
154 |
155 | def generate_dataset(
156 | num_nonterminals,
157 | num_terminals,
158 | rhs_max_options,
159 | rhs_max_len,
160 | constrain_to_pfsa,
161 | num_toks_total,
162 | num_toks_per_seq=context_length,
163 | ):
164 | grammar = create_random_pcfg(
165 | num_nonterminals,
166 | num_terminals,
167 | rhs_max_options=rhs_max_options,
168 | rhs_max_len=rhs_max_len,
169 | constrain_to_pfsa=constrain_to_pfsa,
170 | )
171 |
172 | dataset = []
173 | total_tokens_generated = 0
174 |
175 | while total_tokens_generated < num_toks_total:
176 | document_tokens = 0
177 | document = []
178 |
179 | while document_tokens < num_toks_per_seq:
180 | try:
181 | sentence = next(grammar.generate(1))
182 | except RecursionError:
183 | continue
184 | except StopIteration:
185 | break # No more sentences can be generated
186 |
187 | sentence_token_count = sentence.count(" ") + 2
188 |
189 | available_space = num_toks_per_seq - document_tokens
190 | if sentence_token_count <= available_space:
191 | document.append(sentence)
192 | document_tokens += sentence_token_count
193 | else:
194 | # Split the sentence into words and add words until the document is full
195 | words = sentence.split()
196 | words_to_add = words[:available_space]
197 | truncated_sentence = " ".join(words_to_add)
198 |
199 | document.append(truncated_sentence)
200 | document_tokens += len(words_to_add)
201 |
202 | if document_tokens == num_toks_per_seq:
203 | break
204 |
205 | if document:
206 | dataset.append(" 0 ".join(document))
207 | total_tokens_generated += document_tokens
208 |
209 | if total_tokens_generated >= num_toks_total or not document:
210 | break # Stop if we've met the total token count or can't generate more documents
211 |
212 | return dataset
213 |
214 |
215 | dataset_stats = [
216 | (5, 50, 3, 2, False),
217 | (10, 150, 5, 3, False),
218 | (20, 300, 10, 5, False),
219 | (50, 600, 30, 15, False),
220 | (100, 2000, 100, 30, False),
221 | ]
222 | pcfg_datasets = [generate_dataset(*row, 1_000_000) for row in dataset_stats]
223 |
224 | from datasets import Dataset
225 |
226 |
227 | def pad_and_mask(sequence, sequence_length):
228 | if sequence_length - len(sequence) == 0:
229 | padded_sequence = sequence
230 | elif sequence_length - len(sequence) > 0:
231 | padded_sequence = sequence + [32000] * (sequence_length - len(sequence))
232 | elif sequence_length - len(sequence) < 0:
233 | padded_sequence = sequence[:sequence_length]
234 | mask = [1 if token != 32000 else 0 for token in padded_sequence]
235 | return padded_sequence, mask
236 |
237 |
238 | def pcfg_dataset_to_dataloader(pcfg_dataset, batch_size=8, padder_tokenizer=tokenizer):
239 | tok_seqs = [[int(tok) for tok in doc.split(" ")] for doc in pcfg_dataset]
240 |
241 | input_ids, attention_masks = [], []
242 | for seq in tok_seqs:
243 | padded_seq, mask = pad_and_mask(seq, context_length)
244 | input_ids.append(padded_seq)
245 | attention_masks.append(mask)
246 |
247 | tokenized_dataset = Dataset.from_dict(
248 | {"input_ids": input_ids, "attention_mask": attention_masks}
249 | )
250 | tokenized_dataset = tokenized_dataset.map(
251 | lambda x: {"labels": x["input_ids"].copy()}, batched=True
252 | )
253 | tokenized_dataset.set_format("torch")
254 |
255 | data_collator = DataCollatorWithPadding(tokenizer=padder_tokenizer)
256 |
257 | dataloader = DataLoader(
258 | tokenized_dataset, shuffle=True, batch_size=batch_size, collate_fn=data_collator
259 | )
260 |
261 | return dataloader
262 |
263 |
264 | """## gzip"""
265 |
266 | import gzip
267 | import io
268 | from typing import List, Union
269 |
270 |
271 | def calculate_gzipability(
272 | input_data: Union[str, List[int]], gzip_toks: bool = True
273 | ) -> float:
274 | if type(input_data) == str and not gzip_toks:
275 | input_bytes = input_data.encode("utf-8")
276 | else: # token list
277 | if type(input_data) == str:
278 | input_data = [int(tok) for tok in input_data.split(" ")]
279 | input_bytes = b"".join(
280 | int.to_bytes(i, length=4, byteorder="big", signed=True) for i in input_data
281 | )
282 |
283 | buf = io.BytesIO()
284 | with gzip.GzipFile(fileobj=buf, mode="wb") as f:
285 | f.write(input_bytes)
286 |
287 | compressed_size = buf.tell()
288 | gzipability = compressed_size / len(input_bytes)
289 |
290 | return gzipability
291 |
292 |
293 | from statistics import median, stdev
294 |
295 |
296 | def calculate_median_stdev_gzipability(pcfg_dataset):
297 | gzipability_scores = [
298 | calculate_gzipability([int(tok) for tok in row.split(" ")])
299 | for row in pcfg_dataset
300 | ]
301 | med = median(gzipability_scores)
302 |
303 | if len(gzipability_scores) > 1:
304 | std_dev = stdev(gzipability_scores)
305 | else:
306 | std_dev = 0 # Default to 0 if there's only one element to avoid division by zero in stdev calculation
307 |
308 | return med, std_dev
309 |
310 |
311 | for i, pcfg_dataset in enumerate(pcfg_datasets):
312 | med, std = calculate_median_stdev_gzipability(pcfg_dataset)
313 | total_toks = count_total_tokens(pcfg_dataset_to_dataloader(pcfg_dataset))
314 |
315 | print(
316 | f"{i}: {med:.3f} +- {std:.3f} ({total_toks}) | [{' '.join([str(x) for x in dataset_stats[i]])}]"
317 | )
318 |
319 | """## Training
320 |
321 | Train on 2 synthetic datasets of similar token count but diff gzipability medians; compare perplexity sum over N epochs.
322 |
323 | TODO:
324 | - ensure I don't have train data in the validation set (how many unique sentences is the grammar generating)
325 | - model is unnecessarily large since vocab size is 32001
326 | - set padder_tokenizer for pcfg dataloader during each training run based on terminal_ct of pcfg_dataset
327 | - pass name and run hyperparams to wandb
328 |
329 | """
330 |
331 | import numpy as np
332 | from torch.nn import CrossEntropyLoss
333 |
334 |
335 | def compute_perplexity(dataloader, model, device="cuda"):
336 | # adapted from: https://github.com/huggingface/evaluate/blob/main/metrics/perplexity/perplexity.py
337 | model = model.to(device)
338 |
339 | ppls = []
340 | loss_fct = CrossEntropyLoss(reduction="none")
341 |
342 | for batch in dataloader:
343 | batch.to(device)
344 | encoded_batch = batch["input_ids"]
345 | attn_mask = batch["attention_mask"]
346 |
347 | labels = encoded_batch
348 |
349 | with torch.no_grad():
350 | out_logits = model(encoded_batch, attention_mask=attn_mask).logits
351 |
352 | shift_logits = out_logits[
353 | ..., :-1, :
354 | ].contiguous() # TODO: double check that all this logic is correct
355 | shift_labels = labels[..., 1:].contiguous()
356 | shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
357 |
358 | perplexity_batch = torch.exp(
359 | (
360 | loss_fct(shift_logits.transpose(1, 2), shift_labels)
361 | * shift_attention_mask_batch
362 | ).sum(1)
363 | / shift_attention_mask_batch.sum(1)
364 | )
365 |
366 | ppls += perplexity_batch.tolist()
367 |
368 | return np.mean(ppls)
369 |
370 |
371 | from tqdm.auto import tqdm
372 |
373 |
374 | def run_training(model, train_dataloader, valid_dataloader, num_epochs=10):
375 | train_perplexities = []
376 | valid_perplexities = []
377 |
378 | for epoch in range(num_epochs):
379 | progress_bar = tqdm(
380 | range(len(train_dataloader)), desc=f"Epoch {epoch + 1}/{num_epochs}"
381 | )
382 |
383 | model.train()
384 | for batch in train_dataloader:
385 | batch = {k: v.to(device) for k, v in batch.items()}
386 | outputs = model(**batch)
387 | loss = outputs.loss
388 | loss.backward()
389 |
390 | optimizer.step()
391 | optimizer.zero_grad()
392 | progress_bar.update(1)
393 |
394 | train_perplexity = compute_perplexity(train_dataloader, model)
395 | train_perplexities.append(train_perplexity)
396 |
397 | model.eval()
398 | with torch.no_grad():
399 | valid_perplexity = compute_perplexity(valid_dataloader, model)
400 | valid_perplexities.append(valid_perplexity)
401 |
402 | print(
403 | f"Epoch {epoch}: Training Perplexity: {train_perplexity}, Validation Perplexity: {valid_perplexity}"
404 | )
405 |
406 | return train_perplexities, valid_perplexities
407 |
408 |
409 | import torch
410 |
411 | med_std_gzips = [
412 | calculate_median_stdev_gzipability(pcfg_dataset) for pcfg_dataset in pcfg_datasets
413 | ]
414 |
415 | model_sizes = {
416 | "hidden_size": [64, 128, 256, 512, 1024],
417 | "intermediate_size": [128, 256, 512, 1024, 2048],
418 | "num_hidden_layers": [2, 4, 6, 10, 20],
419 | "num_attention_heads": [1, 2, 4, 8, 16],
420 | }
421 |
422 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
423 |
424 | from transformers import AdamW
425 | import json
426 |
427 | results = []
428 |
429 | for data_portion in (0.01, 0.1, 0.2, 0.5, 0.95):
430 | for i, pcfg_dataset in enumerate(pcfg_datasets):
431 | med_gzip, std_gzip = med_std_gzips[i]
432 |
433 | train_data_size = int(len(pcfg_dataset) * data_portion)
434 | valid_data_size = min(100, int(train_data_size / 10))
435 | train_dataloader = pcfg_dataset_to_dataloader(pcfg_dataset[:train_data_size])
436 | valid_dataloader = pcfg_dataset_to_dataloader(pcfg_dataset[-valid_data_size:])
437 | train_token_ct = count_total_tokens(train_dataloader)
438 |
439 | for j in range(len(list(model_sizes.values())[0])):
440 | print("-" * 20)
441 |
442 | model_stats = {key: val[j] for key, val in model_sizes.items()}
443 | model_config_dict = {
444 | **configuration,
445 | **model_stats,
446 | } # NOTE: update vocab_size and new tokenizer?
447 | model_config = LlamaConfig(**model_config_dict)
448 | model = LlamaForCausalLM(model_config)
449 | model_size = sum(p.numel() for p in model.parameters())
450 |
451 | print(
452 | f"Dataset Stats: {med_gzip:.3f} +- {std_gzip:.3f} | {dataset_stats[i]}"
453 | )
454 | print(f"Model Size: {model_size/1_000_000:.1f}M")
455 | print(f"Train Token Count: {train_token_ct}")
456 |
457 | model.to(device)
458 | optimizer = AdamW(model.parameters(), lr=5e-5)
459 | num_epochs = 10
460 |
461 | train_perplexities, valid_perplexities = run_training(
462 | model, train_dataloader, valid_dataloader, num_epochs=num_epochs
463 | )
464 |
465 | row = {
466 | "dataset_stats": dataset_stats[i],
467 | "dataset_gzip": (med_gzip, std_gzip),
468 | "token_ct": train_token_ct,
469 | "model_stats": model_config_dict,
470 | "model_size": model_size,
471 | "num_epochs": num_epochs,
472 | "train_pplx": train_perplexities,
473 | "valid_pplx": valid_perplexities,
474 | }
475 | results.append(row)
476 |
477 | with open("results.jsonl", "a") as file:
478 | file.write(json.dumps(row) + "\n")
479 |
--------------------------------------------------------------------------------