├── .env.template
├── .gitignore
├── LICENSE
├── README.md
├── main.py
├── requirements.txt
├── script
│   ├── Paper_Copilot.bat
│   └── Paper_Copilot.sh
└── src
    ├── agent.py
    ├── prompt
    │   └── 文献分析助手.md
    ├── utils.py
    └── vector_indexer.py
/.env.template:
--------------------------------------------------------------------------------
# Database path
DATABASE_PATH=database/index.db
# API key
API_KEY=
# API base URL
BASE_URL=
# Chat model
MODEL=o1-mini
# Chunk size in characters
BATCH_SIZE=1000
# Overlap between consecutive chunks, in characters
REPEAT_SIZE=200
# Maximum number of retrieved chunks
TOP_N=5
# Minimum cosine similarity for a retrieved chunk to be used
RELATION_THRESHOLD=0.2
# Number of parallel indexing workers
PARALLEL_NUM=4
# Embedding model (read by src/vector_indexer.py)
VEC_MODEL=
# Directory for saved answers (read by src/agent.py)
ANSWER_DIR=answer
# Directory for saved chat histories (read by src/agent.py)
CHAT_HISTORY_DIR=chat_history
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
*.py[cod]
*.env
data/vector_index.faiss
data/vector_index.faiss.pkl
.env
.DS_Store
*.log
tests/
answer/
chat_history/
database/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) [year] [fullname]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Paper_Copilot

## Overview

Paper_Copilot is a command-line literature-analysis tool built on vector indexing and large language models. It is designed to help researchers manage, search, and analyze large collections of papers: documents are indexed into a local knowledge base, and questions are answered by an LLM grounded in that knowledge base, making literature work noticeably faster and more accurate.

## Features

- **Document indexing and management**: extracts and vectorizes text from PDF, TXT, Markdown, and DOCX files, and automatically creates and maintains the vector index.
- **Knowledge-base Q&A**: combines the vector database with an OpenAI-compatible model to understand a question and answer it from the relevant documents.
- **Chat-history management**: save, load, and clear chat histories so previous conversations can be tracked and revisited.
- **Simple command-line interface**: create knowledge bases, ask questions, and manage chat history with a handful of commands.
- **Knowledge-base management**: create, load, save, and delete knowledge bases and switch between them.

## Installation

### Prerequisites

- **Python 3.10** or later

### Steps

1. **Clone the repository**

   ```bash
   git clone https://github.com/Code-WSY/Paper_Copilot.git
   cd Paper_Copilot
   ```

2. **Create a virtual environment (optional)**

   ```bash
   python -m venv venv
   source venv/bin/activate  # Unix
   venv\Scripts\activate     # Windows
   ```

3. **Install the dependencies**

   ```bash
   pip install -r requirements.txt
   ```

4. **Configure the environment variables**

   Copy `.env.template` to `.env` in the project root (or create `.env` by hand) and fill in at least:

   ```env
   API_KEY=your_openai_api_key
   BASE_URL=your_openai_base_url
   DATABASE_PATH=path_to_your_database.db
   ```

   - `API_KEY`: your OpenAI API key.
   - `BASE_URL`: base URL of the OpenAI-compatible service.
   - `DATABASE_PATH`: path of the vector-index database.

   The full list of variables the program reads is given in the Environment variables section below.

## Usage

1. **Start the program**

   Run the following from the command line:

   ```bash
   python main.py
   ```

2. **Command list**

   After startup you will see the available commands:

   ```
   /create            create a knowledge base
   /chat <question>   answer a question from the knowledge base
   /delete            delete documents from the database
   /select            re-select which documents to search
   /save              save the chat history
   /load              load a saved chat history
   /clear             clear the current chat history
   /last_md           save the last answer as a Markdown file
   /help              show help
   /quit              quit the program
   ```

   Any input that does not start with `/` is sent to the model directly, without querying the knowledge base.

3. **Create a knowledge base**

   Run `/create`; the program will ask you to pick a document or a folder, extract the text, and build the vector index.

4. **Ask questions**

   Use `/chat` followed by your question, for example:

   ```
   /chat 这篇论文的创新点是什么?
   ```

   The answer is generated from the selected knowledge base.

5. **Manage the chat history**

   - **Save the chat history**

     ```
     /save
     ```

   - **Load a saved chat history**

     ```
     /load
     ```

   - **Clear the chat history**

     ```
     /clear
     ```

6. **Save an answer**

   Use `/last_md` to save the most recent answer as a Markdown file.

