├── .gitignore
├── README.md
├── classifier.py
├── config.py
├── models
│   ├── Embeddings.py
│   ├── LLM.py
│   └── VectorBase.py
├── script
│   ├── create_db.py
│   └── create_test.py
├── test.py
└── utils
    └── time_it.py

/.gitignore:
--------------------------------------------------------------------------------
data
storage
*.log
__pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Text Classification with General-Purpose LLMs in Practice
## Basic Idea
Large language models already come with strong comprehension and reasoning abilities, and instruction-tuned models understand ordinary instructions out of the box, so using an LLM for text classification mostly comes down to two points:
- the concrete goal of the classification task has to be spelled out in the prompt;
- each class should come with a reasonably detailed definition, with particular emphasis on the differences between classes.

Combined with the idea of in-context learning, a compact prompt for LLM text classification should contain:
1. an introduction to the classification task and its requirements;
2. a definition of each class;
3. ideally a few examples per class (few-shot, in academic terms);
4. the text to be classified.

In practice, though, there may be many classes and many examples, and items 2 and 3 easily bloat the prompt. Long prompts tend to degrade an LLM's reasoning, and individual requirements buried inside them get ignored, so filtering the material before it reaches the model makes everything easier. The plan is therefore to use vector retrieval first to narrow the candidates, and then let the LLM make the final decision.

This gives the overall scheme:

- Offline, prepare each class's definition and its examples in advance. (In a sense this is a kind of training; the whole idea is analogous to the training step of KNN.)
- Online, first run vector recall for the given query, then hand the recalled material to the LLM for the final decision.

Stated like this it is abstract, so examples are given below to make it concrete.

To emphasize: this method updates no model parameters; everything runs directly on downloaded open-source weights, which counts as one of the method's main advantages.

Project repository: [git@github.com:sunyongdi/llm_classification.git](https://github.com/sunyongdi/llm_classification.git)

![LLM-based text classification architecture](https://picgo-1305561115.cos.ap-beijing.myqcloud.com/img/20240813095623.png)

The figure above shows an LLM-based text classification model adapted from GPT-RE combined with **叉烧's write-up on text classification with general-purpose LLMs**. In short, an LLM does the classification, and in-context learning is added on top to raise accuracy. GPT-RE reports that the closer the provided in-context texts and classes are to the input, the better the results. Here the strong BGE model produces the embeddings, KeyBERT extracts keywords from the text, and the retrieved similar examples are concatenated as context and fed to the LLM.
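
Before the component-by-component walkthrough, here is a minimal sketch of the offline/online flow just described, wired together from the classes defined later in this post (`BgeEmbedding`, `VectorStore`, `QwenMode`, `PROMPT_TEMPLATE`). The class names and prompt key are the project's own; the glue code and the `labeled_texts` variable are illustrative stand-ins, with `labeled_texts` standing for the cnews training lines.

```python
# Minimal sketch of the two phases; `labeled_texts` is a stand-in for the
# cnews training lines formatted as '文本:<content>上述文本所属的类别是<label>'.
emb = BgeEmbedding(EMBEDDING_PATH)
store = VectorStore(emb, DB_PATH)

# --- offline: the KNN-style "training" step, index labeled examples ---
store.create_collection()
store.insert([
    {"id": i, "vector": emb.get_embedding(text), "text": text}
    for i, text in enumerate(labeled_texts)
])

# --- online: recall similar examples, then let the LLM decide ---
llm = QwenMode(LLM_PATH)
examples = store.query(query, k=3)  # [(text, distance), ...]
prompt = PROMPT_TEMPLATE['CLASSIFY_PROMPT_TEMPLATE'].format(
    examples=examples,
    options='财经、彩票、房产、股票、家居、教育、科技、社会、时尚、时政、体育、星座、游戏、娱乐',
    options_detail='', query=query)
label = llm.chat([{'role': 'user', 'content': prompt}])
```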
## Implementation
### Code Structure
1. The LLM is Qwen2-7B-Instruct
```python
class QwenMode(BaseModel):
    def __init__(self, model_path) -> None:
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = self.model.eval()
        self.device = self.model.device

        logger.info("load LLM Model done")

    def chat(self,
             messages: List[dict],
             max_new_tokens: int = 1024,
             do_sample: bool = False,
             top_k: int = 1,
             temperature: float = 0.8
             ) -> str:

        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)

        generated_ids = self.model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,      # greedy decoding by default
            top_k=top_k,              # ignored when do_sample=False
            temperature=temperature   # ignored when do_sample=False
        )
        # Keep only the newly generated tokens, dropping the prompt.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        logger.info(f'input_tokens:{len(model_inputs.input_ids.tolist()[0])} \t generated_ids:{len(generated_ids[0].tolist())} \t all_tokens:{len(model_inputs.input_ids.tolist()[0]) + len(generated_ids[0].tolist())}')
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response
```
2. The embedding model is bge-base-zh-v1.5
```python
class BgeEmbedding(BaseEmbeddings):
    """
    class for BGE embeddings
    """

    def __init__(self, path: str = 'BAAI/bge-base-zh-v1.5') -> None:
        self._model, self._tokenizer = self.load_model(path)

    def get_embedding(self, text: str) -> List[float]:
        import torch
        encoded_input = self._tokenizer([text], max_length=512, padding='max_length', truncation=True, return_tensors='pt')
        encoded_input = {k: v.to(self._model.device) for k, v in encoded_input.items()}
        with torch.no_grad():
            model_output = self._model(**encoded_input)
            # CLS pooling, then L2 normalization, as recommended for BGE.
            sentence_embeddings = model_output[0][:, 0]
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings[0].tolist()

    def load_model(self, path: str):
        import torch
        from transformers import AutoModel, AutoTokenizer
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModel.from_pretrained(path).to(device)
        model.eval()
        return model, tokenizer
```
3. Vector store

   For the vector store I chose Milvus. For learning purposes it is easy enough to deploy and use, and the official documentation is thorough, including a detailed RAG tutorial.
```python
class VectorStore:
    def __init__(self, EmbeddingModel: BaseEmbeddings, db_path: str = 'milvus_demo.db', collection_name: str = 'my_rag_collection') -> None:
        self.EmbeddingModel = EmbeddingModel
        self.milvus_client = MilvusClient(uri=db_path)
        self.collection_name = collection_name

    def create_collection(self) -> None:
        # Recreate the collection from scratch on every run.
        if self.milvus_client.has_collection(self.collection_name):
            self.milvus_client.drop_collection(self.collection_name)

        # # Explicit schema alternative:
        # fields = [
        #     FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        #     FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.EmbeddingModel._model.config.hidden_size),
        #     FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=512)
        # ]
        # collection_schema = CollectionSchema(fields, self.collection_name)

        self.milvus_client.create_collection(
            collection_name=self.collection_name,
            # schema=collection_schema,
            dimension=self.EmbeddingModel._model.config.hidden_size,
            metric_type="IP",  # Inner product distance
            consistency_level="Strong",  # Strong consistency level
        )

    def insert(self, data: List[dict]) -> None:
        self.milvus_client.insert(collection_name=self.collection_name, data=data)

    def query(self, query: str, k: int = 3) -> List[Tuple[str, float]]:
        search_res = self.milvus_client.search(
            collection_name=self.collection_name,
            data=[
                self.EmbeddingModel.get_embedding(query)
            ],
            limit=k,  # Return top-k results
            # search_params={"metric_type": "IP", "params": {}},  # Inner product distance
            output_fields=["text"],  # Return the text field
        )

        return [(res["entity"]["text"], res["distance"]) for res in search_res[0]]
```
4. The main classification routine

   Not much to explain here: recall similar examples from the vector store, assemble the prompt, and hand it to the LLM for the final output (one subtlety of this code is flagged in the note right after this walkthrough).
```python
class VecLlmClassifier:
    def __init__(self) -> None:
        self.emb_model = BgeEmbedding(EMBEDDING_PATH)
        self.retrieval = VectorStore(self.emb_model, DB_PATH)
        self.llm = QwenMode(LLM_PATH)

    def predict(self, query: str, icl=True) -> str:
        task_description = PROMPT_TEMPLATE['CLASSIFY_PROMPT_TEMPLATE']
        demonstrations = ''
        # ICL: recall similar examples, then have the LLM explain each one.
        if icl:
            demonstrations = self.retrieval.query(query, k=3)
            logger.info('Generating explanations with the LLM...')
            demonstrations = ['文本:' + demonstration[0] + '\n' + '原因:' + \
                self.llm.chat([{'role': 'user', 'content': PROMPT_TEMPLATE['REASON_PROMPT_TEMPLATE'].format(content=demonstration[0])}]) \
                    for demonstration in demonstrations]

        # Final decision by the LLM.
        logger.info('Running LLM inference...')
        output = self.llm.chat([{
            'role': 'user',
            'content': task_description.format(examples=demonstrations, options='财经、彩票、房产、股票、家居、教育、科技、社会、时尚、时政、体育、星座、游戏、娱乐', options_detail='', query=query)}])
        return output
```
5. Prompt

   Here I cut a corner and used 叉烧's prompt directly, without modification. (The placeholders {examples}, {options}, {options_detail} and {query} are filled in by `str.format` in the classifier above.)
```
你是一个优秀的句子分类师,能把给定的用户query划分到正确的类目中。现在请你根据给定信息和要求,为给定用户query,从备选类目中选择最合适的类目。

下面是“参考案例”即被标注的正确结果,可供参考:
{examples}

备选类目:
{options}

类目概念:
{options_detail}

用户query:
{query}

请注意:
1. 用户query所选类目,仅能在【备选类目】中进行选择,用户query仅属于一个类目。
2. “参考案例”中的内容可供推理分析,可以仿照案例来分析用户query的所选类目。
3. 请仔细比对【备选类目】的概念和用户query的差异。
4. 如果用户query也不属于【备选类目】中给定的类目,或者比较模糊,请选择“拒识”。
5. 请在“所选类目:”后回复结果,不需要说明理由。

所选类目:
```
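
One subtlety in the classifier of step 4: `demonstrations` is a Python list, so `task_description.format(examples=demonstrations, ...)` interpolates the list's `repr()`, square brackets and quotes included, into the prompt. The model copes, but a plain-text rendering would be cleaner. Here is a minimal sketch of such a helper, my suggestion rather than code from the repository (the no-ICL fallback string is likewise an assumption):

```python
from typing import List

def render_examples(demonstrations: List[str]) -> str:
    # Join the recalled demonstrations into plain prompt text instead of
    # letting str.format render the list's repr().
    return '\n\n'.join(demonstrations) if demonstrations else '(无参考案例)'

# Inside predict(), one would then pass:
#   task_description.format(examples=render_examples(demonstrations), ...)
```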
## Results
### With ICL
Below are the results with ICL. Accuracy is fairly good at 0.94, just 0.04 below the 0.98 of a BERT text classifier. The LLM produced an invalid class 6 times; as the code shows, my handling of an unparseable prediction is simply to append the class "家居", which barely affects the result.
```
              precision    recall  f1-score   support

          家居       0.98      0.99      0.99       200
          体育       0.97      0.98      0.98       200
          教育       0.95      0.87      0.91       200
          房产       0.94      0.84      0.89       200
          科技       0.98      0.79      0.87       200
          娱乐       0.93      0.97      0.95       200
          游戏       0.95      0.94      0.95       200
          财经       1.00      0.99      0.99       200
          时尚       0.98      0.99      0.99       200
          时政       0.76      0.99      0.86       200

    accuracy                           0.94      2000
   macro avg       0.94      0.94      0.94      2000
weighted avg       0.94      0.94      0.94      2000
```
### Without ICL
Without ICL, accuracy drops to 0.84, with 58 ERROR items. This result was a bit outside my expectations; the main reason is probably that I did not adapt the prompt — `examples` was simply passed an empty string. Building a dedicated prompt for this setting is worth trying next.
```
              precision    recall  f1-score   support

          家居       0.97      0.99      0.98       200
          体育       0.90      0.98      0.94       200
          教育       0.70      0.71      0.71       200
          房产       0.94      0.51      0.66       200
          科技       0.94      0.66      0.77       200
          娱乐       0.96      0.92      0.94       200
          游戏       0.89      0.84      0.87       200
          财经       1.00      0.95      0.98       200
          时尚       0.97      0.85      0.91       200
          时政       0.52      0.98      0.68       200

    accuracy                           0.84      2000
   macro avg       0.88      0.84      0.84      2000
weighted avg       0.88      0.84      0.84      2000
```
## Conclusion
First, the pros and cons. The advantage is that a reasonably good result comes with no training at all, and it gets better still with good examples and clearly delineated classes. The approach fits scenarios where a family of tools is built around a single general-purpose LLM API. The drawbacks are the usual LLM ones: the ceiling is not very high, deploying a dedicated LLM for a single classification task is not worth the cost, inference is slow, and ICL consumes a lot of tokens, which is also a real expense with paid APIs.

As for follow-up optimization: the current method does not yet use the KeyBERT keywords. As the figure below shows, a handful of core words can matter more than the overall semantics; a sketch of keyword-augmented recall follows the figure.
![](https://picgo-1305561115.cos.ap-beijing.myqcloud.com/img/20240815112045.png)
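
As a starting point, here is a minimal sketch of keyword-augmented recall, assuming the `keybert` and `jieba` packages. The model name, the jieba-based vectorizer for Chinese, and the prepend-keywords weighting are all illustrative assumptions for future work, not code from this repository.

```python
# Hypothetical sketch: prepend KeyBERT keywords to the query before embedding,
# so core words weigh more heavily in recall. Assumes `pip install keybert jieba`.
import jieba
from keybert import KeyBERT
from sklearn.feature_extraction.text import CountVectorizer

# KeyBERT needs a word-level vectorizer for Chinese, hence jieba tokenization.
vectorizer = CountVectorizer(tokenizer=jieba.lcut)
kw_model = KeyBERT(model='BAAI/bge-base-zh-v1.5')

def keyword_query(text: str, top_n: int = 5) -> str:
    # extract_keywords returns (keyword, score) pairs, highest score first.
    keywords = kw_model.extract_keywords(text, vectorizer=vectorizer, top_n=top_n)
    # Prepend the keywords so they dominate the embedding of the recall query.
    return ' '.join(kw for kw, _ in keywords) + ' ' + text

# usage: store.query(keyword_query(query), k=3)
```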
## References
https://mp.weixin.qq.com/s/H9oY4OaFWGJuwAoboLpcLw
https://github.com/KMnO4-zx/TinyRAG/tree/master
https://arxiv.org/abs/2305.02105
--------------------------------------------------------------------------------
/classifier.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@File    :   classifier.py
@Time    :   2024/08/13 10:27:11
@Author  :   sunyd
@Email   :   sunyongdi@outlook.com
@Version :   1.0
@Desc    :   None
'''
from models.LLM import QwenMode
from models.VectorBase import VectorStore
from models.Embeddings import BgeEmbedding
from config import PROMPT_TEMPLATE, EMBEDDING_PATH, DB_PATH, LLM_PATH
from loguru import logger


class VecLlmClassifier:
    def __init__(self) -> None:
        self.emb_model = BgeEmbedding(EMBEDDING_PATH)
        self.retrieval = VectorStore(self.emb_model, DB_PATH)
        self.llm = QwenMode(LLM_PATH)

    def predict(self, query: str, icl=True) -> str:
        task_description = PROMPT_TEMPLATE['CLASSIFY_PROMPT_TEMPLATE']
        demonstrations = ''
        # ICL: recall similar examples, then have the LLM explain each one.
        if icl:
            demonstrations = self.retrieval.query(query, k=3)
            logger.info('Generating explanations with the LLM...')
            demonstrations = ['文本:' + demonstration[0] + '\n' + '原因:' + \
                self.llm.chat([{'role': 'user', 'content': PROMPT_TEMPLATE['REASON_PROMPT_TEMPLATE'].format(content=demonstration[0])}]) \
                    for demonstration in demonstrations]

        # Final decision by the LLM.
        # NOTE: demonstrations is a list here, so str.format interpolates its repr.
        logger.info('Running LLM inference...')
        output = self.llm.chat([{
            'role': 'user',
            'content': task_description.format(examples=demonstrations, options='财经、彩票、房产、股票、家居、教育、科技、社会、时尚、时政、体育、星座、游戏、娱乐', options_detail='', query=query)}])
        return output

    def predict1(self, query: str) -> str:
        """Search-as-classification: return the label of the top-1 recalled text."""
        output = self.retrieval.query(query, k=1)
        # Stored texts end with '上述文本所属的类别是<label>', and every label is
        # exactly two characters, so the last two characters are the label.
        return output[0][0][-2:]


if __name__ == '__main__':
    vlc = VecLlmClassifier()
    res = vlc.predict1('《神兆OL》剑灵归心新人大奖即将揭晓眼看剑灵归心区开放已接近一个月,我是职业之王活动已接近尾声,究竟鹿死谁手还是未知数,玩家的排行形式尚不明朗,很多玩家还有机会,下面是活动的详细内容,看看你到底排第几?能拿到什么东西?活动名称:我是职业之王活动时间:12月3日—1月3日活动规则:在新区开放的同时将同时开放世界职业排行榜,分为7个职业:力士,拳师,侠客,盗贼,猎人,术士,道士。排行标准为单一职业的等级排行,该排行榜一天更新一次,活动时间持续一个月,一个月后即1月3日活动结束时,按照世界职业排行榜情况给予不同的奖励。第一名:获得御赐金箱20个,将之魂石5颗,龙王之鳞1个(一周版),踏焰白玉兽(一月版),金元宝2000第二名:获得御赐金箱15个,将之魂石4颗,60级踏焰白玉兽(一月版),金元宝1000第三名:获得御赐金箱10个,将之魂石3颗,60级踏焰白玉兽(一月版),金元宝500第四名:获得御赐金箱5个,将之魂石2颗,复活药10个,60级踏焰白玉兽(一周版)第五名:获得将之魂石2颗,复活药10个,60级踏焰白玉兽(一周版)第六名:获得将之魂石2颗,复活药10个,60级浴血铁角犀(一周版)第七名:获得将之魂石2颗,复活药10个,二级力量宝石,二级悟性宝石,二级灵巧宝石各一颗第八名:获得将之魂石2颗,复活药10个,二级力量宝石,二级悟性宝石,二级灵巧宝石各一颗第九名:获得将之魂石2颗,复活药10个,大地药水,优雅药水,智慧药水各一瓶第十名:获得将之魂石2颗,复活药10个,大地药水,优雅药水,智慧药水各一瓶活动时间仅剩不到一周时间了,1月3日晚12时即将锁定排行情况,请玩家们继续努力争夺排行名次,为了自己的一片天地奋力打拼吧!')
    print(res)
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# LLM_PATH = '/root/sunyd/model_hub/qwen/Qwen2-7B-Instruct'
LLM_PATH = '/root/sunyd/model_hub/qwen/Qwen2-0___5B-Instruct'
EMBEDDING_PATH = '/root/sunyd/model_hub/ZhipuAI/bge-large-zh-v1___5'
DB_PATH = '/root/sunyd/llms/llm_classification/storage/rag.db'


PROMPT_TEMPLATE = dict(
    CLASSIFY_PROMPT_TEMPLATE="""你是一个优秀的句子分类师,能把给定的用户query划分到正确的类目中。现在请你根据给定信息和要求,为给定用户query,从备选类目中选择最合适的类目。

下面是“参考案例”即被标注的正确结果,可供参考:
{examples}

备选类目:
{options}

类目概念:
{options_detail}

用户query:
{query}

请注意:
1. 用户query所选类目,仅能在【备选类目】中进行选择,用户query仅属于一个类目。
2. “参考案例”中的内容可供推理分析,可以仿照案例来分析用户query的所选类目。
3. 请仔细比对【备选类目】的概念和用户query的差异。
4. 如果用户query也不属于【备选类目】中给定的类目,或者比较模糊,请选择“拒识”。
5. 请在“所选类目:”后回复结果,不需要说明理由。

所选类目:""",

    REASON_PROMPT_TEMPLATE="""{content}, 请给出对该分类结果的合理解释。""",
    ICL_PROMPT_TEMPLATE=""""""
)
--------------------------------------------------------------------------------
/models/Embeddings.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@File    :   Embeddings.py
@Time    :   2024/08/13 14:51:01
@Author  :   sunyd
@Email   :   sunyongdi@outlook.com
@Version :   1.0
@Desc    :   None
'''
from typing import List
import numpy as np
from abc import ABC, abstractmethod


class BaseEmbeddings(ABC):
    """
    Base class for embeddings
    """
    @abstractmethod
    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for a piece of text."""

    @classmethod
    def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
        """
        calculate cosine similarity between two vectors
        """
        dot_product = np.dot(vector1, vector2)
        magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
        if not magnitude:
            return 0
        return dot_product / magnitude


class BgeEmbedding(BaseEmbeddings):
    """
    class for BGE embeddings
    """

    def __init__(self, path: str = 'BAAI/bge-base-zh-v1.5') -> None:
        self._model, self._tokenizer = self.load_model(path)

    def get_embedding(self, text: str) -> List[float]:
        import torch
        encoded_input = self._tokenizer([text], max_length=512, padding='max_length', truncation=True, return_tensors='pt')
        encoded_input = {k: v.to(self._model.device) for k, v in encoded_input.items()}
        with torch.no_grad():
            model_output = self._model(**encoded_input)
            # CLS pooling, then L2 normalization, as recommended for BGE.
            sentence_embeddings = model_output[0][:, 0]
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings[0].tolist()

    def load_model(self, path: str):
        import torch
        from transformers import AutoModel, AutoTokenizer
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModel.from_pretrained(path).to(device)
        model.eval()
        return model, tokenizer


if __name__ == '__main__':
    model_path = '/root/sunyd/model_hub/ZhipuAI/bge-large-zh-v1___5'
    model = BgeEmbedding(model_path)
    print(model._model.config.hidden_size)
    # emb1 = model.get_embedding('你好')
    # emb2 = model.get_embedding('hello')
    # print(model.cosine_similarity(emb1, emb2))
--------------------------------------------------------------------------------
/models/LLM.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@File    :   LLM.py
@Time    :   2024/08/13 11:28:57
@Author  :   sunyd
@Email   :   sunyongdi@outlook.com
@Version :   1.0
@Desc    :   None
'''

from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List
from loguru import logger
import torch
from abc import ABC, abstractmethod


def _gc():
    # Release cached GPU memory between generations.
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


class BaseModel(ABC):

    @abstractmethod
    def chat(self, messages: List[dict], max_new_tokens: int = 512, do_sample: bool = False, top_k: int = 1,
             temperature: float = 0.8):
        """Chat with the LLM and return its reply."""


class QwenMode(BaseModel):
    def __init__(self, model_path) -> None:
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype="auto",
            device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = self.model.eval()
        self.device = self.model.device

        logger.info("load LLM Model done")

    def chat(self,
             messages: List[dict],
             max_new_tokens: int = 1024,
             do_sample: bool = False,
             top_k: int = 1,
             temperature: float = 0.8
             ) -> str:

        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)

        generated_ids = self.model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,      # greedy decoding by default
            top_k=top_k,              # ignored when do_sample=False
            temperature=temperature   # ignored when do_sample=False
        )
        # Keep only the newly generated tokens, dropping the prompt.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        logger.info(f'input_tokens:{len(model_inputs.input_ids.tolist()[0])} \t generated_ids:{len(generated_ids[0].tolist())} \t all_tokens:{len(model_inputs.input_ids.tolist()[0]) + len(generated_ids[0].tolist())}')
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        _gc()
        return response


if __name__ == "__main__":
    LLM_PATH = '/root/sunyd/model_hub/qwen/Qwen2-7B-Instruct'
    llm_model = QwenMode(LLM_PATH)
    messages = [{'role': 'user', 'content': '文本:流感袭击广东主帅打点滴坚持 李春江无惧“失声”李春江无惧“失声”本报天津电\xa0“首发的队员上去后一定要打好开局,尤其是保护好篮板球,还有就是对于对方重点人的防守……”赛前准备会上,李春江用近乎沙哑的嗓音提醒队员。此次客场之旅,卫冕冠军将士遭遇了寒流的袭击。太原客场,先是王仕鹏在毫无征兆的情况下发烧,教练组在赛前更换了大名单。转战天津后,队医为了预防流感进一步蔓延,让每个队员都喝了板蓝根冲剂。但大鹏刚退烧,李春江和刘晓宇又发烧了。刚一到天津,刘晓宇就感觉不适,前天下午就在队医的带领下前往医院打了点滴,为了不影响到其他队员,球队还专门给刘晓宇安排了单间。即便如此,流感还是没有得到遏制。前天晚上,李春江开始发高烧,撑了一夜后于昨天上午前往医院打了点滴,下午刘晓宇第二次前往医院。为了让其得到休息,昨天的比赛晓宇根本没有上场。队员感冒起码有其他队员轮换,但李春江的位置是无人能顶替的,比赛中李春江依旧用沙哑的嗓音向场内队员高喊注意事项。“还行吧,病来如山倒,我也没办法,只能硬撑着了。”赛后走出体育馆时,李春江在西装外套上了厚厚的大衣和棉裤。其实,大鹏的感冒也没有完全好转,昨天虽然退烧了,但浑身上下没有劲。“腿都是软的,这还算好的,主要是嗓子还在发炎,上场打一会就感觉喘不过气来。”赛后王仕鹏告诉记者。虽然感冒尚未痊愈,但大鹏却展现出了其惊人的效率,他在19分钟内砍下了24分和4个篮板。“现在球队的情况大家都了解,我不能因为感冒而退缩,越是这种艰难时刻,我们老队员就更应该站出来。”王仕鹏说。本报宏远随队记者 刘爱琳\n上述文本所属的类别是体育, 请给出对该分类结果的合理解释。'}]
    print(llm_model.chat(messages))
--------------------------------------------------------------------------------
/models/VectorBase.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
'''
@File    :   VectorBase.py
@Time    :   2024/08/13 15:09:41
@Author  :   sunyd
@Email   :   sunyongdi@outlook.com
@Version :   1.0
@Desc    :   None
'''
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType
from typing import List, Tuple
from .Embeddings import BaseEmbeddings


class VectorStore:
    def __init__(self, EmbeddingModel: BaseEmbeddings, db_path: str = 'milvus_demo.db', collection_name: str = 'my_rag_collection') -> None:
        self.EmbeddingModel = EmbeddingModel
        self.milvus_client = MilvusClient(uri=db_path)
        self.collection_name = collection_name

    def create_collection(self) -> None:
        # Recreate the collection from scratch on every run.
        if self.milvus_client.has_collection(self.collection_name):
            self.milvus_client.drop_collection(self.collection_name)

        # # Explicit schema alternative:
        # fields = [
        #     FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        #     FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=self.EmbeddingModel._model.config.hidden_size),
        #     FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=512)
        # ]
        # collection_schema = CollectionSchema(fields, self.collection_name)

        self.milvus_client.create_collection(
            collection_name=self.collection_name,
            # schema=collection_schema,
            dimension=self.EmbeddingModel._model.config.hidden_size,
            metric_type="IP",  # Inner product distance
            consistency_level="Strong",  # Strong consistency level
        )

    def insert(self, data: List[dict]) -> None:
        self.milvus_client.insert(collection_name=self.collection_name, data=data)

    def query(self, query: str, k: int = 3) -> List[Tuple[str, float]]:
        search_res = self.milvus_client.search(
            collection_name=self.collection_name,
            data=[
                self.EmbeddingModel.get_embedding(query)
            ],
            limit=k,  # Return top-k results
            # search_params={"metric_type": "IP", "params": {}},  # Inner product distance
            output_fields=["text"],  # Return the text field
        )

        return [(res["entity"]["text"], res["distance"]) for res in search_res[0]]


if __name__ == "__main__":
    # Run from the project root: python -m models.VectorBase
    from models.Embeddings import BgeEmbedding
    model_path = '/root/sunyd/model_hub/ZhipuAI/bge-large-zh-v1___5'
    model = BgeEmbedding(model_path)
    vec_model = VectorStore(model, db_path='/root/sunyd/llms/llm_classification/storage/rag.db')
    vec_model.create_collection()
    data = []

    for i, line in enumerate(['你好', 'hello']):
        data.append({"id": i, "vector": model.get_embedding(line), "text": line})
    vec_model.insert(data)
    res = vec_model.query('你好')
    print(res)
--------------------------------------------------------------------------------
/script/create_db.py:
--------------------------------------------------------------------------------
from models.VectorBase import VectorStore
from models.Embeddings import BgeEmbedding
from tqdm import tqdm


def process_data():
    pass


def read_txt(file_path):
    # Each line of the cnews file is '<label>\t<content>'. The indexed text
    # ends with the label, which predict1 later slices off.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            label, content = line.split('\t')
            yield '文本:' + content + '上述文本所属的类别是' + label


if __name__ == '__main__':
    file_path = '/root/sunyd/llms/llm_classification/data/cnews/cnews.train.txt'
    emb_path = '/root/sunyd/model_hub/ZhipuAI/bge-large-zh-v1___5'
    emb_model = BgeEmbedding(emb_path)
    vec = VectorStore(emb_model, db_path='/root/sunyd/llms/llm_classification/storage/rag.db')
    # Initialize the collection.
    vec.create_collection()
    # Build the index, one record at a time.
    dataset = read_txt(file_path)
    for i, text in enumerate(tqdm(dataset, desc='indexing')):
        vec.insert([{"id": i, "vector": emb_model.get_embedding(text), "text": text}])
    res = vec.query('马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 来到沈阳,国奥队依然没有摆脱雨水的困扰。', k=3)
    print(res)
--------------------------------------------------------------------------------
/script/create_test.py:
--------------------------------------------------------------------------------
from models.VectorBase import VectorStore
from models.Embeddings import BgeEmbedding
from tqdm import tqdm

# NOTE: opened in append mode, so rerunning the script appends duplicates.
sample_test = open('/root/sunyd/llms/llm_classification/data/cnews/sample_test.txt', 'a', encoding='utf-8')


def read_txt(file_path, count=200):
    # Sample the first `count` examples of each label into sample_test.txt.
    with open(file_path, 'r', encoding='utf-8') as f:
        data_dict = dict()
        for line in f:
            label, content = line.split('\t')
            data_dict.setdefault(label, [])
            if len(data_dict[label]) < count:
                data_dict[label].append(content)

    for k, v in data_dict.items():
        for i in v:
            sample_test.write(f'{k.strip()}\t{i.strip()}\n')

    sample_test.close()


if __name__ == '__main__':
    file_path = '/root/sunyd/llms/llm_classification/data/cnews/cnews.test.txt'
    read_txt(file_path)
    # emb_path = '/root/sunyd/model_hub/ZhipuAI/bge-large-zh-v1___5'
    # emb_model = BgeEmbedding(emb_path)
    # vec = VectorStore(emb_model, db_path='/root/sunyd/llms/llm_classification/storage/rag.db')
    # vec.create_collection()
    # dataset = read_txt(file_path)
    # for i, text in enumerate(tqdm(dataset, desc='indexing')):
    #     vec.insert({"id": i, "vector": emb_model.get_embedding(text), "text": text})
    # res = vec.query('马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 来到沈阳,国奥队依然没有摆脱雨水的困扰。', k=3)
    # print(res)
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import time

from sklearn.metrics import classification_report
from classifier import VecLlmClassifier
from tqdm import tqdm
from loguru import logger

logger.add(sink='runtime_no_icl.log')

labels = ['家居', '体育', '教育', '房产', '科技', '娱乐', '游戏', '财经', '时尚', '时政']


def timer_decorator(func):
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} took {end_time - start_time:.2f} seconds to execute.")
        return result
    return wrapper


@timer_decorator
def get_labels(data_path):
    labels = set()
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in f:
            true_label, content = line.split('\t')
            labels.add(true_label)
    return list(labels)


def handle_response(llm_response):
    # Map the raw LLM reply to the first known label it mentions.
    pred_label = None
    for label in labels:
        if label in llm_response:
            pred_label = label
            break
    return pred_label


@timer_decorator
def main(data_path):
    vlc = VecLlmClassifier()
    test_outputs, test_targets = [], []
    with open(data_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f, total=10000, desc='predicting'):
            true_label, content = line.split('\t')
            test_targets.append(true_label.strip())
            llm_response = vlc.predict1(content.strip())
            # llm_response = vlc.predict(content.strip(), icl=True)

            logger.info(f'predicted: {llm_response} \t gold: {true_label}')
            pred_label = handle_response(llm_response)
            if pred_label:
                test_outputs.append(pred_label)
            else:
                logger.error('Unparseable LLM output: {}'.format(llm_response))
                # Fall back to "家居" when no label can be parsed (see README).
                test_outputs.append('家居')

    # Pass labels explicitly so the report rows line up with target_names.
    report = classification_report(test_targets, test_outputs, labels=labels, target_names=labels)
    return report


if __name__ == '__main__':
    # test_data_path = '/root/sunyd/llms/llm_classification/data/cnews/sample_test.txt'
    test_data_path = '/root/sunyd/llms/llm_classification/data/cnews/cnews.test.txt'
    print(get_labels(test_data_path))
    print(main(test_data_path))
--------------------------------------------------------------------------------
/utils/time_it.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunyongdi/llm_classification/6cc3f9aad8d41c64eee089c30ef1ba2135048d7b/utils/time_it.py
--------------------------------------------------------------------------------