├── tools
│   ├── __init__.py
│   ├── common.py
│   ├── wikitext_dataset.py
│   ├── openwebtext_dataset.py
│   └── mmap_dataset.py
├── dev-requirements.txt
├── .flake8
├── .dockerignore
├── requirements.txt
├── Dockerfile
├── .github
│   └── workflows
│       └── cffconvert.yml
├── CITATION.cff
├── .gitignore
├── README.md
├── cheatsheet.txt
├── eval_lambada.py
├── LICENSE
├── eval_wikitext.py
└── gpt_pretrain.py
/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/dev-requirements.txt:
--------------------------------------------------------------------------------
1 | black==21.7b0
2 | mypy==0.910
3 | flake8==3.9.2
4 | pytest==6.2.4
5 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 115
3 |
4 | ignore =
5 | # these rules don't play well with black
6 | E203 # whitespace before :
7 | W503 # line break before binary operator
8 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | .dockerignore
2 | .git
3 | notebooks
4 | **.pyc
5 | **/__pycache__
6 | **/.mypy_cache
7 | .gitignore
8 | .git
9 | .github
10 | .flake8
11 | .venv
12 | *.md
13 | mypy.ini
14 | pytest.ini
15 | tests
16 | CITATION.cff
17 | dev-requirements.txt
18 | LICENSE
19 | cheatsheet.txt
20 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.9.0
2 | numpy
3 | pytorch-lightning==1.3
4 | tokenizers==0.10.3
5 | sentencepiece==0.1.96
6 | transformers==4.8.2
7 | datasets==1.9.0
8 | accelerate==0.3.0
9 | click==7.1.2
10 | click-help-colors==0.9.1
11 | more-itertools==8.8.0
12 | tensorboardX
13 | tqdm==4.61.2
14 | test-tube==0.7.5
15 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # This Dockerfile creates an environment suitable for running the training and evaluation scripts.
2 | FROM ghcr.io/allenai/pytorch:1.9.0-cuda11.1
3 |
4 | WORKDIR /stage/
5 |
6 | # Install remaining dependencies.
7 | COPY requirements.txt .
8 | RUN pip install --no-cache-dir -r requirements.txt
9 |
10 | WORKDIR /workspace
11 |
12 | COPY . .
13 |
14 | ENTRYPOINT ["python"]
15 |
--------------------------------------------------------------------------------
/.github/workflows/cffconvert.yml:
--------------------------------------------------------------------------------
1 | name: cffconvert
2 |
3 | on:
4 | pull_request:
5 | paths:
6 | - CITATION.cff
7 | push:
8 | paths:
9 | - CITATION.cff
10 |
11 | jobs:
12 | validate:
13 | name: "validate"
14 | runs-on: ubuntu-latest
15 | steps:
16 | - name: Check out a copy of the repository
17 | uses: actions/checkout@v2
18 |
19 | - name: Check whether the citation metadata from CITATION.cff is valid
20 | uses: citation-file-format/cffconvert-github-action@2.0.0
21 | with:
22 | args: "--validate"
23 |
--------------------------------------------------------------------------------
/tools/common.py:
--------------------------------------------------------------------------------
1 | def get_group_texts_function(block_size: int):
2 | def group_texts(examples):
3 | # Concatenate all texts.
4 | concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} # type: ignore
5 | total_length = len(concatenated_examples[list(examples.keys())[0]])
6 |         # We drop the small remainder; we could add padding instead if the model supported it.
7 |         # You can customize this part to your needs.
8 | if total_length >= block_size:
9 | total_length = (total_length // block_size) * block_size
10 | # Split by chunks of max_len.
11 | result = {
12 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
13 | for k, t in concatenated_examples.items()
14 | }
15 | result["labels"] = result["input_ids"].copy()
16 | return result
17 |
18 | return group_texts
19 |
--------------------------------------------------------------------------------
/tools/wikitext_dataset.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from torch.utils.data import Dataset
3 | from transformers import GPT2Tokenizer
4 |
5 | from .common import get_group_texts_function
6 |
7 |
8 | def get_wikitext_dataset(
9 | tokenizer: GPT2Tokenizer,
10 | *,
11 | split: str = "test",
12 | block_size: int = 1024,
13 | num_workers: int = 1,
14 | ) -> Dataset:
15 | # dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") # type: ignore[assignment]
16 |     dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split=split)
17 | def tokenize_function(example):
18 | return tokenizer(example["text"])
19 |
20 | dataset = dataset.map( # type: ignore[union-attr,call-arg]
21 | tokenize_function,
22 | batched=True,
23 | num_proc=num_workers,
24 | remove_columns=["text"],
25 | desc="Tokenizing dataset",
26 | )
27 |
28 | group_texts = get_group_texts_function(block_size)
29 |
30 | dataset = dataset.map(
31 | group_texts,
32 | batched=True,
33 | num_proc=num_workers,
34 | desc=f"Grouping texts into chunks of {block_size}",
35 | )
36 |
37 | return dataset # type: ignore[return-value]
38 |
--------------------------------------------------------------------------------
/tools/openwebtext_dataset.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from torch.utils.data import Dataset
3 | from transformers import GPT2Tokenizer
4 |
5 | from .common import get_group_texts_function
6 |
7 |
8 | def get_openwebtext_dataset(
9 | tokenizer: GPT2Tokenizer,
10 | *,
11 | block_size: int = 1024,
12 | num_workers: int = 1,
13 | test_size: float = 0.01,
14 | seed: int = 17,
15 | ) -> Dataset:
16 | dataset_dict = load_dataset("openwebtext")
17 |
18 | # This dataset only comes with a single split, so we need to create our own train/test splits.
19 | dataset_dict = dataset_dict["train"].train_test_split( # type: ignore[index,union-attr]
20 | shuffle=True, test_size=test_size, seed=seed
21 | )
22 | dataset = dataset_dict["test"] # type: ignore[index]
23 |
24 | def tokenize_function(example):
25 | return tokenizer(example["text"])
26 |
27 | dataset = dataset.map( # type: ignore[union-attr,call-arg]
28 | tokenize_function,
29 | batched=True,
30 | num_proc=num_workers,
31 | remove_columns=["text"],
32 | desc="Tokenizing dataset",
33 | )
34 |
35 | group_texts = get_group_texts_function(block_size)
36 |
37 | dataset = dataset.map(
38 | group_texts,
39 | batched=True,
40 | num_proc=num_workers,
41 | desc=f"Grouping texts into chunks of {block_size}",
42 | )
43 |
44 | return dataset # type: ignore[return-value]
45 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | # YAML 1.2
2 | ---
3 | cff-version: "1.2.0"
4 | title: "Staged Training for Transformer Language Models"
5 | license: "Apache-2.0"
6 | message: "If you use staged training in your research or wish to refer to the baseline results published here, please cite using this metadata."
7 | repository-code: "https://github.com/allenai/staged-training"
8 | authors:
9 | - affiliation: "University of California, Berkeley"
10 | family-names: Shen
11 | given-names: Sheng
12 | - affiliation: "Allen Institute for Artificial Intelligence"
13 | family-names: Walsh
14 | given-names: Pete
15 | - affiliation: "University of California, Berkeley"
16 | family-names: Keutzer
17 | given-names: Kurt
18 | - affiliation: "Allen Institute for Artificial Intelligence"
19 | family-names: Dodge
20 | given-names: Jesse
21 | - affiliation: "Allen Institute for Artificial Intelligence"
22 | family-names: Peters
23 | given-names: Matthew
24 | - affiliation: "Allen Institute for Artificial Intelligence"
25 | family-names: Beltagy
26 | given-names: Iz
27 | preferred-citation:
28 | type: "article"
29 | title: "Staged Training for Transformer Language Models"
30 | doi: "10.48550/arXiv.2203.06211"
31 | url: "https://arxiv.org/abs/2203.06211"
32 | year: 2022
33 | authors:
34 | - affiliation: "University of California, Berkeley"
35 | family-names: Shen
36 | given-names: Sheng
37 | - affiliation: "Allen Institute for Artificial Intelligence"
38 | family-names: Walsh
39 | given-names: Pete
40 | - affiliation: "University of California, Berkeley"
41 | family-names: Keutzer
42 | given-names: Kurt
43 | - affiliation: "Allen Institute for Artificial Intelligence"
44 | family-names: Dodge
45 | given-names: Jesse
46 | - affiliation: "Allen Institute for Artificial Intelligence"
47 | family-names: Peters
48 | given-names: Matthew
49 | - affiliation: "Allen Institute for Artificial Intelligence"
50 | family-names: Beltagy
51 | given-names: Iz
52 |
--------------------------------------------------------------------------------
/tools/mmap_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import Dataset
4 | from transformers import GPT2Tokenizer
5 |
6 |
7 | class MMapTextDataset(Dataset):
8 | def __init__(
9 | self,
10 | mmap_filename: str,
11 | *,
12 | chunk_size: int = 1024,
13 | bos_token_id: int = 50256,
14 | eos_token_id: int = 50256
15 | ):
16 |         # `chunk_size - 2` to reserve space for the BOS and EOS tokens
17 | self.num_instances = np.memmap(mmap_filename, mode="r", dtype=np.uint16).shape[
18 | 0
19 | ] // (chunk_size - 2)
20 | # defer loading the token_ids memmap until after the first __getitem__ call.
21 | # when spawning new processes for ddp, there is a hard limit in python < 3.8 that
22 | # pickle files need to be < 4GB. By waiting until after the first __getitem__ we
23 | # don't have to pickle the memmap
24 | self.token_ids = None
25 | self._mmap_filename = mmap_filename
26 | self._chunk_size = chunk_size
27 | self._bos_token_id = bos_token_id
28 | self._eos_token_id = eos_token_id
29 |
30 | def __len__(self):
31 | return self.num_instances
32 |
33 | def __getitem__(self, idx: int):
34 | if self.token_ids is None:
35 | self.token_ids = np.memmap(self._mmap_filename, mode="r", dtype=np.uint16)
36 | from_index = idx * (self._chunk_size - 2)
37 | to_index = (idx + 1) * (self._chunk_size - 2)
38 | data = np.concatenate(
39 | (
40 | [self._bos_token_id],
41 | self.token_ids[from_index:to_index], # type: ignore[index]
42 | [self._eos_token_id],
43 | )
44 | )
45 | return torch.tensor(data, dtype=torch.long)
46 |
47 |
48 | def get_mmap_dataset(tokenizer: GPT2Tokenizer, filename: str, **kwargs) -> Dataset:
49 | return MMapTextDataset(
50 | filename,
51 | bos_token_id=tokenizer.bos_token_id or tokenizer.cls_token_id,
52 | eos_token_id=tokenizer.eos_token_id or tokenizer.sep_token_id,
53 | **kwargs,
54 | )
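
# Usage sketch: the binary file below is hypothetical and must be produced by a
# separate preprocessing step that writes uint16 GPT-2 token ids (that step is
# not included in this repo).
#
#   tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#   dataset = get_mmap_dataset(tokenizer, "tokens.bin", chunk_size=1024)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8)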
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Mac stuff
132 | .DS_Store
133 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # staged-training
2 |
3 | In our paper [**Staged Training for Transformer Language Models**](https://arxiv.org/abs/2203.06211), we propose a staged training setup that begins with a small model and incrementally increases the amount of compute used for training by applying a "growth operator" to increase the model depth and width. By initializing each stage with the output of the previous one, the training process effectively re-uses the compute from prior stages and becomes more efficient.
4 |
5 | We release the reproducible code for the growth operator and evaluation scripts here.
6 |
7 | ## Setup
8 |
9 | The scripts in this repository require Python 3.7 or newer.
10 | Once you have a suitable Python environment, first install PyTorch v1.9.0 according to the [official instructions](https://pytorch.org/get-started/previous-versions/#v190). Then run
11 | ```
12 | pip install -r requirements.txt
13 | ```
14 |
15 | ## Growth Operator
16 |
17 | Our growth operators (width/depth) each take as input the entire training state (including model parameters, optimizer state, learning rate schedule, etc.) and output a new training state from which training continues.
18 |
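Conceptually, the width operator grows every weight matrix so that the hidden size doubles, while the depth operator interleaves copies of the existing transformer blocks. The snippet below is only an illustrative sketch of that idea on raw tensors (the helper names are made up for this example); the released operators in `gpt_pretrain.py` operate on the entire training state described above and use the exact scaling derived in the paper.

```python
import torch


def grow_width(weight: torch.Tensor) -> torch.Tensor:
    """Double both dimensions of a square weight matrix by tiling it.

    With duplicated inputs x' = [x; x] the grown matrix maps to duplicated
    outputs [Wx; Wx], so this toy version preserves a plain linear map
    (biases, attention, and layer norm are ignored here).
    """
    doubled = torch.cat([weight, weight], dim=1)
    return torch.cat([doubled, doubled], dim=0) / 2.0


def grow_depth(blocks):
    """Double the number of blocks by placing a copy next to each original,
    roughly the idea behind `--doubling layers --doubling_layers alternate_id`."""
    grown = []
    for block in blocks:
        grown.append(block)
        grown.append({name: param.clone() for name, param in block.items()})
    return grown


if __name__ == "__main__":
    w = torch.randn(4, 4)
    print(grow_width(w).shape)  # torch.Size([8, 8])
    print(len(grow_depth([{"w": w}] * 3)))  # 6
```

In practice you never call helpers like these directly; the operator is applied by passing `--doubling weights` or `--doubling layers` to `gpt_pretrain.py`, as in the examples below.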
19 | Please see `cheatsheet.txt` for more examples of how to use the corresponding scripts.
20 |
21 | For example, you can apply the width operator with:
22 | ```
23 | CUDA_VISIBLE_DEVICES=0,1,2,3 python gpt_pretrain.py \
24 | --save_prefix final_gpt2_large_div2_width_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
25 | --gpu_count -1 \
26 | --model gpt2 \
27 | --tokenizer gpt2 \
28 | --batch_size 4 \
29 | --grad_accum 32 \
30 | --lr 0.002006911598778545 \
31 |     --warmup_steps 3000 \
32 | --train_steps 250000 \
33 | --val_every 50 \
34 | --val_batches 50 \
35 | --fp16 \
36 | --seqlen 1024 \
37 | --log_rate 10 \
38 | --num_workers 4 \
39 | --size GPT2_large_div2_width \
40 | --random \
41 | --resume final_runs/final_gpt2_large_div2_width_check_bs512_lr0.0021_warmup3k_seqlen1024_debug/checkpoint-xxx.ckpt \
42 | --doubling weights
43 | ```
44 |
45 | Or the depth operator with:
46 | ```
47 | CUDA_VISIBLE_DEVICES=0,1,2,3 python gpt_pretrain.py \
48 | --save_prefix final_gpt2_large_div2_depthx2_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
49 | --gpu_count -1 \
50 | --model gpt2 \
51 | --tokenizer gpt2 \
52 | --batch_size 4 \
53 | --grad_accum 32 \
54 | --lr 0.002006911598778545 \
55 | --warmup_steps 3000 \
56 | --train_steps 250000 \
57 | --val_every 50 \
58 | --val_batches 50 \
59 | --fp16 \
60 | --seqlen 1024 \
61 | --log_rate 10 \
62 | --num_workers 4 \
63 | --size GPT2_large_div2_depth \
64 | --random \
65 | --resume final_runs/final_gpt2_large_div2_depth_check_bs512_lr0.0020_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6499.ckpt \
66 | --doubling layers
67 | ```
68 |
69 | ## Evaluation
70 |
71 | Use `eval_wikitext.py` or `eval_lambada.py` to evaluate [GPT-2](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) on one of the supported datasets. For example:
72 |
73 | ```bash
74 | python eval_wikitext.py eval
75 | ```
76 |
77 | Or using Docker:
78 |
79 | ```bash
80 | docker build -t evaluation:latest .
81 | docker run --rm --gpus all evaluation:latest eval_wikitext.py eval
82 | ```
83 |
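The dataset helpers under `tools/` can also be used on their own. Below is a minimal sketch of the evaluation loop, assuming the pinned `transformers`/`datasets` versions from `requirements.txt`; the `DataLoader` and collator wiring is an illustration rather than a copy of `eval_wikitext.py`, which additionally handles checkpoints, CUDA placement, and CLI options:

```python
import math

import torch
from torch.utils.data import DataLoader
from transformers import GPT2LMHeadModel, GPT2Tokenizer, default_data_collator

from tools.wikitext_dataset import get_wikitext_dataset

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

# Blocks of 1024 tokens; `labels` are already attached by group_texts().
dataset = get_wikitext_dataset(tokenizer, split="test", block_size=1024)
loader = DataLoader(dataset, batch_size=4, collate_fn=default_data_collator)

total_loss, num_batches = 0.0, 0
with torch.no_grad():
    for batch in loader:
        total_loss += model(**batch).loss.item()
        num_batches += 1

print(f"perplexity = {math.exp(total_loss / num_batches):.2f}")
```
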
84 | ## Reference
85 |
86 | If you use staged training in your research or wish to refer to the baseline results published here,
87 | please use the following BibTeX entry.
88 | ```
89 | @misc{shen2022staged,
90 | title={Staged Training for Transformer Language Models},
91 | author={Sheng Shen and Pete Walsh and Kurt Keutzer and Jesse Dodge and Matthew Peters and Iz Beltagy},
92 | year={2022},
93 | eprint={2203.06211},
94 | archivePrefix={arXiv},
95 | primaryClass={cs.CL}
96 | }
97 | ```
98 |
--------------------------------------------------------------------------------
/cheatsheet.txt:
--------------------------------------------------------------------------------
1 | # Thresholds are the same, constants are a bit different
2 | threshold_optimality = -0.052
3 | threshold_depth_growth = -0.0575
4 | threshold_width_growth = -0.0475
5 | threshold_depth_width_growth_ours = -0.03 # slightly different from the one you get from the scaling laws
6 |
7 | constant_op_width = 1.776
8 | constant_op_depth = 1.412
9 | constant_op_depth_width = 2.455
10 |
11 | # LR
12 | GPT2_base 0.002132892651963921
13 | GPT2_base_div2_depth 0.0021748863363590465
14 | GPT2_base_div2_width 0.002216880020754172
15 |
16 |
17 | GPT2_large 0.002006911598778545
18 | GPT2_large_div2_width 0.002090898967568796
19 | GPT2_large_div2_depth 0.0020489052831736704
20 | GPT2_large_div4_depth 0.002090898967568796
21 | GPT2_large_div4_width 0.0021748863363590465
22 |
23 |
24 | # Direct Pretrain 4 GPU
25 | CUDA_VISIBLE_DEVICES=0,1,2,3 python gpt_pretrain.py \
26 | --save_prefix final_gpt2_large_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
27 | --gpu_count -1 --model gpt2 --tokenizer gpt2 \
28 | --batch_size 4 --grad_accum 32 --lr 0.002006911598778545 --warmup_steps 3000 \
29 | --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 \
30 | --log_rate 10 --num_workers 4 --size GPT2_large --random
31 |
32 | CUDA_VISIBLE_DEVICES=0,1,2,3 python gpt_pretrain.py \
33 | --save_prefix final_gpt2_large_div2_depth_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
34 | --gpu_count -1 --model gpt2 --tokenizer gpt2 \
35 | --batch_size 4 --grad_accum 32 --lr 0.0020489052831736704 --warmup_steps 3000 \
36 | --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 \
37 | --log_rate 10 --num_workers 2 --size GPT2_large_div2_depth --random
38 |
39 |
40 | CUDA_VISIBLE_DEVICES=0,1,2,3 python gpt_pretrain.py \
41 | --save_prefix final_gpt2_large_div2_width_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
42 | --gpu_count -1 --model gpt2 --tokenizer gpt2 \
43 | --batch_size 4 --grad_accum 32 --lr 0.002090898967568796 --warmup_steps 3000 \
44 | --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 \
45 | --log_rate 10 --num_workers 2 --size GPT2_large_div2_width --random
46 |
47 |
48 | # Use the operator to double the weights
49 | # First, apply the operator to the ckpts
50 | CUDA_VISIBLE_DEVICES=0,1,2,3 python gpt_pretrain.py \
51 | --save_prefix final_gpt2_large_div2_width_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
52 |     --gpu_count -1 --model gpt2 --tokenizer gpt2 --batch_size 4 --grad_accum 32 --lr 0.002006911598778545 --warmup_steps 3000 --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 --log_rate 10 --num_workers 4 \
53 | --size GPT2_large_div2_width --random \
54 | --resume final_runs/final_gpt2_large_div2_width_check_bs512_lr0.0021_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6249.ckpt \
55 | --doubling weights --restart_warmup_steps 200 --restart_steps 3319 \
56 | --reset_lr_scheduler
57 |
58 | # Second, resume the grown ckpt and set the restart step and re-warmup steps
59 | CUDA_VISIBLE_DEVICES=0,1,2,3 python gpt_pretrain.py \
60 | --save_prefix final_gpt2_large_div2_depthx2_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
61 | --gpu_count -1 --model gpt2 --tokenizer gpt2 \
62 | --batch_size 4 --grad_accum 32 --lr 0.002006911598778545 --warmup_steps 3000 \
63 | --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 \
64 | --log_rate 10 --num_workers 4 --size GPT2_large --random \
65 | --resume final_runs/final_gpt2_large_div2_width_check_bs512_lr0.0020_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6499.ckpt.doubled_weights \
66 | --restart_warmup_steps 150 --restart_steps 3319 --reset_lr_scheduler
67 |
68 | # Use the operator to double the layers
69 | # First, apply the operator to the ckpts
70 | CUDA_VISIBLE_DEVICES=0,1,2,3 python gpt_pretrain.py \
71 | --save_prefix final_gpt2_large_div2_depthx2_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
72 | --gpu_count -1 --model gpt2 --tokenizer gpt2 \
73 | --batch_size 4 --grad_accum 32 --lr 0.002006911598778545 --warmup_steps 3000 \
74 | --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 \
75 | --log_rate 10 --num_workers 4 --size GPT2_large_div2_depth --random \
76 | --resume final_runs/final_gpt2_large_div2_depth_check_bs512_lr0.0020_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6499.ckpt \
77 | --doubling layers --restart_warmup_steps 150 --restart_steps 4449 --reset_lr_scheduler --doubling_layers alternate_id
78 |
79 | # Second, resume the grown ckpt and set the restart step and re-warmup steps
80 | CUDA_VISIBLE_DEVICES=0,1,2,3 python gpt_pretrain.py \
81 | --save_prefix final_gpt2_large_div2_depthx2_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \
82 | --gpu_count -1 --model gpt2 --tokenizer gpt2 \
83 | --batch_size 4 --grad_accum 32 --lr 0.002006911598778545 --warmup_steps 3000 \
84 | --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 \
85 | --log_rate 10 --num_workers 4 --size GPT2_large --random \
86 | --resume final_runs/final_gpt2_large_div2_depth_check_bs512_lr0.0020_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6499.ckpt.doubled_layer \
87 | --restart_warmup_steps 150 --restart_steps 4449 --reset_lr_scheduler
88 |
89 |
--------------------------------------------------------------------------------
/eval_lambada.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 | import time
5 |
6 | import numpy as np
7 | import torch
8 | import tqdm
9 | from torch.utils.data import DataLoader, Dataset
10 | from tqdm import trange
11 |
12 |
13 | from transformers import (
14 | GPT2Tokenizer,
15 | GPT2LMHeadModel,
16 | GPT2Config,
17 | default_data_collator,
18 | DataCollatorForLanguageModeling,
19 | )
20 | from torch.utils.data import DataLoader, Dataset, Subset
21 |
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--path', type=str, default='lambada_test.jsonl', help='location of lambada dataset')
28 | parser.add_argument('--batch', type=int, default=4, help='batch size')
29 | parser.add_argument('--max-batches', type=int, default=0, help='maximum number of batches to evaluate (0 = no limit)')
30 | parser.add_argument('--ignore-fragments', action='store_true', help="skip examples whose last word is split into multiple BPE tokens")
31 | parser.add_argument('--preprocess', action='store_true', help="strip quotes")
32 | parser.add_argument('--jeff_suggestion', action='store_true', help="use jeff's suggestion of prepending \n to each example")
33 | parser.add_argument('--dryrun', action='store_true', help="test preprocessing pipeline")
34 | parser.add_argument('--checkpoint-file', default=None, help='path to a trained checkpoint to load model weights from')
35 |
36 | args = parser.parse_args()
37 |
38 |
39 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40 | args.device = device
41 |
42 |
43 | model_name = 'gpt2'
44 | enc = GPT2Tokenizer.from_pretrained(model_name)
45 |
46 | checkpoint_file = args.checkpoint_file
47 |
48 | if checkpoint_file is not None:
49 | state_dict = torch.load(checkpoint_file)
50 | state_dict = state_dict['state_dict']
51 | state_dict = {k.replace("model.", ""): p for k, p in state_dict.items()}
52 |
53 | config = GPT2Config.from_pretrained(model_name)
54 |
55 | # Guess model size.
56 |     print(state_dict.keys())
57 | config.n_embd = state_dict["transformer.wte.weight"].shape[1]
58 | config.n_layer = len(
59 | [key for key in state_dict if key.endswith("mlp.c_proj.bias")]
60 | )
61 | print(
62 | f"Adjusting hidden_size to {config.n_embd}, num_layers to {config.n_layer}"
63 | )
64 | if config.n_embd == 1536 or config.n_embd == 3072:
65 | config.n_head = 16
66 | else:
67 | config.n_head = 8
68 | #config.n_head = 12
69 |     config.n_positions = 1024
70 | if "alibi" in checkpoint_file:
71 | config.alibi_embeddings = True
72 | if "rotary" in checkpoint_file:
73 | config.rotary_embeddings = True
74 | # Initialize model.
75 | model = GPT2LMHeadModel(config)
76 | model.load_state_dict(state_dict, strict=True)
77 | else:
78 | model = GPT2LMHeadModel.from_pretrained(model_name)
79 | model.to(device)
80 |
81 |
82 | def argmax(t):
83 | return int(torch.argmax(t).item())
84 |
85 | # from https://github.com/openai/gpt-2/issues/131#issuecomment-492786058
86 | def preprocess(text):
87 | text = text.replace("“", '"')
88 | text = text.replace("”", '"')
89 | text = text.replace("''", '"')
90 | text = text.replace("``", '"')
91 | return '\n'+text.strip()
92 |
93 | def score_batch(batch):
94 | """Return number of last-word mismatches in a batch."""
95 | batch_encoded = []
96 | lengths = []
97 | fragments = []
98 | for line in batch:
99 | line = line.strip()
100 | if args.jeff_suggestion:
101 | line = '\n'+line
102 | line_encoded = enc.encode(line)
103 | encoded_last_word = enc.decode(line_encoded[-1:]).strip()
104 | actual_last_word = line.split()[-1].strip()
105 | if encoded_last_word != actual_last_word:
106 | fragments.append(True)
107 | else:
108 | fragments.append(False)
109 | batch_encoded.append(line_encoded)
110 |
111 | # array is ragged, so pad to turn into rectangular tensor
112 | max_len = max(len(encoded) for encoded in batch_encoded)
113 | batch_padded = []
114 | for encoded in batch_encoded:
115 | batch_padded.append(encoded+[0]*(max_len - len(encoded)))
116 | lengths.append(len(encoded))
117 |
118 | batch_padded = torch.tensor(batch_padded)
119 | batch_padded = batch_padded.to(device)
120 | if args.dryrun:
121 | return 0, 1
122 |
123 | # logits, presents = model(batch_padded)
124 | outputs = model(batch_padded)
125 | logits = outputs.logits
126 | errors = 0
127 | total = 0
128 | for i in range(args.batch):
129 | # break on small last batch
130 | if i >= len(batch_padded):
131 | break
132 | last_idx = lengths[i]-1
133 | observed = batch_encoded[i][last_idx]
134 | predicted = argmax(logits[i][last_idx-1])
135 | if args.ignore_fragments and fragments[i]:
136 | continue
137 | total+=1
138 | errors += 0 if (observed == predicted) else 1
139 |
140 | return errors, total
141 |
142 |
143 | def main():
144 | ds_raw = open(f'{args.path}').read()
145 | if args.preprocess:
146 | ds_raw = preprocess(ds_raw)
147 |
148 | ds = ds_raw.strip().split('\n')
149 |
150 | # special handling for jsonl file
151 | lines = []
152 | if args.path.endswith('.jsonl'):
153 | # special handling for file from Jeff
154 | for line in ds:
155 | # candidate1 = eval(line)['text']
156 | # lines.append(candidate1)
157 | candidate2 = line[len('{"text": "'):-len('"}')]
158 | candidate2 = f'''"""{candidate2}"""'''
159 | lines.append(eval(candidate2))
160 |
161 | # lines.append(eval(line))
162 | #print(line)
163 | # break
164 | # print(line)
165 | # eprint(lines[-1])
166 | ds = lines
167 | data_loader = DataLoader(ds, batch_size=args.batch, shuffle=False)
168 |
169 | errors = 0
170 | total = 0
171 | for batch in tqdm.tqdm(data_loader):
172 | errors_batch, total_batch = score_batch(batch)
173 | errors += errors_batch
174 | total += total_batch
175 | # if args.max_batches and i>=args.max_batches-1:
176 | # break
177 |
178 | print("Accuracy: %.4f"%(1-errors/total,))
179 |
180 |
181 | if __name__=='__main__':
182 | main()
183 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/eval_wikitext.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluate pretrained GPT-2 models on standard datasets.
3 |
4 | Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py.
5 | """ # noqa: E501
6 |
7 | import math
8 | import os
9 | import random
10 | import shutil
11 | import tempfile
12 | from itertools import islice
13 | from typing import Any, Dict, Optional
14 |
15 | import click
16 | import numpy as np
17 | import torch
18 | from accelerate import Accelerator
19 | from click_help_colors import HelpColorsCommand, HelpColorsGroup
20 | from more_itertools import chunked
21 | from torch.utils.data import DataLoader, Dataset, DistributedSampler, Sampler
22 | from tqdm import tqdm
23 | from transformers import (DataCollatorForLanguageModeling, GPT2Config,
24 | GPT2LMHeadModel, GPT2Tokenizer,
25 | default_data_collator)
26 | from transformers.optimization import AdamW, get_linear_schedule_with_warmup
27 |
28 | from tools.mmap_dataset import get_mmap_dataset
29 | from tools.openwebtext_dataset import get_openwebtext_dataset
30 | from tools.wikitext_dataset import get_wikitext_dataset
31 |
32 |
33 | @click.group(
34 | cls=HelpColorsGroup,
35 | help_options_color="green",
36 | help_headers_color="yellow",
37 | context_settings={"max_content_width": 115},
38 | )
39 | def main():
40 | if torch.cuda.is_available():
41 | click.echo("CUDA is available :)\n")
42 | else:
43 | click.secho("No CUDA devices available!\n", fg="red")
44 |
45 |
46 | @main.command(
47 | cls=HelpColorsCommand,
48 | help_options_color="green",
49 | help_headers_color="yellow",
50 | context_settings={"max_content_width": 115},
51 | )
52 | @click.option("--model-name", default="gpt2")
53 | @click.option(
54 | "--dataset",
55 | default="wikitext2",
56 | type=click.Choice(["wikitext2", "openwebtext", "mmap"]),
57 | show_choices=True,
58 | show_default=True,
59 | help="The dataset to evaluate on.",
60 | )
61 | @click.option(
62 | "--dataset-path",
63 | default=None,
64 | type=click.Path(exists=True, dir_okay=False, resolve_path=True),
65 | help="Path to the memory-mapped dataset (only valid when --dataset=mmap)",
66 | )
67 | @click.option(
68 | "--block-size",
69 | default=1024,
70 | show_default=True,
71 | help="""Input texts are blocked together into blocks of this size.
72 | This should probably match the max input size of the model.""",
73 | )
74 | @click.option(
75 | "--batch-size",
76 | default=32,
77 | show_default=True,
78 | help="The batch size to use for evaluation.",
79 | )
80 | @click.option(
81 | "--checkpoint-file",
82 | default=None,
83 | type=click.Path(exists=True, dir_okay=False, resolve_path=True),
84 | help="A checkpoint file to load the weights from.",
85 | )
86 | @click.option(
87 | "--skip-loading-weights",
88 | is_flag=True,
89 | help="Leave the model's weights at their random initialization.",
90 | )
91 | @click.option(
92 | "--max-steps",
93 | default=None,
94 | type=click.INT,
95 | )
96 | def eval(
97 | model_name: str,
98 | dataset: str,
99 | dataset_path: Optional[str],
100 | block_size: int,
101 | batch_size: int,
102 | checkpoint_file: Optional[str],
103 | skip_loading_weights: bool,
104 | max_steps: Optional[int],
105 | ):
106 | """
107 | Evaluate a GPT-2 model on a dataset.
108 | """
109 | # Validate params.
110 | if dataset != "mmap" and dataset_path is not None:
111 | raise click.UsageError("'--dataset-path' only valid when '--dataset=mmap'")
112 | if dataset == "mmap" and dataset_path is None:
113 | raise click.UsageError("'--dataset-path' is required for this dataset type")
114 |
115 | click.secho("[1/3] Loading tokenizer and model...", fg="green")
116 |
117 | tokenizer = GPT2Tokenizer.from_pretrained(model_name)
118 | if checkpoint_file is not None:
119 | click.echo(f"Loading checkpoint from {checkpoint_file}")
120 |
121 | # Load state dict.
122 | state_dict = torch.load(checkpoint_file)
123 | state_dict = state_dict["state_dict"]
124 | state_dict = {k.replace("model.", ""): p for k, p in state_dict.items()}
125 |
126 | config = GPT2Config.from_pretrained(model_name)
127 |
128 | # Guess model size.
129 | config.n_embd = state_dict["transformer.wte.weight"].shape[1]
130 | config.n_layer = len(
131 | [key for key in state_dict if key.endswith("mlp.c_proj.bias")]
132 | )
133 | click.echo(
134 | f"Adjusting hidden_size to {config.n_embd}, num_layers to {config.n_layer}"
135 | )
136 | if config.n_embd == 1536 or config.n_embd == 3072:
137 | config.n_head = 16
138 | else:
139 | config.n_head = 8
140 |
141 | config.n_positions = 1024
142 | if "alibi" in checkpoint_file:
143 | config.alibi_embeddings = True
144 | if "rotary" in checkpoint_file:
145 | config.rotary_embeddings = True
146 | # Initialize model.
147 | model = GPT2LMHeadModel(config)
148 | if not skip_loading_weights:
149 | model.load_state_dict(state_dict, strict=True)
150 | elif skip_loading_weights:
151 | config = GPT2Config.from_pretrained(model_name)
152 | model = GPT2LMHeadModel(config)
153 | else:
154 | model = GPT2LMHeadModel.from_pretrained(model_name)
155 |
156 | model = model.cuda()
157 |
158 | click.secho("\n[2/3] Preprocessing data...", fg="green")
159 |
160 | dataloader = get_dataloader(
161 | dataset,
162 | tokenizer,
163 | block_size=block_size,
164 | batch_size=batch_size,
165 | dataset_path=dataset_path,
166 | )
167 |
168 | click.secho("\n[3/3] Evaluating model on data...", fg="green")
169 |
170 | model.eval()
171 |
172 | running_loss = 0.0
173 | total_batches = (
174 | len(dataloader) if max_steps is None else min([max_steps, len(dataloader)])
175 | )
176 | with tqdm(
177 | islice(dataloader, total_batches), desc="Evaluating", total=total_batches
178 | ) as batch_iterator:
179 | for i, batch in enumerate(batch_iterator):
180 | batch = {k: v.cuda() for k, v in batch.items()}
181 | # with torch.inference_mode():
182 | with torch.no_grad():
183 | outputs = model(**batch)
184 |
185 | loss = outputs.loss
186 | running_loss += loss.item()
187 |
188 | if i % 50 == 0 or i == total_batches - 1:
189 | mean_loss = running_loss / (i + 1)
190 | ppl = math.exp(mean_loss)
191 | batch_iterator.set_postfix(loss=mean_loss, ppl=ppl)
192 |
193 | mean_loss = running_loss / total_batches
194 | ppl = math.exp(mean_loss)
195 |
196 | click.secho(
197 | f"\nDone! Final loss: {mean_loss:.4f} (ppl = {ppl:.4f})", fg="green", bold=True
198 | )
199 |
200 |
201 | @main.command()
202 | @click.argument(
203 | "train-dataset-path",
204 | type=click.Path(exists=True, dir_okay=False, resolve_path=True),
205 | )
206 | @click.argument(
207 | "validation-dataset-path",
208 | type=click.Path(exists=True, dir_okay=False, resolve_path=True),
209 | )
210 | @click.argument(
211 | "log-dir",
212 | type=click.Path(exists=False, dir_okay=True, resolve_path=True),
213 | )
214 | @click.option(
215 | "--block-size",
216 | default=1024,
217 | show_default=True,
218 | help="""Input texts are blocked together into blocks of this size.
219 | This should probably match the max input size of the model.""",
220 | )
221 | @click.option(
222 | "--batch-size",
223 | default=32,
224 | show_default=True,
225 | help="The batch size to use for training and validation.",
226 | )
227 | @click.option(
228 | "--grad-accum",
229 | default=1,
230 | show_default=True,
231 | help="The number of gradient accumulation steps per update.",
232 | )
233 | @click.option(
234 | "--num-heads",
235 | default=4,
236 | show_default=True,
237 | help="The number of attention heads.",
238 | )
239 | @click.option(
240 | "--num-layers",
241 | default=4,
242 | show_default=True,
243 | help="The number of transformer layers.",
244 | )
245 | @click.option(
246 | "--hidden-size",
247 | default=256,
248 | show_default=True,
249 | help="The hidden size of the model.",
250 | )
251 | @click.option(
252 | "--lr",
253 | default=None,
254 | type=click.FLOAT,
255 | show_default=True,
256 | help="The learning rate. Defaults to '0.003239 - 0.0001395 log(N)'.",
257 | )
258 | @click.option(
259 | "--adam-epsilon",
260 | default=1e-6,
261 | show_default=True,
262 | )
263 | @click.option(
264 | "--adam-beta1",
265 | default=0.9,
266 | show_default=True,
267 | )
268 | @click.option(
269 | "--adam-beta2",
270 | default=0.95,
271 | show_default=True,
272 | )
273 | @click.option(
274 | "--warmup-steps",
275 | default=3000,
276 | show_default=True,
277 | )
278 | @click.option(
279 | "--train-steps",
280 | default=100000,
281 | show_default=True,
282 | )
283 | @click.option(
284 | "--validation-steps",
285 | default=50,
286 | show_default=True,
287 | )
288 | @click.option(
289 | "--validate-every",
290 | default=100,
291 | show_default=True,
292 | )
293 | @click.option(
294 | "--checkpoint-every",
295 | default=100,
296 | show_default=True,
297 | )
298 | @click.option(
299 | "--wandb-entity",
300 | default="allenai-team1",
301 | show_default=True,
302 | )
303 | @click.option(
304 | "--wandb-project",
305 | default="staged-training",
306 | show_default=True,
307 | )
308 | @click.option(
309 | "--amp", is_flag=True, help="""Train with automatic mixed-precision enabled."""
310 | )
311 | @click.option(
312 | "--recover", is_flag=True, help="""Restart training from a previous run."""
313 | )
314 | @click.option(
315 | "--recover-from",
316 | type=click.Path(exists=True, dir_okay=True, resolve_path=True),
317 | help="""Log directory to recover from if different.""",
318 | )
319 | @click.option(
320 | "--init-seed",
321 | default=42,
322 | show_default=True,
323 | )
324 | @click.option(
325 | "--data-seed",
326 | default=42,
327 | show_default=True,
328 | )
329 | @click.option(
330 | "--wandb-tags", help="""A comma-separated list of tags to assign to the W&B run."""
331 | )
332 | def train(
333 | train_dataset_path: str,
334 | validation_dataset_path: str,
335 | log_dir: str,
336 | block_size: int,
337 | batch_size: int,
338 | grad_accum: int,
339 | num_heads: int,
340 | num_layers: int,
341 | hidden_size: int,
342 | lr: float,
343 | adam_epsilon: float,
344 | adam_beta1: float,
345 | adam_beta2: float,
346 | warmup_steps: int,
347 | train_steps: int,
348 | validation_steps: int,
349 | validate_every: int,
350 | checkpoint_every: int,
351 | wandb_entity: str,
352 | wandb_project: str,
353 | amp: bool,
354 | recover: bool,
355 | recover_from: Optional[str],
356 | init_seed: int,
357 | data_seed: int,
358 | wandb_tags: Optional[str],
359 | ):
360 | """
361 | Train a GPT-2 model on C4.
362 | """
363 | accelerator = Accelerator(fp16=amp)
364 | device = accelerator.device
365 | is_distributed = accelerator.num_processes > 1
366 |
367 | state_path = os.path.join(
368 | log_dir, f"state_worker_{accelerator.local_process_index}.pt"
369 | )
370 |
371 | # Check log_dir.
372 | initial_state: Optional[Dict[str, Any]] = None
373 | if recover:
374 | if recover_from is not None:
375 | # Copy over contents to log_dir
376 | assert os.path.isdir(recover_from)
377 | assert os.path.isfile(
378 | os.path.join(
379 | recover_from, f"state_worker_{accelerator.local_process_index}.pt"
380 | )
381 | )
382 | if accelerator.is_local_main_process:
383 | assert not os.path.exists(log_dir) or not os.listdir(log_dir)
384 | shutil.copytree(
385 | recover_from,
386 | log_dir,
387 | # dirs_exist_ok=True, only available for python >= 3.8
388 | )
389 | accelerator.wait_for_everyone()
390 | assert os.path.isdir(log_dir)
391 | assert os.path.isfile(state_path)
392 | click.echo(
393 | f"[Worker {accelerator.local_process_index}] Loading training state from {state_path}"
394 | )
395 | initial_state = torch.load(state_path)
396 | else:
397 | assert not os.path.exists(log_dir) or not os.listdir(log_dir)
398 | if accelerator.is_local_main_process:
399 | os.makedirs(log_dir, exist_ok=True)
400 |
401 | if accelerator.is_local_main_process:
402 | click.echo(f"Training on {accelerator.num_processes} devices")
403 | click.secho(
404 | "\n[1/3] Initializing tokenizer, model, and optimizer...", fg="green"
405 | )
406 |
407 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
408 |
409 | config = GPT2Config.from_pretrained("gpt2")
410 | config.n_head = num_heads
411 | config.n_layer = num_layers
412 | config.n_embd = hidden_size
413 |
414 | # Set random seeds for model initialization.
415 | set_seeds(init_seed)
416 |
417 | model = GPT2LMHeadModel(config)
418 |
419 | total_params = 0
420 | for name, param in model.named_parameters():
421 | # Ignore embedding matrix when calculating size.
422 | if name == "transformer.wte.weight" or name == "transformer.wpe.weight":
423 | continue
424 | total_params += param.numel()
425 | if accelerator.is_local_main_process:
426 | click.echo(f"Total non-embedding parameters: {total_params:,}")
427 |
428 | if lr is None:
429 | lr = 0.003239 - 0.0001395 * np.log(total_params)
430 |
431 | optimizer = AdamW(
432 | model.parameters(),
433 | lr=lr,
434 | eps=adam_epsilon,
435 | betas=(adam_beta1, adam_beta2),
436 | correct_bias=False,
437 | )
438 | scheduler = get_linear_schedule_with_warmup(
439 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=train_steps
440 | )
441 |
442 | if accelerator.is_local_main_process:
443 | click.secho("\n[2/3] Loading data...", fg="green")
444 |
445 | # Set random seeds for data shuffling.
446 | set_seeds(data_seed)
447 |
448 | train_dataloader = get_dataloader(
449 | "mmap",
450 | tokenizer,
451 | dataset_path=train_dataset_path,
452 | block_size=block_size,
453 | batch_size=batch_size,
454 | shuffle=True,
455 | is_distributed=is_distributed,
456 | seed=data_seed,
457 | )
458 |
459 | validation_dataloader = get_dataloader(
460 | "mmap",
461 | tokenizer,
462 | dataset_path=validation_dataset_path,
463 | block_size=block_size,
464 | batch_size=batch_size,
465 | is_distributed=is_distributed,
466 | seed=data_seed,
467 | )
468 |
469 | # NOTE: We don't call `prepare()` on the dataloaders because that causes a memory leak,
470 | # and it's not necessary anyway.
471 | model, optimizer = accelerator.prepare(model, optimizer)
472 |
473 | validation_steps = min([len(validation_dataloader), validation_steps])
474 |
475 | # Load state.
476 | if initial_state is not None:
477 | optimizer.load_state_dict(initial_state["optimizer"])
478 | scheduler.load_state_dict(initial_state["scheduler"])
479 | model.load_state_dict(initial_state["model"])
480 |
481 | wandb_run_id: Optional[str] = None
482 | if accelerator.is_main_process:
483 | import wandb
484 |
485 | if initial_state is not None:
486 | wandb_run_id = initial_state["wandb_run_id"]
487 | else:
488 | wandb_run_id = wandb.util.generate_id()
489 |
490 | wandb.init(
491 | id=wandb_run_id,
492 | dir=log_dir,
493 | entity=wandb_entity,
494 | resume="auto",
495 | project=wandb_project,
496 | tags=None if not wandb_tags else wandb_tags.split(","),
497 | config={
498 | "init_seed": init_seed,
499 | "data_seed": data_seed,
500 | "total_params": total_params,
501 | "learning_rate": lr,
502 | "adam_beta1": adam_beta1,
503 | "adam_beta2": adam_beta2,
504 | "adam_epsilon": adam_epsilon,
505 | "batch_size": batch_size,
506 | "grad_accum": grad_accum,
507 | "num_processes": accelerator.num_processes,
508 | "effective_batch_size": batch_size
509 | * grad_accum
510 | * accelerator.num_processes,
511 | },
512 | )
513 |
514 | accelerator.wait_for_everyone()
515 |
516 | if accelerator.is_local_main_process:
517 | click.secho("\n[3/3] Training...", fg="green")
518 |
519 | model.train()
520 | val_loss: Optional[float] = None
521 | training_batches = enumerate(
522 | islice(
523 | chunked(cycle_through_epochs(train_dataloader, is_distributed), grad_accum),
524 | train_steps,
525 | )
526 | )
527 |
528 | # Catch data loader up to where we left off before.
529 | if initial_state is not None:
530 | click.echo(
531 | f"[Worker {accelerator.local_process_index}] "
532 | f"Catching data loader up to step {initial_state['training_steps']}..."
533 | )
534 | training_steps = initial_state["training_steps"]
535 | for step, batch in training_batches:
536 | del batch
537 | if step >= training_steps - 1:
538 | break
539 | accelerator.wait_for_everyone()
540 |
541 | with tqdm(
542 | training_batches,
543 | desc="Training",
544 | initial=0 if initial_state is None else initial_state["training_steps"],
545 | total=train_steps,
546 | disable=not accelerator.is_local_main_process,
547 | ) as train_batch_iterator:
548 | for step, batch in train_batch_iterator:
549 |
550 | def save_state():
551 | temp_state_file = tempfile.NamedTemporaryFile(
552 | "w+b", dir=log_dir, delete=False, suffix="pt"
553 | )
554 | try:
555 | torch.save(
556 | {
557 | "optimizer": optimizer.state_dict(),
558 | "scheduler": scheduler.state_dict(),
559 | "model": model.state_dict(),
560 | "wandb_run_id": wandb_run_id,
561 | "training_steps": step + 1,
562 | },
563 | temp_state_file.name,
564 | )
565 | temp_state_file.close()
566 | os.replace(temp_state_file.name, state_path)
567 | finally:
568 | if os.path.exists(temp_state_file.name):
569 | os.remove(temp_state_file.name)
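# NOTE: writing to a NamedTemporaryFile in the same directory and then os.replace()-ing it
# over `state_path` makes the checkpoint update effectively atomic (on a POSIX filesystem),
# so an interrupted save cannot leave a truncated state file behind.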
570 |
571 | optimizer.zero_grad()
572 | batch_loss = 0.0
573 | batch_ppl: Optional[float] = None
574 | for micro_batch in batch:
575 | # Move tensors to right device.
576 | micro_batch = {k: v.to(device) for k, v in micro_batch.items()}
577 |
578 | # Get loss.
579 | outputs = model(**micro_batch)
580 | micro_batch_loss = outputs.loss / len(batch)
581 | batch_loss += micro_batch_loss.detach().item()
582 |
583 | # Calculate gradients.
584 | accelerator.backward(micro_batch_loss)
585 |
586 | # Clean up.
587 | del micro_batch
588 | del outputs
589 | del micro_batch_loss
590 |
591 | del batch
592 |
593 | # Take step.
594 | optimizer.step()
595 | scheduler.step()
596 |
597 | should_log_this_step = step % 10 == 0 or step == train_steps - 1
598 | should_checkpoint_this_step = step > 0 and step % checkpoint_every == 0
599 | should_validate_this_step = (
600 | step > 0 and step % validate_every == 0
601 | ) or step == train_steps - 1
602 |
603 | # Gather average loss across all workers.
604 | if should_log_this_step or should_validate_this_step:
605 | batch_loss = (
606 | accelerator.gather(
607 | torch.tensor(batch_loss, device=device).unsqueeze(0)
608 | )
609 | .mean()
610 | .item()
611 | )
612 | batch_ppl = math.exp(batch_loss) # type: ignore[arg-type]
613 |
614 | # Update progress bar and log to W&B.
615 | if accelerator.is_local_main_process and should_log_this_step:
616 | if val_loss is not None:
617 | train_batch_iterator.set_postfix(
618 | batch_loss=batch_loss,
619 | batch_ppl=batch_ppl,
620 | val_loss=val_loss,
621 | val_ppl=math.exp(val_loss),
622 | )
623 | else:
624 | train_batch_iterator.set_postfix(
625 | batch_loss=batch_loss, batch_ppl=batch_ppl
626 | )
627 |
628 | if accelerator.is_main_process:
629 | wandb.log(
630 | {
631 | "batch_loss": batch_loss,
632 | "batch_ppl": batch_ppl,
633 | "lr": optimizer.param_groups[0]["lr"],
634 | },
635 | step=step,
636 | )
637 |
638 | # Checkpoint.
639 | if should_checkpoint_this_step:
640 | save_state()
641 |
642 | # Validate.
643 | if should_validate_this_step:
644 | # Prepare model for validation.
645 | model.eval()
646 | optimizer.zero_grad() # Not strictly necessary.
647 |
648 | running_loss = 0.0
649 | with tqdm(
650 | islice(validation_dataloader, validation_steps),
651 | desc="Validating",
652 | total=validation_steps,
653 | leave=False,
654 | disable=not accelerator.is_local_main_process,
655 | ) as val_batch_iterator:
656 | for val_step, val_batch in enumerate(val_batch_iterator):
657 | # Move tensors to right device.
658 | val_batch = {k: v.to(device) for k, v in val_batch.items()}
659 |
660 | # Get loss.
661 | with torch.inference_mode():
662 | outputs = model(**val_batch)
663 | loss = outputs.loss
664 |
665 | running_loss += loss.item()
666 | val_loss = running_loss / (val_step + 1)
667 |
668 | # Update progress bar.
669 | if accelerator.is_local_main_process and val_step % 10 == 0:
670 | val_batch_iterator.set_postfix(
671 | loss=val_loss, ppl=math.exp(val_loss)
672 | )
673 |
674 | # Clean up.
675 | del val_batch
676 | del outputs
677 | del loss
678 |
679 | # Average loss across all workers.
680 | val_loss = (
681 | accelerator.gather(
682 | torch.tensor(val_loss, device=device).unsqueeze(0)
683 | )
684 | .mean()
685 | .item()
686 | )
687 |
688 | # Reset model to train mode.
689 | model.train()
690 |
691 | # Update progress bar again with validation stats and log to W&B.
692 | val_ppl = math.exp(val_loss) # type: ignore[arg-type]
693 | if accelerator.is_local_main_process:
694 | train_batch_iterator.set_postfix(
695 | batch_loss=batch_loss,
696 | batch_ppl=batch_ppl,
697 | val_loss=val_loss,
698 | val_ppl=val_ppl,
699 | )
700 | if accelerator.is_main_process:
701 | wandb.log({"val_loss": val_loss, "val_ppl": val_ppl}, step=step)
702 |
703 | if accelerator.is_main_process:
704 | wandb.finish()
705 |
706 | click.secho("\nDone!", fg="green", bold=True)
707 |
708 |
709 | def get_dataloader(
710 | dataset: str,
711 | tokenizer: GPT2Tokenizer,
712 | *,
713 | block_size: int = 1024,
714 | batch_size: int = 32,
715 | dataset_path: Optional[str] = None,
716 | shuffle: bool = False,
717 | is_distributed: bool = False,
718 | seed: int = 0,
719 | ) -> DataLoader:
720 | dataset_object: Dataset
721 | collator = default_data_collator
722 | if dataset == "wikitext2":
723 | dataset_object = get_wikitext_dataset(
724 | tokenizer,
725 | split="test",
726 | block_size=block_size,
727 | num_workers=1,
728 | )
729 | elif dataset == "openwebtext":
730 | dataset_object = get_openwebtext_dataset(
731 | tokenizer,
732 | block_size=block_size,
733 | num_workers=8,
734 | )
735 | elif dataset == "mmap":
736 | assert dataset_path is not None
737 | dataset_object = get_mmap_dataset(
738 | tokenizer, dataset_path, chunk_size=block_size
739 | )
740 | collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
741 | else:
742 | raise ValueError(f"Unexpected dataset '{dataset}'")
743 |
744 | sampler: Optional[Sampler] = (
745 | DistributedSampler(dataset_object, shuffle=shuffle, seed=seed)
746 | if is_distributed
747 | else None
748 | )
749 |
750 | dataloader: DataLoader = DataLoader(
751 | dataset_object,
752 | collate_fn=collator,
753 | batch_size=batch_size,
754 | shuffle=shuffle if sampler is None else False,
755 | sampler=sampler,
756 | )
757 |
758 | return dataloader
759 |
760 |
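# cycle_through_epochs yields batches indefinitely; in the distributed case it calls
# DistributedSampler.set_epoch() before each pass so every epoch uses a different shuffle
# order that stays consistent across workers.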
761 | def cycle_through_epochs(dataloader: DataLoader, is_distributed: bool):
762 | epoch = 0
763 | while True:
764 | if is_distributed and isinstance(dataloader.sampler, DistributedSampler):
765 | dataloader.sampler.set_epoch(epoch)
766 | for batch in dataloader:
767 | yield batch
768 | epoch += 1
769 |
770 |
771 | def set_seeds(seed):
772 | random.seed(seed)
773 | np.random.seed(seed)
774 | torch.manual_seed(seed)
775 | if torch.cuda.is_available():
776 | torch.cuda.manual_seed_all(seed)
777 |
778 |
779 | if __name__ == "__main__":
780 | main()
781 |
--------------------------------------------------------------------------------
/gpt_pretrain.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import re
5 | import random
6 | import logging
7 | import numpy as np
8 | import math
9 | from tqdm import tqdm
10 | import time
11 | import torch
12 | from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForCausalLM
13 | from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel
14 | from transformers import AutoConfig
15 | from transformers import DataCollatorForLanguageModeling
16 | from transformers.optimization import AdamW, get_linear_schedule_with_warmup
17 |
18 | from torch.utils.data import Dataset, DataLoader
19 | import pytorch_lightning as ptl
20 | from pytorch_lightning.trainer.training_loop import TrainLoop
21 | from pytorch_lightning.loggers.test_tube import TestTubeLogger
22 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
23 | from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
24 | from pytorch_lightning.utilities import AMPType, rank_zero_warn
25 | from pytorch_lightning.utilities.exceptions import MisconfigurationException
26 | from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
27 | from torch.optim.lr_scheduler import LambdaLR
28 | import copy
29 |
30 | import multiprocessing
31 | try:
32 | from apex import amp
33 | except ImportError:
34 | amp = None
35 |
36 | import tensorflow as tf
37 | logging.basicConfig(level=logging.INFO)
38 | logger = logging.getLogger(__name__)
39 |
40 | try:
41 | import torch_xla.core.xla_model as xm
42 | except ImportError:
43 | XLA_AVAILABLE = False
44 | else:
45 | XLA_AVAILABLE = True
46 |
47 | # ======= linear schedule with warmup, with support for restarting the warmup after resuming ==========
48 | def get_restart_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1, restart_warmup_steps=0, restart_steps=0):
49 | """
50 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
51 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
52 |
53 | Args:
54 | optimizer (:class:`~torch.optim.Optimizer`):
55 | The optimizer for which to schedule the learning rate.
56 | num_warmup_steps (:obj:`int`):
57 | The number of steps for the warmup phase.
58 | num_training_steps (:obj:`int`):
59 | The total number of training steps.
60 | last_epoch (:obj:`int`, `optional`, defaults to -1):
61 | The index of the last epoch when resuming training, will be modified.
62 | restart_steps / restart_warmup_steps (:obj:`int`, `optional`, defaults to 0):
63 | After resuming at step ``restart_steps``, the learning rate is warmed up again over ``restart_warmup_steps`` steps.
64 |
65 | Return:
66 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
67 | """
68 |
69 | def lr_lambda(current_step: int):
70 |
71 | if restart_steps != 0 and restart_warmup_steps != 0 \
72 | and current_step < restart_steps + restart_warmup_steps \
73 | and current_step >= restart_steps:
74 | assert current_step >= restart_steps
75 |
76 | # pre-warmup + restart-warmup
77 | if current_step < num_warmup_steps:
78 | return float(current_step - restart_steps) / float(max(1, restart_warmup_steps)) * float(restart_steps+restart_warmup_steps) / float(max(1, num_warmup_steps))
79 | else:
80 | return float(current_step - restart_steps) / float(max(1, restart_warmup_steps)) * float(num_training_steps - restart_steps-restart_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
81 |
82 | if current_step < num_warmup_steps:
83 | return float(current_step) / float(max(1, num_warmup_steps))
84 | return max(
85 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
86 | )
87 |
88 | return LambdaLR(optimizer, lr_lambda, last_epoch)
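# Illustrative usage of the restart schedule (a sketch; the step counts below are hypothetical):
#
#   scheduler = get_restart_linear_schedule_with_warmup(
#       optimizer, num_warmup_steps=3000, num_training_steps=250_000,
#       restart_steps=100_000, restart_warmup_steps=500)
#
# With these values the LR warms up linearly over the first 3000 steps, decays linearly toward 0,
# and on a restart at step 100_000 it re-warms over 500 steps before approximately rejoining the
# original decay line.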
89 |
90 | # ================= width growth operators (weight doubling) =================
91 | def double_split_matrix_weight(x, is_grad, is_avg_sq):
92 | embed_dim = x.shape[0]
93 | split_dim = x.shape[1] // x.shape[0]
94 | y_shape = [2 * i for i in x.shape]
95 | y = x.new_zeros(*y_shape)
96 |
97 | for split_idx in range(split_dim):
98 | start_idx, end_idx = split_idx * x.shape[0], (1+split_idx) * x.shape[0]
99 | y_split = double_matrix_weight( x[:,start_idx:end_idx], is_grad, is_avg_sq )
100 | y[:,start_idx*2:end_idx*2] = y_split.detach().clone()
101 |
102 | return y
103 |
104 | def double_split_bias(x, embed_dim, is_grad, is_avg_sq):
105 | split_dim = x.shape[0] // embed_dim
106 | y_shape = [2 * i for i in x.shape]
107 | y = x.new_zeros(*y_shape)
108 |
109 | for split_idx in range(split_dim):
110 | start_idx, end_idx = split_idx * embed_dim, (1+split_idx) * embed_dim
111 | y[start_idx*2:start_idx*2+embed_dim] = x[start_idx:end_idx].detach().clone()
112 | y[start_idx*2+embed_dim:end_idx*2] = x[start_idx:end_idx].detach().clone()
113 |
114 | if is_grad:
115 | y /= 2.0 if not is_avg_sq else 4.0
116 |
117 | return y
118 |
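# Width growth for dense layers (cf. Net2Net-style function-preserving growth): the doubled
# matrix is block-diagonal, y = [[x, 0], [0, x]], so feeding it the duplicated activation [h; h]
# reproduces the original outputs [x @ h; x @ h]. Gradients (and Adam's exp_avg) are halved and
# exp_avg_sq is quartered, presumably so the grown model's optimizer state stays consistent
# with the duplication.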
119 | def double_matrix_weight(x, is_grad, is_avg_sq):
120 | # x = (n, m), returns y = (2 * n, 2 * m), used for FF layers
121 | y_shape = [2 * i for i in x.shape]
122 | y = x.new_zeros(*y_shape)
123 |
124 | x_shape = x.shape
125 | y[:x_shape[0], :x_shape[1]] = x.detach().clone()
126 | y[-x_shape[0]:, -x_shape[1]:] = x.detach().clone()
127 | if is_grad:
128 | y /= 2.0 if not is_avg_sq else 4.0
129 |
130 |
131 | return y
132 |
133 |
134 | def double_bias(x, is_grad, is_avg_sq):
135 | # x = (n, ), returns y = (2 * n, ), used for bias weights
136 | y_shape = [2 * i for i in x.shape]
137 | y = x.new_zeros(*y_shape)
138 |
139 | x_shape = x.shape
140 |
141 | y[:x_shape[0]] = x.detach().clone()
142 | y[-x_shape[0]:] = x.detach().clone()
143 | if is_grad:
144 | y /= 2.0 if not is_avg_sq else 4.0
145 |
146 | return y
147 |
148 |
149 | def double_embedding(x, is_grad, is_avg_sq):
150 | # x = (vocab size, M), returns y = (vocab size, 2 * M), used for embedding layers
151 | y_shape = [i for i in x.shape]
152 | y_shape[1] *= 2
153 | y = x.new_zeros(*y_shape)
154 |
155 | x_shape = x.shape
156 | y[:, :x_shape[1]] = x.detach().clone()
157 | y[:, -x_shape[1]:] = x.detach().clone()
158 | if is_grad:
159 | y /= 2.0 if not is_avg_sq else 4.0
160 | # if args.noise_std is not None and args.noise_std != 0.0:
161 | # y += torch.normal(mean=0, std=args.noise_std, size=y.shape)
162 | return y
163 |
164 |
165 | def double_param(key, weight, is_double_embedding, is_grad, is_avg_sq):
166 | if 'lm_head' in key: # for roberta
167 | # the lm_head is a linear layer then the softmax layer
168 | if 'dense' in key:
169 | # this is the linear layer - need to expand as other linear layers
170 | if 'weight' in key:
171 | return double_matrix_weight(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
172 | elif 'bias' in key:
173 | return double_bias(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
174 |
175 | elif 'layer_norm' in key or 'ln' in key:
176 | # layer norm is weight * (x - mean)/std + bias, so need to divide both weight and bias by 2.0
177 | new_weight = double_bias(weight, is_grad=False, is_avg_sq=False)
178 | if not is_grad:
179 | new_weight /= 2.0
180 | return new_weight
181 | elif 'weight' in key:
182 | return double_embedding(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
183 | elif 'bias' in key:
184 | # this is the bias parameter added for the final softmax logit, shape (vocab_size, )
185 | return weight
186 |
187 | elif 'pooler' in key:
188 | # don't think this pooler is used without next-sentence-prediction in bert, but we'll double it anyway
189 | if 'weight' in key:
190 | return double_matrix_weight(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
191 | elif 'bias' in key:
192 | return double_bias(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
193 |
194 | elif 'cls' in key:
195 | # the masked LM head.
196 | # in BERT it is top layer activations -> dense with same hidden dim -> activation -> layer norm -> decoder
197 | # where the decoder is linear that has the same weights as the word embeddings and a new bias layer
198 | # to maintain the same loss, we want the logit outputs to remain the same - so can do it by duplicating
199 | # the word embeddings / bias, and dividing the input by 2. We accomplish the two division by modifying the
200 | # layer norm parameters right before prediction.
201 | #
202 | # cls.predictions.bias torch.Size([30522])
203 | # cls.predictions.transform.dense.weight torch.Size([768, 768])
204 | # cls.predictions.transform.dense.bias torch.Size([768])
205 | # cls.predictions.transform.LayerNorm.weight torch.Size([768])
206 | # cls.predictions.transform.LayerNorm.bias torch.Size([768])
207 | # cls.predictions.decoder.weight torch.Size([30522, 768])
208 | # cls.predictions.decoder.bias torch.Size([30522])
209 | if key.endswith('cls.predictions.bias') or key.endswith('cls.predictions.decoder.bias'):
210 | # these are size(vocab) and remain unchanged
211 | return weight
212 | elif key.endswith('cls.predictions.transform.dense.bias'):
213 | return double_bias(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
214 | elif key.endswith('cls.predictions.transform.dense.weight'):
215 | return double_matrix_weight(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
216 | elif key.endswith('cls.predictions.decoder.weight'):
217 | return double_embedding(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
218 | elif 'LayerNorm' in key:
219 | # layer norm is weight * (x - mean)/std + bias, so need to divide both weight and bias by 2.0
220 | new_weight = double_bias(weight, is_grad=False, is_avg_sq=False)
221 | if not is_grad:
222 | new_weight /= 2.0
223 | return new_weight
224 |
225 | elif 'word_embeddings' in key or 'position_embeddings' in key or 'token_type_embeddings' in key \
226 | or "wte.weight" in key or "wpe.weight" in key:
227 | if is_double_embedding:
228 | return double_embedding(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
229 | else:
230 | return weight.detach().clone()
231 | elif "masked_bias" in key or ("attn.bias" in key and len(weight.shape) != 1):
232 | return weight.detach().clone()
233 | elif "c_attn.weight" in key:
234 | return double_split_matrix_weight(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
235 | elif "c_attn.bias" in key:
236 | # TODO: this is hacked for GPT2
237 | return double_split_bias(weight, embed_dim=weight.shape[0] // 3, is_grad=is_grad, is_avg_sq=is_avg_sq)
238 | elif 'query.weight' in key or 'key.weight' in key or 'value.weight' in key or 'dense.weight' in key \
239 | or "c_proj.weight" in key or "c_fc.weight" in key:
240 | return double_matrix_weight(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
241 | elif "ln_f" in key:
242 | new_weight = double_bias(weight, is_grad=False, is_avg_sq=False)
243 | if not is_grad:
244 | new_weight /= 2.0
245 | return new_weight
246 | elif 'LayerNorm' in key or 'bias' in key or 'ln' in key:
247 | return double_bias(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
248 | elif 'position_ids' in key:
249 | return weight
250 |
251 | # Not found
252 | print(key)
253 | import ipdb; ipdb.set_trace()
254 | # raise ValueError(key, weight.shape)
255 |
256 |
257 | def double_state_dict(old_state_dict, is_double_embedding):
258 | new_state_dict = {}
259 | for key, weight in old_state_dict.items():
260 | new_state_dict[key] = double_param(key, weight, is_double_embedding=is_double_embedding, is_grad=False, is_avg_sq=False)
261 | return new_state_dict
262 |
263 | # depth growth operator
264 | def deep_split_matrix_weight(x, is_identical, is_grad, is_avg_sq):
265 | if not is_identical:
266 | return x.detach().clone()
267 |
268 | embed_dim = x.shape[0]
269 | split_dim = x.shape[1] // x.shape[0]
270 | y_shape = [i for i in x.shape]
271 | y = x.new_zeros(*y_shape)
272 |
273 | for split_idx in range(split_dim):
274 | start_idx, end_idx = split_idx * x.shape[0], (1+split_idx) * x.shape[0]
275 | y_split = deep_matrix_weight( x[:,start_idx:end_idx], is_identical, is_grad, is_avg_sq )
276 | y[:,start_idx:end_idx] = y_split.detach().clone()
277 |
278 | return y
279 |
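# Depth growth: when a newly inserted layer should start as a no-op (is_identical=True), its
# square weight matrices are initialized to the identity and its biases to zero; otherwise the
# original weights are copied verbatim. This is intended to make the inserted layers
# approximately function-preserving (residual and layer-norm effects aside).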
280 | def deep_matrix_weight(x, is_identical, is_grad, is_avg_sq):
281 | # x = (n, m), returns y = (2 * n, 2 * m), used for FF layers
282 | if is_identical:
283 | y = torch.zeros_like(x)
284 | if len(y.shape) > 1:
285 | y.fill_diagonal_(1)
286 | return y
287 | else:
288 | return x.detach().clone()
289 |
290 |
291 | def deep_bias(x, is_identical, is_grad, is_avg_sq):
292 | # x = (n, ), returns y = (2 * n, ), used for bias weights
293 | if is_identical:
294 | return torch.zeros_like(x)
295 | else:
296 | return x.detach().clone()
297 |
298 | def deep_param(key, weight, is_identical, is_grad, is_avg_sq):
299 | if "c_attn.weight" in key:
300 | return deep_split_matrix_weight(weight, is_identical=is_identical, is_grad=is_grad, is_avg_sq=is_avg_sq)
301 | elif 'weight' in key:
302 | return deep_matrix_weight(weight, is_identical=is_identical, is_grad=is_grad, is_avg_sq=is_avg_sq)
303 | elif 'bias' in key:
304 | return deep_bias(weight, is_identical=is_identical, is_grad=is_grad, is_avg_sq=is_avg_sq)
305 |
306 | def deep_state_dict(old_state_dict, map_positions, is_identical):
307 | # how to insert layers: direct copy, identical copy
308 | # operator over the blocks: hacked for GPT-3
309 | new_state_dict = {}
310 | for key, weight in old_state_dict.items():
311 | if map_positions.get( key ):
312 | for (new_key, new_key_copy_flag) in map_positions.get( key ):
313 | # print( new_key_copy_flag, is_identical, new_key, key )
314 | new_state_dict[new_key] = deep_param(key, weight, is_identical=new_key_copy_flag and is_identical, is_grad=False, is_avg_sq=False)
315 | else:
316 | new_state_dict[key] = weight.detach().clone()
317 |
318 | return new_state_dict
319 |
320 |
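# NOTE: the tests below pass extra keyword arguments (reset_model_opt_copy, reset_optimizer_copy,
# reset_model_noise) that the growth functions defined above do not accept; they appear to target
# an extended variant of those functions and will not run against the definitions in this file as-is.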
321 | def test_double_matrix_weight():
322 | weight = torch.rand(52, 88)
323 | x = torch.rand(88, 1)
324 | reset_model_opt_copy=True
325 | weight2 = double_matrix_weight(weight, is_grad=False, is_avg_sq=False, reset_model_opt_copy=reset_model_opt_copy)
326 | y = torch.matmul(weight, x)
327 | y2 = torch.matmul(weight2, torch.cat([x, x], dim=0))
328 | print(torch.allclose(y, y2[:52], atol=1e-05, rtol=1e-03))
329 | assert torch.abs(y - y2[:52]).max() + torch.abs(y - y2[-52:]).max() < 1e-4
330 |
331 | x = torch.rand(1, 11)
332 | c_attn = torch.rand(11, 11*3)
333 | print(len(torch.matmul(x, c_attn).split(11, dim=1)))
334 | y0, y1, y2 = torch.matmul(x, c_attn).split(11, dim=1)
335 |
336 | c_attn2 = double_split_matrix_weight(c_attn, is_grad=False, is_avg_sq=False, reset_model_opt_copy=reset_model_opt_copy)
337 |
338 | y00, y11, y22 = torch.matmul(torch.cat([x, x], dim=1), c_attn2).split(11*2, dim=1)
339 |
340 | all_close = torch.allclose(y0, y00[:, :11], atol=1e-05, rtol=1e-03)
341 | print("reset_model_opt_copy", reset_model_opt_copy, all_close, y0.sum(), y00[:, :11].sum(), y00[:, 11:].sum())
342 | if not all_close:
343 | import ipdb; ipdb.set_trace()
344 | # import ipdb; ipdb.set_trace()
345 |
346 |
347 | def test_double_gradients():
348 | test_double_matrix_weight()
349 | # config = AutoConfig.from_pretrained('roberta-base')
350 | # config.hidden_size = 4
351 | # config.intermediate_size = 16
352 | # config.max_position_embeddings = 8
353 | # config.num_attention_heads = 1
354 | # config.num_hidden_layers = 1
355 | # config.vocab_size = 6
356 |
357 | # config.attention_probs_dropout_prob = 0
358 | # config.hidden_dropout_prob = 0
359 | # model = RobertaForMaskedLM(config=config)
360 |
361 | from gpt_pretrain import double_weights, double_param
362 | from transformers import AutoConfig, RobertaForMaskedLM, AutoModelForMaskedLM
363 | from transformers.optimization import AdamW, get_linear_schedule_with_warmup
364 | import torch
365 | # model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
366 | model = AutoModelForCausalLM.from_pretrained('gpt2')
367 |
368 | optimizer = AdamW(model.parameters(), lr=0.00000, betas=(0.0, 0.0))
369 | model.eval()
370 | input_ids = torch.tensor([[1, 2, 3, 4]])
371 | labels = torch.tensor([[1, 2, 3, 4]])
372 | loss = model(input_ids=input_ids, labels=labels)[0]
373 | loss.backward()
374 | optimizer.step()
375 | # model.roberta.embeddings.word_embeddings.weight.grad
376 | reset_model_opt_copy=True
377 | reset_model_noise=False
378 | double_model = double_weights(model, is_double_embedding=True, reset_model_opt_copy=reset_model_opt_copy,reset_model_noise=reset_model_noise)
379 | double_optimizer = AdamW(double_model.parameters(), lr=0.00000, betas=(0.0, 0.0))
380 | double_model.eval()
381 | double_loss = double_model(input_ids=input_ids, labels=labels)[0]
382 | double_loss.backward()
383 | double_optimizer.step()
384 |
385 | print(double_loss.item(), loss.item(), torch.allclose(double_loss, loss, atol=1e-05, rtol=1e-03))
386 | assert torch.allclose(double_loss, loss, atol=1e-05, rtol=1e-03)
387 | # exit()
388 | for (name, parameter), (double_name, double_parameter), (opt_key, opt_val), (double_opt_key, double_opt_val) in \
389 | zip(model.named_parameters(), double_model.named_parameters(),
390 | optimizer.state.items(), double_optimizer.state.items()):
391 | assert name == double_name
392 | assert id(parameter) == id(opt_key)
393 | assert id(double_parameter) == id(double_opt_key)
394 | predicted = double_param(name, parameter.grad, is_double_embedding=True, is_grad=True, is_avg_sq=False, reset_optimizer_copy=True, reset_model_opt_copy=reset_model_opt_copy)
395 | all_close = torch.allclose(predicted, double_parameter.grad, atol=1e-05, rtol=1e-03)
396 |
397 | if not all_close:
398 | print('1', all_close, name, parameter.shape, )
399 | print(predicted)
400 | print(double_parameter.grad)
401 |
402 | predicted = double_param(name, opt_val['exp_avg'], is_double_embedding=True, is_grad=True, is_avg_sq=False, reset_optimizer_copy=True, reset_model_opt_copy=reset_model_opt_copy)
403 | all_close = torch.allclose(predicted, double_opt_val['exp_avg'], atol=1e-05, rtol=1e-03)
404 | if not all_close:
405 | print('2', all_close, name, parameter.shape, )
406 | print(predicted)
407 | print(double_opt_val['exp_avg'],)
408 |
409 | predicted = double_param(name, opt_val['exp_avg_sq'], is_double_embedding=True, is_grad=True, is_avg_sq=True, reset_optimizer_copy=True, reset_model_opt_copy=reset_model_opt_copy)
410 | all_close = torch.allclose(predicted, double_opt_val['exp_avg_sq'], atol=1e-05, rtol=1e-03)
411 | if not all_close:
412 | print('3', all_close, name, parameter.shape, )
413 | print(predicted)
414 | print(double_opt_val['exp_avg_sq'],)
415 | import ipdb; ipdb.set_trace()
416 | else:
417 | print('3', all_close, name, parameter.shape, )
418 |
419 |
420 | import ipdb; ipdb.set_trace()
421 |
422 | def double_weights(model, is_double_embedding):
423 | print(model)
424 | config = model.config
425 |
426 | # create an instance of the model twice the size
427 | new_config_dict = config.to_dict()
428 | new_config_dict['n_embd'] *= 2
429 | new_config_dict['n_inner'] = new_config_dict['n_inner']*2 if new_config_dict['n_inner'] is not None else None
430 | new_config_dict['n_head'] *= 2
431 |
432 | new_config = type(config).from_dict(new_config_dict)
433 | new_model = type(model)(new_config)
434 |
435 | # load the weights from the old model into new model after duplicating them
436 | model.tie_weights()
437 | new_model.tie_weights()
438 |
439 | new_state_dict = double_state_dict(model.state_dict(), is_double_embedding=is_double_embedding)
440 | new_model.load_state_dict(new_state_dict)
441 | new_model.tie_weights()
442 |
443 | return new_model
444 |
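# Example usage (a sketch; a simplified version of the call in test_double_gradients above):
#
#   model = AutoModelForCausalLM.from_pretrained('gpt2')
#   wide_model = double_weights(model, is_double_embedding=True)
#
# wide_model has n_embd, n_head (and n_inner, if set) doubled, with weights duplicated so that
# its initial loss is intended to match the original model's.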
445 |
446 | def double_depth(model):
447 | print(model)
448 | config = model.config
449 |
450 | # create an instance of the model twice the size
451 | new_config_dict = config.to_dict()
452 | print(new_config_dict)
453 | new_config_dict['num_hidden_layers'] *= 2
454 |
455 | new_config = type(config).from_dict(new_config_dict)
456 | new_model = type(model)(new_config)
457 |
458 | # load the weights from the old model into new model after duplicating them
459 | model.tie_weights()
460 | new_model.tie_weights()
461 |
462 | new_state_dict = deep_state_dict(model.state_dict())
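# NOTE: deep_state_dict (defined above) also expects `map_positions` and `is_identical`;
# as written, double_depth would need those supplied before this call runs.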
463 | new_model.load_state_dict(new_state_dict)
464 | new_model.tie_weights()
465 |
466 | return new_model
467 |
468 |
469 | # the dataset object we are using
470 | class MMapTextDataset(Dataset):
471 | def __init__(self, mmap_filename, chunk_size, bos_token_id, eos_token_id):
472 | # `chunk_size - 2` reserves space for the BOS and EOS tokens added in `__getitem__`
473 | self.num_instances = np.memmap(mmap_filename, mode='r', dtype=np.uint16).shape[0] // (chunk_size - 2)
474 | # defer loading the token_ids memmap until after the first __getitem__ call.
475 | # when spawning new processes for ddp, there is a hard limit in python < 3.8 that
476 | # pickle files need to be < 4GB. By waiting until after the first __getitem__ we
477 | # don't have to pickle the memmap
478 | self.token_ids = None
479 | self._mmap_filename = mmap_filename
480 | self._chunk_size = chunk_size
481 | self._bos_token_id = bos_token_id
482 | self._eos_token_id = eos_token_id
483 |
484 | def __len__(self):
485 | return self.num_instances
486 |
487 | def __getitem__(self, i):
488 | if self.token_ids is None:
489 | self.token_ids = np.memmap(self._mmap_filename, mode='r', dtype=np.uint16)
490 | from_index = i * (self._chunk_size - 2)
491 | to_index = (i + 1) * (self._chunk_size - 2)
492 | data = np.concatenate(([self._bos_token_id], self.token_ids[from_index:to_index], [self._eos_token_id]))
493 | return torch.tensor(data, dtype=torch.long)
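# Indexing arithmetic (illustrative): with chunk_size=1024 each instance covers 1022 raw tokens,
# i.e. instance i reads token_ids[i*1022 : (i+1)*1022], and the BOS/EOS added here bring the
# returned tensor back up to exactly chunk_size tokens.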
494 |
495 | # ========================= preprocessing code ========================= #
496 | @staticmethod
497 | def _process_file(full_fname):
498 | "Step 1: tokenize an input text file then save token ids into `np.memmap` shards of size `args.shard_size`"
499 | fname = full_fname.split('/')[-1]
500 | if args.data_type == 'tfrecord':
501 | log_filename = f'{args.output_dir}/logs-{fname}.log'
502 | elif args.data_type == 'raw_text':
503 | log_filename = f'{args.output_dir}/logs-{args.shard_size}/{fname}.log'
504 | if os.path.isfile(log_filename):
505 | logging.info(f'Skipping {full_fname} ...')
506 | return # log file already exists. Skip current file.
507 |
508 | if args.num_workers > 1:
509 | current = multiprocessing.current_process()
510 | process_identity = int(current._identity[0])
511 | else:
512 | process_identity = 1
513 |
514 | if process_identity == 1:
515 | logging.info(f'Processing {full_fname} ...')
516 |
517 | def _write_shard():
518 | if len(token_list) == 0:
519 | return
520 | # if token_list[-1] != MMapTextDataset.tokenizer.sep_token_id: # handle a rare case
521 | # token_list.append(MMapTextDataset.tokenizer.sep_token_id)
522 | if args.data_type in ['tfrecord', 's2']:
523 | shared_filename = f'{args.output_dir}/{fname}.bin'
524 | elif args.data_type == 'raw_text':
525 | shared_filename = f'{args.output_dir}/shards-{args.shard_size}/{fname}-{shard_count}.bin'
526 | else:
527 | raise NotImplementedError
528 | logging.info(f'Writing {len(token_list)} tokens to shard {shared_filename}')
529 | fp = np.memmap(shared_filename, dtype=np.uint16, mode='w+', shape=len(token_list))
530 | fp[:] = token_list[:]
531 | del fp # flush and close file
532 |
533 | token_list = []
534 | shard_count = 0
535 | tokens_count = 0
536 |
537 | if args.data_type == 'raw_text': # the input file is one doc per line
538 | with open(full_fname, 'r') as fin:
539 | for line in tqdm(fin):
540 | line = line.strip()
541 | if line == '': # drop empty lines
542 | continue
543 | tokens = MMapTextDataset.tokenizer.encode(line, add_special_tokens=False) # `__getitem__` adds special tokens
544 | token_list.extend(tokens)
545 | if len(token_list) > args.shard_size:
546 | _write_shard()
547 | tokens_count += len(token_list)
548 | token_list = []
549 | shard_count += 1
550 | else:
551 | token_list.append(MMapTextDataset.tokenizer.sep_token_id)
552 | _write_shard()
553 | tokens_count += len(token_list)
554 | elif args.data_type == 'tfrecord': # the input file is tfrecord format of the c4 dataset
555 | fin = tf.data.TFRecordDataset(full_fname)
556 | for raw_example in tqdm(iter(fin), disable=process_identity != 1):
557 | parsed = tf.train.Example.FromString(raw_example.numpy())
558 | feature_keys = set(parsed.features.feature.keys())
559 | if 'text' in feature_keys:
560 | line = parsed.features.feature['text'].bytes_list.value[0].decode() # raw text
561 | tokens = MMapTextDataset.tokenizer.encode(line, add_special_tokens=False) # `__getitem__` adds special tokens
562 | if args.add_sep_after_doc:
563 | tokens.append(MMapTextDataset.tokenizer.sep_token_id)
564 | token_list.extend(tokens)
565 | tokens_count += len(token_list)
566 | shard_count += 1
567 | _write_shard()
568 |
569 | with open(log_filename, 'w') as f:
570 | f.write(f'Generated {tokens_count} tokens in {shard_count + 1} shards')
571 |
572 | @staticmethod
573 | def _combine_shards(output_fname, shards_list):
574 | "Step 2: combining memmap shards into one `train.bin` or `val.bin` file"
575 | total_size = 0
576 | for filename in shards_list:
577 | total_size += np.memmap(filename, mode='r', dtype=np.uint16).shape[0]
578 | logging.info(f'Writing {total_size} tokens to {output_fname}')
579 | all_token_ids = np.empty(total_size, dtype=np.uint16)
580 | last_token_index = 0
581 | for filename in tqdm(shards_list):
582 | shared = np.memmap(filename, mode='r', dtype=np.uint16)
583 | all_token_ids[last_token_index:last_token_index+len(shared)] = shared[:]
584 | last_token_index += len(shared)
585 | fp = np.memmap(output_fname, dtype=np.uint16, mode='w+', shape=total_size)
586 | fp[:] = all_token_ids[:]
587 | del fp
588 |
589 | @staticmethod
590 | def raw_text_to_mmap(args):
591 | """This is the main preprocessing function. It processes all the text files in `args.input_dir` and
592 | outputs two np.memmap files, one for training and one for validation with ratio `args.train_dev_split`.
593 | Processing each input file involves tokenizing it, sharding it into shards of size `args.shard_size`,
594 | writing each shard as an np.memmap file, shuffling the shards, splitting them into train and dev shards,
595 | then combining the shards of each set into one big file (train.bin and val.bin).
596 | Notice that only the shards are shuffled, not the instances inside each shard. Therefore, it is important
597 | to choose an `args.shard_size` that is small enough to give a good train/dev split, but not so small that
598 | you end up with a huge number of shards that might be difficult to work with.
599 | The stream of tokens in the memmap files represents documents separated with `tokenizer.sep_token`.
600 | In `__getitem__`, the `tokenizer.bos_token` and `tokenizer.eos_token`
601 | are added. The reason for not adding them at preprocessing time is to allow different sequence lengths
602 | later on. Notice that this is the "FULL-SENTENCES" setting in the RoBERTa paper, Table 2.
603 | Example running the preprocessing:
604 | >>> python scripts/pretrain.py --input_dir dirWithTextFiles --train_dev_split 0.05 \
605 | --shard_size 268435456 --num_preprocessing_workers 16
606 | """
607 | MMapTextDataset.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=True)
608 | assert len(MMapTextDataset.tokenizer) < 65535 # will use uint16 to store token ids
609 | all_files = glob.glob(f'{args.input_dir}/c4-*')
610 | print(len(all_files), MMapTextDataset.tokenizer)
611 | if os.path.exists(f'{args.output_dir}/cache/train.bin') and os.path.exists(f'{args.output_dir}/cache/val.bin'):
612 | logger.info("Cache already exists. Remove the cache directory to regenerate")
613 | return
614 | try:
615 | os.mkdir(f'{args.output_dir}/cache/')
616 | except FileExistsError:
617 | pass
618 | try:
619 | os.mkdir(f'{args.output_dir}/shards-{args.shard_size}/')
620 | except FileExistsError:
621 | pass
622 | try:
623 | os.mkdir(f'{args.output_dir}/logs-{args.shard_size}/') # log progress to be able to resume
624 | except FileExistsError:
625 | pass
626 |
627 | # STEP1: tokenizing and saving to shards
628 | if args.num_preprocessing_workers > 1:
629 | from multiprocessing.pool import Pool
630 | with Pool(args.num_preprocessing_workers) as p:
631 | list(tqdm(p.imap(MMapTextDataset._process_file, all_files), total=len(all_files)))
632 | else:
633 | [MMapTextDataset._process_file(f) for f in tqdm(all_files)]
634 |
635 | if args.data_type == 'raw_text': # c4 tfrecords are already sharded
636 | # STEP2: shuffling shards and combining them into train.bin and val.bin files
637 | all_shards = glob.glob(f'{args.output_dir}/shards-{args.shard_size}/*.bin')
638 | random.shuffle(all_shards) # shuffling based on shards not individual lines
639 | val_shards_count = int(args.train_dev_split * len(all_shards))
640 | val_shards = all_shards[:val_shards_count]
641 | train_shards = all_shards[val_shards_count:]
642 | # TODO: if MMapTextDataset._combine_shards is very slow for large files, it can be skipped, but then we need to
643 | # update the dataset to read from multiple shards directly
644 | MMapTextDataset._combine_shards(f'{args.output_dir}/cache/val.bin', val_shards)
645 | MMapTextDataset._combine_shards(f'{args.output_dir}/cache/train.bin', train_shards)
646 | elif args.data_type == 'tfrecord':
647 | train_shards = glob.glob(f'{args.output_dir}/*train*.bin')
648 | val_shards = glob.glob(f'{args.output_dir}/*val*.bin')
649 | MMapTextDataset._combine_shards(f'{args.output_dir}/val.bin', val_shards)
650 | MMapTextDataset._combine_shards(f'{args.output_dir}/train.bin', train_shards)
651 | del MMapTextDataset.tokenizer
652 | # ========================= end preprocessing code ========================= #
653 |
654 |
655 | class MyCheckpointConnector(CheckpointConnector):
656 | def __init__(self, trainer, reset_optimizer=False, reset_lr_scheduler=False, set_global_step=None):
657 | super().__init__(trainer)
658 | self.reset_optimizer = reset_optimizer
659 | self.reset_lr_scheduler = reset_lr_scheduler
660 | self.set_global_step = set_global_step
661 |
662 | def restore_training_state(self, checkpoint, load_optimizer_states: bool = True):
663 | """
664 | COPIED from https://github.com/PyTorchLightning/pytorch-lightning/blob/1.0.8/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L130-L199
665 | and updated to support reset_optimizer and reset_lr_scheduler
666 | """
667 | # validation
668 | if 'optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint:
669 | raise KeyError(
670 | 'Trying to restore training state but checkpoint contains only the model.'
671 | ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
672 | )
673 |
674 | if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
675 | raise ValueError(
676 | "The checkpoint you're attempting to load follows an"
677 | " outdated schema. You can upgrade to the current schema by running"
678 | " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
679 | " where `model.ckpt` is your checkpoint file."
680 | )
681 |
682 | # restore amp scaling
683 | if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
684 | self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
685 | elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
686 | amp.load_state_dict(checkpoint['amp_scaling_state'])
687 |
688 | # restore callback states
689 | self.trainer.on_load_checkpoint(checkpoint)
690 |
691 | self.trainer.global_step = checkpoint['global_step']
692 | if self.set_global_step is not None:
693 | self.trainer.global_step = self.set_global_step
694 | self.trainer.current_epoch = checkpoint['epoch']
695 |
696 | # crash if max_epochs is lower than the current epoch from the checkpoint
697 | if self.trainer.current_epoch > self.trainer.max_epochs:
698 | m = f"""
699 | you restored a checkpoint with current_epoch={self.trainer.current_epoch}
700 | but the Trainer(max_epochs={self.trainer.max_epochs})
701 | """
702 | raise MisconfigurationException(m)
703 |
704 | # Division deals with global step stepping once per accumulated batch
705 | # Inequality deals with different global step for odd vs even num_training_batches
706 | n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
707 | expected_steps = self.trainer.num_training_batches / n_accum
708 | if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1:
709 | rank_zero_warn(
710 | "You're resuming from a checkpoint that ended mid-epoch. "
711 | "This can cause unreliable results if further training is done, "
712 | "consider using an end of epoch checkpoint. "
713 | )
714 |
715 | if not load_optimizer_states:
716 | return
717 |
718 | # restore the optimizers
719 | if not self.reset_optimizer:
720 | optimizer_states = checkpoint['optimizer_states']
721 | for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
722 | print(opt_state.keys(), optimizer)
723 | # print(optimizer.param_groups.keys(), optimizer.param_groups)
724 | print([x.keys() for x in optimizer.param_groups])
725 | print([x.keys() for x in opt_state['param_groups']])
726 | optimizer.load_state_dict(opt_state)
727 |
728 | # move optimizer to GPU 1 weight at a time
729 | # avoids OOM
730 | if self.trainer.root_gpu is not None:
731 | for state in optimizer.state.values():
732 | for k, v in state.items():
733 | if isinstance(v, torch.Tensor):
734 | state[k] = v.cuda(self.trainer.root_gpu)
735 |
736 | if not self.reset_lr_scheduler:
737 | # restore the lr schedulers
738 | lr_schedulers = checkpoint['lr_schedulers']
739 | if self.set_global_step is not None:
740 | for lrs_state in lr_schedulers:
741 | lrs_state['last_epoch'] = self.set_global_step
742 | lrs_state['_step_count'] = self.set_global_step + 1
743 |
744 | for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
745 | scheduler['scheduler'].load_state_dict(lrs_state)
746 | else:
747 | if self.set_global_step is not None:
748 | for scheduler in self.trainer.lr_schedulers:
749 | scheduler['scheduler'].last_epoch = self.set_global_step
750 | scheduler['scheduler']._step_count = self.set_global_step+ 1
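# NOTE: overwriting last_epoch/_step_count fast-forwards the LR schedulers to `set_global_step`,
# so the learning rate resumes from the right point on the schedule even when the
# optimizer/scheduler state itself is being reset rather than restored.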
751 |
752 |
753 | # Subclass pytorch-lightning's TrainLoop to support batch-size and sequence-length warmup
754 | class MyTrainLoop(TrainLoop):
755 | def __init__(self, trainer, multiple_trainloader_mode, args):
756 | super().__init__(trainer, multiple_trainloader_mode)
757 | self.args = args
758 |
759 | def grad_norm(self, model, norm_type, should_accumulate=False):
760 | # Override PTL's `grad_norm` to return only `total_grad_norm` instead of the norms of individual params
761 | # TODO: grad_norm reporting needs to take fp16 loss scale into account
762 | # parameters = [p for p in self.parameters() if p.grad is not None]
763 | # device = parameters[0].device
764 | # total_norm = torch.zeros([], device=device if parameters else None)
765 | # norm_type = float(norm_type)
766 | # for p in parameters:
767 | # param_norm = p.grad.norm(norm_type)
768 | # total_norm.add_(param_norm)
769 | norm_type = float(norm_type)
770 |
771 | norms, all_norms = {}, []
772 | # local_norm = torch.zeros([], device=model.device)
773 | for name, p in model.named_parameters():
774 | if p.grad is None:
775 | continue
776 |
777 | if not should_accumulate:
778 | # param_norm = float(p.grad.data.norm(norm_type))
779 | p_grad = p.grad.data / args.batch_size / args.grad_accum
780 | param_norm = float( p_grad.norm(norm_type) )
781 | else:
782 | p_grad = p.grad.data / self.trainer.accelerator.precision_plugin.scaler.get_scale() / args.batch_size
783 | param_norm = float( p_grad.norm(norm_type) )
784 | all_norms.append(param_norm)
785 | # local_norm.add_(p.grad.norm(norm_type))
786 |
787 | total_norm = float(torch.tensor(all_norms).norm(norm_type))
788 | # norms[f'grad_{norm_type}_norm_total'] = round(total_norm, 4)
789 | # print("total_norm", total_norm, model.device, local_norm, self.trainer.accelerator.precision_plugin.scaler.get_scale())
790 | if not should_accumulate:
791 | return {'total_grad_norm': total_norm, "batch_size": args.batch_size * self.trainer.world_size, "grad_accum": args.grad_accum }
792 | else:
793 | return { "local_grad_norm %s" % model.device: total_norm, "local_scale": self.trainer.accelerator.precision_plugin.scaler.get_scale() }
794 |
795 | def _track_gradient_norm(self):
796 | grad_norm_dict = {}
797 | if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0:
798 | if float(self.trainer.track_grad_norm) > 0:
799 | model = self.trainer.lightning_module
800 | grad_norm_dict = self.grad_norm(model, self.trainer.track_grad_norm)
801 | return grad_norm_dict
802 |
803 | def backward(self, result, optimizer, opt_idx, *args, **kwargs):
804 | self.trainer.dev_debugger.track_event("backward_call")
805 |
806 | should_accumulate = self.should_accumulate()
807 | # print(should_accumulate)
808 | # backward can be called manually in the training loop
809 | if isinstance(result, torch.Tensor):
810 | self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs)
811 | else:
812 | result.closure_loss = self.trainer.accelerator.backward(
813 | result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
814 | )
815 |
816 | if not self.should_accumulate():
817 | # track gradients
818 | # print("track gradient with should_accumulate False")
819 | cur_grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer)
820 | if 'total_grad_norm' in self._cur_grad_norm_dict:
821 | B_small, B_big = self._cur_grad_norm_dict['batch_size'], self._cur_grad_norm_dict['batch_size'] * self._cur_grad_norm_dict['grad_accum']
822 | grad_norm_B_big = self._cur_grad_norm_dict['total_grad_norm']
823 | grad_norm_B_small = []
824 | if not hasattr(self, 'grad_norm_dict') or (hasattr(self, 'grad_norm_dict') and self.grad_norm_dict is None):
825 | B_critical = B_big
826 | else:
827 | for item in self.grad_norm_dict:
828 | if "local_grad_norm" in item:
829 | grad_norm_B_small.append( self.grad_norm_dict[item] )
830 |
831 | grad_norm_B_small = np.average(grad_norm_B_small)
832 | g2 = 1 / (B_big-B_small) * (B_big * grad_norm_B_big - B_small*grad_norm_B_small)
833 | s = 1 / (1/B_small - 1/B_big) * (grad_norm_B_small - grad_norm_B_big)
834 | B_critical = s / g2
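# This appears to follow the gradient-noise-scale estimate of the critical batch size
# (cf. McCandlish et al., "An Empirical Model of Large-Batch Training"): the gradient magnitude
# term g2 and the noise term s are estimated from gradient norms measured at the un-accumulated
# batch size and at the full accumulated batch size, and B_critical ~ s / g2. Note the paper's
# estimator uses squared gradient norms, whereas the norms gathered here are used as-is.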
835 | self._cur_grad_norm_dict.update( self.grad_norm_dict )
836 | self._cur_grad_norm_dict.update( {"critical_batch_size" : B_critical} )
837 | for e in ['batch_size', 'grad_accum']:
838 | self._cur_grad_norm_dict.pop(e)
839 | # print(self._cur_grad_norm_dict)
840 | self.grad_norm_dict = None
841 | else:
842 | # print("track gradient with should_accumulate True")
843 | # first gradient accumulation step !!!!!!!!!!!!
844 | if hasattr(self, 'grad_norm_dict') and self.grad_norm_dict is None:
845 | model = self.trainer.lightning_module
846 | self.grad_norm_dict = self.grad_norm(model, self.trainer.track_grad_norm, True)
847 |
848 | def run_training_epoch(self):
849 | # modify dataloader if needed (ddp, etc...)
850 | train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
851 |
852 | # track epoch output
853 | epoch_output = [[] for _ in range(self.num_optimizers)]
854 |
855 | train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
856 | dataloader_idx = 0
857 |
858 | batch_idx = None
859 | is_last_batch = None
860 |
861 | accum_bsz, accum_bsz_grad_step = 0, 0
862 | for batch_idx, (batch, is_last_batch) in train_dataloader:
863 | self.trainer.batch_idx = batch_idx
864 | self.trainer.is_last_batch = is_last_batch
865 |
866 | # warmup the batch size via truncation and gradient accumulation
867 | # hack deep into PTL internals to make it happen
868 | if self.args.warmup_bsz != 0:
869 | # for key in batch.keys():
870 | # print(key, batch[key].shape, batch[key].device, batch[key].numel(), self.trainer.accumulate_grad_batches, self.trainer.model.device)
871 | input_ids = batch['input_ids']
872 |
873 | final_bsz = input_ids.shape[0] * self.args.grad_accum * self.trainer.world_size
874 | start_bsz = 64
875 |
876 | current_bsz = start_bsz + (final_bsz - start_bsz) * min( 1.0, accum_bsz / self.args.warmup_bsz )
877 | # print("before current_bsz", current_bsz, accum_bsz)
878 | if current_bsz >= final_bsz:
879 | self.trainer.accumulate_grad_batches = self.args.grad_accum
880 | else:
881 | current_bsz = current_bsz // self.trainer.world_size
882 | # try to reset gradient accum steps
883 | grad_accum = int(max(1, current_bsz // input_ids.shape[0]))
884 |
885 | if grad_accum == 1 or accum_bsz_grad_step <= 0:
886 | if grad_accum != 1 and accum_bsz_grad_step == 0:
887 | accum_bsz_grad_step = grad_accum
888 | self.trainer.accumulate_grad_batches = grad_accum
889 | bsz_after_chunk = int(current_bsz // self.trainer.accumulate_grad_batches)
890 | else:
891 | accum_bsz_grad_step -= 1
892 |
893 | # try to chunk the inputs
894 | # print("current_bsz", current_bsz, "grad_accum", grad_accum, self.trainer.accumulate_grad_batches, accum_bsz_grad_step, self.should_accumulate(), 'bsz_after_chunk', bsz_after_chunk, input_ids.shape[0])
895 | if bsz_after_chunk < input_ids.shape[0]:
896 | for key in batch.keys():
897 | batch[key] = torch.narrow(batch[key], 0, 0, bsz_after_chunk)#.to( self.trainer.model.device )
898 |
899 | accum_bsz += batch['input_ids'].numel()
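# Batch-size warmup, in short: the effective batch grows linearly from start_bsz=64 sequences up
# to batch_size * grad_accum * world_size as the number of tokens seen (accum_bsz) approaches
# args.warmup_bsz, by lowering accumulate_grad_batches and, when needed, truncating the
# per-device batch along dim 0.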
900 |
901 | if self.args.warmup_seq != 0:
902 |
903 | input_ids = batch['input_ids']
904 |
905 | start_seq = 64
906 | final_seq = input_ids.shape[1]
907 |
908 | current_seq = int(start_seq + (final_seq - start_seq) * min( 1.0, accum_bsz / self.args.warmup_seq ))
909 | if accum_bsz_grad_step <= 0:
910 | accum_bsz_grad_step = self.trainer.accumulate_grad_batches
911 | else:
912 | accum_bsz_grad_step -= 1
913 |
914 | if current_seq < final_seq:
915 | for key in batch.keys():
916 | batch[key] = torch.narrow(batch[key], 1, 0, current_seq)
917 |
918 | accum_bsz += batch['input_ids'].numel()
919 |
920 | # ------------------------------------
921 | # TRAINING_STEP + TRAINING_STEP_END
922 | # ------------------------------------
923 | with self.trainer.profiler.profile("run_training_batch"):
924 | batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
925 |
926 | # when returning -1 from train_step, we end epoch early
927 | if batch_output.signal == -1:
928 | break
929 |
930 | # hook
931 | # TODO: add outputs to batches
932 | self.on_train_batch_end(
933 | epoch_output,
934 | batch_output.training_step_output_for_epoch_end,
935 | batch,
936 | batch_idx,
937 | dataloader_idx,
938 | )
939 |
940 | # -----------------------------------------
941 | # SAVE METRICS TO LOGGERS
942 | # -----------------------------------------
943 | self.trainer.logger_connector.log_train_step_metrics(batch_output)
944 |
945 | # -----------------------------------------
946 | # VALIDATE IF NEEDED
947 | # -----------------------------------------
948 | should_check_val = self._should_check_val_fx(batch_idx, is_last_batch)
949 | if should_check_val:
950 | self.trainer.validating = True
951 | self.trainer.run_evaluation()
952 | self.trainer.training = True
953 |
954 | # -----------------------------------------
955 | # SAVE LOGGERS (ie: Tensorboard, etc...)
956 | # -----------------------------------------
957 | self.save_loggers_on_train_batch_end()
958 |
959 | # update LR schedulers
960 | monitor_metrics = copy.deepcopy(self.trainer.logger_connector.callback_metrics)
961 | self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
962 | self.trainer.checkpoint_connector.has_trained = True
963 |
964 | self.trainer.total_batch_idx += 1
965 |
966 | # max steps reached, end training
967 | if (
968 | self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1
969 | and self._accumulated_batches_reached()
970 | ):
971 | break
972 |
973 | # end epoch early
974 | # stop when the flag is changed or we've gone past the amount
975 | # requested in the batches
976 | if self.trainer.should_stop:
977 | break
978 |
979 | # stop epoch if we limited the number of training batches
980 | if self._num_training_batches_reached(is_last_batch):
981 | break
982 |
983 | # progress global step according to grads progress
984 | self.increment_accumulated_grad_global_step()
985 |
986 | if batch_idx is None:
987 | # dataloader/iterator did not produce a batch
988 | return
989 |
990 | # handle epoch_output on epoch end
991 | self.on_train_epoch_end(epoch_output)
992 |
993 | # log epoch metrics
994 | self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
995 |
996 | should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
997 | should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
998 | should_train_only = self.trainer.disable_validation or should_skip_eval
999 |
1000 | # update epoch level lr_schedulers if no val loop outside train loop is triggered
1001 | if not should_check_val or should_train_only:
1002 | self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
1003 |
1004 | if should_train_only:
1005 | self.check_checkpoint_callback(True)
1006 |
1007 | if should_check_val:
1008 | self.trainer.validating = True
1009 | self.trainer.run_evaluation(on_epoch=True)
1010 | self.trainer.training = True
1011 |
1012 | if batch_output.signal != -1:
1013 | self.increment_accumulated_grad_global_step()
1014 |
1015 | class Pretrainer(ptl.LightningModule):
1016 |
1017 | def __init__(self):
1018 | super().__init__()
1019 |
1020 | self.args = args # hparams
1021 | self._set_hparams(self.args) #v1.3.5 ptl issue
1022 | # self.hparams = self.args
1023 |
1024 | #self.model = AutoModelForMaskedLM.from_pretrained(args.model)
1025 | self.model = AutoModelForCausalLM.from_pretrained(args.model)
1026 | if args.random:
1027 | if args.layers is not None and args.size is not None:
1028 | raise ValueError('args.layers and args.size cannot both be set')
1029 | if args.layers is not None:
1030 | self.model.config.n_layer = args.layers
1031 | if args.size is not None:
1032 | if args.size == 'GPT2_base':
1033 | self.model.config.n_layer = 12
1034 | self.model.config.n_embd = 768
1035 | self.model.config.n_head = 8
1036 | elif args.size == 'GPT2_large':
1037 | self.model.config.n_layer = 24
1038 | self.model.config.n_embd = 1536
1039 | self.model.config.n_head = 16
1040 | elif args.size == 'GPT2_base_div2_width':
1041 | self.model.config.n_layer = 12
1042 | self.model.config.n_embd = 384
1043 | self.model.config.n_head = 4
1044 | elif args.size == 'GPT2_base_div2_depth':
1045 | self.model.config.n_layer = 6
1046 | self.model.config.n_embd = 768
1047 | self.model.config.n_head = 8
1048 |
1049 | elif args.size == 'GPT2_large_div4_width':
1050 | self.model.config.n_layer = 24
1051 | self.model.config.n_embd = 384
1052 | self.model.config.n_head = 4
1053 |
1054 | elif args.size == 'GPT2_large_div2_width':
1055 | self.model.config.n_layer = 24
1056 | self.model.config.n_embd = 768
1057 | self.model.config.n_head = 8
1058 | elif args.size == 'GPT2_large_div4_depth':
1059 | self.model.config.n_layer = 6
1060 | self.model.config.n_embd = 1536
1061 | self.model.config.n_head = 16
1062 | elif args.size == 'GPT2_large_div2_depth':
1063 | self.model.config.n_layer = 12
1064 | self.model.config.n_embd = 1536
1065 | self.model.config.n_head = 16
1066 | else:
1067 | assert False
1068 |
1069 | assert self.model.config.n_positions == 1024
1070 | self.model.config.n_positions = args.seqlen
1071 | self.model = GPT2LMHeadModel(config=self.model.config)
1072 | else:
1073 | assert args.layers is None
1074 | assert args.size is None
1075 |
1076 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
1077 | self.pad_token_id = tokenizer.pad_token_id
1078 | self.eos_token_id = tokenizer.eos_token_id or tokenizer.sep_token_id
1079 | self.bos_token_id = tokenizer.bos_token_id or tokenizer.cls_token_id
1080 |
1081 | logger.info(f'Creating dataset cache from dir {self.args.input_dir}. This could be slow the first time.')
1082 | MMapTextDataset.raw_text_to_mmap(args)
1083 |
1084 | # TODO: add support for other objective functions (whole word masking, BART, Pegasus)
1085 | # self.data_collator = DataCollatorForLanguageModeling(
1086 | # tokenizer=tokenizer, mlm=True, mlm_probability=self.args.mlm_prob
1087 | # )
1088 | self.data_collator = DataCollatorForLanguageModeling(
1089 | tokenizer=tokenizer, mlm=False
1090 | )
1091 | self.start_time = 0
1092 |
1093 | def to(self, *args, **kwargs):
1094 | param_count_before_to = len(list(self.parameters()))
1095 | super().to(*args, **kwargs)
1096 | if self.trainer.on_tpu:
1097 | # if self.trainer.use_tpu:
1098 | # need to re-tie the weights after moving to XLA!
1099 | self.model.tie_weights()
1100 | if 'roberta' in self.args.model or 'longformer' in self.args.model:
1101 | self.model.lm_head.bias = self.model.lm_head.decoder.bias
1102 | param_count_after_to = len(list(self.parameters()))
1103 | assert param_count_before_to == param_count_after_to
1104 |
1105 | def forward(self, inputs):
1106 | # for MLM
1107 | # get the padding mask - 1 for NOT masked, 0 for MASKED/PAD
1108 | # attention_mask = (input_ids != self.pad_token_id).int()
1109 |
1110 | # output is loss, prediction_scores, hidden_states
1111 | # output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
1112 |
1113 | # for LM
1114 | output = self.model(**inputs)
1115 | return output[0] # loss
1116 |
1117 | def training_step(self, batch, batch_nb):
1118 | loss = self(batch)
1119 | input_ids = batch['input_ids']
1120 | tensorboard_logs = {
1121 | 'input_size': input_ids.numel(),
1122 | 'token_per_step': input_ids.numel() * self.trainer.accumulate_grad_batches * self.trainer.world_size,
1123 | }
1124 | # if not self.use_tpu:
1125 | if not self.trainer.on_tpu:
1126 | # logging additional losses is slow on tpu
1127 | tensorboard_logs['lm_loss'] = loss
1128 | tensorboard_logs['lm_bpc'] = loss/math.log(2)
1129 | tensorboard_logs['lm_perplexity'] = torch.exp(loss)
1130 |
1131 | if self.start_time != 0:
1132 | # torch.cuda.synchronize()
1133 | elapsed_time = time.monotonic() - self.start_time
1134 | tensorboard_logs['second_per_batch'] = elapsed_time
1135 | self.start_time = time.monotonic()
1136 |
1137 | if self.on_gpu:
1138 | tensorboard_logs['memory'] = torch.cuda.memory_allocated(loss.device) / 1024 ** 3
1139 |
1140 | for k, v in tensorboard_logs.items():
1141 | self.log(k, v)
1142 |
1143 | return {'loss': loss}
1144 |
1145 | def on_train_batch_start(self, *args, **kwargs):
1146 | self._start = time.monotonic()
1147 |
1148 | def on_train_batch_end(self, *args, **kwargs):
1149 | delta = time.monotonic() - self._start
1150 | self.log("time_per_batch", delta, on_step=True, on_epoch=False)
1151 |
1152 | def validation_step(self, batch, batch_nb):
1153 | # TODO: log how long evaluation takes
1154 | self.start_time = 0 # reset training_step timer
1155 |
1156 | loss = self(batch)
1157 | tensorboard_logs = {
1158 | 'val_lm_loss': loss.detach(),
1159 | }
1160 | return {'val_loss': tensorboard_logs["val_lm_loss"], 'log': tensorboard_logs}
1161 |
1162 | def validation_epoch_end(self, outputs):
1163 | avg_loss = torch.stack([x['log']['val_lm_loss'] for x in outputs if 'val_lm_loss' in x['log']]).mean()
1164 | if self.trainer.accelerator_connector.use_ddp:
1165 | # TODO: PTL is already doing this. Is it still needed here?
1166 | # https://github.com/PyTorchLightning/pytorch-lightning/blob/0.8.5/pytorch_lightning/metrics/converters.py#L251
1167 | torch.distributed.all_reduce(avg_loss, op=torch.distributed.ReduceOp.SUM)
1168 | avg_loss /= torch.distributed.get_world_size()
1169 |         elif self.trainer.on_tpu:
1170 | avg_loss = xm.all_reduce(xm.REDUCE_SUM, avg_loss) / xm.xrt_world_size()
1171 |
1172 | self.log('val_loss', avg_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
1173 |
1174 | def configure_optimizers(self):
1175 | # no_decay = ["bias", "LayerNorm.weight"]
1176 |
1177 | # optimizer_grouped_parameters = [
1178 | # {
1179 | # "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
1180 | # "weight_decay": self.args.weight_decay,
1181 | # },
1182 | # {
1183 | # "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
1184 | # "weight_decay": 0.0,
1185 | # },
1186 | # ]
1187 | # optimizer_grouped_parameters
1188 |
1189 | optimizer = AdamW(self.parameters(), lr=self.args.lr, eps=self.args.adam_epsilon,
1190 | betas=(self.args.adam_beta1, self.args.adam_beta2),
1191 | correct_bias=False)
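     |         # with --restart_steps / --restart_warmup_steps set, use a schedule that re-warms up the LR after the restart point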
1192 | if self.args.restart_warmup_steps != 0 and self.args.restart_steps != 0:
1193 | scheduler = get_restart_linear_schedule_with_warmup(
1194 | optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.args.train_steps,
1195 | restart_steps=self.args.restart_steps, restart_warmup_steps=self.args.restart_warmup_steps,
1196 | )
1197 | else:
1198 | scheduler = get_linear_schedule_with_warmup(
1199 | optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.args.train_steps
1200 | )
1201 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
1202 |
1203 | def _get_loader(self, fname, is_train):
1204 | dataset = MMapTextDataset(fname, chunk_size=self.args.seqlen,
1205 | bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id)
1206 |
1207 | # TODO: consider `replace_sampler_ddp=True` and removing the following if statement
1208 | # if self.trainer.use_ddp:
1209 | if self.trainer.accelerator_connector.use_ddp:
1210 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train)
1211 | shuffle = False
1212 | elif self.trainer.on_tpu:
1213 | sampler = torch.utils.data.distributed.DistributedSampler(
1214 | dataset,
1215 | num_replicas=xm.xrt_world_size(),
1216 | rank=xm.get_ordinal(),
1217 | shuffle=is_train,
1218 | )
1219 | shuffle = False
1220 | else:
1221 | sampler = None
1222 | shuffle = is_train
1223 |
1224 | loader = DataLoader(
1225 | dataset,
1226 | batch_size=self.args.batch_size,
1227 | shuffle=shuffle,
1228 | sampler=sampler,
1229 | num_workers=self.args.num_workers,
1230 | collate_fn=self.data_collator,
1231 | drop_last=is_train,
1232 | )
1233 | return loader
1234 |
1235 | def train_dataloader(self):
1236 | return self._get_loader(f'{self.args.input_dir}/cache/train.bin', True)
1237 |
1238 | def val_dataloader(self):
1239 | return self._get_loader(f'{self.args.input_dir}/cache/val.bin', False)
1240 |
1241 | def grad_norm(self, norm_type):
1242 |         # Override PTL `grad_norm` so it only returns `total_grad_norm` instead of the norms of individual params
1243 |         # TODO: grad_norm reporting needs to take the fp16 loss scale into account
1244 |         parameters = [p for p in self.parameters() if p.grad is not None]
1245 |         device = parameters[0].device if parameters else None
1246 |         total_norm = torch.zeros([], device=device)
1247 |         norm_type = float(norm_type)
1248 |         for p in parameters:
1249 |             param_norm = p.grad.norm(norm_type)
1250 |             total_norm.add_(param_norm ** norm_type)
     |         total_norm = total_norm ** (1.0 / norm_type)  # p-th root gives the actual total gradient norm
1251 |         return {'total_grad_norm': total_norm}
1252 |
1253 | @staticmethod
1254 | def add_args(parser):
1255 | parser.add_argument("--seed", type=int, default=3)
1256 |
1257 | # Dataset. Some of these params are only useful when generating the dataset cache
1258 | parser.add_argument("--input_dir", type=str, default='/net/nfs2.allennlp/shengs/c4/')
1259 | parser.add_argument("--output_dir", type=str, default='/net/nfs2.allennlp/shengs/c4/')
1260 |
1261 | parser.add_argument("--data_type", type=str, default='tfrecord')
1262 | parser.add_argument("--add_sep_after_doc", action='store_true', default=False, help='add sep token after document')
1263 |
1264 | # Used only at the preprocessing phase
1265 | parser.add_argument("--train_dev_split", type=float, default=0.05)
1266 |         parser.add_argument("--shard_size", type=int, default=1024 ** 3 // 4)  # 256MB shards
1267 | parser.add_argument("--num_preprocessing_workers", type=int, default=1)
1268 | # Used only at the training phase
1269 | parser.add_argument("--seqlen", type=int, default=512)
1270 |
1271 | # HF model loading
1272 | parser.add_argument("--tokenizer", type=str, default='gpt2')
1273 | parser.add_argument("--model", type=str, default='gpt2')
1274 | parser.add_argument("--doubling", type=str) # could be layers / weights
1275 | parser.add_argument("--doubling_layers", type=str) # could be alternate_id, append_id, alternate_copy, append_copy
1276 | # parser.add_argument("--noise_std", type=float, default=0.0)
1277 | parser.add_argument("--warmup_bsz", type=int, default=0, help='# warmup batch size')
1278 | parser.add_argument("--warmup_seq", type=int, default=0, help='# warmup sequence length')
1279 |
1280 | parser.add_argument("--random", default=False, action='store_true')
1281 | parser.add_argument("--layers", type=int)
1282 | parser.add_argument("--size", type=str)
1283 |
1284 | # Checkpointing and logging
1285 | parser.add_argument("--save_dir", type=str, default='runs/')
1286 | parser.add_argument("--save_prefix", type=str, default='test',
1287 | help="path of output directory is --save_dir/--save_prefix")
1288 | parser.add_argument("--resume", type=str, default=None, # It is better to use a different output dir.
1289 | help="Path to a checkpoint to load model weights and training state. It overwrites args")
1290 | parser.add_argument("--resume_model_only", type=str, default=None,
1291 | help="Path to a checkpoint to load model weights but not training state")
1292 | parser.add_argument("--reset_optimizer", default=False, action='store_true')
1293 | parser.add_argument("--reset_lr_scheduler", default=False, action='store_true')
1294 | parser.add_argument("--log_rate", type=int, default=10)
1295 | parser.add_argument("--disable_checkpointing", action='store_true', default=False)
1296 |
1297 | # Training hyperparams
1298 | parser.add_argument("--lr", type=float, default=1e-5)
1299 | parser.add_argument("--train_steps", type=int, default=3000, help='# training grad. updates')
1300 | parser.add_argument("--warmup_steps", type=int, default=1000, help='# warmup grad. updates')
1301 | parser.add_argument("--val_every", type=int, default=100, help='# training grad. updates between evaluations')
1302 | parser.add_argument("--val_batches", type=int, default=1000, help='# evaluation **batches**')
1303 | parser.add_argument("--weight_decay", type=float, default=0.01)
1304 | parser.add_argument("--adam_epsilon", type=float, default=1e-6)
1305 | parser.add_argument("--adam_beta1", type=float, default=0.9)
1306 | parser.add_argument("--adam_beta2", type=float, default=0.98)
1307 | parser.add_argument("--grad_clip", type=float, default=0) # TODO: test this with fp16. Likely not working
1308 |
1309 | # RoBERTa's tokens_per_step = 2^18 = 512(seqlen) x 1(gpu_count) x 32(batch_size) x 16(grad_accum)
1310 | parser.add_argument("--batch_size", type=int, default=32)
1311 | parser.add_argument("--grad_accum", type=int, default=1)
1312 |
1313 | # Compute resources
1314 | parser.add_argument("--fp16", default=False, action='store_true')
1315 | parser.add_argument("--num_workers", type=int, default=0)
1316 | parser.add_argument("--gpu_count", type=int, default=1, # `--gpus` is reserved for internal use by PTL
1317 | help="Number of gpus. This respects `CUDA_VISIBLE_DEVICES`")
1318 |
1319 | # For restarting with warmup
1320 | parser.add_argument("--restart_warmup_steps", type=int, default=0, help='# warmup grad. updates after restart')
1321 |         parser.add_argument("--restart_steps", type=int, default=0, help='# restart steps, should be the same as set_global_step')
1322 | # For multi-node training, use the PyTorch launch script. The script and instructions can be found here:
1323 | # https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py.
1324 | # To run PTL in a mode compatible with the launch script, two things are needed:
1325 | # - pass the argument `--use_env` to `torch.distributed.launch`
1326 | # - make sure `--nproc_per_node` matches `--gpu_count` and `--nnodes` matches `--node_count`.
1327 | # For example, to run on 2 nodes, 3 gpus each, the command line on node rank 1 would be like:
1328 | # >>>> python -m torch.distributed.launch \
1329 | # --use_env --nnodes 2 --nproc_per_node 3 \
1330 | # --node_rank 1 --master_addr s2-server4 --master_port 12343 \
1331 | # scripts/pretrain.py \
1332 |         #          --gpu_count 3 --node_count 2 \
1333 | # --input_dir my_data_dir --save_prefix test_multinode
1334 | parser.add_argument("--node_count", type=int, default=1,
1335 | help="Number of nodes. It needs to match --nnodes of torch.distributed.launch")
1336 | parser.add_argument("--tpu_core_count", type=int, default=None)
1337 |
1338 | return parser
1339 |
1340 |
1341 | def main(args):
1342 | random.seed(args.seed * 10)
1343 | np.random.seed(args.seed * 100)
1344 | torch.manual_seed(args.seed * 1000)
1345 | if torch.cuda.is_available():
1346 | torch.cuda.manual_seed_all(args.seed * 10000)
1347 |
1348 | if args.resume_model_only is not None:
1349 | pretrainer = Pretrainer.load_from_checkpoint(args.resume_model_only)
1350 | else:
1351 | pretrainer = Pretrainer()
1352 |
1353 | if args.doubling is not None:
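     |         # doubling workflow: the first run rewrites the checkpoint at --resume into a '.doubled_*'
     |         # copy and exits; a second run with the same flags then resumes training from that copy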
1354 |
1355 | doubled_resume = args.resume + '.doubled_weights' if args.doubling == 'weights' else args.resume + '.doubled_layer'
1356 | print(doubled_resume)
1357 |         exist_flag = os.path.isfile(doubled_resume)
1358 |
1359 |         if exist_flag:
1360 | args.resume = doubled_resume
1361 | print('================== warning: reusing old ckpt =======================')
1362 |
1363 |
1364 | # doubling the checkpoint before doubling the in-memory model
1365 |         if args.resume is not None and not exist_flag:
1366 | ckpt = torch.load(args.resume)
1367 |
1368 | # doubling state dict of the saved model
1369 | if args.doubling == 'weights':
1370 | model_state_dict = ckpt['state_dict']
1371 | ckpt['state_dict'] = double_state_dict(model_state_dict, is_double_embedding=True)
1372 |
1373 | # doubling state dict of the saved optimizer
1374 | # no_decay = ["bias", "LayerNorm.weight"]
1375 | # optimizer_params_by_name = [(n, p.shape) for n, p in pretrainer.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad]
1376 | # optimizer_params_by_name.extend([(n, p.shape) for n, p in pretrainer.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad])
1377 | optimizer_params_by_name = [(n, p.shape) for n, p in pretrainer.named_parameters()]
1378 | assert len(optimizer_params_by_name) == len(ckpt['optimizer_states'][0]['state'])
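     |                 # NOTE: this zip relies on the optimizer state dict preserving the same ordering as named_parameters()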
1379 | for (param_name, param_shape), param in zip(optimizer_params_by_name, ckpt['optimizer_states'][0]['state'].values()):
1380 | assert param['exp_avg'].shape == param_shape
1381 | assert param['exp_avg_sq'].shape == param_shape
1382 | param['exp_avg'] = double_param(param_name, param['exp_avg'], is_double_embedding=True, is_grad=True, is_avg_sq=False)
1383 | param['exp_avg_sq'] = double_param(param_name, param['exp_avg_sq'], is_double_embedding=True, is_grad=True, is_avg_sq=True)
1384 |
1385 | # print(name_shape[0])
1386 | args.resume += '.doubled_weights'
1387 | elif args.doubling == 'layers':
1388 | model_state_dict = ckpt['state_dict']
1389 | # hack for doubling the layers
1390 | prefix = 'model.transformer.h'
1391 | map_positions, copy_positions = {}, {}
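     |                 # map_positions: original key -> [origin copy, inserted copy] in the doubled model
     |                 # copy_positions: new key -> (source key, True if it keeps the original position)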
1392 | for key in model_state_dict:
1393 | if prefix in key:
1394 |                     layer_idx = re.findall(r"[-\d]+", key)[0]
1395 | origin_idx = prefix + "." + str(int(layer_idx))
1396 |                     if 'alternate' in args.doubling_layers:
1397 |                         insert_idx = prefix + "." + str(int(layer_idx) * 2 + 1)
1398 |                         origin_key = key.replace(origin_idx, prefix + "." + str(int(layer_idx) * 2))
1399 |                     elif 'append' in args.doubling_layers:
1400 |                         insert_idx = prefix + "." + str(pretrainer.model.config.n_layer + int(layer_idx))
1401 |                         origin_key = key
     |                     else:
     |                         assert False, "--doubling_layers must contain 'alternate' or 'append'"
1402 |
1403 |                     insert_key = key.replace(origin_idx, insert_idx)
1404 |
1405 |                     map_positions[key] = [(origin_key, False), (insert_key, False)]
1406 |                     copy_positions[insert_key] = (key, False)
1407 |                     copy_positions[origin_key] = (key, True)
1408 |
1409 | is_identical = 'id' in args.doubling_layers
1410 |
1411 | ckpt['state_dict'] = deep_state_dict(model_state_dict, is_identical=is_identical, map_positions=map_positions)
1412 |
1413 | # deal with the optimizer state
1414 | original_optimizer_params_by_name = [(n, p.shape) for n, p in pretrainer.named_parameters()]
1415 | # print( "original_optimizer_params_by_name", original_optimizer_params_by_name )
1416 | # print( "ckpt optimizer_states", ckpt['optimizer_states'][0]['state'].keys() )
1417 | layers = pretrainer.model.transformer.h
1418 | n = len(layers)
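     |                 # 'alternate' interleaves a copy next to each existing block;
     |                 # 'append' stacks copies of all original blocks on top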
1419 | for i in range(n):
1420 | if 'alternate' in args.doubling_layers:
1421 | layers.insert(i * 2, copy.deepcopy(layers[i * 2]))
1422 | elif 'append' in args.doubling_layers:
1423 |                         layers.append(copy.deepcopy(layers[i]))
1424 |
1425 | pretrainer.model.config.n_layer *= 2
1426 | pretrainer.model.tie_weights()
1427 | new_optimizer_params_by_name = [(n, p.shape) for n, p in pretrainer.named_parameters()]
1428 |
1429 |                 new_optimizer_state = {i: {} for i in range(len(new_optimizer_params_by_name))}
1430 | assert len(original_optimizer_params_by_name) == len(ckpt['optimizer_states'][0]['state'])
1431 | original_optimizer_param_name_dict = {}
1432 | for (param_name, param_shape), param in zip(original_optimizer_params_by_name, ckpt['optimizer_states'][0]['state'].values()):
1433 | assert param['exp_avg'].shape == param_shape
1434 | assert param['exp_avg_sq'].shape == param_shape
1435 |                     original_optimizer_param_name_dict[param_name] = copy.deepcopy(param)
1436 |
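     |                 # rebuild the Adam state so indices follow the doubled parameter list; copied layers
     |                 # reuse their source layer's exp_avg / exp_avg_sq (passed through deep_param)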
1437 | for param_idx, (param_name, param_shape) in enumerate(new_optimizer_params_by_name):
1438 | if copy_positions.get(param_name):
1439 | copy_param_name, copy_param_flag = copy_positions.get(param_name)
1440 | param_is_identical = copy_param_flag and is_identical
1441 |                         new_optimizer_state[param_idx] = copy.deepcopy(original_optimizer_param_name_dict[copy_param_name])
1442 |                         new_optimizer_state[param_idx]['exp_avg'] = deep_param(param_name, original_optimizer_param_name_dict[copy_param_name]['exp_avg'], is_identical=param_is_identical, is_grad=True, is_avg_sq=False)
1443 |                         new_optimizer_state[param_idx]['exp_avg_sq'] = deep_param(param_name, original_optimizer_param_name_dict[copy_param_name]['exp_avg_sq'], is_identical=param_is_identical, is_grad=True, is_avg_sq=True)
1444 |                     else:
1445 |                         new_optimizer_state[param_idx] = copy.deepcopy(original_optimizer_param_name_dict[param_name])
1446 |
1447 | ckpt['optimizer_states'][0]['state'] = new_optimizer_state
1448 | ckpt['optimizer_states'][0]['param_groups'][0]['params'] = list(new_optimizer_state.keys())
1449 | del original_optimizer_param_name_dict
1450 | args.resume += '.doubled_layer'
1451 |
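     |             # write the doubled checkpoint and stop; relaunch with the same flags to resume from it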
1452 | torch.save(ckpt, args.resume)
1453 | exit()
1454 |
1455 |         # the doubled checkpoint is picked up later via `resume_from_checkpoint`; here we only
1456 |         # sanity-check that a supported doubling mode was requested
1457 |         assert args.doubling in ('layers', 'weights')
1462 |
1463 |     # logger here is a SummaryWriter for tensorboard
1464 | # it is used by the trainer, and certain return variables
1465 | # from the model are automatically logged
1466 | logger = TestTubeLogger(
1467 | save_dir=args.save_dir,
1468 | name=args.save_prefix,
1469 | version=0 # always use version=0
1470 | )
1471 |
1472 | checkpoint_callback = ModelCheckpoint(
1473 | # model saved to filepath/prefix_....
1474 | # filepath=os.path.join(args.save_dir, args.save_prefix, 'checkpoint'),
1475 | # prefix='',
1476 | dirpath=os.path.join(args.save_dir, args.save_prefix),
1477 | filename='checkpoint-{epoch}-{step}',
1478 | save_top_k=-1,
1479 | # save_top_k=10,
1480 | every_n_train_steps=250,
1481 | save_last=True,
1482 | verbose=True,
1483 | # monitor='val_loss',
1484 | # mode='min',
1485 | )
1486 |     args.val_every *= args.grad_accum  # PTL expects `val_check_interval` in batches per GPU, not grad. updates
1487 | print(args.val_every, args.disable_checkpointing, checkpoint_callback.__dict__)
1488 | trainer = ptl.Trainer(
1489 | gpus=args.gpu_count,
1490 | num_nodes=args.node_count,
1491 | tpu_cores=args.tpu_core_count,
1492 | distributed_backend='ddp', # if (args.gpu_count > 1 or args.node_count > 1) else None,
1493 | replace_sampler_ddp=False,
1494 | track_grad_norm=2 if args.tpu_core_count is None else -1, # gradnorm logging is slow on tpus
1495 | max_epochs=10000, min_epochs=0,
1496 | max_steps=args.train_steps, # run for many epochs, but stop after max_steps
1497 | val_check_interval=args.val_every, limit_val_batches=args.val_batches,
1498 | log_every_n_steps=args.log_rate,
1499 | progress_bar_refresh_rate=args.log_rate,
1500 | logger=logger,
1501 | # checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else None,
1502 | accumulate_grad_batches=args.grad_accum,
1503 | resume_from_checkpoint=args.resume,
1504 | gradient_clip_val=args.grad_clip,
1505 | precision=16 if args.fp16 else 32, amp_level='O2',
1506 | num_sanity_val_steps=2,
1507 | callbacks=[LearningRateMonitor(), checkpoint_callback],
1508 | profiler="simple",
1509 | )
1510 | trainer.profiler.dirpath = os.path.join(args.save_dir, args.save_prefix)
1511 | trainer.profiler.filename = 'profiler'
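     |     # swap in the custom train loop (gets the full args, presumably for --warmup_bsz / --warmup_seq)
     |     # and checkpoint connector so optimizer/LR-scheduler resets and the global step are applied on resume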
1512 | trainer.train_loop = MyTrainLoop(trainer,
1513 | multiple_trainloader_mode='max_size_cycle', args=args)
1514 | trainer.checkpoint_connector = MyCheckpointConnector(trainer,
1515 | reset_lr_scheduler=args.reset_lr_scheduler,
1516 | reset_optimizer=args.reset_optimizer,
1517 | set_global_step=args.restart_steps+1)
1518 | trainer.fit(pretrainer)
1519 |
1520 |
1521 | if __name__ == "__main__":
1522 | parser = Pretrainer.add_args(argparse.ArgumentParser(description="pretrain"))
1523 | args = parser.parse_args()
1524 | main(args)
1525 |
--------------------------------------------------------------------------------