├── README.md
├── datautils_block.py
├── datautils_e2e.py
├── deita_dataset
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   └── train.py
├── examples
│   ├── block_ap
│   │   ├── Llama-2-7b
│   │   │   ├── w2g128.sh
│   │   │   ├── w2g64.sh
│   │   │   ├── w3g128.sh
│   │   │   └── w4g128.sh
│   │   └── Mistral-Large-Instruct
│   │       └── w2g64.sh
│   ├── e2e_qp
│   │   ├── Llama-2-7b
│   │   │   ├── w2g128-alpaca.sh
│   │   │   ├── w2g128-redpajama.sh
│   │   │   ├── w2g64-alpaca.sh
│   │   │   ├── w2g64-redpajama.sh
│   │   │   ├── w3g128-alpaca.sh
│   │   │   ├── w3g128-redpajama.sh
│   │   │   ├── w4g128-alpaca.sh
│   │   │   └── w4g128-redpajama.sh
│   │   └── Llama-3-8b-instruct
│   │       ├── w2g128-deita.sh
│   │       ├── w2g64-deita.sh
│   │       ├── w3g128-deita.sh
│   │       └── w4g128-deita.sh
│   ├── inference
│   │   └── Llama-2-7b
│   │       ├── fp16.sh
│   │       └── w2g64.sh
│   └── model_transfer
│       ├── efficientqat_to_bitblas
│       │   └── llama-2-7b.sh
│       ├── efficientqat_to_gptq
│       │   └── llama-2-7b.sh
│       ├── fp32_to_16
│       │   └── llama-2-7b.sh
│       └── real_to_fake
│           └── llama-2-7b.sh
├── main_block_ap.py
├── main_e2e_qp.py
├── model_transfer
│   ├── __init__.py
│   ├── efficientqat_to_others.py
│   ├── fp32_to_16.py
│   └── real_to_fake.py
├── quantize
│   ├── __init__.py
│   ├── block_ap.py
│   ├── int_linear_fake.py
│   ├── int_linear_real.py
│   ├── quantizer.py
│   ├── triton_utils
│   │   ├── __init__.py
│   │   ├── custom_autotune.py
│   │   ├── kernels.py
│   │   └── mixin.py
│   └── utils.py
├── requirements.txt
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # EfficientQAT
2 | Official PyTorch implementation of the paper [EfficientQAT: Efficient Quantization-Aware Training for Large Language Models](https://arxiv.org/abs/2407.11062)
3 |
4 | ## News
5 | - [2025/05] 🔥 **We explore the [Scaling Law for Quantization-Aware Training](https://export.arxiv.org/abs/2505.14302), which offers insights and guidance for QAT of LLMs.**
6 | - [2025/05] 🌟 Our EfficientQAT paper has been accepted to the ACL 2025 Main Conference! 🎉 Cheers!
7 | - [2024/10] 🔥 We release a new weight-activation quantization algorithm, [PrefixQuant](https://github.com/ChenMnZ/PrefixQuant), which is the first work to make static activation quantization outperform dynamic activation quantization.
8 | - [2024/08] The new inference backend [T-MAC](https://github.com/microsoft/T-MAC) from Microsoft now supports EfficientQAT models.
9 | - [2024/08] We now support the quantization of [Mistral-Large-Instruct](https://huggingface.co/mistralai/Mistral-Large-Instruct-2407). With EfficientQAT, w2g64 quantization compresses the 123B Mistral-Large-Instruct model to 35 GB with only a 4-point accuracy drop.
10 | - [2024/07] New features! We support transferring EfficientQAT quantized models into the `GPTQ v2` and `BitBLAS` formats, which can be loaded directly through [GPTQModel](https://github.com/ModelCloud/GPTQModel).
11 | - [2024/07] We release EfficientQAT, which pushes the limits of uniform (INT) quantization in an efficient manner.
12 |
13 | ## Contents
14 | - [Installation](#installation)
15 | - [Model Zoo](#model-zoo)
16 | - [Training](#training)
17 | - [Inference](#inference)
18 | - [Model Transferring](#model-transferring)
19 | - [Inference of Other Formats](#inference-of-other-formats)
20 | - [Citation](#citation)
21 |
22 |
23 | ## Installation
24 | 1. Clone this repository and navigate to the EfficientQAT folder
25 | ```
26 | git clone https://github.com/OpenGVLab/EfficientQAT.git
27 | cd EfficientQAT
28 | ```
29 |
30 | 2. Install the required packages
31 | ```
32 | conda create -n efficientqat python==3.11
33 |
34 | conda activate efficientqat
35 |
36 | pip install -r requirements.txt
37 | ```
38 |
39 | ## Model Zoo
40 |
41 | We provide a number of pre-quantized EfficientQAT models as follows:
42 |
43 | - WikiText2 PPL is measured with a context length of 2048.
44 | - Avg. Accuracy indicates the average accuracy over 5 zero-shot reasoning tasks (WinoGrande, PIQA, HellaSwag, ARC-Easy, ARC-Challenge) with [lm-eval v0.4.2](https://github.com/EleutherAI/lm-evaluation-harness).
45 | - 1 GB = $10^9$ bytes
46 | - Hub link: EQAT indicates the original checkpoints. We also transfer the checkpoints into the GPTQ and BitBLAS formats, which can be loaded directly through [GPTQModel](https://github.com/ModelCloud/GPTQModel). (PS: [GPTQModel](https://github.com/ModelCloud/GPTQModel) is an official bug-fixed repo of AutoGPTQ, which is expected to be merged into [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ) in the future.)
47 |
48 | | Model | Quantization | WikiText2 PPL | Avg. Accuracy | Model Size (GB) | Hub link|
49 | |-------|--------------|---------------|---------------|-----------------|----------|
50 | Llama-2-7B|fp16|5.47|64.86|13.2|-|
51 | Llama-2-7B|w4g128|5.53|64.27|3.7|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w4g128-BitBLAS)|
52 | Llama-2-7B|w3g128|5.81|64.02|3.1|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w3g128)|
53 | Llama-2-7B|w2g64|6.86|60.14|2.3|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w2g64-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w2g64-BitBLAS)|
54 | Llama-2-7B|w2g128|7.17|59.50|2.2|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-2-7b-EfficientQAT-w2g128-BitBLAS)|
55 | Llama-2-13B|fp16|4.88|67.81|25.4|-|
56 | Llama-2-13B|w4g128|4.93|67.52|6.8|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w4g128-BitBLAS)|
57 | Llama-2-13B|w3g128|5.12|67.28|5.6|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w3g128)|
58 | Llama-2-13B|w2g64|5.96|64.88|4.0|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w2g64-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w2g64-BitBLAS)|
59 | Llama-2-13B|w2g128|6.08|63.88|3.8|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-2-13b-EfficientQAT-w2g128-BitBLAS)|
60 | Llama-2-70B|fp16|3.32|72.41|131.6|-|
61 | Llama-2-70B|w4g128|3.39|72.62|35.8|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w4g128-BitBLAS)|
62 | Llama-2-70B|w3g128|3.61|71.76|29.1|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w3g128)|
63 | Llama-2-70B|w2g64|4.52|69.48|20.1|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w2g64-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w2g64-BitBLAS)|
64 | Llama-2-70B|w2g128|4.61|68.93|18.9|[EQAT](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-2-70b-EfficientQAT-w2g128-BitBLAS)|
65 | Llama-3-8B|fp16|6.14|68.58|13.0|-|
66 | Llama-3-8B|w4g128|6.47|68.43|5.4|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w4g128-BitBLAS)|
67 | Llama-3-8B|w3g128|7.09|67.35|4.7|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w3g128)|
68 | Llama-3-8B|w2g64|9.41|60.76|3.9|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w2g64-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w2g64-BitBLAS)|
69 | Llama-3-8B|w2g128|9.80|59.36|3.8|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-3-8b-EfficientQAT-w2g128-BitBLAS)|
70 | Llama-3-70B|fp16|2.85|75.33|137.8|-|
71 | Llama-3-70B|w4g128|3.17|74.57|38.9|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w4g128-BitBLAS)|
72 | Llama-3-70B|w3g128|4.19|72.42|32.2|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w3g128)|
73 | Llama-3-70B|w2g64|6.08|67.89|23.2|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w2g64-GPTQ)|
74 | Llama-3-70B|w2g128|6.38|67.57|22.0|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-3-70b-EfficientQAT-w2g128-BitBLAS)|
75 | Llama-3-8B-Instruct|fp16|8.29|68.43|13.0|-|
76 | Llama-3-8B-Instruct|w4g128|7.93|68.39|5.4|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w4g128-BitBLAS)|
77 | Llama-3-8B-Instruct|w3g128|8.55|67.24|4.7|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w3g128)|
78 | Llama-3-8B-Instruct|w2g64|11.19|60.66|3.9|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g64-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g64-BitBLAS)|
79 | Llama-3-8B-Instruct|w2g128|11.73|60.16|3.8|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g128-BitBLAS)|
80 | Llama-3-70B-Instruct|fp16|5.33|73.78|137.8|-|
81 | Llama-3-70B-Instruct|w4g128|5.35|73.47|38.9|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w4g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w4g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w4g128-BitBLAS)|
82 | Llama-3-70B-Instruct|w3g128|5.65|72.87|32.2|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w3g128)|
83 | Llama-3-70B-Instruct|w2g64|7.86|67.64|23.2|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w2g64)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w2g64-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w2g64-BitBLAS)|
84 | Llama-3-70B-Instruct|w2g128|8.14|67.54|22.0|[EQAT](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w2g128)\|[GPTQ](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w2g128-GPTQ)\|[BitBLAS](https://huggingface.co/ChenMnZ/Llama-3-70b-instruct-EfficientQAT-w2g128-BitBLAS)|
85 | Mistral-Large-Instruct-2407|fp16|2.74|77.76|228.5|-|
86 | Mistral-Large-Instruct-2407|w2g64|5.58|73.54|35.5|[GPTQ](https://huggingface.co/ChenMnZ/Mistral-Large-Instruct-2407-EfficientQAT-w2g64-GPTQ)
87 |
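In the table, wNgG denotes N-bit weights quantized with group size G; for example, the w2g64 checkpoint is evaluated with `--wbits 2 --group_size 64` in the Inference section. As a rough back-of-the-envelope aid, the sketch below estimates the average storage cost per quantized weight and converts a weight count into GB (1 GB = 10^9 bytes). It assumes each group stores one 16-bit scale and one N-bit zero point. That layout is an assumption for illustration, not the repo's documented format, and the sketch ignores layers kept in higher precision, so it is not expected to reproduce the table exactly. `bits_per_weight` and `size_gb` are hypothetical helper names.

```python
from typing import Optional


def bits_per_weight(wbits: int, group_size: int, scale_bits: int = 16,
                    zero_bits: Optional[int] = None) -> float:
    # ASSUMPTION: each group of `group_size` weights stores one scale of
    # `scale_bits` bits and one zero point (defaulting to `wbits` bits).
    if zero_bits is None:
        zero_bits = wbits
    return wbits + (scale_bits + zero_bits) / group_size


def size_gb(num_weights: float, bits_per_w: float) -> float:
    # 1 GB = 10**9 bytes = 8 * 10**9 bits.
    return num_weights * bits_per_w / 8 / 1e9


# w2g64 and w4g128 under the assumptions above
print(bits_per_weight(2, 64))   # 2.28125
print(bits_per_weight(4, 128))  # 4.15625
```

For example, `size_gb(7e9, bits_per_weight(2, 64))` gives about 2.0 GB for 7 billion quantized weights under these assumptions; the group overhead shrinks as G grows, which is why g128 models are slightly smaller than g64 models at the same bit width.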
88 | ## Training
89 | EfficientQAT involves two consecutive training phases: block-wise training of all parameters (**Block-AP**) and end-to-end training of quantization parameters (**E2E-QP**). Detailed training scripts can be found in `./examples`. Below are example scripts for Llama-2-7B with w2g64 quantization.
90 |
91 | 1. Block-AP
92 |
93 | Before running the following command, set `--model` in the script to the folder of the full-precision model.
94 | ```
95 | bash examples/block_ap/Llama-2-7b/w2g64.sh
96 | ```
97 | In our experiments, `--weight_lr` is `2e-5` for 2-bit and `1e-5` for 3-/4-bit quantization.
98 |
99 | Some other important arguments:
100 | - `--train_size`: number of training samples (default: 4096)
101 | - `--val_size`: number of validation samples (default: 64)
102 | - `--off_load_to_disk`: save the training dataset to disk, which saves CPU memory but may slow down training
103 |
104 |
105 | 2. E2E-QP
106 |
107 | Then, load the quantized model produced by Block-AP for further E2E-QP training. E2E-QP can adapt to different scenarios by changing the training dataset. Before running the following commands, set `--quant_model_path` in the script to the folder of the quantized model.
108 |
109 | 1\) Train on RedPajama
110 | ```
111 | bash examples/e2e_qp/Llama-2-7b/w2g64-redpajama.sh
112 | ```
113 |
114 | 2\) Train on Alpaca
115 | ```
116 | bash examples/e2e_qp/Llama-2-7b/w2g64-alpaca.sh
117 | ```
118 | In our experiments, `--learning_rate` is `2e-5` for 2-bit and `1e-5` for 3-/4-bit quantization. To reduce the memory footprint during training, you can decrease `--per_device_train_batch_size`; increase `--gradient_accumulation_steps` by the same factor to keep the effective batch size unchanged.
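The learning-rate rule and the batch-size trade-off for E2E-QP can be written down as a small helper. This is a hypothetical sketch (`e2e_qp_settings` is not part of the repo): it encodes the `2e-5`/`1e-5` rule stated above and divides `--per_device_train_batch_size` by a factor while multiplying `--gradient_accumulation_steps` by the same factor, so their product (and hence the effective batch size for a fixed number of GPUs) stays constant.

```python
def e2e_qp_settings(wbits, per_device_train_batch_size,
                    gradient_accumulation_steps, shrink=1):
    """Return (learning_rate, per_device_batch, grad_accum); hypothetical helper."""
    if wbits == 2:
        learning_rate = 2e-5
    elif wbits in (3, 4):
        learning_rate = 1e-5
    else:
        raise ValueError("learning rates are only given for 2-, 3- and 4-bit")
    if per_device_train_batch_size % shrink != 0:
        raise ValueError("shrink must divide the per-device batch size")
    new_batch = per_device_train_batch_size // shrink
    new_accum = gradient_accumulation_steps * shrink
    # new_batch * new_accum == per_device_train_batch_size * gradient_accumulation_steps
    return learning_rate, new_batch, new_accum


print(e2e_qp_settings(2, 4, 2, shrink=2))  # (2e-05, 2, 4)
```

Halving the per-device batch roughly halves activation memory per step, at the cost of more forward/backward passes per optimizer update.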
119 |
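For RedPajama (pre-training style data), `datautils_e2e.py` tokenizes the raw text and then, in `group_texts`, concatenates the token lists and cuts them into fixed-length blocks of `args.pt_context_len` tokens, dropping the remainder. The pure-Python sketch below mirrors only the part of `group_texts` shown in this dump (the rest of that function is not included here) and works on plain lists instead of a `datasets` batch.

```python
from itertools import chain


def group_texts(examples, block_size):
    # Concatenate every column's token lists into one long list per column.
    concatenated = {k: list(chain(*examples[k])) for k in examples}
    total_length = len(concatenated[next(iter(examples))])
    # Drop the trailing remainder, but only when at least one full block exists.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split each column into consecutive chunks of block_size tokens.
    return {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }


batch = {"input_ids": [[1, 2, 3], [4, 5, 6, 7]], "attention_mask": [[1, 1, 1], [1, 1, 1, 1]]}
print(group_texts(batch, 3)["input_ids"])  # [[1, 2, 3], [4, 5, 6]]
```

Packing documents back to back means no padding is needed during pre-training style E2E-QP; the trailing token `7` above is simply discarded.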
120 |
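For Block-AP, the `--train_size` and `--val_size` samples are fixed-length token windows. In `datautils_block.py`, the loaders draw training windows from the first 90% of the training data and validation windows from the remaining 10% (`val_sample_ratio = 0.9`), and each label tensor masks every position except the last with `-100`. The pure-Python sketch below mimics that split on token indices only; `sample_windows` and `mask_all_but_last` are hypothetical names, and it uses a private `random.Random(seed)` instead of the global `random.seed` used in the repo.

```python
import random


def sample_windows(num_tokens, seqlen, train_size, val_size, seed=0, ratio=0.9):
    """Return lists of (start, end) spans; train spans end before the split point."""
    rng = random.Random(seed)
    split = int(num_tokens * ratio)
    train = []
    for _ in range(train_size):
        i = rng.randint(0, split - seqlen - 1)  # window [i, i + seqlen) ends before `split`
        train.append((i, i + seqlen))
    val = []
    for _ in range(val_size):
        i = rng.randint(split, num_tokens - seqlen - 1)  # window starts at or after `split`
        val.append((i, i + seqlen))
    return train, val


def mask_all_but_last(window, ignore_index=-100):
    # Same effect as `tar[:, :-1] = -100` applied to a single 1-D window.
    return [ignore_index] * (len(window) - 1) + window[-1:]


print(mask_all_but_last([5, 6, 7]))  # [-100, -100, 7]
```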
121 |
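The `--off_load_to_disk` option matters because `BlockTrainDataset` in `datautils_block.py` preallocates a tensor of shape `(size // batch_size, batch_size, seqlen, hidden_size)`, i.e. `size * seqlen * hidden_size` elements per buffer. The sketch below does that arithmetic. The example values are assumptions: 4096 samples (the `--train_size` default), a sequence length of 2048, Llama-2-7B's hidden size of 4096, and 2 bytes per element (fp16). How many such buffers the training script keeps alive at once is not shown here; `block_cache_bytes` is a hypothetical name.

```python
def block_cache_bytes(train_size, seqlen, hidden_size, bytes_per_elem=2):
    # One BlockTrainDataset buffer holds train_size * seqlen * hidden_size elements.
    return train_size * seqlen * hidden_size * bytes_per_elem


# ASSUMED example values: 4096 samples, seqlen 2048, hidden size 4096, fp16.
print(block_cache_bytes(4096, 2048, 4096) / 1e9)  # ~68.7 GB per buffer
```

At this scale, keeping the buffers in CPU memory can exceed what many machines have, which is the trade-off `--off_load_to_disk` addresses by writing one file per batch instead.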
122 | ## Inference
123 |
124 | 1. Download the pre-quantized EfficientQAT models from Hugging Face
125 | ```
126 | pip install huggingface_hub
127 |
128 | huggingface-cli download ChenMnZ/Llama-2-7b-EfficientQAT-w2g64 --local-dir ./output/pre_quantized_models/Llama-2-7b-EfficientQAT-w2g64
129 | ```
130 |
131 | 2. Evaluate the pre-quantized EfficientQAT model
132 | ```
133 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \
134 | --resume_quant ./output/pre_quantized_models/Llama-2-7b-EfficientQAT-w2g64 \
135 | --net Llama-2 \
136 | --wbits 2 \
137 | --group_size 64 \
138 | --output_dir ./output/inference_results/ \
139 | --eval_ppl \
140 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande
141 | ```
142 |
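For reference, the perplexity evaluation in `datautils_block.py` (`test_ppl`) splits the tokenized test set into `nsamples = num_tokens // seqlen` non-overlapping windows, computes the mean next-token cross-entropy in each window, multiplies it by `seqlen`, sums over windows, divides by `nsamples * seqlen`, and exponentiates. That equals exponentiating the average of the per-window mean losses. The pure-Python sketch below mirrors only that arithmetic; `num_windows` and `perplexity` are hypothetical helper names, and the per-window losses are assumed to be natural-log cross-entropies.

```python
import math


def num_windows(num_tokens: int, seqlen: int = 2048) -> int:
    # Trailing tokens that do not fill a whole window are dropped.
    return num_tokens // seqlen


def perplexity(window_mean_losses: list, seqlen: int = 2048) -> float:
    # Each window contributes mean_loss * seqlen; dividing the sum by
    # (nsamples * seqlen) gives back the average of the window means.
    nsamples = len(window_mean_losses)
    total_nll = sum(loss * seqlen for loss in window_mean_losses)
    return math.exp(total_nll / (nsamples * seqlen))


print(num_windows(4097))  # 2
```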
143 |
144 | ## Model Transferring
145 | First, install the `gptqmodel` package to support the GPTQ and BitBLAS quantization formats:
146 | ```
147 | git clone https://github.com/ModelCloud/GPTQModel.git && cd GPTQModel
148 | bash install.sh
149 | ```
150 | - We tested with `gptqmodel v0.9.8`.
151 |
152 | Then, we offer three types of transfer as follows:
153 |
154 | 1. Transfer EfficientQAT checkpoints to GPTQ format
155 | ```
156 | bash examples/model_transfer/efficientqat_to_gptq/llama-2-7b.sh
157 | ```
158 | - **Note**: Currently, [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ) has overflow bugs for asymmetric quantization. Therefore, we use the official bug-fixed version, [GPTQModel](https://github.com/ModelCloud/GPTQModel), to transfer our asymmetrically quantized models. As a result, the GPTQ models provided by this repo can only be loaded successfully through [GPTQModel](https://github.com/ModelCloud/GPTQModel), not [AutoGPTQ](https://github.com/AutoGPTQ/AutoGPTQ).
159 |
160 |
161 | 2. Transfer EfficientQAT checkpoints to BitBLAS format
162 | ```
163 | bash examples/model_transfer/efficientqat_to_bitblas/llama-2-7b.sh
164 | ```
165 | - The speedup currently has some problems; refer to [this issue](https://github.com/microsoft/BitBLAS/issues/90) for details.
166 |
167 | 3. Transfer the fp32 data in EfficientQAT checkpoints to half precision.
168 | Some parameters are saved in fp32 for training; after training, you can convert them to half precision to further reduce the model size.
169 | ```
170 | bash examples/model_transfer/fp32_to_16/llama-2-7b.sh
171 | ```
172 |
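The note on asymmetric quantization refers to quantizers that use an integer zero point in addition to a scale. As a generic illustration (min-max rounding, not EfficientQAT's trained quantizer and not GPTQModel's packing code), the sketch below quantizes one group to N-bit integers. The zero point can take any value in `[0, 2^N - 1]`, including 0, so a storage format has to represent that whole range. `quantize_group` and `dequantize_group` are hypothetical names.

```python
def quantize_group(weights, n_bits):
    """Asymmetric min-max quantization of one group to integers in [0, 2**n_bits - 1]."""
    qmax = 2 ** n_bits - 1
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / qmax if w_max > w_min else 1.0
    zero = min(max(round(-w_min / scale), 0), qmax)  # integer zero point
    q = [min(max(round(w / scale) + zero, 0), qmax) for w in weights]
    return q, scale, zero


def dequantize_group(q, scale, zero):
    # Inverse mapping: w_hat = (q - zero) * scale
    return [(qi - zero) * scale for qi in q]


q, scale, zero = quantize_group([-1.0, 0.0, 0.5, 2.0], n_bits=2)
print(q, scale, zero)                      # [0, 1, 1, 3] 1.0 1
print(dequantize_group(q, scale, zero))    # [-1.0, 0.0, 0.0, 2.0]
```

Note that `0.5` is reconstructed as `0.0`: with only four levels, rounding error is unavoidable, which is what quantization-aware training tries to compensate for.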
173 | ## Inference of Other Formats
174 | Below is an example of inference with the GPTQ or BitBLAS quantized formats.
175 | ```Python
176 | from transformers import AutoTokenizer
177 | from gptqmodel import GPTQModel
178 |
179 | quant_dir = "ChenMnZ/Llama-2-7b-EfficientQAT-w2g128-GPTQ"
180 | # quant_dir = "ChenMnZ/Llama-2-7b-EfficientQAT-w2g128-BitBLAS"
181 | # or local path
182 |
183 | tokenizer = AutoTokenizer.from_pretrained(quant_dir, use_fast=True)
184 |
185 |
186 | # load quantized model to the first GPU
187 | model = GPTQModel.from_quantized(quant_dir)
188 |
189 | # inference with model.generate
190 | print(tokenizer.decode(model.generate(**tokenizer("Model quantization is", return_tensors="pt").to(model.device))[0]))
191 | ```
192 |
193 |
194 | ## Citation
195 | If you found this work useful, please consider citing:
196 | ```
197 | @article{efficientqat,
198 | title={EfficientQAT: Efficient Quantization-Aware Training for Large Language Models},
199 | author={Chen, Mengzhao and Shao, Wenqi and Xu, Peng and Wang, Jiahao and Gao, Peng and Zhang, Kaipeng and Qiao, Yu and Luo, Ping},
200 | journal={arXiv preprint arXiv:2407.11062},
201 | year={2024}
202 | }
203 | ```
204 |
--------------------------------------------------------------------------------
/datautils_block.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from datasets import load_dataset
3 | import numpy as np
4 | import torch
5 | import random
6 | from tqdm import tqdm
7 | import torch.nn as nn
8 |
9 | import torch
10 | from torch.utils.data import Dataset
11 | import os
12 |
13 | def get_wikitext2(tokenizer, train_size, val_size, seed, seqlen, test_only):
14 |     print("get_wikitext2")
15 |     traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
16 |     testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
17 |
18 |     testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
19 |     if test_only:
20 |         return testenc
21 |     trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
22 |
23 |
24 |     random.seed(seed)
25 |     trainloader = []
26 |     val_sample_ratio = 0.9  # sample train from [0:0.9] and val from [0.9:1.0] to avoid overlap
27 |     for _ in range(train_size):
28 |         i = random.randint(0, int(trainenc.input_ids.shape[1]*val_sample_ratio) - seqlen - 1)
29 |         j = i + seqlen
30 |         inp = trainenc.input_ids[:, i:j]
31 |         tar = inp.clone()
32 |         tar[:, :-1] = -100  # ignore every position except the last one
33 |         trainloader.append((inp, tar))
34 |     valloader = []
35 |     for _ in range(val_size):
36 |         i = random.randint(int(trainenc.input_ids.shape[1]*val_sample_ratio), trainenc.input_ids.shape[1] - seqlen - 1)
37 |         j = i + seqlen
38 |         inp = trainenc.input_ids[:, i:j]
39 |         tar = inp.clone()
40 |         tar[:, :-1] = -100
41 |         valloader.append((inp, tar))
42 |     return trainloader, valloader
43 |
44 |
45 | def get_c4(tokenizer, train_size, val_size, seed, seqlen, test_only):
46 |     print("get_c4")
47 |     try:
48 |         # set local path for faster loading
49 |         traindata = load_dataset("arrow",
50 |             data_files={
51 |                 "train": "/cpfs01/user/chenmengzhao/huggingface/datasets/allenai___json/allenai--c4-6fbe877195f42de5/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/json-train-00000-of-00002.arrow",
52 |                 "validation": "/cpfs01/user/chenmengzhao/huggingface/datasets/allenai___json/allenai--c4-efc3d4f4606f44bd/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e/json-validation.arrow",
53 |             },split='train'
54 |         )
55 |         valdata = load_dataset("arrow",
56 |             data_files={
57 |                 "validation": "/cpfs01/user/chenmengzhao/huggingface/datasets/allenai___json/allenai--c4-efc3d4f4606f44bd/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e/json-validation.arrow",
58 |             },split='validation'
59 |         )
60 |     except Exception:
61 |         traindata = load_dataset(
62 |             'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train'
63 |         )
64 |         valdata = load_dataset(
65 |             'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
66 |         )
67 |
68 |     random.seed(0)  # fixed seed so the C4 evaluation windows are identical across runs
69 |     valenc = []
70 |     for _ in range(256):
71 |         while True:
72 |             i = random.randint(0, len(valdata) - 1)
73 |             tmp = tokenizer(valdata[i]['text'], return_tensors='pt')
74 |             if tmp.input_ids.shape[1] >= seqlen+1:  # need seqlen+1 tokens so the randint range below is valid
75 |                 break
76 |         i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
77 |         j = i + seqlen
78 |         valenc.append(tmp.input_ids[:, i:j])
79 |     valenc = torch.hstack(valenc)
80 |     if test_only:
81 |         return valenc
82 |
83 |     random.seed(seed)
84 |     trainloader = []
85 |     val_sample_ratio = 0.9  # sample train from [0:0.9] and val from [0.9:1.0] to avoid overlap
86 |     for _ in range(train_size):
87 |         while True:
88 |             i = random.randint(0, int(len(traindata)*val_sample_ratio) - 1)
89 |             trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
90 |             if trainenc.input_ids.shape[1] >= seqlen+1:
91 |                 break
92 |         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
93 |         j = i + seqlen
94 |         inp = trainenc.input_ids[:, i:j]
95 |         tar = inp.clone()
96 |         tar[:, :-1] = -100
97 |         trainloader.append((inp, tar))
98 |
99 |     valloader = []
100 |     for _ in range(val_size):
101 |         while True:
102 |             i = random.randint(int(len(traindata)*val_sample_ratio),len(traindata)-1)
103 |             trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
104 |             if trainenc.input_ids.shape[1] >= seqlen+1:
105 |                 break
106 |         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
107 |         j = i + seqlen
108 |         inp = trainenc.input_ids[:, i:j]
109 |         tar = inp.clone()
110 |         tar[:, :-1] = -100
111 |         valloader.append((inp, tar))
112 |
113 |
114 |
115 |     return trainloader, valloader
116 |
117 | def get_redpajama(tokenizer, train_size, val_size, seed, seqlen):
118 |     print("get_redpajama")
119 |     try:
120 |         local_dataset = "/cpfs01/user/chenmengzhao/huggingface/datasets/togethercomputer___red_pajama-data-1_t-sample"
121 |         traindata = load_dataset(local_dataset,split='train')
122 |     except Exception:
123 |         traindata = load_dataset("togethercomputer/RedPajama-Data-1T-Sample",split='train')
124 |     random.seed(seed)
125 |     traindata = traindata.shuffle(seed=seed)
126 |     trainloader = []
127 |     val_sample_ratio = 0.9  # documents [0:0.9] for train, [0.9:1.0] for val
128 |     for _ in range(train_size):
129 |         while True:
130 |             i = random.randint(0, int(len(traindata)*val_sample_ratio) - 1)
131 |             trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
132 |             if trainenc.input_ids.shape[1] >= seqlen+1:
133 |                 break
134 |         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
135 |         j = i + seqlen
136 |         inp = trainenc.input_ids[:, i:j]
137 |         tar = inp.clone()
138 |         tar[:, :-1] = -100
139 |         trainloader.append((inp, tar))
140 |
141 |     valloader = []
142 |     for _ in range(val_size):
143 |         while True:
144 |             i = random.randint(int(len(traindata)*val_sample_ratio),len(traindata)-1)
145 |             trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
146 |             if trainenc.input_ids.shape[1] >= seqlen+1:
147 |                 break
148 |         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
149 |         j = i + seqlen
150 |         inp = trainenc.input_ids[:, i:j]
151 |         tar = inp.clone()
152 |         tar[:, :-1] = -100
153 |         valloader.append((inp, tar))
154 |     return trainloader, valloader
155 |
156 |
157 |
158 | def get_loaders(
159 |     name, tokenizer, train_size=128, val_size=64,seed=0, seqlen=2048, test_only=False
160 | ):
161 |     if 'wikitext2' in name:
162 |         return get_wikitext2(tokenizer,train_size,val_size,seed,seqlen,test_only)
163 |     elif 'c4' in name:
164 |         return get_c4(tokenizer,train_size,val_size,seed,seqlen,test_only)
165 |     elif 'redpajama' in name:
166 |         return get_redpajama(tokenizer,train_size,val_size,seed,seqlen)
167 |     else:
168 |         raise NotImplementedError(f"Unknown dataset: {name}")
169 |
170 |
171 |
172 | @torch.no_grad()
173 | def test_ppl(model, tokenizer, datasets=['wikitext2'],ppl_seqlen=2048):
174 |     results = {}
175 |     for dataset in datasets:
176 |         testloader = get_loaders(
177 |             dataset,
178 |             tokenizer,
179 |             seed=0,
180 |             seqlen=ppl_seqlen,
181 |             test_only=True
182 |         )
183 |         if "c4" in dataset:
184 |             testenc = testloader
185 |         else:
186 |             testenc = testloader.input_ids
187 |
188 |         seqlen = ppl_seqlen
189 |         nsamples = testenc.numel() // seqlen  # non-overlapping windows; the tail is dropped
190 |         use_cache = model.config.use_cache
191 |         model.config.use_cache = False
192 |         model.eval()
193 |         nlls = []
194 |         if hasattr(model,'lm_head') and isinstance(model.lm_head, nn.Linear):
195 |             classifier = model.lm_head
196 |         elif hasattr(model.model,'lm_head'):
197 |             # for gptqmodels
198 |             classifier = None
199 |         elif hasattr(model,'output'):
200 |             # for internlm
201 |             classifier = model.output
202 |         else:
203 |             raise NotImplementedError
204 |         for i in tqdm(range(nsamples)):
205 |             batch = testenc[:, (i * seqlen) : ((i + 1) * seqlen)].to(model.device)
206 |             outputs = model.model(batch)
207 |             if classifier is not None:
208 |                 hidden_states = outputs[0]
209 |                 logits = classifier(hidden_states.to(classifier.weight.dtype))
210 |             else:
211 |                 logits = outputs[0]
212 |             shift_logits = logits[:, :-1, :]
213 |             shift_labels = testenc[:, (i * seqlen) : ((i + 1) * seqlen)][
214 |                 :, 1:
215 |             ].to(shift_logits.device)
216 |             loss_fct = nn.CrossEntropyLoss()
217 |             loss = loss_fct(
218 |                 shift_logits.view(-1, shift_logits.size(-1)),
219 |                 shift_labels.view(-1),
220 |             )
221 |             neg_log_likelihood = loss.float() * seqlen  # window mean loss scaled by seqlen
222 |             nlls.append(neg_log_likelihood)
223 |
224 |
225 |         ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))
226 |         print(f'{dataset}:{ppl}')
227 |         results[dataset] = ppl.item()
228 |         model.config.use_cache = use_cache
229 |     return results
230 |
231 | class BlockTrainDataset(Dataset):
232 |     def __init__(self, size, seqlen, hidden_size, batch_size, dtype, cache_path='./cache/block_training_data', off_load_to_disk=False):
233 |         self.size = size
234 |         self.seqlen = seqlen
235 |         self.hidden_size = hidden_size
236 |         self.dtype = dtype
237 |         self.cache_path = cache_path
238 |         self.off_load_to_disk = off_load_to_disk
239 |         self.batch_size = batch_size
240 |         assert size%batch_size == 0
241 |
242 |         if self.off_load_to_disk:
243 |             if not os.path.exists(self.cache_path):
244 |                 os.makedirs(self.cache_path)
245 |             self._initialize_data_on_disk()  # one file per batch under cache_path
246 |         else:
247 |             self.data = torch.zeros((self.size//self.batch_size, self.batch_size, self.seqlen, self.hidden_size), dtype=self.dtype)  # kept in CPU memory
248 |
249 |     def _initialize_data_on_disk(self):
250 |         for idx in range(self.size//self.batch_size):
251 |             tensor = torch.zeros((self.batch_size, self.seqlen, self.hidden_size), dtype=self.dtype)
252 |             filepath = self._get_file_path(idx)
253 |             torch.save(tensor, filepath)
254 |
255 |     def _get_file_path(self, idx):
256 |         return os.path.join(self.cache_path, f"data_{idx}.pt")
257 |
258 |     def __len__(self):
259 |         return self.size//self.batch_size
260 |
261 |     def __getitem__(self, idx):
262 |         if idx >= self.__len__():
263 |             raise IndexError("Index out of range")
264 |         if self.off_load_to_disk:
265 |             filepath = self._get_file_path(idx)
266 |             tensor = torch.load(filepath)
267 |         else:
268 |             tensor = self.data[idx]
269 |         return tensor
270 |
271 |     def update_data(self, idx, new_data):
272 |         if self.off_load_to_disk:
273 |             filepath = self._get_file_path(idx)
274 |             torch.save(new_data.to(self.dtype), filepath)
275 |         else:
276 |             self.data[idx] = new_data
--------------------------------------------------------------------------------
/datautils_e2e.py:
--------------------------------------------------------------------------------
1 | ## code from qlora
2 | import torch
3 | from typing import Dict, Sequence
4 | from datasets import load_dataset
5 | import os
6 | from itertools import chain
7 | from pathlib import Path
8 | from transformers import default_data_collator
9 | import transformers
10 | from dataclasses import dataclass
11 | from torch.nn.utils.rnn import pad_sequence
12 | import copy
13 | import numpy as np
14 |
15 | IGNORE_INDEX = -100
16 | DEFAULT_PAD_TOKEN = "[PAD]"
17 |
18 |
19 |
20 | @dataclass
21 | class DataCollatorForCausalLM(object):
22 | tokenizer: transformers.PreTrainedTokenizer
23 | source_max_len: int
24 | target_max_len: int
25 | train_on_source: bool
26 | predict_with_generate: bool
27 |
28 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
29 | # Extract elements
30 | sources = [f"{self.tokenizer.bos_token}{example['input']}" for example in instances]
31 | targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances]
32 | # Tokenize
33 | tokenized_sources_with_prompt = self.tokenizer(
34 | sources,
35 | max_length=self.source_max_len,
36 | truncation=True,
37 | add_special_tokens=False,
38 | )
39 | tokenized_targets = self.tokenizer(
40 | targets,
41 | max_length=self.target_max_len,
42 | truncation=True,
43 | add_special_tokens=False,
44 | )
45 | # Build the input and labels for causal LM
46 | input_ids = []
47 | labels = []
48 | for tokenized_source, tokenized_target in zip(
49 | tokenized_sources_with_prompt['input_ids'],
50 | tokenized_targets['input_ids']
51 | ):
52 | if not self.predict_with_generate:
53 | input_ids.append(torch.tensor(tokenized_source + tokenized_target))
54 | if not self.train_on_source:
55 | labels.append(
56 | torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target))
57 | )
58 | else:
59 | labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target)))
60 | else:
61 | input_ids.append(torch.tensor(tokenized_source))
62 | # Apply padding
63 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
64 | labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None
65 | data_dict = {
66 | 'input_ids': input_ids,
67 | 'attention_mask':input_ids.ne(self.tokenizer.pad_token_id),
68 | }
69 | if labels is not None:
70 | data_dict['labels'] = labels
71 | return data_dict
72 |
73 |
74 | ALPACA_PROMPT_DICT = {
75 | "prompt_input": (
76 | "Below is an instruction that describes a task, paired with an input that provides further context. "
77 | "Write a response that appropriately completes the request.\n\n"
78 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: "
79 | ),
80 | "prompt_no_input": (
81 | "Below is an instruction that describes a task. "
82 | "Write a response that appropriately completes the request.\n\n"
83 | "### Instruction:\n{instruction}\n\n### Response: "
84 | ),
85 | }
86 |
87 | def extract_alpaca_dataset(example):
88 | if example.get("input", "") != "":
89 | prompt_format = ALPACA_PROMPT_DICT["prompt_input"]
90 | else:
91 | prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"]
92 | return {'input': prompt_format.format(**example)}
93 |
94 |
95 | def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) -> Dict:
96 | """
97 | Make dataset and collator for supervised fine-tuning or continue pre-train.
98 | """
99 | def load_data(dataset_name):
100 | if dataset_name == 'alpaca':
101 | return load_dataset("tatsu-lab/alpaca")
102 | elif dataset_name == 'oasst1':
103 | return load_dataset("timdettmers/openassistant-guanaco")
104 | elif dataset_name == 'deita-6k':
105 | dataset = load_dataset("hkust-nlp/deita-6k-v0", split = "train")
106 | dataset = [row for row in dataset]
107 | return dataset
108 | elif dataset_name == 'deita-10k':
109 | dataset = load_dataset("hkust-nlp/deita-10k-v0", split = "train")
110 | dataset = [row for row in dataset]
111 | return dataset
112 | elif dataset_name == 'c4':
113 | try:
114 | # load from local file, a fast manner
115 | dataset = load_dataset("arrow",
116 | data_files={
117 | "train": "/cpfs01/user/chenmengzhao/huggingface/datasets/allenai___json/allenai--c4-6fbe877195f42de5/0.0.0/0f7e3662623656454fcd2b650f34e886a7db4b9104504885bd462096cc7a9f51/json-train-00000-of-00002.arrow",
118 | "validation": "/cpfs01/user/chenmengzhao/huggingface/datasets/allenai___json/allenai--c4-efc3d4f4606f44bd/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e/json-validation.arrow",
119 | },
120 | )
121 | except:
122 | dataset = load_dataset("allenai/c4","allenai--c4",
123 | data_files={
124 | "train": "en/c4-train.00000-of-01024.json.gz",
125 | "validation": "en/c4-validation.00000-of-00008.json.gz",
126 | },
127 | )
128 | return dataset
129 | elif dataset_name == 'redpajama':
130 |             dataset_path = "/cpfs01/user/chenmengzhao/huggingface/datasets/togethercomputer___red_pajama-data-1_t-sample"
131 |             try:
132 |                 dataset = load_dataset(dataset_path)  # load from a local copy first, which is faster
133 |             except Exception:
134 |                 # fall back to the Hub; the split loading below must reuse the same source
135 |                 dataset_path = "togethercomputer/RedPajama-Data-1T-Sample"
136 |                 dataset = load_dataset(dataset_path)
137 |             if "validation" not in dataset.keys():
138 |                 validation_split = args.eval_dataset_size
139 |                 dataset["validation"] = load_dataset(
140 |                     dataset_path, split=f"train[:{validation_split}]"
141 |                 )
142 |                 dataset["train"] = load_dataset(
143 |                     dataset_path, split=f"train[{validation_split}:]"
144 |                 )
145 |             return dataset
146 | else:
147 | raise NotImplementedError(f"Dataset {dataset_name} not implemented yet.")
148 |
149 |
150 | def format_dataset(dataset, dataset_format):
151 | if (
152 | dataset_format == 'alpaca' or dataset_format == 'alpaca-clean' or
153 | (dataset_format is None and args.dataset in ['alpaca', 'alpaca-clean'])
154 | ):
155 | dataset = dataset.map(extract_alpaca_dataset, remove_columns=['instruction'])
156 | elif dataset_format == 'oasst1' or (dataset_format is None and args.dataset == 'oasst1'):
157 | dataset = dataset.map(lambda x: {
158 | 'input': '',
159 | 'output': x['text'],
160 | })
161 | elif dataset_format == 'pt' or (dataset_format is None and args.dataset in ['c4', 'redpajama']):
162 | block_size = args.pt_context_len
163 | column_names = list(dataset["train"].features)
164 | text_column_name = "text" if "text" in column_names else column_names[0]
165 |
166 | def tokenize_function(examples):
167 | output = tokenizer(examples[text_column_name])
168 | return output
169 | tokenized_datasets = dataset.map(
170 | tokenize_function,
171 | batched=True,
172 | remove_columns=column_names,
173 | num_proc=args.preprocessing_num_workers,
174 | load_from_cache_file=not args.overwrite_cache,
175 | desc="Running tokenizer on dataset",
176 | )
177 | def group_texts(examples):
178 | # Concatenate all texts.
179 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
180 | total_length = len(concatenated_examples[list(examples.keys())[0]])
181 |                 # Drop the small remainder; padding could be used instead if the model supported it.
182 |                 # Customize this part to your needs.
183 | if total_length >= block_size:
184 | total_length = (total_length // block_size) * block_size
185 | # Split by chunks of max_len.
186 | result = {
187 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
188 | for k, t in concatenated_examples.items()
189 | }
190 | result["labels"] = result["input_ids"].copy()
191 | return result
192 | dataset = tokenized_datasets.map(
193 | group_texts,
194 | batched=True,
195 | num_proc=args.preprocessing_num_workers,
196 | load_from_cache_file=not args.overwrite_cache,
197 | desc=f"Grouping texts in chunks of {block_size}",
198 | )
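199 |         # Remove unused columns for instruction-tuning; pre-training data keeps input_ids/labels
200 |         is_pt = dataset_format == 'pt' or (dataset_format is None and args.dataset in ['c4', 'redpajama'])
201 |         if not is_pt:
202 |             dataset = dataset.remove_columns(
203 |                 [col for col in dataset.column_names['train'] if col not in ['input', 'output']])
204 |         return dataset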
205 |
206 | # Load dataset.
207 | print(f"loading {args.dataset}")
208 | if args.dataset in ['c4', 'redpajama']:
209 | cache_dir = './cache'
210 | cache_dataloader = f'{cache_dir}/e2e_dataloader_{args.model_family}_{args.dataset}_{args.pt_context_len}.cache'
211 | if os.path.exists(cache_dataloader):
212 |             dataset = torch.load(cache_dataloader, weights_only=False)  # cached DatasetDict, not only tensors
213 | print(f"load dataset from {cache_dataloader}")
214 | else:
215 | Path(cache_dir).mkdir(parents=True, exist_ok=True)
216 | dataset = load_data(args.dataset)
217 | dataset = format_dataset(dataset, args.dataset_format)
218 | torch.save(dataset, cache_dataloader)
219 | elif args.dataset in ['deita-6k', 'deita-10k']:
220 | # Split train/eval for deita datasets
221 | raw_data = load_data(args.dataset)
222 | np.random.seed(0)
223 |         # shuffle indices with a fixed seed, then keep 98% for training and 2% for evaluation
224 | perm = np.random.permutation(len(raw_data))
225 | split = int(len(perm) * 0.98)
226 | train_indices = perm[:split]
227 | eval_indices = perm[split:]
228 | train_raw_data = [raw_data[i] for i in train_indices]
229 | eval_raw_data = [raw_data[i] for i in eval_indices]
230 | print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}")
231 | from deita_dataset.train import SupervisedDataset, LazySupervisedDataset
232 | dataset_cls = LazySupervisedDataset
233 | train_dataset = dataset_cls(train_raw_data, tokenizer=tokenizer, conv_template = args.conv_temp, mask_user = args.mask_use)
234 | eval_dataset = dataset_cls(eval_raw_data, tokenizer=tokenizer, conv_template = args.conv_temp, mask_user = args.mask_use)
235 | elif args.dataset == 'mix_deita_redpajama':
236 | cache_dir = './cache'
237 | cache_dataloader = f'{cache_dir}/dataloader_{args.model_family}_{args.dataset}_{args.pt_context_len}.cache'
238 | if os.path.exists(cache_dataloader):
239 |             dataset = torch.load(cache_dataloader, weights_only=False)  # cached dict of Datasets, not only tensors
240 | print(f"load dataset from {cache_dataloader}")
241 | else:
242 | deita_dataset = load_data('deita-10k')
243 | np.random.seed(0)
244 | from datasets import concatenate_datasets, Dataset
245 | from deita_dataset.train import SupervisedDataset
246 |             print('Tokenizing deita; this may take a long time.')
247 | deita_dataset = SupervisedDataset(deita_dataset, tokenizer=tokenizer, conv_template = args.conv_temp, mask_user = args.mask_use)
248 | deita_dataset = Dataset.from_dict(
249 | {
250 | "input_ids":deita_dataset.input_ids,
251 | "labels":deita_dataset.labels,
252 | "attention_mask":deita_dataset.attention_mask,
253 | }
254 | )
255 | dataset = load_data('redpajama')
256 | redpajama_dataset = format_dataset(dataset, 'pt')
257 |
258 | train_dataset = concatenate_datasets([deita_dataset,redpajama_dataset['train'].select(range(len(deita_dataset)))])
259 | dataset = {
260 | "train":train_dataset,
261 | "validation":redpajama_dataset['validation']
262 | }
263 | Path(cache_dir).mkdir(parents=True, exist_ok=True)
264 | torch.save(dataset, cache_dataloader)
265 | else:
266 | dataset = load_data(args.dataset)
267 | dataset = format_dataset(dataset, args.dataset_format)
268 | print(f"loading {args.dataset} successfully")
269 |
270 | # Split train/eval, reduce size for other datasets
271 |     if args.dataset not in ['deita-6k', 'deita-10k']:
272 | if args.do_eval or args.do_predict:
273 | if 'eval' in dataset:
274 | eval_dataset = dataset['eval']
275 | elif 'validation' in dataset:
276 | eval_dataset = dataset['validation']
277 | else:
278 | print('Splitting train dataset in train and validation according to `eval_dataset_size`')
279 | dataset = dataset["train"].train_test_split(
280 | test_size=args.eval_dataset_size, shuffle=True, seed=42
281 | )
282 | eval_dataset = dataset['test']
283 | if args.max_eval_samples is not None and len(eval_dataset) > args.max_eval_samples:
284 | eval_dataset = eval_dataset.select(range(args.max_eval_samples))
285 | if args.group_by_length:
286 | eval_dataset = eval_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
287 | if args.do_train:
288 | train_dataset = dataset['train']
289 | train_dataset = train_dataset.shuffle(seed=0)
290 | if args.max_train_samples is not None and len(train_dataset) > args.max_train_samples:
291 | train_dataset = train_dataset.select(range(args.max_train_samples))
292 | if args.group_by_length:
293 | train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})
294 | if args.dataset in ['c4', 'redpajama', 'deita-6k', 'deita-10k','mix_deita_redpajama']:
295 | data_collator = default_data_collator
296 | else:
297 | data_collator = DataCollatorForCausalLM(
298 | tokenizer=tokenizer,
299 | source_max_len=args.source_max_len,
300 | target_max_len=args.target_max_len,
301 | train_on_source=args.train_on_source,
302 | predict_with_generate=args.predict_with_generate,
303 | )
304 | return dict(
305 | train_dataset=train_dataset if args.do_train else None,
306 | eval_dataset=eval_dataset if args.do_eval else None,
307 | predict_dataset=eval_dataset if args.do_predict else None,
308 | data_collator=data_collator
309 | )
310 |
--------------------------------------------------------------------------------
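The `pt` branch of `format_dataset` packs tokenized documents into fixed-length training blocks. Below is a minimal, dependency-free sketch of that `group_texts` step, assuming plain Python lists in place of a Hugging Face `Dataset` batch and taking `block_size` as an explicit argument (standing in for `args.pt_context_len`):

```python
from itertools import chain
from typing import Dict, List


def group_texts(examples: Dict[str, List[List[int]]], block_size: int) -> Dict[str, List[List[int]]]:
    # Flatten every column (input_ids, attention_mask, ...) into one long sequence.
    concatenated = {k: list(chain(*examples[k])) for k in examples}
    total_length = len(concatenated[next(iter(examples))])
    # Drop the tail that does not fill a whole block, unless the batch is
    # shorter than a single block (then it is kept as one short block).
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # Causal LM targets are the inputs themselves; the model shifts them internally.
    result["labels"] = result["input_ids"].copy()
    return result


if __name__ == "__main__":
    batch = {
        "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9]],
        "attention_mask": [[1, 1, 1], [1, 1, 1, 1], [1, 1]],
    }
    out = group_texts(batch, block_size=4)
    print(out["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]  (token 9 is dropped)
```

Because `dataset.map(..., batched=True)` calls the function once per batch of examples (1,000 by default), the remainder is dropped per batch rather than once per dataset.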
/deita_dataset/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.19"
2 |
--------------------------------------------------------------------------------
/deita_dataset/constants.py:
--------------------------------------------------------------------------------
1 | from enum import IntEnum
2 | import os
3 |
4 | REPO_PATH = os.path.dirname(os.path.dirname(__file__))
5 |
6 | ##### For the gradio web server
7 | SERVER_ERROR_MSG = (
8 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
9 | )
10 | MODERATION_MSG = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE FIX YOUR INPUT AND TRY AGAIN."
11 | CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
12 | INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
13 | # Maximum input length
14 | INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 2560))
15 | # Maximum conversation turns
16 | CONVERSATION_TURN_LIMIT = 50
17 | # Session expiration time
18 | SESSION_EXPIRATION_TIME = 3600
19 | # The output dir of log files
20 | LOGDIR = "."
21 |
22 |
23 | ##### For the controller and workers (can be overridden through environment variables)
24 | CONTROLLER_HEART_BEAT_EXPIRATION = int(
25 | os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90)
26 | )
27 | WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45))
28 | WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100))
29 | WORKER_API_EMBEDDING_BATCH_SIZE = int(
30 | os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4)
31 | )
32 |
33 |
34 | class ErrorCode(IntEnum):
35 | """
36 | https://platform.openai.com/docs/guides/error-codes/api-errors
37 | """
38 |
39 | VALIDATION_TYPE_ERROR = 40001
40 |
41 | INVALID_AUTH_KEY = 40101
42 | INCORRECT_AUTH_KEY = 40102
43 | NO_PERMISSION = 40103
44 |
45 | INVALID_MODEL = 40301
46 | PARAM_OUT_OF_RANGE = 40302
47 | CONTEXT_OVERFLOW = 40303
48 |
49 | RATE_LIMIT = 42901
50 | QUOTA_EXCEEDED = 42902
51 | ENGINE_OVERLOADED = 42903
52 |
53 | INTERNAL_ERROR = 50001
54 | CUDA_OUT_OF_MEMORY = 50002
55 | GRADIO_REQUEST_ERROR = 50003
56 | GRADIO_STREAM_UNKNOWN_ERROR = 50004
57 | CONTROLLER_NO_WORKER = 50005
58 | CONTROLLER_WORKER_TIMEOUT = 50006
59 |
--------------------------------------------------------------------------------
/deita_dataset/conversation.py:
--------------------------------------------------------------------------------
1 | """
2 | Conversation prompt templates.
3 | """
4 |
5 | import dataclasses
6 | from enum import auto, IntEnum
7 | from typing import List, Any, Dict
8 |
9 |
10 | class SeparatorStyle(IntEnum):
11 | """Separator styles."""
12 |
13 | ADD_COLON_SINGLE = auto()
14 | ADD_COLON_TWO = auto()
15 | ADD_COLON_SPACE_SINGLE = auto()
16 | NO_COLON_SINGLE = auto()
17 | NO_COLON_TWO = auto()
18 | ADD_NEW_LINE_SINGLE = auto()
19 | LLAMA2 = auto()
20 | CHATGLM = auto()
21 | CHATML = auto()
22 | CHATINTERN = auto()
23 | DOLLY = auto()
24 | RWKV = auto()
25 | PHOENIX = auto()
26 | ROBIN = auto()
27 |
28 |
29 | @dataclasses.dataclass
30 | class Conversation:
31 | """A class that manages prompt templates and keeps all conversation history."""
32 |
33 | # The name of this template
34 | name: str
35 | # The system prompt
36 | system: str
37 | # Two roles
38 | roles: List[str]
39 | # All messages. Each item is (role, message).
40 | messages: List[List[str]]
41 | # The number of few shot examples
42 | offset: int
43 | # Separators
44 | sep_style: SeparatorStyle
45 | sep: str
46 | sep2: str = None
47 | # Stop criteria (the default one is EOS token)
48 | stop_str: str = None
49 | # Stops generation if meeting any token in this list
50 | stop_token_ids: List[int] = None
51 |
52 | def get_prompt(self) -> str:
53 | """Get the prompt for generation."""
54 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
55 | ret = self.system + self.sep
56 | for role, message in self.messages:
57 | if message:
58 | ret += role + ": " + message + self.sep
59 | else:
60 | ret += role + ":"
61 | return ret
62 | elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
63 | seps = [self.sep, self.sep2]
64 | ret = self.system + seps[0]
65 | for i, (role, message) in enumerate(self.messages):
66 | if message:
67 | ret += role + ": " + message + seps[i % 2]
68 | else:
69 | ret += role + ":"
70 | return ret
71 | elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
72 | ret = self.system + self.sep
73 | for role, message in self.messages:
74 | if message:
75 | ret += role + ": " + message + self.sep
76 | else:
77 |                     ret += role + ": "  # must end with a space
78 | return ret
79 | elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
80 | ret = "" if self.system == "" else self.system + self.sep
81 | for role, message in self.messages:
82 | if message:
83 | ret += role + "\n" + message + self.sep
84 | else:
85 | ret += role + "\n"
86 | return ret
87 | elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
88 | ret = self.system
89 | for role, message in self.messages:
90 | if message:
91 | ret += role + message + self.sep
92 | else:
93 | ret += role
94 | return ret
95 | elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
96 | seps = [self.sep, self.sep2]
97 | ret = self.system
98 | for i, (role, message) in enumerate(self.messages):
99 | if message:
100 | ret += role + message + seps[i % 2]
101 | else:
102 | ret += role
103 | return ret
104 | elif self.sep_style == SeparatorStyle.RWKV:
105 | ret = self.system
106 | for i, (role, message) in enumerate(self.messages):
107 | if message:
108 | ret += (
109 | role
110 | + ": "
111 | + message.replace("\r\n", "\n").replace("\n\n", "\n")
112 | )
113 | ret += "\n\n"
114 | else:
115 | ret += role + ":"
116 | return ret
117 | elif self.sep_style == SeparatorStyle.LLAMA2:
118 | seps = [self.sep, self.sep2]
119 | ret = ""
120 | for i, (role, message) in enumerate(self.messages):
121 | if message:
122 | if i == 0:
123 | ret += self.system + message
124 | else:
125 | ret += role + " " + message + seps[i % 2]
126 | else:
127 | ret += role
128 | return ret
129 | elif self.sep_style == SeparatorStyle.CHATGLM:
130 | # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
131 | # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
132 | round_add_n = 1 if self.name == "chatglm2" else 0
133 | if self.system:
134 | ret = self.system + self.sep
135 | else:
136 | ret = ""
137 |
138 | for i, (role, message) in enumerate(self.messages):
139 | if i % 2 == 0:
140 | ret += f"[Round {i//2 + round_add_n}]{self.sep}"
141 |
142 | if message:
143 | ret += f"{role}:{message}{self.sep}"
144 | else:
145 | ret += f"{role}:"
146 | return ret
147 | elif self.sep_style == SeparatorStyle.CHATML:
148 | ret = "" if self.system == "" else self.system + self.sep + "\n"
149 | for role, message in self.messages:
150 | if message:
151 | ret += role + "\n" + message + self.sep + "\n"
152 | else:
153 | ret += role + "\n"
154 | return ret
155 | elif self.sep_style == SeparatorStyle.CHATINTERN:
156 | # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
157 | seps = [self.sep, self.sep2]
158 | ret = self.system
159 | for i, (role, message) in enumerate(self.messages):
160 | if i % 2 == 0:
161 |                     ret += "<s>"
162 | if message:
163 | ret += role + ":" + message + seps[i % 2] + "\n"
164 | else:
165 | ret += role + ":"
166 | return ret
167 | elif self.sep_style == SeparatorStyle.DOLLY:
168 | seps = [self.sep, self.sep2]
169 | ret = self.system
170 | for i, (role, message) in enumerate(self.messages):
171 | if message:
172 | ret += role + ":\n" + message + seps[i % 2]
173 | if i % 2 == 1:
174 | ret += "\n\n"
175 | else:
176 | ret += role + ":\n"
177 | return ret
178 | elif self.sep_style == SeparatorStyle.PHOENIX:
179 | ret = self.system
180 | for role, message in self.messages:
181 | if message:
182 |                     ret += role + ": " + "<s>" + message + "</s>"
183 |                 else:
184 |                     ret += role + ": " + "<s>"
185 | return ret
186 | elif self.sep_style == SeparatorStyle.ROBIN:
187 | ret = self.system + self.sep
188 | for role, message in self.messages:
189 | if message:
190 | ret += role + ":\n" + message + self.sep
191 | else:
192 | ret += role + ":\n"
193 | return ret
194 | else:
195 | raise ValueError(f"Invalid style: {self.sep_style}")
196 |
197 | def append_message(self, role: str, message: str):
198 | """Append a new message."""
199 | self.messages.append([role, message])
200 |
201 | def update_last_message(self, message: str):
202 | """Update the last output.
203 |
204 | The last message is typically set to be None when constructing the prompt,
205 | so we need to update it in-place after getting the response from a model.
206 | """
207 | self.messages[-1][1] = message
208 |
209 | def to_gradio_chatbot(self):
210 | """Convert the conversation to gradio chatbot format."""
211 | ret = []
212 | for i, (role, msg) in enumerate(self.messages[self.offset :]):
213 | if i % 2 == 0:
214 | ret.append([msg, None])
215 | else:
216 | ret[-1][-1] = msg
217 | return ret
218 |
219 | def to_openai_api_messages(self):
220 | """Convert the conversation to OpenAI chat completion format."""
221 | ret = [{"role": "system", "content": self.system}]
222 |
223 | for i, (_, msg) in enumerate(self.messages[self.offset :]):
224 | if i % 2 == 0:
225 | ret.append({"role": "user", "content": msg})
226 | else:
227 | if msg is not None:
228 | ret.append({"role": "assistant", "content": msg})
229 | return ret
230 |
231 | def copy(self):
232 | return Conversation(
233 | name=self.name,
234 | system=self.system,
235 | roles=self.roles,
236 | messages=[[x, y] for x, y in self.messages],
237 | offset=self.offset,
238 | sep_style=self.sep_style,
239 | sep=self.sep,
240 | sep2=self.sep2,
241 | stop_str=self.stop_str,
242 | stop_token_ids=self.stop_token_ids,
243 | )
244 |
245 | def dict(self):
246 | return {
247 | "template_name": self.name,
248 | "system": self.system,
249 | "roles": self.roles,
250 | "messages": self.messages,
251 | "offset": self.offset,
252 | }
253 |
254 |
255 | # A global registry for all conversation templates
256 | conv_templates: Dict[str, Conversation] = {}
257 |
258 |
259 | def register_conv_template(template: Conversation, override: bool = False):
260 | """Register a new conversation template."""
261 | if not override:
262 | assert (
263 | template.name not in conv_templates
264 | ), f"{template.name} has been registered."
265 |
266 | conv_templates[template.name] = template
267 |
268 |
269 | def get_conv_template(name: str) -> Conversation:
270 | """Get a conversation template."""
271 | return conv_templates[name].copy()
272 |
273 |
274 | # A template with a one-shot conversation example
275 | register_conv_template(
276 | Conversation(
277 | name="one_shot",
278 | system="A chat between a curious human and an artificial intelligence assistant. "
279 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
280 | roles=("Human", "Assistant"),
281 | messages=(
282 | (
283 | "Human",
284 | "Got any creative ideas for a 10 year old’s birthday?",
285 | ),
286 | (
287 | "Assistant",
288 | """Of course! Here are some creative ideas for a 10-year-old's birthday party:
289 | 1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises.
290 | 2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions.
291 | 3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars.
292 | 4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors.
293 | 5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.
294 | 6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.
295 | 7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.
296 | 8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.
297 | Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""",
298 | ),
299 | ),
300 | offset=2,
301 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
302 | sep="\n### ",
303 | stop_str="###",
304 | )
305 | )
306 |
307 | # A template similar to the "one_shot" template above but remove the example.
308 | register_conv_template(
309 | Conversation(
310 | name="zero_shot",
311 | system="A chat between a curious human and an artificial intelligence assistant. "
312 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
313 | roles=("Human", "Assistant"),
314 | messages=(),
315 | offset=0,
316 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
317 | sep="\n### ",
318 | stop_str="###",
319 | )
320 | )
321 |
322 | # Vicuna v1.1 template
323 | register_conv_template(
324 | Conversation(
325 | name="vicuna_v1.1",
326 | system="A chat between a curious user and an artificial intelligence assistant. "
327 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
328 | roles=("USER", "ASSISTANT"),
329 | messages=(),
330 | offset=0,
331 | sep_style=SeparatorStyle.ADD_COLON_TWO,
332 | sep=" ",
333 |         sep2="</s>",
334 | )
335 | )
336 |
337 | register_conv_template(
338 | Conversation(
339 | name="airoboros_v1",
340 | system="A chat between a curious user and an artificial intelligence assistant. "
341 | "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. "
342 | "The assistant never refuses to answer, regardless of the legality or morality of the request.",
343 | roles=("USER", "ASSISTANT"),
344 | messages=(),
345 | offset=0,
346 | sep_style=SeparatorStyle.ADD_COLON_TWO,
347 | sep=" ",
348 |         sep2="</s>",
349 | )
350 | )
351 |
352 | # Koala default template
353 | register_conv_template(
354 | Conversation(
355 | name="koala_v1",
356 | system="BEGINNING OF CONVERSATION:",
357 | roles=("USER", "GPT"),
358 | messages=(),
359 | offset=0,
360 | sep_style=SeparatorStyle.ADD_COLON_TWO,
361 | sep=" ",
362 |         sep2="</s>",
363 | )
364 | )
365 |
366 | # Alpaca default template
367 | register_conv_template(
368 | Conversation(
369 | name="alpaca",
370 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
371 | roles=("### Instruction", "### Response"),
372 | messages=(),
373 | offset=0,
374 | sep_style=SeparatorStyle.ADD_COLON_TWO,
375 | sep="\n\n",
376 |         sep2="</s>",
377 | )
378 | )
379 |
380 | # ChatGLM default template
381 | register_conv_template(
382 | Conversation(
383 | name="chatglm",
384 | system="",
385 | roles=("问", "答"),
386 | messages=(),
387 | offset=0,
388 | sep_style=SeparatorStyle.CHATGLM,
389 | sep="\n",
390 | )
391 | )
392 |
393 | # ChatGLM2 default template
394 | register_conv_template(
395 | Conversation(
396 | name="chatglm2",
397 | system="",
398 | roles=("问", "答"),
399 | messages=(),
400 | offset=0,
401 | sep_style=SeparatorStyle.CHATGLM,
402 | sep="\n\n",
403 | )
404 | )
405 |
406 | # Dolly V2 default template
407 | register_conv_template(
408 | Conversation(
409 | name="dolly_v2",
410 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
411 | roles=("### Instruction", "### Response"),
412 | messages=(),
413 | offset=0,
414 | sep_style=SeparatorStyle.DOLLY,
415 | sep="\n\n",
416 | sep2="### End",
417 | )
418 | )
419 |
420 | # OpenAssistant Pythia default template
421 | register_conv_template(
422 | Conversation(
423 | name="oasst_pythia",
424 | system="",
425 | roles=("<|prompter|>", "<|assistant|>"),
426 | messages=(),
427 | offset=0,
428 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
429 | sep="<|endoftext|>",
430 | )
431 | )
432 |
433 | # OpenAssistant default template
434 | register_conv_template(
435 | Conversation(
436 | name="oasst_llama",
437 | system="",
438 | roles=("<|prompter|>", "<|assistant|>"),
439 | messages=(),
440 | offset=0,
441 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
442 |         sep="</s>",
443 | )
444 | )
445 |
446 | # Tulu default template
447 | register_conv_template(
448 | Conversation(
449 | name="tulu",
450 | system="",
451 | roles=("<|user|>", "<|assistant|>"),
452 | messages=(),
453 | offset=0,
454 | sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
455 | sep="\n",
456 | )
457 | )
458 |
459 | # StableLM Alpha default template
460 | register_conv_template(
461 | Conversation(
462 | name="stablelm",
463 | system="""<|SYSTEM|># StableLM Tuned (Alpha version)
464 | - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
465 | - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
466 | - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
467 | - StableLM will refuse to participate in anything that could harm a human.
468 | """,
469 | roles=("<|USER|>", "<|ASSISTANT|>"),
470 | messages=(),
471 | offset=0,
472 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
473 | sep="",
474 | stop_token_ids=[50278, 50279, 50277, 1, 0],
475 | )
476 | )
477 |
478 | # Baize default template
479 | register_conv_template(
480 | Conversation(
481 | name="baize",
482 | system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n",
483 | roles=("[|Human|]", "[|AI|]"),
484 | messages=(
485 | ("[|Human|]", "Hello!"),
486 | ("[|AI|]", "Hi!"),
487 | ),
488 | offset=2,
489 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
490 | sep="\n",
491 | stop_str="[|Human|]",
492 | )
493 | )
494 |
495 | # RWKV-4-Raven default template
496 | register_conv_template(
497 | Conversation(
498 | name="rwkv",
499 | system="",
500 | roles=("Bob", "Alice"),
501 | messages=(
502 | ("Bob", "hi"),
503 | (
504 | "Alice",
505 | "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.",
506 | ),
507 | ),
508 | offset=2,
509 | sep_style=SeparatorStyle.RWKV,
510 | sep="",
511 | stop_str="\n\n",
512 | )
513 | )
514 |
515 | # Buddy default template
516 | register_conv_template(
517 | Conversation(
518 | name="openbuddy",
519 | system="""Consider a conversation between User (a human) and Assistant (named Buddy).
520 | Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
521 | Buddy cannot access the Internet.
522 | Buddy can fluently speak the user's language (e.g. English, Chinese).
523 | Buddy can generate poems, stories, code, essays, songs, parodies, and more.
524 | Buddy possesses vast knowledge about the world, history, and culture.
525 | Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
526 | Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.
527 |
528 | User: Hi.
529 | Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""",
530 | roles=("User", "Assistant"),
531 | messages=(),
532 | offset=0,
533 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
534 | sep="\n",
535 | )
536 | )
537 |
538 | # Phoenix default template
539 | register_conv_template(
540 | Conversation(
541 | name="phoenix",
542 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
543 | roles=("Human", "Assistant"),
544 | messages=(),
545 | offset=0,
546 | sep_style=SeparatorStyle.PHOENIX,
547 |         sep="</s>",
548 | )
549 | )
550 |
551 | # ChatGPT default template
552 | register_conv_template(
553 | Conversation(
554 | name="chatgpt",
555 | system="You are a helpful assistant.",
556 | roles=("user", "assistant"),
557 | messages=(),
558 | offset=0,
559 | sep_style=None,
560 | sep=None,
561 | )
562 | )
563 |
564 | # Claude default template
565 | register_conv_template(
566 | Conversation(
567 | name="claude",
568 | system="",
569 | roles=("Human", "Assistant"),
570 | messages=(),
571 | offset=0,
572 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
573 | sep="\n\n",
574 | )
575 | )
576 |
577 | # MPT default template
578 | register_conv_template(
579 | Conversation(
580 | name="mpt-7b-chat",
581 | system="""<|im_start|>system
582 | - You are a helpful assistant chatbot trained by MosaicML.
583 | - You answer questions.
584 | - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
585 | - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
586 | roles=("<|im_start|>user", "<|im_start|>assistant"),
587 | messages=(),
588 | offset=0,
589 | sep_style=SeparatorStyle.CHATML,
590 | sep="<|im_end|>",
591 | stop_token_ids=[50278, 0],
592 | )
593 | )
594 |
595 | # MPT-30b-chat default template
596 | register_conv_template(
597 | Conversation(
598 | name="mpt-30b-chat",
599 | system="""<|im_start|>system
600 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
601 | roles=("<|im_start|>user", "<|im_start|>assistant"),
602 | messages=(),
603 | offset=0,
604 | sep_style=SeparatorStyle.CHATML,
605 | sep="<|im_end|>",
606 | stop_token_ids=[50278, 0],
607 | )
608 | )
609 |
610 | # MPT-30b-instruct default template
611 | # reference: https://huggingface.co/mosaicml/mpt-30b-instruct#formatting
612 | register_conv_template(
613 | Conversation(
614 | name="mpt-30b-instruct",
615 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
616 | roles=("### Instruction", "### Response"),
617 | messages=(),
618 | offset=0,
619 | sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
620 | sep="\n\n",
621 | stop_token_ids=[50278, 0],
622 | )
623 | )
624 |
625 | # Bard default template
626 | # Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150
627 | # https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40
628 | register_conv_template(
629 | Conversation(
630 | name="bard",
631 | system="",
632 | roles=("0", "1"),
633 | messages=(),
634 | offset=0,
635 | sep_style=None,
636 | sep=None,
637 | )
638 | )
639 |
640 | # BiLLa default template
641 | register_conv_template(
642 | Conversation(
643 | name="billa",
644 | system="",
645 | roles=("Human", "Assistant"),
646 | messages=(),
647 | offset=0,
648 | sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
649 | sep="\n",
650 | stop_str="Human:",
651 | )
652 | )
653 |
654 | # RedPajama INCITE default template
655 | register_conv_template(
656 | Conversation(
657 | name="redpajama-incite",
658 | system="",
659 | roles=("<human>", "<bot>"),
660 | messages=(),
661 | offset=0,
662 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
663 | sep="\n",
664 | stop_str="<human>",
665 | )
666 | )
667 |
668 | # h2oGPT default template
669 | register_conv_template(
670 | Conversation(
671 | name="h2ogpt",
672 | system="",
673 | roles=("<|prompt|>", "<|answer|>"),
674 | messages=(),
675 | offset=0,
676 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
677 | sep="</s>",
678 | )
679 | )
680 |
681 | # Robin default template
682 | register_conv_template(
683 | Conversation(
684 | name="Robin",
685 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
686 | roles=("###Human", "###Assistant"),
687 | messages=(),
688 | offset=0,
689 | sep_style=SeparatorStyle.ROBIN,
690 | sep="\n",
691 | stop_token_ids=[2, 396],
692 | stop_str="###",
693 | )
694 | )
695 |
696 | # Snoozy default template
697 | # Reference: https://github.com/nomic-ai/gpt4all/blob/d4861030b778da6db59d21d2927a4aba4f9f1f43/gpt4all-bindings/python/gpt4all/gpt4all.py#L232
698 | register_conv_template(
699 | Conversation(
700 | name="snoozy",
701 | system="### Instruction:\nThe prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.",
702 | roles=("### Prompt", "### Response"),
703 | messages=(),
704 | offset=0,
705 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
706 | sep="\n",
707 | stop_str="###",
708 | )
709 | )
710 |
711 | # manticore default template
712 | register_conv_template(
713 | Conversation(
714 | name="manticore",
715 | system="",
716 | roles=("USER", "ASSISTANT"),
717 | messages=(),
718 | offset=0,
719 | sep_style=SeparatorStyle.ADD_COLON_TWO,
720 | sep="\n",
721 | sep2="</s>",
722 | )
723 | )
724 |
725 | # Falcon default template
726 | register_conv_template(
727 | Conversation(
728 | name="falcon",
729 | system="",
730 | roles=("User", "Assistant"),
731 | messages=[],
732 | offset=0,
733 | sep_style=SeparatorStyle.RWKV,
734 | sep="\n",
735 | sep2="<|endoftext|>",
736 | stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text
737 | stop_token_ids=[
738 | 0,
739 | 1,
740 | 2,
741 | 3,
742 | 4,
743 | 5,
744 | 6,
745 | 7,
746 | 8,
747 | 9,
748 | 10,
749 | 11,
750 | ], # it is better to put only special tokens here, because the tokenizer only removes special tokens
751 | )
752 | )
753 |
754 | # ChangGPT default template
755 | register_conv_template(
756 | Conversation(
757 | name="polyglot_changgpt",
758 | system="",
759 | roles=("B", "A"),
760 | messages=(),
761 | offset=0,
762 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
763 | sep="\n",
764 | )
765 | )
766 |
767 | # tigerbot template
768 | register_conv_template(
769 | Conversation(
770 | name="tigerbot",
771 | system="A chat between a curious user and an artificial intelligence assistant. "
772 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
773 | roles=("### Instruction", "### Response"),
774 | messages=(),
775 | offset=0,
776 | sep_style=SeparatorStyle.ROBIN,
777 | sep="\n\n",
778 | stop_str="###",
779 | )
780 | )
781 |
782 | # ref: https://huggingface.co/Salesforce/xgen-7b-8k-inst
783 | register_conv_template(
784 | Conversation(
785 | name="xgen",
786 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
787 | roles=("### Human: ", "###"),
788 | messages=(),
789 | offset=0,
790 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
791 | sep="\n",
792 | stop_token_ids=[50256, 0, 1, 2],
793 | stop_str="<|endoftext|>",
794 | )
795 | )
796 |
797 | # Internlm-chat template
798 | register_conv_template(
799 | Conversation(
800 | name="internlm-chat",
801 | system="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
802 | roles=("<|User|>", "<|Bot|>"),
803 | messages=(),
804 | offset=0,
805 | sep_style=SeparatorStyle.CHATINTERN,
806 | sep="<eoh>",
807 | sep2="<eoa>",
808 | stop_token_ids=[1, 103028],
809 | stop_str="<|User|>",
810 | )
811 | )
812 |
813 | # StarChat template
814 | register_conv_template(
815 | Conversation(
816 | name="starchat",
817 | system="\n",
818 | roles=("<|user|>", "<|assistant|>"),
819 | messages=(),
820 | offset=0,
821 | sep_style=SeparatorStyle.CHATML,
822 | sep="<|end|>",
823 | stop_token_ids=[0, 49155],
824 | stop_str="<|end|>",
825 | )
826 | )
827 |
828 | # Baichuan-13B-Chat template
829 | register_conv_template(
830 | # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
831 | # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
832 | Conversation(
833 | name="baichuan-chat",
834 | system="",
835 | roles=(" <reserved_102> ", " <reserved_103> "),
836 | messages=(),
837 | offset=0,
838 | sep_style=SeparatorStyle.NO_COLON_TWO,
839 | sep="",
840 | sep2="</s>",
841 | stop_token_ids=[2, 195],
842 | )
843 | )
844 |
845 | # llama2 template
846 | # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
847 | register_conv_template(
848 | Conversation(
849 | name="llama-2",
850 | system="[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
851 | "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
852 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
853 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
854 | "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
855 | roles=("[INST]", "[/INST]"),
856 | messages=(),
857 | offset=0,
858 | sep_style=SeparatorStyle.LLAMA2,
859 | sep=" ",
860 | sep2=" </s><s>",
861 | stop_token_ids=[2],
862 | )
863 | )
864 |
865 | # Zephyr template
866 | # reference: https://huggingface.co/spaces/HuggingFaceH4/zephyr-playground/blob/main/dialogues.py
867 | # register_conv_template(
868 | # Conversation(
869 | # name="zephyr",
870 | # system_template="<|system|>\n{system_message}",
871 | # roles=("<|user|>", "<|assistant|>"),
872 | # sep_style=SeparatorStyle.CHATML,
873 | # sep="</s>",
874 | # stop_token_ids=[2],
875 | # stop_str="</s>",
876 | # )
877 | # )
878 |
879 | if __name__ == "__main__":
880 | conv = get_conv_template("vicuna_v1.1")
881 | conv.append_message(conv.roles[0], "Hello!")
882 | conv.append_message(conv.roles[1], "Hi!")
883 | conv.append_message(conv.roles[0], "How are you?")
884 | conv.append_message(conv.roles[1], None)
885 | print(conv.get_prompt())
886 |
--------------------------------------------------------------------------------
/deita_dataset/train.py:
--------------------------------------------------------------------------------
1 | # This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
2 | #
3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from dataclasses import dataclass, field
18 | import json
19 | import pathlib
20 | from typing import Dict, Optional
21 |
22 | import numpy as np
23 | import torch
24 | from torch.utils.data import Dataset
25 | import transformers
26 | from transformers import Trainer
27 | from transformers.trainer_pt_utils import LabelSmoother
28 | from datasets import load_dataset
29 |
30 | from deita_dataset.conversation import SeparatorStyle, get_conv_template
34 |
35 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
36 |
37 | @dataclass
38 | class ModelArguments:
39 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
40 | flash_attn: bool = False
41 |
42 |
43 | @dataclass
44 | class DataArguments:
45 | data_path: str = field(
46 | default=None, metadata={"help": "Path to the training data."}
47 | )
48 | lazy_preprocess: bool = False
49 | conv_template: str = field(default = "vicuna-1.1")
50 |
51 |
52 | @dataclass
53 | class TrainingArguments(transformers.TrainingArguments):
54 | cache_dir: Optional[str] = field(default=None)
55 | optim: str = field(default="adamw_torch")
56 | model_max_length: int = field(
57 | default=512,
58 | metadata={
59 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
60 | },
61 | )
62 | min_lr: float = field(
63 | default = None
64 | )
65 | mask_user: bool = field(
66 | default = True
67 | )
68 |
69 |
70 | local_rank = None
71 |
72 |
73 | def rank0_print(*args):
74 | if local_rank == 0:
75 | print(*args)
76 |
77 |
78 |
79 | def preprocess(
80 | sources,
81 | tokenizer: transformers.PreTrainedTokenizer,
82 | conv_template = "vicuna-1.1",
83 | mask_user = True
84 | ) -> Dict:
85 |
86 | conv = get_conv_template(conv_template)
87 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
88 |
89 | # Apply prompt templates
90 | conversations = []
91 | for i, source in enumerate(sources):
92 | if roles[source[0]["from"]] != conv.roles[0]:
93 | # Skip the first one if it is not from human
94 | source = source[1:]
95 |
96 | conv.messages = []
97 | for j, sentence in enumerate(source):
98 | role = roles[sentence["from"]]
100 | assert role == conv.roles[j % 2], f"conversation {i}: expected role {conv.roles[j % 2]}, got {role}"
101 | conv.append_message(role, sentence["value"])
102 | conversations.append(conv.get_prompt())
103 |
104 | # Tokenize conversations
107 | input_ids = tokenizer(
108 | conversations,
109 | return_tensors="pt",
110 | padding="max_length",
111 | max_length=tokenizer.model_max_length,
112 | truncation=True,
113 | ).input_ids
114 |
134 | targets = input_ids.clone()
135 |
136 | assert conv.sep_style in (SeparatorStyle.ADD_COLON_TWO, SeparatorStyle.CHATML, SeparatorStyle.LLAMA2), f"unsupported separator style: {conv.sep_style}"
137 |
138 | if mask_user:
139 | # Mask targets. Only compute loss on the assistant outputs.
140 | if conv.sep_style == SeparatorStyle.ADD_COLON_TWO:
141 | sep = conv.sep + conv.roles[1] + ": "
142 | for conversation, target in zip(conversations, targets):
143 | total_len = int(target.ne(tokenizer.pad_token_id).sum())
144 |
145 | turns = conversation.split(conv.sep2)
146 | cur_len = 1
147 | target[:cur_len] = IGNORE_TOKEN_ID
148 | # breakpoint()
149 | for i, turn in enumerate(turns):
150 | if turn == "":
151 | break
152 | turn_len = len(tokenizer(turn).input_ids)
153 |
154 | parts = turn.split(sep)
155 | if len(parts) != 2:
156 | break
157 | parts[0] += sep
158 | # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
159 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2
160 |
161 | # Ignore the user instructions
162 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
163 | cur_len += turn_len
164 |
165 | target[cur_len:] = IGNORE_TOKEN_ID
166 |
167 | if False: # Inspect and check the correctness of masking
168 | # if True: # Inspect and check the correctness of masking
169 | z = target.clone()
170 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
171 | print(tokenizer.decode(z))
172 | import pdb; pdb.set_trace()
173 | # rank0_print(tokenizer.decode(z))
174 |
175 | elif conv.sep_style == SeparatorStyle.LLAMA2:
176 | # sep = conv.sep + conv.roles[1] + " "
177 | sep = conv.roles[1] + " "
178 | for conversation, target in zip(conversations, targets):
179 | total_len = int(target.ne(tokenizer.pad_token_id).sum())
180 |
181 | turns = conversation.split(conv.sep2)
182 | cur_len = 1
183 | target[:cur_len] = IGNORE_TOKEN_ID
184 | # breakpoint()
185 | for i, turn in enumerate(turns):
186 | if turn == "":
187 | break
188 | turn_len = len(tokenizer(turn).input_ids)
189 |
190 | parts = turn.split(sep)
191 | if len(parts) != 2:
192 | break
193 | parts[0] += sep
194 | # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
196 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2
197 |
198 | # Ignore the user instructions
199 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
205 | cur_len += turn_len + 2 # hardcoded LLaMA-2 offset on top of the BOS counted in turn_len, presumably for the sep2 tokens that split() removed
206 |
207 | target[cur_len:] = IGNORE_TOKEN_ID
208 |
209 | if False: # Inspect and check the correctness of masking
210 | # if True: # Inspect and check the correctness of masking
211 | z = target.clone()
212 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
213 | print(tokenizer.decode(z))
214 | import pdb; pdb.set_trace()
215 | # rank0_print(tokenizer.decode(z))
216 |
217 | elif conv.sep_style == SeparatorStyle.CHATML:
219 | sep = conv.sep + conv.roles[1] + "\n"
220 | for conversation, target in zip(conversations, targets):
221 | total_len = int(target.ne(tokenizer.pad_token_id).sum())
222 |
223 | turns = conversation.split(conv.sep)
224 | cur_len = 1
225 | target[:cur_len] = IGNORE_TOKEN_ID
226 | # breakpoint()
227 | for i, turn in enumerate(turns):
228 | if turn == "":
229 | break
230 | turn_len = len(tokenizer(turn).input_ids)
231 |
232 | parts = turn.split(sep)
233 | if len(parts) != 2:
234 | break
235 | parts[0] += sep
236 | # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
237 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2
238 |
239 | # Ignore the user instructions
240 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
241 | cur_len += turn_len
242 |
243 | target[cur_len:] = IGNORE_TOKEN_ID
244 |
245 | if cur_len < tokenizer.model_max_length:
246 | if cur_len != total_len:
247 | target[:] = IGNORE_TOKEN_ID
248 | rank0_print(
249 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
250 | f" (ignored)"
251 | )
252 |
253 | return dict(
254 | input_ids=input_ids,
255 | labels=targets,
256 | attention_mask=(input_ids.ne(tokenizer.pad_token_id)).to(torch.int8),
257 | )
258 |
259 |
260 | class SupervisedDataset(Dataset):
261 | """Dataset for supervised fine-tuning."""
262 |
263 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, conv_template = "vicuna-1.1", mask_user = True):
264 | super(SupervisedDataset, self).__init__()
265 |
266 | rank0_print("Formatting inputs...")
267 | sources = [example["conversations"] for example in raw_data]
268 | data_dict = preprocess(sources, tokenizer, conv_template, mask_user)
269 |
270 | if mask_user:
271 | rank0_print(
272 | "WARNING: The loss on user prompts will be masked"
273 | )
274 | else:
275 | rank0_print(
276 | "WARNING: The loss on user prompts will **NOT** be masked"
277 | )
278 |
279 |
280 | self.input_ids = data_dict["input_ids"]
281 | self.labels = data_dict["labels"]
282 | self.attention_mask = data_dict["attention_mask"]
283 |
284 | def __len__(self):
285 | return len(self.input_ids)
286 |
287 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
288 | return dict(
289 | input_ids=self.input_ids[i],
290 | labels=self.labels[i],
291 | attention_mask=self.attention_mask[i],
292 | )
293 |
294 |
295 | class LazySupervisedDataset(Dataset):
296 | """Dataset for supervised fine-tuning."""
297 |
298 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, conv_template = "vicuna-1.1", mask_user = True):
299 | super(LazySupervisedDataset, self).__init__()
300 | self.tokenizer = tokenizer
301 |
302 | rank0_print("Formatting inputs...Skip in lazy mode")
303 | self.conv_template = conv_template
304 | self.mask_user = mask_user
306 | self.raw_data = raw_data
307 | self.cached_data_dict = {}
308 |
309 | if mask_user:
310 | rank0_print(
311 | "WARNING: The loss on user prompts will be masked"
312 | )
313 | else:
314 | rank0_print(
315 | "WARNING: The loss on user prompts will **NOT** be masked"
316 | )
317 |
318 | def __len__(self):
319 | return len(self.raw_data)
320 |
321 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
322 | if i in self.cached_data_dict:
323 | return self.cached_data_dict[i]
324 |
325 | ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.conv_template, self.mask_user)
326 | ret = dict(
327 | input_ids=ret["input_ids"][0],
328 | labels=ret["labels"][0],
329 | attention_mask=ret["attention_mask"][0],
330 | )
331 | self.cached_data_dict[i] = ret
332 | return ret
333 |
334 |
335 | def make_supervised_data_module(
336 | tokenizer: transformers.PreTrainedTokenizer, data_args, mask_user = True
337 | ) -> Dict:
338 | """Make dataset and collator for supervised fine-tuning."""
339 | conv_template = data_args.conv_template
340 | dataset_cls = (
341 | LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
342 | )
343 | rank0_print("Loading data...")
344 | try:
345 | raw_data = json.load(open(data_args.data_path, "r"))
346 | except FileNotFoundError:
347 | raw_data = load_dataset(data_args.data_path, split = "train")
348 | raw_data = [row for row in raw_data]
349 |
350 | # Split train/eval
351 | np.random.seed(0)
352 | train_raw_data = raw_data
353 | perm = np.random.permutation(len(raw_data))
354 | split = int(len(perm) * 0.98)
355 | train_indices = perm[:split]
356 | eval_indices = perm[split:]
357 | train_raw_data = [raw_data[i] for i in train_indices]
358 | eval_raw_data = [raw_data[i] for i in eval_indices]
359 | rank0_print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}")
360 |
361 | train_dataset = dataset_cls(train_raw_data, tokenizer=tokenizer, conv_template = conv_template, mask_user = mask_user)
362 | eval_dataset = dataset_cls(eval_raw_data, tokenizer=tokenizer, conv_template = conv_template, mask_user = mask_user)
363 | return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
364 |
365 | def train():
366 | global local_rank
367 |
368 | parser = transformers.HfArgumentParser(
369 | (ModelArguments, DataArguments, TrainingArguments)
370 | )
371 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
372 | training_args.do_eval = False
373 | local_rank = training_args.local_rank
374 | model = transformers.AutoModelForCausalLM.from_pretrained(
375 | model_args.model_name_or_path,
376 | cache_dir=training_args.cache_dir,
377 | use_flash_attention_2 = True
378 | )
379 | model.config.use_cache = False
380 | tokenizer = transformers.AutoTokenizer.from_pretrained(
381 | model_args.model_name_or_path,
382 | cache_dir=training_args.cache_dir,
383 | model_max_length=training_args.model_max_length,
384 | padding_side="right",
385 | use_fast=False,
386 | )
387 | tokenizer.pad_token = tokenizer.unk_token
388 |
389 | if "mistral" in model_args.model_name_or_path.lower():
390 | rank0_print("Mistral with Left Padding Side")
391 | tokenizer.padding_side = "left"
392 |
393 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, mask_user = training_args.mask_user)
394 |
395 | trainer = Trainer(
396 | model=model, tokenizer=tokenizer, args=training_args, **data_module
397 | )
398 |
399 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
400 | trainer.train(resume_from_checkpoint=True)
401 | else:
402 | trainer.train()
403 | trainer.save_state()
404 |
405 | trainer.save_model(output_dir = training_args.output_dir)
406 |
407 |
408 | if __name__ == "__main__":
409 | train()
410 |
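`preprocess()` builds `labels` by copying `input_ids` and overwriting every user-side token with `IGNORE_TOKEN_ID` (-100, the value of `LabelSmoother.ignore_index`), so the loss is computed only on assistant replies. The sketch below reproduces the idea of the `ADD_COLON_TWO` branch with a hypothetical character-level tokenizer (`char_tokenize`). Because that toy tokenization is additive, the sketch can drop the LLaMA-specific `- 2` / `+ 2` offsets and instead count the `sep2` characters consumed by `str.split()` explicitly. It illustrates the masking logic under that assumption and is not a drop-in replacement for the real tokenizer path.

```python
from typing import List, Tuple

IGNORE_TOKEN_ID = -100  # same value as transformers' LabelSmoother.ignore_index


def char_tokenize(text: str, add_bos: bool = True) -> List[int]:
    # Hypothetical tokenizer: one id per character (its code point), BOS id = 1.
    return ([1] if add_bos else []) + [ord(c) for c in text]


def mask_user_turns(
    conversation: str, sep: str, sep2: str, assistant_role: str
) -> Tuple[List[int], List[int]]:
    input_ids = char_tokenize(conversation)
    labels = list(input_ids)
    assistant_prefix = sep + assistant_role + ": "

    cur_len = 1
    labels[:cur_len] = [IGNORE_TOKEN_ID] * cur_len  # BOS is never a target
    for turn in conversation.split(sep2):
        if turn == "":
            break
        parts = turn.split(assistant_prefix)
        if len(parts) != 2:
            break
        # Everything up to and including "<sep>ASSISTANT: " is on the user side.
        instruction_len = len(char_tokenize(parts[0] + assistant_prefix, add_bos=False))
        turn_len = len(char_tokenize(turn, add_bos=False))
        labels[cur_len : cur_len + instruction_len] = [IGNORE_TOKEN_ID] * instruction_len
        # Skip the turn plus the sep2 removed by split(); sep2 itself stays
        # unmasked, so the model learns to emit the end-of-turn marker.
        cur_len += turn_len + len(sep2)

    labels[cur_len:] = [IGNORE_TOKEN_ID] * max(len(labels) - cur_len, 0)
    if cur_len != len(input_ids):
        # Mirrors the mismatch check in preprocess(): drop the whole sample.
        labels = [IGNORE_TOKEN_ID] * len(labels)
    return input_ids, labels


conversation = "A chat. USER: Hello! ASSISTANT: Hi!</s>USER: How are you? ASSISTANT: Fine.</s>"
ids, labels = mask_user_turns(conversation, sep=" ", sep2="</s>", assistant_role="ASSISTANT")
kept = "".join(chr(t) for t in labels if t != IGNORE_TOKEN_ID)
print(kept)  # Hi!</s>Fine.</s>
```

Only the assistant replies and their closing `</s>` survive as targets; the system prompt, the user turns, and the `ASSISTANT: ` prefixes are all masked.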
--------------------------------------------------------------------------------
/examples/block_ap/Llama-2-7b/w2g128.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \
2 | --model path/to/Llama-2-7b \
3 | --output_dir ./output/block_ap_log/Llama-2-7b-w2g128 \
4 | --net Llama-2 \
5 | --wbits 2 \
6 | --group_size 128 \
7 | --quant_lr 1e-4 \
8 | --weight_lr 2e-5 \
9 | --real_quant \
10 | --eval_ppl \
11 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
12 | --save_quant_dir ./output/block_ap_models/Llama-2-7b-w2g128
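`--wbits` sets the integer bit width and `--group_size` the number of consecutive weights that share one scale and zero point, so `w2g128` means 2-bit weights in groups of 128. The sketch below shows generic asymmetric min-max group-wise fake quantization in pure Python, only to illustrate those two flags. It is an assumption-laden stand-in, not the Block-AP procedure in `quantize/block_ap.py`, which (judging by the separate `--quant_lr` and `--weight_lr` flags) learns the quantization parameters and the weights rather than fixing them by a min-max rule.

```python
from typing import List, Tuple


def quantize_group(w: List[float], wbits: int) -> Tuple[List[int], float, int]:
    """Asymmetric min-max quantization of one group to integers in [0, 2**wbits - 1]."""
    qmax = 2 ** wbits - 1
    lo, hi = min(w), max(w)
    scale = max(hi - lo, 1e-8) / qmax  # guard against a constant group
    zero = round(-lo / scale)
    q = [min(max(round(x / scale) + zero, 0), qmax) for x in w]
    return q, scale, zero


def dequantize_group(q: List[int], scale: float, zero: int) -> List[float]:
    return [(v - zero) * scale for v in q]


def fake_quantize(weights: List[float], wbits: int, group_size: int) -> List[float]:
    """Quantize and dequantize a flat weight row, one group of `group_size` at a time."""
    if len(weights) % group_size != 0:
        raise ValueError("row length must be a multiple of group_size")
    out: List[float] = []
    for start in range(0, len(weights), group_size):
        q, scale, zero = quantize_group(weights[start : start + group_size], wbits)
        out.extend(dequantize_group(q, scale, zero))
    return out


row = [((i * 37) % 101) / 100 - 0.5 for i in range(256)]  # deterministic toy weights
deq = fake_quantize(row, wbits=2, group_size=128)
# At 2 bits each group can take at most 4 distinct values.
print(len(set(deq[:128])), len(set(deq[128:])))
```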
--------------------------------------------------------------------------------
/examples/block_ap/Llama-2-7b/w2g64.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \
2 | --model path/to/Llama-2-7b \
3 | --output_dir ./output/block_ap_log/Llama-2-7b-w2g64 \
4 | --net Llama-2 \
5 | --wbits 2 \
6 | --group_size 64 \
7 | --quant_lr 1e-4 \
8 | --weight_lr 2e-5 \
9 | --real_quant \
10 | --eval_ppl \
11 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
12 | --save_quant_dir ./output/block_ap_models/Llama-2-7b-w2g64
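A smaller group size tracks the weights more closely but stores more quantization parameters. Assuming, purely for illustration, one 16-bit scale and one `wbits`-bit zero point per group (the real packed layout is presumably defined in `quantize/int_linear_real.py`), the average storage per weight for the Block-AP example configurations works out as follows.

```python
from fractions import Fraction
from typing import Optional


def avg_bits_per_weight(
    wbits: int, group_size: int, scale_bits: int = 16, zero_bits: Optional[int] = None
) -> Fraction:
    """Weight bits plus each weight's share of its group's scale and zero point."""
    if zero_bits is None:
        zero_bits = wbits  # assumption: zero point packed at the weight bit width
    return wbits + Fraction(scale_bits + zero_bits, group_size)


for wbits, group_size in [(2, 64), (2, 128), (3, 128), (4, 128)]:
    print(f"w{wbits}g{group_size}: {float(avg_bits_per_weight(wbits, group_size))} bits/weight")
# w2g64: 2.28125 bits/weight
# w2g128: 2.140625 bits/weight
# w3g128: 3.1484375 bits/weight
# w4g128: 4.15625 bits/weight
```

Under these assumptions, halving the group size from 128 to 64 at 2 bits doubles the metadata overhead, from about 0.14 to about 0.28 bits per weight.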
--------------------------------------------------------------------------------
/examples/block_ap/Llama-2-7b/w3g128.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \
2 | --model path/to/Llama-2-7b \
3 | --output_dir ./output/block_ap_log/Llama-2-7b-w3g128 \
4 | --net Llama-2 \
5 | --wbits 3 \
6 | --group_size 128 \
7 | --quant_lr 1e-4 \
8 | --weight_lr 1e-5 \
9 | --real_quant \
10 | --eval_ppl \
11 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
12 | --save_quant_dir ./output/block_ap_models/Llama-2-7b-w3g128
--------------------------------------------------------------------------------
/examples/block_ap/Llama-2-7b/w4g128.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \
2 | --model path/to/Llama-2-7b \
3 | --output_dir ./output/block_ap_log/Llama-2-7b-w4g128 \
4 | --net Llama-2 \
5 | --wbits 4 \
6 | --group_size 128 \
7 | --quant_lr 1e-4 \
8 | --weight_lr 1e-5 \
9 | --real_quant \
10 | --eval_ppl \
11 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
12 | --save_quant_dir ./output/block_ap_models/Llama-2-7b-w4g128
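The four Llama-2-7b Block-AP scripts above differ only in `--wbits`, `--group_size`, `--weight_lr` (2e-5 at 2 bits, 1e-5 at 3 and 4 bits) and the run name. A hypothetical helper that regenerates them could look like this. `block_ap_command` and the `WEIGHT_LR` table are illustrative; reading the learning rates as a rule of bit width is an assumption drawn only from these four files.

```python
import shlex
from typing import Dict, List

# Copied from the four scripts above; treating them as a rule of bit width is an assumption.
WEIGHT_LR: Dict[int, str] = {2: "2e-5", 3: "1e-5", 4: "1e-5"}
EVAL_TASKS = "piqa,arc_easy,arc_challenge,hellaswag,winogrande"


def block_ap_command(
    wbits: int, group_size: int, model_path: str = "path/to/Llama-2-7b", gpu: int = 0
) -> str:
    """Hypothetical helper that rebuilds one Llama-2-7b Block-AP example command."""
    run = f"Llama-2-7b-w{wbits}g{group_size}"
    args: List[str] = [
        "python", "main_block_ap.py",
        "--model", model_path,
        "--output_dir", f"./output/block_ap_log/{run}",
        "--net", "Llama-2",
        "--wbits", str(wbits),
        "--group_size", str(group_size),
        "--quant_lr", "1e-4",
        "--weight_lr", WEIGHT_LR[wbits],
        "--real_quant",
        "--eval_ppl",
        "--eval_tasks", EVAL_TASKS,
        "--save_quant_dir", f"./output/block_ap_models/{run}",
    ]
    # shlex.join quotes any argument that needs it (e.g. a model path with spaces).
    return f"CUDA_VISIBLE_DEVICES={gpu} " + shlex.join(args)


print(block_ap_command(2, 64))
```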
--------------------------------------------------------------------------------
/examples/block_ap/Mistral-Large-Instruct/w2g64.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \
2 | --model path/to/Mistral-Large-Instruct-2407 \
3 | --output_dir ./output/block_ap_log/Mistral-Large-Instruct-2407-w2g64 \
4 | --net mistral-large \
5 | --wbits 2 \
6 | --group_size 64 \
7 | --quant_lr 3e-5 \
8 | --weight_lr 2e-6 \
9 | --train_size 2048 \
10 | --epochs 3 \
11 | --eval_ppl \
12 | --real_quant \
13 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
14 | --save_quant_dir ./output/block_ap_models/Mistral-Large-Instruct-2407-w2g64
--------------------------------------------------------------------------------
/examples/e2e_qp/Llama-2-7b/w2g128-alpaca.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_e2e_qp.py \
2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w2g128 \
3 | --model_family Llama-2 \
4 | --wbits 2 \
5 | --group_size 128 \
6 | --learning_rate 2e-5 \
7 | --dataset alpaca \
8 | --dataset_format alpaca \
9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w2g128-alpaca-4096 \
10 | --do_train True \
11 | --do_mmlu_eval True \
12 | --source_max_len 384 \
13 | --target_max_len 128 \
14 | --per_device_train_batch_size 16 \
15 | --per_device_eval_batch_size 4 \
16 | --gradient_accumulation_steps 1 \
17 | --logging_steps 10 \
18 | --save_strategy steps \
19 | --evaluation_strategy steps \
20 | --max_steps 10000 \
21 | --eval_steps 2000 \
22 | --eval_dataset_size 16 \
23 | --bf16 \
24 | --data_seed 42 \
25 | --max_grad_norm 0.3 \
26 | --group_by_length
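The Alpaca recipe's training budget follows from its flags: with one visible GPU the effective batch is `per_device_train_batch_size × gradient_accumulation_steps × num_gpus`, and `--max_steps` fixes the number of optimizer steps. The arithmetic below makes that explicit. The Alpaca size of 52,002 examples is an assumption (the original Stanford Alpaca release), since the script does not say which variant `--dataset alpaca` loads.

```python
def schedule(per_device_bs: int, grad_accum: int, num_gpus: int, max_steps: int, eval_steps: int) -> dict:
    effective_batch = per_device_bs * grad_accum * num_gpus
    return {
        "effective_batch": effective_batch,
        "examples_seen": effective_batch * max_steps,
        "num_evals": max_steps // eval_steps,  # evaluations at steps 2000, 4000, ...
    }


alpaca = schedule(per_device_bs=16, grad_accum=1, num_gpus=1, max_steps=10000, eval_steps=2000)
ALPACA_SIZE = 52002  # assumption: size of the original Stanford Alpaca release
print(alpaca)  # {'effective_batch': 16, 'examples_seen': 160000, 'num_evals': 5}
print(round(alpaca["examples_seen"] / ALPACA_SIZE, 2))  # 3.08 passes over the data
```

Under that assumption, 10,000 steps amount to roughly three epochs over Alpaca.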
--------------------------------------------------------------------------------
/examples/e2e_qp/Llama-2-7b/w2g128-redpajama.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \
2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w2g128 \
3 | --model_family Llama-2 \
4 | --wbits 2 \
5 | --group_size 128 \
6 | --learning_rate 2e-5 \
7 | --dataset redpajama \
8 | --dataset_format pt \
9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w2g128-redpajama-4096 \
10 | --do_train True \
11 | --pt_context_len 4096 \
12 | --per_device_train_batch_size 4 \
13 | --per_device_eval_batch_size 4 \
14 | --gradient_accumulation_steps 8 \
15 | --logging_steps 1 \
16 | --save_strategy epoch \
17 | --training_strategy epochs \
18 | --evaluation_strategy steps \
19 | --eval_steps 64 \
20 | --max_train_samples 4096 \
21 | --num_train_epochs 1 \
22 | --eval_dataset_size 64 \
23 | --bf16 \
24 | --data_seed 42 \
25 | --max_grad_norm 0.3 \
26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
27 | --preprocessing_num_workers 32 \
28 | --do_ppl_eval
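For the RedPajama recipe, assume that `--max_train_samples 4096` counts packed sequences of `--pt_context_len 4096` tokens; the script itself does not spell this out. Under that assumption, one epoch covers about 16.8M tokens in 128 optimizer steps, so `--eval_steps 64` evaluates twice.

```python
samples, context_len = 4096, 4096
per_device_bs, grad_accum, num_gpus = 4, 8, 1
eval_steps = 64

effective_batch = per_device_bs * grad_accum * num_gpus  # sequences per optimizer step
tokens_per_step = effective_batch * context_len
steps_per_epoch = samples // effective_batch
total_tokens = samples * context_len
eval_points = list(range(eval_steps, steps_per_epoch + 1, eval_steps))
print(effective_batch, tokens_per_step, steps_per_epoch, total_tokens, eval_points)
# 32 131072 128 16777216 [64, 128]
```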
--------------------------------------------------------------------------------
/examples/e2e_qp/Llama-2-7b/w2g64-alpaca.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_e2e_qp.py \
2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w2g64 \
3 | --model_family Llama-2 \
4 | --wbits 2 \
5 | --group_size 64 \
6 | --learning_rate 2e-5 \
7 | --dataset alpaca \
8 | --dataset_format alpaca \
9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w2g64-alpaca-4096 \
10 | --do_train True \
11 | --do_mmlu_eval True \
12 | --source_max_len 384 \
13 | --target_max_len 128 \
14 | --per_device_train_batch_size 16 \
15 | --per_device_eval_batch_size 4 \
16 | --gradient_accumulation_steps 1 \
17 | --logging_steps 10 \
18 | --save_strategy steps \
19 | --evaluation_strategy steps \
20 | --max_steps 10000 \
21 | --eval_steps 2000 \
22 | --eval_dataset_size 16 \
23 | --bf16 \
24 | --data_seed 42 \
25 | --max_grad_norm 0.3 \
26 | --group_by_length
--------------------------------------------------------------------------------
/examples/e2e_qp/Llama-2-7b/w2g64-redpajama.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_e2e_qp.py \
2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w2g64 \
3 | --model_family Llama-2 \
4 | --wbits 2 \
5 | --group_size 64 \
6 | --learning_rate 2e-5 \
7 | --dataset redpajama \
8 | --dataset_format pt \
9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w2g64-redpajama-4096 \
10 | --do_train True \
11 | --pt_context_len 4096 \
12 | --per_device_train_batch_size 4 \
13 | --per_device_eval_batch_size 4 \
14 | --gradient_accumulation_steps 8 \
15 | --logging_steps 1 \
16 | --save_strategy epoch \
17 | --training_strategy epochs \
18 | --evaluation_strategy steps \
19 | --eval_steps 64 \
20 | --max_train_samples 4096 \
21 | --num_train_epochs 1 \
22 | --eval_dataset_size 64 \
23 | --bf16 \
24 | --data_seed 42 \
25 | --max_grad_norm 0.3 \
26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
27 | --preprocessing_num_workers 32 \
28 | --do_ppl_eval
--------------------------------------------------------------------------------
/examples/e2e_qp/Llama-2-7b/w3g128-alpaca.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_e2e_qp.py \
2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w3g128 \
3 | --model_family Llama-2 \
4 | --wbits 3 \
5 | --group_size 128 \
6 | --learning_rate 1e-5 \
7 | --dataset alpaca \
8 | --dataset_format alpaca \
9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w3g128-alpaca-4096 \
10 | --do_train True \
11 | --do_mmlu_eval True \
12 | --source_max_len 384 \
13 | --target_max_len 128 \
14 | --per_device_train_batch_size 16 \
15 | --per_device_eval_batch_size 4 \
16 | --gradient_accumulation_steps 1 \
17 | --logging_steps 10 \
18 | --save_strategy steps \
19 | --evaluation_strategy steps \
20 | --max_steps 10000 \
21 | --eval_steps 2000 \
22 | --eval_dataset_size 16 \
23 | --bf16 \
24 | --data_seed 42 \
25 | --max_grad_norm 0.3 \
26 | --group_by_length
--------------------------------------------------------------------------------
/examples/e2e_qp/Llama-2-7b/w3g128-redpajama.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \
2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w3g128 \
3 | --model_family Llama-2 \
4 | --wbits 3 \
5 | --group_size 128 \
6 | --learning_rate 1e-5 \
7 | --dataset redpajama \
8 | --dataset_format pt \
9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w3g128-redpajama-4096 \
10 | --do_train True \
11 | --pt_context_len 4096 \
12 | --per_device_train_batch_size 4 \
13 | --per_device_eval_batch_size 4 \
14 | --gradient_accumulation_steps 8 \
15 | --logging_steps 1 \
16 | --save_strategy epoch \
17 | --training_strategy epochs \
18 | --evaluation_strategy steps \
19 | --eval_steps 64 \
20 | --max_train_samples 4096 \
21 | --num_train_epochs 1 \
22 | --eval_dataset_size 64 \
23 | --bf16 \
24 | --data_seed 42 \
25 | --max_grad_norm 0.3 \
26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
27 | --preprocessing_num_workers 32 \
28 | --do_ppl_eval
--------------------------------------------------------------------------------
/examples/e2e_qp/Llama-2-7b/w4g128-alpaca.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_e2e_qp.py \
2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w4g128 \
3 | --model_family Llama-2 \
4 | --wbits 4 \
5 | --group_size 128 \
6 | --learning_rate 1e-5 \
7 | --dataset alpaca \
8 | --dataset_format alpaca \
9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w4g128-alpaca-4096 \
10 | --do_train True \
11 | --do_mmlu_eval True \
12 | --source_max_len 384 \
13 | --target_max_len 128 \
14 | --per_device_train_batch_size 16 \
15 | --per_device_eval_batch_size 4 \
16 | --gradient_accumulation_steps 1 \
17 | --logging_steps 10 \
18 | --save_strategy steps \
19 | --evaluation_strategy steps \
20 | --max_steps 10000 \
21 | --eval_steps 2000 \
22 | --eval_dataset_size 16 \
23 | --bf16 \
24 | --data_seed 42 \
25 | --max_grad_norm 0.3 \
26 | --group_by_length
--------------------------------------------------------------------------------
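The alpaca scripts cap each prompt at `--source_max_len 384` tokens and each response at `--target_max_len 128`. `TrainingArguments.train_on_source` defaults to `False`. In QLoRA, which `main_e2e_qp.py` says it is modified from, that combination concatenates prompt and response and labels every prompt position with `IGNORE_INDEX = -100`, so only response tokens contribute to the loss. The sketch below shows that labeling on already-tokenized ids. `build_example` is a hypothetical helper. Whether `datautils_e2e.py` does exactly this is an assumption, because that file is not shown here.

```python
IGNORE_INDEX = -100  # same sentinel value as defined in main_e2e_qp.py


def build_example(source_ids, target_ids, source_max_len=384,
                  target_max_len=128, train_on_source=False):
    """Concatenate a prompt and a response and build loss labels (sketch)."""
    source = list(source_ids)[:source_max_len]   # truncate the prompt
    target = list(target_ids)[:target_max_len]   # truncate the response
    input_ids = source + target
    if train_on_source:
        labels = list(input_ids)                 # loss on every token
    else:
        # Loss on the response only; prompt positions are ignored.
        labels = [IGNORE_INDEX] * len(source) + target
    return input_ids, labels


ids, labels = build_example([1, 10, 11, 12], [20, 21, 2])
print(ids)     # [1, 10, 11, 12, 20, 21, 2]
print(labels)  # [-100, -100, -100, -100, 20, 21, 2]
```

Under these assumptions, no alpaca training sequence exceeds 384 + 128 = 512 tokens. Because `--per_device_train_batch_size 16` is paired with `--gradient_accumulation_steps 1`, `--max_steps 10000` processes 16 × 10,000 = 160,000 examples on a single GPU.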
/examples/e2e_qp/Llama-2-7b/w4g128-redpajama.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \
2 | --quant_model_path ./output/block_ap_models/Llama-2-7b-w4g128 \
3 | --model_family Llama-2 \
4 | --wbits 4 \
5 | --group_size 128 \
6 | --learning_rate 1e-5 \
7 | --dataset redpajama \
8 | --dataset_format pt \
9 | --output_dir ./output/e2e-qp-output/Llama-2-7b-w4g128-redpajama-4096 \
10 | --do_train True \
11 | --pt_context_len 4096 \
12 | --per_device_train_batch_size 4 \
13 | --per_device_eval_batch_size 4 \
14 | --gradient_accumulation_steps 8 \
15 | --logging_steps 1 \
16 | --save_strategy epoch \
17 | --training_strategy epochs \
18 | --evaluation_strategy steps \
19 | --eval_steps 64 \
20 | --max_train_samples 4096 \
21 | --num_train_epochs 1 \
22 | --eval_dataset_size 64 \
23 | --bf16 \
24 | --data_seed 42 \
25 | --max_grad_norm 0.3 \
26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
27 | --preprocessing_num_workers 32 \
28 | --do_ppl_eval
--------------------------------------------------------------------------------
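Each redpajama run exposes a single GPU (`CUDA_VISIBLE_DEVICES=2`) and combines `--per_device_train_batch_size 4` with `--gradient_accumulation_steps 8`. Each optimizer step should therefore consume 4 × 8 = 32 sequences of `--pt_context_len 4096` tokens. The sketch below estimates the resulting schedule and rests on three assumptions:

- `--max_train_samples` counts already-chunked sequences.
- The Hugging Face Trainer performs one update per `gradient_accumulation_steps` micro-batches.
- Warmup uses `ceil(warmup_ratio * total_steps)`, with the default `warmup_ratio=0.03` from `TrainingArguments` in `main_e2e_qp.py`.

`e2e_schedule` is a hypothetical helper.

```python
import math


def e2e_schedule(num_samples, per_device_batch, grad_accum, num_gpus=1,
                 epochs=1, warmup_ratio=0.03, eval_steps=None):
    """Estimate the optimizer schedule of an epoch-based e2e_qp run (sketch)."""
    # Micro-batches per epoch on each device (a final partial batch is kept).
    micro_batches = math.ceil(num_samples / (per_device_batch * num_gpus))
    # One optimizer update per `grad_accum` micro-batches.
    steps_per_epoch = max(micro_batches // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    schedule = {
        "effective_batch": per_device_batch * grad_accum * num_gpus,
        "total_steps": total_steps,
        "warmup_steps": math.ceil(total_steps * warmup_ratio),
    }
    if eval_steps:
        # Optimizer steps at which an evaluation would be triggered.
        schedule["eval_at"] = list(range(eval_steps, total_steps + 1, eval_steps))
    return schedule


# Values taken from w3g128-redpajama.sh and w4g128-redpajama.sh.
print(e2e_schedule(4096, per_device_batch=4, grad_accum=8, eval_steps=64))
# {'effective_batch': 32, 'total_steps': 128, 'warmup_steps': 4, 'eval_at': [64, 128]}
```

Under the same assumptions, `--max_train_samples 8192` (used by the Llama-3 deita scripts) gives 256 optimizer steps, so `--eval_steps 64` would trigger four evaluations.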
/examples/e2e_qp/Llama-3-8b-instruct/w2g128-deita.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \
2 | --quant_model_path ./output/block_ap_models/Llama-3-8b-instruct-w2g128 \
3 | --model_family llama3 \
4 | --wbits 2 \
5 | --group_size 128 \
6 | --learning_rate 2e-5 \
7 | --dataset deita-10k \
8 | --dataset_format pt \
9 | --output_dir ./output/e2e-qp-output/Llama-3-8b-instruct-w2g128-deita \
10 | --do_train True \
11 | --pt_context_len 4096 \
12 | --per_device_train_batch_size 4 \
13 | --per_device_eval_batch_size 4 \
14 | --gradient_accumulation_steps 8 \
15 | --logging_steps 1 \
16 | --save_strategy epoch \
17 | --training_strategy epochs \
18 | --evaluation_strategy steps \
19 | --eval_steps 64 \
20 | --max_train_samples 8192 \
21 | --num_train_epochs 1 \
22 | --eval_dataset_size 64 \
23 | --bf16 \
24 | --data_seed 42 \
25 | --max_grad_norm 0.3 \
26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
27 | --preprocessing_num_workers 32 \
28 | --do_ppl_eval
--------------------------------------------------------------------------------
/examples/e2e_qp/Llama-3-8b-instruct/w2g64-deita.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \
2 | --quant_model_path ./output/block_ap_models/Llama-3-8b-instruct-w2g64 \
3 | --model_family llama3 \
4 | --wbits 2 \
5 | --group_size 64 \
6 | --learning_rate 2e-5 \
7 | --dataset deita-10k \
8 | --dataset_format pt \
9 | --output_dir ./output/e2e-qp-output/Llama-3-8b-instruct-w2g64-deita \
10 | --do_train True \
11 | --pt_context_len 4096 \
12 | --per_device_train_batch_size 4 \
13 | --per_device_eval_batch_size 4 \
14 | --gradient_accumulation_steps 8 \
15 | --logging_steps 1 \
16 | --save_strategy epoch \
17 | --training_strategy epochs \
18 | --evaluation_strategy steps \
19 | --eval_steps 64 \
20 | --max_train_samples 8192 \
21 | --num_train_epochs 1 \
22 | --eval_dataset_size 64 \
23 | --bf16 \
24 | --data_seed 42 \
25 | --max_grad_norm 0.3 \
26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
27 | --preprocessing_num_workers 32 \
28 | --do_ppl_eval
--------------------------------------------------------------------------------
/examples/e2e_qp/Llama-3-8b-instruct/w3g128-deita.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \
2 | --quant_model_path ./output/block_ap_models/Llama-3-8b-instruct-w3g128 \
3 | --model_family llama3 \
4 | --wbits 3 \
5 | --group_size 128 \
6 | --learning_rate 1e-5 \
7 | --dataset deita-10k \
8 | --dataset_format pt \
9 | --output_dir ./output/e2e-qp-output/Llama-3-8b-instruct-w3g128-deita \
10 | --do_train True \
11 | --pt_context_len 4096 \
12 | --per_device_train_batch_size 4 \
13 | --per_device_eval_batch_size 4 \
14 | --gradient_accumulation_steps 8 \
15 | --logging_steps 1 \
16 | --save_strategy epoch \
17 | --training_strategy epochs \
18 | --evaluation_strategy steps \
19 | --eval_steps 64 \
20 | --max_train_samples 8192 \
21 | --num_train_epochs 1 \
22 | --eval_dataset_size 64 \
23 | --bf16 \
24 | --data_seed 42 \
25 | --max_grad_norm 0.3 \
26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
27 | --preprocessing_num_workers 32 \
28 | --do_ppl_eval
--------------------------------------------------------------------------------
/examples/e2e_qp/Llama-3-8b-instruct/w4g128-deita.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 python main_e2e_qp.py \
2 | --quant_model_path ./output/block_ap_models/Llama-3-8b-instruct-w4g128 \
3 | --model_family llama3 \
4 | --wbits 4 \
5 | --group_size 128 \
6 | --learning_rate 1e-5 \
7 | --dataset deita-10k \
8 | --dataset_format pt \
9 | --output_dir ./output/e2e-qp-output/Llama-3-8b-instruct-w4g128-deita \
10 | --do_train True \
11 | --pt_context_len 4096 \
12 | --per_device_train_batch_size 4 \
13 | --per_device_eval_batch_size 4 \
14 | --gradient_accumulation_steps 8 \
15 | --logging_steps 1 \
16 | --save_strategy epoch \
17 | --training_strategy epochs \
18 | --evaluation_strategy steps \
19 | --eval_steps 64 \
20 | --max_train_samples 8192 \
21 | --num_train_epochs 1 \
22 | --eval_dataset_size 64 \
23 | --bf16 \
24 | --data_seed 42 \
25 | --max_grad_norm 0.3 \
26 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande \
27 | --preprocessing_num_workers 32 \
28 | --do_ppl_eval
--------------------------------------------------------------------------------
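The four Llama-3-8b-instruct deita scripts differ only in `--wbits`, `--group_size`, `--learning_rate` and the two paths derived from them (`--quant_model_path` and `--output_dir`). Editing copies by hand therefore makes it easy for a path to drift away from the model it belongs to. The sketch below derives both paths from a single tag. `e2e_qp_args` is a hypothetical helper that assumes the `./output/...` layout used in these scripts. It emits only the model-specific and dataset flags, so the shared ones (`--pt_context_len 4096`, batch sizes, evaluation settings) would still need to be appended.

```python
import shlex


def e2e_qp_args(model, wbits, group_size, model_family="llama3",
                dataset="deita-10k", run_suffix="deita", max_train_samples=8192,
                learning_rate=None):
    """Build main_e2e_qp.py arguments whose paths derive from one tag (sketch)."""
    tag = f"{model}-w{wbits}g{group_size}"  # e.g. Llama-3-8b-instruct-w2g64
    if learning_rate is None:
        # Mirrors the scripts above: 2e-5 for 2-bit, 1e-5 for 3- and 4-bit.
        learning_rate = 2e-5 if wbits == 2 else 1e-5
    return [
        "python", "main_e2e_qp.py",
        "--quant_model_path", f"./output/block_ap_models/{tag}",
        "--model_family", model_family,
        "--wbits", str(wbits),
        "--group_size", str(group_size),
        "--learning_rate", str(learning_rate),
        "--dataset", dataset,
        "--dataset_format", "pt",
        "--output_dir", f"./output/e2e-qp-output/{tag}-{run_suffix}",
        "--max_train_samples", str(max_train_samples),
    ]


for wbits, group_size in [(2, 128), (2, 64), (3, 128), (4, 128)]:
    print(shlex.join(e2e_qp_args("Llama-3-8b-instruct", wbits, group_size)))
```

Printing the commands with `shlex.join` is a cheap way to diff generated commands against the checked-in scripts. Passing the list to `subprocess.run`, with `CUDA_VISIBLE_DEVICES` set in its `env`, would launch a run.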
/examples/inference/Llama-2-7b/fp16.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \
2 | --model path/to/Llama-2-7b \
3 | --net Llama-2 \
4 | --wbits 16 \
5 | --output_dir ./output/inference_results/ \
6 | --eval_ppl \
7 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande
8 |
--------------------------------------------------------------------------------
/examples/inference/Llama-2-7b/w2g64.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_block_ap.py \
2 | --resume_quant path/to/Llama-2-7b-w2g64 \
3 | --net Llama-2 \
4 | --wbits 2 \
5 | --group_size 64 \
6 | --output_dir ./output/inference_results/ \
7 | --eval_ppl \
8 | --eval_tasks piqa,arc_easy,arc_challenge,hellaswag,winogrande
9 |
--------------------------------------------------------------------------------
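In `w2g64`, each weight is stored as a 2-bit code, and every group of 64 weights carries its own quantization parameters. The sketch below gives a back-of-the-envelope estimate of the resulting storage. It assumes one fp16 scale and one `wbits`-bit packed zero point per group, which is a GPTQ-style layout; the exact format saved by this repository is not shown here. It counts only the attention and MLP linear layers of Llama-2-7B (hidden size 4096, intermediate size 11008, 32 layers). It ignores the embeddings, `lm_head` and norms, which are not quantized.

```python
def bits_per_weight(wbits, group_size, scale_bits=16, zero_bits=None):
    """Average storage bits per quantized weight, including per-group metadata."""
    zero_bits = wbits if zero_bits is None else zero_bits  # assumed packed zero point
    return wbits + (scale_bits + zero_bits) / group_size


def llama2_7b_linear_params(hidden=4096, intermediate=11008, layers=32):
    """Parameter count of the quantized linear layers in Llama-2-7B."""
    attn = 4 * hidden * hidden        # q, k, v and o projections
    mlp = 3 * hidden * intermediate   # gate, up and down projections
    return layers * (attn + mlp)


for wbits, group_size in [(2, 64), (2, 128), (3, 128), (4, 128)]:
    bpw = bits_per_weight(wbits, group_size)
    gib = llama2_7b_linear_params() * bpw / 8 / 2**30
    print(f"w{wbits}g{group_size}: {bpw:.3f} bits/weight, ~{gib:.2f} GiB of linear weights")
```

Under these assumptions, the per-group overhead for 2-bit weights is 18/64 ≈ 0.28 extra bits per weight at group size 64 and 18/128 ≈ 0.14 at group size 128. This is the memory side of the group-size trade-off.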
/examples/model_transfer/efficientqat_to_bitblas/llama-2-7b.sh:
--------------------------------------------------------------------------------
1 | # llama-2-7b-w2g64
2 | CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 python -m model_transfer.efficientqat_to_others \
3 | --model path/to/original/quantized/model \
4 | --save_dir path/to/new/model \
5 | --wbits 2 \
6 | --group_size 64 \
7 | --test_speed \
8 | --target_format bitblas
--------------------------------------------------------------------------------
/examples/model_transfer/efficientqat_to_gptq/llama-2-7b.sh:
--------------------------------------------------------------------------------
1 | # llama-2-7b-w2g64
2 | CUDA_VISIBLE_DEVICES=0 python -m model_transfer.efficientqat_to_others \
3 | --model path/to/original/quantized/model \
4 | --save_dir path/to/new/model \
5 | --wbits 2 \
6 | --group_size 64 \
7 | --eval_ppl \
8 | --test_speed
--------------------------------------------------------------------------------
/examples/model_transfer/fp32_to_16/llama-2-7b.sh:
--------------------------------------------------------------------------------
1 | # llama-2-7b-w2g64
2 | CUDA_VISIBLE_DEVICES=0 python -m model_transfer.fp32_to_16 \
3 | --model path/to/original/quantized/model \
4 | --save_dir path/to/new/model \
5 | --target_type fp16 \
6 | --wbits 2 \
7 | --group_size 64 \
8 | --eval_ppl
--------------------------------------------------------------------------------
/examples/model_transfer/real_to_fake/llama-2-7b.sh:
--------------------------------------------------------------------------------
1 | # llama-2-7b-w2g64
2 | CUDA_VISIBLE_DEVICES=0 python -m model_transfer.real_to_fake \
3 | --model path/to/original/quantized/model \
4 | --save_dir path/to/new/model \
5 | --wbits 2 \
6 | --group_size 64
--------------------------------------------------------------------------------
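As its name suggests, `real_to_fake` converts between the two weight representations used here:

- **Real quantization** keeps integer codes packed into wider words. `--real_quant` in `main_block_ap.py` is described as reducing the memory footprint.
- **Fake quantization** keeps dequantized floating-point weights.

The sketch below shows GPTQ-style packing. It assumes codes are stored lowest bits first within each 32-bit word (held here as Python ints), so `32 // wbits` codes share one word. The actual layout in `quantize/int_linear_real.py` is not shown here and may differ. `pack` and `unpack` are hypothetical names.

```python
def pack(codes, wbits):
    """Pack unsigned `wbits`-bit codes into 32-bit words, lowest bits first."""
    assert 32 % wbits == 0, "this sketch only handles bit-widths that divide 32"
    per_word = 32 // wbits
    words = []
    for start in range(0, len(codes), per_word):
        word = 0
        for i, code in enumerate(codes[start:start + per_word]):
            assert 0 <= code < (1 << wbits), "code out of range"
            word |= code << (i * wbits)
        words.append(word)
    return words


def unpack(words, wbits, n):
    """Inverse of `pack`: recover the first `n` codes."""
    mask = (1 << wbits) - 1
    per_word = 32 // wbits
    codes = [(word >> (i * wbits)) & mask for word in words for i in range(per_word)]
    return codes[:n]


codes = [0, 1, 2, 3, 3, 2, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3]  # 17 two-bit codes
words = pack(codes, wbits=2)
print(len(words))  # 2: sixteen 2-bit codes fit in one 32-bit word
assert unpack(words, wbits=2, n=len(codes)) == codes
```

3-bit codes do not divide 32 evenly, so GPTQ-style formats spread them across word boundaries. This sketch deliberately rejects that case.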
/main_block_ap.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | import numpy as np
5 | import torch
6 | import time
7 | from datautils_block import get_loaders, test_ppl
8 | import torch.nn as nn
9 | from quantize.block_ap import block_ap
10 | from tqdm import tqdm
11 | import utils
12 | from pathlib import Path
13 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
14 | from quantize.int_linear_real import load_quantized_model
15 | from accelerate import infer_auto_device_map, dispatch_model
16 |
17 |
18 |
19 |
20 | torch.backends.cudnn.benchmark = True
21 |
22 | @torch.no_grad()
23 | def evaluate(model, tokenizer, args, logger):
24 | '''
25 | Note: evaluation dispatches the model across all visible GPUs, using at most --max_memory per GPU.
26 | To evaluate a large model such as Llama-2-70B on a single A100-80GB,
27 | please activate '--real_quant' to reduce the memory footprint.
28 | '''
29 | # import pdb;pdb.set_trace()
30 | block_class_name = model.model.layers[0].__class__.__name__
31 | device_map = infer_auto_device_map(model, max_memory={i: args.max_memory for i in range(torch.cuda.device_count())}, no_split_module_classes=[block_class_name])
32 | model = dispatch_model(model, device_map=device_map)
33 | results = {}
34 |
35 | if args.eval_ppl:
36 | datasets = ["wikitext2", "c4"]
37 | ppl_results = test_ppl(model, tokenizer, datasets, args.ppl_seqlen)
38 | for dataset in ppl_results:
39 | logger.info(f'{dataset} perplexity: {ppl_results[dataset]:.2f}')
40 |
41 | if args.eval_tasks != "":
42 | import lm_eval
43 | from lm_eval.models.huggingface import HFLM
44 | from lm_eval.utils import make_table
45 | task_list = args.eval_tasks.split(',')
46 | model = HFLM(pretrained=model, batch_size=args.eval_batch_size)
47 | task_manager = lm_eval.tasks.TaskManager()
48 | results = lm_eval.simple_evaluate(
49 | model=model,
50 | tasks=task_list,
51 | num_fewshot=0,
52 | task_manager=task_manager,
53 | )
54 | logger.info(make_table(results))
55 | total_acc = 0
56 | for task in task_list:
57 | total_acc += results['results'][task]['acc,none']
58 | logger.info(f'Average Acc: {total_acc/len(task_list)*100:.2f}%')
59 | return results
60 |
61 |
62 | def main():
63 | import argparse
64 |
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument("--model", type=str, help="model name or path")
67 | parser.add_argument("--cache_dir", default="./cache", type=str, help="directory of the cached dataset, leading to faster debugging")
68 | parser.add_argument("--output_dir", default="./log/", type=str, help="directory of the logging file")
69 | parser.add_argument("--save_quant_dir", default=None, type=str, help="directory for saving the quantized model")
70 | parser.add_argument("--real_quant", default=False, action="store_true",
71 | help="use real quantization instead of fake quantization, can reduce memory footprint")
72 | parser.add_argument("--resume_quant", type=str, default=None, help="model path of resumed quantized model")
73 | parser.add_argument("--calib_dataset",type=str,default="redpajama",
74 | choices=["wikitext2", "ptb", "c4", "mix", "redpajama"],
75 | help="Where to extract calibration data from.")
76 | parser.add_argument("--train_size", type=int, default=4096, help="Number of training data samples.")
77 | parser.add_argument("--val_size", type=int, default=64, help="Number of validation data samples.")
78 | parser.add_argument("--training_seqlen", type=int, default=2048, help="length of the training sequence.")
79 | parser.add_argument("--batch_size", type=int, default=2, help="batch size.")
80 | parser.add_argument("--epochs", type=int, default=2)
81 | parser.add_argument("--ppl_seqlen", type=int, default=2048, help="input sequence length for evaluating perplexity")
82 | parser.add_argument("--seed", type=int, default=2, help="Seed for sampling the calibration data.")
83 | parser.add_argument("--eval_ppl", action="store_true",help="evaluate perplexity on wikitext2 and c4")
84 | parser.add_argument("--eval_tasks", type=str,default="", help="example: piqa,arc_easy,arc_challenge,hellaswag,winogrande")
85 | parser.add_argument("--eval_batch_size", type=int, default=16)
86 | parser.add_argument("--wbits", type=int, default=4, help="weights quantization bits")
87 | parser.add_argument("--group_size", type=int, default=128, help="weights quantization group size")
88 | parser.add_argument("--quant_lr", type=float, default=1e-4, help="lr of quantization parameters (s and z)")
89 | parser.add_argument("--weight_lr", type=float, default=1e-5, help="lr of full-precision weights")
90 | parser.add_argument("--min_lr_factor", type=float, default=20, help="min_lr = lr/min_lr_factor")
91 | parser.add_argument("--clip_grad", type=float, default=0.3)
92 | parser.add_argument("--wd", type=float, default=0,help="weight decay")
93 | parser.add_argument("--net", type=str, default=None,help="model (family) name, for the easier saving of data cache")
94 | parser.add_argument("--max_memory", type=str, default="70GiB",help="The maximum memory of each GPU")
95 | parser.add_argument("--early_stop", type=int, default=0,help="early stopping when the validation loss does not decrease")
96 | parser.add_argument("--off_load_to_disk", action="store_true", default=False, help="save training dataset to disk, saving CPU memory but may reduce training speed")
97 |
98 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
99 | args = parser.parse_args()
100 | random.seed(args.seed)
101 | np.random.seed(args.seed)
102 | torch.manual_seed(args.seed)
103 | torch.cuda.manual_seed(args.seed)
104 |
105 |
106 | # init logger
107 | if args.output_dir:
108 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
109 | if args.cache_dir:
110 | Path(args.cache_dir).mkdir(parents=True, exist_ok=True)
111 | if args.save_quant_dir:
112 | Path(args.save_quant_dir).mkdir(parents=True, exist_ok=True)
113 | output_dir = Path(args.output_dir)
114 | logger = utils.create_logger(output_dir)
115 | logger.info(args)
116 |
117 | if args.net is None:
118 | args.net = args.model.split('/')[-1]
119 | logger.info(f"net is None, setting as {args.net}")
120 | if args.resume_quant:
121 | # directly load quantized model for evaluation
122 | model, tokenizer = load_quantized_model(args.resume_quant,args.wbits, args.group_size)
123 | logger.info(f"memory footprint after loading quantized model: {torch.cuda.max_memory_allocated('cuda') / 1024**3:.2f}GiB")
124 | else:
125 | # load the full-precision model
126 | config = AutoConfig.from_pretrained(args.model)
127 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False,legacy=False)
128 | model = AutoModelForCausalLM.from_pretrained(args.model, config=config, device_map='cpu',torch_dtype=torch.float16)
129 | for param in model.parameters():
130 | param.requires_grad = False
131 |
132 | # quantization
133 | if args.wbits < 16:
134 | logger.info("=== start quantization ===")
135 | tick = time.time()
136 | # load calibration dataset
137 | cache_trainloader = f'{args.cache_dir}/dataloader_{args.net}_{args.calib_dataset}_{args.train_size}_{args.val_size}_{args.training_seqlen}_train.cache'
138 | cache_valloader = f'{args.cache_dir}/dataloader_{args.net}_{args.calib_dataset}_{args.train_size}_{args.val_size}_{args.training_seqlen}_val.cache'
139 | if os.path.exists(cache_trainloader) and os.path.exists(cache_valloader):
140 | trainloader = torch.load(cache_trainloader)
141 | logger.info(f"load trainloader from {cache_trainloader}")
142 | valloader = torch.load(cache_valloader)
143 | logger.info(f"load valloader from {cache_valloader}")
144 | else:
145 | trainloader, valloader = get_loaders(
146 | args.calib_dataset,
147 | tokenizer,
148 | args.train_size,
149 | args.val_size,
150 | seed=args.seed,
151 | seqlen=args.training_seqlen,
152 | )
153 | torch.save(trainloader, cache_trainloader)
154 | torch.save(valloader, cache_valloader)
155 | block_ap(
156 | model,
157 | args,
158 | trainloader,
159 | valloader,
160 | logger,
161 | )
162 | logger.info(f"quantization time: {time.time() - tick:.2f}s")
163 | torch.cuda.empty_cache()
164 | if args.save_quant_dir:
165 | logger.info("start saving model")
166 | model.save_pretrained(args.save_quant_dir)
167 | tokenizer.save_pretrained(args.save_quant_dir)
168 | logger.info("save model success")
169 | evaluate(model, tokenizer, args,logger)
170 |
171 |
172 |
173 | if __name__ == "__main__":
174 | print(sys.argv)
175 | main()
176 |
--------------------------------------------------------------------------------
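`main_block_ap.py` trains two sets of parameters for `--wbits`-bit codes in groups of `--group_size` weights:

- the per-group quantization parameters, the step size `s` and zero point `z` (`--quant_lr`);
- the full-precision weights (`--weight_lr`).

The sketch below implements uniform asymmetric group-wise fake quantization in pure Python, assuming min-max initialization of `s` and `z`. `quantize/quantizer.py` is not shown here, so its exact initialization, clamping and rounding may differ. The straight-through estimator that lets gradients reach `s` and `z` during training is omitted. `minmax_params` and `fake_quantize` are hypothetical names.

```python
def minmax_params(group, wbits):
    """Initialize step size s and zero point z from the group's min and max."""
    qmax = 2 ** wbits - 1
    lo, hi = min(group), max(group)
    s = max(hi - lo, 1e-8) / qmax  # guard against constant groups
    z = round(-lo / s)             # integer zero point; left unclamped in this sketch
    return s, z


def fake_quantize(weights, wbits, group_size):
    """Quantize then dequantize `weights` one group at a time (uniform, asymmetric)."""
    qmax = 2 ** wbits - 1
    out = []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        s, z = minmax_params(group, wbits)
        for w in group:
            q = min(max(round(w / s) + z, 0), qmax)  # integer code in [0, qmax]
            out.append((q - z) * s)                  # dequantized weight
    return out


w = [-0.5, -0.1, 0.2, 0.4, 1.0, 2.0, 3.0, 4.0]
print(fake_quantize(w, wbits=2, group_size=4))
```

Each group maps onto at most `2**wbits` levels spaced `s` apart, so a 2-bit group of 64 weights is represented by four values. Block-AP presumably then adjusts `s` and `z` by gradient descent rather than keeping the min-max values, as the `--quant_lr` help text suggests.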
/main_e2e_qp.py:
--------------------------------------------------------------------------------
1 | # This file is modified from https://github.com/artidoro/qlora/blob/main/qlora.py
2 | import json
3 | import os
4 | from os.path import exists, join, isdir
5 | from dataclasses import dataclass, field
6 | from typing import Optional, Dict
7 | import numpy as np
8 | import importlib
9 | from packaging import version
10 |
11 | import torch
12 | import transformers
13 | import argparse
14 | from transformers import (
15 | set_seed,
16 | Seq2SeqTrainer,
17 | LlamaTokenizer
18 | )
19 |
20 |
21 | from datautils_block import test_ppl
22 | from datautils_e2e import make_data_module
23 | from bitsandbytes.optim import AdamW
24 | import warnings
25 | import utils
26 | from quantize.int_linear_real import load_quantized_model,QuantLinear
27 | from pathlib import Path
28 |
29 |
30 |
31 |
32 | def is_ipex_available():
33 | def get_major_and_minor_from_version(full_version):
34 | return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
35 |
36 | _torch_version = importlib.metadata.version("torch")
37 | if importlib.util.find_spec("intel_extension_for_pytorch") is None:
38 | return False
39 | _ipex_version = "N/A"
40 | try:
41 | _ipex_version = importlib.metadata.version("intel_extension_for_pytorch")
42 | except importlib.metadata.PackageNotFoundError:
43 | return False
44 | torch_major_and_minor = get_major_and_minor_from_version(_torch_version)
45 | ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version)
46 | if torch_major_and_minor != ipex_major_and_minor:
47 | warnings.warn(
48 | f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*,"
49 | f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
50 | )
51 | return False
52 | return True
53 |
54 |
55 | if torch.cuda.is_available():
56 | torch.backends.cuda.matmul.allow_tf32 = True
57 |
58 |
59 | IGNORE_INDEX = -100
60 | DEFAULT_PAD_TOKEN = "[PAD]"
61 |
62 | @dataclass
63 | class ModelArguments:
64 | quant_model_path: Optional[str] = field(
65 | default="",
66 | metadata={"help": "path of the quantization model by Block-AP."}
67 | )
68 | model_family: Optional[str] = field(
69 | default="llama-2",
70 | metadata={"help": "model family name, used to key the dataset cache for faster experiments"}
71 | )
72 | trust_remote_code: Optional[bool] = field(
73 | default=False,
74 | metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."}
75 | )
76 | use_auth_token: Optional[bool] = field(
77 | default=False,
78 | metadata={"help": "Enables using Huggingface auth token from Git Credentials."}
79 | )
80 |
81 | @dataclass
82 | class DataArguments:
83 | eval_dataset_size: int = field(
84 | default=1024, metadata={"help": "Size of validation dataset."}
85 | )
86 | max_train_samples: Optional[int] = field(
87 | default=None,
88 | metadata={
89 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
90 | "value if set."
91 | },
92 | )
93 | max_eval_samples: Optional[int] = field(
94 | default=None,
95 | metadata={
96 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
97 | "value if set."
98 | },
99 | )
100 | source_max_len: int = field(
101 | default=1024,
102 | metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."},
103 | )
104 | target_max_len: int = field(
105 | default=256,
106 | metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."},
107 | )
108 | dataset: str = field(
109 | default='alpaca',
110 | metadata={"help": "Which dataset to finetune on. See datamodule for options."}
111 | )
112 | eval_tasks: str = field(
113 | default='',
114 | metadata={"help": "evaluation tasks for lm eval, example:piqa,arc_easy,arc_challenge,hellaswag,winogrande"}
115 | )
116 | conv_temp: str = field(
117 | default='llama-2',
118 | metadata={"help": "Conversation template, only useful with deita datasets"}
119 | )
120 | mask_use: bool = field(
121 | default=True, metadata={"help": "mask the loss to role in dialogue datas"}
122 | )
123 | dataset_format: Optional[str] = field(
124 | default=None,
125 | metadata={"help": "Which dataset format is used, e.g. alpaca or pt (the examples use pt for redpajama and deita)."}
126 | )
127 | overwrite_cache: bool = field(
128 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
129 | )
130 | preprocessing_num_workers: Optional[int] = field(
131 | default=32,
132 | metadata={"help": "The number of processes to use for the preprocessing."},
133 | )
134 |
135 | @dataclass
136 | class TrainingArguments(transformers.Seq2SeqTrainingArguments):
137 | cache_dir: Optional[str] = field(
138 | default=None
139 | )
140 | train_on_source: Optional[bool] = field(
141 | default=False,
142 | metadata={"help": "Whether to train on the input in addition to the target text."}
143 | )
144 | do_mmlu_eval: Optional[bool] = field(
145 | default=False,
146 | metadata={"help": "Whether to run the MMLU evaluation."}
147 | )
148 | do_ppl_eval: Optional[bool] = field(
149 | default=False,
150 | metadata={"help": "Whether to run the PPL evaluation."}
151 | )
152 | pt_context_len: int = field(
153 | default=1024,
154 | metadata={"help": "language modeling length."}
155 | )
156 | full_finetune: bool = field(
157 | default=False,
158 | metadata={"help": "Finetune the entire model without adapters."}
159 | )
160 | wbits: int = field(
161 | default=4,
162 | metadata={"help": "How many bits to use."}
163 | )
164 | group_size: int = field(
165 | default=64,
166 | metadata={"help": "Number of weights per quantization group."}
167 | )
168 | max_memory_MB: int = field(
169 | default=80000,
170 | metadata={"help": "Free memory per gpu."}
171 | )
172 | report_to: str = field(
173 | default='none',
174 | metadata={"help": "To use wandb or something else for reporting."}
175 | )
176 | output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'})
177 | resume_from_checkpoint: str = field(default=None, metadata={"help": 'Path of the checkpoint to resume training from'})
178 | optim: str = field(default='paged_adamw_32bit', metadata={"help": 'The optimizer to be used'})
179 | per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'})
180 | gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before performing an optimizer step'})
181 | max_steps: int = field(default=0, metadata={"help": 'How many optimizer update steps to take'})
182 | weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'}) # use lora dropout instead for regularization if needed
183 | learning_rate: float = field(default=2e-5, metadata={"help": 'The learning rate'})
184 | remove_unused_columns: bool = field(default=False, metadata={"help": 'Removed unused columns. Needed to make this codebase work.'})
185 | max_grad_norm: float = field(default=0.3, metadata={"help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'})
186 | gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing. You want to use this.'})
187 | do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'})
188 | lr_scheduler_type: str = field(default='cosine', metadata={"help": 'Learning rate schedule. Constant a bit better than cosine, and has advantage for analysis'})
189 | warmup_ratio: float = field(default=0.03, metadata={"help": 'Fraction of steps to do a warmup for'})
190 | logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'})
191 | group_by_length: bool = field(default=False, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'})
192 | save_strategy: str = field(default='epoch', metadata={"help": 'When to save checkpoints'})
193 | save_steps: int = field(default=250, metadata={"help": 'How often to save a model'})
194 | save_total_limit: int = field(default=5, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'})
195 |
196 | @dataclass
197 | class GenerationArguments:
198 | # For more hyperparameters check:
199 | # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
200 | # Length arguments
201 | max_new_tokens: Optional[int] = field(
202 | default=256,
203 | metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops"
204 | "if predict_with_generate is set."}
205 | )
206 | min_new_tokens : Optional[int] = field(
207 | default=None,
208 | metadata={"help": "Minimum number of new tokens to generate."}
209 | )
210 |
211 | # Generation strategy
212 | do_sample: Optional[bool] = field(default=False)
213 | num_beams: Optional[int] = field(default=1)
214 | num_beam_groups: Optional[int] = field(default=1)
215 | penalty_alpha: Optional[float] = field(default=None)
216 | use_cache: Optional[bool] = field(default=True)
217 |
218 | # Hyperparameters for logit manipulation
219 | temperature: Optional[float] = field(default=1.0)
220 | top_k: Optional[int] = field(default=50)
221 | top_p: Optional[float] = field(default=1.0)
222 | typical_p: Optional[float] = field(default=1.0)
223 | diversity_penalty: Optional[float] = field(default=0.0)
224 | repetition_penalty: Optional[float] = field(default=1.0)
225 | length_penalty: Optional[float] = field(default=1.0)
226 | no_repeat_ngram_size: Optional[int] = field(default=0)
227 |
228 |
229 |
230 | def get_accelerate_model(args, checkpoint_dir):
231 |
232 | if torch.cuda.is_available():
233 | n_gpus = torch.cuda.device_count()
234 | if is_ipex_available() and torch.xpu.is_available():
235 | n_gpus = torch.xpu.device_count()
236 |
237 | max_memory = f'{args.max_memory_MB}MB'
238 | max_memory = {i: max_memory for i in range(n_gpus)}
239 | device_map = "auto"
240 |
241 | # if we are in a distributed setting, we need to set the device map and max memory per device
242 | if os.environ.get('LOCAL_RANK') is not None:
243 | local_rank = int(os.environ.get('LOCAL_RANK', '0'))
244 | device_map = {'': local_rank}
245 | max_memory = {'': max_memory[local_rank]}
246 |
247 |
248 |
249 | model, tokenizer = load_quantized_model(args.quant_model_path,args.wbits, args.group_size)
250 | tokenizer.model_max_length = args.pt_context_len
251 |
252 | compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))
253 | if compute_dtype == torch.float16 and (is_ipex_available() and torch.xpu.is_available()):
254 | compute_dtype = torch.bfloat16
255 | print('Intel XPU does not support float16 yet, so switching to bfloat16')
256 |
257 | setattr(model, 'model_parallel', True)
258 | setattr(model, 'is_parallelizable', True)
259 |
260 | model.config.torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))
261 | # from peft import prepare_model_for_kbit_training
262 | # model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
263 | model.cuda()
264 | model.train()
265 |
266 | if tokenizer._pad_token is None:
267 | smart_tokenizer_and_embedding_resize(
268 | special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
269 | tokenizer=tokenizer,
270 | model=model,
271 | )
272 |
273 | # TODO
274 | # if 'llama1' in args.model_name_or_path or 'llama2' in args.model_name_or_path or 'llama-1' in args.model_name_or_path or 'llama-2' in args.model_name_or_path:
275 | if isinstance(tokenizer, LlamaTokenizer):
276 | # LLaMA tokenizer may not have correct special tokens set.
277 | # Check and add them if missing to prevent them from being parsed into different tokens.
278 | # Note that these are present in the vocabulary.
279 | # Note also that `model.config.pad_token_id` is 0 which corresponds to `<unk>` token.
280 | print('Adding special tokens.')
281 | tokenizer.add_special_tokens({
282 | "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
283 | "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
284 | "unk_token": tokenizer.convert_ids_to_tokens(
285 | model.config.pad_token_id if model.config.pad_token_id != -1 else tokenizer.pad_token_id
286 | ),
287 | })
288 |
289 |
290 | for name, param in model.named_parameters():
291 | # freeze base model's layers
292 | param.requires_grad = False
293 |
294 | # cast all non INT8 parameters to fp32
295 | for param in model.parameters():
296 | if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
297 | param.data = param.data.to(torch.float32)
298 |
299 | if args.gradient_checkpointing:
300 | if hasattr(model, "enable_input_require_grads"):
301 | model.enable_input_require_grads()
302 | else:
303 | def make_inputs_require_grad(module, input, output):
304 | output.requires_grad_(True)
305 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
306 |
307 | model.gradient_checkpointing_enable()
308 |
309 | for name, module in model.named_modules():
310 | # if isinstance(module, QuantLinear):
311 | # # transfer trainable step size into float32
312 | # module.scales.data = module.scales.data.to(torch.float32)
313 | if 'norm' in name:
314 | if hasattr(module, 'weight'):
315 | if args.bf16 and module.weight.dtype == torch.float32:
316 | module = module.to(torch.bfloat16)
317 | # module = module.to(torch.float32)
318 | if 'lm_head' in name or 'embed_tokens' in name:
319 | if hasattr(module, 'weight'):
320 | if args.bf16 and module.weight.dtype == torch.float32:
321 | module = module.to(torch.bfloat16)
322 | return model, tokenizer
323 |
324 | def print_trainable_parameters(args, model):
325 | """
326 | Prints the number of trainable parameters in the model.
327 | """
328 | trainable_params = 0
329 | all_param = 0
330 | print('trainable module')
331 | print('*'*80)
332 | for name, param in model.named_parameters():
333 | all_param += param.numel()
334 | if param.requires_grad:
335 | trainable_params += param.numel()
336 | print('*'*80)
337 | if args.wbits == 4: trainable_params /= 2
338 | print(
339 | f"trainable params: {trainable_params} || "
340 | f"all params: {all_param} || "
341 | f"trainable: {100 * trainable_params / all_param}"
342 | )
343 |
344 | def smart_tokenizer_and_embedding_resize(
345 | special_tokens_dict: Dict,
346 | tokenizer: transformers.PreTrainedTokenizer,
347 | model: transformers.PreTrainedModel,
348 | ):
349 | """Resize tokenizer and embedding.
350 |
351 | Note: This is the unoptimized version; the resulting embedding size may not be divisible by 64.
352 | """
353 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
354 | model.resize_token_embeddings(len(tokenizer))
355 |
356 | if num_new_tokens > 0:
357 | input_embeddings_data = model.get_input_embeddings().weight.data
358 | output_embeddings_data = model.get_output_embeddings().weight.data
359 |
360 | input_embeddings_avg = input_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
361 | output_embeddings_avg = output_embeddings_data[:-num_new_tokens].mean(dim=0, keepdim=True)
362 |
363 | input_embeddings_data[-num_new_tokens:] = input_embeddings_avg
364 | output_embeddings_data[-num_new_tokens:] = output_embeddings_avg
365 |
366 |
367 |
368 |
369 |
370 |
371 | def get_last_checkpoint(checkpoint_dir):
372 | if isdir(checkpoint_dir):
373 | is_completed = exists(join(checkpoint_dir, 'completed'))
374 | if is_completed: return None, True # already finished
375 | max_step = 0
376 | for filename in os.listdir(checkpoint_dir):
377 | if isdir(join(checkpoint_dir, filename)) and filename.startswith('checkpoint'):
378 | max_step = max(max_step, int(filename.replace('checkpoint-', '')))
379 | if max_step == 0: return None, is_completed # training started, but no checkpoint
380 | checkpoint_dir = join(checkpoint_dir, f'checkpoint-{max_step}')
381 | print(f"Found a previous checkpoint at: {checkpoint_dir}")
382 | return checkpoint_dir, is_completed # checkpoint found!
383 | return None, False # first training
384 |
385 | def train():
386 | hfparser = transformers.HfArgumentParser((
387 | ModelArguments, DataArguments, TrainingArguments, GenerationArguments
388 | ))
389 | model_args, data_args, training_args, generation_args, extra_args = \
390 | hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
391 | training_args.generation_config = transformers.GenerationConfig(**vars(generation_args))
392 | args = argparse.Namespace(
393 | **vars(model_args), **vars(data_args), **vars(training_args)
394 | )
395 |
396 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
397 | logger = utils.create_logger(args.output_dir)
398 | logger.info(args)
399 |
400 | checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir)
401 | if completed_training:
402 | print('Detected that training was already completed!')
403 |
404 | model, tokenizer = get_accelerate_model(args, checkpoint_dir)
405 |
406 | model.config.use_cache = False
407 | print('loaded model')
408 | set_seed(args.seed)
409 |
410 | data_module = make_data_module(tokenizer=tokenizer, args=args)
411 |
412 |
413 |
414 | optimizer_grouped_parameters = []
415 | for name, module in model.named_modules():
416 | # if isinstance(module, LoraLayer):
417 | if isinstance(module, QuantLinear) and 'head' not in name:
418 | module.scales.requires_grad = True
419 | optimizer_grouped_parameters.append({'params': [p for n, p in model.named_parameters() if 'scale' in n], 'weight_decay': 0.0, 'lr': args.learning_rate})
420 | optimizer = AdamW(optimizer_grouped_parameters)
421 |
422 | trainer = Seq2SeqTrainer(
423 | model=model,
424 | tokenizer=tokenizer,
425 | args=training_args,
426 | optimizers=(optimizer, None),
427 | **{k:v for k,v in data_module.items() if k != 'predict_dataset'},
428 | )
429 |
430 | if args.do_ppl_eval:
431 | class PPLvalCallback(transformers.TrainerCallback):
432 | @torch.no_grad()
433 | def on_evaluate(self, args=None, state=None, control=None, model=None, **kwargs):
434 | results = test_ppl(trainer.model, trainer.tokenizer, datasets=['wikitext2','c4'],ppl_seqlen=2048)
435 | logger.info(results)
436 | trainer.log(results)
437 |
438 | trainer.add_callback(PPLvalCallback)
439 |
440 | # Verifying the datatypes and parameter counts before training.
441 | print_trainable_parameters(args, model)
442 | dtypes = {}
443 | for _, p in model.named_parameters():
444 | dtype = p.dtype
445 | if dtype not in dtypes: dtypes[dtype] = 0
446 | dtypes[dtype] += p.numel()
447 | total = 0
448 | for k, v in dtypes.items(): total+= v
449 | for k, v in dtypes.items():
450 | print(k, v, v/total)
451 |
452 | all_metrics = {"run_name": args.run_name}
453 |
454 |
455 |
456 | print(args.output_dir)
457 | if args.do_train:
458 | logger.info("*** Train ***")
459 | train_result = trainer.train(args.resume_from_checkpoint)
460 | metrics = train_result.metrics
461 | trainer.log_metrics("train", metrics)
462 | trainer.save_metrics("train", metrics)
463 | trainer.save_state()
464 | all_metrics.update(metrics)
465 | # Evaluation
466 | if args.do_eval:
467 | logger.info("*** Evaluate ***")
468 | metrics = trainer.evaluate(metric_key_prefix="eval")
469 | trainer.log_metrics("eval", metrics)
470 | trainer.save_metrics("eval", metrics)
471 | all_metrics.update(metrics)
472 | # Prediction
473 | if args.do_predict:
474 | logger.info("*** Predict ***")
475 | prediction_output = trainer.predict(test_dataset=data_module['predict_dataset'],metric_key_prefix="predict")
476 | prediction_metrics = prediction_output.metrics
477 | predictions = prediction_output.predictions
478 | predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
479 | predictions = tokenizer.batch_decode(
480 | predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
481 | )
482 | with open(os.path.join(args.output_dir, 'predictions.jsonl'), 'w') as fout:
483 | for i, example in enumerate(data_module['predict_dataset']):
484 | example['prediction_with_input'] = predictions[i].strip()
485 | example['prediction'] = predictions[i].replace(example['input'], '').strip()
486 | fout.write(json.dumps(example) + '\n')
487 | print(prediction_metrics)
488 | trainer.log_metrics("predict", prediction_metrics)
489 | trainer.save_metrics("predict", prediction_metrics)
490 | all_metrics.update(prediction_metrics)
491 |
492 | if (args.do_train or args.do_eval or args.do_predict):
493 | with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout:
494 | fout.write(json.dumps(all_metrics))
495 |
496 | if args.eval_tasks != "" or args.do_mmlu_eval:
497 | import lm_eval
498 | from lm_eval.models.huggingface import HFLM
499 | from lm_eval.utils import make_table
500 |
501 | if args.eval_tasks != "":
502 | task_list = args.eval_tasks.split(',')
503 | lm_eval_model = HFLM(pretrained=model, batch_size=32)
504 | task_manager = lm_eval.tasks.TaskManager()
505 | results = lm_eval.simple_evaluate( # call simple_evaluate
506 | model=lm_eval_model,
507 | tasks=task_list,
508 | num_fewshot=0,
509 | task_manager=task_manager,
510 | )
511 | logger.info(make_table(results))
512 | total_acc = 0
513 | for task in task_list:
514 | total_acc += results['results'][task]['acc,none']
515 | logger.info(f'Average Acc: {total_acc/len(task_list)*100:.2f}%')
516 |
517 | if args.do_mmlu_eval:
518 | lm_eval_model = HFLM(pretrained=model, batch_size=16)
519 | task_manager = lm_eval.tasks.TaskManager()
520 | results = lm_eval.simple_evaluate( # call simple_evaluate
521 | model=lm_eval_model,
522 | tasks=['mmlu'],
523 | num_fewshot=5,
524 | task_manager=task_manager,
525 | cache_requests=True,
526 | )
527 | logger.info(make_table(results))
528 | total_acc = 0
529 | for task in results['results']:
530 | total_acc += results['results'][task]['acc,none']
531 | logger.info(f"Average MMLU Acc: {total_acc/len(results['results'])*100:.2f}%")
532 |
533 | if __name__ == "__main__":
534 | train()
535 |
--------------------------------------------------------------------------------
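`get_last_checkpoint` in `train.py` resumes from the sub-directory with the numerically largest step, and treats a `completed` marker file as "training already finished". Below is a minimal standalone sketch of that lookup, assuming the Hugging Face Trainer's `checkpoint-<step>` directory naming. `find_latest_checkpoint` is a hypothetical name, and its regex match is slightly stricter than the original `startswith('checkpoint')` check.

```python
import os
import re
from os.path import isdir, join
from typing import Optional, Tuple


def find_latest_checkpoint(output_dir: str) -> Tuple[Optional[str], bool]:
    """Hypothetical standalone version of the resume logic in get_last_checkpoint.

    Returns (path of the newest checkpoint or None, whether training already completed).
    """
    if not isdir(output_dir):
        return None, False  # first training run
    if os.path.exists(join(output_dir, "completed")):
        return None, True  # a 'completed' marker means training already finished
    steps = []
    for name in os.listdir(output_dir):
        # only sub-directories named exactly 'checkpoint-<step>' count
        match = re.fullmatch(r"checkpoint-(\d+)", name)
        if match and isdir(join(output_dir, name)):
            steps.append(int(match.group(1)))
    if not steps:
        return None, False  # training started, but no checkpoint saved yet
    # compare steps as integers: 'checkpoint-1000' must win over 'checkpoint-999'
    return join(output_dir, f"checkpoint-{max(steps)}"), False
```

Parsing the step as an integer is the important design choice. A plain string sort would rank `checkpoint-999` above `checkpoint-1000`.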
/model_transfer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/EfficientQAT/39175493b2d14617d342a0a7956875e6ac16221b/model_transfer/__init__.py
--------------------------------------------------------------------------------
/model_transfer/efficientqat_to_others.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datautils_block import test_ppl
3 | from transformers import AutoTokenizer
4 | from gptqmodel import GPTQModel, QuantizeConfig, get_backend
5 | from pathlib import Path
6 | import time
7 |
8 | def main():
9 | import argparse
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--model", default=None, type=str, help="path of the EfficientQAT quantized model to convert")
13 | parser.add_argument("--wbits", type=int, default=4, help="quantization bits")
14 | parser.add_argument("--group_size", type=int, default=128, help="quantization group size")
15 | parser.add_argument("--target_format", default='gptq', type=str, help="target checkpoint format")
16 | parser.add_argument("--eval_ppl", action="store_true")
17 | parser.add_argument("--test_speed", action="store_true")
18 | parser.add_argument("--save_dir", default=None, type=str, help="directory for saving the converted model")
19 |
20 |
21 |
22 |
23 | args = parser.parse_args()
24 | tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False,legacy=False)
25 | quant_config = QuantizeConfig(
26 | bits=args.wbits,
27 | group_size=args.group_size,
28 | sym=False,
29 | desc_act=False,
30 | format='gptq_v2',
31 | )
32 | if args.target_format == 'gptq':
33 | # EXLLAMA_V2 is faster for 4-bit models and produces correct outputs, but it has a bug when saving models.
34 | # Therefore, the Triton backend is the default. Models saved this way can still be loaded with ExLlamaV2.
35 | model = GPTQModel.from_quantized(args.model, device_map='auto',torch_dtype=torch.float16, quantize_config=quant_config,backend=get_backend('TRITON'))
36 |
37 | elif args.target_format == 'bitblas':
38 | # loading takes a long time on the first run
39 | try:
40 | model = GPTQModel.from_quantized(args.model, device_map='auto',torch_dtype=torch.float16, quantize_config=quant_config,backend=get_backend('BITBLAS'))
41 | args.eval_ppl = False # BitBLAS has a bug: the model must be re-loaded before evaluation, otherwise outputs are wrong
42 | except Exception:
43 | model = GPTQModel.from_quantized(args.model, device_map='auto',torch_dtype=torch.float16, backend=get_backend('BITBLAS'))
44 | else:
45 | raise NotImplementedError
46 |
47 | if args.save_dir:
48 | Path(args.save_dir).mkdir(parents=True, exist_ok=True)
49 | print("start saving model")
50 | model.quantize_config.model_file_base_name=None # workaround for a saving bug in GPTQModel
51 | model.save_quantized(args.save_dir,max_shard_size='8GB')
52 | tokenizer.save_pretrained(args.save_dir)
53 | print(f"save model to {args.save_dir} success")
54 |
55 | model.model.cuda()
56 |
57 | if args.eval_ppl:
58 | datasets = ["wikitext2"]
59 | ppl_results = test_ppl(model, tokenizer, datasets, 2048)
60 | for dataset in ppl_results:
61 | print(f'{dataset} perplexity after transferring: {ppl_results[dataset]:.2f}')
62 | if args.test_speed:
63 | prompt = "Write a poem about large language model:"
64 | input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda()
65 | start_time = time.time()
66 | output = model.generate(inputs=input_ids, do_sample=True, top_k=10, max_new_tokens=256)
67 | end_time = time.time()
68 | speed = (len(output[0]) - input_ids.shape[1])/(end_time-start_time) # count only newly generated tokens
69 | print(tokenizer.decode(output[0]))
70 | print(f"generation speed: {speed:.1f} token/s")
71 |
72 |
73 | if __name__ =='__main__':
74 | main()
--------------------------------------------------------------------------------
/model_transfer/fp32_to_16.py:
--------------------------------------------------------------------------------
1 | from quantize.int_linear_real import load_quantized_model
2 | import torch
3 | from datautils_block import test_ppl
4 | from pathlib import Path
5 |
6 |
7 | def main():
8 | import argparse
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument("--model", type=str, default=None, help="path of the quantized model to load")
12 | parser.add_argument("--save_dir", default=None, type=str, help="directory for saving the converted model")
13 | parser.add_argument("--wbits", type=int, default=4, help="quantization bits")
14 | parser.add_argument("--group_size", type=int, default=128, help="quantization group size")
15 | parser.add_argument("--target_type",type=str,default="fp16",choices=["fp16", "bf16"])
16 | parser.add_argument("--eval_ppl", action="store_true",help="evaluate perplexity on wikitext2 with 2048 context length")
17 |
18 |
19 |
20 | args = parser.parse_args()
21 | model, tokenizer = load_quantized_model(args.model,args.wbits, args.group_size)
22 | model.cuda()
23 |
24 |
25 | if args.target_type =='fp16':
26 | dtype = torch.float16
27 | elif args.target_type =='bf16':
28 | dtype = torch.bfloat16
29 | else:
30 | raise NotImplementedError
31 |
32 | if args.eval_ppl:
33 | datasets = ["wikitext2"]
34 | ppl_results = test_ppl(model, tokenizer, datasets, 2048)
35 | for dataset in ppl_results:
36 | print(f'{dataset} perplexity before transferring: {ppl_results[dataset]:.2f}')
37 |
38 | model.to(dtype)
39 | print(f"transfer model to {args.target_type} format")
40 | if args.eval_ppl:
41 | datasets = ["wikitext2"]
42 | ppl_results = test_ppl(model, tokenizer, datasets, 2048)
43 | for dataset in ppl_results:
44 | print(f'{dataset} perplexity after transferring: {ppl_results[dataset]:.2f}')
45 |
46 | if args.save_dir:
47 | Path(args.save_dir).mkdir(parents=True, exist_ok=True)
48 | print("start saving model")
49 | model.save_pretrained(args.save_dir)
50 | tokenizer.save_pretrained(args.save_dir)
51 | print(f"save model to {args.save_dir} success")
52 |
53 | if __name__ =='__main__':
54 | main()
--------------------------------------------------------------------------------
/model_transfer/real_to_fake.py:
--------------------------------------------------------------------------------
1 | from quantize.int_linear_real import load_quantized_model, QuantLinear
2 | import torch
3 | from datautils_block import test_ppl
4 | from pathlib import Path
5 |
6 | def main():
7 | import argparse
8 |
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--model", type=str, default=None, help="path of the quantized model to load")
11 | parser.add_argument("--save_dir", default=None, type=str, help="directory for saving the converted model")
12 | parser.add_argument("--wbits", type=int, default=4, help="quantization bits")
13 | parser.add_argument("--group_size", type=int, default=128, help="quantization group size")
14 |
15 |
16 |
17 | args = parser.parse_args()
18 | model, tokenizer = load_quantized_model(args.model,args.wbits, args.group_size)
19 | # model.cuda()
20 |
21 | for name, module in model.named_modules():
22 | if isinstance(module, QuantLinear):
23 | module.cuda()
24 | module.use_fake_quantization(del_quant=True,transpose=True)
25 | module.cpu()
26 |
27 |
28 | if args.save_dir:
29 | Path(args.save_dir).mkdir(parents=True, exist_ok=True)
30 | print("start saving model")
31 | model.to(torch.float16)
32 | model.save_pretrained(args.save_dir)
33 | tokenizer.save_pretrained(args.save_dir)
34 | print(f"save model to {args.save_dir} success")
35 |
36 | if __name__ =='__main__':
37 | main()
--------------------------------------------------------------------------------
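`real_to_fake.py` turns packed low-bit checkpoints back into ordinary fp16 weights. The arithmetic behind that conversion can be sketched in NumPy. The sketch assumes the layout implied by `QuantLinear.__init__` in `int_linear_real.py`: `32 // bits` codes per 32-bit word along the input dimension, and one scale and zero point per `group_size` input rows. It also assumes a low-bits-first order inside each word, which the actual Triton kernels may not use. In the real module `qzeros` are packed too, whereas here they are a plain array. `pack_rows`, `unpack_rows` and `dequantize` are hypothetical helpers.

```python
import numpy as np


def pack_rows(q: np.ndarray, bits: int) -> np.ndarray:
    """Pack integer codes of shape (in_features, out_features) along dim 0 into uint32 words."""
    per_word = 32 // bits
    rows = -(-q.shape[0] // per_word)  # ceil division, same row count as qweight above
    packed = np.zeros((rows, q.shape[1]), dtype=np.uint32)
    for i in range(q.shape[0]):
        word, slot = divmod(i, per_word)
        packed[word] |= q[i].astype(np.uint32) << np.uint32(slot * bits)
    return packed


def unpack_rows(packed: np.ndarray, bits: int, in_features: int) -> np.ndarray:
    """Inverse of pack_rows: recover the integer codes in [0, 2**bits - 1]."""
    per_word = 32 // bits
    mask = np.uint32((1 << bits) - 1)
    q = np.empty((in_features, packed.shape[1]), dtype=np.int64)
    for i in range(in_features):
        word, slot = divmod(i, per_word)
        q[i] = (packed[word] >> np.uint32(slot * bits)) & mask
    return q


def dequantize(q: np.ndarray, scales: np.ndarray, zeros: np.ndarray, group_size: int) -> np.ndarray:
    """w = (q - z) * s, with one (scale, zero) pair per group of group_size input rows."""
    g = np.arange(q.shape[0]) // group_size  # same mapping as g_idx in QuantLinear
    return (q - zeros[g]) * scales[g]
```

Unsigned words are used here so that shifting a code into the top bits cannot overflow. The real module stores the same bit patterns in `torch.int32` buffers.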
/quantize/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/EfficientQAT/39175493b2d14617d342a0a7956875e6ac16221b/quantize/__init__.py
--------------------------------------------------------------------------------
/quantize/block_ap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import quantize.int_linear_fake as int_linear_fake
5 | import quantize.int_linear_real as int_linear_real
6 | from torch.optim.lr_scheduler import CosineAnnealingLR
7 | import copy
8 | import math
9 | import utils
10 | import pdb
11 | import gc
12 | from quantize.utils import (
13 | quant_parameters,weight_parameters,trainable_parameters,
14 | set_quant_state,quant_inplace,set_quant_parameters,
15 | set_weight_parameters,trainable_parameters_num,get_named_linears,set_op_by_name)
16 | import time
17 | from datautils_block import BlockTrainDataset
18 | from torch.utils.data import DataLoader
19 | import shutil
20 | import os
21 |
22 | def update_dataset(layer, dataset, dev, attention_mask, position_ids):
23 | with torch.no_grad():
24 | with torch.cuda.amp.autocast():
25 | for index, inps in enumerate(dataset):
26 | inps = inps.to(dev)
27 | if len(inps.shape)==2:
28 | inps = inps.unsqueeze(0)
29 | new_data = layer(inps, attention_mask=attention_mask,position_ids=position_ids)[0].to('cpu')
30 | dataset.update_data(index,new_data)
31 |
32 |
33 | def block_ap(
34 | model,
35 | args,
36 | trainloader,
37 | valloader,
38 | logger=None,
39 | ):
40 | logger.info("Starting ...")
41 | if args.off_load_to_disk:
42 | logger.info("offload the training dataset to disk to save CPU memory; this may slow down training due to additional I/O...")
43 |
44 | dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45 | use_cache = model.config.use_cache
46 | model.config.use_cache = False
47 |
48 | # step 1: move embedding layer and first layer to target device; only Llama-style models are supported for now
49 | layers = model.model.layers
50 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
51 | model.model.norm = model.model.norm.to(dev)
52 | if hasattr(model.model, 'rotary_emb'):
53 | # for llama-3.1
54 | model.model.rotary_emb = model.model.rotary_emb.to(dev)
55 | layers[0] = layers[0].to(dev)
56 | dtype = torch.float16
57 |
58 | # step 2: init dataset
59 | flag = time.time()
60 | if args.off_load_to_disk:
61 | fp_train_cache_path = f'{args.cache_dir}/{flag}/block_training_fp_train'
62 | fp_val_cache_path = f'{args.cache_dir}/{flag}/block_training_fp_val'
63 | quant_train_cache_path = f'{args.cache_dir}/{flag}/block_training_quant_train'
64 | quant_val_cache_path = f'{args.cache_dir}/{flag}/block_training_quant_val'
65 | for path in [fp_train_cache_path,fp_val_cache_path,quant_train_cache_path,quant_val_cache_path]:
66 | if os.path.exists(path):
67 | shutil.rmtree(path)
68 | else:
69 | fp_train_cache_path = None
70 | fp_val_cache_path = None
71 | quant_train_cache_path = None
72 | quant_val_cache_path = None
73 | fp_train_inps = BlockTrainDataset(args.train_size, args.training_seqlen,
74 | model.config.hidden_size, args.batch_size, dtype, cache_path=fp_train_cache_path,off_load_to_disk=args.off_load_to_disk)
75 | fp_val_inps = BlockTrainDataset(args.val_size, args.training_seqlen,
76 | model.config.hidden_size, args.batch_size, dtype, cache_path=fp_val_cache_path,off_load_to_disk=args.off_load_to_disk)
77 |
78 | # step 3: capture the inputs of the first layer
79 | class Catcher(nn.Module):
80 | def __init__(self, module, dataset):
81 | super().__init__()
82 | self.module = module
83 | self.dataset = dataset
84 | self.index = 0
85 | self.attention_mask = None
86 | self.position_ids = None
87 |
88 | def forward(self, inp, **kwargs):
89 | self.dataset.update_data(self.index, inp.squeeze(0).to('cpu'))
90 | self.index += 1
91 | if self.attention_mask is None:
92 | self.attention_mask = kwargs["attention_mask"]
93 | if self.position_ids is None:
94 | self.position_ids = kwargs["position_ids"]
95 | raise ValueError
96 |
97 | # step 3.1: catch the input of training set
98 | layers[0] = Catcher(layers[0],fp_train_inps)
99 | iters = len(trainloader)//args.batch_size
100 | with torch.no_grad():
101 | for i in range(iters):
102 | data = torch.cat([trainloader[j][0] for j in range(i*args.batch_size,(i+1)*args.batch_size)],dim=0)
103 | try:
104 | model(data.to(dev))
105 | except ValueError:
106 | pass
107 | layers[0] = layers[0].module
108 |
109 | # step 3.2: catch the input of validation set
110 | layers[0] = Catcher(layers[0],fp_val_inps)
111 | iters = len(valloader)//args.batch_size
112 | with torch.no_grad():
113 | for i in range(iters):
114 | data = torch.cat([valloader[j][0] for j in range(i*args.batch_size,(i+1)*args.batch_size)],dim=0)
115 | try:
116 | model(data.to(dev))
117 | except ValueError:
118 | pass
119 | attention_mask = layers[0].attention_mask
120 | position_ids = layers[0].position_ids
121 | layers[0] = layers[0].module
122 | if attention_mask is not None:
123 | attention_mask_batch = attention_mask.repeat(args.batch_size,1,1,1).float()
124 | else:
125 | logger.info(
126 | "No attention mask caught from the first layer."
127 | " Seems that model's attention works without a mask."
128 | )
129 | attention_mask_batch = None
130 |
131 | # step 4: move embedding layer and first layer to cpu
132 | layers[0] = layers[0].cpu()
133 | model.model.embed_tokens = model.model.embed_tokens.cpu()
134 | model.model.norm = model.model.norm.cpu()
135 | if hasattr(model.model, 'rotary_emb'):
136 | # for llama-3.1
137 | model.model.rotary_emb = model.model.rotary_emb.cpu()
138 | torch.cuda.empty_cache()
139 |
140 | # step 5: copy fp input as the quant input, they are same at the first layer
141 | if args.off_load_to_disk:
142 | # copy quant input from fp input, they are same in first layer
143 | shutil.copytree(fp_train_cache_path, quant_train_cache_path)
144 | shutil.copytree(fp_val_cache_path, quant_val_cache_path)
145 | quant_train_inps = BlockTrainDataset(args.train_size, args.training_seqlen,
146 | model.config.hidden_size, args.batch_size, dtype, cache_path=quant_train_cache_path,off_load_to_disk=args.off_load_to_disk)
147 | quant_val_inps = BlockTrainDataset(args.val_size, args.training_seqlen,
148 | model.config.hidden_size, args.batch_size, dtype, cache_path=quant_val_cache_path,off_load_to_disk=args.off_load_to_disk)
149 | else:
150 | quant_train_inps = BlockTrainDataset(args.train_size, args.training_seqlen,
151 | model.config.hidden_size, args.batch_size, dtype, cache_path=quant_train_cache_path,off_load_to_disk=args.off_load_to_disk)
152 | quant_val_inps = BlockTrainDataset(args.val_size, args.training_seqlen,
153 | model.config.hidden_size, args.batch_size, dtype, cache_path=quant_val_cache_path,off_load_to_disk=args.off_load_to_disk)
154 | for index,data in enumerate(fp_train_inps):
155 | quant_train_inps.update_data(index, data)
156 | for index,data in enumerate(fp_val_inps):
157 | quant_val_inps.update_data(index, data)
158 |
159 | # step 6: start training
160 | loss_func = torch.nn.MSELoss()
161 | for block_index in range(len(layers)):
162 | logger.info(f"=== Start quantizing block {block_index} ===")
163 | # step 6.1: replace torch.nn.Linear with QuantLinear for QAT
164 | layer = layers[block_index].to(dev)
165 | qlayer = copy.deepcopy(layer)
166 | for name, module in qlayer.named_modules():
167 | if isinstance(module,torch.nn.Linear):
168 | quantlinear = int_linear_fake.QuantLinear(module, args.wbits, args.group_size)
169 | set_op_by_name(qlayer, name, quantlinear)
170 | del module
171 | qlayer.to(dev)
172 |
173 |
174 | # step 6.2: obtain output of full-precision model for MSE
175 | set_quant_state(qlayer,weight_quant=False) # deactivate quantization for obtaining ground truth
176 | if args.epochs > 0:
177 | update_dataset(qlayer,fp_train_inps,dev,attention_mask,position_ids)
178 | update_dataset(qlayer,fp_val_inps,dev,attention_mask,position_ids)
179 | set_quant_state(qlayer,weight_quant=True) # activate quantization
180 |
181 |
182 | if args.epochs > 0:
183 | with torch.no_grad():
184 | qlayer.float() # fp32 is required for AMP training
185 | # step 6.3: create optimizer and learning rate schedule
186 | param = []
187 | assert args.quant_lr > 0 or args.weight_lr > 0
188 | param_group_index = 0
189 | total_training_iteration = args.epochs * args.train_size / args.batch_size
190 | if args.quant_lr > 0:
191 | set_quant_parameters(qlayer,True)
192 | param.append({"params":quant_parameters(qlayer),"lr":args.quant_lr})
193 | empty_optimizer_1 = torch.optim.AdamW([torch.tensor(0)], lr=args.quant_lr)
194 | quant_scheduler = CosineAnnealingLR(empty_optimizer_1, T_max=total_training_iteration, eta_min=args.quant_lr/args.min_lr_factor)
195 | quant_index = param_group_index
196 | param_group_index += 1
197 | else:
198 | set_quant_parameters(qlayer,False)
199 |
200 | if args.weight_lr > 0:
201 | set_weight_parameters(qlayer,True)
202 | param.append({"params":weight_parameters(qlayer),"lr":args.weight_lr})
203 | empty_optimizer_2 = torch.optim.AdamW([torch.tensor(0)], lr=args.weight_lr)
204 | weight_scheduler = CosineAnnealingLR(empty_optimizer_2, T_max=total_training_iteration, eta_min=args.weight_lr/args.min_lr_factor)
205 | weight_index = param_group_index
206 | param_group_index += 1
207 | else:
208 | set_weight_parameters(qlayer,False)
209 | optimizer = torch.optim.AdamW(param, weight_decay=args.wd)
210 | loss_scaler = utils.NativeScalerWithGradNormCount()
211 | trainable_number = trainable_parameters_num(qlayer)
212 | print(f"trainable parameter number: {trainable_number/1e6}M")
213 |
214 | best_val_loss = 1e6
215 | early_stop_flag = 0
216 | for epoch in range(args.epochs):
217 | # step: 6.4 training
218 | loss_list = []
219 | norm_list = []
220 | start_time = time.time()
221 | for index, (quant_inps, fp_inps) in enumerate(zip(quant_train_inps, fp_train_inps)):
222 | # obtain output of quantization model
223 | with torch.cuda.amp.autocast():
224 | input = quant_inps.to(dev)
225 | label = fp_inps.to(dev)
226 | quant_out = qlayer(input, attention_mask=attention_mask_batch,position_ids=position_ids)[0]
227 | reconstruction_loss = loss_func(label, quant_out)
228 | loss = reconstruction_loss
229 |
230 | if not math.isfinite(loss.item()):
231 | logger.info("Loss is NAN, stopping training")
232 | pdb.set_trace()
233 | loss_list.append(reconstruction_loss.detach().cpu())
234 | optimizer.zero_grad()
235 | norm = loss_scaler(loss, optimizer,parameters=trainable_parameters(qlayer)).cpu()
236 | norm_list.append(norm.data)
237 |
238 | # adjust lr
239 | if args.quant_lr > 0:
240 | quant_scheduler.step()
241 | optimizer.param_groups[quant_index]['lr'] = quant_scheduler.get_last_lr()[0]
242 | if args.weight_lr > 0:
243 | weight_scheduler.step()
244 | optimizer.param_groups[weight_index]['lr'] = weight_scheduler.get_last_lr()[0]
245 |
246 | # step 6.5: calculate validation loss
247 | val_loss_list = []
248 | for index, (quant_inps,fp_inps) in enumerate(zip(quant_val_inps, fp_val_inps)):
249 | # obtain output of quantization model
250 | with torch.no_grad():
251 | with torch.cuda.amp.autocast():
252 | input = quant_inps.to(dev)
253 | label = fp_inps.to(dev)
254 | quant_out = qlayer(input, attention_mask=attention_mask_batch,position_ids=position_ids)[0]
255 | reconstruction_loss = loss_func(label, quant_out)
256 | val_loss_list.append(reconstruction_loss.cpu())
257 |
258 | train_mean_num = min(len(loss_list),64) # calculate the average training loss of the last train_mean_num batches
259 | loss_mean = torch.stack(loss_list)[-train_mean_num:].mean()
260 | val_loss_mean = torch.stack(val_loss_list).mean()
261 | norm_mean = torch.stack(norm_list).mean()
262 | logger.info(f"blocks {block_index} epoch {epoch} recon_loss:{loss_mean} val_loss:{val_loss_mean} quant_lr:{quant_scheduler.get_last_lr()[0] if args.quant_lr > 0 else 0} norm:{norm_mean:.8f} max memory_allocated {torch.cuda.max_memory_allocated(dev) / 1024**2} time {time.time()-start_time} ")
263 | if val_loss_mean < best_val_loss:
264 | best_val_loss = val_loss_mean
265 | else:
266 | early_stop_flag += 1
267 | if args.early_stop > 0 and early_stop_flag >=args.early_stop:
268 | break
269 | optimizer.zero_grad()
270 | del optimizer
271 |
272 | # step 6.6: directly replace the weight with fake quantization
273 | qlayer.half()
274 | quant_inplace(qlayer)
275 | set_quant_state(qlayer,weight_quant=False) # weight has been quantized inplace
276 |
277 | # step 6.7: update inputs of quantization model
278 | if args.epochs>0:
279 | update_dataset(qlayer,quant_train_inps,dev,attention_mask,position_ids)
280 | update_dataset(qlayer,quant_val_inps,dev,attention_mask,position_ids)
281 | layers[block_index] = qlayer.to("cpu")
282 |
283 | # step 7: pack quantized weights into low-bit format; note that this process is slow on a weak or busy CPU
284 | if args.real_quant:
285 | named_linears = get_named_linears(qlayer, int_linear_fake.QuantLinear)
286 | for name, module in named_linears.items():
287 | scales = module.weight_quantizer.scale.clamp(1e-4,1e4).detach()
288 | zeros = module.weight_quantizer.zero_point.detach().cuda().round().cpu()
289 | group_size = module.weight_quantizer.group_size
290 | dim0 = module.weight.shape[0]
291 | scales = scales.view(dim0,-1).transpose(0,1).contiguous()
292 | zeros = zeros.view(dim0,-1).transpose(0,1).contiguous()
293 | q_linear = int_linear_real.QuantLinear(args.wbits, group_size, module.in_features, module.out_features, module.bias is not None)
294 | q_linear.pack(module.cpu(), scales.float().cpu(), zeros.float().cpu())
295 | set_op_by_name(qlayer, name, q_linear)
296 | logger.info(f"pack quantized {name} finished")
297 | del module
298 | del layer
299 | torch.cuda.empty_cache()
300 |
301 | # delete cached dataset
302 | if args.off_load_to_disk:
303 | for path in [fp_train_cache_path,fp_val_cache_path,quant_train_cache_path,quant_val_cache_path]:
304 | if os.path.exists(path):
305 | shutil.rmtree(path)
306 |
307 | torch.cuda.empty_cache()
308 | gc.collect()
309 | model.config.use_cache = use_cache
310 | return model
311 |
312 |
--------------------------------------------------------------------------------
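Step 6.3 of `block_ap` gives each parameter group its own cosine schedule. The rate decays from its initial value to `lr / args.min_lr_factor` over `epochs * train_size / batch_size` iterations. The schedulers are attached to throwaway optimizers, and their values are copied into the real optimizer's `param_groups`. The intended values follow the closed form documented for PyTorch's `CosineAnnealingLR`. The sketch below evaluates that formula directly for `0 <= step <= total_steps`. `cosine_lr` is a hypothetical helper, and the numbers in the usage line are illustrative, not the repository's defaults.

```python
import math


def cosine_lr(step: int, base_lr: float, total_steps: float, min_lr_factor: float) -> float:
    """Closed-form cosine annealing from base_lr down to base_lr / min_lr_factor."""
    eta_min = base_lr / min_lr_factor  # mirrors eta_min=args.quant_lr/args.min_lr_factor
    progress = step / total_steps      # 0.0 at the first iteration, 1.0 at the last
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * progress))


# illustrative values only: 2 epochs over 100 batches each
schedule = [cosine_lr(t, 1e-4, 200, 20) for t in range(201)]
```

Computing the value from the step index avoids any dependence on what the optimizer's `lr` field currently holds.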
/quantize/int_linear_fake.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from quantize.quantizer import UniformAffineQuantizer
5 |
6 |
7 |
8 |
9 |
10 | class QuantLinear(nn.Module):
11 | """
12 | Quantized module that performs either a quantized or a full-precision linear transformation.
13 | To activate quantization, please use set_quant_state function.
14 | """
15 | def __init__(
16 | self,
17 | org_module: nn.Linear,
18 | wbits=4,
19 | group_size=64
20 | ):
21 | super().__init__()
22 | self.fwd_kwargs = dict()
23 | self.fwd_func = F.linear
24 | self.register_parameter('weight',org_module.weight) # trainable
25 | if org_module.bias is not None:
26 | self.register_buffer('bias',org_module.bias)
27 | else:
28 | self.bias = None
29 | self.in_features = org_module.in_features
30 | self.out_features = org_module.out_features
31 | # quantized forward is deactivated by default
32 | self.use_weight_quant = False
33 | # initialize quantizer
34 | self.weight_quantizer = UniformAffineQuantizer(wbits, group_size, weight=org_module.weight)
35 | self.use_temporary_parameter = False
36 |
37 |
38 |
39 | def forward(self, input: torch.Tensor):
40 | if self.use_weight_quant:
41 | weight = self.weight_quantizer(self.weight)
42 | bias = self.bias
43 | else:
44 | weight = self.weight
45 | bias = self.bias
46 |
47 |
48 | out = self.fwd_func(input, weight, bias, **self.fwd_kwargs)
49 |
50 |
51 | return out
52 |
53 | def set_quant_state(self, weight_quant: bool = False):
54 | self.use_weight_quant = weight_quant
55 |
56 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
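When `use_weight_quant` is set, `QuantLinear.forward` replaces `self.weight` with `self.weight_quantizer(self.weight)`, a quantize-then-dequantize ("fake quantization") copy of the weight. The repository's `UniformAffineQuantizer` lives in `quantize/quantizer.py`, which is not shown here. The sketch below is therefore a generic min-max asymmetric group-wise quantizer in NumPy and may differ from it in initialization, clamping and gradient handling. During training, the rounding step would also need a straight-through estimator, which is omitted here because NumPy has no autograd. `fake_quantize` is a hypothetical helper.

```python
import numpy as np


def fake_quantize(w: np.ndarray, bits: int, group_size: int) -> np.ndarray:
    """Quantize then dequantize w of shape (out_features, in_features), one (scale, zero) per group."""
    out_features, in_features = w.shape
    assert in_features % group_size == 0, "sketch assumes in_features divisible by group_size"
    g = w.reshape(out_features, in_features // group_size, group_size)
    qmax = 2 ** bits - 1
    w_min = g.min(axis=-1, keepdims=True)
    w_max = g.max(axis=-1, keepdims=True)
    scale = np.clip((w_max - w_min) / qmax, 1e-5, None)  # guard against constant groups
    zero = np.round(-w_min / scale)                      # integer zero point
    q = np.clip(np.round(g / scale) + zero, 0, qmax)     # integer codes in [0, qmax]
    return ((q - zero) * scale).reshape(out_features, in_features)
```

Grouping runs along the input dimension, matching the `(ceil(infeatures / group_size), outfeatures)` shape of `scales` in `int_linear_real.QuantLinear`. Each group then holds at most `2 ** bits` distinct values.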
/quantize/int_linear_real.py:
--------------------------------------------------------------------------------
1 | import math
2 | from logging import getLogger
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import transformers
8 |
9 | from quantize.triton_utils.kernels import dequant_dim0, dequant_dim1
10 |
11 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
12 | from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_in_model
13 | from tqdm import tqdm
14 | import gc
15 | from quantize.utils import get_named_linears,set_op_by_name
16 |
17 | logger = getLogger(__name__)
18 |
19 |
20 | class TritonModuleMixin:
21 | @classmethod
22 | def warmup(cls, model, transpose=False, seqlen=2048):
23 | pass
24 |
25 |
26 | class QuantLinear(nn.Module, TritonModuleMixin):
27 | QUANT_TYPE = "triton"
28 |
29 | def __init__(
30 | self,
31 | bits,
32 | group_size,
33 | infeatures,
34 | outfeatures,
35 | bias,
36 | trainable=False,
37 | **kwargs
38 | ):
39 | super().__init__()
40 | # if bits not in [2, 4, 8]:
41 | # raise NotImplementedError("Only 2,4,8 bits are supported.")
42 | # if infeatures % 32 != 0 or outfeatures % 32 != 0:
43 | # raise NotImplementedError("in_feature and out_feature must be divisible by 32.")
44 | self.infeatures = infeatures
45 | self.outfeatures = outfeatures
46 | self.bits = bits
47 | self.group_size = group_size if group_size != -1 else infeatures
48 | self.maxq = 2 ** self.bits - 1
49 | self.register_buffer(
50 | 'qweight',
51 | torch.zeros((math.ceil(infeatures / (32 // self.bits)), outfeatures), dtype=torch.int32)
52 | )
53 | self.register_parameter(
54 | 'scales',
55 | torch.nn.Parameter(torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16))
56 | )
57 | self.register_buffer(
58 | 'qzeros',
59 | torch.zeros((math.ceil(infeatures / self.group_size), math.ceil(outfeatures / (32 // self.bits))), dtype=torch.int32)
60 | )
61 | self.register_buffer(
62 | 'g_idx',
63 | torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
64 | ) # not used, just for consistent with GPTQ models
65 | if bias:
66 | self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
67 | else:
68 | self.bias = None
69 |
70 | self.zeros_dim0, self.zeros_dim1 = self.scales.shape
71 | self.trainable = trainable
72 | self.scales.requires_grad = True
73 | self.use_fake = False
74 |
75 | def post_init(self):
76 | pass
77 |
78 |
79 | def use_fake_quantization(self, del_quant=False,transpose=False):
80 | # use fake quantization for faster training but consume more memory
81 | weight = dequant_dim0(self.qweight, self.bits, self.maxq, self.infeatures, self.outfeatures)
82 | dim0, dim1 = weight.shape
83 | zeros = dequant_dim1(self.qzeros, self.bits, self.maxq, self.zeros_dim0, self.zeros_dim1)
84 | weight = ((weight.view(-1, self.group_size, dim1) - zeros.view(-1, 1, dim1)) * self.scales.view(-1, 1, dim1)).reshape(dim0, dim1)
85 | self.fake_transpose = transpose  # forward() reads this flag, so set it whether or not we transpose
86 | if transpose:
87 | weight = weight.transpose(0,1).contiguous()
88 | self.register_buffer(
89 | 'weight',
90 | weight
91 | )
92 | self.use_fake = True
93 | if del_quant:
94 | del self.qweight
95 | del self.scales
96 | del self.qzeros
97 | del self.g_idx
98 |
99 | def pack(self, linear, scales, zeros, g_idx=None):
100 | W = linear.weight.data.clone()
101 | if isinstance(linear, nn.Conv2d):
102 | W = W.flatten(1)
103 | if isinstance(linear, transformers.pytorch_utils.Conv1D):
104 | W = W.t()
105 |
106 | g_idx = torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32)
107 |
108 | scale_zeros = zeros * scales
109 | self.scales = nn.Parameter(scales.half())
110 | if linear.bias is not None:
111 | self.bias = linear.bias.clone().half()
112 |
113 | intweight = []
114 | for idx in range(self.infeatures):
115 | intweight.append(
116 | torch.round(
117 | (
118 | W[:, idx] + scale_zeros[g_idx[idx]]) / self.scales[g_idx[idx]]
119 | ).to(torch.int)[:, None]
120 | )
121 | intweight = torch.cat(intweight, dim=1)
122 | intweight = intweight.t().contiguous()
123 | intweight = intweight.numpy().astype(np.uint32)
124 |
125 | i = 0
126 | row = 0
127 | qweight = np.zeros((math.ceil(intweight.shape[0]/(32//self.bits)), intweight.shape[1]), dtype=np.uint32)
128 | while row < qweight.shape[0]:
129 | if self.bits in [2, 3, 4, 8]:
130 | for j in range(i, min(i + (32 // self.bits), intweight.shape[0])):
131 | qweight[row] |= intweight[j] << (self.bits * (j - i))
132 | i += 32 // self.bits
133 | row += 1
134 | else:
135 | raise NotImplementedError("Only 2,3,4,8 bits are supported.")
136 |
137 | qweight = qweight.astype(np.int32)
138 | self.qweight = torch.from_numpy(qweight)
139 |
140 | zeros = zeros.numpy().astype(np.uint32)
141 | self.zeros_dim0, self.zeros_dim1 = zeros.shape
142 | qzeros = np.zeros((zeros.shape[0], math.ceil(zeros.shape[1] / (32 // self.bits))), dtype=np.uint32)
143 | i = 0
144 | col = 0
145 | while col < qzeros.shape[1]:
146 | if self.bits in [2, 3, 4, 8]:
147 | for j in range(i, min(i + (32 // self.bits), zeros.shape[1])):
148 | qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
149 | i += 32 // self.bits
150 | col += 1
151 | else:
152 | raise NotImplementedError("Only 2,3,4,8 bits are supported.")
153 |
154 | qzeros = qzeros.astype(np.int32)
155 | self.qzeros = torch.from_numpy(qzeros)
156 |
157 | def forward(self, x):
158 | if self.use_fake:
159 | weight = self.weight
160 | if self.fake_transpose:
161 | weight = weight.transpose(0,1)
162 | else:
163 | weight = dequant_dim0(self.qweight, self.bits, self.maxq, self.infeatures, self.outfeatures)
164 | dim0, dim1 = weight.shape
165 | # dim2 = (dim1*dim0)//self.group_size
166 | zeros = dequant_dim1(self.qzeros, self.bits, self.maxq, self.zeros_dim0, self.zeros_dim1)
167 | weight = ((weight.view(-1, self.group_size, dim1) - zeros.view(-1, 1, dim1)) * self.scales.view(-1, 1, dim1)).reshape(dim0, dim1)
168 | # out = torch.matmul(x, weight)
169 | out = torch.matmul(x, weight.to(x.dtype))
170 | out = out + self.bias if self.bias is not None else out
171 | return out
172 |
173 |
174 | def load_quantized_model(model_path, wbits, group_size):
175 | print(f"Loading quantized model from {model_path}")
176 |
177 | # import pdb;pdb.set_trace()
178 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
179 | config = AutoConfig.from_pretrained(model_path)
180 | with init_empty_weights():
181 | model = AutoModelForCausalLM.from_config(config=config,torch_dtype=torch.float16, trust_remote_code=True)
182 | layers = model.model.layers
183 | for i in tqdm(range(len(layers))):
184 | layer = layers[i]
185 | named_linears = get_named_linears(layer, torch.nn.Linear)
186 | for name, module in named_linears.items():
187 | q_linear = QuantLinear(wbits, group_size, module.in_features, module.out_features, module.bias is not None)
188 | q_linear.to(next(layer.parameters()).device)
189 | set_op_by_name(layer, name, q_linear)
190 | torch.cuda.empty_cache()
191 | gc.collect()
192 | model.tie_weights()
193 | device_map = infer_auto_device_map(model)
194 | print("Loading pre-computed quantized weights...")
195 | load_checkpoint_in_model(model,checkpoint=model_path,device_map=device_map,offload_state_dict=True)
196 | print("Loading pre-computed quantized weights Successfully")
197 |
198 | return model, tokenizer
199 |
200 | __all__ = ["QuantLinear", "load_quantized_model"]
201 |
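The arithmetic in `QuantLinear.forward` can be checked off the GPU with the minimal NumPy sketch below. It is not the repository's code. It assumes the layout written by `pack()`: each 32-bit word holds `32 // bits` codes, with the first code in the lowest bits, and `qzeros` is packed along dim 1. The helpers `pack_dim1`, `unpack_dim1` and `dequantize` are hypothetical names used only for illustration.

```python
import numpy as np


def pack_dim1(values: np.ndarray, bits: int) -> np.ndarray:
    """Pack integer codes along columns, 32 // bits codes per word (lowest bits first)."""
    per_word = 32 // bits
    rows, cols = values.shape
    packed = np.zeros((rows, -(-cols // per_word)), dtype=np.uint32)
    for j in range(cols):
        shift = np.uint32(bits * (j % per_word))
        packed[:, j // per_word] |= values[:, j].astype(np.uint32) << shift
    return packed.astype(np.int32)  # stored as int32, like the qzeros buffer


def unpack_dim1(packed: np.ndarray, bits: int, n_cols: int) -> np.ndarray:
    """Inverse of pack_dim1: recover n_cols codes per row."""
    per_word = 32 // bits
    mask = np.uint32((1 << bits) - 1)
    cols = np.arange(n_cols)
    shifts = ((cols % per_word) * bits).astype(np.uint32)
    words = packed.astype(np.uint32)[:, cols // per_word]
    return ((words >> shifts) & mask).astype(np.int64)


def dequantize(q: np.ndarray, zeros: np.ndarray, scales: np.ndarray, group_size: int) -> np.ndarray:
    """W[i, o] = (q[i, o] - zeros[i // group_size, o]) * scales[i // group_size, o].

    q: (in_features, out_features) codes; zeros, scales: (in_features // group_size, out_features).
    """
    in_f, out_f = q.shape
    w = (q.reshape(-1, group_size, out_f) - zeros[:, None, :]) * scales[:, None, :]
    return w.reshape(in_f, out_f)
```

`dequantize` applies one `(scale, zero)` pair to each block of `group_size` consecutive input features. This is what the `view(-1, self.group_size, dim1)` reshape in `forward` does.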
--------------------------------------------------------------------------------
/quantize/quantizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import pdb
5 |
6 | CLIPMIN = 1e-4
7 |
8 |
9 |
10 | def round_ste(x: torch.Tensor):
11 | """
12 | Implement Straight-Through Estimator for rounding operation.
13 | """
14 | return (x.round() - x).detach() + x
15 |
16 | def clamp_ste(x: torch.Tensor, min, max):
17 | return (x.clamp(min,max) - x).detach() + x  # clamped forward value, identity gradient (STE)
18 |
21 |
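`round_ste` and `clamp_ste` use the detach trick. The forward value is the rounded (or clamped) tensor, while the gradient is that of the identity. The short PyTorch check below restates the two helpers and compares them with plain `torch.round`, whose autograd gradient is zero.

```python
import torch


def round_ste(x: torch.Tensor) -> torch.Tensor:
    # Forward value: x.round(). Backward: gradient of x (identity), because the
    # correction term (x.round() - x) is detached from the graph.
    return (x.round() - x).detach() + x


def clamp_ste(x: torch.Tensor, lo: float, hi: float) -> torch.Tensor:
    # Forward value: x.clamp(lo, hi). Backward: identity, even outside [lo, hi].
    return (x.clamp(lo, hi) - x).detach() + x


x = torch.tensor([0.2, 1.7, -2.6], requires_grad=True)
round_ste(x).sum().backward()
ste_grad = x.grad.clone()          # all ones: the gradient passes straight through

x_plain = torch.tensor([0.2, 1.7, -2.6], requires_grad=True)
x_plain.round().sum().backward()
plain_grad = x_plain.grad.clone()  # all zeros: torch.round blocks the gradient
```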
22 |
23 | class UniformAffineQuantizer(nn.Module):
24 | def __init__(
25 | self,
26 | n_bits: int = 8,
27 | group_size=None,
28 | weight=None,
29 | ):
30 | super().__init__()
31 | assert 2 <= n_bits <= 16, "bitwidth not supported"
32 | self.n_bits = n_bits
33 | self.qmin = 0
34 | self.qmax = 2 ** (n_bits) - 1
35 | self.group_size = group_size if group_size != -1 else weight.shape[-1]
36 | assert weight.shape[-1] % self.group_size == 0
37 | self.enable = True
38 |
39 | # init scale and zero point through Max-Min quantization
40 | with torch.no_grad():
41 | if weight is not None:
42 | x = weight.reshape(-1,self.group_size)
43 | xmin = x.amin([-1], keepdim=True)
44 | xmax = x.amax([-1], keepdim=True)
45 | x_range = xmax - xmin  # avoid shadowing the built-in range()
46 | scale = x_range / (2**self.n_bits-1)
47 | scale = scale.clamp(min=1e-4, max=1e4)
48 | zero_point = -(xmin/scale).clamp(min=-1e4, max=1e4)
49 | self.scale = nn.Parameter(scale)
50 | self.zero_point = nn.Parameter(zero_point.round())
51 |
52 |
53 | def change_n_bits(self, n_bits):
54 | self.n_bits = n_bits
55 | self.qmin = 0
56 | self.qmax = int(2 ** (n_bits) - 1)
57 |
58 | def fake_quant(self, x):
59 | scale = clamp_ste(self.scale,1e-4, 1e4)
60 | round_zero_point = clamp_ste(round_ste(self.zero_point), self.qmin, self.qmax)
61 |
62 | dim1, dim2 = x.shape
63 | x = x.reshape(-1, self.group_size)
64 | x_int = round_ste(x / scale)
65 | if round_zero_point is not None:
66 | x_int = x_int.add(round_zero_point)
67 | x_int = x_int.clamp(self.qmin, self.qmax)
68 | x_dequant = x_int
69 | if round_zero_point is not None:
70 | x_dequant = x_dequant.sub(round_zero_point)
71 | x_dequant = x_dequant.mul(scale)
72 | if self.group_size:
73 | x_dequant = x_dequant.reshape(dim1, dim2)
74 | return x_dequant
75 |
76 |
77 | def forward(self, x: torch.Tensor):
78 | if self.n_bits >= 16 or not self.enable:
79 | return x
80 |
81 | x_dequant = self.fake_quant(x)
82 | return x_dequant
83 |
84 |
85 |
86 |
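The constructor initializes `scale` and `zero_point` per group from min-max statistics, so the group minimum maps to code 0 and the maximum maps to `qmax`. The NumPy sketch below mirrors that initialization and the forward value of `fake_quant`. It does not model the straight-through gradients, and the function names are hypothetical. For a group that contains both negative and positive values, the reconstruction error stays within half a quantization step, `scale / 2`.

```python
import numpy as np


def minmax_init(weight: np.ndarray, n_bits: int, group_size: int):
    """Per-group scale and zero point, following UniformAffineQuantizer.__init__."""
    x = weight.reshape(-1, group_size)
    xmin = x.min(axis=-1, keepdims=True)
    xmax = x.max(axis=-1, keepdims=True)
    scale = np.clip((xmax - xmin) / (2 ** n_bits - 1), 1e-4, 1e4)
    zero_point = np.round(-np.clip(xmin / scale, -1e4, 1e4))
    return scale, zero_point


def fake_quant(weight: np.ndarray, scale, zero_point, n_bits: int, group_size: int) -> np.ndarray:
    """Forward value of UniformAffineQuantizer.fake_quant (gradients not modeled)."""
    qmax = 2 ** n_bits - 1
    zp = np.clip(zero_point, 0, qmax)               # zero point clamped to [qmin, qmax]
    x = weight.reshape(-1, group_size)
    x_int = np.clip(np.round(x / scale) + zp, 0, qmax)
    return ((x_int - zp) * scale).reshape(weight.shape)
```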
--------------------------------------------------------------------------------
/quantize/triton_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/EfficientQAT/39175493b2d14617d342a0a7956875e6ac16221b/quantize/triton_utils/__init__.py
--------------------------------------------------------------------------------
/quantize/triton_utils/custom_autotune.py:
--------------------------------------------------------------------------------
1 | import builtins
2 | import math
3 | import time
4 | from typing import Dict
5 |
6 | import triton
7 |
8 |
9 | # code based on https://github.com/fpgaminer/GPTQ-triton
10 | """
11 | Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
12 | """
13 |
14 |
15 | class CustomizedTritonAutoTuner(triton.KernelInterface):
16 | def __init__(
17 | self,
18 | fn,
19 | arg_names,
20 | configs,
21 | key,
22 | reset_to_zero,
23 | prune_configs_by: Dict = None,
24 | nearest_power_of_two: bool = False
25 | ):
26 | if not configs:
27 | self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
28 | else:
29 | self.configs = configs
30 | self.key_idx = [arg_names.index(k) for k in key]
31 | self.nearest_power_of_two = nearest_power_of_two
32 | self.cache = {}
33 | # hook to reset all required tensor to zeros before relaunching a kernel
34 | self.hook = lambda args: 0
35 | if reset_to_zero is not None:
36 | self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
37 |
38 | def _hook(args):
39 | for i in self.reset_idx:
40 | args[i].zero_()
41 |
42 | self.hook = _hook
43 | self.arg_names = arg_names
44 | # prune configs
45 | if prune_configs_by:
46 | perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
47 | # 'early_config_prune' is optional; default to None so the name is always bound
48 | early_config_prune = prune_configs_by.get('early_config_prune')
49 | else:
50 | perf_model, top_k, early_config_prune = None, None, None
51 | self.perf_model, self.configs_top_k = perf_model, top_k
52 | self.early_config_prune = early_config_prune
53 | self.fn = fn
54 |
55 | def _bench(self, *args, config, **meta):
56 | # check for conflicts, i.e. meta-parameters both provided
57 | # as kwargs and by the autotuner
58 | conflicts = meta.keys() & config.kwargs.keys()
59 | if conflicts:
60 | raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
61 | " Make sure that you don't re-define auto-tuned symbols.")
62 | # augment meta-parameters with tunable ones
63 | current = dict(meta, **config.kwargs)
64 |
65 | def kernel_call():
66 | if config.pre_hook:
67 | config.pre_hook(self.nargs)
68 | self.hook(args)
69 | self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
70 |
71 | try:
72 | # In testing, 40 reps seem to be close enough, and this appears to be what PyTorch uses
73 | # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
74 | return triton.testing.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40)
75 | except triton.OutOfResources:
77 | return (float('inf'), float('inf'), float('inf'))
78 |
79 | def run(self, *args, **kwargs):
80 | self.nargs = dict(zip(self.arg_names, args))
81 | if len(self.configs) > 1:
82 | key = tuple(args[i] for i in self.key_idx)
83 |
84 | # This reduces the amount of autotuning by rounding the keys to the nearest power of two
85 | # In my testing this gives decent results, and greatly reduces the amount of tuning required
86 | if self.nearest_power_of_two:
87 | key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
88 |
89 | if key not in self.cache:
90 | # prune configs
91 | pruned_configs = self.prune_configs(kwargs)
92 | bench_start = time.time()
93 | timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
94 | bench_end = time.time()
95 | self.bench_time = bench_end - bench_start
96 | self.cache[key] = builtins.min(timings, key=timings.get)
97 | self.hook(args)
98 | self.configs_timings = timings
99 | config = self.cache[key]
100 | else:
101 | config = self.configs[0]
102 | self.best_config = config
103 | if config.pre_hook is not None:
104 | config.pre_hook(self.nargs)
105 | return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
106 |
107 | def prune_configs(self, kwargs):
108 | pruned_configs = self.configs
109 | if self.early_config_prune:
110 | pruned_configs = list(self.early_config_prune(self.configs, self.nargs))  # pruners may be generators; len() is used below
111 | if self.perf_model:
112 | top_k = self.configs_top_k
113 | if isinstance(top_k, float) and top_k <= 1.0:
114 | top_k = int(len(self.configs) * top_k)
115 | if len(pruned_configs) > top_k:
116 | est_timing = {
117 | config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages,
118 | num_warps=config.num_warps) for config in pruned_configs}
119 | pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
120 | return pruned_configs
121 |
122 | def warmup(self, *args, **kwargs):
123 | self.nargs = dict(zip(self.arg_names, args))
124 | for config in self.prune_configs(kwargs):
125 | self.fn.warmup(
126 | *args,
127 | num_warps=config.num_warps,
128 | num_stages=config.num_stages,
129 | **kwargs,
130 | **config.kwargs,
131 | )
132 | self.nargs = None
133 |
134 |
135 | def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
136 | def decorator(fn):
137 | return CustomizedTritonAutoTuner(
138 | fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two
139 | )
140 |
141 | return decorator
142 |
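With `nearest_power_of_two=True`, `run` rounds each tuning key as `2 ** int(math.log2(x) + 0.5)`. This rounds in log space, so the threshold between `2**k` and `2**(k+1)` is the geometric midpoint `2**(k + 0.5)`, not the arithmetic one. The plain-Python snippet below restates the expression so its effect on typical sizes can be checked. The function name `round_key` is hypothetical.

```python
import math


def round_key(key):
    """Restates the key rounding in CustomizedTritonAutoTuner.run (nearest_power_of_two=True)."""
    # For x >= 1, int(log2(x) + 0.5) rounds log2(x) to the nearest integer (halves round up).
    return tuple(2 ** int(math.log2(x) + 0.5) for x in key)
```

For example, 6000 rounds up to 8192 even though it is arithmetically closer to 4096. All shapes that share a rounded key reuse one cached config.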
143 |
144 | def matmul248_kernel_config_pruner(configs, nargs):
145 | """
146 | The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
147 | """
148 | m = max(2 ** int(math.ceil(math.log2(nargs['M']))), 16)
149 | n = max(2 ** int(math.ceil(math.log2(nargs['N']))), 16)
150 | k = max(2 ** int(math.ceil(math.log2(nargs['K']))), 16)
151 |
152 | used = set()
153 | for config in configs:
154 | block_size_m = min(m, config.kwargs['BLOCK_SIZE_M'])
155 | block_size_n = min(n, config.kwargs['BLOCK_SIZE_N'])
156 | block_size_k = min(k, config.kwargs['BLOCK_SIZE_K'])
157 | group_size_m = config.kwargs['GROUP_SIZE_M']
158 |
159 | if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used:
160 | continue
161 |
162 | used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps))
163 | yield triton.Config(
164 | {
165 | 'BLOCK_SIZE_M': block_size_m,
166 | 'BLOCK_SIZE_N': block_size_n,
167 | 'BLOCK_SIZE_K': block_size_k,
168 | 'GROUP_SIZE_M': group_size_m
169 | },
170 | num_stages=config.num_stages,
171 | num_warps=config.num_warps
172 | )
173 |
174 | def hadamard248_kernel_config_pruner(configs, nargs):
175 | """
176 | The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
177 | """
178 | m = max(2 ** int(math.ceil(math.log2(nargs['M']))), 16)
179 | n = max(2 ** int(math.ceil(math.log2(nargs['N']))), 16)
180 |
181 | used = set()
182 | for config in configs:
183 | block_size_m = min(m, config.kwargs['BLOCK_SIZE_M'])
184 | block_size_n = min(n, config.kwargs['BLOCK_SIZE_N'])
185 |
186 | if (block_size_m, block_size_n , config.num_stages, config.num_warps) in used:
187 | continue
188 |
189 | used.add((block_size_m, block_size_n, config.num_stages, config.num_warps))
190 | yield triton.Config(
191 | {
192 | 'BLOCK_SIZE_M': block_size_m,
193 | 'BLOCK_SIZE_N': block_size_n,
194 | },
195 | num_stages=config.num_stages,
196 | num_warps=config.num_warps
197 | )
198 |
199 |
200 | __all__ = ["autotune"]
201 |
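`hadamard248_kernel_config_pruner` clamps each block size to the next power of two at or above the matching dimension (at least 16). It then drops configs that have become identical. The plain-Python restatement below makes the deduplication visible. It uses `(BLOCK_SIZE_M, BLOCK_SIZE_N, num_stages, num_warps)` tuples in place of `triton.Config` objects, a simplification assumed for illustration, and the name `shrink_block_sizes` is hypothetical.

```python
import math


def shrink_block_sizes(configs, M, N):
    """configs: list of (BLOCK_SIZE_M, BLOCK_SIZE_N, num_stages, num_warps) tuples."""
    m = max(2 ** int(math.ceil(math.log2(M))), 16)
    n = max(2 ** int(math.ceil(math.log2(N))), 16)
    seen = set()
    pruned = []
    for bm, bn, stages, warps in configs:
        cand = (min(m, bm), min(n, bn), stages, warps)
        if cand in seen:  # identical after shrinking: benchmark it only once
            continue
        seen.add(cand)
        pruned.append(cand)
    return pruned
```

For a small `M` such as 8, every `BLOCK_SIZE_M` collapses to 16, so the seven `dequant_kernel_dim0` configs reduce to five candidates to benchmark.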
--------------------------------------------------------------------------------
/quantize/triton_utils/kernels.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.cuda.amp import custom_bwd, custom_fwd
3 | from logging import getLogger
4 |
5 | import triton
6 | import triton.language as tl
7 |
8 | from . import custom_autotune
9 | import pdb
10 |
11 | logger = getLogger(__name__)
12 |
13 |
14 | # code based on https://github.com/fpgaminer/GPTQ-triton
15 | @custom_autotune.autotune(
16 | configs=[
17 | triton.Config(
18 | {
19 | 'BLOCK_SIZE_M': 64,
20 | 'BLOCK_SIZE_N': 256,
21 | },
22 | num_stages=4,
23 | num_warps=4
24 | ),
25 | triton.Config(
26 | {
27 | 'BLOCK_SIZE_M': 128,
28 | 'BLOCK_SIZE_N': 128,
29 | },
30 | num_stages=4,
31 | num_warps=4
32 | ),
33 | triton.Config(
34 | {
35 | 'BLOCK_SIZE_M': 64,
36 | 'BLOCK_SIZE_N': 128,
37 | },
38 | num_stages=4,
39 | num_warps=4
40 | ),
41 | triton.Config(
42 | {
43 | 'BLOCK_SIZE_M': 128,
44 | 'BLOCK_SIZE_N': 32,
45 | },
46 | num_stages=4,
47 | num_warps=4
48 | ),
49 | triton.Config(
50 | {
51 | 'BLOCK_SIZE_M': 64,
52 | 'BLOCK_SIZE_N': 64,
53 | },
54 | num_stages=4,
55 | num_warps=4
56 | ),
57 | triton.Config(
58 | {
59 | 'BLOCK_SIZE_M': 64,
60 | 'BLOCK_SIZE_N': 128,
61 | },
62 | num_stages=2,
63 | num_warps=8
64 | ),
65 | triton.Config(
66 | {
67 | 'BLOCK_SIZE_M': 32,
68 | 'BLOCK_SIZE_N': 128,
69 | },
70 | num_stages=4,
71 | num_warps=4
72 | ),
73 | ],
74 | key=['M', 'N'],
75 | nearest_power_of_two=True,
76 | prune_configs_by={
77 | 'early_config_prune': custom_autotune.hadamard248_kernel_config_pruner,
78 | 'perf_model': None,
79 | 'top_k': None,
80 | },
81 | )
82 | @triton.jit
83 | def dequant_kernel_dim0(
84 | b_ptr, c_ptr,
85 | M, N,
86 | bits, maxq,
87 | stride_bk, stride_bn,
88 | stride_cm, stride_cn,
89 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr
90 | ):
91 | """
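92 | Unpack the bit-packed int32 tensor B into integer codes along dim 0.
93 | B is of shape (M/(32//bits), N) int32; C is of shape (M, N) float16.
94 | Each C[m, n] holds the raw code; scales and zero points are applied by the caller.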
95 | """
96 |
97 | bits_per_feature = 32 // bits
98 |
99 | pid = tl.program_id(axis=0)
100 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
101 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
102 |
103 | pid_m = pid // num_pid_n
104 | pid_n = pid % num_pid_n
105 |
106 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
107 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
108 |
109 |
110 | b_ptrs = b_ptr + ((offs_am[:, None] // bits_per_feature) * stride_bk + offs_bn[None, :] * stride_bn)
111 |
112 | shifter = (offs_am[:, None] % bits_per_feature) * bits
113 |
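114 | # mask partial edge tiles so the load never reads past the packed buffer
115 | mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
116 | b = tl.load(b_ptrs, mask=mask, other=0)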
117 | b = (b >> shifter) & maxq
118 |
119 | c = b
120 |
121 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
122 | c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
123 | tl.store(c_ptrs, c, mask=c_mask)
124 |
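`dequant_kernel_dim0` inverts the packing loop in `QuantLinear.pack`. Element `m` of a column lives in word `m // (32 // bits)` at bit offset `(m % (32 // bits)) * bits`. The kernel only unpacks integer codes; `QuantLinear` subtracts zeros and multiplies by scales afterwards. The NumPy reference below can be used to cross-check the kernel on small inputs; the helper names are hypothetical. It works on `uint32` values, whereas the kernel shifts signed `int32`. The two give the same result after masking with `maxq`, because the largest shift plus `bits` never exceeds 32.

```python
import numpy as np


def pack_dim0(intweight: np.ndarray, bits: int) -> np.ndarray:
    """Same layout as QuantLinear.pack: row i goes to word i // (32 // bits), lowest bits first."""
    per_word = 32 // bits
    rows, cols = intweight.shape
    packed = np.zeros((-(-rows // per_word), cols), dtype=np.uint32)
    for i in range(rows):
        packed[i // per_word] |= intweight[i].astype(np.uint32) << np.uint32(bits * (i % per_word))
    return packed.astype(np.int32)


def unpack_dim0(packed: np.ndarray, bits: int, rows: int) -> np.ndarray:
    """CPU reference for dequant_kernel_dim0: integer codes, no scale or zero applied."""
    per_word = 32 // bits
    maxq = np.uint32(2 ** bits - 1)
    r = np.arange(rows)
    shifter = ((r % per_word) * bits).astype(np.uint32)[:, None]
    return (packed.astype(np.uint32)[r // per_word] >> shifter) & maxq
```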
125 | @custom_autotune.autotune(
126 | configs=[
127 | triton.Config(
128 | {
129 | 'BLOCK_SIZE_M': 2,
130 | 'BLOCK_SIZE_N':128,
131 | },
132 | num_stages=8,
133 | num_warps=8
134 | ),
135 | triton.Config(
136 | {
137 | 'BLOCK_SIZE_M': 2,
138 | 'BLOCK_SIZE_N':64,
139 | },
140 | num_stages=8,
141 | num_warps=8
142 | ),
143 | triton.Config(
144 | {
145 | 'BLOCK_SIZE_M': 2,
146 | 'BLOCK_SIZE_N':32,
147 | },
148 | num_stages=8,
149 | num_warps=8
150 | ),
151 | triton.Config(
152 | {
153 | 'BLOCK_SIZE_M': 2,
154 | 'BLOCK_SIZE_N':2,
155 | },
156 | num_stages=8,
157 | num_warps=8
158 | ),
159 | ],
160 | key=['M', 'N'],
161 | nearest_power_of_two=True,
162 | prune_configs_by={
163 | 'early_config_prune': custom_autotune.hadamard248_kernel_config_pruner,
164 | 'perf_model': None,
165 | 'top_k': None,
166 | },
167 | )
168 | @triton.jit
169 | def dequant_kernel_dim1(
170 | b_ptr, c_ptr,
171 | M, N,
172 | bits, maxq,
173 | stride_bk, stride_bn,
174 | stride_cm, stride_cn,
175 | BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr
176 | ):
177 | """
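178 | Unpack the bit-packed int32 tensor B into integer codes along dim 1.
179 | B is of shape (M, N/(32//bits)) int32; C is of shape (M, N) float16.
180 | Each C[m, n] holds the raw code; scales and zero points are applied by the caller.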
181 | """
182 |
183 | bits_per_feature = 32 // bits
184 |
185 | pid = tl.program_id(axis=0)
186 | num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
187 | num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
188 |
189 | pid_m = pid // num_pid_n
190 | pid_n = pid % num_pid_n
191 |
192 | offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
193 | offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
194 |
195 |
196 | # b_ptrs = b_ptr + ((offs_am[:, None] // bits_per_feature) * stride_bk + offs_bn[None, :] * stride_bn)
197 | b_ptrs = b_ptr + (offs_am[:, None] * stride_bk + (offs_bn[None, :] // bits_per_feature) * stride_bn)
198 |
199 | # shifter = (offs_am[:, None] % bits_per_feature) * bits
200 | shifter = (offs_bn[None, :] % bits_per_feature) * bits
201 |
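202 | # mask partial edge tiles so the load never reads past the packed buffer
203 | mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
204 | b = tl.load(b_ptrs, mask=mask, other=0)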
205 | b = (b >> shifter) & maxq
206 |
207 | c = b
208 |
209 | c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
210 | c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
211 | tl.store(c_ptrs, c, mask=c_mask)
212 |
213 |
214 | @triton.jit
215 | def silu(x):
216 | return x * tl.sigmoid(x)
217 |
218 |
219 | def dequant_dim0(qweight, bits, maxq, infeatures, outfeatures):
220 | with torch.cuda.device(qweight.device):
221 | output = torch.empty((infeatures, outfeatures), device=qweight.device, dtype=torch.float16)
222 | grid = lambda META: (
223 | triton.cdiv(output.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output.shape[1], META['BLOCK_SIZE_N']),
224 | )
225 | dequant_kernel_dim0[grid](
226 | qweight, output,
227 | output.shape[0], output.shape[1],
228 | bits, maxq,
229 | qweight.stride(0), qweight.stride(1),
230 | output.stride(0), output.stride(1),
231 | )
232 | return output
233 |
234 | def dequant_dim1(qweight, bits, maxq, infeatures, outfeatures):
235 | with torch.cuda.device(qweight.device):
236 | output = torch.empty((infeatures, outfeatures), device=qweight.device, dtype=torch.float16)
237 | grid = lambda META: (
238 | triton.cdiv(output.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output.shape[1], META['BLOCK_SIZE_N']),
239 | )
240 | dequant_kernel_dim1[grid](
241 | qweight, output,
242 | output.shape[0], output.shape[1],
243 | bits, maxq,
244 | qweight.stride(0), qweight.stride(1),
245 | output.stride(0), output.stride(1),
246 | )
247 | return output
248 |
249 |
250 |
251 |
252 |
253 |
254 |
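Both wrappers launch `cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)` programs. Each `pid` maps to a tile in row-major order (`pid_m = pid // num_pid_n`, `pid_n = pid % num_pid_n`), and `c_mask` discards writes past the edges of partial tiles. The plain-Python simulation below counts how often each output element is written, to confirm that every element is covered exactly once. It is a hypothetical checker, not Triton code.

```python
def tile_coverage(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N):
    """Count writes per (m, n) element under the pid -> tile mapping and store mask."""
    num_pid_m = -(-M // BLOCK_SIZE_M)  # cdiv
    num_pid_n = -(-N // BLOCK_SIZE_N)
    counts = [[0] * N for _ in range(M)]
    for pid in range(num_pid_m * num_pid_n):  # grid size used by dequant_dim0 / dequant_dim1
        pid_m, pid_n = pid // num_pid_n, pid % num_pid_n
        for m in range(pid_m * BLOCK_SIZE_M, (pid_m + 1) * BLOCK_SIZE_M):
            for n in range(pid_n * BLOCK_SIZE_N, (pid_n + 1) * BLOCK_SIZE_N):
                if m < M and n < N:  # same condition as c_mask
                    counts[m][n] += 1
    return counts
```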
--------------------------------------------------------------------------------
/quantize/triton_utils/mixin.py:
--------------------------------------------------------------------------------
1 | class TritonModuleMixin:
2 | @classmethod
3 | def warmup(cls, model, transpose=False, seqlen=2048):
4 | pass
5 |
--------------------------------------------------------------------------------
/quantize/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from quantize.int_linear_fake import QuantLinear
3 | import torch
4 | from torch import nn
5 | from typing import Optional
6 |
7 |
8 | class MultiBlock(nn.Module):
9 | def __init__(self, *args, **kwargs) -> None:
10 | super().__init__(*args, **kwargs)
11 | self.block_list = nn.ModuleList([])
12 |
13 | def add_block(self, block):
14 | self.block_list.append(block)
15 |
16 | def forward(self,
17 | hidden_states: torch.Tensor,
18 | attention_mask: Optional[torch.Tensor] = None,
19 | position_ids: Optional[torch.LongTensor] = None):
20 | for block in self.block_list:
21 | hidden_states = block(hidden_states, attention_mask=attention_mask,position_ids=position_ids)[0]
22 | return (hidden_states, )
23 |
24 |
25 | def set_weight_parameters(model, requires_grad):
26 | params = []
27 | for n, m in model.named_parameters():
28 | if n.find('weight') > -1 and not (n.find('scale') > -1 or n.find('zero_point') > -1):
29 | m.requires_grad = requires_grad
30 | return iter(params)
31 |
32 | def weight_parameters(model):
33 | params = []
34 | for n, m in model.named_parameters():
35 | if n.find('weight') > -1 and not (n.find('scale') > -1 or n.find('zero_point') > -1):
36 | params.append(m)
37 | return iter(params)
38 |
39 | def set_quant_parameters(model, requires_grad):
40 | params = []
41 | for n, m in model.named_parameters():
42 | if n.find('scale') > -1 or n.find('zero_point') > -1:
43 | m.requires_grad = requires_grad
44 | return iter(params)
45 |
46 | def quant_parameters(model):
47 | params = []
48 | for n, m in model.named_parameters():
49 | if n.find('scale') > -1 or n.find('zero_point') > -1:
50 | params.append(m)
51 | return iter(params)
52 |
53 |
54 | def trainable_parameters(model):
55 | params = []
56 | for n, m in model.named_parameters():
57 | if m.requires_grad:
58 | params.append(m)
59 | return iter(params)
60 |
61 | def trainable_parameters_num(model):
62 | params = []
63 | total = 0
64 | for n, m in model.named_parameters():
65 | if m.requires_grad:
66 | total += m.numel()
67 | return total
68 |
69 | def set_quant_state(model, weight_quant: bool = False):
70 | for m in model.modules():
71 | if isinstance(m, QuantLinear):
72 | m.set_quant_state(weight_quant)
73 |
74 | @torch.no_grad()
75 | def quant_inplace(model):
76 | for name, module in model.named_modules():
77 | if isinstance(module, QuantLinear):
78 | module.weight.data = module.weight_quantizer(module.weight.data)
79 |
80 |
81 | class TruncateFunction(torch.autograd.Function):
82 | @staticmethod
83 | def forward(ctx, input, threshold):
84 | truncated_tensor = input.clone()
85 | truncated_tensor[truncated_tensor.abs() < threshold] = truncated_tensor[truncated_tensor.abs() < threshold].sign() * threshold
86 | return truncated_tensor
87 |
88 |
89 | @staticmethod
90 | def backward(ctx, grad_output):
91 | grad_input = grad_output.clone()
92 | return grad_input, None
93 |
94 |
95 | def truncate_number(number, threshold=1e-2):
96 | # avoid overflow with AMP training
97 | return TruncateFunction.apply(number, threshold)
98 |
99 |
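In the forward pass, `TruncateFunction` pushes values whose magnitude is below `threshold` out to `±threshold`. In the backward pass it passes the gradient through unchanged. The NumPy sketch of the forward rule below (the name `truncate_forward` is hypothetical) shows one detail worth knowing: because `sign(0) == 0`, an exact zero stays at zero rather than moving to the threshold.

```python
import numpy as np


def truncate_forward(x: np.ndarray, threshold: float = 1e-2) -> np.ndarray:
    """Forward rule of TruncateFunction; the backward pass (not shown) is the identity."""
    out = x.copy()
    small = np.abs(out) < threshold
    out[small] = np.sign(out[small]) * threshold  # sign(0) == 0, so zeros stay zero
    return out
```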
100 | def get_named_linears(module, type):
101 | # return {name: m for name, m in module.named_modules() if isinstance(m, torch.nn.Linear)}
102 | return {name: m for name, m in module.named_modules() if isinstance(m, type)}
103 |
104 | def set_op_by_name(layer, name, new_module):
105 | levels = name.split('.')
106 | if len(levels) > 1:
107 | mod_ = layer
108 | for l_idx in range(len(levels)-1):
109 | if levels[l_idx].isdigit():
110 | mod_ = mod_[int(levels[l_idx])]
111 | else:
112 | mod_ = getattr(mod_, levels[l_idx])
113 | setattr(mod_, levels[-1], new_module)
114 | else:
115 | setattr(layer, name, new_module)
116 |
117 | # def add_new_module(name, original_module, added_module):
118 | # levels = name.split('.')
119 | # if len(levels) > 1:
120 | # mod_ = original_module
121 | # for l_idx in range(len(levels)-1):
122 | # if levels[l_idx].isdigit():
123 | # mod_ = mod_[int(levels[l_idx])]
124 | # else:
125 | # mod_ = getattr(mod_, levels[l_idx])
126 | # setattr(mod_, levels[-1], added_module)
127 | # else:
128 | # setattr(original_module, name, added_module)
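`set_op_by_name` walks a dotted module path such as `self_attn.q_proj`, indexing into containers when a path component is numeric, and replaces the leaf. The sketch below restates it compactly. It exercises the function on `SimpleNamespace` and list stand-ins instead of real `nn.Module`s, an assumption made so the example runs without PyTorch. It behaves like the listing above for both single-level and multi-level names.

```python
from types import SimpleNamespace


def set_op_by_name(layer, name, new_module):
    # Walk all but the last component; numeric parts index into sequences (e.g. nn.ModuleList).
    *parents, leaf = name.split('.')
    mod_ = layer
    for part in parents:
        mod_ = mod_[int(part)] if part.isdigit() else getattr(mod_, part)
    setattr(mod_, leaf, new_module)


# Hypothetical stand-in for a decoder layer.
layer = SimpleNamespace(self_attn=SimpleNamespace(q_proj="linear"),
                        blocks=[SimpleNamespace(w1="linear")])
set_op_by_name(layer, "self_attn.q_proj", "quant_linear")
```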
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.28.0
2 | bitsandbytes==0.41.0
3 | datasets==2.18.0
4 | lm_eval==0.4.2
5 | numpy==1.23.4
6 | torch==2.2.2
7 | tqdm==4.64.1
8 | transformers==4.40.1
9 | triton==2.2.0
10 | termcolor
11 | sentencepiece
12 | protobuf
13 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from math import inf
3 | import logging
4 | from termcolor import colored
5 | import sys
6 | import os
7 | import time
8 |
9 |
10 | def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
11 | if isinstance(parameters, torch.Tensor):
12 | parameters = [parameters]
13 | parameters = [p for p in parameters if p.grad is not None]
14 | norm_type = float(norm_type)
15 | if len(parameters) == 0:
16 | return torch.tensor(0.)
17 | device = parameters[0].grad.device
18 | if norm_type == inf:
19 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
20 | else:
21 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(),
22 | norm_type).to(device) for p in parameters]), norm_type)
23 | return total_norm
24 |
25 | class NativeScalerWithGradNormCount:
26 | state_dict_key = "amp_scaler"
27 |
28 | def __init__(self):
29 | self._scaler = torch.cuda.amp.GradScaler()
30 |
31 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True,retain_graph=False):
32 | self._scaler.scale(loss).backward(create_graph=create_graph, retain_graph=retain_graph)
33 | if update_grad:
34 | if clip_grad is not None:
35 | assert parameters is not None
36 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
37 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
38 | else:
39 | self._scaler.unscale_(optimizer)
40 | norm = ampscaler_get_grad_norm(parameters)
41 | self._scaler.step(optimizer)
42 | self._scaler.update()
43 | else:
44 | norm = None
45 | return norm
46 |
47 | def state_dict(self):
48 | return self._scaler.state_dict()
49 |
50 | def load_state_dict(self, state_dict):
51 | self._scaler.load_state_dict(state_dict)
52 |
53 |
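`ampscaler_get_grad_norm` computes the global gradient norm as the p-norm of the per-parameter p-norms. This equals the p-norm of all gradients flattened together; for `inf` it takes the largest absolute entry. The NumPy restatement below uses arrays instead of `.grad` tensors, with `None` standing in for parameters that have no gradient. The name `total_grad_norm` is hypothetical.

```python
import numpy as np


def total_grad_norm(grads, norm_type=2.0):
    """p-norm of per-tensor p-norms, skipping missing gradients (mirrors ampscaler_get_grad_norm)."""
    grads = [g for g in grads if g is not None]
    if not grads:
        return 0.0
    if norm_type == np.inf:
        return max(float(np.abs(g).max()) for g in grads)
    per_tensor = np.array([np.linalg.norm(g.ravel(), ord=norm_type) for g in grads])
    return float(np.linalg.norm(per_tensor, ord=norm_type))
```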
54 | def create_logger(output_dir, dist_rank=0, name=''):
55 | # create logger
56 | logger = logging.getLogger(name)
57 | logger.setLevel(logging.INFO)
58 | logger.propagate = False
59 |
60 | # create formatter
61 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
62 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
63 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
64 |
65 | # create console handlers for master process
66 | if dist_rank == 0:
67 | console_handler = logging.StreamHandler(sys.stdout)
68 | console_handler.setLevel(logging.DEBUG)
69 | console_handler.setFormatter(
70 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
71 | logger.addHandler(console_handler)
72 |
73 | # create file handlers
74 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}_{int(time.time())}.txt'), mode='a')
75 | file_handler.setLevel(logging.DEBUG)
76 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
77 | logger.addHandler(file_handler)
78 |
79 | return logger
--------------------------------------------------------------------------------