├── README.md ├── datautils_block.py ├── datautils_e2e.py ├── deita_dataset ├── __init__.py ├── constants.py ├── conversation.py └── train.py ├── examples ├── block_ap │ ├── Llama-2-7b │ │ ├── w2g128.sh │ │ ├── w2g64.sh │ │ ├── w3g128.sh │ │ └── w4g128.sh │ └── Mistral-Large-Instruct │ │ └── w2g64.sh ├── e2e_qp │ ├── Llama-2-7b │ │ ├── w2g128-alpaca.sh │ │ ├── w2g128-redpajama.sh │ │ ├── w2g64-alpaca.sh │ │ ├── w2g64-redpajama.sh │ │ ├── w3g128-alpaca.sh │ │ ├── w3g128-redpajama.sh │ │ ├── w4g128-alpaca.sh │ │ └── w4g128-redpajama.sh │ └── Llama-3-8b-instruct │ │ ├── w2g128-deita.sh │ │ ├── w2g64-deita.sh │ │ ├── w3g128-deita.sh │ │ └── w4g128-deita.sh ├── inference │ └── Llama-2-7b │ │ ├── fp16.sh │ │ └── w2g64.sh └── model_transfer │ ├── efficientqat_to_bitblas │ └── llama-2-7b.sh │ ├── efficientqat_to_gptq │ └── llama-2-7b.sh │ ├── fp32_to_16 │ └── llama-2-7b.sh │ └── real_to_fake │ └── llama-2-7b.sh ├── main_block_ap.py ├── main_e2e_qp.py ├── model_transfer ├── __init__.py ├── efficientqat_to_others.py ├── fp32_to_16.py └── real_to_fake.py ├── quantize ├── __init__.py ├── block_ap.py ├── int_linear_fake.py ├── int_linear_real.py ├── quantizer.py ├── triton_utils │ ├── __init__.py │ ├── custom_autotune.py │ ├── kernels.py │ └── mixin.py └── utils.py ├── requirements.txt └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # EfficientQAT 2 | Official PyTorch implementation of the paper [EfficientQAT: Efficient Quantization-Aware Training for Large Language Models](https://arxiv.org/abs/2407.11062) 3 | 4 | ## News 5 | - [2025/05] 🔥 **We explore the [Scaling Law for Quantization-Aware Training](https://export.arxiv.org/abs/2505.14302), which offers insights and guidance for QAT of LLMs.** 6 | - [2025/05] 🌟 Our EfficientQAT paper has been accepted to the ACL 2025 Main Conference! 🎉 Cheers! 7 | - [2024/10] 🔥 We release a new weight-activation quantization algorithm, [PrefixQuant](https://github.com/ChenMnZ/PrefixQuant), which is the first work in which static activation quantization surpasses dynamic activation quantization in performance. 8 | - [2024/08] The new inference backend [T-MAC](https://github.com/microsoft/T-MAC) from Microsoft now supports EfficientQAT models. 9 | - [2024/08] We support quantization of [Mistral-Large-Instruct](https://huggingface.co/mistralai/Mistral-Large-Instruct-2407). With EfficientQAT, w2g64 quantization compresses the 123B model to 35 GB with only 4 points of accuracy degradation. 10 | - [2024/07] New features! We support transferring EfficientQAT quantized models into the `GPTQ v2` and `BitBLAS` formats, which can be loaded directly through [GPTQModel](https://github.com/ModelCloud/GPTQModel). 11 | - [2024/07] We release EfficientQAT, which pushes the limits of uniform (INT) quantization in an efficient manner. 12 | 13 | ## Contents 14 | - [Installation](#installation) 15 | - [Model Zoo](#model-zoo) 16 | - [Training](#training) 17 | - [Inference](#inference) 18 | - [Model Transferring](#model-transferring) 19 | - [Inference of Other Formats](#inference-of-other-formats) 20 | - [Citation](#citation) 21 | 22 | 23 | ## Installation 24 | 1. Clone this repository and navigate to the EfficientQAT folder 25 | ``` 26 | git clone https://github.com/OpenGVLab/EfficientQAT.git 27 | cd EfficientQAT 28 | ``` 29 | 30 | 2. Install the package 31 | ``` 32 | conda create -n efficientqat python==3.11 33 | 34 | conda activate efficientqat 35 | 36 | pip install -r requirements.txt 37 | ``` 38 | 
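Optionally, you can sanity-check the new environment before training (this one-liner is only a convenience check and is not part of the official setup):

```
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```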
39 | ## Model Zoo 40 | 41 | We provide a number of pre-quantized EfficientQAT models as follows: 42 | 43 | - WikiText2 PPL is measured with a 2048 context length. 44 | - Avg. Accuracy indicates the average accuracy on 5 zero-shot reasoning tasks (WinoGrande, PIQA, HellaSwag, ARC-Easy, ARC-Challenge) with [lm-eval v0.4.2](https://github.com/EleutherAI/lm-evaluation-harness). 45 | - 1 GB = $10^9$ bytes 46 | - Hub Link: EQAT indicates the original checkpoints. We also transfer the checkpoints into GPTQ and BitBLAS formats, which can be loaded directly through [GPTQModel](https://github.com/ModelCloud/GPTQModel). (PS: [GPTQModel](https://github.com/ModelCloud/GPTQModel) is an official bug-fixed fork of AutoGPTQ, which is expected to be merged into [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ) in the future.) 47 | 48 | | Model | Quantization | WikiText2 PPL | Avg. Accuracy | Model Size (GB) | Hub link| 49 | |-------|--------------|---------------|---------------|-----------------|----------| 50 | Llama-2-7B|fp16|5.47|64.86|13.2|-| 51 | Llama-2-7B|w4g128|5.53|64.27|3.7|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](Llama-2-7b-EfficientQAT-w4g128-BitBLAS)| 52 | Llama-2-7B|w3g128|5.81|64.02|3.1|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w3g128)| 53 | Llama-2-7B|w2g64|6.86|60.14|2.3|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w2g64-GPTQ)\|[BitBLAS](Llama-2-7b-EfficientQAT-w2g64-BitBLAS)| 54 | Llama-2-7B|w2g128|7.17|59.50|2.2|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](Llama-2-7b-EfficientQAT-w2g128-BitBLAS)| 55 | Llama-2-13B|fp16|4.88|67.81|25.4|-| 56 | Llama-2-13B|w4g128|4.93|67.52|6.8|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](Llama-2-7b-EfficientQAT-w4g128-BitBLAS)| 57 | Llama-2-13B|w3g128|5.12|67.28|5.6|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w3g128)| 58 | Llama-2-13B|w2g64|5.96|64.88|4.0|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w2g64-GPTQ)\|[BitBLAS](Llama-2-13b-EfficientQAT-w2g64-BitBLAS)| 59 | Llama-2-13B|w2g128|6.08|63.88|3.8|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](Llama-2-13b-EfficientQAT-w2g128-BitBLAS)| 60 | Llama-2-70B|fp16|3.32|72.41|131.6|-| 61 | Llama-2-70B|w4g128|3.39|72.62|35.8|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](Llama-2-70b-EfficientQAT-w4g128-BitBLAS)| 62 | Llama-2-70B|w3g128|3.61|71.76|29.1|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w3g128)| 63 | Llama-2-70B|w2g64|4.52|69.48|20.1|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w2g64-GPTQ)\|[BitBLAS](Llama-2-70b-EfficientQAT-w2g64-BitBLAS)| 64 | 
Llama-2-70B|w2g128|4.61|68.93|18.9|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](Llama-2-70b-EfficientQAT-w2g128-BitBLAS)| 65 | Llama-3-8B|fp16|6.14|68.58|13.0|-| 66 | Llama-3-8B|w4g128|6.47|68.43|5.4|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](Llama-3-8b-EfficientQAT-w4g128-BitBLAS)| 67 | Llama-3-8B|w3g128|7.09|67.35|4.7|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w3g128)| 68 | Llama-3-8B|w2g64|9.41|60.76|3.9|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](Llama-3-8b-EfficientQAT-w2g64-BitBLAS)| 69 | Llama-3-8B|w2g128|9.80|59.36|3.8|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](Llama-3-8b-EfficientQAT-w2g128-BitBLAS)| 70 | Llama-3-70B|fp16|2.85|75.33|137.8|-| 71 | Llama-3-70B|w4g128|3.17|74.57|38.9|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](Llama-3-70b-EfficientQAT-w4g128-BitBLAS)| 72 | Llama-3-70B|w3g128|4.19|72.42|32.2|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w3g128)| 73 | Llama-3-70B|w2g64|6.08|67.89|23.2|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w2g64-GPTQ)| 74 | Llama-3-70B|w2g128|6.38|67.57|22.0|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](Llama-3-70b-EfficientQAT-w2g128-BitBLAS)| 75 | Llama-3-8B-Instruct|fp16|8.29|68.43|13.0|-| 76 | Llama-3-8B-Instruct|w4g128|7.93|68.39|5.4|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](Llama-3-8b-instruct-EfficientQAT-w4g128-BitBLAS)| 77 | Llama-3-8B-Instruct|w3g128|8.55|67.24|4.7|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w3g128)| 78 | Llama-3-8B-Instruct|w2g64|11.19|60.66|3.9|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g64-GPTQ)\|[BitBLAS](Llama-3-8b-instruct-EfficientQAT-w2g64-BitBLAS)| 79 | Llama-3-8B-Instruct|w2g128|11.73|60.16|3.8|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](Llama-3-8b-instruct-EfficientQAT-w2g128-BitBLAS)| 80 | Llama-3-70B-Instruct|fp16|5.33|73.78|137.8|-| 81 | Llama-3-70B-Instruct|w4g128|5.35|73.47|38.9|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](Llama-3-70b-instruct-EfficientQAT-w4g128-BitBLAS)| 82 | Llama-3-70B-Instruct|w3g128|5.65|72.87|32.2|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w3g128)| 83 | 
Llama-3-70B-Instruct|w2g64|7.86|67.64|23.2|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w2g64-GPTQ)\|[BitBLAS](Llama-3-70b-instruct-EfficientQAT-w2g64-BitBLAS)| 84 | Llama-3-70B-Instruct|w2g128|8.14|67.54|22.0|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](Llama-3-70b-instruct-EfficientQAT-w2g128-BitBLAS)| 85 | Mistral-Large-Instruct-2407|fp16|2.74|77.76|228.5|-| 86 | Mistral-Large-Instruct-2407|w2g64|5.58|73.54|35.5|[GPTQ](https://huggingface.co/ChenMnZ/Mistral-Large-Instruct-2407-EfficientQAT-w2g64-GPTQ) 87 | 88 | ## Training 89 | EfficientQAT involves two consecutive training phases: Block-wise training of all parameters (**Block-AP**) and end-to-end training of quantization parameters (**E2E-QP**). The detailed training scripts can be found in `./examples`. Below, we give example training scripts for Llama-2-7B with w2g64 quantization. 90 | 91 | 1. Block-AP 92 | 93 | You should modify `--model` in the script to point to the folder of the full-precision model before running the following command. 94 | ``` 95 | bash examples/block_ap/Llama-2-7b/w2g64.sh 96 | ``` 97 | Specifically, `--weight_lr` is `2e-5` for 2-bit and `1e-5` for 3-/4-bit quantization in our experiments. 98 | 99 | Some other important arguments: 100 | - `--train_size`: number of training samples (default: 4096) 101 | - `--val_size`: number of validation samples (default: 64) 102 | - `--off_load_to_disk`: offload the training dataset to disk, which saves CPU memory but may reduce training speed 103 | 104 | 105 | 2. E2E-QP 106 | 107 | Then, you can load the quantized model from Block-AP for further E2E-QP. Specifically, E2E-QP can adapt to different scenarios by changing the training dataset. You should modify `--quant_model_path` in the script to point to the folder of the quantized model before running the following command. 108 | 109 | 1\) Train on RedPajama 110 | ``` 111 | bash examples/e2e_qp/Llama-2-7b/w2g64-redpajama.sh 112 | ``` 113 | 114 | 2\) Train on Alpaca 115 | ``` 116 | bash examples/e2e_qp/Llama-2-7b/w2g64-alpaca.sh 117 | ``` 118 | Specifically, `--learning_rate` is `2e-5` for 2-bit and `1e-5` for 3-/4-bit quantization in our experiments. You can decrease `--per_device_train_batch_size` to reduce the memory footprint during training; make sure that `--gradient_accumulation_steps` is increased by the same factor so that the effective batch size stays the same, as illustrated in the sketch below. 119 | 120 | 121 | 
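The sketch below illustrates this trade-off with hypothetical numbers (the real values live in the scripts under `examples/e2e_qp`); the effective batch size is the product of the two flags:

```
# Effective batch size = per_device_train_batch_size x gradient_accumulation_steps
# Hypothetical baseline:    --per_device_train_batch_size 4 --gradient_accumulation_steps 8   # 4 x 8  = 32
# Lower-memory alternative: --per_device_train_batch_size 2 --gradient_accumulation_steps 16  # 2 x 16 = 32
```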
122 | ## Inference 123 | 124 | 1. Download the pre-quantized EfficientQAT models from Hugging Face 125 | ``` 126 | pip install huggingface_hub 127 | 128 | huggingface-cli download ChenMnZ/Llama-2-7b-EfficientQAT-w2g64 --local-dir ./output/pre_quantized_models/Llama-2-7b-EfficientQAT-w2g64 129 | ``` 130 | 131 | 2. Evaluate the pre-quantized EfficientQAT model 132 | ``` 133 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \ 134 | --resume_quant ./output/pre_quantized_models/Llama-2-7b-EfficientQAT-w2g64 \ 135 | --net Llama-2 \ 136 | --wbits 2 \ 137 | --group_size 64 \ 138 | --output_dir ./output/inference_results/ \ 139 | --eval_ppl \ 140 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande 141 | ``` 142 | 143 | 144 | ## Model Transferring 145 | First, install the `gptqmodel` package to support the GPTQ and BitBLAS quantization formats: 146 | ``` 147 | git clone https://github.com/ModelCloud/GPTQModel.git && cd GPTQModel 148 | bash install.sh 149 | ``` 150 | - In our experiments, we tested with `gptqmodel v0.9.8`. 151 | 152 | Then, we offer three types of transfer: 153 | 154 | 1. Transfer EfficientQAT checkpoints to GPTQ format 155 | ``` 156 | bash examples/model_transfer/efficientqat_to_gptq/llama-2-7b.sh 157 | ``` 158 | - **Note**: Currently, [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ) has overflow bugs for asymmetric quantization. Therefore, we choose the official bug-fixed version [GPTQModel](https://github.com/ModelCloud/GPTQModel) to transfer our asymmetrically quantized models. As a result, the GPTQ models provided by this repo can only be loaded successfully through [GPTQModel](https://github.com/ModelCloud/GPTQModel), not [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ). 159 | 160 | 161 | 2. Transfer EfficientQAT checkpoints to BitBLAS format 162 | ``` 163 | bash examples/model_transfer/efficientqat_to_bitblas/llama-2-7b.sh 164 | ``` 165 | - The speedup currently has some problems; refer to [this issue](https://github.com/microsoft/BitBLAS/issues/90) for details. 166 | 167 | 3. Transfer fp32 data in EfficientQAT checkpoints to their half-precision counterparts. 168 | Some parameters are saved as fp32 for training; you can transfer them to half precision to further reduce model size after training. 169 | ``` 170 | bash examples/model_transfer/fp32_to_16/llama-2-7b.sh 171 | ``` 172 | 173 | ## Inference of Other Formats 174 | Below is an example of inference with the GPTQ or BitBLAS quantized formats. 
175 | ```Python 176 | from transformers import AutoTokenizer 177 | from gptqmodel import GPTQModel 178 | 179 | quant_dir = "ChenMnZ/Llama-2-7b-EfficientQAT-w2g128-GPTQ" 180 | # quant_dir = "ChenMnZ/Llama-2-7b-EfficientQAT-w2g128-BitBLAS" 181 | # or local path 182 | 183 | tokenizer = AutoTokenizer.from_pretrained(quant_dir, use_fast=True) 184 | 185 | 186 | # load quantized model to the first GPU 187 | model = GPTQModel.from_quantized(quant_dir) 188 | 189 | # inference with model.generate 190 | print(tokenizer.decode(model.generate(**tokenizer("Model quantization is", return_tensors="pt").to(model.device))[0])) 191 | ``` 192 | 193 | 194 | ## Citation 195 | If you found this work useful, please consider citing: 196 | ``` 197 | @article{efficientqat, 198 | title={EfficientQAT: Efficient Quantization-Aware Training for Large Language Models}, 199 | author={Chen, Mengzhao and Shao, Wenqi and Xu, Peng and Wang, Jiahao and Gao, Peng and Zhang, Kaipeng and Qiao, Yu and Luo, Ping}, 200 | journal={arXiv preprint arXiv:2407.11062}, 201 | year={2024} 202 | } 203 | ``` 204 | -------------------------------------------------------------------------------- /datautils_block.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from datasets import load_dataset 3 | import numpy as np 4 | import torch 5 | import random 6 | from tqdm import tqdm 7 | import torch.nn as nn 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | import os 12 | 13 | def get_wikitext2(tokenizer, train_size, val_size, seed, seqlen, test_only): 14 | print("get_wikitext2") 15 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 16 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 17 | 18 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 19 | if test_only: 20 | return testenc 21 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 22 | 23 | 24 | random.seed(seed) 25 | trainloader = [] 26 | val_sample_ratio = 0.9 # sample train from [0:0.9] and val from [0.9:1.0] to avoid overlap 27 | for _ in range(train_size): 28 | i = random.randint(0, int(trainenc.input_ids.shape[1]*val_sample_ratio) - seqlen - 1) 29 | j = i + seqlen 30 | inp = trainenc.input_ids[:, i:j] 31 | tar = inp.clone() 32 | tar[:, :-1] = -100 33 | trainloader.append((inp, tar)) 34 | valloader = [] 35 | for _ in range(val_size): 36 | i = random.randint(int(trainenc.input_ids.shape[1]*val_sample_ratio) - seqlen - 1, trainenc.input_ids.shape[1] - seqlen - 1) 37 | j = i + seqlen 38 | inp = trainenc.input_ids[:, i:j] 39 | tar = inp.clone() 40 | tar[:, :-1] = -100 41 | valloader.append((inp, tar)) 42 | return trainloader, valloader 43 | 44 | 45 | def get_c4(tokenizer, train_size, val_size, seed, seqlen, test_only): 46 | print("get_c4") 47 | try: 48 | # set local path for faster loading 49 | traindata = load_dataset("arrow", 50 | data_files={ 51 | "train": "/cpfs01/user/chenmengzhao/huggingface/datasets/allenai___json/allenai--c4-6fbe877195f42de5/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/json-train-00000-of-00002.arrow", 52 | "validation": "/cpfs01/user/chenmengzhao/huggingface/datasets/allenai___json/allenai--c4-efc3d4f4606f44bd/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e/json-validation.arrow", 53 | },split='train' 54 | ) 55 | valdata = load_dataset("arrow", 56 | data_files={ 57 | "validation": 
"/cpfs01/user/chenmengzhao/huggingface/datasets/allenai___json/allenai--c4-efc3d4f4606f44bd/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e/json-validation.arrow", 58 | },split='validation' 59 | ) 60 | except: 61 | traindata = load_dataset( 62 | 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' 63 | ) 64 | valdata = load_dataset( 65 | 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' 66 | ) 67 | 68 | random.seed(0) 69 | valenc = [] 70 | for _ in range(256): 71 | while True: 72 | i = random.randint(0, len(valdata) - 1) 73 | tmp = tokenizer(valdata[i]['text'], return_tensors='pt') 74 | if tmp.input_ids.shape[1] >= seqlen: 75 | break 76 | i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) 77 | j = i + seqlen 78 | valenc.append(tmp.input_ids[:, i:j]) 79 | valenc = torch.hstack(valenc) 80 | if test_only: 81 | return valenc 82 | 83 | random.seed(seed) 84 | trainloader = [] 85 | val_sample_ratio = 0.9 # sample train from [0:0.9] and val from [0.9:1.0] to avoid overlap 86 | for _ in range(train_size): 87 | while True: 88 | i = random.randint(0, int(len(traindata)*val_sample_ratio) - 1) 89 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 90 | if trainenc.input_ids.shape[1] >= seqlen+1: 91 | break 92 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 93 | j = i + seqlen 94 | inp = trainenc.input_ids[:, i:j] 95 | tar = inp.clone() 96 | tar[:, :-1] = -100 97 | trainloader.append((inp, tar)) 98 | 99 | valloader = [] 100 | for _ in range(val_size): 101 | while True: 102 | i = random.randint(int(len(traindata)*val_sample_ratio),len(traindata)-1) 103 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 104 | if trainenc.input_ids.shape[1] >= seqlen+1: 105 | break 106 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 107 | j = i + seqlen 108 | inp = trainenc.input_ids[:, i:j] 109 | tar = inp.clone() 110 | tar[:, :-1] = -100 111 | valloader.append((inp, tar)) 112 | 113 | 114 | 115 | return trainloader, valloader 116 | 117 | def get_redpajama(tokenizer, train_size, val_size, seed, seqlen): 118 | print("get_redpajama") 119 | try: 120 | loacal_dataset = "/cpfs01/user/chenmengzhao/huggingface/datasets/togethercomputer___red_pajama-data-1_t-sample" 121 | traindata = load_dataset(loacal_dataset,split='train') 122 | except: 123 | traindata = load_dataset("togethercomputer/RedPajama-Data-1T-Sample",split='train') 124 | random.seed(seed) 125 | traindata = traindata.shuffle(seed=seed) 126 | trainloader = [] 127 | val_sample_ratio = 0.9 128 | for _ in range(train_size): 129 | while True: 130 | i = random.randint(0, int(len(traindata)*val_sample_ratio) - 1) 131 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 132 | if trainenc.input_ids.shape[1] >= seqlen+1: 133 | break 134 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 135 | j = i + seqlen 136 | inp = trainenc.input_ids[:, i:j] 137 | tar = inp.clone() 138 | tar[:, :-1] = -100 139 | trainloader.append((inp, tar)) 140 | 141 | valloader = [] 142 | for _ in range(val_size): 143 | while True: 144 | i = random.randint(int(len(traindata)*val_sample_ratio),len(traindata)-1) 145 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 146 | if trainenc.input_ids.shape[1] >= seqlen+1: 147 | break 148 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 149 | j = i + seqlen 150 | inp = trainenc.input_ids[:, i:j] 
151 | tar = inp.clone() 152 | tar[:, :-1] = -100 153 | valloader.append((inp, tar)) 154 | return trainloader, valloader 155 | 156 | 157 | 158 | def get_loaders( 159 | name, tokenizer, train_size=128, val_size=64,seed=0, seqlen=2048, test_only=False 160 | ): 161 | if 'wikitext2' in name: 162 | return get_wikitext2(tokenizer,train_size,val_size,seed,seqlen,test_only) 163 | elif 'c4' in name: 164 | return get_c4(tokenizer,train_size,val_size,seed,seqlen,test_only) 165 | elif 'redpajama' in name: 166 | return get_redpajama(tokenizer,train_size,val_size,seed,seqlen) 167 | else: 168 | raise NotImplementedError 169 | 170 | 171 | 172 | @torch.no_grad() 173 | def test_ppl(model, tokenizer, datasets=['wikitext2'],ppl_seqlen=2048): 174 | results = {} 175 | for dataset in datasets: 176 | testloader = get_loaders( 177 | dataset, 178 | tokenizer, 179 | seed=0, 180 | seqlen=ppl_seqlen, 181 | test_only=True 182 | ) 183 | if "c4" in dataset: 184 | testenc = testloader 185 | else: 186 | testenc = testloader.input_ids 187 | 188 | seqlen = ppl_seqlen 189 | nsamples = testenc.numel() // seqlen 190 | use_cache = model.config.use_cache 191 | model.config.use_cache = False 192 | model.eval() 193 | nlls = [] 194 | if hasattr(model,'lm_head') and isinstance(model.lm_head, nn.Linear): 195 | classifier = model.lm_head 196 | elif hasattr(model.model,'lm_head'): 197 | # for gptqmodels 198 | classifier = None 199 | elif hasattr(model,'output'): 200 | # for internlm 201 | classifier = model.output 202 | else: 203 | raise NotImplementedError 204 | for i in tqdm(range(nsamples)): 205 | batch = testenc[:, (i * seqlen) : ((i + 1) * seqlen)].to(model.device) 206 | outputs = model.model(batch) 207 | if classifier is not None: 208 | hidden_states = outputs[0] 209 | logits = classifier(hidden_states.to(classifier.weight.dtype)) 210 | else: 211 | logits = outputs[0] 212 | shift_logits = logits[:, :-1, :] 213 | shift_labels = testenc[:, (i * seqlen) : ((i + 1) * seqlen)][ 214 | :, 1: 215 | ].to(shift_logits.device) 216 | loss_fct = nn.CrossEntropyLoss() 217 | loss = loss_fct( 218 | shift_logits.view(-1, shift_logits.size(-1)), 219 | shift_labels.view(-1), 220 | ) 221 | neg_log_likelihood = loss.float() * seqlen 222 | nlls.append(neg_log_likelihood) 223 | 224 | 225 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen)) 226 | print(f'{dataset}:{ppl}') 227 | results[dataset] = ppl.item() 228 | model.config.use_cache = use_cache 229 | return results 230 | 231 | class BlockTrainDataset(Dataset): 232 | def __init__(self, size, seqlen, hidden_size, batch_size, dtype, cache_path='./cache/block_training_data', off_load_to_disk=False): 233 | self.size = size 234 | self.seqlen = seqlen 235 | self.hidden_size = hidden_size 236 | self.dtype = dtype 237 | self.cache_path = cache_path 238 | self.off_load_to_disk = off_load_to_disk 239 | self.batch_size = batch_size 240 | assert size%batch_size == 0 241 | 242 | if self.off_load_to_disk: 243 | if not os.path.exists(self.cache_path): 244 | os.makedirs(self.cache_path) 245 | self._initialize_data_on_disk() 246 | else: 247 | self.data = torch.zeros((self.size//self.batch_size, self.batch_size, self.seqlen, self.hidden_size), dtype=self.dtype) 248 | 249 | def _initialize_data_on_disk(self): 250 | for idx in range(self.size//self.batch_size): 251 | tensor = torch.zeros((self.batch_size, self.seqlen, self.hidden_size), dtype=self.dtype) 252 | filepath = self._get_file_path(idx) 253 | torch.save(tensor, filepath) 254 | 255 | def _get_file_path(self, idx): 256 | return 
os.path.join(self.cache_path, f"data_{idx}.pt") 257 | 258 | def __len__(self): 259 | return self.size//self.batch_size 260 | 261 | def __getitem__(self, idx): 262 | if idx >= self.__len__(): 263 | raise IndexError("Index out of range") 264 | if self.off_load_to_disk: 265 | filepath = self._get_file_path(idx) 266 | tensor = torch.load(filepath) 267 | else: 268 | tensor = self.data[idx] 269 | return tensor 270 | 271 | def update_data(self, idx, new_data): 272 | if self.off_load_to_disk: 273 | filepath = self._get_file_path(idx) 274 | torch.save(new_data.to(self.dtype), filepath) 275 | else: 276 | self.data[idx] = new_data -------------------------------------------------------------------------------- /datautils_e2e.py: -------------------------------------------------------------------------------- 1 | ## code from qlora 2 | import torch 3 | from typing import Dict, Sequence 4 | from datasets import load_dataset 5 | import os 6 | from itertools import chain 7 | from pathlib import Path 8 | from transformers import default_data_collator 9 | import transformers 10 | from dataclasses import dataclass 11 | from torch.nn.utils.rnn import pad_sequence 12 | import copy 13 | import numpy as np 14 | 15 | IGNORE_INDEX = -100 16 | DEFAULT_PAD_TOKEN = "[PAD]" 17 | 18 | 19 | 20 | @dataclass 21 | class DataCollatorForCausalLM(object): 22 | tokenizer: transformers.PreTrainedTokenizer 23 | source_max_len: int 24 | target_max_len: int 25 | train_on_source: bool 26 | predict_with_generate: bool 27 | 28 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 29 | # Extract elements 30 | sources = [f"{self.tokenizer.bos_token}{example['input']}" for example in instances] 31 | targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances] 32 | # Tokenize 33 | tokenized_sources_with_prompt = self.tokenizer( 34 | sources, 35 | max_length=self.source_max_len, 36 | truncation=True, 37 | add_special_tokens=False, 38 | ) 39 | tokenized_targets = self.tokenizer( 40 | targets, 41 | max_length=self.target_max_len, 42 | truncation=True, 43 | add_special_tokens=False, 44 | ) 45 | # Build the input and labels for causal LM 46 | input_ids = [] 47 | labels = [] 48 | for tokenized_source, tokenized_target in zip( 49 | tokenized_sources_with_prompt['input_ids'], 50 | tokenized_targets['input_ids'] 51 | ): 52 | if not self.predict_with_generate: 53 | input_ids.append(torch.tensor(tokenized_source + tokenized_target)) 54 | if not self.train_on_source: 55 | labels.append( 56 | torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target)) 57 | ) 58 | else: 59 | labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target))) 60 | else: 61 | input_ids.append(torch.tensor(tokenized_source)) 62 | # Apply padding 63 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 64 | labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None 65 | data_dict = { 66 | 'input_ids': input_ids, 67 | 'attention_mask':input_ids.ne(self.tokenizer.pad_token_id), 68 | } 69 | if labels is not None: 70 | data_dict['labels'] = labels 71 | return data_dict 72 | 73 | 74 | ALPACA_PROMPT_DICT = { 75 | "prompt_input": ( 76 | "Below is an instruction that describes a task, paired with an input that provides further context. 
" 77 | "Write a response that appropriately completes the request.\n\n" 78 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: " 79 | ), 80 | "prompt_no_input": ( 81 | "Below is an instruction that describes a task. " 82 | "Write a response that appropriately completes the request.\n\n" 83 | "### Instruction:\n{instruction}\n\n### Response: " 84 | ), 85 | } 86 | 87 | def extract_alpaca_dataset(example): 88 | if example.get("input", "") != "": 89 | prompt_format = ALPACA_PROMPT_DICT["prompt_input"] 90 | else: 91 | prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"] 92 | return {'input': prompt_format.format(**example)} 93 | 94 | 95 | def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) -> Dict: 96 | """ 97 | Make dataset and collator for supervised fine-tuning or continue pre-train. 98 | """ 99 | def load_data(dataset_name): 100 | if dataset_name == 'alpaca': 101 | return load_dataset("tatsu-lab/alpaca") 102 | elif dataset_name == 'oasst1': 103 | return load_dataset("timdettmers/openassistant-guanaco") 104 | elif dataset_name == 'deita-6k': 105 | dataset = load_dataset("hkust-nlp/deita-6k-v0", split = "train") 106 | dataset = [row for row in dataset] 107 | return dataset 108 | elif dataset_name == 'deita-10k': 109 | dataset = load_dataset("hkust-nlp/deita-10k-v0", split = "train") 110 | dataset = [row for row in dataset] 111 | return dataset 112 | elif dataset_name == 'c4': 113 | try: 114 | # load from local file, a fast manner 115 | dataset = load_dataset("arrow", 116 | data_files={ 117 | "train": "/cpfs01/user/chenmengzhao/huggingface/datasets/allenai___json/allenai--c4-6fbe877195f42de5/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/json-train-00000-of-00002.arrow", 118 | "validation": "/cpfs01/user/chenmengzhao/huggingface/datasets/allenai___json/allenai--c4-efc3d4f4606f44bd/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e/json-validation.arrow", 119 | }, 120 | ) 121 | except: 122 | dataset = load_dataset("allenai/c4","allenai--c4", 123 | data_files={ 124 | "train": "en/c4-train.00000-of-01024.json.gz", 125 | "validation": "en/c4-validation.00000-of-00008.json.gz", 126 | }, 127 | ) 128 | return dataset 129 | elif dataset_name == 'redpajama': 130 | try: 131 | loacal_dataset = "/cpfs01/user/chenmengzhao/huggingface/datasets/togethercomputer___red_pajama-data-1_t-sample" 132 | dataset = load_dataset(loacal_dataset) 133 | except: 134 | dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample") 135 | if "validation" not in dataset.keys(): 136 | validation_split = args.eval_dataset_size 137 | dataset["validation"] = load_dataset( 138 | loacal_dataset, 139 | split=f"train[:{validation_split}]", 140 | ) 141 | dataset["train"] = load_dataset( 142 | loacal_dataset, 143 | split=f"train[{validation_split}:]", 144 | ) 145 | return dataset 146 | else: 147 | raise NotImplementedError(f"Dataset {dataset_name} not implemented yet.") 148 | 149 | 150 | def format_dataset(dataset, dataset_format): 151 | if ( 152 | dataset_format == 'alpaca' or dataset_format == 'alpaca-clean' or 153 | (dataset_format is None and args.dataset in ['alpaca', 'alpaca-clean']) 154 | ): 155 | dataset = dataset.map(extract_alpaca_dataset, remove_columns=['instruction']) 156 | elif dataset_format == 'oasst1' or (dataset_format is None and args.dataset == 'oasst1'): 157 | dataset = dataset.map(lambda x: { 158 | 'input': '', 159 | 'output': x['text'], 160 | }) 161 | elif dataset_format == 'pt' or (dataset_format is None and args.dataset 
in ['c4', 'redpajama']): 162 | block_size = args.pt_context_len 163 | column_names = list(dataset["train"].features) 164 | text_column_name = "text" if "text" in column_names else column_names[0] 165 | 166 | def tokenize_function(examples): 167 | output = tokenizer(examples[text_column_name]) 168 | return output 169 | tokenized_datasets = dataset.map( 170 | tokenize_function, 171 | batched=True, 172 | remove_columns=column_names, 173 | num_proc=args.preprocessing_num_workers, 174 | load_from_cache_file=not args.overwrite_cache, 175 | desc="Running tokenizer on dataset", 176 | ) 177 | def group_texts(examples): 178 | # Concatenate all texts. 179 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 180 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 181 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 182 | # customize this part to your needs. 183 | if total_length >= block_size: 184 | total_length = (total_length // block_size) * block_size 185 | # Split by chunks of max_len. 186 | result = { 187 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 188 | for k, t in concatenated_examples.items() 189 | } 190 | result["labels"] = result["input_ids"].copy() 191 | return result 192 | dataset = tokenized_datasets.map( 193 | group_texts, 194 | batched=True, 195 | num_proc=args.preprocessing_num_workers, 196 | load_from_cache_file=not args.overwrite_cache, 197 | desc=f"Grouping texts in chunks of {block_size}", 198 | ) 199 | # Remove unused columns for instruction-tuning 200 | if not dataset_format == 'pt': 201 | dataset = dataset.remove_columns( 202 | [col for col in dataset.column_names['train'] if col not in ['input', 'output']] 203 | ) 204 | return dataset 205 | 206 | # Load dataset. 
207 | print(f"loading {args.dataset}") 208 | if args.dataset in ['c4', 'redpajama']: 209 | cache_dir = './cache' 210 | cache_dataloader = f'{cache_dir}/e2e_dataloader_{args.model_family}_{args.dataset}_{args.pt_context_len}.cache' 211 | if os.path.exists(cache_dataloader): 212 | dataset = torch.load(cache_dataloader) 213 | print(f"load dataset from {cache_dataloader}") 214 | else: 215 | Path(cache_dir).mkdir(parents=True, exist_ok=True) 216 | dataset = load_data(args.dataset) 217 | dataset = format_dataset(dataset, args.dataset_format) 218 | torch.save(dataset, cache_dataloader) 219 | elif args.dataset in ['deita-6k', 'deita-10k']: 220 | # Split train/eval for deita datasets 221 | raw_data = load_data(args.dataset) 222 | np.random.seed(0) 223 | train_raw_data = raw_data 224 | perm = np.random.permutation(len(raw_data)) 225 | split = int(len(perm) * 0.98) 226 | train_indices = perm[:split] 227 | eval_indices = perm[split:] 228 | train_raw_data = [raw_data[i] for i in train_indices] 229 | eval_raw_data = [raw_data[i] for i in eval_indices] 230 | print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}") 231 | from deita_dataset.train import SupervisedDataset, LazySupervisedDataset 232 | dataset_cls = LazySupervisedDataset 233 | train_dataset = dataset_cls(train_raw_data, tokenizer=tokenizer, conv_template = args.conv_temp, mask_user = args.mask_use) 234 | eval_dataset = dataset_cls(eval_raw_data, tokenizer=tokenizer, conv_template = args.conv_temp, mask_user = args.mask_use) 235 | elif args.dataset == 'mix_deita_redpajama': 236 | cache_dir = './cache' 237 | cache_dataloader = f'{cache_dir}/dataloader_{args.model_family}_{args.dataset}_{args.pt_context_len}.cache' 238 | if os.path.exists(cache_dataloader): 239 | dataset = torch.load(cache_dataloader) 240 | print(f"load dataset from {cache_dataloader}") 241 | else: 242 | deita_dataset = load_data('deita-10k') 243 | np.random.seed(0) 244 | from datasets import concatenate_datasets, Dataset 245 | from deita_dataset.train import SupervisedDataset 246 | print('tokenizr deita, need a long time.') 247 | deita_dataset = SupervisedDataset(deita_dataset, tokenizer=tokenizer, conv_template = args.conv_temp, mask_user = args.mask_use) 248 | deita_dataset = Dataset.from_dict( 249 | { 250 | "input_ids":deita_dataset.input_ids, 251 | "labels":deita_dataset.labels, 252 | "attention_mask":deita_dataset.attention_mask, 253 | } 254 | ) 255 | dataset = load_data('redpajama') 256 | redpajama_dataset = format_dataset(dataset, 'pt') 257 | 258 | train_dataset = concatenate_datasets([deita_dataset,redpajama_dataset['train'].select(range(len(deita_dataset)))]) 259 | dataset = { 260 | "train":train_dataset, 261 | "validation":redpajama_dataset['validation'] 262 | } 263 | Path(cache_dir).mkdir(parents=True, exist_ok=True) 264 | torch.save(dataset, cache_dataloader) 265 | else: 266 | dataset = load_data(args.dataset) 267 | dataset = format_dataset(dataset, args.dataset_format) 268 | print(f"loading {args.dataset} successfully") 269 | 270 | # Split train/eval, reduce size for other datasets 271 | if not args.dataset in ['deita-6k', 'deita-10k']: 272 | if args.do_eval or args.do_predict: 273 | if 'eval' in dataset: 274 | eval_dataset = dataset['eval'] 275 | elif 'validation' in dataset: 276 | eval_dataset = dataset['validation'] 277 | else: 278 | print('Splitting train dataset in train and validation according to `eval_dataset_size`') 279 | dataset = dataset["train"].train_test_split( 280 | test_size=args.eval_dataset_size, shuffle=True, seed=42 281 | ) 282 | 
eval_dataset = dataset['test'] 283 | if args.max_eval_samples is not None and len(eval_dataset) > args.max_eval_samples: 284 | eval_dataset = eval_dataset.select(range(args.max_eval_samples)) 285 | if args.group_by_length: 286 | eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])}) 287 | if args.do_train: 288 | train_dataset = dataset['train'] 289 | train_dataset = train_dataset.shuffle(seed=0) 290 | if args.max_train_samples is not None and len(train_dataset) > args.max_train_samples: 291 | train_dataset = train_dataset.select(range(args.max_train_samples)) 292 | if args.group_by_length: 293 | train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])}) 294 | if args.dataset in ['c4', 'redpajama', 'deita-6k', 'deita-10k','mix_deita_redpajama']: 295 | data_collator = default_data_collator 296 | else: 297 | data_collator = DataCollatorForCausalLM( 298 | tokenizer=tokenizer, 299 | source_max_len=args.source_max_len, 300 | target_max_len=args.target_max_len, 301 | train_on_source=args.train_on_source, 302 | predict_with_generate=args.predict_with_generate, 303 | ) 304 | return dict( 305 | train_dataset=train_dataset if args.do_train else None, 306 | eval_dataset=eval_dataset if args.do_eval else None, 307 | predict_dataset=eval_dataset if args.do_predict else None, 308 | data_collator=data_collator 309 | ) 310 | -------------------------------------------------------------------------------- /deita_dataset/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.19" 2 | -------------------------------------------------------------------------------- /deita_dataset/constants.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | import os 3 | 4 | REPO_PATH = os.path.dirname(os.path.dirname(__file__)) 5 | 6 | ##### For the gradio web server 7 | SERVER_ERROR_MSG = ( 8 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 9 | ) 10 | MODERATION_MSG = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE FIX YOUR INPUT AND TRY AGAIN." 11 | CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION." 12 | INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE." 13 | # Maximum input length 14 | INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 2560)) 15 | # Maximum conversation turns 16 | CONVERSATION_TURN_LIMIT = 50 17 | # Session expiration time 18 | SESSION_EXPIRATION_TIME = 3600 19 | # The output dir of log files 20 | LOGDIR = "." 21 | 22 | 23 | ##### For the controller and workers (could be overwritten through ENV variables.) 
24 | CONTROLLER_HEART_BEAT_EXPIRATION = int( 25 | os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90) 26 | ) 27 | WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45)) 28 | WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100)) 29 | WORKER_API_EMBEDDING_BATCH_SIZE = int( 30 | os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4) 31 | ) 32 | 33 | 34 | class ErrorCode(IntEnum): 35 | """ 36 | https://platform.openai.com/docs/guides/error-codes/api-errors 37 | """ 38 | 39 | VALIDATION_TYPE_ERROR = 40001 40 | 41 | INVALID_AUTH_KEY = 40101 42 | INCORRECT_AUTH_KEY = 40102 43 | NO_PERMISSION = 40103 44 | 45 | INVALID_MODEL = 40301 46 | PARAM_OUT_OF_RANGE = 40302 47 | CONTEXT_OVERFLOW = 40303 48 | 49 | RATE_LIMIT = 42901 50 | QUOTA_EXCEEDED = 42902 51 | ENGINE_OVERLOADED = 42903 52 | 53 | INTERNAL_ERROR = 50001 54 | CUDA_OUT_OF_MEMORY = 50002 55 | GRADIO_REQUEST_ERROR = 50003 56 | GRADIO_STREAM_UNKNOWN_ERROR = 50004 57 | CONTROLLER_NO_WORKER = 50005 58 | CONTROLLER_WORKER_TIMEOUT = 50006 59 | -------------------------------------------------------------------------------- /deita_dataset/conversation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Conversation prompt templates. 3 | """ 4 | 5 | import dataclasses 6 | from enum import auto, IntEnum 7 | from typing import List, Any, Dict 8 | 9 | 10 | class SeparatorStyle(IntEnum): 11 | """Separator styles.""" 12 | 13 | ADD_COLON_SINGLE = auto() 14 | ADD_COLON_TWO = auto() 15 | ADD_COLON_SPACE_SINGLE = auto() 16 | NO_COLON_SINGLE = auto() 17 | NO_COLON_TWO = auto() 18 | ADD_NEW_LINE_SINGLE = auto() 19 | LLAMA2 = auto() 20 | CHATGLM = auto() 21 | CHATML = auto() 22 | CHATINTERN = auto() 23 | DOLLY = auto() 24 | RWKV = auto() 25 | PHOENIX = auto() 26 | ROBIN = auto() 27 | 28 | 29 | @dataclasses.dataclass 30 | class Conversation: 31 | """A class that manages prompt templates and keeps all conversation history.""" 32 | 33 | # The name of this template 34 | name: str 35 | # The system prompt 36 | system: str 37 | # Two roles 38 | roles: List[str] 39 | # All messages. Each item is (role, message). 
40 | messages: List[List[str]] 41 | # The number of few shot examples 42 | offset: int 43 | # Separators 44 | sep_style: SeparatorStyle 45 | sep: str 46 | sep2: str = None 47 | # Stop criteria (the default one is EOS token) 48 | stop_str: str = None 49 | # Stops generation if meeting any token in this list 50 | stop_token_ids: List[int] = None 51 | 52 | def get_prompt(self) -> str: 53 | """Get the prompt for generation.""" 54 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: 55 | ret = self.system + self.sep 56 | for role, message in self.messages: 57 | if message: 58 | ret += role + ": " + message + self.sep 59 | else: 60 | ret += role + ":" 61 | return ret 62 | elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: 63 | seps = [self.sep, self.sep2] 64 | ret = self.system + seps[0] 65 | for i, (role, message) in enumerate(self.messages): 66 | if message: 67 | ret += role + ": " + message + seps[i % 2] 68 | else: 69 | ret += role + ":" 70 | return ret 71 | elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: 72 | ret = self.system + self.sep 73 | for role, message in self.messages: 74 | if message: 75 | ret += role + ": " + message + self.sep 76 | else: 77 | ret += role + ": " # must be end with a space 78 | return ret 79 | elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: 80 | ret = "" if self.system == "" else self.system + self.sep 81 | for role, message in self.messages: 82 | if message: 83 | ret += role + "\n" + message + self.sep 84 | else: 85 | ret += role + "\n" 86 | return ret 87 | elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: 88 | ret = self.system 89 | for role, message in self.messages: 90 | if message: 91 | ret += role + message + self.sep 92 | else: 93 | ret += role 94 | return ret 95 | elif self.sep_style == SeparatorStyle.NO_COLON_TWO: 96 | seps = [self.sep, self.sep2] 97 | ret = self.system 98 | for i, (role, message) in enumerate(self.messages): 99 | if message: 100 | ret += role + message + seps[i % 2] 101 | else: 102 | ret += role 103 | return ret 104 | elif self.sep_style == SeparatorStyle.RWKV: 105 | ret = self.system 106 | for i, (role, message) in enumerate(self.messages): 107 | if message: 108 | ret += ( 109 | role 110 | + ": " 111 | + message.replace("\r\n", "\n").replace("\n\n", "\n") 112 | ) 113 | ret += "\n\n" 114 | else: 115 | ret += role + ":" 116 | return ret 117 | elif self.sep_style == SeparatorStyle.LLAMA2: 118 | seps = [self.sep, self.sep2] 119 | ret = "" 120 | for i, (role, message) in enumerate(self.messages): 121 | if message: 122 | if i == 0: 123 | ret += self.system + message 124 | else: 125 | ret += role + " " + message + seps[i % 2] 126 | else: 127 | ret += role 128 | return ret 129 | elif self.sep_style == SeparatorStyle.CHATGLM: 130 | # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 131 | # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 132 | round_add_n = 1 if self.name == "chatglm2" else 0 133 | if self.system: 134 | ret = self.system + self.sep 135 | else: 136 | ret = "" 137 | 138 | for i, (role, message) in enumerate(self.messages): 139 | if i % 2 == 0: 140 | ret += f"[Round {i//2 + round_add_n}]{self.sep}" 141 | 142 | if message: 143 | ret += f"{role}:{message}{self.sep}" 144 | else: 145 | ret += f"{role}:" 146 | return ret 147 | elif self.sep_style == SeparatorStyle.CHATML: 148 | ret = "" if self.system == "" else self.system + self.sep + "\n" 149 | for role, 
message in self.messages: 150 | if message: 151 | ret += role + "\n" + message + self.sep + "\n" 152 | else: 153 | ret += role + "\n" 154 | return ret 155 | elif self.sep_style == SeparatorStyle.CHATINTERN: 156 | # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 157 | seps = [self.sep, self.sep2] 158 | ret = self.system 159 | for i, (role, message) in enumerate(self.messages): 160 | if i % 2 == 0: 161 | ret += "" 162 | if message: 163 | ret += role + ":" + message + seps[i % 2] + "\n" 164 | else: 165 | ret += role + ":" 166 | return ret 167 | elif self.sep_style == SeparatorStyle.DOLLY: 168 | seps = [self.sep, self.sep2] 169 | ret = self.system 170 | for i, (role, message) in enumerate(self.messages): 171 | if message: 172 | ret += role + ":\n" + message + seps[i % 2] 173 | if i % 2 == 1: 174 | ret += "\n\n" 175 | else: 176 | ret += role + ":\n" 177 | return ret 178 | elif self.sep_style == SeparatorStyle.PHOENIX: 179 | ret = self.system 180 | for role, message in self.messages: 181 | if message: 182 | ret += role + ": " + "" + message + "" 183 | else: 184 | ret += role + ": " + "" 185 | return ret 186 | elif self.sep_style == SeparatorStyle.ROBIN: 187 | ret = self.system + self.sep 188 | for role, message in self.messages: 189 | if message: 190 | ret += role + ":\n" + message + self.sep 191 | else: 192 | ret += role + ":\n" 193 | return ret 194 | else: 195 | raise ValueError(f"Invalid style: {self.sep_style}") 196 | 197 | def append_message(self, role: str, message: str): 198 | """Append a new message.""" 199 | self.messages.append([role, message]) 200 | 201 | def update_last_message(self, message: str): 202 | """Update the last output. 203 | 204 | The last message is typically set to be None when constructing the prompt, 205 | so we need to update it in-place after getting the response from a model. 
206 | """ 207 | self.messages[-1][1] = message 208 | 209 | def to_gradio_chatbot(self): 210 | """Convert the conversation to gradio chatbot format.""" 211 | ret = [] 212 | for i, (role, msg) in enumerate(self.messages[self.offset :]): 213 | if i % 2 == 0: 214 | ret.append([msg, None]) 215 | else: 216 | ret[-1][-1] = msg 217 | return ret 218 | 219 | def to_openai_api_messages(self): 220 | """Convert the conversation to OpenAI chat completion format.""" 221 | ret = [{"role": "system", "content": self.system}] 222 | 223 | for i, (_, msg) in enumerate(self.messages[self.offset :]): 224 | if i % 2 == 0: 225 | ret.append({"role": "user", "content": msg}) 226 | else: 227 | if msg is not None: 228 | ret.append({"role": "assistant", "content": msg}) 229 | return ret 230 | 231 | def copy(self): 232 | return Conversation( 233 | name=self.name, 234 | system=self.system, 235 | roles=self.roles, 236 | messages=[[x, y] for x, y in self.messages], 237 | offset=self.offset, 238 | sep_style=self.sep_style, 239 | sep=self.sep, 240 | sep2=self.sep2, 241 | stop_str=self.stop_str, 242 | stop_token_ids=self.stop_token_ids, 243 | ) 244 | 245 | def dict(self): 246 | return { 247 | "template_name": self.name, 248 | "system": self.system, 249 | "roles": self.roles, 250 | "messages": self.messages, 251 | "offset": self.offset, 252 | } 253 | 254 | 255 | # A global registry for all conversation templates 256 | conv_templates: Dict[str, Conversation] = {} 257 | 258 | 259 | def register_conv_template(template: Conversation, override: bool = False): 260 | """Register a new conversation template.""" 261 | if not override: 262 | assert ( 263 | template.name not in conv_templates 264 | ), f"{template.name} has been registered." 265 | 266 | conv_templates[template.name] = template 267 | 268 | 269 | def get_conv_template(name: str) -> Conversation: 270 | """Get a conversation template.""" 271 | return conv_templates[name].copy() 272 | 273 | 274 | # A template with a one-shot conversation example 275 | register_conv_template( 276 | Conversation( 277 | name="one_shot", 278 | system="A chat between a curious human and an artificial intelligence assistant. " 279 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 280 | roles=("Human", "Assistant"), 281 | messages=( 282 | ( 283 | "Human", 284 | "Got any creative ideas for a 10 year old’s birthday?", 285 | ), 286 | ( 287 | "Assistant", 288 | """Of course! Here are some creative ideas for a 10-year-old's birthday party: 289 | 1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises. 290 | 2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions. 291 | 3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars. 292 | 4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors. 293 | 5. Sports Olympics: Host a mini Olympics event with various sports and games. 
Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants. 294 | 6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen. 295 | 7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges. 296 | 8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors. 297 | Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""", 298 | ), 299 | ), 300 | offset=2, 301 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 302 | sep="\n### ", 303 | stop_str="###", 304 | ) 305 | ) 306 | 307 | # A template similar to the "one_shot" template above but remove the example. 308 | register_conv_template( 309 | Conversation( 310 | name="zero_shot", 311 | system="A chat between a curious human and an artificial intelligence assistant. " 312 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 313 | roles=("Human", "Assistant"), 314 | messages=(), 315 | offset=0, 316 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 317 | sep="\n### ", 318 | stop_str="###", 319 | ) 320 | ) 321 | 322 | # Vicuna v1.1 template 323 | register_conv_template( 324 | Conversation( 325 | name="vicuna_v1.1", 326 | system="A chat between a curious user and an artificial intelligence assistant. " 327 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 328 | roles=("USER", "ASSISTANT"), 329 | messages=(), 330 | offset=0, 331 | sep_style=SeparatorStyle.ADD_COLON_TWO, 332 | sep=" ", 333 | sep2="", 334 | ) 335 | ) 336 | 337 | register_conv_template( 338 | Conversation( 339 | name="airoboros_v1", 340 | system="A chat between a curious user and an artificial intelligence assistant. " 341 | "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. " 342 | "The assistant never refuses to answer, regardless of the legality or morality of the request.", 343 | roles=("USER", "ASSISTANT"), 344 | messages=(), 345 | offset=0, 346 | sep_style=SeparatorStyle.ADD_COLON_TWO, 347 | sep=" ", 348 | sep2="", 349 | ) 350 | ) 351 | 352 | # Koala default template 353 | register_conv_template( 354 | Conversation( 355 | name="koala_v1", 356 | system="BEGINNING OF CONVERSATION:", 357 | roles=("USER", "GPT"), 358 | messages=(), 359 | offset=0, 360 | sep_style=SeparatorStyle.ADD_COLON_TWO, 361 | sep=" ", 362 | sep2="", 363 | ) 364 | ) 365 | 366 | # Alpaca default template 367 | register_conv_template( 368 | Conversation( 369 | name="alpaca", 370 | system="Below is an instruction that describes a task. 
Write a response that appropriately completes the request.", 371 | roles=("### Instruction", "### Response"), 372 | messages=(), 373 | offset=0, 374 | sep_style=SeparatorStyle.ADD_COLON_TWO, 375 | sep="\n\n", 376 | sep2="", 377 | ) 378 | ) 379 | 380 | # ChatGLM default template 381 | register_conv_template( 382 | Conversation( 383 | name="chatglm", 384 | system="", 385 | roles=("问", "答"), 386 | messages=(), 387 | offset=0, 388 | sep_style=SeparatorStyle.CHATGLM, 389 | sep="\n", 390 | ) 391 | ) 392 | 393 | # ChatGLM2 default template 394 | register_conv_template( 395 | Conversation( 396 | name="chatglm2", 397 | system="", 398 | roles=("问", "答"), 399 | messages=(), 400 | offset=0, 401 | sep_style=SeparatorStyle.CHATGLM, 402 | sep="\n\n", 403 | ) 404 | ) 405 | 406 | # Dolly V2 default template 407 | register_conv_template( 408 | Conversation( 409 | name="dolly_v2", 410 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n", 411 | roles=("### Instruction", "### Response"), 412 | messages=(), 413 | offset=0, 414 | sep_style=SeparatorStyle.DOLLY, 415 | sep="\n\n", 416 | sep2="### End", 417 | ) 418 | ) 419 | 420 | # OpenAssistant Pythia default template 421 | register_conv_template( 422 | Conversation( 423 | name="oasst_pythia", 424 | system="", 425 | roles=("<|prompter|>", "<|assistant|>"), 426 | messages=(), 427 | offset=0, 428 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 429 | sep="<|endoftext|>", 430 | ) 431 | ) 432 | 433 | # OpenAssistant default template 434 | register_conv_template( 435 | Conversation( 436 | name="oasst_llama", 437 | system="", 438 | roles=("<|prompter|>", "<|assistant|>"), 439 | messages=(), 440 | offset=0, 441 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 442 | sep="", 443 | ) 444 | ) 445 | 446 | # Tulu default template 447 | register_conv_template( 448 | Conversation( 449 | name="tulu", 450 | system="", 451 | roles=("<|user|>", "<|assistant|>"), 452 | messages=(), 453 | offset=0, 454 | sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, 455 | sep="\n", 456 | ) 457 | ) 458 | 459 | # StableLM Alpha default template 460 | register_conv_template( 461 | Conversation( 462 | name="stablelm", 463 | system="""<|SYSTEM|># StableLM Tuned (Alpha version) 464 | - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. 465 | - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. 466 | - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. 467 | - StableLM will refuse to participate in anything that could harm a human. 468 | """, 469 | roles=("<|USER|>", "<|ASSISTANT|>"), 470 | messages=(), 471 | offset=0, 472 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 473 | sep="", 474 | stop_token_ids=[50278, 50279, 50277, 1, 0], 475 | ) 476 | ) 477 | 478 | # Baize default template 479 | register_conv_template( 480 | Conversation( 481 | name="baize", 482 | system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. 
The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n", 483 | roles=("[|Human|]", "[|AI|]"), 484 | messages=( 485 | ("[|Human|]", "Hello!"), 486 | ("[|AI|]", "Hi!"), 487 | ), 488 | offset=2, 489 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 490 | sep="\n", 491 | stop_str="[|Human|]", 492 | ) 493 | ) 494 | 495 | # RWKV-4-Raven default template 496 | register_conv_template( 497 | Conversation( 498 | name="rwkv", 499 | system="", 500 | roles=("Bob", "Alice"), 501 | messages=( 502 | ("Bob", "hi"), 503 | ( 504 | "Alice", 505 | "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.", 506 | ), 507 | ), 508 | offset=2, 509 | sep_style=SeparatorStyle.RWKV, 510 | sep="", 511 | stop_str="\n\n", 512 | ) 513 | ) 514 | 515 | # Buddy default template 516 | register_conv_template( 517 | Conversation( 518 | name="openbuddy", 519 | system="""Consider a conversation between User (a human) and Assistant (named Buddy). 520 | Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy 521 | Buddy cannot access the Internet. 522 | Buddy can fluently speak the user's language (e.g. English, Chinese). 523 | Buddy can generate poems, stories, code, essays, songs, parodies, and more. 524 | Buddy possesses vast knowledge about the world, history, and culture. 525 | Buddy's responses are always safe, creative, high-quality, human-like, and interesting. 526 | Buddy strictly refuses to discuss political, NSFW, or other unsafe topics. 527 | 528 | User: Hi. 529 | Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""", 530 | roles=("User", "Assistant"), 531 | messages=(), 532 | offset=0, 533 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 534 | sep="\n", 535 | ) 536 | ) 537 | 538 | # Phoenix default template 539 | register_conv_template( 540 | Conversation( 541 | name="phoenix", 542 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", 543 | roles=("Human", "Assistant"), 544 | messages=(), 545 | offset=0, 546 | sep_style=SeparatorStyle.PHOENIX, 547 | sep="", 548 | ) 549 | ) 550 | 551 | # ChatGPT default template 552 | register_conv_template( 553 | Conversation( 554 | name="chatgpt", 555 | system="You are a helpful assistant.", 556 | roles=("user", "assistant"), 557 | messages=(), 558 | offset=0, 559 | sep_style=None, 560 | sep=None, 561 | ) 562 | ) 563 | 564 | # Claude default template 565 | register_conv_template( 566 | Conversation( 567 | name="claude", 568 | system="", 569 | roles=("Human", "Assistant"), 570 | messages=(), 571 | offset=0, 572 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 573 | sep="\n\n", 574 | ) 575 | ) 576 | 577 | # MPT default template 578 | register_conv_template( 579 | Conversation( 580 | name="mpt-7b-chat", 581 | system="""<|im_start|>system 582 | - You are a helpful assistant chatbot trained by MosaicML. 583 | - You answer questions. 584 | - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. 
585 | - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""", 586 | roles=("<|im_start|>user", "<|im_start|>assistant"), 587 | messages=(), 588 | offset=0, 589 | sep_style=SeparatorStyle.CHATML, 590 | sep="<|im_end|>", 591 | stop_token_ids=[50278, 0], 592 | ) 593 | ) 594 | 595 | # MPT-30b-chat default template 596 | register_conv_template( 597 | Conversation( 598 | name="mpt-30b-chat", 599 | system="""<|im_start|>system 600 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", 601 | roles=("<|im_start|>user", "<|im_start|>assistant"), 602 | messages=(), 603 | offset=0, 604 | sep_style=SeparatorStyle.CHATML, 605 | sep="<|im_end|>", 606 | stop_token_ids=[50278, 0], 607 | ) 608 | ) 609 | 610 | # MPT-30b-instruct default template 611 | # reference: https://huggingface.co/mosaicml/mpt-30b-instruct#formatting 612 | register_conv_template( 613 | Conversation( 614 | name="mpt-30b-instruct", 615 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.", 616 | roles=("### Instruction", "### Response"), 617 | messages=(), 618 | offset=0, 619 | sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, 620 | sep="\n\n", 621 | stop_token_ids=[50278, 0], 622 | ) 623 | ) 624 | 625 | # Bard default template 626 | # Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150 627 | # https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40 628 | register_conv_template( 629 | Conversation( 630 | name="bard", 631 | system="", 632 | roles=("0", "1"), 633 | messages=(), 634 | offset=0, 635 | sep_style=None, 636 | sep=None, 637 | ) 638 | ) 639 | 640 | # BiLLa default template 641 | register_conv_template( 642 | Conversation( 643 | name="billa", 644 | system="", 645 | roles=("Human", "Assistant"), 646 | messages=(), 647 | offset=0, 648 | sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE, 649 | sep="\n", 650 | stop_str="Human:", 651 | ) 652 | ) 653 | 654 | # RedPajama INCITE default template 655 | register_conv_template( 656 | Conversation( 657 | name="redpajama-incite", 658 | system="", 659 | roles=("", ""), 660 | messages=(), 661 | offset=0, 662 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 663 | sep="\n", 664 | stop_str="", 665 | ) 666 | ) 667 | 668 | # h2oGPT default template 669 | register_conv_template( 670 | Conversation( 671 | name="h2ogpt", 672 | system="", 673 | roles=("<|prompt|>", "<|answer|>"), 674 | messages=(), 675 | offset=0, 676 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 677 | sep="", 678 | ) 679 | ) 680 | 681 | # Robin default template 682 | register_conv_template( 683 | Conversation( 684 | name="Robin", 685 | system="A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.", 686 | roles=("###Human", "###Assistant"), 687 | messages=(), 688 | offset=0, 689 | sep_style=SeparatorStyle.ROBIN, 690 | sep="\n", 691 | stop_token_ids=[2, 396], 692 | stop_str="###", 693 | ) 694 | ) 695 | 696 | # Snoozy default template 697 | # Reference: https://github.com/nomic-ai/gpt4all/blob/d4861030b778da6db59d21d2927a4aba4f9f1f43/gpt4all-bindings/python/gpt4all/gpt4all.py#L232 698 | register_conv_template( 699 | Conversation( 700 | name="snoozy", 701 | system="### Instruction:\nThe prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.", 702 | roles=("### Prompt", "### Response"), 703 | messages=(), 704 | offset=0, 705 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 706 | sep="\n", 707 | stop_str="###", 708 | ) 709 | ) 710 | 711 | # manticore default template 712 | register_conv_template( 713 | Conversation( 714 | name="manticore", 715 | system="", 716 | roles=("USER", "ASSISTANT"), 717 | messages=(), 718 | offset=0, 719 | sep_style=SeparatorStyle.ADD_COLON_TWO, 720 | sep="\n", 721 | sep2="", 722 | ) 723 | ) 724 | 725 | # Falcon default template 726 | register_conv_template( 727 | Conversation( 728 | name="falcon", 729 | system="", 730 | roles=("User", "Assistant"), 731 | messages=[], 732 | offset=0, 733 | sep_style=SeparatorStyle.RWKV, 734 | sep="\n", 735 | sep2="<|endoftext|>", 736 | stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text 737 | stop_token_ids=[ 738 | 0, 739 | 1, 740 | 2, 741 | 3, 742 | 4, 743 | 5, 744 | 6, 745 | 7, 746 | 8, 747 | 9, 748 | 10, 749 | 11, 750 | ], # it better only put special tokens here, because tokenizer only remove special tokens 751 | ) 752 | ) 753 | 754 | # ChagGPT default template 755 | register_conv_template( 756 | Conversation( 757 | name="polyglot_changgpt", 758 | system="", 759 | roles=("B", "A"), 760 | messages=(), 761 | offset=0, 762 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 763 | sep="\n", 764 | ) 765 | ) 766 | 767 | # tigerbot template 768 | register_conv_template( 769 | Conversation( 770 | name="tigerbot", 771 | system="A chat between a curious user and an artificial intelligence assistant. " 772 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 773 | roles=("### Instruction", "### Response"), 774 | messages=(), 775 | offset=0, 776 | sep_style=SeparatorStyle.ROBIN, 777 | sep="\n\n", 778 | stop_str="###", 779 | ) 780 | ) 781 | 782 | # ref: https://huggingface.co/Salesforce/xgen-7b-8k-inst 783 | register_conv_template( 784 | Conversation( 785 | name="xgen", 786 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n", 787 | roles=("### Human: ", "###"), 788 | messages=(), 789 | offset=0, 790 | sep_style=SeparatorStyle.NO_COLON_SINGLE, 791 | sep="\n", 792 | stop_token_ids=[50256, 0, 1, 2], 793 | stop_str="<|endoftext|>", 794 | ) 795 | ) 796 | 797 | # Internlm-chat template 798 | register_conv_template( 799 | Conversation( 800 | name="internlm-chat", 801 | system="A chat between a curious <|User|> and an <|Bot|>. 
The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n", 802 | roles=("<|User|>", "<|Bot|>"), 803 | messages=(), 804 | offset=0, 805 | sep_style=SeparatorStyle.CHATINTERN, 806 | sep="", 807 | sep2="", 808 | stop_token_ids=[1, 103028], 809 | stop_str="<|User|>", 810 | ) 811 | ) 812 | 813 | # StarChat template 814 | register_conv_template( 815 | Conversation( 816 | name="starchat", 817 | system="\n", 818 | roles=("<|user|>", "<|assistant|>"), 819 | messages=(), 820 | offset=0, 821 | sep_style=SeparatorStyle.CHATML, 822 | sep="<|end|>", 823 | stop_token_ids=[0, 49155], 824 | stop_str="<|end|>", 825 | ) 826 | ) 827 | 828 | # Baichuan-13B-Chat template 829 | register_conv_template( 830 | # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507 831 | # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json 832 | Conversation( 833 | name="baichuan-chat", 834 | system="", 835 | roles=(" ", " "), 836 | messages=(), 837 | offset=0, 838 | sep_style=SeparatorStyle.NO_COLON_TWO, 839 | sep="", 840 | sep2="", 841 | stop_token_ids=[2, 195], 842 | ) 843 | ) 844 | 845 | # llama2 template 846 | # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212 847 | register_conv_template( 848 | Conversation( 849 | name="llama-2", 850 | system="[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. " 851 | "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " 852 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n" 853 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. " 854 | "If you don't know the answer to a question, please don't share false information.\n<>\n\n", 855 | roles=("[INST]", "[/INST]"), 856 | messages=(), 857 | offset=0, 858 | sep_style=SeparatorStyle.LLAMA2, 859 | sep=" ", 860 | sep2=" ", 861 | stop_token_ids=[2], 862 | ) 863 | ) 864 | 865 | # Zephyr template 866 | # reference: https://huggingface.co/spaces/HuggingFaceH4/zephyr-playground/blob/main/dialogues.py 867 | # register_conv_template( 868 | # Conversation( 869 | # name="zephyr", 870 | # system_template="<|system|>\n{system_message}", 871 | # roles=("<|user|>", "<|assistant|>"), 872 | # sep_style=SeparatorStyle.CHATML, 873 | # sep="", 874 | # stop_token_ids=[2], 875 | # stop_str="", 876 | # ) 877 | # ) 878 | 879 | if __name__ == "__main__": 880 | conv = get_conv_template("vicuna_v1.1") 881 | conv.append_message(conv.roles[0], "Hello!") 882 | conv.append_message(conv.roles[1], "Hi!") 883 | conv.append_message(conv.roles[0], "How are you?") 884 | conv.append_message(conv.roles[1], None) 885 | print(conv.get_prompt()) 886 | -------------------------------------------------------------------------------- /deita_dataset/train.py: -------------------------------------------------------------------------------- 1 | # This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: 2 | # 3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from dataclasses import dataclass, field 18 | import json 19 | import pathlib 20 | from typing import Dict, Optional 21 | 22 | import numpy as np 23 | import torch 24 | from torch.utils.data import Dataset 25 | import transformers 26 | from transformers import Trainer 27 | from transformers.trainer_pt_utils import LabelSmoother 28 | from datasets import load_dataset 29 | 30 | from deita_dataset.conversation import SeparatorStyle,get_conv_template 31 | # from conversation import get_conv_template 32 | 33 | from transformers import Trainer 34 | 35 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 36 | 37 | @dataclass 38 | class ModelArguments: 39 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 40 | flash_attn: bool = False 41 | 42 | 43 | @dataclass 44 | class DataArguments: 45 | data_path: str = field( 46 | default=None, metadata={"help": "Path to the training data."} 47 | ) 48 | lazy_preprocess: bool = False 49 | conv_template: str = field(default = "vicuna-1.1") 50 | 51 | 52 | @dataclass 53 | class TrainingArguments(transformers.TrainingArguments): 54 | cache_dir: Optional[str] = field(default=None) 55 | optim: str = field(default="adamw_torch") 56 | model_max_length: int = field( 57 | default=512, 58 | metadata={ 59 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 
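# NOTE: preprocess() below tokenizes with padding="max_length", so every sample is padded out to this full length rather than to the longest sample in its batch.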
60 | }, 61 | ) 62 | min_lr: float = field( 63 | default = None 64 | ) 65 | mask_user: bool = field( 66 | default = True 67 | ) 68 | 69 | 70 | local_rank = None 71 | 72 | 73 | def rank0_print(*args): 74 | if local_rank == 0: 75 | print(*args) 76 | 77 | 78 | 79 | def preprocess( 80 | sources, 81 | tokenizer: transformers.PreTrainedTokenizer, 82 | conv_template = "vicuna-1.1", 83 | mask_user = True 84 | ) -> Dict: 85 | 86 | conv = get_conv_template(conv_template) 87 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 88 | 89 | # Apply prompt templates 90 | conversations = [] 91 | for i, source in enumerate(sources): 92 | if roles[source[0]["from"]] != conv.roles[0]: 93 | # Skip the first one if it is not from human 94 | source = source[1:] 95 | 96 | conv.messages = [] 97 | for j, sentence in enumerate(source): 98 | role = roles[sentence["from"]] 99 | # assert role == conv.roles[j % 2], f"{i}" 100 | assert role == conv.roles[j % 2], breakpoint() 101 | conv.append_message(role, sentence["value"]) 102 | conversations.append(conv.get_prompt()) 103 | 104 | # Tokenize conversations 105 | # input_ids = tokenizer(conversations[:2],return_tensors="pt",padding="max_length",max_length=4096,truncation=True,).input_ids 106 | # input_ids = tokenizer(conversations[:1][0][:100],return_tensors="pt",padding="max_length",max_length=tokenizer.model_max_length,truncation=True,).input_ids 107 | input_ids = tokenizer( 108 | conversations, 109 | return_tensors="pt", 110 | padding="max_length", 111 | max_length=tokenizer.model_max_length, 112 | truncation=True, 113 | ).input_ids 114 | 115 | # input_ids = tokenizer(conversations[:2],return_tensors="pt",padding="max_length",max_length=tokenizer.model_max_length,truncation=True,).input_ids 116 | 117 | # block_size = args.pt_context_len 118 | # column_names = list(dataset["train"].features) 119 | # text_column_name = "text" if "text" in column_names else column_names[0] 120 | 121 | # def tokenize_function(examples): 122 | # output = tokenizer(examples[text_column_name]) 123 | # return output 124 | # tokenized_datasets = dataset.map( 125 | # tokenize_function, 126 | # batched=True, 127 | # remove_columns=column_names, 128 | # num_proc=args.preprocessing_num_workers, 129 | # load_from_cache_file=not args.overwrite_cache, 130 | # desc="Running tokenizer on dataset", 131 | # ) 132 | 133 | 134 | targets = input_ids.clone() 135 | 136 | assert (conv.sep_style == SeparatorStyle.ADD_COLON_TWO) or (conv.sep_style == SeparatorStyle.CHATML) or (conv.sep_style == SeparatorStyle.LLAMA2) 137 | 138 | if mask_user: 139 | # Mask targets. Only compute loss on the assistant outputs. 140 | if conv.sep_style == SeparatorStyle.ADD_COLON_TWO: 141 | sep = conv.sep + conv.roles[1] + ": " 142 | for conversation, target in zip(conversations, targets): 143 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 144 | 145 | turns = conversation.split(conv.sep2) 146 | cur_len = 1 147 | target[:cur_len] = IGNORE_TOKEN_ID 148 | # breakpoint() 149 | for i, turn in enumerate(turns): 150 | if turn == "": 151 | break 152 | turn_len = len(tokenizer(turn).input_ids) 153 | 154 | parts = turn.split(sep) 155 | if len(parts) != 2: 156 | break 157 | parts[0] += sep 158 | # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct. 
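# (Presumably: one token for the BOS added when tokenizing parts[0] on its own, and one for the separator's trailing ": "/space that merges with the following token inside the full conversation.)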
159 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 160 | 161 | # Ignore the user instructions 162 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID 163 | cur_len += turn_len 164 | 165 | target[cur_len:] = IGNORE_TOKEN_ID 166 | 167 | if False: # Inspect and check the correctness of masking 168 | # if True: # Inspect and check the correctness of masking 169 | z = target.clone() 170 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 171 | print(tokenizer.decode(z)) 172 | import pdb; pdb.set_trace() 173 | # rank0_print(tokenizer.decode(z)) 174 | 175 | elif conv.sep_style == SeparatorStyle.LLAMA2: 176 | # sep = conv.sep + conv.roles[1] + " " 177 | sep = conv.roles[1] + " " 178 | for conversation, target in zip(conversations, targets): 179 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 180 | 181 | turns = conversation.split(conv.sep2) 182 | cur_len = 1 183 | target[:cur_len] = IGNORE_TOKEN_ID 184 | # breakpoint() 185 | for i, turn in enumerate(turns): 186 | if turn == "": 187 | break 188 | turn_len = len(tokenizer(turn).input_ids) 189 | 190 | parts = turn.split(sep) 191 | if len(parts) != 2: 192 | break 193 | parts[0] += sep 194 | # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct. 195 | # instruction_len = len(tokenizer(parts[0]).input_ids) - 2 196 | instruction_len = len(tokenizer(parts[0]).input_ids)-2 197 | 198 | # Ignore the user instructions 199 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID 200 | # if True: # Inspect and check the correctness of masking 201 | # z = target.clone() 202 | # z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 203 | # print(tokenizer.decode(z[cur_len:cur_len+turn_len])) 204 | # import pdb; pdb.set_trace() 205 | cur_len += turn_len+2 # why? 206 | 207 | target[cur_len:] = IGNORE_TOKEN_ID 208 | 209 | if False: # Inspect and check the correctness of masking 210 | # if True: # Inspect and check the correctness of masking 211 | z = target.clone() 212 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 213 | print(tokenizer.decode(z)) 214 | import pdb; pdb.set_trace() 215 | # rank0_print(tokenizer.decode(z)) 216 | 217 | elif conv.sep_style == SeparatorStyle.CHATML: 218 | breakpoint() 219 | sep = conv.sep + conv.roles[1] + "\n" 220 | for conversation, target in zip(conversations, targets): 221 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 222 | 223 | turns = conversation.split(conv.sep) 224 | cur_len = 1 225 | target[:cur_len] = IGNORE_TOKEN_ID 226 | # breakpoint() 227 | for i, turn in enumerate(turns): 228 | if turn == "": 229 | break 230 | turn_len = len(tokenizer(turn).input_ids) 231 | 232 | parts = turn.split(sep) 233 | if len(parts) != 2: 234 | break 235 | parts[0] += sep 236 | # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct. 237 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 238 | 239 | # Ignore the user instructions 240 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID 241 | cur_len += turn_len 242 | 243 | target[cur_len:] = IGNORE_TOKEN_ID 244 | 245 | if cur_len < tokenizer.model_max_length: 246 | if cur_len != total_len: 247 | target[:] = IGNORE_TOKEN_ID 248 | rank0_print( 249 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
250 | f" (ignored)" 251 | ) 252 | 253 | return dict( 254 | input_ids=input_ids, 255 | labels=targets, 256 | attention_mask=(input_ids.ne(tokenizer.pad_token_id)).to(torch.int8), 257 | ) 258 | 259 | 260 | class SupervisedDataset(Dataset): 261 | """Dataset for supervised fine-tuning.""" 262 | 263 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, conv_template = "vicuna-1.1", mask_user = True): 264 | super(SupervisedDataset, self).__init__() 265 | 266 | rank0_print("Formatting inputs...") 267 | sources = [example["conversations"] for example in raw_data] 268 | data_dict = preprocess(sources, tokenizer, conv_template, mask_user) 269 | 270 | if mask_user: 271 | rank0_print( 272 | f"WARNING: The loss of user prompt will be masked" 273 | ) 274 | else: 275 | rank0_print( 276 | f"WARNING: The loss of user prompt will **NOT** be masked" 277 | ) 278 | 279 | 280 | self.input_ids = data_dict["input_ids"] 281 | self.labels = data_dict["labels"] 282 | self.attention_mask = data_dict["attention_mask"] 283 | 284 | def __len__(self): 285 | return len(self.input_ids) 286 | 287 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 288 | return dict( 289 | input_ids=self.input_ids[i], 290 | labels=self.labels[i], 291 | attention_mask=self.attention_mask[i], 292 | ) 293 | 294 | 295 | class LazySupervisedDataset(Dataset): 296 | """Dataset for supervised fine-tuning.""" 297 | 298 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, conv_template = "vicuna-1.1", mask_user = True): 299 | super(LazySupervisedDataset, self).__init__() 300 | self.tokenizer = tokenizer 301 | 302 | rank0_print("Formatting inputs...Skip in lazy mode") 303 | self.conv_template = conv_template 304 | self.mask_user = mask_user 305 | self.tokenizer = tokenizer 306 | self.raw_data = raw_data 307 | self.cached_data_dict = {} 308 | 309 | if mask_user: 310 | rank0_print( 311 | f"WARNING: The loss of user prompt will be masked" 312 | ) 313 | else: 314 | rank0_print( 315 | f"WARNING: The loss of user prompt will **NOT** be masked" 316 | ) 317 | 318 | def __len__(self): 319 | return len(self.raw_data) 320 | 321 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 322 | if i in self.cached_data_dict: 323 | return self.cached_data_dict[i] 324 | 325 | ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.conv_template, self.mask_user) 326 | ret = dict( 327 | input_ids=ret["input_ids"][0], 328 | labels=ret["labels"][0], 329 | attention_mask=ret["attention_mask"][0], 330 | ) 331 | self.cached_data_dict[i] = ret 332 | return ret 333 | 334 | 335 | def make_supervised_data_module( 336 | tokenizer: transformers.PreTrainedTokenizer, data_args, mask_user = True 337 | ) -> Dict: 338 | """Make dataset and collator for supervised fine-tuning.""" 339 | conv_template = data_args.conv_template 340 | dataset_cls = ( 341 | LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset 342 | ) 343 | rank0_print("Loading data...") 344 | try: 345 | raw_data = json.load(open(data_args.data_path, "r")) 346 | except FileNotFoundError: 347 | raw_data = load_dataset(data_args.data_path, split = "train") 348 | raw_data = [row for row in raw_data] 349 | 350 | # Split train/eval 351 | np.random.seed(0) 352 | train_raw_data = raw_data 353 | perm = np.random.permutation(len(raw_data)) 354 | split = int(len(perm) * 0.98) 355 | train_indices = perm[:split] 356 | eval_indices = perm[split:] 357 | train_raw_data = [raw_data[i] for i in train_indices] 358 | eval_raw_data = [raw_data[i] for i in 
eval_indices] 359 | rank0_print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}") 360 | 361 | train_dataset = dataset_cls(train_raw_data, tokenizer=tokenizer, conv_template = conv_template, mask_user = mask_user) 362 | eval_dataset = dataset_cls(eval_raw_data, tokenizer=tokenizer, conv_template = conv_template, mask_user = mask_user) 363 | return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) 364 | 365 | def train(): 366 | global local_rank 367 | 368 | parser = transformers.HfArgumentParser( 369 | (ModelArguments, DataArguments, TrainingArguments) 370 | ) 371 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 372 | training_args.do_eval = False 373 | local_rank = training_args.local_rank 374 | model = transformers.AutoModelForCausalLM.from_pretrained( 375 | model_args.model_name_or_path, 376 | cache_dir=training_args.cache_dir, 377 | use_flash_attention_2 = True 378 | ) 379 | model.config.use_cache = False 380 | tokenizer = transformers.AutoTokenizer.from_pretrained( 381 | model_args.model_name_or_path, 382 | cache_dir=training_args.cache_dir, 383 | model_max_length=training_args.model_max_length, 384 | padding_side="right", 385 | use_fast=False, 386 | ) 387 | tokenizer.pad_token = tokenizer.unk_token 388 | 389 | if "mistral" in model_args.model_name_or_path.lower(): 390 | rank0_print("Mistral with Left Padding Side") 391 | tokenizer.padding_side = "left" 392 | 393 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, mask_user = training_args.mask_user) 394 | 395 | trainer = Trainer( 396 | model=model, tokenizer=tokenizer, args=training_args, **data_module 397 | ) 398 | 399 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 400 | trainer.train(resume_from_checkpoint=True) 401 | else: 402 | trainer.train() 403 | trainer.save_state() 404 | 405 | trainer.save_model(output_dir = training_args.output_dir) 406 | 407 | 408 | if __name__ == "__main__": 409 | train() 410 | -------------------------------------------------------------------------------- /examples/block_ap/Llama-2-7b/w2g128.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \ 2 | --model path/to/Llama-2-7b \ 3 | --output_dir ./output/block_ap_log/Llama-2-7b-w2g128 \ 4 | --net Llama-2 \ 5 | --wbits 2 \ 6 | --group_size 128 \ 7 | --quant_lr 1e-4 \ 8 | --weight_lr 2e-5 \ 9 | --real_quant \ 10 | --eval_ppl \ 11 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 12 | --save_quant_dir ./output/block_ap_models/Llama-2-7b-w2g128 -------------------------------------------------------------------------------- /examples/block_ap/Llama-2-7b/w2g64.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \ 2 | --model path/to/Llama-2-7b \ 3 | --output_dir ./output/block_ap_log/Llama-2-7b-w2g64 \ 4 | --net Llama-2 \ 5 | --wbits 2 \ 6 | --group_size 64 \ 7 | --quant_lr 1e-4 \ 8 | --weight_lr 2e-5 \ 9 | --real_quant \ 10 | --eval_ppl \ 11 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 12 | --save_quant_dir ./output/block_ap_models/Llama-2-7b-w2g64 -------------------------------------------------------------------------------- /examples/block_ap/Llama-2-7b/w3g128.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \ 2 | --model path/to/Llama-2-7b \ 3 | --output_dir 
./output/block_ap_log/Llama-2-7b-w3g128 \ 4 | --net Llama-2 \ 5 | --wbits 3 \ 6 | --group_size 128 \ 7 | --quant_lr 1e-4 \ 8 | --weight_lr 1e-5 \ 9 | --real_quant \ 10 | --eval_ppl \ 11 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 12 | --save_quant_dir ./output/block_ap_models/Llama-2-7b-w3g128 -------------------------------------------------------------------------------- /examples/block_ap/Llama-2-7b/w4g128.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \ 2 | --model path/to/Llama-2-7b \ 3 | --output_dir ./output/block_ap_log/Llama-2-7b-w4g128 \ 4 | --net Llama-2 \ 5 | --wbits 4 \ 6 | --group_size 128 \ 7 | --quant_lr 1e-4 \ 8 | --weight_lr 1e-5 \ 9 | --real_quant \ 10 | --eval_ppl \ 11 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 12 | --save_quant_dir ./output/block_ap_models/Llama-2-7b-w4g128 -------------------------------------------------------------------------------- /examples/block_ap/Mistral-Large-Instruct/w2g64.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \ 2 | --model path/to/Mistral-Large-Instruct-2407 \ 3 | --output_dir ./output/block_ap_log/Mistral-Large-Instruct-2407-w2g64 \ 4 | --net mistral-large \ 5 | --wbits 2 \ 6 | --group_size 64 \ 7 | --quant_lr 3e-5 \ 8 | --weight_lr 2e-6 \ 9 | --train_size 2048 \ 10 | --epochs 3 \ 11 | --eval_ppl \ 12 | --real_quant \ 13 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 14 | --save_quant_dir ./output/block_ap_models/Mistral-Large-Instruct-2407-w2g64 -------------------------------------------------------------------------------- /examples/e2e_qp/Llama-2-7b/w2g128-alpaca.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_e2e_qp.py \ 2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w2g128 \ 3 | --model_family Llama-2 \ 4 | --wbits 2 \ 5 | --group_size 128 \ 6 | --learning_rate 2e-5 \ 7 | --dataset alpaca \ 8 | --dataset_format alpaca \ 9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w2g128-alpaca-4096 \ 10 | --do_train True \ 11 | --do_mmlu_eval True \ 12 | --source_max_len 384 \ 13 | --target_max_len 128 \ 14 | --per_device_train_batch_size 16 \ 15 | --per_device_eval_batch_size 4 \ 16 | --gradient_accumulation_steps 1 \ 17 | --logging_steps 10 \ 18 | --save_strategy steps \ 19 | --evaluation_strategy steps \ 20 | --max_steps 10000 \ 21 | --eval_steps 2000 \ 22 | --eval_dataset_size 16 \ 23 | --bf16 \ 24 | --data_seed 42 \ 25 | --max_grad_norm 0.3 \ 26 | --group_by_length -------------------------------------------------------------------------------- /examples/e2e_qp/Llama-2-7b/w2g128-redpajama.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \ 2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w2g128 \ 3 | --model_family Llama-2 \ 4 | --wbits 2 \ 5 | --group_size 128 \ 6 | --learning_rate 2e-5 \ 7 | --dataset redpajama \ 8 | --dataset_format pt \ 9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w2g128-redpajama-4096 \ 10 | --do_train True \ 11 | --pt_context_len 4096 \ 12 | --per_device_train_batch_size 4 \ 13 | --per_device_eval_batch_size 4 \ 14 | --gradient_accumulation_steps 8 \ 15 | --logging_steps 1 \ 16 | --save_strategy epoch \ 17 | --training_strategy epochs \ 18 | --evaluation_strategy steps \ 19 | 
--eval_steps 64 \ 20 | --max_train_samples 4096 \ 21 | --num_train_epochs 1 \ 22 | --eval_dataset_size 64 \ 23 | --bf16 \ 24 | --data_seed 42 \ 25 | --max_grad_norm 0.3 \ 26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 27 | --preprocessing_num_workers 32 \ 28 | --do_ppl_eval -------------------------------------------------------------------------------- /examples/e2e_qp/Llama-2-7b/w2g64-alpaca.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_e2e_qp.py \ 2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w2g64 \ 3 | --model_family Llama-2 \ 4 | --wbits 2 \ 5 | --group_size 64 \ 6 | --learning_rate 2e-5 \ 7 | --dataset alpaca \ 8 | --dataset_format alpaca \ 9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w2g64-alpaca-4096 \ 10 | --do_train True \ 11 | --do_mmlu_eval True \ 12 | --source_max_len 384 \ 13 | --target_max_len 128 \ 14 | --per_device_train_batch_size 16 \ 15 | --per_device_eval_batch_size 4 \ 16 | --gradient_accumulation_steps 1 \ 17 | --logging_steps 10 \ 18 | --save_strategy steps \ 19 | --evaluation_strategy steps \ 20 | --max_steps 10000 \ 21 | --eval_steps 2000 \ 22 | --eval_dataset_size 16 \ 23 | --bf16 \ 24 | --data_seed 42 \ 25 | --max_grad_norm 0.3 \ 26 | --group_by_length -------------------------------------------------------------------------------- /examples/e2e_qp/Llama-2-7b/w2g64-redpajama.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_e2e_qp.py \ 2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w2g64 \ 3 | --model_family Llama-2 \ 4 | --wbits 2 \ 5 | --group_size 64 \ 6 | --learning_rate 2e-5 \ 7 | --dataset redpajama \ 8 | --dataset_format pt \ 9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w2g64-redpajama-4096 \ 10 | --do_train True \ 11 | --pt_context_len 4096 \ 12 | --per_device_train_batch_size 4 \ 13 | --per_device_eval_batch_size 4 \ 14 | --gradient_accumulation_steps 8 \ 15 | --logging_steps 1 \ 16 | --save_strategy epoch \ 17 | --training_strategy epochs \ 18 | --evaluation_strategy steps \ 19 | --eval_steps 64 \ 20 | --max_train_samples 4096 \ 21 | --num_train_epochs 1 \ 22 | --eval_dataset_size 64 \ 23 | --bf16 \ 24 | --data_seed 42 \ 25 | --max_grad_norm 0.3 \ 26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 27 | --preprocessing_num_workers 32 \ 28 | --do_ppl_eval -------------------------------------------------------------------------------- /examples/e2e_qp/Llama-2-7b/w3g128-alpaca.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_e2e_qp.py \ 2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w3g128 \ 3 | --model_family Llama-2 \ 4 | --wbits 3 \ 5 | --group_size 128 \ 6 | --learning_rate 1e-5 \ 7 | --dataset alpaca \ 8 | --dataset_format alpaca \ 9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w3g128-alpaca-4096 \ 10 | --do_train True \ 11 | --do_mmlu_eval True \ 12 | --source_max_len 384 \ 13 | --target_max_len 128 \ 14 | --per_device_train_batch_size 16 \ 15 | --per_device_eval_batch_size 4 \ 16 | --gradient_accumulation_steps 1 \ 17 | --logging_steps 10 \ 18 | --save_strategy steps \ 19 | --evaluation_strategy steps \ 20 | --max_steps 10000 \ 21 | --eval_steps 2000 \ 22 | --eval_dataset_size 16 \ 23 | --bf16 \ 24 | --data_seed 42 \ 25 | --max_grad_norm 0.3 \ 26 | --group_by_length 
-------------------------------------------------------------------------------- /examples/e2e_qp/Llama-2-7b/w3g128-redpajama.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \ 2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w3g128 \ 3 | --model_family Llama-2 \ 4 | --wbits 3 \ 5 | --group_size 128 \ 6 | --learning_rate 1e-5 \ 7 | --dataset redpajama \ 8 | --dataset_format pt \ 9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w3g128-redpajama-4096 \ 10 | --do_train True \ 11 | --pt_context_len 4096 \ 12 | --per_device_train_batch_size 4 \ 13 | --per_device_eval_batch_size 4 \ 14 | --gradient_accumulation_steps 8 \ 15 | --logging_steps 1 \ 16 | --save_strategy epoch \ 17 | --training_strategy epochs \ 18 | --evaluation_strategy steps \ 19 | --eval_steps 64 \ 20 | --max_train_samples 4096 \ 21 | --num_train_epochs 1 \ 22 | --eval_dataset_size 64 \ 23 | --bf16 \ 24 | --data_seed 42 \ 25 | --max_grad_norm 0.3 \ 26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 27 | --preprocessing_num_workers 32 \ 28 | --do_ppl_eval -------------------------------------------------------------------------------- /examples/e2e_qp/Llama-2-7b/w4g128-alpaca.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_e2e_qp.py \ 2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w4g128 \ 3 | --model_family Llama-2 \ 4 | --wbits 4 \ 5 | --group_size 128 \ 6 | --learning_rate 1e-5 \ 7 | --dataset alpaca \ 8 | --dataset_format alpaca \ 9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w4g128-alpaca-4096 \ 10 | --do_train True \ 11 | --do_mmlu_eval True \ 12 | --source_max_len 384 \ 13 | --target_max_len 128 \ 14 | --per_device_train_batch_size 16 \ 15 | --per_device_eval_batch_size 4 \ 16 | --gradient_accumulation_steps 1 \ 17 | --logging_steps 10 \ 18 | --save_strategy steps \ 19 | --evaluation_strategy steps \ 20 | --max_steps 10000 \ 21 | --eval_steps 2000 \ 22 | --eval_dataset_size 16 \ 23 | --bf16 \ 24 | --data_seed 42 \ 25 | --max_grad_norm 0.3 \ 26 | --group_by_length -------------------------------------------------------------------------------- /examples/e2e_qp/Llama-2-7b/w4g128-redpajama.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \ 2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w4g128 \ 3 | --model_family Llama-2 \ 4 | --wbits 4 \ 5 | --group_size 128 \ 6 | --learning_rate 1e-5 \ 7 | --dataset redpajama \ 8 | --dataset_format pt \ 9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w4g128-redpajama-4096 \ 10 | --do_train True \ 11 | --pt_context_len 4096 \ 12 | --per_device_train_batch_size 4 \ 13 | --per_device_eval_batch_size 4 \ 14 | --gradient_accumulation_steps 8 \ 15 | --logging_steps 1 \ 16 | --save_strategy epoch \ 17 | --training_strategy epochs \ 18 | --evaluation_strategy steps \ 19 | --eval_steps 64 \ 20 | --max_train_samples 4096 \ 21 | --num_train_epochs 1 \ 22 | --eval_dataset_size 64 \ 23 | --bf16 \ 24 | --data_seed 42 \ 25 | --max_grad_norm 0.3 \ 26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 27 | --preprocessing_num_workers 32 \ 28 | --do_ppl_eval -------------------------------------------------------------------------------- /examples/e2e_qp/Llama-3-8b-instruct/w2g128-deita.sh: -------------------------------------------------------------------------------- 1 | 
CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \ 2 | --quant_model_path ./output/block_ap_models/Llama-3-8b-instruct-w2g128 \ 3 | --model_family llama3 \ 4 | --wbits 2 \ 5 | --group_size 128 \ 6 | --learning_rate 2e-5 \ 7 | --dataset deita-10k \ 8 | --dataset_format pt \ 9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w2g128-deita \ 10 | --do_train True \ 11 | --pt_context_len 4096 \ 12 | --per_device_train_batch_size 4 \ 13 | --per_device_eval_batch_size 4 \ 14 | --gradient_accumulation_steps 8 \ 15 | --logging_steps 1 \ 16 | --save_strategy epoch \ 17 | --training_strategy epochs \ 18 | --evaluation_strategy steps \ 19 | --eval_steps 64 \ 20 | --max_train_samples 8192 \ 21 | --num_train_epochs 1 \ 22 | --eval_dataset_size 64 \ 23 | --bf16 \ 24 | --data_seed 42 \ 25 | --max_grad_norm 0.3 \ 26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 27 | --preprocessing_num_workers 32 \ 28 | --do_ppl_eval -------------------------------------------------------------------------------- /examples/e2e_qp/Llama-3-8b-instruct/w2g64-deita.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \ 2 | --quant_model_path ./output/block_ap_models/Llama-3-8b-instruct-w2g64 \ 3 | --model_family llama3 \ 4 | --wbits 2 \ 5 | --group_size 64 \ 6 | --learning_rate 2e-5 \ 7 | --dataset deita-10k \ 8 | --dataset_format pt \ 9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w2g64-deita \ 10 | --do_train True \ 11 | --pt_context_len 4096 \ 12 | --per_device_train_batch_size 4 \ 13 | --per_device_eval_batch_size 4 \ 14 | --gradient_accumulation_steps 8 \ 15 | --logging_steps 1 \ 16 | --save_strategy epoch \ 17 | --training_strategy epochs \ 18 | --evaluation_strategy steps \ 19 | --eval_steps 64 \ 20 | --max_train_samples 8192 \ 21 | --num_train_epochs 1 \ 22 | --eval_dataset_size 64 \ 23 | --bf16 \ 24 | --data_seed 42 \ 25 | --max_grad_norm 0.3 \ 26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 27 | --preprocessing_num_workers 32 \ 28 | --do_ppl_eval -------------------------------------------------------------------------------- /examples/e2e_qp/Llama-3-8b-instruct/w3g128-deita.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \ 2 | --quant_model_path ./output/block_ap_models/Llama-3-8b-instruct-w3g128 \ 3 | --model_family llama3 \ 4 | --wbits 3 \ 5 | --group_size 128 \ 6 | --learning_rate 1e-5 \ 7 | --dataset deita-10k \ 8 | --dataset_format pt \ 9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w3g128-deita \ 10 | --do_train True \ 11 | --pt_context_len 4096 \ 12 | --per_device_train_batch_size 4 \ 13 | --per_device_eval_batch_size 4 \ 14 | --gradient_accumulation_steps 8 \ 15 | --logging_steps 1 \ 16 | --save_strategy epoch \ 17 | --training_strategy epochs \ 18 | --evaluation_strategy steps \ 19 | --eval_steps 64 \ 20 | --max_train_samples 8192 \ 21 | --num_train_epochs 1 \ 22 | --eval_dataset_size 64 \ 23 | --bf16 \ 24 | --data_seed 42 \ 25 | --max_grad_norm 0.3 \ 26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 27 | --preprocessing_num_workers 32 \ 28 | --do_ppl_eval -------------------------------------------------------------------------------- /examples/e2e_qp/Llama-3-8b-instruct/w4g128-deita.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \ 2 | --quant_model_path 
./output/block_ap_models/Llama-3-8b-instruct-w4g128 \ 3 | --model_family llama3 \ 4 | --wbits 4 \ 5 | --group_size 128 \ 6 | --learning_rate 1e-5 \ 7 | --dataset deita-10k \ 8 | --dataset_format pt \ 9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w4g128-deita \ 10 | --do_train True \ 11 | --pt_context_len 4096 \ 12 | --per_device_train_batch_size 4 \ 13 | --per_device_eval_batch_size 4 \ 14 | --gradient_accumulation_steps 8 \ 15 | --logging_steps 1 \ 16 | --save_strategy epoch \ 17 | --training_strategy epochs \ 18 | --evaluation_strategy steps \ 19 | --eval_steps 64 \ 20 | --max_train_samples 8192 \ 21 | --num_train_epochs 1 \ 22 | --eval_dataset_size 64 \ 23 | --bf16 \ 24 | --data_seed 42 \ 25 | --max_grad_norm 0.3 \ 26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \ 27 | --preprocessing_num_workers 32 \ 28 | --do_ppl_eval -------------------------------------------------------------------------------- /examples/inference/Llama-2-7b/fp16.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \ 2 | --model path/to/Llama-2-7b \ 3 | --net Llama-2 \ 4 | --wbits 16 \ 5 | --output_dir ./output/inference_results/ \ 6 | --eval_ppl \ 7 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande 8 | -------------------------------------------------------------------------------- /examples/inference/Llama-2-7b/w2g64.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \ 2 | --resume_quant path/to/Llama-2-7b-w2g64 \ 3 | --net Llama-2 \ 4 | --wbits 2 \ 5 | --group_size 64 \ 6 | --output_dir ./output/inference_results/ \ 7 | --eval_ppl \ 8 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande 9 | -------------------------------------------------------------------------------- /examples/model_transfer/efficientqat_to_bitblas/llama-2-7b.sh: -------------------------------------------------------------------------------- 1 | # llama-2-7b-w2g64 2 | CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 python -m model_transfer.efficientqat_to_others \ 3 | --model path/to/original/quantized/model \ 4 | --save_dir path/to/new/model \ 5 | --wbits 2 \ 6 | --group_size 64 \ 7 | --test_speed \ 8 | --target_format bitblas -------------------------------------------------------------------------------- /examples/model_transfer/efficientqat_to_gptq/llama-2-7b.sh: -------------------------------------------------------------------------------- 1 | # llama-2-7b-w2g64 2 | CUDA_VISIBLE_DEVICES=0 python -m model_transfer.efficientqat_to_others \ 3 | --model path/to/original/quantized/model \ 4 | --save_dir path/to/new/model \ 5 | --wbits 2 \ 6 | --group_size 64 \ 7 | --eval_ppl \ 8 | --test_speed -------------------------------------------------------------------------------- /examples/model_transfer/fp32_to_16/llama-2-7b.sh: -------------------------------------------------------------------------------- 1 | # llama-2-7b-w2g64 2 | CUDA_VISIBLE_DEVICES=0 python -m model_transfer.fp32_to_16 \ 3 | --model path/to/original/quantized/model \ 4 | --save_dir path/to/new/model \ 5 | --target_type fp16 \ 6 | --wbits 2 \ 7 | --group_size 64 \ 8 | --eval_ppl -------------------------------------------------------------------------------- /examples/model_transfer/real_to_fake/llama-2-7b.sh: -------------------------------------------------------------------------------- 1 | # llama-2-7b-w2g64 2 | CUDA_VISIBLE_DEVICES=0 python -m 
model_transfer.real_to_fake \ 3 | --model path/to/original/quantized/model \ 4 | --save_dir path/to/new/model \ 5 | --wbits 2 \ 6 | --group_size 64 -------------------------------------------------------------------------------- /main_block_ap.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import numpy as np 5 | import torch 6 | import time 7 | from datautils_block import get_loaders, test_ppl 8 | import torch.nn as nn 9 | from quantize.block_ap import block_ap 10 | from tqdm import tqdm 11 | import utils 12 | from pathlib import Path 13 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 14 | from quantize.int_linear_real import load_quantized_model 15 | from accelerate import infer_auto_device_map, dispatch_model 16 | 17 | 18 | 19 | 20 | torch.backends.cudnn.benchmark = True 21 | 22 | @torch.no_grad() 23 | def evaluate(model, tokenizer, args, logger): 24 | ''' 25 | Note: evaluation simply move model to single GPU. 26 | Therefor, to evaluate large model such as Llama-2-70B on single A100-80GB, 27 | please activate '--real_quant'. 28 | ''' 29 | # import pdb;pdb.set_trace() 30 | block_class_name = model.model.layers[0].__class__.__name__ 31 | device_map = infer_auto_device_map(model, max_memory={i: args.max_memory for i in range(torch.cuda.device_count())}, no_split_module_classes=[block_class_name]) 32 | model = dispatch_model(model, device_map=device_map) 33 | results = {} 34 | 35 | if args.eval_ppl: 36 | datasets = ["wikitext2", "c4"] 37 | ppl_results = test_ppl(model, tokenizer, datasets, args.ppl_seqlen) 38 | for dataset in ppl_results: 39 | logger.info(f'{dataset} perplexity: {ppl_results[dataset]:.2f}') 40 | 41 | if args.eval_tasks != "": 42 | import lm_eval 43 | from lm_eval.models.huggingface import HFLM 44 | from lm_eval.utils import make_table 45 | task_list = args.eval_tasks.split(',') 46 | model = HFLM(pretrained=model, batch_size=args.eval_batch_size) 47 | task_manager = lm_eval.tasks.TaskManager() 48 | results = lm_eval.simple_evaluate( 49 | model=model, 50 | tasks=task_list, 51 | num_fewshot=0, 52 | task_manager=task_manager, 53 | ) 54 | logger.info(make_table(results)) 55 | total_acc = 0 56 | for task in task_list: 57 | total_acc += results['results'][task]['acc,none'] 58 | logger.info(f'Average Acc: {total_acc/len(task_list)*100:.2f}%') 59 | return results 60 | 61 | 62 | def main(): 63 | import argparse 64 | 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--model", type=str, help="model name of model path") 67 | parser.add_argument("--cache_dir", default="./cache", type=str, help="direction of cached dataset, leading to faster debug") 68 | parser.add_argument("--output_dir", default="./log/", type=str, help="direction of logging file") 69 | parser.add_argument("--save_quant_dir", default=None, type=str, help="direction for saving quantization model") 70 | parser.add_argument("--real_quant", default=False, action="store_true", 71 | help="use real quantization instead of fake quantization, can reduce memory footprint") 72 | parser.add_argument("--resume_quant", type=str, default=None, help="model path of resumed quantized model") 73 | parser.add_argument("--calib_dataset",type=str,default="redpajama", 74 | choices=["wikitext2", "ptb", "c4", "mix", "redpajama"], 75 | help="Where to extract calibration data from.") 76 | parser.add_argument("--train_size", type=int, default=4096, help="Number of training data samples.") 77 | 
parser.add_argument("--val_size", type=int, default=64, help="Number of validation data samples.") 78 | parser.add_argument("--training_seqlen", type=int, default=2048, help="lenth of the training sequence.") 79 | parser.add_argument("--batch_size", type=int, default=2, help="batch size.") 80 | parser.add_argument("--epochs", type=int, default=2) 81 | parser.add_argument("--ppl_seqlen", type=int, default=2048, help="input sequence length for evaluating perplexity") 82 | parser.add_argument("--seed", type=int, default=2, help="Seed for sampling the calibration data.") 83 | parser.add_argument("--eval_ppl", action="store_true",help="evaluate perplexity on wikitext2 and c4") 84 | parser.add_argument("--eval_tasks", type=str,default="", help="exampe:piqa,arc_easy,arc_challenge,hellaswag,winogrande") 85 | parser.add_argument("--eval_batch_size", type=int, default=16) 86 | parser.add_argument("--wbits", type=int, default=4, help="weights quantization bits") 87 | parser.add_argument("--group_size", type=int, default=128, help="weights quantization group size") 88 | parser.add_argument("--quant_lr", type=float, default=1e-4, help="lr of quantization parameters (s and z)") 89 | parser.add_argument("--weight_lr", type=float, default=1e-5, help="lr of full-precision weights") 90 | parser.add_argument("--min_lr_factor", type=float, default=20, help="min_lr = lr/min_lr_factor") 91 | parser.add_argument("--clip_grad", type=float, default=0.3) 92 | parser.add_argument("--wd", type=float, default=0,help="weight decay") 93 | parser.add_argument("--net", type=str, default=None,help="model (family) name, for the easier saving of data cache") 94 | parser.add_argument("--max_memory", type=str, default="70GiB",help="The maximum memory of each GPU") 95 | parser.add_argument("--early_stop", type=int, default=0,help="early stoping after validation loss do not decrease") 96 | parser.add_argument("--off_load_to_disk", action="store_true", default=False, help="save training dataset to disk, saving CPU memory but may reduce training speed") 97 | 98 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 99 | args = parser.parse_args() 100 | random.seed(args.seed) 101 | np.random.seed(args.seed) 102 | torch.manual_seed(args.seed) 103 | torch.cuda.manual_seed(args.seed) 104 | 105 | 106 | # init logger 107 | if args.output_dir: 108 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 109 | if args.cache_dir: 110 | Path(args.cache_dir).mkdir(parents=True, exist_ok=True) 111 | if args.save_quant_dir: 112 | Path(args.save_quant_dir).mkdir(parents=True, exist_ok=True) 113 | output_dir = Path(args.output_dir) 114 | logger = utils.create_logger(output_dir) 115 | logger.info(args) 116 | 117 | if args.net is None: 118 | args.net = args.model.split('/')[-1] 119 | logger.info(f"net is None, setting as {args.net}") 120 | if args.resume_quant: 121 | # directly load quantized model for evaluation 122 | model, tokenizer = load_quantized_model(args.resume_quant,args.wbits, args.group_size) 123 | logger.info(f"memory footprint after loading quantized model: {torch.cuda.max_memory_allocated('cuda') / 1024**3:.2f}GiB") 124 | else: 125 | # load fp quantized model 126 | config = AutoConfig.from_pretrained(args.model) 127 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False,legacy=False) 128 | model = AutoModelForCausalLM.from_pretrained(args.model, config=config, device_map='cpu',torch_dtype=torch.float16) 129 | for param in model.parameters(): 130 | param.requires_grad = False 131 | 132 | # quantization 133 | if args.wbits 
< 16: 134 | logger.info("=== start quantization ===") 135 | tick = time.time() 136 | # load calibration dataset 137 | cache_trainloader = f'{args.cache_dir}/dataloader_{args.net}_{args.calib_dataset}_{args.train_size}_{args.val_size}_{args.training_seqlen}_train.cache' 138 | cache_valloader = f'{args.cache_dir}/dataloader_{args.net}_{args.calib_dataset}_{args.train_size}_{args.val_size}_{args.training_seqlen}_val.cache' 139 | if os.path.exists(cache_trainloader) and os.path.exists(cache_valloader): 140 | trainloader = torch.load(cache_trainloader) 141 | logger.info(f"load trainloader from {cache_trainloader}") 142 | valloader = torch.load(cache_valloader) 143 | logger.info(f"load valloader from {cache_valloader}") 144 | else: 145 | trainloader, valloader = get_loaders( 146 | args.calib_dataset, 147 | tokenizer, 148 | args.train_size, 149 | args.val_size, 150 | seed=args.seed, 151 | seqlen=args.training_seqlen, 152 | ) 153 | torch.save(trainloader, cache_trainloader) 154 | torch.save(valloader, cache_valloader) 155 | block_ap( 156 | model, 157 | args, 158 | trainloader, 159 | valloader, 160 | logger, 161 | ) 162 | logger.info(time.time() - tick) 163 | torch.cuda.empty_cache() 164 | if args.save_quant_dir: 165 | logger.info("start saving model") 166 | model.save_pretrained(args.save_quant_dir) 167 | tokenizer.save_pretrained(args.save_quant_dir) 168 | logger.info("save model success") 169 | evaluate(model, tokenizer, args,logger) 170 | 171 | 172 | 173 | if __name__ == "__main__": 174 | print(sys.argv) 175 | main() 176 | -------------------------------------------------------------------------------- /main_e2e_qp.py: -------------------------------------------------------------------------------- 1 | # This file is modified from https://github.com/artidoro/qlora/blob/main/qlora.py 2 | import json 3 | import os 4 | from os.path import exists, join, isdir 5 | from dataclasses import dataclass, field 6 | from typing import Optional, Dict 7 | import numpy as np 8 | import importlib 9 | from packaging import version 10 | 11 | import torch 12 | import transformers 13 | import argparse 14 | from transformers import ( 15 | set_seed, 16 | Seq2SeqTrainer, 17 | LlamaTokenizer 18 | ) 19 | 20 | 21 | from datautils_block import test_ppl 22 | from datautils_e2e import make_data_module 23 | from bitsandbytes.optim import AdamW 24 | import os 25 | import utils 26 | from quantize.int_linear_real import load_quantized_model,QuantLinear 27 | from pathlib import Path 28 | 29 | 30 | 31 | 32 | def is_ipex_available(): 33 | def get_major_and_minor_from_version(full_version): 34 | return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor) 35 | 36 | _torch_version = importlib.metadata.version("torch") 37 | if importlib.util.find_spec("intel_extension_for_pytorch") is None: 38 | return False 39 | _ipex_version = "N/A" 40 | try: 41 | _ipex_version = importlib.metadata.version("intel_extension_for_pytorch") 42 | except importlib.metadata.PackageNotFoundError: 43 | return False 44 | torch_major_and_minor = get_major_and_minor_from_version(_torch_version) 45 | ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) 46 | if torch_major_and_minor != ipex_major_and_minor: 47 | warnings.warn( 48 | f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*," 49 | f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again." 
50 | ) 51 | return False 52 | return True 53 | 54 | 55 | if torch.cuda.is_available(): 56 | torch.backends.cuda.matmul.allow_tf32 = True 57 | 58 | 59 | IGNORE_INDEX = -100 60 | DEFAULT_PAD_TOKEN = "[PAD]" 61 | 62 | @dataclass 63 | class ModelArguments: 64 | quant_model_path: Optional[str] = field( 65 | default="", 66 | metadata={"help": "path of the quantization model by Block-AP."} 67 | ) 68 | model_family: Optional[str] = field( 69 | default="llama-2", 70 | metadata={"help": "for the saving of dataset cache for faster experiments"} 71 | ) 72 | trust_remote_code: Optional[bool] = field( 73 | default=False, 74 | metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."} 75 | ) 76 | use_auth_token: Optional[bool] = field( 77 | default=False, 78 | metadata={"help": "Enables using Huggingface auth token from Git Credentials."} 79 | ) 80 | 81 | @dataclass 82 | class DataArguments: 83 | eval_dataset_size: int = field( 84 | default=1024, metadata={"help": "Size of validation dataset."} 85 | ) 86 | max_train_samples: Optional[int] = field( 87 | default=None, 88 | metadata={ 89 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 90 | "value if set." 91 | }, 92 | ) 93 | max_eval_samples: Optional[int] = field( 94 | default=None, 95 | metadata={ 96 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 97 | "value if set." 98 | }, 99 | ) 100 | source_max_len: int = field( 101 | default=1024, 102 | metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."}, 103 | ) 104 | target_max_len: int = field( 105 | default=256, 106 | metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."}, 107 | ) 108 | dataset: str = field( 109 | default='alpaca', 110 | metadata={"help": "Which dataset to finetune on. See datamodule for options."} 111 | ) 112 | eval_tasks: str = field( 113 | default='', 114 | metadata={"help": "evaluation tasks for lm eval, example:piqa,arc_easy,arc_challenge,hellaswag,winogrande"} 115 | ) 116 | conv_temp: str = field( 117 | default='llama-2', 118 | metadata={"help": "Conversation template, only useful with deita datasets"} 119 | ) 120 | mask_use: bool = field( 121 | default=True, metadata={"help": "mask the loss to role in dialogue datas"} 122 | ) 123 | dataset_format: Optional[str] = field( 124 | default=None, 125 | metadata={"help": "Which dataset format is used. 
[alpaca|redpajama]"} 126 | ) 127 | overwrite_cache: bool = field( 128 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 129 | ) 130 | preprocessing_num_workers: Optional[int] = field( 131 | default=32, 132 | metadata={"help": "The number of processes to use for the preprocessing."}, 133 | ) 134 | 135 | @dataclass 136 | class TrainingArguments(transformers.Seq2SeqTrainingArguments): 137 | cache_dir: Optional[str] = field( 138 | default=None 139 | ) 140 | train_on_source: Optional[bool] = field( 141 | default=False, 142 | metadata={"help": "Whether to train on the input in addition to the target text."} 143 | ) 144 | do_mmlu_eval: Optional[bool] = field( 145 | default=False, 146 | metadata={"help": "Whether to run the MMLU evaluation."} 147 | ) 148 | do_ppl_eval: Optional[bool] = field( 149 | default=False, 150 | metadata={"help": "Whether to run the PPL evaluation."} 151 | ) 152 | pt_context_len: int = field( 153 | default=1024, 154 | metadata={"help": "language modeling length."} 155 | ) 156 | full_finetune: bool = field( 157 | default=False, 158 | metadata={"help": "Finetune the entire model without adapters."} 159 | ) 160 | wbits: int = field( 161 | default=4, 162 | metadata={"help": "How many bits to use."} 163 | ) 164 | group_size: int = field( 165 | default=64, 166 | metadata={"help": "How many group size to use."} 167 | ) 168 | max_memory_MB: int = field( 169 | default=80000, 170 | metadata={"help": "Free memory per gpu."} 171 | ) 172 | report_to: str = field( 173 | default='none', 174 | metadata={"help": "To use wandb or something else for reporting."} 175 | ) 176 | output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'}) 177 | resume_from_checkpoint: str = field(default=None, metadata={"help": 'The output dir for logs and checkpoints'}) 178 | optim: str = field(default='paged_adamw_32bit', metadata={"help": 'The optimizer to be used'}) 179 | per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'}) 180 | gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before to perform an optimizer step'}) 181 | max_steps: int = field(default=0, metadata={"help": 'How many optimizer update steps to take'}) 182 | weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'}) # use lora dropout instead for regularization if needed 183 | learning_rate: float = field(default=2e-5, metadata={"help": 'The learnign rate'}) 184 | remove_unused_columns: bool = field(default=False, metadata={"help": 'Removed unused columns. Needed to make this codebase work.'}) 185 | max_grad_norm: float = field(default=0.3, metadata={"help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'}) 186 | gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing. You want to use this.'}) 187 | do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'}) 188 | lr_scheduler_type: str = field(default='cosine', metadata={"help": 'Learning rate schedule. 
Constant a bit better than cosine, and has advantage for analysis'}) 189 | warmup_ratio: float = field(default=0.03, metadata={"help": 'Fraction of steps to do a warmup for'}) 190 | logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'}) 191 | group_by_length: bool = field(default=False, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'}) 192 | save_strategy: str = field(default='epoch', metadata={"help": 'When to save checkpoints'}) 193 | save_steps: int = field(default=250, metadata={"help": 'How often to save a model'}) 194 | save_total_limit: int = field(default=5, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'}) 195 | 196 | @dataclass 197 | class GenerationArguments: 198 | # For more hyperparameters check: 199 | # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig 200 | # Length arguments 201 | max_new_tokens: Optional[int] = field( 202 | default=256, 203 | metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops" 204 | "if predict_with_generate is set."} 205 | ) 206 | min_new_tokens : Optional[int] = field( 207 | default=None, 208 | metadata={"help": "Minimum number of new tokens to generate."} 209 | ) 210 | 211 | # Generation strategy 212 | do_sample: Optional[bool] = field(default=False) 213 | num_beams: Optional[int] = field(default=1) 214 | num_beam_groups: Optional[int] = field(default=1) 215 | penalty_alpha: Optional[float] = field(default=None) 216 | use_cache: Optional[bool] = field(default=True) 217 | 218 | # Hyperparameters for logit manipulation 219 | temperature: Optional[float] = field(default=1.0) 220 | top_k: Optional[int] = field(default=50) 221 | top_p: Optional[float] = field(default=1.0) 222 | typical_p: Optional[float] = field(default=1.0) 223 | diversity_penalty: Optional[float] = field(default=0.0) 224 | repetition_penalty: Optional[float] = field(default=1.0) 225 | length_penalty: Optional[float] = field(default=1.0) 226 | no_repeat_ngram_size: Optional[int] = field(default=0) 227 | 228 | 229 | 230 | def get_accelerate_model(args, checkpoint_dir): 231 | 232 | if torch.cuda.is_available(): 233 | n_gpus = torch.cuda.device_count() 234 | if is_ipex_available() and torch.xpu.is_available(): 235 | n_gpus = torch.xpu.device_count() 236 | 237 | max_memory = f'{args.max_memory_MB}MB' 238 | max_memory = {i: max_memory for i in range(n_gpus)} 239 | device_map = "auto" 240 | 241 | # if we are in a distributed setting, we need to set the device map and max memory per device 242 | if os.environ.get('LOCAL_RANK') is not None: 243 | local_rank = int(os.environ.get('LOCAL_RANK', '0')) 244 | device_map = {'': local_rank} 245 | max_memory = {'': max_memory[local_rank]} 246 | 247 | 248 | 249 | model, tokenizer = load_quantized_model(args.quant_model_path,args.wbits, args.group_size) 250 | tokenizer.model_max_length = args.pt_context_len 251 | 252 | compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)) 253 | if compute_dtype == torch.float16 and (is_ipex_available() and torch.xpu.is_available()): 254 | compute_dtype = torch.bfloat16 255 | print('Intel XPU does not support float16 yet, so switching to bfloat16') 256 | 257 | setattr(model, 'model_parallel', True) 258 | setattr(model, 'is_parallelizable', True) 259 | 260 | model.config.torch_dtype=(torch.float32 if args.fp16 else 
(torch.bfloat16 if args.bf16 else torch.float32)) 261 | # from peft import prepare_model_for_kbit_training 262 | # model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing) 263 | model.cuda() 264 | model.train() 265 | 266 | if tokenizer._pad_token is None: 267 | smart_tokenizer_and_embedding_resize( 268 | special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), 269 | tokenizer=tokenizer, 270 | model=model, 271 | ) 272 | 273 | # TODO 274 | # if 'llama1' in args.model_name_or_path or 'llama2' in args.model_name_or_path or 'llama-1' in args.model_name_or_path or 'llama-2' in args.model_name_or_path: 275 | if isinstance(tokenizer, LlamaTokenizer): 276 | # LLaMA tokenizer may not have correct special tokens set. 277 | # Check and add them if missing to prevent them from being parsed into different tokens. 278 | # Note that these are present in the vocabulary. 279 | # Note also that `model.config.pad_token_id` is 0 which corresponds to `` token. 280 | print('Adding special tokens.') 281 | tokenizer.add_special_tokens({ 282 | "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id), 283 | "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id), 284 | "unk_token": tokenizer.convert_ids_to_tokens( 285 | model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id 286 | ), 287 | }) 288 | 289 | 290 | for name, param in model.named_parameters(): 291 | # freeze base model's layers 292 | param.requires_grad = False 293 | 294 | # cast all non INT8 parameters to fp32 295 | for param in model.parameters(): 296 | if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16): 297 | param.data = param.data.to(torch.float32) 298 | 299 | if args.gradient_checkpointing: 300 | if hasattr(model, "enable_input_require_grads"): 301 | model.enable_input_require_grads() 302 | else: 303 | def make_inputs_require_grad(module, input, output): 304 | output.requires_grad_(True) 305 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 306 | 307 | model.gradient_checkpointing_enable() 308 | 309 | for name, module in model.named_modules(): 310 | # if isinstance(module, QuantLinear): 311 | # # transfer trainable step size into float32 312 | # module.scales.data = module.scales.data.to(torch.float32) 313 | if 'norm' in name: 314 | if hasattr(module, 'weight'): 315 | if args.bf16 and module.weight.dtype == torch.float32: 316 | module = module.to(torch.bfloat16) 317 | # module = module.to(torch.float32) 318 | if 'lm_head' in name or 'embed_tokens' in name: 319 | if hasattr(module, 'weight'): 320 | if args.bf16 and module.weight.dtype == torch.float32: 321 | module = module.to(torch.bfloat16) 322 | return model, tokenizer 323 | 324 | def print_trainable_parameters(args, model): 325 | """ 326 | Prints the number of trainable parameters in the model. 
327 | """ 328 | trainable_params = 0 329 | all_param = 0 330 | print('trainable module') 331 | print('*'*80) 332 | for name, param in model.named_parameters(): 333 | all_param += param.numel() 334 | if param.requires_grad: 335 | trainable_params += param.numel() 336 | print('*'*80) 337 | if args.wbits == 4: trainable_params /= 2 338 | print( 339 | f"trainable params: {trainable_params} || " 340 | f"all params: {all_param} || " 341 | f"trainable: {100 * trainable_params / all_param}" 342 | ) 343 | 344 | def smart_tokenizer_and_embedding_resize( 345 | special_tokens_dict: Dict, 346 | tokenizer: transformers.PreTrainedTokenizer, 347 | model: transformers.PreTrainedModel, 348 | ): 349 | """Resize tokenizer and embedding. 350 | 351 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 352 | """ 353 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 354 | model.resize_token_embeddings(len(tokenizer)) 355 | 356 | if num_new_tokens > 0: 357 | input_embeddings_data = model.get_input_embeddings().weight.data 358 | output_embeddings_data = model.get_output_embeddings().weight.data 359 | 360 | input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True) 361 | output_embeddings_avg = output_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True) 362 | 363 | input_embeddings_data[-num_new_tokens:] = input_embeddings_avg 364 | output_embeddings_data[-num_new_tokens:] = output_embeddings_avg 365 | 366 | 367 | 368 | 369 | 370 | 371 | def get_last_checkpoint(checkpoint_dir): 372 | if isdir(checkpoint_dir): 373 | is_completed = exists(join(checkpoint_dir, 'completed')) 374 | if is_completed: return None, True # already finished 375 | max_step = 0 376 | for filename in os.listdir(checkpoint_dir): 377 | if isdir(join(checkpoint_dir, filename)) and filename.startswith('checkpoint'): 378 | max_step = max(max_step, int(filename.replace('checkpoint-', ''))) 379 | if max_step == 0: return None, is_completed # training started, but no checkpoint 380 | checkpoint_dir = join(checkpoint_dir, f'checkpoint-{max_step}') 381 | print(f"Found a previous checkpoint at: {checkpoint_dir}") 382 | return checkpoint_dir, is_completed # checkpoint found! 
383 | return None, False # first training 384 | 385 | def train(): 386 | hfparser = transformers.HfArgumentParser(( 387 | ModelArguments, DataArguments, TrainingArguments, GenerationArguments 388 | )) 389 | model_args, data_args, training_args, generation_args, extra_args = \ 390 | hfparser.parse_args_into_dataclasses(return_remaining_strings=True) 391 | training_args.generation_config = transformers.GenerationConfig(**vars(generation_args)) 392 | args = argparse.Namespace( 393 | **vars(model_args), **vars(data_args), **vars(training_args) 394 | ) 395 | 396 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 397 | logger = utils.create_logger(args.output_dir) 398 | logger.info(args) 399 | 400 | checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir) 401 | if completed_training: 402 | print('Detected that training was already completed!') 403 | 404 | model, tokenizer = get_accelerate_model(args, checkpoint_dir) 405 | 406 | model.config.use_cache = False 407 | print('loaded model') 408 | set_seed(args.seed) 409 | 410 | data_module = make_data_module(tokenizer=tokenizer, args=args) 411 | 412 | 413 | 414 | optimizer_grouped_parameters = [] 415 | for name, module in model.named_modules(): 416 | # if isinstance(module, LoraLayer): 417 | if isinstance(module, QuantLinear) and not 'head' in name: 418 | module.scales.requires_grad = True 419 | optimizer_grouped_parameters.append({'params': [p for n, p in model.named_parameters() if 'scale' in n], 'weight_decay': 0.0, 'lr': args.learning_rate}) 420 | optimizer = AdamW(optimizer_grouped_parameters) 421 | 422 | trainer = Seq2SeqTrainer( 423 | model=model, 424 | tokenizer=tokenizer, 425 | args=training_args, 426 | optimizers=(optimizer, None), 427 | **{k:v for k,v in data_module.items() if k != 'predict_dataset'}, 428 | ) 429 | 430 | if args.do_ppl_eval: 431 | class PPLvalCallback(transformers.TrainerCallback): 432 | @torch.no_grad() 433 | def on_evaluate(self, args=None, state=None, control=None, model=None, **kwargs): 434 | results = test_ppl(trainer.model, trainer.tokenizer, datasets=['wikitext2','c4'],ppl_seqlen=2048) 435 | logger.info(results) 436 | trainer.log(results) 437 | 438 | trainer.add_callback(PPLvalCallback) 439 | 440 | # Verifying the datatypes and parameter counts before training. 
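# Only the per-group quantization step sizes should be trainable at this point:
# get_accelerate_model() freezes every parameter, and the loop above re-enables
# requires_grad solely for the `scales` of each QuantLinear (excluding the head).
# The dtype census below is a sanity check of how many parameters end up in each dtype.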
441 | print_trainable_parameters(args, model) 442 | dtypes = {} 443 | for _, p in model.named_parameters(): 444 | dtype = p.dtype 445 | if dtype not in dtypes: dtypes[dtype] = 0 446 | dtypes[dtype] += p.numel() 447 | total = 0 448 | for k, v in dtypes.items(): total+= v 449 | for k, v in dtypes.items(): 450 | print(k, v, v/total) 451 | 452 | all_metrics = {"run_name": args.run_name} 453 | 454 | 455 | 456 | print(args.output_dir) 457 | if args.do_train: 458 | logger.info("*** Train ***") 459 | train_result = trainer.train(args.resume_from_checkpoint) 460 | metrics = train_result.metrics 461 | trainer.log_metrics("train", metrics) 462 | trainer.save_metrics("train", metrics) 463 | trainer.save_state() 464 | all_metrics.update(metrics) 465 | # Evaluation 466 | if args.do_eval: 467 | logger.info("*** Evaluate ***") 468 | metrics = trainer.evaluate(metric_key_prefix="eval") 469 | trainer.log_metrics("eval", metrics) 470 | trainer.save_metrics("eval", metrics) 471 | all_metrics.update(metrics) 472 | # Prediction 473 | if args.do_predict: 474 | logger.info("*** Predict ***") 475 | prediction_output = trainer.predict(test_dataset=data_module['predict_dataset'],metric_key_prefix="predict") 476 | prediction_metrics = prediction_output.metrics 477 | predictions = prediction_output.predictions 478 | predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id) 479 | predictions = tokenizer.batch_decode( 480 | predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 481 | ) 482 | with open(os.path.join(args.output_dir, 'predictions.jsonl'), 'w') as fout: 483 | for i, example in enumerate(data_module['predict_dataset']): 484 | example['prediction_with_input'] = predictions[i].strip() 485 | example['prediction'] = predictions[i].replace(example['input'], '').strip() 486 | fout.write(json.dumps(example) + '\n') 487 | print(prediction_metrics) 488 | trainer.log_metrics("predict", prediction_metrics) 489 | trainer.save_metrics("predict", prediction_metrics) 490 | all_metrics.update(prediction_metrics) 491 | 492 | if (args.do_train or args.do_eval or args.do_predict): 493 | with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout: 494 | fout.write(json.dumps(all_metrics)) 495 | 496 | if args.eval_tasks != "" or args.do_mmlu_eval: 497 | import lm_eval 498 | from lm_eval.models.huggingface import HFLM 499 | from lm_eval.utils import make_table 500 | 501 | if args.eval_tasks != "": 502 | task_list = args.eval_tasks.split(',') 503 | lm_eval_model = HFLM(pretrained=model, batch_size=32) 504 | task_manager = lm_eval.tasks.TaskManager() 505 | results = lm_eval.simple_evaluate( # call simple_evaluate 506 | model=lm_eval_model, 507 | tasks=task_list, 508 | num_fewshot=0, 509 | task_manager=task_manager, 510 | ) 511 | logger.info(make_table(results)) 512 | total_acc = 0 513 | for task in task_list: 514 | total_acc += results['results'][task]['acc,none'] 515 | logger.info(f'Average Acc: {total_acc/len(task_list)*100:.2f}%') 516 | 517 | if args.do_mmlu_eval: 518 | lm_eval_model = HFLM(pretrained=model, batch_size=16) 519 | task_manager = lm_eval.tasks.TaskManager() 520 | results = lm_eval.simple_evaluate( # call simple_evaluate 521 | model=lm_eval_model, 522 | tasks=['mmlu'], 523 | num_fewshot=5, 524 | task_manager=task_manager, 525 | cache_requests=True, 526 | ) 527 | logger.info(make_table(results)) 528 | total_acc = 0 529 | for task in results['results']: 530 | total_acc += results['results'][task]['acc,none'] 531 | logger.info(f"Average MMLU Acc: 
{total_acc/len(results['results'])*100:.2f}%")
532 | 
533 | if __name__ == "__main__":
534 |     train()
535 | 
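E2E-QP above trains only the per-group quantization step sizes (the `scales` of each `QuantLinear`) while every other weight stays frozen. The sketch below illustrates that parameter-selection pattern in plain PyTorch; `ToyQuantLinear` and all shapes are made-up stand-ins rather than code from this repository, and `torch.optim.AdamW` replaces the bitsandbytes optimizer to keep the example dependency-free.

```python
import torch
import torch.nn as nn

class ToyQuantLinear(nn.Module):
    """Toy stand-in for a real quantized linear layer: the packed weight is a frozen
    buffer and only the per-group step sizes (`scales`) are learnable."""
    def __init__(self, in_features, out_features, group_size=64):
        super().__init__()
        self.register_buffer("qweight", torch.zeros(out_features, in_features, dtype=torch.int32))
        self.scales = nn.Parameter(torch.ones(in_features // group_size, out_features))

model = nn.ModuleDict({
    "layer": ToyQuantLinear(128, 128),
    "lm_head": nn.Linear(128, 32000),
})

# 1) freeze everything, as get_accelerate_model() does
for param in model.parameters():
    param.requires_grad = False

# 2) re-enable gradients only for the quantization step sizes (skip the head)
for name, module in model.named_modules():
    if isinstance(module, ToyQuantLinear) and "head" not in name:
        module.scales.requires_grad = True

# 3) put exactly those scales into one optimizer group
scale_params = [p for n, p in model.named_parameters() if "scale" in n and p.requires_grad]
optimizer = torch.optim.AdamW([{"params": scale_params, "weight_decay": 0.0, "lr": 2e-5}])
print(f"trainable scale parameters: {sum(p.numel() for p in scale_params)}")
```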
--------------------------------------------------------------------------------
/model_transfer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/EfficientQAT/39175493b2d14617d342a0a7956875e6ac16221b/model_transfer/__init__.py
--------------------------------------------------------------------------------
/model_transfer/efficientqat_to_others.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datautils_block import test_ppl
3 | from transformers import AutoTokenizer
4 | from gptqmodel import GPTQModel, QuantizeConfig, get_backend
5 | from pathlib import Path
6 | import time
7 | 
8 | def main():
9 |     import argparse
10 | 
11 |     parser = argparse.ArgumentParser()
12 |     parser.add_argument("--model", default=None, type=str, help="path of the EfficientQAT quantized model to convert")
13 |     parser.add_argument("--wbits", type=int, default=4, help="quantization bits")
14 |     parser.add_argument("--group_size", type=int, default=128, help="quantization group size")
15 |     parser.add_argument("--target_format", default='gptq', type=str, help="target checkpoint format")
16 |     parser.add_argument("--eval_ppl", action="store_true")
17 |     parser.add_argument("--test_speed", action="store_true")
18 |     parser.add_argument("--save_dir", default=None, type=str, help="directory for saving the converted model")
19 | 
20 | 
21 | 
22 | 
23 |     args = parser.parse_args()
24 |     tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False, legacy=False)
25 |     quant_config = QuantizeConfig(
26 |         bits=args.wbits,
27 |         group_size=args.group_size,
28 |         sym=False,
29 |         desc_act=False,
30 |         format='gptq_v2',
31 |     )
32 |     if args.target_format == 'gptq':
33 |         # EXLLAMA_V2 is faster in 4-bit and runs inference correctly, but it has some bugs when saving models.
34 |         # Therefore, we choose the Triton backend as the default. Note that the saved model can also be loaded by ExLlama.
35 |         model = GPTQModel.from_quantized(args.model, device_map='auto', torch_dtype=torch.float16, quantize_config=quant_config, backend=get_backend('TRITON'))
36 | 
37 |     elif args.target_format == 'bitblas':
38 |         # the first run takes a long time
39 |         try:
40 |             model = GPTQModel.from_quantized(args.model, device_map='auto', torch_dtype=torch.float16, quantize_config=quant_config, backend=get_backend('BITBLAS'))
41 |             args.eval_ppl = False  # BitBLAS has a bug: the model must be re-loaded for evaluation, otherwise the outputs are wrong
42 |         except Exception:
43 |             model = GPTQModel.from_quantized(args.model, device_map='auto', torch_dtype=torch.float16, backend=get_backend('BITBLAS'))
44 |     else:
45 |         raise NotImplementedError
46 | 
47 |     if args.save_dir:
48 |         Path(args.save_dir).mkdir(parents=True, exist_ok=True)
49 |         print("start saving model")
50 |         model.quantize_config.model_file_base_name = None  # trick to avoid a saving bug in GPTQModel
51 |         model.save_quantized(args.save_dir, max_shard_size='8GB')
52 |         tokenizer.save_pretrained(args.save_dir)
53 |         print(f"save model to {args.save_dir} success")
54 | 
55 |     model.model.cuda()
56 | 
57 |     if args.eval_ppl:
58 |         datasets = ["wikitext2"]
59 |         ppl_results = test_ppl(model, tokenizer, datasets, 2048)
60 |         for dataset in ppl_results:
61 |             print(f'{dataset} perplexity after transferring: {ppl_results[dataset]:.2f}')
62 |     if args.test_speed:
63 |         prompt = "Write a poem about large language model:"
64 |         input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda()
65 |         start_time = time.time()
66 |         output = model.generate(inputs=input_ids, do_sample=True, top_k=10, max_new_tokens=256)
67 |         end_time = time.time()
68 |         speed = len(output[0])/(end_time-start_time)
69 |         print(tokenizer.decode(output[0]))
70 |         print(f"generation speed: {speed:.1f} token/s")
71 | 
72 | 
73 | if __name__ == '__main__':
74 |     main()
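Once `efficientqat_to_others.py` has written a GPTQ-format checkpoint to `--save_dir`, that directory can be loaded again through GPTQModel for inference. A minimal sketch, reusing only the calls already shown in the script above; the checkpoint path is a placeholder, and depending on the installed GPTQModel version you may need to pass the same `QuantizeConfig` as in the script:

```python
import torch
from transformers import AutoTokenizer
from gptqmodel import GPTQModel, get_backend

converted_dir = "./Llama-2-7b-EfficientQAT-w2g64-GPTQ"  # placeholder: whatever --save_dir pointed to

tokenizer = AutoTokenizer.from_pretrained(converted_dir, use_fast=False)
model = GPTQModel.from_quantized(
    converted_dir,
    device_map='auto',
    torch_dtype=torch.float16,
    backend=get_backend('TRITON'),  # same backend the conversion script defaults to
)

prompt = "Write a poem about large language model:"
input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda()
output = model.generate(inputs=input_ids, do_sample=True, top_k=10, max_new_tokens=64)
print(tokenizer.decode(output[0]))
```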
--------------------------------------------------------------------------------
/model_transfer/fp32_to_16.py:
--------------------------------------------------------------------------------
1 | from quantize.int_linear_real import load_quantized_model
2 | import torch
3 | from datautils_block import test_ppl
4 | from pathlib import Path
5 | 
6 | 
7 | def main():
8 |     import argparse
9 | 
10 |     parser = argparse.ArgumentParser()
11 |     parser.add_argument("--model", type=str, default=None, help="path of the quantized model to convert")
12 |     parser.add_argument("--save_dir", default=None, type=str, help="directory for saving the converted model")
13 |     parser.add_argument("--wbits", type=int, default=4, help="quantization bits")
14 |     parser.add_argument("--group_size", type=int, default=128, help="quantization group size")
15 |     parser.add_argument("--target_type", type=str, default="fp16", choices=["fp16", "bf16"])
16 |     parser.add_argument("--eval_ppl", action="store_true", help="evaluate perplexity on wikitext2 with 2048 context length")
17 | 
18 | 
19 | 
20 |     args = parser.parse_args()
21 |     model, tokenizer = load_quantized_model(args.model, args.wbits, args.group_size)
22 |     model.cuda()
23 | 
24 | 
25 |     if args.target_type == 'fp16':
26 |         dtype = torch.float16
27 |     elif args.target_type == 'bf16':
28 |         dtype = torch.bfloat16
29 |     else:
30 |         raise NotImplementedError
31 | 
32 |     if args.eval_ppl:
33 |         datasets = ["wikitext2"]
34 |         ppl_results = test_ppl(model, tokenizer, datasets, 2048)
35 |         for dataset in ppl_results:
36 |             print(f'{dataset} perplexity before transferring: {ppl_results[dataset]:.2f}')
37 | 
38 |     model.to(dtype)
39 |     print(f"transferred model to {args.target_type} format")
40 |     if args.eval_ppl:
41 |         datasets = ["wikitext2"]
42 |         ppl_results = test_ppl(model, tokenizer, datasets, 2048)
43 |         for dataset in ppl_results:
44 |             print(f'{dataset} perplexity after transferring: {ppl_results[dataset]:.2f}')
45 | 
46 |     if args.save_dir:
47 |         Path(args.save_dir).mkdir(parents=True, exist_ok=True)
48 |         print("start saving model")
49 |         model.save_pretrained(args.save_dir)
50 |         tokenizer.save_pretrained(args.save_dir)
51 |         print(f"save model to {args.save_dir} success")
52 | 
53 | if __name__ == '__main__':
54 |     main()
--------------------------------------------------------------------------------
/model_transfer/real_to_fake.py:
--------------------------------------------------------------------------------
1 | from quantize.int_linear_real import load_quantized_model, QuantLinear
2 | import torch
3 | from datautils_block import test_ppl
4 | from pathlib import Path
5 | 
6 | def main():
7 |     import argparse
8 | 
9 |     parser = argparse.ArgumentParser()
10 |     parser.add_argument("--model", type=str, default=None, help="path of the quantized model to convert")
11 |     parser.add_argument("--save_dir", default=None, type=str, help="directory for saving the fake-quantized (fp16) model")
12 |     parser.add_argument("--wbits", type=int, default=4, help="quantization bits")
13 |     parser.add_argument("--group_size", type=int, default=128, help="quantization group size")
14 | 
15 | 
16 | 
17 |     args = parser.parse_args()
18 |     model, tokenizer = load_quantized_model(args.model, args.wbits, args.group_size)
19 |     # model.cuda()
20 | 
21 |     for name, module in model.named_modules():
22 |         if isinstance(module, QuantLinear):
23 |             module.cuda()
24 |             module.use_fake_quantization(del_quant=True, transpose=True)
25 |             module.cpu()
26 | 
27 | 
28 |     if args.save_dir:
29 |         Path(args.save_dir).mkdir(parents=True, exist_ok=True)
30 |         print("start saving model")
31 |         model.to(torch.float16)
32 |         model.save_pretrained(args.save_dir)
33 |         tokenizer.save_pretrained(args.save_dir)
34 |         print(f"save model to {args.save_dir} success")
35 | 
36 | if __name__ == '__main__':
37 |     main()
--------------------------------------------------------------------------------
/quantize/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/EfficientQAT/39175493b2d14617d342a0a7956875e6ac16221b/quantize/__init__.py
--------------------------------------------------------------------------------
/quantize/block_ap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import quantize.int_linear_fake as int_linear_fake
5 | import quantize.int_linear_real as int_linear_real
6 | from torch.optim.lr_scheduler import CosineAnnealingLR
7 | import copy
8 | import math
9 | import utils
10 | import pdb
11 | import gc
12 | from quantize.utils import (
13 |     quant_parameters, weight_parameters, trainable_parameters,
14 |     set_quant_state, quant_inplace, set_quant_parameters,
15 |     set_weight_parameters, trainable_parameters_num, get_named_linears, set_op_by_name)
16 | import time
17 | from datautils_block import BlockTrainDataset
18 | from torch.utils.data import DataLoader
19 | import shutil
20 | import os
21 | 
22 | def update_dataset(layer, dataset, dev, attention_mask, position_ids):
23 |     with torch.no_grad():
24 |         with torch.cuda.amp.autocast():
25 |             for index, inps in enumerate(dataset):
26 |                 inps = inps.to(dev)
27 |                 if len(inps.shape)==2:
28 |                     inps = inps.unsqueeze(0)
29
| new_data = layer(inps, attention_mask=attention_mask,position_ids=position_ids)[0].to('cpu') 30 | dataset.update_data(index,new_data) 31 | 32 | 33 | def block_ap( 34 | model, 35 | args, 36 | trainloader, 37 | valloader, 38 | logger=None, 39 | ): 40 | logger.info("Starting ...") 41 | if args.off_load_to_disk: 42 | logger.info("offload the training dataset to disk, saving CPU memory, but may slowdown the training due to additional I/O...") 43 | 44 | dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") 45 | use_cache = model.config.use_cache 46 | model.config.use_cache = False 47 | 48 | # step 1: move embedding layer and first layer to target device, only suppress llama models now 49 | layers = model.model.layers 50 | model.model.embed_tokens = model.model.embed_tokens.to(dev) 51 | model.model.norm = model.model.norm.to(dev) 52 | if hasattr(model.model, 'rotary_emb'): 53 | # for llama-3.1 54 | model.model.rotary_emb = model.model.rotary_emb.to(dev) 55 | layers[0] = layers[0].to(dev) 56 | dtype = torch.float16 57 | 58 | # step 2: init dataset 59 | flag = time.time() 60 | if args.off_load_to_disk: 61 | fp_train_cache_path = f'{args.cache_dir}/{flag}/block_training_fp_train' 62 | fp_val_cache_path = f'{args.cache_dir}/{flag}/block_training_fp_val' 63 | quant_train_cache_path = f'{args.cache_dir}/{flag}/block_training_quant_train' 64 | quant_val_cache_path = f'{args.cache_dir}/{flag}/block_training_quant_val' 65 | for path in [fp_train_cache_path,fp_val_cache_path,quant_train_cache_path,quant_val_cache_path]: 66 | if os.path.exists(path): 67 | shutil.rmtree(path) 68 | else: 69 | fp_train_cache_path = None 70 | fp_val_cache_path = None 71 | quant_train_cache_path = None 72 | quant_val_cache_path = None 73 | fp_train_inps = BlockTrainDataset(args.train_size, args.training_seqlen, 74 | model.config.hidden_size, args.batch_size, dtype, cache_path=fp_train_cache_path,off_load_to_disk=args.off_load_to_disk) 75 | fp_val_inps = BlockTrainDataset(args.val_size, args.training_seqlen, 76 | model.config.hidden_size, args.batch_size, dtype, cache_path=fp_val_cache_path,off_load_to_disk=args.off_load_to_disk) 77 | 78 | # step 3: catch the input of thefirst layer 79 | class Catcher(nn.Module): 80 | def __init__(self, module, dataset): 81 | super().__init__() 82 | self.module = module 83 | self.dataset = dataset 84 | self.index = 0 85 | self.attention_mask = None 86 | self.position_ids = None 87 | 88 | def forward(self, inp, **kwargs): 89 | self.dataset.update_data(self.index, inp.squeeze(0).to('cpu')) 90 | self.index += 1 91 | if self.attention_mask is None: 92 | self.attention_mask = kwargs["attention_mask"] 93 | if self.position_ids is None: 94 | self.position_ids = kwargs["position_ids"] 95 | raise ValueError 96 | 97 | # step 3.1: catch the input of training set 98 | layers[0] = Catcher(layers[0],fp_train_inps) 99 | iters = len(trainloader)//args.batch_size 100 | with torch.no_grad(): 101 | for i in range(iters): 102 | data = torch.cat([trainloader[j][0] for j in range(i*args.batch_size,(i+1)*args.batch_size)],dim=0) 103 | try: 104 | model(data.to(dev)) 105 | except ValueError: 106 | pass 107 | layers[0] = layers[0].module 108 | 109 | # step 3.2: catch the input of validation set 110 | layers[0] = Catcher(layers[0],fp_val_inps) 111 | iters = len(valloader)//args.batch_size 112 | with torch.no_grad(): 113 | for i in range(iters): 114 | data = torch.cat([valloader[j][0] for j in range(i*args.batch_size,(i+1)*args.batch_size)],dim=0) 115 | try: 116 | model(data.to(dev)) 117 | except ValueError: 
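# expected: Catcher.forward() raises ValueError on purpose after caching the block input
# and the attention kwargs, so the rest of the forward pass is skipped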
118 | pass 119 | attention_mask = layers[0].attention_mask 120 | position_ids = layers[0].position_ids 121 | layers[0] = layers[0].module 122 | if attention_mask is not None: 123 | attention_mask_batch = attention_mask.repeat(args.batch_size,1,1,1).float() 124 | else: 125 | logger.info( 126 | "No attention mask caught from the first layer." 127 | " Seems that model's attention works without a mask." 128 | ) 129 | attention_mask_batch = None 130 | 131 | # step 4: move embedding layer and first layer to cpu 132 | layers[0] = layers[0].cpu() 133 | model.model.embed_tokens = model.model.embed_tokens.cpu() 134 | model.model.norm = model.model.norm.cpu() 135 | if hasattr(model.model, 'rotary_emb'): 136 | # for llama-3.1 137 | model.model.rotary_emb = model.model.rotary_emb.cpu() 138 | torch.cuda.empty_cache() 139 | 140 | # step 5: copy fp input as the quant input, they are same at the first layer 141 | if args.off_load_to_disk: 142 | # copy quant input from fp input, they are same in first layer 143 | shutil.copytree(fp_train_cache_path, quant_train_cache_path) 144 | shutil.copytree(fp_val_cache_path, quant_val_cache_path) 145 | quant_train_inps = BlockTrainDataset(args.train_size, args.training_seqlen, 146 | model.config.hidden_size, args.batch_size, dtype, cache_path=quant_train_cache_path,off_load_to_disk=args.off_load_to_disk) 147 | quant_val_inps = BlockTrainDataset(args.val_size, args.training_seqlen, 148 | model.config.hidden_size, args.batch_size, dtype, cache_path=quant_val_cache_path,off_load_to_disk=args.off_load_to_disk) 149 | else: 150 | quant_train_inps = BlockTrainDataset(args.train_size, args.training_seqlen, 151 | model.config.hidden_size, args.batch_size, dtype, cache_path=quant_train_cache_path,off_load_to_disk=args.off_load_to_disk) 152 | quant_val_inps = BlockTrainDataset(args.val_size, args.training_seqlen, 153 | model.config.hidden_size, args.batch_size, dtype, cache_path=quant_val_cache_path,off_load_to_disk=args.off_load_to_disk) 154 | for index,data in enumerate(fp_train_inps): 155 | quant_train_inps.update_data(index, data) 156 | for index,data in enumerate(fp_val_inps): 157 | quant_val_inps.update_data(index, data) 158 | 159 | # step 6: start training 160 | loss_func = torch.nn.MSELoss() 161 | for block_index in range(len(layers)): 162 | logger.info(f"=== Start quantize blocks {block_index}===") 163 | # step 6.1: replace torch.nn.Linear with QuantLinear for QAT 164 | layer = layers[block_index].to(dev) 165 | qlayer = copy.deepcopy(layer) 166 | for name, module in qlayer.named_modules(): 167 | if isinstance(module,torch.nn.Linear): 168 | quantlinear = int_linear_fake.QuantLinear(module, args.wbits, args.group_size) 169 | set_op_by_name(qlayer, name, quantlinear) 170 | del module 171 | qlayer.to(dev) 172 | 173 | 174 | # step 6.2: obtain output of full-precision model for MSE 175 | set_quant_state(qlayer,weight_quant=False) # deactivate quantization for obtaining ground truth 176 | if args.epochs > 0: 177 | update_dataset(qlayer,fp_train_inps,dev,attention_mask,position_ids) 178 | update_dataset(qlayer,fp_val_inps,dev,attention_mask,position_ids) 179 | set_quant_state(qlayer,weight_quant=True) # activate quantization 180 | 181 | 182 | if args.epochs > 0: 183 | with torch.no_grad(): 184 | qlayer.float() # fp32 is required for AMP training 185 | # step 6.3: create optimizer and learning rate schedule 186 | param = [] 187 | assert args.quant_lr > 0 or args.weight_lr > 0 188 | param_group_index = 0 189 | total_training_iteration = args.epochs * args.train_size / 
args.batch_size 190 | if args.quant_lr > 0: 191 | set_quant_parameters(qlayer,True) 192 | param.append({"params":quant_parameters(qlayer),"lr":args.quant_lr}) 193 | empty_optimizer_1 = torch.optim.AdamW([torch.tensor(0)], lr=args.quant_lr) 194 | quant_scheduler = CosineAnnealingLR(empty_optimizer_1, T_max=total_training_iteration, eta_min=args.quant_lr/args.min_lr_factor) 195 | quant_index = param_group_index 196 | param_group_index += 1 197 | else: 198 | set_quant_parameters(qlayer,False) 199 | 200 | if args.weight_lr > 0: 201 | set_weight_parameters(qlayer,True) 202 | param.append({"params":weight_parameters(qlayer),"lr":args.weight_lr}) 203 | empty_optimizer_2 = torch.optim.AdamW([torch.tensor(0)], lr=args.weight_lr) 204 | weight_scheduler = CosineAnnealingLR(empty_optimizer_2, T_max=total_training_iteration, eta_min=args.weight_lr/args.min_lr_factor) 205 | weight_index = param_group_index 206 | param_group_index += 1 207 | else: 208 | set_weight_parameters(qlayer,False) 209 | optimizer = torch.optim.AdamW(param, weight_decay=args.wd) 210 | loss_scaler = utils.NativeScalerWithGradNormCount() 211 | trainable_number = trainable_parameters_num(qlayer) 212 | print(f"trainable parameter number: {trainable_number/1e6}M") 213 | 214 | best_val_loss = 1e6 215 | early_stop_flag = 0 216 | for epoch in range(args.epochs): 217 | # step: 6.4 training 218 | loss_list = [] 219 | norm_list = [] 220 | start_time = time.time() 221 | for index, (quant_inps, fp_inps) in enumerate(zip(quant_train_inps, fp_train_inps)): 222 | # obtain output of quantization model 223 | with torch.cuda.amp.autocast(): 224 | input = quant_inps.to(dev) 225 | label = fp_inps.to(dev) 226 | quant_out = qlayer(input, attention_mask=attention_mask_batch,position_ids=position_ids)[0] 227 | reconstruction_loss = loss_func(label, quant_out) 228 | loss = reconstruction_loss 229 | 230 | if not math.isfinite(loss.item()): 231 | logger.info("Loss is NAN, stopping training") 232 | pdb.set_trace() 233 | loss_list.append(reconstruction_loss.detach().cpu()) 234 | optimizer.zero_grad() 235 | norm = loss_scaler(loss, optimizer,parameters=trainable_parameters(qlayer)).cpu() 236 | norm_list.append(norm.data) 237 | 238 | # adjust lr 239 | if args.quant_lr > 0: 240 | quant_scheduler.step() 241 | optimizer.param_groups[quant_index]['lr'] = quant_scheduler.get_lr()[0] 242 | if args.weight_lr >0 : 243 | weight_scheduler.step() 244 | optimizer.param_groups[weight_index]['lr'] = weight_scheduler.get_lr()[0] 245 | 246 | # step 6.5: calculate validation loss 247 | val_loss_list = [] 248 | for index, (quant_inps,fp_inps) in enumerate(zip(quant_val_inps, fp_val_inps)): 249 | # obtain output of quantization model 250 | with torch.no_grad(): 251 | with torch.cuda.amp.autocast(): 252 | input = quant_inps.to(dev) 253 | label = fp_inps.to(dev) 254 | quant_out = qlayer(input, attention_mask=attention_mask_batch,position_ids=position_ids)[0] 255 | reconstruction_loss = loss_func(label, quant_out) 256 | val_loss_list.append(reconstruction_loss.cpu()) 257 | 258 | train_mean_num = min(len(loss_list),64) # calculate the average training loss of last train_mean_num samples 259 | loss_mean = torch.stack(loss_list)[-(train_mean_num-1):].mean() 260 | val_loss_mean = torch.stack(val_loss_list).mean() 261 | norm_mean = torch.stack(norm_list).mean() 262 | logger.info(f"blocks {block_index} epoch {epoch} recon_loss:{loss_mean} val_loss:{val_loss_mean} quant_lr:{quant_scheduler.get_lr()[0]} norm:{norm_mean:.8f} max memory_allocated {torch.cuda.max_memory_allocated(dev) / 
1024**2} time {time.time()-start_time} ") 263 | if val_loss_mean < best_val_loss: 264 | best_val_loss = val_loss_mean 265 | else: 266 | early_stop_flag += 1 267 | if args.early_stop > 0 and early_stop_flag >=args.early_stop: 268 | break 269 | optimizer.zero_grad() 270 | del optimizer 271 | 272 | # step 6.6: directly replace the weight with fake quantization 273 | qlayer.half() 274 | quant_inplace(qlayer) 275 | set_quant_state(qlayer,weight_quant=False) # weight has been quantized inplace 276 | 277 | # step 6.7: update inputs of quantization model 278 | if args.epochs>0: 279 | update_dataset(qlayer,quant_train_inps,dev,attention_mask,position_ids) 280 | update_dataset(qlayer,quant_val_inps,dev,attention_mask,position_ids) 281 | layers[block_index] = qlayer.to("cpu") 282 | 283 | # step 7: pack quantized weights into low-bits format, note that this process is slow on poor CPU or busy CPU 284 | if args.real_quant: 285 | named_linears = get_named_linears(qlayer, int_linear_fake.QuantLinear) 286 | for name, module in named_linears.items(): 287 | scales = module.weight_quantizer.scale.clamp(1e-4,1e4).detach() 288 | zeros = module.weight_quantizer.zero_point.detach().cuda().round().cpu() 289 | group_size = module.weight_quantizer.group_size 290 | dim0 = module.weight.shape[0] 291 | scales = scales.view(dim0,-1).transpose(0,1).contiguous() 292 | zeros = zeros.view(dim0,-1).transpose(0,1).contiguous() 293 | q_linear = int_linear_real.QuantLinear(args.wbits, group_size, module.in_features,module.out_features,not module.bias is None) 294 | q_linear.pack(module.cpu(), scales.float().cpu(), zeros.float().cpu()) 295 | set_op_by_name(qlayer, name, q_linear) 296 | logger.info(f"pack quantized {name} finished") 297 | del module 298 | del layer 299 | torch.cuda.empty_cache() 300 | 301 | # delete cached dataset 302 | if args.off_load_to_disk: 303 | for path in [fp_train_cache_path,fp_val_cache_path,quant_train_cache_path,quant_val_cache_path]: 304 | if os.path.exists(path): 305 | shutil.rmtree(path) 306 | 307 | torch.cuda.empty_cache() 308 | gc.collect() 309 | model.config.use_cache = use_cache 310 | return model 311 | 312 | -------------------------------------------------------------------------------- /quantize/int_linear_fake.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from quantize.quantizer import UniformAffineQuantizer 5 | 6 | 7 | 8 | 9 | 10 | class QuantLinear(nn.Module): 11 | """ 12 | Quantized Module that can perform quantized convolution or normal convolution. 13 | To activate quantization, please use set_quant_state function. 
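When quantization is active, the weight is passed through a UniformAffineQuantizer
(simulated, i.e. fake, quantization) before the linear projection; otherwise the
original floating-point weight is used unchanged.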
14 | """ 15 | def __init__( 16 | self, 17 | org_module: nn.Linear, 18 | wbits=4, 19 | group_size=64 20 | ): 21 | super().__init__() 22 | self.fwd_kwargs = dict() 23 | self.fwd_func = F.linear 24 | self.register_parameter('weight',org_module.weight) # trainable 25 | if org_module.bias is not None: 26 | self.register_buffer('bias',org_module.bias) 27 | else: 28 | self.bias = None 29 | self.in_features = org_module.in_features 30 | self.out_features = org_module.out_features 31 | # de-activate the quantized forward default 32 | self.use_weight_quant = False 33 | # initialize quantizer 34 | self.weight_quantizer = UniformAffineQuantizer(wbits, group_size, weight=org_module.weight) 35 | self.use_temporary_parameter = False 36 | 37 | 38 | 39 | def forward(self, input: torch.Tensor): 40 | if self.use_weight_quant: 41 | weight = self.weight_quantizer(self.weight) 42 | bias = self.bias 43 | else: 44 | weight = self.weight 45 | bias = self.bias 46 | 47 | 48 | out = self.fwd_func(input, weight, bias, **self.fwd_kwargs) 49 | 50 | 51 | return out 52 | 53 | def set_quant_state(self, weight_quant: bool = False): 54 | self.use_weight_quant = weight_quant 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /quantize/int_linear_real.py: -------------------------------------------------------------------------------- 1 | import math 2 | from logging import getLogger 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import transformers 8 | 9 | from quantize.triton_utils.kernels import dequant_dim0, dequant_dim1 10 | import math 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 12 | from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_in_model 13 | from tqdm import tqdm 14 | import gc 15 | from quantize.utils import get_named_linears,set_op_by_name 16 | 17 | logger = getLogger(__name__) 18 | 19 | 20 | class TritonModuleMixin: 21 | @classmethod 22 | def warmup(cls, model, transpose=False, seqlen=2048): 23 | pass 24 | 25 | 26 | class QuantLinear(nn.Module, TritonModuleMixin): 27 | QUANT_TYPE = "triton" 28 | 29 | def __init__( 30 | self, 31 | bits, 32 | group_size, 33 | infeatures, 34 | outfeatures, 35 | bias, 36 | trainable=False, 37 | **kwargs 38 | ): 39 | super().__init__() 40 | # if bits not in [2, 4, 8]: 41 | # raise NotImplementedError("Only 2,4,8 bits are supported.") 42 | # if infeatures % 32 != 0 or outfeatures % 32 != 0: 43 | # raise NotImplementedError("in_feature and out_feature must be divisible by 32.") 44 | self.infeatures = infeatures 45 | self.outfeatures = outfeatures 46 | self.bits = bits 47 | self.group_size = group_size if group_size != -1 else infeatures 48 | self.maxq = 2 ** self.bits - 1 49 | self.register_buffer( 50 | 'qweight', 51 | torch.zeros((math.ceil(infeatures / (32 // self.bits)), outfeatures), dtype=torch.int32) 52 | ) 53 | self.register_parameter( 54 | 'scales', 55 | torch.nn.Parameter(torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)) 56 | ) 57 | self.register_buffer( 58 | 'qzeros', 59 | torch.zeros((math.ceil(infeatures / self.group_size), math.ceil(outfeatures / (32 // self.bits))), dtype=torch.int32) 60 | ) 61 | self.register_buffer( 62 | 'g_idx', 63 | torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32) 64 | ) # not used, just for consistent with GPTQ models 65 | if bias: 66 | self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) 67 | else: 68 | 
self.bias = None 69 | 70 | self.zeros_dim0, self.zeros_dim1 = self.scales.shape 71 | self.trainable = trainable 72 | self.scales.requires_grad = True 73 | self.use_fake = False 74 | 75 | def post_init(self): 76 | pass 77 | 78 | 79 | def use_fake_quantization(self, del_quant=False,transpose=False): 80 | # use fake quantization for faster training but consume more memory 81 | weight = dequant_dim0(self.qweight, self.bits, self.maxq, self.infeatures, self.outfeatures) 82 | dim0, dim1 = weight.shape 83 | zeros = dequant_dim1(self.qzeros, self.bits, self.maxq, self.zeros_dim0, self.zeros_dim1) 84 | weight = ((weight.view(-1, self.group_size, dim1) - zeros.view(-1, 1, dim1)) * self.scales.view(-1, 1, dim1)).reshape(dim0, dim1) 85 | if transpose: 86 | self.fake_transpose = True 87 | weight = weight.transpose(0,1).contiguous() 88 | self.register_buffer( 89 | 'weight', 90 | weight 91 | ) 92 | self.use_fake = True 93 | if del_quant: 94 | del self.qweight 95 | del self.scales 96 | del self.qzeros 97 | del self.g_idx 98 | 99 | def pack(self, linear, scales, zeros, g_idx=None): 100 | W = linear.weight.data.clone() 101 | if isinstance(linear, nn.Conv2d): 102 | W = W.flatten(1) 103 | if isinstance(linear, transformers.pytorch_utils.Conv1D): 104 | W = W.t() 105 | 106 | g_idx = torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32) 107 | 108 | scale_zeros = zeros * scales 109 | self.scales = nn.Parameter(scales.half()) 110 | if linear.bias is not None: 111 | self.bias = linear.bias.clone().half() 112 | 113 | intweight = [] 114 | for idx in range(self.infeatures): 115 | intweight.append( 116 | torch.round( 117 | ( 118 | W[:, idx] + scale_zeros[g_idx[idx]]) / self.scales[g_idx[idx]] 119 | ).to(torch.int)[:, None] 120 | ) 121 | intweight = torch.cat(intweight, dim=1) 122 | intweight = intweight.t().contiguous() 123 | intweight = intweight.numpy().astype(np.uint32) 124 | 125 | i = 0 126 | row = 0 127 | qweight = np.zeros((math.ceil(intweight.shape[0]/(32//self.bits)), intweight.shape[1]), dtype=np.uint32) 128 | while row < qweight.shape[0]: 129 | if self.bits in [2, 3, 4, 8]: 130 | for j in range(i, min(i + (32 // self.bits), intweight.shape[0])): 131 | qweight[row] |= intweight[j] << (self.bits * (j - i)) 132 | i += 32 // self.bits 133 | row += 1 134 | else: 135 | raise NotImplementedError("Only 2,3,4,8 bits are supported.") 136 | 137 | qweight = qweight.astype(np.int32) 138 | self.qweight = torch.from_numpy(qweight) 139 | 140 | zeros = zeros.numpy().astype(np.uint32) 141 | self.zeros_dim0, self.zeros_dim1 = zeros.shape 142 | qzeros = np.zeros((zeros.shape[0], math.ceil(zeros.shape[1] / (32 // self.bits))), dtype=np.uint32) 143 | i = 0 144 | col = 0 145 | while col < qzeros.shape[1]: 146 | if self.bits in [2, 3, 4, 8]: 147 | for j in range(i, min(i + (32 // self.bits), zeros.shape[1])): 148 | qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) 149 | i += 32 // self.bits 150 | col += 1 151 | else: 152 | raise NotImplementedError("Only 2,3,4,8 bits are supported.") 153 | 154 | qzeros = qzeros.astype(np.int32) 155 | self.qzeros = torch.from_numpy(qzeros) 156 | 157 | def forward(self, x): 158 | if self.use_fake: 159 | weight = self.weight 160 | if self.fake_transpose: 161 | weight = weight.transpose(0,1) 162 | else: 163 | weight = dequant_dim0(self.qweight, self.bits, self.maxq, self.infeatures, self.outfeatures) 164 | dim0, dim1 = weight.shape 165 | # dim2 = (dim1*dim0)//self.group_size 166 | zeros = dequant_dim1(self.qzeros, self.bits, self.maxq, self.zeros_dim0, 
self.zeros_dim1) 167 | weight = ((weight.view(-1, self.group_size, dim1) - zeros.view(-1, 1, dim1)) * self.scales.view(-1, 1, dim1)).reshape(dim0, dim1) 168 | # out = torch.matmul(x, weight) 169 | out = torch.matmul(x, weight.to(x.dtype)) 170 | out = out + self.bias if self.bias is not None else out 171 | return out 172 | 173 | 174 | def load_quantized_model(model_path, wbits, group_size): 175 | print(f"Loading quantized model from {model_path}") 176 | 177 | # import pdb;pdb.set_trace() 178 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 179 | config = AutoConfig.from_pretrained(model_path) 180 | with init_empty_weights(): 181 | model = AutoModelForCausalLM.from_config(config=config,torch_dtype=torch.float16, trust_remote_code=True) 182 | layers = model.model.layers 183 | for i in tqdm(range(len(layers))): 184 | layer = layers[i] 185 | named_linears = get_named_linears(layer, torch.nn.Linear) 186 | for name, module in named_linears.items(): 187 | q_linear = QuantLinear(wbits, group_size, module.in_features,module.out_features,not module.bias is None) 188 | q_linear.to(next(layer.parameters()).device) 189 | set_op_by_name(layer, name, q_linear) 190 | torch.cuda.empty_cache() 191 | gc.collect() 192 | model.tie_weights() 193 | device_map = infer_auto_device_map(model) 194 | print("Loading pre-computed quantized weights...") 195 | load_checkpoint_in_model(model,checkpoint=model_path,device_map=device_map,offload_state_dict=True) 196 | print("Loading pre-computed quantized weights Successfully") 197 | 198 | return model, tokenizer 199 | 200 | __all__ = ["QuantLinear","load_omniq_quantized"] 201 | -------------------------------------------------------------------------------- /quantize/quantizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import pdb 5 | 6 | CLIPMIN = 1e-4 7 | 8 | 9 | 10 | def round_ste(x: torch.Tensor): 11 | """ 12 | Implement Straight-Through Estimator for rounding operation. 
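The forward pass returns round(x); in the backward pass the rounding behaves as the
identity, because the residual (x.round() - x) is detached and carries no gradient.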
13 | """ 14 | return (x.round() - x).detach() + x 15 | 16 | def clamp_ste(x: torch.Tensor, min, max): 17 | return (x.clamp(min,max) - x).detach() + x 18 | 19 | def clamp_ste(x: torch.Tensor, min, max): 20 | return (x.clamp(min,max) - x).detach() + x 21 | 22 | 23 | class UniformAffineQuantizer(nn.Module): 24 | def __init__( 25 | self, 26 | n_bits: int = 8, 27 | group_size=None, 28 | weight=None, 29 | ): 30 | super().__init__() 31 | assert 2 <= n_bits <= 16, "bitwidth not supported" 32 | self.n_bits = n_bits 33 | self.qmin = 0 34 | self.qmax = 2 ** (n_bits) - 1 35 | self.group_size = group_size if group_size != -1 else weight.shape[-1] 36 | assert weight.shape[-1] % group_size == 0 37 | self.enable = True 38 | 39 | # init scale and zero point through Max-Min quantization 40 | with torch.no_grad(): 41 | if weight is not None: 42 | x = weight.reshape(-1,self.group_size) 43 | xmin = x.amin([-1], keepdim=True) 44 | xmax = x.amax([-1], keepdim=True) 45 | range = xmax - xmin 46 | scale = range / (2**self.n_bits-1) 47 | scale = scale.clamp(min=1e-4, max=1e4) 48 | zero_point = -(xmin/scale).clamp(min=-1e4, max=1e4) 49 | self.scale = nn.Parameter(scale) 50 | self.zero_point = nn.Parameter(zero_point.round()) 51 | 52 | 53 | def change_n_bits(self, n_bits): 54 | self.n_bits = n_bits 55 | self.qmin = 0 56 | self.qmax = int(2 ** (n_bits) - 1) 57 | 58 | def fake_quant(self, x): 59 | scale = clamp_ste(self.scale,1e-4, 1e4) 60 | round_zero_point = clamp_ste(round_ste(self.zero_point), self.qmin, self.qmax) 61 | 62 | dim1, dim2 = x.shape 63 | x = x.reshape(-1, self.group_size) 64 | x_int = round_ste(x / scale) 65 | if round_zero_point is not None: 66 | x_int = x_int.add(round_zero_point) 67 | x_int = x_int.clamp(self.qmin, self.qmax) 68 | x_dequant = x_int 69 | if round_zero_point is not None: 70 | x_dequant = x_dequant.sub(round_zero_point) 71 | x_dequant = x_dequant.mul(scale) 72 | if self.group_size: 73 | x_dequant = x_dequant.reshape(dim1, dim2) 74 | return x_dequant 75 | 76 | 77 | def forward(self, x: torch.Tensor): 78 | if self.n_bits >= 16 or not self.enable: 79 | return x 80 | 81 | x_dequant = self.fake_quant(x) 82 | return x_dequant 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /quantize/triton_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/EfficientQAT/39175493b2d14617d342a0a7956875e6ac16221b/quantize/triton_utils/__init__.py -------------------------------------------------------------------------------- /quantize/triton_utils/custom_autotune.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import math 3 | import time 4 | from typing import Dict 5 | 6 | import triton 7 | 8 | 9 | # code based https://github.com/fpgaminer/GPTQ-triton 10 | """ 11 | Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. 
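With nearest_power_of_two=True, the autotuning keys are additionally rounded to the
nearest power of two, so nearby problem sizes share one cached config and tuning time
is reduced further.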
12 | """ 13 | 14 | 15 | class CustomizedTritonAutoTuner(triton.KernelInterface): 16 | def __init__( 17 | self, 18 | fn, 19 | arg_names, 20 | configs, 21 | key, 22 | reset_to_zero, 23 | prune_configs_by: Dict = None, 24 | nearest_power_of_two: bool = False 25 | ): 26 | if not configs: 27 | self.configs = [triton.Config({}, num_warps=4, num_stages=2)] 28 | else: 29 | self.configs = configs 30 | self.key_idx = [arg_names.index(k) for k in key] 31 | self.nearest_power_of_two = nearest_power_of_two 32 | self.cache = {} 33 | # hook to reset all required tensor to zeros before relaunching a kernel 34 | self.hook = lambda args: 0 35 | if reset_to_zero is not None: 36 | self.reset_idx = [arg_names.index(k) for k in reset_to_zero] 37 | 38 | def _hook(args): 39 | for i in self.reset_idx: 40 | args[i].zero_() 41 | 42 | self.hook = _hook 43 | self.arg_names = arg_names 44 | # prune configs 45 | if prune_configs_by: 46 | perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] 47 | if 'early_config_prune' in prune_configs_by: 48 | early_config_prune = prune_configs_by['early_config_prune'] 49 | else: 50 | perf_model, top_k, early_config_prune = None, None, None 51 | self.perf_model, self.configs_top_k = perf_model, top_k 52 | self.early_config_prune = early_config_prune 53 | self.fn = fn 54 | 55 | def _bench(self, *args, config, **meta): 56 | # check for conflicts, i.e. meta-parameters both provided 57 | # as kwargs and by the autotuner 58 | conflicts = meta.keys() & config.kwargs.keys() 59 | if conflicts: 60 | raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." 61 | " Make sure that you don't re-define auto-tuned symbols.") 62 | # augment meta-parameters with tunable ones 63 | current = dict(meta, **config.kwargs) 64 | 65 | def kernel_call(): 66 | if config.pre_hook: 67 | config.pre_hook(self.nargs) 68 | self.hook(args) 69 | self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) 70 | 71 | try: 72 | # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses 73 | # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default 74 | return triton.testing.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40) 75 | except triton.OutOfResources: 76 | # except triton.OutOfResources: 77 | return (float('inf'), float('inf'), float('inf')) 78 | 79 | def run(self, *args, **kwargs): 80 | self.nargs = dict(zip(self.arg_names, args)) 81 | if len(self.configs) > 1: 82 | key = tuple(args[i] for i in self.key_idx) 83 | 84 | # This reduces the amount of autotuning by rounding the keys to the nearest power of two 85 | # In my testing this gives decent results, and greatly reduces the amount of tuning required 86 | if self.nearest_power_of_two: 87 | key = tuple([2 ** int(math.log2(x) + 0.5) for x in key]) 88 | 89 | if key not in self.cache: 90 | # prune configs 91 | pruned_configs = self.prune_configs(kwargs) 92 | bench_start = time.time() 93 | timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} 94 | bench_end = time.time() 95 | self.bench_time = bench_end - bench_start 96 | self.cache[key] = builtins.min(timings, key=timings.get) 97 | self.hook(args) 98 | self.configs_timings = timings 99 | config = self.cache[key] 100 | else: 101 | config = self.configs[0] 102 | self.best_config = config 103 | if config.pre_hook is not None: 104 | config.pre_hook(self.nargs) 105 | return self.fn.run(*args, num_warps=config.num_warps, 
num_stages=config.num_stages, **kwargs, **config.kwargs) 106 | 107 | def prune_configs(self, kwargs): 108 | pruned_configs = self.configs 109 | if self.early_config_prune: 110 | pruned_configs = self.early_config_prune(self.configs, self.nargs) 111 | if self.perf_model: 112 | top_k = self.configs_top_k 113 | if isinstance(top_k, float) and top_k <= 1.0: 114 | top_k = int(len(self.configs) * top_k) 115 | if len(pruned_configs) > top_k: 116 | est_timing = { 117 | config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, 118 | num_warps=config.num_warps) for config in pruned_configs} 119 | pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] 120 | return pruned_configs 121 | 122 | def warmup(self, *args, **kwargs): 123 | self.nargs = dict(zip(self.arg_names, args)) 124 | for config in self.prune_configs(kwargs): 125 | self.fn.warmup( 126 | *args, 127 | num_warps=config.num_warps, 128 | num_stages=config.num_stages, 129 | **kwargs, 130 | **config.kwargs, 131 | ) 132 | self.nargs = None 133 | 134 | 135 | def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): 136 | def decorator(fn): 137 | return CustomizedTritonAutoTuner( 138 | fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two 139 | ) 140 | 141 | return decorator 142 | 143 | 144 | def matmul248_kernel_config_pruner(configs, nargs): 145 | """ 146 | The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. 147 | """ 148 | m = max(2 ** int(math.ceil(math.log2(nargs['M']))), 16) 149 | n = max(2 ** int(math.ceil(math.log2(nargs['N']))), 16) 150 | k = max(2 ** int(math.ceil(math.log2(nargs['K']))), 16) 151 | 152 | used = set() 153 | for config in configs: 154 | block_size_m = min(m, config.kwargs['BLOCK_SIZE_M']) 155 | block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) 156 | block_size_k = min(k, config.kwargs['BLOCK_SIZE_K']) 157 | group_size_m = config.kwargs['GROUP_SIZE_M'] 158 | 159 | if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used: 160 | continue 161 | 162 | used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps)) 163 | yield triton.Config( 164 | { 165 | 'BLOCK_SIZE_M': block_size_m, 166 | 'BLOCK_SIZE_N': block_size_n, 167 | 'BLOCK_SIZE_K': block_size_k, 168 | 'GROUP_SIZE_M': group_size_m 169 | }, 170 | num_stages=config.num_stages, 171 | num_warps=config.num_warps 172 | ) 173 | 174 | def hadamard248_kernel_config_pruner(configs, nargs): 175 | """ 176 | The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. 
177 | """ 178 | m = max(2 ** int(math.ceil(math.log2(nargs['M']))), 16) 179 | n = max(2 ** int(math.ceil(math.log2(nargs['N']))), 16) 180 | 181 | used = set() 182 | for config in configs: 183 | block_size_m = min(m, config.kwargs['BLOCK_SIZE_M']) 184 | block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) 185 | 186 | if (block_size_m, block_size_n , config.num_stages, config.num_warps) in used: 187 | continue 188 | 189 | used.add((block_size_m, block_size_n, config.num_stages, config.num_warps)) 190 | yield triton.Config( 191 | { 192 | 'BLOCK_SIZE_M': block_size_m, 193 | 'BLOCK_SIZE_N': block_size_n, 194 | }, 195 | num_stages=config.num_stages, 196 | num_warps=config.num_warps 197 | ) 198 | 199 | 200 | __all__ = ["autotune"] 201 | -------------------------------------------------------------------------------- /quantize/triton_utils/kernels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.cuda.amp import custom_bwd, custom_fwd 3 | from logging import getLogger 4 | 5 | import triton 6 | import triton.language as tl 7 | 8 | from . import custom_autotune 9 | import pdb 10 | 11 | logger = getLogger(__name__) 12 | 13 | 14 | # code based https://github.com/fpgaminer/GPTQ-triton 15 | @custom_autotune.autotune( 16 | configs=[ 17 | triton.Config( 18 | { 19 | 'BLOCK_SIZE_M': 64, 20 | 'BLOCK_SIZE_N': 256, 21 | }, 22 | num_stages=4, 23 | num_warps=4 24 | ), 25 | triton.Config( 26 | { 27 | 'BLOCK_SIZE_M': 128, 28 | 'BLOCK_SIZE_N': 128, 29 | }, 30 | num_stages=4, 31 | num_warps=4 32 | ), 33 | triton.Config( 34 | { 35 | 'BLOCK_SIZE_M': 64, 36 | 'BLOCK_SIZE_N': 128, 37 | }, 38 | num_stages=4, 39 | num_warps=4 40 | ), 41 | triton.Config( 42 | { 43 | 'BLOCK_SIZE_M': 128, 44 | 'BLOCK_SIZE_N': 32, 45 | }, 46 | num_stages=4, 47 | num_warps=4 48 | ), 49 | triton.Config( 50 | { 51 | 'BLOCK_SIZE_M': 64, 52 | 'BLOCK_SIZE_N': 64, 53 | }, 54 | num_stages=4, 55 | num_warps=4 56 | ), 57 | triton.Config( 58 | { 59 | 'BLOCK_SIZE_M': 64, 60 | 'BLOCK_SIZE_N': 128, 61 | }, 62 | num_stages=2, 63 | num_warps=8 64 | ), 65 | triton.Config( 66 | { 67 | 'BLOCK_SIZE_M': 32, 68 | 'BLOCK_SIZE_N': 128, 69 | }, 70 | num_stages=4, 71 | num_warps=4 72 | ), 73 | ], 74 | key=['M', 'N'], 75 | nearest_power_of_two=True, 76 | prune_configs_by={ 77 | 'early_config_prune': custom_autotune.hadamard248_kernel_config_pruner, 78 | 'perf_model': None, 79 | 'top_k': None, 80 | }, 81 | ) 82 | @triton.jit 83 | def dequant_kernel_dim0( 84 | b_ptr, c_ptr, 85 | M, N, 86 | bits, maxq, 87 | stride_bk, stride_bn, 88 | stride_cm, stride_cn, 89 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr 90 | ): 91 | """ 92 | dequant the quantized tensor to fp tensor 93 | B is of shape (M/(32//bits), N) int32 94 | C is of shape (M, N) float16 95 | """ 96 | 97 | bits_per_feature = 32 // bits 98 | 99 | pid = tl.program_id(axis=0) 100 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 101 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 102 | 103 | pid_m = pid // num_pid_n 104 | pid_n = pid % num_pid_n 105 | 106 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 107 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 108 | 109 | 110 | b_ptrs = b_ptr + ((offs_am[:, None] // bits_per_feature) * stride_bk + offs_bn[None, :] * stride_bn) 111 | 112 | shifter = (offs_am[:, None] % bits_per_feature) * bits 113 | 114 | 115 | 116 | b = tl.load(b_ptrs) 117 | b = (b >> shifter) & maxq 118 | 119 | c = b 120 | 121 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] 122 | c_mask = 
(offs_am[:, None] < M) & (offs_bn[None, :] < N) 123 | tl.store(c_ptrs, c, mask=c_mask) 124 | 125 | @custom_autotune.autotune( 126 | configs=[ 127 | triton.Config( 128 | { 129 | 'BLOCK_SIZE_M': 2, 130 | 'BLOCK_SIZE_N':128, 131 | }, 132 | num_stages=8, 133 | num_warps=8 134 | ), 135 | triton.Config( 136 | { 137 | 'BLOCK_SIZE_M': 2, 138 | 'BLOCK_SIZE_N':64, 139 | }, 140 | num_stages=8, 141 | num_warps=8 142 | ), 143 | triton.Config( 144 | { 145 | 'BLOCK_SIZE_M': 2, 146 | 'BLOCK_SIZE_N':32, 147 | }, 148 | num_stages=8, 149 | num_warps=8 150 | ), 151 | triton.Config( 152 | { 153 | 'BLOCK_SIZE_M': 2, 154 | 'BLOCK_SIZE_N':2, 155 | }, 156 | num_stages=8, 157 | num_warps=8 158 | ), 159 | ], 160 | key=['M', 'N'], 161 | nearest_power_of_two=True, 162 | prune_configs_by={ 163 | 'early_config_prune': custom_autotune.hadamard248_kernel_config_pruner, 164 | 'perf_model': None, 165 | 'top_k': None, 166 | }, 167 | ) 168 | @triton.jit 169 | def dequant_kernel_dim1( 170 | b_ptr, c_ptr, 171 | M, N, 172 | bits, maxq, 173 | stride_bk, stride_bn, 174 | stride_cm, stride_cn, 175 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr 176 | ): 177 | """ 178 | dequant the quantized tensor to fp tensor 179 | B is of shape (M, N/(32//bits)) int32 180 | C is of shape (M, N) float16 181 | """ 182 | 183 | bits_per_feature = 32 // bits 184 | 185 | pid = tl.program_id(axis=0) 186 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) 187 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) 188 | 189 | pid_m = pid // num_pid_n 190 | pid_n = pid % num_pid_n 191 | 192 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 193 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) 194 | 195 | 196 | # b_ptrs = b_ptr + ((offs_am[:, None] // bits_per_feature) * stride_bk + offs_bn[None, :] * stride_bn) 197 | b_ptrs = b_ptr + (offs_am[:, None] * stride_bk + (offs_bn[None, :] // bits_per_feature) * stride_bn) 198 | 199 | # shifter = (offs_am[:, None] % bits_per_feature) * bits 200 | shifter = (offs_bn[None, :] % bits_per_feature) * bits 201 | 202 | 203 | 204 | b = tl.load(b_ptrs) 205 | b = (b >> shifter) & maxq 206 | 207 | c = b 208 | 209 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] 210 | c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) 211 | tl.store(c_ptrs, c, mask=c_mask) 212 | 213 | 214 | @triton.jit 215 | def silu(x): 216 | return x * tl.sigmoid(x) 217 | 218 | 219 | def dequant_dim0(qweight, bits, maxq, infeatures, outfeatures): 220 | with torch.cuda.device(qweight.device): 221 | output = torch.empty((infeatures, outfeatures), device=qweight.device, dtype=torch.float16) 222 | grid = lambda META: ( 223 | triton.cdiv(output.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output.shape[1], META['BLOCK_SIZE_N']), 224 | ) 225 | dequant_kernel_dim0[grid]( 226 | qweight, output, 227 | output.shape[0], output.shape[1], 228 | bits, maxq, 229 | qweight.stride(0), qweight.stride(1), 230 | output.stride(0), output.stride(1), 231 | ) 232 | return output 233 | 234 | def dequant_dim1(qweight, bits, maxq, infeatures, outfeatures): 235 | with torch.cuda.device(qweight.device): 236 | output = torch.empty((infeatures, outfeatures), device=qweight.device, dtype=torch.float16) 237 | grid = lambda META: ( 238 | triton.cdiv(output.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output.shape[1], META['BLOCK_SIZE_N']), 239 | ) 240 | dequant_kernel_dim1[grid]( 241 | qweight, output, 242 | output.shape[0], output.shape[1], 243 | bits, maxq, 244 | qweight.stride(0), qweight.stride(1), 245 | output.stride(0), output.stride(1), 246 | 
) 247 | return output 248 | 249 | 250 | 251 | 252 | 253 | 254 | -------------------------------------------------------------------------------- /quantize/triton_utils/mixin.py: -------------------------------------------------------------------------------- 1 | class TritonModuleMixin: 2 | @classmethod 3 | def warmup(cls, model, transpose=False, seqlen=2048): 4 | pass 5 | -------------------------------------------------------------------------------- /quantize/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from quantize.int_linear_fake import QuantLinear 3 | import torch 4 | from torch import nn 5 | from typing import Optional 6 | 7 | 8 | class MultiBlock(nn.Module): 9 | def __init__(self, *args, **kwargs) -> None: 10 | super().__init__(*args, **kwargs) 11 | self.block_list = nn.ModuleList([]) 12 | 13 | def add_block(self, block): 14 | self.block_list.append(block) 15 | 16 | def forward(self, 17 | hidden_states: torch.Tensor, 18 | attention_mask: Optional[torch.Tensor] = None, 19 | position_ids: Optional[torch.LongTensor] = None): 20 | for block in self.block_list: 21 | hidden_states = block(hidden_states, attention_mask=attention_mask,position_ids=position_ids)[0] 22 | return (hidden_states, ) 23 | 24 | 25 | def set_weight_parameters(model, requires_grad): 26 | params = [] 27 | for n, m in model.named_parameters(): 28 | if n.find('weight') > -1 and not (n.find('scale') > -1 or n.find('zero_point') > -1): 29 | m.requires_grad = requires_grad 30 | return iter(params) 31 | 32 | def weight_parameters(model): 33 | params = [] 34 | for n, m in model.named_parameters(): 35 | if n.find('weight') > -1 and not (n.find('scale') > -1 or n.find('zero_point') > -1): 36 | params.append(m) 37 | return iter(params) 38 | 39 | def set_quant_parameters(model, requires_grad): 40 | params = [] 41 | for n, m in model.named_parameters(): 42 | if n.find('scale') > -1 or n.find('zero_point') > -1: 43 | m.requires_grad = requires_grad 44 | return iter(params) 45 | 46 | def quant_parameters(model): 47 | params = [] 48 | for n, m in model.named_parameters(): 49 | if n.find('scale') > -1 or n.find('zero_point') > -1: 50 | params.append(m) 51 | return iter(params) 52 | 53 | 54 | def trainable_parameters(model): 55 | params = [] 56 | for n, m in model.named_parameters(): 57 | if m.requires_grad: 58 | params.append(m) 59 | return iter(params) 60 | 61 | def trainable_parameters_num(model): 62 | params = [] 63 | total = 0 64 | for n, m in model.named_parameters(): 65 | if m.requires_grad: 66 | total += m.numel() 67 | return total 68 | 69 | def set_quant_state(model, weight_quant: bool = False): 70 | for m in model.modules(): 71 | if isinstance(m, QuantLinear): 72 | m.set_quant_state(weight_quant) 73 | 74 | @torch.no_grad() 75 | def quant_inplace(model): 76 | for name, module in model.named_modules(): 77 | if isinstance(module, QuantLinear): 78 | module.weight.data = module.weight_quantizer(module.weight.data) 79 | 80 | 81 | class TruncateFunction(torch.autograd.Function): 82 | @staticmethod 83 | def forward(ctx, input, threshold): 84 | truncated_tensor = input.clone() 85 | truncated_tensor[truncated_tensor.abs() < threshold] = truncated_tensor[truncated_tensor.abs() < threshold].sign() * threshold 86 | return truncated_tensor 87 | 88 | 89 | @staticmethod 90 | def backward(ctx, grad_output): 91 | grad_input = grad_output.clone() 92 | return grad_input, None 93 | 94 | 95 | def truncate_number(number, threshold=1e-2): 96 | # avoid 
overflow with AMP training 97 | return TruncateFunction.apply(number, threshold) 98 | 99 | 100 | def get_named_linears(module, type): 101 | # return {name: m for name, m in module.named_modules() if isinstance(m, torch.nn.Linear)} 102 | return {name: m for name, m in module.named_modules() if isinstance(m, type)} 103 | 104 | def set_op_by_name(layer, name, new_module): 105 | levels = name.split('.') 106 | if len(levels) > 1: 107 | mod_ = layer 108 | for l_idx in range(len(levels)-1): 109 | if levels[l_idx].isdigit(): 110 | mod_ = mod_[int(levels[l_idx])] 111 | else: 112 | mod_ = getattr(mod_, levels[l_idx]) 113 | setattr(mod_, levels[-1], new_module) 114 | else: 115 | setattr(layer, name, new_module) 116 | 117 | # def add_new_module(name, original_module, added_module): 118 | # levels = name.split('.') 119 | # if len(levels) > 1: 120 | # mod_ = original_module 121 | # for l_idx in range(len(levels)-1): 122 | # if levels[l_idx].isdigit(): 123 | # mod_ = mod_[int(levels[l_idx])] 124 | # else: 125 | # mod_ = getattr(mod_, levels[l_idx]) 126 | # setattr(mod_, levels[-1], added_module) 127 | # else: 128 | # setattr(original_module, name, added_module) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.28.0 2 | bitsandbytes==0.41.0 3 | datasets==2.18.0 4 | lm_eval==0.4.2 5 | numpy==1.23.4 6 | torch==2.2.2 7 | tqdm==4.64.1 8 | transformers==4.40.1 9 | triton==2.2.0 10 | termcolor 11 | sentencepiece 12 | protobuf 13 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from math import inf 3 | import logging 4 | from termcolor import colored 5 | import sys 6 | import os 7 | import time 8 | 9 | 10 | def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor: 11 | if isinstance(parameters, torch.Tensor): 12 | parameters = [parameters] 13 | parameters = [p for p in parameters if p.grad is not None] 14 | norm_type = float(norm_type) 15 | if len(parameters) == 0: 16 | return torch.tensor(0.) 
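    # move each per-parameter norm onto the first gradient's device, then reduce:
    # max over per-parameter maxima for the inf-norm, otherwise the p-norm of the stacked per-parameter norms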
17 | device = parameters[0].grad.device 18 | if norm_type == inf: 19 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 20 | else: 21 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 22 | norm_type).to(device) for p in parameters]), norm_type) 23 | return total_norm 24 | 25 | class NativeScalerWithGradNormCount: 26 | state_dict_key = "amp_scaler" 27 | 28 | def __init__(self): 29 | self._scaler = torch.cuda.amp.GradScaler() 30 | 31 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True,retain_graph=False): 32 | self._scaler.scale(loss).backward(create_graph=create_graph, retain_graph=retain_graph) 33 | if update_grad: 34 | if clip_grad is not None: 35 | assert parameters is not None 36 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 37 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 38 | else: 39 | self._scaler.unscale_(optimizer) 40 | norm = ampscaler_get_grad_norm(parameters) 41 | self._scaler.step(optimizer) 42 | self._scaler.update() 43 | else: 44 | norm = None 45 | return norm 46 | 47 | def state_dict(self): 48 | return self._scaler.state_dict() 49 | 50 | def load_state_dict(self, state_dict): 51 | self._scaler.load_state_dict(state_dict) 52 | 53 | 54 | def create_logger(output_dir, dist_rank=0, name=''): 55 | # create logger 56 | logger = logging.getLogger(name) 57 | logger.setLevel(logging.INFO) 58 | logger.propagate = False 59 | 60 | # create formatter 61 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 62 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 63 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 64 | 65 | # create console handlers for master process 66 | if dist_rank == 0: 67 | console_handler = logging.StreamHandler(sys.stdout) 68 | console_handler.setLevel(logging.DEBUG) 69 | console_handler.setFormatter( 70 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 71 | logger.addHandler(console_handler) 72 | 73 | # create file handlers 74 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}_{int(time.time())}.txt'), mode='a') 75 | file_handler.setLevel(logging.DEBUG) 76 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 77 | logger.addHandler(file_handler) 78 | 79 | return logger --------------------------------------------------------------------------------
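A note referring back to the dequantization kernels in `quantize/triton_utils/kernels.py` above: both kernels recover integer codes from int32-packed storage with a shift-and-mask, storing `32 // bits` codes per int32. The sketch below is a minimal pure-PyTorch illustration of the same unpacking rule for the dim-0 layout (`qweight` of shape `(infeatures // (32 // bits), outfeatures)`). It is not part of the repository, the helper name `unpack_dim0_reference` is hypothetical, and it only unpacks the raw codes; scales and zero points are applied elsewhere.

```
import torch

def unpack_dim0_reference(qweight: torch.Tensor, bits: int, infeatures: int) -> torch.Tensor:
    # qweight: int32 tensor of shape (infeatures // (32 // bits), outfeatures),
    # the layout dequant_kernel_dim0 assumes; returns the unpacked codes as float16.
    vals_per_int32 = 32 // bits
    maxq = (1 << bits) - 1
    rows = torch.arange(infeatures, device=qweight.device)
    packed = qweight[rows // vals_per_int32]                # (infeatures, outfeatures)
    shift = ((rows % vals_per_int32) * bits).unsqueeze(1)   # bit offset of each row's code
    # same shift-and-mask as the Triton kernel: (b >> shifter) & maxq
    return ((packed >> shift) & maxq).to(torch.float16)
```

`dequant_kernel_dim1` applies the same rule along the other axis, dividing the column index rather than the row index by `32 // bits` and shifting by `(col % (32 // bits)) * bits`.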