├── .gitignore ├── LICENSE ├── README.md ├── requirements.txt ├── train_512K.sh ├── train_64K.sh ├── train_sft.sh └── training ├── dataset.py ├── distributed_attention.py ├── modeling_flash_llama.py ├── train_language_model.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | debug* 165 | checkpoints 166 | datasets 167 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Princeton Natural Language Processing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProLong 2 | 3 | [[Paper](https://arxiv.org/pdf/2410.02660)] [[HF Page](https://huggingface.co/collections/princeton-nlp/prolong-66c72d55d2051a86ac7bd7e4)] 4 | 5 | This is the homepage for **ProLong** (Princeton long-context language models). 6 | 7 | ProLong is a family of long-context models that are continued trained and supervised fine-tuned from Llama-3-8B, with a maximum context window of 512K tokens. Our [main ProLong model](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-512k-Instruct) is one of the best-performing long-context models at the 10B scale (evaluated by [HELMET](https://github.com/princeton-nlp/helmet)). 8 | 9 | To train this strong long-context model, we conduct thorough ablations on the long-context pre-training data, SFT data, and numerous other design choices. 
We demonstrate our findings in our paper, [How to Train Long-Context Language Models (Effectively)](https://arxiv.org/pdf/2410.02660). 10 | 11 | Authors: [Tianyu Gao](https://gaotianyu.xyz/about)\*, [Alexander Wettig](https://www.cs.princeton.edu/~awettig/)\*, [Howard Yen](https://howard-yen.github.io/), [Danqi Chen](https://www.cs.princeton.edu/~danqic/) (* equal contribution) 12 | 13 | ## Release Progress 14 | 15 | 16 | - [x] ProLong models 17 | - [x] ProLong data 18 | - [x] Pre-training and SFT code 19 | - [x] Sequence parallelism 20 | 21 | ## Model card 22 | 23 | Here are some quick facts about our main ProLong model: [princeton-nlp/Llama-3-8B-ProLong-512k-Instruct](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-512k-Instruct). 24 | * Base model: [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) 25 | * Long-context continued training: 20B tokens on 64K training data, and 20B tokens on 512K training data 26 | * Supervised fine-tuning (SFT): [UltraChat](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) 27 | * Maximum context window: 512K tokens 28 | 29 | 30 |

31 | [figure] 32 | 33 | 34 | 35 | ProLong performance on HELMET averaged over 32K, 64K, and 128K lengths. All models are instruct models. 36 |
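Since ProLong keeps the Llama-3-8B architecture, the instruct checkpoints load with the standard `transformers` API. Below is a minimal sketch (the prompt, dtype, and generation settings are illustrative, not the evaluation setup):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "princeton-nlp/Llama-3-8B-ProLong-512k-Instruct"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16, device_map="auto")

# Chat-style prompt; a long document can go directly inside the user message
messages = [{"role": "user", "content": "Summarize the document below.\n\n<your long document>"}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```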

37 | 38 | 39 | ## Download the models and packed data 40 | 41 | All ProLong models are available on Hugging Face. All the models are based on Llama-3-8B, so any code that supports Llama-3-8B is also compatible with ProLong models. 42 | 43 | | Model | HF Link | 44 | |-------|---------| 45 | | ProLong-64k-Base | [princeton-nlp/Llama-3-8B-ProLong-64k-Base](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-64k-Base) | 46 | | ProLong-64k-Instruct | [princeton-nlp/Llama-3-8B-ProLong-64k-Instruct](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-64k-Instruct) | 47 | | ProLong-512k-Base | [princeton-nlp/Llama-3-8B-ProLong-512k-Base](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-512k-Base) | 48 | | ⭐ ProLong-512k-Instruct | [princeton-nlp/Llama-3-8B-ProLong-512k-Instruct](https://huggingface.co/princeton-nlp/Llama-3-8B-ProLong-512k-Instruct) | 49 | 50 | Our training data (packed and sampled version) are also available on Hugging Face (in [mosaicml-streaming](https://docs.mosaicml.com/projects/streaming/en/stable/index.html) format). 51 | 52 | | Data | HF Link | 53 | |------|---------| 54 | | Stage 1: 64K training data (40B tokens) | [princeton-nlp/prolong-data-64K](https://huggingface.co/datasets/princeton-nlp/prolong-data-64K) | 55 | | Stage 2: 512K training data (40B tokens)| [princeton-nlp/prolong-data-512K](https://huggingface.co/datasets/princeton-nlp/prolong-data-512K) | 56 | | SFT: UltraChat (1B tokens) | [princeton-nlp/prolong-ultrachat-64K](https://huggingface.co/datasets/princeton-nlp/prolong-ultrachat-64K) | 57 | 58 | 59 | 60 | 61 | ## Download and prepare raw data 62 | 63 | If you want to experiment with different data lengths or data mixtures, 64 | We also provide the (unpacked, unfiltered, but tokenized) raw data from each domain below. 65 | Due to the large size of the raw data, we store it on AWS S3. To download the data, you need to have an AWS account (with an access key and a secret key). **Note that data downloading will incur a charge on your AWS account**. According to [this S3 document](https://aws.amazon.com/s3/pricing/), each GB of data downloaded incurs $0.09 and the first 100GB is free. You can download the data using the following commands: 66 | 67 | ```bash 68 | # Install AWS CLI if you haven't already 69 | pip install awscli 70 | 71 | # Configure AWS CLI with your credentials (you will need an access key and a secret key from your AWS account) 72 | aws configure 73 | 74 | # Download the raw code repo data (concatenated by repo names from the stack v1) 75 | aws s3 sync s3://princeton-prolong/data_before_packing/code_repos/ /target/path/ --request-payer requester 76 | ``` 77 | 78 | Below is the available unpacked raw data (tokenized with the Llama-3 tokenizer). All data is in the [mosaicml-streaming](https://docs.mosaicml.com/projects/streaming/en/stable/index.html) format, with three fields: `domain` (`str`), `input_ids` (`int32 numpy array`, the Llama-3 tokenized document with no BOS/EOS), and `length` (`int32`, number of tokens). 
79 | 80 | | Data | Size | S3 path | 81 | |------|------|---------| 82 | | Code repos | 689 GB | s3://princeton-prolong/data_before_packing/code_repos/ | 83 | | Books (SlimPajama)| 180 GB| s3://princeton-prolong/data_before_packing/books/ | 84 | | FineWeb (sampled) | 864 GB | s3://princeton-prolong/data_before_packing/fineweb-2023-50/ | 85 | | FineWeb-edu (sampled) | 365 GB | s3://princeton-prolong/data_before_packing/fineweb-edu-100B/ | 86 | | OpenWebMath | 48 GB| s3://princeton-prolong/data_before_packing/openwebmath/ | 87 | | Wikipedia (Dolma) | 14 GB | s3://princeton-prolong/data_before_packing/wikipedia/ | 88 | | Textbooks | 1 GB | s3://princeton-prolong/data_before_packing/textbooks/ | 89 | | Tulu-v2 | 1 GB | s3://princeton-prolong/data_before_packing/tuluv2/ | 90 | | StackExchange (SlimPajama) | 135 GB | s3://princeton-prolong/data_before_packing/stackexchange/ | 91 | | ArXiv (SlimPajama) | 210 GB | s3://princeton-prolong/data_before_packing/arxiv/ | 92 | 93 | 94 |
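Since the bucket is requester-pays, it can be worth checking the size of a prefix before syncing it. A small sketch using standard AWS CLI flags (the Wikipedia prefix is just one example from the table above):

```bash
# Print per-object listings plus a total object count and size for one prefix
aws s3 ls s3://princeton-prolong/data_before_packing/wikipedia/ \
    --recursive --summarize --human-readable --request-payer requester
```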
95 | A quick guide of mosaicml-streaming 96 | 97 | Full documentation and installation guide can be found [here](https://docs.mosaicml.com/projects/streaming/en/stable/index.html). 98 | 99 |
100 | >>> from streaming import LocalDataset
101 | >>> dataset = LocalDataset("path/to/dataset")
102 | >>> len(dataset) # number of samples
103 | >>> dataset[0] # allow random access, use like a dictionary/JSON
104 | {'domain': 'book', 'input_ids': array([ 1038, 19017,  2041, ...,   271, 12488,   220], dtype=uint32), 'length': 111200}
105 | 
106 | 107 |
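After downloading a raw domain, it is easy to spot-check that the tokenized documents decode back to sensible text. A minimal sketch; the local path is a placeholder and the tokenizer must be the Llama-3 tokenizer used for preprocessing:

```python
from streaming import LocalDataset
from transformers import AutoTokenizer

dataset = LocalDataset("/target/path/openwebmath")  # placeholder: any downloaded domain directory
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

sample = dataset[0]
print(sample["domain"], sample["length"])            # `length` is the token count of `input_ids`
print(tokenizer.decode(sample["input_ids"][:200]))   # documents are stored without BOS/EOS
```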
108 | 109 | 110 | 111 | ### How to filter and pack data 112 | 113 | We use our own [datatools](https://github.com/CodeCreator/datatools) (created by Alex and Tianyu) to filter (by length) and pack data. `datatools` is a versatile repo that supports tokenization, packing, and filtering from various raw formats (json, jsonl, Hugging Face, mosaicml-streaming, etc.) and outputs the data in the mosaicml-streaming format. 114 | 115 | Example usage: 116 | ```bash 117 | pack <input dir> <output dir> --pack_length <pack length> --min_length <min length> -w <num workers> 118 | 119 | # For example, pack our raw code data to 64K with 40 workers 120 | pack data/code_repo data/code_repo-packto64k-minlen64k --pack_length 65536 --min_length 65536 -w 40 121 | 122 | # Our script is also compatible with distributed workflows on SLURM. The example below uses 20 SLURM array jobs, each using 40 workers 123 | pack data/code_repo data/code_repo-packto64k-minlen64k --pack_length 65536 --min_length 65536 -w 40 --num_jobs 20 --slurm_array 124 | 125 | # To tokenize raw data with text strings into tokenized data (which can then be packed), use `tokenize`. The example below uses 20 SLURM array jobs, each using 40 workers. 126 | # The input directory should also be in the mosaicml-streaming format. Each item should have a "text" field containing the raw document string. 127 | # You should first run at a smaller scale and check that the result looks correct. 128 | tokenize data/code_repo_text data/code_repo -w 40 --num_jobs 20 --slurm_array --tokenizer {HF tokenizer name / llama2 / llama3} 129 | 130 | ``` 131 | 132 | ## How to train ProLong 133 | 134 |

135 | [figure] 136 | 137 | 138 | ProLong training recipe. 139 |

140 | 141 | 142 | Our training code is built on top of Hugging Face's [Transformers](https://github.com/huggingface/transformers). Compared to the original codebase, we make the following changes: 143 | 144 | * Support `mosaicml-streaming` formats for datasets (much faster and IO friendly). 145 | * Support FlashAttention-2's variable-length attention (for efficient document masking). We implemented an in-batch length-sorting dataloader that balances data loads on different devices and improves training throughput. 146 | * Support sequence parallelism (inspired by DeepSpeed Ulysses). 147 | * Support SFT (masking out instructions) and token-averaged losses (instead of torch's standard sequence-and-device-averaged losses). 148 | * We implemented a memory-efficient cross entropy that allows 64K-token training of Llama-3-8B without using sequence parallelism. 149 | * Various improvements on checkpoint resuming and logging. 150 | 151 | #### File structures 152 | 153 | All our code is under `training`: 154 | * `dataset.py`: datasets and packing strategies for mosaicml-streaming data. 155 | * `distributed_attention.py`: sequence parallelism implementation. 156 | * `modeling_flash_llama.py`: our modified FlashAttention-2 Llama code, with support for variable-length attention, sequence parallelism, memory-efficient cross entropy, and token-averaged losses. 157 | * `trainer.py`: our trainer derived from Hugging Face's `Trainer` with various improvements. 158 | * `train_language_model.py`: the main training script. 159 | 160 | #### Preparation 161 | 162 | 1. Download all the data to `datasets/` 163 | ```bash 164 | git clone https://huggingface.co/datasets/princeton-nlp/prolong-data-64K datasets/long-context-65536 165 | git clone https://huggingface.co/datasets/princeton-nlp/prolong-data-512K datasets/long-context-524288 166 | git clone https://huggingface.co/datasets/princeton-nlp/prolong-ultrachat-64K datasets/prolong-ultrachat-64K 167 | ``` 168 | 169 | 2. Install dependencies 170 | ```bash 171 | pip install -r requirements.txt 172 | ``` 173 | 174 | #### Training 175 | 176 | We provide the scripts for 64K training (`train_64K.sh`), 512K training (`train_512K.sh`), and the final SFT training (`train_sft.sh`). The scripts require at least 8 GPUs (each with at least 80GB memory) to run. To run it on a local machine, simply do `bash {script_name}.sh`. If you are using SLURM in a cluster environment, you can submit the job by `sbatch {script_name}.sh`. To submit a resume-from-checkpoint job, the same script will work too. 177 | 178 | The 512K training will load the 64K checkpoint and the optimizer state. To allow this, **please do the following** 179 | ```bash 180 | cd {the HF checkpoint folder of the 64K model} 181 | mv trainer_state.json trainer_state.json.backup # Otherwise the model will reload the old LR scheduler 182 | ln -s checkpoint-5000/optimizer.pt . # Link the optimizer state so that it can be loaded; replace checkpoint-5000 to whichever that is the last checkpoint 183 | ``` 184 | 185 | 186 | #### Customization 187 | 188 | You can read the comments in the scripts to see what customized training arguments we used. 189 | Here is a brief explanation of them (we skip all that are already defined in Hugging Face): 190 | * `--cuda_empty_cache`: empty CUDA cache after each step to avoid OOM. 191 | * `--config_overrides`: override the default HF config with specified arguments, e.g., `--config_overrides "rope_theta=8000000"`. 192 | * `--seq_parallel_size`: sequence parallelism size. 
For example, `--seq_parallel_size 8` means we use 8 GPUs to handle one long sequence. 193 | * `--apply_instruct_masks`: read the `mask` field from the dataset and mask out those tokens during instruction tuning (e.g., the instructions). 194 | * `--token_scaled_loss`: average losses over valid training tokens instead of devices. This should be turned on during instruction tuning. 195 | 196 | There are more options regarding FSDP, gradient checkpointing, etc. Please refer to the scripts for more details. 197 | 198 | ## Contact 199 | 200 | Please email Tianyu (`tianyug@princeton.edu`) or Alex (`awettig@princeton.edu`) if you have any questions. If you encounter any issues with the code, models, or data, please open an issue on GitHub. 201 | 202 | 203 | ## Citation 204 | 205 | ```bibtex 206 | @inproceedings{gao2025prolong, 207 | title={How to Train Long-Context Language Models (Effectively)}, 208 | author={Gao, Tianyu and Wettig, Alexander and Yen, Howard and Chen, Danqi}, 209 | booktitle={ACL}, 210 | year={2025} 211 | } 212 | ``` 213 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.32.1 2 | datasets==2.20.0 3 | mosaicml-cli==0.5.34 4 | mosaicml-streaming==0.8.1 5 | ninja==1.11.1.1 6 | numpy==1.26.3 7 | packaging==24.1 8 | pandas==2.2.1 9 | protobuf==4.25.3 10 | python-dateutil==2.9.0 11 | regex==2023.12.25 12 | sentencepiece==0.1.99 13 | tiktoken==0.7.0 14 | torch==2.4.1 15 | tqdm==4.66.4 16 | transformers==4.44.2 17 | triton==3.0.0 18 | wandb==0.17.3 19 | zstandard==0.22.0 20 | zstd==1.5.5.1 21 | 22 | flash-attn==2.6.1, --config-settings=--global-option="--no-build-isolation" 23 | rotary-emb @ git+https://github.com/Dao-AILab/flash-attention.git@9356a1c0389660d7e231ff3163c1ac17d9e3824a#subdirectory=csrc/rotary 24 | -------------------------------------------------------------------------------- /train_512K.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | #SBATCH -J train_512K 3 | #SBATCH -N 1 4 | #SBATCH --output=slurm/%x-%j.out 5 | #SBATCH --gres=gpu:8 6 | #SBATCH --mem=400G 7 | #SBATCH -c 32 8 | 9 | # !!!! Load your own environment here !!!! # 10 | # !!!! Load your own environment here !!!! 
# 11 | 12 | # Fine-tune from this model 13 | model=${MODEL:-meta-llama/Meta-Llama-3-8B-Instruct} 14 | # Point to the base dir of the ProLong 512K data 15 | dataset=${DATASET:-"datasets/long-context-524288"} 16 | 17 | # Directories in the dataset root folder where @ is followed by the mixing proportion 18 | domains=( 19 | thestackv1_concat_by_repo-524288@0.15 20 | thestackv1_concat_by_repo-65536@0.15 21 | book-524288@0.05 22 | book-65536@0.25 23 | fineweb-edu@0.1 24 | fineweb-2023-50@0.1 25 | stackexchange@0.04 26 | dolmawiki@0.04 27 | tuluv2@0.03 28 | arxiv@0.03 29 | openwebmath@0.03 30 | textbooks@0.03 31 | ) 32 | domains_name=ProLong512KMix 33 | 34 | 35 | bsz=${BSZ:-128} # * 512K (seq len) / 8 (seq parallel size) = 8M 36 | seq=${SEQ:-1} # per-device batch size 37 | lr=${LR:-5e-6} 38 | steps=${STEPS:-2500} 39 | save_steps=${SAVE:-125} 40 | warmup=${WARMUP:-0.1} 41 | suffix=${SUFFIX:-""} # for model saving name 42 | 43 | 44 | run_name="lcft_$(basename $model)_$(basename $dataset)_${domains_name}_bsz${bsz}_steps${steps}_lr${lr}_warmup${warmup}${suffix}" 45 | out_dir="checkpoints/$run_name" 46 | 47 | if [ -z "$CUDA_VISIBLE_DEVICES" ]; then 48 | num_gpus=$(nvidia-smi -L | wc -l) 49 | else 50 | num_gpus=$(jq -n "[$CUDA_VISIBLE_DEVICES] | length") 51 | fi 52 | num_gpus=${NUM_GPUS:-$num_gpus} 53 | 54 | num_nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l) 55 | if [ $num_nodes == 0 ]; then 56 | num_nodes=1 57 | fi 58 | num_nodes=${NUM_NODES:-$num_nodes} 59 | 60 | # Gradient accumulation 61 | accu=$(($bsz / $seq / $num_gpus / $num_nodes)) 62 | 63 | 64 | # [0] Disable 65 | # [1] FULL_SHARD (shards optimizer states, gradients and parameters), 66 | # [2] SHARD_GRAD_OP (shards optimizer states and gradients), 67 | # [3] NO_SHARD (DDP), 68 | # [4] HYBRID_SHARD (shards optimizer states, gradients and parameters within each node while each node has full copy), 69 | # [5] HYBRID_SHARD_ZERO2 (shards optimizer states and gradients within each node while each node has full copy). For more information, please refer the official PyTorch docs. 
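# Example invocations (illustrative): most knobs above and below are env-var defaults,
# so they can be overridden at launch time, e.g.
#   FSDP=4 GC=1 sbatch train_512K.sh       # HYBRID_SHARD + gradient checkpointing via SLURM
#   FSDP=1 NUM_GPUS=8 bash train_512K.sh   # FULL_SHARD on a single 8-GPU machine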
70 | fsdp=${FSDP:-"1"} 71 | gc=${GC:-"1"} 72 | 73 | export LOGIT_BLOCK_SIZE=2048 # Compute Llama logits in blocks of 2048 tokens 74 | 75 | mkdir -p $out_dir 76 | nvidia-smi 77 | 78 | if [ $num_nodes -gt 1 ]; then 79 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 80 | master_addr=${MASTER_ADDR:-$master_addr} 81 | 82 | # Launch via srun 83 | header="srun torchrun \ 84 | --rdzv-backend=c10d \ 85 | --rdzv-endpoint=$master_addr:56321 \ 86 | --nnodes=$num_nodes \ 87 | --nproc-per-node=$num_gpus \ 88 | -m training.train_language_model" 89 | else 90 | master_port=$(comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 1) 91 | 92 | # Launch without srun 93 | header="torchrun \ 94 | --rdzv-backend=c10d \ 95 | --rdzv-endpoint=localhost:$master_port \ 96 | --nnodes=1 \ 97 | --nproc-per-node=$num_gpus \ 98 | -m training.train_language_model" 99 | fi 100 | echo "slurm_nodelist=${SLURM_NODELIST} num_nodes=${num_nodes} master_addr=${master_addr} master_port=${master_port} num_gpus=${num_gpus}" 101 | 102 | export OMP_NUM_THREADS=$num_gpus 103 | export WANDB_PROJECT="prolong" 104 | export WANDB_DIR=$out_dir 105 | export WANDB_MODE="offline" # We turn off wandb online sync by default 106 | export TOKENIZERS_PARALLELISM=true 107 | 108 | 109 | base_arguments=( 110 | --report_to wandb 111 | --do_train 112 | 113 | --model_name $model 114 | --tokenizer_name $model 115 | 116 | # Initialize model + optimizer state with ProLong64K (please follow the README for the correct setup) 117 | --resume_from_checkpoint path/to/the/root/64K/checkpoint/folder 118 | 119 | --run_name $run_name 120 | --output_dir $out_dir 121 | --config_overrides_json "$overrides" 122 | --gradient_accumulation_steps $accu 123 | --per_device_train_batch_size $seq 124 | --per_device_eval_batch_size $seq 125 | 126 | --bf16 127 | --learning_rate $lr 128 | --min_lr_ratio 0.1 129 | --lr_scheduler_type cosine 130 | --max_grad_norm 1.0 131 | --adam_beta1 0.9 132 | --adam_beta2 0.95 133 | --weight_decay 0.1 134 | --warmup_ratio $warmup 135 | --optim adamw_torch 136 | 137 | --logging_steps 1 138 | --log_level info 139 | 140 | --max_steps $steps 141 | --save_steps $save_steps 142 | --dataloader_num_workers 1 143 | 144 | --disable_tqdm true 145 | --use_fast_tokenizer false 146 | --remove_unused_columns false 147 | --ddp_find_unused_parameters false 148 | 149 | --per_device_max_tokens 524288 150 | 151 | # --torch_compile 152 | --cuda_empty_cache 153 | --config_overrides "rope_theta=128000000" 154 | 155 | --seq_parallel_size 8 156 | ) 157 | 158 | 159 | 160 | if [ $fsdp -ne 0 ]; then 161 | export FSDP_SHARDING_STRATEGY=$fsdp 162 | base_arguments+=( --fsdp "auto_wrap" ) 163 | # [1] FULL_STATE_DICT, [2] LOCAL_STATE_DICT, [3] SHARDED_STATE_DICT 164 | export FSDP_STATE_DICT_TYPE="FULL_STATE_DICT" 165 | fi 166 | 167 | if [ $gc -ne 0 ]; then 168 | base_arguments+=( --gradient_checkpointing ) 169 | fi 170 | 171 | base_arguments+=( --tokenized_mds_train ) 172 | for domain in "${domains[@]}"; do 173 | base_arguments+=( $dataset/$domain ) 174 | done 175 | 176 | base_arguments+=( $@ ) 177 | 178 | echo command: "${header} ${base_arguments[@]}" 179 | ${header} "${base_arguments[@]}" 2>&1 | tee -a $out_dir/log.out 180 | -------------------------------------------------------------------------------- /train_64K.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | #SBATCH -J train_64K 3 | #SBATCH -N 1 4 | #SBATCH --output=slurm/%x-%j.out 5 | 
#SBATCH --gres=gpu:8 6 | #SBATCH --mem=400G 7 | #SBATCH -c 32 8 | 9 | # !!!! Load your own environment here !!!! # 10 | # !!!! Load your own environment here !!!! # 11 | 12 | # Fine-tune from this model 13 | model=${MODEL:-meta-llama/Meta-Llama-3-8B-Instruct} 14 | # Point to the base dir of the ProLong 64K data 15 | dataset=${DATASET:-"datasets/long-context-65536"} 16 | 17 | # Directories in the dataset root folder where @ is followed by the mixing proportion 18 | domains=( 19 | thestackv1_concat_by_repo-65536@0.3 20 | book-65536@0.3 21 | fineweb-edu@0.1 22 | fineweb-2023-50@0.1 23 | stackexchange@0.04 24 | dolmawiki@0.04 25 | tuluv2@0.03 26 | arxiv@0.03 27 | openwebmath@0.03 28 | textbooks@0.03 29 | ) 30 | domains_name=ProLong64KMix 31 | 32 | 33 | bsz=${BSZ:-64} # * 64k (seq len) = 4M 34 | seq=${SEQ:-1} # per-device batch size 35 | lr=${LR:-1e-5} 36 | steps=${STEPS:-5000} 37 | save_steps=${SAVE:-125} 38 | warmup=${WARMUP:-0.1} 39 | suffix=${SUFFIX:-""} # for model saving name 40 | 41 | 42 | run_name="lcft_$(basename $model)_$(basename $dataset)_${domains_name}_bsz${bsz}_steps${steps}_lr${lr}_warmup${warmup}${suffix}" 43 | out_dir="checkpoints/$run_name" 44 | 45 | if [ -z "$CUDA_VISIBLE_DEVICES" ]; then 46 | num_gpus=$(nvidia-smi -L | wc -l) 47 | else 48 | num_gpus=$(jq -n "[$CUDA_VISIBLE_DEVICES] | length") 49 | fi 50 | num_gpus=${NUM_GPUS:-$num_gpus} 51 | 52 | num_nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l) 53 | if [ $num_nodes == 0 ]; then 54 | num_nodes=1 55 | fi 56 | num_nodes=${NUM_NODES:-$num_nodes} 57 | 58 | # Gradient accumulation 59 | accu=$(($bsz / $seq / $num_gpus / $num_nodes)) 60 | 61 | 62 | # [0] Disable 63 | # [1] FULL_SHARD (shards optimizer states, gradients and parameters), 64 | # [2] SHARD_GRAD_OP (shards optimizer states and gradients), 65 | # [3] NO_SHARD (DDP), 66 | # [4] HYBRID_SHARD (shards optimizer states, gradients and parameters within each node while each node has full copy), 67 | # [5] HYBRID_SHARD_ZERO2 (shards optimizer states and gradients within each node while each node has full copy). For more information, please refer the official PyTorch docs. 
68 | fsdp=${FSDP:-"1"} 69 | gc=${GC:-"1"} 70 | 71 | export LOGIT_BLOCK_SIZE=2048 # Compute Llama logits in blocks of 2048 tokens 72 | 73 | mkdir -p $out_dir 74 | nvidia-smi 75 | 76 | if [ $num_nodes -gt 1 ]; then 77 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 78 | master_addr=${MASTER_ADDR:-$master_addr} 79 | 80 | # Launch via srun 81 | header="srun torchrun \ 82 | --rdzv-backend=c10d \ 83 | --rdzv-endpoint=$master_addr:56321 \ 84 | --nnodes=$num_nodes \ 85 | --nproc-per-node=$num_gpus \ 86 | -m training.train_language_model" 87 | else 88 | master_port=$(comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 1) 89 | 90 | # Launch without srun 91 | header="torchrun \ 92 | --rdzv-backend=c10d \ 93 | --rdzv-endpoint=localhost:$master_port \ 94 | --nnodes=1 \ 95 | --nproc-per-node=$num_gpus \ 96 | -m training.train_language_model" 97 | fi 98 | echo "slurm_nodelist=${SLURM_NODELIST} num_nodes=${num_nodes} master_addr=${master_addr} master_port=${master_port} num_gpus=${num_gpus}" 99 | 100 | export OMP_NUM_THREADS=$num_gpus 101 | export WANDB_PROJECT="prolong" 102 | export WANDB_DIR=$out_dir 103 | export WANDB_MODE="offline" # We turn off wandb online sync by default 104 | export TOKENIZERS_PARALLELISM=true 105 | 106 | 107 | base_arguments=( 108 | --report_to wandb 109 | --do_train 110 | 111 | --model_name $model 112 | --tokenizer_name $model 113 | 114 | --run_name $run_name 115 | --output_dir $out_dir 116 | --config_overrides_json "$overrides" 117 | --gradient_accumulation_steps $accu 118 | --per_device_train_batch_size $seq 119 | --per_device_eval_batch_size $seq 120 | 121 | --bf16 122 | --learning_rate $lr 123 | --min_lr_ratio 0.1 124 | --lr_scheduler_type cosine 125 | --max_grad_norm 1.0 126 | --adam_beta1 0.9 127 | --adam_beta2 0.95 128 | --weight_decay 0.1 129 | --warmup_ratio $warmup 130 | --optim adamw_torch 131 | 132 | --logging_steps 1 133 | --log_level info 134 | 135 | --max_steps $steps 136 | --save_steps $save_steps 137 | --dataloader_num_workers 1 138 | 139 | --disable_tqdm true 140 | --use_fast_tokenizer false 141 | --remove_unused_columns false 142 | --ddp_find_unused_parameters false 143 | 144 | --per_device_max_tokens 65536 145 | 146 | # --torch_compile 147 | --cuda_empty_cache 148 | --config_overrides "rope_theta=8000000" 149 | ) 150 | 151 | 152 | 153 | if [ $fsdp -ne 0 ]; then 154 | export FSDP_SHARDING_STRATEGY=$fsdp 155 | base_arguments+=( --fsdp "auto_wrap" ) 156 | # [1] FULL_STATE_DICT, [2] LOCAL_STATE_DICT, [3] SHARDED_STATE_DICT 157 | export FSDP_STATE_DICT_TYPE="FULL_STATE_DICT" 158 | fi 159 | 160 | if [ $gc -ne 0 ]; then 161 | base_arguments+=( --gradient_checkpointing ) 162 | fi 163 | 164 | base_arguments+=( --tokenized_mds_train ) 165 | for domain in "${domains[@]}"; do 166 | base_arguments+=( $dataset/$domain ) 167 | done 168 | 169 | base_arguments+=( $@ ) 170 | 171 | echo command: "${header} ${base_arguments[@]}" 172 | ${header} "${base_arguments[@]}" 2>&1 | tee -a $out_dir/log.out 173 | -------------------------------------------------------------------------------- /train_sft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | #SBATCH -J train_64K 3 | #SBATCH -N 1 4 | #SBATCH --output=slurm/%x-%j.out 5 | #SBATCH --gres=gpu:8 6 | #SBATCH --mem=400G 7 | #SBATCH -c 32 8 | 9 | # !!!! Load your own environment here !!!! # 10 | # !!!! Load your own environment here !!!! 
# 11 | 12 | # Fine-tune from this model 13 | model=${MODEL:-princeton-nlp/Llama-3-8B-ProLong-512k-Base} 14 | # Point to the base dir of the ProLong 64K data 15 | dataset=${DATASET:-"datasets"} 16 | 17 | # Directories in the dataset root folder where @ is followed by the mixing proportion 18 | domains=( 19 | prolong-ultrachat-64K@1.0 20 | ) 21 | domains_name=ultrachat 22 | 23 | 24 | bsz=${BSZ:-64} # * 64k (seq len) = 4M 25 | seq=${SEQ:-1} # per-device batch size 26 | lr=${LR:-2e-5} 27 | steps=${STEPS:-250} 28 | save_steps=${SAVE:-250} 29 | warmup=${WARMUP:-0.05} 30 | suffix=${SUFFIX:-""} # for model saving name 31 | 32 | 33 | run_name="sft_$(basename $model)_$(basename $dataset)_${domains_name}_bsz${bsz}_steps${steps}_lr${lr}_warmup${warmup}${suffix}" 34 | out_dir="checkpoints/$run_name" 35 | 36 | if [ -z "$CUDA_VISIBLE_DEVICES" ]; then 37 | num_gpus=$(nvidia-smi -L | wc -l) 38 | else 39 | num_gpus=$(jq -n "[$CUDA_VISIBLE_DEVICES] | length") 40 | fi 41 | num_gpus=${NUM_GPUS:-$num_gpus} 42 | 43 | num_nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l) 44 | if [ $num_nodes == 0 ]; then 45 | num_nodes=1 46 | fi 47 | num_nodes=${NUM_NODES:-$num_nodes} 48 | 49 | # Gradient accumulation 50 | accu=$(($bsz / $seq / $num_gpus / $num_nodes)) 51 | 52 | 53 | # [0] Disable 54 | # [1] FULL_SHARD (shards optimizer states, gradients and parameters), 55 | # [2] SHARD_GRAD_OP (shards optimizer states and gradients), 56 | # [3] NO_SHARD (DDP), 57 | # [4] HYBRID_SHARD (shards optimizer states, gradients and parameters within each node while each node has full copy), 58 | # [5] HYBRID_SHARD_ZERO2 (shards optimizer states and gradients within each node while each node has full copy). For more information, please refer the official PyTorch docs. 59 | fsdp=${FSDP:-"1"} 60 | gc=${GC:-"1"} 61 | 62 | export LOGIT_BLOCK_SIZE=2048 # Compute Llama logits in blocks of 2048 tokens 63 | 64 | mkdir -p $out_dir 65 | nvidia-smi 66 | 67 | if [ $num_nodes -gt 1 ]; then 68 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 69 | master_addr=${MASTER_ADDR:-$master_addr} 70 | 71 | # Launch via srun 72 | header="srun torchrun \ 73 | --rdzv-backend=c10d \ 74 | --rdzv-endpoint=$master_addr:56321 \ 75 | --nnodes=$num_nodes \ 76 | --nproc-per-node=$num_gpus \ 77 | -m training.train_language_model" 78 | else 79 | master_port=$(comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 1) 80 | 81 | # Launch without srun 82 | header="torchrun \ 83 | --rdzv-backend=c10d \ 84 | --rdzv-endpoint=localhost:$master_port \ 85 | --nnodes=1 \ 86 | --nproc-per-node=$num_gpus \ 87 | -m training.train_language_model" 88 | fi 89 | echo "slurm_nodelist=${SLURM_NODELIST} num_nodes=${num_nodes} master_addr=${master_addr} master_port=${master_port} num_gpus=${num_gpus}" 90 | 91 | export OMP_NUM_THREADS=$num_gpus 92 | export WANDB_PROJECT="prolong" 93 | export WANDB_DIR=$out_dir 94 | export WANDB_MODE="offline" # We turn off wandb online sync by default 95 | export TOKENIZERS_PARALLELISM=true 96 | 97 | 98 | base_arguments=( 99 | --report_to wandb 100 | --do_train 101 | 102 | --model_name $model 103 | --tokenizer_name $model 104 | 105 | --run_name $run_name 106 | --output_dir $out_dir 107 | --config_overrides_json "$overrides" 108 | --gradient_accumulation_steps $accu 109 | --per_device_train_batch_size $seq 110 | --per_device_eval_batch_size $seq 111 | 112 | --bf16 113 | --learning_rate $lr 114 | --min_lr_ratio 0.1 115 | --lr_scheduler_type cosine 116 | --max_grad_norm 1.0 
117 | --adam_beta1 0.9 118 | --adam_beta2 0.95 119 | --weight_decay 0.1 120 | --warmup_ratio $warmup 121 | --optim adamw_torch 122 | 123 | --logging_steps 1 124 | --log_level info 125 | 126 | --max_steps $steps 127 | --save_steps $save_steps 128 | --dataloader_num_workers 1 129 | 130 | --disable_tqdm true 131 | --use_fast_tokenizer false 132 | --remove_unused_columns false 133 | --ddp_find_unused_parameters false 134 | 135 | --per_device_max_tokens 65536 136 | 137 | --cuda_empty_cache 138 | 139 | --apply_instruct_masks # mask out the tokens from instructions (instead of responses) when calculating losses 140 | --token_scaled_loss # average losses over valid training tokens instead of devices 141 | ) 142 | 143 | 144 | 145 | if [ $fsdp -ne 0 ]; then 146 | export FSDP_SHARDING_STRATEGY=$fsdp 147 | base_arguments+=( --fsdp "auto_wrap" ) 148 | # [1] FULL_STATE_DICT, [2] LOCAL_STATE_DICT, [3] SHARDED_STATE_DICT 149 | export FSDP_STATE_DICT_TYPE="FULL_STATE_DICT" 150 | fi 151 | 152 | if [ $gc -ne 0 ]; then 153 | base_arguments+=( --gradient_checkpointing ) 154 | fi 155 | 156 | base_arguments+=( --tokenized_mds_train ) 157 | for domain in "${domains[@]}"; do 158 | base_arguments+=( $dataset/$domain ) 159 | done 160 | 161 | base_arguments+=( $@ ) 162 | 163 | echo command: "${header} ${base_arguments[@]}" 164 | ${header} "${base_arguments[@]}" 2>&1 | tee -a $out_dir/log.out 165 | -------------------------------------------------------------------------------- /training/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from streaming import StreamingDataset, Stream 5 | import logging 6 | 7 | from itertools import islice 8 | 9 | from typing import Dict, Any, List, Tuple 10 | from collections.abc import Iterator 11 | 12 | from training.trainer import TrainingArguments 13 | 14 | from dataclasses import dataclass, field 15 | from typing import Optional, List 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | @dataclass 21 | class DataArguments: 22 | single_seq: bool = field(default=False, metadata={"help": "Ignore the document boundaries and treat the whole packed sequence as a single sequence"}) 23 | per_device_max_tokens: Optional[int] = field(default=4_294_967_296, metadata={"help": "Maximum number of tokens per device; this is to avoid some catastrophic cases where the indices or data sequences are not filtered/truncated properly in preprocessing"}) 24 | apply_instruct_masks: bool = field(default=False, metadata={"help": "Whether to apply loss masks over the instructions (for instruction tuning). If enabled, will read the `mask` field in the data and set the corresponding labels to -100."}) 25 | 26 | 27 | class SafeStream(Stream): 28 | """Safe if multiple processes try to decompress the same shard.""" 29 | 30 | def _decompress_shard_part(self, zip_info, zip_filename, raw_filename, compression): 31 | unique_extension = "." 
+ str(os.getenv("SLURM_JOB_ID", "local")) + "-" + str(os.getpid()) 32 | super()._decompress_shard_part(zip_info, zip_filename, raw_filename + unique_extension, compression) 33 | os.rename(raw_filename + unique_extension, raw_filename) 34 | 35 | 36 | class DataCollator: 37 | def __init__(self, tokenizer, args: DataArguments): 38 | self.tokenizer = tokenizer 39 | self.args = args 40 | 41 | @torch.no_grad() 42 | def __call__(self, features): 43 | input_ids = [] 44 | labels = [] 45 | seq_lengths = [] 46 | 47 | available_tokens = self.args.per_device_max_tokens 48 | for item in features: 49 | apply_instruct_masks = self.args.apply_instruct_masks and ("mask" in item) 50 | indices = item["indices"] if "indices" in item else [(0, len(item["input_ids"]))] 51 | if self.args.single_seq: 52 | indices = [(0, len(item["input_ids"]))] 53 | 54 | label_seq = torch.tensor(item["input_ids"], dtype=torch.long) 55 | 56 | for a, b in indices: 57 | b = a + min(b - a, available_tokens) 58 | if b - a > 1: 59 | input_seq = torch.tensor(item["input_ids"][a:b], dtype=torch.long) 60 | input_ids.append(input_seq) 61 | 62 | _label = label_seq[a:b] 63 | _label[0] = -100 # Don't predict the first token 64 | if apply_instruct_masks: 65 | # Read the `mask` field and set the corresponding labels to -100 66 | mask = torch.tensor(item["mask"][a:b], dtype=torch.long) 67 | _label[mask == 0] = -100 68 | labels.append(_label) 69 | 70 | seq_lengths.append(b - a) 71 | available_tokens -= b - a 72 | elif available_tokens <= 0: 73 | assert available_tokens == 0, "Available tokens should be non-negative" 74 | break 75 | 76 | input_ids = torch.concat(input_ids, dim=0) 77 | labels = torch.concat(labels, dim=0) 78 | seq_lengths = torch.tensor(seq_lengths, dtype=torch.long) 79 | 80 | return dict(input_ids=input_ids, 81 | attention_mask=None, 82 | labels=labels, 83 | seq_lengths=seq_lengths) 84 | 85 | 86 | 87 | class SortByLengthDataset(StreamingDataset): 88 | def __init__(self, *args, sort_by_length_size=1, data_args=None, **kwargs): 89 | super().__init__(*args, **kwargs) 90 | self.sort_by_length_size = sort_by_length_size 91 | self.data_args = data_args 92 | 93 | def _negative_item_cost(self, item): 94 | if "indices" in item: 95 | return -sum( 96 | (end - start)**2 for start, end in item["indices"] 97 | ) 98 | elif "length" in item: 99 | return -item["length"]**2 100 | else: 101 | return -len(item["input_ids"])**2 102 | 103 | def __iter__(self) -> Iterator[Dict[str, Any]]: 104 | if self.sort_by_length_size <= 1: 105 | yield from super().__iter__() 106 | else: 107 | iterator = super().__iter__() 108 | while True: 109 | block = list(islice(iterator, self.sort_by_length_size)) 110 | if not block: 111 | return 112 | 113 | yield from sorted(block, key=self._negative_item_cost) 114 | 115 | 116 | def build_dataset(paths, training_args: TrainingArguments, data_args: DataArguments, is_training: bool) -> StreamingDataset: 117 | logger.info(f"Loading datasets for {'training' if is_training else 'evaluation'}") 118 | 119 | streams = [] 120 | for path in paths: 121 | if "@" in path: 122 | path, proportion = path.split("@", 1) 123 | logger.info(f"Loading dataset from {path} with proportion {proportion}") 124 | streams.append(SafeStream(remote=path, local=path, proportion=float(proportion))) 125 | elif "#" in path: 126 | path, proportion = path.split("#", 1) 127 | logger.info(f"Loading dataset from {path} with repeat {proportion}") 128 | streams.append(SafeStream(remote=path, local=path, repeat=float(proportion))) 129 | else: 130 | 
streams.append(SafeStream(remote=path, local=path)) 131 | 132 | epoch_size = ( 133 | training_args.max_steps * training_args.train_batch_size * training_args.gradient_accumulation_steps * 134 | training_args.world_size // training_args.seq_parallel_size 135 | ) 136 | 137 | num_dataloaders = max(training_args.dataloader_num_workers, 1) 138 | per_device_step_size = training_args.gradient_accumulation_steps * training_args.train_batch_size 139 | per_worker_step_size = per_device_step_size // num_dataloaders 140 | assert per_device_step_size % num_dataloaders == 0, "dataloader workers should divide local batch size" 141 | 142 | return SortByLengthDataset( 143 | streams=streams, 144 | shuffle=is_training, 145 | shuffle_seed=training_args.seed, 146 | batch_size=(training_args.train_batch_size if is_training else training_args.eval_batch_size), 147 | epoch_size=(epoch_size if is_training else None), 148 | sort_by_length_size=(per_worker_step_size if is_training else 1), 149 | data_args=data_args, 150 | replication=training_args.seq_parallel_size, 151 | ) 152 | -------------------------------------------------------------------------------- /training/distributed_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | 6 | import torch 7 | 8 | from typing import Any, Tuple 9 | from torch import Tensor 10 | from torch.nn import Module 11 | 12 | import torch.distributed as dist 13 | 14 | class SeqAllToAll(torch.autograd.Function): 15 | @staticmethod 16 | def forward(ctx: Any, input: Tensor, scatter_idx: int, gather_idx: int, group: Any) -> Tensor: 17 | ctx.scatter_idx = scatter_idx 18 | ctx.gather_idx = gather_idx 19 | ctx.group = group 20 | 21 | world_size = dist.get_world_size(group) 22 | 23 | input_list = [t.contiguous() for t in torch.tensor_split(input, world_size, scatter_idx)] 24 | output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] 25 | 26 | dist.all_to_all(output_list, input_list, group=group) 27 | return torch.cat(output_list, dim=gather_idx).contiguous() 28 | 29 | @staticmethod 30 | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None, None, None]: 31 | return (SeqAllToAll.apply(*grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.group), None, None, None) 32 | 33 | 34 | class DistributedAttention(torch.nn.Module): 35 | """Initialization. 
36 | 37 | Arguments: 38 | local_attention (Module): local attention with q,k,v 39 | scatter_idx (int): scatter_idx for all2all comm 40 | gather_idx (int): gather_idx for all2all comm 41 | """ 42 | 43 | def __init__( 44 | self, 45 | local_attention: Module, 46 | scatter_idx: int = -2, 47 | gather_idx: int = 1, 48 | ) -> None: 49 | 50 | super(DistributedAttention, self).__init__() 51 | self.local_attn = local_attention 52 | self.scatter_idx = scatter_idx # head axis 53 | self.gather_idx = gather_idx # seq axis 54 | 55 | def forward(self, query: Tensor, key_values: Tensor, *args, group: Any = None, **kwargs) -> Tensor: 56 | """ forward 57 | 58 | Arguments: 59 | query (Tensor): query input to the layer 60 | key (Tensor): key input to the layer 61 | value (Tensor): value input to the layer 62 | args: other args 63 | 64 | Returns: 65 | * output (Tensor): context output 66 | """ 67 | #in shape : e.g., [s/p:h:] 68 | query_heads = SeqAllToAll.apply(query, self.scatter_idx, self.gather_idx, group) 69 | key_values_heads = SeqAllToAll.apply(key_values, self.scatter_idx, self.gather_idx, group) 70 | 71 | #out shape : e.g., [s:h/p:] 72 | output_heads = self.local_attn(query_heads, key_values_heads, *args, **kwargs) 73 | 74 | #out e.g., [s/p::h] 75 | return SeqAllToAll.apply(output_heads, self.gather_idx, self.scatter_idx, group) 76 | -------------------------------------------------------------------------------- /training/modeling_flash_llama.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
20 | """ PyTorch LLaMA model.""" 21 | from typing import List, Optional, Tuple, Union, Any 22 | 23 | import torch 24 | import torch.nn.functional as F 25 | import torch.utils.checkpoint 26 | from torch import nn 27 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 28 | 29 | import torch.distributed as dist 30 | 31 | import os 32 | 33 | from transformers.activations import ACT2FN 34 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast 35 | from transformers.modeling_utils import PreTrainedModel 36 | from transformers.utils import logging 37 | from transformers.models.llama.configuration_llama import LlamaConfig 38 | 39 | from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS 40 | 41 | from flash_attn import flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func, flash_attn_with_kvcache 42 | from flash_attn.bert_padding import unpad_input, pad_input 43 | 44 | try: 45 | from flash_attn.layers.rotary import apply_rotary_emb_func 46 | except ImportError: 47 | raise ImportError('Please install RoPE kernels: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`') 48 | 49 | from training.distributed_attention import DistributedAttention 50 | 51 | logger = logging.get_logger(__name__) 52 | 53 | # @torch.jit.script 54 | def rmsnorm_func(hidden_states, weight, variance_epsilon): 55 | input_dtype = hidden_states.dtype 56 | hidden_states = hidden_states.to(torch.float32) 57 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 58 | hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) 59 | return (weight * hidden_states).to(input_dtype) 60 | 61 | 62 | class LlamaRMSNorm(nn.Module): 63 | def __init__(self, hidden_size, eps=1e-6): 64 | """ 65 | LlamaRMSNorm is equivalent to T5LayerNorm 66 | """ 67 | super().__init__() 68 | self.weight = nn.Parameter(torch.ones(hidden_size)) 69 | self.register_buffer( 70 | "variance_epsilon", 71 | torch.tensor(eps), 72 | persistent=False, 73 | ) 74 | 75 | def forward(self, hidden_states): 76 | return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon) 77 | 78 | 79 | class LlamaRotaryEmbedding(nn.Module): 80 | def __init__( 81 | self, 82 | dim=None, 83 | max_position_embeddings=2048, 84 | base=10000, 85 | device=None, 86 | scaling_factor=1.0, 87 | rope_type="default", 88 | interleaved=False, 89 | config: Optional[LlamaConfig] = None, 90 | ): 91 | super().__init__() 92 | self.rope_kwargs = {} 93 | self.scaling_factor = scaling_factor 94 | self.interleaved = interleaved 95 | self.pos_idx_in_fp32 = True 96 | 97 | if config is None: 98 | logger.warning_once( 99 | "`L3lamaRotaryEmbedding` can now be fully parameterized by passing the model config through the " 100 | "`config` argument. 
All other arguments will be removed in v4.46" 101 | ) 102 | self.rope_kwargs = { 103 | "rope_type": rope_type, 104 | "factor": scaling_factor, 105 | "dim": dim, 106 | "base": base, 107 | "max_position_embeddings": max_position_embeddings, 108 | } 109 | self.rope_type = rope_type 110 | else: 111 | # BC: "rope_type" was originally "type" 112 | if config.rope_scaling is not None: 113 | self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) 114 | else: 115 | self.rope_type = "default" 116 | 117 | self._seq_len_cached = 0 118 | 119 | self.config = config 120 | self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] 121 | 122 | inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) 123 | self.register_buffer("inv_freq", inv_freq, persistent=False) 124 | 125 | 126 | @torch.no_grad() 127 | def _update_cos_sin_cache(self, seq_len, device=None, dtype=None): 128 | # Reset the tables if the sequence length has changed, 129 | # if we're on a new device (possibly due to tracing for instance), 130 | # or if we're switching from inference mode to training 131 | if (seq_len > self._seq_len_cached or self._cos_cached.device != device 132 | or self._cos_cached.dtype != dtype 133 | or (self.training and self._cos_cached.is_inference())): 134 | self._seq_len_cached = seq_len 135 | 136 | if "dynamic" in self.rope_type: 137 | inv_freq, self.attention_scaling = self.rope_init_fn( 138 | self.config, device, seq_len=seq_len, **self.rope_kwargs 139 | ) 140 | self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation 141 | 142 | # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 143 | # And the output of arange can be quite large, so bf16 would lose a lot of precision. 144 | # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. 145 | if self.pos_idx_in_fp32: 146 | t = torch.arange(seq_len, device=device, dtype=torch.float32) 147 | t /= self.scaling_factor 148 | # We want fp32 here as well since inv_freq will be multiplied with t, and the output 149 | # will be large. Having it in bf16 will lose a lot of precision and cause the 150 | # cos & sin output to change significantly. 
151 | # We want to recompute self.inv_freq if it was not loaded in fp32 152 | if self.inv_freq.dtype != torch.float32: 153 | inv_freq = self.inv_freq.to(torch.float32) 154 | else: 155 | inv_freq = self.inv_freq 156 | else: 157 | t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) 158 | t /= self.scaling_factor 159 | inv_freq = self.inv_freq 160 | 161 | # Don't do einsum, it converts fp32 to fp16 under AMP 162 | # freqs = torch.einsum("i,j->ij", t, self.inv_freq) 163 | freqs = torch.outer(t, inv_freq) 164 | self._cos_cached = (torch.cos(freqs) * self.attention_scaling).to(dtype) 165 | self._sin_cached = (torch.sin(freqs) * self.attention_scaling).to(dtype) 166 | 167 | 168 | def forward( 169 | self, 170 | q: torch.Tensor, k: torch.Tensor, 171 | seqlen_offset: int = 0, # Used in sequence parallelism where each device sees only a chunk of the full sequence 172 | unpadded_lengths: Optional[Tuple[torch.Tensor]] = None 173 | ): 174 | if unpadded_lengths is not None: 175 | cu_seqlens, max_seqlen = unpadded_lengths 176 | if seqlen_offset > 0: 177 | raise ValueError("seqlen_offset is not supported with unpadded_lengths") 178 | else: 179 | cu_seqlens, max_seqlen = None, q.shape[1] 180 | 181 | self._update_cos_sin_cache(max_seqlen + seqlen_offset, q.device, q.dtype) 182 | 183 | return apply_rotary_emb_func( 184 | q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], 185 | self.interleaved, True, # inplace=True, 186 | cu_seqlens=cu_seqlens, max_seqlen=max_seqlen 187 | ), apply_rotary_emb_func( 188 | k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], 189 | self.interleaved, True, # inplace=True 190 | cu_seqlens=cu_seqlens, max_seqlen=max_seqlen 191 | ) 192 | 193 | 194 | class LlamaMLP(nn.Module): 195 | def __init__(self, config): 196 | super().__init__() 197 | self.config = config 198 | self.hidden_size = config.hidden_size 199 | self.intermediate_size = config.intermediate_size 200 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 201 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 202 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 203 | self.act_fn = ACT2FN[config.hidden_act] 204 | 205 | def forward(self, x): 206 | return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 207 | 208 | 209 | @torch.jit.script 210 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 211 | if n_rep == 1: 212 | return hidden_states 213 | final_shape = list(hidden_states.shape[:-2]) + [-1] + [hidden_states.shape[-1]] 214 | expand_shape = [-1] * (len(hidden_states.shape) - 1) + [n_rep] + [-1] 215 | hidden_states = hidden_states.unsqueeze(-2).expand(expand_shape) 216 | return hidden_states.reshape(final_shape) 217 | 218 | 219 | class LlamaAttention(nn.Module): 220 | """Multi-headed attention from 'Attention Is All You Need' paper""" 221 | 222 | def __init__( 223 | self, 224 | config: LlamaConfig, 225 | layer_idx: Optional[int] = None, 226 | context_window_toggle: Optional[int] = 4096, 227 | ): 228 | """ 229 | @context_window_toggle: if not None, the attention will be limited to a context window specified by this value 230 | """ 231 | super().__init__() 232 | self.config = config 233 | self.hidden_size = config.hidden_size 234 | self.num_heads = config.num_attention_heads 235 | self.head_dim = self.hidden_size // self.num_heads 236 | self.num_key_value_heads = getattr(config, "num_key_value_heads", self.num_heads) 237 | 
self.num_key_value_groups = self.num_heads // self.num_key_value_heads 238 | self.max_position_embeddings = config.max_position_embeddings 239 | 240 | if (self.head_dim * self.num_heads) != self.hidden_size: 241 | raise ValueError( 242 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 243 | f" and `num_heads`: {self.num_heads})." 244 | ) 245 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) 246 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 247 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 248 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 249 | 250 | self.register_buffer( 251 | "norm_factor", 252 | torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()), 253 | persistent=False, 254 | ) 255 | self.rotary_emb = LlamaRotaryEmbedding(config=self.config) 256 | 257 | self.distributed_attn_func = DistributedAttention(flash_attn_kvpacked_func, gather_idx=1) 258 | self.distributed_varlen_attn_func = DistributedAttention(flash_attn_varlen_kvpacked_func, gather_idx=0) 259 | 260 | 261 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 262 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 263 | 264 | def forward( 265 | self, 266 | hidden_states: torch.Tensor, 267 | attention_mask: Optional[torch.Tensor] = None, 268 | position_ids: Optional[torch.LongTensor] = None, 269 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 270 | output_attentions: bool = False, 271 | use_cache: bool = False, 272 | unpadded_lengths: Optional[Tuple[torch.Tensor]] = None, 273 | seq_parallel_group: Optional[Any] = None, 274 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 275 | **kwargs, 276 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 277 | q_len, h_size = hidden_states.size(-2), hidden_states.size(-1) 278 | 279 | q = self.q_proj(hidden_states) 280 | k = self.k_proj(hidden_states) 281 | v = self.v_proj(hidden_states) 282 | 283 | q = q.view(*q.shape[:-1], self.num_heads, self.head_dim) 284 | k = k.view(*k.shape[:-1], self.num_key_value_heads, self.head_dim) 285 | v = v.view(*v.shape[:-1], self.num_key_value_heads, self.head_dim) 286 | 287 | has_layer_past = past_key_value is not None 288 | 289 | if has_layer_past: 290 | past_kv = past_key_value[0] 291 | past_len = past_key_value[1] 292 | else: 293 | past_len = 0 294 | 295 | # NOTE: Hacky way to include position_ids in sequence parallelism, assuming they are increasing uniformly per block 296 | if position_ids is not None: 297 | past_len += position_ids.min() 298 | 299 | if unpadded_lengths is not None: 300 | # We don't use the unpadded_length during rotary embeds and instead create a temporary `batch` dimension 301 | # This does not actually affect the otucome since the positional embeddings are relative and stay valid 302 | # This also ensures that in sequence parallelism the correct `past_len` offset is applied to mid-sequence chunks 303 | q, k = self.rotary_emb(q.unsqueeze(0), k.unsqueeze(0), past_len) 304 | q, k = q.squeeze(0), k.squeeze(0) 305 | else: 306 | q, k = self.rotary_emb(q, k, past_len) 307 | 308 | kv = torch.stack([k, v], -3) 309 | kv = repeat_kv(kv, self.num_key_value_groups) 310 | 311 | # Cache QKV values 312 | if has_layer_past: 313 | new_len = 
past_len+q.size(1) 314 | if new_len > past_kv.size(1): 315 | past_kv = torch.cat([past_kv, torch.empty(hidden_states.size(0), 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1) 316 | past_kv[:, past_len:new_len] = kv 317 | kv = past_kv[:, :new_len] 318 | else: 319 | past_kv = kv 320 | 321 | if seq_parallel_group is not None and dist.is_initialized() and dist.get_world_size(seq_parallel_group) > 1: 322 | attention_func = self.distributed_varlen_attn_func if unpadded_lengths is not None else self.distributed_attn_func 323 | kwargs = {"group": seq_parallel_group} 324 | else: 325 | attention_func = flash_attn_varlen_kvpacked_func if unpadded_lengths is not None else flash_attn_kvpacked_func 326 | kwargs = {} 327 | 328 | if unpadded_lengths is not None: 329 | # varlen, ignore padding tokens, efficient for large batch with many paddings 330 | cu_seqlens, max_seqlen = unpadded_lengths 331 | 332 | attn_outputs = attention_func( 333 | q, kv, 334 | cu_seqlens, cu_seqlens, 335 | max_seqlen, max_seqlen, 336 | dropout_p=0.0, softmax_scale=1.0/self.norm_factor, 337 | causal=True, return_attn_probs=output_attentions, 338 | **kwargs 339 | ) 340 | else: 341 | attn_outputs = attention_func( 342 | q, kv, 343 | dropout_p=0.0, 344 | softmax_scale=1.0/self.norm_factor, 345 | causal=True, 346 | return_attn_probs=output_attentions, 347 | **kwargs 348 | ) 349 | 350 | past_key_value = (past_kv, past_len+q.size(1)) if use_cache else None 351 | 352 | attn_output = attn_outputs[0] if output_attentions else attn_outputs 353 | attn_output = attn_output.reshape(*attn_output.shape[:-2], h_size) 354 | attn_weights = attn_outputs[2] if output_attentions else None 355 | 356 | attn_output = self.o_proj(attn_output) 357 | 358 | if not output_attentions: 359 | attn_weights = None 360 | 361 | return attn_output, attn_weights, past_key_value 362 | 363 | 364 | class LlamaDecoderLayer(nn.Module): 365 | def __init__(self, config: LlamaConfig): 366 | super().__init__() 367 | self.hidden_size = config.hidden_size 368 | self.self_attn = LlamaAttention(config=config) 369 | self.mlp = LlamaMLP(config) 370 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 371 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 372 | self._fsdp_wrap = True 373 | 374 | def forward( 375 | self, 376 | hidden_states: torch.Tensor, 377 | attention_mask: Optional[torch.Tensor] = None, 378 | position_ids: Optional[torch.LongTensor] = None, 379 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 380 | unpadded_lengths: Optional[Tuple[torch.Tensor]] = None, 381 | output_attentions: Optional[bool] = False, 382 | use_cache: Optional[bool] = False, 383 | seq_parallel_group: Optional[Any] = None, 384 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 385 | """ 386 | Args: 387 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 388 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 389 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 390 | output_attentions (`bool`, *optional*): 391 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 392 | returned tensors for more detail. 393 | use_cache (`bool`, *optional*): 394 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 395 | (see `past_key_values`). 
396 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 397 | """ 398 | 399 | residual = hidden_states 400 | 401 | hidden_states = self.input_layernorm(hidden_states) 402 | 403 | # Self Attention 404 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 405 | hidden_states=hidden_states, 406 | attention_mask=attention_mask, 407 | position_ids=position_ids, 408 | past_key_value=past_key_value, 409 | output_attentions=output_attentions, 410 | use_cache=use_cache, 411 | unpadded_lengths=unpadded_lengths, 412 | seq_parallel_group=seq_parallel_group, 413 | ) 414 | hidden_states = residual + hidden_states 415 | 416 | # Fully Connected 417 | residual = hidden_states 418 | hidden_states = self.post_attention_layernorm(hidden_states) 419 | hidden_states = self.mlp(hidden_states) 420 | hidden_states = residual + hidden_states 421 | 422 | outputs = (hidden_states,) 423 | 424 | if output_attentions: 425 | outputs += (self_attn_weights,) 426 | 427 | if use_cache: 428 | outputs += (present_key_value,) 429 | 430 | return outputs 431 | 432 | 433 | class LlamaPreTrainedModel(PreTrainedModel): 434 | config_class = LlamaConfig 435 | base_model_prefix = "model" 436 | supports_gradient_checkpointing = True 437 | _no_split_modules = ["LlamaDecoderLayer"] 438 | _skip_keys_device_placement = "past_key_values" 439 | 440 | def _init_weights(self, module): 441 | std = self.config.initializer_range 442 | if isinstance(module, nn.Linear): 443 | module.weight.data.normal_(mean=0.0, std=std) 444 | if module.bias is not None: 445 | module.bias.data.zero_() 446 | elif isinstance(module, nn.Embedding): 447 | module.weight.data.normal_(mean=0.0, std=std) 448 | if module.padding_idx is not None: 449 | module.weight.data[module.padding_idx].zero_() 450 | 451 | 452 | class LlamaModel(LlamaPreTrainedModel): 453 | """ 454 | Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] 455 | 456 | Args: 457 | config: LlamaConfig 458 | """ 459 | 460 | def __init__(self, config: LlamaConfig): 461 | super().__init__(config) 462 | self.padding_idx = config.pad_token_id 463 | self.vocab_size = config.vocab_size 464 | 465 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 466 | self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) 467 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 468 | 469 | self.gradient_checkpointing = False 470 | # Initialize weights and apply final processing 471 | self.post_init() 472 | 473 | def get_input_embeddings(self): 474 | return self.embed_tokens 475 | 476 | def set_input_embeddings(self, value): 477 | self.embed_tokens = value 478 | 479 | def forward( 480 | self, 481 | input_ids: torch.LongTensor = None, 482 | attention_mask: Optional[torch.Tensor] = None, 483 | position_ids: Optional[torch.LongTensor] = None, 484 | past_key_values: Optional[List[torch.FloatTensor]] = None, 485 | inputs_embeds: Optional[torch.FloatTensor] = None, 486 | use_cache: Optional[bool] = None, 487 | output_attentions: Optional[bool] = None, 488 | output_hidden_states: Optional[bool] = None, 489 | return_dict: Optional[bool] = None, 490 | unpadded_lengths: Optional[Tuple[torch.Tensor]] = None, 491 | seq_parallel_group: Optional[Any] = None, 492 | ) -> Union[Tuple, BaseModelOutputWithPast]: 493 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 494 | output_hidden_states = ( 495 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 496 | ) 497 | use_cache = use_cache if use_cache is not None else self.config.use_cache 498 | 499 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 500 | 501 | # retrieve input_ids and inputs_embeds 502 | if input_ids is not None and inputs_embeds is not None: 503 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 504 | elif input_ids is None and inputs_embeds is None: 505 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 506 | 507 | # position_ids = None 508 | 509 | if inputs_embeds is None: 510 | inputs_embeds = self.embed_tokens(input_ids) 511 | 512 | hidden_states = inputs_embeds 513 | 514 | if self.gradient_checkpointing and self.training: 515 | if use_cache: 516 | logger.warning_once( 517 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
518 | ) 519 | use_cache = False 520 | 521 | # decoder layers 522 | all_hidden_states = () if output_hidden_states else None 523 | all_self_attns = () if output_attentions else None 524 | next_decoder_cache = () if use_cache else None 525 | 526 | for idx, decoder_layer in enumerate(self.layers): 527 | if output_hidden_states: 528 | all_hidden_states += (hidden_states,) 529 | 530 | past_key_value = past_key_values[idx] if past_key_values is not None else None 531 | 532 | if self.gradient_checkpointing and self.training: 533 | layer_outputs = torch.utils.checkpoint.checkpoint( 534 | decoder_layer, 535 | hidden_states, 536 | attention_mask, 537 | position_ids, 538 | None, 539 | unpadded_lengths, 540 | output_attentions, 541 | False, 542 | seq_parallel_group, 543 | use_reentrant=False, 544 | ) 545 | else: 546 | layer_outputs = decoder_layer( 547 | hidden_states, 548 | attention_mask=attention_mask, 549 | position_ids=position_ids, 550 | past_key_value=past_key_value, 551 | unpadded_lengths=unpadded_lengths, 552 | output_attentions=output_attentions, 553 | use_cache=use_cache, 554 | seq_parallel_group=seq_parallel_group, 555 | ) 556 | 557 | hidden_states = layer_outputs[0] 558 | 559 | if use_cache: 560 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 561 | 562 | if output_attentions: 563 | all_self_attns += (layer_outputs[1],) 564 | 565 | hidden_states = self.norm(hidden_states) 566 | 567 | # add hidden states from the last decoder layer 568 | if output_hidden_states: 569 | all_hidden_states += (hidden_states,) 570 | 571 | next_cache = next_decoder_cache if use_cache else None 572 | if not return_dict: 573 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 574 | return BaseModelOutputWithPast( 575 | last_hidden_state=hidden_states, 576 | past_key_values=next_cache, 577 | hidden_states=all_hidden_states, 578 | attentions=all_self_attns, 579 | ) 580 | 581 | 582 | class LlamaForCausalLM(LlamaPreTrainedModel): 583 | _tied_weights_keys = ["lm_head.weight"] 584 | 585 | def __init__(self, config): 586 | super().__init__(config) 587 | self.model = LlamaModel(config) 588 | self.vocab_size = config.vocab_size 589 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 590 | 591 | self.logit_block_size = int(os.environ.get("LOGIT_BLOCK_SIZE", 0)) 592 | 593 | # Initialize weights and apply final processing 594 | self.post_init() 595 | 596 | def get_input_embeddings(self): 597 | return self.model.embed_tokens 598 | 599 | def set_input_embeddings(self, value): 600 | self.model.embed_tokens = value 601 | 602 | def get_output_embeddings(self): 603 | return self.lm_head 604 | 605 | def set_output_embeddings(self, new_embeddings): 606 | self.lm_head = new_embeddings 607 | 608 | def set_decoder(self, decoder): 609 | self.model = decoder 610 | 611 | def get_decoder(self): 612 | return self.model 613 | 614 | def compute_loss(self, hidden_states, labels, token_losses=False): 615 | logits = self.lm_head(hidden_states) 616 | if len(logits.shape) > 2: 617 | logits = logits.transpose(-1, -2) 618 | # For num-valid-token-scaled loss, we sum up here and later reweight in `compute_loss` in the trainer 619 | return F.cross_entropy( 620 | logits, labels, 621 | ignore_index=-100, 622 | reduction=("sum" if getattr(self, "token_scaled_loss", False) else "mean") 623 | ) 624 | 625 | def forward( 626 | self, 627 | input_ids: torch.LongTensor = None, 628 | attention_mask: Optional[torch.Tensor] = None, 629 | position_ids: 
Optional[torch.LongTensor] = None, 630 | past_key_values: Optional[List[torch.FloatTensor]] = None, 631 | inputs_embeds: Optional[torch.FloatTensor] = None, 632 | labels: Optional[torch.LongTensor] = None, 633 | use_cache: Optional[bool] = None, 634 | output_attentions: Optional[bool] = None, 635 | output_hidden_states: Optional[bool] = None, 636 | return_dict: Optional[bool] = None, 637 | seq_lengths: Optional[torch.Tensor] = None, 638 | return_token_losses: bool = False, 639 | shifted_labels: Optional[torch.LongTensor] = None, 640 | seq_parallel_group: Optional[Any] = None, 641 | ) -> Union[Tuple, CausalLMOutputWithPast]: 642 | r""" 643 | Args: 644 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 645 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 646 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 647 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 648 | 649 | Returns: 650 | 651 | Example: 652 | 653 | ```python 654 | >>> from transformers import AutoTokenizer, LlamaForCausalLM 655 | 656 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 657 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 658 | 659 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 660 | >>> inputs = tokenizer(prompt, return_tensors="pt") 661 | 662 | >>> # Generate 663 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 664 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 665 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
666 | ```""" 667 | 668 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 669 | output_hidden_states = ( 670 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 671 | ) 672 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 673 | 674 | if seq_lengths is not None: 675 | if inputs_embeds is not None: 676 | assert len(inputs_embeds.shape) == 2, "inputs_embeds should be a 2D tensor with `seq_lengths`" 677 | # assert inputs_embeds.size(0) == seq_lengths.sum(), "inputs_embeds and seq_lengths should have the same batch size" 678 | else: 679 | assert len(input_ids.shape) == 1, "input_ids should be a 1D tensor with `seq_lengths`" 680 | # assert input_ids.size(0) == seq_lengths.sum(), "input_ids and seq_lengths should have the same batch size" 681 | 682 | assert attention_mask is None or attention_mask.all().item(), "attention_mask should be None or all ones for `seq_lengths`" 683 | assert not use_cache, "use_cache is not supported with `seq_lengths`" 684 | 685 | cu_seqlens = F.pad(torch.cumsum(seq_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) 686 | max_seqlen = seq_lengths.max().item() 687 | 688 | unpadded_lengths = (cu_seqlens, max_seqlen) 689 | elif ( 690 | ((attention_mask is not None) and (not attention_mask.all().item())) 691 | and not use_cache 692 | ): 693 | if inputs_embeds is not None: 694 | bsz = inputs_embeds.size(0) 695 | inputs_embeds, unpad_indices, cu_seqlens, max_seqlen = unpad_input(inputs_embeds, attention_mask) 696 | else: 697 | bsz = input_ids.size(0) 698 | input_ids, unpad_indices, cu_seqlens, max_seqlen = unpad_input(input_ids.unsqueeze(-1), attention_mask) 699 | input_ids = input_ids.squeeze(-1) 700 | unpadded_lengths = (cu_seqlens, max_seqlen) 701 | else: 702 | unpadded_lengths = None 703 | 704 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 705 | outputs = self.model( 706 | input_ids=input_ids, 707 | attention_mask=attention_mask, 708 | position_ids=position_ids, 709 | past_key_values=past_key_values, 710 | inputs_embeds=inputs_embeds, 711 | use_cache=use_cache, 712 | output_attentions=output_attentions, 713 | output_hidden_states=output_hidden_states, 714 | return_dict=return_dict, 715 | unpadded_lengths=unpadded_lengths, 716 | seq_parallel_group=seq_parallel_group, 717 | ) 718 | hidden_states = outputs[0] 719 | 720 | if seq_lengths is None and unpadded_lengths is not None: 721 | hidden_states = pad_input(hidden_states, unpad_indices, bsz, max_seqlen) 722 | 723 | if labels is not None or shifted_labels is not None: 724 | if shifted_labels is not None: 725 | labels = shifted_labels.reshape(-1) 726 | hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) 727 | else: 728 | labels = labels[..., 1:].reshape(-1) 729 | hidden_states = hidden_states[..., :-1 ,:].reshape(-1, hidden_states.size(-1)) 730 | 731 | if self.logit_block_size > 0: 732 | num_valid_labels = (labels != -100).sum() 733 | hidden_states = torch.split(hidden_states, self.logit_block_size, dim=0) 734 | labels = torch.split(labels, self.logit_block_size, dim=0) 735 | 736 | if getattr(self, "token_scaled_loss", False): 737 | # Just calculate the sum of loss here; we will divide by valid #tokens later in the trainer 738 | loss = sum( 739 | torch.utils.checkpoint.checkpoint(self.compute_loss, 740 | hidden_state_block, 741 | label_block, 742 | use_reentrant=False) 743 | for hidden_state_block, label_block in zip(hidden_states, labels) 
744 | ) 745 | else: 746 | loss = sum( 747 | ((label_block != -100).sum() / num_valid_labels) * 748 | torch.utils.checkpoint.checkpoint(self.compute_loss, 749 | hidden_state_block, 750 | label_block, 751 | use_reentrant=False) 752 | for hidden_state_block, label_block in zip(hidden_states, labels) 753 | ) 754 | 755 | else: 756 | loss = self.compute_loss(hidden_states, labels) 757 | logits = None 758 | else: 759 | logits = self.lm_head(hidden_states) 760 | loss = None 761 | 762 | if not return_dict: 763 | output = (None,) + outputs[1:] 764 | return (loss,) + output if loss is not None else output 765 | 766 | return CausalLMOutputWithPast( 767 | loss=loss, 768 | logits=logits, 769 | past_key_values=outputs.past_key_values, 770 | hidden_states=outputs.hidden_states, 771 | attentions=outputs.attentions, 772 | ) 773 | 774 | def prepare_inputs_for_generation( 775 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 776 | ): 777 | if past_key_values: 778 | input_ids = input_ids[:, -1:] 779 | 780 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 781 | if inputs_embeds is not None and past_key_values is None: 782 | model_inputs = {"inputs_embeds": inputs_embeds} 783 | else: 784 | model_inputs = {"input_ids": input_ids} 785 | 786 | model_inputs.update( 787 | { 788 | "past_key_values": past_key_values, 789 | "use_cache": kwargs.get("use_cache"), 790 | "attention_mask": attention_mask 791 | } 792 | ) 793 | return model_inputs 794 | 795 | @staticmethod 796 | def _reorder_cache(past_key_values, beam_idx): 797 | reordered_past = () 798 | for layer_past in past_key_values: 799 | reordered_past += ( 800 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), 801 | ) 802 | return reordered_past 803 | 804 | 805 | class LlamaForSequenceClassification(LlamaPreTrainedModel): 806 | def __init__(self, config): 807 | super().__init__(config) 808 | self.num_labels = config.num_labels 809 | self.model = LlamaModel(config) 810 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) 811 | 812 | # Initialize weights and apply final processing 813 | self.post_init() 814 | 815 | def get_input_embeddings(self): 816 | return self.model.embed_tokens 817 | 818 | def set_input_embeddings(self, value): 819 | self.model.embed_tokens = value 820 | 821 | def forward( 822 | self, 823 | input_ids: torch.LongTensor = None, 824 | attention_mask: Optional[torch.Tensor] = None, 825 | position_ids: Optional[torch.LongTensor] = None, 826 | past_key_values: Optional[List[torch.FloatTensor]] = None, 827 | inputs_embeds: Optional[torch.FloatTensor] = None, 828 | labels: Optional[torch.LongTensor] = None, 829 | use_cache: Optional[bool] = None, 830 | output_attentions: Optional[bool] = None, 831 | output_hidden_states: Optional[bool] = None, 832 | return_dict: Optional[bool] = None, 833 | seq_parallel_group: Optional[Any] = None, 834 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]: 835 | r""" 836 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 837 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 838 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 839 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
840 | """ 841 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 842 | 843 | transformer_outputs = self.model( 844 | input_ids, 845 | attention_mask=attention_mask, 846 | position_ids=position_ids, 847 | past_key_values=past_key_values, 848 | inputs_embeds=inputs_embeds, 849 | use_cache=use_cache, 850 | output_attentions=output_attentions, 851 | output_hidden_states=output_hidden_states, 852 | return_dict=return_dict, 853 | seq_parallel_group=seq_parallel_group 854 | ) 855 | hidden_states = transformer_outputs[0] 856 | logits = self.score(hidden_states) 857 | 858 | if input_ids is not None: 859 | batch_size = input_ids.shape[0] 860 | else: 861 | batch_size = inputs_embeds.shape[0] 862 | 863 | if self.config.pad_token_id is None and batch_size != 1: 864 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 865 | if self.config.pad_token_id is None: 866 | sequence_lengths = -1 867 | else: 868 | if input_ids is not None: 869 | sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) 870 | else: 871 | sequence_lengths = -1 872 | 873 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 874 | 875 | loss = None 876 | if labels is not None: 877 | labels = labels.to(logits.device) 878 | if self.config.problem_type is None: 879 | if self.num_labels == 1: 880 | self.config.problem_type = "regression" 881 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 882 | self.config.problem_type = "single_label_classification" 883 | else: 884 | self.config.problem_type = "multi_label_classification" 885 | 886 | if self.config.problem_type == "regression": 887 | loss_fct = MSELoss() 888 | if self.num_labels == 1: 889 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) 890 | else: 891 | loss = loss_fct(pooled_logits, labels) 892 | elif self.config.problem_type == "single_label_classification": 893 | loss_fct = CrossEntropyLoss() 894 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) 895 | elif self.config.problem_type == "multi_label_classification": 896 | loss_fct = BCEWithLogitsLoss() 897 | loss = loss_fct(pooled_logits, labels) 898 | if not return_dict: 899 | output = (pooled_logits,) + transformer_outputs[1:] 900 | return ((loss,) + output) if loss is not None else output 901 | 902 | return SequenceClassifierOutputWithPast( 903 | loss=loss, 904 | logits=pooled_logits, 905 | past_key_values=transformer_outputs.past_key_values, 906 | hidden_states=transformer_outputs.hidden_states, 907 | attentions=transformer_outputs.attentions, 908 | ) 909 | 910 | -------------------------------------------------------------------------------- /training/train_language_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import torch 5 | import datasets 6 | import transformers 7 | import functools 8 | 9 | from transformers import ( 10 | AutoConfig, 11 | AutoTokenizer, 12 | HfArgumentParser, 13 | set_seed, 14 | ) 15 | 16 | from training.modeling_flash_llama import LlamaForCausalLM 17 | from training.trainer import Trainer, TrainingArguments 18 | from training.dataset import build_dataset, DataCollator, DataArguments 19 | from training.dataset import logger as dataset_logger 20 | 21 | 22 | from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy 23 | 24 | from transformers.trainer_utils import get_last_checkpoint 25 | 
import json 26 | from dataclasses import dataclass, field 27 | from typing import Optional, List 28 | 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | @dataclass 33 | class ScriptArguments: 34 | model_name_or_path: Optional[str] = field( 35 | default=None, 36 | metadata={ 37 | "help": ( 38 | "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." 39 | ) 40 | }, 41 | ) 42 | config_overrides: Optional[str] = field( 43 | default=None, 44 | metadata={ 45 | "help": ( 46 | "Override some existing default config settings when a model is trained from scratch. Example: " 47 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 48 | ) 49 | }, 50 | ) 51 | config_overrides_json: Optional[str] = field( 52 | default=None, 53 | metadata={ 54 | "help": ( 55 | "Override some existing default config settings when a model is trained from scratch. Example: " 56 | "'{\"resid_pdrop\": 0.2}'" 57 | ) 58 | }, 59 | ) 60 | config_name: Optional[str] = field( 61 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 62 | ) 63 | tokenizer_name: Optional[str] = field( 64 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 65 | ) 66 | cache_dir: Optional[str] = field( 67 | default=None, 68 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 69 | ) 70 | use_fast_tokenizer: bool = field( 71 | default=True, 72 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 73 | ) 74 | model_revision: str = field( 75 | default="main", 76 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 77 | ) 78 | use_auth_token: bool = field( 79 | default=False, 80 | metadata={ 81 | "help": ( 82 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 83 | "with private models)." 84 | ) 85 | }, 86 | ) 87 | 88 | tokenized_mds_train: List[str] = field(default_factory=list, metadata={"help": "Paths to tokenized training datasets in MDS format"}) 89 | tokenized_mds_validation: List[str] = field(default_factory=list, metadata={"help": "Paths to tokenized validation datasets in MDS format"}) 90 | tokenized_mds_test: List[str] = field(default_factory=list, metadata={"help": "Paths to tokenized test datasets in MDS format"}) 91 | 92 | token_scaled_loss: bool = field(default=False, metadata={"help": "Whether to re-scale the loss by the number of valid training tokens instead of averaging loss across sequences and devices. This should be turned on for instruction tuning, especially when using synthetic data, as the valid training tokens vary across devices."}) 93 | 94 | 95 | def main(): 96 | # See all possible arguments in src/transformers/training_args.py 97 | # or by passing the --help flag to this script. 98 | # We now keep distinct sets of script_args, for a cleaner separation of concerns. 
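# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original script: HfArgumentParser (used
# just below) maps one flat command line onto several dataclasses. For brevity
# this example parses only the stock transformers TrainingArguments; the
# script itself also parses ScriptArguments and DataArguments the same way.
def _example_parse_args():
    from transformers import HfArgumentParser, TrainingArguments

    parser = HfArgumentParser(TrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses(
        args=["--output_dir", "checkpoints/demo", "--learning_rate", "1e-5"]
    )
    return training_args.learning_rate   # 1e-05
# ---------------------------------------------------------------------------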
99 | parser = HfArgumentParser((ScriptArguments, TrainingArguments, DataArguments)) 100 | script_args, training_args, data_args = parser.parse_args_into_dataclasses() 101 | # Setup logging 102 | logging.basicConfig( 103 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 104 | datefmt="%m/%d/%Y %H:%M:%S", 105 | handlers=[logging.StreamHandler(sys.stdout)], 106 | ) 107 | log_level = training_args.get_process_log_level() 108 | logger.setLevel(log_level) 109 | dataset_logger.setLevel(log_level) 110 | datasets.utils.logging.set_verbosity(log_level) 111 | transformers.utils.logging.set_verbosity(log_level) 112 | transformers.utils.logging.enable_default_handler() 113 | transformers.utils.logging.enable_explicit_format() 114 | 115 | # Log on each process the small summary: 116 | logger.warning( 117 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 118 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 119 | ) 120 | logger.info(f"Training/evaluation parameters {training_args}") 121 | logger.info(f"Data arguments {data_args}") 122 | logger.info(f"Additional arguments {script_args}") 123 | # Detecting last checkpoint. 124 | last_checkpoint = None 125 | if os.path.isdir(training_args.output_dir): 126 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 127 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: 128 | logger.info( 129 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 130 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 131 | ) 132 | 133 | # Set seed before initializing model. 134 | set_seed(training_args.seed) 135 | tokenizer = AutoTokenizer.from_pretrained( 136 | script_args.tokenizer_name or script_args.model_name_or_path, 137 | cache_dir=script_args.cache_dir, 138 | use_fast=script_args.use_fast_tokenizer, 139 | revision=script_args.model_revision, 140 | use_auth_token=True if script_args.use_auth_token else None, 141 | ) 142 | config = AutoConfig.from_pretrained( 143 | script_args.config_name or script_args.model_name_or_path, 144 | cache_dir=script_args.cache_dir, 145 | revision=script_args.model_revision, 146 | use_auth_token=True if script_args.use_auth_token else None 147 | ) 148 | if script_args.config_overrides: 149 | logger.info(f"Overriding config: {script_args.config_overrides}") 150 | config.update_from_string(script_args.config_overrides) 151 | logger.info(f"New config: {config}") 152 | 153 | if script_args.config_overrides_json: 154 | logger.info(f"Overriding config: {script_args.config_overrides_json}") 155 | config.update(json.loads(script_args.config_overrides_json)) 156 | logger.info(f"New config: {config}") 157 | 158 | config.pad_token_id = 0 159 | 160 | if script_args.model_name_or_path: 161 | model = LlamaForCausalLM.from_pretrained( 162 | script_args.model_name_or_path, 163 | from_tf=bool(".ckpt" in script_args.model_name_or_path), 164 | config=config, 165 | cache_dir=script_args.cache_dir, 166 | revision=script_args.model_revision, 167 | use_auth_token=True if script_args.use_auth_token else None, 168 | ) 169 | else: 170 | logger.warning(f"Initializing new LlamaForCausalLM from scratch") 171 | model = LlamaForCausalLM(config) 172 | 173 | if script_args.tokenizer_name is not None and script_args.model_name_or_path != script_args.tokenizer_name: 174 | model.resize_token_embeddings(len(tokenizer)) 175 | 176 | 
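# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original script: how the two override
# flags above compose. `--config_overrides` takes a "key=value,key=value"
# string (keys must already exist in the config), while
# `--config_overrides_json` takes a JSON object. The model name and the keys
# shown here (max_position_embeddings, rope_theta) are placeholder examples.
def _example_config_overrides(model_name_or_path: str):
    import json
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(model_name_or_path)   # e.g. a Llama config
    config.update_from_string("max_position_embeddings=524288")
    config.update(json.loads('{"rope_theta": 8000000.0}'))
    return config.max_position_embeddings, config.rope_theta
# ---------------------------------------------------------------------------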
logger.info(f"Model: {model}") 177 | 178 | # This avoids weird issues when doing multiple runs from different codebases 179 | import streaming 180 | streaming.base.util.clean_stale_shared_memory() 181 | 182 | if script_args.token_scaled_loss: 183 | model.token_scaled_loss = True 184 | training_args.token_scaled_loss = True 185 | 186 | # load_datasets 187 | if training_args.do_train: 188 | train_dataset = build_dataset(script_args.tokenized_mds_train, training_args, data_args, is_training=True) 189 | 190 | if training_args.do_eval: 191 | eval_dataset = { 192 | x.split("/")[-1]: build_dataset(x, tokenizer, training_args, data_args, is_training=False) 193 | for x in script_args.tokenized_mds_validation 194 | } 195 | 196 | if training_args.do_predict: 197 | test_dataset = { 198 | x.split("/")[-1]: build_dataset(x, tokenizer, training_args, data_args, is_training=False) 199 | for x in script_args.tokenized_mds_test 200 | } 201 | 202 | data_collator = DataCollator(tokenizer, data_args) 203 | 204 | # Initialize our Trainer 205 | trainer = Trainer( 206 | model=model, 207 | args=training_args, 208 | train_dataset=train_dataset if training_args.do_train else None, 209 | eval_dataset=eval_dataset if training_args.do_eval else None, 210 | tokenizer=tokenizer, 211 | data_collator=data_collator, 212 | ) 213 | 214 | if trainer.is_fsdp_enabled: 215 | # Identify which modules have "_fsdp_wrap" attribute set to True and wrap these 216 | def fsdp_policy_fn(module): 217 | return getattr(module, "_fsdp_wrap", False) 218 | 219 | auto_wrap_policy = functools.partial(lambda_auto_wrap_policy, 220 | lambda_fn=fsdp_policy_fn) 221 | trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = auto_wrap_policy 222 | 223 | # Training 224 | if training_args.do_train: 225 | checkpoint = None 226 | if training_args.resume_from_checkpoint is not None: 227 | checkpoint = training_args.resume_from_checkpoint 228 | elif last_checkpoint is not None: 229 | checkpoint = last_checkpoint 230 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 231 | trainer.save_model() 232 | 233 | metrics = train_result.metrics 234 | trainer.log_metrics("train", metrics) 235 | trainer.save_metrics("train", metrics) 236 | trainer.save_state() 237 | 238 | if torch.distributed.is_initialized(): 239 | torch.distributed.barrier() 240 | 241 | 242 | # Evaluation 243 | if training_args.do_eval: 244 | logger.info("*** Evaluate ***") 245 | metrics = trainer.evaluate(eval_dataset) 246 | trainer.log_metrics("eval", metrics) 247 | trainer.save_metrics("eval", metrics) 248 | 249 | # Predict 250 | if training_args.do_predict: 251 | logger.info("*** Predict ***") 252 | predictions = trainer.predict(test_dataset=test_dataset) 253 | print(predictions) 254 | predictions = predictions.predictions 255 | predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True) 256 | with open('dump.json', 'w') as f: 257 | print(json.dumps(predictions), file=f, flush=True) 258 | 259 | 260 | if __name__ == "__main__": 261 | main() 262 | -------------------------------------------------------------------------------- /training/trainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. 17 | """ 18 | 19 | import time 20 | from collections.abc import Mapping 21 | from distutils.util import strtobool 22 | from pathlib import Path 23 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 24 | from functools import partial 25 | import math 26 | import gc 27 | 28 | from dataclasses import dataclass, field 29 | from datasets import Dataset 30 | import transformers 31 | from transformers import Trainer as HFTrainer 32 | from transformers.trainer import _get_fsdp_ckpt_kwargs 33 | # Integrations must be imported before ML frameworks: 34 | 35 | import numpy as np 36 | import torch 37 | import torch.distributed as dist 38 | from packaging import version 39 | from torch import nn 40 | from torch.utils.data import DataLoader, Dataset, SequentialSampler 41 | from torch.optim.lr_scheduler import LambdaLR 42 | 43 | 44 | from transformers import __version__ 45 | from transformers.trainer_callback import ( 46 | PrinterCallback, 47 | TrainerCallback, 48 | ) 49 | from transformers.trainer_pt_utils import ( 50 | IterableDatasetShard, 51 | find_batch_size, 52 | nested_concat, 53 | nested_detach, 54 | nested_numpify, 55 | nested_truncate, 56 | get_parameter_names, 57 | ) 58 | from transformers.trainer_utils import ( 59 | EvalLoopOutput, 60 | EvalPrediction, 61 | denumpify_detensorize, 62 | has_length, 63 | seed_worker, 64 | PREFIX_CHECKPOINT_DIR 65 | ) 66 | from transformers.utils import ( 67 | get_full_repo_name, 68 | is_apex_available, 69 | is_sagemaker_mp_enabled, 70 | is_torch_tpu_available, 71 | ) 72 | from transformers.optimization import get_scheduler 73 | from transformers import TrainingArguments as HfTrainingArguments 74 | 75 | if is_torch_tpu_available(check_device=False): 76 | import torch_xla.core.xla_model as xm # type: ignore 77 | import torch_xla.distributed.parallel_loader as pl # type: ignore 78 | 79 | 80 | if is_sagemaker_mp_enabled(): 81 | import smdistributed.modelparallel.torch as smp # type: ignore 82 | from smdistributed.modelparallel import __version__ as SMP_VERSION # type: ignore 83 | 84 | IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") 85 | else: 86 | IS_SAGEMAKER_MP_POST_1_10 = False 87 | 88 | 89 | from transformers.trainer import logger 90 | from streaming import StreamingDataLoader, StreamingDataset 91 | import torch.distributed as dist 92 | import datasets 93 | import os 94 | import json 95 | 96 | from transformers.trainer_callback import TrainerState 97 | from transformers.trainer_utils import enable_full_determinism, get_last_checkpoint, set_seed, find_executable_batch_size 98 | from transformers.trainer_pt_utils import reissue_pt_warnings 99 | import warnings 100 | import huggingface_hub.utils as hf_hub_utils 101 | import glob 102 | 103 | from accelerate.utils import load_fsdp_optimizer 104 | 105 | # Name of the files used for checkpointing 106 | TRAINING_ARGS_NAME = "training_args.bin" 107 | TRAINER_STATE_NAME = "trainer_state.json" 108 | 
OPTIMIZER_NAME = "optimizer.pt" 109 | OPTIMIZER_NAME_BIN = "optimizer.bin" 110 | SCHEDULER_NAME = "scheduler.pt" 111 | SCALER_NAME = "scaler.pt" 112 | 113 | 114 | class LogCallback(TrainerCallback): 115 | def __init__(self, *args, **kwargs): 116 | super().__init__(*args, **kwargs) 117 | self.start_time = None 118 | self.last_log_time = None 119 | self.log_time_interval = 0 120 | self.is_training = False 121 | 122 | self.max_steps = -1 123 | self.first_step_of_run = 0 124 | 125 | def on_step_begin(self, args, state, control, **kwargs): 126 | if state.is_world_process_zero and self.last_log_time is None: 127 | self.log_time_interval = getattr(args, "log_time_interval", 0) 128 | if self.log_time_interval > 0: 129 | logger.info(f"Using log_time_interval {self.log_time_interval} s. This will override logging_steps and logging_strategy.") 130 | args.logging_steps = 1 131 | args.logging_strategy = "steps" 132 | 133 | self.last_step = 0 134 | 135 | self.start_time = time.time() 136 | self.last_log_time = self.start_time 137 | self.max_steps = state.max_steps 138 | self.first_step_of_run = state.global_step 139 | 140 | self.last_tokens_seen = state.num_input_tokens_seen 141 | 142 | def on_log(self, args, state, control, logs=None, **kwargs): 143 | _ = logs.pop("total_flos", None) 144 | 145 | if state.is_world_process_zero: 146 | if self.is_training: 147 | current_time = time.time() 148 | time_diff = current_time - self.last_log_time 149 | force = logs.get("force", False) 150 | 151 | if time_diff > self.log_time_interval or state.global_step >= self.max_steps - 1 or force: 152 | self.last_log_time = current_time 153 | steps_completed = max(state.global_step, 1) 154 | 155 | steps_since_first = max(1, state.global_step - self.first_step_of_run) 156 | self.last_step = state.global_step 157 | 158 | tokens_seen_since_last = (state.num_input_tokens_seen - self.last_tokens_seen) // args.seq_parallel_size 159 | self.last_tokens_seen = state.num_input_tokens_seen 160 | 161 | remaining_steps = self.max_steps - steps_completed 162 | pct_completed = (steps_completed / self.max_steps) * 100 163 | time_since_start = current_time - self.start_time 164 | remaining_time = (time_since_start / steps_since_first) * remaining_steps 165 | 166 | gpu_mem_free, _ = torch.cuda.mem_get_info(device=args.device) 167 | 168 | update = { 169 | "completed": f"{pct_completed:.2f}% ({steps_completed:_} / {self.max_steps:_})", 170 | "remaining time": self.format_duration(remaining_time), 171 | "throughput": f"{tokens_seen_since_last / time_diff:.2f}", 172 | "gpu_mem_free": f"{gpu_mem_free / 1024 / 1024:.0f}MB", 173 | } 174 | 175 | logger.info(str({**logs, **update})) 176 | else: 177 | logger.info(str(logs)) 178 | 179 | def on_train_begin(self, args, state, control, **kwargs): 180 | args.include_num_input_tokens_seen = True 181 | 182 | def on_step_end(self, args, state, control, **kwargs): 183 | if state.is_world_process_zero: 184 | self.is_training = True 185 | 186 | def on_prediction_step(self, args, state, control, **kwargs): 187 | if state.is_world_process_zero: 188 | self.is_training = False 189 | 190 | @staticmethod 191 | def format_duration(seconds): 192 | hours, remainder = divmod(seconds, 3600) 193 | minutes, seconds = divmod(remainder, 60) 194 | return f"{int(hours)}:{int(minutes):02}:{int(seconds):02}" 195 | 196 | 197 | import signal 198 | from subprocess import call 199 | class SIGUSR1Callback(transformers.TrainerCallback): 200 | def __init__(self, trainer) -> None: 201 | super().__init__() 202 | self.signal_received = 
False 203 | signal.signal(signal.SIGUSR1, self.handle_signal) 204 | # signal.signal(signal.SIGINT, self.handle_signal) 205 | logger.warn("Handler registered") 206 | self.trainer = trainer 207 | 208 | def handle_signal(self, signum, frame): 209 | self.signal_received = True 210 | logger.warn("Stop signal received...") 211 | 212 | def on_substep_end(self, args, state, control, **kwargs): 213 | if self.signal_received: 214 | self.trainer._save_checkpoint(self.trainer.model, None) # Note that here _save_checkpoint does not actually use this, so we can just pass on any model 215 | # The reason we don't set should_save but instead directly save here 216 | # is that streaming may collapse after receiving the signal and it 217 | # would be too late to wait till the save function is called. 218 | # Same reason for why we handle the single in both on_substep_end 219 | # and on_step_end, even though ideally we want to do on_step_end. 220 | # control.should_save = True 221 | control.should_training_stop = True 222 | 223 | def on_step_end(self, args, state, control, **kwargs): 224 | if self.signal_received: 225 | self.trainer._save_checkpoint(self.trainer.model, None) 226 | # control.should_save = True 227 | control.should_training_stop = True 228 | 229 | def on_train_end(self, args, state, control, **kwargs): 230 | if self.signal_received: 231 | exit(0) 232 | 233 | 234 | @dataclass 235 | class TrainingArguments(HfTrainingArguments): 236 | min_lr_ratio: float = field( 237 | default=0.0 238 | ) 239 | cuda_empty_cache: bool = field( 240 | default=False, metadata={"help": "Empty cuda cache before every step."} 241 | ) 242 | streaming_dataset: bool = field( 243 | default=True, metadata={"help": "Use streaming dataset, dataloader, and their ckpt and resume"} 244 | ) 245 | seq_parallel_size: int = field( 246 | default=1, metadata={"help": "Sequence parallelism group size (1 is no parallelism)"} 247 | ) 248 | 249 | 250 | def min_lr_bound(current_step: int, wrapped_func: Callable[[float], float], min_lr_ratio: float, warmup_steps: int): 251 | if current_step < warmup_steps: 252 | return wrapped_func(current_step) 253 | return min_lr_ratio + wrapped_func(current_step) * (1.0 - min_lr_ratio) 254 | 255 | 256 | # - Callbacks: transformers.trainer_callback.DefaultFlowCallback, transformers.integrations.WandbCallback, transformers.trainer_callback.ProgressCallback 257 | class Trainer(HFTrainer): 258 | def __init__(self, model, args, *more_args, **kwargs): 259 | super().__init__(model, args, *more_args, **kwargs) 260 | 261 | if not dist.is_initialized() or args.seq_parallel_size == dist.get_world_size(): 262 | logger.info(f"Using world as sequence parallel group") 263 | self.seq_parallel_group = dist.group.WORLD 264 | else: 265 | logger.info(f"Initializing sequence parallel groups with size {args.seq_parallel_size}") 266 | self.seq_parallel_group, _ = dist.new_subgroups(args.seq_parallel_size) 267 | 268 | try: 269 | self.remove_callback(PrinterCallback) 270 | self.add_callback(LogCallback) 271 | # self.add_callback(SIGUSR1Callback(self)) 272 | except ValueError: 273 | logger.warn("Couldn't remove PrinterCallback") 274 | 275 | def get_sequence_parallel_inputs(self, inputs): 276 | seq_parallel_world_size = (dist.get_world_size(self.seq_parallel_group) if dist.is_initialized() else 1) 277 | 278 | if seq_parallel_world_size > 1: 279 | seq_parallel_rank = dist.get_rank(self.seq_parallel_group) 280 | 281 | input_ids = inputs["input_ids"] 282 | labels = inputs["labels"] 283 | 284 | shifted_labels = labels.roll(-1, dims=-1) 
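# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: the pre-shift-and-split
# performed around this point. Labels are rolled left by one (so each position
# already holds its next-token target) and the packed sequence is split into
# equal chunks, one per sequence-parallel rank. With world size 2:
def _example_sequence_parallel_split():
    import torch

    input_ids = torch.tensor([[10, 11, 12, 13, 14, 15]])
    labels = input_ids.clone()
    shifted_labels = labels.roll(-1, dims=-1)
    shifted_labels[..., -1] = -100                       # last token has no target
    id_chunks = torch.tensor_split(input_ids, 2, dim=-1)
    label_chunks = torch.tensor_split(shifted_labels, 2, dim=-1)
    # rank 0 sees ids [10, 11, 12] with labels [11, 12, 13];
    # rank 1 sees ids [13, 14, 15] with labels [14, 15, -100].
    return id_chunks, label_chunks
# ---------------------------------------------------------------------------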
285 | shifted_labels[..., -1] = -100 286 | 287 | seq_lengths = inputs["seq_lengths"] 288 | 289 | # add right padding here to make equal sized chunks 290 | if input_ids.size(-1) % seq_parallel_world_size != 0: 291 | padding = seq_parallel_world_size - (input_ids.size(-1) % seq_parallel_world_size) 292 | padding_zeros = torch.full(input_ids.size()[:-1] + (padding,), 0, dtype=input_ids.dtype, device=input_ids.device) 293 | input_ids = torch.cat([input_ids, padding_zeros], dim=-1) 294 | shifted_labels = torch.cat([shifted_labels, padding_zeros-100], dim=-1) 295 | seq_lengths[-1] += padding 296 | 297 | # select chunk of input_ids and labels 298 | input_ids_chunks = torch.tensor_split(input_ids, seq_parallel_world_size, dim=-1) 299 | shifted_labels_chunks = torch.tensor_split(shifted_labels, seq_parallel_world_size, dim=-1) 300 | 301 | inputs = { 302 | "input_ids": input_ids_chunks[seq_parallel_rank], 303 | "shifted_labels": shifted_labels_chunks[seq_parallel_rank], 304 | "seq_lengths": seq_lengths, 305 | "seq_parallel_group": self.seq_parallel_group, 306 | } 307 | 308 | max_seq_length = seq_lengths.max() 309 | max_tokens_per_device = seq_lengths.sum() // seq_parallel_world_size 310 | 311 | start_index = sum(chunk.size(-1) for chunk in input_ids_chunks[:seq_parallel_rank]) 312 | end_index = start_index + input_ids_chunks[seq_parallel_rank].size(-1) 313 | 314 | inputs["position_ids"] = torch.tensor([start_index]).to(input_ids.device) 315 | 316 | # max sequence length is smaller per device => no need for sequence parallelism 317 | if max_seq_length <= max_tokens_per_device: 318 | # take the seq length field and only retain seq lengths with indices that are valid for this rank 319 | seq_indices = seq_lengths.cumsum(-1) 320 | seq_indices = seq_indices[(seq_indices < end_index) & (seq_indices >= start_index)] 321 | 322 | start_index_tensor = torch.tensor([start_index], device=seq_indices.device) 323 | end_index_tensor = torch.tensor([end_index], device=seq_indices.device) 324 | 325 | seq_lengths = seq_indices.diff(prepend=start_index_tensor, append=end_index_tensor) 326 | seq_lengths = seq_lengths[seq_lengths > 0] 327 | inputs["seq_lengths"] = seq_lengths 328 | inputs["seq_parallel_group"] = None 329 | 330 | return inputs 331 | 332 | def compute_loss(self, model, inputs, return_outputs=False, return_output_and_metrics=False): 333 | """ 334 | How the loss is computed by Trainer. By default, all models return the loss in the first element. 335 | 336 | Subclass and override for custom behavior. 
337 | """ 338 | 339 | inputs = self.get_sequence_parallel_inputs(inputs) 340 | 341 | try: 342 | outputs = model(**inputs, use_cache=False) 343 | except Exception as e: 344 | error_str = "-"*30 345 | for k, v in inputs.items(): 346 | if isinstance(v, torch.Tensor): 347 | error_str += f"\n{k}:\n{v.cpu().tolist()}\n ({v.dtype}, {v.shape})\n" 348 | error_str += "-"*30 349 | print(error_str[:256], flush=True) 350 | raise e 351 | 352 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] 353 | 354 | if getattr(self.args, "token_scaled_loss", False): 355 | # We are going to scale the loss by the ratio of #valid tokens / avg #valid tokens per device 356 | # The standard loss = sum(avg loss of current seq) / #devices / #ga 357 | # What we are doing = sum(sum loss of current seq) / avg #valid tokens per device / #devices / #ga = original * (current device valid seq tokens / avg device valid seq tokens) = original_sumreduction (done in `compute_loss` in modeling_flash_llama.py) / avg device valid seq tokens 358 | # Technically we should use the avg #valid tokens per device for this batch (may include multiple steps because of gradient accumulation). But for simplicity we use the moving average (from the whole training process) 359 | 360 | seq_parallel_world_size = (dist.get_world_size(self.seq_parallel_group) if dist.is_initialized() else 1) 361 | if seq_parallel_world_size > 1: # Sequence parallelism 362 | device_num_valid_tokens = (inputs["shifted_labels"] != -100).sum().float() # Should be on the device already 363 | else: 364 | device_num_valid_tokens = (inputs["labels"] != -100).sum().float() # Should be on the device already 365 | 366 | avg_device_num_valid_tokens = torch.mean(self.accelerator.gather(device_num_valid_tokens)).item() 367 | 368 | if not hasattr(self.state, "count_step_for_num_valid_tokens"): 369 | self.state.count_step_for_num_valid_tokens = 1 370 | self.state.avg_num_valid_tokens_per_device = avg_device_num_valid_tokens 371 | else: 372 | self.state.count_step_for_num_valid_tokens += 1 373 | steps = self.state.count_step_for_num_valid_tokens 374 | self.state.avg_num_valid_tokens_per_device = self.state.avg_num_valid_tokens_per_device * ((steps - 1) / steps) + avg_device_num_valid_tokens / steps # moving avg 375 | 376 | loss = loss / self.state.avg_num_valid_tokens_per_device 377 | 378 | if return_output_and_metrics: 379 | # shifted_labels = inputs["labels"][:,1:].contiguous() 380 | # valid_mask = (shifted_labels != -100) 381 | # correct = (outputs.logits[:,:-1].argmax(-1) == shifted_labels).float() 382 | # correct[~valid_mask] = 0.0 383 | # acc = correct.sum(dim=-1) / valid_mask.float().sum(dim=-1) 384 | 385 | metrics = {} 386 | 387 | return (loss, outputs, metrics) 388 | if return_outputs: 389 | return (loss, outputs) 390 | else: 391 | return loss 392 | 393 | def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): 394 | """ 395 | Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or 396 | passed as an argument. 397 | 398 | Args: 399 | num_training_steps (int): The number of training steps to do. 
400 | """ 401 | 402 | self.lr_scheduler = super().create_scheduler(num_training_steps, optimizer) 403 | 404 | if self.args.min_lr_ratio != 0.0: 405 | if isinstance(self.lr_scheduler, LambdaLR): 406 | lr_lambdas = self.lr_scheduler.lr_lambdas 407 | new_lr_lambdas = [ 408 | lr_lambda 409 | if lr_lambda is None or isinstance(lr_lambda, partial) and lr_lambda.func == min_lr_bound 410 | else 411 | partial(min_lr_bound, 412 | wrapped_func=lr_lambda, 413 | min_lr_ratio=self.args.min_lr_ratio, 414 | warmup_steps=self.args.get_warmup_steps(num_training_steps)) 415 | for lr_lambda in lr_lambdas 416 | ] 417 | 418 | self.lr_scheduler.lr_lambdas = new_lr_lambdas 419 | else: 420 | raise NotImplementedError("Only LambdaLR is supported for min_lr_ratio") 421 | 422 | return self.lr_scheduler 423 | 424 | def prediction_step( 425 | self, 426 | model: nn.Module, 427 | inputs: Dict[str, Union[torch.Tensor, Any]], 428 | prediction_loss_only: bool, 429 | ignore_keys: Optional[List[str]] = None, 430 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 431 | """ 432 | Perform an evaluation step on `model` using `inputs`. 433 | 434 | Subclass and override to inject custom behavior. 435 | 436 | Args: 437 | model (`nn.Module`): 438 | The model to evaluate. 439 | inputs (`Dict[str, Union[torch.Tensor, Any]]`): 440 | The inputs and targets of the model. 441 | 442 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the 443 | argument `labels`. Check your model's documentation for all accepted arguments. 444 | prediction_loss_only (`bool`): 445 | Whether or not to return the loss only. 446 | ignore_keys (`Lst[str]`, *optional*): 447 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 448 | gathering predictions. 449 | 450 | Return: 451 | Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, 452 | logits and labels (each being optional). 453 | """ 454 | has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) 455 | # For CLIP-like models capable of returning loss values. 456 | # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` 457 | # is `True` in `model.forward`. 458 | return_loss = inputs.get("return_loss", None) 459 | if return_loss is None: 460 | return_loss = self.can_return_loss 461 | loss_without_labels = True if len(self.label_names) == 0 and return_loss else False 462 | 463 | inputs = self._prepare_inputs(inputs) 464 | if ignore_keys is None: 465 | if hasattr(self.model, "config"): 466 | ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) 467 | else: 468 | ignore_keys = [] 469 | 470 | # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. 
471 | if has_labels or loss_without_labels: 472 | labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) 473 | if len(labels) == 1: 474 | labels = labels[0] 475 | else: 476 | labels = None 477 | 478 | with torch.no_grad(): 479 | if is_sagemaker_mp_enabled(): 480 | raise ValueError("SageMaker Model Parallelism is not supported in BaseTrainer") 481 | else: 482 | with self.compute_loss_context_manager(): 483 | loss, outputs, metrics = self.compute_loss(model, inputs, return_output_and_metrics=True) 484 | if loss is not None: 485 | loss = loss.mean().detach() 486 | 487 | if isinstance(outputs, dict): 488 | logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) 489 | else: 490 | logits = outputs[1:] 491 | 492 | if prediction_loss_only: 493 | return (loss, None, None, metrics) 494 | 495 | logits = nested_detach(logits) 496 | if len(logits) == 1: 497 | logits = logits[0] 498 | 499 | return (loss, logits, labels, metrics) 500 | 501 | def compute_loss_context_manager(self): 502 | """ 503 | A helper wrapper to group together context managers. 504 | """ 505 | if self.args.cuda_empty_cache: 506 | gc.collect() 507 | torch.cuda.empty_cache() 508 | return self.autocast_smart_context_manager() 509 | 510 | def evaluation_loop( 511 | self, 512 | dataloader: DataLoader, 513 | description: str, 514 | prediction_loss_only: Optional[bool] = None, 515 | ignore_keys: Optional[List[str]] = None, 516 | metric_key_prefix: str = "eval", 517 | ) -> EvalLoopOutput: 518 | """ 519 | Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. 520 | 521 | Works both with or without labels. 522 | """ 523 | args = self.args 524 | 525 | prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only 526 | 527 | model = self._wrap_model(self.model, training=False, dataloader=dataloader) 528 | 529 | if len(self.accelerator._models) == 0 and model is self.model: 530 | model = ( 531 | self.accelerator.prepare(model) 532 | if self.is_deepspeed_enabled 533 | else self.accelerator.prepare_model(model, evaluation_mode=False) 534 | ) 535 | 536 | if self.is_fsdp_enabled: 537 | self.model = model 538 | 539 | # for the rest of this function `model` is the outside model, whether it was wrapped or not 540 | if model is not self.model: 541 | self.model_wrapped = model 542 | 543 | # backward compatibility 544 | if self.is_deepspeed_enabled: 545 | self.deepspeed = self.model_wrapped 546 | 547 | # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called 548 | # while ``train`` is running, cast it to the right dtype first and then put on device 549 | if not self.is_in_train: 550 | if args.fp16_full_eval: 551 | model = model.to(dtype=torch.float16, device=args.device) 552 | elif args.bf16_full_eval: 553 | model = model.to(dtype=torch.bfloat16, device=args.device) 554 | 555 | batch_size = self.args.eval_batch_size 556 | 557 | logger.info(f"***** Running {description} *****") 558 | if has_length(dataloader): 559 | logger.info(f" Num examples = {self.num_examples(dataloader)}") 560 | else: 561 | logger.info(" Num examples: Unknown") 562 | logger.info(f" Batch size = {batch_size}") 563 | 564 | model.eval() 565 | 566 | self.callback_handler.eval_dataloader = dataloader 567 | # Do this before wrapping. 
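# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: the arithmetic behind the
# token-scaled loss in `compute_loss` above. Each device produces a *sum* of
# token losses (reduction="sum" in the model) and divides it by the running
# average of valid tokens per device, so the effective objective matches a
# global per-token mean rather than a per-sequence/per-device mean.
def _example_token_scaled_loss():
    import torch

    # Suppose 2 devices report these sums of token losses and valid-token counts.
    loss_sums = torch.tensor([120.0, 30.0])
    valid_tokens = torch.tensor([400.0, 100.0])

    avg_valid_per_device = valid_tokens.mean()      # 250.0
    scaled = loss_sums / avg_valid_per_device       # what each device backpropagates
    # Averaging the scaled losses across devices recovers the global token mean:
    assert torch.isclose(scaled.mean(), loss_sums.sum() / valid_tokens.sum())
    return scaled
# ---------------------------------------------------------------------------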
568 | eval_dataset = getattr(dataloader, "dataset", None) 569 | 570 | if is_torch_tpu_available(): 571 | dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) 572 | 573 | if args.past_index >= 0: 574 | self._past = None 575 | 576 | # Initialize containers 577 | # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps) 578 | losses_host = None 579 | preds_host = None 580 | labels_host = None 581 | inputs_host = None 582 | metrics_host = None 583 | 584 | metrics_names = None 585 | 586 | # losses/preds/labels on CPU (final containers) 587 | all_losses = None 588 | all_preds = None 589 | all_labels = None 590 | all_inputs = None 591 | all_metrics = None 592 | # Will be useful when we have an iterable dataset so don't know its length. 593 | 594 | observed_num_examples = 0 595 | # Main evaluation loop 596 | for step, inputs in enumerate(dataloader): 597 | # Update the observed num examples 598 | observed_batch_size = find_batch_size(inputs) 599 | if observed_batch_size is not None: 600 | observed_num_examples += observed_batch_size 601 | # For batch samplers, batch_size is not known by the dataloader in advance. 602 | if batch_size is None: 603 | batch_size = observed_batch_size 604 | 605 | # Prediction step 606 | loss, logits, labels, metrics = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) 607 | inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None 608 | 609 | if is_torch_tpu_available(): 610 | xm.mark_step() 611 | 612 | 613 | # Update containers on host 614 | if loss is not None: 615 | losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size))) 616 | losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100) 617 | if labels is not None: 618 | labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) 619 | if inputs_decode is not None: 620 | inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) 621 | inputs_decode = self.accelerator.gather_for_metrics((inputs_decode)) 622 | inputs_host = ( 623 | inputs_decode 624 | if inputs_host is None 625 | else nested_concat(inputs_host, inputs_decode, padding_index=-100) 626 | ) 627 | if metrics is not None: 628 | if metrics_names is None: 629 | metrics_names = list(metrics.keys()) 630 | else: 631 | assert metrics_names == list(metrics.keys()), "Metrics should have the same keys across batches" 632 | 633 | 634 | metrics = [ 635 | metric if metric.shape else metric.repeat(batch_size) for metric in metrics.values() 636 | ] 637 | metrics = self.accelerator.pad_across_processes(metrics, dim=1, pad_index=float('nan')) 638 | metrics = self.accelerator.gather_for_metrics(metrics) 639 | metrics_host = metrics if metrics_host is None else nested_concat(metrics_host, metrics, padding_index=float('nan')) 640 | if logits is not None: 641 | logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) 642 | if self.preprocess_logits_for_metrics is not None: 643 | logits = self.preprocess_logits_for_metrics(logits, labels) 644 | logits = self.accelerator.gather_for_metrics((logits)) 645 | preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) 646 | 647 | if labels is not None: 648 | labels = self.accelerator.gather_for_metrics((labels)) 649 | labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) 650 | 651 | 652 | 
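# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: how the evaluation loop
# above accumulates variable-length tensors across batches. `nested_concat`
# (from transformers.trainer_pt_utils) concatenates along the batch dimension
# and pads the shorter tensor along dim 1 with `padding_index`.
def _example_nested_concat():
    import torch
    from transformers.trainer_pt_utils import nested_concat

    a = torch.zeros(2, 3)     # two examples with sequence length 3
    b = torch.ones(2, 5)      # two examples with sequence length 5
    out = nested_concat(a, b, padding_index=-100)
    # out.shape == (4, 5); rows from `a` are padded with -100 in columns 3:5.
    return out.shape
# ---------------------------------------------------------------------------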
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
653 |
654 | # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
655 | if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
656 | if losses_host is not None:
657 | losses = nested_numpify(losses_host)
658 | all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
659 | if preds_host is not None:
660 | logits = nested_numpify(preds_host)
661 | all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
662 | if metrics_host is not None:
663 | metrics = nested_numpify(metrics_host)
664 | all_metrics = (
665 | metrics if all_metrics is None else nested_concat(all_metrics, metrics, padding_index=float('nan'))
666 | )
667 | if inputs_host is not None:
668 | inputs_decode = nested_numpify(inputs_host)
669 | all_inputs = (
670 | inputs_decode
671 | if all_inputs is None
672 | else nested_concat(all_inputs, inputs_decode, padding_index=-100)
673 | )
674 | if labels_host is not None:
675 | labels = nested_numpify(labels_host)
676 | all_labels = (
677 | labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
678 | )
679 |
680 | # Set back to None to begin a new accumulation (metrics_host must be reset too, otherwise it gets re-appended on the next flush)
681 | losses_host, preds_host, inputs_host, labels_host, metrics_host = None, None, None, None, None
682 |
683 | if args.past_index and hasattr(self, "_past"):
684 | # Clean the state at the end of the evaluation loop
685 | delattr(self, "_past")
686 |
687 | # Gather all remaining tensors and put them back on the CPU
688 | if losses_host is not None:
689 | losses = nested_numpify(losses_host)
690 | all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
691 | if preds_host is not None:
692 | logits = nested_numpify(preds_host)
693 | all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
694 | if inputs_host is not None:
695 | inputs_decode = nested_numpify(inputs_host)
696 | all_inputs = (
697 | inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
698 | )
699 | if metrics_host is not None:
700 | metrics = nested_numpify(metrics_host)
701 | all_metrics = (
702 | metrics if all_metrics is None else nested_concat(all_metrics, metrics, padding_index=float('nan'))
703 | )
704 | if labels_host is not None:
705 | labels = nested_numpify(labels_host)
706 | all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
707 |
708 | # Number of samples
709 | if has_length(eval_dataset):
710 | num_samples = len(eval_dataset)
711 | # The instance check is weird and does not actually check for the type, but whether the dataset has the right
712 | # methods. Therefore we need to make sure it also has the attribute.
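# Note: `IterableDatasetShard` keeps a running `num_examples` count of the samples it has actually
# yielded, so a positive value can be trusted here even though len() is unavailable for iterable datasets.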
713 | elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
714 | num_samples = eval_dataset.num_examples
715 | else:
716 | if has_length(dataloader):
717 | num_samples = self.num_examples(dataloader)
718 | else: # both len(dataloader.dataset) and len(dataloader) fail
719 | num_samples = observed_num_examples
720 | if num_samples == 0 and observed_num_examples > 0:
721 | num_samples = observed_num_examples
722 |
723 | # The number of losses has been rounded to a multiple of batch_size and, in distributed training, the number of
724 | # samples has been rounded to a multiple of batch_size, so we truncate.
725 | if all_losses is not None:
726 | all_losses = all_losses[:num_samples]
727 | if all_preds is not None:
728 | all_preds = nested_truncate(all_preds, num_samples)
729 | if all_labels is not None:
730 | all_labels = nested_truncate(all_labels, num_samples)
731 | if all_inputs is not None:
732 | all_inputs = nested_truncate(all_inputs, num_samples)
733 | # if all_metrics is not None:
734 | # all_metrics = nested_truncate(all_metrics, num_samples)
735 |
736 | # Metrics!
737 | if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
738 | if args.include_inputs_for_metrics:
739 | metrics = self.compute_metrics(
740 | EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
741 | )
742 | else:
743 | metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
744 | else:
745 | metrics = {}
746 |
747 | if all_metrics is not None:
748 | for key, value in zip(metrics_names, all_metrics):
749 | valid = ~np.isnan(value)
750 | metrics[key] = value[valid].mean().item()
751 | metrics[f"{key}___samples"] = np.sum(valid).item()
752 |
753 | metrics["samples"] = num_samples
754 |
755 | # To be JSON-serializable, we need to remove numpy types or zero-d tensors
756 | metrics = denumpify_detensorize(metrics)
757 |
758 | if all_losses is not None:
759 | metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
760 | if hasattr(self, "jit_compilation_time"):
761 | metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
762 |
763 | # Prefix all keys with metric_key_prefix + '_'
764 | for key in list(metrics.keys()):
765 | if not key.startswith(f"{metric_key_prefix}_"):
766 | metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
767 |
768 | return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
769 |
770 |
771 | def evaluate(
772 | self,
773 | eval_dataset: Optional[Union[Dict[str, Dataset], Dataset]] = None,
774 | ignore_keys: Optional[List[str]] = None,
775 | metric_key_prefix: str = "eval",
776 | ) -> Dict[str, float]:
777 | if eval_dataset is None:
778 | eval_dataset = self.eval_dataset
779 |
780 | if isinstance(eval_dataset, dict):
781 | metrics = {}
782 | for key, dataset in eval_dataset.items():
783 | metrics.update(super().evaluate(dataset, ignore_keys=ignore_keys, metric_key_prefix=f"{metric_key_prefix}_{key}"))
784 | else:
785 | metrics = super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
786 |
787 | return metrics
788 |
789 | def get_train_dataloader(self):
790 | """
791 | Because streaming handles distributed data parallelism by itself, we don't need a special data loader.
792 | The plainest data loader is enough.
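Note: this assumes the streaming dataset (e.g. mosaicml-streaming's StreamingDataset) already
partitions shards across ranks and dataloader workers internally, so no DistributedSampler or
accelerate preparation is required.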
793 | """ 794 | if not self.args.streaming_dataset: 795 | return super().get_train_dataloader() 796 | 797 | logger.warn("Use streaming dataloader for train") 798 | 799 | if self.train_dataset is None: 800 | raise ValueError("Trainer: training requires a train_dataset.") 801 | 802 | train_dataset = self.train_dataset 803 | data_collator = self.data_collator 804 | data_collator = self._get_collator_with_removed_columns(data_collator, description="training") 805 | 806 | dataloader_params = { 807 | "batch_size": self._train_batch_size, 808 | "collate_fn": data_collator, 809 | "num_workers": self.args.dataloader_num_workers, 810 | "pin_memory": self.args.dataloader_pin_memory, 811 | "persistent_workers": self.args.dataloader_persistent_workers, 812 | } 813 | 814 | # Streaming is iterable so no need to set sampler etc. 815 | 816 | # Instead of use accelerate to prepare the dataloader, we just return a plain dataloader 817 | self.train_dataloader = DataLoader(train_dataset, **dataloader_params) 818 | # This actually uses the dataset first dimension...... 819 | 820 | return self.train_dataloader 821 | 822 | 823 | def get_eval_dataloader(self, eval_dataset): 824 | """ 825 | Because streaming handles the distributed data parallel by itself, we don't need special data loader. 826 | The plainest data loader is enough. 827 | """ 828 | if not self.args.streaming_dataset: 829 | return super().get_eval_dataloader() 830 | 831 | logger.warn("Use streaming dataloader for val") 832 | 833 | if eval_dataset is None and self.eval_dataset is None: 834 | raise ValueError("Trainer: evaluation requires an eval_dataset.") 835 | eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset 836 | data_collator = self.data_collator 837 | data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") 838 | 839 | dataloader_params = { 840 | "batch_size": self.args.eval_batch_size, 841 | "collate_fn": data_collator, 842 | "num_workers": self.args.dataloader_num_workers, 843 | "pin_memory": self.args.dataloader_pin_memory, 844 | "persistent_workers": self.args.dataloader_persistent_workers, 845 | } 846 | 847 | # Streaming is iterable so no need to set sampler etc. 
848 |
849 | # Instead of using accelerate to prepare the dataloader, we just return a plain dataloader
850 | return StreamingDataLoader(eval_dataset, **dataloader_params)
851 |
852 |
853 | def _save_checkpoint(self, model, trial, metrics=None):
854 | # A wrapper around the original _save_checkpoint function to save streaming dataset state
855 |
856 | # Save model checkpoint
857 | super()._save_checkpoint(model, trial, metrics=metrics)
858 |
859 | # Get the path
860 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
861 | run_dir = self._get_output_dir(trial=trial)
862 | output_dir = os.path.join(run_dir, checkpoint_folder)
863 |
864 | # Save streaming dataset state
865 | if isinstance(self.train_dataset, StreamingDataset) and self.state.is_world_process_zero:
866 | num_samples = self.state.global_step * self.args.train_batch_size * self.args.world_size * self.args.gradient_accumulation_steps
867 | if self.train_dataset.replication is not None:
868 | num_samples = num_samples // self.train_dataset.replication
869 | dataset_state_dict = self.train_dataset.state_dict(num_samples, True)
870 | logger.warn(f"Save streaming dataset state: {dataset_state_dict}")
871 | with open(os.path.join(output_dir, "streaming_dataset_state.json"), "w") as f: json.dump(dataset_state_dict, f)
872 |
873 |
874 | def _load_optimizer_and_scheduler(self, checkpoint):
875 | # A wrapper around the original _load_optimizer_and_scheduler to resume the dataloader
876 |
877 | # Call the original function
878 | # super()._load_optimizer_and_scheduler(checkpoint)
879 | # Below is copied from the original _load_optimizer_and_scheduler,
880 | # but allows loading only the optimizer when the scheduler state does not exist
881 |
882 | """If optimizer and scheduler states exist, load them."""
883 | if checkpoint is None:
884 | return
885 |
886 | checkpoint_file_exists = (
887 | glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
888 | if is_sagemaker_mp_enabled()
889 | else (
890 | os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
891 | or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN))
892 | or (
893 | os.path.isdir(checkpoint)
894 | and any(
895 | OPTIMIZER_NAME_BIN.split(".")[0] in folder_name
896 | for folder_name in os.listdir(checkpoint)
897 | if os.path.isdir(os.path.join(checkpoint, folder_name))
898 | )
899 | )
900 | )
901 | )
902 | if checkpoint_file_exists:
903 | logger.warn(f"Load optimizer state from {checkpoint}")
904 | # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
905 | # In distributed training however, we load directly on each GPU and risk GPU OOM, as it's more
906 | # likely to get OOM on CPU (since we would load the optimizer state num_gpu times).
907 | map_location = self.args.device if self.args.world_size > 1 else "cpu"
908 | if self.is_fsdp_enabled:
909 | load_fsdp_optimizer(
910 | self.accelerator.state.fsdp_plugin,
911 | self.accelerator,
912 | self.optimizer,
913 | self.model,
914 | checkpoint,
915 | **_get_fsdp_ckpt_kwargs(),
916 | )
917 | else:
918 | self.optimizer.load_state_dict(
919 | torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
920 | )
921 |
922 | if os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
923 | logger.warn(f"Load scheduler state from {checkpoint}")
924 | with warnings.catch_warnings(record=True) as caught_warnings:
925 | self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
926 | reissue_pt_warnings(caught_warnings)
927 |
928 |
929 | # Resume dataloader
930 | if checkpoint is not None and self.args.streaming_dataset:
931 | try:
932 | dataset_state_dict = json.load(open(os.path.join(checkpoint, "streaming_dataset_state.json")))
933 | except Exception:
934 | logger.warn(f"Failed to load streaming dataset state from {checkpoint}")
935 | logger.warn("Fall back to the HF data skip")
936 | self.args.ignore_data_skip = False
937 |
938 | return
939 |
940 | # First, disable HF's data skip
941 | self.args.ignore_data_skip = True
942 |
943 | # We save sample_in_epoch assuming we only train for one epoch, so we need to adjust it when resuming multi-epoch training
944 | epoch_size = len(self.train_dataset) * self.args.world_size
945 | assert dataset_state_dict["sample_in_epoch"] - dataset_state_dict["epoch"] * epoch_size == dataset_state_dict["sample_in_epoch"] % epoch_size
946 |
947 | dataset_state_dict["sample_in_epoch"] = dataset_state_dict["sample_in_epoch"] - dataset_state_dict["epoch"] * epoch_size
948 |
949 | # Load the dataset state and reinit the dataloader
950 | logger.warn(f"Resume streaming dataset state from {checkpoint}: {dataset_state_dict}")
951 | self.train_dataset.load_state_dict(dataset_state_dict)
952 |
953 | # Override the original train() to handle the case
954 | # when resuming from a checkpoint but no trainer_state is there
955 | # (e.g., continual training with optimizer states)
956 | def train(
957 | self,
958 | resume_from_checkpoint: Optional[Union[str, bool]] = None,
959 | trial: Union["optuna.Trial", Dict[str, Any]] = None,
960 | ignore_keys_for_eval: Optional[List[str]] = None,
961 | **kwargs,
962 | ):
963 | """
964 | Main training entry point.
965 |
966 | Args:
967 | resume_from_checkpoint (`str` or `bool`, *optional*):
968 | If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
969 | `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
970 | of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
971 | trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
972 | The trial run or the hyperparameter dictionary for hyperparameter search.
973 | ignore_keys_for_eval (`List[str]`, *optional*):
974 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
975 | gathering predictions for evaluation during the training.
976 | kwargs (`Dict[str, Any]`, *optional*):
977 | Additional keyword arguments used to hide deprecated arguments
978 | """
979 | if resume_from_checkpoint is False:
980 | resume_from_checkpoint = None
981 |
982 | # memory metrics - must set up as early as possible
983 | self._memory_tracker.start()
984 |
985 | args = self.args
986 |
987 | self.is_in_train = True
988 |
989 | # Attach NEFTune hooks if necessary
990 | if self.neftune_noise_alpha is not None:
991 | self.model = self._activate_neftune(self.model)
992 |
993 | # do_train is not a reliable argument, as it might not be set and .train() still called, so
994 | # the following is a workaround:
995 | if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
996 | self._move_model_to_device(self.model, args.device)
997 |
998 | if "model_path" in kwargs:
999 | resume_from_checkpoint = kwargs.pop("model_path")
1000 | warnings.warn(
1001 | "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
1002 | "instead.",
1003 | FutureWarning,
1004 | )
1005 | if len(kwargs) > 0:
1006 | raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
1007 | # This might change the seed so needs to run first.
1008 | self._hp_search_setup(trial)
1009 | self._train_batch_size = self.args.train_batch_size
1010 |
1011 | # Model re-init
1012 | model_reloaded = False
1013 | if self.model_init is not None:
1014 | # Seed must be set before instantiating the model when using model_init.
1015 | enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
1016 | self.model = self.call_model_init(trial)
1017 | model_reloaded = True
1018 | # Reinitializes optimizer and scheduler
1019 | self.optimizer, self.lr_scheduler = None, None
1020 |
1021 | # Load potential model checkpoint
1022 | if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
1023 | resume_from_checkpoint = get_last_checkpoint(args.output_dir)
1024 | if resume_from_checkpoint is None:
1025 | raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
1026 |
1027 | if resume_from_checkpoint is not None:
1028 | if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
1029 | self._load_from_checkpoint(resume_from_checkpoint)
1030 | # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
1031 | if os.path.isfile(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)):
1032 | state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
1033 | if state.train_batch_size is not None:
1034 | self._train_batch_size = state.train_batch_size
1035 |
1036 | # If model was re-initialized, put it on the right device and update self.model_wrapped
1037 | if model_reloaded:
1038 | if self.place_model_on_device:
1039 | self._move_model_to_device(self.model, args.device)
1040 | self.model_wrapped = self.model
1041 |
1042 | inner_training_loop = find_executable_batch_size(
1043 | self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
1044 | )
1045 | if args.push_to_hub:
1046 | try:
1047 | # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
1048 | hf_hub_utils.disable_progress_bars()
1049 | return inner_training_loop(
1050 | args=args,
1051 | resume_from_checkpoint=resume_from_checkpoint,
1052 | trial=trial,
1053 | ignore_keys_for_eval=ignore_keys_for_eval,
1054 | )
1055 |
finally:
1056 | hf_hub_utils.enable_progress_bars()
1057 | else:
1058 | return inner_training_loop(
1059 | args=args,
1060 | resume_from_checkpoint=resume_from_checkpoint,
1061 | trial=trial,
1062 | ignore_keys_for_eval=ignore_keys_for_eval,
1063 | )
1064 |
1065 | def _fsdp_qlora_plugin_updates(self):
1066 | pass # No-op: the parent implementation messes with the FSDP auto-wrap policy, so we skip it
1067 |
--------------------------------------------------------------------------------
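To make the sample accounting in `_save_checkpoint` and `_load_optimizer_and_scheduler` above concrete, here is a minimal sketch of the arithmetic with made-up numbers; none of these values come from the repo, and `replication` mirrors `train_dataset.replication`:

# Sketch (assumed values) of the streaming checkpoint arithmetic used in trainer.py above.
global_step = 1000        # optimizer steps taken so far (self.state.global_step)
per_device_batch = 1      # args.train_batch_size
world_size = 8            # args.world_size
grad_accum = 16           # args.gradient_accumulation_steps
replication = 2           # train_dataset.replication; may be None

num_samples = global_step * per_device_batch * world_size * grad_accum   # 128000 sequences seen
if replication is not None:
    num_samples //= replication                                          # 64000 unique samples

# On resume, sample_in_epoch is a global counter, so the within-epoch offset is recovered by
# subtracting whole epochs (this is exactly what the assert in _load_optimizer_and_scheduler checks):
epoch_size = 50000                            # len(train_dataset) * world_size (assumed)
sample_in_epoch, epoch = 64000, 1             # as they would appear in the saved state (assumed)
offset = sample_in_epoch - epoch * epoch_size # 14000, which equals 64000 % 50000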