├── Dockerfile
├── README.md
├── bm25_retriever.py
├── build.sh
├── config.py
├── data
│   ├── result.json
│   ├── test_question.json
│   └── train_a.pdf
├── faiss_retriever.py
├── images
│   ├── 01.png
│   ├── 02.png
│   └── 03.png
├── pdf_parse.py
├── pre_train_model
│   └── Qwen-7B-Chat
│       └── download.py
├── qwen_generation_utils.py
├── requirements.txt
├── rerank_model.py
├── run.py
├── run.sh
├── vllm_model.py
└── vllm_wrapper.py
/Dockerfile:
--------------------------------------------------------------------------------
1 | # FROM registry.cn-shanghai.aliyuncs.com/tcc-public/pytorch:2.0.0-py3.9.12-cuda11.8.0-u22.04-cudnn
2 |
3 | FROM registry.cn-shanghai.aliyuncs.com/aicar/vllm:base
4 |
5 | # If you need to install additional system packages
6 | # RUN apt-get update && apt-get install curl
7 | # If you need to install additional Python packages
8 | #pip3 install numpy --index-url=http://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
9 | # RUN pip install --progress-bar off numpy pandas PyPDF2 langchain jieba rank_bm25 sentence-transformers faiss-gpu modelscope tiktoken transformers_stream_generator accelerate pdfplumber --index-url=http://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
10 | # Copy the code into the image
11 | COPY app /app
12 |
13 | # Set the working directory
14 | WORKDIR /app
15 |
16 | # Command executed when the container starts
17 | CMD ["bash", "run.sh"]
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Tianchi-LLM-QA
2 | Alibaba Tianchi: 2023 Global Intelligent Automotive AI Challenge, Track 1 (LLM-based retrieval QA), baseline scoring 80+
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | ### 1. Code structure
12 |
13 | ```text
14 | .
15 | ├── Dockerfile
16 | ├── README.md
17 | ├── bm25_retriever.py
18 | ├── build.sh
19 | ├── config.py
20 | ├── data
21 | │ ├── result.json
22 | │ ├── test_question.json
23 | │ └── train_a.pdf
24 | ├── faiss_retriever.py
25 | ├── vllm_model.py
26 | ├── pdf_parse.py
27 | ├── pre_train_model
28 | │ ├── Qwen-7B-Chat
29 | │ │ └── download.py
30 | │ ├── bge-reranker-large
31 | │ └── m3e-large
32 | ├── qwen_generation_utils.py
33 | ├── requirements.txt
34 | ├── rerank_model.py
35 | ├── run.py
36 | ├── run.sh
37 | └── vllm_wrapper.py
38 | ```
39 |
40 | ### 2. [Competition overview](https://tianchi.aliyun.com/competition/entrance/532154)
41 | #### 2.1 Task: document retrieval QA with a large language model
42 |
43 | Task: contestants build a question-answering system centered on a large language model that answers users' automotive questions. Given a question, the system must locate the relevant passages in the document and use the model to generate an answer grounded in the document content. The questions mainly cover vehicle usage, repair, and maintenance, for example:
44 |
45 | Question 1: How do I turn on the hazard warning lights?
46 | Answer 1: The hazard warning light switch is below the steering wheel; press it to turn on the hazard warning lights.
47 |
48 | Question 2: How should the vehicle be maintained?
49 | Answer 2: To keep the vehicle in optimal condition, pay regular attention to its state, including scheduled maintenance, washing, interior and exterior cleaning, tire care, and care of the low-voltage battery.
50 |
51 | Question 3: What should I do if the seat back is too hot?
52 | Answer 3: If the seat back is too hot, try turning off seat heating. On the multimedia display, tap the air-conditioning button → Seats → Heating, and switch off seat heating on that screen.
53 |
54 | #### 2.2 Data (for the final round the organizers released only partial reference samples)
55 |
56 | [初赛训练数据集.pdf](https://tianchi-race-prod-sh.oss-cn-shanghai.aliyuncs.com/file/race/documents/532154/%E5%88%9D%E8%B5%9B%E8%AE%AD%E7%BB%83%E9%9B%86/%E5%88%9D%E8%B5%9B%E8%AE%AD%E7%BB%83%E6%95%B0%E6%8D%AE%E9%9B%86.pdf?Expires=1703022585&OSSAccessKeyId=LTAI5t7fj2oKqzKgLGz6kGQc&Signature=pg9tnYgHDLkAlfCU%2Bs3h3QBrvfA%3D&response-content-disposition=attachment%3B%20)
57 |
58 | [测试问题.json](https://tianchi-race-prod-sh.oss-cn-shanghai.aliyuncs.com/file/race/documents/532154/%E5%85%B6%E5%AE%83/%E6%B5%8B%E8%AF%95%E9%97%AE%E9%A2%98.json?Expires=1703022684&OSSAccessKeyId=LTAI5t7fj2oKqzKgLGz6kGQc&Signature=kTn%2BN4ZnY9tftVmz5kjNKOCoFAs%3D&response-content-disposition=attachment%3B%20)
59 |
60 |
61 | ### 3. Solution
62 |
63 | #### 3.1 PDF parsing
64 |
65 | ##### 3.1.1 Block-level PDF parsing
66 | 
67 | As shown in the figure, we want the PDF to be parsed block by block wherever possible, with each block treated as one sample; this preserves the integrity of the text in the PDF as far as possible.
68 | Possible improvement: use OCR to detect the blocks in the PDF.
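
A condensed sketch of the idea behind `ParseBlock` in `pdf_parse.py` (header detection and most filtering are omitted here): consecutive words that pdfplumber reports with the same font size are grouped into one block.

```python
import pdfplumber

def parse_blocks(pdf_path):
    """Group consecutive words with the same font size into one block per page."""
    blocks = []
    with pdfplumber.open(pdf_path) as pdf:
        for page in pdf.pages:
            words = page.extract_words(use_text_flow=True, extra_attrs=["size"])
            block, last_size = "", None
            for w in words:
                if last_size is None or abs(w["size"] - last_size) < 1e-5:
                    block += w["text"]           # same font size: still the same block
                else:
                    if block:
                        blocks.append(block)     # font size changed: close the block
                    block = w["text"]
                last_size = w["size"]
            if block:
                blocks.append(block)
    return blocks

# blocks = parse_blocks("data/train_a.pdf")
```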
69 |
70 | ##### 3.1.2 Sliding-window PDF parsing
71 | 
72 | 
73 | As shown in figures 1 and 2, the text runs continuously from one page to the next. To preserve this cross-page continuity we use a sliding-window approach.
74 | Concretely, the entire PDF is treated as one string, split on the Chinese full stop (。), and a window is slid over the resulting array of sentences. For example, given:
75 |
76 | ["aa","bb","cc","dd"]
77 |
78 | with a window size (kernel) of 4 characters, the sliding window produces:
79 |
80 | aabb
81 |
82 | bbcc
83 |
84 | ccdd
85 |
86 | Like a convolution, the sliding window can be run with different kernel sizes and strides to find the segmentation that yields the best recall coverage.
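
A minimal sketch of this sliding window, mirroring `DataProcess.SlidingWindow` in `pdf_parse.py` (the repo fixes the stride at one sentence and also de-duplicates chunks, which is omitted here):

```python
def sliding_window(sentences, kernel=512):
    """Emit a chunk whenever the window grows past `kernel` characters, then slide forward by one sentence."""
    chunks, cur, slow = [], "", 0
    for sentence in sentences:
        if len(cur + sentence) > kernel:
            chunks.append(cur + sentence + "。")
            cur = cur[len(sentences[slow] + "。"):]  # drop the oldest sentence from the window
            slow += 1
        cur = cur + sentence + "。"
    return chunks

print(sliding_window(["aa", "bb", "cc", "dd"], kernel=4))
# ['aa。bb。', 'bb。cc。', 'cc。dd。']
```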
87 |
88 | #### 3.2 Recall
89 |
90 | Recall uses the retrievers provided by LangChain. Dense (vector) recall and BM25 recall are complementary, so both are used.
91 |
92 | ##### 3.2.1 Vector recall
93 |
94 | Vector recall uses FAISS for index construction and lookup; the embeddings come from [M3E-large](https://modelscope.cn/models/Jerry0/M3E-large/summary) or [bge-large-zh](https://modelscope.cn/models/AI-ModelScope/bge-large-zh/summary).
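
A condensed sketch of what `faiss_retriever.py` does; the model path is the local m3e-large download and the two chunks are toy stand-ins for the segments produced by the PDF parser:

```python
from langchain.schema import Document
from langchain.vectorstores import FAISS
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

# toy chunks standing in for the segments produced by pdf_parse.py
chunks = ["座椅加热:在多媒体显示屏点击空调→座椅→加热。",
          "危险警告灯开关在方向盘下方,按下开关即可打开。"]
docs = [Document(page_content=c, metadata={"id": i}) for i, c in enumerate(chunks)]

embeddings = HuggingFaceEmbeddings(model_name="pre_train_model/m3e-large",
                                   model_kwargs={"device": "cuda"})
vector_store = FAISS.from_documents(docs, embeddings)

# top-k chunks with their distances (a smaller distance means a closer match)
hits = vector_store.similarity_search_with_score("怎样加热座椅?", k=2)
```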
95 |
96 | ##### 3.2.2 BM25 recall
97 |
98 | BM25 recall uses LangChain's built-in BM25 retriever.
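
A condensed sketch of the BM25 side, mirroring `bm25_retriever.py`: text is tokenized with jieba before it is handed to LangChain's `BM25Retriever`, and the untokenized chunk is recovered through the `id` metadata:

```python
import jieba
from langchain.retrievers import BM25Retriever
from langchain.schema import Document

chunks = ["座椅加热:在多媒体显示屏点击空调→座椅→加热。",
          "危险警告灯开关在方向盘下方,按下开关即可打开。"]

# index the space-joined jieba tokens, keep the original chunk id in the metadata
docs = [Document(page_content=" ".join(jieba.cut_for_search(c)), metadata={"id": i})
        for i, c in enumerate(chunks)]
retriever = BM25Retriever.from_documents(docs)
retriever.k = 2

hits = retriever.get_relevant_documents(" ".join(jieba.cut_for_search("怎样加热座椅?")))
answers = [chunks[d.metadata["id"]] for d in hits]
```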
99 |
100 | #### 3.3 Re-ranking
101 |
102 | 1. Re-ranking reorders the recalled passages to obtain a more precise and smaller set of candidate answers.
103 | 2. Vector recall uses a bi-encoder, while bge-reranker-large is a cross-encoder; a cross-encoder generally scores query-passage relevance more accurately than a bi-encoder.
104 |
105 | ##### 3.3.1 Cross-encoder
106 |
107 | Re-ranking uses [bge-reranker-large](https://modelscope.cn/models/Xorbits/bge-reranker-large/files).
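
A condensed sketch of the cross-encoder scoring in `rerank_model.py` (the model path points at the local bge-reranker-large download, the candidates are toy examples): each query-passage pair is encoded jointly, which is what distinguishes it from the bi-encoder used for recall:

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_path = "pre_train_model/bge-reranker-large"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path).eval()

query = "怎样加热座椅?"
candidates = ["座椅加热:在多媒体显示屏点击空调→座椅→加热。",
              "更换雨刮片前请将雨刮臂立起。"]

# score each (query, passage) pair jointly with the cross-encoder
pairs = [(query, c) for c in candidates]
with torch.no_grad():
    inputs = tokenizer(pairs, padding=True, truncation=True, max_length=512, return_tensors="pt")
    scores = model(**inputs).logits.view(-1)

# candidates sorted from most to least relevant
reranked = [c for _, c in sorted(zip(scores.tolist(), candidates), reverse=True)]
```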
108 |
109 | #### 3.4 Inference optimization
110 |
111 | ##### 3.4.1 vLLM batch inference
112 |
113 | vLLM speeds up inference with its PagedAttention mechanism, and batch inference is close to twice as fast as issuing requests one at a time.
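
A condensed sketch of the batched call in `vllm_model.py` (the chat-prompt construction via `make_context` and the stop-token handling are omitted; the model path is the local Qwen-7B-Chat download):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="pre_train_model/Qwen-7B-Chat",
          trust_remote_code=True,
          gpu_memory_utilization=0.90,
          dtype="bfloat16")
sampling = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=2000)

# a single generate() call processes the whole batch of prompts
prompts = ["怎么打开危险警告灯?", "车辆如何保养?", "靠背太热怎么办?"]
outputs = llm.generate(prompts, sampling)
answers = [o.outputs[0].text for o in outputs]
```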
114 |
115 | ##### 3.4.2 TensorRT-LLM
116 |
117 | TensorRT-LLM is NVIDIA's inference framework and provides both C++ and Python APIs. For running Qwen with TensorRT-LLM, see the official guide: [TensorRT-LLM Qwen](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwen).
118 |
119 | ### 4. Ranking
120 |
121 | [Preliminary round: 2nd place](https://tianchi.aliyun.com/competition/entrance/532154/rankingList)
122 | [Final round: 13th place](https://tianchi.aliyun.com/competition/entrance/532154/rankingList)
123 |
--------------------------------------------------------------------------------
/bm25_retriever.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 |
5 | from langchain.retrievers import BM25Retriever
6 | from langchain.schema import Document
7 | from pdf_parse import DataProcess
8 | import jieba
9 |
10 | class BM25(object):
11 |
12 | def __init__(self, documents):
13 |
14 | docs = []
15 | full_docs = []
16 | for idx, line in enumerate(documents):
17 | line = line.strip("\n").strip()
18 | if(len(line)<5):
19 | continue
20 | tokens = " ".join(jieba.cut_for_search(line))
21 | # docs.append(Document(page_content=tokens, metadata={"id": idx, "cate":words[1],"pageid":words[2]}))
22 | docs.append(Document(page_content=tokens, metadata={"id": idx}))
23 | # full_docs.append(Document(page_content=words[0], metadata={"id": idx, "cate":words[1], "pageid":words[2]}))
24 | words = line.split("\t")
25 | full_docs.append(Document(page_content=words[0], metadata={"id": idx}))
26 | self.documents = docs
27 | self.full_documents = full_docs
28 | self.retriever = self._init_bm25()
29 |
30 |     # Build the BM25 index over the knowledge base
31 | def _init_bm25(self):
32 | return BM25Retriever.from_documents(self.documents)
33 |
34 |     # Return the top-k documents by BM25 score
35 | def GetBM25TopK(self, query, topk):
36 | self.retriever.k = topk
37 | query = " ".join(jieba.cut_for_search(query))
38 | ans_docs = self.retriever.get_relevant_documents(query)
39 | ans = []
40 | for line in ans_docs:
41 | ans.append(self.full_documents[line.metadata["id"]])
42 | return ans
43 |
44 | if __name__ == "__main__":
45 |
46 |     # BM25
47 | dp = DataProcess(pdf_path = "/root/autodl-tmp/codes/data/train_a.pdf")
48 | dp.ParseBlock(max_seq = 1024)
49 | dp.ParseBlock(max_seq = 512)
50 | print(len(dp.data))
51 | dp.ParseAllPage(max_seq = 256)
52 | dp.ParseAllPage(max_seq = 512)
53 | print(len(dp.data))
54 | dp.ParseOnePageWithRule(max_seq = 256)
55 | dp.ParseOnePageWithRule(max_seq = 512)
56 | print(len(dp.data))
57 | data = dp.data
58 | bm25 = BM25(data)
59 | res = bm25.GetBM25TopK("座椅加热", 6)
60 | print(res)
61 |
--------------------------------------------------------------------------------
/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # The commands below are a packaging example; replace the image address with your own, and bump the version and rebuild before each submission.
4 | # docker build -t registry.cn-shanghai.aliyuncs.com/taylor:0.1 .
5 |
6 | imageid=`docker images|awk 'NR>1'|grep "aicar/taylor"|awk '{print($3)}'`
7 | echo $imageid
8 | docker rmi -f $imageid
9 | echo yes|docker builder prune
10 | docker build -t registry.cn-shanghai.aliyuncs.com/taylor:0.1 .
11 |
12 | ImageId=`docker images|awk 'NR>1'|grep "0.1"|awk '{print($3)}'`
13 | echo $ImageId
14 | docker tag $ImageId registry.cn-shanghai.aliyuncs.com/aicar/taylor:v1
15 | docker login --username=xxx -p xxx registry.cn-shanghai.aliyuncs.com
16 | docker push registry.cn-shanghai.aliyuncs.com/aicar/taylor:v1
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | # device config
6 | EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available(
7 | ) else "mps" if torch.backends.mps.is_available() else "cpu"
8 | LLM_DEVICE = "cuda" if torch.cuda.is_available(
9 | ) else "mps" if torch.backends.mps.is_available() else "cpu"
10 | num_gpus = torch.cuda.device_count()
11 |
12 | # model cache config
13 | MODEL_CACHE_PATH = os.path.join(os.path.dirname(__file__), 'model_cache')
14 |
15 |
16 | # vector storage config
17 | VECTOR_STORE_PATH='./vector_store'
18 | COLLECTION_NAME='my_collection'
19 |
20 |
21 | # init model config
22 | init_llm = "ChatGLM2-6B"
23 | init_embedding_model = "text2vec-base"
24 |
25 | # model config
26 | embedding_model_dict = {
27 | "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
28 | "ernie-base": "nghuyong/ernie-3.0-base-zh",
29 | "ernie-medium": "nghuyong/ernie-3.0-medium-zh",
30 | "ernie-xbase": "nghuyong/ernie-3.0-xbase-zh",
31 | "text2vec-base": "GanymedeNil/text2vec-base-chinese",
32 | 'simbert-base-chinese': 'WangZeJun/simbert-base-chinese',
33 | 'paraphrase-multilingual-MiniLM-L12-v2': "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
34 | }
35 |
36 |
37 | llm_model_dict = {
38 | "chatglm2": {
39 | "ChatGLM2-6B": "THUDM/chatglm2-6b",
40 | "ChatGLM2-6B-int4": "THUDM/chatglm2-6b-int4",
41 | },
42 | "chatglm": {
43 | "ChatGLM-6B": "THUDM/chatglm-6b",
44 | "ChatGLM-6B-int4": "THUDM/chatglm-6b-int4",
45 | "ChatGLM-6B-int8": "THUDM/chatglm-6b-int8",
46 | "ChatGLM-6b-int4-qe": "THUDM/chatglm-6b-int4-qe"
47 | },
48 | "belle": {
49 | "BELLE-LLaMA-Local": "/pretrainmodel/belle",
50 | },
51 | "vicuna": {
52 | "Vicuna-Local": "/pretrainmodel/vicuna",
53 | },
54 | "internlm": {
55 | "internlm-chat-7b-8k": "internlm/internlm-chat-7b-8k",
56 | "internlm-chat-7b": "internlm/internlm-chat-7b",
57 | "internlm-chat-7b-v1_1": "internlm/internlm-chat-7b-v1_1",
58 | }
59 | }
60 |
--------------------------------------------------------------------------------
/data/test_question.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "question": "中国足球的队长是谁",
4 | "answer_1": "",
5 | "answer_2": "",
6 | "answer_3": ""
7 | },
8 | {
9 | "question": "新冠肺炎如何预防?",
10 | "answer_1": "",
11 | "answer_2": "",
12 | "answer_3": ""
13 | },
14 | {
15 | "question": "交通事故如何处理?",
16 | "answer_1": "",
17 | "answer_2": "",
18 | "answer_3": ""
19 | },
20 | {
21 | "question": "怎样加热座椅?",
22 | "answer_1": "",
23 | "answer_2": "",
24 | "answer_3": ""
25 | },
26 | {
27 | "question": "自动模式下,中央显示屏是如何切换日间和夜间模式的?",
28 | "answer_1": "",
29 | "answer_2": "",
30 | "answer_3": ""
31 | },
32 | {
33 | "question": "如何通过中央显示屏进行副驾驶员座椅设置?",
34 | "answer_1": "",
35 | "answer_2": "",
36 | "answer_3": ""
37 | },
38 | {
39 | "question": "副仪表台按钮如何操作中央显示屏?",
40 | "answer_1": "",
41 | "answer_2": "",
42 | "answer_3": ""
43 | },
44 | {
45 | "question": "如何从锁定状态唤醒中央显示器?",
46 | "answer_1": "",
47 | "answer_2": "",
48 | "answer_3": ""
49 | },
50 | {
51 | "question": "如何正确使用颈椎保护系统?",
52 | "answer_1": "",
53 | "answer_2": "",
54 | "answer_3": ""
55 | },
56 | {
57 | "question": "前方交叉路口预警系统(FCTA)的作用是什么?",
58 | "answer_1": "",
59 | "answer_2": "",
60 | "answer_3": ""
61 | },
62 | {
63 | "question": "在使用FCTA时需要注意哪些事项?",
64 | "answer_1": "",
65 | "answer_2": "",
66 | "answer_3": ""
67 | },
68 | {
69 | "question": "如何打开车辆尾门?",
70 | "answer_1": "",
71 | "answer_2": "",
72 | "answer_3": ""
73 | },
74 | {
75 | "question": "在哪些情况下智能钥匙可能会受到干扰,导致功能异常?",
76 | "answer_1": "",
77 | "answer_2": "",
78 | "answer_3": ""
79 | },
80 | {
81 | "question": "车辆尾门的防夹保护功能是如何工作的?",
82 | "answer_1": "",
83 | "answer_2": "",
84 | "answer_3": ""
85 | },
86 | {
87 | "question": "在操作电动后备厢时需要注意哪些事项?",
88 | "answer_1": "",
89 | "answer_2": "",
90 | "answer_3": ""
91 | },
92 | {
93 | "question": "如何进入车辆功能界面?",
94 | "answer_1": "",
95 | "answer_2": "",
96 | "answer_3": ""
97 | },
98 | {
99 | "question": "在车辆功能界面有哪些操作选项?",
100 | "answer_1": "",
101 | "answer_2": "",
102 | "answer_3": ""
103 | },
104 | {
105 | "question": "如何编辑快捷开关图标?",
106 | "answer_1": "",
107 | "answer_2": "",
108 | "answer_3": ""
109 | },
110 | {
111 | "question": "如何减少车辆腐蚀风险?",
112 | "answer_1": "",
113 | "answer_2": "",
114 | "answer_3": ""
115 | },
116 | {
117 | "question": "如何通过空调系统面板调节空调风量?",
118 | "answer_1": "",
119 | "answer_2": "",
120 | "answer_3": ""
121 | },
122 | {
123 | "question": "如何创建新的Lynk&CoID?",
124 | "answer_1": "",
125 | "answer_2": "",
126 | "answer_3": ""
127 | },
128 | {
129 | "question": "什么是车主账户?",
130 | "answer_1": "",
131 | "answer_2": "",
132 | "answer_3": ""
133 | },
134 | {
135 | "question": "如何创建人脸识别?",
136 | "answer_1": "",
137 | "answer_2": "",
138 | "answer_3": ""
139 | },
140 | {
141 | "question": "如何添加亲情账号?",
142 | "answer_1": "",
143 | "answer_2": "",
144 | "answer_3": ""
145 | },
146 | {
147 | "question": "如何开启或关闭用车偏好自动同步?",
148 | "answer_1": "",
149 | "answer_2": "",
150 | "answer_3": ""
151 | },
152 | {
153 | "question": "如何熄火我的车辆?",
154 | "answer_1": "",
155 | "answer_2": "",
156 | "answer_3": ""
157 | },
158 | {
159 | "question": "如何通过遥控钥匙启动车辆?",
160 | "answer_1": "",
161 | "answer_2": "",
162 | "answer_3": ""
163 | },
164 | {
165 | "question": "如果遥控钥匙电池电量低,我应该如何启动车辆?",
166 | "answer_1": "",
167 | "answer_2": "",
168 | "answer_3": ""
169 | },
170 | {
171 | "question": "如何调节外后视镜?",
172 | "answer_1": "",
173 | "answer_2": "",
174 | "answer_3": ""
175 | },
176 | {
177 | "question": "外部反光境显示物体距离是否准确?",
178 | "answer_1": "",
179 | "answer_2": "",
180 | "answer_3": ""
181 | },
182 | {
183 | "question": "什么是自动驻车系统?",
184 | "answer_1": "",
185 | "answer_2": "",
186 | "answer_3": ""
187 | },
188 | {
189 | "question": "在什么情况下会停用AutoHold并启用EPB功能?",
190 | "answer_1": "",
191 | "answer_2": "",
192 | "answer_3": ""
193 | },
194 | {
195 | "question": "中央扶手箱的USB接口有几个?它们分别是什么类型?",
196 | "answer_1": "",
197 | "answer_2": "",
198 | "answer_3": ""
199 | },
200 | {
201 | "question": "中央扶手箱所支持的U盘及数据传输格式有哪些?",
202 | "answer_1": "",
203 | "answer_2": "",
204 | "answer_3": ""
205 | },
206 | {
207 | "question": "如何通过中央显示屏调节驾驶员侧座椅通风强度?",
208 | "answer_1": "",
209 | "answer_2": "",
210 | "answer_3": ""
211 | },
212 | {
213 | "question": "如何关闭前排座行车通风功能?",
214 | "answer_1": "",
215 | "answer_2": "",
216 | "answer_3": ""
217 | },
218 | {
219 | "question": "如何进入系统设置界面?",
220 | "answer_1": "",
221 | "answer_2": "",
222 | "answer_3": ""
223 | },
224 | {
225 | "question": "在系统界面可以进行哪些操作?",
226 | "answer_1": "",
227 | "answer_2": "",
228 | "answer_3": ""
229 | },
230 | {
231 | "question": "什么是无钥匙进入系统?",
232 | "answer_1": "",
233 | "answer_2": "",
234 | "answer_3": ""
235 | },
236 | {
237 | "question": "如何设置无钥匙解锁模式?",
238 | "answer_1": "",
239 | "answer_2": "",
240 | "answer_3": ""
241 | },
242 | {
243 | "question": "设置无钥匙解锁中单门和全车的区别在于什么?",
244 | "answer_1": "",
245 | "answer_2": "",
246 | "answer_3": ""
247 | },
248 | {
249 | "question": "驾驶车辆时应遵守哪些注意事项?",
250 | "answer_1": "",
251 | "answer_2": "",
252 | "answer_3": ""
253 | },
254 | {
255 | "question": "如何启用或停用手套箱密码保护功能?",
256 | "answer_1": "",
257 | "answer_2": "",
258 | "answer_3": ""
259 | },
260 | {
261 | "question": "驾驶员状态监测系统是如何工作的?",
262 | "answer_1": "",
263 | "answer_2": "",
264 | "answer_3": ""
265 | },
266 | {
267 | "question": "什么情况下会影响到驾驶员状态监测系统的工作?",
268 | "answer_1": "",
269 | "answer_2": "",
270 | "answer_3": ""
271 | },
272 | {
273 | "question": "如何启用后排儿童锁功能?",
274 | "answer_1": "",
275 | "answer_2": "",
276 | "answer_3": ""
277 | },
278 | {
279 | "question": "安全气囊是什么?它的作用是什么?",
280 | "answer_1": "",
281 | "answer_2": "",
282 | "answer_3": ""
283 | },
284 | {
285 | "question": "如果未使用或未正确使用安全带,会对安全气囊有何影响? ",
286 | "answer_1": "",
287 | "answer_2": "",
288 | "answer_3": ""
289 | },
290 | {
291 | "question": "在使用车辆时,有哪些安全气囊的注意事项?",
292 | "answer_1": "",
293 | "answer_2": "",
294 | "answer_3": ""
295 | },
296 | {
297 | "question": "如何开启动力电池电量保持功能?",
298 | "answer_1": "",
299 | "answer_2": "",
300 | "answer_3": ""
301 | },
302 | {
303 | "question": "在开启动力电池的情况下,选择经济性优先和充电速度优先有什么区别?",
304 | "answer_1": "",
305 | "answer_2": "",
306 | "answer_3": ""
307 | },
308 | {
309 | "question": "后方碰撞预警系统在什么情况下会启动?",
310 | "answer_1": "",
311 | "answer_2": "",
312 | "answer_3": ""
313 | },
314 | {
315 | "question": "如何调整方向盘的位置?",
316 | "answer_1": "",
317 | "answer_2": "",
318 | "answer_3": ""
319 | },
320 | {
321 | "question": "什么情况下不能调节车辆的方向盘?",
322 | "answer_1": "",
323 | "answer_2": "",
324 | "answer_3": ""
325 | },
326 | {
327 | "question": "如何通过手机APP启动车辆?",
328 | "answer_1": "",
329 | "answer_2": "",
330 | "answer_3": ""
331 | },
332 | {
333 | "question": "什么是陡坡缓降系统(HDC什么是陡坡缓降系统(HDC)?",
334 | "answer_1": "",
335 | "answer_2": "",
336 | "answer_3": ""
337 | },
338 | {
339 | "question": "当坡度过大时,如何操作才能使车辆保持匀速地行驶?",
340 | "answer_1": "",
341 | "answer_2": "",
342 | "answer_3": ""
343 | },
344 | {
345 | "question": "在什么情况下HDC会激活?",
346 | "answer_1": "",
347 | "answer_2": "",
348 | "answer_3": ""
349 | },
350 | {
351 | "question": "在什么情况下无法激活或自动退出HDC功能?",
352 | "answer_1": "",
353 | "answer_2": "",
354 | "answer_3": ""
355 | },
356 | {
357 | "question": "开启陡坡缓降系统后,组合仪表显示什么颜色的指示灯?",
358 | "answer_1": "",
359 | "answer_2": "",
360 | "answer_3": ""
361 | },
362 | {
363 | "question": "激活陡坡缓降系统时,组合仪表显示什么颜色的指示灯?",
364 | "answer_1": "",
365 | "answer_2": "",
366 | "answer_3": ""
367 | },
368 | {
369 | "question": "当陡坡缓降系统出现故障时,组合仪表会显示怎样的提示?",
370 | "answer_1": "",
371 | "answer_2": "",
372 | "answer_3": ""
373 | },
374 | {
375 | "question": "坡道辅助系统的主要功能是什么?",
376 | "answer_1": "",
377 | "answer_2": "",
378 | "answer_3": ""
379 | },
380 | {
381 | "question": "使用坡道辅助系统时需要注意哪些警告?",
382 | "answer_1": "",
383 | "answer_2": "",
384 | "answer_3": ""
385 | },
386 | {
387 | "question": "在有路缘石的上坡和下坡驻车时,应如何操作?",
388 | "answer_1": "",
389 | "answer_2": "",
390 | "answer_3": ""
391 | },
392 | {
393 | "question": "新车在最初的2000km行驶期间应遵守哪些事项?",
394 | "answer_1": "",
395 | "answer_2": "",
396 | "answer_3": ""
397 | },
398 | {
399 | "question": "当刹车片将达到最小安全厚度时会有什么表现?",
400 | "answer_1": "",
401 | "answer_2": "",
402 | "answer_3": ""
403 | },
404 | {
405 | "question": "如果听到持续的高频尖锐噪声应该怎么做?",
406 | "answer_1": "",
407 | "answer_2": "",
408 | "answer_3": ""
409 | },
410 | {
411 | "question": "风挡摄像头和前雷达在什么情况下会弹出警示消息?",
412 | "answer_1": "",
413 | "answer_2": "",
414 | "answer_3": ""
415 | },
416 | {
417 | "question": "如何通过蓝牙钥匙启动车辆?",
418 | "answer_1": "",
419 | "answer_2": "",
420 | "answer_3": ""
421 | },
422 | {
423 | "question": "在什么情况下无法通过蓝牙钥匙启动车辆?",
424 | "answer_1": "",
425 | "answer_2": "",
426 | "answer_3": ""
427 | },
428 | {
429 | "question": "如何通过网络远程启动发动机?",
430 | "answer_1": "",
431 | "answer_2": "",
432 | "answer_3": ""
433 | },
434 | {
435 | "question": "变道辅助系统由哪三部分组成?",
436 | "answer_1": "",
437 | "answer_2": "",
438 | "answer_3": ""
439 | },
440 | {
441 | "question": "如何检查制动液液位?",
442 | "answer_1": "",
443 | "answer_2": "",
444 | "answer_3": ""
445 | },
446 | {
447 | "question": "如何调节车辆的背光亮度?",
448 | "answer_1": "",
449 | "answer_2": "",
450 | "answer_3": ""
451 | },
452 | {
453 | "question": "在白天和夜晚,开启和未开启整车背光联动时,转动调光旋钮有什么不同?",
454 | "answer_1": "",
455 | "answer_2": "",
456 | "answer_3": ""
457 | },
458 | {
459 | "question": "当车辆发生故障时,应该如何处理?",
460 | "answer_1": "",
461 | "answer_2": "",
462 | "answer_3": ""
463 | },
464 | {
465 | "question": "为什么不建议自行修理车辆故障?",
466 | "answer_1": "",
467 | "answer_2": "",
468 | "answer_3": ""
469 | },
470 | {
471 | "question": "什么是能量回收系统?",
472 | "answer_1": "",
473 | "answer_2": "",
474 | "answer_3": ""
475 | },
476 | {
477 | "question": "如何知道车辆正在进行能量回收?",
478 | "answer_1": "",
479 | "answer_2": "",
480 | "answer_3": ""
481 | },
482 | {
483 | "question": "影响能量回收多少的因素有哪些?",
484 | "answer_1": "",
485 | "answer_2": "",
486 | "answer_3": ""
487 | },
488 | {
489 | "question": "什么是智能远近光控制系统?",
490 | "answer_1": "",
491 | "answer_2": "",
492 | "answer_3": ""
493 | },
494 | {
495 | "question": "在什么情况下可以开启智能远近光控制系统?",
496 | "answer_1": "",
497 | "answer_2": "",
498 | "answer_3": ""
499 | },
500 | {
501 | "question": "哪些因素可能导致智能远近光控制系统无法正常工作?",
502 | "answer_1": "",
503 | "answer_2": "",
504 | "answer_3": ""
505 | },
506 | {
507 | "question": "如何启用前除霜/除雾功能?",
508 | "answer_1": "",
509 | "answer_2": "",
510 | "answer_3": ""
511 | },
512 | {
513 | "question": "当我开启前挡风玻璃的去冰去雾模式时,车辆会有什么反应?",
514 | "answer_1": "",
515 | "answer_2": "",
516 | "answer_3": ""
517 | },
518 | {
519 | "question": "为什么驾驶之前需要确保挡风玻璃无冰渣、积雪或冷凝水?",
520 | "answer_1": "",
521 | "answer_2": "",
522 | "answer_3": ""
523 | },
524 | {
525 | "question": "什么情况下车辆需要进行报废处理?",
526 | "answer_1": "",
527 | "answer_2": "",
528 | "answer_3": ""
529 | },
530 | {
531 | "question": "什么是遥控泊车(RPA)?",
532 | "answer_1": "",
533 | "answer_2": "",
534 | "answer_3": ""
535 | },
536 | {
537 | "question": "使用Lynk&CoApp的RPA功能需要注意什么?",
538 | "answer_1": "",
539 | "answer_2": "",
540 | "answer_3": ""
541 | },
542 | {
543 | "question": "如何开启方向盘助力与驾驶模式联动?",
544 | "answer_1": "",
545 | "answer_2": "",
546 | "answer_3": ""
547 | },
548 | {
549 | "question": "有哪些可选的方向盘转向助力模式?",
550 | "answer_1": "",
551 | "answer_2": "",
552 | "answer_3": ""
553 | },
554 | {
555 | "question": "主动式座舱清洁系统的作用是什么?",
556 | "answer_1": "",
557 | "answer_2": "",
558 | "answer_3": ""
559 | },
560 | {
561 | "question": "全景天窗是由几部分组成的?",
562 | "answer_1": "",
563 | "answer_2": "",
564 | "answer_3": ""
565 | },
566 | {
567 | "question": "如何使用位置记忆功能?",
568 | "answer_1": "",
569 | "answer_2": "",
570 | "answer_3": ""
571 | },
572 | {
573 | "question": "如何更换后挡风玻璃雨刮片?",
574 | "answer_1": "",
575 | "answer_2": "",
576 | "answer_3": ""
577 | },
578 | {
579 | "question": "在更换雨刮片时需要注意什么?",
580 | "answer_1": "",
581 | "answer_2": "",
582 | "answer_3": ""
583 | },
584 | {
585 | "question": "如何添加香氛精油?",
586 | "answer_1": "",
587 | "answer_2": "",
588 | "answer_3": ""
589 | },
590 | {
591 | "question": "我应该在哪里添加香氛精油?",
592 | "answer_1": "",
593 | "answer_2": "",
594 | "answer_3": ""
595 | },
596 | {
597 | "question": "什么是转向助力系统?",
598 | "answer_1": "",
599 | "answer_2": "",
600 | "answer_3": ""
601 | },
602 | {
603 | "question": "涉水行驶前应注意什么?",
604 | "answer_1": "",
605 | "answer_2": "",
606 | "answer_3": ""
607 | },
608 | {
609 | "question": "涉水驾驶后需要进行哪些检查?",
610 | "answer_1": "",
611 | "answer_2": "",
612 | "answer_3": ""
613 | },
614 | {
615 | "question": "什么时候应该为车辆打蜡?",
616 | "answer_1": "",
617 | "answer_2": "",
618 | "answer_3": ""
619 | }
620 | ]
621 |
--------------------------------------------------------------------------------
/data/train_a.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dawoshi/Tianchi-LLM-QA/fccb285a683f9ccb578eda2f22e9c94883b43672/data/train_a.pdf
--------------------------------------------------------------------------------
/faiss_retriever.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 |
5 | from langchain.schema import Document
6 | from langchain.vectorstores import Chroma,FAISS
7 | from langchain.embeddings.huggingface import HuggingFaceEmbeddings
8 | from pdf_parse import DataProcess
9 | import torch
10 | # from bm25_retriever import BM25
11 |
12 | class FaissRetriever(object):
13 | def __init__(self, model_path, data):
14 | self.embeddings = HuggingFaceEmbeddings(
15 | model_name = model_path,
16 | model_kwargs = {"device":"cuda"}
17 | )
18 | docs = []
19 | for idx, line in enumerate(data):
20 | line = line.strip("\n").strip()
21 | words = line.split("\t")
22 | docs.append(Document(page_content=words[0], metadata={"id": idx}))
23 | self.vector_store = FAISS.from_documents(docs, self.embeddings)
24 | del self.embeddings
25 | torch.cuda.empty_cache()
26 |
27 | def GetTopK(self, query, k):
28 | context = self.vector_store.similarity_search_with_score(query, k=k)
29 | return context
30 | def GetvectorStore(self):
31 | return self.vector_store
32 |
33 | if __name__ == "__main__":
34 | base = "/root/autodl-tmp/codes"
35 | model_name=base + "/pre_train_model/m3e-large" #text2vec-large-chinese
36 | dp = DataProcess(pdf_path = base + "/data/train_a.pdf")
37 | dp.ParseBlock(max_seq = 1024)
38 | dp.ParseBlock(max_seq = 512)
39 | print(len(dp.data))
40 | dp.ParseAllPage(max_seq = 256)
41 | dp.ParseAllPage(max_seq = 512)
42 | print(len(dp.data))
43 | dp.ParseOnePageWithRule(max_seq = 256)
44 | dp.ParseOnePageWithRule(max_seq = 512)
45 | print(len(dp.data))
46 | data = dp.data
47 |
48 | faissretriever = FaissRetriever(model_name, data)
49 | # bm25 = BM25(data)
50 | faiss_ans = faissretriever.GetTopK("如何预防新冠肺炎", 6)
51 | print(faiss_ans)
52 | faiss_ans = faissretriever.GetTopK("交通事故如何处理", 6)
53 | print(faiss_ans)
54 | faiss_ans = faissretriever.GetTopK("吉利集团的董事长是谁", 6)
55 | print(faiss_ans)
56 | faiss_ans = faissretriever.GetTopK("吉利汽车语音组手叫什么", 6)
57 | print(faiss_ans)
58 | # bm25_ans = bm25.GetBM25TopK("座椅加热", 6)
59 | # ans = reRank(6, bm25_ans, faiss_ans)
60 |
--------------------------------------------------------------------------------
/images/01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dawoshi/Tianchi-LLM-QA/fccb285a683f9ccb578eda2f22e9c94883b43672/images/01.png
--------------------------------------------------------------------------------
/images/02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dawoshi/Tianchi-LLM-QA/fccb285a683f9ccb578eda2f22e9c94883b43672/images/02.png
--------------------------------------------------------------------------------
/images/03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dawoshi/Tianchi-LLM-QA/fccb285a683f9ccb578eda2f22e9c94883b43672/images/03.png
--------------------------------------------------------------------------------
/pdf_parse.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | import pdfplumber
5 | from PyPDF2 import PdfReader
6 |
7 |
8 | class DataProcess(object):
9 |
10 | def __init__(self, pdf_path):
11 | self.pdf_path = pdf_path
12 | self.data = []
13 | def SlidingWindow(self, sentences, kernel = 512, stride = 1):
14 | sz = len(sentences)
15 | cur = ""
16 | fast = 0
17 | slow = 0
18 | while(fast < len(sentences)):
19 | sentence = sentences[fast]
20 | if(len(cur + sentence) > kernel and (cur + sentence) not in self.data):
21 | self.data.append(cur + sentence + "。")
22 | cur = cur[len(sentences[slow] + "。"):]
23 | slow = slow + 1
24 | cur = cur + sentence + "。"
25 | fast = fast + 1
26 |
27 | def Datafilter(self, line, header, pageid, max_seq = 1024):
28 |
29 | sz = len(line)
30 | if(sz < 6):
31 | return
32 |
33 | if(sz > max_seq):
34 |
35 | if("■" in line):
36 | sentences = line.split("■")
37 | elif("•" in line):
38 | sentences = line.split("•")
39 | elif("\t" in line):
40 | sentences = line.split("\t")
41 | else:
42 | sentences = line.split("。")
43 |
44 | for subsentence in sentences:
45 | subsentence = subsentence.replace("\n", "")
46 |
47 | if(len(subsentence) < max_seq and len(subsentence) > 5):
48 | # subsentence = subsentence.replace(",", "").replace("\n","").replace("\t","") + "\t" + header+ "\t" + str(pageid)
49 | subsentence = subsentence.replace(",", "").replace("\n","").replace("\t","")
50 | if(subsentence not in self.data):
51 | self.data.append(subsentence)
52 | else:
53 | # line = line.replace("\n","").replace(",", "").replace("\t","") + "\t" + header + "\t" + str(pageid)
54 | line = line.replace("\n","").replace(",", "").replace("\t","")
55 | if(line not in self.data):
56 | self.data.append(line)
57 | # Extract the page header, i.e. the first-level heading
58 |
59 | def GetHeader(self, page):
60 | try:
61 | lines = page.extract_words()[::]
62 | except:
63 | return None
64 | if(len(lines) > 0):
65 | for line in lines:
66 | if("目录" in line["text"] or ".........." in line["text"]):
67 | return None
68 | if(line["top"] < 20 and line["top"] > 17):
69 | return line["text"]
70 | return lines[0]["text"]
71 | return None
72 |
73 | # Extract content block by block from each page and combine it with the first-level heading; together with Document metadata this can support intent recognition
74 | def ParseBlock(self, max_seq = 1024):
75 |
76 | with pdfplumber.open(self.pdf_path) as pdf:
77 |
78 | for i, p in enumerate(pdf.pages):
79 | header = self.GetHeader(p)
80 |
81 | if(header == None):
82 | continue
83 |
84 | texts = p.extract_words(use_text_flow=True, extra_attrs = ["size"])[::]
85 |
86 | squence = ""
87 | lastsize = 0
88 |
89 | for idx, line in enumerate(texts):
90 | if(idx <1):
91 | continue
92 | if(idx == 1):
93 | if(line["text"].isdigit()):
94 | continue
95 | cursize = line["size"]
96 | text = line["text"]
97 | if(text == "□" or text == "•"):
98 | continue
99 | elif(text== "警告!" or text == "注意!" or text == "说明!"):
100 | if(len(squence) > 0):
101 | self.Datafilter(squence, header, i, max_seq = max_seq)
102 | squence = ""
103 | elif(format(lastsize,".5f") == format(cursize,".5f")):
104 | if(len(squence)>0):
105 | squence = squence + text
106 | else:
107 | squence = text
108 | else:
109 | lastsize = cursize
110 | if(len(squence) < 15 and len(squence)>0):
111 | squence = squence + text
112 | else:
113 | if(len(squence) > 0):
114 | self.Datafilter(squence, header, i, max_seq = max_seq)
115 | squence = text
116 | if(len(squence) > 0):
117 | self.Datafilter(squence, header, i, max_seq = max_seq)
118 | def ParseOnePageWithRule(self, max_seq = 512, min_len = 6):
119 | for idx, page in enumerate(PdfReader(self.pdf_path).pages):
120 | page_content = ""
121 | text = page.extract_text()
122 | words = text.split("\n")
123 | for idx, word in enumerate(words):
124 | text = word.strip().strip("\n")
125 | if("...................." in text or "目录" in text):
126 | continue
127 | if(len(text) < 1):
128 | continue
129 | if(text.isdigit()):
130 | continue
131 | page_content = page_content + text
132 | if(len(page_content) < min_len):
133 | continue
134 | if(len(page_content) < max_seq):
135 | if(page_content not in self.data):
136 | self.data.append(page_content)
137 | else:
138 | sentences = page_content.split("。")
139 | cur = ""
140 | for idx, sentence in enumerate(sentences):
141 | if(len(cur + sentence) > max_seq and (cur + sentence) not in self.data):
142 | self.data.append(cur + sentence)
143 | cur = sentence
144 | else:
145 | cur = cur + sentence
146 |     # Extract passages with the sliding-window method
147 |     # 1. Treat the whole PDF as a single string
148 |     # 2. Split it into an array using the full stop (。) as the separator
149 |     # 3. Slide a window over the array (the window size here is max_seq)
150 | def ParseAllPage(self, max_seq = 512, min_len = 6):
151 | all_content = ""
152 | for idx, page in enumerate(PdfReader(self.pdf_path).pages):
153 | page_content = ""
154 | text = page.extract_text()
155 | words = text.split("\n")
156 | for idx, word in enumerate(words):
157 | text = word.strip().strip("\n")
158 | if("...................." in text or "目录" in text):
159 | continue
160 | if(len(text) < 1):
161 | continue
162 | if(text.isdigit()):
163 | continue
164 | page_content = page_content + text
165 | if(len(page_content) < min_len):
166 | continue
167 | all_content = all_content + page_content
168 | sentences = all_content.split("。")
169 | self.SlidingWindow(sentences, kernel = max_seq)
170 |
171 | # for idx, sentence in enumerate(sentences):
172 | # if(len(cur + sentence) > max_seq and (cur + sentence) not in self.data):
173 | # self.data.append(cur + sentence)
174 | # cur = sentence
175 | # else:
176 | # cur = cur + sentence
177 |
178 | if __name__ == "__main__":
179 | dp = DataProcess(pdf_path = "/root/autodl-tmp/codes/data/train_a.pdf")
180 | dp.ParseBlock(max_seq = 1024)
181 | dp.ParseBlock(max_seq = 512)
182 | print(len(dp.data))
183 | dp.ParseAllPage(max_seq = 256)
184 | dp.ParseAllPage(max_seq = 512)
185 | print(len(dp.data))
186 | dp.ParseOnePageWithRule(max_seq = 256)
187 | dp.ParseOnePageWithRule(max_seq = 512)
188 | print(len(dp.data))
189 | data = dp.data
190 | out = open("all_text.txt", "w")
191 | for line in data:
192 | line = line.strip("\n")
193 | out.write(line)
194 | out.write("\n")
195 | out.close()
196 |
--------------------------------------------------------------------------------
/pre_train_model/Qwen-7B-Chat/download.py:
--------------------------------------------------------------------------------
1 | # Download the model from ModelScope
2 | from modelscope import snapshot_download
3 | model_dir = snapshot_download('qwen/Qwen-7B-Chat')
4 |
--------------------------------------------------------------------------------
/qwen_generation_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba Cloud.
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """Generation support."""
7 |
8 | from typing import Tuple, List, Union, Iterable
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | from transformers import PreTrainedTokenizer
14 | from transformers import logging
15 | from transformers.generation import LogitsProcessor
16 |
17 | logger = logging.get_logger(__name__)
18 |
19 | # Types.
20 | HistoryType = List[Tuple[str, str]]
21 | TokensType = List[int]
22 | BatchTokensType = List[List[int]]
23 |
24 |
25 | def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
26 | for tokens in batch:
27 | context_length = len(tokens)
28 | if context_length < seq_length:
29 | tokens.extend([pad_id] * (seq_length - context_length))
30 | return batch
31 |
32 |
33 | def get_ltor_masks_and_position_ids(
34 | data,
35 | eod_token,
36 | reset_position_ids,
37 | reset_attention_mask,
38 | eod_mask_loss,
39 | ):
40 | """Build masks and position id for left to right model."""
41 |
42 | # Extract batch size and sequence length.
43 | micro_batch_size, seq_length = data.size()
44 |
45 | # Attention mask (lower triangular).
46 | if reset_attention_mask:
47 | att_mask_batch = micro_batch_size
48 | else:
49 | att_mask_batch = 1
50 | attention_mask = torch.tril(
51 | torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
52 | ).view(att_mask_batch, 1, seq_length, seq_length)
53 |
54 | # Loss mask.
55 | loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
56 | if eod_mask_loss:
57 | loss_mask[data == eod_token] = 0.0
58 |
59 | # Position ids.
60 | position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
61 | position_ids = position_ids.unsqueeze(0).expand_as(data)
62 |     # We need to clone as the ids will be modified based on batch index.
63 | if reset_position_ids:
64 | position_ids = position_ids.clone()
65 |
66 | if reset_position_ids or reset_attention_mask:
67 | # Loop through the batches:
68 | for b in range(micro_batch_size):
69 |
70 |             # Find indices where EOD token is.
71 | eod_index = position_ids[b, data[b] == eod_token]
72 |             # Detach indices from positions if going to modify positions.
73 | if reset_position_ids:
74 | eod_index = eod_index.clone()
75 |
76 |             # Loop through EOD indices:
77 | prev_index = 0
78 | for j in range(eod_index.size()[0]):
79 | i = eod_index[j]
80 | # Mask attention loss.
81 | if reset_attention_mask:
82 | attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
83 | # Reset positions.
84 | if reset_position_ids:
85 | position_ids[b, (i + 1) :] -= i + 1 - prev_index
86 | prev_index = i + 1
87 |
88 | # Convert attention mask to binary:
89 | attention_mask = attention_mask < 0.5
90 |
91 | return attention_mask, loss_mask, position_ids
92 |
93 |
94 | def get_batch(context_tokens: torch.LongTensor, eod_id: int):
95 | """Generate batch from context tokens."""
96 | # Move to GPU.
97 | tokens = context_tokens.contiguous().to(context_tokens.device)
98 |     # Get the attention mask and position ids.
99 | attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
100 | tokens,
101 | eod_id,
102 | reset_position_ids=False,
103 | reset_attention_mask=False,
104 | eod_mask_loss=False,
105 | )
106 | return tokens, attention_mask, position_ids
107 |
108 |
109 | def get_stop_words_ids(chat_format, tokenizer):
110 | if chat_format == "raw":
111 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
112 | elif chat_format == "chatml":
113 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
114 | else:
115 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
116 | return stop_words_ids
117 |
118 |
119 | def make_context(
120 | tokenizer: PreTrainedTokenizer,
121 | query: str,
122 | history: List[Tuple[str, str]] = None,
123 | system: str = "",
124 | max_window_size: int = 6144,
125 | chat_format: str = "chatml",
126 | ):
127 | if history is None:
128 | history = []
129 |
130 | if chat_format == "chatml":
131 | im_start, im_end = "<|im_start|>", "<|im_end|>"
132 | im_start_tokens = [tokenizer.im_start_id]
133 | im_end_tokens = [tokenizer.im_end_id]
134 | nl_tokens = tokenizer.encode("\n")
135 |
136 | def _tokenize_str(role, content):
137 | return f"{role}\n{content}", tokenizer.encode(
138 | role, allowed_special=set()
139 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
140 |
141 | system_text, system_tokens_part = _tokenize_str("system", system)
142 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
143 |
144 | raw_text = ""
145 | context_tokens = []
146 |
147 | for turn_query, turn_response in reversed(history):
148 | query_text, query_tokens_part = _tokenize_str("user", turn_query)
149 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
150 | response_text, response_tokens_part = _tokenize_str(
151 | "assistant", turn_response
152 | )
153 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
154 |
155 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
156 | prev_chat = (
157 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
158 | )
159 |
160 | current_context_size = (
161 | len(system_tokens) + len(next_context_tokens) + len(context_tokens)
162 | )
163 | if current_context_size < max_window_size:
164 | context_tokens = next_context_tokens + context_tokens
165 | raw_text = prev_chat + raw_text
166 | else:
167 | break
168 |
169 | context_tokens = system_tokens + context_tokens
170 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text
171 | context_tokens += (
172 | nl_tokens
173 | + im_start_tokens
174 | + _tokenize_str("user", query)[1]
175 | + im_end_tokens
176 | + nl_tokens
177 | + im_start_tokens
178 | + tokenizer.encode("assistant")
179 | + nl_tokens
180 | )
181 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
182 |
183 | elif chat_format == "raw":
184 | raw_text = query
185 | context_tokens = tokenizer.encode(raw_text)
186 | else:
187 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
188 |
189 | return raw_text, context_tokens
190 |
191 |
192 | def _decode_default(
193 | tokens: List[int],
194 | *,
195 | stop_words: List[str],
196 | eod_words: List[str],
197 | tokenizer: PreTrainedTokenizer,
198 | raw_text_len: int,
199 | verbose: bool = False,
200 | return_end_reason: bool = False,
201 | errors: str='replace',
202 | ):
203 | trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
204 | if verbose:
205 | print("\nRaw Generate: ", trim_decode_tokens)
206 |
207 | end_reason = f"Gen length {len(tokens)}"
208 | for stop_word in stop_words:
209 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
210 | for eod_word in eod_words:
211 | if eod_word in trim_decode_tokens:
212 | end_reason = f"Gen {eod_word!r}"
213 | trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
214 | trim_decode_tokens = trim_decode_tokens.strip()
215 | if verbose:
216 | print("\nEnd Reason:", end_reason)
217 | print("\nGenerate: ", trim_decode_tokens)
218 |
219 | if return_end_reason:
220 | return trim_decode_tokens, end_reason
221 | else:
222 | return trim_decode_tokens
223 |
224 |
225 | def _decode_chatml(
226 | tokens: List[int],
227 | *,
228 | stop_words: List[str],
229 | eod_token_ids: List[int],
230 | tokenizer: PreTrainedTokenizer,
231 | raw_text_len: int,
232 | context_length: int,
233 | verbose: bool = False,
234 | return_end_reason: bool = False,
235 | errors: str='replace'
236 | ):
237 | end_reason = f"Gen length {len(tokens)}"
238 | eod_token_idx = context_length
239 | for eod_token_idx in range(context_length, len(tokens)):
240 | if tokens[eod_token_idx] in eod_token_ids:
241 | end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
242 | break
243 |
244 | trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
245 | if verbose:
246 | print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
247 | print("\nRaw Generate:", trim_decode_tokens)
248 | print("\nEnd Reason:", end_reason)
249 | for stop_word in stop_words:
250 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
251 | trim_decode_tokens = trim_decode_tokens.strip()
252 | if verbose:
253 | print("\nGenerate:", trim_decode_tokens)
254 |
255 | if return_end_reason:
256 | return trim_decode_tokens, end_reason
257 | else:
258 | return trim_decode_tokens
259 |
260 |
261 | def decode_tokens(
262 | tokens: Union[torch.LongTensor, TokensType],
263 | tokenizer: PreTrainedTokenizer,
264 | raw_text_len: int,
265 | context_length: int,
266 | chat_format: str,
267 | verbose: bool = False,
268 | return_end_reason: bool = False,
269 | errors: str="replace",
270 | ) -> str:
271 | if torch.is_tensor(tokens):
272 | tokens = tokens.cpu().numpy().tolist()
273 |
274 | if chat_format == "chatml":
275 | return _decode_chatml(
276 | tokens,
277 | stop_words=[],
278 | eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
279 | tokenizer=tokenizer,
280 | raw_text_len=raw_text_len,
281 | context_length=context_length,
282 | verbose=verbose,
283 | return_end_reason=return_end_reason,
284 | errors=errors,
285 | )
286 | elif chat_format == "raw":
287 | return _decode_default(
288 | tokens,
289 | stop_words=["<|endoftext|>"],
290 | eod_words=["<|endoftext|>"],
291 | tokenizer=tokenizer,
292 | raw_text_len=raw_text_len,
293 | verbose=verbose,
294 | return_end_reason=return_end_reason,
295 | errors=errors,
296 | )
297 | else:
298 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
299 |
300 |
301 | class StopWordsLogitsProcessor(LogitsProcessor):
302 | """
303 |     :class:`transformers.LogitsProcessor` that stops generation when the specified stop sequences appear.
304 |
305 | Args:
306 | stop_words_ids (:obj:`List[List[int]]`):
307 | List of list of token ids of stop ids. In order to get the tokens of the words
308 | that should not appear in the generated text, use :obj:`tokenizer(bad_word,
309 | add_prefix_space=True).input_ids`.
310 | eos_token_id (:obj:`int`):
311 | The id of the `end-of-sequence` token.
312 | """
313 |
314 | def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
315 |
316 | if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
317 | raise ValueError(
318 |                 f"`stop_words_ids` has to be a non-empty list, but is {stop_words_ids}."
319 | )
320 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
321 | raise ValueError(
322 | f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
323 | )
324 | if any(
325 | any(
326 | (not isinstance(token_id, (int, np.integer)) or token_id < 0)
327 | for token_id in stop_word_ids
328 | )
329 | for stop_word_ids in stop_words_ids
330 | ):
331 | raise ValueError(
332 | f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
333 | )
334 |
335 | self.stop_words_ids = list(
336 | filter(
337 | lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
338 | )
339 | )
340 | self.eos_token_id = eos_token_id
341 | for stop_token_seq in self.stop_words_ids:
342 | assert (
343 | len(stop_token_seq) > 0
344 | ), "Stop words token sequences {} cannot have an empty list".format(
345 | stop_words_ids
346 | )
347 |
348 | def __call__(
349 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor
350 | ) -> torch.FloatTensor:
351 | stopped_samples = self._calc_stopped_samples(input_ids)
352 | for i, should_stop in enumerate(stopped_samples):
353 | if should_stop:
354 | scores[i, self.eos_token_id] = float(2**15)
355 | return scores
356 |
357 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
358 | if len(tokens) == 0:
359 |             # an empty stop sequence matches anything, so always flag a stop
360 | return True
361 | elif len(tokens) > len(prev_tokens):
362 |             # if bad word tokens are longer than prev input_ids they can't be equal
363 | return False
364 | elif prev_tokens[-len(tokens) :].tolist() == tokens:
365 | # if tokens match
366 | return True
367 | else:
368 | return False
369 |
370 | def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
371 | stopped_samples = []
372 | for prev_input_ids_slice in prev_input_ids:
373 | match = False
374 | for stop_token_seq in self.stop_words_ids:
375 | if self._tokens_match(prev_input_ids_slice, stop_token_seq):
376 |                     # the stop sequence matches: mark this sample as stopped
377 | match = True
378 | break
379 | stopped_samples.append(match)
380 |
381 | return stopped_samples
382 |
383 |
384 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
385 | """This function has been mostly taken from huggingface conversational
386 | ai code at
387 | https://medium.com/huggingface/how-to-build-a-state-of-the-art-
388 | conversational-ai-with-transfer-learning-2d818ac26313"""
389 |
390 | if top_k > 0:
391 | # Remove all tokens with a probability less than the
392 | # last token of the top-k
393 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
394 | logits[indices_to_remove] = filter_value
395 |
396 | if top_p > 0.0:
397 |         # Convert to 1D
398 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
399 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
400 |
401 | # Remove tokens with cumulative probability above the threshold
402 | sorted_indices_to_remove = cumulative_probs > top_p
403 | # Shift the indices to the right to keep also the first token
404 | # above the threshold
405 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
406 | sorted_indices_to_remove[..., 0] = 0
407 | for i in range(sorted_indices.size(0)):
408 | indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
409 | logits[i][indices_to_remove] = filter_value
410 |
411 | return logits
412 |
413 |
414 | def switch(val1, val2, boolean):
415 | boolean = boolean.type_as(val1)
416 | return (1 - boolean) * val1 + boolean * val2
417 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Python 3.9.12
2 | vllm
3 | modelscope
4 | tiktoken
5 | pdfplumber
6 | PyPDF2
7 | langchain
8 | jieba
9 | rank_bm25
10 | sentence-transformers
11 | faiss-gpu
12 | torch
13 | transformers
14 | transformers_stream_generator
15 | accelerate
16 | numpy
17 | pandas
18 | tqdm
19 |
--------------------------------------------------------------------------------
/rerank_model.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForSequenceClassification, AutoTokenizer
2 | import os
3 | import torch
4 | from bm25_retriever import BM25
5 | from pdf_parse import DataProcess
6 | from config import *
7 |
8 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
9 |
10 | DEVICE = LLM_DEVICE
11 | DEVICE_ID = "0"
12 | CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
13 |
14 |
15 | def torch_gc():
16 | if torch.cuda.is_available():
17 | with torch.cuda.device(CUDA_DEVICE):
18 | torch.cuda.empty_cache()
19 | torch.cuda.ipc_collect()
20 | class reRankLLM(object):
21 | def __init__(self, model_path, max_length = 512):
22 | self.tokenizer = AutoTokenizer.from_pretrained(model_path)
23 | self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
24 | self.model.eval()
25 | self.model.half()
26 | self.model.cuda()
27 | self.max_length = max_length
28 |
29 | def predict(self, query, docs):
30 | pairs = [(query, doc.page_content) for doc in docs]
31 | inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=self.max_length).to("cuda")
32 | with torch.no_grad():
33 | scores = self.model(**inputs).logits
34 | scores = scores.detach().cpu().clone().numpy()
35 | response = [doc for score, doc in sorted(zip(scores, docs), reverse=True, key=lambda x:x[0])]
36 | torch_gc()
37 | return response
38 | if __name__ == "__main__":
39 | bge_reranker_large = "/Users/william/codes/contest/aicar_docker/new_build/app/pre_train_model/bge-reranker-large"
40 | rerank = reRankLLM(bge_reranker_large)
41 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | import json
5 | import jieba
6 | import pandas as pd
7 | import numpy as np
8 | from tqdm import tqdm
9 | from langchain.schema import Document
10 | from langchain.vectorstores import Chroma,FAISS
11 | from langchain import PromptTemplate, LLMChain
12 | from langchain.chains import RetrievalQA
13 | import time
14 | import re
15 |
16 | from vllm_model import ChatLLM
17 | from vllm_wrapper import vLLMWrapper
18 | from rerank_model import reRankLLM
19 | from faiss_retriever import FaissRetriever
20 | from bm25_retriever import BM25
21 | from pdf_parse import DataProcess
22 |
23 | def get_qa_chain(llm, vector_store, prompt_template):
24 |
25 | prompt = PromptTemplate(template=prompt_template,
26 | input_variables=["context", "question"])
27 |
28 | return RetrievalQA.from_llm(llm=llm, retriever=vector_store.as_retriever(search_kwargs={"k": 10}), prompt=prompt)
29 | def get_emb_bm25_merge(faiss_context, bm25_context, query):
30 | max_length = 2500
31 | emb_ans = ""
32 | cnt = 0
33 | for doc, score in faiss_context:
34 | cnt =cnt + 1
35 | if(cnt>6):
36 | break
37 | if(len(emb_ans + doc.page_content) > max_length):
38 | break
39 | emb_ans = emb_ans + doc.page_content
40 | bm25_ans = ""
41 | cnt = 0
42 | for doc in bm25_context:
43 | cnt = cnt + 1
44 | if(len(bm25_ans + doc.page_content) > max_length):
45 | break
46 | bm25_ans = bm25_ans + doc.page_content
47 | if(cnt > 6):
48 | break
49 |
50 | prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
51 | 如果无法从中得到答案,请说 "无答案"或"无答案",不允许在答案中添加编造成分,答案请使用中文。
52 | 已知内容为吉利控股集团汽车销售有限公司的吉利用户手册:
53 | 1: {emb_ans}
54 | 2: {bm25_ans}
55 | 问题:
56 | {question}""".format(emb_ans=emb_ans, bm25_ans = bm25_ans, question = query)
57 | return prompt_template
58 | def get_rerank(emb_ans, query):
59 |
60 | prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。
61 | 如果无法从中得到答案,请说 "无答案"或"无答案" ,不允许在答案中添加编造成分,答案请使用中文。
62 | 已知内容为吉利控股集团汽车销售有限公司的吉利用户手册:
63 | 1: {emb_ans}
64 | 问题:
65 | {question}""".format(emb_ans=emb_ans, question = query)
66 | return prompt_template
67 |
68 |
69 | def question(text, llm, vector_store, prompt_template):
70 |
71 | chain = get_qa_chain(llm, vector_store, prompt_template)
72 |
73 | response = chain({"query": text})
74 | return response
75 |
76 | def reRank(rerank, top_k, query, bm25_ans, faiss_ans):
77 | items = []
78 | max_length = 4000
79 | for doc, score in faiss_ans:
80 | items.append(doc)
81 | items.extend(bm25_ans)
82 | rerank_ans = rerank.predict(query, items)
83 | rerank_ans = rerank_ans[:top_k]
84 | # docs_sort = sorted(rerank_ans, key = lambda x:x.metadata["id"])
85 | emb_ans = ""
86 | for doc in rerank_ans:
87 | if(len(emb_ans + doc.page_content) > max_length):
88 | break
89 | emb_ans = emb_ans + doc.page_content
90 | return emb_ans
91 |
92 | if __name__ == "__main__":
93 |
94 | start = time.time()
95 | # base = "/app"
96 | # qwen7 = "/tcdata/qwen/Qwen-7B-Chat"
97 |
98 | base = "/root/autodl-tmp/codes"
99 | qwen7 = base + "/pre_train_model/Qwen-7B-Chat"
100 | m3e = base + "/pre_train_model/m3e-large"
101 | bge_reranker_large = base + "/pre_train_model/bge-reranker-large"
102 |
103 | # data
104 | # dp = DataProcess(pdf_path = "/tcdata/trainning_data.pdf")
105 | dp = DataProcess(pdf_path = base + "/data/train_a.pdf")
106 | dp.ParseBlock(max_seq = 1024)
107 | dp.ParseBlock(max_seq = 512)
108 | print(len(dp.data))
109 | dp.ParseAllPage(max_seq = 256)
110 | dp.ParseAllPage(max_seq = 512)
111 | print(len(dp.data))
112 | dp.ParseOnePageWithRule(max_seq = 256)
113 | dp.ParseOnePageWithRule(max_seq = 512)
114 | print(len(dp.data))
115 | data = dp.data
116 | print("data load ok")
117 |
118 | # Faiss
119 | faissretriever = FaissRetriever(m3e, data)
120 | vector_store = faissretriever.vector_store
121 | print("faissretriever load ok")
122 |
123 |     # BM25
124 | bm25 = BM25(data)
125 | print("bm25 load ok")
126 |
127 | # LLM
128 | # llm = vLLMWrapper(qwen7)
129 | llm = ChatLLM(qwen7)
130 | print("llm qwen load ok")
131 |
132 | # reRank
133 | rerank = reRankLLM(bge_reranker_large)
134 | print("rerank model load ok")
135 |
136 | # with open("/tcdata/test_question.json", "r") as f:
137 |
138 | with open(base + "/data/test_question.json", "r") as f:
139 | jdata = json.loads(f.read())
140 | print(len(jdata))
141 | max_length = 4000
142 | for idx, line in enumerate(jdata):
143 | query = line["question"]
144 |
145 | # faiss
146 | faiss_context = faissretriever.GetTopK(query, 15)
147 | faiss_min_score = 0.0
148 | if(len(faiss_context) > 0):
149 | faiss_min_score = faiss_context[0][1]
150 | cnt = 0
151 | emb_ans = ""
152 | for doc, score in faiss_context:
153 | cnt =cnt + 1
154 | if(len(emb_ans + doc.page_content) > max_length):
155 | break
156 | emb_ans = emb_ans + doc.page_content
157 | if(cnt>6):
158 | break
159 |
160 |         # bm25
161 | bm25_context = bm25.GetBM25TopK(query, 15)
162 | bm25_ans = ""
163 | cnt = 0
164 | for doc in bm25_context:
165 | cnt = cnt + 1
166 | if(len(bm25_ans + doc.page_content) > max_length):
167 | break
168 | bm25_ans = bm25_ans + doc.page_content
169 | if(cnt > 6):
170 | break
171 |
172 | emb_bm25_merge_inputs = get_emb_bm25_merge(faiss_context, bm25_context, query)
173 | bm25_inputs = get_rerank(bm25_ans, query)
174 | emb_inputs = get_rerank(emb_ans, query)
175 |
176 | # rerank emb recall
177 | rerank_ans = reRank(rerank, 6, query, bm25_context, faiss_context)
178 | rerank_inputs = get_rerank(rerank_ans, query)
179 |
180 | batch_input = []
181 | batch_input.append(emb_bm25_merge_inputs)
182 | batch_input.append(bm25_inputs)
183 | batch_input.append(emb_inputs)
184 | batch_input.append(rerank_inputs)
185 | batch_output = llm.infer(batch_input)
186 | line["answer_1"] = batch_output[0]
187 | line["answer_2"] = batch_output[1]
188 | line["answer_3"] = batch_output[2]
189 | line["answer_4"] = batch_output[3]
190 | line["answer_5"] = emb_ans
191 | line["answer_6"] = bm25_ans
192 | line["answer_7"] = rerank_ans
193 | if(faiss_min_score >500):
194 | line["answer_5"] = "无答案"
195 | else:
196 | line["answer_5"] = str(faiss_min_score)
197 | # json.dump(jdata, open("/app/result.json", "w", encoding='utf-8'), ensure_ascii=False, indent=2)
198 | json.dump(jdata, open(base + "/data/result.json", "w", encoding='utf-8'), ensure_ascii=False, indent=2)
199 | end = time.time()
200 | print("cost time: " + str(int(end-start)/60))
201 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | python /app/run.py
2 |
--------------------------------------------------------------------------------
/vllm_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import time
4 |
5 | from config import *
6 | from vllm import LLM, SamplingParams
7 |
8 | from transformers import AutoModelForCausalLM, AutoTokenizer
9 | from transformers import GenerationConfig
10 | from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids
11 |
12 |
13 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
14 |
15 | DEVICE = LLM_DEVICE
16 | DEVICE_ID = "0"
17 | CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
18 |
19 | IMEND = "<|im_end|>"
20 | ENDOFTEXT = "<|endoftext|>"
21 |
22 | def get_stop_words_ids(chat_format, tokenizer):
23 | if chat_format == "raw":
24 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
25 | elif chat_format == "chatml":
26 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
27 | else:
28 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
29 | return stop_words_ids
30 |
31 | def torch_gc():
32 | if torch.cuda.is_available():
33 | with torch.cuda.device(CUDA_DEVICE):
34 | torch.cuda.empty_cache()
35 | torch.cuda.ipc_collect()
36 |
37 | class ChatLLM(object):
38 |
39 | def __init__(self, model_path):
40 | self.tokenizer = AutoTokenizer.from_pretrained(
41 | model_path,
42 | pad_token='<|extra_0|>',
43 | eos_token='<|endoftext|>',
44 | padding_side='left',
45 | trust_remote_code=True
46 | )
47 | self.generation_config = GenerationConfig.from_pretrained(model_path, pad_token_id=self.tokenizer.pad_token_id)
48 | self.tokenizer.eos_token_id = self.generation_config.eos_token_id
49 | self.stop_words_ids = []
50 | self.model = LLM(model=model_path,
51 | tokenizer=model_path,
52 | tensor_parallel_size=1,
53 | trust_remote_code=True,
54 | gpu_memory_utilization=0.90,
55 | dtype="bfloat16")
56 | for stop_id in get_stop_words_ids(self.generation_config.chat_format, self.tokenizer):
57 | self.stop_words_ids.extend(stop_id)
58 | self.stop_words_ids.extend([self.generation_config.eos_token_id])
59 | sampling_kwargs = {
60 | "stop_token_ids": self.stop_words_ids,
61 | "early_stopping": False,
62 | "top_p": 1.0,
63 | "top_k": -1 if self.generation_config.top_k == 0 else self.generation_config.top_k,
64 | "temperature": 0.0,
65 | "max_tokens": 2000,
66 | "repetition_penalty": self.generation_config.repetition_penalty,
67 | "n":1,
68 | "best_of":2,
69 | "use_beam_search":True
70 | }
71 | self.sampling_params = SamplingParams(**sampling_kwargs)
72 |
73 | def infer(self, prompts):
74 | batch_text = []
75 | for q in prompts:
76 | raw_text, _ = make_context(
77 | self.tokenizer,
78 | q,
79 | system="You are a helpful assistant.",
80 | max_window_size=self.generation_config.max_window_size,
81 | chat_format=self.generation_config.chat_format,
82 | )
83 | batch_text.append(raw_text)
84 | outputs = self.model.generate(batch_text,
85 | sampling_params = self.sampling_params
86 | )
87 | batch_response = []
88 | for output in outputs:
89 | output_str = output.outputs[0].text
90 | if IMEND in output_str:
91 | output_str = output_str[:-len(IMEND)]
92 | if ENDOFTEXT in output_str:
93 | output_str = output_str[:-len(ENDOFTEXT)]
94 | batch_response.append(output_str)
95 | torch_gc()
96 | return batch_response
97 |
98 | if __name__ == "__main__":
99 | qwen7 = "/root/autodl-tmp/codes/pre_train_model/Qwen-7B-Chat"
100 | start = time.time()
101 | llm = ChatLLM(qwen7)
102 |     test = ["吉利汽车座椅按摩","吉利汽车语音助手唤醒","自动驾驶功能介绍"]
103 | generated_text = llm.infer(test)
104 | print(generated_text)
105 | end = time.time()
106 | print("cost time: " + str((end-start)/60))
107 |
--------------------------------------------------------------------------------
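Note on the ChatLLM sampling configuration above: temperature=0.0 together with use_beam_search=True and best_of=2 makes generation deterministic, and the stop tokens are the chatml delimiters plus the eos token, so decoding halts at the end of the assistant turn. These flags match the vLLM version baked into the base image; newer vLLM releases have reworked or dropped the beam-search fields of SamplingParams, in which case a plain greedy configuration is the closest drop-in. A minimal sketch (not part of the repo code, written as it would appear inside ChatLLM.__init__):

    from vllm import SamplingParams

    # Greedy-decoding fallback: deterministic like the beam-search setup,
    # using only fields that exist across vLLM versions.
    self.sampling_params = SamplingParams(
        temperature=0.0,                      # greedy decoding
        top_p=1.0,
        max_tokens=2000,
        stop_token_ids=self.stop_words_ids,   # chatml stop ids collected above
    )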
/vllm_wrapper.py:
--------------------------------------------------------------------------------
1 | from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
2 | from typing import Optional, Callable, List, Tuple, Union
3 | import copy
4 | import torch
5 | from transformers import AutoTokenizer
6 | from transformers.generation.logits_process import LogitsProcessorList
7 | from packaging import version
8 |
9 | _ERROR_BAD_CHAT_FORMAT = """\
10 | We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
11 | If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
12 | 我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
13 | 如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
14 | """
15 |
16 | IMEND = "<|im_end|>"
17 | ENDOFTEXT = "<|endoftext|>"
18 |
19 | HistoryType = List[Tuple[str, str]]
20 | TokensType = List[int]
21 | BatchTokensType = List[List[int]]
22 |
23 | def get_stop_words_ids(chat_format, tokenizer):
24 | if chat_format == "raw":
25 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
26 | elif chat_format == "chatml":
27 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
28 | else:
29 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
30 | return stop_words_ids
31 |
32 | def make_context(
33 | tokenizer: PreTrainedTokenizer,
34 | query: str,
35 | history: List[Tuple[str, str]] = None,
36 | system: str = "",
37 | max_window_size: int = 6144,
38 | chat_format: str = "chatml",
39 | ):
40 | if history is None:
41 | history = []
42 |
43 | if chat_format == "chatml":
44 | im_start, im_end = "<|im_start|>", "<|im_end|>"
45 | im_start_tokens = [tokenizer.im_start_id]
46 | im_end_tokens = [tokenizer.im_end_id]
47 | nl_tokens = tokenizer.encode("\n")
48 |
49 | def _tokenize_str(role, content):
50 | return f"{role}\n{content}", tokenizer.encode(
51 | role, allowed_special=set()
52 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
53 |
54 | system_text, system_tokens_part = _tokenize_str("system", system)
55 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
56 |
57 | raw_text = ""
58 | context_tokens = []
59 |
60 | for turn_query, turn_response in reversed(history):
61 | query_text, query_tokens_part = _tokenize_str("user", turn_query)
62 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
63 | response_text, response_tokens_part = _tokenize_str(
64 | "assistant", turn_response
65 | )
66 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
67 |
68 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
69 | prev_chat = (
70 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
71 | )
72 |
73 | current_context_size = (
74 | len(system_tokens) + len(next_context_tokens) + len(context_tokens)
75 | )
76 | if current_context_size < max_window_size:
77 | context_tokens = next_context_tokens + context_tokens
78 | raw_text = prev_chat + raw_text
79 | else:
80 | break
81 |
82 | context_tokens = system_tokens + context_tokens
83 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text
84 | context_tokens += (
85 | nl_tokens
86 | + im_start_tokens
87 | + _tokenize_str("user", query)[1]
88 | + im_end_tokens
89 | + nl_tokens
90 | + im_start_tokens
91 | + tokenizer.encode("assistant")
92 | + nl_tokens
93 | )
94 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
95 |
96 | elif chat_format == "raw":
97 | raw_text = query
98 | context_tokens = tokenizer.encode(raw_text)
99 | else:
100 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
101 |
102 | return raw_text, context_tokens
103 |
104 | class vLLMWrapper:
105 | def __init__(self,
106 | model_dir: str,
107 | trust_remote_code: bool = True,
108 | tensor_parallel_size: int = 1,
109 | gpu_memory_utilization: float = 0.98,
110 | dtype: str = "bfloat16",
111 | **kwargs):
112 |
113 | if dtype not in ("bfloat16", "float16", "float32"):
114 |             print("dtype {} is not supported!".format(dtype))
115 |             raise ValueError("unsupported dtype: {}".format(dtype))
116 |
117 | # build generation_config
118 | self.generation_config = GenerationConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
119 |
120 | # build tokenizer
121 | self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
122 | self.tokenizer.eos_token_id = self.generation_config.eos_token_id
123 |
124 | self.stop_words_ids = []
125 |
126 | from vllm import LLM
127 | import vllm
128 | if version.parse(vllm.__version__) >= version.parse("0.2.2"):
129 | self.__vllm_support_repetition_penalty = True
130 | else:
131 | self.__vllm_support_repetition_penalty = False
132 |
133 |         quantization = kwargs.get('quantization', None)  # kwargs is a dict, so use .get() instead of getattr()
134 |
135 | self.model = LLM(model=model_dir,
136 | tokenizer=model_dir,
137 | tensor_parallel_size=tensor_parallel_size,
138 | trust_remote_code=trust_remote_code,
139 | quantization=quantization,
140 | gpu_memory_utilization=gpu_memory_utilization,
141 | dtype=dtype)
142 |
143 | for stop_id in get_stop_words_ids(self.generation_config.chat_format, self.tokenizer):
144 | self.stop_words_ids.extend(stop_id)
145 | self.stop_words_ids.extend([self.generation_config.eos_token_id])
146 |
147 | def chat(self,
148 | query: str,
149 | history: Optional[HistoryType],
150 | tokenizer: PreTrainedTokenizer = None,
151 | system: str = "You are a helpful assistant.",
152 | generation_config: Optional[GenerationConfig] = None,
153 | **kwargs):
154 | generation_config = generation_config if generation_config is not None else self.generation_config
155 | tokenizer = self.tokenizer if tokenizer is None else tokenizer
156 |
157 | assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
158 | if not self.__vllm_support_repetition_penalty and generation_config.repetition_penalty != 1:
159 | raise RuntimeError("The installed vLLM doesn't support repetition_penalty, please set ``model.generation_config.repetition_penalty = 1`` or install vllm>=0.2.2")
160 |
161 | if history is None:
162 | history = []
163 | else:
164 |             # make a copy of the user's input such that it is left untouched
165 | history = copy.deepcopy(history)
166 |
167 | extra_stop_words_ids = kwargs.get('stop_words_ids', None)
168 | if extra_stop_words_ids is None:
169 | extra_stop_words_ids = []
170 |
171 | max_window_size = kwargs.get('max_window_size', None)
172 | if max_window_size is None:
173 | max_window_size = generation_config.max_window_size
174 |
175 | from vllm.sampling_params import SamplingParams
176 | sampling_kwargs = {
177 | "stop_token_ids": self.stop_words_ids,
178 | "early_stopping": False,
179 | "top_p": 1.0,
180 | "top_k": -1 if generation_config.top_k == 0 else generation_config.top_k,
181 | "temperature": 0.0,
182 | "max_tokens": 512,
183 | "repetition_penalty": generation_config.repetition_penalty,
184 | "n":1,
185 | "best_of":2,
186 | "use_beam_search":True
187 | }
188 | if not self.__vllm_support_repetition_penalty:
189 | sampling_kwargs.pop("repetition_penalty")
190 | sampling_params = SamplingParams(**sampling_kwargs)
191 |
192 | raw_text, context_tokens = make_context(
193 | self.tokenizer,
194 | query,
195 | history=history,
196 | system=system,
197 | max_window_size=max_window_size,
198 | chat_format=generation_config.chat_format,
199 | )
200 |
201 | req_outputs = self.model.generate([query],
202 | sampling_params=sampling_params,
203 | prompt_token_ids=[context_tokens])
204 | req_output = req_outputs[0]
205 |
206 | prompt_str = req_output.prompt
207 | prompt_ids = req_output.prompt_token_ids
208 | req_sample_output_ids = []
209 | req_sample_output_strs = []
210 | for sample in req_output.outputs:
211 | output_str = sample.text
212 | output_ids = sample.token_ids
213 | if IMEND in output_str:
214 | output_str = output_str[:-len(IMEND)]
215 | if ENDOFTEXT in output_str:
216 | output_str = output_str[:-len(ENDOFTEXT)]
217 | req_sample_output_ids.append(prompt_ids + output_ids)
218 | req_sample_output_strs.append(prompt_str + output_str)
219 | assert len(req_sample_output_strs) == 1
220 | response = req_sample_output_strs[0][len(prompt_str):]
221 | history.append((prompt_str, response))
222 |
223 | return response, history
224 |
225 | if __name__ == '__main__':
226 |
227 | model_dir = 'Qwen/Qwen-72B-Chat'
228 | tensor_parallel_size = 1
229 |
230 | model = vLLMWrapper(model_dir,
231 | tensor_parallel_size=tensor_parallel_size,
232 | )
233 |
234 | response, history = model.chat(query="你好",
235 | history=None)
236 | print(response)
237 | response, history = model.chat(query="给我讲一个年轻人奋斗创业最终取得成功的故事。",
238 | history=history)
239 | print(response)
240 | response, history = model.chat(query="给这个故事起一个标题",
241 | history=history)
242 | print(response)
243 |
--------------------------------------------------------------------------------
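Note on make_context above: for a single-turn query (empty history) the raw_text it returns follows Qwen's chatml layout, which is what ultimately reaches vLLM. Rendered with {system} and {query} as placeholders, it looks like:

    <|im_start|>system
    {system}<|im_end|>
    <|im_start|>user
    {query}<|im_end|>
    <|im_start|>assistant

The model then completes the assistant turn, and the trailing <|im_end|> / <|endoftext|> markers are stripped from the generated text before it is returned.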