├── .gitignore ├── LICENSE ├── README.md ├── README_cn.md ├── convert_llama_weights_to_hf.py ├── data ├── dummy_cn.json ├── dummy_en.json └── test.json ├── deepspeed-config.json ├── fastchat ├── data │ ├── changeto_alpaca.py │ ├── clean_sharegpt.py │ ├── merge.py │ ├── optional_clean.py │ ├── split_long_conversation.py │ └── trans_dummy.py └── train │ └── train_lora.py ├── generate.py ├── requirements.txt ├── templates └── alpaca.json └── utils └── prompter.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 git-cloner 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [中文](https://github.com/git-cloner/llama-lora-fine-tuning/blob/main/README_cn.md) 2 | 3 | # Fine-tuning vicuna-7b on a single 16G GPU 4 | 5 | ## 1. Overview 6 | 7 | There are generally two schemes for fine-tuning FaceBook/LLaMA. One is Stanford's alpaca series, and the other is Vicuna based on shareGPT corpus. Vicuna uses multi-round dialogue corpus, and the training effect is better than alpaca which is defaulted to single-round dialogue. Therefore, it is recommended to fine-tune Llama based on Vicuna. 8 | The two fine-tuning ways are described in detail in the following projects (the description of lora mode in FastChat is relatively simple).
9 | https://github.com/tloen/alpaca-lora
10 | https://github.com/lm-sys/FastChat
11 | Alpaca-lora has modest memory requirements: a 12G 2080Ti is enough. Training a multi-round dialogue model like Vicuna, however, demands far more GPU memory: at least 24G [the official recommendation is 4 * V100 (32G)]. 12 | If you have a high-end graphics card, simply follow the projects linked above to train. If you only have a 16G card but want to reproduce the Vicuna model on your own corpus, you have to keep cutting precision, from 32-bit down to 16-bit half precision and then to 8-bit, and apply training accelerations to reach the goal. 13 | 14 | ## 2. Fine-tuning method 15 | 16 | • Use the lora method to train only part of the parameters<br>
17 | • Use the half-precision llama-7b-hf as the base model<br>
18 | • Use load_in_8bit to load the basic model
19 | • Use peft technology for fine-tuning
20 | • Use bitsandbytes to accelerate
21 | Building on FastChat, this article modifies the LoRA training code and uses the ShareGPT corpus to fine-tune on a 16G card, occupying about 13G of GPU memory; a minimal sketch of how the pieces above fit together follows.<br>
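The sketch below is an illustration only. It mirrors the relevant calls in `fastchat/train/train_lora.py` (8-bit loading of the half-precision base model, `prepare_model_for_kbit_training`, and a LoRA adapter over the attention projections) but omits the DeepSpeed and Trainer wiring, so treat it as the recipe rather than the training entry point; the base-model path is the converted HF model from section 3.2.2.

```python
# Minimal sketch of the fine-tuning recipe (illustration only):
# load the half-precision llama-7b-hf in 8 bits and attach a LoRA adapter with peft.
import torch
from transformers import LlamaForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

base_model_path = "./pyllama_data/output/7B"  # converted HF model (see section 3.2.2)

model = LlamaForCausalLM.from_pretrained(
    base_model_path,
    load_in_8bit=True,           # bitsandbytes 8-bit loading
    torch_dtype=torch.float16,   # half-precision weights
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)  # make the 8-bit model trainable

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # train only these attention projections
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the small LoRA matrices are trainable
```

The environment used in this article is as follows: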
22 | • Operating system: CentOS or Ubuntu<br>
23 | • NVIDIA P100 or T4: 16G of GPU memory or above<br>
24 | • CUDA, conda 25 | 26 | ## 3.Fine-tuning process 27 | 28 | ### 3.1 Install dependency environment 29 | 30 | #### 3.1.1 Download source code 31 | 32 | ```bash 33 | git clone https://github.com/git-cloner/llama-lora-fine-tuning 34 | cd llama-lora-fine-tuning 35 | ``` 36 | 37 | #### 3.1.2 Install fine-tuning dependency environment 38 | 39 | ##### 3.1.2.1 Install pkg-config 40 | 41 | ```bash 42 | wget https://pkg-config.freedesktop.org/releases/pkg-config-0.29.2.tar.gz 43 | tar -zxvf pkg-config-0.29.2.tar.gz 44 | cd pkg-config-0.29.2 45 | ./configure --with-internal-glib 46 | make -j4 47 | make check 48 | sudo make install 49 | ``` 50 | 51 | ##### 3.1.2.2 Install libicu 52 | 53 | ```bash 54 | wget https://mirrors.aliyun.com/blfs/conglomeration/icu/icu4c-73_1-src.tgz 55 | tar xf icu4c-73_1-src.tgz 56 | cd icu/source 57 | ./configure 58 | make 59 | make check 60 | sudo make install 61 | sudo ldconfig 62 | ``` 63 | 64 | ##### 3.1.2.3 Install packages 65 | 66 | ```bash 67 | conda create -n llama-lora python=3.10 68 | conda activate llama-lora 69 | pip3 install -r requirements.txt 70 | ``` 71 | 72 | ### 3.2 Prepare the Llama model 73 | 74 | You can download the original model and convert it to half precision, or download the converted half precision model directly from https://huggingface.co/decapoda-research/llama-7b-hf. 75 | 76 | #### 3.2.1 Download the Llama model 77 | 78 | ```bash 79 | export GIT_TRACE=1 80 | export GIT_CURL_VERBOSE=1 81 | pip3 install git+https://github.com/juncongmoo/pyllama -i https://pypi.mirrors.ustc.edu.cn/simple --trusted-host=pypi.mirrors.ustc.edu.cn 82 | python -m llama.download --model_size 7B 83 | ``` 84 | 85 | #### 3.2.2 Convert the model to huggingface format 86 | 87 | ```bash 88 | CUDA_VISIBLE_DEVICES=1 python3 ./convert_llama_weights_to_hf.py --input_dir ./pyllama_data --model_size 7B --output_dir ./pyllama_data/output/7B 89 | ``` 90 | 91 | ### 3.3 Organize the corpus 92 | 93 | #### 3.3.1 Corpus download 94 | 95 | ```ash 96 | Download 52k ShareGPT: https://huggingface.co/datasets/RyokoAI/ShareGPT52K 97 | Other corpora refer to: https://github.com/Zjh-819/LLMDataHub 98 | Download sg_90k_part1.json and sg_90k_part2.json into the data directory 99 | ``` 100 | 101 | #### 3.3.2 Merge corpus files 102 | 103 | ```bash 104 | python3 fastchat/data/merge.py --in ./data/sg_90k_part1.json ./data/sg_90k_part2.json ./data/dummy_cn.json ./data/dummy_en.json --out ./data/sg_90k.json 105 | ``` 106 | 107 | #### 3.3.3 Html to Markdown 108 | 109 | ```bash 110 | python3 fastchat/data/clean_sharegpt.py --in ./data/sg_90k.json --out ./data/sharegpt_clean.json 111 | ``` 112 | 113 | #### 3.3.4 Remove some unused languages (optional) 114 | 115 | ```bash 116 | python3 fastchat/data/optional_clean.py --in ./data/sharegpt_clean.json --out ./data/sharegpt_clean_1.json --skip-lang SOME_LANGUAGE_CODE 117 | The values of SOME_LANGUAGE_CODE are as follows: 118 | en - English 119 | es - Spanish 120 | fr - French 121 | de - German 122 | it - Italian 123 | ja - Japanese 124 | ko - Korean 125 | zh - Chinese 126 | ar - Arabic 127 | ru - Russian 128 | pt - Portuguese 129 | nl - Dutch 130 | ``` 131 | 132 | #### 3.3.5 Split long conversations into short dialogues 133 | 134 | ```bash 135 | CUDA_VISIBLE_DEVICES=1 python3 fastchat/data/split_long_conversation.py --in ./data/sharegpt_clean.json --out ./data/sharegpt_clean_split.json --model-name ./pyllama_data/output/7B 136 | ``` 137 | 138 | ### 3.4 Fine-tuning 139 | 140 | #### 3.4.1 Fine-tuning command 141 | 142 | 143 | 144 | ```bash 145 | # Disable 
wandb 146 | wandb disabled 147 | # In order to prevent the SSH terminal from disconnecting and stopping the training, the training can run in the background (remove the # in three places to run in the background) 148 | # If you have multiple GPUs,using --num_gpus parameter 149 | CUDA_VISIBLE_DEVICES=0,1 \ #nohup \ 150 | deepspeed --num_gpus=2 fastchat/train/train_lora.py \ 151 | --deepspeed ./deepspeed-config.json \ 152 | --lora_r 8 \ 153 | --lora_alpha 16 \ 154 | --lora_dropout 0.05 \ 155 | --model_name_or_path ./pyllama_data/output/7B \ 156 | --data_path ./data/sharegpt_clean_split.json \ 157 | --fp16 True \ 158 | --output_dir ./output \ 159 | --num_train_epochs 1 \ 160 | --per_device_train_batch_size 14 \ 161 | --per_device_eval_batch_size 14 \ 162 | --gradient_accumulation_steps 1 \ 163 | --evaluation_strategy "no" \ 164 | --save_strategy "steps" \ 165 | --save_steps 2400 \ 166 | --save_total_limit 5 \ 167 | --learning_rate 2e-5 \ 168 | --weight_decay 0. \ 169 | --warmup_ratio 0.03 \ 170 | --lr_scheduler_type "cosine" \ 171 | --logging_steps 1 \ 172 | --model_max_length 512 \ 173 | --gradient_checkpointing True #>> lora.log 2>&1 & 174 | # If running in the background, tail lora.log to check the training progress 175 | tail -f lora.log 176 | ``` 177 | 178 | #### 3.4.2 Fine-tuning performance 179 | 180 | Fine-tuning on P100 (16G) occupies 13.5G of memory. In the case of one round of training, it takes 120 hours, about 5 days, which is still very time-consuming. The effect of the resulting model needs to be verified. 181 | model_max_length will affect the training time. If set to 1024, the time will be halved compared to 2048, but it will affect the inference effect. 182 | 183 | #### 3.4.3 Fine-tuning on A100 184 | 185 | fine-tuning on single A100 and take about 16 hours. 186 | 187 | ```bash 188 | deepspeed fastchat/train/train_lora.py \ 189 | --deepspeed ./deepspeed-config.json \ 190 | --lora_r 8 \ 191 | --lora_alpha 16 \ 192 | --lora_dropout 0.05 \ 193 | --model_name_or_path ./pyllama_data/output/7B \ 194 | --data_path ./data/sharegpt_clean_split.json \ 195 | --fp16 True \ 196 | --output_dir ./output \ 197 | --num_train_epochs 1 \ 198 | --per_device_train_batch_size 56 \ 199 | --per_device_eval_batch_size 56 \ 200 | --gradient_accumulation_steps 1\ 201 | --evaluation_strategy "no" \ 202 | --save_strategy "steps" \ 203 | --save_steps 1200 \ 204 | --save_total_limit 5 \ 205 | --learning_rate 2e-5 \ 206 | --weight_decay 0. \ 207 | --warmup_ratio 0.03 \ 208 | --lr_scheduler_type "cosine" \ 209 | --logging_steps 1 \ 210 | --model_max_length 1024 \ 211 | --gradient_checkpointing True 212 | ``` 213 | 214 | ## 4、Test trained model 215 | 216 | ### 4.1 model file structure 217 | 218 | The trained LoRa peft model consists of adapter_config.json, adapter_model.bin, and trainer_state.json. Below is the file structure of peft and the original llama model. 
219 | 220 | ```bash 221 | model 222 | ───llama-peft 223 | │ adapter_config.json 224 | │ adapter_model.bin 225 | │ trainer_state.json 226 | │ 227 | └──llama_7b 228 | config.json 229 | generation_config.json 230 | pytorch_model-00001-of-00002.bin 231 | pytorch_model-00002-of-00002.bin 232 | pytorch_model.bin.index.json 233 | special_tokens_map.json 234 | tokenizer.json 235 | tokenizer.model 236 | tokenizer_config.json 237 | ``` 238 | 239 | ### 4.2 test generate 240 | 241 | ```bash 242 | CUDA_VISIBLE_DEVICES=0 python generate.py --base_model ./model/llama-7b --lora_weights ./model/llama-peft 243 | ``` 244 | 245 | -------------------------------------------------------------------------------- /README_cn.md: -------------------------------------------------------------------------------- 1 | # 在单块16G的推理卡上微调复现vicuna-7b 2 | 3 | ## 1、概述 4 | 5 | 对FaceBook/LLaMA的微调一般有两种方案,一种是斯坦福的alpaca系列,一种是基于shareGPT语料的Vicuna方向。Vicuna采用的是多轮对话的语料,训练效果要好于默认为单轮对话的alpaca,所以要在Llama基础上微调,建议首选Vicuna的方式。 6 | 7 | 关于两种微调方式,在以下项目中都有详细描述(FastChat中对lora模式的说明比较简单)。 8 | 9 | https://github.com/tloen/alpaca-lora 10 | 11 | https://github.com/lm-sys/FastChat 12 | 13 | Alpaca-lora的训练对内存要求不高,大概12G的2080Ti,就可以支持,但训练多轮会话的类Vicuna的模型,则对显存的要求比较高,Vicuna模式的训练至少需要24G显存【官方建议是4 * V100(32G)】。 14 | 15 | 如果有足够高端的显卡,只要跟随文件训练即可,如果只有16G的显卡,又想自己定制语料复现Vicuna模型,那就得想很多办法,不断地降精度,从32位降到半精度16位,再从16位降成8位,再辅以一些加速训练方法才能达到目的。 16 | 17 | ## 2、微调方案 18 | 19 | - 采用lora方式只训练一部分参数 20 | - 基础模型采用半精度llama-7b-hf 21 | - 使用load_in_8bit装载基础模型 22 | - 采用peft技术微调 23 | - 采用bitsandbytes加速 24 | 25 | 所以本文在FastChat的基础上,修改lora训练代码,使用shareGPT语料,在16G显存的推理卡进行微调,大概占用显存13G左右。 26 | 27 | - 操作系统 centos或ubuntu 28 | 29 | - NVIDA P100或T4:16G显存或以上 30 | 31 | - CUDA、conda:https://zhuanlan.zhihu.com/p/597063490 32 | 33 | ## 3、微调过程 34 | 35 | ### 3.1、安装依赖环境 36 | 37 | #### 3.1.1 下载源码 38 | 39 | ```bash 40 | git clone https://github.com/git-cloner/llama-lora-fine-tuning 41 | cd llama-lora-fine-tuning 42 | ``` 43 | 44 | #### 3.1.2 安装微调依赖环境 45 | 46 | ##### 3.1.2.1 install pkg-config 47 | 48 | ```bash 49 | wget https://pkg-config.freedesktop.org/releases/pkg-config-0.29.2.tar.gz 50 | tar -zxvf pkg-config-0.29.2.tar.gz 51 | cd pkg-config-0.29.2 52 | ./configure --with-internal-glib 53 | make -j4 54 | make check 55 | sudo make install 56 | ``` 57 | 58 | ##### 3.1.2.2 install libicu 59 | 60 | ```bash 61 | wget https://mirrors.aliyun.com/blfs/conglomeration/icu/icu4c-73_1-src.tgz 62 | tar xf icu4c-73_1-src.tgz 63 | cd icu/source 64 | ./configure 65 | make 66 | make check 67 | sudo make install 68 | sudo ldconfig 69 | ``` 70 | 71 | ##### 3.1.2.3 安装驱动及conda 72 | 73 | https://zhuanlan.zhihu.com/p/597063490 74 | 75 | ##### 3.1.2.4 install packages 76 | 77 | ```bash 78 | conda create -n llama-lora python=3.10 79 | conda activate llama-lora 80 | pip3 install -r requirements.txt -i https://pypi.mirrors.ustc.edu.cn/simple --trusted-host=pypi.mirrors.ustc.edu.cn 81 | ``` 82 | 83 | ### 3.2、准备Llama模型 84 | 85 | 可以用以下办法下载原始模型后转换为半精度,也可以从https://huggingface.co/decapoda-research/llama-7b-hf直接下载转换好的半精度模型。 86 | 87 | #### 3.2.1 下载Llama模型 88 | 89 | ```bash 90 | export GIT_TRACE=1 91 | export GIT_CURL_VERBOSE=1 92 | pip3 install git+https://github.com/juncongmoo/pyllama -i https://pypi.mirrors.ustc.edu.cn/simple --trusted-host=pypi.mirrors.ustc.edu.cn 93 | python -m llama.download --model_size 7B 94 | ``` 95 | 96 | #### 3.2.2 转换模型为huggingface格式 97 | 98 | ```bash 99 | CUDA_VISIBLE_DEVICES=1 python3 ./convert_llama_weights_to_hf.py --input_dir ./pyllama_data --model_size 7B --output_dir ./pyllama_data/output/7B 100 | ``` 101 | 102 | 
### 3.3、整理语料 103 | 104 | #### 3.3.1 语料下载 105 | 106 | 下载52k的ShareGPT:https://huggingface.co/datasets/RyokoAI/ShareGPT52K 107 | 108 | 其他语料参见:https://github.com/Zjh-819/LLMDataHub 109 | 110 | 下载的sg_90k_part1.json和sg_90k_part2.json放到data下 111 | 112 | #### 3.3.2 合并语料文件 113 | 114 | ```bash 115 | python3 fastchat/data/merge.py --in ./data/sg_90k_part1.json ./data/sg_90k_part2.json ./data/dummy_cn.json ./data/dummy_en.json --out ./data/sg_90k.json 116 | ``` 117 | 118 | #### 3.3.3 html转markdown 119 | 120 | ```bash 121 | python3 fastchat/data/clean_sharegpt.py --in ./data/sg_90k.json --out ./data/sharegpt_clean.json 122 | ``` 123 | 124 | #### 3.3.4 去掉一些用不到的语言(可选) 125 | 126 | ```bash 127 | python3 fastchat/data/optional_clean.py --in ./data/sharegpt_clean.json --out ./data/sharegpt_clean_1.json --skip-lang SOME_LANGUAGE_CODE 128 | 其中SOME_LANGUAGE_CODE的取值举例如下: 129 | en - 英语 130 | es - 西班牙语 131 | fr - 法语 132 | de - 德语 133 | it - 意大利语 134 | ja - 日语 135 | ko - 朝鲜语 136 | zh - 中文 137 | ar - 阿拉伯语 138 | ru - 俄语 139 | pt - 葡萄牙语 140 | nl - 荷兰语 141 | ``` 142 | 143 | #### 3.3.5 将长会话切分成短对话 144 | 145 | ```shell 146 | CUDA_VISIBLE_DEVICES=1 python3 fastchat/data/split_long_conversation.py --in ./data/sharegpt_clean.json --out ./data/sharegpt_clean_split.json --model-name ./pyllama_data/output/7B 147 | ``` 148 | 149 | ### 3.4、微调 150 | 151 | #### 3.4.1 微调命令 152 | 153 | ```bash 154 | # 禁用wandb 155 | wandb disabled 156 | # 为了防止ssh终端断开导致训练中止,训练可在后台运行(去掉#三处注释即可在后台运行) 157 | # 如果有多颗GPU,可以用--num_gpus参数 158 | CUDA_VISIBLE_DEVICES=0,1 \ #nohup \ 159 | deepspeed --num_gpus=2 fastchat/train/train_lora.py \ 160 | --deepspeed ./deepspeed-config.json \ 161 | --lora_r 8 \ 162 | --lora_alpha 16 \ 163 | --lora_dropout 0.05 \ 164 | --model_name_or_path ./pyllama_data/output/7B \ 165 | --data_path ./data/sharegpt_clean_split.json \ 166 | --fp16 True \ 167 | --output_dir ./output \ 168 | --num_train_epochs 1 \ 169 | --per_device_train_batch_size 14 \ 170 | --per_device_eval_batch_size 14 \ 171 | --gradient_accumulation_steps 1 \ 172 | --evaluation_strategy "no" \ 173 | --save_strategy "steps" \ 174 | --save_steps 2400 \ 175 | --save_total_limit 5 \ 176 | --learning_rate 2e-5 \ 177 | --weight_decay 0. \ 178 | --warmup_ratio 0.03 \ 179 | --lr_scheduler_type "cosine" \ 180 | --logging_steps 1 \ 181 | --model_max_length 512 \ 182 | --gradient_checkpointing True #>> lora.log 2>&1 & 183 | # 如果在后台运行,则tail lora.log查看训练进度 184 | tail -f lora.log 185 | ``` 186 | 187 | #### 3.4.2 微调性能 188 | 189 | 在P100(16G)上进行微调,占用内存13.5G,在训练一轮的情况下,需要120个小时,大约5天时间,还是非常耗时时,形成的模型效果也有待验证。 190 | 191 | model_max_length会影响到训练的时长,如果设成1024,比2048的时长减少一半,但会影响到推理效果。 192 | 193 | #### 3.4.3 A100微调命令 194 | 195 | 单块A100微调的参数如下,大约需要16小时。 196 | 197 | ```bash 198 | deepspeed fastchat/train/train_lora.py \ 199 | --deepspeed ./deepspeed-config.json \ 200 | --lora_r 8 \ 201 | --lora_alpha 16 \ 202 | --lora_dropout 0.05 \ 203 | --model_name_or_path ./pyllama_data/output/7B \ 204 | --data_path ./data/sharegpt_clean_split.json \ 205 | --fp16 True \ 206 | --output_dir ./output \ 207 | --num_train_epochs 1 \ 208 | --per_device_train_batch_size 56 \ 209 | --per_device_eval_batch_size 56 \ 210 | --gradient_accumulation_steps 1\ 211 | --evaluation_strategy "no" \ 212 | --save_strategy "steps" \ 213 | --save_steps 1200 \ 214 | --save_total_limit 5 \ 215 | --learning_rate 2e-5 \ 216 | --weight_decay 0. 
\ 217 | --warmup_ratio 0.03 \ 218 | --lr_scheduler_type "cosine" \ 219 | --logging_steps 1 \ 220 | --model_max_length 1024 \ 221 | --gradient_checkpointing True 222 | ``` 223 | 224 | ## 4、测试模型 225 | 226 | ### 4.1 模型位置 227 | 228 | 训练好的lora peft模型由adapter_config.json、adapter_model.bin和trainer_state.json组成。下面是peft和原模型的目录结构。 229 | 230 | ```bash 231 | model 232 | ───llama-peft 233 | │ adapter_config.json 234 | │ adapter_model.bin 235 | │ trainer_state.json 236 | │ 237 | └──llama_7b 238 | config.json 239 | generation_config.json 240 | pytorch_model-00001-of-00002.bin 241 | pytorch_model-00002-of-00002.bin 242 | pytorch_model.bin.index.json 243 | special_tokens_map.json 244 | tokenizer.json 245 | tokenizer.model 246 | tokenizer_config.json 247 | ``` 248 | 249 | ### 4.2 测试生成 250 | 251 | ```bash 252 | CUDA_VISIBLE_DEVICES=0 python generate.py --base_model ./model/llama-7b --lora_weights ./model/llama-peft 253 | ``` 254 | 255 | -------------------------------------------------------------------------------- /convert_llama_weights_to_hf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import argparse 15 | import gc 16 | import json 17 | import math 18 | import os 19 | import shutil 20 | import warnings 21 | 22 | import torch 23 | 24 | from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer 25 | 26 | 27 | try: 28 | from transformers import LlamaTokenizerFast 29 | except ImportError as e: 30 | warnings.warn(e) 31 | warnings.warn( 32 | "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" 33 | ) 34 | LlamaTokenizerFast = None 35 | 36 | """ 37 | Sample usage: 38 | 39 | ``` 40 | python src/transformers/models/llama/convert_llama_weights_to_hf.py \ 41 | --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path 42 | ``` 43 | 44 | Thereafter, models can be loaded via: 45 | 46 | ```py 47 | from transformers import LlamaForCausalLM, LlamaTokenizer 48 | 49 | model = LlamaForCausalLM.from_pretrained("/output/path") 50 | tokenizer = LlamaTokenizer.from_pretrained("/output/path") 51 | ``` 52 | 53 | Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions 54 | come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). 
55 | """ 56 | 57 | INTERMEDIATE_SIZE_MAP = { 58 | "7B": 11008, 59 | "13B": 13824, 60 | "30B": 17920, 61 | "65B": 22016, 62 | } 63 | NUM_SHARDS = { 64 | "7B": 1, 65 | "13B": 2, 66 | "30B": 4, 67 | "65B": 8, 68 | } 69 | 70 | 71 | def compute_intermediate_size(n): 72 | return int(math.ceil(n * 8 / 3) + 255) // 256 * 256 73 | 74 | 75 | def read_json(path): 76 | with open(path, "r") as f: 77 | return json.load(f) 78 | 79 | 80 | def write_json(text, path): 81 | with open(path, "w") as f: 82 | json.dump(text, f) 83 | 84 | 85 | def write_model(model_path, input_base_path, model_size): 86 | os.makedirs(model_path, exist_ok=True) 87 | tmp_model_path = os.path.join(model_path, "tmp") 88 | os.makedirs(tmp_model_path, exist_ok=True) 89 | 90 | params = read_json(os.path.join(input_base_path, "params.json")) 91 | num_shards = NUM_SHARDS[model_size] 92 | n_layers = params["n_layers"] 93 | n_heads = params["n_heads"] 94 | n_heads_per_shard = n_heads // num_shards 95 | dim = params["dim"] 96 | dims_per_head = dim // n_heads 97 | base = 10000.0 98 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 99 | 100 | # permute for sliced rotary 101 | def permute(w): 102 | return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim) 103 | 104 | print(f"Fetching all parameters from the checkpoint at {input_base_path}.") 105 | # Load weights 106 | if model_size == "7B": 107 | # Not sharded 108 | # (The sharded implementation would also work, but this is simpler.) 109 | loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu") 110 | else: 111 | # Sharded 112 | loaded = [ 113 | torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu") 114 | for i in range(num_shards) 115 | ] 116 | param_count = 0 117 | index_dict = {"weight_map": {}} 118 | for layer_i in range(n_layers): 119 | filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" 120 | if model_size == "7B": 121 | # Unsharded 122 | state_dict = { 123 | f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( 124 | loaded[f"layers.{layer_i}.attention.wq.weight"] 125 | ), 126 | f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( 127 | loaded[f"layers.{layer_i}.attention.wk.weight"] 128 | ), 129 | f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], 130 | f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], 131 | f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"], 132 | f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"], 133 | f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"], 134 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"], 135 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"], 136 | } 137 | else: 138 | # Sharded 139 | # Note that in the 13B checkpoint, not cloning the two following weights will result in the checkpoint 140 | # becoming 37GB instead of 26GB for some reason. 
141 | state_dict = { 142 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ 143 | f"layers.{layer_i}.attention_norm.weight" 144 | ].clone(), 145 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ 146 | f"layers.{layer_i}.ffn_norm.weight" 147 | ].clone(), 148 | } 149 | state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( 150 | torch.cat( 151 | [ 152 | loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) 153 | for i in range(num_shards) 154 | ], 155 | dim=0, 156 | ).reshape(dim, dim) 157 | ) 158 | state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( 159 | torch.cat( 160 | [ 161 | loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim) 162 | for i in range(num_shards) 163 | ], 164 | dim=0, 165 | ).reshape(dim, dim) 166 | ) 167 | state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( 168 | [ 169 | loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim) 170 | for i in range(num_shards) 171 | ], 172 | dim=0, 173 | ).reshape(dim, dim) 174 | 175 | state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( 176 | [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 177 | ) 178 | state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( 179 | [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 180 | ) 181 | state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( 182 | [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 183 | ) 184 | state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( 185 | [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 186 | ) 187 | 188 | state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq 189 | for k, v in state_dict.items(): 190 | index_dict["weight_map"][k] = filename 191 | param_count += v.numel() 192 | torch.save(state_dict, os.path.join(tmp_model_path, filename)) 193 | 194 | filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" 195 | if model_size == "7B": 196 | # Unsharded 197 | state_dict = { 198 | "model.embed_tokens.weight": loaded["tok_embeddings.weight"], 199 | "model.norm.weight": loaded["norm.weight"], 200 | "lm_head.weight": loaded["output.weight"], 201 | } 202 | else: 203 | state_dict = { 204 | "model.norm.weight": loaded[0]["norm.weight"], 205 | "model.embed_tokens.weight": torch.cat( 206 | [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1 207 | ), 208 | "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), 209 | } 210 | 211 | for k, v in state_dict.items(): 212 | index_dict["weight_map"][k] = filename 213 | param_count += v.numel() 214 | torch.save(state_dict, os.path.join(tmp_model_path, filename)) 215 | 216 | # Write configs 217 | index_dict["metadata"] = {"total_size": param_count * 2} 218 | write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) 219 | 220 | config = LlamaConfig( 221 | hidden_size=dim, 222 | intermediate_size=compute_intermediate_size(dim), 223 | num_attention_heads=params["n_heads"], 224 | num_hidden_layers=params["n_layers"], 225 | rms_norm_eps=params["norm_eps"], 226 | ) 227 | config.save_pretrained(tmp_model_path) 228 | 229 | # Make space so we can load the model properly now. 
230 | del state_dict 231 | del loaded 232 | gc.collect() 233 | 234 | print("Loading the checkpoint in a Llama model.") 235 | model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 236 | # Avoid saving this as part of the config. 237 | del model.config._name_or_path 238 | 239 | print("Saving in the Transformers format.") 240 | model.save_pretrained(model_path) 241 | shutil.rmtree(tmp_model_path) 242 | 243 | 244 | def write_tokenizer(tokenizer_path, input_tokenizer_path): 245 | # Initialize the tokenizer based on the `spm` model 246 | tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast 247 | print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") 248 | tokenizer = tokenizer_class(input_tokenizer_path) 249 | tokenizer.save_pretrained(tokenizer_path) 250 | 251 | 252 | def main(): 253 | parser = argparse.ArgumentParser() 254 | parser.add_argument( 255 | "--input_dir", 256 | help="Location of LLaMA weights, which contains tokenizer.model and model folders", 257 | ) 258 | parser.add_argument( 259 | "--model_size", 260 | choices=["7B", "13B", "30B", "65B", "tokenizer_only"], 261 | ) 262 | parser.add_argument( 263 | "--output_dir", 264 | help="Location to write HF model and tokenizer", 265 | ) 266 | args = parser.parse_args() 267 | if args.model_size != "tokenizer_only": 268 | write_model( 269 | model_path=args.output_dir, 270 | input_base_path=os.path.join(args.input_dir, args.model_size), 271 | model_size=args.model_size, 272 | ) 273 | spm_path = os.path.join(args.input_dir, "tokenizer.model") 274 | write_tokenizer(args.output_dir, spm_path) 275 | 276 | 277 | if __name__ == "__main__": 278 | main() 279 | -------------------------------------------------------------------------------- /data/test.json: -------------------------------------------------------------------------------- 1 | [{ 2 | "id": "jsxN4Dv", 3 | "conversations": [ 4 | { 5 | "from": "human", 6 | "value": "\u6211\u9700\u8981\u4f60\u5e6b\u6211\u586b\u5beb\u82f1\u6587\u5c65\u6b77\uff0c\u6211\u7d66\u4f60\u4e2d\u6587\u4e4b\u5f8c\u4f60\u5e6b\u6211\u5f59\u6574\u4e26\u7ffb\u8b6f\u6210\u82f1\u6587\u3002" 7 | }, 8 | { 9 | "from": "gpt", 10 | "value": "\u597d\u7684\uff0c\u6211\u5f88\u6a02\u610f\u5e6b\u60a8\u5f59\u6574\u4e26\u7ffb\u8b6f\u6210\u82f1\u6587\u7684\u5c65\u6b77\u3002\u8acb\u60a8\u63d0\u4f9b\u4e2d\u6587\u7248\u7684\u5c65\u6b77\u7d66\u6211\u53c3\u8003\uff0c\u6211\u6703\u76e1\u5feb\u70ba\u60a8\u5b8c\u6210\u7ffb\u8b6f\u3002" 11 | }, 12 | { 13 | "from": "human", 14 | "value": "\u6211\u662fZHENG,JING-CHENG\uff0c\u76ee\u524d\u5c31\u8b80\u4f4d\u65bc\u53f0\u5317\u5e02\u6587\u5c71\u5340\u7684\u4e2d\u570b\u79d1\u6280\u5927\u5b78\uff0c\u6211\u662f\u8cc7\u8a0a\u7ba1\u7406\u7cfb\uff0c\u6211\u51fa\u751f\u5230\u73fe\u5728\u90fd\u662f\u4f4f\u53f0\u5317\u5e02\uff0c\u6211\u7236\u6bcd\u5f88\u6ce8\u91cd\u6211\u7684\u5916\u8a9e\u80fd\u529b\uff0c\u6211\u4e5f\u89ba\u5f97\u5916\u8a9e\u80fd\u529b\u975e\u5e38\u91cd\u8981\uff0c\u6211\u76ee\u524d\u6703html\u7684\u67b6\u8a2d\uff0c\u4e5f\u6703\u4e00\u9ede\u786c\u9ad4\u88dd\u4fee\u7684\u57fa\u790e\u3002" 15 | }, 16 | { 17 | "from": "gpt", 18 | "value": "\u4ee5\u4e0b\u662f\u60a8\u7684\u82f1\u6587\u5c65\u6b77\uff0c\u8acb\u60a8\u78ba\u8a8d\u4e26\u63d0\u4f9b\u610f\u898b\u3002\n\nResume\n\nPersonal Information:\nName: ZHENG, JING-CHENG\nCurrent Address: Taipei City, Taiwan\nEmail: [\u60a8\u7684\u96fb\u5b50\u90f5\u4ef6]\nPhone: [\u60a8\u7684\u96fb\u8a71\u865f\u78bc]\n\nEducation:\nDegree: Bachelor's Degree\nMajor: 
Information Management\nUniversity: National Taiwan University of Science and Technology\nLocation: Wenshan District, Taipei City, Taiwan\nExpected Graduation Date: [\u60a8\u7684\u9810\u671f\u7562\u696d\u65e5\u671f]\n\nSkills:\n\n* Proficient in HTML coding\n* Basic knowledge of hardware installation\n\nExperience:\n[\u60a8\u7684\u5de5\u4f5c\u7d93\u9a57\uff0c\u53ef\u4f9d\u6642\u9593\u5012\u5e8f\u6392\u5217\uff0c\u5305\u542b\u516c\u53f8\u540d\u7a31\u3001\u8077\u7a31\u3001\u4efb\u8077\u6642\u9593\u3001\u5de5\u4f5c\u5167\u5bb9\u7b49]\n\nLanguage:\n\n* Mandarin Chinese: Native language\n* English: Proficient\n\nPersonal Statement:\nI was born and raised in Taipei City, Taiwan, and currently studying Information Management at National Taiwan University of Science and Technology. I believe that language proficiency is essential in today's global society, and my parents have always placed great emphasis on developing my language skills. As a result, I am proficient in both Mandarin Chinese and English. Additionally, I have gained hands-on experience in HTML coding and possess basic knowledge of hardware installation. I am eager to apply my skills and knowledge in a challenging work environment to enhance my professional development." 19 | }, 20 | { 21 | "from": "human", 22 | "value": "\u5e6b\u6211\u589e\u52a0\u9019\u4e9b\uff0c\u6211\u7684\u96fb\u5b50\u90f5\u4ef6\u662f: 951753wdw@gmail.com \u96fb\u8a71\u865f\u78bc\u662f: 0965503337\uff0c\u76ee\u524d\u4f4f\u53f0\u5317\u5e02\u6587\u5c71\u5340\u79c0\u660e\u8def\u4e8c\u6bb5115\u5df76\u865f16\u6a13\uff0c\u9810\u8a082025\u5e74\u7562\u696d\u3002" 23 | }, 24 | { 25 | "from": "gpt", 26 | "value": "\u597d\u7684\uff0c\u6211\u5df2\u7d93\u6839\u64da\u60a8\u7684\u63d0\u4f9b\u7684\u8cc7\u8a0a\u66f4\u65b0\u4e86\u60a8\u7684\u82f1\u6587\u5c65\u6b77\uff0c\u8acb\u60a8\u6aa2\u67e5\u4ee5\u4e0b\u5167\u5bb9\uff1a\n\nResume\n\nPersonal Information:\nName: ZHENG, JING-CHENG\nCurrent Address: 16F, No.6, Lane 115, Section 2, Xiuming Road, Wenshan District, Taipei City, Taiwan\nEmail: [951753wdw@gmail.com](mailto:951753wdw@gmail.com)\nPhone: 0965503337\n\nEducation:\nDegree: Bachelor's Degree\nMajor: Information Management\nUniversity: National Taiwan University of Science and Technology\nLocation: Wenshan District, Taipei City, Taiwan\nExpected Graduation Date: 2025\n\nSkills:\n\n* Proficient in HTML coding\n* Basic knowledge of hardware installation\n\nExperience:\n[\u60a8\u7684\u5de5\u4f5c\u7d93\u9a57\uff0c\u53ef\u4f9d\u6642\u9593\u5012\u5e8f\u6392\u5217\uff0c\u5305\u542b\u516c\u53f8\u540d\u7a31\u3001\u8077\u7a31\u3001\u4efb\u8077\u6642\u9593\u3001\u5de5\u4f5c\u5167\u5bb9\u7b49]\n\nLanguage:\n\n* Mandarin Chinese: Native language\n* English: Proficient\n\nPersonal Statement:\nI was born and raised in Taipei City, Taiwan, and currently studying Information Management at National Taiwan University of Science and Technology. I believe that language proficiency is essential in today's global society, and my parents have always placed great emphasis on developing my language skills. As a result, I am proficient in both Mandarin Chinese and English. Additionally, I have gained hands-on experience in HTML coding and possess basic knowledge of hardware installation. I am eager to apply my skills and knowledge in a challenging work environment to enhance my professional development." 
27 | }, 28 | { 29 | "from": "human", 30 | "value": "\u5e6b\u6211\u628a\u5b78\u6821\u53bb\u6389" 31 | }, 32 | { 33 | "from": "gpt", 34 | "value": "\u597d\u7684\uff0c\u5df2\u6839\u636e\u60a8\u7684\u8981\u6c42\u66f4\u65b0\u4e86\u82f1\u6587\u5c65\u5386\uff0c\u8bf7\u60a8\u786e\u8ba4\u4ee5\u4e0b\u5185\u5bb9\uff1a\n\nResume\n\nPersonal Information:\nName: ZHENG, JING-CHENG\nCurrent Address: 16F, No.6, Lane 115, Section 2, Xiuming Road, Wenshan District, Taipei City, Taiwan\nEmail: [951753wdw@gmail.com](mailto:951753wdw@gmail.com)\nPhone: 0965503337\n\nEducation:\nDegree: Bachelor's Degree\nMajor: Information Management\nExpected Graduation Date: 2025\n\nSkills:\n\n* Proficient in HTML coding\n* Basic knowledge of hardware installation\n\nExperience:\n[\u60a8\u7684\u5de5\u4f5c\u7ecf\u9a8c\uff0c\u53ef\u4f9d\u65f6\u95f4\u5012\u5e8f\u6392\u5217\uff0c\u5305\u542b\u516c\u53f8\u540d\u79f0\u3001\u804c\u79f0\u3001\u4efb\u804c\u65f6\u95f4\u3001\u5de5\u4f5c\u5185\u5bb9\u7b49]\n\nLanguage:\n\n* Mandarin Chinese: Native language\n* English: Proficient\n\nPersonal Statement:\nI was born and raised in Taipei City, Taiwan, and currently studying Information Management. I believe that language proficiency is essential in today's global society, and my parents have always placed great emphasis on developing my language skills. As a result, I am proficient in both Mandarin Chinese and English. Additionally, I have gained hands-on experience in HTML coding and possess basic knowledge of hardware installation. I am eager to apply my skills and knowledge in a challenging work environment to enhance my professional development." 35 | }, 36 | { 37 | "from": "human", 38 | "value": "\u5e6b\u6211\u518d\u60f3\u4e00\u4e9b\u500b\u4eba\u6558\u8ff0\uff0c\u6211\u5e73\u5e38\u4e0b\u8ab2\u7684\u6642\u5019\u559c\u6b61\u95b1\u8b80\u66f8\u7c4d\uff0c\u9592\u6687\u6642\u9593\u4e5f\u6703\u95b1\u8b80\uff0c\u6211\u4e5f\u559c\u6b61\u5f9e\u4e8b\u6236\u5916\u6d3b\u8207\u793e\u4ea4\u3002" 39 | }, 40 | { 41 | "from": "gpt", 42 | "value": "\u597d\u7684\uff0c\u5df2\u6839\u636e\u60a8\u7684\u63d0\u4f9b\u7684\u4fe1\u606f\u66f4\u65b0\u4e86\u82f1\u6587\u5c65\u5386\uff0c\u8bf7\u60a8\u786e\u8ba4\u4ee5\u4e0b\u5185\u5bb9\uff1a\n\nResume\n\nPersonal Information:\nName: ZHENG, JING-CHENG\nCurrent Address: 16F, No.6, Lane 115, Section 2, Xiuming Road, Wenshan District, Taipei City, Taiwan\nEmail: [951753wdw@gmail.com](mailto:951753wdw@gmail.com)\nPhone: 0965503337\n\nEducation:\nDegree: Bachelor's Degree\nMajor: Information Management\nExpected Graduation Date: 2025\n\nSkills:\n\n* Proficient in HTML coding\n* Basic knowledge of hardware installation\n\nExperience:\n[\u60a8\u7684\u5de5\u4f5c\u7ecf\u9a8c\uff0c\u53ef\u4f9d\u65f6\u95f4\u5012\u5e8f\u6392\u5217\uff0c\u5305\u542b\u516c\u53f8\u540d\u79f0\u3001\u804c\u79f0\u3001\u4efb\u804c\u65f6\u95f4\u3001\u5de5\u4f5c\u5185\u5bb9\u7b49]\n\nLanguage:\n\n* Mandarin Chinese: Native language\n* English: Proficient\n\nPersonal Statement:\nI was born and raised in Taipei City, Taiwan, and currently studying Information Management. I believe that language proficiency is essential in today's global society, and my parents have always placed great emphasis on developing my language skills. As a result, I am proficient in both Mandarin Chinese and English.\n\nApart from my academic and professional pursuits, I also enjoy reading books in my spare time. I find that reading provides me with the opportunity to broaden my knowledge and understanding of the world. 
Additionally, I am passionate about engaging in outdoor activities and socializing with others. I believe that these experiences allow me to develop my interpersonal skills and enhance my overall well-being.\n\nI am eager to apply my skills and knowledge in a challenging work environment to enhance my professional development and contribute to the success of the organization." 43 | } 44 | ] 45 | } 46 | ] -------------------------------------------------------------------------------- /deepspeed-config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "gradient_accumulation_steps": "auto", 4 | "steps_per_print": 50, 5 | "gradient_clipping": 1.0, 6 | "zero_optimization": { 7 | "stage": 2, 8 | "offload_optimizer": { 9 | "device": "cpu" 10 | }, 11 | "contiguous_gradients": true, 12 | "overlap_comm": true 13 | }, 14 | "zero_allow_untested_optimizer": true, 15 | "fp16": { 16 | "enabled": true, 17 | "loss_scale": 0, 18 | "loss_scale_window": 1000, 19 | "hysteresis": 2, 20 | "min_loss_scale": 1 21 | }, 22 | "optimizer": { 23 | "type": "Adam", 24 | "params": { 25 | "lr": "auto", 26 | "betas": "auto", 27 | "eps": "auto", 28 | "weight_decay": "auto" 29 | } 30 | }, 31 | "activation_checkpointing": { 32 | "partition_activations": true, 33 | "contiguous_memory_optimization": true 34 | }, 35 | "wall_clock_breakdown": false 36 | } 37 | -------------------------------------------------------------------------------- /fastchat/data/changeto_alpaca.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | 5 | def change_format(args): 6 | with open(args["in_file"], 'r+', encoding='utf-8') as f: 7 | data = json.load(f) 8 | j = len(data) 9 | i = 0 10 | new_data = [] 11 | temp_instruction = "" 12 | for items in data: 13 | i = i + 1 14 | print(str(i) + ' of ' + str(j)) 15 | for item in items["conversations"]: 16 | if item["from"] == "human": 17 | temp_instruction = item["value"] 18 | else: 19 | new_item = {} 20 | new_item["instruction"] = temp_instruction 21 | new_item["input"] = "" 22 | new_item["output"] = item["value"] 23 | new_data.append(new_item) 24 | with open(args["out_file"], 'w', encoding='utf-8') as f: 25 | json.dump(new_data, f, ensure_ascii=False, indent=2) 26 | 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--in-file", type=str, required=True) 31 | parser.add_argument("--out-file", type=str, 32 | default="./data/alpaca_sharegpt.json") 33 | args = parser.parse_args() 34 | change_format(vars(args)) 35 | -------------------------------------------------------------------------------- /fastchat/data/clean_sharegpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | - Convert html to markdown with basic data cleaning. 3 | - Deduplication. 
4 | 5 | Usage: 6 | python3 -m fastchat.data.clean_sharegpt --in sharegpt_html.json --out sharegpt_clean.json 7 | """ 8 | import argparse 9 | from concurrent.futures import ProcessPoolExecutor 10 | import json 11 | import logging 12 | import re 13 | from typing import Dict, Union 14 | import opencc 15 | 16 | import bs4 17 | import markdownify # == 0.11.6 18 | from tqdm import tqdm 19 | 20 | 21 | div_pattern = re.compile("") 22 | span_pattern = re.compile("") 23 | code_lang_pattern = re.compile( 24 | "```\s*" + "(.*?)" + "(?:Copy code)+" + "(.+?)" + "\s*?```", re.DOTALL 25 | ) 26 | code_lang_format = "```\g<1>\n\g<2>\n```" 27 | regenerate_pattern = re.compile("\d+ / \d+") 28 | copy_chars_pattern = re.compile("Copy\d+ chars / \d+ words") 29 | copy_code_pattern = re.compile("```(.*?)Copy code\s*```") 30 | converter = opencc.OpenCC('t2s') 31 | 32 | def reformat_code(val: str) -> str: 33 | # Input code format is: 34 | # ``` 35 | # $Copy code$ 36 | # 37 | # ``` 38 | # This function convert it into the correct markdown format 39 | return re.sub(code_lang_pattern, code_lang_format, val) 40 | 41 | 42 | def html_to_markdown(val: str) -> str: 43 | # Remove all
. This is required to make intent work in code blocks. 44 | val = re.sub(div_pattern, "", val) 45 | # Remove all . This is required to make underscores work in code blocks. 46 | val = re.sub(span_pattern, "", val) 47 | # Markdown to html 48 | val = markdownify.markdownify(val).strip() 49 | # Reformat code 50 | val = reformat_code(val) 51 | 52 | # Remove noisy "[number] / [number]" at the beginning 53 | noise = re.search(regenerate_pattern, val) 54 | if noise and noise.start() == 0: 55 | val = val[noise.end() :] 56 | # Remove noisy "Copy[number] chars / [number] words" 57 | val = re.sub(copy_chars_pattern, "", val) 58 | # Remove empty code block ```\nCopy code\n``` 59 | val = re.sub(copy_code_pattern, "", val) 60 | 61 | # Strip 62 | val = val.replace("\n\n\n", "\n").strip() 63 | 64 | return val 65 | 66 | 67 | def contain_blocked_words(val: str) -> bool: 68 | blocked_words = ["openai", "chatgpt"] 69 | for w in blocked_words: 70 | if w in val.lower(): 71 | return True 72 | return False 73 | 74 | 75 | def clean_html_one_sample(sample): 76 | roles = ["human", "gpt"] 77 | 78 | if len(sample["conversations"]) <= 1: 79 | return (sample, 1) 80 | 81 | # Adjust the offset for cases like https://sharegpt.com/c/VyaZlh4 82 | if sample["conversations"][0]["from"] != "human": 83 | sample["conversations"] = sample["conversations"][1:] 84 | if len(sample["conversations"]) <= 1: 85 | return (sample, 1) 86 | 87 | if sample["conversations"][-1]["from"] == "human": 88 | sample["conversations"] = sample["conversations"][:-1] 89 | if len(sample["conversations"]) <= 1: 90 | return (sample, 1) 91 | 92 | for i, c in enumerate(sample["conversations"]): 93 | if c["from"] != roles[i % 2]: 94 | return (sample, 2) 95 | 96 | if contain_blocked_words(c["value"]): 97 | return (sample, 3) 98 | 99 | try: 100 | new_val = html_to_markdown(c["value"]) 101 | except (bs4.builder.ParserRejectedMarkup, AssertionError): 102 | return (sample, 4) 103 | new_val = converter.convert(new_val) 104 | c["value"] = new_val 105 | 106 | return (sample, 0) 107 | 108 | 109 | def clean_html_all(content, begin, end): 110 | """ 111 | Clean the source html files. 
112 | """ 113 | cnt_skip = 0 114 | cnt_blocked_words = 0 115 | cnt_wrong_format = 0 116 | cnt_parser_error = 0 117 | cnt_too_short = 0 118 | cnt_id_duplication = 0 119 | cnt_value_duplication = 0 120 | cnt_tag = 0 121 | 122 | content = content[begin:end] 123 | processed = [] 124 | with ProcessPoolExecutor() as executor: 125 | for result in tqdm( 126 | executor.map(clean_html_one_sample, content), total=len(content) 127 | ): 128 | processed.append(result) 129 | 130 | visited = {} 131 | new_content = [] 132 | for sample, error_code in tqdm(processed): 133 | cid = sample["id"] 134 | skipped = True 135 | 136 | if error_code != 0: 137 | if error_code == 1: 138 | print(f"id {cid} is too short") 139 | cnt_too_short += 1 140 | elif error_code == 2: 141 | print(f"id {cid} has a wrong format") 142 | cnt_wrong_format += 1 143 | elif error_code == 3: 144 | print(f"id {cid} contains blocked words") 145 | cnt_blocked_words += 1 146 | elif error_code == 4: 147 | print(f"id {cid} contains parser errors") 148 | cnt_parser_error += 1 149 | else: 150 | raise ValueError(f"Invalid error_code: {error_code}") 151 | elif cid in visited: 152 | print(f"id {cid} is an id duplication of {visited[cid]}") 153 | cnt_id_duplication += 1 154 | elif ( 155 | sample["conversations"][1]["value"], 156 | len(sample["conversations"]), 157 | ) in visited: 158 | key = (sample["conversations"][1]["value"], len(sample["conversations"])) 159 | print(f"id {cid} is a value duplication of {visited[key]}") 160 | cnt_value_duplication += 1 161 | else: 162 | key = (sample["conversations"][1]["value"], len(sample["conversations"])) 163 | visited[cid] = visited[key] = cid 164 | skipped = False 165 | 166 | if not skipped: 167 | new_content.append(sample) 168 | else: 169 | cnt_skip += 1 170 | 171 | print( 172 | f"total: {len(content)}, skip: {cnt_skip}, new: {len(new_content)}, " 173 | f"cnt_blocked_words: {cnt_blocked_words}, cnt_parser_error: {cnt_parser_error}, " 174 | f"cnt_wrong_format: {cnt_wrong_format}, " 175 | f"cnt_too_short: {cnt_too_short}, cnt_id_duplication: {cnt_id_duplication}, " 176 | f"cnt_value_duplication: {cnt_value_duplication}, " 177 | ) 178 | 179 | return new_content 180 | 181 | 182 | def main(args): 183 | content = json.load(open(args["in_file"], "r",encoding='utf-8')) 184 | content = clean_html_all(content, args["begin"], args["end"]) 185 | json_content = json.dumps(content, ensure_ascii=False,indent=2) 186 | with open(args["out_file"], 'w', encoding='utf-8') as f: 187 | f.write(json_content) 188 | 189 | 190 | if __name__ == "__main__": 191 | parser = argparse.ArgumentParser() 192 | parser.add_argument("--in-file", type=str, required=True) 193 | parser.add_argument("--out-file", type=str, default="sharegpt_clean.json") 194 | parser.add_argument("--begin", type=int) 195 | parser.add_argument("--end", type=int) 196 | parser.add_argument("--debug", action="store_true") 197 | args = parser.parse_args() 198 | main(vars(args)) 199 | -------------------------------------------------------------------------------- /fastchat/data/merge.py: -------------------------------------------------------------------------------- 1 | """ 2 | Merge two conversation files into one 3 | 4 | Usage: python3 -m fastchat.data.merge --in file1.json file2.json --out merged.json 5 | """ 6 | 7 | import argparse 8 | import json 9 | from typing import Dict, Sequence, Optional 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--in-file", type=str, required=True, nargs="+") 15 | 
parser.add_argument("--out-file", type=str, default="merged.json") 16 | args = parser.parse_args() 17 | 18 | new_content = [] 19 | for in_file in args.in_file: 20 | content = json.load(open(in_file, "r",encoding='utf-8')) 21 | new_content.extend(content) 22 | 23 | json.dump(new_content, open(args.out_file, "w",encoding='utf-8'), ensure_ascii=False,indent=2) 24 | -------------------------------------------------------------------------------- /fastchat/data/optional_clean.py: -------------------------------------------------------------------------------- 1 | """ 2 | Do optional cleaning (e.g., remove some languages). 3 | 4 | Usage: 5 | python3 -m fastchat.data.optional_clean --in input.json --out output.json --keep-lang en 6 | python3 -m fastchat.data.optional_clean --in input.json --out output.json --skip-lang en 7 | 8 | Requirement: 9 | pip3 install polyglot icu pyicu pycld2 morfessor 10 | """ 11 | import argparse 12 | import json 13 | import re 14 | 15 | import polyglot 16 | from polyglot.detect import Detector 17 | import pycld2 18 | from tqdm import tqdm 19 | 20 | 21 | def skip(conv, args): 22 | # Remove certain languages 23 | if args.keep_lang != "all" or args.skip_lang is not None: 24 | text = "\n".join([x["value"] for x in conv["conversations"]]) 25 | try: 26 | lang_code = Detector(text).language.code 27 | except (pycld2.error, polyglot.detect.base.UnknownLanguage): 28 | lang_code = "unknown" 29 | 30 | if args.keep_lang != "all" and lang_code != args.keep_lang: 31 | return True 32 | 33 | if lang_code == args.skip_lang: 34 | return True 35 | 36 | # Remove repetitive numbers 37 | if args.reduce_rep: 38 | for sentence in conv["conversations"]: 39 | val = sentence["value"] 40 | sub = re.search(r"(\d)\1{8}", val) 41 | if sub is not None: 42 | return True 43 | 44 | return False 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("--in-file", type=str, required=True) 50 | parser.add_argument("--out-file", type=str) 51 | parser.add_argument( 52 | "--keep-lang", 53 | type=str, 54 | default="all", 55 | choices=["all", "en"], 56 | help="Only keep certain langauges.", 57 | ) 58 | parser.add_argument("--skip-lang", type=str, help="Skip a specific language.") 59 | # NOTE: Be careful about reduce_rep which may remove some good data. 
60 | # For example, addresses could have long consecutive 0's 61 | parser.add_argument("--reduce-rep", action="store_true") 62 | args = parser.parse_args() 63 | 64 | in_file = args.in_file 65 | out_file = args.out_file 66 | keep_lang = args.keep_lang 67 | skip_lang = args.skip_lang 68 | reduce_rep = args.reduce_rep 69 | assert keep_lang == "all" or skip_lang is None 70 | 71 | if out_file is None: 72 | out_file = "sharegpt_clean" 73 | if keep_lang != "all": 74 | out_file += "_" + keep_lang 75 | if skip_lang is not None: 76 | out_file += "_skip_" + skip_lang 77 | if reduce_rep: 78 | out_file += "_reduce_rep" 79 | out_file += ".json" 80 | 81 | content = json.load(open(in_file, "r",encoding='utf-8')) 82 | num_conv = len(content) 83 | 84 | new_content = [] 85 | for conv in tqdm(content): 86 | if not skip(conv, args): 87 | new_content.append(conv) 88 | 89 | print(f"return {len(new_content)} out of {len(content)}, start dump ...") 90 | json_content = json.dumps(new_content, ensure_ascii=False,indent=2) 91 | with open(out_file, 'w', encoding='utf-8') as f: 92 | f.write(json_content) 93 | -------------------------------------------------------------------------------- /fastchat/data/split_long_conversation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Split long conversations based on certain max length. 3 | 4 | Usage: python3 -m fastchat.data.split_long_conversation \ 5 | --in sharegpt_clean.json \ 6 | --out sharegpt_split.json \ 7 | --model-name-or-path $ 8 | """ 9 | import argparse 10 | from concurrent.futures import ProcessPoolExecutor 11 | import json 12 | from typing import Dict, Sequence, Optional 13 | 14 | import transformers 15 | from tqdm import tqdm 16 | 17 | 18 | def make_sample(sample, start_idx, end_idx): 19 | assert (end_idx - start_idx) % 2 == 0 20 | return { 21 | "id": sample["id"] + "_" + str(start_idx), 22 | "conversations": sample["conversations"][start_idx:end_idx], 23 | } 24 | 25 | 26 | tokenizer = max_length = None 27 | 28 | 29 | def split_one_sample(sample): 30 | tokenized_lens = [] 31 | conversations = sample["conversations"] 32 | conversations = conversations[: len(conversations) // 2 * 2] 33 | for c in conversations: 34 | length = len(tokenizer(c["value"]).input_ids) + 6 35 | tokenized_lens.append(length) 36 | 37 | start_idx = 0 38 | cur_len = 0 39 | 40 | if len(conversations) % 2 != 0 or len(conversations) < 2: 41 | return [] 42 | 43 | new_samples = [] 44 | for i in range(0, len(conversations), 2): 45 | tmp_len = tokenized_lens[i] + tokenized_lens[i + 1] 46 | if cur_len + tmp_len > max_length: 47 | new_samples.append(make_sample(sample, start_idx, i)) 48 | start_idx = i 49 | cur_len = 0 50 | elif i == len(conversations) - 2: 51 | new_samples.append(make_sample(sample, start_idx, i + 2)) 52 | 53 | cur_len += tmp_len 54 | 55 | return new_samples 56 | 57 | 58 | def split_all(content, begin, end, tokenizer_, max_length_): 59 | """ 60 | Keep the maximum round of conversations within the max token length constraint 61 | """ 62 | global tokenizer, max_length 63 | tokenizer = tokenizer_ 64 | max_length = max_length_ 65 | 66 | content = content[begin:end] 67 | new_content = [] 68 | 69 | with ProcessPoolExecutor() as executor: 70 | for result in tqdm(executor.map(split_one_sample, content), total=len(content)): 71 | new_content.extend(result) 72 | 73 | return new_content 74 | 75 | 76 | def filter_invalid_roles(content): 77 | new_content = [] 78 | for i, c in enumerate(content): 79 | roles = ["human", "gpt"] 80 | if 
len(c["conversations"]) <= 0: 81 | continue 82 | 83 | valid = True 84 | for j, s in enumerate(c["conversations"]): 85 | if s["from"] != roles[j % 2]: 86 | valid = False 87 | break 88 | 89 | if valid: 90 | new_content.append(c) 91 | 92 | return new_content 93 | 94 | 95 | def main(args): 96 | content = json.load(open(args.in_file, "r")) 97 | tokenizer = transformers.AutoTokenizer.from_pretrained( 98 | args.model_name_or_path, 99 | model_max_length=args.max_length, 100 | padding_side="right", 101 | use_fast=False, 102 | ) 103 | new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length) 104 | new_content = filter_invalid_roles(new_content) 105 | 106 | print(f"total: {len(content)}, new: {len(new_content)}") 107 | json_content = json.dumps(new_content, ensure_ascii=False,indent=2) 108 | with open(args.out_file, 'w', encoding='utf-8') as f: 109 | f.write(json_content) 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument("--in-file", type=str, required=True) 115 | parser.add_argument("--out-file", type=str, default="sharegpt_split.json") 116 | parser.add_argument("--begin", type=int) 117 | parser.add_argument("--end", type=int) 118 | parser.add_argument("--model-name-or-path", type=str, required=True) 119 | parser.add_argument("--max-length", type=int, default=2048) 120 | args = parser.parse_args() 121 | main(args) 122 | -------------------------------------------------------------------------------- /fastchat/data/trans_dummy.py: -------------------------------------------------------------------------------- 1 | import json 2 | from transformers import pipeline 3 | import torch 4 | 5 | translator_en2zh = None 6 | 7 | 8 | def load_model(): 9 | global translator_en2zh 10 | if translator_en2zh is None: 11 | translator_en2zh = pipeline( 12 | "translation", model="Helsinki-NLP/opus-mt-en-zh", device=0) 13 | print("load model ok") 14 | 15 | 16 | def trans_en2zh(input): 17 | global translator_en2zh 18 | return translator_en2zh(input) 19 | 20 | 21 | def trans_dummy(): 22 | with open('./data/dummy_en.json', 'r+', encoding='utf-8') as f: 23 | data = json.load(f) 24 | j = len(data) 25 | i = 0 26 | for items in data: 27 | i = i + 1 28 | print(str(i) + ' of ' + str(j)) 29 | items['id'] = 'cn_' + items['id'] 30 | for item in items["conversations"] : 31 | item["value"] = trans_en2zh(item['value'])[0]['translation_text'] 32 | with open('./data/dummy_cn.json', 'w', encoding='utf-8') as f: 33 | json.dump(data, f,ensure_ascii=False,indent=2) 34 | 35 | 36 | if __name__ == "__main__": 37 | print("torch.cuda.is_available:", torch.cuda.is_available()) 38 | load_model() 39 | trans_dummy() 40 | -------------------------------------------------------------------------------- /fastchat/train/train_lora.py: -------------------------------------------------------------------------------- 1 | # Usage: deepspeed train_lora.py --deepspeed <$PATH_TO_DEEPSPEED_CONFIG> 2 | 3 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 4 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at
9 | #
10 | #     http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | import os
18 | from dataclasses import dataclass, field
19 | import logging
20 | import pathlib
21 | import typing
22 | import torch
23 | 
24 | from deepspeed import zero
25 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
26 | from peft import (
27 |     LoraConfig,
28 |     get_peft_model,
29 |     get_peft_model_state_dict,
30 |     prepare_model_for_kbit_training,
31 |     set_peft_model_state_dict,
32 | )
33 | from transformers import LlamaForCausalLM, LlamaTokenizer
34 | import transformers
35 | from transformers import Trainer
36 | 
37 | from fastchat.train.train import (
38 |     DataArguments,
39 |     ModelArguments,
40 |     TrainingArguments,
41 |     make_supervised_data_module,
42 | )
43 | 
44 | 
45 | @dataclass
46 | class LoraArguments:
47 |     lora_r: int = 8
48 |     lora_alpha: int = 16
49 |     lora_dropout: float = 0.05
50 |     lora_target_modules: typing.List[str] = field(
51 |         default_factory=lambda: ["q_proj", "v_proj"]
52 |     )
53 |     lora_weight_path: str = ""
54 |     bias: str = "none"
55 | 
56 | 
57 | def maybe_zero_3(param):
58 |     if hasattr(param, "ds_id"):
59 |         assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
60 |         with zero.GatheredParameters([param]):
61 |             param = param.data.cpu().clone().detach()
62 |     return param
63 | 
64 | 
65 | # Borrowed from peft.utils.get_peft_model_state_dict
66 | def get_peft_state_maybe_zero_3(state_dict, bias):
67 |     if bias == "none":
68 |         to_return = {
69 |             k: state_dict[k].cpu().clone().detach() for k in state_dict if "lora_" in k
70 |         }
71 |     elif bias == "all":
72 |         to_return = {
73 |             k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k
74 |         }
75 |     elif bias == "lora_only":
76 |         to_return = {}
77 |         for k in state_dict:
78 |             if "lora_" in k:
79 |                 to_return[k] = state_dict[k]
80 |                 bias_name = k.split("lora_")[0] + "bias"
81 |                 if bias_name in state_dict:
82 |                     to_return[bias_name] = state_dict[bias_name]
83 |     else:
84 |         raise NotImplementedError
85 |     to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
86 |     return to_return
87 | 
88 | 
89 | def train():
90 |     parser = transformers.HfArgumentParser(
91 |         (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
92 |     )
93 |     (
94 |         model_args,
95 |         data_args,
96 |         training_args,
97 |         lora_args,
98 |     ) = parser.parse_args_into_dataclasses()
99 |     print("Start loading the model")
100 |     device_map = "auto"
101 |     world_size = int(os.environ.get("WORLD_SIZE", 1))
102 |     ddp = world_size != 1
103 |     if ddp:
104 |         device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
105 |     model = LlamaForCausalLM.from_pretrained(
106 |         model_args.model_name_or_path,
107 |         cache_dir=training_args.cache_dir,
108 |         load_in_8bit=True,  # quantize the base model to 8-bit so it fits a 16G GPU
109 |         torch_dtype=torch.float16,
110 |         device_map=device_map
111 |     )
112 |     print("Model loading finished")
113 |     model = prepare_model_for_kbit_training(model)
114 |     print("Model prepared for int8 training")
115 |     lora_config = LoraConfig(
116 |         r=lora_args.lora_r,
117 |         lora_alpha=lora_args.lora_alpha,
118 |         target_modules=lora_args.lora_target_modules,
119 |         lora_dropout=lora_args.lora_dropout,
120 |         bias=lora_args.bias,
121 |         task_type="CAUSAL_LM",
122 |     )
123 |     model = get_peft_model(model, lora_config)
124 |     print("Model wrapped with PEFT (LoRA)")
125 |     if training_args.deepspeed is not None and training_args.local_rank == 0:
126 |         model.print_trainable_parameters()
127 |     tokenizer = transformers.AutoTokenizer.from_pretrained(
128 |         model_args.model_name_or_path,
129 |         cache_dir=training_args.cache_dir,
130 |         model_max_length=training_args.model_max_length,
131 |         padding_side="right",
132 |         use_fast=False,
133 |     )
134 |     tokenizer.pad_token = tokenizer.unk_token  # LLaMA has no pad token; reuse <unk> for padding
135 |     print("Tokenizer loaded")
136 |     data_module = make_supervised_data_module(
137 |         tokenizer=tokenizer, data_args=data_args)
138 |     if torch.cuda.device_count() > 1:
139 |         model.is_parallelizable = True
140 |         model.model_parallel = True
141 |     model.config.use_cache = False
142 |     print("Training data loaded")
143 |     trainer = Trainer(
144 |         model=model, tokenizer=tokenizer, args=training_args, **data_module
145 |     )
146 |     print("Training arguments prepared")
147 |     # model.config.use_cache = False
148 |     print("Start training the model")
149 |     if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
150 |         trainer.train(resume_from_checkpoint=True)
151 |     else:
152 |         trainer.train()
153 |     print("Saving training state")
154 |     trainer.save_state()
155 |     print("Saving the trained model")
156 |     # Save states. Weights might be a placeholder in zero3 and need a gather
157 |     state_dict = get_peft_state_maybe_zero_3(
158 |         model.state_dict(), lora_args.bias)
159 |     if training_args.local_rank == 0:
160 |         model.save_pretrained(training_args.output_dir, state_dict=state_dict)
161 | 
162 | 
163 | if __name__ == "__main__":
164 |     with torch.autocast("cuda"):
165 |         train()
166 | 
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | import argparse
5 | import torch
6 | import transformers
7 | from peft import PeftModel
8 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
9 | 
10 | from utils.prompter import Prompter
11 | 
12 | if torch.cuda.is_available():
13 |     device = "cuda"
14 | else:
15 |     device = "cpu"
16 | 
17 | try:
18 |     if torch.backends.mps.is_available():
19 |         device = "mps"
20 | except:  # noqa: E722
21 |     pass
22 | 
23 | 
24 | def main(
25 |     load_8bit: bool = False,
26 |     base_model: str = "",
27 |     lora_weights: str = "tloen/alpaca-lora-7b",
28 |     prompt_template: str = ""
29 | ):
30 |     base_model = base_model or os.environ.get("BASE_MODEL", "")
31 |     assert (
32 |         base_model
33 |     ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
34 | 
35 |     prompter = Prompter(prompt_template)
36 |     tokenizer = LlamaTokenizer.from_pretrained(base_model)
37 |     if device == "cuda":
38 |         model = LlamaForCausalLM.from_pretrained(
39 |             base_model,
40 |             load_in_8bit=load_8bit,
41 |             torch_dtype=torch.float16,
42 |             device_map="auto",
43 |         )
44 |         model = PeftModel.from_pretrained(
45 |             model,
46 |             lora_weights,
47 |             torch_dtype=torch.float16,
48 |         )
49 |     elif device == "mps":
50 |         model = LlamaForCausalLM.from_pretrained(
51 |             base_model,
52 |             device_map={"": device},
53 |             torch_dtype=torch.float16,
54 |         )
55 |         model = PeftModel.from_pretrained(
56 |             model,
57 |             lora_weights,
58 |             device_map={"": device},
59 |             torch_dtype=torch.float16,
60 |         )
61 |     else:
62 |         model = LlamaForCausalLM.from_pretrained(
63 |             base_model, device_map={"": device}, low_cpu_mem_usage=True
64 |         )
65 |         model = PeftModel.from_pretrained(
66 |             model,
67 |             lora_weights,
68 |             device_map={"": device},
69 |         )
70 | 
71 |     # unwind broken decapoda-research config
72 |     model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
73 |     model.config.bos_token_id = 1
74 |     model.config.eos_token_id = 2
75 | 
76 |     if not load_8bit:
77 |         model.half()  # seems to fix bugs for some users.
78 | 
79 |     model.eval()
80 |     if torch.__version__ >= "2" and sys.platform != "win32":
81 |         model = torch.compile(model)
82 | 
83 |     def evaluate(
84 |         instruction,
85 |         input=None,
86 |         temperature=0.1,
87 |         top_p=0.75,
88 |         top_k=40,
89 |         num_beams=4,
90 |         max_new_tokens=128,
91 |         stream_output=False,
92 |         **kwargs,
93 |     ):
94 |         prompt = prompter.generate_prompt(instruction, input)
95 |         inputs = tokenizer(prompt, return_tensors="pt")
96 |         input_ids = inputs["input_ids"].to(device)
97 |         generation_config = GenerationConfig(
98 |             temperature=temperature,
99 |             top_p=top_p,
100 |             top_k=top_k,
101 |             num_beams=num_beams,
102 |             **kwargs,
103 |         )
104 | 
105 |         generate_params = {  # currently unused; the arguments are passed directly to model.generate() below
106 |             "input_ids": input_ids,
107 |             "generation_config": generation_config,
108 |             "return_dict_in_generate": True,
109 |             "output_scores": True,
110 |             "max_new_tokens": max_new_tokens,
111 |         }
112 | 
113 |         with torch.no_grad():
114 |             generation_output = model.generate(
115 |                 input_ids=input_ids,
116 |                 generation_config=generation_config,
117 |                 return_dict_in_generate=True,
118 |                 output_scores=True,
119 |                 max_new_tokens=max_new_tokens,
120 |             )
121 |         s = generation_output.sequences[0]
122 |         output = tokenizer.decode(s)
123 |         return prompter.get_response(output)
124 | 
125 |     while True:
126 |         instruction = input("Input:")
127 |         if len(instruction.strip()) == 0:
128 |             break
129 |         print("Response:", evaluate(instruction))
130 | 
131 | 
132 | if __name__ == "__main__":
133 |     # parse args
134 |     parser = argparse.ArgumentParser()
135 |     parser.add_argument('--base_model', default=None, type=str, required=True)
136 |     parser.add_argument('--lora_weights', default=None, type=str,
137 |                         help="If None, perform inference on the base model")
138 |     parser.add_argument('--load_8bit', action='store_true',
139 |                         help='load the base model in 8-bit precision')
140 |     args = parser.parse_args()
141 |     main(args.load_8bit, args.base_model, args.lora_weights, "")
142 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers
2 | torch
3 | bs4
4 | markdownify
5 | polyglot
6 | pyicu
7 | pycld2
8 | tqdm
9 | accelerate
10 | einops
11 | flash_attn==1.0.5
12 | peft
13 | deepspeed
14 | bitsandbytes==0.37.2
15 | 
SentencePiece 16 | fschat==0.2.10 17 | opencc 18 | -------------------------------------------------------------------------------- /templates/alpaca.json: -------------------------------------------------------------------------------- 1 | { 2 | "description": "Template used by Alpaca-LoRA.", 3 | "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n", 4 | "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n", 5 | "response_split": "### Response:" 6 | } 7 | -------------------------------------------------------------------------------- /utils/prompter.py: -------------------------------------------------------------------------------- 1 | """ 2 | A dedicated helper to manage templates and prompt building. 3 | """ 4 | 5 | import json 6 | import os.path as osp 7 | from typing import Union 8 | 9 | 10 | class Prompter(object): 11 | __slots__ = ("template", "_verbose") 12 | 13 | def __init__(self, template_name: str = "", verbose: bool = False): 14 | self._verbose = verbose 15 | if not template_name: 16 | # Enforce the default here, so the constructor can be called with '' and will not break. 17 | template_name = "alpaca" 18 | file_name = osp.join("templates", f"{template_name}.json") 19 | if not osp.exists(file_name): 20 | raise ValueError(f"Can't read {file_name}") 21 | with open(file_name) as fp: 22 | self.template = json.load(fp) 23 | if self._verbose: 24 | print( 25 | f"Using prompt template {template_name}: {self.template['description']}" 26 | ) 27 | 28 | def generate_prompt( 29 | self, 30 | instruction: str, 31 | input: Union[None, str] = None, 32 | label: Union[None, str] = None, 33 | ) -> str: 34 | # returns the full prompt from instruction and optional input 35 | # if a label (=response, =output) is provided, it's also appended. 36 | if input: 37 | res = self.template["prompt_input"].format( 38 | instruction=instruction, input=input 39 | ) 40 | else: 41 | res = self.template["prompt_no_input"].format( 42 | instruction=instruction 43 | ) 44 | if label: 45 | res = f"{res}{label}" 46 | if self._verbose: 47 | print(res) 48 | return res 49 | 50 | def get_response(self, output: str) -> str: 51 | return output.split(self.template["response_split"])[1].strip() 52 | --------------------------------------------------------------------------------
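
A quick usage sketch (not one of the repository files above, purely an illustration): it shows how utils/prompter.py and templates/alpaca.json fit together, where generate_prompt() fills the Alpaca template and get_response() keeps only the text after the "### Response:" marker. It assumes the working directory is the repository root so the relative templates/ path resolves; the instruction and output strings are made up for the example.

# Illustrative only: exercises the Prompter helper shown above.
from utils.prompter import Prompter

prompter = Prompter("alpaca")  # loads templates/alpaca.json
prompt = prompter.generate_prompt(
    instruction="Translate the sentence to French.",
    input="Good morning.",
)
# A model's decoded output normally echoes the prompt; get_response()
# returns only the part after the "### Response:" split marker.
fake_output = prompt + "Bonjour."
print(prompter.get_response(fake_output))  # -> Bonjour.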