├── tools ├── __init__.py ├── common.py ├── wikitext_dataset.py ├── openwebtext_dataset.py └── mmap_dataset.py ├── dev-requirements.txt ├── .flake8 ├── .dockerignore ├── requirements.txt ├── Dockerfile ├── .github └── workflows │ └── cffconvert.yml ├── CITATION.cff ├── .gitignore ├── README.md ├── cheatsheet.txt ├── eval_lambada.py ├── LICENSE ├── eval_wikitext.py └── gpt_pretrain.py /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | black==21.7b0 2 | mypy==0.910 3 | flake8==3.9.2 4 | pytest==6.2.4 5 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 115 3 | 4 | ignore = 5 | # these rules don't play well with black 6 | E203 # whitespace before : 7 | W503 # line break before binary operator 8 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .dockerignore 2 | .git 3 | notebooks 4 | **.pyc 5 | **/__pycache__ 6 | **/.mypy_cache 7 | .gitignore 8 | .git 9 | .github 10 | .flake8 11 | .venv 12 | *.md 13 | mypy.ini 14 | pytest.ini 15 | tests 16 | CITATION.cff 17 | dev-requirements.txt 18 | LICENSE 19 | cheatsheet.txt 20 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.0 2 | numpy 3 | pytorch-lightning==1.3 4 | tokenizers==0.10.3 5 | sentencepiece==0.1.96 6 | transformers==4.8.2 7 | datasets==1.9.0 8 | accelerate==0.3.0 9 | click==7.1.2 10 | click-help-colors==0.9.1 11 | more-itertools==8.8.0 12 | tensorboardX 13 | tqdm==4.61.2 14 | test-tube==0.7.5 15 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # This Dockerfile creates an environment suitable for running `run.py`. 2 | FROM ghcr.io/allenai/pytorch:1.9.0-cuda11.1 3 | 4 | WORKDIR /stage/ 5 | 6 | # Install remaining dependencies. 7 | COPY requirements.txt . 8 | RUN pip install --no-cache-dir -r requirements.txt 9 | 10 | WORKDIR /workspace 11 | 12 | COPY . . 13 | 14 | ENTRYPOINT ["python"] 15 | -------------------------------------------------------------------------------- /.github/workflows/cffconvert.yml: -------------------------------------------------------------------------------- 1 | name: cffconvert 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | - CITATION.cff 7 | push: 8 | paths: 9 | - CITATION.cff 10 | 11 | jobs: 12 | validate: 13 | name: "validate" 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Check out a copy of the repository 17 | uses: actions/checkout@v2 18 | 19 | - name: Check whether the citation metadata from CITATION.cff is valid 20 | uses: citation-file-format/cffconvert-github-action@2.0.0 21 | with: 22 | args: "--validate" 23 | -------------------------------------------------------------------------------- /tools/common.py: -------------------------------------------------------------------------------- 1 | def get_group_texts_function(block_size: int): 2 | def group_texts(examples): 3 | # Concatenate all texts. 
4 | concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} # type: ignore 5 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 6 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 7 | # customize this part to your needs. 8 | if total_length >= block_size: 9 | total_length = (total_length // block_size) * block_size 10 | # Split by chunks of max_len. 11 | result = { 12 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 13 | for k, t in concatenated_examples.items() 14 | } 15 | result["labels"] = result["input_ids"].copy() 16 | return result 17 | 18 | return group_texts 19 | -------------------------------------------------------------------------------- /tools/wikitext_dataset.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from torch.utils.data import Dataset 3 | from transformers import GPT2Tokenizer 4 | 5 | from .common import get_group_texts_function 6 | 7 | 8 | def get_wikitext_dataset( 9 | tokenizer: GPT2Tokenizer, 10 | *, 11 | split: str = "test", 12 | block_size: int = 1024, 13 | num_workers: int = 1, 14 | ) -> Dataset: 15 | # dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") # type: ignore[assignment] 16 | dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="test") 17 | def tokenize_function(example): 18 | return tokenizer(example["text"]) 19 | 20 | dataset = dataset.map( # type: ignore[union-attr,call-arg] 21 | tokenize_function, 22 | batched=True, 23 | num_proc=num_workers, 24 | remove_columns=["text"], 25 | desc="Tokenizing dataset", 26 | ) 27 | 28 | group_texts = get_group_texts_function(block_size) 29 | 30 | dataset = dataset.map( 31 | group_texts, 32 | batched=True, 33 | num_proc=num_workers, 34 | desc=f"Grouping texts into chunks of {block_size}", 35 | ) 36 | 37 | return dataset # type: ignore[return-value] 38 | -------------------------------------------------------------------------------- /tools/openwebtext_dataset.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from torch.utils.data import Dataset 3 | from transformers import GPT2Tokenizer 4 | 5 | from .common import get_group_texts_function 6 | 7 | 8 | def get_openwebtext_dataset( 9 | tokenizer: GPT2Tokenizer, 10 | *, 11 | block_size: int = 1024, 12 | num_workers: int = 1, 13 | test_size: float = 0.01, 14 | seed: int = 17, 15 | ) -> Dataset: 16 | dataset_dict = load_dataset("openwebtext") 17 | 18 | # This dataset only comes with a single split, so we need to create our own train/test splits. 
19 | dataset_dict = dataset_dict["train"].train_test_split( # type: ignore[index,union-attr] 20 | shuffle=True, test_size=test_size, seed=seed 21 | ) 22 | dataset = dataset_dict["test"] # type: ignore[index] 23 | 24 | def tokenize_function(example): 25 | return tokenizer(example["text"]) 26 | 27 | dataset = dataset.map( # type: ignore[union-attr,call-arg] 28 | tokenize_function, 29 | batched=True, 30 | num_proc=num_workers, 31 | remove_columns=["text"], 32 | desc="Tokenizing dataset", 33 | ) 34 | 35 | group_texts = get_group_texts_function(block_size) 36 | 37 | dataset = dataset.map( 38 | group_texts, 39 | batched=True, 40 | num_proc=num_workers, 41 | desc=f"Grouping texts into chunks of {block_size}", 42 | ) 43 | 44 | return dataset # type: ignore[return-value] 45 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # YAML 1.2 2 | --- 3 | cff-version: "1.2.0" 4 | title: "Staged Training for Transformer Language Models" 5 | license: "Apache-2.0" 6 | message: "If you use staged training in your research or wish to refer to the baseline results published here, please cite using this metadata." 7 | repository-code: "https://github.com/allenai/staged-training" 8 | authors: 9 | - affiliation: "University of California, Berkeley" 10 | family-names: Shen 11 | given-names: Sheng 12 | - affiliation: "Allen Institute for Artificial Intelligence" 13 | family-names: Walsh 14 | given-names: Pete 15 | - affiliation: "University of California, Berkeley" 16 | family-names: Keutzer 17 | given-names: Kurt 18 | - affiliation: "Allen Institute for Artificial Intelligence" 19 | family-names: Dodge 20 | given-names: Jesse 21 | - affiliation: "Allen Institute for Artificial Intelligence" 22 | family-names: Peters 23 | given-names: Matthew 24 | - affiliation: "Allen Institute for Artificial Intelligence" 25 | family-names: Beltagy 26 | given-names: Iz 27 | preferred-citation: 28 | type: "article" 29 | title: "Staged Training for Transformer Language Models" 30 | doi: "10.48550/arXiv.2203.06211" 31 | url: "https://arxiv.org/abs/2203.06211" 32 | year: 2022 33 | authors: 34 | - affiliation: "University of California, Berkeley" 35 | family-names: Shen 36 | given-names: Sheng 37 | - affiliation: "Allen Institute for Artificial Intelligence" 38 | family-names: Walsh 39 | given-names: Pete 40 | - affiliation: "University of California, Berkeley" 41 | family-names: Keutzer 42 | given-names: Kurt 43 | - affiliation: "Allen Institute for Artificial Intelligence" 44 | family-names: Dodge 45 | given-names: Jesse 46 | - affiliation: "Allen Institute for Artificial Intelligence" 47 | family-names: Peters 48 | given-names: Matthew 49 | - affiliation: "Allen Institute for Artificial Intelligence" 50 | family-names: Beltagy 51 | given-names: Iz 52 | -------------------------------------------------------------------------------- /tools/mmap_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | from transformers import GPT2Tokenizer 5 | 6 | 7 | class MMapTextDataset(Dataset): 8 | def __init__( 9 | self, 10 | mmap_filename: str, 11 | *, 12 | chunk_size: int = 1024, 13 | bos_token_id: int = 50256, 14 | eos_token_id: int = 50256 15 | ): 16 | # `chunk_size - 2` to reserve space for and 17 | self.num_instances = np.memmap(mmap_filename, mode="r", dtype=np.uint16).shape[ 18 | 0 19 | ] // 
(chunk_size - 2) 20 | # defer loading the token_ids memmap until after the first __getitem__ call. 21 | # when spawning new processes for ddp, there is a hard limit in python < 3.8 that 22 | # pickle files need to be < 4GB. By waiting until after the first __getitem__ we 23 | # don't have to pickle the memmap 24 | self.token_ids = None 25 | self._mmap_filename = mmap_filename 26 | self._chunk_size = chunk_size 27 | self._bos_token_id = bos_token_id 28 | self._eos_token_id = eos_token_id 29 | 30 | def __len__(self): 31 | return self.num_instances 32 | 33 | def __getitem__(self, idx: int): 34 | if self.token_ids is None: 35 | self.token_ids = np.memmap(self._mmap_filename, mode="r", dtype=np.uint16) 36 | from_index = idx * (self._chunk_size - 2) 37 | to_index = (idx + 1) * (self._chunk_size - 2) 38 | data = np.concatenate( 39 | ( 40 | [self._bos_token_id], 41 | self.token_ids[from_index:to_index], # type: ignore[index] 42 | [self._eos_token_id], 43 | ) 44 | ) 45 | return torch.tensor(data, dtype=torch.long) 46 | 47 | 48 | def get_mmap_dataset(tokenizer: GPT2Tokenizer, filename: str, **kwargs) -> Dataset: 49 | return MMapTextDataset( 50 | filename, 51 | bos_token_id=tokenizer.bos_token_id or tokenizer.cls_token_id, 52 | eos_token_id=tokenizer.eos_token_id or tokenizer.sep_token_id, 53 | **kwargs, 54 | ) 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Mac stuff 132 | .DS_Store 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # staged-training 2 | 3 | In our paper [**Staged Training for Transformer Language Models**](https://arxiv.org/abs/2203.06211), we propose a staged training setup that begins with a small model and incrementally increases the amount of compute used for training by applying a "growth operator" to increase the model depth and width. By initializing each stage with the output of the previous one, the training process effectively re-uses the compute from prior stages and becomes more efficient. 4 | 5 | We release the reproducible code for the growth operators and evaluation scripts here. 6 | 7 | ## Setup 8 | 9 | The scripts in this repository require Python 3.7 or newer. 10 | Once you have a suitable Python environment, first install PyTorch v1.9.0 according to the [official instructions](https://pytorch.org/get-started/previous-versions/#v190). Then run 11 | ``` 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ## Growth Operator 16 | 17 | Our growth operators (width/depth) each take as input the entire training state (including model parameters, optimizer state, learning rate schedule, etc.) and output a new training state from which training continues. 18 | 19 | Please see `cheatsheet.txt` for more examples of how to use the corresponding scripts.
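To build intuition for what the width operator does to an individual parameter tensor, here is a minimal, illustrative sketch of one way to do function-preserving width doubling for a single linear layer. This is our simplification, not the repository's exact implementation (the helper name `double_linear_width` is ours); the actual operator in `gpt_pretrain.py` also grows the embeddings, GPT-2's `Conv1D` weights, and the optimizer state.

```python
import torch

def double_linear_width(weight: torch.Tensor, bias: torch.Tensor):
    """Illustrative width-doubling for one linear layer.

    `weight` has shape (out_features, in_features), as in torch.nn.Linear.
    The duplicated input dimensions are halved so that, when the incoming
    activations are themselves duplicated (because the previous layer was
    also doubled), the grown layer computes the same function as before.
    """
    w = torch.cat([weight, weight], dim=1) / 2.0  # duplicate inputs, halve to preserve the sum
    w = torch.cat([w, w], dim=0)                  # duplicate outputs
    b = torch.cat([bias, bias], dim=0)
    return w, b

# Quick check: the doubled layer reproduces the original output, duplicated along the hidden dim.
w, b = torch.randn(8, 4), torch.randn(8)
x = torch.randn(4)
w2, b2 = double_linear_width(w, b)
assert torch.allclose(torch.cat([w @ x + b] * 2), w2 @ torch.cat([x, x]) + b2, atol=1e-5)
```

Applying a transformation like this consistently across all parameters (and the corresponding optimizer moments) is what lets training resume from an equivalent but wider state, which is what the `--doubling weights` flag in the example below triggers.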
20 | 21 | For example, you can apply the width operator with: 22 | ``` 23 | CUDA_VISIBLE_DEVICES=0,1,2,3 python gpt_pretrain.py \ 24 | --save_prefix final_gpt2_large_div2_width_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \ 25 | --gpu_count -1 \ 26 | --model gpt2 \ 27 | --tokenizer gpt2 \ 28 | --batch_size 4 \ 29 | --grad_accum 32 \ 30 | --lr 0.002006911598778545 \ 31 | --warmup_steps 3000 \ 32 | --train_steps 250000 \ 33 | --val_every 50 \ 34 | --val_batches 50 \ 35 | --fp16 \ 36 | --seqlen 1024 \ 37 | --log_rate 10 \ 38 | --num_workers 4 \ 39 | --size GPT2_large_div2_width \ 40 | --random \ 41 | --resume final_runs/final_gpt2_large_div2_width_check_bs512_lr0.0021_warmup3k_seqlen1024_debug/checkpoint-xxx.ckpt \ 42 | --doubling weights 43 | ``` 44 | 45 | Or the depth operator with: 46 | ``` 47 | CUDA_VISIBLE_DEVICES=0,1,2,3 python gpt_pretrain.py \ 48 | --save_prefix final_gpt2_large_div2_depthx2_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \ 49 | --gpu_count -1 \ 50 | --model gpt2 \ 51 | --tokenizer gpt2 \ 52 | --batch_size 4 \ 53 | --grad_accum 32 \ 54 | --lr 0.002006911598778545 \ 55 | --warmup_steps 3000 \ 56 | --train_steps 250000 \ 57 | --val_every 50 \ 58 | --val_batches 50 \ 59 | --fp16 \ 60 | --seqlen 1024 \ 61 | --log_rate 10 \ 62 | --num_workers 4 \ 63 | --size GPT2_large_div2_depth \ 64 | --random \ 65 | --resume final_runs/final_gpt2_large_div2_depth_check_bs512_lr0.0020_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6499.ckpt \ 66 | --doubling layers 67 | ``` 68 | 69 | ## Evaluation 70 | 71 | Use `eval_wikitext.py` or `eval_lambada.py` to evaluate [GPT-2](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) on one of the supported datasets. For example: 72 | 73 | ```bash 74 | python eval_wikitext.py eval 75 | ``` 76 | 77 | Or using Docker: 78 | 79 | ```bash 80 | docker build -t evaluation:latest . 81 | docker run --rm --gpus all evaluation:latest eval_wikitext.py eval 82 | ``` 83 | 84 | ## Reference 85 | 86 | If you use staged training in your research or wish to refer to the baseline results published here, 87 | please use the following BibTeX entry.
88 | ``` 89 | @misc{shen2022staged, 90 | title={Staged Training for Transformer Language Models}, 91 | author={Sheng Shen and Pete Walsh and Kurt Keutzer and Jesse Dodge and Matthew Peters and Iz Beltagy}, 92 | year={2022}, 93 | eprint={2203.06211}, 94 | archivePrefix={arXiv}, 95 | primaryClass={cs.CL} 96 | } 97 | ``` 98 | -------------------------------------------------------------------------------- /cheatsheet.txt: -------------------------------------------------------------------------------- 1 | # Thresholds are the same, constants are a bit different 2 | threshold_optimality = -0.052 3 | threshold_depth_growth = -0.0575 4 | threshold_width_growth = -0.0475 5 | threshold_depth_width_growth_ours = -0.03 # slightly different from the one you get from the scaling laws 6 | 7 | constant_op_width = 1.776 8 | constant_op_depth = 1.412 9 | constant_op_depth_width = 2.455 10 | 11 | # LR 12 | GPT2_base 0.002132892651963921 13 | GPT2_base_div2_depth 0.0021748863363590465 14 | GPT2_base_div2_width 0.002216880020754172 15 | 16 | 17 | GPT2_large 0.002006911598778545 18 | GPT2_large_div2_width 0.002090898967568796 19 | GPT2_large_div2_depth 0.0020489052831736704 20 | GPT2_large_div4_depth 0.002090898967568796 21 | GPT2_large_div4_width 0.0021748863363590465 22 | 23 | 24 | # Direct Pretrain 4 GPU 25 | CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py \ 26 | --save_prefix final_gpt2_large_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \ 27 | --gpu_count -1 --model gpt2 --tokenizer gpt2 \ 28 | --batch_size 4 --grad_accum 32 --lr 0.002006911598778545 --warmup_steps 3000 \ 29 | --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 \ 30 | --log_rate 10 --num_workers 4 --size GPT2_large --random 31 | 32 | CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py \ 33 | --save_prefix final_gpt2_large_div2_depth_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \ 34 | --gpu_count -1 --model gpt2 --tokenizer gpt2 \ 35 | --batch_size 4 --grad_accum 32 --lr 0.0020489052831736704 --warmup_steps 3000 \ 36 | --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 \ 37 | --log_rate 10 --num_workers 2 --size GPT2_large_div2_depth --random 38 | 39 | 40 | CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py \ 41 | --save_prefix final_gpt2_large_div2_width_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \ 42 | --gpu_count -1 --model gpt2 --tokenizer gpt2 \ 43 | --batch_size 4 --grad_accum 32 --lr 0.002090898967568796 --warmup_steps 3000 \ 44 | --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 \ 45 | --log_rate 10 --num_workers 2 --size GPT2_large_div2_width --random 46 | 47 | 48 | # Use the operator to double the weights 49 | # First, apply the operator to the ckpts 50 | CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py \ 51 | --save_prefix final_gpt2_large_div2_width_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \ 52 | --gpu_count -1 --model gpt2 --tokenizer gpt2 --batch_size 4 --grad_accum 32 --lr 0.002006911598778545 --warmup_steps 3000 \ --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 --log_rate 10 --num_workers 4 \ 53 | --size GPT2_large_div2_width --random \ 54 | --resume final_runs/final_gpt2_large_div2_width_check_bs512_lr0.0021_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6249.ckpt \ 55 | --doubling weights --restart_warmup_steps 200 --restart_steps 3319 \ 56 | --reset_lr_scheduler 57 | 58 | # Second, resume the grown ckpt and set the restart step and re-warmup steps 59 | 
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py \ 60 | --save_prefix final_gpt2_large_div2_depthx2_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \ 61 | --gpu_count -1 --model gpt2 --tokenizer gpt2 \ 62 | --batch_size 4 --grad_accum 32 --lr 0.002006911598778545 --warmup_steps 3000 \ 63 | --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 \ 64 | --log_rate 10 --num_workers 4 --size GPT2_large --random \ 65 | --resume final_runs/final_gpt2_large_div2_width_check_bs512_lr0.0020_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6499.ckpt.doubled_weights \ 66 | --restart_warmup_steps 150 --restart_steps 3319 --reset_lr_scheduler 67 | 68 | # Use the operator to double the layers 69 | # First, apply the operator to the ckpts 70 | CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py \ 71 | --save_prefix final_gpt2_large_div2_depthx2_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \ 72 | --gpu_count -1 --model gpt2 --tokenizer gpt2 \ 73 | --batch_size 4 --grad_accum 32 --lr 0.002006911598778545 --warmup_steps 3000 \ 74 | --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 \ 75 | --log_rate 10 --num_workers 4 --size GPT2_large_div2_depth --random \ 76 | --resume final_runs/final_gpt2_large_div2_depth_check_bs512_lr0.0020_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6499.ckpt \ 77 | --doubling layers --restart_warmup_steps 150 --restart_steps 4449 --reset_lr_scheduler --doubling_layers alternate_id 78 | 79 | # Second, resume the grown ckpt and set the restart step and re-warmup steps 80 | CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/gpt_pretrain.py \ 81 | --save_prefix final_gpt2_large_div2_depthx2_check_bs512_lr0.0020_warmup3k_seqlen1024_debug \ 82 | --gpu_count -1 --model gpt2 --tokenizer gpt2 \ 83 | --batch_size 4 --grad_accum 32 --lr 0.002006911598778545 --warmup_steps 3000 \ 84 | --train_steps 250000 --val_every 50 --val_batches 50 --fp16 --seqlen 1024 \ 85 | --log_rate 10 --num_workers 4 --size GPT2_large --random \ 86 | --resume final_runs/final_gpt2_large_div2_depth_check_bs512_lr0.0020_warmup3k_seqlen1024_debug/checkpoint-epoch=0-step=6499.ckpt.doubled_layer \ 87 | --restart_warmup_steps 150 --restart_steps 4449 --reset_lr_scheduler 88 | 89 | -------------------------------------------------------------------------------- /eval_lambada.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | from torch.utils.data import DataLoader, Dataset 10 | from tqdm import trange 11 | 12 | 13 | from transformers import ( 14 | GPT2Tokenizer, 15 | GPT2LMHeadModel, 16 | GPT2Config, 17 | default_data_collator, 18 | DataCollatorForLanguageModeling, 19 | ) 20 | from torch.utils.data import DataLoader, Dataset, Subset 21 | 22 | model_name = 'gpt2' 23 | enc = GPT2Tokenizer.from_pretrained(model_name) 24 | model = GPT2LMHeadModel.from_pretrained(model_name) 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--path', type=str, default='lambada_test.jsonl', help='location of lambada dataset') 28 | parser.add_argument('--batch', type=int, default=4, help='batch size') 29 | parser.add_argument('--max-batches', type=int, default=0, help='batch size') 30 | parser.add_argument('--ignore-fragments', action='store_true', help="Whether to run training.") 31 | parser.add_argument('--preprocess', action='store_true', help="strip quotes") 32 | parser.add_argument('--jeff_suggestion', 
action='store_true', help="use jeff's suggestion of prepending \n to each example") 33 | parser.add_argument('--dryrun', action='store_true', help="test preprocessing pipeline") 34 | parser.add_argument('--checkpoint-file', default=None, help='location of lambada dataset') 35 | 36 | args = parser.parse_args() 37 | 38 | 39 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 40 | args.device = device 41 | 42 | 43 | model_name = 'gpt2' 44 | enc = GPT2Tokenizer.from_pretrained(model_name) 45 | 46 | checkpoint_file = args.checkpoint_file 47 | 48 | if checkpoint_file is not None: 49 | state_dict = torch.load(checkpoint_file) 50 | state_dict = state_dict['state_dict'] 51 | state_dict = {k.replace("model.", ""): p for k, p in state_dict.items()} 52 | 53 | config = GPT2Config.from_pretrained(model_name) 54 | 55 | # Guess model size. 56 | print(state_dict.keys) 57 | config.n_embd = state_dict["transformer.wte.weight"].shape[1] 58 | config.n_layer = len( 59 | [key for key in state_dict if key.endswith("mlp.c_proj.bias")] 60 | ) 61 | print( 62 | f"Adjusting hidden_size to {config.n_embd}, num_layers to {config.n_layer}" 63 | ) 64 | if config.n_embd == 1536 or config.n_embd == 3072: 65 | config.n_head = 16 66 | else: 67 | config.n_head = 8 68 | #config.n_head = 12 69 | config.n_positions == 1024 70 | if "alibi" in checkpoint_file: 71 | config.alibi_embeddings = True 72 | if "rotary" in checkpoint_file: 73 | config.rotary_embeddings = True 74 | # Initialize model. 75 | model = GPT2LMHeadModel(config) 76 | model.load_state_dict(state_dict, strict=True) 77 | else: 78 | model = GPT2LMHeadModel.from_pretrained(model_name) 79 | model.to(device) 80 | 81 | 82 | def argmax(t): 83 | return int(torch.argmax(t).item()) 84 | 85 | # from https://github.com/openai/gpt-2/issues/131#issuecomment-492786058 86 | def preprocess(text): 87 | text = text.replace("“", '"') 88 | text = text.replace("”", '"') 89 | text = text.replace("''", '"') 90 | text = text.replace("``", '"') 91 | return '\n'+text.strip() 92 | 93 | def score_batch(batch): 94 | """Return number of last-word mismatches in a batch.""" 95 | batch_encoded = [] 96 | lengths = [] 97 | fragments = [] 98 | for line in batch: 99 | line = line.strip() 100 | if args.jeff_suggestion: 101 | line = '\n'+line 102 | line_encoded = enc.encode(line) 103 | encoded_last_word = enc.decode(line_encoded[-1:]).strip() 104 | actual_last_word = line.split()[-1].strip() 105 | if encoded_last_word != actual_last_word: 106 | fragments.append(True) 107 | else: 108 | fragments.append(False) 109 | batch_encoded.append(line_encoded) 110 | 111 | # array is ragged, so pad to turn into rectangular tensor 112 | max_len = max(len(encoded) for encoded in batch_encoded) 113 | batch_padded = [] 114 | for encoded in batch_encoded: 115 | batch_padded.append(encoded+[0]*(max_len - len(encoded))) 116 | lengths.append(len(encoded)) 117 | 118 | batch_padded = torch.tensor(batch_padded) 119 | batch_padded = batch_padded.to(device) 120 | if args.dryrun: 121 | return 0, 1 122 | 123 | # logits, presents = model(batch_padded) 124 | outputs = model(batch_padded) 125 | logits = outputs.logits 126 | errors = 0 127 | total = 0 128 | for i in range(args.batch): 129 | # break on small last batch 130 | if i >= len(batch_padded): 131 | break 132 | last_idx = lengths[i]-1 133 | observed = batch_encoded[i][last_idx] 134 | predicted = argmax(logits[i][last_idx-1]) 135 | if args.ignore_fragments and fragments[i]: 136 | continue 137 | total+=1 138 | errors += 0 if (observed == predicted) else 1 139 | 
140 | return errors, total 141 | 142 | 143 | def main(): 144 | ds_raw = open(f'{args.path}').read() 145 | if args.preprocess: 146 | ds_raw = preprocess(ds_raw) 147 | 148 | ds = ds_raw.strip().split('\n') 149 | 150 | # special handling for jsonl file 151 | lines = [] 152 | if args.path.endswith('.jsonl'): 153 | # special handling for file from Jeff 154 | for line in ds: 155 | # candidate1 = eval(line)['text'] 156 | # lines.append(candidate1) 157 | candidate2 = line[len('{"text": "'):-len('"}')] 158 | candidate2 = f'''"""{candidate2}"""''' 159 | lines.append(eval(candidate2)) 160 | 161 | # lines.append(eval(line)) 162 | #print(line) 163 | # break 164 | # print(line) 165 | # eprint(lines[-1]) 166 | ds = lines 167 | data_loader = DataLoader(ds, batch_size=args.batch, shuffle=False) 168 | 169 | errors = 0 170 | total = 0 171 | for batch in tqdm.tqdm(data_loader): 172 | errors_batch, total_batch = score_batch(batch) 173 | errors += errors_batch 174 | total += total_batch 175 | # if args.max_batches and i>=args.max_batches-1: 176 | # break 177 | 178 | print("Accuracy: %.4f"%(1-errors/total,)) 179 | 180 | 181 | if __name__=='__main__': 182 | main() 183 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /eval_wikitext.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate pretrained GPT-2 models on standard datasets. 3 | 4 | Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm_no_trainer.py. 
5 | """ # noqa: E501 6 | 7 | import math 8 | import os 9 | import random 10 | import shutil 11 | import tempfile 12 | from itertools import islice 13 | from typing import Any, Dict, Optional 14 | 15 | import click 16 | import numpy as np 17 | import torch 18 | from accelerate import Accelerator 19 | from click_help_colors import HelpColorsCommand, HelpColorsGroup 20 | from more_itertools import chunked 21 | from torch.utils.data import DataLoader, Dataset, DistributedSampler, Sampler 22 | from tqdm import tqdm 23 | from transformers import (DataCollatorForLanguageModeling, GPT2Config, 24 | GPT2LMHeadModel, GPT2Tokenizer, 25 | default_data_collator) 26 | from transformers.optimization import AdamW, get_linear_schedule_with_warmup 27 | 28 | from tools.mmap_dataset import get_mmap_dataset 29 | from tools.openwebtext_dataset import get_openwebtext_dataset 30 | from tools.wikitext_dataset import get_wikitext_dataset 31 | 32 | 33 | @click.group( 34 | cls=HelpColorsGroup, 35 | help_options_color="green", 36 | help_headers_color="yellow", 37 | context_settings={"max_content_width": 115}, 38 | ) 39 | def main(): 40 | if torch.cuda.is_available(): 41 | click.echo("CUDA is available :)\n") 42 | else: 43 | click.secho("No CUDA devices available!\n", fg="red") 44 | 45 | 46 | @main.command( 47 | cls=HelpColorsCommand, 48 | help_options_color="green", 49 | help_headers_color="yellow", 50 | context_settings={"max_content_width": 115}, 51 | ) 52 | @click.option("--model-name", default="gpt2") 53 | @click.option( 54 | "--dataset", 55 | default="wikitext2", 56 | type=click.Choice(["wikitext2", "openwebtext", "mmap"]), 57 | show_choices=True, 58 | show_default=True, 59 | help="The dataset to evaluate on.", 60 | ) 61 | @click.option( 62 | "--dataset-path", 63 | default=None, 64 | type=click.Path(exists=True, dir_okay=False, resolve_path=True), 65 | help="Path to the memory-mapped dataset (only valid when --dataset=mmap)", 66 | ) 67 | @click.option( 68 | "--block-size", 69 | default=1024, 70 | show_default=True, 71 | help="""Input texts are blocked together into blocks of this size. 72 | This should probably match the max input size of the model.""", 73 | ) 74 | @click.option( 75 | "--batch-size", 76 | default=32, 77 | show_default=True, 78 | help="The batch size to use for evaluation.", 79 | ) 80 | @click.option( 81 | "--checkpoint-file", 82 | default=None, 83 | type=click.Path(exists=True, dir_okay=False, resolve_path=True), 84 | help="A checkpoint file to load the weights from.", 85 | ) 86 | @click.option( 87 | "--skip-loading-weights", 88 | is_flag=True, 89 | help="Leave the model's weights at their random initialization.", 90 | ) 91 | @click.option( 92 | "--max-steps", 93 | default=None, 94 | type=click.INT, 95 | ) 96 | def eval( 97 | model_name: str, 98 | dataset: str, 99 | dataset_path: Optional[str], 100 | block_size: int, 101 | batch_size: int, 102 | checkpoint_file: Optional[str], 103 | skip_loading_weights: bool, 104 | max_steps: Optional[int], 105 | ): 106 | """ 107 | Evaluate a GPT-2 model on a dataset. 108 | """ 109 | # Validate params. 
110 | if dataset != "mmap" and dataset_path is not None: 111 | raise click.UsageError("'--dataset-path' only valid when '--dataset=mmap'") 112 | if dataset == "mmap" and dataset_path is None: 113 | raise click.UsageError("'--dataset-path' is required for this dataset type") 114 | 115 | click.secho("[1/3] Loading tokenizer and model...", fg="green") 116 | 117 | tokenizer = GPT2Tokenizer.from_pretrained(model_name) 118 | if checkpoint_file is not None: 119 | click.echo(f"Loading checkpoint from {checkpoint_file}") 120 | 121 | # Load state dict. 122 | state_dict = torch.load(checkpoint_file) 123 | state_dict = state_dict["state_dict"] 124 | state_dict = {k.replace("model.", ""): p for k, p in state_dict.items()} 125 | 126 | config = GPT2Config.from_pretrained(model_name) 127 | 128 | # Guess model size. 129 | config.n_embd = state_dict["transformer.wte.weight"].shape[1] 130 | config.n_layer = len( 131 | [key for key in state_dict if key.endswith("mlp.c_proj.bias")] 132 | ) 133 | click.echo( 134 | f"Adjusting hidden_size to {config.n_embd}, num_layers to {config.n_layer}" 135 | ) 136 | if config.n_embd == 1536 or config.n_embd == 3072: 137 | config.n_head = 16 138 | else: 139 | config.n_head = 8 140 | 141 | config.n_positions = 1024 142 | if "alibi" in checkpoint_file: 143 | config.alibi_embeddings = True 144 | if "rotary" in checkpoint_file: 145 | config.rotary_embeddings = True 146 | # Initialize model. 147 | model = GPT2LMHeadModel(config) 148 | if not skip_loading_weights: 149 | model.load_state_dict(state_dict, strict=True) 150 | elif skip_loading_weights: 151 | config = GPT2Config.from_pretrained(model_name) 152 | model = GPT2LMHeadModel(config) 153 | else: 154 | model = GPT2LMHeadModel.from_pretrained(model_name) 155 | 156 | model = model.cuda() 157 | 158 | click.secho("\n[2/3] Preprocessing data...", fg="green") 159 | 160 | dataloader = get_dataloader( 161 | dataset, 162 | tokenizer, 163 | block_size=block_size, 164 | batch_size=batch_size, 165 | dataset_path=dataset_path, 166 | ) 167 | 168 | click.secho("\n[3/3] Evaluating model on data...", fg="green") 169 | 170 | model.eval() 171 | 172 | running_loss = 0.0 173 | total_batches = ( 174 | len(dataloader) if max_steps is None else min([max_steps, len(dataloader)]) 175 | ) 176 | with tqdm( 177 | islice(dataloader, total_batches), desc="Evaluating", total=total_batches 178 | ) as batch_iterator: 179 | for i, batch in enumerate(batch_iterator): 180 | batch = {k: v.cuda() for k, v in batch.items()} 181 | # with torch.inference_mode(): 182 | with torch.no_grad(): 183 | outputs = model(**batch) 184 | 185 | loss = outputs.loss 186 | running_loss += loss.item() 187 | 188 | if i % 50 == 0 or i == total_batches - 1: 189 | mean_loss = running_loss / (i + 1) 190 | ppl = math.exp(mean_loss) 191 | batch_iterator.set_postfix(loss=mean_loss, ppl=ppl) 192 | 193 | mean_loss = running_loss / total_batches 194 | ppl = math.exp(mean_loss) 195 | 196 | click.secho( 197 | f"\nDone! 
Final loss: {mean_loss:.4f} (ppl = {ppl:.4f})", fg="green", bold=True 198 | ) 199 | 200 | 201 | @main.command() 202 | @click.argument( 203 | "train-dataset-path", 204 | type=click.Path(exists=True, dir_okay=False, resolve_path=True), 205 | ) 206 | @click.argument( 207 | "validation-dataset-path", 208 | type=click.Path(exists=True, dir_okay=False, resolve_path=True), 209 | ) 210 | @click.argument( 211 | "log-dir", 212 | type=click.Path(exists=False, dir_okay=True, resolve_path=True), 213 | ) 214 | @click.option( 215 | "--block-size", 216 | default=1024, 217 | show_default=True, 218 | help="""Input texts are blocked together into blocks of this size. 219 | This should probably match the max input size of the model.""", 220 | ) 221 | @click.option( 222 | "--batch-size", 223 | default=32, 224 | show_default=True, 225 | help="The batch size to use for training and validation.", 226 | ) 227 | @click.option( 228 | "--grad-accum", 229 | default=1, 230 | show_default=True, 231 | help="The number of gradient accumulation steps per update.", 232 | ) 233 | @click.option( 234 | "--num-heads", 235 | default=4, 236 | show_default=True, 237 | help="The number of attention heads.", 238 | ) 239 | @click.option( 240 | "--num-layers", 241 | default=4, 242 | show_default=True, 243 | help="The number of transformer layers.", 244 | ) 245 | @click.option( 246 | "--hidden-size", 247 | default=256, 248 | show_default=True, 249 | help="The hidden size of the model.", 250 | ) 251 | @click.option( 252 | "--lr", 253 | default=None, 254 | type=click.FLOAT, 255 | show_default=True, 256 | help="The learning rate. Defaults to '0.003239 - 0.0001395 log(N)'.", 257 | ) 258 | @click.option( 259 | "--adam-epsilon", 260 | default=1e-6, 261 | show_default=True, 262 | ) 263 | @click.option( 264 | "--adam-beta1", 265 | default=0.9, 266 | show_default=True, 267 | ) 268 | @click.option( 269 | "--adam-beta2", 270 | default=0.95, 271 | show_default=True, 272 | ) 273 | @click.option( 274 | "--warmup-steps", 275 | default=3000, 276 | show_default=True, 277 | ) 278 | @click.option( 279 | "--train-steps", 280 | default=100000, 281 | show_default=True, 282 | ) 283 | @click.option( 284 | "--validation-steps", 285 | default=50, 286 | show_default=True, 287 | ) 288 | @click.option( 289 | "--validate-every", 290 | default=100, 291 | show_default=True, 292 | ) 293 | @click.option( 294 | "--checkpoint-every", 295 | default=100, 296 | show_default=True, 297 | ) 298 | @click.option( 299 | "--wandb-entity", 300 | default="allenai-team1", 301 | show_default=True, 302 | ) 303 | @click.option( 304 | "--wandb-project", 305 | default="staged-training", 306 | show_default=True, 307 | ) 308 | @click.option( 309 | "--amp", is_flag=True, help="""Train with automatic mixed-precision enabled.""" 310 | ) 311 | @click.option( 312 | "--recover", is_flag=True, help="""Restart training from a previous run.""" 313 | ) 314 | @click.option( 315 | "--recover-from", 316 | type=click.Path(exists=True, dir_okay=True, resolve_path=True), 317 | help="""Log directory to recover from if different.""", 318 | ) 319 | @click.option( 320 | "--init-seed", 321 | default=42, 322 | show_default=True, 323 | ) 324 | @click.option( 325 | "--data-seed", 326 | default=42, 327 | show_default=True, 328 | ) 329 | @click.option( 330 | "--wandb-tags", help="""A comma-separated list of tags to assign to the W&B run.""" 331 | ) 332 | def train( 333 | train_dataset_path: str, 334 | validation_dataset_path: str, 335 | log_dir: str, 336 | block_size: int, 337 | batch_size: int, 338 | grad_accum: 
int, 339 | num_heads: int, 340 | num_layers: int, 341 | hidden_size: int, 342 | lr: float, 343 | adam_epsilon: float, 344 | adam_beta1: float, 345 | adam_beta2: float, 346 | warmup_steps: int, 347 | train_steps: int, 348 | validation_steps: int, 349 | validate_every: int, 350 | checkpoint_every: int, 351 | wandb_entity: str, 352 | wandb_project: str, 353 | amp: bool, 354 | recover: bool, 355 | recover_from: Optional[str], 356 | init_seed: int, 357 | data_seed: int, 358 | wandb_tags: Optional[str], 359 | ): 360 | """ 361 | Train a GPT-2 model on C4. 362 | """ 363 | accelerator = Accelerator(fp16=amp) 364 | device = accelerator.device 365 | is_distributed = accelerator.num_processes > 1 366 | 367 | state_path = os.path.join( 368 | log_dir, f"state_worker_{accelerator.local_process_index}.pt" 369 | ) 370 | 371 | # Check log_dir. 372 | initial_state: Optional[Dict[str, Any]] = None 373 | if recover: 374 | if recover_from is not None: 375 | # Copy over contents to log_dir 376 | assert os.path.isdir(recover_from) 377 | assert os.path.isfile( 378 | os.path.join( 379 | recover_from, f"state_worker_{accelerator.local_process_index}.pt" 380 | ) 381 | ) 382 | if accelerator.is_local_main_process: 383 | assert not os.path.exists(log_dir) or not os.listdir(log_dir) 384 | shutil.copytree( 385 | recover_from, 386 | log_dir, 387 | # dirs_exist_ok=True, only available for python >= 3.8 388 | ) 389 | accelerator.wait_for_everyone() 390 | assert os.path.isdir(log_dir) 391 | assert os.path.isfile(state_path) 392 | click.echo( 393 | f"[Worker {accelerator.local_process_index}] Loading training state from {state_path}" 394 | ) 395 | initial_state = torch.load(state_path) 396 | else: 397 | assert not os.path.exists(log_dir) or not os.listdir(log_dir) 398 | if accelerator.is_local_main_process: 399 | os.makedirs(log_dir, exist_ok=True) 400 | 401 | if accelerator.is_local_main_process: 402 | click.echo(f"Training on {accelerator.num_processes} devices") 403 | click.secho( 404 | "\n[1/3] Initializing tokenizer, model, and optimizer...", fg="green" 405 | ) 406 | 407 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 408 | 409 | config = GPT2Config.from_pretrained("gpt2") 410 | config.n_head = num_heads 411 | config.n_layer = num_layers 412 | config.n_embd = hidden_size 413 | 414 | # Set random seeds for model initialization. 415 | set_seeds(init_seed) 416 | 417 | model = GPT2LMHeadModel(config) 418 | 419 | total_params = 0 420 | for name, param in model.named_parameters(): 421 | # Ignore embedding matrix when calculating size. 422 | if name == "transformer.wte.weight" or name == "transformer.wpe.weight": 423 | continue 424 | total_params += param.numel() 425 | if accelerator.is_local_main_process: 426 | click.echo(f"Total non-embedding parameters: {total_params:,}") 427 | 428 | if lr is None: 429 | lr = 0.003239 - 0.0001395 * np.log(total_params) 430 | 431 | optimizer = AdamW( 432 | model.parameters(), 433 | lr=lr, 434 | eps=adam_epsilon, 435 | betas=(adam_beta1, adam_beta2), 436 | correct_bias=False, 437 | ) 438 | scheduler = get_linear_schedule_with_warmup( 439 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=train_steps 440 | ) 441 | 442 | if accelerator.is_local_main_process: 443 | click.secho("\n[2/3] Loading data...", fg="green") 444 | 445 | # Set random seeds for data shuffling. 
446 | set_seeds(data_seed) 447 | 448 | train_dataloader = get_dataloader( 449 | "mmap", 450 | tokenizer, 451 | dataset_path=train_dataset_path, 452 | block_size=block_size, 453 | batch_size=batch_size, 454 | shuffle=True, 455 | is_distributed=is_distributed, 456 | seed=data_seed, 457 | ) 458 | 459 | validation_dataloader = get_dataloader( 460 | "mmap", 461 | tokenizer, 462 | dataset_path=validation_dataset_path, 463 | block_size=block_size, 464 | batch_size=batch_size, 465 | is_distributed=is_distributed, 466 | seed=data_seed, 467 | ) 468 | 469 | # NOTE: We don't call `prepare()` on the dataloaders because that causes a memory leak, 470 | # and it's not necessary anyway. 471 | model, optimizer = accelerator.prepare(model, optimizer) 472 | 473 | validation_steps = min([len(validation_dataloader), validation_steps]) 474 | 475 | # Load state. 476 | if initial_state is not None: 477 | optimizer.load_state_dict(initial_state["optimizer"]) 478 | scheduler.load_state_dict(initial_state["scheduler"]) 479 | model.load_state_dict(initial_state["model"]) 480 | 481 | wandb_run_id: Optional[str] = None 482 | if accelerator.is_main_process: 483 | import wandb 484 | 485 | if initial_state is not None: 486 | wandb_run_id = initial_state["wandb_run_id"] 487 | else: 488 | wandb_run_id = wandb.util.generate_id() 489 | 490 | wandb.init( 491 | id=wandb_run_id, 492 | dir=log_dir, 493 | entity=wandb_entity, 494 | resume="auto", 495 | project=wandb_project, 496 | tags=None if not wandb_tags else wandb_tags.split(","), 497 | config={ 498 | "init_seed": init_seed, 499 | "data_seed": data_seed, 500 | "total_params": total_params, 501 | "learning_rate": lr, 502 | "adam_beta1": adam_beta1, 503 | "adam_beta2": adam_beta2, 504 | "adam_epsilon": adam_epsilon, 505 | "batch_size": batch_size, 506 | "grad_accum": grad_accum, 507 | "num_processes": accelerator.num_processes, 508 | "effective_batch_size": batch_size 509 | * grad_accum 510 | * accelerator.num_processes, 511 | }, 512 | ) 513 | 514 | accelerator.wait_for_everyone() 515 | 516 | if accelerator.is_local_main_process: 517 | click.secho("\n[3/3] Training...", fg="green") 518 | 519 | model.train() 520 | val_loss: Optional[float] = None 521 | training_batches = enumerate( 522 | islice( 523 | chunked(cycle_through_epochs(train_dataloader, is_distributed), grad_accum), 524 | train_steps, 525 | ) 526 | ) 527 | 528 | # Catch data loader up to where we left off before. 529 | if initial_state is not None: 530 | click.echo( 531 | f"[Worker {accelerator.local_process_index}] " 532 | f"Catching data loader up to step {initial_state['training_steps']}..." 
533 | ) 534 | training_steps = initial_state["training_steps"] 535 | for step, batch in training_batches: 536 | del batch 537 | if step >= training_steps - 1: 538 | break 539 | accelerator.wait_for_everyone() 540 | 541 | with tqdm( 542 | training_batches, 543 | desc="Training", 544 | initial=0 if initial_state is None else initial_state["training_steps"], 545 | total=train_steps, 546 | disable=not accelerator.is_local_main_process, 547 | ) as train_batch_iterator: 548 | for step, batch in train_batch_iterator: 549 | 550 | def save_state(): 551 | temp_state_file = tempfile.NamedTemporaryFile( 552 | "w+b", dir=log_dir, delete=False, suffix="pt" 553 | ) 554 | try: 555 | torch.save( 556 | { 557 | "optimizer": optimizer.state_dict(), 558 | "scheduler": scheduler.state_dict(), 559 | "model": model.state_dict(), 560 | "wandb_run_id": wandb_run_id, 561 | "training_steps": step + 1, 562 | }, 563 | temp_state_file.name, 564 | ) 565 | temp_state_file.close() 566 | os.replace(temp_state_file.name, state_path) 567 | finally: 568 | if os.path.exists(temp_state_file.name): 569 | os.remove(temp_state_file.name) 570 | 571 | optimizer.zero_grad() 572 | batch_loss = 0.0 573 | batch_ppl: Optional[float] = None 574 | for micro_batch in batch: 575 | # Move tensors to right device. 576 | micro_batch = {k: v.to(device) for k, v in micro_batch.items()} 577 | 578 | # Get loss. 579 | outputs = model(**micro_batch) 580 | micro_batch_loss = outputs.loss / len(batch) 581 | batch_loss += micro_batch_loss.detach().item() 582 | 583 | # Calculate gradients. 584 | accelerator.backward(micro_batch_loss) 585 | 586 | # Clean up. 587 | del micro_batch 588 | del outputs 589 | del micro_batch_loss 590 | 591 | del batch 592 | 593 | # Take step. 594 | optimizer.step() 595 | scheduler.step() 596 | 597 | should_log_this_step = step % 10 == 0 or step == train_steps - 1 598 | should_checkpoint_this_step = step > 0 and step % checkpoint_every == 0 599 | should_validate_this_step = ( 600 | step > 0 and step % validate_every == 0 601 | ) or step == train_steps - 1 602 | 603 | # Gather average loss across all workers. 604 | if should_log_this_step or should_validate_this_step: 605 | batch_loss = ( 606 | accelerator.gather( 607 | torch.tensor(batch_loss, device=device).unsqueeze(0) 608 | ) 609 | .mean() 610 | .item() 611 | ) 612 | batch_ppl = math.exp(batch_loss) # type: ignore[arg-type] 613 | 614 | # Update progress bar and log to W&B. 615 | if accelerator.is_local_main_process and should_log_this_step: 616 | if val_loss is not None: 617 | train_batch_iterator.set_postfix( 618 | batch_loss=batch_loss, 619 | batch_ppl=batch_ppl, 620 | val_loss=val_loss, 621 | val_ppl=math.exp(val_loss), 622 | ) 623 | else: 624 | train_batch_iterator.set_postfix( 625 | batch_loss=batch_loss, batch_ppl=batch_ppl 626 | ) 627 | 628 | if accelerator.is_main_process: 629 | wandb.log( 630 | { 631 | "batch_loss": batch_loss, 632 | "batch_ppl": batch_ppl, 633 | "lr": optimizer.param_groups[0]["lr"], 634 | }, 635 | step=step, 636 | ) 637 | 638 | # Checkpoint. 639 | if should_checkpoint_this_step: 640 | save_state() 641 | 642 | # Validate. 643 | if should_validate_this_step: 644 | # Prepare model for validation. 645 | model.eval() 646 | optimizer.zero_grad() # Not strictly necessary. 
647 | 648 | running_loss = 0.0 649 | with tqdm( 650 | islice(validation_dataloader, validation_steps), 651 | desc="Validating", 652 | total=validation_steps, 653 | leave=False, 654 | disable=not accelerator.is_local_main_process, 655 | ) as val_batch_iterator: 656 | for val_step, val_batch in enumerate(val_batch_iterator): 657 | # Move tensors to right device. 658 | val_batch = {k: v.to(device) for k, v in val_batch.items()} 659 | 660 | # Get loss. 661 | with torch.inference_mode(): 662 | outputs = model(**val_batch) 663 | loss = outputs.loss 664 | 665 | running_loss += loss.item() 666 | val_loss = running_loss / (val_step + 1) 667 | 668 | # Update progress bar. 669 | if accelerator.is_local_main_process and val_step % 10 == 0: 670 | val_batch_iterator.set_postfix( 671 | loss=val_loss, ppl=math.exp(val_loss) 672 | ) 673 | 674 | # Clean up. 675 | del val_batch 676 | del outputs 677 | del loss 678 | 679 | # Average loss across all workers. 680 | val_loss = ( 681 | accelerator.gather( 682 | torch.tensor(val_loss, device=device).unsqueeze(0) 683 | ) 684 | .mean() 685 | .item() 686 | ) 687 | 688 | # Reset model to train mode. 689 | model.train() 690 | 691 | # Update progress bar again with validation stats and log to W&B. 692 | val_ppl = math.exp(val_loss) # type: ignore[arg-type] 693 | if accelerator.is_local_main_process: 694 | train_batch_iterator.set_postfix( 695 | batch_loss=batch_loss, 696 | batch_ppl=batch_ppl, 697 | val_loss=val_loss, 698 | val_ppl=val_ppl, 699 | ) 700 | if accelerator.is_main_process: 701 | wandb.log({"val_loss": val_loss, "val_ppl": val_ppl}, step=step) 702 | 703 | if accelerator.is_main_process: 704 | wandb.finish() 705 | 706 | click.secho("\nDone!", fg="green", bold=True) 707 | 708 | 709 | def get_dataloader( 710 | dataset: str, 711 | tokenizer: GPT2Tokenizer, 712 | *, 713 | block_size: int = 1024, 714 | batch_size: int = 32, 715 | dataset_path: Optional[str] = None, 716 | shuffle: bool = False, 717 | is_distributed: bool = False, 718 | seed: int = 0, 719 | ) -> DataLoader: 720 | dataset_object: Dataset 721 | collator = default_data_collator 722 | if dataset == "wikitext2": 723 | dataset_object = get_wikitext_dataset( 724 | tokenizer, 725 | split="test", 726 | block_size=block_size, 727 | num_workers=1, 728 | ) 729 | elif dataset == "openwebtext": 730 | dataset_object = get_openwebtext_dataset( 731 | tokenizer, 732 | block_size=block_size, 733 | num_workers=8, 734 | ) 735 | elif dataset == "mmap": 736 | assert dataset_path is not None 737 | dataset_object = get_mmap_dataset( 738 | tokenizer, dataset_path, chunk_size=block_size 739 | ) 740 | collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 741 | else: 742 | raise ValueError(f"Unexpected dataset '{dataset}'") 743 | 744 | sampler: Optional[Sampler] = ( 745 | DistributedSampler(dataset_object, shuffle=shuffle, seed=seed) 746 | if is_distributed 747 | else None 748 | ) 749 | 750 | dataloader: DataLoader = DataLoader( 751 | dataset_object, 752 | collate_fn=collator, 753 | batch_size=batch_size, 754 | shuffle=shuffle if sampler is None else False, 755 | sampler=sampler, 756 | ) 757 | 758 | return dataloader 759 | 760 | 761 | def cycle_through_epochs(dataloader: DataLoader, is_distributed: bool): 762 | epoch = 0 763 | while True: 764 | if is_distributed and isinstance(dataloader.sampler, DistributedSampler): 765 | dataloader.sampler.set_epoch(epoch) 766 | for batch in dataloader: 767 | yield batch 768 | epoch += 1 769 | 770 | 771 | def set_seeds(seed): 772 | random.seed(seed) 773 | 
np.random.seed(seed) 774 | torch.manual_seed(seed) 775 | if torch.cuda.is_available(): 776 | torch.cuda.manual_seed_all(seed) 777 | 778 | 779 | if __name__ == "__main__": 780 | main() 781 | -------------------------------------------------------------------------------- /gpt_pretrain.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import re 5 | import random 6 | import logging 7 | import numpy as np 8 | import math 9 | from tqdm import tqdm 10 | import time 11 | import torch 12 | from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForCausalLM 13 | from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel 14 | from transformers import AutoConfig 15 | from transformers import DataCollatorForLanguageModeling 16 | from transformers.optimization import AdamW, get_linear_schedule_with_warmup 17 | 18 | from torch.utils.data import Dataset, DataLoader 19 | import pytorch_lightning as ptl 20 | from pytorch_lightning.trainer.training_loop import TrainLoop 21 | from pytorch_lightning.loggers.test_tube import TestTubeLogger 22 | from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor 23 | from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector 24 | from pytorch_lightning.utilities import AMPType, rank_zero_warn 25 | from pytorch_lightning.utilities.exceptions import MisconfigurationException 26 | from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS 27 | from torch.optim.lr_scheduler import LambdaLR 28 | import copy 29 | 30 | import multiprocessing 31 | try: 32 | from apex import amp 33 | except ImportError: 34 | amp = None 35 | 36 | import tensorflow as tf 37 | logging.basicConfig(level=logging.INFO) 38 | logger = logging.getLogger(__name__) 39 | 40 | try: 41 | import torch_xla.core.xla_model as xm 42 | except ImportError: 43 | XLA_AVAILABLE = False 44 | else: 45 | XLA_AVAILABLE = True 46 | 47 | # =======restart the linear warmup strategy with linear warmup========== 48 | def get_restart_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1, restart_warmup_steps=0, restart_steps=0): 49 | """ 50 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after 51 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. 52 | 53 | Args: 54 | optimizer (:class:`~torch.optim.Optimizer`): 55 | The optimizer for which to schedule the learning rate. 56 | num_warmup_steps (:obj:`int`): 57 | The number of steps for the warmup phase. 58 | num_training_steps (:obj:`int`): 59 | The total number of training steps. 60 | last_epoch (:obj:`int`, `optional`, defaults to -1): 61 | The index of the last epoch when resuming training, will be modified. 62 | restart_warmup_steps: 63 | the restart_warmup_steps should be set last_epoch + restart_warmup_steps; 64 | 65 | Return: 66 | :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
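    Example (an illustrative sketch only -- `model`, the learning rate, and the step counts
    below are made-up placeholder values, not ones used elsewhere in this script)::

        >>> optimizer = AdamW(model.parameters(), lr=1e-4)
        >>> scheduler = get_restart_linear_schedule_with_warmup(
        ...     optimizer, num_warmup_steps=1000, num_training_steps=100000,
        ...     restart_steps=50000, restart_warmup_steps=500)
        >>> # call scheduler.step() once after every optimizer.step()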
67 | """ 68 | 69 | def lr_lambda(current_step: int): 70 | 71 | if restart_steps != 0 and restart_warmup_steps != 0 \ 72 | and current_step < restart_steps + restart_warmup_steps \ 73 | and current_step >= restart_steps: 74 | assert current_step >= restart_steps 75 | 76 | # pre-warmup + restart-warmup 77 | if current_step < num_warmup_steps: 78 | return float(current_step - restart_steps) / float(max(1, restart_warmup_steps)) * float(restart_steps+restart_warmup_steps) / float(max(1, num_warmup_steps)) 79 | else: 80 | return float(current_step - restart_steps) / float(max(1, restart_warmup_steps)) * float(num_training_steps - restart_steps-restart_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 81 | 82 | if current_step < num_warmup_steps: 83 | return float(current_step) / float(max(1, num_warmup_steps)) 84 | return max( 85 | 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) 86 | ) 87 | 88 | return LambdaLR(optimizer, lr_lambda, last_epoch) 89 | 90 | # weights growth operator ================= 91 | def double_split_matrix_weight(x, is_grad, is_avg_sq): 92 | embed_dim = x.shape[0] 93 | split_dim = x.shape[1] // x.shape[0] 94 | y_shape = [2 * i for i in x.shape] 95 | y = x.new_zeros(*y_shape) 96 | 97 | for split_idx in range(split_dim): 98 | start_idx, end_idx = split_idx * x.shape[0], (1+split_idx) * x.shape[0] 99 | y_split = double_matrix_weight( x[:,start_idx:end_idx], is_grad, is_avg_sq ) 100 | y[:,start_idx*2:end_idx*2] = y_split.detach().clone() 101 | 102 | return y 103 | 104 | def double_split_bias(x, embed_dim, is_grad, is_avg_sq): 105 | split_dim = x.shape[0] // embed_dim 106 | y_shape = [2 * i for i in x.shape] 107 | y = x.new_zeros(*y_shape) 108 | 109 | for split_idx in range(split_dim): 110 | start_idx, end_idx = split_idx * embed_dim, (1+split_idx) * embed_dim 111 | y[start_idx*2:start_idx*2+embed_dim] = x[start_idx:end_idx].detach().clone() 112 | y[start_idx*2+embed_dim:end_idx*2] = x[start_idx:end_idx].detach().clone() 113 | 114 | if is_grad: 115 | y /= 2.0 if not is_avg_sq else 4.0 116 | 117 | return y 118 | 119 | def double_matrix_weight(x, is_grad, is_avg_sq): 120 | # x = (n, m), returns y = (2 * n, 2 * m), used for FF layers 121 | y_shape = [2 * i for i in x.shape] 122 | y = x.new_zeros(*y_shape) 123 | 124 | x_shape = x.shape 125 | y[:x_shape[0], :x_shape[1]] = x.detach().clone() 126 | y[-x_shape[0]:, -x_shape[1]:] = x.detach().clone() 127 | if is_grad: 128 | y /= 2.0 if not is_avg_sq else 4.0 129 | 130 | 131 | return y 132 | 133 | 134 | def double_bias(x, is_grad, is_avg_sq): 135 | # x = (n, ), returns y = (2 * n, ), used for bias weights 136 | y_shape = [2 * i for i in x.shape] 137 | y = x.new_zeros(*y_shape) 138 | 139 | x_shape = x.shape 140 | 141 | y[:x_shape[0]] = x.detach().clone() 142 | y[-x_shape[0]:] = x.detach().clone() 143 | if is_grad: 144 | y /= 2.0 if not is_avg_sq else 4.0 145 | 146 | return y 147 | 148 | 149 | def double_embedding(x, is_grad, is_avg_sq): 150 | # x = (vocab size, M), returns y = (vocab size, 2 * M), used for embedding layers 151 | y_shape = [i for i in x.shape] 152 | y_shape[1] *= 2 153 | y = x.new_zeros(*y_shape) 154 | 155 | x_shape = x.shape 156 | y[:, :x_shape[1]] = x.detach().clone() 157 | y[:, -x_shape[1]:] = x.detach().clone() 158 | if is_grad: 159 | y /= 2.0 if not is_avg_sq else 4.0 160 | # if args.noise_std is not None and args.noise_std != 0.0: 161 | # y += torch.normal(mean=0, std=args.noise_std, size=y.shape) 162 | return y 163 | 164 | 165 | def 
double_param(key, weight, is_double_embedding, is_grad, is_avg_sq):
166 |     if 'lm_head' in key:  # for roberta
167 |         # the lm_head is a linear layer followed by the softmax layer
168 |         if 'dense' in key:
169 |             # this is the linear layer - it needs to be expanded like the other linear layers
170 |             if 'weight' in key:
171 |                 return double_matrix_weight(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
172 |             elif 'bias' in key:
173 |                 return double_bias(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
174 | 
175 |         elif 'layer_norm' in key or 'ln' in key:
176 |             # layer norm is weight * (x - mean)/std + bias, so both weight and bias need to be divided by 2.0
177 |             new_weight = double_bias(weight, is_grad=False, is_avg_sq=False)
178 |             if not is_grad:
179 |                 new_weight /= 2.0
180 |             return new_weight
181 |         elif 'weight' in key:
182 |             return double_embedding(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
183 |         elif 'bias' in key:
184 |             # this is the bias parameter added for the final softmax logits, shape (vocab_size, )
185 |             return weight
186 | 
187 |     elif 'pooler' in key:
188 |         # this pooler is probably unused without next-sentence prediction in BERT, but we double it anyway
189 |         if 'weight' in key:
190 |             return double_matrix_weight(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
191 |         elif 'bias' in key:
192 |             return double_bias(weight, is_grad=is_grad, is_avg_sq=is_avg_sq)
193 | 
194 |     elif 'cls' in key:
195 |         # the masked LM head.
196 |         # in BERT it is top-layer activations -> dense with the same hidden dim -> activation -> layer norm -> decoder,
197 |         # where the decoder is a linear layer that shares its weights with the word embeddings plus a new bias.
198 |         # To keep the loss unchanged, we want the logit outputs to remain the same - we can do that by duplicating
199 |         # the word embeddings / bias and dividing the input by 2. We accomplish the division by two by modifying the
200 |         # layer norm parameters right before prediction.
201 | # 202 | # cls.predictions.bias torch.Size([30522]) 203 | # cls.predictions.transform.dense.weight torch.Size([768, 768]) 204 | # cls.predictions.transform.dense.bias torch.Size([768]) 205 | # cls.predictions.transform.LayerNorm.weight torch.Size([768]) 206 | # cls.predictions.transform.LayerNorm.bias torch.Size([768]) 207 | # cls.predictions.decoder.weight torch.Size([30522, 768]) 208 | # cls.predictions.decoder.bias torch.Size([30522]) 209 | if key.endswith('cls.predictions.bias') or key.endswith('cls.predictions.decoder.bias'): 210 | # these are size(vocab) and remain unchanged 211 | return weight 212 | elif key.endswith('cls.predictions.transform.dense.bias'): 213 | return double_bias(weight, is_grad=is_grad, is_avg_sq=is_avg_sq) 214 | elif key.endswith('cls.predictions.transform.dense.weight'): 215 | return double_matrix_weight(weight, is_grad=is_grad, is_avg_sq=is_avg_sq) 216 | elif key.endswith('cls.predictions.decoder.weight'): 217 | return double_embedding(weight, is_grad=is_grad, is_avg_sq=is_avg_sq) 218 | elif 'LayerNorm' in key: 219 | # layer norm is weight * (x - mean)/std + bias, so need to divide both weight and bias by 2.0 220 | new_weight = double_bias(weight, is_grad=False, is_avg_sq=False) 221 | if not is_grad: 222 | new_weight /= 2.0 223 | return new_weight 224 | 225 | elif 'word_embeddings' in key or 'position_embeddings' in key or 'token_type_embeddings' in key \ 226 | or "wte.weight" in key or "wpe.weight" in key: 227 | if is_double_embedding: 228 | return double_embedding(weight, is_grad=is_grad, is_avg_sq=is_avg_sq) 229 | else: 230 | return weight.detach().clone() 231 | elif "masked_bias" in key or ("attn.bias" in key and len(weight.shape) != 1): 232 | return weight.detach().clone() 233 | elif "c_attn.weight" in key: 234 | return double_split_matrix_weight(weight, is_grad=is_grad, is_avg_sq=is_avg_sq) 235 | elif "c_attn.bias" in key: 236 | # TODO: this is hacked for GPT2 237 | return double_split_bias(weight, embed_dim=weight.shape[0] // 3, is_grad=is_grad, is_avg_sq=is_avg_sq) 238 | elif 'query.weight' in key or 'key.weight' in key or 'value.weight' in key or 'dense.weight' in key \ 239 | or "c_proj.weight" in key or "c_fc.weight" in key: 240 | return double_matrix_weight(weight, is_grad=is_grad, is_avg_sq=is_avg_sq) 241 | elif "ln_f" in key: 242 | new_weight = double_bias(weight, is_grad=False, is_avg_sq=False) 243 | if not is_grad: 244 | new_weight /= 2.0 245 | return new_weight 246 | elif 'LayerNorm' in key or 'bias' in key or 'ln' in key: 247 | return double_bias(weight, is_grad=is_grad, is_avg_sq=is_avg_sq) 248 | elif 'position_ids' in key: 249 | return weight 250 | 251 | # Not found 252 | print(key) 253 | import ipdb; ipdb.set_trace() 254 | # raise ValueError(key, weight.shape) 255 | 256 | 257 | def double_state_dict(old_state_dict, is_double_embedding): 258 | new_state_dict = {} 259 | for key, weight in old_state_dict.items(): 260 | new_state_dict[key] = double_param(key, weight, is_double_embedding=is_double_embedding, is_grad=False, is_avg_sq=False) 261 | return new_state_dict 262 | 263 | # depth growth operator 264 | def deep_split_matrix_weight(x, is_identical, is_grad, is_avg_sq): 265 | if not is_identical: 266 | return x.detach().clone() 267 | 268 | embed_dim = x.shape[0] 269 | split_dim = x.shape[1] // x.shape[0] 270 | y_shape = [i for i in x.shape] 271 | y = x.new_zeros(*y_shape) 272 | 273 | for split_idx in range(split_dim): 274 | start_idx, end_idx = split_idx * x.shape[0], (1+split_idx) * x.shape[0] 275 | y_split = deep_matrix_weight( 
x[:,start_idx:end_idx], is_identical, is_grad, is_avg_sq ) 276 | y[:,start_idx:end_idx] = y_split.detach().clone() 277 | 278 | return y 279 | 280 | def deep_matrix_weight(x, is_identical, is_grad, is_avg_sq): 281 | # x = (n, m), returns y = (2 * n, 2 * m), used for FF layers 282 | if is_identical: 283 | y = torch.zeros_like(x) 284 | if len(y.shape) > 1: 285 | y.fill_diagonal_(1) 286 | return y 287 | else: 288 | return x.detach().clone() 289 | 290 | 291 | def deep_bias(x, is_identical, is_grad, is_avg_sq): 292 | # x = (n, ), returns y = (2 * n, ), used for bias weights 293 | if is_identical: 294 | return torch.zeros_like(x) 295 | else: 296 | return x.detach().clone() 297 | 298 | def deep_param(key, weight, is_identical, is_grad, is_avg_sq): 299 | if "c_attn.weight" in key: 300 | return deep_split_matrix_weight(weight, is_identical=is_identical, is_grad=is_grad, is_avg_sq=is_avg_sq) 301 | elif 'weight' in key: 302 | return deep_matrix_weight(weight, is_identical=is_identical, is_grad=is_grad, is_avg_sq=is_avg_sq) 303 | elif 'bias' in key: 304 | return deep_bias(weight, is_identical=is_identical, is_grad=is_grad, is_avg_sq=is_avg_sq) 305 | 306 | def deep_state_dict(old_state_dict, map_positions, is_identical): 307 | # how to insert layers: direct copy, identical copy 308 | # operator over the blocks: hacked for GPT-3 309 | new_state_dict = {} 310 | for key, weight in old_state_dict.items(): 311 | if map_positions.get( key ): 312 | for (new_key, new_key_copy_flag) in map_positions.get( key ): 313 | # print( new_key_copy_flag, is_identical, new_key, key ) 314 | new_state_dict[new_key] = deep_param(key, weight, is_identical=new_key_copy_flag and is_identical, is_grad=False, is_avg_sq=False) 315 | else: 316 | new_state_dict[key] = weight.detach().clone() 317 | 318 | return new_state_dict 319 | 320 | 321 | def test_double_matrix_weight(): 322 | weight = torch.rand(52, 88) 323 | x = torch.rand(88, 1) 324 | reset_model_opt_copy=True 325 | weight2 = double_matrix_weight(weight, is_grad=False, is_avg_sq=False, reset_model_opt_copy=reset_model_opt_copy) 326 | y = torch.matmul(weight, x) 327 | y2 = torch.matmul(weight2, torch.cat([x, x], dim=0)) 328 | print(torch.allclose(y, y2[:52], atol=1e-05, rtol=1e-03)) 329 | assert torch.abs(y - y2[:52]).max() + torch.abs(y - y2[-52:]).max() < 1e-4 330 | 331 | x = torch.rand(1, 11) 332 | c_attn = torch.rand(11, 11*3) 333 | print(len(torch.matmul(x, c_attn).split(11, dim=1))) 334 | y0, y1, y2 = torch.matmul(x, c_attn).split(11, dim=1) 335 | 336 | c_attn2 = double_split_matrix_weight(c_attn, is_grad=False, is_avg_sq=False, reset_model_opt_copy=reset_model_opt_copy) 337 | 338 | y00, y11, y22 = torch.matmul(torch.cat([x, x], dim=1), c_attn2).split(11*2, dim=1) 339 | 340 | allcose = torch.allclose(y0, y00[:, :11], atol=1e-05, rtol=1e-03) 341 | print("reset_model_opt_copy", reset_model_opt_copy, allcose, y0.sum(), y00[:, :11].sum(), y00[:, 11:].sum()) 342 | if not allcose: 343 | import ipdb; ipdb.set_trace() 344 | # import ipdb; ipdb.set_trace() 345 | 346 | 347 | def test_double_gradients(): 348 | test_double_matrix_weight() 349 | # config = AutoConfig.from_pretrained('roberta-base') 350 | # config.hidden_size = 4 351 | # config.intermediate_size = 16 352 | # config.max_position_embeddings = 8 353 | # config.num_attention_heads = 1 354 | # config.num_hidden_layers = 1 355 | # config.vocab_size = 6 356 | 357 | # config.attention_probs_dropout_prob = 0 358 | # config.hidden_dropout_prob = 0 359 | # model = RobertaForMaskedLM(config=config) 360 | 361 | from gpt_pretrain 
import double_weights, double_param 362 | from transformers import AutoConfig, RobertaForMaskedLM, AutoModelForMaskedLM 363 | from transformers.optimization import AdamW, get_linear_schedule_with_warmup 364 | import torch 365 | # model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased') 366 | model = AutoModelForCausalLM.from_pretrained('gpt2') 367 | 368 | optimizer = AdamW(model.parameters(), lr=0.00000, betas=(0.0, 0.0)) 369 | model.eval() 370 | input_ids = torch.tensor([[1, 2, 3, 4]]) 371 | labels = torch.tensor([[1, 2, 3, 4]]) 372 | loss = model(input_ids=input_ids, labels=labels)[0] 373 | loss.backward() 374 | optimizer.step() 375 | # model.roberta.embeddings.word_embeddings.weight.grad 376 | reset_model_opt_copy=True 377 | reset_model_noise=False 378 | double_model = double_weights(model, is_double_embedding=True, reset_model_opt_copy=reset_model_opt_copy,reset_model_noise=reset_model_noise) 379 | double_optimizer = AdamW(double_model.parameters(), lr=0.00000, betas=(0.0, 0.0)) 380 | double_model.eval() 381 | double_loss = double_model(input_ids=input_ids, labels=labels)[0] 382 | double_loss.backward() 383 | double_optimizer.step() 384 | 385 | print(double_loss.item(), loss.item(), torch.allclose(double_loss, loss, atol=1e-05, rtol=1e-03)) 386 | assert torch.allclose(double_loss, loss, atol=1e-05, rtol=1e-03) 387 | # exit() 388 | for (name, parameter), (double_name, double_parameter), (opt_key, opt_val), (double_opt_key, double_opt_val) in \ 389 | zip(model.named_parameters(), double_model.named_parameters(), 390 | optimizer.state.items(), double_optimizer.state.items()): 391 | assert name == double_name 392 | assert id(parameter) == id(opt_key) 393 | assert id(double_parameter) == id(double_opt_key) 394 | predicted = double_param(name, parameter.grad, is_double_embedding=True, is_grad=True, is_avg_sq=False, reset_optimizer_copy=True, reset_model_opt_copy=reset_model_opt_copy) 395 | all_close = torch.allclose(predicted, double_parameter.grad, atol=1e-05, rtol=1e-03) 396 | 397 | if not all_close: 398 | print('1', all_close, name, parameter.shape, ) 399 | print(predicted) 400 | print(double_parameter.grad) 401 | 402 | predicted = double_param(name, opt_val['exp_avg'], is_double_embedding=True, is_grad=True, is_avg_sq=False, reset_optimizer_copy=True, reset_model_opt_copy=reset_model_opt_copy) 403 | all_close = torch.allclose(predicted, double_opt_val['exp_avg'], atol=1e-05, rtol=1e-03) 404 | if not all_close: 405 | print('2', all_close, name, parameter.shape, ) 406 | print(predicted) 407 | print(double_opt_val['exp_avg'],) 408 | 409 | predicted = double_param(name, opt_val['exp_avg_sq'], is_double_embedding=True, is_grad=True, is_avg_sq=True, reset_optimizer_copy=True, reset_model_opt_copy=reset_model_opt_copy) 410 | all_close = torch.allclose(predicted, double_opt_val['exp_avg_sq'], atol=1e-05, rtol=1e-03) 411 | if not all_close: 412 | print('3', all_close, name, parameter.shape, ) 413 | print(predicted) 414 | print(double_opt_val['exp_avg_sq'],) 415 | import ipdb; ipdb.set_trace() 416 | else: 417 | print('3', all_close, name, parameter.shape, ) 418 | 419 | 420 | import ipdb; ipdb.set_trace() 421 | 422 | def double_weights(model, is_double_embedding): 423 | print(model) 424 | config = model.config 425 | 426 | # create an instance of the model twice the size 427 | new_config_dict = config.to_dict() 428 | new_config_dict['n_embd'] *= 2 429 | new_config_dict['n_inner'] = new_config_dict['n_inner']*2 if new_config_dict['n_inner'] is not None else None 430 | 
new_config_dict['n_head'] *= 2 431 | 432 | new_config = type(config).from_dict(new_config_dict) 433 | new_model = type(model)(new_config) 434 | 435 | # load the weights from the old model into new model after duplicating them 436 | model.tie_weights() 437 | new_model.tie_weights() 438 | 439 | new_state_dict = double_state_dict(model.state_dict(), is_double_embedding=is_double_embedding) 440 | new_model.load_state_dict(new_state_dict) 441 | new_model.tie_weights() 442 | 443 | return new_model 444 | 445 | 446 | def double_depth(model): 447 | print(model) 448 | config = model.config 449 | 450 | # create an instance of the model twice the size 451 | new_config_dict = config.to_dict() 452 | print(new_config_dict) 453 | new_config_dict['num_hidden_layers'] *= 2 454 | 455 | new_config = type(config).from_dict(new_config_dict) 456 | new_model = type(model)(new_config) 457 | 458 | # load the weights from the old model into new model after duplicating them 459 | model.tie_weights() 460 | new_model.tie_weights() 461 | 462 | new_state_dict = deep_state_dict(model.state_dict()) 463 | new_model.load_state_dict(new_state_dict) 464 | new_model.tie_weights() 465 | 466 | return new_model 467 | 468 | 469 | # the dataset object we are using 470 | class MMapTextDataset(Dataset): 471 | def __init__(self, mmap_filename, chunk_size, bos_token_id, eos_token_id): 472 | # `chunk_size - 2` to reserve space for and 473 | self.num_instances = np.memmap(mmap_filename, mode='r', dtype=np.uint16).shape[0] // (chunk_size - 2) 474 | # defer loading the token_ids memmap until after the first __getitem__ call. 475 | # when spawning new processes for ddp, there is a hard limit in python < 3.8 that 476 | # pickle files need to be < 4GB. By waiting until after the first __getitem__ we 477 | # don't have to pickle the memmap 478 | self.token_ids = None 479 | self._mmap_filename = mmap_filename 480 | self._chunk_size = chunk_size 481 | self._bos_token_id = bos_token_id 482 | self._eos_token_id = eos_token_id 483 | 484 | def __len__(self): 485 | return self.num_instances 486 | 487 | def __getitem__(self, i): 488 | if self.token_ids is None: 489 | self.token_ids = np.memmap(self._mmap_filename, mode='r', dtype=np.uint16) 490 | from_index = i * (self._chunk_size - 2) 491 | to_index = (i + 1) * (self._chunk_size - 2) 492 | data = np.concatenate(([self._bos_token_id], self.token_ids[from_index:to_index], [self._eos_token_id])) 493 | return torch.tensor(data, dtype=torch.long) 494 | 495 | # ========================= preprocessing code ========================= # 496 | @staticmethod 497 | def _process_file(full_fname): 498 | "Step 1: tokenize an input text file then save token ids into `np.memmap` shards of size `args.shard_size`" 499 | fname = full_fname.split('/')[-1] 500 | if args.data_type == 'tfrecord': 501 | log_filename = f'{args.output_dir}/logs-{fname}.log' 502 | elif args.data_type == 'raw_text': 503 | log_filename = f'{args.output_dir}/logs-{args.shard_size}/{fname}.log' 504 | if os.path.isfile(log_filename): 505 | logging.info(f'Skipping {full_fname} ...') 506 | return # log file already exists. Skip current file. 
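        # What follows: read the input file (raw text with one document per line, or a
        # c4-style tfrecord), tokenize each document with `MMapTextDataset.tokenizer`, and
        # accumulate the token ids in `token_list`. For raw text, `_write_shard()` flushes
        # `token_list` to a uint16 np.memmap shard whenever it grows past `args.shard_size`;
        # for tfrecords the file is written out as a single shard at the end. A small log
        # file is written last, so interrupted preprocessing runs can skip finished files.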
507 | 508 | if args.num_workers > 1: 509 | current = multiprocessing.current_process() 510 | process_identity = int(current._identity[0]) 511 | else: 512 | process_identity = 1 513 | 514 | if process_identity == 1: 515 | logging.info(f'Processing {full_fname} ...') 516 | 517 | def _write_shard(): 518 | if len(token_list) == 0: 519 | return 520 | # if token_list[-1] != MMapTextDataset.tokenizer.sep_token_id: # handle a rare case 521 | # token_list.append(MMapTextDataset.tokenizer.sep_token_id) 522 | if args.data_type in ['tfrecord', 's2']: 523 | shared_filename = f'{args.output_dir}/{fname}.bin' 524 | elif args.data_type == 'raw_text': 525 | shared_filename = f'{args.output_dir}/shards-{args.shard_size}/{fname}-{shard_count}.bin' 526 | else: 527 | raise NotImplementedError 528 | logging.info(f'Writing {len(token_list)} tokens to shared {shared_filename}') 529 | fp = np.memmap(shared_filename, dtype=np.uint16, mode='w+', shape=len(token_list)) 530 | fp[:] = token_list[:] 531 | del fp # flush and close file 532 | 533 | token_list = [] 534 | shard_count = 0 535 | tokens_count = 0 536 | 537 | if args.data_type == 'raw_text': # the input file is one doc per line 538 | with open(full_fname, 'r') as fin: 539 | for line in tqdm(fin): 540 | line = line.strip() 541 | if line == '': # drop empty lines 542 | continue 543 | tokens = MMapTextDataset.tokenizer.encode(line, add_special_tokens=False) # `__getitem__` adds special tokens 544 | token_list.extend(tokens) 545 | if len(token_list) > args.shard_size: 546 | _write_shard() 547 | tokens_count += len(token_list) 548 | token_list = [] 549 | shard_count += 1 550 | else: 551 | token_list.append(MMapTextDataset.tokenizer.sep_token_id) 552 | _write_shard() 553 | tokens_count += len(token_list) 554 | elif args.data_type == 'tfrecord': # the input file is tfrecord format of the c4 dataset 555 | fin = tf.data.TFRecordDataset(full_fname) 556 | for raw_example in tqdm(iter(fin), disable=process_identity != 1): 557 | parsed = tf.train.Example.FromString(raw_example.numpy()) 558 | feature_keys = set(parsed.features.feature.keys()) 559 | if 'text' in feature_keys: 560 | line = parsed.features.feature['text'].bytes_list.value[0].decode() # raw text 561 | tokens = MMapTextDataset.tokenizer.encode(line, add_special_tokens=False) # `__getitem__` adds special tokens 562 | if args.add_sep_after_doc: 563 | tokens.append(MMapTextDataset.tokenizer.sep_token_id) 564 | token_list.extend(tokens) 565 | tokens_count += len(token_list) 566 | shard_count += 1 567 | _write_shard() 568 | 569 | with open(log_filename, 'w') as f: 570 | f.write(f'Generated {tokens_count} tokens in {shard_count + 1} shards') 571 | 572 | @staticmethod 573 | def _combine_shards(output_fname, shards_list): 574 | "Step 2: combining memmap shards into one `train.bin` or `val.bin` file" 575 | total_size = 0 576 | for filename in shards_list: 577 | total_size += np.memmap(filename, mode='r', dtype=np.uint16).shape[0] 578 | logging.info(f'Writing {total_size} tokens to {output_fname}') 579 | all_token_ids = np.empty(total_size, dtype=np.uint16) 580 | last_token_index = 0 581 | for filename in tqdm(shards_list): 582 | shared = np.memmap(filename, mode='r', dtype=np.uint16) 583 | all_token_ids[last_token_index:last_token_index+len(shared)] = shared[:] 584 | last_token_index += len(shared) 585 | fp = np.memmap(output_fname, dtype=np.uint16, mode='w+', shape=total_size) 586 | fp[:] = all_token_ids[:] 587 | del fp 588 | 589 | @staticmethod 590 | def raw_text_to_mmap(args): 591 | """This is the main preprocessing 
function. It processes all the text files in `args.input_dir` and
592 |         outputs two np.memmap files, one for training and one for validation with ratio `args.train_dev_split`.
593 |         Processing each input file involves tokenizing it, sharding it into shards of size `args.shard_size`,
594 |         then writing each shard as an np.memmap file, shuffling the shards, splitting them into train and dev shards,
595 |         then combining the shards of each set into one big file (train.bin and val.bin).
596 |         Note that only the shards are shuffled, not the instances inside each shard. Therefore, it is important
597 |         to use an `args.shard_size` that is small enough to give a good train/dev split, but not so small
598 |         that you end up with a huge number of shards that are difficult to work with.
599 |         The stream of tokens in the memmap files represents documents separated by `tokenizer.sep_token`.
600 |         In `__getitem__`, the `tokenizer.bos_token` and `tokenizer.eos_token`
601 |         are added. The reason for not adding them at preprocessing time is to allow different sequence lengths
602 |         later on. Note that this is the "FULL-SENTENCES" setting in the RoBERTa paper, Table 2.
603 |         Example running the preprocessing:
604 |         >>> python scripts/pretrain.py --input_dir dirWithTextFiles --train_dev_split 0.05 \
605 |                 --shard_size 268435456 --num_preprocessing_workers 16
606 |         """
607 |         MMapTextDataset.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=True)
608 |         assert len(MMapTextDataset.tokenizer) < 65535  # will use uint16 to store token ids
609 |         all_files = glob.glob(f'{args.input_dir}/c4-*')
610 |         print(len(all_files), MMapTextDataset.tokenizer)
611 |         if os.path.exists(f'{args.output_dir}/cache/train.bin') and os.path.exists(f'{args.output_dir}/cache/val.bin'):
612 |             logger.info("Cache already exists. 
Remove the cache directory to regenerate") 613 | return 614 | try: 615 | os.mkdir(f'{args.output_dir}/cache/') 616 | except FileExistsError: 617 | pass 618 | try: 619 | os.mkdir(f'{args.output_dir}/shards-{args.shard_size}/') 620 | except FileExistsError: 621 | pass 622 | try: 623 | os.mkdir(f'{args.output_dir}/logs-{args.shard_size}/') # log progrss to be able to resume 624 | except FileExistsError: 625 | pass 626 | 627 | # STEP1: tokenizing and saving to shards 628 | if args.num_preprocessing_workers > 1: 629 | from multiprocessing.pool import Pool 630 | with Pool(args.num_preprocessing_workers) as p: 631 | list(tqdm(p.imap(MMapTextDataset._process_file, all_files), total=len(all_files))) 632 | else: 633 | [MMapTextDataset._process_file(f) for f in tqdm(all_files)] 634 | 635 | if args.data_type == 'raw_text': # c4 tfrecords are already sharded 636 | # STEP2: shuffling shards and combining them into train.bin and val.bin files 637 | all_shards = glob.glob(f'{args.output_dir}/shards-{args.shard_size}/*.bin') 638 | random.shuffle(all_shards) # shuffling based on shards not individual lines 639 | val_shards_count = int(args.train_dev_split * len(all_shards)) 640 | val_shards = all_shards[:val_shards_count] 641 | train_shards = all_shards[val_shards_count:] 642 | # TODO: if MMapTextDataset._combining_shards is very slow for large files, it can be skipped but we nned to 643 | # update the dataset to read from multiple shards directly 644 | MMapTextDataset._combine_shards(f'{args.output_dir}/cache/val.bin', val_shards) 645 | MMapTextDataset._combine_shards(f'{args.output_dir}/cache/train.bin', train_shards) 646 | elif args.data_type == 'tfrecord': 647 | train_shards = glob.glob(f'{args.output_dir}/*train*.bin') 648 | val_shards = glob.glob(f'{args.output_dir}/*val*.bin') 649 | MMapTextDataset._combine_shards(f'{args.output_dir}/val.bin', val_shards) 650 | MMapTextDataset._combine_shards(f'{args.output_dir}/train.bin', train_shards) 651 | del MMapTextDataset.tokenizer 652 | # ========================= end preprocessing code ========================= # 653 | 654 | 655 | class MyCheckpointConnector(CheckpointConnector): 656 | def __init__(self, trainer, reset_optimizer=False, reset_lr_scheduler=False, set_global_step=None): 657 | super().__init__(trainer) 658 | self.reset_optimizer = reset_optimizer 659 | self.reset_lr_scheduler = reset_lr_scheduler 660 | self.set_global_step = set_global_step 661 | 662 | def restore_training_state(self, checkpoint, load_optimizer_states: bool = True): 663 | """ 664 | COPIED from https://github.com/PyTorchLightning/pytorch-lightning/blob/1.0.8/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L130-L199 665 | and updated to support reset_optimizer and reset_lr_scheduler 666 | """ 667 | # validation 668 | if 'optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint: 669 | raise KeyError( 670 | 'Trying to restore training state but checkpoint contains only the model.' 671 | ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.' 672 | ) 673 | 674 | if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]): 675 | raise ValueError( 676 | "The checkpoint you're attempting to load follows an" 677 | " outdated schema. You can upgrade to the current schema by running" 678 | " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`" 679 | " where `model.ckpt` is your checkpoint file." 
680 | ) 681 | 682 | # restore amp scaling 683 | if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint: 684 | self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state']) 685 | elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint: 686 | amp.load_state_dict(checkpoint['amp_scaling_state']) 687 | 688 | # restore callback states 689 | self.trainer.on_load_checkpoint(checkpoint) 690 | 691 | self.trainer.global_step = checkpoint['global_step'] 692 | if self.set_global_step is not None: 693 | self.trainer.global_step = self.set_global_step 694 | self.trainer.current_epoch = checkpoint['epoch'] 695 | 696 | # crash if max_epochs is lower then the current epoch from the checkpoint 697 | if self.trainer.current_epoch > self.trainer.max_epochs: 698 | m = f""" 699 | you restored a checkpoint with current_epoch={self.trainer.current_epoch} 700 | but the Trainer(max_epochs={self.trainer.max_epochs}) 701 | """ 702 | raise MisconfigurationException(m) 703 | 704 | # Division deals with global step stepping once per accumulated batch 705 | # Inequality deals with different global step for odd vs even num_training_batches 706 | n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches 707 | expected_steps = self.trainer.num_training_batches / n_accum 708 | if self.trainer.num_training_batches != 0 and self.trainer.global_step % expected_steps > 1: 709 | rank_zero_warn( 710 | "You're resuming from a checkpoint that ended mid-epoch. " 711 | "This can cause unreliable results if further training is done, " 712 | "consider using an end of epoch checkpoint. " 713 | ) 714 | 715 | if not load_optimizer_states: 716 | return 717 | 718 | # restore the optimizers 719 | if not self.reset_optimizer: 720 | optimizer_states = checkpoint['optimizer_states'] 721 | for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states): 722 | print(opt_state.keys(), optimizer) 723 | # print(optimizer.param_groups.keys(), optimizer.param_groups) 724 | print([x.keys() for x in optimizer.param_groups]) 725 | print([x.keys() for x in opt_state['param_groups']]) 726 | optimizer.load_state_dict(opt_state) 727 | 728 | # move optimizer to GPU 1 weight at a time 729 | # avoids OOM 730 | if self.trainer.root_gpu is not None: 731 | for state in optimizer.state.values(): 732 | for k, v in state.items(): 733 | if isinstance(v, torch.Tensor): 734 | state[k] = v.cuda(self.trainer.root_gpu) 735 | 736 | if not self.reset_lr_scheduler: 737 | # restore the lr schedulers 738 | lr_schedulers = checkpoint['lr_schedulers'] 739 | if self.set_global_step is not None: 740 | for lrs_state in lr_schedulers: 741 | lrs_state['last_epoch'] = self.set_global_step 742 | lrs_state['_step_count'] = self.set_global_step + 1 743 | 744 | for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers): 745 | scheduler['scheduler'].load_state_dict(lrs_state) 746 | else: 747 | if self.set_global_step is not None: 748 | for scheduler in self.trainer.lr_schedulers: 749 | scheduler['scheduler'].last_epoch = self.set_global_step 750 | scheduler['scheduler']._step_count = self.set_global_step+ 1 751 | 752 | 753 | # rewrite the MyTrainLoop from pytorch-lightning to support batch size and sequence length warmup 754 | class MyTrainLoop(TrainLoop): 755 | def __init__(self, trainer, multiple_trainloader_mode, args): 756 | super().__init__(trainer, multiple_trainloader_mode) 757 | self.args = args 758 | 759 | def grad_norm(self, model, 
norm_type, should_accumulate=False): 760 | # Override PTL `grad_norm` function to only return `total_grad_norm` instead norms of individual params 761 | # TODO: grad_norm reporting needs to take fp16 loss scale into account 762 | # parameters = [p for p in self.parameters() if p.grad is not None] 763 | # device = parameters[0].device 764 | # total_norm = torch.zeros([], device=device if parameters else None) 765 | # norm_type = float(norm_type) 766 | # for p in parameters: 767 | # param_norm = p.grad.norm(norm_type) 768 | # total_norm.add_(param_norm) 769 | norm_type = float(norm_type) 770 | 771 | norms, all_norms = {}, [] 772 | # local_norm = torch.zeros([], device=model.device) 773 | for name, p in model.named_parameters(): 774 | if p.grad is None: 775 | continue 776 | 777 | if not should_accumulate: 778 | # param_norm = float(p.grad.data.norm(norm_type)) 779 | p_grad = p.grad.data / args.batch_size / args.grad_accum 780 | param_norm = float( p_grad.norm(norm_type) ) 781 | else: 782 | p_grad = p.grad.data / self.trainer.accelerator.precision_plugin.scaler.get_scale() / args.batch_size 783 | param_norm = float( p_grad.norm(norm_type) ) 784 | all_norms.append(param_norm) 785 | # local_norm.add_(p.grad.norm(norm_type)) 786 | 787 | total_norm = float(torch.tensor(all_norms).norm(norm_type)) 788 | # norms[f'grad_{norm_type}_norm_total'] = round(total_norm, 4) 789 | # print("total_norm", total_norm, model.device, local_norm, self.trainer.accelerator.precision_plugin.scaler.get_scale()) 790 | if not should_accumulate: 791 | return {'total_grad_norm': total_norm, "batch_size": args.batch_size * self.trainer.world_size, "grad_accum": args.grad_accum } 792 | else: 793 | return { "local_grad_norm %s" % model.device: total_norm, "local_scale": self.trainer.accelerator.precision_plugin.scaler.get_scale() } 794 | 795 | def _track_gradient_norm(self): 796 | grad_norm_dict = {} 797 | if (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0: 798 | if float(self.trainer.track_grad_norm) > 0: 799 | model = self.trainer.lightning_module 800 | grad_norm_dict = self.grad_norm(model, self.trainer.track_grad_norm) 801 | return grad_norm_dict 802 | 803 | def backward(self, result, optimizer, opt_idx, *args, **kwargs): 804 | self.trainer.dev_debugger.track_event("backward_call") 805 | 806 | should_accumulate = self.should_accumulate() 807 | # print(should_accumulate) 808 | # backward can be called manually in the training loop 809 | if isinstance(result, torch.Tensor): 810 | self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs) 811 | else: 812 | result.closure_loss = self.trainer.accelerator.backward( 813 | result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs 814 | ) 815 | 816 | if not self.should_accumulate(): 817 | # track gradients 818 | # print("track gradient with should_accumulate False") 819 | cur_grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) 820 | if 'total_grad_norm' in self._cur_grad_norm_dict: 821 | B_small, B_big = self._cur_grad_norm_dict['batch_size'], self._cur_grad_norm_dict['batch_size'] * self._cur_grad_norm_dict['grad_accum'] 822 | grad_norm_B_big = self._cur_grad_norm_dict['total_grad_norm'] 823 | grad_norm_B_small = [] 824 | if not hasattr(self, 'grad_norm_dict') or (hasattr(self, 'grad_norm_dict') and self.grad_norm_dict is None): 825 | B_critical = B_big 826 | else: 827 | for item in self.grad_norm_dict: 828 | if "local_grad_norm" in item: 829 | grad_norm_B_small.append( 
self.grad_norm_dict[item] ) 830 | 831 | grad_norm_B_small = np.average(grad_norm_B_small) 832 | g2 = 1 / (B_big-B_small) * (B_big * grad_norm_B_big - B_small*grad_norm_B_small) 833 | s = 1 / (1/B_small - 1/B_big) * (grad_norm_B_small - grad_norm_B_big) 834 | B_critical = s / g2 835 | self._cur_grad_norm_dict.update( self.grad_norm_dict ) 836 | self._cur_grad_norm_dict.update( {"critical_batch_size" : B_critical} ) 837 | for e in ['batch_size', 'grad_accum']: 838 | self._cur_grad_norm_dict.pop(e) 839 | # print(self._cur_grad_norm_dict) 840 | self.grad_norm_dict = None 841 | else: 842 | # print("track gradient with should_accumulate True") 843 | # first gradient accumulation step !!!!!!!!!!!! 844 | if hasattr(self, 'grad_norm_dict') and self.grad_norm_dict is None: 845 | model = self.trainer.lightning_module 846 | self.grad_norm_dict = self.grad_norm(model, self.trainer.track_grad_norm, True) 847 | 848 | def run_training_epoch(self): 849 | # modify dataloader if needed (ddp, etc...) 850 | train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) 851 | 852 | # track epoch output 853 | epoch_output = [[] for _ in range(self.num_optimizers)] 854 | 855 | train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) 856 | dataloader_idx = 0 857 | 858 | batch_idx = None 859 | is_last_batch = None 860 | 861 | accum_bsz, accum_bsz_grad_step = 0, 0 862 | for batch_idx, (batch, is_last_batch) in train_dataloader: 863 | self.trainer.batch_idx = batch_idx 864 | self.trainer.is_last_batch = is_last_batch 865 | 866 | # warmup the batch size via truncation and gradient accumulation 867 | # hack the deepest into the PTL to make it happen 868 | if self.args.warmup_bsz != 0: 869 | # for key in batch.keys(): 870 | # print(key, batch[key].shape, batch[key].device, batch[key].numel(), self.trainer.accumulate_grad_batches, self.trainer.model.device) 871 | input_ids = batch['input_ids'] 872 | 873 | final_bsz = input_ids.shape[0] * self.args.grad_accum * self.trainer.world_size 874 | start_bsz = 64 875 | 876 | current_bsz = start_bsz + (final_bsz - start_bsz) * min( 1.0, accum_bsz / self.args.warmup_bsz ) 877 | # print("before current_bsz", current_bsz, accum_bsz) 878 | if current_bsz >= final_bsz: 879 | self.trainer.accumulate_grad_batches = self.args.grad_accum 880 | else: 881 | current_bsz = current_bsz // self.trainer.world_size 882 | # try to reset gradient accum steps 883 | grad_accum = int(max(1, current_bsz // input_ids.shape[0])) 884 | 885 | if grad_accum == 1 or accum_bsz_grad_step <= 0: 886 | if grad_accum != 1 and accum_bsz_grad_step == 0: 887 | accum_bsz_grad_step = grad_accum 888 | self.trainer.accumulate_grad_batches = grad_accum 889 | bsz_after_chunk = int(current_bsz // self.trainer.accumulate_grad_batches) 890 | else: 891 | accum_bsz_grad_step -= 1 892 | 893 | # try to chunk the inputs 894 | # print("current_bsz", current_bsz, "grad_accum", grad_accum, self.trainer.accumulate_grad_batches, accum_bsz_grad_step, self.should_accumulate(), 'bsz_after_chunk', bsz_after_chunk, input_ids.shape[0]) 895 | if bsz_after_chunk < input_ids.shape[0]: 896 | for key in batch.keys(): 897 | batch[key] = torch.narrow(batch[key], 0, 0, bsz_after_chunk)#.to( self.trainer.model.device ) 898 | 899 | accum_bsz += batch['input_ids'].numel() 900 | 901 | if self.args.warmup_seq != 0: 902 | 903 | input_ids = batch['input_ids'] 904 | 905 | start_seq = 64 906 | final_seq = input_ids.shape[1] 907 | 908 | current_seq = int(start_seq + (final_seq - start_seq) * 
min( 1.0, accum_bsz / self.args.warmup_seq )) 909 | if accum_bsz_grad_step <= 0: 910 | accum_bsz_grad_step = self.trainer.accumulate_grad_batches 911 | else: 912 | accum_bsz_grad_step -= 1 913 | 914 | if current_seq < final_seq: 915 | for key in batch.keys(): 916 | batch[key] = torch.narrow(batch[key], 1, 0, current_seq) 917 | 918 | accum_bsz += batch['input_ids'].numel() 919 | 920 | # ------------------------------------ 921 | # TRAINING_STEP + TRAINING_STEP_END 922 | # ------------------------------------ 923 | with self.trainer.profiler.profile("run_training_batch"): 924 | batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx) 925 | 926 | # when returning -1 from train_step, we end epoch early 927 | if batch_output.signal == -1: 928 | break 929 | 930 | # hook 931 | # TODO: add outputs to batches 932 | self.on_train_batch_end( 933 | epoch_output, 934 | batch_output.training_step_output_for_epoch_end, 935 | batch, 936 | batch_idx, 937 | dataloader_idx, 938 | ) 939 | 940 | # ----------------------------------------- 941 | # SAVE METRICS TO LOGGERS 942 | # ----------------------------------------- 943 | self.trainer.logger_connector.log_train_step_metrics(batch_output) 944 | 945 | # ----------------------------------------- 946 | # VALIDATE IF NEEDED 947 | # ----------------------------------------- 948 | should_check_val = self._should_check_val_fx(batch_idx, is_last_batch) 949 | if should_check_val: 950 | self.trainer.validating = True 951 | self.trainer.run_evaluation() 952 | self.trainer.training = True 953 | 954 | # ----------------------------------------- 955 | # SAVE LOGGERS (ie: Tensorboard, etc...) 956 | # ----------------------------------------- 957 | self.save_loggers_on_train_batch_end() 958 | 959 | # update LR schedulers 960 | monitor_metrics = copy.deepcopy(self.trainer.logger_connector.callback_metrics) 961 | self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) 962 | self.trainer.checkpoint_connector.has_trained = True 963 | 964 | self.trainer.total_batch_idx += 1 965 | 966 | # max steps reached, end training 967 | if ( 968 | self.trainer.max_steps is not None and self.trainer.max_steps <= self.trainer.global_step + 1 969 | and self._accumulated_batches_reached() 970 | ): 971 | break 972 | 973 | # end epoch early 974 | # stop when the flag is changed or we've gone past the amount 975 | # requested in the batches 976 | if self.trainer.should_stop: 977 | break 978 | 979 | # stop epoch if we limited the number of training batches 980 | if self._num_training_batches_reached(is_last_batch): 981 | break 982 | 983 | # progress global step according to grads progress 984 | self.increment_accumulated_grad_global_step() 985 | 986 | if batch_idx is None: 987 | # dataloader/iterator did not produce a batch 988 | return 989 | 990 | # handle epoch_output on epoch end 991 | self.on_train_epoch_end(epoch_output) 992 | 993 | # log epoch metrics 994 | self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output) 995 | 996 | should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) 997 | should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) 998 | should_train_only = self.trainer.disable_validation or should_skip_eval 999 | 1000 | # update epoch level lr_schedulers if no val loop outside train loop is triggered 1001 | if not should_check_val or should_train_only: 1002 | self.trainer.optimizer_connector.update_learning_rates(interval='epoch') 1003 | 1004 | if 
should_train_only: 1005 | self.check_checkpoint_callback(True) 1006 | 1007 | if should_check_val: 1008 | self.trainer.validating = True 1009 | self.trainer.run_evaluation(on_epoch=True) 1010 | self.trainer.training = True 1011 | 1012 | if batch_output.signal != -1: 1013 | self.increment_accumulated_grad_global_step() 1014 | 1015 | class Pretrainer(ptl.LightningModule): 1016 | 1017 | def __init__(self): 1018 | super().__init__() 1019 | 1020 | self.args = args # hparams 1021 | self._set_hparams(self.args) #v1.3.5 ptl issue 1022 | # self.hparams = self.args 1023 | 1024 | #self.model = AutoModelForMaskedLM.from_pretrained(args.model) 1025 | self.model = AutoModelForCausalLM.from_pretrained(args.model) 1026 | if args.random: 1027 | if args.layers is not None and args.size is not None: 1028 | raise False 1029 | if args.layers is not None: 1030 | self.model.config.n_layer = args.layers 1031 | if args.size is not None: 1032 | if args.size == 'GPT2_base': 1033 | self.model.config.n_layer = 12 1034 | self.model.config.n_embd = 768 1035 | self.model.config.n_head = 8 1036 | elif args.size == 'GPT2_large': 1037 | self.model.config.n_layer = 24 1038 | self.model.config.n_embd = 1536 1039 | self.model.config.n_head = 16 1040 | elif args.size == 'GPT2_base_div2_width': 1041 | self.model.config.n_layer = 12 1042 | self.model.config.n_embd = 384 1043 | self.model.config.n_head = 4 1044 | elif args.size == 'GPT2_base_div2_depth': 1045 | self.model.config.n_layer = 6 1046 | self.model.config.n_embd = 768 1047 | self.model.config.n_head = 8 1048 | 1049 | elif args.size == 'GPT2_large_div4_width': 1050 | self.model.config.n_layer = 24 1051 | self.model.config.n_embd = 384 1052 | self.model.config.n_head = 4 1053 | 1054 | elif args.size == 'GPT2_large_div2_width': 1055 | self.model.config.n_layer = 24 1056 | self.model.config.n_embd = 768 1057 | self.model.config.n_head = 8 1058 | elif args.size == 'GPT2_large_div4_depth': 1059 | self.model.config.n_layer = 6 1060 | self.model.config.n_embd = 1536 1061 | self.model.config.n_head = 16 1062 | elif args.size == 'GPT2_large_div2_depth': 1063 | self.model.config.n_layer = 12 1064 | self.model.config.n_embd = 1536 1065 | self.model.config.n_head = 16 1066 | else: 1067 | assert False 1068 | 1069 | assert self.model.config.n_positions == 1024 1070 | self.model.config.n_positions = args.seqlen 1071 | self.model = GPT2LMHeadModel(config=self.model.config) 1072 | else: 1073 | assert args.layers is None 1074 | assert args.size is None 1075 | 1076 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) 1077 | self.pad_token_id = tokenizer.pad_token_id 1078 | self.eos_token_id = tokenizer.eos_token_id or tokenizer.sep_token_id 1079 | self.bos_token_id = tokenizer.bos_token_id or tokenizer.cls_token_id 1080 | 1081 | logger.info(f'Creating dataset cache from dir {self.args.input_dir}. 
This could be slow the first time.') 1082 | MMapTextDataset.raw_text_to_mmap(args) 1083 | 1084 | # TODO: add support for other objective functions (whole word masking, BART, Pegasus) 1085 | # self.data_collator = DataCollatorForLanguageModeling( 1086 | # tokenizer=tokenizer, mlm=True, mlm_probability=self.args.mlm_prob 1087 | # ) 1088 | self.data_collator = DataCollatorForLanguageModeling( 1089 | tokenizer=tokenizer, mlm=False 1090 | ) 1091 | self.start_time = 0 1092 | 1093 | def to(self, *args, **kwargs): 1094 | param_count_before_to = len(list(self.parameters())) 1095 | super().to(*args, **kwargs) 1096 | if self.trainer.on_tpu: 1097 | # if self.trainer.use_tpu: 1098 | # need to re-tie the weights after moving to XLA! 1099 | self.model.tie_weights() 1100 | if 'roberta' in self.args.model or 'longformer' in self.args.model: 1101 | self.model.lm_head.bias = self.model.lm_head.decoder.bias 1102 | param_count_after_to = len(list(self.parameters())) 1103 | assert param_count_before_to == param_count_after_to 1104 | 1105 | def forward(self, inputs): 1106 | # for MLM 1107 | # get the padding mask - 1 for NOT masked, 0 for MASKED/PAD 1108 | # attention_mask = (input_ids != self.pad_token_id).int() 1109 | 1110 | # output is loss, prediction_scores, hidden_states 1111 | # output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 1112 | 1113 | # for LM 1114 | output = self.model(**inputs) 1115 | return output[0] # loss 1116 | 1117 | def training_step(self, batch, batch_nb): 1118 | loss = self(batch) 1119 | input_ids = batch['input_ids'] 1120 | tensorboard_logs = { 1121 | 'input_size': input_ids.numel(), 1122 | 'token_per_step': input_ids.numel() * self.trainer.accumulate_grad_batches * self.trainer.world_size, 1123 | } 1124 | # if not self.use_tpu: 1125 | if not self.trainer.on_tpu: 1126 | # logging additional losses is slow on tpu 1127 | tensorboard_logs['lm_loss'] = loss 1128 | tensorboard_logs['lm_bpc'] = loss/math.log(2) 1129 | tensorboard_logs['lm_perplexity'] = torch.exp(loss) 1130 | 1131 | if self.start_time != 0: 1132 | # torch.cuda.synchronize() 1133 | elapsed_time = time.monotonic() - self.start_time 1134 | tensorboard_logs['second_per_batch'] = elapsed_time 1135 | self.start_time = time.monotonic() 1136 | 1137 | if self.on_gpu: 1138 | tensorboard_logs['memory'] = torch.cuda.memory_allocated(loss.device) / 1024 ** 3 1139 | 1140 | for k, v in tensorboard_logs.items(): 1141 | self.log(k, v) 1142 | 1143 | return {'loss': loss} 1144 | 1145 | def on_train_batch_start(self, *args, **kwargs): 1146 | self._start = time.monotonic() 1147 | 1148 | def on_train_batch_end(self, *args, **kwargs): 1149 | delta = time.monotonic() - self._start 1150 | self.log("time_per_batch", delta, on_step=True, on_epoch=False) 1151 | 1152 | def validation_step(self, batch, batch_nb): 1153 | # TODO: log how long evaluation takes 1154 | self.start_time = 0 # reset training_step timer 1155 | 1156 | loss = self(batch) 1157 | tensorboard_logs = { 1158 | 'val_lm_loss': loss.detach(), 1159 | } 1160 | return {'val_loss': tensorboard_logs["val_lm_loss"], 'log': tensorboard_logs} 1161 | 1162 | def validation_epoch_end(self, outputs): 1163 | avg_loss = torch.stack([x['log']['val_lm_loss'] for x in outputs if 'val_lm_loss' in x['log']]).mean() 1164 | if self.trainer.accelerator_connector.use_ddp: 1165 | # TODO: PTL is already doing this. Is it still needed here? 
1166 | # https://github.com/PyTorchLightning/pytorch-lightning/blob/0.8.5/pytorch_lightning/metrics/converters.py#L251 1167 | torch.distributed.all_reduce(avg_loss, op=torch.distributed.ReduceOp.SUM) 1168 | avg_loss /= torch.distributed.get_world_size() 1169 | elif self.on_tpu: 1170 | avg_loss = xm.all_reduce(xm.REDUCE_SUM, avg_loss) / xm.xrt_world_size() 1171 | 1172 | self.log('val_loss', avg_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True) 1173 | 1174 | def configure_optimizers(self): 1175 | # no_decay = ["bias", "LayerNorm.weight"] 1176 | 1177 | # optimizer_grouped_parameters = [ 1178 | # { 1179 | # "params": [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad], 1180 | # "weight_decay": self.args.weight_decay, 1181 | # }, 1182 | # { 1183 | # "params": [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad], 1184 | # "weight_decay": 0.0, 1185 | # }, 1186 | # ] 1187 | # optimizer_grouped_parameters 1188 | 1189 | optimizer = AdamW(self.parameters(), lr=self.args.lr, eps=self.args.adam_epsilon, 1190 | betas=(self.args.adam_beta1, self.args.adam_beta2), 1191 | correct_bias=False) 1192 | if self.args.restart_warmup_steps != 0 and self.args.restart_steps != 0: 1193 | scheduler = get_restart_linear_schedule_with_warmup( 1194 | optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.args.train_steps, 1195 | restart_steps=self.args.restart_steps, restart_warmup_steps=self.args.restart_warmup_steps, 1196 | ) 1197 | else: 1198 | scheduler = get_linear_schedule_with_warmup( 1199 | optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=self.args.train_steps 1200 | ) 1201 | return [optimizer], [{"scheduler": scheduler, "interval": "step"}] 1202 | 1203 | def _get_loader(self, fname, is_train): 1204 | dataset = MMapTextDataset(fname, chunk_size=self.args.seqlen, 1205 | bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id) 1206 | 1207 | # TODO: consider `replace_sampler_ddp=True` and removing the following if statement 1208 | # if self.trainer.use_ddp: 1209 | if self.trainer.accelerator_connector.use_ddp: 1210 | sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=is_train) 1211 | shuffle = False 1212 | elif self.trainer.on_tpu: 1213 | sampler = torch.utils.data.distributed.DistributedSampler( 1214 | dataset, 1215 | num_replicas=xm.xrt_world_size(), 1216 | rank=xm.get_ordinal(), 1217 | shuffle=is_train, 1218 | ) 1219 | shuffle = False 1220 | else: 1221 | sampler = None 1222 | shuffle = is_train 1223 | 1224 | loader = DataLoader( 1225 | dataset, 1226 | batch_size=self.args.batch_size, 1227 | shuffle=shuffle, 1228 | sampler=sampler, 1229 | num_workers=self.args.num_workers, 1230 | collate_fn=self.data_collator, 1231 | drop_last=is_train, 1232 | ) 1233 | return loader 1234 | 1235 | def train_dataloader(self): 1236 | return self._get_loader(f'{self.args.input_dir}/cache/train.bin', True) 1237 | 1238 | def val_dataloader(self): 1239 | return self._get_loader(f'{self.args.input_dir}/cache/val.bin', False) 1240 | 1241 | def grad_norm(self, norm_type): 1242 | # Override PTL `grad_norm` function to only return `total_grad_norm` instead norms of individual params 1243 | # TODO: grad_norm reporting needs to take fp16 loss scale into account 1244 | parameters = [p for p in self.parameters() if p.grad is not None] 1245 | device = parameters[0].device 1246 | total_norm = torch.zeros([], device=device if parameters else None) 1247 | norm_type = 
float(norm_type) 1248 | for p in parameters: 1249 | param_norm = p.grad.norm(norm_type) 1250 | total_norm.add_(param_norm) 1251 | return {'total_grad_norm': total_norm} 1252 | 1253 | @staticmethod 1254 | def add_args(parser): 1255 | parser.add_argument("--seed", type=int, default=3) 1256 | 1257 | # Dataset. Some of these params are only useful when generating the dataset cache 1258 | parser.add_argument("--input_dir", type=str, default='/net/nfs2.allennlp/shengs/c4/') 1259 | parser.add_argument("--output_dir", type=str, default='/net/nfs2.allennlp/shengs/c4/') 1260 | 1261 | parser.add_argument("--data_type", type=str, default='tfrecord') 1262 | parser.add_argument("--add_sep_after_doc", action='store_true', default=False, help='add sep token after document') 1263 | 1264 | # Used only at the preprocessing phase 1265 | parser.add_argument("--train_dev_split", type=float, default=0.05) 1266 | parser.add_argument("--shard_size", type=int, default=1024 ** 3 // 4) # 250MB 1267 | parser.add_argument("--num_preprocessing_workers", type=int, default=1) 1268 | # Used only at the training phase 1269 | parser.add_argument("--seqlen", type=int, default=512) 1270 | 1271 | # HF model loading 1272 | parser.add_argument("--tokenizer", type=str, default='gpt2') 1273 | parser.add_argument("--model", type=str, default='gpt2') 1274 | parser.add_argument("--doubling", type=str) # could be layers / weights 1275 | parser.add_argument("--doubling_layers", type=str) # could be alternate_id, append_id, alternate_copy, append_copy 1276 | # parser.add_argument("--noise_std", type=float, default=0.0) 1277 | parser.add_argument("--warmup_bsz", type=int, default=0, help='# warmup batch size') 1278 | parser.add_argument("--warmup_seq", type=int, default=0, help='# warmup sequence length') 1279 | 1280 | parser.add_argument("--random", default=False, action='store_true') 1281 | parser.add_argument("--layers", type=int) 1282 | parser.add_argument("--size", type=str) 1283 | 1284 | # Checkpointing and logging 1285 | parser.add_argument("--save_dir", type=str, default='runs/') 1286 | parser.add_argument("--save_prefix", type=str, default='test', 1287 | help="path of output directory is --save_dir/--save_prefix") 1288 | parser.add_argument("--resume", type=str, default=None, # It is better to use a different output dir. 1289 | help="Path to a checkpoint to load model weights and training state. It overwrites args") 1290 | parser.add_argument("--resume_model_only", type=str, default=None, 1291 | help="Path to a checkpoint to load model weights but not training state") 1292 | parser.add_argument("--reset_optimizer", default=False, action='store_true') 1293 | parser.add_argument("--reset_lr_scheduler", default=False, action='store_true') 1294 | parser.add_argument("--log_rate", type=int, default=10) 1295 | parser.add_argument("--disable_checkpointing", action='store_true', default=False) 1296 | 1297 | # Training hyperparams 1298 | parser.add_argument("--lr", type=float, default=1e-5) 1299 | parser.add_argument("--train_steps", type=int, default=3000, help='# training grad. updates') 1300 | parser.add_argument("--warmup_steps", type=int, default=1000, help='# warmup grad. updates') 1301 | parser.add_argument("--val_every", type=int, default=100, help='# training grad. 
updates between evaluations')
1302 |         parser.add_argument("--val_batches", type=int, default=1000, help='# evaluation **batches**')
1303 |         parser.add_argument("--weight_decay", type=float, default=0.01)
1304 |         parser.add_argument("--adam_epsilon", type=float, default=1e-6)
1305 |         parser.add_argument("--adam_beta1", type=float, default=0.9)
1306 |         parser.add_argument("--adam_beta2", type=float, default=0.98)
1307 |         parser.add_argument("--grad_clip", type=float, default=0)  # TODO: test this with fp16. Likely not working
1308 | 
1309 |         # RoBERTa's tokens_per_step = 2^18 = 512(seqlen) x 1(gpu_count) x 32(batch_size) x 16(grad_accum)
1310 |         parser.add_argument("--batch_size", type=int, default=32)
1311 |         parser.add_argument("--grad_accum", type=int, default=1)
1312 | 
1313 |         # Compute resources
1314 |         parser.add_argument("--fp16", default=False, action='store_true')
1315 |         parser.add_argument("--num_workers", type=int, default=0)
1316 |         parser.add_argument("--gpu_count", type=int, default=1,  # `--gpus` is reserved for internal use by PTL
1317 |                             help="Number of gpus. This respects `CUDA_VISIBLE_DEVICES`")
1318 | 
1319 |         # For restarting with warmup
1320 |         parser.add_argument("--restart_warmup_steps", type=int, default=0, help='# warmup grad. updates after restart')
1321 |         parser.add_argument("--restart_steps", type=int, default=0, help='# restart steps, should be the same as set_global_steps')
1322 |         # For multi-node training, use the PyTorch launch script. The script and instructions can be found here:
1323 |         # https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py.
1324 |         # To run PTL in a mode compatible with the launch script, two things are needed:
1325 |         #   - pass the argument `--use_env` to `torch.distributed.launch`
1326 |         #   - make sure `--nproc_per_node` matches `--gpu_count` and `--nnodes` matches `--node_count`.
1327 |         # For example, to run on 2 nodes, 3 gpus each, the command line on node rank 1 would be like:
1328 |         # >>>> python -m torch.distributed.launch \
1329 |         #          --use_env --nnodes 2 --nproc_per_node 3 \
1330 |         #          --node_rank 1 --master_addr s2-server4 --master_port 12343 \
1331 |         #          scripts/pretrain.py \
1332 |         #          --gpu_count 3 --node_count 2 \
1333 |         #          --input_dir my_data_dir --save_prefix test_multinode
1334 |         parser.add_argument("--node_count", type=int, default=1,
1335 |                             help="Number of nodes. It needs to match --nnodes of torch.distributed.launch")
1336 |         parser.add_argument("--tpu_core_count", type=int, default=None)
1337 | 
1338 |         return parser
1339 | 
1340 | 
1341 | def main(args):
1342 |     random.seed(args.seed * 10)
1343 |     np.random.seed(args.seed * 100)
1344 |     torch.manual_seed(args.seed * 1000)
1345 |     if torch.cuda.is_available():
1346 |         torch.cuda.manual_seed_all(args.seed * 10000)
1347 | 
1348 |     if args.resume_model_only is not None:
1349 |         pretrainer = Pretrainer.load_from_checkpoint(args.resume_model_only)
1350 |     else:
1351 |         pretrainer = Pretrainer()
1352 | 
1353 |     if args.doubling is not None:
1354 | 
1355 |         doubled_resume = args.resume + '.doubled_weights' if args.doubling == 'weights' else args.resume + '.doubled_layer'
1356 |         print(doubled_resume)
1357 |         exist_flag = os.path.isfile(doubled_resume)
1358 | 
1359 |         if exist_flag:
1360 |             args.resume = doubled_resume
1361 |             print('================== warning: reusing old ckpt =======================')
1362 | 
1363 | 
1364 |         # doubling the checkpoint before doubling the in-memory model
1365 |         if args.resume is not None and not exist_flag:
1366 |             ckpt = torch.load(args.resume)
1367 | 
1368 |             # doubling state dict of the saved model
1369 |             if args.doubling == 'weights':
1370 |                 model_state_dict = ckpt['state_dict']
1371 |                 ckpt['state_dict'] = double_state_dict(model_state_dict, is_double_embedding=True)
1372 | 
1373 |                 # doubling state dict of the saved optimizer
1374 |                 # no_decay = ["bias", "LayerNorm.weight"]
1375 |                 # optimizer_params_by_name = [(n, p.shape) for n, p in pretrainer.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad]
1376 |                 # optimizer_params_by_name.extend([(n, p.shape) for n, p in pretrainer.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad])
1377 |                 optimizer_params_by_name = [(n, p.shape) for n, p in pretrainer.named_parameters()]
1378 |                 assert len(optimizer_params_by_name) == len(ckpt['optimizer_states'][0]['state'])
1379 |                 for (param_name, param_shape), param in zip(optimizer_params_by_name, ckpt['optimizer_states'][0]['state'].values()):
1380 |                     assert param['exp_avg'].shape == param_shape
1381 |                     assert param['exp_avg_sq'].shape == param_shape
1382 |                     param['exp_avg'] = double_param(param_name, param['exp_avg'], is_double_embedding=True, is_grad=True, is_avg_sq=False)
1383 |                     param['exp_avg_sq'] = double_param(param_name, param['exp_avg_sq'], is_double_embedding=True, is_grad=True, is_avg_sq=True)
1384 | 
1385 |                 # print(name_shape[0])
1386 |                 args.resume += '.doubled_weights'
1387 |             elif args.doubling == 'layers':
1388 |                 model_state_dict = ckpt['state_dict']
1389 |                 # hack for doubling the layers
1390 |                 prefix = 'model.transformer.h'
1391 |                 map_positions, copy_positions = {}, {}
1392 |                 for key in model_state_dict:
1393 |                     if prefix in key:
1394 |                         layer_idx = re.findall(r"[-\d]+", key)[0]
1395 |                         origin_idx = prefix + "." + str(int(layer_idx))
1396 |                         if 'alternate' in args.doubling_layers:
1397 |                             insert_idx = prefix + "." + str(int(layer_idx) * 2 + 1)
1398 |                             origin_key = key.replace(origin_idx, prefix + "." + str(int(layer_idx) * 2))
1399 |                         elif 'append' in args.doubling_layers:
1400 |                             insert_idx = prefix + "." 
+ str(pretrainer.model.config.n_layer + int( layer_idx )) 1401 | origin_key = key 1402 | 1403 | insert_key = key.replace( origin_idx, insert_idx ) 1404 | 1405 | map_positions[ key ] = [ (origin_key, False), (insert_key, False) ] 1406 | copy_positions[ insert_key ] = (key, False) 1407 | copy_positions[ origin_key ] = (key, True) 1408 | 1409 | is_identical = 'id' in args.doubling_layers 1410 | 1411 | ckpt['state_dict'] = deep_state_dict(model_state_dict, is_identical=is_identical, map_positions=map_positions) 1412 | 1413 | # deal with the optimizer state 1414 | original_optimizer_params_by_name = [(n, p.shape) for n, p in pretrainer.named_parameters()] 1415 | # print( "original_optimizer_params_by_name", original_optimizer_params_by_name ) 1416 | # print( "ckpt optimizer_states", ckpt['optimizer_states'][0]['state'].keys() ) 1417 | layers = pretrainer.model.transformer.h 1418 | n = len(layers) 1419 | for i in range(n): 1420 | if 'alternate' in args.doubling_layers: 1421 | layers.insert(i * 2, copy.deepcopy(layers[i * 2])) 1422 | elif 'append' in args.doubling_layers: 1423 | layers.append( copy.deepcopy(layers[i]) ) 1424 | 1425 | pretrainer.model.config.n_layer *= 2 1426 | pretrainer.model.tie_weights() 1427 | new_optimizer_params_by_name = [(n, p.shape) for n, p in pretrainer.named_parameters()] 1428 | 1429 | new_optimizer_state = { _:{} for _ in range(len(new_optimizer_params_by_name)) } 1430 | assert len(original_optimizer_params_by_name) == len(ckpt['optimizer_states'][0]['state']) 1431 | original_optimizer_param_name_dict = {} 1432 | for (param_name, param_shape), param in zip(original_optimizer_params_by_name, ckpt['optimizer_states'][0]['state'].values()): 1433 | assert param['exp_avg'].shape == param_shape 1434 | assert param['exp_avg_sq'].shape == param_shape 1435 | original_optimizer_param_name_dict[ param_name ] = copy.deepcopy(param) 1436 | 1437 | for param_idx, (param_name, param_shape) in enumerate(new_optimizer_params_by_name): 1438 | if copy_positions.get(param_name): 1439 | copy_param_name, copy_param_flag = copy_positions.get(param_name) 1440 | param_is_identical = copy_param_flag and is_identical 1441 | new_optimizer_state[ param_idx ] = copy.deepcopy(original_optimizer_param_name_dict[ copy_param_name ]) 1442 | new_optimizer_state[ param_idx ]['exp_avg'] = deep_param(param_name, original_optimizer_param_name_dict[ copy_param_name ]['exp_avg'], is_identical=param_is_identical, is_grad=True, is_avg_sq=False) 1443 | new_optimizer_state[ param_idx ]['exp_avg_sq'] = deep_param(param_name, original_optimizer_param_name_dict[ copy_param_name ]['exp_avg_sq'], is_identical=param_is_identical, is_grad=True, is_avg_sq=True) 1444 | else: 1445 | new_optimizer_state[ param_idx ] = copy.deepcopy(original_optimizer_param_name_dict[param_name]) 1446 | 1447 | ckpt['optimizer_states'][0]['state'] = new_optimizer_state 1448 | ckpt['optimizer_states'][0]['param_groups'][0]['params'] = list(new_optimizer_state.keys()) 1449 | del original_optimizer_param_name_dict 1450 | args.resume += '.doubled_layer' 1451 | 1452 | torch.save(ckpt, args.resume) 1453 | exit() 1454 | 1455 | # we need to resume the model after the doubling 1456 | if args.doubling == 'layers': 1457 | assert True 1458 | elif args.doubling == 'weights': 1459 | assert True 1460 | else: 1461 | assert False 1462 | 1463 | # logger here is a SummaryWritter for tensorboard 1464 | # it is used by the trainer, and certain return variables 1465 | # from the model are automatically logged 1466 | logger = TestTubeLogger( 1467 | 
save_dir=args.save_dir, 1468 | name=args.save_prefix, 1469 | version=0 # always use version=0 1470 | ) 1471 | 1472 | checkpoint_callback = ModelCheckpoint( 1473 | # model saved to filepath/prefix_.... 1474 | # filepath=os.path.join(args.save_dir, args.save_prefix, 'checkpoint'), 1475 | # prefix='', 1476 | dirpath=os.path.join(args.save_dir, args.save_prefix), 1477 | filename='checkpoint-{epoch}-{step}', 1478 | save_top_k=-1, 1479 | # save_top_k=10, 1480 | every_n_train_steps=250, 1481 | save_last=True, 1482 | verbose=True, 1483 | # monitor='val_loss', 1484 | # mode='min', 1485 | ) 1486 | args.val_every *= args.grad_accum # PTL is expecting number of batches_per_gpu 1487 | print(args.val_every, args.disable_checkpointing, checkpoint_callback.__dict__) 1488 | trainer = ptl.Trainer( 1489 | gpus=args.gpu_count, 1490 | num_nodes=args.node_count, 1491 | tpu_cores=args.tpu_core_count, 1492 | distributed_backend='ddp', # if (args.gpu_count > 1 or args.node_count > 1) else None, 1493 | replace_sampler_ddp=False, 1494 | track_grad_norm=2 if args.tpu_core_count is None else -1, # gradnorm logging is slow on tpus 1495 | max_epochs=10000, min_epochs=0, 1496 | max_steps=args.train_steps, # run for many epochs, but stop after max_steps 1497 | val_check_interval=args.val_every, limit_val_batches=args.val_batches, 1498 | log_every_n_steps=args.log_rate, 1499 | progress_bar_refresh_rate=args.log_rate, 1500 | logger=logger, 1501 | # checkpoint_callback=checkpoint_callback if not args.disable_checkpointing else None, 1502 | accumulate_grad_batches=args.grad_accum, 1503 | resume_from_checkpoint=args.resume, 1504 | gradient_clip_val=args.grad_clip, 1505 | precision=16 if args.fp16 else 32, amp_level='O2', 1506 | num_sanity_val_steps=2, 1507 | callbacks=[LearningRateMonitor(), checkpoint_callback], 1508 | profiler="simple", 1509 | ) 1510 | trainer.profiler.dirpath = os.path.join(args.save_dir, args.save_prefix) 1511 | trainer.profiler.filename = 'profiler' 1512 | trainer.train_loop = MyTrainLoop(trainer, 1513 | multiple_trainloader_mode='max_size_cycle', args=args) 1514 | trainer.checkpoint_connector = MyCheckpointConnector(trainer, 1515 | reset_lr_scheduler=args.reset_lr_scheduler, 1516 | reset_optimizer=args.reset_optimizer, 1517 | set_global_step=args.restart_steps+1) 1518 | trainer.fit(pretrainer) 1519 | 1520 | 1521 | if __name__ == "__main__": 1522 | parser = Pretrainer.add_args(argparse.ArgumentParser(description="pretrain")) 1523 | args = parser.parse_args() 1524 | main(args) 1525 | --------------------------------------------------------------------------------
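A note on the data collator configured in the constructor above: with mlm=False, transformers' DataCollatorForLanguageModeling pads each batch and copies input_ids into labels (padded positions become -100), which is exactly what the causal-LM call in forward() expects. A minimal standalone check, assuming the gpt2 tokenizer can be downloaded; this snippet is not part of the repo:

    from transformers import DataCollatorForLanguageModeling, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    batch = collator([tokenizer("hello world"), tokenizer("hi")])
    assert batch["labels"].shape == batch["input_ids"].shape
    # padded positions are excluded from the loss (note: pad == eos here, so a real
    # eos token in the input would be masked as well)
    assert (batch["labels"][batch["input_ids"] == tokenizer.pad_token_id] == -100).all()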
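For reference, the token_per_step value logged in training_step is input_ids.numel() x grad_accum x world_size, i.e. seqlen x batch_size x grad_accum x gpu_count. With the values cited in the RoBERTa comment in add_args (seqlen 512, batch_size 32, grad_accum 16, one GPU), that is 512 x 32 x 16 x 1 = 262,144 = 2^18 tokens per optimizer step, matching the 2^18 figure in that comment.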
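get_restart_linear_schedule_with_warmup, used in configure_optimizers, is defined earlier in this file and is not shown in this excerpt. Purely as an illustrative sketch of the intended shape (an assumption, not the repo's implementation): a linear warmup/decay schedule that re-warms the learning rate for restart_warmup_steps steps after resuming at restart_steps can be expressed as a LambdaLR:

    from torch.optim.lr_scheduler import LambdaLR

    def linear_schedule_with_restart_warmup(optimizer, num_warmup_steps, num_training_steps,
                                            restart_steps, restart_warmup_steps, last_epoch=-1):
        # Illustrative sketch only; the real get_restart_linear_schedule_with_warmup may differ.
        def lr_lambda(step):
            # linear decay factor the schedule would normally have reached at `step`
            decay = max(0.0, float(num_training_steps - step) / float(max(1, num_training_steps - num_warmup_steps)))
            if restart_steps <= step < restart_steps + restart_warmup_steps:
                # after a restart, ramp linearly from zero back up to the decayed value
                return decay * float(step - restart_steps + 1) / float(max(1, restart_warmup_steps))
            if step < num_warmup_steps:
                return float(step) / float(max(1, num_warmup_steps))
            return decay
        return LambdaLR(optimizer, lr_lambda, last_epoch)

    # usage mirrors configure_optimizers above, e.g.:
    # scheduler = linear_schedule_with_restart_warmup(optimizer, num_warmup_steps=1000,
    #                                                 num_training_steps=3000,
    #                                                 restart_steps=1500, restart_warmup_steps=100)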
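The grad_norm override adds up per-parameter gradient norms, which is cheap to log but gives an upper bound on the true global norm rather than the norm itself (and, as its TODO notes, ignores the fp16 loss scale). For comparison, the standard reduction, as used by torch.nn.utils.clip_grad_norm_, combines the per-parameter norms with another p-norm:

    import torch

    def global_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
        # p-norm of the per-parameter gradient p-norms, i.e. the true global norm.
        grads = [p.grad.detach() for p in parameters if p.grad is not None]
        if not grads:
            return torch.tensor(0.0)
        return torch.norm(torch.stack([torch.norm(g, norm_type) for g in grads]), norm_type)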
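The weight-doubling branch of main() leans on two properties that its asserts encode: entries in ckpt['optimizer_states'][0]['state'] come back in the same order as named_parameters(), and each AdamW state entry carries exp_avg / exp_avg_sq tensors with the corresponding parameter's shape. A toy check of that layout (hypothetical two-layer module, not from the repo; torch's AdamW is used here and its state entries also carry exp_avg / exp_avg_sq, like the transformers AdamW used above):

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    model(torch.randn(3, 4)).sum().backward()
    optimizer.step()

    state = optimizer.state_dict()["state"]  # {param_index: {"step", "exp_avg", "exp_avg_sq"}}
    assert len(state) == len(list(model.named_parameters()))
    for (_, param), entry in zip(model.named_parameters(), state.values()):
        assert entry["exp_avg"].shape == param.shape
        assert entry["exp_avg_sq"].shape == param.shape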
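In the layer-doubling branch, the transformer.h.<idx> key remapping reduces to two small index maps, and writing them out makes the 'alternate' vs. 'append' modes easier to see (the *_id variants additionally set is_identical=True when the new state dict is built). An illustrative helper, not part of the repo:

    def doubled_layer_positions(layer_idx: int, n_layer: int, doubling_layers: str):
        # Returns (origin_position, inserted_copy_position) for one original block,
        # mirroring the origin_key / insert_key construction in main().
        if "alternate" in doubling_layers:
            # each copy is interleaved right after its source: 0 -> (0, 1), 1 -> (2, 3), ...
            return 2 * layer_idx, 2 * layer_idx + 1
        if "append" in doubling_layers:
            # copies are stacked on top of the original depth: 0 -> (0, n), 1 -> (1, n+1), ...
            return layer_idx, n_layer + layer_idx
        raise ValueError(f"unknown doubling_layers mode: {doubling_layers}")

    # for a 2-layer model:
    assert [doubled_layer_positions(i, 2, "alternate_copy") for i in range(2)] == [(0, 1), (2, 3)]
    assert [doubled_layer_positions(i, 2, "append_copy") for i in range(2)] == [(0, 2), (1, 3)]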