├── README.md └── Document_QA.py /README.md: -------------------------------------------------------------------------------- 1 | # Document_QA 2 | 3 | 根据传入的文本文件,回答你的问题。 4 | 5 | 核心逻辑来自于chatPDF,自动化客服AI,以及:[ChatWeb](https://github.com/SkywalkerDarren/chatWeb) 6 | 7 | 由于原来的ChatWeb项目使用的是pqsql作为向量存储和计算工具,较为复杂,本项目修改成faiss,更简单快速。 8 | 9 | 10 | # 基本原理 11 | 12 | 1. 读取文件,并进行分割 13 | 2. 对于每段文本,使用text-embedding-ada-002生成特征向量 14 | 3. 将向量和文本对应关系存入本地pkl文件 15 | 4. 对于用户输入,生成向量 16 | 5. 使用向量数据库进行最近邻搜索,返回最相似的文本列表 17 | 6. 使用gpt3.5的chatAPI,设计prompt,使其基于最相似的文本列表进行回答 18 | 19 | 就是先把大量文本中提取相关内容,再进行回答,最终可以达到类似突破token限制的效果 20 | 后续可以考虑将openai的文本向量改成自定义的向量生成工具 21 | 22 | # 准备开始 23 | 24 | - 项目依赖 25 | 26 | 主要依赖 27 | ``` 28 | faiss 29 | numpy 30 | openai 31 | ``` 32 | 33 | - 环境变量 34 | 35 | 设置`OPENAI_API_KEY`为你的openai的api key 36 | 37 | ```shell 38 | export OPENAI_API_KEY="sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" 39 | ``` 40 | 41 | - 运行 42 | 43 | ``` 44 | python Document_QA.py --input_file test.md --file_embeding test.pkl 45 | ``` -------------------------------------------------------------------------------- /Document_QA.py: -------------------------------------------------------------------------------- 1 | 2 | import openai 3 | import faiss 4 | import numpy as np 5 | import pickle 6 | from tqdm import tqdm 7 | import argparse 8 | import os 9 | 10 | def create_embeddings(input): 11 | """Create embeddings for the provided input.""" 12 | result = [] 13 | # limit about 1000 tokens per request 14 | lens = [len(text) for text in input] 15 | query_len = 0 16 | start_index = 0 17 | tokens = 0 18 | 19 | def get_embedding(input_slice): 20 | embedding = openai.Embedding.create(model="text-embedding-ada-002", input=input_slice) 21 | return [(text, data.embedding) for text, data in zip(input_slice, embedding.data)], embedding.usage.total_tokens 22 | 23 | for index, l in tqdm(enumerate(lens)): 24 | query_len += l 25 | if query_len > 4096: 26 | ebd, tk = get_embedding(input[start_index:index + 1]) 27 | query_len = 0 28 | start_index = index + 1 29 | tokens += tk 30 | result.extend(ebd) 31 | 32 | if query_len > 0: 33 | ebd, tk = get_embedding(input[start_index:]) 34 | tokens += tk 35 | result.extend(ebd) 36 | return result, tokens 37 | 38 | def create_embedding(text): 39 | """Create an embedding for the provided text.""" 40 | embedding = openai.Embedding.create(model="text-embedding-ada-002", input=text) 41 | return text, embedding.data[0].embedding 42 | 43 | class QA(): 44 | def __init__(self,data_embe) -> None: 45 | d = 1536 46 | index = faiss.IndexFlatL2(d) 47 | embe = np.array([emm[1] for emm in data_embe]) 48 | data = [emm[0] for emm in data_embe] 49 | index.add(embe) 50 | self.index = index 51 | self.data = data 52 | def __call__(self, query): 53 | embedding = create_embedding(query) 54 | context = self.get_texts(embedding[1], limit) 55 | answer = self.completion(query,context) 56 | return answer,context 57 | def get_texts(self,embeding,limit): 58 | _,text_index = self.index.search(np.array([embeding]),limit) 59 | context = [] 60 | for i in list(text_index[0]): 61 | context.extend(self.data[i:i+5]) 62 | # context = [self.data[i] for i in list(text_index[0])] 63 | return context 64 | 65 | def completion(self,query, context): 66 | """Create a completion.""" 67 | lens = [len(text) for text in context] 68 | 69 | maximum = 3000 70 | for index, l in enumerate(lens): 71 | maximum -= l 72 | if maximum < 0: 73 | context = context[:index + 1] 74 | print("超过最大长度,截断到前", index + 1, "个片段") 75 | break 76 | 77 | text = "\n".join(f"{index}. {text}" for index, text in enumerate(context)) 78 | response = openai.ChatCompletion.create( 79 | model="gpt-3.5-turbo", 80 | messages=[ 81 | {'role': 'system', 82 | 'content': f'你是一个有帮助的AI文章助手,从下文中提取有用的内容进行回答,不能回答不在下文提到的内容,相关性从高到底排序:\n\n{text}'}, 83 | {'role': 'user', 'content': query}, 84 | ], 85 | ) 86 | print("使用的tokens:", response.usage.total_tokens) 87 | return response.choices[0].message.content 88 | 89 | if __name__ == '__main__': 90 | parser = argparse.ArgumentParser(description="Document QA") 91 | parser.add_argument("--input_file", default="input.txt", dest="input_file", type=str,help="输入文件路径") 92 | parser.add_argument("--file_embeding", default="input_embed.pkl", dest="file_embeding", type=str,help="文件embeding文件路径") 93 | parser.add_argument("--print_context", action='store_true',help="是否打印上下文") 94 | 95 | 96 | args = parser.parse_args() 97 | 98 | if os.path.isfile(args.file_embeding): 99 | data_embe = pickle.load(open(args.file_embeding,'rb')) 100 | else: 101 | with open(args.input_file,'r',encoding='utf-8') as f: 102 | texts = f.readlines() 103 | texts = [text.strip() for text in texts if text.strip()] 104 | data_embe,tokens = create_embeddings(texts) 105 | pickle.dump(data_embe,open(args.file_embeding,'wb')) 106 | print("文本消耗 {} tokens".format(tokens)) 107 | 108 | qa =QA(data_embe) 109 | 110 | limit = 10 111 | while True: 112 | query = input("请输入查询(help可查看指令):") 113 | if query == "quit": 114 | break 115 | elif query.startswith("limit"): 116 | try: 117 | limit = int(query.split(" ")[1]) 118 | print("已设置limit为", limit) 119 | except Exception as e: 120 | print("设置limit失败", e) 121 | continue 122 | elif query == "help": 123 | print("输入limit [数字]设置limit") 124 | print("输入quit退出") 125 | continue 126 | answer,context = qa(query) 127 | if args.print_context: 128 | print("已找到相关片段:") 129 | for text in context: 130 | print('\t', text) 131 | print("=====================================") 132 | print("回答如下\n\n") 133 | print(answer.strip()) 134 | print("=====================================") 135 | 136 | --------------------------------------------------------------------------------