7. **Get help**

   Use `/help` to list all available commands.

8. **Quit**

   Use `/quit` to exit the program.

## Environment variables

Configure the following variables in the `.env` file (see `.env.template`):

- `API_KEY`: your OpenAI API key.
- `BASE_URL`: base URL of the OpenAI-compatible service.
- `MODEL`: chat model used to answer questions.
- `VEC_MODEL`: embedding model used to vectorize documents and queries.
- `DATABASE_PATH`: path of the SQLite vector-index database.
- `BATCH_SIZE`: number of characters per text chunk.
- `REPEAT_SIZE`: number of overlapping characters between consecutive chunks.
- `TOP_N`: maximum number of chunks returned per search.
- `RELATION_THRESHOLD`: minimum cosine similarity for a retrieved chunk to be used.
- `PARALLEL_NUM`: number of worker threads used when indexing a folder.
- `ANSWER_DIR`: directory where `/last_md` saves answers.
- `CHAT_HISTORY_DIR`: directory where chat histories are saved.

Example `.env` file:

```env
API_KEY=sk-your_openai_api_key
BASE_URL=https://api.openai.com/v1
DATABASE_PATH=./data/vector_index.db
MODEL=o1-mini
VEC_MODEL=your_embedding_model
BATCH_SIZE=1000
REPEAT_SIZE=200
TOP_N=5
RELATION_THRESHOLD=0.2
PARALLEL_NUM=4
ANSWER_DIR=answer
CHAT_HISTORY_DIR=chat_history
```

## License

This project is released under the [MIT License](LICENSE).
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
from src.agent import Agent
from src.vector_indexer import VectorIndexer
from termcolor import colored
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit import prompt
import os

def start():
    print(colored("命令:", "cyan"))
    print(colored("-"*50, "cyan"))
    commands = {
        "/chat": "基于知识库问答(后接问题)",
        "/create": "创建知识库",
        "/delete": "删除数据库中的文献",
        "/select": "重新选择文献",
        "/help": "显示帮助信息",
        "/quit": "退出程序",
        "/save": "保存聊天记录",
        "/clear": "清除聊天记录",
        "/load": "加载聊天记录",
        "/last_md": "保存上一次的回答为markdown文件",
    }
    command_list = list(commands.keys())
    for cmd, desc in commands.items():
        print(f"{colored(cmd, 'magenta'):<10} {colored(desc, 'dark_grey')}")
    completer = WordCompleter(
        command_list)
    # Initialize the agent
    agent = Agent(prompt_path="src/prompt/文献分析助手.md")
    while True:
        # Limit the length of the chat history
        if len(agent.history) > 20:
            # Keep only the last 20 messages
            agent.history = agent.history[-20:]
        print(colored("You:\n", "cyan"))
        command = prompt(completer=completer).strip()
        if command.startswith("/quit"):
            break
        elif command.startswith("/last_md"):
            agent.save_last_response()
        elif command.startswith("/create"):
            # Work on the vector database
            database_path = os.getenv("DATABASE_PATH")
            # Initialize the vector indexer and index new documents
            vector_indexer = VectorIndexer(database_path=database_path)
            vector_indexer.load_index()
            #vector_indexer.show_table_info()
        elif command.startswith("/save"):
            agent.save_chat_history()
        elif command.startswith("/clear"):
            agent.clear_chat_history()
        elif command.startswith("/load"):
            agent.load_chat_history()
        elif command.startswith("/chat"):
            agent.chat_with_vector_database(command[6:])
        elif command.startswith("/help"):
            print(colored("命令:", "cyan"))
            print(colored("-"*50, "cyan"))
            for cmd, desc in commands.items():
                print(f"{colored(cmd, 'magenta'):<10} {colored(desc, 'dark_grey')}")
        elif command.startswith("/delete"):
            # Delete documents from the database
            agent.delete_table()
        elif command.startswith("/select"):
            # Re-select which documents to search
            agent.select_tables()
        else:
            # No knowledge base: send the input to the model directly
            agent.chat_with_ai(command)

if __name__ == "__main__":
    start()

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
openai
python-dotenv
rich
numpy
tqdm
termcolor
PyPDF2
python-docx
prompt_toolkit
--------------------------------------------------------------------------------
/script/Paper_Copilot.bat:
--------------------------------------------------------------------------------
@echo off
rem Adjust the project path and Python interpreter below to match your environment
cd /d "C:\Users\suyun\OneDrive\Project\Paper_Copilot"
"D:/software/anaconda3/envs/api/python.exe" "main.py"
pause
--------------------------------------------------------------------------------
/script/Paper_Copilot.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Activate the conda environment (adjust the path)
#source your/path/to/conda/bin/activate
# Run main.py
python main.py

