├── Dockerfile
├── README.md
├── bm25_retriever.py
├── build.sh
├── config.py
├── data
│   ├── result.json
│   ├── test_question.json
│   └── train_a.pdf
├── faiss_retriever.py
├── images
│   ├── 01.png
│   ├── 02.png
│   └── 03.png
├── pdf_parse.py
├── pre_train_model
│   └── Qwen-7B-Chat
│       └── download.py
├── qwen_generation_utils.py
├── requirements.txt
├── rerank_model.py
├── run.py
├── run.sh
├── vllm_model.py
└── vllm_wrapper.py
/Dockerfile:
--------------------------------------------------------------------------------
1 | # FROM registry.cn-shanghai.aliyuncs.com/tcc-public/pytorch:2.0.0-py3.9.12-cuda11.8.0-u22.04-cudnn
2 |
3 | FROM registry.cn-shanghai.aliyuncs.com/aicar/vllm:base
4 |
5 | # 如有安装其他软件的需求
6 | # RUN apt-get update && apt-get install curl
7 | # 如果安装其他python包的情况
8 | #pip3 install numpy --index-url=http://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
9 | # RUN pip install --progress-bar off numpy pandas PyPDF2 langchain jieba rank_bm25 sentence-transformers faiss-gpu modelscope tiktoken transformers_stream_generator accelerate pdfplumber --index-url=http://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
10 | # 复制代码到镜像仓库
11 | COPY app /app
12 |
13 | # 指定工作目录
14 | WORKDIR /app
15 |
16 | # 容器启动运行命令
17 | CMD ["bash", "run.sh"]
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Tianchi-LLM-QA
2 | 阿里天池: 2023全球智能汽车AI挑战赛——赛道一:AI大模型检索问答 baseline 80+
3 |
4 |

5 |
6 |
7 |
8 |

9 | 10 | 11 | ### 1、代码结构 12 | 13 | ```text 14 | . 15 | ├── Dockerfile 16 | ├── README.md 17 | ├── bm25_retriever.py 18 | ├── build.sh 19 | ├── config.py 20 | ├── data 21 | │   ├── result.json 22 | │   ├── test_question.json 23 | │   └── train_a.pdf 24 | ├── faiss_retriever.py 25 | ├── vllm_model.py 26 | ├── pdf_parse.py 27 | ├── pre_train_model 28 | │   ├── Qwen-7B-Chat 29 | │   │   └── download.py 30 | │   ├── bge-reranker-large 31 | │   └── m3e-large 32 | ├── qwen_generation_utils.py 33 | ├── requirements.txt 34 | ├── rerank_model.py 35 | ├── run.py 36 | ├── run.sh 37 | └── vllm_wrapper.py 38 | ``` 39 | 40 | ### 2、[赛题概述](https://tianchi.aliyun.com/competition/entrance/532154) 41 | #### 2.1 赛题:基于大模型的文档检索问答 42 | 43 | 任务:本次比赛要求参赛选手以大模型为中心制作一个问答系统,回答用户的汽车相关问题。参赛选手需要根据问题,在文档中定位相关信息的位置,并根据文档内容通过大模型生成相应的答案。本次比赛涉及的问题主要围绕汽车使用、维修、保养等方面,具体可参考下面的例子: 44 | 45 | 问题1:怎么打开危险警告灯? 46 | 答案1:危险警告灯开关在方向盘下方,按下开关即可打开危险警告灯。 47 | 48 | 问题2:车辆如何保养? 49 | 答案2:为了保持车辆处于最佳状态,建议您定期关注车辆状态,包括定期保养、洗车、内部清洁、外部清洁、轮胎的保养、低压蓄电池的保养等。 50 | 51 | 问题3:靠背太热怎么办? 52 | 答案3:您好,如果您的座椅靠背太热,可以尝试关闭座椅加热功能。在多媒体显示屏上依次点击空调开启按键→座椅→加热,在该界面下可以关闭座椅加热。 53 | 54 | #### 2.2 数据(复赛数据官方只提供部分参考样式) 55 | 56 | [初赛训练数据集.pdf](https://tianchi-race-prod-sh.oss-cn-shanghai.aliyuncs.com/file/race/documents/532154/%E5%88%9D%E8%B5%9B%E8%AE%AD%E7%BB%83%E9%9B%86/%E5%88%9D%E8%B5%9B%E8%AE%AD%E7%BB%83%E6%95%B0%E6%8D%AE%E9%9B%86.pdf?Expires=1703022585&OSSAccessKeyId=LTAI5t7fj2oKqzKgLGz6kGQc&Signature=pg9tnYgHDLkAlfCU%2Bs3h3QBrvfA%3D&response-content-disposition=attachment%3B%20) 57 | 58 | [测试问题.json](https://tianchi-race-prod-sh.oss-cn-shanghai.aliyuncs.com/file/race/documents/532154/%E5%85%B6%E5%AE%83/%E6%B5%8B%E8%AF%95%E9%97%AE%E9%A2%98.json?Expires=1703022684&OSSAccessKeyId=LTAI5t7fj2oKqzKgLGz6kGQc&Signature=kTn%2BN4ZnY9tftVmz5kjNKOCoFAs%3D&response-content-disposition=attachment%3B%20) 59 | 60 | 61 | ### 3、解决方案 62 | 63 | #### 3.1 pdf解析 64 | 65 | ##### 3.1.1 pdf分块解析 66 | ![分块解析示例图](images/01.png) 67 | 如图所示,我们希望pdf解析能尽可能的按照快状进行解析,每一块当做一个样本,这样能尽可能的保证pdf中文本内容的完整性 68 | 改进==》希望借助OCR进行pdf的块状识别 69 | 70 | ##### 3.1.2 pdf 滑窗法解析 71 | ![滑窗法解析示例图1](images/02.png) 72 | ![滑窗法解析示例图2](images/03.png) 73 | 如图1,2 所示,我们可以看到图1和图2上下文是连续的,如何保证文本内容的跨页连续性问题,我们提出滑窗法。 74 | 具体的把pdf中所有内容当做一个字符串来处理,按照句号进行分割,根据分割后的数组进行滑窗。具体的如下所示: 75 | 76 | ["aa","bb","cc","dd"] 77 | 78 | 如果字符串长度为4, 经过滑窗后的结果如下: 79 | 80 | aabb 81 | 82 | bbcc 83 | 84 | ccdd 85 | 86 | 我们希望滑窗法像卷积一样可以不同的kernel,Stride,来寻找能覆盖到的最优的样本召回 87 | 88 | #### 3.2 召回 89 | 90 | 召回主要使用langchain中的retrievers进行文本的召回。我们知道向量召回和bm25召回具有互补性,因此选用了这两个进行召回 91 | 92 | ##### 3.2.1 向量召回 93 | 94 | 向量召回利用 FAISS 进行索引创建和查找,embedding 利用 [M3E-large](https://modelscope.cn/models/Jerry0/M3E-large/summary) 或者[bge-large-zh](https://modelscope.cn/models/AI-ModelScope/bge-large-zh/summary) 95 | 96 | ##### 3.2.2 bm25召回 97 | 98 | bm25召回利用 langchain自带的bm25 retrievers 99 | 100 | #### 3.3 重排序 101 | 102 | 1、重排序是对召回的文本进行进一步的重排,以获得更精准,数据量更少的可能答案。 103 | 2、向量召回中使用的是bi-encoder结构,而bge-reranker-large 使用的是 cross-encoder结构,cross-encoder结构一定程度上要优于bi-encoder 104 | 105 | ##### 3.3.1 cross-encoder 106 | 107 | 重排序此处使用了 [bge-reranker-large](https://modelscope.cn/models/Xorbits/bge-reranker-large/files) 108 | 109 | #### 3.4 推理优化 110 | 111 | ##### 3.4.1 vllm batch 112 | 113 | vllm 利用page attention 技术使推理速度得到提升,batch推理比普通推理有接近1倍的提升空间 114 | 115 | ##### 3.4.2 tensorRT-LLM 116 | 117 | tensorRT-LLM是英伟达推出的推理框架,并且提供了c++和python的调用方式。关于qwen的tensorRT-LLM使用请参考官方介绍[tensorRT-LLM Qwen](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/qwen) 118 | 119 | ### 4、排名 120 | 121 | 
[初赛2名](https://tianchi.aliyun.com/competition/entrance/532154/rankingList) 122 | [复赛13名](https://tianchi.aliyun.com/competition/entrance/532154/rankingList) 123 | -------------------------------------------------------------------------------- /bm25_retriever.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | from langchain.retrievers import BM25Retriever 6 | from langchain.schema import Document 7 | from pdf_parse import DataProcess 8 | import jieba 9 | 10 | class BM25(object): 11 | 12 | def __init__(self, documents): 13 | 14 | docs = [] 15 | full_docs = [] 16 | for idx, line in enumerate(documents): 17 | line = line.strip("\n").strip() 18 | if(len(line)<5): 19 | continue 20 | tokens = " ".join(jieba.cut_for_search(line)) 21 | # docs.append(Document(page_content=tokens, metadata={"id": idx, "cate":words[1],"pageid":words[2]})) 22 | docs.append(Document(page_content=tokens, metadata={"id": idx})) 23 | # full_docs.append(Document(page_content=words[0], metadata={"id": idx, "cate":words[1], "pageid":words[2]})) 24 | words = line.split("\t") 25 | full_docs.append(Document(page_content=words[0], metadata={"id": idx})) 26 | self.documents = docs 27 | self.full_documents = full_docs 28 | self.retriever = self._init_bm25() 29 | 30 | # 初始化BM25的知识库 31 | def _init_bm25(self): 32 | return BM25Retriever.from_documents(self.documents) 33 | 34 | # 获得得分在topk的文档和分数 35 | def GetBM25TopK(self, query, topk): 36 | self.retriever.k = topk 37 | query = " ".join(jieba.cut_for_search(query)) 38 | ans_docs = self.retriever.get_relevant_documents(query) 39 | ans = [] 40 | for line in ans_docs: 41 | ans.append(self.full_documents[line.metadata["id"]]) 42 | return ans 43 | 44 | if __name__ == "__main__": 45 | 46 | # bm2.5 47 | dp = DataProcess(pdf_path = "/root/autodl-tmp/codes/data/train_a.pdf") 48 | dp.ParseBlock(max_seq = 1024) 49 | dp.ParseBlock(max_seq = 512) 50 | print(len(dp.data)) 51 | dp.ParseAllPage(max_seq = 256) 52 | dp.ParseAllPage(max_seq = 512) 53 | print(len(dp.data)) 54 | dp.ParseOnePageWithRule(max_seq = 256) 55 | dp.ParseOnePageWithRule(max_seq = 512) 56 | print(len(dp.data)) 57 | data = dp.data 58 | bm25 = BM25(data) 59 | res = bm25.GetBM25TopK("座椅加热", 6) 60 | print(res) 61 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 以下命令作为打包示例,实际使用时请修改为自己的镜像地址, 建议每次提交前完成版本修改重新打包 4 | # docker build -t registry.cn-shanghai.aliyuncs.com/taylor:0.1 . 5 | 6 | imageid=`docker images|awk 'NR>1'|grep "aicar/taylor"|awk '{print($3)}'` 7 | echo $imageid 8 | docker rmi -f $imageid 9 | echo yes|docker builder prune 10 | docker build -t registry.cn-shanghai.aliyuncs.com/taylor:0.1 . 
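# The image built above is looked up again below, re-tagged into the aicar namespace
# (registry.cn-shanghai.aliyuncs.com/aicar/taylor:v1, an example submission name per the
# comment at the top of this script), and pushed; the xxx placeholders in `docker login`
# stand for your own registry username and password.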
11 | 12 | ImageId=`docker images|awk 'NR>1'|grep "0.1"|awk '{print($3)}'` 13 | echo $ImageId 14 | docker tag $ImageId registry.cn-shanghai.aliyuncs.com/aicar/taylor:v1 15 | docker login --username=xxx -p xxx registry.cn-shanghai.aliyuncs.com 16 | docker push registry.cn-shanghai.aliyuncs.com/aicar/taylor:v1 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | # device config 6 | EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available( 7 | ) else "mps" if torch.backends.mps.is_available() else "cpu" 8 | LLM_DEVICE = "cuda" if torch.cuda.is_available( 9 | ) else "mps" if torch.backends.mps.is_available() else "cpu" 10 | num_gpus = torch.cuda.device_count() 11 | 12 | # model cache config 13 | MODEL_CACHE_PATH = os.path.join(os.path.dirname(__file__), 'model_cache') 14 | 15 | 16 | # vector storage config 17 | VECTOR_STORE_PATH='./vector_store' 18 | COLLECTION_NAME='my_collection' 19 | 20 | 21 | # init model config 22 | init_llm = "ChatGLM2-6B" 23 | init_embedding_model = "text2vec-base" 24 | 25 | # model config 26 | embedding_model_dict = { 27 | "ernie-tiny": "nghuyong/ernie-3.0-nano-zh", 28 | "ernie-base": "nghuyong/ernie-3.0-base-zh", 29 | "ernie-medium": "nghuyong/ernie-3.0-medium-zh", 30 | "ernie-xbase": "nghuyong/ernie-3.0-xbase-zh", 31 | "text2vec-base": "GanymedeNil/text2vec-base-chinese", 32 | 'simbert-base-chinese': 'WangZeJun/simbert-base-chinese', 33 | 'paraphrase-multilingual-MiniLM-L12-v2': "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" 34 | } 35 | 36 | 37 | llm_model_dict = { 38 | "chatglm2": { 39 | "ChatGLM2-6B": "THUDM/chatglm2-6b", 40 | "ChatGLM2-6B-int4": "THUDM/chatglm2-6b-int4", 41 | }, 42 | "chatglm": { 43 | "ChatGLM-6B": "THUDM/chatglm-6b", 44 | "ChatGLM-6B-int4": "THUDM/chatglm-6b-int4", 45 | "ChatGLM-6B-int8": "THUDM/chatglm-6b-int8", 46 | "ChatGLM-6b-int4-qe": "THUDM/chatglm-6b-int4-qe" 47 | }, 48 | "belle": { 49 | "BELLE-LLaMA-Local": "/pretrainmodel/belle", 50 | }, 51 | "vicuna": { 52 | "Vicuna-Local": "/pretrainmodel/vicuna", 53 | }, 54 | "internlm": { 55 | "internlm-chat-7b-8k": "internlm/internlm-chat-7b-8k", 56 | "internlm-chat-7b": "internlm/internlm-chat-7b", 57 | "internlm-chat-7b-v1_1": "internlm/internlm-chat-7b-v1_1", 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /data/test_question.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "question": "中国足球的队长是谁", 4 | "answer_1": "", 5 | "answer_2": "", 6 | "answer_3": "" 7 | }, 8 | { 9 | "question": "新冠肺炎如何预防?", 10 | "answer_1": "", 11 | "answer_2": "", 12 | "answer_3": "" 13 | }, 14 | { 15 | "question": "交通事故如何处理?", 16 | "answer_1": "", 17 | "answer_2": "", 18 | "answer_3": "" 19 | }, 20 | { 21 | "question": "怎样加热座椅?", 22 | "answer_1": "", 23 | "answer_2": "", 24 | "answer_3": "" 25 | }, 26 | { 27 | "question": "自动模式下,中央显示屏是如何切换日间和夜间模式的?", 28 | "answer_1": "", 29 | "answer_2": "", 30 | "answer_3": "" 31 | }, 32 | { 33 | "question": "如何通过中央显示屏进行副驾驶员座椅设置?", 34 | "answer_1": "", 35 | "answer_2": "", 36 | "answer_3": "" 37 | }, 38 | { 39 | "question": "副仪表台按钮如何操作中央显示屏?", 40 | "answer_1": "", 41 | "answer_2": "", 42 | "answer_3": "" 43 | }, 44 | { 45 | "question": "如何从锁定状态唤醒中央显示器?", 46 | "answer_1": "", 47 | "answer_2": "", 48 | "answer_3": "" 49 | }, 50 | { 51 | "question": "如何正确使用颈椎保护系统?", 52 | "answer_1": "", 53 | 
"answer_2": "", 54 | "answer_3": "" 55 | }, 56 | { 57 | "question": "前方交叉路口预警系统(FCTA)的作用是什么?", 58 | "answer_1": "", 59 | "answer_2": "", 60 | "answer_3": "" 61 | }, 62 | { 63 | "question": "在使用FCTA时需要注意哪些事项?", 64 | "answer_1": "", 65 | "answer_2": "", 66 | "answer_3": "" 67 | }, 68 | { 69 | "question": "如何打开车辆尾门?", 70 | "answer_1": "", 71 | "answer_2": "", 72 | "answer_3": "" 73 | }, 74 | { 75 | "question": "在哪些情况下智能钥匙可能会受到干扰,导致功能异常?", 76 | "answer_1": "", 77 | "answer_2": "", 78 | "answer_3": "" 79 | }, 80 | { 81 | "question": "车辆尾门的防夹保护功能是如何工作的?", 82 | "answer_1": "", 83 | "answer_2": "", 84 | "answer_3": "" 85 | }, 86 | { 87 | "question": "在操作电动后备厢时需要注意哪些事项?", 88 | "answer_1": "", 89 | "answer_2": "", 90 | "answer_3": "" 91 | }, 92 | { 93 | "question": "如何进入车辆功能界面?", 94 | "answer_1": "", 95 | "answer_2": "", 96 | "answer_3": "" 97 | }, 98 | { 99 | "question": "在车辆功能界面有哪些操作选项?", 100 | "answer_1": "", 101 | "answer_2": "", 102 | "answer_3": "" 103 | }, 104 | { 105 | "question": "如何编辑快捷开关图标?", 106 | "answer_1": "", 107 | "answer_2": "", 108 | "answer_3": "" 109 | }, 110 | { 111 | "question": "如何减少车辆腐蚀风险?", 112 | "answer_1": "", 113 | "answer_2": "", 114 | "answer_3": "" 115 | }, 116 | { 117 | "question": "如何通过空调系统面板调节空调风量?", 118 | "answer_1": "", 119 | "answer_2": "", 120 | "answer_3": "" 121 | }, 122 | { 123 | "question": "如何创建新的Lynk&CoID?", 124 | "answer_1": "", 125 | "answer_2": "", 126 | "answer_3": "" 127 | }, 128 | { 129 | "question": "什么是车主账户?", 130 | "answer_1": "", 131 | "answer_2": "", 132 | "answer_3": "" 133 | }, 134 | { 135 | "question": "如何创建人脸识别?", 136 | "answer_1": "", 137 | "answer_2": "", 138 | "answer_3": "" 139 | }, 140 | { 141 | "question": "如何添加亲情账号?", 142 | "answer_1": "", 143 | "answer_2": "", 144 | "answer_3": "" 145 | }, 146 | { 147 | "question": "如何开启或关闭用车偏好自动同步?", 148 | "answer_1": "", 149 | "answer_2": "", 150 | "answer_3": "" 151 | }, 152 | { 153 | "question": "如何熄火我的车辆?", 154 | "answer_1": "", 155 | "answer_2": "", 156 | "answer_3": "" 157 | }, 158 | { 159 | "question": "如何通过遥控钥匙启动车辆?", 160 | "answer_1": "", 161 | "answer_2": "", 162 | "answer_3": "" 163 | }, 164 | { 165 | "question": "如果遥控钥匙电池电量低,我应该如何启动车辆?", 166 | "answer_1": "", 167 | "answer_2": "", 168 | "answer_3": "" 169 | }, 170 | { 171 | "question": "如何调节外后视镜?", 172 | "answer_1": "", 173 | "answer_2": "", 174 | "answer_3": "" 175 | }, 176 | { 177 | "question": "外部反光境显示物体距离是否准确?", 178 | "answer_1": "", 179 | "answer_2": "", 180 | "answer_3": "" 181 | }, 182 | { 183 | "question": "什么是自动驻车系统?", 184 | "answer_1": "", 185 | "answer_2": "", 186 | "answer_3": "" 187 | }, 188 | { 189 | "question": "在什么情况下会停用AutoHold并启用EPB功能?", 190 | "answer_1": "", 191 | "answer_2": "", 192 | "answer_3": "" 193 | }, 194 | { 195 | "question": "中央扶手箱的USB接口有几个?它们分别是什么类型?", 196 | "answer_1": "", 197 | "answer_2": "", 198 | "answer_3": "" 199 | }, 200 | { 201 | "question": "中央扶手箱所支持的U盘及数据传输格式有哪些?", 202 | "answer_1": "", 203 | "answer_2": "", 204 | "answer_3": "" 205 | }, 206 | { 207 | "question": "如何通过中央显示屏调节驾驶员侧座椅通风强度?", 208 | "answer_1": "", 209 | "answer_2": "", 210 | "answer_3": "" 211 | }, 212 | { 213 | "question": "如何关闭前排座行车通风功能?", 214 | "answer_1": "", 215 | "answer_2": "", 216 | "answer_3": "" 217 | }, 218 | { 219 | "question": "如何进入系统设置界面?", 220 | "answer_1": "", 221 | "answer_2": "", 222 | "answer_3": "" 223 | }, 224 | { 225 | "question": "在系统界面可以进行哪些操作?", 226 | "answer_1": "", 227 | "answer_2": "", 228 | "answer_3": "" 229 | }, 230 | { 231 | "question": "什么是无钥匙进入系统?", 232 | "answer_1": "", 233 | "answer_2": "", 234 | 
"answer_3": "" 235 | }, 236 | { 237 | "question": "如何设置无钥匙解锁模式?", 238 | "answer_1": "", 239 | "answer_2": "", 240 | "answer_3": "" 241 | }, 242 | { 243 | "question": "设置无钥匙解锁中单门和全车的区别在于什么?", 244 | "answer_1": "", 245 | "answer_2": "", 246 | "answer_3": "" 247 | }, 248 | { 249 | "question": "驾驶车辆时应遵守哪些注意事项?", 250 | "answer_1": "", 251 | "answer_2": "", 252 | "answer_3": "" 253 | }, 254 | { 255 | "question": "如何启用或停用手套箱密码保护功能?", 256 | "answer_1": "", 257 | "answer_2": "", 258 | "answer_3": "" 259 | }, 260 | { 261 | "question": "驾驶员状态监测系统是如何工作的?", 262 | "answer_1": "", 263 | "answer_2": "", 264 | "answer_3": "" 265 | }, 266 | { 267 | "question": "什么情况下会影响到驾驶员状态监测系统的工作?", 268 | "answer_1": "", 269 | "answer_2": "", 270 | "answer_3": "" 271 | }, 272 | { 273 | "question": "如何启用后排儿童锁功能?", 274 | "answer_1": "", 275 | "answer_2": "", 276 | "answer_3": "" 277 | }, 278 | { 279 | "question": "安全气囊是什么?它的作用是什么?", 280 | "answer_1": "", 281 | "answer_2": "", 282 | "answer_3": "" 283 | }, 284 | { 285 | "question": "如果未使用或未正确使用安全带,会对安全气囊有何影响? ", 286 | "answer_1": "", 287 | "answer_2": "", 288 | "answer_3": "" 289 | }, 290 | { 291 | "question": "在使用车辆时,有哪些安全气囊的注意事项?", 292 | "answer_1": "", 293 | "answer_2": "", 294 | "answer_3": "" 295 | }, 296 | { 297 | "question": "如何开启动力电池电量保持功能?", 298 | "answer_1": "", 299 | "answer_2": "", 300 | "answer_3": "" 301 | }, 302 | { 303 | "question": "在开启动力电池的情况下,选择经济性优先和充电速度优先有什么区别?", 304 | "answer_1": "", 305 | "answer_2": "", 306 | "answer_3": "" 307 | }, 308 | { 309 | "question": "后方碰撞预警系统在什么情况下会启动?", 310 | "answer_1": "", 311 | "answer_2": "", 312 | "answer_3": "" 313 | }, 314 | { 315 | "question": "如何调整方向盘的位置?", 316 | "answer_1": "", 317 | "answer_2": "", 318 | "answer_3": "" 319 | }, 320 | { 321 | "question": "什么情况下不能调节车辆的方向盘?", 322 | "answer_1": "", 323 | "answer_2": "", 324 | "answer_3": "" 325 | }, 326 | { 327 | "question": "如何通过手机APP启动车辆?", 328 | "answer_1": "", 329 | "answer_2": "", 330 | "answer_3": "" 331 | }, 332 | { 333 | "question": "什么是陡坡缓降系统(HDC什么是陡坡缓降系统(HDC)?", 334 | "answer_1": "", 335 | "answer_2": "", 336 | "answer_3": "" 337 | }, 338 | { 339 | "question": "当坡度过大时,如何操作才能使车辆保持匀速地行驶?", 340 | "answer_1": "", 341 | "answer_2": "", 342 | "answer_3": "" 343 | }, 344 | { 345 | "question": "在什么情况下HDC会激活?", 346 | "answer_1": "", 347 | "answer_2": "", 348 | "answer_3": "" 349 | }, 350 | { 351 | "question": "在什么情况下无法激活或自动退出HDC功能?", 352 | "answer_1": "", 353 | "answer_2": "", 354 | "answer_3": "" 355 | }, 356 | { 357 | "question": "开启陡坡缓降系统后,组合仪表显示什么颜色的指示灯?", 358 | "answer_1": "", 359 | "answer_2": "", 360 | "answer_3": "" 361 | }, 362 | { 363 | "question": "激活陡坡缓降系统时,组合仪表显示什么颜色的指示灯?", 364 | "answer_1": "", 365 | "answer_2": "", 366 | "answer_3": "" 367 | }, 368 | { 369 | "question": "当陡坡缓降系统出现故障时,组合仪表会显示怎样的提示?", 370 | "answer_1": "", 371 | "answer_2": "", 372 | "answer_3": "" 373 | }, 374 | { 375 | "question": "坡道辅助系统的主要功能是什么?", 376 | "answer_1": "", 377 | "answer_2": "", 378 | "answer_3": "" 379 | }, 380 | { 381 | "question": "使用坡道辅助系统时需要注意哪些警告?", 382 | "answer_1": "", 383 | "answer_2": "", 384 | "answer_3": "" 385 | }, 386 | { 387 | "question": "在有路缘石的上坡和下坡驻车时,应如何操作?", 388 | "answer_1": "", 389 | "answer_2": "", 390 | "answer_3": "" 391 | }, 392 | { 393 | "question": "新车在最初的2000km行驶期间应遵守哪些事项?", 394 | "answer_1": "", 395 | "answer_2": "", 396 | "answer_3": "" 397 | }, 398 | { 399 | "question": "当刹车片将达到最小安全厚度时会有什么表现?", 400 | "answer_1": "", 401 | "answer_2": "", 402 | "answer_3": "" 403 | }, 404 | { 405 | "question": "如果听到持续的高频尖锐噪声应该怎么做?", 406 | "answer_1": "", 
407 | "answer_2": "", 408 | "answer_3": "" 409 | }, 410 | { 411 | "question": "风挡摄像头和前雷达在什么情况下会弹出警示消息?", 412 | "answer_1": "", 413 | "answer_2": "", 414 | "answer_3": "" 415 | }, 416 | { 417 | "question": "如何通过蓝牙钥匙启动车辆?", 418 | "answer_1": "", 419 | "answer_2": "", 420 | "answer_3": "" 421 | }, 422 | { 423 | "question": "在什么情况下无法通过蓝牙钥匙启动车辆?", 424 | "answer_1": "", 425 | "answer_2": "", 426 | "answer_3": "" 427 | }, 428 | { 429 | "question": "如何通过网络远程启动发动机?", 430 | "answer_1": "", 431 | "answer_2": "", 432 | "answer_3": "" 433 | }, 434 | { 435 | "question": "变道辅助系统由哪三部分组成?", 436 | "answer_1": "", 437 | "answer_2": "", 438 | "answer_3": "" 439 | }, 440 | { 441 | "question": "如何检查制动液液位?", 442 | "answer_1": "", 443 | "answer_2": "", 444 | "answer_3": "" 445 | }, 446 | { 447 | "question": "如何调节车辆的背光亮度?", 448 | "answer_1": "", 449 | "answer_2": "", 450 | "answer_3": "" 451 | }, 452 | { 453 | "question": "在白天和夜晚,开启和未开启整车背光联动时,转动调光旋钮有什么不同?", 454 | "answer_1": "", 455 | "answer_2": "", 456 | "answer_3": "" 457 | }, 458 | { 459 | "question": "当车辆发生故障时,应该如何处理?", 460 | "answer_1": "", 461 | "answer_2": "", 462 | "answer_3": "" 463 | }, 464 | { 465 | "question": "为什么不建议自行修理车辆故障?", 466 | "answer_1": "", 467 | "answer_2": "", 468 | "answer_3": "" 469 | }, 470 | { 471 | "question": "什么是能量回收系统?", 472 | "answer_1": "", 473 | "answer_2": "", 474 | "answer_3": "" 475 | }, 476 | { 477 | "question": "如何知道车辆正在进行能量回收?", 478 | "answer_1": "", 479 | "answer_2": "", 480 | "answer_3": "" 481 | }, 482 | { 483 | "question": "影响能量回收多少的因素有哪些?", 484 | "answer_1": "", 485 | "answer_2": "", 486 | "answer_3": "" 487 | }, 488 | { 489 | "question": "什么是智能远近光控制系统?", 490 | "answer_1": "", 491 | "answer_2": "", 492 | "answer_3": "" 493 | }, 494 | { 495 | "question": "在什么情况下可以开启智能远近光控制系统?", 496 | "answer_1": "", 497 | "answer_2": "", 498 | "answer_3": "" 499 | }, 500 | { 501 | "question": "哪些因素可能导致智能远近光控制系统无法正常工作?", 502 | "answer_1": "", 503 | "answer_2": "", 504 | "answer_3": "" 505 | }, 506 | { 507 | "question": "如何启用前除霜/除雾功能?", 508 | "answer_1": "", 509 | "answer_2": "", 510 | "answer_3": "" 511 | }, 512 | { 513 | "question": "当我开启前挡风玻璃的去冰去雾模式时,车辆会有什么反应?", 514 | "answer_1": "", 515 | "answer_2": "", 516 | "answer_3": "" 517 | }, 518 | { 519 | "question": "为什么驾驶之前需要确保挡风玻璃无冰渣、积雪或冷凝水?", 520 | "answer_1": "", 521 | "answer_2": "", 522 | "answer_3": "" 523 | }, 524 | { 525 | "question": "什么情况下车辆需要进行报废处理?", 526 | "answer_1": "", 527 | "answer_2": "", 528 | "answer_3": "" 529 | }, 530 | { 531 | "question": "什么是遥控泊车(RPA)?", 532 | "answer_1": "", 533 | "answer_2": "", 534 | "answer_3": "" 535 | }, 536 | { 537 | "question": "使用Lynk&CoApp的RPA功能需要注意什么?", 538 | "answer_1": "", 539 | "answer_2": "", 540 | "answer_3": "" 541 | }, 542 | { 543 | "question": "如何开启方向盘助力与驾驶模式联动?", 544 | "answer_1": "", 545 | "answer_2": "", 546 | "answer_3": "" 547 | }, 548 | { 549 | "question": "有哪些可选的方向盘转向助力模式?", 550 | "answer_1": "", 551 | "answer_2": "", 552 | "answer_3": "" 553 | }, 554 | { 555 | "question": "主动式座舱清洁系统的作用是什么?", 556 | "answer_1": "", 557 | "answer_2": "", 558 | "answer_3": "" 559 | }, 560 | { 561 | "question": "全景天窗是由几部分组成的?", 562 | "answer_1": "", 563 | "answer_2": "", 564 | "answer_3": "" 565 | }, 566 | { 567 | "question": "如何使用位置记忆功能?", 568 | "answer_1": "", 569 | "answer_2": "", 570 | "answer_3": "" 571 | }, 572 | { 573 | "question": "如何更换后挡风玻璃雨刮片?", 574 | "answer_1": "", 575 | "answer_2": "", 576 | "answer_3": "" 577 | }, 578 | { 579 | "question": "在更换雨刮片时需要注意什么?", 580 | "answer_1": "", 581 | "answer_2": "", 582 | "answer_3": "" 583 | }, 
584 | { 585 | "question": "如何添加香氛精油?", 586 | "answer_1": "", 587 | "answer_2": "", 588 | "answer_3": "" 589 | }, 590 | { 591 | "question": "我应该在哪里添加香氛精油?", 592 | "answer_1": "", 593 | "answer_2": "", 594 | "answer_3": "" 595 | }, 596 | { 597 | "question": "什么是转向助力系统?", 598 | "answer_1": "", 599 | "answer_2": "", 600 | "answer_3": "" 601 | }, 602 | { 603 | "question": "涉水行驶前应注意什么?", 604 | "answer_1": "", 605 | "answer_2": "", 606 | "answer_3": "" 607 | }, 608 | { 609 | "question": "涉水驾驶后需要进行哪些检查?", 610 | "answer_1": "", 611 | "answer_2": "", 612 | "answer_3": "" 613 | }, 614 | { 615 | "question": "什么时候应该为车辆打蜡?", 616 | "answer_1": "", 617 | "answer_2": "", 618 | "answer_3": "" 619 | } 620 | ] 621 | -------------------------------------------------------------------------------- /data/train_a.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawoshi/Tianchi-LLM-QA/fccb285a683f9ccb578eda2f22e9c94883b43672/data/train_a.pdf -------------------------------------------------------------------------------- /faiss_retriever.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | from langchain.schema import Document 6 | from langchain.vectorstores import Chroma,FAISS 7 | from langchain.embeddings.huggingface import HuggingFaceEmbeddings 8 | from pdf_parse import DataProcess 9 | import torch 10 | # from bm25_retriever import BM25 11 | 12 | class FaissRetriever(object): 13 | def __init__(self, model_path, data): 14 | self.embeddings = HuggingFaceEmbeddings( 15 | model_name = model_path, 16 | model_kwargs = {"device":"cuda"} 17 | ) 18 | docs = [] 19 | for idx, line in enumerate(data): 20 | line = line.strip("\n").strip() 21 | words = line.split("\t") 22 | docs.append(Document(page_content=words[0], metadata={"id": idx})) 23 | self.vector_store = FAISS.from_documents(docs, self.embeddings) 24 | del self.embeddings 25 | torch.cuda.empty_cache() 26 | 27 | def GetTopK(self, query, k): 28 | context = self.vector_store.similarity_search_with_score(query, k=k) 29 | return context 30 | def GetvectorStore(self): 31 | return self.vector_store 32 | 33 | if __name__ == "__main__": 34 | base = "/root/autodl-tmp/codes" 35 | model_name=base + "/pre_train_model/m3e-large" #text2vec-large-chinese 36 | dp = DataProcess(pdf_path = base + "/data/train_a.pdf") 37 | dp.ParseBlock(max_seq = 1024) 38 | dp.ParseBlock(max_seq = 512) 39 | print(len(dp.data)) 40 | dp.ParseAllPage(max_seq = 256) 41 | dp.ParseAllPage(max_seq = 512) 42 | print(len(dp.data)) 43 | dp.ParseOnePageWithRule(max_seq = 256) 44 | dp.ParseOnePageWithRule(max_seq = 512) 45 | print(len(dp.data)) 46 | data = dp.data 47 | 48 | faissretriever = FaissRetriever(model_name, data) 49 | # bm25 = BM25(data) 50 | faiss_ans = faissretriever.GetTopK("如何预防新冠肺炎", 6) 51 | print(faiss_ans) 52 | faiss_ans = faissretriever.GetTopK("交通事故如何处理", 6) 53 | print(faiss_ans) 54 | faiss_ans = faissretriever.GetTopK("吉利集团的董事长是谁", 6) 55 | print(faiss_ans) 56 | faiss_ans = faissretriever.GetTopK("吉利汽车语音组手叫什么", 6) 57 | print(faiss_ans) 58 | # bm25_ans = bm25.GetBM25TopK("座椅加热", 6) 59 | # ans = reRank(6, bm25_ans, faiss_ans) 60 | -------------------------------------------------------------------------------- /images/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawoshi/Tianchi-LLM-QA/fccb285a683f9ccb578eda2f22e9c94883b43672/images/01.png 
-------------------------------------------------------------------------------- /images/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawoshi/Tianchi-LLM-QA/fccb285a683f9ccb578eda2f22e9c94883b43672/images/02.png -------------------------------------------------------------------------------- /images/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dawoshi/Tianchi-LLM-QA/fccb285a683f9ccb578eda2f22e9c94883b43672/images/03.png -------------------------------------------------------------------------------- /pdf_parse.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import pdfplumber 5 | from PyPDF2 import PdfReader 6 | 7 | 8 | class DataProcess(object): 9 | 10 | def __init__(self, pdf_path): 11 | self.pdf_path = pdf_path 12 | self.data = [] 13 | def SlidingWindow(self, sentences, kernel = 512, stride = 1): 14 | sz = len(sentences) 15 | cur = "" 16 | fast = 0 17 | slow = 0 18 | while(fast < len(sentences)): 19 | sentence = sentences[fast] 20 | if(len(cur + sentence) > kernel and (cur + sentence) not in self.data): 21 | self.data.append(cur + sentence + "。") 22 | cur = cur[len(sentences[slow] + "。"):] 23 | slow = slow + 1 24 | cur = cur + sentence + "。" 25 | fast = fast + 1 26 | 27 | def Datafilter(self, line, header, pageid, max_seq = 1024): 28 | 29 | sz = len(line) 30 | if(sz < 6): 31 | return 32 | 33 | if(sz > max_seq): 34 | 35 | if("■" in line): 36 | sentences = line.split("■") 37 | elif("•" in line): 38 | sentences = line.split("•") 39 | elif("\t" in line): 40 | sentences = line.split("\t") 41 | else: 42 | sentences = line.split("。") 43 | 44 | for subsentence in sentences: 45 | subsentence = subsentence.replace("\n", "") 46 | 47 | if(len(subsentence) < max_seq and len(subsentence) > 5): 48 | # subsentence = subsentence.replace(",", "").replace("\n","").replace("\t","") + "\t" + header+ "\t" + str(pageid) 49 | subsentence = subsentence.replace(",", "").replace("\n","").replace("\t","") 50 | if(subsentence not in self.data): 51 | self.data.append(subsentence) 52 | else: 53 | # line = line.replace("\n","").replace(",", "").replace("\t","") + "\t" + header + "\t" + str(pageid) 54 | line = line.replace("\n","").replace(",", "").replace("\t","") 55 | if(line not in self.data): 56 | self.data.append(line) 57 | # 提取页头即一级标题 58 | 59 | def GetHeader(self, page): 60 | try: 61 | lines = page.extract_words()[::] 62 | except: 63 | return None 64 | if(len(lines) > 0): 65 | for line in lines: 66 | if("目录" in line["text"] or ".........." in line["text"]): 67 | return None 68 | if(line["top"] < 20 and line["top"] > 17): 69 | return line["text"] 70 | return lines[0]["text"] 71 | return None 72 | 73 | # 按照每页中块提取内容,并和一级标题进行组合,配合Document 可进行意图识别 74 | def ParseBlock(self, max_seq = 1024): 75 | 76 | with pdfplumber.open(self.pdf_path) as pdf: 77 | 78 | for i, p in enumerate(pdf.pages): 79 | header = self.GetHeader(p) 80 | 81 | if(header == None): 82 | continue 83 | 84 | texts = p.extract_words(use_text_flow=True, extra_attrs = ["size"])[::] 85 | 86 | squence = "" 87 | lastsize = 0 88 | 89 | for idx, line in enumerate(texts): 90 | if(idx <1): 91 | continue 92 | if(idx == 1): 93 | if(line["text"].isdigit()): 94 | continue 95 | cursize = line["size"] 96 | text = line["text"] 97 | if(text == "□" or text == "•"): 98 | continue 99 | elif(text== "警告!" or text == "注意!" 
or text == "说明!"): 100 | if(len(squence) > 0): 101 | self.Datafilter(squence, header, i, max_seq = max_seq) 102 | squence = "" 103 | elif(format(lastsize,".5f") == format(cursize,".5f")): 104 | if(len(squence)>0): 105 | squence = squence + text 106 | else: 107 | squence = text 108 | else: 109 | lastsize = cursize 110 | if(len(squence) < 15 and len(squence)>0): 111 | squence = squence + text 112 | else: 113 | if(len(squence) > 0): 114 | self.Datafilter(squence, header, i, max_seq = max_seq) 115 | squence = text 116 | if(len(squence) > 0): 117 | self.Datafilter(squence, header, i, max_seq = max_seq) 118 | def ParseOnePageWithRule(self, max_seq = 512, min_len = 6): 119 | for idx, page in enumerate(PdfReader(self.pdf_path).pages): 120 | page_content = "" 121 | text = page.extract_text() 122 | words = text.split("\n") 123 | for idx, word in enumerate(words): 124 | text = word.strip().strip("\n") 125 | if("...................." in text or "目录" in text): 126 | continue 127 | if(len(text) < 1): 128 | continue 129 | if(text.isdigit()): 130 | continue 131 | page_content = page_content + text 132 | if(len(page_content) < min_len): 133 | continue 134 | if(len(page_content) < max_seq): 135 | if(page_content not in self.data): 136 | self.data.append(page_content) 137 | else: 138 | sentences = page_content.split("。") 139 | cur = "" 140 | for idx, sentence in enumerate(sentences): 141 | if(len(cur + sentence) > max_seq and (cur + sentence) not in self.data): 142 | self.data.append(cur + sentence) 143 | cur = sentence 144 | else: 145 | cur = cur + sentence 146 | # 滑窗法提取段落 147 | # 1. 把pdf看做一个整体,作为一个字符串 148 | # 2. 利用句号当做分隔符,切分成一个数组 149 | # 3. 利用滑窗法对数组进行滑动, 此处的 150 | def ParseAllPage(self, max_seq = 512, min_len = 6): 151 | all_content = "" 152 | for idx, page in enumerate(PdfReader(self.pdf_path).pages): 153 | page_content = "" 154 | text = page.extract_text() 155 | words = text.split("\n") 156 | for idx, word in enumerate(words): 157 | text = word.strip().strip("\n") 158 | if("...................." 
in text or "目录" in text): 159 | continue 160 | if(len(text) < 1): 161 | continue 162 | if(text.isdigit()): 163 | continue 164 | page_content = page_content + text 165 | if(len(page_content) < min_len): 166 | continue 167 | all_content = all_content + page_content 168 | sentences = all_content.split("。") 169 | self.SlidingWindow(sentences, kernel = max_seq) 170 | 171 | # for idx, sentence in enumerate(sentences): 172 | # if(len(cur + sentence) > max_seq and (cur + sentence) not in self.data): 173 | # self.data.append(cur + sentence) 174 | # cur = sentence 175 | # else: 176 | # cur = cur + sentence 177 | 178 | if __name__ == "__main__": 179 | dp = DataProcess(pdf_path = "/root/autodl-tmp/codes/data/train_a.pdf") 180 | dp.ParseBlock(max_seq = 1024) 181 | dp.ParseBlock(max_seq = 512) 182 | print(len(dp.data)) 183 | dp.ParseAllPage(max_seq = 256) 184 | dp.ParseAllPage(max_seq = 512) 185 | print(len(dp.data)) 186 | dp.ParseOnePageWithRule(max_seq = 256) 187 | dp.ParseOnePageWithRule(max_seq = 512) 188 | print(len(dp.data)) 189 | data = dp.data 190 | out = open("all_text.txt", "w") 191 | for line in data: 192 | line = line.strip("\n") 193 | out.write(line) 194 | out.write("\n") 195 | out.close() 196 | -------------------------------------------------------------------------------- /pre_train_model/Qwen-7B-Chat/download.py: -------------------------------------------------------------------------------- 1 | #模型下载 2 | from modelscope import snapshot_download 3 | model_dir = snapshot_download('qwen/Qwen-7B-Chat') 4 | -------------------------------------------------------------------------------- /qwen_generation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """Generation support.""" 7 | 8 | from typing import Tuple, List, Union, Iterable 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from transformers import PreTrainedTokenizer 14 | from transformers import logging 15 | from transformers.generation import LogitsProcessor 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | # Types. 20 | HistoryType = List[Tuple[str, str]] 21 | TokensType = List[int] 22 | BatchTokensType = List[List[int]] 23 | 24 | 25 | def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType: 26 | for tokens in batch: 27 | context_length = len(tokens) 28 | if context_length < seq_length: 29 | tokens.extend([pad_id] * (seq_length - context_length)) 30 | return batch 31 | 32 | 33 | def get_ltor_masks_and_position_ids( 34 | data, 35 | eod_token, 36 | reset_position_ids, 37 | reset_attention_mask, 38 | eod_mask_loss, 39 | ): 40 | """Build masks and position id for left to right model.""" 41 | 42 | # Extract batch size and sequence length. 43 | micro_batch_size, seq_length = data.size() 44 | 45 | # Attention mask (lower triangular). 46 | if reset_attention_mask: 47 | att_mask_batch = micro_batch_size 48 | else: 49 | att_mask_batch = 1 50 | attention_mask = torch.tril( 51 | torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) 52 | ).view(att_mask_batch, 1, seq_length, seq_length) 53 | 54 | # Loss mask. 55 | loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) 56 | if eod_mask_loss: 57 | loss_mask[data == eod_token] = 0.0 58 | 59 | # Position ids. 
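    # The lines below create positions 0..seq_length-1 and broadcast them over the batch;
    # when reset_position_ids / reset_attention_mask are set, the per-sample loop further
    # down restarts positions after every EOD token and blocks attention across EOD
    # boundaries, so packed sequences cannot attend to each other.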
60 | position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) 61 | position_ids = position_ids.unsqueeze(0).expand_as(data) 62 | # We need to clone as the ids will be modifed based on batch index. 63 | if reset_position_ids: 64 | position_ids = position_ids.clone() 65 | 66 | if reset_position_ids or reset_attention_mask: 67 | # Loop through the batches: 68 | for b in range(micro_batch_size): 69 | 70 | # Find indecies where EOD token is. 71 | eod_index = position_ids[b, data[b] == eod_token] 72 | # Detach indecies from positions if going to modify positions. 73 | if reset_position_ids: 74 | eod_index = eod_index.clone() 75 | 76 | # Loop through EOD indecies: 77 | prev_index = 0 78 | for j in range(eod_index.size()[0]): 79 | i = eod_index[j] 80 | # Mask attention loss. 81 | if reset_attention_mask: 82 | attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 83 | # Reset positions. 84 | if reset_position_ids: 85 | position_ids[b, (i + 1) :] -= i + 1 - prev_index 86 | prev_index = i + 1 87 | 88 | # Convert attention mask to binary: 89 | attention_mask = attention_mask < 0.5 90 | 91 | return attention_mask, loss_mask, position_ids 92 | 93 | 94 | def get_batch(context_tokens: torch.LongTensor, eod_id: int): 95 | """Generate batch from context tokens.""" 96 | # Move to GPU. 97 | tokens = context_tokens.contiguous().to(context_tokens.device) 98 | # Get the attention mask and postition ids. 99 | attention_mask, _, position_ids = get_ltor_masks_and_position_ids( 100 | tokens, 101 | eod_id, 102 | reset_position_ids=False, 103 | reset_attention_mask=False, 104 | eod_mask_loss=False, 105 | ) 106 | return tokens, attention_mask, position_ids 107 | 108 | 109 | def get_stop_words_ids(chat_format, tokenizer): 110 | if chat_format == "raw": 111 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]] 112 | elif chat_format == "chatml": 113 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]] 114 | else: 115 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 116 | return stop_words_ids 117 | 118 | 119 | def make_context( 120 | tokenizer: PreTrainedTokenizer, 121 | query: str, 122 | history: List[Tuple[str, str]] = None, 123 | system: str = "", 124 | max_window_size: int = 6144, 125 | chat_format: str = "chatml", 126 | ): 127 | if history is None: 128 | history = [] 129 | 130 | if chat_format == "chatml": 131 | im_start, im_end = "<|im_start|>", "<|im_end|>" 132 | im_start_tokens = [tokenizer.im_start_id] 133 | im_end_tokens = [tokenizer.im_end_id] 134 | nl_tokens = tokenizer.encode("\n") 135 | 136 | def _tokenize_str(role, content): 137 | return f"{role}\n{content}", tokenizer.encode( 138 | role, allowed_special=set() 139 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set()) 140 | 141 | system_text, system_tokens_part = _tokenize_str("system", system) 142 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens 143 | 144 | raw_text = "" 145 | context_tokens = [] 146 | 147 | for turn_query, turn_response in reversed(history): 148 | query_text, query_tokens_part = _tokenize_str("user", turn_query) 149 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens 150 | response_text, response_tokens_part = _tokenize_str( 151 | "assistant", turn_response 152 | ) 153 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens 154 | 155 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens 156 | prev_chat = ( 157 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" 
158 | ) 159 | 160 | current_context_size = ( 161 | len(system_tokens) + len(next_context_tokens) + len(context_tokens) 162 | ) 163 | if current_context_size < max_window_size: 164 | context_tokens = next_context_tokens + context_tokens 165 | raw_text = prev_chat + raw_text 166 | else: 167 | break 168 | 169 | context_tokens = system_tokens + context_tokens 170 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text 171 | context_tokens += ( 172 | nl_tokens 173 | + im_start_tokens 174 | + _tokenize_str("user", query)[1] 175 | + im_end_tokens 176 | + nl_tokens 177 | + im_start_tokens 178 | + tokenizer.encode("assistant") 179 | + nl_tokens 180 | ) 181 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" 182 | 183 | elif chat_format == "raw": 184 | raw_text = query 185 | context_tokens = tokenizer.encode(raw_text) 186 | else: 187 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 188 | 189 | return raw_text, context_tokens 190 | 191 | 192 | def _decode_default( 193 | tokens: List[int], 194 | *, 195 | stop_words: List[str], 196 | eod_words: List[str], 197 | tokenizer: PreTrainedTokenizer, 198 | raw_text_len: int, 199 | verbose: bool = False, 200 | return_end_reason: bool = False, 201 | errors: str='replace', 202 | ): 203 | trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:] 204 | if verbose: 205 | print("\nRaw Generate: ", trim_decode_tokens) 206 | 207 | end_reason = f"Gen length {len(tokens)}" 208 | for stop_word in stop_words: 209 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() 210 | for eod_word in eod_words: 211 | if eod_word in trim_decode_tokens: 212 | end_reason = f"Gen {eod_word!r}" 213 | trim_decode_tokens = trim_decode_tokens.split(eod_word)[0] 214 | trim_decode_tokens = trim_decode_tokens.strip() 215 | if verbose: 216 | print("\nEnd Reason:", end_reason) 217 | print("\nGenerate: ", trim_decode_tokens) 218 | 219 | if return_end_reason: 220 | return trim_decode_tokens, end_reason 221 | else: 222 | return trim_decode_tokens 223 | 224 | 225 | def _decode_chatml( 226 | tokens: List[int], 227 | *, 228 | stop_words: List[str], 229 | eod_token_ids: List[int], 230 | tokenizer: PreTrainedTokenizer, 231 | raw_text_len: int, 232 | context_length: int, 233 | verbose: bool = False, 234 | return_end_reason: bool = False, 235 | errors: str='replace' 236 | ): 237 | end_reason = f"Gen length {len(tokens)}" 238 | eod_token_idx = context_length 239 | for eod_token_idx in range(context_length, len(tokens)): 240 | if tokens[eod_token_idx] in eod_token_ids: 241 | end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}" 242 | break 243 | 244 | trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:] 245 | if verbose: 246 | print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:]) 247 | print("\nRaw Generate:", trim_decode_tokens) 248 | print("\nEnd Reason:", end_reason) 249 | for stop_word in stop_words: 250 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() 251 | trim_decode_tokens = trim_decode_tokens.strip() 252 | if verbose: 253 | print("\nGenerate:", trim_decode_tokens) 254 | 255 | if return_end_reason: 256 | return trim_decode_tokens, end_reason 257 | else: 258 | return trim_decode_tokens 259 | 260 | 261 | def decode_tokens( 262 | tokens: Union[torch.LongTensor, TokensType], 263 | tokenizer: PreTrainedTokenizer, 264 | raw_text_len: int, 265 | context_length: int, 266 | chat_format: str, 267 | verbose: bool = False, 268 
| return_end_reason: bool = False, 269 | errors: str="replace", 270 | ) -> str: 271 | if torch.is_tensor(tokens): 272 | tokens = tokens.cpu().numpy().tolist() 273 | 274 | if chat_format == "chatml": 275 | return _decode_chatml( 276 | tokens, 277 | stop_words=[], 278 | eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id], 279 | tokenizer=tokenizer, 280 | raw_text_len=raw_text_len, 281 | context_length=context_length, 282 | verbose=verbose, 283 | return_end_reason=return_end_reason, 284 | errors=errors, 285 | ) 286 | elif chat_format == "raw": 287 | return _decode_default( 288 | tokens, 289 | stop_words=["<|endoftext|>"], 290 | eod_words=["<|endoftext|>"], 291 | tokenizer=tokenizer, 292 | raw_text_len=raw_text_len, 293 | verbose=verbose, 294 | return_end_reason=return_end_reason, 295 | errors=errors, 296 | ) 297 | else: 298 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 299 | 300 | 301 | class StopWordsLogitsProcessor(LogitsProcessor): 302 | """ 303 | :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration. 304 | 305 | Args: 306 | stop_words_ids (:obj:`List[List[int]]`): 307 | List of list of token ids of stop ids. In order to get the tokens of the words 308 | that should not appear in the generated text, use :obj:`tokenizer(bad_word, 309 | add_prefix_space=True).input_ids`. 310 | eos_token_id (:obj:`int`): 311 | The id of the `end-of-sequence` token. 312 | """ 313 | 314 | def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int): 315 | 316 | if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0: 317 | raise ValueError( 318 | f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}." 319 | ) 320 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids): 321 | raise ValueError( 322 | f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}." 323 | ) 324 | if any( 325 | any( 326 | (not isinstance(token_id, (int, np.integer)) or token_id < 0) 327 | for token_id in stop_word_ids 328 | ) 329 | for stop_word_ids in stop_words_ids 330 | ): 331 | raise ValueError( 332 | f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}." 
333 | ) 334 | 335 | self.stop_words_ids = list( 336 | filter( 337 | lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids 338 | ) 339 | ) 340 | self.eos_token_id = eos_token_id 341 | for stop_token_seq in self.stop_words_ids: 342 | assert ( 343 | len(stop_token_seq) > 0 344 | ), "Stop words token sequences {} cannot have an empty list".format( 345 | stop_words_ids 346 | ) 347 | 348 | def __call__( 349 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 350 | ) -> torch.FloatTensor: 351 | stopped_samples = self._calc_stopped_samples(input_ids) 352 | for i, should_stop in enumerate(stopped_samples): 353 | if should_stop: 354 | scores[i, self.eos_token_id] = float(2**15) 355 | return scores 356 | 357 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool: 358 | if len(tokens) == 0: 359 | # if bad word tokens is just one token always ban it 360 | return True 361 | elif len(tokens) > len(prev_tokens): 362 | # if bad word tokens are longer then prev input_ids they can't be equal 363 | return False 364 | elif prev_tokens[-len(tokens) :].tolist() == tokens: 365 | # if tokens match 366 | return True 367 | else: 368 | return False 369 | 370 | def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]: 371 | stopped_samples = [] 372 | for prev_input_ids_slice in prev_input_ids: 373 | match = False 374 | for stop_token_seq in self.stop_words_ids: 375 | if self._tokens_match(prev_input_ids_slice, stop_token_seq): 376 | # if tokens do not match continue 377 | match = True 378 | break 379 | stopped_samples.append(match) 380 | 381 | return stopped_samples 382 | 383 | 384 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): 385 | """This function has been mostly taken from huggingface conversational 386 | ai code at 387 | https://medium.com/huggingface/how-to-build-a-state-of-the-art- 388 | conversational-ai-with-transfer-learning-2d818ac26313""" 389 | 390 | if top_k > 0: 391 | # Remove all tokens with a probability less than the 392 | # last token of the top-k 393 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 394 | logits[indices_to_remove] = filter_value 395 | 396 | if top_p > 0.0: 397 | # Cconvert to 1D 398 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) 399 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 400 | 401 | # Remove tokens with cumulative probability above the threshold 402 | sorted_indices_to_remove = cumulative_probs > top_p 403 | # Shift the indices to the right to keep also the first token 404 | # above the threshold 405 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 406 | sorted_indices_to_remove[..., 0] = 0 407 | for i in range(sorted_indices.size(0)): 408 | indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] 409 | logits[i][indices_to_remove] = filter_value 410 | 411 | return logits 412 | 413 | 414 | def switch(val1, val2, boolean): 415 | boolean = boolean.type_as(val1) 416 | return (1 - boolean) * val1 + boolean * val2 417 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python==3.9.12 2 | vllm 3 | modelscope 4 | tiktoken 5 | pdfplumber 6 | PdfReader 7 | PyPDF2 8 | langchain 9 | jieba 10 | rank_bm25 11 | sentence-transformers 12 | faiss-gpu 13 | 
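# Note: "python==3.9.12" above documents the expected interpreter version rather than a
# pip-installable package, and PdfReader appears to be covered by the PyPDF2 entry already,
# since pdf_parse.py imports it via `from PyPDF2 import PdfReader`.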
-------------------------------------------------------------------------------- /rerank_model.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 2 | import os 3 | 4 | from bm25_retriever import BM25 5 | from pdf_parse import DataProcess 6 | from config import * 7 | 8 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 9 | 10 | DEVICE = LLM_DEVICE 11 | DEVICE_ID = "0" 12 | CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE 13 | 14 | 15 | def torch_gc(): 16 | if torch.cuda.is_available(): 17 | with torch.cuda.device(CUDA_DEVICE): 18 | torch.cuda.empty_cache() 19 | torch.cuda.ipc_collect() 20 | class reRankLLM(object): 21 | def __init__(self, model_path, max_length = 512): 22 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 23 | self.model = AutoModelForSequenceClassification.from_pretrained(model_path) 24 | self.model.eval() 25 | self.model.half() 26 | self.model.cuda() 27 | self.max_length = max_length 28 | 29 | def predict(self, query, docs): 30 | pairs = [(query, doc.page_content) for doc in docs] 31 | inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=self.max_length).to("cuda") 32 | with torch.no_grad(): 33 | scores = self.model(**inputs).logits 34 | scores = scores.detach().cpu().clone().numpy() 35 | response = [doc for score, doc in sorted(zip(scores, docs), reverse=True, key=lambda x:x[0])] 36 | torch_gc() 37 | return response 38 | if __name__ == "__main__": 39 | bge_reranker_large = "/Users/william/codes/contest/aicar_docker/new_build/app/pre_train_model/bge-reranker-large" 40 | rerank = reRankLLM(bge_reranker_large) 41 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import json 5 | import jieba 6 | import pandas as pd 7 | import numpy as np 8 | from tqdm import tqdm 9 | from langchain.schema import Document 10 | from langchain.vectorstores import Chroma,FAISS 11 | from langchain import PromptTemplate, LLMChain 12 | from langchain.chains import RetrievalQA 13 | import time 14 | import re 15 | 16 | from vllm_model import ChatLLM 17 | from vllm_wrapper import vLLMWrapper 18 | from rerank_model import reRankLLM 19 | from faiss_retriever import FaissRetriever 20 | from bm25_retriever import BM25 21 | from pdf_parse import DataProcess 22 | 23 | def get_qa_chain(llm, vector_store, prompt_template): 24 | 25 | prompt = PromptTemplate(template=prompt_template, 26 | input_variables=["context", "question"]) 27 | 28 | return RetrievalQA.from_llm(llm=llm, retriever=vector_store.as_retriever(search_kwargs={"k": 10}), prompt=prompt) 29 | def get_emb_bm25_merge(faiss_context, bm25_context, query): 30 | max_length = 2500 31 | emb_ans = "" 32 | cnt = 0 33 | for doc, score in faiss_context: 34 | cnt =cnt + 1 35 | if(cnt>6): 36 | break 37 | if(len(emb_ans + doc.page_content) > max_length): 38 | break 39 | emb_ans = emb_ans + doc.page_content 40 | bm25_ans = "" 41 | cnt = 0 42 | for doc in bm25_context: 43 | cnt = cnt + 1 44 | if(len(bm25_ans + doc.page_content) > max_length): 45 | break 46 | bm25_ans = bm25_ans + doc.page_content 47 | if(cnt > 6): 48 | break 49 | 50 | prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。 51 | 如果无法从中得到答案,请说 "无答案"或"无答案",不允许在答案中添加编造成分,答案请使用中文。 52 | 已知内容为吉利控股集团汽车销售有限公司的吉利用户手册: 53 | 1: {emb_ans} 54 | 2: {bm25_ans} 55 | 问题: 56 | 
{question}""".format(emb_ans=emb_ans, bm25_ans = bm25_ans, question = query) 57 | return prompt_template 58 | def get_rerank(emb_ans, query): 59 | 60 | prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。 61 | 如果无法从中得到答案,请说 "无答案"或"无答案" ,不允许在答案中添加编造成分,答案请使用中文。 62 | 已知内容为吉利控股集团汽车销售有限公司的吉利用户手册: 63 | 1: {emb_ans} 64 | 问题: 65 | {question}""".format(emb_ans=emb_ans, question = query) 66 | return prompt_template 67 | 68 | 69 | def question(text, llm, vector_store, prompt_template): 70 | 71 | chain = get_qa_chain(llm, vector_store, prompt_template) 72 | 73 | response = chain({"query": text}) 74 | return response 75 | 76 | def reRank(rerank, top_k, query, bm25_ans, faiss_ans): 77 | items = [] 78 | max_length = 4000 79 | for doc, score in faiss_ans: 80 | items.append(doc) 81 | items.extend(bm25_ans) 82 | rerank_ans = rerank.predict(query, items) 83 | rerank_ans = rerank_ans[:top_k] 84 | # docs_sort = sorted(rerank_ans, key = lambda x:x.metadata["id"]) 85 | emb_ans = "" 86 | for doc in rerank_ans: 87 | if(len(emb_ans + doc.page_content) > max_length): 88 | break 89 | emb_ans = emb_ans + doc.page_content 90 | return emb_ans 91 | 92 | if __name__ == "__main__": 93 | 94 | start = time.time() 95 | # base = "/app" 96 | # qwen7 = "/tcdata/qwen/Qwen-7B-Chat" 97 | 98 | base = "/root/autodl-tmp/codes" 99 | qwen7 = base + "/pre_train_model/Qwen-7B-Chat" 100 | m3e = base + "/pre_train_model/m3e-large" 101 | bge_reranker_large = base + "/pre_train_model/bge-reranker-large" 102 | 103 | # data 104 | # dp = DataProcess(pdf_path = "/tcdata/trainning_data.pdf") 105 | dp = DataProcess(pdf_path = base + "/data/train_a.pdf") 106 | dp.ParseBlock(max_seq = 1024) 107 | dp.ParseBlock(max_seq = 512) 108 | print(len(dp.data)) 109 | dp.ParseAllPage(max_seq = 256) 110 | dp.ParseAllPage(max_seq = 512) 111 | print(len(dp.data)) 112 | dp.ParseOnePageWithRule(max_seq = 256) 113 | dp.ParseOnePageWithRule(max_seq = 512) 114 | print(len(dp.data)) 115 | data = dp.data 116 | print("data load ok") 117 | 118 | # Faiss 119 | faissretriever = FaissRetriever(m3e, data) 120 | vector_store = faissretriever.vector_store 121 | print("faissretriever load ok") 122 | 123 | # BM2.5 124 | bm25 = BM25(data) 125 | print("bm25 load ok") 126 | 127 | # LLM 128 | # llm = vLLMWrapper(qwen7) 129 | llm = ChatLLM(qwen7) 130 | print("llm qwen load ok") 131 | 132 | # reRank 133 | rerank = reRankLLM(bge_reranker_large) 134 | print("rerank model load ok") 135 | 136 | # with open("/tcdata/test_question.json", "r") as f: 137 | 138 | with open(base + "/data/test_question.json", "r") as f: 139 | jdata = json.loads(f.read()) 140 | print(len(jdata)) 141 | max_length = 4000 142 | for idx, line in enumerate(jdata): 143 | query = line["question"] 144 | 145 | # faiss 146 | faiss_context = faissretriever.GetTopK(query, 15) 147 | faiss_min_score = 0.0 148 | if(len(faiss_context) > 0): 149 | faiss_min_score = faiss_context[0][1] 150 | cnt = 0 151 | emb_ans = "" 152 | for doc, score in faiss_context: 153 | cnt =cnt + 1 154 | if(len(emb_ans + doc.page_content) > max_length): 155 | break 156 | emb_ans = emb_ans + doc.page_content 157 | if(cnt>6): 158 | break 159 | 160 | # bm2.5 161 | bm25_context = bm25.GetBM25TopK(query, 15) 162 | bm25_ans = "" 163 | cnt = 0 164 | for doc in bm25_context: 165 | cnt = cnt + 1 166 | if(len(bm25_ans + doc.page_content) > max_length): 167 | break 168 | bm25_ans = bm25_ans + doc.page_content 169 | if(cnt > 6): 170 | break 171 | 172 | emb_bm25_merge_inputs = get_emb_bm25_merge(faiss_context, bm25_context, query) 173 | bm25_inputs = get_rerank(bm25_ans, 
query) 174 | emb_inputs = get_rerank(emb_ans, query) 175 | 176 | # rerank emb recall 177 | rerank_ans = reRank(rerank, 6, query, bm25_context, faiss_context) 178 | rerank_inputs = get_rerank(rerank_ans, query) 179 | 180 | batch_input = [] 181 | batch_input.append(emb_bm25_merge_inputs) 182 | batch_input.append(bm25_inputs) 183 | batch_input.append(emb_inputs) 184 | batch_input.append(rerank_inputs) 185 | batch_output = llm.infer(batch_input) 186 | line["answer_1"] = batch_output[0] 187 | line["answer_2"] = batch_output[1] 188 | line["answer_3"] = batch_output[2] 189 | line["answer_4"] = batch_output[3] 190 | line["answer_5"] = emb_ans 191 | line["answer_6"] = bm25_ans 192 | line["answer_7"] = rerank_ans 193 | if(faiss_min_score >500): 194 | line["answer_5"] = "无答案" 195 | else: 196 | line["answer_5"] = str(faiss_min_score) 197 | # json.dump(jdata, open("/app/result.json", "w", encoding='utf-8'), ensure_ascii=False, indent=2) 198 | json.dump(jdata, open(base + "/data/result.json", "w", encoding='utf-8'), ensure_ascii=False, indent=2) 199 | end = time.time() 200 | print("cost time: " + str(int(end-start)/60)) 201 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python /app/run.py 2 | -------------------------------------------------------------------------------- /vllm_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import time 4 | 5 | from config import * 6 | from vllm import LLM, SamplingParams 7 | 8 | from transformers import AutoModelForCausalLM, AutoTokenizer 9 | from transformers import GenerationConfig 10 | from qwen_generation_utils import make_context, decode_tokens, get_stop_words_ids 11 | 12 | 13 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 14 | 15 | DEVICE = LLM_DEVICE 16 | DEVICE_ID = "0" 17 | CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE 18 | 19 | IMEND = "<|im_end|>" 20 | ENDOFTEXT = "<|endoftext|>" 21 | 22 | def get_stop_words_ids(chat_format, tokenizer): 23 | if chat_format == "raw": 24 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]] 25 | elif chat_format == "chatml": 26 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]] 27 | else: 28 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 29 | return stop_words_ids 30 | 31 | def torch_gc(): 32 | if torch.cuda.is_available(): 33 | with torch.cuda.device(CUDA_DEVICE): 34 | torch.cuda.empty_cache() 35 | torch.cuda.ipc_collect() 36 | 37 | class ChatLLM(object): 38 | 39 | def __init__(self, model_path): 40 | self.tokenizer = AutoTokenizer.from_pretrained( 41 | model_path, 42 | pad_token='<|extra_0|>', 43 | eos_token='<|endoftext|>', 44 | padding_side='left', 45 | trust_remote_code=True 46 | ) 47 | self.generation_config = GenerationConfig.from_pretrained(model_path, pad_token_id=self.tokenizer.pad_token_id) 48 | self.tokenizer.eos_token_id = self.generation_config.eos_token_id 49 | self.stop_words_ids = [] 50 | self.model = LLM(model=model_path, 51 | tokenizer=model_path, 52 | tensor_parallel_size=1, 53 | trust_remote_code=True, 54 | gpu_memory_utilization=0.90, 55 | dtype="bfloat16") 56 | for stop_id in get_stop_words_ids(self.generation_config.chat_format, self.tokenizer): 57 | self.stop_words_ids.extend(stop_id) 58 | self.stop_words_ids.extend([self.generation_config.eos_token_id]) 59 | sampling_kwargs = { 60 | "stop_token_ids": self.stop_words_ids, 61 | 
"early_stopping": False, 62 | "top_p": 1.0, 63 | "top_k": -1 if self.generation_config.top_k == 0 else self.generation_config.top_k, 64 | "temperature": 0.0, 65 | "max_tokens": 2000, 66 | "repetition_penalty": self.generation_config.repetition_penalty, 67 | "n":1, 68 | "best_of":2, 69 | "use_beam_search":True 70 | } 71 | self.sampling_params = SamplingParams(**sampling_kwargs) 72 | 73 | def infer(self, prompts): 74 | batch_text = [] 75 | for q in prompts: 76 | raw_text, _ = make_context( 77 | self.tokenizer, 78 | q, 79 | system="You are a helpful assistant.", 80 | max_window_size=self.generation_config.max_window_size, 81 | chat_format=self.generation_config.chat_format, 82 | ) 83 | batch_text.append(raw_text) 84 | outputs = self.model.generate(batch_text, 85 | sampling_params = self.sampling_params 86 | ) 87 | batch_response = [] 88 | for output in outputs: 89 | output_str = output.outputs[0].text 90 | if IMEND in output_str: 91 | output_str = output_str[:-len(IMEND)] 92 | if ENDOFTEXT in output_str: 93 | output_str = output_str[:-len(ENDOFTEXT)] 94 | batch_response.append(output_str) 95 | torch_gc() 96 | return batch_response 97 | 98 | if __name__ == "__main__": 99 | qwen7 = "/root/autodl-tmp/codes/pre_train_model/Qwen-7B-Chat" 100 | start = time.time() 101 | llm = ChatLLM(qwen7) 102 | test = ["吉利汽车座椅按摩","吉利汽车语音组手唤醒","自动驾驶功能介绍"] 103 | generated_text = llm.infer(test) 104 | print(generated_text) 105 | end = time.time() 106 | print("cost time: " + str((end-start)/60)) 107 | -------------------------------------------------------------------------------- /vllm_wrapper.py: -------------------------------------------------------------------------------- 1 | from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList 2 | from typing import Optional, Callable, List, Tuple, Union 3 | import copy 4 | import torch 5 | from transformers import AutoTokenizer 6 | from transformers.generation.logits_process import LogitsProcessorList 7 | from packaging import version 8 | 9 | _ERROR_BAD_CHAT_FORMAT = """\ 10 | We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml". 11 | If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat(). 
12 | 我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
13 | 如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
14 | """
15 | 
16 | IMEND = "<|im_end|>"
17 | ENDOFTEXT = "<|endoftext|>"
18 | 
19 | HistoryType = List[Tuple[str, str]]
20 | TokensType = List[int]
21 | BatchTokensType = List[List[int]]
22 | 
23 | def get_stop_words_ids(chat_format, tokenizer):
24 |     if chat_format == "raw":
25 |         stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
26 |     elif chat_format == "chatml":
27 |         stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
28 |     else:
29 |         raise NotImplementedError(f"Unknown chat format {chat_format!r}")
30 |     return stop_words_ids
31 | 
32 | def make_context(
33 |     tokenizer: PreTrainedTokenizer,
34 |     query: str,
35 |     history: List[Tuple[str, str]] = None,
36 |     system: str = "",
37 |     max_window_size: int = 6144,
38 |     chat_format: str = "chatml",
39 | ):
40 |     if history is None:
41 |         history = []
42 | 
43 |     if chat_format == "chatml":
44 |         im_start, im_end = "<|im_start|>", "<|im_end|>"
45 |         im_start_tokens = [tokenizer.im_start_id]
46 |         im_end_tokens = [tokenizer.im_end_id]
47 |         nl_tokens = tokenizer.encode("\n")
48 | 
49 |         def _tokenize_str(role, content):
50 |             return f"{role}\n{content}", tokenizer.encode(
51 |                 role, allowed_special=set()
52 |             ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
53 | 
54 |         system_text, system_tokens_part = _tokenize_str("system", system)
55 |         system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
56 | 
57 |         raw_text = ""
58 |         context_tokens = []
59 | 
60 |         for turn_query, turn_response in reversed(history):
61 |             query_text, query_tokens_part = _tokenize_str("user", turn_query)
62 |             query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
63 |             response_text, response_tokens_part = _tokenize_str(
64 |                 "assistant", turn_response
65 |             )
66 |             response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
67 | 
68 |             next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
69 |             prev_chat = (
70 |                 f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
71 |             )
72 | 
73 |             current_context_size = (
74 |                 len(system_tokens) + len(next_context_tokens) + len(context_tokens)
75 |             )
76 |             if current_context_size < max_window_size:
77 |                 context_tokens = next_context_tokens + context_tokens
78 |                 raw_text = prev_chat + raw_text
79 |             else:
80 |                 break
81 | 
82 |         context_tokens = system_tokens + context_tokens
83 |         raw_text = f"{im_start}{system_text}{im_end}" + raw_text
84 |         context_tokens += (
85 |             nl_tokens
86 |             + im_start_tokens
87 |             + _tokenize_str("user", query)[1]
88 |             + im_end_tokens
89 |             + nl_tokens
90 |             + im_start_tokens
91 |             + tokenizer.encode("assistant")
92 |             + nl_tokens
93 |         )
94 |         raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
95 | 
96 |     elif chat_format == "raw":
97 |         raw_text = query
98 |         context_tokens = tokenizer.encode(raw_text)
99 |     else:
100 |         raise NotImplementedError(f"Unknown chat format {chat_format!r}")
101 | 
102 |     return raw_text, context_tokens
103 | 
104 | class vLLMWrapper:
105 |     def __init__(self,
106 |                  model_dir: str,
107 |                  trust_remote_code: bool = True,
108 |                  tensor_parallel_size: int = 1,
109 |                  gpu_memory_utilization: float = 0.98,
110 |                  dtype: str = "bfloat16",
111 |                  **kwargs):
112 | 
113 |         if dtype not in ("bfloat16", "float16", "float32"):
114 |             print("dtype {} is not supported!".format(dtype))
115 |             raise Exception
116 | 
117 |         # build generation_config
118 |         self.generation_config = GenerationConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
119 | 
120 |         # build tokenizer
121 |         self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
122 |         self.tokenizer.eos_token_id = self.generation_config.eos_token_id
123 | 
124 |         self.stop_words_ids = []
125 | 
126 |         from vllm import LLM
127 |         import vllm
128 |         if version.parse(vllm.__version__) >= version.parse("0.2.2"):
129 |             self.__vllm_support_repetition_penalty = True
130 |         else:
131 |             self.__vllm_support_repetition_penalty = False
132 | 
133 |         quantization = kwargs.get('quantization', None)  # kwargs is a dict, so use .get(); getattr() would always fall back to None
134 | 
135 |         self.model = LLM(model=model_dir,
136 |                          tokenizer=model_dir,
137 |                          tensor_parallel_size=tensor_parallel_size,
138 |                          trust_remote_code=trust_remote_code,
139 |                          quantization=quantization,
140 |                          gpu_memory_utilization=gpu_memory_utilization,
141 |                          dtype=dtype)
142 | 
143 |         for stop_id in get_stop_words_ids(self.generation_config.chat_format, self.tokenizer):
144 |             self.stop_words_ids.extend(stop_id)
145 |         self.stop_words_ids.extend([self.generation_config.eos_token_id])
146 | 
147 |     def chat(self,
148 |              query: str,
149 |              history: Optional[HistoryType],
150 |              tokenizer: PreTrainedTokenizer = None,
151 |              system: str = "You are a helpful assistant.",
152 |              generation_config: Optional[GenerationConfig] = None,
153 |              **kwargs):
154 |         generation_config = generation_config if generation_config is not None else self.generation_config
155 |         tokenizer = self.tokenizer if tokenizer is None else tokenizer
156 | 
157 |         assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
158 |         if not self.__vllm_support_repetition_penalty and generation_config.repetition_penalty != 1:
159 |             raise RuntimeError("The installed vLLM doesn't support repetition_penalty, please set ``model.generation_config.repetition_penalty = 1`` or install vllm>=0.2.2")
160 | 
161 |         if history is None:
162 |             history = []
163 |         else:
164 |             # make a copy of the user's input so that it is left untouched
165 |             history = copy.deepcopy(history)
166 | 
167 |         extra_stop_words_ids = kwargs.get('stop_words_ids', None)
168 |         if extra_stop_words_ids is None:
169 |             extra_stop_words_ids = []
170 | 
171 |         max_window_size = kwargs.get('max_window_size', None)
172 |         if max_window_size is None:
173 |             max_window_size = generation_config.max_window_size
174 | 
175 |         from vllm.sampling_params import SamplingParams
176 |         sampling_kwargs = {
177 |             "stop_token_ids": self.stop_words_ids,
178 |             "early_stopping": False,
179 |             "top_p": 1.0,
180 |             "top_k": -1 if generation_config.top_k == 0 else generation_config.top_k,
181 |             "temperature": 0.0,
182 |             "max_tokens": 512,
183 |             "repetition_penalty": generation_config.repetition_penalty,
184 |             "n": 1,
185 |             "best_of": 2,
186 |             "use_beam_search": True
187 |         }
188 |         if not self.__vllm_support_repetition_penalty:
189 |             sampling_kwargs.pop("repetition_penalty")
190 |         sampling_params = SamplingParams(**sampling_kwargs)
191 | 
192 |         raw_text, context_tokens = make_context(
193 |             self.tokenizer,
194 |             query,
195 |             history=history,
196 |             system=system,
197 |             max_window_size=max_window_size,
198 |             chat_format=generation_config.chat_format,
199 |         )
200 | 
201 |         req_outputs = self.model.generate([query],
202 |                                           sampling_params=sampling_params,
203 |                                           prompt_token_ids=[context_tokens])
204 |         req_output = req_outputs[0]
205 | 
206 |         prompt_str = req_output.prompt
207 |         prompt_ids = req_output.prompt_token_ids
208 |         req_sample_output_ids = []
209 |         req_sample_output_strs = []
210 |         for sample in req_output.outputs:
211 |             output_str = sample.text
212 |             output_ids = sample.token_ids
213 |             if IMEND in output_str:
214 |                 output_str = output_str[:-len(IMEND)]
215 |             if ENDOFTEXT in output_str:
216 |                 output_str = output_str[:-len(ENDOFTEXT)]
217 |             req_sample_output_ids.append(prompt_ids + output_ids)
218 |             req_sample_output_strs.append(prompt_str + output_str)
219 |         assert len(req_sample_output_strs) == 1
220 |         response = req_sample_output_strs[0][len(prompt_str):]
221 |         history.append((query, response))  # store the raw query rather than the fully formatted prompt, so make_context can rebuild multi-turn context correctly
222 | 
223 |         return response, history
224 | 
225 | if __name__ == '__main__':
226 | 
227 |     model_dir = 'Qwen/Qwen-72B-Chat'
228 |     tensor_parallel_size = 1
229 | 
230 |     model = vLLMWrapper(model_dir,
231 |                         tensor_parallel_size=tensor_parallel_size,
232 |                         )
233 | 
234 |     response, history = model.chat(query="你好",
235 |                                    history=None)
236 |     print(response)
237 |     response, history = model.chat(query="给我讲一个年轻人奋斗创业最终取得成功的故事。",
238 |                                    history=history)
239 |     print(response)
240 |     response, history = model.chat(query="给这个故事起一个标题",
241 |                                    history=history)
242 |     print(response)
243 | 
--------------------------------------------------------------------------------