├── .gitignore
├── speculative_decoding
│   ├── utils.patch
│   └── codellama_spec.sh
├── requirement.txt
├── lit_gpt
│   ├── __init__.py
│   ├── rmsnorm.py
│   ├── rotary_ebm.py
│   ├── tokenizer.py
│   ├── fused_rotary_embedding.py
│   ├── fused_cross_entropy.py
│   ├── packed_dataset.py
│   ├── adapter.py
│   ├── adapter_v2.py
│   ├── model.py
│   └── utils.py
├── scripts
│   ├── prepare_mnbvc.sh
│   ├── prepare_skypile.sh
│   ├── prepare_slimpajama_train.sh
│   ├── prepare_slimpajama_valid.sh
│   ├── prepare_starcoder.sh
│   ├── prepare_project_gutenberg.sh
│   ├── prepare_starcoder_python.sh
│   ├── run_lm_eval.sh
│   ├── convert_lit_model_to_hf.sh
│   ├── datasets_statistics.py
│   ├── prepare_starcoder.py
│   ├── prepare_skypile.py
│   ├── prepare_starcoder_python.py
│   ├── prepare_project_gutenberg.py
│   ├── prepare_mnbvc.py
│   ├── prepare_slimpajama.py
│   ├── convert_hf_checkpoint.py
│   └── convert_lit_checkpoint.py
├── cluster
│   ├── pretrain.sh
│   ├── pretrain_node_0.sh
│   ├── pretrain_node_1.sh
│   ├── pretrain_node_2.sh
│   ├── pretrain_node_3.sh
│   └── finetune.sh
├── README.md
└── pretrain
    ├── tinyllama.py
    └── tinyllama_code.py

/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | add_license.py -------------------------------------------------------------------------------- /speculative_decoding/utils.patch: -------------------------------------------------------------------------------- 1 | @@ -92,6 +92,7 @@ 2 | if is_accelerate_available(): 3 | from accelerate.hooks import AlignDevicesHook, add_hook_to_module 4 | 5 | +import logging 6 | 7 | @dataclass 8 | class GenerateDecoderOnlyOutput(ModelOutput): 9 | @@ -4415,6 +4416,7 @@ 10 | # is no match. 11 | 12 | # 4.1. Get the valid continuation, after the matching tokens 13 | + logging.critical(f"[PROFILE] valid_tokens {valid_tokens.shape[1]} , n_matches {n_matches}") 14 | input_ids = torch.cat((input_ids, valid_tokens), dim=-1) 15 | if streamer is not None: 16 | streamer.put(valid_tokens.cpu()) 17 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | arrow==1.3.0 2 | boto3==1.19.12 3 | filelock==3.12.4 4 | lightning==2.1.2 5 | lightning-cloud==0.5.52 6 | lightning-utilities==0.10.0 7 | markdown-it-py==3.0.0 8 | pydantic==2.5.2 9 | pydantic_core==2.14.5 10 | pytorch-lightning==2.1.2 11 | sentencepiece==0.1.99 12 | wandb==0.15.3 13 | zstandard==0.22.0 14 | transformers==4.37.2 15 | numpy==1.22.4 16 | jsonargparse==4.32.0 17 | backoff==2.2.1 18 | beautifulsoup4==4.12.3 19 | blessed==1.20.0 20 | croniter==1.4.1 21 | dateutils==0.6.12 22 | deepdiff==6.7.1 23 | editor==1.6.6 24 | inquirer==3.4.0 25 | itsdangerous==2.2.0 26 | ordered-set==4.1.0 27 | readchar==4.2.0 28 | runs==1.2.2 29 | soupsieve==2.6 30 | starsessions==1.3.0 31 | traitlets==5.14.3 32 | wcwidth==0.2.13 33 | websockets==11.0.3 34 | xmod==1.8.1 35 | datasets -------------------------------------------------------------------------------- /lit_gpt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from lit_gpt.model import GPT 17 | from lit_gpt.config import Config 18 | from lit_gpt.tokenizer import Tokenizer 19 | from lightning_utilities.core.imports import RequirementCache 20 | 21 | __all__ = ["GPT", "Config", "Tokenizer"] 22 | -------------------------------------------------------------------------------- /scripts/prepare_mnbvc.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | source_path=../MNBVC 17 | target_path=../mnbvc_processed/ 18 | tokenizer_path=./scripts/tokenizer 19 | 20 | python ./scripts/prepare_mnbvc.py \ 21 | --source_path $source_path \ 22 | --tokenizer_path $tokenizer_path \ 23 | --destination_path $target_path \ 24 | --split train \ 25 | --percentage 1.0 26 | -------------------------------------------------------------------------------- /scripts/prepare_skypile.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | source_path=../SkyPile-150B/ 17 | target_path=../skypile_processed/ 18 | tokenizer_path=./scripts/tokenizer 19 | 20 | python ./scripts/prepare_skypile.py \ 21 | --source_path $source_path \ 22 | --tokenizer_path $tokenizer_path \ 23 | --destination_path $target_path \ 24 | --split train \ 25 | --percentage 1.0 26 | -------------------------------------------------------------------------------- /scripts/prepare_slimpajama_train.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | source_path=../SlimPajama-627B 17 | target_path=../slim_processed 18 | tokenizer_path=./scripts/tokenizer 19 | 20 | python3 ./scripts/prepare_slimpajama.py \ 21 | --source_path $source_path \ 22 | --tokenizer_path $tokenizer_path \ 23 | --destination_path $target_path \ 24 | --split train \ 25 | --percentage 1.0 26 | -------------------------------------------------------------------------------- /scripts/prepare_slimpajama_valid.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | source_path=../SlimPajama-627B 17 | target_path=../slim_validation_processed 18 | tokenizer_path=./scripts/tokenizer 19 | python3 ./scripts/prepare_slimpajama.py \ 20 | --source_path $source_path \ 21 | --tokenizer_path $tokenizer_path \ 22 | --destination_path $target_path \ 23 | --split validation \ 24 | --percentage 1.0 25 | -------------------------------------------------------------------------------- /speculative_decoding/codellama_spec.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | target_model=/path/to/target/model 17 | draft_model=/path/to/draft/model 18 | max_new_tokens=512 19 | 20 | python ./speculative_decoding/codellama_spec.py \ 21 | --target_model $target_model\ 22 | --draft_model $draft_model\ 23 | --max_new_tokens $max_new_tokens \ 24 | --temperature 0.1 \ 25 | --do_sample \ 26 | --bf16 \ 27 | -------------------------------------------------------------------------------- /scripts/prepare_starcoder.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | source_path=./starcoderdata/ 17 | target_path=./starcoderdata_processed/ 18 | #train: 12549 secs, ~500G 19 | tokenizer_path=./scripts/tokenizer 20 | 21 | python prepare_starcoder.py \ 22 | --source_path $source_path \ 23 | --tokenizer_path $tokenizer_path \ 24 | --destination_path $target_path \ 25 | --split train \ 26 | --percentage 1.0 27 | -------------------------------------------------------------------------------- /scripts/prepare_project_gutenberg.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | source_path=../project_gutenberg 17 | target_path=../slim_processed 18 | 19 | # train: 873 secs, ~20G 20 | tokenizer_path=./scripts/tokenizer 21 | 22 | python ./scripts/prepare_project_gutenberg.py \ 23 | --source_path $source_path \ 24 | --tokenizer_path $tokenizer_path \ 25 | --destination_path $target_path \ 26 | --split train \ 27 | --percentage 1.0 28 | -------------------------------------------------------------------------------- /scripts/prepare_starcoder_python.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | source_path=../starcoderdata/ 17 | target_path=../starcoderdata_python_processed/ 18 | #train: 12549 secs, ~500G 19 | tokenizer_path=./scripts/tokenizer 20 | 21 | python ./scripts/prepare_starcoder_python.py \ 22 | --source_path $source_path \ 23 | --tokenizer_path $tokenizer_path \ 24 | --destination_path $target_path \ 25 | --split train \ 26 | --percentage 1.0 27 | -------------------------------------------------------------------------------- /scripts/run_lm_eval.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | export HF_DATASETS_CACHE="./huggingface_data" 17 | 18 | for task in wikitext lambada_openai winogrande piqa sciq wsc arc_easy arc_challenge logiqa hellaswag mmlu truthfulqa gsm8k ceval-valid 19 | do 20 | export CUDA_VISIBLE_DEVICES="0" 21 | lm_eval --model hf \ 22 | --tasks $task \ 23 | --model_args pretrained=/path/to/your/huggingface/model \ 24 | --device cuda:0 \ 25 | --batch_size 2 26 | done 27 | -------------------------------------------------------------------------------- /scripts/convert_lit_model_to_hf.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | iter=104000 17 | model_path=./out_bak_100m_2k_code_iter_$iter/ 18 | mkdir $model_path 19 | checkpoint_name=iter-${iter}-ckpt.pth 20 | cp ./out/tinyllama_135M_2k/$checkpoint_name $model_path 21 | python scripts/convert_lit_checkpoint.py \ 22 | --checkpoint_name=$checkpoint_name\ 23 | --out_dir=$model_path \ 24 | --model_name='tiny_LLaMA_135M_2k' \ 25 | --model_only=False 26 | 27 | cp ./scripts/tokenizer/* ${model_path} 28 | mv ${model_path}/iter-${iter}-ckpt.bin ${model_path}/pytorch_model.bin 29 | rm ${model_path}/${checkpoint_name} 30 | -------------------------------------------------------------------------------- /cluster/pretrain.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | export TRAIN_DATA_PATH=/path/to/preprocessed/training/data 17 | export VALID_DATA_PATH=/path/to/preprocessed/validation/data 18 | export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' 19 | export NCCL_SOCKET_IFNAME={network interface name} 20 | export MASTER_ADDRESS={master node ip} 21 | export MAIN_OPRT={port} 22 | 23 | MODEL_NAME='tiny_LLaMA_135M_2k' 24 | lightning run model \ 25 | --node-rank=0 \ 26 | --main-address=$MASTER_ADDRESS \ 27 | --accelerator=cuda \ 28 | --devices=8 \ 29 | --num-nodes=1 \ 30 | --main-port=$MAIN_OPRT \ 31 | pretrain/tinyllama.py --devices 8 --train_data_dir $TRAIN_DATA_PATH --val_data_dir $VALID_DATA_PATH --model_name $MODEL_NAME\ 32 | -------------------------------------------------------------------------------- /cluster/pretrain_node_0.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | export TRAIN_DATA_PATH=/path/to/preprocessed/training/data 17 | export VALID_DATA_PATH=/path/to/preprocessed/validation/data 18 | export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' 19 | export NCCL_SOCKET_IFNAME={network interface name} 20 | export MASTER_ADDRESS={master node ip} 21 | export MAIN_OPRT={port} 22 | 23 | MODEL_NAME='tiny_LLaMA_135M_2k' 24 | lightning run model \ 25 | --node-rank=0 \ 26 | --main-address=$MASTER_ADDRESS \ 27 | --accelerator=cuda \ 28 | --devices=8 \ 29 | --num-nodes=4 \ 30 | --main-port=$MAIN_OPRT \ 31 | pretrain/tinyllama.py --precision 'bf16-mixed' --devices 8 --train_data_dir $TRAIN_DATA_PATH --val_data_dir $VALID_DATA_PATH --model_name $MODEL_NAME\ 32 | -------------------------------------------------------------------------------- /cluster/pretrain_node_1.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | export TRAIN_DATA_PATH=/path/to/preprocessed/training/data 17 | export VALID_DATA_PATH=/path/to/preprocessed/validation/data 18 | export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' 19 | export NCCL_SOCKET_IFNAME={network interface name} 20 | export MASTER_ADDRESS={master node ip} 21 | export MAIN_OPRT={port} 22 | 23 | MODEL_NAME='tiny_LLaMA_135M_2k' 24 | lightning run model \ 25 | --node-rank=1 \ 26 | --main-address=$MASTER_ADDRESS \ 27 | --accelerator=cuda \ 28 | --devices=8 \ 29 | --num-nodes=4 \ 30 | --main-port=$MAIN_OPRT \ 31 | pretrain/tinyllama.py --precision 'bf16-mixed' --devices 8 --train_data_dir $TRAIN_DATA_PATH --val_data_dir $VALID_DATA_PATH --model_name $MODEL_NAME\ 32 | -------------------------------------------------------------------------------- /cluster/pretrain_node_2.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | export TRAIN_DATA_PATH=/path/to/preprocessed/training/data 17 | export VALID_DATA_PATH=/path/to/preprocessed/validation/data 18 | export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' 19 | export NCCL_SOCKET_IFNAME={network interface name} 20 | export MASTER_ADDRESS={master node ip} 21 | export MAIN_OPRT={port} 22 | 23 | MODEL_NAME='tiny_LLaMA_135M_2k' 24 | lightning run model \ 25 | --node-rank=2 \ 26 | --main-address=$MASTER_ADDRESS \ 27 | --accelerator=cuda \ 28 | --devices=8 \ 29 | --num-nodes=4 \ 30 | --main-port=$MAIN_OPRT \ 31 | pretrain/tinyllama.py --precision 'bf16-mixed' --devices 8 --train_data_dir $TRAIN_DATA_PATH --val_data_dir $VALID_DATA_PATH --model_name $MODEL_NAME\ 32 | -------------------------------------------------------------------------------- /cluster/pretrain_node_3.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | export TRAIN_DATA_PATH=/path/to/preprocessed/training/data 17 | export VALID_DATA_PATH=/path/to/preprocessed/validation/data 18 | export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' 19 | export NCCL_SOCKET_IFNAME={network interface name} 20 | export MASTER_ADDRESS={master node ip} 21 | export MAIN_OPRT={port} 22 | 23 | MODEL_NAME='tiny_LLaMA_135M_2k' 24 | lightning run model \ 25 | --node-rank=3 \ 26 | --main-address=$MASTER_ADDRESS \ 27 | --accelerator=cuda \ 28 | --devices=8 \ 29 | --num-nodes=4 \ 30 | --main-port=$MAIN_OPRT \ 31 | pretrain/tinyllama.py --precision 'bf16-mixed' --devices 8 --train_data_dir $TRAIN_DATA_PATH --val_data_dir $VALID_DATA_PATH --model_name $MODEL_NAME\ 32 | -------------------------------------------------------------------------------- /cluster/finetune.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | export TRAIN_DATA_PATH=/path/to/preprocessed/training/data 17 | export VALID_DATA_PATH=/path/to/preprocessed/validation/data 18 | export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' 19 | export BASE_MODEL_PATH=/path/to/base/model/ckpt 20 | export NCCL_SOCKET_IFNAME={network interface name} 21 | export MASTER_ADDRESS={master node ip} 22 | export MAIN_OPRT={port} 23 | 24 | MODEL_NAME='tiny_LLaMA_135M_2k' 25 | lightning run model \ 26 | --node-rank=0 \ 27 | --main-address=$MASTER_ADDRESS \ 28 | --accelerator=cuda \ 29 | --devices=8 \ 30 | --num-nodes=1 \ 31 | --main-port=$MAIN_OPRT \ 32 | pretrain/tinyllama_code.py --devices 8 --train_data_dir $TRAIN_DATA_PATH --val_data_dir $VALID_DATA_PATH --model_name $MODEL_NAME \ 33 | --checkpoint_path $BASE_MODEL_PATH 34 | -------------------------------------------------------------------------------- /lit_gpt/rmsnorm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | # Copyright (c) 2022, Tri Dao. 18 | # Adapted from https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16 19 | 20 | import torch 21 | from torch.nn import init 22 | 23 | 24 | class RMSNorm(torch.nn.Module): 25 | """Root Mean Square Layer Normalization. 
26 | 27 | Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: 28 | https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. 29 | """ 30 | 31 | def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None: 32 | super().__init__() 33 | self.weight = torch.nn.Parameter(torch.ones(size)) 34 | self.eps = eps 35 | self.dim = dim 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | # NOTE: the original RMSNorm paper implementation is not equivalent 39 | norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) 40 | x_normed = x * torch.rsqrt(norm_x + self.eps) 41 | return self.weight * x_normed 42 | 43 | def reset_parameters(self): 44 | torch.nn.init.ones_(self.weight) 45 | 46 | try: 47 | import apex 48 | class FusedRMSNorm(apex.normalization.FusedRMSNorm): 49 | def __init__(self, size: int, dim: int = -1, eps: float = 1e-5): 50 | super().__init__(size, eps=eps, elementwise_affine=True) 51 | self.eps = eps 52 | self.weight = torch.nn.Parameter(torch.ones(size)) 53 | self.dim = dim 54 | self.reset_parameters() 55 | 56 | def reset_parameters(self): 57 | init.ones_(self.weight) 58 | 59 | # def forward(self, x): 60 | # return rms_norm(x, self.weight, self.eps) 61 | except: 62 | print("Fail to import FusedRMSNorm, use RMSNorm instead.") 63 | FusedRMSNorm = RMSNorm 64 | -------------------------------------------------------------------------------- /lit_gpt/rotary_ebm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import torch 17 | 18 | 19 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): 20 | """Applies Rotary Position Embedding to the query and key tensors. 21 | 22 | Args: 23 | q (`torch.Tensor`): The query tensor. 24 | k (`torch.Tensor`): The key tensor. 25 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 26 | sin (`torch.Tensor`): The sine part of the rotary embedding. 27 | position_ids (`torch.Tensor`): 28 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be 29 | used to pass offsetted position ids when working with a KV-cache. 30 | unsqueeze_dim (`int`, *optional*, defaults to 1): 31 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 32 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 33 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 34 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 35 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 36 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
37 | Returns: 38 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 39 | """ 40 | cos = cos[position_ids].unsqueeze(unsqueeze_dim) 41 | sin = sin[position_ids].unsqueeze(unsqueeze_dim) 42 | q_embed = (q * cos) + (rotate_half(q) * sin) 43 | k_embed = (k * cos) + (rotate_half(k) * sin) 44 | return q_embed.type_as(q), k_embed.type_as(k) 45 | 46 | 47 | def rotate_half(x): 48 | """Rotates half the hidden dims of the input.""" 49 | x1 = x[..., : x.shape[-1] // 2] 50 | x2 = x[..., x.shape[-1] // 2 :] 51 | return torch.cat((-x2, x1), dim=-1) 52 | 53 | -------------------------------------------------------------------------------- /scripts/datasets_statistics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import glob 17 | import os 18 | 19 | data_path = '../slim_processed' 20 | chunk_size=4097*1024 21 | 22 | prefix = 'train_*' 23 | data_split = { 24 | 'starcoder': 'train_starcoder', 25 | 'slimpajama_wiki': 'train_wikipedia_slimpajama', 26 | 'slimpajama_git': 'train_github_slimpajama', 27 | 'slimpajama_book': 'train_book_slimpajama', 28 | 'slimpajama': 'train_slimpajama', 29 | 'mnbvc': 'train_mnbvc', 30 | 'skypile': 'train_skypile', 31 | 'openwebmath': 'train_openwebmath', 32 | 'project_gutenberg': 'train_project_gutenberg', 33 | } 34 | data_epoches = { 35 | 'starcoder': 1.0, 36 | 'slimpajama_wiki': 1.0, 37 | 'slimpajama_git': 1.0, 38 | 'slimpajama_book': 1.0, 39 | 'slimpajama': 1.0, 40 | 'mnbvc': 1.0, 41 | 'skypile': 1.0, 42 | 'openwebmath': 1.0, 43 | 'project_gutenberg': 1.0, 44 | } 45 | data_statis = {} 46 | for data_name in data_split: 47 | data_statis[data_name] = 0 48 | total_chunks = 0 49 | total_tokens = 0 50 | 51 | filenames = glob.glob(os.path.join(data_path, prefix), recursive=True) 52 | for filename in filenames: 53 | for data_name, pref in data_split.items(): 54 | if filename[len(os.path.dirname(filename))+1:].startswith(pref): 55 | data_statis[data_name] += 1 56 | total_chunks += 1 57 | print('statistics:') 58 | for data_name, num_chunk in data_statis.items(): 59 | print(f'{num_chunk*chunk_size/1000/1000/1000} B tokens, ', f'{num_chunk} chunks, ', data_name) 60 | total_tokens += num_chunk*chunk_size 61 | print(f"{total_tokens/1000/1000/1000} B tokens", f"{total_chunks} chunks in total.") 62 | 63 | print("percentage:") 64 | for data_name, num_chunk in data_statis.items(): 65 | print(f'1.0 epoches, {num_chunk*chunk_size / total_tokens} %, ', data_name) 66 | 67 | print("weighted:") 68 | total_tokens = 0 69 | for data_name, num_chunk in data_statis.items(): 70 | print(f'{num_chunk*chunk_size*data_epoches[data_name]/1000/1000/1000} B tokens, ', f'{num_chunk} chunks, ', data_name) 71 | total_tokens += num_chunk*chunk_size*data_epoches[data_name] 72 | 73 | for data_name, num_chunk in data_statis.items(): 74 | 
print(f'{data_epoches[data_name]} epoches, {num_chunk*chunk_size*data_epoches[data_name] / total_tokens*100} %, ', data_name) 75 | print(f"{total_tokens/1000/1000/1000} B tokens", f"{total_chunks} chunks in total.") 76 | -------------------------------------------------------------------------------- /lit_gpt/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import json 17 | from pathlib import Path 18 | from typing import Optional 19 | 20 | import torch 21 | 22 | 23 | class Tokenizer: 24 | def __init__(self, checkpoint_dir: Path) -> None: 25 | # some checkpoints have both files, `.model` takes precedence 26 | if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file(): 27 | from sentencepiece import SentencePieceProcessor 28 | 29 | self.processor = SentencePieceProcessor(model_file=str(vocabulary_path)) 30 | self.backend = "sentencepiece" 31 | self.bos_id = self.processor.bos_id() 32 | self.eos_id = self.processor.eos_id() 33 | elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file(): 34 | from tokenizers import Tokenizer as HFTokenizer 35 | 36 | self.processor = HFTokenizer.from_file(str(vocabulary_path)) 37 | self.backend = "huggingface" 38 | with open(checkpoint_dir / "tokenizer_config.json") as fp: 39 | config = json.load(fp) 40 | self.eos_id = self.token_to_id(config["eos_token"]) 41 | bos_token = config.get("bos_token") 42 | self.bos_id = self.token_to_id(bos_token) if bos_token is not None else self.eos_id 43 | else: 44 | raise NotImplementedError 45 | 46 | @property 47 | def vocab_size(self) -> int: 48 | if self.backend == "huggingface": 49 | return self.processor.get_vocab_size(with_added_tokens=False) 50 | if self.backend == "sentencepiece": 51 | return self.processor.vocab_size() 52 | raise RuntimeError 53 | 54 | def token_to_id(self, token: str) -> int: 55 | if self.backend == "huggingface": 56 | id_ = self.processor.token_to_id(token) 57 | elif self.backend == "sentencepiece": 58 | id_ = self.processor.piece_to_id(token) 59 | else: 60 | raise RuntimeError 61 | if id_ is None: 62 | raise ValueError(f"token {token!r} not found in the collection.") 63 | return id_ 64 | 65 | def encode( 66 | self, 67 | string: str, 68 | device: Optional[torch.device] = None, 69 | bos: bool = False, 70 | eos: bool = True, 71 | max_length: int = -1, 72 | ) -> torch.Tensor: 73 | if self.backend == "huggingface": 74 | tokens = self.processor.encode(string).ids 75 | elif self.backend == "sentencepiece": 76 | tokens = self.processor.encode(string) 77 | else: 78 | raise RuntimeError 79 | if bos: 80 | bos_id = self.bos_id 81 | if bos_id is None: 82 | raise NotImplementedError("This tokenizer does not defined a bos token") 83 | tokens = [bos_id] + tokens 84 | if eos: 85 | tokens = tokens + [self.eos_id] 86 | if max_length > 0: 87 | tokens = tokens[:max_length] 88 
| return torch.tensor(tokens, dtype=torch.int, device=device) 89 | 90 | def decode(self, tensor: torch.Tensor) -> str: 91 | tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist() 92 | return self.processor.decode(tokens) 93 | -------------------------------------------------------------------------------- /lit_gpt/fused_rotary_embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # Copyright (c) 2023, Tri Dao. 17 | 18 | import math 19 | from typing import Optional, Tuple 20 | 21 | import rotary_emb 22 | import torch 23 | from einops import rearrange, repeat 24 | 25 | class ApplyRotaryEmb(torch.autograd.Function): 26 | @staticmethod 27 | def forward(ctx, x, cos, sin, interleaved=False, inplace=False): 28 | """ 29 | x: (batch_size, seqlen, nheads, headdim) 30 | cos, sin: (seqlen, rotary_dim / 2) 31 | interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead 32 | of 1st half and 2nd half (GPT-NeoX style). 33 | rotary_dim must be <= headdim 34 | Apply rotary embedding to the first rotary_dim of x. 35 | """ 36 | batch, seqlen, nheads, headdim = x.shape 37 | rotary_seqlen, rotary_dim = cos.shape 38 | rotary_dim *= 2 39 | assert rotary_dim <= headdim 40 | assert seqlen <= rotary_seqlen 41 | assert sin.shape == (rotary_seqlen, rotary_dim // 2) 42 | x_ro = x[..., :rotary_dim] 43 | x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) 44 | out = torch.empty_like(x) if not inplace else x 45 | out_ro = out[..., :rotary_dim] 46 | if inplace: 47 | o1, o2 = x1, x2 48 | else: 49 | o1, o2 = ( 50 | out_ro.chunk(2, dim=-1) 51 | if not interleaved 52 | else (out_ro[..., ::2], out_ro[..., 1::2]) 53 | ) 54 | rotary_emb.apply_rotary( 55 | x1, 56 | x2, 57 | rearrange(cos[:seqlen], "s d -> s 1 d"), 58 | rearrange(sin[:seqlen], "s d -> s 1 d"), 59 | o1, 60 | o2, 61 | False, 62 | ) 63 | if not inplace and rotary_dim < headdim: 64 | out[..., rotary_dim:].copy_(x[..., rotary_dim:]) 65 | ctx.save_for_backward(cos, sin) 66 | ctx.interleaved = interleaved 67 | ctx.inplace = inplace 68 | return out if not inplace else x 69 | 70 | @staticmethod 71 | def backward(ctx, do): 72 | cos, sin = ctx.saved_tensors 73 | _, seqlen, _, headdim = do.shape 74 | rotary_dim = cos.shape[-1] 75 | rotary_dim *= 2 76 | inplace = ctx.inplace 77 | do_ro = do[..., :rotary_dim] 78 | do1, do2 = ( 79 | do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2]) 80 | ) 81 | dx = torch.empty_like(do) if not inplace else do 82 | if inplace: 83 | dx1, dx2 = do1, do2 84 | else: 85 | dx_ro = dx[..., :rotary_dim] 86 | dx1, dx2 = ( 87 | dx_ro.chunk(2, dim=-1) 88 | if not ctx.interleaved 89 | else (dx_ro[..., ::2], dx_ro[..., 1::2]) 90 | ) 91 | rotary_emb.apply_rotary( 92 | do1, 93 | do2, 94 | rearrange(cos[:seqlen], "s d -> s 1 d"), 95 | 
rearrange(sin[:seqlen], "s d -> s 1 d"), 96 | dx1, 97 | dx2, 98 | True, 99 | ) 100 | if not inplace and rotary_dim < headdim: 101 | dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) 102 | return dx, None, None, None, None 103 | 104 | 105 | apply_rotary_emb_func = ApplyRotaryEmb.apply 106 | 107 | -------------------------------------------------------------------------------- /scripts/prepare_starcoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import json 17 | import glob 18 | import os 19 | from pathlib import Path 20 | import sys 21 | from typing import List 22 | import numpy as np 23 | from tqdm import tqdm 24 | from multiprocessing import Process, cpu_count 25 | 26 | # support running without installing as a package 27 | wd = Path(__file__).parent.parent.resolve() 28 | sys.path.append(str(wd)) 29 | 30 | import lit_gpt.packed_dataset as packed_dataset 31 | from lit_gpt import Tokenizer 32 | 33 | import pandas as pd 34 | 35 | 36 | def prepare_full( 37 | source_path: Path, 38 | tokenizer_path: Path, 39 | destination_path: Path, 40 | chunk_size: int, 41 | split: str="train", 42 | filenames_subset: List[str] = None, 43 | process_id: int = 0 44 | ) -> None: 45 | import zstandard as zstd 46 | 47 | destination_path.mkdir(parents=True, exist_ok=True) 48 | 49 | tokenizer = Tokenizer(tokenizer_path) 50 | 51 | # Use the provided filenames_subset or default to all filenames 52 | filenames = filenames_subset 53 | 54 | if not filenames: 55 | raise RuntimeError( 56 | f"No files matching found at {source_path}. \n" 57 | "Make sure you download the data..." 
58 | ) 59 | 60 | builder = packed_dataset.PackedDatasetBuilder( 61 | outdir=destination_path, 62 | prefix=f"{split}_starcoder_{process_id}", # Use process_id to differentiate builders 63 | chunk_size=chunk_size, 64 | sep_token=tokenizer.bos_id, 65 | dtype="auto", 66 | vocab_size=tokenizer.vocab_size, 67 | ) 68 | 69 | for filepath in filenames: 70 | print(f"Processing {filepath}") 71 | try: 72 | contents = pd.read_parquet(filepath, engine='pyarrow')['content'] 73 | except: 74 | print(f"Error reading {filepath}!!") 75 | continue 76 | for text in contents: 77 | text_ids = tokenizer.encode(text) 78 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 79 | 80 | # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details 81 | # builder.write_reminder() 82 | 83 | 84 | def prepare( 85 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 86 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 87 | destination_path: Path = Path("data/red_pajama_sample"), 88 | chunk_size: int = 2048 * 2049, 89 | split: str="train", 90 | percentage: float = 1.0, 91 | filenames_subset: List[str] = None, 92 | ) -> None: 93 | import time 94 | assert split == "train" # starcoder only has train data 95 | filenames = glob.glob(os.path.join(source_path, "*/*.parquet"), recursive=True) 96 | # only retrain subsets that follow the prefix in filenames_subset 97 | if filenames_subset: 98 | filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])] 99 | filenames = filenames[:int(len(filenames) * percentage)] 100 | num_processes = cpu_count()#64 101 | chunked_filenames = np.array_split(filenames, num_processes) 102 | 103 | processes = [] 104 | start_time = time.time() 105 | 106 | for i, subset in enumerate(chunked_filenames): 107 | p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) 108 | processes.append(p) 109 | p.start() 110 | 111 | for p in processes: 112 | p.join() 113 | end_time = time.time() 114 | elapsed_time = end_time - start_time 115 | print(f"Time taken: {elapsed_time:.2f} seconds") 116 | 117 | 118 | if __name__ == "__main__": 119 | from jsonargparse import CLI 120 | CLI(prepare) 121 | -------------------------------------------------------------------------------- /scripts/prepare_skypile.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import json 17 | import glob 18 | import os 19 | from pathlib import Path 20 | import sys 21 | from typing import List 22 | import numpy as np 23 | from tqdm import tqdm 24 | from multiprocessing import Process, cpu_count 25 | 26 | # support running without installing as a package 27 | wd = Path(__file__).parent.parent.resolve() 28 | sys.path.append(str(wd)) 29 | 30 | import lit_gpt.packed_dataset as packed_dataset 31 | from lit_gpt import Tokenizer 32 | 33 | import pandas as pd 34 | 35 | 36 | def prepare_full( 37 | source_path: Path, 38 | tokenizer_path: Path, 39 | destination_path: Path, 40 | chunk_size: int, 41 | split: str="train", 42 | filenames_subset: List[str] = None, 43 | process_id: int = 0 44 | ) -> None: 45 | import zstandard as zstd 46 | 47 | destination_path.mkdir(parents=True, exist_ok=True) 48 | 49 | tokenizer = Tokenizer(tokenizer_path) 50 | 51 | # Use the provided filenames_subset or default to all filenames 52 | filenames = filenames_subset 53 | 54 | if not filenames: 55 | raise RuntimeError( 56 | f"No files matching found at {source_path}. \n" 57 | "Make sure you download the data..." 58 | ) 59 | 60 | builder = packed_dataset.PackedDatasetBuilder( 61 | outdir=destination_path, 62 | prefix=f"{split}_skypile_{process_id}", # Use process_id to differentiate builders 63 | chunk_size=chunk_size, 64 | sep_token=tokenizer.bos_id, 65 | dtype="auto", 66 | vocab_size=tokenizer.vocab_size, 67 | ) 68 | 69 | for filepath in filenames: 70 | print(f"Processing {filepath}") 71 | try: 72 | # contents = pd.read_parquet(filepath, engine='pyarrow')['content'] 73 | contents = pd.read_json(path_or_buf=filepath, lines=True)['text'] 74 | except: 75 | print(f"Error reading {filepath}!!") 76 | continue 77 | for text in contents: 78 | text_ids = tokenizer.encode(text) 79 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 80 | 81 | # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details 82 | # builder.write_reminder() 83 | 84 | 85 | def prepare( 86 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 87 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 88 | destination_path: Path = Path("data/red_pajama_sample"), 89 | chunk_size: int = 2048 * 2049, 90 | split: str="train", 91 | percentage: float = 1.0, 92 | filenames_subset: List[str] = None, 93 | ) -> None: 94 | import time 95 | assert split == "train" # starcoder only has train data 96 | filenames = glob.glob(os.path.join(source_path, "*/*.jsonl"), recursive=True) 97 | # only retrain subsets that follow the prefix in filenames_subset 98 | if filenames_subset: 99 | filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])] 100 | filenames = filenames[:int(len(filenames) * percentage)] 101 | num_processes = cpu_count() 102 | chunked_filenames = np.array_split(filenames, num_processes) 103 | 104 | processes = [] 105 | start_time = time.time() 106 | 107 | for i, subset in enumerate(chunked_filenames): 108 | p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) 109 | processes.append(p) 110 | p.start() 111 | 112 | for p in processes: 113 | p.join() 114 | end_time = time.time() 115 | elapsed_time = end_time - start_time 116 | print(f"Time taken: {elapsed_time:.2f} seconds") 117 | 118 | 119 | if __name__ == "__main__": 120 | from jsonargparse import CLI 121 | CLI(prepare) 122 | 
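# Note on the defaults used above (a hedged reading based on the TinyLlama-style
# packed-dataset format this script appears to build on, not behaviour documented here):
# - chunk_size = 2048 * 2049 is read as 2048 packed blocks of 2049 token ids each,
#   i.e. a 2048-token training context plus one extra id so inputs and shifted
#   targets can be sliced from the same block.
# - sep_token=tokenizer.bos_id is the fill value for the unused tail of a chunk,
#   which is why the final, partially filled chunk is dropped (write_reminder stays
#   commented out) rather than written out mostly full of bos ids.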
-------------------------------------------------------------------------------- /scripts/prepare_starcoder_python.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import json 17 | import glob 18 | import os 19 | from pathlib import Path 20 | import sys 21 | from typing import List 22 | import numpy as np 23 | from tqdm import tqdm 24 | from multiprocessing import Process, cpu_count 25 | 26 | # support running without installing as a package 27 | wd = Path(__file__).parent.parent.resolve() 28 | sys.path.append(str(wd)) 29 | 30 | import lit_gpt.packed_dataset as packed_dataset 31 | from lit_gpt import Tokenizer 32 | 33 | import pandas as pd 34 | 35 | 36 | def prepare_full( 37 | source_path: Path, 38 | tokenizer_path: Path, 39 | destination_path: Path, 40 | chunk_size: int, 41 | split: str="train", 42 | filenames_subset: List[str] = None, 43 | process_id: int = 0 44 | ) -> None: 45 | import zstandard as zstd 46 | 47 | destination_path.mkdir(parents=True, exist_ok=True) 48 | 49 | tokenizer = Tokenizer(tokenizer_path) 50 | 51 | # Use the provided filenames_subset or default to all filenames 52 | filenames = filenames_subset 53 | 54 | if not filenames: 55 | raise RuntimeError( 56 | f"No files matching found at {source_path}. \n" 57 | "Make sure you download the data..." 
58 | ) 59 | 60 | builder = packed_dataset.PackedDatasetBuilder( 61 | outdir=destination_path, 62 | prefix=f"{split}_starcoder_{process_id}", # Use process_id to differentiate builders 63 | chunk_size=chunk_size, 64 | sep_token=tokenizer.bos_id, 65 | dtype="auto", 66 | vocab_size=tokenizer.vocab_size, 67 | ) 68 | 69 | for filepath in filenames: 70 | print(f"Processing {filepath}") 71 | contents = pd.read_parquet(filepath, engine='pyarrow')['content'] 72 | try: 73 | contents = pd.read_parquet(filepath, engine='pyarrow')['content'] 74 | except: 75 | print(f"Error reading {filepath}!!") 76 | continue 77 | for text in contents: 78 | text_ids = tokenizer.encode(text) 79 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 80 | 81 | # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details 82 | # builder.write_reminder() 83 | 84 | 85 | def prepare( 86 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 87 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 88 | destination_path: Path = Path("data/red_pajama_sample"), 89 | chunk_size: int = 2048 * 2049, 90 | split: str="train", 91 | percentage: float = 1.0, 92 | filenames_subset: List[str] = None, 93 | ) -> None: 94 | import time 95 | assert split == "train" # starcoder only has train data 96 | filenames = glob.glob(os.path.join(source_path, "python/*.parquet"), recursive=True) 97 | # only retrain subsets that follow the prefix in filenames_subset 98 | if filenames_subset: 99 | filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])] 100 | filenames = filenames[:int(len(filenames) * percentage)] 101 | num_processes = min(len(filenames), cpu_count()) 102 | chunked_filenames = np.array_split(filenames, num_processes) 103 | 104 | processes = [] 105 | start_time = time.time() 106 | 107 | for i, subset in enumerate(chunked_filenames): 108 | p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) 109 | processes.append(p) 110 | p.start() 111 | 112 | for p in processes: 113 | p.join() 114 | end_time = time.time() 115 | elapsed_time = end_time - start_time 116 | print(f"Time taken: {elapsed_time:.2f} seconds") 117 | 118 | 119 | if __name__ == "__main__": 120 | from jsonargparse import CLI 121 | CLI(prepare) 122 | -------------------------------------------------------------------------------- /scripts/prepare_project_gutenberg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import json 17 | import glob 18 | import os 19 | from pathlib import Path 20 | import sys 21 | from typing import List 22 | import numpy as np 23 | from tqdm import tqdm 24 | from multiprocessing import Process, cpu_count 25 | 26 | # support running without installing as a package 27 | wd = Path(__file__).parent.parent.resolve() 28 | sys.path.append(str(wd)) 29 | 30 | import lit_gpt.packed_dataset as packed_dataset 31 | from lit_gpt import Tokenizer 32 | 33 | import pandas as pd 34 | 35 | 36 | def prepare_full( 37 | source_path: Path, 38 | tokenizer_path: Path, 39 | destination_path: Path, 40 | chunk_size: int, 41 | split: str="train", 42 | filenames_subset: List[str] = None, 43 | process_id: int = 0 44 | ) -> None: 45 | import zstandard as zstd 46 | 47 | destination_path.mkdir(parents=True, exist_ok=True) 48 | 49 | tokenizer = Tokenizer(tokenizer_path) 50 | 51 | # Use the provided filenames_subset or default to all filenames 52 | filenames = filenames_subset 53 | 54 | if not filenames: 55 | raise RuntimeError( 56 | f"No files matching found at {source_path}. \n" 57 | "Make sure you download the data..." 58 | ) 59 | 60 | builder = packed_dataset.PackedDatasetBuilder( 61 | outdir=destination_path, 62 | prefix=f"{split}_project_gutenberg_{process_id}", # Use process_id to differentiate builders 63 | chunk_size=chunk_size, 64 | sep_token=tokenizer.bos_id, 65 | dtype="auto", 66 | vocab_size=tokenizer.vocab_size, 67 | ) 68 | 69 | for filepath in filenames: 70 | print(f"Processing {filepath}") 71 | print(filepath) 72 | contents = pd.read_parquet(filepath, engine='pyarrow')['text'] 73 | #try: 74 | # contents = pd.read_parquet(filepath, engine='pyarrow')['text'] 75 | #except: 76 | # print(f"Error reading {filepath}!!") 77 | # continue 78 | for text in contents: 79 | text_ids = tokenizer.encode(text) 80 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 81 | 82 | # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details 83 | # builder.write_reminder() 84 | 85 | 86 | def prepare( 87 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 88 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 89 | destination_path: Path = Path("data/red_pajama_sample"), 90 | chunk_size: int = 2049 * 2048, 91 | split: str="train", 92 | percentage: float = 1.0, 93 | filenames_subset: List[str] = None, 94 | ) -> None: 95 | import time 96 | assert split == "train" # starcoder only has train data 97 | filenames = glob.glob(os.path.join(source_path, "*/*.parquet"), recursive=True) 98 | # only retrain subsets that follow the prefix in filenames_subset 99 | if filenames_subset: 100 | filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])] 101 | filenames = filenames[:int(len(filenames) * percentage)] 102 | num_processes = min(cpu_count(), len(filenames)) 103 | chunked_filenames = np.array_split(filenames, num_processes) 104 | 105 | processes = [] 106 | start_time = time.time() 107 | 108 | for i, subset in enumerate(chunked_filenames): 109 | p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) 110 | processes.append(p) 111 | p.start() 112 | 113 | for p in processes: 114 | p.join() 115 | end_time = time.time() 116 | elapsed_time = end_time - start_time 117 | print(f"Time taken: {elapsed_time:.2f} seconds") 118 | 119 | 120 | if __name__ == "__main__": 121 | from jsonargparse import CLI 
122 | CLI(prepare) 123 | -------------------------------------------------------------------------------- /scripts/prepare_mnbvc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import json 17 | import glob 18 | import os 19 | from pathlib import Path 20 | import sys 21 | from typing import List 22 | import numpy as np 23 | from tqdm import tqdm 24 | from multiprocessing import Process, cpu_count 25 | 26 | # support running without installing as a package 27 | wd = Path(__file__).parent.parent.resolve() 28 | sys.path.append(str(wd)) 29 | 30 | import lit_gpt.packed_dataset as packed_dataset 31 | from lit_gpt import Tokenizer 32 | 33 | import pandas as pd 34 | import gzip 35 | 36 | def prepare_full( 37 | source_path: Path, 38 | tokenizer_path: Path, 39 | destination_path: Path, 40 | chunk_size: int, 41 | split: str="train", 42 | filenames_subset: List[str] = None, 43 | process_id: int = 0 44 | ) -> None: 45 | import zstandard as zstd 46 | 47 | destination_path.mkdir(parents=True, exist_ok=True) 48 | 49 | tokenizer = Tokenizer(tokenizer_path) 50 | 51 | # Use the provided filenames_subset or default to all filenames 52 | filenames = filenames_subset 53 | 54 | if not filenames: 55 | raise RuntimeError( 56 | f"No files matching found at {source_path}. \n" 57 | "Make sure you download the data..." 
58 | ) 59 | 60 | builder = packed_dataset.PackedDatasetBuilder( 61 | outdir=destination_path, 62 | prefix=f"{split}_mnbvc_{process_id}", # Use process_id to differentiate builders 63 | chunk_size=chunk_size, 64 | sep_token=tokenizer.bos_id, 65 | dtype="auto", 66 | vocab_size=tokenizer.vocab_size, 67 | ) 68 | 69 | for filepath in filenames: 70 | print(f"Processing {filepath}") 71 | try: 72 | # contents = pd.read_parquet(filepath, engine='pyarrow')['content'] 73 | if 'code/metadata/' in filepath: 74 | print("Not use metadata!") 75 | continue 76 | with gzip.open(open(filepath, "rb"), mode="rt") as f: 77 | for row in tqdm(f): 78 | text = json.loads(row)["text"] 79 | text_ids = tokenizer.encode(text) 80 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 81 | except: 82 | print(f"Error reading {filepath}!!") 83 | continue 84 | 85 | # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details 86 | # builder.write_reminder() 87 | 88 | 89 | def prepare( 90 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 91 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 92 | destination_path: Path = Path("data/red_pajama_sample"), 93 | chunk_size: int = 4097 * 1024, 94 | split: str="train", 95 | percentage: float = 1.0, 96 | filenames_subset: List[str] = None, 97 | ) -> None: 98 | import time 99 | assert split == "train" # starcoder only has train data 100 | filenames = glob.glob(os.path.join(source_path, "*/*/*.jsonl.gz"), recursive=True) 101 | filenames += glob.glob(os.path.join(source_path, "*/*/*/*.jsonl.gz"), recursive=True) 102 | print(len(filenames)) 103 | # only retrain subsets that follow the prefix in filenames_subset 104 | if filenames_subset: 105 | filenames = [f for f in filenames if any([prefix in f for prefix in filenames_subset])] 106 | filenames = filenames[:int(len(filenames) * percentage)] 107 | num_processes = 64 108 | chunked_filenames = np.array_split(filenames, num_processes) 109 | 110 | processes = [] 111 | start_time = time.time() 112 | 113 | for i, subset in enumerate(chunked_filenames): 114 | p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) 115 | processes.append(p) 116 | p.start() 117 | 118 | for p in processes: 119 | p.join() 120 | end_time = time.time() 121 | elapsed_time = end_time - start_time 122 | print(f"Time taken: {elapsed_time:.2f} seconds") 123 | 124 | 125 | if __name__ == "__main__": 126 | from jsonargparse import CLI 127 | CLI(prepare) 128 | -------------------------------------------------------------------------------- /scripts/prepare_slimpajama.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import json 17 | import glob 18 | import os 19 | from pathlib import Path 20 | import sys 21 | from typing import List 22 | import numpy as np 23 | from tqdm import tqdm 24 | from multiprocessing import Process, cpu_count 25 | 26 | # support running without installing as a package 27 | wd = Path(__file__).parent.parent.resolve() 28 | sys.path.append(str(wd)) 29 | 30 | import lit_gpt.packed_dataset as packed_dataset 31 | from lit_gpt.tokenizer import Tokenizer 32 | 33 | # Filename for SlimPajama 34 | slimpajama_sets = { 35 | "train": "train/chunk*/*", 36 | "validation": "validation/chunk*/*", 37 | "test": "test/chunk*/*", 38 | } 39 | 40 | 41 | def prepare_full( 42 | source_path: Path, 43 | tokenizer_path: Path, 44 | destination_path: Path, 45 | chunk_size: int, 46 | split: str="train", 47 | filenames_subset: List[str] = None, 48 | process_id: int = 0 49 | ) -> None: 50 | import zstandard as zstd 51 | 52 | destination_path.mkdir(parents=True, exist_ok=True) 53 | 54 | tokenizer = Tokenizer(tokenizer_path) 55 | print(f"Loaded tokenizer from {tokenizer_path}") 56 | print(f"bos_id: {tokenizer.bos_id}") 57 | 58 | # Use the provided filenames_subset or default to all filenames 59 | filenames = filenames_subset 60 | 61 | if not filenames: 62 | raise RuntimeError( 63 | f"No files matching {slimpajama_sets[split]} found at {source_path}. \n" 64 | "Make sure you download the data..." 65 | ) 66 | 67 | builder = packed_dataset.PackedDatasetBuilder( 68 | outdir=destination_path, 69 | prefix=f"{split}_slimpajama_{process_id}", # Use process_id to differentiate builders 70 | chunk_size=chunk_size, 71 | sep_token=tokenizer.bos_id, 72 | dtype="auto", 73 | vocab_size=tokenizer.vocab_size, 74 | ) 75 | builder_wiki = packed_dataset.PackedDatasetBuilder( 76 | outdir=destination_path, 77 | prefix=f"{split}_wikipedia_slimpajama_{process_id}", # Use process_id to differentiate builders 78 | chunk_size=chunk_size, 79 | sep_token=tokenizer.bos_id, 80 | dtype="auto", 81 | vocab_size=tokenizer.vocab_size, 82 | ) 83 | for filepath in filenames: 84 | print(f"Processing {filepath}") 85 | with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: 86 | for row in tqdm(f): 87 | row_data = json.loads(row) # parse each row once instead of three times 88 | text_ids = tokenizer.encode(row_data["text"]) 89 | if row_data["meta"]["redpajama_set_name"] == 'RedPajamaBook': 90 | print("skip red pajama book!!!") 91 | continue 92 | if split == 'train' and row_data["meta"]["redpajama_set_name"] == "RedPajamaWikipedia": 93 | builder_wiki.add_array(np.array(text_ids, dtype=builder_wiki.dtype)) 94 | else: 95 | builder.add_array(np.array(text_ids, dtype=builder.dtype)) 96 | 97 | # we throw away the final corpus to avoid meaningless corpus filled with bos_ids, see https://github.com/jzhang38/TinyLlama/issues/83 for more details 98 | # builder.write_reminder() 99 | 100 | 101 | def prepare( 102 | source_path: Path = Path("data/RedPajama-Data-1T-Sample"), 103 | tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), 104 | destination_path: Path = Path("data/red_pajama_sample"), 105 | chunk_size: int = 2048 * 2049, 106 | split: str="train", 107 | percentage: float = 1.0, 108 | ) -> None: 109 | import time 110 | 111 | filenames = glob.glob(os.path.join(source_path, slimpajama_sets[split]), recursive=True) 112 | filenames = filenames[:int(len(filenames) * percentage)] 113 | 114 | num_processes = 16 # fixed worker count; could also use cpu_count() 115 | chunked_filenames = np.array_split(filenames, num_processes) 116 | 117 | processes = [] 118 | start_time = time.time() 119 | 120 | for i, subset in
enumerate(chunked_filenames): 121 | p = Process(target=prepare_full, args=(source_path, tokenizer_path, destination_path, chunk_size, split, list(subset), i)) 122 | processes.append(p) 123 | p.start() 124 | 125 | for p in processes: 126 | p.join() 127 | end_time = time.time() 128 | elapsed_time = end_time - start_time 129 | print(f"Time taken: {elapsed_time:.2f} seconds") 130 | 131 | 132 | if __name__ == "__main__": 133 | from jsonargparse import CLI 134 | CLI(prepare) 135 | -------------------------------------------------------------------------------- /lit_gpt/fused_cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | # Copyright (c) 2023, Tri Dao. 17 | 18 | import torch 19 | import torch.nn as nn 20 | import xentropy_cuda_lib 21 | 22 | # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for 23 | # `_all_gather_base` and `_reduce_scatter_base`. They require the most recent 24 | # version of PyTorch. The following 2 lines are for backward compatibility with 25 | # older PyTorch. 26 | if "all_gather_into_tensor" not in dir(torch.distributed): 27 | torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base 28 | 29 | 30 | class SoftmaxCrossEntropyLossFn(torch.autograd.Function): 31 | @staticmethod 32 | def forward( 33 | ctx, 34 | logits, 35 | labels, 36 | smoothing=0.0, 37 | ignored_index=-100, 38 | inplace_backward=False, 39 | process_group=None, 40 | ): 41 | """ 42 | logits: (batch, vocab_size) 43 | labels: (batch,) 44 | If process_group is not None, we're doing Tensor Parallel: each process is responsible for 45 | one part of the vocab. The loss needs to be aggregated across processes. 46 | """ 47 | batch, vocab_size = logits.shape 48 | assert labels.shape == (batch,) 49 | world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) 50 | ctx.total_classes = world_size * vocab_size 51 | 52 | if world_size == 1: 53 | losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) 54 | losses.masked_fill_(labels == ignored_index, 0) 55 | labels_local = labels 56 | else: 57 | rank = torch.distributed.get_rank(process_group) 58 | vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size 59 | 60 | # Create a mask of valid vocab ids (1 means it needs to be masked). 61 | labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) 62 | ignored_mask = labels == ignored_index 63 | labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) 64 | 65 | # For tensor parallel cross entropy with smoothing, we want to pass in the total number 66 | # of classes so that smoothing can be applied correctly. If total_classes=-1, use the 67 | # last dimension of the input tensor. 
68 | losses, lse_local = xentropy_cuda_lib.forward( 69 | logits, labels_local, smoothing, world_size * vocab_size 70 | ) 71 | assert lse_local.shape == (batch,) 72 | assert losses.shape == (batch,) 73 | losses.masked_fill_(ignored_mask, 0) 74 | # For labels == ignored_index, the loss is always 0. 75 | # If there's no smoothing, if labels are in the vocab of this partition, losses contains 76 | # lse_local - predicted logit, and 0 otherwise. 77 | # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains 78 | # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) 79 | # For labels not in the vocab of this partition, losses contains 80 | # 0.1 * (lse_local - sum logit / total_classes). 81 | 82 | lse_allgather = torch.empty( 83 | world_size, batch, dtype=lse_local.dtype, device=lse_local.device 84 | ) 85 | torch.distributed.all_gather_into_tensor( 86 | lse_allgather, lse_local.contiguous(), group=process_group 87 | ) 88 | handle_losses = torch.distributed.all_reduce( 89 | losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True 90 | ) 91 | lse = torch.logsumexp(lse_allgather, dim=0) 92 | # If there's no smoothing, the total losses are lse_local - predicted_logit, 93 | # we just have to subtract the lse_local and add the lse (global). 94 | # If there's smoothing=0.1, the total losses are 95 | # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) 96 | # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). 97 | rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor") 98 | lse_local = lse_allgather[ 99 | rank_per_sample, torch.arange(batch, device=lse_allgather.device) 100 | ] 101 | 102 | handle_losses.wait() 103 | if smoothing == 0.0: 104 | losses += lse - lse_local 105 | else: 106 | losses += (1 - smoothing) * (lse - lse_local) + smoothing * ( 107 | lse - lse_allgather.sum(dim=0) 108 | ) 109 | losses.masked_fill_(ignored_mask, 0) 110 | 111 | ctx.save_for_backward(logits, lse, labels_local) 112 | ctx.smoothing = smoothing 113 | ctx.ignored_index = ignored_index 114 | ctx.inplace_backward = inplace_backward 115 | return losses 116 | 117 | @staticmethod 118 | def backward(ctx, grad_loss): 119 | logits, lse, labels = ctx.saved_tensors 120 | grad_loss = grad_loss.contiguous() 121 | grad_loss.masked_fill_(labels == ctx.ignored_index, 0) 122 | grad_logits = xentropy_cuda_lib.backward( 123 | grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes 124 | ) 125 | return grad_logits, None, None, None, None, None, None 126 | 127 | 128 | class FusedCrossEntropyLoss(nn.Module): 129 | def __init__( 130 | self, 131 | ignore_index=-100, 132 | reduction="mean", 133 | label_smoothing=0.0, 134 | inplace_backward=True, 135 | process_group=None, 136 | ): 137 | super().__init__() 138 | if reduction not in ["mean", "none"]: 139 | raise NotImplementedError("Only support reduction = 'mean' or 'none'") 140 | self.ignore_index = ignore_index 141 | self.reduction = reduction 142 | self.label_smoothing = label_smoothing 143 | self.inplace_backward = inplace_backward 144 | self.process_group = process_group 145 | 146 | def forward(self, input, target): 147 | assert input.is_cuda and target.is_cuda 148 | # SoftmaxCrossEntropyLoss implicitly casts to float 149 | if len(input.shape) == 3: 150 | input = input.view(-1, input.size(-1)) 151 | target = target.view(-1) 152 | loss = SoftmaxCrossEntropyLossFn.apply( 153 | input, 154 | target, 
155 | self.label_smoothing, 156 | self.ignore_index, 157 | self.inplace_backward, 158 | self.process_group, 159 | ) 160 | if self.reduction == "mean": 161 | return loss.sum() / (target != self.ignore_index).sum() 162 | else: 163 | return loss -------------------------------------------------------------------------------- /lit_gpt/packed_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import random 18 | import struct 19 | 20 | import numpy as np 21 | import torch 22 | from torch.utils.data import IterableDataset, get_worker_info 23 | 24 | dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float32, 7: np.float64, 8: np.uint16} 25 | 26 | 27 | def code(dtype): 28 | for k in dtypes: 29 | if dtypes[k] == dtype: 30 | return k 31 | raise ValueError(dtype) 32 | 33 | 34 | HDR_MAGIC = b"LITPKDS" 35 | HDR_SIZE = 24 # bytes 36 | 37 | 38 | class PackedDataset(IterableDataset): 39 | def __init__( 40 | self, filenames, n_chunks, block_size, seed=12345, shuffle=True, wrap=False, num_processes=1, process_rank=0 41 | ): 42 | self._filenames = filenames 43 | self._n_chunks = n_chunks 44 | self._block_size = block_size 45 | self._seed = seed 46 | self._shuffle = shuffle 47 | self._wrap = wrap 48 | self._num_processes = num_processes 49 | self._process_rank = process_rank 50 | 51 | def __iter__(self): 52 | worker_info = get_worker_info() 53 | num_workers = worker_info.num_workers if worker_info is not None else 1 54 | worker_id = worker_info.id if worker_info is not None else 0 55 | num_shards = num_workers * self._num_processes 56 | shard_id = self._process_rank * num_workers + worker_id 57 | 58 | max_num_files = len(self._filenames) // num_shards * num_shards 59 | filenames = self._filenames[shard_id:max_num_files:num_shards] 60 | 61 | return PackedDatasetIterator( 62 | filenames=filenames, 63 | n_chunks=self._n_chunks, 64 | block_size=self._block_size, 65 | seed=self._seed, 66 | shuffle=self._shuffle, 67 | wrap=self._wrap, 68 | ) 69 | 70 | 71 | class PackedDatasetBuilder(object): 72 | def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None): 73 | print("++++++++++{}".format(sep_token)) 74 | if dtype == "auto": 75 | if vocab_size is None: 76 | raise ValueError("vocab_size cannot be None when dtype='auto'") 77 | if vocab_size is not None and vocab_size < 65500: 78 | self._dtype = np.uint16 79 | else: 80 | self._dtype = np.int32 81 | else: 82 | self._dtype = dtype 83 | self._counter = 0 84 | self._chunk_size = chunk_size 85 | self._outdir = outdir 86 | self._prefix = prefix 87 | self._sep_token = sep_token 88 | self._arr = np.zeros(self._chunk_size, dtype=self._dtype) 89 | self._arr.fill(self._sep_token) 90 | self._idx = 0 91 | self._version = 1 92 | self._filenames = [] 93 | 94 | def _write_chunk(self): 95 | 
filename = f"{self._prefix}_{self._counter:010d}.bin" 96 | filename = os.path.join(self._outdir, filename) 97 | 98 | with open(filename, "wb") as f: 99 | f.write(HDR_MAGIC) 100 | f.write(struct.pack("<Q", self._version)) 101 | f.write(struct.pack("<B", code(self._dtype))) 102 | f.write(struct.pack("<Q", self._chunk_size)) 103 | f.write(self._arr.tobytes(order="C")) 104 | 105 | self._filenames.append(filename) 106 | self._counter += 1 107 | self._arr.fill(self._sep_token) 108 | self._idx = 0 109 | 110 | @property 111 | def dtype(self): 112 | return self._dtype 113 | 114 | @property 115 | def filenames(self): 116 | return self._filenames.copy() 117 | 118 | def add_array(self, arr): 119 | while self._idx + arr.shape[0] > self._chunk_size: 120 | part_len = self._chunk_size - self._idx 121 | self._arr[self._idx : self._idx + part_len] = arr[:part_len] 122 | self._write_chunk() 123 | arr = arr[part_len:] 124 | 125 | arr_len = arr.shape[0] 126 | self._arr[self._idx : self._idx + arr_len] = arr 127 | self._idx += arr_len 128 | 129 | def write_reminder(self): 130 | self._write_chunk() 131 | 132 | 133 | class PackedDatasetIterator: 134 | def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap): 135 | self._seed = seed 136 | self._shuffle = shuffle 137 | self._rng = np.random.default_rng(seed) if shuffle else None 138 | self._block_idxs = None 139 | 140 | self._wrap = wrap 141 | 142 | # TODO: instead of filenames, we could have a single text stream 143 | # (or text file) with the sequence of all files to be 144 | # fetched/loaded. 145 | self._filenames = filenames 146 | self._file_idx = 0 147 | 148 | self._n_chunks = n_chunks 149 | 150 | self._dtype = None 151 | self._block_size = block_size 152 | self._n_blocks = None 153 | 154 | self._mmaps = [] 155 | self._buffers = [] 156 | 157 | self._block_idxs = [] 158 | self._curr_idx = 0 159 | 160 | self._load_n_chunks() 161 | 162 | def _read_header(self, path): 163 | with open(path, "rb") as f: 164 | magic = f.read(len(HDR_MAGIC)) 165 | assert magic == HDR_MAGIC, "File doesn't match expected format." 166 | version = struct.unpack("<Q", f.read(8)) 167 | assert version == (1,) 168 | (dtype,) = struct.unpack("<B", f.read(1)) 169 | dtype = dtypes[dtype] 170 | (chunk_size,) = struct.unpack("<Q", f.read(8)) 171 | return dtype, chunk_size 172 | 173 | def _close_mmaps(self): 174 | for mmap in self._mmaps: 175 | mmap._mmap.close() 176 | 177 | def _load_n_chunks(self): 178 | self._close_mmaps() 179 | self._mmaps = [] 180 | self._buffers = [] 181 | 182 | if self._n_chunks > len(self._filenames[self._file_idx :]): 183 | # if not self._wrap: 184 | # raise StopIteration 185 | self._file_idx = 0 186 | actual_n_chunks = min(self._n_chunks, len(self._filenames[self._file_idx :])) 187 | for i in range(actual_n_chunks): 188 | filename = self._filenames[self._file_idx + i] 189 | if self._dtype is None: 190 | self._dtype, self._chunk_size = self._read_header(filename) 191 | self._n_blocks = self._chunk_size // self._block_size 192 | # TODO: check header matches with previous files 193 | mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE) 194 | self._mmaps.append(mmap) 195 | self._buffers.append(memoryview(mmap)) 196 | 197 | self._file_idx += actual_n_chunks 198 | n_all_blocks = actual_n_chunks * self._n_blocks 199 | 200 | self._block_idxs = self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks) 201 | 202 | self._curr_idx = 0 203 | 204 | def __del__(self): 205 | self._close_mmaps() 206 | del self._mmaps 207 | del self._buffers 208 | 209 | def __iter__(self): 210 | return self 211 | 212 | def __next__(self): 213 | if self._curr_idx >= len(self._block_idxs): 214 | self._load_n_chunks() 215 | # TODO: trigger fetching next next n_chunks if remote 216 | block_idx = self._block_idxs[self._curr_idx] 217 | chunk_id = block_idx // self._n_blocks 218 | buffer = self._buffers[chunk_id] 219 | elem_id = (block_idx % self._n_blocks) * self._block_size 220 | offset = np.dtype(self._dtype).itemsize * elem_id 221 | arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset) 222 | self._curr_idx += 1 223 | return torch.from_numpy(arr.astype(np.int64)) 224 | 225 | 226 | class CombinedDataset(IterableDataset): 227 | def __init__(self, datasets, seed, weights=None): 228 | self._seed = seed 229 | self._datasets = datasets 230 | self._weights = weights 231 | n_datasets = len(datasets) 232 | if weights is None: 233 | self._weights = [1 / n_datasets] * n_datasets 234 | 235 |
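    # CombinedDataset draws every sample from one of the wrapped datasets at random,
    # with probability given by `weights` (uniform when weights is None), so several
    # PackedDataset streams can be mixed into a single pretraining stream.
    # Illustrative sketch (the file lists and weights below are made-up values):
    #   slim = PackedDataset(slim_filenames, n_chunks=8, block_size=2049)
    #   code = PackedDataset(code_filenames, n_chunks=8, block_size=2049)
    #   train_data = CombinedDataset([slim, code], seed=42, weights=[0.9, 0.1])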
def __iter__(self): 236 | return CombinedDatasetIterator(self._datasets, self._seed, self._weights) 237 | 238 | 239 | class CombinedDatasetIterator: 240 | def __init__(self, datasets, seed, weights): 241 | self._datasets = [iter(el) for el in datasets] 242 | self._weights = weights 243 | self._rng = random.Random(seed) 244 | 245 | def __next__(self): 246 | (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1) 247 | return next(dataset) 248 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AMD-135M 2 | This repository provides the implementation for training AMD-135M models and is based on [TinyLlama](https://github.com/jzhang38/TinyLlama). 3 | 4 | AMD-135M is a language model trained on AMD MI250 GPUs. Based on the LLaMA2 model architecture, it can be loaded directly as LlamaForCausalLM with huggingface transformers. Furthermore, it uses the same tokenizer as LLaMA2, enabling it to serve as a draft model for speculative decoding with LLaMA2 and CodeLlama. 5 | 6 | ### Docker image 7 | Please use the following ROCm docker image from [docker hub](https://hub.docker.com/layers/rocm/pytorch/rocm6.1_ubuntu20.04_py3.9_pytorch_2.3.0_preview/images/sha256-0136f3e678290e0ae78cdd78c90d9f849ee3ac3602864c486e0252f8f8b9662b?context=explore) 8 | 9 | `docker pull rocm/pytorch:rocm6.1_ubuntu20.04_py3.9_pytorch_2.3.0_preview` 10 | 11 | ### Python packages dependency 12 | Please run `pip install -r requirement.txt` to install the extra Python packages required on top of the docker image above. 13 | 14 | ### Dataset 15 | Step 1, download [SlimPajama-627B](https://huggingface.co/datasets/cerebras/SlimPajama-627B), [project gutenberg](https://huggingface.co/datasets/manu/project_gutenberg) and [StarCoder](https://huggingface.co/datasets/bigcode/starcoderdata). 16 | 17 | ```bash 18 | git clone https://huggingface.co/datasets/cerebras/SlimPajama-627B 19 | git clone https://huggingface.co/datasets/manu/project_gutenberg 20 | git clone https://huggingface.co/datasets/bigcode/starcoderdata 21 | ``` 22 | 23 | Step 2, process the text data into token IDs. The processed datasets can then be found at `./slim_processed`, `./slim_validation_processed` and `./starcoderdata_python_processed`. 24 | 25 | ```bash 26 | # For pretraining 27 | bash ./scripts/prepare_slimpajama_train.sh 28 | bash ./scripts/prepare_project_gutenberg.sh 29 | # For validation 30 | bash ./scripts/prepare_slimpajama_valid.sh 31 | # For code finetuning 32 | bash ./scripts/prepare_starcoder_python.sh 33 | ``` 34 | 35 | ### Pretraining 36 | To train a TinyLlama-style model, please run the following scripts on 4 nodes, with 4 MI250 GPUs (8 virtual devices) on each node. 37 | 38 | ```bash 39 | # run on node 0. 40 | bash ./cluster/pretrain_node_0.sh 41 | # run on node 1. 42 | bash ./cluster/pretrain_node_1.sh 43 | # run on node 2. 44 | bash ./cluster/pretrain_node_2.sh 45 | # run on node 3. 46 | bash ./cluster/pretrain_node_3.sh 47 | ``` 48 | 49 | ### Code Finetuning 50 | To finetune the pretrained model on code data, please run the following script. 51 | 52 | ```bash 53 | bash ./cluster/finetune.sh 54 | ``` 55 | 56 | ### Evaluation 57 | We evaluate AMD-Llama-135m using [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) on popular NLP benchmarks, and the results are listed as follows.
58 | 59 | | **Model** | **SciQ** | **WinoGrande** | **PIQA** | **WSC** | **MMLU** | **Lambada (OpenAI)** | **ARC - Easy** | **ARC - Challenge** | **LogiQA** | **Hellaswag** | 60 | |----------------------|---------------|----------------|---------------|---------------|---------------|----------------------|----------------|---------------------|---------------|---------------| 61 | | GPT2-124M (small) | 0.753±0.0136 | 0.5162±0.0140 | 0.6289±0.0113 | 0.4327±0.0488 | 0.2292±0.0383 | 0.3256±0.0065 | 0.4381±0.0102 | 0.1903±0.0115 | 0.2181±0.0162 | 0.2892±0.0045 | 62 | | OPT-125M | 0.751±0.014 | 0.503±0.014 | 0.630±0.011 | 0.365±0.047 | 0.229±0.038 | 0.379±0.007 | 0.436±0.010 | 0.191±0.012 | 0.229±0.016 | 0.292±0.004 | 63 | | JackFram/llama-68m | 0.652±0.0151 | 0.513±0.014 | 0.6197±0.0113 | 0.4038±0.0483 | 0.2302±0.0035 | 0.1351±0.0048 | 0.3864±0.0100 | 0.1792±0.0112 | 0.2273±0.0164 | 0.2790±0.0045 | 64 | | JackFram/llama-160m | 0.724±0.0141 | 0.5012±0.0141 | 0.6605±0.011 | 0.3654±0.0474 | 0.2299±0.0035 | 0.3134±0.0065 | 0.4335±0.0102 | 0.1980±0.0116 | 0.2197±0.0162 | 0.3094±0.0046 | 65 | | [AMD-Llama-135m](https://huggingface.co/amd/AMD-Llama-135m) | 0.761±0.0135 | 0.5012±0.0141 | 0.6420±0.0112 | 0.3654±0.0474 | 0.2302±0.0035 | 0.3330±0.0066 | 0.4364±0.0102 | 0.1911±0.0115 | 0.2120±0.0160 | 0.3048±0.0046 | 66 | 67 | 68 | ### Speculative Decoding 69 | To run speculative decoding using AMD-Llama-135m-code as the draft model for CodeLlama-7b on the [Humaneval](https://huggingface.co/datasets/openai_humaneval) dataset, please run the following script. 70 | 71 | ```bash 72 | # Patch huggingface transformers==4.37.2 to add the logging needed to calculate the acceptance rate of speculative decoding. 73 | patch -u /path/to/transformers/generation/utils.py -i ./speculative_decoding/utils.patch 74 | bash ./speculative_decoding/codellama_spec.sh 75 | ``` 76 | 77 | We evaluate the performance of decoding with the target model only versus speculative decoding, on an MI250 GPU and a Ryzen AI CPU (with NPU kernel). All experiments are run on the Humaneval dataset. A minimal Python sketch of the draft-model setup is shown after the cost summary below. 78 | 79 | | Target Model Device | Draft Model Device | Random Sampling | Target Model Humaneval Pass@1 | Speculative Decoding Humaneval Pass@1 | Acceptance Rate | Throughput Speedup | 80 | |:----------------------|:---------------------|:-----------------------|-------------------------------:|---------------------------------------:|----------------:|-------------------:| 81 | | FP32 MI250 | FP32 MI250 | TRUE | 32.31% | 29.27% | 0.650355 | 2.58x | 82 | | FP32 MI250 | FP32 MI250 | FALSE | 31.10% | 31.10% | 0.657839 | **2.80x** | 83 | | BF16 MI250 | BF16 MI250 | TRUE | 31.10% | 31.10% | 0.668822 | 1.67x | 84 | | BF16 MI250 | BF16 MI250 | FALSE | 34.15% | 33.54% | 0.665497 | 1.75x | 85 | | INT4 NPU | BF16 CPU | TRUE | 28.05% | 30.49% | 0.722913 | 2.83x | 86 | | INT4 NPU | BF16 CPU | FALSE | 28.66% | 28.66% | 0.738072 | **2.98x** | 87 | | BF16 CPU | BF16 CPU | TRUE | 31.10% | 31.71% | 0.723971 | 3.68x | 88 | | BF16 CPU | BF16 CPU | FALSE | 33.54% | 33.54% | 0.727548 | **3.88x** | 89 | | FP32 CPU | FP32 CPU | TRUE | 29.87% | 28.05% | 0.727214 | 3.57x | 90 | | FP32 CPU | FP32 CPU | FALSE | 31.10% | 31.10% | 0.738641 | 3.66x | 91 | 92 | 93 | ### Training and finetuning cost 94 | It takes 6 days to pretrain AMD-Llama-135m on 4 MI250 nodes, each of which has 4 MI250 GPUs (8 virtual GPU cards with 64 GB memory each). 95 | It takes 4 days to finetune AMD-Llama-135m-code on 4 MI250 GPUs. 96 | It takes 11 TB of disk space to store the raw and processed SlimPajama, project gutenberg and Starcoder datasets.
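For reference, below is a minimal sketch of the speculative decoding setup described in the Speculative Decoding section, using the assisted-generation API of huggingface transformers==4.37.2 (the `assistant_model` argument of `generate`). The model IDs and the prompt are illustrative assumptions; `./speculative_decoding/codellama_spec.sh` remains the actual entry point used for the results above.

```python
# Sketch only: CodeLlama-7b as the target model, AMD-Llama-135m-code as the draft model.
from transformers import AutoModelForCausalLM, AutoTokenizer

# Both models share the LLaMA2 tokenizer, which is what makes the small model usable as a draft.
tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
target = AutoModelForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf")
draft = AutoModelForCausalLM.from_pretrained("amd/AMD-Llama-135m-code")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")

# Passing `assistant_model` enables speculative (assisted) decoding:
# the draft model proposes several tokens per step and the target model verifies them.
outputs = target.generate(**inputs, assistant_model=draft, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```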
97 | 98 | 99 | #### ROCM 100 | ``` 101 | Version: 6.1.2.60102-119~20.04 102 | Priority: optional 103 | Section: devel 104 | Maintainer: ROCm Dev Support 105 | Installed-Size: 13.3 kB 106 | Depends: hipblas (= 2.1.0.60102-119~20.04), hipblaslt (= 0.7.0.60102-119~20.04), hipfft (= 1.0.14.60102-119~20.04), hipsolver (= 2.1.1.60102-119~20.04), hipsparse (= 3.0.1.60102-119~20.04), hiptensor (= 1.2.0.60102-119~20.04), miopen-hip (= 3.1.0.60102-119~20.04), half (= 1.12.0.60102-119~20.04), rccl (= 2.18.6.60102-119~20.04), rocalution (= 3.1.1.60102-119~20.04), rocblas (= 4.1.2.60102-119~20.04), rocfft (= 1.0.27.60102-119~20.04), rocrand (= 3.0.1.60102-119~20.04), hiprand (= 2.10.16.60102-119~20.04), rocsolver (= 3.25.0.60102-119~20.04), rocsparse (= 3.1.2.60102-119~20.04), rocm-core (= 6.1.2.60102-119~20.04), hipsparselt (= 0.2.0.60102-119~20.04), composablekernel-dev (= 1.1.0.60102-119~20.04), hipblas-dev (= 2.1.0.60102-119~20.04), hipblaslt-dev (= 0.7.0.60102-119~20.04), hipcub-dev (= 3.1.0.60102-119~20.04), hipfft-dev (= 1.0.14.60102-119~20.04), hipsolver-dev (= 2.1.1.60102-119~20.04), hipsparse-dev (= 3.0.1.60102-119~20.04), hiptensor-dev (= 1.2.0.60102-119~20.04), miopen-hip-dev (= 3.1.0.60102-119~20.04), rccl-dev (= 2.18.6.60102-119~20.04), rocalution-dev (= 3.1.1.60102-119~20.04), rocblas-dev (= 4.1.2.60102-119~20.04), rocfft-dev (= 1.0.27.60102-119~20.04), rocprim-dev (= 3.1.0.60102-119~20.04), rocrand-dev (= 3.0.1.60102-119~20.04), hiprand-dev (= 2.10.16.60102-119~20.04), rocsolver-dev (= 3.25.0.60102-119~20.04), rocsparse-dev (= 3.1.2.60102-119~20.04), rocthrust-dev (= 3.0.1.60102-119~20.04), rocwmma-dev (= 1.4.0.60102-119~20.04), hipsparselt-dev (= 0.2.0.60102-119~20.04) 107 | Homepage: 108 | https://github.com/RadeonOpenCompute/ROCm 109 | Download-Size: 1064 B 110 | APT-Manual-Installed: yes 111 | APT-Sources: 112 | http://repo.radeon.com/rocm/apt/6.1.2 113 | focal/main amd64 Packages 114 | Description: Radeon Open Compute (ROCm) Runtime software stack 115 | ``` 116 | ### System info 117 | ``` 118 | Ubuntu 22.04.3 LTS 119 | Release: 22.04 120 | Codename: jammy 121 | 122 | Linux version 5.15.0-88-generic (buildd@lcy02-amd64-058) (gcc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0, GNU ld (GNU Binutils for Ubuntu) 2.38) #98-Ubuntu SMP Mon Oct 2 15:18:56 UTC 2023 123 | 124 | Linux sjc144-canary-node035.dcgpu.amd.com 5.15.0-88-generic #98-Ubuntu SMP Mon Oct 2 15:18:56 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux 125 | ``` 126 | #### License 127 | Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 128 | 129 | Licensed under the Apache License, Version 2.0 (the "License"); 130 | you may not use this file except in compliance with the License. 131 | You may obtain a copy of the License at 132 | 133 | http://www.apache.org/licenses/LICENSE-2.0 134 | 135 | Unless required by applicable law or agreed to in writing, software 136 | distributed under the License is distributed on an "AS IS" BASIS, 137 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 138 | See the License for the specific language governing permissions and 139 | limitations under the License. 140 | -------------------------------------------------------------------------------- /scripts/convert_hf_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import contextlib 17 | import gc 18 | import json 19 | import sys 20 | from functools import partial 21 | from pathlib import Path 22 | from typing import Dict, List, Literal, Optional, Tuple, Union 23 | 24 | import torch 25 | 26 | # support running without installing as a package 27 | wd = Path(__file__).parent.parent.resolve() 28 | sys.path.append(str(wd)) 29 | 30 | from lit_gpt import Config 31 | from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load 32 | 33 | 34 | def copy_weights_gpt_neox( 35 | state_dict: Dict[str, torch.Tensor], 36 | hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 37 | saver: Optional[incremental_save] = None, 38 | dtype: Optional[torch.dtype] = None, 39 | ) -> None: 40 | weight_map = { 41 | "gpt_neox.embed_in.weight": "transformer.wte.weight", 42 | "gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", 43 | "gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", 44 | "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.attn.bias", 45 | "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", 46 | "gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias", 47 | "gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight", 48 | "gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None, 49 | "gpt_neox.layers.{}.attention.bias": None, 50 | "gpt_neox.layers.{}.attention.masked_bias": None, 51 | "gpt_neox.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", 52 | "gpt_neox.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", 53 | "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias": "transformer.h.{}.mlp.fc.bias", 54 | "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", 55 | "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias": "transformer.h.{}.mlp.proj.bias", 56 | "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", 57 | "gpt_neox.final_layer_norm.bias": "transformer.ln_f.bias", 58 | "gpt_neox.final_layer_norm.weight": "transformer.ln_f.weight", 59 | "embed_out.weight": "lm_head.weight", 60 | } 61 | 62 | for name, param in hf_weights.items(): 63 | if "gpt_neox.layers" in name: 64 | from_name, number = layer_template(name, 2) 65 | to_name = weight_map[from_name] 66 | if to_name is None: 67 | continue 68 | to_name = to_name.format(number) 69 | else: 70 | to_name = weight_map[name] 71 | param = load_param(param, name, dtype) 72 | if saver is not None: 73 | param = saver.store_early(param) 74 | state_dict[to_name] = param 75 | 76 | 77 | def copy_weights_falcon( 78 | size: Literal["7b", "40b"], 79 | state_dict: Dict[str, torch.Tensor], 80 | hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 81 | saver: Optional[incremental_save] = None, 82 | dtype: Optional[torch.dtype] = None, 83 | ) -> None: 84 | weight_map 
= { 85 | "transformer.word_embeddings.weight": "transformer.wte.weight", 86 | "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", 87 | "transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight", 88 | "transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", 89 | "transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", 90 | "transformer.ln_f.bias": "transformer.ln_f.bias", 91 | "transformer.ln_f.weight": "transformer.ln_f.weight", 92 | "lm_head.weight": "lm_head.weight", 93 | } 94 | # the original model definition is different for each size 95 | if size == "7b": 96 | weight_map.update( 97 | { 98 | "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", 99 | "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", 100 | } 101 | ) 102 | elif size == "40b": 103 | weight_map.update( 104 | { 105 | "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias", 106 | "transformer.h.{}.ln_attn.weight": "transformer.h.{}.norm_1.weight", 107 | "transformer.h.{}.ln_mlp.bias": "transformer.h.{}.norm_2.bias", 108 | "transformer.h.{}.ln_mlp.weight": "transformer.h.{}.norm_2.weight", 109 | } 110 | ) 111 | else: 112 | raise NotImplementedError 113 | 114 | for name, param in hf_weights.items(): 115 | if "transformer.h" in name: 116 | from_name, number = layer_template(name, 2) 117 | to_name = weight_map[from_name].format(number) 118 | else: 119 | to_name = weight_map[name] 120 | param = load_param(param, name, dtype) 121 | if saver is not None: 122 | param = saver.store_early(param) 123 | state_dict[to_name] = param 124 | 125 | 126 | def copy_weights_hf_llama( 127 | config: Config, 128 | qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], 129 | state_dict: Dict[str, torch.Tensor], 130 | hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 131 | saver: Optional[incremental_save] = None, 132 | dtype: Optional[torch.dtype] = None, 133 | ) -> None: 134 | weight_map = { 135 | "model.embed_tokens.weight": "transformer.wte.weight", 136 | "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", 137 | "model.layers.{}.self_attn.q_proj.weight": None, 138 | "model.layers.{}.self_attn.k_proj.weight": None, 139 | "model.layers.{}.self_attn.v_proj.weight": None, 140 | "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", 141 | "model.layers.{}.self_attn.rotary_emb.inv_freq": None, 142 | "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", 143 | "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.swiglu.w1.weight", 144 | "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.swiglu.w2.weight", 145 | "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.swiglu.w3.weight", 146 | "model.norm.weight": "transformer.ln_f.weight", 147 | "lm_head.weight": "lm_head.weight", 148 | } 149 | 150 | for name, param in hf_weights.items(): 151 | if "model.layers" in name: 152 | from_name, number = layer_template(name, 2) 153 | qkv = qkv_weights.setdefault(number, [None, None, None]) 154 | if "q_proj" in name: 155 | qkv[0] = param 156 | elif "k_proj" in name: 157 | qkv[1] = param 158 | elif "v_proj" in name: 159 | qkv[2] = param 160 | to_name = weight_map[from_name] 161 | if to_name is None: 162 | continue 163 | to_name = to_name.format(number) 164 | else: 165 | to_name = weight_map[name] 166 | param = load_param(param, name, dtype) 167 | if saver is not None: 168 | param = 
saver.store_early(param) 169 | state_dict[to_name] = param 170 | 171 | for i, (q, k, v) in list(qkv_weights.items()): 172 | if q is None or k is None or v is None: 173 | # split across different .bin files 174 | continue 175 | q = load_param(q, f"layer {i} q", dtype) 176 | k = load_param(k, f"layer {i} k", dtype) 177 | v = load_param(v, f"layer {i} v", dtype) 178 | q_per_kv = config.n_head // config.n_query_groups 179 | qs = torch.split(q, config.head_size * q_per_kv) 180 | ks = torch.split(k, config.head_size) 181 | vs = torch.split(v, config.head_size) 182 | cycled = [t for group in zip(qs, ks, vs) for t in group] 183 | qkv = torch.cat(cycled) 184 | state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv 185 | del qkv_weights[i] 186 | 187 | 188 | def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: 189 | split = layer_name.split(".") 190 | number = int(split[idx]) 191 | split[idx] = "{}" 192 | from_name = ".".join(split) 193 | return from_name, number 194 | 195 | 196 | def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor: 197 | if hasattr(param, "_load_tensor"): 198 | # support tensors loaded via `lazy_load()` 199 | print(f"Loading {name!r} into RAM") 200 | param = param._load_tensor() 201 | if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: 202 | print(f"Converting {name!r} from {param.dtype} to {dtype}") 203 | param = param.to(dtype) 204 | return param 205 | 206 | 207 | @torch.inference_mode() 208 | def convert_hf_checkpoint( 209 | *, 210 | checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), 211 | model_name: Optional[str] = None, 212 | dtype: Optional[str] = None, 213 | ) -> None: 214 | if model_name is None: 215 | model_name = checkpoint_dir.name 216 | if dtype is not None: 217 | dtype = getattr(torch, dtype) 218 | 219 | config = Config.from_name(model_name) 220 | print(f"Model config {config.__dict__}") 221 | with open(checkpoint_dir / "lit_config.json", "w") as json_config: 222 | json.dump(config.__dict__, json_config) 223 | 224 | if "falcon" in model_name: 225 | copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b") 226 | elif config._mlp_class == "LLaMAMLP": 227 | # holder to reconstitute the split q, k, v 228 | qkv_weights = {} 229 | copy_fn = partial(copy_weights_hf_llama, config, qkv_weights) 230 | else: 231 | copy_fn = copy_weights_gpt_neox 232 | 233 | # initialize a new empty state dict to hold our new weights 234 | sd = {} 235 | 236 | # Load the json file containing weight mapping 237 | pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json" 238 | if pytorch_bin_map_json_path.is_file(): # not all checkpoints have this file 239 | with open(pytorch_bin_map_json_path) as json_map: 240 | bin_index = json.load(json_map) 241 | bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()} 242 | else: 243 | bin_files = set(checkpoint_dir.glob("*.bin")) 244 | if not bin_files: 245 | raise ValueError(f"Expected {str(checkpoint_dir)!r} to contain .bin files") 246 | 247 | with incremental_save(checkpoint_dir / "lit_model.pth") as saver: 248 | # for checkpoints that split the QKV across several files, we need to keep all the bin files 249 | # open, so we use `ExitStack` to close them all together at the end 250 | with contextlib.ExitStack() as stack: 251 | for bin_file in sorted(bin_files): 252 | print("Processing", bin_file) 253 | hf_weights = 
stack.enter_context(lazy_load(bin_file)) 254 | copy_fn(sd, hf_weights, saver=None, dtype=dtype) 255 | gc.collect() 256 | print("Saving converted checkpoint") 257 | saver.save(sd) 258 | 259 | 260 | if __name__ == "__main__": 261 | from jsonargparse import CLI 262 | 263 | CLI(convert_hf_checkpoint) 264 | -------------------------------------------------------------------------------- /lit_gpt/adapter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """Implementation of the paper: 17 | 18 | LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention 19 | https://arxiv.org/abs/2303.16199 20 | 21 | Port for Lit-GPT 22 | """ 23 | from dataclasses import dataclass 24 | from typing import Any, Dict, List, Optional, Tuple, Union 25 | 26 | import torch 27 | import torch.nn as nn 28 | from typing_extensions import Self 29 | 30 | from lit_gpt.config import Config as BaseConfig 31 | from lit_gpt.model import GPT as BaseModel 32 | from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention 33 | from lit_gpt.model import KVCache, RoPECache, apply_rope 34 | 35 | 36 | @dataclass 37 | class Config(BaseConfig): 38 | adapter_prompt_length: int = 10 39 | adapter_start_layer: int = 2 40 | 41 | 42 | class GPT(BaseModel): 43 | """The implementation is identical to `lit_gpt.model.GPT` with the exception that 44 | the `Block` saves the layer index and passes it down to the attention layer.""" 45 | 46 | def __init__(self, config: Config) -> None: 47 | nn.Module.__init__(self) 48 | assert config.padded_vocab_size is not None 49 | self.config = config 50 | 51 | self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) 52 | self.transformer = nn.ModuleDict( 53 | dict( 54 | wte=nn.Embedding(config.padded_vocab_size, config.n_embd), 55 | h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), 56 | ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), 57 | ) 58 | ) 59 | 60 | self.rope_cache: Optional[RoPECache] = None 61 | self.mask_cache: Optional[torch.Tensor] = None 62 | self.kv_caches: List[KVCache] = [] 63 | self.adapter_kv_caches: List[KVCache] = [] 64 | 65 | def reset_cache(self) -> None: 66 | super().reset_cache() 67 | self.adapter_kv_caches.clear() 68 | 69 | def forward( 70 | self, 71 | idx: torch.Tensor, 72 | max_seq_length: Optional[int] = None, 73 | input_pos: Optional[torch.Tensor] = None, 74 | lm_head_chunk_size: int = 0, 75 | ) -> Union[torch.Tensor, List[torch.Tensor]]: 76 | B, T = idx.size() 77 | use_kv_cache = input_pos is not None 78 | 79 | block_size = self.config.block_size 80 | if max_seq_length is None: 81 | max_seq_length = block_size 82 | if use_kv_cache: # not relevant otherwise 83 | assert ( 84 | max_seq_length >= T 85 | ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" 86 | assert 
max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" 87 | assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}" 88 | 89 | if self.rope_cache is None: 90 | self.rope_cache = self.build_rope_cache(idx) 91 | # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask 92 | # for the kv-cache support (only during inference), we only create it in that situation 93 | # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 94 | if use_kv_cache and self.mask_cache is None: 95 | self.mask_cache = self.build_mask_cache(idx) 96 | 97 | cos, sin = self.rope_cache 98 | if use_kv_cache: 99 | cos = cos.index_select(0, input_pos) 100 | sin = sin.index_select(0, input_pos) 101 | mask = self.mask_cache.index_select(2, input_pos) 102 | mask = mask[:, :, :, :max_seq_length] 103 | else: 104 | cos = cos[:T] 105 | sin = sin[:T] 106 | mask = None 107 | 108 | # forward the model itself 109 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 110 | 111 | if not use_kv_cache: 112 | for block in self.transformer.h: 113 | x, *_ = block(x, (cos, sin), max_seq_length) 114 | else: 115 | self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1)) 116 | self.adapter_kv_caches = self.adapter_kv_caches or [None for _ in range(self.config.n_layer)] 117 | for i, block in enumerate(self.transformer.h): 118 | x, self.kv_caches[i], self.adapter_kv_caches[i] = block( 119 | x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i], self.adapter_kv_caches[i] 120 | ) 121 | 122 | x = self.transformer.ln_f(x) 123 | 124 | if lm_head_chunk_size > 0: 125 | # chunk the lm head logits to reduce the peak memory used by autograd 126 | return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] 127 | return self.lm_head(x) # (b, t, vocab_size) 128 | 129 | @classmethod 130 | def from_name(cls, name: str, **kwargs: Any) -> Self: 131 | return cls(Config.from_name(name, **kwargs)) 132 | 133 | def _init_weights(self, module: nn.Module) -> None: 134 | """Meant to be used with `gpt.apply(gpt._init_weights)`. 
Unused method left for completeness.""" 135 | super()._init_weights(module) 136 | if isinstance(module, CausalSelfAttention): 137 | module.reset_parameters() 138 | 139 | 140 | class Block(nn.Module): 141 | """The implementation is identical to `lit_gpt.model.Block` with the exception that 142 | we replace the attention layer where adaption is implemented.""" 143 | 144 | def __init__(self, config: Config, block_idx: int) -> None: 145 | super().__init__() 146 | self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) 147 | self.attn = CausalSelfAttention(config, block_idx) 148 | if not config.shared_attention_norm: 149 | self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) 150 | self.mlp = config.mlp_class(config) 151 | 152 | self.config = config 153 | 154 | def forward( 155 | self, 156 | x: torch.Tensor, 157 | rope: RoPECache, 158 | max_seq_length: int, 159 | mask: Optional[torch.Tensor] = None, 160 | input_pos: Optional[torch.Tensor] = None, 161 | kv_cache: Optional[KVCache] = None, 162 | adapter_kv_cache: Optional[KVCache] = None, 163 | ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: 164 | n_1 = self.norm_1(x) 165 | h, new_kv_cache, new_adapter_kv_cache = self.attn( 166 | n_1, rope, max_seq_length, mask, input_pos, kv_cache, adapter_kv_cache 167 | ) 168 | if self.config.parallel_residual: 169 | n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) 170 | x = x + h + self.mlp(n_2) 171 | else: 172 | if self.config.shared_attention_norm: 173 | raise NotImplementedError( 174 | "No checkpoint amongst the ones we support uses this configuration" 175 | " (non-parallel residual and shared attention norm)." 176 | ) 177 | x = x + h 178 | x = x + self.mlp(self.norm_2(x)) 179 | return x, new_kv_cache, new_adapter_kv_cache 180 | 181 | 182 | class CausalSelfAttention(BaseCausalSelfAttention): 183 | """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention 184 | over the adaption prompt.""" 185 | 186 | def __init__(self, config: Config, block_idx: int) -> None: 187 | super().__init__(config) 188 | if block_idx >= config.adapter_start_layer: 189 | # adapter embedding layer 190 | self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) 191 | # gate for adaption 192 | self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) 193 | self.reset_parameters() 194 | self.block_idx = block_idx 195 | 196 | def forward( 197 | self, 198 | x: torch.Tensor, 199 | rope: RoPECache, 200 | max_seq_length: int, 201 | mask: Optional[torch.Tensor] = None, 202 | input_pos: Optional[torch.Tensor] = None, 203 | kv_cache: Optional[KVCache] = None, 204 | adapter_kv_cache: Optional[KVCache] = None, 205 | ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: 206 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 207 | 208 | qkv = self.attn(x) 209 | 210 | # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) 211 | q_per_kv = self.config.n_head // self.config.n_query_groups 212 | total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value 213 | qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) 214 | qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) 215 | 216 | # split batched computation into three 217 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) 218 | 219 | # repeat k and v if necessary 220 | if self.config.n_query_groups != 1: # doing this would 
require a full kv cache with MQA (inefficient!) 221 | # for MHA this is a no-op 222 | k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 223 | v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 224 | 225 | q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) 226 | k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) 227 | v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) 228 | 229 | n_elem = int(self.config.rotary_percentage * self.config.head_size) 230 | 231 | cos, sin = rope 232 | q_roped = apply_rope(q[..., :n_elem], cos, sin) 233 | k_roped = apply_rope(k[..., :n_elem], cos, sin) 234 | q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) 235 | k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) 236 | 237 | if kv_cache is not None: 238 | cache_k, cache_v = kv_cache 239 | cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) 240 | # check if reached token limit 241 | if input_pos[-1] >= max_seq_length: 242 | input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) 243 | # shift 1 position to the left 244 | cache_k = torch.roll(cache_k, -1, dims=2) 245 | cache_v = torch.roll(cache_v, -1, dims=2) 246 | k = cache_k.index_copy_(2, input_pos, k) 247 | v = cache_v.index_copy_(2, input_pos, v) 248 | kv_cache = k, v 249 | 250 | y = self.scaled_dot_product_attention(q, k, v, mask=mask) 251 | 252 | if self.block_idx >= self.config.adapter_start_layer: 253 | aT = self.config.adapter_prompt_length 254 | if adapter_kv_cache is not None: 255 | ak, av = adapter_kv_cache 256 | else: 257 | prefix = self.adapter_wte.weight.reshape(1, aT, C) 258 | aqkv = self.attn(prefix) 259 | aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) 260 | aqkv = aqkv.permute(0, 2, 3, 1, 4) 261 | _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) 262 | if self.config.n_query_groups != 1: 263 | # for MHA this is a no-op 264 | ak = ak.repeat_interleave(q_per_kv, dim=2) 265 | av = av.repeat_interleave(q_per_kv, dim=2) 266 | ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) 267 | av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) 268 | adapter_kv_cache = (ak, av) 269 | 270 | amask = torch.ones(T, aT, dtype=torch.bool, device=x.device) 271 | ay = self.scaled_dot_product_attention(q, ak, av, amask) 272 | y = y + self.gating_factor * ay 273 | 274 | y = y.reshape(B, T, C) # re-assemble all head outputs side by side 275 | 276 | # output projection 277 | y = self.proj(y) 278 | 279 | return y, kv_cache, adapter_kv_cache 280 | 281 | def reset_parameters(self) -> None: 282 | torch.nn.init.zeros_(self.gating_factor) 283 | 284 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 285 | """For compatibility with older checkpoints.""" 286 | if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: 287 | state_dict[key] = state_dict[key].permute(0, 2, 1, 3) 288 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 289 | 290 | 291 | def mark_only_adapter_as_trainable(model: GPT) -> None: 292 | """Sets `requires_grad=False` for all non-adapter weights.""" 293 | for name, param in model.named_parameters(): 294 | param.requires_grad = adapter_filter(name, param) 295 | 296 | 297 | def adapter_filter(key: str, value: Any) -> bool: 298 | return "adapter_wte" in key or "gating_factor" in key 299 | 
-------------------------------------------------------------------------------- /lit_gpt/adapter_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """Implementation of the paper: 17 | 18 | LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model 19 | https://arxiv.org/abs/2304.15010 20 | 21 | Port for Lit-GPT 22 | """ 23 | from dataclasses import dataclass 24 | from typing import Any, Dict, List, Optional, Tuple, Type 25 | 26 | import torch 27 | import torch.nn as nn 28 | from typing_extensions import Self 29 | 30 | import lit_gpt 31 | from lit_gpt.adapter import GPT as BaseModel 32 | from lit_gpt.adapter import Block as BaseBlock 33 | from lit_gpt.adapter import Config as BaseConfig 34 | from lit_gpt.adapter import KVCache, RoPECache 35 | from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention 36 | from lit_gpt.model import apply_rope 37 | from lit_gpt.utils import map_old_state_dict_weights 38 | 39 | 40 | @dataclass 41 | class Config(BaseConfig): 42 | @property 43 | def mlp_class(self) -> Type: 44 | return getattr(lit_gpt.adapter_v2, self._mlp_class) 45 | 46 | 47 | def adapter_filter(key: str, value: Any) -> bool: 48 | adapter_substrings = ( 49 | # regular adapter v1 parameters 50 | "adapter_wte", 51 | "gating_factor", 52 | # adapter v2: new bias and scale used in Linear 53 | "adapter_scale", 54 | "adapter_bias", 55 | # adapter v2: Norm parameters are now trainable 56 | "norm_1", 57 | "norm_2", 58 | "ln_f", 59 | ) 60 | return any(s in key for s in adapter_substrings) 61 | 62 | 63 | class AdapterV2Linear(torch.nn.Module): 64 | def __init__(self, in_features: int, out_features: int, **kwargs) -> None: 65 | super().__init__() 66 | self.linear = torch.nn.Linear(in_features, out_features, **kwargs) 67 | self.adapter_bias = torch.nn.Parameter(torch.zeros(out_features), requires_grad=False) 68 | self.adapter_scale = torch.nn.Parameter(torch.ones(out_features), requires_grad=False) 69 | self.reset_parameters() 70 | 71 | def forward(self, x: torch.Tensor) -> torch.Tensor: 72 | return self.adapter_scale * (self.linear(x) + self.adapter_bias) 73 | 74 | def reset_parameters(self) -> None: 75 | nn.init.zeros_(self.adapter_bias) 76 | nn.init.ones_(self.adapter_scale) 77 | 78 | 79 | class GPT(BaseModel): 80 | def __init__(self, config: Config) -> None: 81 | # Skip the parent class __init__ altogether and replace it to avoid useless allocations 82 | nn.Module.__init__(self) 83 | assert config.padded_vocab_size is not None 84 | self.config = config 85 | 86 | self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=False) 87 | self.transformer = nn.ModuleDict( 88 | dict( 89 | wte=nn.Embedding(config.padded_vocab_size, config.n_embd), 90 | h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), 91 | ln_f=config.norm_class(config.n_embd, 
eps=config.norm_eps), 92 | ) 93 | ) 94 | 95 | self.rope_cache: Optional[RoPECache] = None 96 | self.mask_cache: Optional[torch.Tensor] = None 97 | self.kv_caches: List[KVCache] = [] 98 | self.adapter_kv_caches: List[KVCache] = [] 99 | 100 | @classmethod 101 | def from_name(cls, name: str, **kwargs: Any) -> Self: 102 | return cls(Config.from_name(name, **kwargs)) 103 | 104 | def _init_weights(self, module: nn.Module) -> None: 105 | """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" 106 | super()._init_weights(module) 107 | if isinstance(module, CausalSelfAttention): 108 | module.reset_parameters() 109 | if isinstance(module, AdapterV2Linear): 110 | module.reset_parameters() 111 | 112 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 113 | """For compatibility with base checkpoints.""" 114 | mapping = {"lm_head.weight": "lm_head.linear.weight"} 115 | state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 116 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 117 | 118 | 119 | class Block(BaseBlock): 120 | """The implementation is identical to `lit_gpt.model.Block` with the exception that 121 | we replace the attention layer where adaption is implemented.""" 122 | 123 | def __init__(self, config: Config, block_idx: int) -> None: 124 | # Skip the parent class __init__ altogether and replace it to avoid useless allocations 125 | nn.Module.__init__(self) 126 | self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) 127 | self.attn = CausalSelfAttention(config, block_idx) 128 | if not config.shared_attention_norm: 129 | self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) 130 | self.mlp = config.mlp_class(config) 131 | 132 | self.config = config 133 | 134 | 135 | class CausalSelfAttention(BaseCausalSelfAttention): 136 | def __init__(self, config: Config, block_idx: int) -> None: 137 | """Causal self-attention with calculating qkv matrices with a single matrix* and Low Ranking Adaptation for 138 | parameter-efficient fine-tuning. 139 | 140 | *Instead of creating multiple heads and concatenating the result (in addition to creating separate matrices for 141 | query, key and value for each head) we can do this in a single pass with a single weight matrix. 
142 | """ 143 | # Skip the parent class __init__ altogether and replace it to avoid useless allocations 144 | nn.Module.__init__(self) 145 | shape = (config.n_head + 2 * config.n_query_groups) * config.head_size 146 | # key, query, value projections for all heads, but in a batch 147 | self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias) 148 | # output projection 149 | self.proj = AdapterV2Linear(config.n_embd, config.n_embd, bias=config.bias) 150 | if block_idx >= config.adapter_start_layer: 151 | # adapter embedding layer 152 | self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) 153 | # gate for adaption 154 | self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) 155 | self.reset_parameters() 156 | self.block_idx = block_idx 157 | 158 | self.config = config 159 | 160 | def forward( 161 | self, 162 | x: torch.Tensor, 163 | rope: RoPECache, 164 | max_seq_length: int, 165 | mask: Optional[torch.Tensor] = None, 166 | input_pos: Optional[torch.Tensor] = None, 167 | kv_cache: Optional[KVCache] = None, 168 | adapter_kv_cache: Optional[KVCache] = None, 169 | ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[KVCache]]: 170 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 171 | 172 | qkv = self.attn(x) 173 | 174 | # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) 175 | q_per_kv = self.config.n_head // self.config.n_query_groups 176 | total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value 177 | qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) 178 | qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) 179 | 180 | # split batched computation into three 181 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) 182 | 183 | # repeat k and v if necessary 184 | if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!) 
185 | # for MHA this is a no-op 186 | k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 187 | v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 188 | 189 | q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) 190 | k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) 191 | v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) 192 | 193 | n_elem = int(self.config.rotary_percentage * self.config.head_size) 194 | 195 | cos, sin = rope 196 | q_roped = apply_rope(q[..., :n_elem], cos, sin) 197 | k_roped = apply_rope(k[..., :n_elem], cos, sin) 198 | q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) 199 | k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) 200 | 201 | if kv_cache is not None: 202 | cache_k, cache_v = kv_cache 203 | cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) 204 | # check if reached token limit 205 | if input_pos[-1] >= max_seq_length: 206 | input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) 207 | # shift 1 position to the left 208 | cache_k = torch.roll(cache_k, -1, dims=2) 209 | cache_v = torch.roll(cache_v, -1, dims=2) 210 | k = cache_k.index_copy_(2, input_pos, k) 211 | v = cache_v.index_copy_(2, input_pos, v) 212 | kv_cache = k, v 213 | 214 | y = self.scaled_dot_product_attention(q, k, v, mask=mask) 215 | 216 | if self.block_idx >= self.config.adapter_start_layer: 217 | aT = self.config.adapter_prompt_length 218 | if adapter_kv_cache is not None: 219 | ak, av = adapter_kv_cache 220 | else: 221 | prefix = self.adapter_wte.weight.reshape(1, aT, C) 222 | aqkv = self.attn(prefix) 223 | aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) 224 | aqkv = aqkv.permute(0, 2, 3, 1, 4) 225 | _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) 226 | if self.config.n_query_groups != 1: 227 | # for MHA this is a no-op 228 | ak = ak.repeat_interleave(q_per_kv, dim=2) 229 | av = av.repeat_interleave(q_per_kv, dim=2) 230 | ak = ak.view(1, -1, aT, self.config.head_size) # (1, nh_ak, aT, hs) 231 | av = av.view(1, -1, aT, self.config.head_size) # (1, nh_av, aT, hs) 232 | adapter_kv_cache = (ak, av) 233 | 234 | amask = torch.ones(T, aT, dtype=torch.bool, device=x.device) 235 | ay = self.scaled_dot_product_attention(q, ak, av, amask) 236 | y = y + self.gating_factor * ay 237 | 238 | y = y.reshape(B, T, C) # re-assemble all head outputs side by side 239 | 240 | # output projection 241 | y = self.proj(y) 242 | 243 | return y, kv_cache, adapter_kv_cache 244 | 245 | def reset_parameters(self) -> None: 246 | torch.nn.init.zeros_(self.gating_factor) 247 | 248 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 249 | """For compatibility with base checkpoints.""" 250 | mapping = { 251 | "attn.weight": "attn.linear.weight", 252 | "attn.bias": "attn.linear.bias", 253 | "proj.weight": "proj.linear.weight", 254 | "proj.bias": "proj.linear.bias", 255 | } 256 | state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 257 | # For compatibility with older checkpoints 258 | if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: 259 | state_dict[key] = state_dict[key].permute(0, 2, 1, 3) 260 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 261 | 262 | 263 | class GptNeoxMLP(lit_gpt.model.GptNeoxMLP): 264 | def __init__(self, config: Config) -> None: 265 | nn.Module.__init__(self) 266 | self.fc = 
AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) 267 | self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) 268 | 269 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 270 | """For compatibility with base checkpoints.""" 271 | mapping = { 272 | "fc.weight": "fc.linear.weight", 273 | "fc.bias": "fc.linear.bias", 274 | "proj.weight": "proj.linear.weight", 275 | "proj.bias": "proj.linear.bias", 276 | } 277 | state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 278 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 279 | 280 | 281 | class LLaMAMLP(lit_gpt.model.LLaMAMLP): 282 | def __init__(self, config: Config) -> None: 283 | nn.Module.__init__(self) 284 | self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) 285 | self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) 286 | self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) 287 | 288 | def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: 289 | """For compatibility with base checkpoints.""" 290 | mapping = { 291 | "fc_1.weight": "fc_1.linear.weight", 292 | "fc_1.bias": "fc_1.linear.bias", 293 | "fc_2.weight": "fc_2.linear.weight", 294 | "fc_2.bias": "fc_2.linear.bias", 295 | "proj.weight": "proj.linear.weight", 296 | "proj.bias": "proj.linear.bias", 297 | } 298 | state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) 299 | super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) 300 | 301 | 302 | def mark_only_adapter_v2_as_trainable(model: GPT) -> None: 303 | """Sets requires_grad=False for all non-adapter weights""" 304 | for name, param in model.named_parameters(): 305 | param.requires_grad = adapter_filter(name, param) 306 | -------------------------------------------------------------------------------- /scripts/convert_lit_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
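# A hedged usage sketch for the `convert_lit_checkpoint` entry point defined at the bottom
# of this file (it is normally driven through the jsonargparse CLI in the `__main__` block).
# The checkpoint file name below follows the `iter-XXXXXX-ckpt.pth` pattern produced by the
# pretrain scripts but is otherwise a hypothetical placeholder, and the import assumes the
# repository root is on PYTHONPATH:
#
#     from pathlib import Path
#     from scripts.convert_lit_checkpoint import convert_lit_checkpoint
#
#     convert_lit_checkpoint(
#         checkpoint_name="iter-320000-ckpt.pth",   # hypothetical checkpoint inside out_dir
#         out_dir=Path("out/tinyllama_135M_2k"),    # directory written by pretrain/tinyllama.py
#         model_name="tiny_LLaMA_135M_2k",
#         model_only=False,                         # also emit an HF-style config.json
#     )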
14 | 15 | 16 | import contextlib 17 | import gc 18 | import sys 19 | from functools import partial 20 | from pathlib import Path 21 | from typing import Dict, Literal, Optional, Tuple, Union 22 | from dataclasses import asdict 23 | import json 24 | import torch 25 | 26 | # support running without installing as a package 27 | wd = Path(__file__).parent.parent.resolve() 28 | sys.path.append(str(wd)) 29 | 30 | from lit_gpt import Config 31 | from lit_gpt.utils import NotYetLoadedTensor, incremental_save, lazy_load 32 | # from scripts.convert_hf_checkpoint import layer_template, load_param 33 | 34 | 35 | def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: 36 | split = layer_name.split(".") 37 | number = int(split[idx]) 38 | split[idx] = "{}" 39 | from_name = ".".join(split) 40 | return from_name, number 41 | 42 | 43 | def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype]) -> torch.Tensor: 44 | if hasattr(param, "_load_tensor"): 45 | # support tensors loaded via `lazy_load()` 46 | print(f"Loading {name!r} into RAM") 47 | param = param._load_tensor() 48 | if dtype is not None and type(dtype) is not NotYetLoadedTensor and dtype != param.dtype: 49 | print(f"Converting {name!r} from {param.dtype} to {dtype}") 50 | param = param.to(dtype) 51 | return param 52 | def copy_weights_falcon( 53 | size: Literal["7b", "40b"], 54 | state_dict: Dict[str, torch.Tensor], 55 | lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 56 | saver: Optional[incremental_save] = None, 57 | ): 58 | weight_map = { 59 | "transformer.wte.weight": "transformer.word_embeddings.weight", 60 | "transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight", 61 | "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", 62 | "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", 63 | "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", 64 | "transformer.ln_f.bias": "transformer.ln_f.bias", 65 | "transformer.ln_f.weight": "transformer.ln_f.weight", 66 | "lm_head.weight": "lm_head.weight", 67 | } 68 | # the original model definition is different for each size 69 | if size == "7b": 70 | weight_map.update( 71 | { 72 | "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", 73 | "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", 74 | } 75 | ) 76 | elif size == "40b": 77 | weight_map.update( 78 | { 79 | "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", 80 | "transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight", 81 | "transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias", 82 | "transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight", 83 | } 84 | ) 85 | else: 86 | raise NotImplementedError 87 | 88 | for name, param in lit_weights.items(): 89 | if "transformer.h" in name: 90 | from_name, number = layer_template(name, 2) 91 | to_name = weight_map[from_name].format(number) 92 | else: 93 | to_name = weight_map[name] 94 | param = load_param(param, name, None) 95 | if saver is not None: 96 | param = saver.store_early(param) 97 | state_dict[to_name] = param 98 | 99 | 100 | def copy_weights_gpt_neox( 101 | state_dict: Dict[str, torch.Tensor], 102 | lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 103 | saver: Optional[incremental_save] = None, 104 | ) -> None: 105 | weight_map = { 106 | "transformer.wte.weight": "gpt_neox.embed_in.weight", 107 | 
"transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", 108 | "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", 109 | "transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", 110 | "transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", 111 | "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", 112 | "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", 113 | "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", 114 | "transformer.h.{}.norm_2.weight": "gpt_neox.layers.{}.post_attention_layernorm.weight", 115 | "transformer.h.{}.mlp.fc.bias": "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias", 116 | "transformer.h.{}.mlp.fc.weight": "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight", 117 | "transformer.h.{}.mlp.proj.bias": "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias", 118 | "transformer.h.{}.mlp.proj.weight": "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight", 119 | "transformer.ln_f.bias": "gpt_neox.final_layer_norm.bias", 120 | "transformer.ln_f.weight": "gpt_neox.final_layer_norm.weight", 121 | "lm_head.weight": "embed_out.weight", 122 | } 123 | 124 | for name, param in lit_weights.items(): 125 | if "transformer.h" in name: 126 | from_name, number = layer_template(name, 2) 127 | to_name = weight_map[from_name].format(number) 128 | else: 129 | to_name = weight_map[name] 130 | param = load_param(param, name, None) 131 | if saver is not None: 132 | param = saver.store_early(param) 133 | state_dict[to_name] = param 134 | 135 | 136 | def copy_weights_llama( 137 | config: Config, 138 | state_dict: Dict[str, torch.Tensor], 139 | lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], 140 | saver: Optional[incremental_save] = None, 141 | ): 142 | weight_map = { 143 | "transformer.wte.weight": "model.embed_tokens.weight", 144 | "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", 145 | "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", 146 | "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", 147 | "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", 148 | "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", 149 | "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", 150 | "transformer.ln_f.weight": "model.norm.weight", 151 | "lm_head.weight": "lm_head.weight", 152 | } 153 | for name, param in lit_weights.items(): 154 | if name.endswith(".attn.attn.weight"): 155 | from_name, number = layer_template(name, 2) 156 | q = "model.layers.{}.self_attn.q_proj.weight".format(number) 157 | k = "model.layers.{}.self_attn.k_proj.weight".format(number) 158 | v = "model.layers.{}.self_attn.v_proj.weight".format(number) 159 | qkv = load_param(param, name,None) 160 | qp, kp, vp = tensor_split(qkv, config) 161 | for to_name, param in zip((q, k, v), (qp, kp, vp)): 162 | if saver is not None: 163 | param = saver.store_early(param) 164 | state_dict[to_name] = param 165 | elif "transformer.h" in name: 166 | from_name, number = layer_template(name, 2) 167 | to_name = weight_map[from_name] 168 | 169 | if to_name is None: 170 | continue 171 | to_name = to_name.format(number) 172 | param = load_param(param, name,None) 173 | if saver is not None: 174 | param = saver.store_early(param) 175 | state_dict[to_name] = param 176 | 177 | else: 178 | to_name = weight_map[name] 179 | param = 
load_param(param, name, None) 180 | if saver is not None: 181 | param = saver.store_early(param) 182 | state_dict[to_name] = param 183 | 184 | 185 | def tensor_split( 186 | param: Union[torch.Tensor, NotYetLoadedTensor], config: Config 187 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 188 | def kstart(start, blen, klen) -> int: 189 | """returns start index of keys in batch""" 190 | return start + (blen - (klen * 2)) 191 | 192 | def vstart(start, blen, klen) -> int: 193 | """returns start index of values in batch""" 194 | return start + blen - klen 195 | 196 | def vend(start, blen) -> int: 197 | """returns last index of values in batch""" 198 | return start + blen 199 | 200 | # num observations 201 | nobs = param.shape[0] 202 | # batch length 203 | blen = nobs // config.n_query_groups 204 | # key length in batch 205 | klen = config.head_size 206 | # value length in batch 207 | vlen = config.head_size 208 | # the starting index of each new batch 209 | starts = range(0, nobs, blen) 210 | # the indices to splice on 211 | splices = [(s, kstart(s, blen, klen), vstart(s, blen, vlen), vend(s, blen)) for s in starts] 212 | 213 | qc = () 214 | kc = () 215 | vc = () 216 | 217 | for splice in splices: 218 | qs, ks, vs, ve = splice 219 | qc += (param[qs:ks, :],) 220 | kc += (param[ks:vs, :],) 221 | vc += (param[vs:ve, :],) 222 | 223 | q = torch.cat(qc) 224 | k = torch.cat(kc) 225 | v = torch.cat(vc) 226 | 227 | return q, k, v 228 | 229 | 230 | def maybe_unwrap_state_dict(lit_weights: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 231 | return lit_weights.get("model", lit_weights) 232 | 233 | 234 | def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: 235 | weight_names = {wk.split(".")[-1] for wk in lit_weights} 236 | # LoRA or QLoRA 237 | if any("lora" in wn for wn in weight_names): 238 | raise ValueError("Model weights must be merged using `lora.merge_lora_weights()` before conversion.") 239 | # adapter v2. adapter_bias will only be in adapter_v2 240 | elif "adapter_bias" in weight_names: 241 | raise NotImplementedError("Converting models finetuned with adapter_v2 not yet supported.") 242 | # adapter. 
gating_factor is in adapter and adapter_v2 243 | elif "gating_factor" in weight_names: 244 | raise NotImplementedError("Converting models finetuned with adapter not yet supported.") 245 | 246 | 247 | def get_tinyllama_init_hf_config() -> dict: 248 | return { 249 | "architectures": ["LlamaForCausalLM"], 250 | "bos_token_id": 1, 251 | "eos_token_id": 2, 252 | "hidden_act": "silu", 253 | "hidden_size": None, 254 | "initializer_range": 0.02, 255 | "intermediate_size": None, 256 | "max_position_embeddings": None, 257 | "model_type": "llama", 258 | "num_attention_heads": None, 259 | "num_hidden_layers": None, 260 | "num_key_value_heads": None, 261 | "pretraining_tp": 1, 262 | "rms_norm_eps": None, 263 | "rope_scaling": None, 264 | "tie_word_embeddings": False, 265 | "torch_dtype": "float32", 266 | "transformers_version": "4.31.0.dev0", 267 | "use_cache": True, 268 | "vocab_size": None, 269 | } 270 | 271 | 272 | def convert_config_lit_to_hf(lit_config_dict: dict) -> dict: 273 | lit_hf_mapping = { 274 | "block_size": "max_position_embeddings", 275 | "vocab_size": "vocab_size", 276 | "n_layer": "num_hidden_layers", 277 | "n_embd": "hidden_size", 278 | "n_head": "num_attention_heads", 279 | "n_query_groups": "num_key_value_heads", 280 | "intermediate_size": "intermediate_size", 281 | "norm_eps": "rms_norm_eps", 282 | 283 | } 284 | hf_config_dict = get_tinyllama_init_hf_config() 285 | 286 | for lit_key, hf_key in lit_hf_mapping.items(): 287 | hf_config_dict[hf_key] = lit_config_dict[lit_key] 288 | return hf_config_dict 289 | 290 | 291 | @torch.inference_mode() 292 | def convert_lit_checkpoint(*, 293 | checkpoint_name: str, 294 | out_dir: Path, 295 | model_name: str, 296 | model_only: bool = True) -> None: 297 | config = Config.from_name(model_name) 298 | 299 | if "falcon" in model_name: 300 | copy_fn = partial(copy_weights_falcon, "40b" if config.n_embd == 8192 else "7b") 301 | elif config._mlp_class == "LLaMAMLP": 302 | copy_fn = partial(copy_weights_llama, config) 303 | else: 304 | copy_fn = copy_weights_gpt_neox 305 | 306 | # initialize a new empty state dict to hold our new weights 307 | sd = {} 308 | 309 | # checkpoint_name cannot be hardcoded because there exists different outputs such as 310 | # ("lit_model_finetuned.pth", "lit_model_lora_finetuned.pth", "lit_model_adapter_finetuned.pth"") 311 | pth_file = out_dir / checkpoint_name 312 | bin_file = pth_file.with_suffix(".bin") 313 | 314 | with incremental_save(bin_file) as saver: 315 | with contextlib.ExitStack() as stack: 316 | lit_weights = stack.enter_context(lazy_load(pth_file)) 317 | lit_weights = maybe_unwrap_state_dict(lit_weights) 318 | check_conversion_supported(lit_weights) 319 | # Incremental save will trigger error 320 | copy_fn(sd, lit_weights, saver=None) 321 | gc.collect() 322 | saver.save(sd) 323 | 324 | # convert lit config file to hf-style 325 | if not model_only: 326 | print('Converting config file...') 327 | lit_config = asdict(config) 328 | hf_config = convert_config_lit_to_hf(lit_config) 329 | config_path = out_dir / "config.json" 330 | with open(config_path, "w") as f: 331 | json.dump(hf_config, f, indent=4) 332 | 333 | 334 | 335 | 336 | if __name__ == "__main__": 337 | from jsonargparse import CLI 338 | 339 | CLI(convert_lit_checkpoint, as_positional=False) 340 | -------------------------------------------------------------------------------- /pretrain/tinyllama.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import glob 17 | import math 18 | import sys 19 | import time 20 | from pathlib import Path 21 | from typing import Optional, Tuple, Union 22 | import math 23 | import lightning as L 24 | import torch 25 | from lightning.fabric.strategies import FSDPStrategy, XLAStrategy 26 | from torch.utils.data import DataLoader 27 | from functools import partial 28 | # support running without installing as a package 29 | wd = Path(__file__).parent.parent.resolve() 30 | sys.path.append(str(wd)) 31 | # from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually 32 | from lit_gpt.model import GPT, Block, Config, CausalSelfAttention 33 | from lit_gpt.packed_dataset import CombinedDataset, PackedDataset 34 | from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor 35 | from lit_gpt.speed_monitor import estimate_flops, measure_flops 36 | from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load 37 | from pytorch_lightning.loggers import WandbLogger 38 | #from lit_gpt import FusedCrossEntropyLoss 39 | import random 40 | 41 | model_name = 'tiny_LLaMA_135M_2k' # model to train 42 | 43 | name = "tinyllama" 44 | out_dir = Path("./out") / (name+"_135M_2k") 45 | 46 | default_seed=3407 47 | 48 | # Hyperparameters 49 | num_node=8 50 | num_of_devices = 8 51 | global_batch_size = 1024/num_node 52 | learning_rate = 6e-4 53 | micro_batch_size = 16 54 | num_epochs=1 55 | num_total_token_in_b = 670 * num_epochs 56 | 57 | warmup_steps = 2000 58 | log_step_interval = 50 59 | eval_iters = 1000 60 | save_step_interval = 2000 61 | eval_step_interval = 2000 62 | 63 | 64 | weight_decay = 1e-1 65 | beta1 = 0.9 66 | beta2 = 0.95 67 | grad_clip = 1.0 68 | decay_lr = True 69 | min_lr = 6e-5 70 | 71 | batch_size = global_batch_size // num_of_devices 72 | gradient_accumulation_steps = math.ceil(batch_size / micro_batch_size) 73 | actual_global_batch = gradient_accumulation_steps*micro_batch_size*num_of_devices*num_node 74 | print(actual_global_batch) 75 | max_step = int(num_total_token_in_b * 10**9/(actual_global_batch*2048)//save_step_interval + 1)*save_step_interval 76 | assert gradient_accumulation_steps > 0 77 | warmup_iters = warmup_steps * gradient_accumulation_steps 78 | 79 | 80 | 81 | import math 82 | max_iters = max_step * gradient_accumulation_steps 83 | lr_decay_iters = max_iters 84 | log_iter_interval = math.ceil(log_step_interval * gradient_accumulation_steps) 85 | 86 | 87 | 88 | # Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. 
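# For example, a weighted mixture would be expressed directly in this list; the prefixes
# below are hypothetical placeholders matching the `{prefix}*` glob used in
# `create_dataloader`, and the weights are normalized before being passed to
# `CombinedDataset`:
#
#     train_data_config = [
#         ("train_slimpajama", 0.7),
#         ("train_starcoder", 0.3),
#     ]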
89 | train_data_config = [ 90 | ("train_", 1.0) 91 | ] 92 | 93 | val_data_config = [ 94 | ("validation_slim", 1.0), 95 | ] 96 | 97 | hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} 98 | logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) 99 | wandb_logger = WandbLogger() 100 | 101 | 102 | def setup( 103 | devices: int = 8, 104 | train_data_dir: Path = Path("./slim_star_combined"), 105 | val_data_dir: Optional[Path] = None, 106 | precision: Optional[str] = 'bf16-mixed', 107 | tpu: bool = False, 108 | resume: Union[bool, Path] = False, 109 | model_name: str=None 110 | ) -> None: 111 | precision = precision or get_default_supported_precision(training=True, tpu=tpu) 112 | 113 | if devices > 1: 114 | if tpu: 115 | # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 116 | devices = "auto" 117 | strategy = XLAStrategy(sync_module_states=False) 118 | else: 119 | strategy = FSDPStrategy( 120 | auto_wrap_policy={Block}, 121 | activation_checkpointing_policy=None, 122 | state_dict_type="full", 123 | limit_all_gathers=True, 124 | cpu_offload=False, 125 | ) 126 | else: 127 | strategy = "auto" 128 | 129 | fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) 130 | fabric.print(hparams) 131 | #fabric.launch(main, train_data_dir, val_data_dir, resume) 132 | main(fabric, train_data_dir, val_data_dir, resume, model_name) 133 | 134 | 135 | def main(fabric, train_data_dir, val_data_dir, resume, model_name=None): 136 | monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) 137 | 138 | if fabric.global_rank == 0: 139 | out_dir.mkdir(parents=True, exist_ok=True) 140 | 141 | config = Config.from_name(model_name) 142 | 143 | train_dataloader, val_dataloader = create_dataloaders( 144 | batch_size=micro_batch_size, 145 | block_size=config.block_size, 146 | fabric=fabric, 147 | train_data_dir=train_data_dir, 148 | val_data_dir=val_data_dir, 149 | seed=default_seed, 150 | ) 151 | if val_dataloader is None: 152 | train_dataloader = fabric.setup_dataloaders(train_dataloader) 153 | else: 154 | train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) 155 | 156 | fabric.seed_everything(default_seed) # same seed for every process to init model (FSDP) 157 | 158 | fabric.print(f"Loading model with {config.__dict__}") 159 | t0 = time.perf_counter() 160 | with fabric.init_module(empty_init=False): 161 | model = GPT(config) 162 | model.apply(partial(model._init_weights ,n_layer=config.n_layer)) 163 | 164 | 165 | fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") 166 | fabric.print(f"Total parameters {num_parameters(model):,}") 167 | 168 | model = fabric.setup(model) 169 | optimizer = torch.optim.AdamW( 170 | model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False 171 | ) 172 | # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) 173 | optimizer = fabric.setup_optimizers(optimizer) 174 | 175 | state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} 176 | 177 | if resume is True: 178 | resume = sorted(out_dir.glob("*.pth"))[-1] 179 | if resume : 180 | fabric.print(f"Resuming training from {resume}") 181 | fabric.load(resume, state) 182 | 183 | train_time = 
time.perf_counter() 184 | train(fabric, state, train_dataloader, val_dataloader, monitor, resume) 185 | fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") 186 | if fabric.device.type == "cuda": 187 | fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") 188 | 189 | 190 | def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): 191 | model = state["model"] 192 | optimizer = state["optimizer"] 193 | 194 | # if val_dataloader is not None: 195 | # validate(fabric, model, val_dataloader) # sanity check 196 | model.train() 197 | 198 | meta_model = GPT(model.config).cuda() 199 | # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 200 | # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, 201 | # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead 202 | estimated_flops = estimate_flops(meta_model) * micro_batch_size 203 | fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") 204 | x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) 205 | # measured_flos run in meta. Will trigger fusedRMSNorm error 206 | #measured_flops = measure_flops(meta_model, x) 207 | #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") 208 | del meta_model, x 209 | 210 | total_lengths = 0 211 | total_t0 = time.perf_counter() 212 | 213 | if fabric.device.type == "xla": 214 | import torch_xla.core.xla_model as xm 215 | 216 | xm.mark_step() 217 | 218 | 219 | initial_iter = state["iter_num"] 220 | curr_iter = 0 221 | 222 | loss_func = torch.nn.CrossEntropyLoss() #FusedCrossEntropyLoss() 223 | for i, train_data in enumerate(train_dataloader): 224 | # resume loader state. This is not elegant but it works. Should rewrite it in the future. 
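# (The block below simply fast-forwards the dataloader: while resuming, already-consumed
# batches are skipped until `curr_iter` catches up with the `iter_num` stored in the
# checkpoint, so training continues from the same position in the shuffled stream.)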
225 | if resume: 226 | if curr_iter < initial_iter: 227 | curr_iter += 1 228 | continue 229 | else: 230 | resume = False 231 | curr_iter = -1 232 | fabric.barrier() 233 | fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) 234 | if state["iter_num"] >= max_iters: 235 | break 236 | 237 | # determine and set the learning rate for this iteration 238 | lr = get_lr(state["iter_num"]) if decay_lr else learning_rate 239 | for param_group in optimizer.param_groups: 240 | param_group["lr"] = lr 241 | 242 | iter_t0 = time.perf_counter() 243 | 244 | input_ids = train_data[:, 0 : model.config.block_size].contiguous() 245 | targets = train_data[:, 1 : model.config.block_size + 1].contiguous() 246 | is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 247 | with fabric.no_backward_sync(model, enabled=is_accumulating): 248 | logits = model(input_ids) 249 | loss = loss_func(logits.transpose(1,2), targets) 250 | # loss = chunked_cross_entropy(logits, targets, chunk_size=0) 251 | fabric.backward(loss / gradient_accumulation_steps) 252 | 253 | if not is_accumulating: 254 | fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 255 | optimizer.step() 256 | optimizer.zero_grad() 257 | state["step_count"] += 1 258 | elif fabric.device.type == "xla": 259 | xm.mark_step() 260 | state["iter_num"] += 1 261 | # input_id: B L 262 | total_lengths += input_ids.size(1) 263 | t1 = time.perf_counter() 264 | if i % log_step_interval == 0: 265 | fabric.print( 266 | f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" 267 | f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " 268 | # print days as well 269 | f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" 270 | ) 271 | 272 | monitor.on_train_batch_end( 273 | state["iter_num"] * micro_batch_size, 274 | t1 - total_t0, 275 | # this assumes that device FLOPs are the same and that all devices have the same batch size 276 | fabric.world_size, 277 | state["step_count"], 278 | flops_per_batch=estimated_flops, 279 | lengths=total_lengths, 280 | train_loss = loss.item() 281 | ) 282 | 283 | 284 | 285 | 286 | if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: 287 | 288 | t0 = time.perf_counter() 289 | val_loss = validate(fabric, model, val_dataloader) 290 | t1 = time.perf_counter() - t0 291 | monitor.eval_end(t1) 292 | fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") 293 | fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 294 | fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 295 | fabric.barrier() 296 | if not is_accumulating and state["step_count"] % save_step_interval == 0: 297 | checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" 298 | fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") 299 | # only works for pytorch>=2.0 300 | fabric.save(checkpoint_path, state) 301 | 302 | 303 | @torch.no_grad() 304 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: 305 | fabric.print("Validating ...") 306 | model.eval() 307 | 308 | losses = torch.zeros(eval_iters, device=fabric.device) 309 | for k, val_data in enumerate(val_dataloader): 310 | if k >= eval_iters: 311 | break 312 | input_ids = val_data[:, 0 : model.config.block_size].contiguous() 313 | targets = val_data[:, 1 : model.config.block_size + 1].contiguous() 314 | logits = model(input_ids) 315 | loss = chunked_cross_entropy(logits, targets, chunk_size=0) 316 | 317 | # loss_func = FusedCrossEntropyLoss() 318 | # loss = loss_func(logits, targets) 319 | losses[k] = loss.item() 320 | 321 | out = losses.mean() 322 | 323 | model.train() 324 | return out 325 | 326 | 327 | def create_dataloader( 328 | batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" 329 | ) -> DataLoader: 330 | datasets = [] 331 | data_config = train_data_config if split == "train" else val_data_config 332 | for prefix, _ in data_config: 333 | filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) 334 | random.seed(seed) 335 | random.shuffle(filenames) 336 | 337 | dataset = PackedDataset( 338 | filenames, 339 | # n_chunks control the buffer size. 340 | # Note that the buffer size also impacts the random shuffle 341 | # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) 342 | n_chunks=128, 343 | block_size=block_size, 344 | shuffle=shuffle, 345 | seed=seed+fabric.global_rank, 346 | num_processes=fabric.world_size, 347 | process_rank=fabric.global_rank, 348 | ) 349 | datasets.append(dataset) 350 | 351 | if not datasets: 352 | raise RuntimeError( 353 | f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
354 | ) 355 | 356 | weights = [weight for _, weight in data_config] 357 | sum_weights = sum(weights) 358 | weights = [el / sum_weights for el in weights] 359 | 360 | combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) 361 | 362 | return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 363 | 364 | 365 | def create_dataloaders( 366 | batch_size: int, 367 | block_size: int, 368 | fabric, 369 | train_data_dir: Path = Path("data/redpajama_sample"), 370 | val_data_dir: Optional[Path] = None, 371 | seed: int = 12345, 372 | ) -> Tuple[DataLoader, DataLoader]: 373 | # Increase by one because we need the next word as well 374 | effective_block_size = block_size + 1 375 | train_dataloader = create_dataloader( 376 | batch_size=batch_size, 377 | block_size=effective_block_size, 378 | fabric=fabric, 379 | data_dir=train_data_dir, 380 | shuffle=True, 381 | seed=seed, 382 | split="train" 383 | ) 384 | val_dataloader = ( 385 | create_dataloader( 386 | batch_size=batch_size, 387 | block_size=effective_block_size, 388 | fabric=fabric, 389 | data_dir=val_data_dir, 390 | shuffle=False, 391 | seed=seed, 392 | split="validation" 393 | ) 394 | if val_data_dir 395 | else None 396 | ) 397 | return train_dataloader, val_dataloader 398 | 399 | 400 | # learning rate decay scheduler (cosine with warmup) 401 | def get_lr(it): 402 | # 1) linear warmup for warmup_iters steps 403 | if it < warmup_iters: 404 | return learning_rate * it / warmup_iters 405 | # 2) if it > lr_decay_iters, return min learning rate 406 | if it > lr_decay_iters: 407 | return min_lr 408 | # 3) in between, use cosine decay down to min learning rate 409 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 410 | assert 0 <= decay_ratio <= 1 411 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 412 | return min_lr + coeff * (learning_rate - min_lr) 413 | 414 | 415 | if __name__ == "__main__": 416 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" 417 | # torch.backends.cuda.enable_flash_sdp(False) 418 | torch.set_float32_matmul_precision("high") 419 | 420 | from jsonargparse import CLI 421 | 422 | CLI(setup) 423 | -------------------------------------------------------------------------------- /pretrain/tinyllama_code.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
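# A hedged launch sketch for this script: compared to `pretrain/tinyllama.py`, the `setup`
# entry point below additionally accepts a `checkpoint_path`, which `main` loads into the
# freshly built model before the code-data run starts. The data directory and checkpoint
# path below are hypothetical placeholders; the keyword arguments match the `setup`
# signature further down (the script is normally invoked through the jsonargparse CLI in
# the `__main__` block):
#
#     from pathlib import Path
#     from pretrain.tinyllama_code import setup
#
#     setup(
#         devices=8,
#         train_data_dir=Path("./starcoder_python_processed"),              # hypothetical
#         model_name="tiny_LLaMA_135M_2k",
#         checkpoint_path="out/tinyllama_135M_2k/iter-320000-ckpt.pth",     # hypothetical
#     )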
14 | 15 | 16 | import glob 17 | import math 18 | import sys 19 | import time 20 | from pathlib import Path 21 | from typing import Optional, Tuple, Union 22 | import math 23 | import lightning as L 24 | import torch 25 | from lightning.fabric.strategies import FSDPStrategy, XLAStrategy 26 | from torch.utils.data import DataLoader 27 | from functools import partial 28 | # support running without installing as a package 29 | wd = Path(__file__).parent.parent.resolve() 30 | sys.path.append(str(wd)) 31 | # from apex.optimizers import FusedAdam #torch optimizer has a cuda backend, which is faster actually 32 | from lit_gpt.model import GPT, Block, Config, CausalSelfAttention 33 | from lit_gpt.packed_dataset import CombinedDataset, PackedDataset 34 | from lit_gpt.speed_monitor import SpeedMonitorFabric as Monitor 35 | from lit_gpt.speed_monitor import estimate_flops, measure_flops 36 | from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, step_csv_logger, lazy_load 37 | from pytorch_lightning.loggers import WandbLogger 38 | #from lit_gpt import FusedCrossEntropyLoss 39 | import random 40 | 41 | model_name = 'tiny_LLaMA_135M_2k' # model to train 42 | 43 | name = "tinyllama" 44 | out_dir = Path("./out") / (name+"_135M_2k_code") 45 | 46 | default_seed=3407 47 | 48 | # Hyperparameters 49 | num_node=1 50 | num_of_devices = 8 51 | global_batch_size = 320/num_node 52 | learning_rate = 3e-4 53 | micro_batch_size = 16 54 | num_epochs=1 55 | num_total_token_in_b = 21 * num_epochs 56 | 57 | warmup_steps = 2000 58 | log_step_interval = 10 59 | eval_iters = 1000 60 | save_step_interval = 2000 61 | eval_step_interval = 2000 62 | 63 | 64 | weight_decay = 1e-1 65 | beta1 = 0.9 66 | beta2 = 0.95 67 | grad_clip = 1.0 68 | decay_lr = True 69 | min_lr = 3e-5 70 | 71 | batch_size = global_batch_size // num_of_devices 72 | gradient_accumulation_steps = batch_size // micro_batch_size 73 | actual_global_batch = gradient_accumulation_steps*micro_batch_size*num_of_devices*num_node 74 | print(actual_global_batch) 75 | max_step = int(num_total_token_in_b * 10**9/(actual_global_batch*2048)//save_step_interval + 1)*save_step_interval 76 | assert gradient_accumulation_steps > 0 77 | warmup_iters = warmup_steps * gradient_accumulation_steps 78 | 79 | 80 | 81 | import math 82 | max_iters = max_step * gradient_accumulation_steps 83 | lr_decay_iters = max_iters 84 | log_iter_interval = math.ceil(log_step_interval * gradient_accumulation_steps) 85 | 86 | 87 | 88 | # Treat all dataset equally by their size. If you want to use a different weight for a dataset, add it to the list with the weight. 
89 | train_data_config = [ 90 | ("train_", 1.0), 91 | ] 92 | 93 | val_data_config = [ 94 | ("validation_slim", 1.0), 95 | ] 96 | 97 | hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} 98 | logger = step_csv_logger("out", name, flush_logs_every_n_steps=log_iter_interval) 99 | wandb_logger = WandbLogger() 100 | 101 | 102 | def setup( 103 | devices: int = 8, 104 | train_data_dir: Path = Path("./slim_star_combined"), 105 | val_data_dir: Optional[Path] = None, 106 | precision: Optional[str] = 'bf16-mixed', 107 | tpu: bool = False, 108 | resume: Union[bool, Path] = False, 109 | model_name: str=None, 110 | checkpoint_path: str=None 111 | ) -> None: 112 | precision = precision# or get_default_supported_precision(training=True, tpu=tpu) 113 | 114 | if devices > 1: 115 | if tpu: 116 | # For multi-host TPU training, the device count for Fabric is limited to the count on a single host. 117 | devices = "auto" 118 | strategy = XLAStrategy(sync_module_states=False) 119 | else: 120 | strategy = FSDPStrategy( 121 | auto_wrap_policy={Block}, 122 | activation_checkpointing_policy=None, 123 | state_dict_type="full", 124 | limit_all_gathers=True, 125 | cpu_offload=False, 126 | ) 127 | else: 128 | strategy = "auto" 129 | 130 | fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger]) 131 | fabric.print(hparams) 132 | #fabric.launch(main, train_data_dir, val_data_dir, resume) 133 | main(fabric, train_data_dir, val_data_dir, resume, model_name, checkpoint_path) 134 | 135 | 136 | def main(fabric, train_data_dir, val_data_dir, resume, model_name=None, checkpoint_path=None): 137 | print('continue {}'.format(checkpoint_path)) 138 | monitor = Monitor(fabric, window_size=2, time_unit="seconds", log_iter_interval=log_iter_interval) 139 | 140 | if fabric.global_rank == 0: 141 | out_dir.mkdir(parents=True, exist_ok=True) 142 | 143 | config = Config.from_name(model_name) 144 | 145 | train_dataloader, val_dataloader = create_dataloaders( 146 | batch_size=micro_batch_size, 147 | block_size=config.block_size, 148 | fabric=fabric, 149 | train_data_dir=train_data_dir, 150 | val_data_dir=val_data_dir, 151 | seed=default_seed, 152 | ) 153 | if val_dataloader is None: 154 | train_dataloader = fabric.setup_dataloaders(train_dataloader) 155 | else: 156 | train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader) 157 | 158 | fabric.seed_everything(default_seed) # same seed for every process to init model (FSDP) 159 | 160 | fabric.print(f"Loading model with {config.__dict__}") 161 | t0 = time.perf_counter() 162 | with fabric.init_module(empty_init=True): 163 | model = GPT(config) 164 | 165 | 166 | fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") 167 | fabric.print(f"Total parameters {num_parameters(model):,}") 168 | 169 | model = fabric.setup(model) 170 | fabric.load_raw(checkpoint_path, model, strict=True) 171 | optimizer = torch.optim.AdamW( 172 | model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False 173 | ) 174 | # optimizer = FusedAdam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2),adam_w_mode=True) 175 | optimizer = fabric.setup_optimizers(optimizer) 176 | 177 | state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0} 178 | 179 | if resume is True: 180 | resume = sorted(out_dir.glob("*.pth"))[-1] 181 | if resume : 182 | 
fabric.print(f"Resuming training from {resume}") 183 | fabric.load(resume, state) 184 | 185 | train_time = time.perf_counter() 186 | train(fabric, state, train_dataloader, val_dataloader, monitor, resume) 187 | fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") 188 | if fabric.device.type == "cuda": 189 | fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") 190 | 191 | 192 | def train(fabric, state, train_dataloader, val_dataloader, monitor, resume): 193 | model = state["model"] 194 | optimizer = state["optimizer"] 195 | 196 | # if val_dataloader is not None: 197 | # validate(fabric, model, val_dataloader) # sanity check 198 | model.train() 199 | 200 | meta_model = GPT(model.config).cuda() 201 | # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild. 202 | # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs, 203 | # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead 204 | estimated_flops = estimate_flops(meta_model) * micro_batch_size 205 | fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}") 206 | x = torch.randint(0, 1, (micro_batch_size, model.config.block_size)) 207 | # measured_flos run in meta. Will trigger fusedRMSNorm error 208 | #measured_flops = measure_flops(meta_model, x) 209 | #fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}") 210 | del meta_model, x 211 | 212 | total_lengths = 0 213 | total_t0 = time.perf_counter() 214 | 215 | if fabric.device.type == "xla": 216 | import torch_xla.core.xla_model as xm 217 | 218 | xm.mark_step() 219 | 220 | 221 | initial_iter = state["iter_num"] 222 | curr_iter = 0 223 | 224 | loss_func = torch.nn.CrossEntropyLoss() #FusedCrossEntropyLoss() 225 | for i, train_data in enumerate(train_dataloader): 226 | # resume loader state. This is not elegant but it works. Should rewrite it in the future. 
227 | if resume: 228 | if curr_iter < initial_iter: 229 | curr_iter += 1 230 | continue 231 | else: 232 | resume = False 233 | curr_iter = -1 234 | fabric.barrier() 235 | fabric.print("resume finished, taken {} seconds".format(time.perf_counter() - total_t0)) 236 | if state["iter_num"] >= max_iters: 237 | break 238 | 239 | # determine and set the learning rate for this iteration 240 | lr = get_lr(state["iter_num"]) if decay_lr else learning_rate 241 | for param_group in optimizer.param_groups: 242 | param_group["lr"] = lr 243 | 244 | iter_t0 = time.perf_counter() 245 | 246 | input_ids = train_data[:, 0 : model.config.block_size].contiguous() 247 | targets = train_data[:, 1 : model.config.block_size + 1].contiguous() 248 | is_accumulating = (state["iter_num"] + 1) % gradient_accumulation_steps != 0 249 | with fabric.no_backward_sync(model, enabled=is_accumulating): 250 | logits = model(input_ids) 251 | loss = loss_func(logits.transpose(1,2), targets) 252 | # loss = chunked_cross_entropy(logits, targets, chunk_size=0) 253 | fabric.backward(loss / gradient_accumulation_steps) 254 | 255 | if not is_accumulating: 256 | fabric.clip_gradients(model, optimizer, max_norm=grad_clip) 257 | optimizer.step() 258 | optimizer.zero_grad() 259 | state["step_count"] += 1 260 | elif fabric.device.type == "xla": 261 | xm.mark_step() 262 | state["iter_num"] += 1 263 | # input_id: B L 264 | total_lengths += input_ids.size(1) 265 | t1 = time.perf_counter() 266 | if i % log_step_interval == 0: 267 | fabric.print( 268 | f"iter {state['iter_num']} step {state['step_count']}: loss {loss.item():.4f}, iter time:" f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}" 269 | f" remaining time: {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600:.2f} hours. " 270 | # print days as well 271 | f" or {(t1 - total_t0) / (state['iter_num'] - initial_iter) * (max_iters - state['iter_num']) / 3600 / 24:.2f} days. 
" 272 | ) 273 | 274 | monitor.on_train_batch_end( 275 | state["iter_num"] * micro_batch_size, 276 | t1 - total_t0, 277 | # this assumes that device FLOPs are the same and that all devices have the same batch size 278 | fabric.world_size, 279 | state["step_count"], 280 | flops_per_batch=estimated_flops, 281 | lengths=total_lengths, 282 | train_loss = loss.item() 283 | ) 284 | 285 | 286 | 287 | 288 | if val_dataloader is not None and not is_accumulating and state["step_count"] % eval_step_interval == 0: 289 | 290 | t0 = time.perf_counter() 291 | val_loss = validate(fabric, model, val_dataloader) 292 | t1 = time.perf_counter() - t0 293 | monitor.eval_end(t1) 294 | fabric.print(f"step {state['iter_num']}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms") 295 | fabric.log_dict({"metric/val_loss": val_loss.item(), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 296 | fabric.log_dict({"metric/val_ppl": math.exp(val_loss.item()), "total_tokens": model.config.block_size * (state["iter_num"] + 1) * micro_batch_size * fabric.world_size}, state["step_count"]) 297 | fabric.barrier() 298 | if not is_accumulating and state["step_count"] % save_step_interval == 0: 299 | checkpoint_path = out_dir / f"iter-{state['iter_num']:06d}-ckpt.pth" 300 | fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") 301 | # only works for pytorch>=2.0 302 | fabric.save(checkpoint_path, state) 303 | 304 | 305 | @torch.no_grad() 306 | def validate(fabric: L.Fabric, model: torch.nn.Module, val_dataloader: DataLoader) -> torch.Tensor: 307 | fabric.print("Validating ...") 308 | model.eval() 309 | 310 | losses = torch.zeros(eval_iters, device=fabric.device) 311 | for k, val_data in enumerate(val_dataloader): 312 | if k >= eval_iters: 313 | break 314 | input_ids = val_data[:, 0 : model.config.block_size].contiguous() 315 | targets = val_data[:, 1 : model.config.block_size + 1].contiguous() 316 | logits = model(input_ids) 317 | loss = chunked_cross_entropy(logits, targets, chunk_size=0) 318 | 319 | # loss_func = FusedCrossEntropyLoss() 320 | # loss = loss_func(logits, targets) 321 | losses[k] = loss.item() 322 | 323 | out = losses.mean() 324 | 325 | model.train() 326 | return out 327 | 328 | 329 | def create_dataloader( 330 | batch_size: int, block_size: int, data_dir: Path, fabric, shuffle: bool = True, seed: int = 12345, split="train" 331 | ) -> DataLoader: 332 | datasets = [] 333 | data_config = train_data_config if split == "train" else val_data_config 334 | for prefix, _ in data_config: 335 | filenames = sorted(glob.glob(str(data_dir / f"{prefix}*"))) 336 | random.seed(seed) 337 | random.shuffle(filenames) 338 | 339 | dataset = PackedDataset( 340 | filenames, 341 | # n_chunks control the buffer size. 342 | # Note that the buffer size also impacts the random shuffle 343 | # (PackedDataset is an IterableDataset. So the shuffle is done by prefetch a buffer and shuffle the buffer) 344 | n_chunks=64, 345 | block_size=block_size, 346 | shuffle=shuffle, 347 | seed=seed+fabric.global_rank, 348 | num_processes=fabric.world_size, 349 | process_rank=fabric.global_rank, 350 | ) 351 | datasets.append(dataset) 352 | 353 | if not datasets: 354 | raise RuntimeError( 355 | f"No data found at {data_dir}. Make sure you ran prepare_redpajama.py to create the dataset." 
356 | ) 357 | 358 | weights = [weight for _, weight in data_config] 359 | sum_weights = sum(weights) 360 | weights = [el / sum_weights for el in weights] 361 | 362 | combined_dataset = CombinedDataset(datasets=datasets, seed=seed, weights=weights) 363 | 364 | return DataLoader(combined_dataset, batch_size=batch_size, shuffle=False, pin_memory=True) 365 | 366 | 367 | def create_dataloaders( 368 | batch_size: int, 369 | block_size: int, 370 | fabric, 371 | train_data_dir: Path = Path("data/redpajama_sample"), 372 | val_data_dir: Optional[Path] = None, 373 | seed: int = 12345, 374 | ) -> Tuple[DataLoader, DataLoader]: 375 | # Increase by one because we need the next word as well 376 | effective_block_size = block_size + 1 377 | train_dataloader = create_dataloader( 378 | batch_size=batch_size, 379 | block_size=effective_block_size, 380 | fabric=fabric, 381 | data_dir=train_data_dir, 382 | shuffle=True, 383 | seed=seed, 384 | split="train" 385 | ) 386 | val_dataloader = ( 387 | create_dataloader( 388 | batch_size=batch_size, 389 | block_size=effective_block_size, 390 | fabric=fabric, 391 | data_dir=val_data_dir, 392 | shuffle=False, 393 | seed=seed, 394 | split="validation" 395 | ) 396 | if val_data_dir 397 | else None 398 | ) 399 | return train_dataloader, val_dataloader 400 | 401 | 402 | # learning rate decay scheduler (cosine with warmup) 403 | def get_lr(it): 404 | # 1) linear warmup for warmup_iters steps 405 | if it < warmup_iters: 406 | return learning_rate * it / warmup_iters 407 | # 2) if it > lr_decay_iters, return min learning rate 408 | if it > lr_decay_iters: 409 | return min_lr 410 | # 3) in between, use cosine decay down to min learning rate 411 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 412 | assert 0 <= decay_ratio <= 1 413 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 414 | return min_lr + coeff * (learning_rate - min_lr) 415 | 416 | 417 | if __name__ == "__main__": 418 | # Uncomment this line if you see an error: "Expected is_sm80 to be true, but got false" 419 | # torch.backends.cuda.enable_flash_sdp(False) 420 | torch.set_float32_matmul_precision("high") 421 | 422 | from jsonargparse import CLI 423 | 424 | CLI(setup) 425 | -------------------------------------------------------------------------------- /lit_gpt/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """Full definition of a GPT NeoX Language Model, all of it in this single file. 17 | 18 | Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and 19 | https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. 
20 | """ 21 | import math 22 | from typing import Any, List, Optional, Tuple 23 | 24 | import torch 25 | import torch.nn as nn 26 | # from lightning_utilities.core.imports import RequirementCache 27 | from typing_extensions import Self 28 | # from flash_attn import flash_attn_func 29 | from lit_gpt.config import Config 30 | #from xformers.ops import SwiGLU 31 | #from .fused_rotary_embedding import apply_rotary_emb_func 32 | from .rotary_ebm import apply_rotary_pos_emb 33 | 34 | RoPECache = Tuple[torch.Tensor, torch.Tensor] 35 | KVCache = Tuple[torch.Tensor, torch.Tensor] 36 | # FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1") 37 | 38 | # input_pos_global = torch.arange(0, 4096, device=torch.device('cuda')) 39 | #import triton 40 | #from triton import ops 41 | 42 | class GPT(nn.Module): 43 | def __init__(self, config: Config) -> None: 44 | super().__init__() 45 | assert config.padded_vocab_size is not None 46 | self.config = config 47 | 48 | self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) 49 | self.transformer = nn.ModuleDict( 50 | dict( 51 | wte=nn.Embedding(config.padded_vocab_size, config.n_embd), 52 | h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), 53 | ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), 54 | ) 55 | ) 56 | self.rope_cache: Optional[RoPECache] = None 57 | self.mask_cache: Optional[torch.Tensor] = None 58 | self.kv_caches: List[KVCache] = [] 59 | 60 | def _init_weights(self, module: nn.Module, n_layer) -> None: 61 | """Meant to be used with `gpt.apply(gpt._init_weights)`.""" 62 | # GPT-NeoX https://arxiv.org/pdf/2204.06745.pdf 63 | if isinstance(module, nn.Embedding): 64 | torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd)) 65 | # RWKV: set it to 1e-4 66 | # torch.nn.init.uniform_(module.weight, -1e-4, 1e-4) 67 | elif isinstance(module, nn.Linear): 68 | torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd)) 69 | if module.bias is not None: 70 | torch.nn.init.zeros_(module.bias) 71 | # GPT-NeoX 72 | for name, p in module.named_parameters(): 73 | if (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == "w3.weight" and isinstance(module, SwiGLU) or (name=="proj.weight" and isinstance(module, CausalSelfAttention))): #if use xformer swiglu, fc2 layer will be renamed to w3 74 | nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer) 75 | 76 | 77 | def reset_cache(self) -> None: 78 | self.kv_caches.clear() 79 | if self.mask_cache is not None and self.mask_cache.device.type == "xla": 80 | # https://github.com/Lightning-AI/lit-gpt/pull/83#issuecomment-1558150179 81 | self.rope_cache = None 82 | self.mask_cache = None 83 | 84 | def forward( 85 | self, idx: torch.Tensor, max_seq_length: Optional[int] = None, input_pos: Optional[torch.Tensor] = None 86 | ) -> torch.Tensor: 87 | B, T = idx.size() 88 | use_kv_cache = input_pos is not None 89 | 90 | block_size = self.config.block_size 91 | if max_seq_length is None: 92 | max_seq_length = block_size 93 | if use_kv_cache: # not relevant otherwise 94 | assert ( 95 | max_seq_length >= T 96 | ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" 97 | assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" 98 | assert block_size >= T, f"Cannot forward sequence of length {T}, block size is only {block_size}" 99 | 100 | if self.rope_cache is None: 101 | self.rope_cache = 
self.build_rope_cache(idx) 102 | # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask 103 | # for the kv-cache support (only during inference), we only create it in that situation 104 | # this will be resolved by https://github.com/pytorch/pytorch/issues/96099 105 | if use_kv_cache and self.mask_cache is None: 106 | self.mask_cache = self.build_mask_cache(idx) 107 | 108 | cos, sin = self.rope_cache 109 | if use_kv_cache: 110 | 111 | cos = cos.index_select(0, input_pos) 112 | sin = sin.index_select(0, input_pos) 113 | mask = self.mask_cache.index_select(2, input_pos) 114 | mask = mask[:, :, :, :max_seq_length] 115 | else: 116 | cos = cos[:T] 117 | sin = sin[:T] 118 | mask = None 119 | input_pos = torch.arange(0, T, device=idx.device) 120 | 121 | # forward the model itself 122 | x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 123 | 124 | if not use_kv_cache: 125 | for block in self.transformer.h: 126 | x, *_ = block(x, (cos, sin), max_seq_length, input_pos=input_pos) 127 | else: 128 | self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2) 129 | for i, block in enumerate(self.transformer.h): 130 | x, self.kv_caches[i] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[i]) 131 | 132 | x = self.transformer.ln_f(x) 133 | return self.lm_head(x) # (b, t, vocab_size) 134 | 135 | @classmethod 136 | def from_name(cls, name: str, **kwargs: Any) -> Self: 137 | return cls(Config.from_name(name, **kwargs)) 138 | 139 | def build_rope_cache(self, idx: torch.Tensor) -> RoPECache: 140 | return build_rope_cache( 141 | seq_len=self.config.block_size, 142 | n_elem=int(self.config.rotary_percentage * self.config.head_size), 143 | dtype=torch.bfloat16, 144 | device=idx.device, 145 | condense_ratio=self.config.condense_ratio, 146 | ) 147 | 148 | def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor: 149 | ones = torch.ones((self.config.block_size, self.config.block_size), device=idx.device, dtype=torch.bool) 150 | return torch.tril(ones).unsqueeze(0).unsqueeze(0) 151 | 152 | def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) -> List[KVCache]: 153 | B = idx.size(0) 154 | heads = 1 if self.config.n_query_groups == 1 else self.config.n_query_groups 155 | 156 | k_cache_shape = ( 157 | B, 158 | max_seq_length, 159 | heads, 160 | rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size), 161 | ) 162 | v_cache_shape = (B, max_seq_length, heads, self.config.head_size) 163 | device = idx.device 164 | return [ 165 | (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device)) 166 | for _ in range(self.config.n_layer) 167 | ] 168 | 169 | 170 | class Block(nn.Module): 171 | def __init__(self, config: Config) -> None: 172 | super().__init__() 173 | self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) 174 | self.attn = CausalSelfAttention(config) 175 | if not config.shared_attention_norm: 176 | self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) 177 | self.mlp = config.mlp_class(config) 178 | self.config = config 179 | def forward( 180 | self, 181 | x: torch.Tensor, 182 | rope: RoPECache, 183 | max_seq_length: int, 184 | mask: Optional[torch.Tensor] = None, 185 | input_pos: Optional[torch.Tensor] = None, 186 | kv_cache: Optional[KVCache] = None, 187 | ) -> Tuple[torch.Tensor, Optional[KVCache]]: 188 | 189 | n_1 = self.norm_1(x) 190 | h, 
new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache) 191 | if self.config.parallel_residual: 192 | n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) 193 | x = x + h + self.mlp(n_2) 194 | else: 195 | if self.config.shared_attention_norm: 196 | raise NotImplementedError( 197 | "No checkpoint amongst the ones we support uses this configuration" 198 | " (non-parallel residual and shared attention norm)." 199 | ) 200 | 201 | x = x + h 202 | x = x + self.mlp(self.norm_2(x)) 203 | return x, new_kv_cache 204 | 205 | 206 | class CausalSelfAttention(nn.Module): 207 | def __init__(self, config: Config) -> None: 208 | super().__init__() 209 | shape = (config.n_head + 2 * config.n_query_groups) * config.head_size 210 | # key, query, value projections for all heads, but in a batch 211 | self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) 212 | # output projection 213 | self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 214 | 215 | self.config = config 216 | 217 | def forward( 218 | self, 219 | x: torch.Tensor, 220 | rope: RoPECache, 221 | max_seq_length: int, 222 | mask: Optional[torch.Tensor] = None, 223 | input_pos: Optional[torch.Tensor] = None, 224 | kv_cache: Optional[KVCache] = None, 225 | ) -> Tuple[torch.Tensor, Optional[KVCache]]: 226 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 227 | 228 | qkv = self.attn(x) 229 | 230 | # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) 231 | q_per_kv = self.config.n_head // self.config.n_query_groups 232 | total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value 233 | qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) # (B, T, n_query_groups, total_qkv, hs) 234 | # qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) 235 | 236 | # split batched computation into three 237 | q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2) 238 | 239 | # repeat k and v if necessary 240 | # Peiyuan: we do not need to do this as flash attention 2 already support GQA 241 | # if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!) 
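# Illustration of the grouped-query layout above (hypothetical sizes, not taken from any
# config in this repo): with n_head=32, n_query_groups=4, head_size=64 the fused `attn`
# projection is (32 + 2*4) * 64 = 2560 wide, q_per_kv = 8 and total_qkv = 10, so
# qkv.view(...) gives (B, T, 4, 10, 64) and the split yields q: (B, T, 4, 8, 64) and
# k/v: (B, T, 4, 1, 64). After the reshapes below q becomes (B, T, 32, 64) while k and v
# stay at (B, T, 4, 64); n_query_groups == n_head recovers MHA and n_query_groups == 1 MQA.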
242 | # # for MHA this is a no-op 243 | # k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 244 | # v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) 245 | 246 | q = q.reshape(B, T, -1, self.config.head_size) # (B, T, nh_q, hs) 247 | k = k.reshape(B, T, -1, self.config.head_size) 248 | v = v.reshape(B, T, -1, self.config.head_size) 249 | 250 | cos, sin = rope 251 | 252 | # apply rope in fp32 significanly stabalize training 253 | # fused rope expect (batch_size, seqlen, nheads, headdim) 254 | #q = apply_rotary_emb_func(q, cos, sin, False, True) 255 | #k = apply_rotary_emb_func(k, cos, sin, False, True) 256 | q, k = apply_rotary_pos_emb(q, k, cos, sin, input_pos) 257 | 258 | # n_elem = int(self.config.rotary_percentage * self.config.head_size) 259 | 260 | # q_roped = apply_rope(q[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2)) 261 | # k_roped = apply_rope(k[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2)) 262 | # print( (q_roped - q).sum()) 263 | # q = torch.cat((q_roped, q[..., n_elem:]), dim=-1) 264 | # k = torch.cat((k_roped, k[..., n_elem:]), dim=-1) 265 | 266 | if kv_cache is not None: 267 | cache_k, cache_v = kv_cache 268 | cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype) 269 | # check if reached token limit 270 | if input_pos[-1] >= max_seq_length: 271 | input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device) 272 | # shift 1 position to the left 273 | cache_k = torch.roll(cache_k, -1, dims=1) 274 | cache_v = torch.roll(cache_v, -1, dims=1) 275 | 276 | k = cache_k.index_copy_(1, input_pos, k) 277 | v = cache_v.index_copy_(1, input_pos, v) 278 | kv_cache = k, v 279 | 280 | y = self.scaled_dot_product_attention(q, k, v, mask=mask) 281 | 282 | y = y.reshape(B, T, C) # re-assemble all head outputs side by side 283 | 284 | # output projection 285 | y = self.proj(y) 286 | 287 | return y, kv_cache 288 | 289 | def scaled_dot_product_attention( 290 | self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None 291 | ): 292 | scale = 1.0 / math.sqrt(self.config.head_size) 293 | ''' 294 | if ( 295 | FlashAttention2Available 296 | and mask is None 297 | and q.device.type == "cuda" 298 | and self.config.enable_flash_attn 299 | #and q.dtype in (torch.float16, torch.bfloat16) 300 | ): 301 | from flash_attn import flash_attn_func 302 | return flash_attn_func(q.to(self.config.flash_attn_dtype), k.to(self.config.flash_attn_dtype), v.to(self.config.flash_attn_dtype), dropout_p=0.0, softmax_scale=scale, causal=True).to(v.dtype) 303 | ''' 304 | q = q.transpose(1, 2) 305 | k = k.transpose(1, 2) 306 | v = v.transpose(1, 2) 307 | if q.size() != k.size(): 308 | k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1) 309 | v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1) 310 | y = torch.nn.functional.scaled_dot_product_attention( 311 | q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=mask is None, scale=scale 312 | ) 313 | 314 | return y.transpose(1, 2) 315 | 316 | # Efficient implementation equivalent to the following: 317 | def raw_product_attention(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: 318 | # Efficient implementation equivalent to the following: 319 | L, S = query.size(-2), key.size(-2) 320 | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale 321 | attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) 322 | if is_causal: 323 | assert attn_mask is None 324 | temp_mask = 
torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) 325 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) 326 | attn_bias.to(query.dtype) 327 | 328 | if attn_mask is not None: 329 | if attn_mask.dtype == torch.bool: 330 | attn_mask = (~attn_mask).to(query.dtype).masked_fill_(attn_mask.logical_not(), float("-inf")) 331 | else: 332 | attn_bias += attn_mask 333 | attn_weight = query @ key.transpose(-2, -1) * scale_factor 334 | attn_weight += attn_bias 335 | attn_weight = torch.softmax(attn_weight, dim=-1) 336 | attn_weight = torch.dropout(attn_weight, dropout_p, train=True) 337 | return attn_weight @ value 338 | 339 | 340 | def test_attn(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: 341 | # Efficient implementation equivalent to the following: 342 | L, S = query.size(-2), key.size(-2) 343 | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale 344 | attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) 345 | if is_causal: 346 | assert attn_mask is None 347 | temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) 348 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) 349 | attn_bias.to(query.dtype) 350 | 351 | if attn_mask is not None: 352 | if attn_mask.dtype == torch.bool: 353 | attn_mask = (~attn_mask).to(query.dtype).masked_fill_(attn_mask.logical_not(), float("-inf")) 354 | else: 355 | attn_bias += attn_mask 356 | attn_weight = query @ key.transpose(-2, -1) * scale_factor 357 | attn_weight += attn_bias 358 | attn_weight = torch.softmax(attn_weight, dim=-1) 359 | attn_weight = torch.dropout(attn_weight, dropout_p, train=True) 360 | return attn_weight @ value 361 | 362 | class GptNeoxMLP(nn.Module): 363 | def __init__(self, config: Config) -> None: 364 | super().__init__() 365 | self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 366 | self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) 367 | 368 | def forward(self, x: torch.Tensor) -> torch.Tensor: 369 | x = self.fc(x) 370 | x = torch.nn.functional.gelu(x) 371 | return self.proj(x) 372 | 373 | 374 | class LLaMAMLP(nn.Module): ##NOTE: changed to use torch ativation version Dec 8. 375 | def __init__(self, config: Config) -> None: 376 | super().__init__() 377 | self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 378 | self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) 379 | self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) 380 | # self.swiglu = SwiGLU(config.n_embd,config.intermediate_size, bias=False, _pack_weights=False) 381 | def forward(self, x: torch.Tensor) -> torch.Tensor: 382 | x_fc_1 = self.fc_1(x) 383 | x_fc_2 = self.fc_2(x) 384 | x = torch.nn.functional.silu(x_fc_1) * x_fc_2 385 | return self.proj(x) 386 | # return self.swiglu(x) 387 | 388 | 389 | def build_rope_cache( 390 | seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1 391 | ) -> RoPECache: 392 | """Enhanced Transformer with Rotary Position Embedding. 393 | 394 | Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ 395 | transformers/rope/__init__.py. MIT License: 396 | https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
397 | """ 398 | # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ 399 | theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem)) 400 | 401 | # Create position indexes `[0, 1, ..., seq_len - 1]` 402 | seq_idx = torch.arange(seq_len, device=device) / condense_ratio 403 | 404 | # Calculate the product of position index and $\theta_i$ 405 | idx_theta = torch.outer(seq_idx, theta) 406 | idx_theta = torch.cat((idx_theta, idx_theta), dim=-1) 407 | 408 | cos, sin = torch.cos(idx_theta), torch.sin(idx_theta) 409 | 410 | # added by peiyuan to ensure same data type with q, k, to use fused rotary embedding 411 | # if dtype == torch.bfloat16: 412 | # return cos.bfloat16(), sin.bfloat16() 413 | # # this is to mimic the behaviour of complex32, else we will get different results 414 | # if dtype in (torch.float16, torch.bfloat16, torch.int8): 415 | # return cos.half(), sin.half() 416 | return cos, sin 417 | 418 | 419 | def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: 420 | head_size = x.size(-1) 421 | x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) 422 | x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) 423 | rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) 424 | roped = (x * cos) + (rotated * sin) 425 | return roped.type_as(x) 426 | -------------------------------------------------------------------------------- /lit_gpt/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018-2024 Advanced Micro Devices, Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | """Utility functions for training and inference.""" 17 | 18 | import pickle 19 | import sys 20 | import warnings 21 | from contextlib import contextmanager 22 | from functools import partial 23 | from io import BytesIO 24 | from pathlib import Path 25 | from types import MethodType 26 | from typing import Any, Dict, List, Mapping, Optional, Type, TypeVar, Union 27 | 28 | import torch 29 | import torch.nn as nn 30 | from lightning.fabric.loggers import CSVLogger 31 | from torch.serialization import normalize_storage_type 32 | 33 | 34 | def find_multiple(n: int, k: int) -> int: 35 | assert k > 0 36 | if n % k == 0: 37 | return n 38 | return n + k - (n % k) 39 | 40 | 41 | def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int: 42 | return sum(p.numel() for p in module.parameters() if requires_grad is None or p.requires_grad == requires_grad) 43 | 44 | 45 | @contextmanager 46 | def quantization(mode: Optional[str] = None): 47 | if mode is None: 48 | yield 49 | return 50 | 51 | if mode == "bnb.int8": 52 | from quantize.bnb import InferenceLinear8bitLt 53 | 54 | quantized_linear_cls = InferenceLinear8bitLt 55 | elif mode == "bnb.fp4": 56 | from quantize.bnb import Linear4bit 57 | 58 | # Use a class instead `functools.partial` to respect `isinstance` checks and attribute accesses 59 | class QuantizedLinear(Linear4bit): 60 | def __init__(self, *args, **kwargs): 61 | super().__init__(*args, quant_type="fp4", compress_statistics=False, **kwargs) 62 | 63 | quantized_linear_cls = QuantizedLinear 64 | elif mode == "bnb.fp4-dq": 65 | from quantize.bnb import Linear4bit 66 | 67 | class QuantizedLinear(Linear4bit): 68 | def __init__(self, *args, **kwargs): 69 | super().__init__(*args, quant_type="fp4", compress_statistics=True, **kwargs) 70 | 71 | quantized_linear_cls = QuantizedLinear 72 | elif mode == "bnb.nf4": 73 | from quantize.bnb import Linear4bit 74 | 75 | class QuantizedLinear(Linear4bit): 76 | def __init__(self, *args, **kwargs): 77 | super().__init__(*args, quant_type="nf4", compress_statistics=False, **kwargs) 78 | 79 | quantized_linear_cls = QuantizedLinear 80 | elif mode == "bnb.nf4-dq": 81 | from quantize.bnb import Linear4bit 82 | 83 | class QuantizedLinear(Linear4bit): 84 | def __init__(self, *args, **kwargs): 85 | super().__init__(*args, quant_type="nf4", compress_statistics=True, **kwargs) 86 | 87 | quantized_linear_cls = QuantizedLinear 88 | elif mode == "gptq.int4": 89 | from quantize.gptq import ColBlockQuantizedLinear 90 | 91 | class QuantizedLinear(ColBlockQuantizedLinear): 92 | def __init__(self, *args, **kwargs): 93 | super().__init__(*args, bits=4, tile_cols=-1, **kwargs) 94 | 95 | quantized_linear_cls = QuantizedLinear 96 | else: 97 | raise ValueError(f"Unknown quantization mode: {mode}") 98 | 99 | torch_linear_cls = torch.nn.Linear 100 | torch.nn.Linear = quantized_linear_cls 101 | yield 102 | torch.nn.Linear = torch_linear_cls 103 | 104 | 105 | # this is taken from torchhacks https://github.com/lernapparat/torchhacks 106 | 107 | 108 | class NotYetLoadedTensor: 109 | def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args): 110 | self.metatensor = metatensor 111 | self.archiveinfo = archiveinfo 112 | self.storageinfo = storageinfo 113 | self.rebuild_args = rebuild_args 114 | 115 | @classmethod 116 | def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None): 117 | ret = func(*args) 118 | if isinstance(ret, NotYetLoadedTensor): 119 | old_lt = ret._load_tensor 120 | 121 | def _load_tensor(): 122 | t = 
old_lt() 123 | return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state) 124 | 125 | ret._load_tensor = _load_tensor 126 | return ret 127 | return torch._tensor._rebuild_from_type_v2(func, new_type, args, state) 128 | 129 | @classmethod 130 | def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None): 131 | if isinstance(data, NotYetLoadedTensor): 132 | old_lt = data._load_tensor 133 | 134 | def _load_tensor(): 135 | t = old_lt() 136 | return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks) 137 | 138 | data._load_tensor = _load_tensor 139 | return data 140 | return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks) 141 | 142 | @classmethod 143 | def rebuild_tensor_v2( 144 | cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None 145 | ): 146 | rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata) 147 | metatensor = torch._utils._rebuild_tensor_v2( 148 | storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata 149 | ) 150 | storageinfo = storage.archiveinfo 151 | return NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args) 152 | 153 | def _load_tensor(self): 154 | name, storage_cls, fn, device, size = self.storageinfo 155 | dtype = self.metatensor.dtype 156 | 157 | uts = ( 158 | self.archiveinfo.zipfile_context.zf.get_storage_from_record( 159 | f"data/{fn}", size * torch._utils._element_size(dtype), torch.UntypedStorage 160 | ) 161 | ._typed_storage() 162 | ._untyped_storage 163 | ) 164 | with warnings.catch_warnings(): 165 | warnings.simplefilter("ignore") 166 | storage = torch.storage.TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True) 167 | return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args) 168 | 169 | @classmethod 170 | def __torch_function__(cls, func, types, args=(), kwargs=None): 171 | if kwargs is None: 172 | kwargs = {} 173 | loaded_args = [(a._load_tensor() if isinstance(a, NotYetLoadedTensor) else a) for a in args] 174 | return func(*loaded_args, **kwargs) 175 | # gc.collect would be costly here, maybe do it optionally 176 | 177 | def __getattr__(self, name): 178 | # properties 179 | ## TODO: device, is_...?? 180 | ## TODO: mH, mT, H, T, data, imag, real 181 | ## name ??? 
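# The lookups below fall into two buckets: metadata such as dtype, shape or requires_grad
# is answered from the meta tensor without touching the checkpoint on disk, whereas
# `contiguous` has to call `_load_tensor()` and materialize the real storage (needed for
# quantization); any other attribute raises AttributeError so misuse fails loudly.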
182 | if name in { 183 | "dtype", 184 | "grad", 185 | "grad_fn", 186 | "layout", 187 | "names", 188 | "ndim", 189 | "output_nr", 190 | "requires_grad", 191 | "retains_grad", 192 | "shape", 193 | "volatile", 194 | }: 195 | return getattr(self.metatensor, name) 196 | if name in {"size"}: 197 | return getattr(self.metatensor, name) 198 | # materializing with contiguous is needed for quantization 199 | if name in {"contiguous"}: 200 | return getattr(self._load_tensor(), name) 201 | 202 | raise AttributeError(f"{type(self)} does not have {name}") 203 | 204 | def __repr__(self): 205 | return f"NotYetLoadedTensor({repr(self.metatensor)})" 206 | 207 | 208 | class LazyLoadingUnpickler(pickle.Unpickler): 209 | def __init__(self, file, zipfile_context): 210 | super().__init__(file) 211 | self.zipfile_context = zipfile_context 212 | 213 | def find_class(self, module, name): 214 | res = super().find_class(module, name) 215 | if module == "torch._utils" and name == "_rebuild_tensor_v2": 216 | return partial(NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self) 217 | if module == "torch._tensor" and name == "_rebuild_from_type_v2": 218 | return partial(NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self) 219 | if module == "torch._utils" and name == "_rebuild_parameter": 220 | return partial(NotYetLoadedTensor.rebuild_parameter, archiveinfo=self) 221 | return res 222 | 223 | def persistent_load(self, pid): 224 | name, cls, fn, device, size = pid 225 | with warnings.catch_warnings(): 226 | warnings.simplefilter("ignore") 227 | s = torch.storage.TypedStorage(dtype=cls().dtype, device="meta") 228 | s.archiveinfo = pid 229 | return s 230 | 231 | 232 | class lazy_load: 233 | def __init__(self, fn): 234 | self.zf = torch._C.PyTorchFileReader(str(fn)) 235 | with BytesIO(self.zf.get_record("data.pkl")) as pkl: 236 | mup = LazyLoadingUnpickler(pkl, self) 237 | self.sd = mup.load() 238 | 239 | def __enter__(self): 240 | return self.sd 241 | 242 | def __exit__(self, exc_type, exc_val, exc_tb): 243 | del self.zf # I don't think there is a way to force closing... 244 | self.zf = None 245 | 246 | 247 | def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None: 248 | files = { 249 | "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(), 250 | "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), 251 | "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file() or ( 252 | checkpoint_dir / "tokenizer.model" 253 | ).is_file(), 254 | "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(), 255 | } 256 | if checkpoint_dir.is_dir(): 257 | if all(files.values()): 258 | # we're good 259 | return 260 | problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}" 261 | else: 262 | problem = " is not a checkpoint directory" 263 | 264 | # list locally available checkpoints 265 | available = list(Path("checkpoints").glob("*/*")) 266 | if available: 267 | options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available]) 268 | extra = f"\nYou have downloaded locally:{options}\n" 269 | else: 270 | extra = "" 271 | 272 | error_message = ( 273 | f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}." 
274 | "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n" 275 | f"{extra}\nSee all download options by running:\n python scripts/download.py" 276 | ) 277 | print(error_message, file=sys.stderr) 278 | raise SystemExit(1) 279 | 280 | 281 | class SavingProxyForStorage: 282 | def __init__(self, obj, saver, protocol_version=5): 283 | self.protocol_version = protocol_version 284 | self.saver = saver 285 | if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)): 286 | raise TypeError(f"expected storage, not {type(obj)}") 287 | 288 | # this logic is taken from PyTorch 2.0+ torch/serialization.py 289 | if isinstance(obj, torch.storage.TypedStorage): 290 | # PT upstream wants to deprecate this eventually... 291 | storage = obj._untyped_storage 292 | storage_type_str = obj._pickle_storage_type() 293 | storage_type = getattr(torch, storage_type_str) 294 | storage_numel = obj._size() 295 | else: 296 | storage = obj 297 | storage_type = normalize_storage_type(type(obj)) 298 | storage_numel = storage.nbytes() 299 | 300 | storage_key = saver._write_storage_and_return_key(storage) 301 | location = torch.serialization.location_tag(storage) 302 | 303 | self.storage_info = ("storage", storage_type, storage_key, location, storage_numel) 304 | 305 | def __reduce_ex__(self, protocol_version): 306 | assert False, "this should be handled with out of band" 307 | 308 | 309 | class SavingProxyForTensor: 310 | def __init__(self, tensor, saver, protocol_version=5): 311 | self.protocol_version = protocol_version 312 | self.reduce_ret_fn, (storage, *other_reduce_args) = tensor.__reduce_ex__(protocol_version) 313 | assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates" 314 | storage_proxy = SavingProxyForStorage(storage, saver, protocol_version=protocol_version) 315 | self.reduce_args = (storage_proxy, *other_reduce_args) 316 | 317 | def __reduce_ex__(self, protocol_version): 318 | if protocol_version != self.protocol_version: 319 | raise RuntimeError(f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}") 320 | return self.reduce_ret_fn, self.reduce_args 321 | 322 | 323 | class IncrementalPyTorchPickler(pickle.Pickler): 324 | def __init__(self, saver, *args, **kwargs): 325 | super().__init__(*args, **kwargs) 326 | self.storage_dtypes = {} 327 | self.saver = saver 328 | self.id_map = {} 329 | 330 | # this logic is taken from PyTorch 2.0+ torch/serialization.py 331 | def persistent_id(self, obj): 332 | # FIXME: the docs say that persistent_id should only return a string 333 | # but torch store returns tuples. 
This works only in the binary protocol 334 | # see 335 | # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects 336 | # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 337 | if isinstance(obj, SavingProxyForStorage): 338 | return obj.storage_info 339 | 340 | if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): 341 | if isinstance(obj, torch.storage.TypedStorage): 342 | # TODO: Once we decide to break serialization FC, this case 343 | # can be deleted 344 | storage = obj._untyped_storage 345 | storage_dtype = obj.dtype 346 | storage_type_str = obj._pickle_storage_type() 347 | storage_type = getattr(torch, storage_type_str) 348 | storage_numel = obj._size() 349 | 350 | else: 351 | storage = obj 352 | storage_dtype = torch.uint8 353 | storage_type = normalize_storage_type(type(obj)) 354 | storage_numel = storage.nbytes() 355 | 356 | # If storage is allocated, ensure that any other saved storages 357 | # pointing to the same data all have the same dtype. If storage is 358 | # not allocated, don't perform this check 359 | if storage.data_ptr() != 0: 360 | if storage.data_ptr() in self.storage_dtypes: 361 | if storage_dtype != self.storage_dtypes[storage.data_ptr()]: 362 | raise RuntimeError( 363 | "Cannot save multiple tensors or storages that view the same data as different types" 364 | ) 365 | else: 366 | self.storage_dtypes[storage.data_ptr()] = storage_dtype 367 | 368 | storage_key = self.id_map.get(storage._cdata) 369 | if storage_key is None: 370 | storage_key = self.saver._write_storage_and_return_key(storage) 371 | self.id_map[storage._cdata] = storage_key 372 | location = torch.serialization.location_tag(storage) 373 | 374 | return ("storage", storage_type, storage_key, location, storage_numel) 375 | 376 | return None 377 | 378 | 379 | class incremental_save: 380 | def __init__(self, name): 381 | self.name = name 382 | self.zipfile = torch._C.PyTorchFileWriter(str(name)) 383 | self.has_saved = False 384 | self.next_key = 0 385 | 386 | def __enter__(self): 387 | return self 388 | 389 | def store_early(self, tensor): 390 | if isinstance(tensor, torch.Tensor): 391 | return SavingProxyForTensor(tensor, self) 392 | raise TypeError(f"can only store tensors early, not {type(tensor)}") 393 | 394 | def save(self, obj): 395 | if self.has_saved: 396 | raise RuntimeError("have already saved") 397 | # Write the pickle data for `obj` 398 | data_buf = BytesIO() 399 | pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5) 400 | pickler.dump(obj) 401 | data_value = data_buf.getvalue() 402 | self.zipfile.write_record("data.pkl", data_value, len(data_value)) 403 | self.has_saved = True 404 | 405 | def _write_storage_and_return_key(self, storage): 406 | if self.has_saved: 407 | raise RuntimeError("have already saved") 408 | key = self.next_key 409 | self.next_key += 1 410 | name = f"data/{key}" 411 | if storage.device.type != "cpu": 412 | storage = storage.cpu() 413 | num_bytes = storage.nbytes() 414 | self.zipfile.write_record(name, storage.data_ptr(), num_bytes) 415 | return key 416 | 417 | def __exit__(self, type, value, traceback): 418 | self.zipfile.write_end_of_file() 419 | 420 | 421 | T = TypeVar("T") 422 | 423 | 424 | def step_csv_logger(*args: Any, cls: Type[T] = CSVLogger, **kwargs: Any) -> T: 425 | logger = cls(*args, **kwargs) 426 | 427 | def merge_by(dicts, key): 428 | from collections import defaultdict 429 | 430 | out = defaultdict(dict) 431 | for d in dicts: 432 | if key in d: 433 | out[d[key]].update(d) 
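# Illustrative behaviour (hypothetical rows): merging on "step" turns
# [{"step": 0, "loss": 1.0}, {"step": 0, "lr": 0.1}, {"step": 1, "loss": 0.9}] into
# [{"step": 0, "loss": 1.0, "lr": 0.1}, {"step": 1, "loss": 0.9}], i.e. rows logged at
# the same step are collapsed into a single CSV row, ordered by step.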
434 | return [v for _, v in sorted(out.items())] 435 | 436 | def save(self) -> None: 437 | """Overridden to merge CSV by the step number.""" 438 | import csv 439 | 440 | if not self.metrics: 441 | return 442 | metrics = merge_by(self.metrics, "step") 443 | keys = sorted({k for m in metrics for k in m}) 444 | with self._fs.open(self.metrics_file_path, "w", newline="") as f: 445 | writer = csv.DictWriter(f, fieldnames=keys) 446 | writer.writeheader() 447 | writer.writerows(metrics) 448 | 449 | logger.experiment.save = MethodType(save, logger.experiment) 450 | 451 | return logger 452 | 453 | 454 | def chunked_cross_entropy( 455 | logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128 456 | ) -> torch.Tensor: 457 | # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate 458 | # the memory usage in fine-tuning settings with low number of parameters. 459 | # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing 460 | # the memory spike's magnitude 461 | 462 | # lm_head was chunked (we are fine-tuning) 463 | if isinstance(logits, list): 464 | # don't want to chunk cross entropy 465 | if chunk_size == 0: 466 | logits = torch.cat(logits, dim=1) 467 | logits = logits.reshape(-1, logits.size(-1)) 468 | targets = targets.reshape(-1) 469 | return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) 470 | 471 | # chunk cross entropy 472 | logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits] 473 | target_chunks = [target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)] 474 | loss_chunks = [ 475 | torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") 476 | for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) 477 | ] 478 | return torch.cat(loss_chunks).mean() 479 | 480 | # no chunking at all 481 | logits = logits.reshape(-1, logits.size(-1)) 482 | targets = targets.reshape(-1) 483 | if chunk_size == 0: 484 | return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) 485 | 486 | # lm_head wasn't chunked, chunk cross entropy 487 | logit_chunks = logits.split(chunk_size) 488 | target_chunks = targets.split(chunk_size) 489 | loss_chunks = [ 490 | torch.nn.functional.cross_entropy(logit_chunk, target_chunk, ignore_index=-1, reduction="none") 491 | for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) 492 | ] 493 | return torch.cat(loss_chunks).mean() 494 | 495 | 496 | def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict: 497 | for checkpoint_name, attribute_name in mapping.items(): 498 | full_checkpoint_name = prefix + checkpoint_name 499 | if full_checkpoint_name in state_dict: 500 | full_attribute_name = prefix + attribute_name 501 | state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name) 502 | return state_dict 503 | 504 | 505 | def get_default_supported_precision(training: bool, tpu: bool = False) -> str: 506 | """Return default precision that is supported by the hardware. 
507 | 508 | Args: 509 | training: `-mixed` or `-true` version of the precision to use 510 | tpu: whether TPU device is used 511 | 512 | Returns: 513 | default precision that is suitable for the task and is supported by the hardware 514 | """ 515 | if tpu: 516 | return "32-true" 517 | if not torch.cuda.is_available() or torch.cuda.is_bf16_supported(): 518 | return "bf16-mixed" if training else "bf16-true" 519 | return "16-mixed" if training else "16-true" 520 | --------------------------------------------------------------------------------
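As a quick sanity check of the utilities defined in lit_gpt/utils.py, the sketch below exercises chunked_cross_entropy and get_default_supported_precision; the tensor shapes and vocabulary size are made up for illustration and do not correspond to any config shipped in this repository.

import torch

from lit_gpt.utils import chunked_cross_entropy, get_default_supported_precision

# Hypothetical batch: 2 sequences of 16 tokens over a 100-token vocabulary.
logits = torch.randn(2, 16, 100)
targets = torch.randint(0, 100, (2, 16))

# chunk_size=0 is a single F.cross_entropy call; chunk_size=8 splits the flattened
# (B*T, vocab) logits into row chunks so the backward pass allocates smaller buffers.
full = chunked_cross_entropy(logits, targets, chunk_size=0)
chunked = chunked_cross_entropy(logits, targets, chunk_size=8)
assert torch.allclose(full, chunked)

# "bf16-mixed" on bf16-capable GPUs (or on CPU), otherwise "16-mixed" when training.
print(get_default_supported_precision(training=True))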