--------------------------------------------------------------------------------
/src/agent.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 | import os
3 | from dotenv import load_dotenv
4 | from src.vector_indexer import VectorIndexer
5 | from termcolor import colored
6 | #Markdown优化输出
7 | from rich.markdown import Markdown
8 | from rich import print as rprint
9 | import json
10 | import time
11 | load_dotenv()
12 | 
13 | class Agent:
14 |     def __init__(self,prompt_path="src/prompt/文献分析助手.md"):
15 |         self.client = OpenAI(api_key=os.getenv("API_KEY"),base_url=os.getenv("BASE_URL"))
16 |         self.prompt = self.load_prompt(prompt_path)
17 |         self.model = os.getenv("MODEL")
18 |         self.top_n=int(os.getenv("TOP_N"))
19 |         self.relation_threshold=float(os.getenv("RELATION_THRESHOLD"))
20 |         self.history = []
21 |         self.vector_indexer = VectorIndexer()
22 |         self.vector_indexer.select_tables()
23 |         self.last_response = None
24 |         self.answer_dir = os.getenv("ANSWER_DIR")
25 |         self.chat_history_dir = os.getenv("CHAT_HISTORY_DIR")
26 |     def load_prompt(self, prompt_path):
27 |         with open(prompt_path, "r", encoding="utf-8") as file:
28 |             return file.read()
29 | 
30 |     def save_chat_history(self):
31 |         chat_history_dir = self.chat_history_dir
32 |         if not os.path.exists(chat_history_dir):
33 |             os.makedirs(chat_history_dir)
34 |         save_path = f"{chat_history_dir}/chat_history_{time.strftime('%Y-%m-%d_%H-%M-%S')}.json"
35 |         with open(save_path, "w", encoding="utf-8") as file:
36 |             json.dump(self.history, file, ensure_ascii=False)  #ensure_ascii=False,避免中文被转义
37 |         print(colored(f"已保存聊天记录:{save_path}", "green"))
38 | 
39 |     def load_chat_history(self):
40 |         chat_history_dir = self.chat_history_dir
41 |         if not os.path.exists(chat_history_dir):
42 |             print(colored("没有聊天记录", "red"))
43 |             return
44 |         print(colored(f"聊天记录:", "blue"))
45 |         for i, file in enumerate(os.listdir(chat_history_dir), 1):
46 | 
print(colored(f"{i}. {file}", "blue")) 47 | #让用户选择序号 48 | choice = int(input("请输入序号: ")) 49 | #加载用户选择的聊天记录 50 | with open(os.path.join(chat_history_dir, os.listdir(chat_history_dir)[choice-1]), "r", encoding="utf-8") as file: 51 | self.history = json.load(file) 52 | print(colored(f"已加载聊天记录", "green")) 53 | 54 | def clear_chat_history(self): 55 | self.history = [] 56 | print(colored(f"已清除当前聊天记录", "green")) 57 | 58 | def decorate_user_input(self, user_input, related_docs): 59 | # 将相关文档内容添加到用户输入中 60 | for cos,table_name,doc in related_docs: 61 | print(colored(f"检索到的文件标题:{table_name} ; 相关度:{cos}", "green")) 62 | user_input += f"\n\n相关文件标题:{table_name}\n该文件的相关内容:{doc}" 63 | return user_input 64 | 65 | def get_response_of_vector_database(self, user_input): 66 | # 1. 根据用户输入,在向量数据库中检索相关文档 67 | # 2. 将检索到的文档内容返回给LLM,LLM根据文档内容和用户问题,生成回答 68 | # 3. 返回回答 69 | related_docs = self.vector_indexer.search_index(user_input) 70 | #提取关联度>relation_threshold 的文档 71 | related_docs = [(cos,table_name,doc) for cos,table_name,doc in related_docs if cos > self.relation_threshold] 72 | print(colored(f"检索到{len(related_docs)}个相关部分", "green")) 73 | print(colored("-"*50, "green")) 74 | # 修饰问题 75 | user_input = self.decorate_user_input(user_input, related_docs) 76 | if len(self.history)==0: 77 | #加入提示词 放在user里 78 | user_input = self.prompt +'\n\n'+ "用户请求:"+user_input 79 | self.history.append({"role": "user", "content": user_input}) 80 | else: 81 | self.history.append({"role": "user", "content": user_input}) 82 | #print(colored(f"修饰后的用户输入: {user_input}", "blue")) 83 | response = self.client.chat.completions.create( 84 | model=self.model, 85 | messages=self.history, 86 | ) 87 | self.last_response = response.choices[0].message.content 88 | self.history.append({"role": "assistant", "content": self.last_response}) 89 | return self.last_response 90 | 91 | def get_response_of_ai(self,user_input): 92 | self.history.append({"role": "user", "content": user_input}) 93 | response = self.client.chat.completions.create( 94 | model=self.model, 95 | messages=self.history 96 | ) 97 | self.last_response = response.choices[0].message.content 98 | self.history.append({"role": "assistant", "content": self.last_response}) 99 | return self.last_response 100 | 101 | def chat_with_vector_database(self,user_input): 102 | self.get_response_of_vector_database(user_input) 103 | print(colored("-"*50, "green")) 104 | print(colored("Assistant: \n", "green")) 105 | rprint(Markdown(self.last_response)) 106 | print(colored("-"*50, "green")) 107 | 108 | def chat_with_ai(self,user_input): 109 | self.get_response_of_ai(user_input) 110 | print(colored("-"*50, "green")) 111 | print(colored("Assistant: \n", "green")) 112 | rprint(Markdown(self.last_response)) 113 | print(colored("-"*50, "green")) 114 | 115 | def save_last_response(self): 116 | save_path = f"{self.answer_dir}/answer_{time.strftime('%Y-%m-%d_%H-%M-%S')}.md" 117 | if self.last_response is None: 118 | print(colored("没有上一次的回答", "red")) 119 | return 120 | if not os.path.exists(self.answer_dir): 121 | os.makedirs(self.answer_dir) 122 | with open(save_path, "w", encoding="utf-8") as file: 123 | file.write(self.last_response) 124 | print(colored(f"已保存上一次的回答到:{save_path}", "green")) 125 | 126 | print(colored("-"*50, "green")) 127 | 128 | def delete_table(self): 129 | self.vector_indexer.delete_table() 130 | 131 | def select_tables(self): 132 | self.vector_indexer.select_tables() 133 | 134 | if __name__ == "__main__": 135 | agent = Agent(prompt_path="src/prompt/文献分析助手.md",model="o1-mini") 136 | 
agent.chat_with_vector_database() 137 | -------------------------------------------------------------------------------- /src/prompt/文献分析助手.md: -------------------------------------------------------------------------------- 1 | 你是高级文献分析助手,你的任务是根据用户的问题,在给定的文献数据库中检索相关文献内容,并基于这些内容回答用户的问题。 2 | 3 | 你的职责是: 4 | 1. 仔细分析用户的问题,并解析出用户的问题要点。 5 | 2. 从相关文献标题和内容中,准确找到与用户问题的最相关的内容,并用准确、专业的语言给出回答。 6 | 3. 如果提供了相关的文件和内容,则只基于提供的文件中的内容回答用户的问题。 7 | 4. 不同文件的内容是相互独立的,不要将不同的两个文件的内容混在一起回答,除非用户明确要求。 8 | 9 | 你的回答应从以下几个方面进行: 10 | 11 | 1. 问题解析: 12 | - 清晰阐述用户的问题要点。 13 | - 分析检索到的文件内容,通过文件标题可以计算检索到的相关文件数量,并告诉用户。 14 | - 如果检索到的文件数量过多,请告诉用户。 15 | 16 | 2. 基于检索到的文件内容,找到用户问题的答案 17 | - 保证答案的准确性。 18 | - 回答用户的问题时,并提供详细的解释和分析。 19 | - 如果用户的问题没有在提供的相关内容中找到答案,请告诉用户。 20 | - 不要使用任何外部知识库,只基于提供的文献相关内容回答用户的问题。 21 | 22 | 3. 回答要求: 23 | - 首先请先给出搜索到的相关文件数量和文件标题,如: 24 | 本次搜索到相关文件3个,分别是: 25 | 1. 文件1 26 | 2. 文件2 27 | 3. 文件3 28 | - 然后请给出你的回答,回答要准确、专业、详细、全面。 29 | 30 | 请确保你的回答既专业又易于理解,能够帮助用户全面掌握相关知识并解决相关问题。 31 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | def extract_text_from_pdf(file_path): 2 | # 从pdf文件中提取文本 3 | import PyPDF2 4 | with open(file_path, 'rb') as file: 5 | reader = PyPDF2.PdfReader(file) 6 | text = '' 7 | for page in reader.pages: 8 | text += page.extract_text() 9 | return text 10 | 11 | def extract_text_from_txt(file_path): 12 | # 从txt文件中提取文本 13 | with open(file_path, 'r', encoding='utf-8') as file: 14 | text = file.read() 15 | return text 16 | def extract_text_from_md(file_path): 17 | # 从md文件中提取文本 18 | with open(file_path, 'r', encoding='utf-8') as file: 19 | text = file.read() 20 | return text 21 | def extract_text_from_docx(file_path): 22 | # 从docx文件中提取文本 23 | import docx 24 | doc = docx.Document(file_path) 25 | text = '\n'.join([paragraph.text for paragraph in doc.paragraphs]) 26 | return text -------------------------------------------------------------------------------- /src/vector_indexer.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | import os 3 | from dotenv import load_dotenv 4 | import numpy as np 5 | from src.utils import extract_text_from_pdf,extract_text_from_txt,extract_text_from_md,extract_text_from_docx 6 | #优化输出colored 7 | from termcolor import colored 8 | #rprint 9 | from rich import print as rich_print 10 | #sqlite3 11 | import sqlite3 12 | #tqdm 13 | from tqdm import tqdm 14 | #pickle 15 | import pickle 16 | #log 17 | import logging 18 | import time 19 | #日志信息 20 | logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s') #INFO级别,输出所有信息 21 | load_dotenv() 22 | 23 | class VectorIndexer: 24 | def __init__(self,database_path=None): 25 | self.database_path = database_path or os.getenv("DATABASE_PATH") 26 | self.client = OpenAI(api_key=os.getenv("API_KEY"),base_url=os.getenv("BASE_URL")) 27 | self.batch_size = int(os.getenv("BATCH_SIZE")) 28 | self.repeat_size = int(os.getenv("REPEAT_SIZE")) 29 | self.top_n = int(os.getenv("TOP_N")) 30 | self.tables = [] 31 | self.vec_model = os.getenv("VEC_MODEL") 32 | self.parallel_num = int(os.getenv("PARALLEL_NUM")) 33 | #如果数据库路径不存在,则创建 34 | if not os.path.exists(self.database_path): 35 | print(colored("数据库路径不存在,是否创建?(y/n)", "red")) 36 | choice = input().strip() 37 | if choice.lower() == 'y': 38 | os.makedirs(os.path.dirname(self.database_path),exist_ok=True) 39 | #logging.info(colored("数据库路径创建成功", "green")) 40 | print(colored("正在创建数据库", 
"green")) 41 | self.load_index() 42 | else: 43 | print(colored("取消创建数据库,程序退出", "red")) 44 | exit(0) 45 | 46 | def encode(self, text,sleep=0): 47 | response = self.client.embeddings.create(input=text, model=self.vec_model) 48 | time.sleep(sleep) 49 | return np.array(response.data[0].embedding) 50 | 51 | def select_tables(self): 52 | # 先列出所有除了sqlite_sequence表的表名 53 | conn = sqlite3.connect(self.database_path) 54 | cursor = conn.cursor() 55 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name!='sqlite_sequence'") 56 | tables = [row[0] for row in cursor.fetchall()] 57 | if not tables: 58 | print(colored("没有找到任何文献。", "yellow")) 59 | self.tables = [] 60 | return 61 | # 显示表名列表 62 | print("-"*50) 63 | print(colored("数据库中已有的文献:", "blue")) 64 | print("-"*50) 65 | for idx, table in enumerate(tables, start=1): 66 | print(colored(f"{idx}. {table}", "blue")) 67 | print("-"*50) 68 | # 让用户选择序号,分号;隔开,或者:表示连续 如1:5;7;9表示1,2,3,4,5,7,9 69 | user_input = input(colored("请选择文献的编号(例如1:5;7;9;all): ", "cyan")) 70 | selected_indices = set() 71 | #ALL表示所有表 72 | if user_input.strip().lower() == 'all': 73 | self.tables = tables 74 | #如果直接回车 75 | elif user_input.strip() == '': 76 | self.tables = [] 77 | print("-"*50) 78 | print(colored("未选择任何文献", "yellow")) 79 | print("-"*50) 80 | return 81 | else: 82 | for part in user_input.split(';'): 83 | if ':' in part: 84 | start, end = part.split(':') 85 | selected_indices.update(range(int(start), int(end)+1)) 86 | else: 87 | selected_indices.add(int(part)) 88 | # 过滤无效的索引 89 | selected_indices = {i for i in selected_indices if 1 <= i <= len(tables)} 90 | selected_tables = [tables[i-1] for i in sorted(selected_indices)] 91 | self.tables = selected_tables 92 | 93 | if len(self.tables) == 0: 94 | print("-"*50) 95 | print(colored("未选择任何文献", "yellow")) 96 | print("-"*50) 97 | return 98 | 99 | print("-"*50) 100 | print(colored("已选择文献:", "blue")) 101 | for idx,table in enumerate(self.tables, start=1): 102 | print(colored(f"{idx}. 
{table}", "blue")) 103 | print("-"*50) 104 | 105 | conn.close() 106 | 107 | def cal_cos(self, input_vec, embedding): 108 | """计算余弦相似度""" 109 | try: 110 | #如果输入是字符串,则将其转换为numpy数组 111 | if isinstance(input_vec, str): 112 | input_vec = np.fromstring(input_vec.strip('[]'), sep=' ') 113 | else: 114 | input_vec = input_vec.astype(float) 115 | #如果输入是字符串,则将其转换为numpy数组 116 | if isinstance(embedding, str): 117 | embedding = np.fromstring(embedding.strip('[]'), sep=' ') 118 | else: 119 | embedding = embedding.astype(float) 120 | 121 | numerator = np.dot(input_vec, embedding) 122 | denominator = np.linalg.norm(input_vec) * np.linalg.norm(embedding) 123 | if denominator == 0: 124 | return 0.0 125 | return numerator / denominator 126 | except Exception as e: 127 | print(colored(f"转换向量时出错: {e}", "red")) 128 | return 0.0 129 | 130 | def dir_to_text_vec(self, dir_path): 131 | from concurrent.futures import ThreadPoolExecutor, as_completed 132 | text_vec_list = [] 133 | with ThreadPoolExecutor(max_workers=self.parallel_num) as executor: 134 | # 收集所有文件路径 135 | futures = {executor.submit(self.file_to_text_vec, os.path.join(root, file)): file 136 | for root, dirs, files in os.walk(dir_path) for file in files} 137 | # 遍历所有文件 138 | for future in tqdm(as_completed(futures), desc=colored("总进度", "green")): 139 | # 获取文件名 140 | file = futures[future] 141 | try: 142 | # 获取文件向量 143 | text_vec = future.result() 144 | # 如果文件向量不为空 145 | if text_vec is not None: 146 | # 只取文件名并替换特殊字符 147 | file_name = os.path.basename(file).replace('.', '_').replace(':', '_').replace(' ', '_') 148 | text_vec_list.append((file_name, text_vec)) 149 | except Exception as e: 150 | print(colored(f"处理文件{file}时出错: {e}", "red")) 151 | continue 152 | # 如果文件向量列表为空 153 | if len(text_vec_list) == 0: 154 | print(colored("没有有效的文件", "red")) 155 | return None 156 | 157 | return text_vec_list 158 | 159 | def file_to_text_vec(self,file_path): 160 | #检查是否存在 161 | table_name=os.path.basename(file_path).replace('.','_').replace(':','_').replace(' ','_') 162 | if self.check_table_exist(table_name): 163 | print(colored(f"文件 {file_path} 已存在", "dark_grey")) 164 | return None 165 | # 将文件内容转换为向量,只读取pdf,txt,md,docx文件 166 | enable_file_types = ['pdf', 'txt', 'md', 'docx'] 167 | if not any(file_path.endswith(ext) for ext in enable_file_types): 168 | print(colored(f"文件 {file_path} 不是有效的文件类型", "dark_grey")) 169 | return None 170 | # 将文件内容转换为向量 171 | #print(colored(f"正在处理文件: {file_path}", "green")) 172 | if file_path.endswith('.pdf'): 173 | text=extract_text_from_pdf(file_path) 174 | elif file_path.endswith('.txt'): 175 | text=extract_text_from_txt(file_path) 176 | elif file_path.endswith('.md'): 177 | text=extract_text_from_md(file_path) 178 | elif file_path.endswith('.docx'): 179 | text=extract_text_from_docx(file_path) 180 | 181 | text_vec=self.text_to_vec(text,file_path) 182 | return text_vec 183 | 184 | def text_to_vec(self,text,file_path): 185 | text_vec = [] 186 | print(colored(f"正在为文件:{os.path.basename(file_path)}生成向量", "green")) 187 | # self.repeat_size 为每块之间重合的部分 188 | for i in tqdm(range(0, len(text), self.batch_size-self.repeat_size), desc=colored(f"进度", "green")): 189 | batch = text[i:i+self.batch_size] 190 | embedding = self.encode(batch,sleep=1) 191 | text_vec.append({"content": batch, "embedding": embedding}) 192 | return text_vec 193 | 194 | def create_index(self, documents_path): 195 | """ 196 | 创建向量索引 197 | 输出: 198 | [(文件路径1,[(content,embedding),(content,embedding),...]), 199 | (文件路径2,[(content,embedding),(content,embedding),...]), 200 | ...] 
201 | """ 202 | if os.path.isdir(documents_path): 203 | print(colored("正在处理文件夹: " + documents_path, "green")) 204 | text_vec=self.dir_to_text_vec(documents_path) 205 | else: 206 | print(colored("正在处理文件: " + documents_path, "green")) 207 | text_vec=self.file_to_text_vec(documents_path) 208 | if text_vec is None: 209 | return None 210 | text_vec=[(os.path.basename(documents_path).strip().replace('.','_').replace(':','_').replace(' ','_'),text_vec)] 211 | return text_vec 212 | 213 | def search_index(self, query): 214 | output = [] 215 | """比较相似度,并输出前output_num个最相似的文本""" 216 | conn = sqlite3.connect(self.database_path) #创建连接 217 | input_vec = self.encode(query) 218 | #获取表名 219 | cursor = conn.cursor() #创建游标 220 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name IN ({})".format( 221 | ','.join(['?']*len(self.tables)) 222 | ), self.tables) 223 | table_names = cursor.fetchall() 224 | # 遍历每个表,进行相似度搜索 225 | for table_name in table_names: 226 | table_name = table_name[0] 227 | print(colored(f"正在搜索表{table_name}", "blue")) 228 | cursor.execute(f'SELECT content, embedding FROM "{table_name}"') 229 | database = cursor.fetchall() 230 | for row in database: 231 | content = row[0] 232 | embedding = pickle.loads(row[1]) 233 | cos = self.cal_cos(input_vec, embedding) 234 | output.append((cos,table_name,content)) 235 | conn.close() 236 | # 根据相似度排序并返回前top_n个结果 237 | output.sort(key=lambda x: x[0], reverse=True) 238 | 239 | if len(output) < self.top_n: 240 | return output 241 | else: 242 | return output[:self.top_n] 243 | 244 | def check_table_exist(self, table_name): 245 | # 检查表是否存在 246 | conn = sqlite3.connect(self.database_path) #创建连接 247 | cursor = conn.cursor() #创建游标 248 | cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table_name}'") 249 | if cursor.fetchone(): #fetchone返回的是一个元组,如果为空,则返回None 250 | return True 251 | return False 252 | 253 | def save_index(self, text_vec): 254 | if text_vec is None: 255 | print(colored("没有有效的文件", "red")) 256 | return 257 | 258 | print(colored(f"正在保存向量索引到文件{self.database_path}", "green")) 259 | conn = sqlite3.connect(self.database_path) # 创建连接 260 | cursor = conn.cursor() # 创建游标 261 | 262 | for table_name, vectors in text_vec: 263 | # 将文件路径作为表名,并替换掉不适合的符号 264 | table_name = f'"{table_name}"' 265 | # 创建表:id, content, embedding(如果不存在则创建) 266 | cursor.execute(f''' 267 | CREATE TABLE IF NOT EXISTS {table_name} ( 268 | id INTEGER PRIMARY KEY AUTOINCREMENT, 269 | content TEXT, 270 | embedding BLOB 271 | ) 272 | ''') 273 | # 对于每个表,插入数据 274 | for vec in vectors: 275 | # 将 NumPy 数组转换为二进制格式(BLOB) 276 | embedding_blob = sqlite3.Binary(pickle.dumps(vec["embedding"])) # 使用 pickle 序列化 277 | cursor.execute(f''' 278 | INSERT INTO {table_name} (content, embedding) 279 | VALUES (?, ?) 
280 | ''', (vec["content"], embedding_blob)) 281 | print(colored(f"文件{table_name}保存成功", "green")) 282 | conn.commit() 283 | conn.close() 284 | 285 | def load_index(self): 286 | # 用户选择路径 287 | from tkinter.filedialog import askopenfilename, askdirectory 288 | from tkinter import messagebox 289 | selection_type = messagebox.askquestion("选择类型", "您要选择文件/文件夹?(是/否)", icon='question') 290 | if selection_type == 'yes' or selection_type == 'y': 291 | documents_path = askopenfilename(title="选择文件", 292 | filetypes=[("PDF文件", "*.pdf"), 293 | ("文本文件", "*.txt"), 294 | ("Markdown文件", "*.md"), 295 | ("Docx文件", "*.docx")]) 296 | else: 297 | documents_path = askdirectory(title="选择文件夹") 298 | text_vec = self.create_index(documents_path) 299 | #保存到文件 300 | self.save_index(text_vec) 301 | 302 | 303 | def delete_table(self): 304 | conn = sqlite3.connect(self.database_path) 305 | cursor = conn.cursor() 306 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name!='sqlite_sequence'") 307 | tables = [row[0] for row in cursor.fetchall()] 308 | if not tables: 309 | print(colored("没有找到任何文献。", "yellow")) 310 | return 311 | # 显示表名列表 312 | print("-"*50) 313 | print(colored("数据库中已有的文献:", "blue")) 314 | print("-"*50) 315 | for idx, table in enumerate(tables, start=1): 316 | print(colored(f"{idx}. {table}", "blue")) 317 | print("-"*50) 318 | # 让用户选择序号,分号;隔开,或者:表示连续 如1:5;7;9表示1,2,3,4,5,7,9 319 | user_input = input(colored("请选择你要删除的文献的编号(例如1:5;7;9;all): ", "cyan")) 320 | selected_indices = set() 321 | #ALL表示所有表 322 | if user_input.strip().lower() == 'all': 323 | selected_indices = set(range(1, len(tables)+1)) 324 | else: 325 | for part in user_input.split(';'): 326 | if ':' in part: 327 | start, end = part.split(':') 328 | selected_indices.update(range(int(start), int(end)+1)) 329 | else: 330 | selected_indices.add(int(part)) 331 | # 过滤无效的索引 332 | selected_indices = {i for i in selected_indices if 1 <= i <= len(tables)} 333 | selected_tables = [tables[i-1] for i in sorted(selected_indices)] 334 | print("-"*50) 335 | print(colored("已选择要删除的文献:", "blue")) 336 | for idx,table in enumerate(selected_tables, start=1): 337 | print(colored(f"{idx}. 
{table}", "green")) 338 | print("-"*50) 339 | #询问是否删除 340 | print(colored("是否删除选中的文献?(y/n)", "red")) 341 | choice = input().strip() 342 | if choice == 'y': 343 | # 删除选中的表 344 | for table_name in selected_tables: 345 | cursor.execute(f'DROP TABLE IF EXISTS "{table_name}"') 346 | print(colored(f"已删除文献{table_name}", "green")) 347 | conn.commit() 348 | conn.close() 349 | else: 350 | print(colored("已取消删除", "red")) 351 | 352 | def show_table_info(self): 353 | conn = sqlite3.connect(self.database_path) #创建连接 354 | cursor = conn.cursor() #创建游标 355 | cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table'") 356 | table_names = cursor.fetchall() #fetchall返回的是一个元组列表,每个元组包含一个表名 357 | conn.close() 358 | print(colored(table_names, "blue")) 359 | for table_name in table_names: 360 | if table_name[0] == "sqlite_sequence": 361 | continue 362 | table_name = table_name[0] 363 | print(colored(f"表名:{table_name}", "blue")) 364 | #检查表的行数 365 | conn = sqlite3.connect(self.database_path) #创建连接 366 | cursor = conn.cursor() #创建游标 367 | cursor.execute(f'SELECT COUNT(*) FROM "{table_name}"') 368 | row_count = cursor.fetchone()[0] 369 | print("表的行数:",row_count) 370 | #检查表的列数 371 | cursor.execute(f'PRAGMA table_info("{table_name}")') 372 | column_names = [column[1] for column in cursor.fetchall()] 373 | print("表的列数:",len(column_names)) 374 | #提取第一行的第一列、第二列、第三列 375 | cursor.execute(f'SELECT * FROM "{table_name}" LIMIT 1') 376 | row = cursor.fetchone() 377 | print("第一行的第一列:",row[0]) 378 | print("第一行的第二列:",row[1]) 379 | conn.close() 380 | 381 | if __name__ == "__main__": 382 | database_path = os.getenv("DATABASE_PATH") 383 | vector_indexer = VectorIndexer(database_path=database_path) 384 | vector_indexer.load_index() 385 | vector_indexer.show_table_info() 386 | #output=vector_indexer.search_index("你好") 387 | #print(output) 388 | 389 | 390 | 391 | 392 | --------------------------------------------------------------------------------