├── src ├── __init__.py ├── .env ├── __pycache__ │ ├── llm.cpython-310.pyc │ ├── llm.cpython-311.pyc │ ├── agent.cpython-310.pyc │ ├── chains.cpython-310.pyc │ ├── chains.cpython-311.pyc │ ├── utils.cpython-310.pyc │ ├── utils.cpython-311.pyc │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-311.pyc │ ├── AgentTools.cpython-310.pyc │ ├── constants.cpython-310.pyc │ ├── constants.cpython-311.pyc │ ├── vectorstore.cpython-310.pyc │ └── vectorstore.cpython-311.pyc ├── constants.py ├── agent.py ├── utils.py ├── AgentTools.py ├── chains.py ├── vectorstore.py └── llm.py ├── test ├── __init__.py └── AgentToolsTest.py ├── img.png ├── vectorStore └── XianzhiVectorStore_gemini │ ├── index.pkl │ └── index.faiss ├── .idea └── .gitignore ├── requirements.txt ├── README.md └── main.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/.env: -------------------------------------------------------------------------------- 1 | GOOGLE_API_KEY = xxxxxxx -------------------------------------------------------------------------------- /img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/img.png -------------------------------------------------------------------------------- /src/__pycache__/llm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/llm.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/llm.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/llm.cpython-311.pyc -------------------------------------------------------------------------------- /src/__pycache__/agent.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/agent.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/chains.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/chains.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/chains.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/chains.cpython-311.pyc -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /src/__pycache__/AgentTools.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/AgentTools.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/constants.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/constants.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/constants.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/constants.cpython-311.pyc -------------------------------------------------------------------------------- /src/__pycache__/vectorstore.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/vectorstore.cpython-310.pyc -------------------------------------------------------------------------------- /src/__pycache__/vectorstore.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/src/__pycache__/vectorstore.cpython-311.pyc -------------------------------------------------------------------------------- /vectorStore/XianzhiVectorStore_gemini/index.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/vectorStore/XianzhiVectorStore_gemini/index.pkl -------------------------------------------------------------------------------- /vectorStore/XianzhiVectorStore_gemini/index.faiss: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kento996/xianzhi_assistant/HEAD/vectorStore/XianzhiVectorStore_gemini/index.faiss -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | #项目root文件夹 4 | ROOT_DIR=Path(__file__).parent.parent 5 | #Prompts文件夹 6 | PROMPTS_DIR=ROOT_DIR.joinpath("Prompts") 7 | #向量知识库 8 | VECTOR_DB_KNOWLEDGE_DIR=ROOT_DIR.joinpath("vectorStore") 9 | #先知文章 10 | XIANZHI_DOCUMENT_DIR=ROOT_DIR.joinpath("DocumentStore/xianzhi") 11 | 12 | 13 | -------------------------------------------------------------------------------- /src/agent.py: -------------------------------------------------------------------------------- 1 | # src/agent.py 2 | 3 | from langchain.agents import initialize_agent, AgentType 4 | from src.llm import get_llm 5 | from src.AgentTools import TOOL_MAP 6 | 7 | def create_agent(tool_name: str, model_provider="gemini", model_name=None): 8 | """ 9 | 创建只包含指定工具的 Agent。 10 | """ 11 | if tool_name not in TOOL_MAP: 12 | raise ValueError(f"未知工具名: {tool_name}") 13 | 14 | llm = get_llm(model_provider, model_name) 15 | tools = [TOOL_MAP[tool_name]] 16 | 17 | agent = initialize_agent( 18 | tools=tools, 19 | llm=llm, 20 | agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, 21 | verbose=True, 22 | ) 23 | 24 | return agent 25 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from PyPDF2 import PdfReader 2 | from langchain_community.document_loaders import AsyncChromiumLoader 3 | from langchain_community.document_transformers import BeautifulSoupTransformer 4 | from langchain_text_splitters import RecursiveCharacterTextSplitter 5 | 6 | 7 | def get_pdf_text(pdf): 8 | text = "" 9 | pdf_reader = PdfReader(pdf) 10 | for page in pdf_reader.pages: 11 | text += page.extract_text() 12 | 13 | return text 14 | 15 | def get_text_chunks( 16 | text, 17 | chunk_size: int = 1000, 18 | chunk_overlap: int = 150 19 | ): 20 | text_splitter = RecursiveCharacterTextSplitter( 21 | chunk_size=chunk_size, 22 | chunk_overlap=chunk_overlap, 23 | length_function=len 24 | ) 25 | return text_splitter.split_text(text) 26 | 27 | def documentScapy(url): 28 | urls=[] 29 | urls.append(url) 30 | loader = AsyncChromiumLoader(urls) 31 | html = loader.load() 32 | bs_transformer = BeautifulSoupTransformer() 33 | docs_transformed = bs_transformer.transform_documents(html, tags_to_extract=["p"]) 34 | 35 | return docs_transformed[0].page_content -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp==3.9.5 2 | aiosignal==1.3.1 3 | annotated-types==0.6.0 4 | async-timeout==4.0.3 5 | attrs==23.2.0 6 | cachetools==5.3.3 7 | certifi==2024.2.2 8 | charset-normalizer==3.3.2 9 | dataclasses-json==0.6.6 10 | frozenlist==1.4.1 11 | google-ai-generativelanguage==0.6.2 12 | google-api-core==2.19.0 13 | google-api-python-client==2.129.0 14 | google-auth==2.29.0 15 | google-auth-httplib2==0.2.0 16 | google-generativeai==0.5.2 17 | googleapis-common-protos==1.63.0 18 | greenlet==3.0.3 19 | grpcio==1.63.0 20 | grpcio-status==1.62.2 21 | httplib2==0.22.0 22 | idna==3.7 23 | jsonpatch==1.33 24 | jsonpointer==2.4 25 | langchain==0.1.14 26 | langchain-community==0.0.38 27 | langchain-core==0.1.52 28 | langchain-google-genai==1.0.3 29 | langchain-openai==0.0.8 30 | langchain-text-splitters==0.0.1 31 | langsmith==0.1.19 32 | marshmallow==3.21.2 33 | multidict==6.0.5 34 | mypy-extensions==1.0.0 35 | numpy==1.26.4 36 | ollama==0.1.6 37 | openai==1.13.3 38 | orjson==3.10.3 39 | packaging==23.2 40 | proto-plus==1.23.0 41 | protobuf==4.25.3 42 | pyasn1==0.6.0 43 | pyasn1_modules==0.4.0 44 | pydantic==2.7.1 45 | pydantic_core==2.18.2 46 | pyparsing==3.1.2 47 | PyPDF2==3.0.1 48 | python-dotenv==1.0.1 49 | PyYAML==6.0.1 50 | requests==2.31.0 51 | rsa==4.9 52 | SQLAlchemy==2.0.30 53 | tenacity==8.3.0 54 | tqdm==4.66.4 55 | typing-inspect==0.9.0 56 | typing_extensions==4.11.0 57 | uritemplate==4.1.1 58 | urllib3==2.2.1 59 | yarl==1.9.4 -------------------------------------------------------------------------------- /test/AgentToolsTest.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, Mock 3 | from bs4 import BeautifulSoup 4 | import requests 5 | from src.AgentTools import read_web_page # replace with the actual module name 6 | 7 | class TestReadWebPage(unittest.TestCase): 8 | @patch('requests.get') 9 | def test_successful_retrieval(self, mock_get): 10 | mock_response = Mock() 11 | mock_response.status_code = 200 12 | mock_response.text = 'Hello World!' 13 | mock_get.return_value = mock_response 14 | 15 | result = read_web_page('https://example.com') 16 | self.assertEqual(result, 'Hello World!') 17 | 18 | @patch('requests.get') 19 | def test_invalid_url(self, mock_get): 20 | mock_response = Mock() 21 | mock_response.status_code = 404 22 | mock_get.return_value = mock_response 23 | 24 | result = read_web_page('https://invalid-url.com') 25 | self.assertEqual(result, '抓取失败: 404') 26 | 27 | @patch('requests.get') 28 | def test_timeout(self, mock_get): 29 | mock_get.side_effect = requests.Timeout() 30 | 31 | result = read_web_page('https://example.com') 32 | self.assertEqual(result, '抓取失败: Timeout') 33 | 34 | @patch('requests.get') 35 | def test_html_parsing_error(self, mock_get): 36 | mock_response = Mock() 37 | mock_response.status_code = 200 38 | mock_response.text = 'Hello World!' 39 | mock_get.return_value = mock_response 40 | 41 | with patch('bs4.BeautifulSoup') as mock_bs: 42 | mock_bs.side_effect = Exception('HTML parsing error') 43 | 44 | result = read_web_page('https://example.com') 45 | self.assertEqual(result, '抓取失败: HTML parsing error') 46 | 47 | @patch('requests.get') 48 | def test_text_extraction_error(self, mock_get): 49 | mock_response = Mock() 50 | mock_response.status_code = 200 51 | mock_response.text = 'Hello World!' 52 | mock_get.return_value = mock_response 53 | 54 | with patch('bs4.BeautifulSoup.get_text') as mock_get_text: 55 | mock_get_text.side_effect = Exception('Text extraction error') 56 | 57 | result = read_web_page('https://example.com') 58 | self.assertEqual(result, '抓取失败: Text extraction error') 59 | 60 | @patch('requests.get') 61 | def test_unknown_exception(self, mock_get): 62 | mock_get.side_effect = Exception('Unknown error') 63 | 64 | result = read_web_page('https://example.com') 65 | self.assertEqual(result, '抓取失败: Unknown error') 66 | 67 | if __name__ == '__main__': 68 | unittest.main() -------------------------------------------------------------------------------- /src/AgentTools.py: -------------------------------------------------------------------------------- 1 | # src/tools/my_tools.py 2 | 3 | from langchain.tools import Tool 4 | import requests 5 | from bs4 import BeautifulSoup 6 | 7 | 8 | # 网络搜索工具函数 9 | def search_web(query: str) -> str: 10 | """ 11 | 使用 DuckDuckGo 进行简单搜索。 12 | """ 13 | try: 14 | url = f"https://duckduckgo.com/html/?q={query}" 15 | headers = { 16 | "User-Agent": "Mozilla/5.0" 17 | } 18 | response = requests.get(url, headers=headers, timeout=10) 19 | 20 | if response.status_code == 200: 21 | return f"已搜索: “{query}”,请访问 DuckDuckGo 查看结果。\n{url}" 22 | else: 23 | return f"搜索失败,状态码: {response.status_code}" 24 | except Exception as e: 25 | return f"搜索出错: {str(e)}" 26 | 27 | # 网页抓取工具函数 28 | def read_web_page(url: str) -> str: 29 | """ 30 | 读取网页内容,返回纯文本(用于分析)。 31 | """ 32 | try: 33 | headers = {"User-Agent": "Mozilla/5.0"} 34 | response = requests.get(url, headers=headers, timeout=10) 35 | soup = BeautifulSoup(response.text, 'html.parser') 36 | 37 | # 提取正文文本 38 | text = soup.get_text(separator="\n") 39 | cleaned = '\n'.join(line.strip() for line in text.splitlines() if line.strip()) 40 | 41 | return cleaned[:3000] # 限制长度,避免太长(LangChain限制输入) 42 | except Exception as e: 43 | return f"抓取失败: {str(e)}" 44 | 45 | # CVE 查询工具函数 46 | def query_cve(cve_id: str) -> str: 47 | """ 48 | 查询 CVE 详情。 49 | 使用 CIRCL CVE API: https://cve.circl.lu/api/cve/{CVE-ID} 50 | """ 51 | try: 52 | api_url = f"https://cve.circl.lu/api/cve/{cve_id.upper()}" 53 | response = requests.get(api_url, timeout=10) 54 | 55 | if response.status_code != 200: 56 | return f"查询失败,状态码: {response.status_code}" 57 | 58 | data = response.json() 59 | if "summary" not in data: 60 | return "未找到该 CVE 的信息" 61 | 62 | summary = data.get("summary", "无描述") 63 | cvss = data.get("cvss", "无评分") 64 | references = data.get("references", []) 65 | ref_str = "\n".join(references[:5]) if references else "无参考链接" 66 | 67 | return ( 68 | f"🛡️ CVE编号: {cve_id.upper()}\n" 69 | f"📄 简要描述: {summary}\n" 70 | f"📊 CVSS评分: {cvss}\n" 71 | f"🔗 参考链接:\n{ref_str}" 72 | ) 73 | 74 | except Exception as e: 75 | return f"查询失败: {str(e)}" 76 | 77 | # 工具字典,可扩展更多 78 | TOOL_MAP = { 79 | "SearchWeb": Tool.from_function( 80 | func=search_web, 81 | name="SearchWeb", 82 | description="用于互联网搜索,比如查找最新漏洞资讯。输入应为自然语言问题。" 83 | ), 84 | "ReadWebPage": Tool.from_function( 85 | func=read_web_page, 86 | name="ReadWebPage", 87 | description="读取网页并提取正文,用于分析指定URL的内容。输入应为网页URL。" 88 | ), 89 | "CVEQuery": Tool.from_function( 90 | func=query_cve, 91 | name="CVEQuery", 92 | description="查询指定 CVE 编号的详细信息,包括描述、评分、参考链接。输入应为合法的 CVE ID。" 93 | ) 94 | } 95 | -------------------------------------------------------------------------------- /src/chains.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | import logging 4 | 5 | from langchain_core.prompts import PromptTemplate 6 | 7 | from src.constants import XIANZHI_DOCUMENT_DIR 8 | from src.llm import get_llm 9 | from src.utils import get_pdf_text, get_text_chunks 10 | 11 | # 配置日志 12 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 13 | logger = logging.getLogger('chains') 14 | 15 | # 加载环境变量 16 | load_dotenv() 17 | logger.info("加载环境变量配置") 18 | 19 | # 从环境变量获取默认模型配置 20 | DEFAULT_MODEL_PROVIDER = os.getenv("DEFAULT_MODEL_PROVIDER", "gemini") 21 | DEFAULT_MODEL_NAME = os.getenv("DEFAULT_MODEL_NAME") 22 | 23 | logger.info(f"Chains默认模型提供商: {DEFAULT_MODEL_PROVIDER}") 24 | logger.info(f"Chains默认模型名称: {DEFAULT_MODEL_NAME or '未指定,将使用提供商默认模型'}") 25 | 26 | class Chains: 27 | """ 28 | 用于创建和执行LLM链的基类 29 | """ 30 | 31 | def __init__(self, model_provider=None, model_name=None): 32 | """ 33 | 初始化Chains对象 34 | 35 | Args: 36 | model_provider: 模型提供商,支持"gemini", "openai", "ollama",默认使用环境变量中的配置 37 | model_name: 模型名称,默认使用环境变量中的配置 38 | """ 39 | # 使用环境变量中的默认配置(如果参数未指定) 40 | self.model_provider = model_provider or DEFAULT_MODEL_PROVIDER 41 | self.model_name = model_name or DEFAULT_MODEL_NAME 42 | 43 | logger.info(f"初始化Chains,模型提供商: {self.model_provider},模型名称: {self.model_name or '未指定'}") 44 | 45 | try: 46 | self.llm = get_llm(self.model_provider, self.model_name) 47 | logger.info("成功初始化LLM模型") 48 | except Exception as e: 49 | logger.error(f"初始化LLM模型失败: {str(e)}") 50 | raise 51 | 52 | self._init_prompts() 53 | 54 | def _init_prompts(self): 55 | """初始化提示模板""" 56 | logger.info("初始化提示模板") 57 | 58 | self.contentAbstractPrompt = """请对分析如下文档并完成以下任务: 59 | 1. 分析文档的主题和内容 60 | 2. 用一段话概括文档 61 | ## 文档 62 | ``` 63 | {content_by_question} 64 | ``` 65 | ## 注意 66 | 你的输出结果是包含文档主题和内容的一段话""" 67 | self.contentAbstract_PromptTemplate = PromptTemplate(template=self.contentAbstractPrompt, 68 | input_variables=["content_by_question"]) 69 | 70 | self.signalAnswerPrompt="""你是一位网络安全专家,请你完成如下任务: 71 | 1. 分析如下安全问题 72 | 2. 根据如下的相关文档知识回答问题 73 | ##问题 74 | {question} 75 | ##相关文档 76 | {contents} 77 | ##注意 78 | 如果你觉得相关文档没用时,请你根据你自己的知识回答问题""" 79 | self.signalAnswer_PromptTemplate = PromptTemplate(template=self.signalAnswerPrompt, 80 | input_variables=["question","contents"]) 81 | 82 | self.analyzeResultPrompt="""你是一位网络安全专家,请你完成如下任务: 83 | 1. 分析如下安全问题 84 | 2. 对如下的参考答案进行分析 85 | 3. 根据有用的参考答案回答问题 86 | ##问题 87 | {question} 88 | ##参考内容 89 | {contents}""" 90 | self.analyzeResult_PromptTemplate = PromptTemplate(template=self.analyzeResultPrompt, 91 | input_variables=["question", "contents"]) 92 | 93 | def ContentAbstract_chain(self, content_by_question): 94 | """执行内容摘要链""" 95 | logger.info("执行内容摘要链") 96 | try: 97 | chain=self.contentAbstract_PromptTemplate | self.llm 98 | return chain.invoke({"content_by_question":content_by_question}).content 99 | except Exception as e: 100 | logger.error(f"执行内容摘要链失败: {str(e)}") 101 | raise 102 | 103 | def get_document_description_chain(self, filename, question): 104 | """执行文档描述链""" 105 | logger.info(f"执行文档描述链,文件名: {filename}") 106 | 107 | chain = self.signalAnswer_PromptTemplate | self.llm 108 | ctf_folder = XIANZHI_DOCUMENT_DIR 109 | pdf_path = os.path.join(ctf_folder, filename) 110 | 111 | if not os.path.exists(pdf_path): 112 | logger.error(f"文件不存在: {pdf_path}") 113 | raise FileNotFoundError(f"文件不存在: {pdf_path}") 114 | 115 | try: 116 | logger.info(f"读取PDF文件: {pdf_path}") 117 | raw_text = get_pdf_text(pdf_path) 118 | text_chunks = get_text_chunks(raw_text) 119 | 120 | contents = "" 121 | for index, text in enumerate(text_chunks): 122 | contents += str(index + 1) + "." + text.replace("\n", " ") + "\n" 123 | 124 | logger.info("执行文档描述链") 125 | return chain.invoke({"question":question, "contents":contents}).content 126 | except Exception as e: 127 | logger.error(f"处理文档时出错: {str(e)}") 128 | raise 129 | 130 | def analyze_chain(self, question, contents): 131 | """执行分析链""" 132 | logger.info("执行分析链") 133 | try: 134 | chain=self.analyzeResult_PromptTemplate | self.llm 135 | return chain.invoke({"question":question, "contents":contents}).content 136 | except Exception as e: 137 | logger.error(f"执行分析链失败: {str(e)}") 138 | raise 139 | 140 | # 保持向后兼容性 141 | class Chains_Gemini(Chains): 142 | """保持向后兼容的Gemini模型链类""" 143 | 144 | def __init__(self): 145 | logger.info("初始化Chains_Gemini (兼容模式)") 146 | super().__init__(model_provider="gemini") 147 | 148 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 基于LLM的先知社区知识库 2 | 本项目的开发初衷是为了方便检索先知社区的文章,在ctf比赛中能够基于 3 | llm快速筛选到有用的文章并生成一个结果 4 | ## 实现原理 5 | 本项目基于先知社区的知识内容构建了一个向量知识库,通过llm能够实现基于先知内容的回答,具体内容参见如下流程图: 6 | ![img.png](img.png) 7 | ## 文章目录 8 | 知识库的文章构建范围为:7023~12923,共计2898篇 9 | ## 使用方式 10 | xianzhi_assistant有两种使用方式,用户可以按照知识库构建范围 11 | 在本地接入先知社区的文章,也可以使用url模式。 12 | 13 | 14 | 在env中填写gemini api: 15 | ``` 16 | GOOGLE_API_KEY = AIzaSyA9cKkm4U65BPksk-pVgHmclxxxxxxxxxxx 17 | ``` 18 | ### 本地模式 19 | 在DocumentStore中补充相应范围的先知社区文章,详细的先知文章参考范围参见`xianzhi_index.json`。 20 | 在DocumentStore文件夹下,新建xianzhi文件夹并将各位爬取的文章放入即可 21 | 22 | 然后按照如下方式调用即可: 23 | ``` 24 | python main.py --type "local" --question "k8s存在哪些漏洞" --num 3 25 | ``` 26 | ### URL模式 27 | 直接运行: 28 | ``` 29 | python main.py --type "url" --question "k8s存在哪些漏洞" --num 3 30 | ``` 31 | ### 更新知识库 32 | 用户可以通过`--update`参数指定自己本地的wp地址用于更新知识库中的文章 33 | ``` 34 | python main.py --update "xxxxxxx" 35 | 36 | ``` 37 | ### 示例 38 | - 问题:k8s存在哪些漏洞 39 | ``` 40 | ['https://xz.aliyun.com/t/12437', 'https://xz.aliyun.com/t/12055', 'https://xz.aliyun.com/t/11138', 'https://xz.aliyun.com/t/8000', 'https://xz.aliyun.com/t/11890'] 41 | ## 分析安全问题 42 | 43 | Kubernetes(k8s)是一个开源容器编排系统,它存在以下漏洞: 44 | 45 | * **容器逃逸:**攻击者可以从容器中逃逸到主机操作系统,从而获得对底层系统的访问权限。 46 | * **特权提升:**攻击者可以提升容器内的权限,从而获得对集群的控制权。 47 | * **网络攻击:**攻击者可以利用网络配置错误或漏洞来访问或破坏集群中的容器。 48 | * **数据泄露:**攻击者可以访问或窃取存储在容器中的敏感数据。 49 | * **拒绝服务(DoS):**攻击者可以发起DoS攻击,使集群中的容器或服务不可用。 50 | 51 | ## 对参考答案的分析 52 | 53 | 参考答案1、2、3、4、5都列出了k8s存在的漏洞,但内容有所不同。 54 | 55 | * **参考答案1**提供了最全面的漏洞列表,涵盖了容器逃逸、特权提升、网络安全、数据泄露、拒绝服务、供应链攻击、配置错误、API安全、镜像漏洞和编排漏洞。 56 | * **参考答案2**提供了具体CVE编号的漏洞,但数量较少。 57 | * **参考答案3**提供了与参考答案1类似的漏洞列表,但缺少了供应链攻击和编排漏洞。 58 | * **参考答案4**没有提供任何漏洞信息。 59 | * **参考答案5**提供了与参考答案1类似的漏洞列表,但缺少了供应链攻击和编排漏洞,并增加了Kubernetes API服务器漏洞和网络策略绕过。 60 | 61 | ## 根据有用的参考答案回答问题 62 | 63 | 根据参考答案1、3和5,k8s存在的漏洞包括: 64 | 65 | * 容器逃逸 66 | * 特权提升 67 | * 网络攻击 68 | * 数据泄露 69 | * 拒绝服务 70 | * 配置错误 71 | * Kubernetes API服务器漏洞(参考答案5) 72 | * 网络策略绕过(参考答案5) 73 | ``` 74 | ## 注意 75 | 本项目仅做研究使用,切勿用于任何违法行为 76 | 77 | # XianzhiAsistant 78 | 79 | XianzhiAsistant是一个用于网络安全领域问题处理的助手工具,能够根据问题搜索相关文档并生成回答。 80 | 81 | ## 特性 82 | 83 | - 支持多种模型提供商:Gemini、OpenAI、Ollama 84 | - 通过.env文件灵活配置模型提供商、模型名称和API密钥 85 | - 支持本地存储和URL查询 86 | - 自动提取和分析文档内容 87 | - 根据问题搜索最相关的文档 88 | - 生成综合分析结果 89 | 90 | ## 环境要求 91 | 92 | - Python 3.8+ 93 | - 安装requirements.txt中的依赖 94 | 95 | ## 安装 96 | 97 | 1. 克隆仓库: 98 | ```bash 99 | git clone https://github.com/yourusername/XianzhiAsistant.git 100 | cd XianzhiAsistant 101 | ``` 102 | 103 | 2. 安装依赖: 104 | ```bash 105 | pip install -r requirements.txt 106 | ``` 107 | 108 | 3. 配置环境变量: 109 | 110 | 拷贝示例环境变量文件: 111 | ```bash 112 | cp .env.example .env 113 | ``` 114 | 115 | 根据需要编辑.env文件: 116 | ``` 117 | # API密钥配置 118 | GOOGLE_API_KEY=your_google_api_key_here 119 | OPENAI_API_KEY=your_openai_api_key_here 120 | 121 | # 默认模型配置 122 | DEFAULT_MODEL_PROVIDER=gemini # 可选: gemini, openai, ollama 123 | DEFAULT_MODEL_NAME= # 为空时使用各提供商的默认模型 124 | 125 | # Gemini模型配置 126 | DEFAULT_GEMINI_MODEL=gemini-pro # 可选: gemini-pro, gemini-1.5-pro-latest等 127 | DEFAULT_GEMINI_EMBEDDING_MODEL=models/embedding-001 # Gemini嵌入模型 128 | 129 | # OpenAI模型配置 130 | DEFAULT_OPENAI_MODEL=gpt-3.5-turbo # 可选: gpt-3.5-turbo, gpt-4, gpt-4-turbo等 131 | DEFAULT_OPENAI_EMBEDDING_MODEL=text-embedding-ada-002 # OpenAI嵌入模型 132 | 133 | # Ollama模型配置 134 | OLLAMA_BASE_URL=http://localhost:11434 # Ollama服务URL 135 | DEFAULT_OLLAMA_MODEL=llama2 # 可选: llama2, llama3, mistral等 136 | 137 | # 向量库配置 138 | DEFAULT_VECTOR_DB_DIR=vectorStore # 向量库目录 139 | DEFAULT_NUM_RESULTS=5 # 默认查询结果数 140 | ``` 141 | 142 | ## 配置说明 143 | 144 | 通过.env文件,您可以灵活配置所有模型参数: 145 | 146 | | 环境变量 | 说明 | 默认值 | 147 | |---------|------|--------| 148 | | GOOGLE_API_KEY | Google API密钥 | 无 | 149 | | OPENAI_API_KEY | OpenAI API密钥 | 无 | 150 | | DEFAULT_MODEL_PROVIDER | 默认模型提供商 | gemini | 151 | | DEFAULT_MODEL_NAME | 通用默认模型名称 | 无 | 152 | | DEFAULT_GEMINI_MODEL | Gemini模型名称 | gemini-pro | 153 | | DEFAULT_GEMINI_EMBEDDING_MODEL | Gemini嵌入模型 | models/embedding-001 | 154 | | DEFAULT_OPENAI_MODEL | OpenAI模型名称 | gpt-3.5-turbo | 155 | | DEFAULT_OPENAI_EMBEDDING_MODEL | OpenAI嵌入模型 | text-embedding-ada-002 | 156 | | OLLAMA_BASE_URL | Ollama服务地址 | http://localhost:11434 | 157 | | DEFAULT_OLLAMA_MODEL | Ollama模型名称 | llama2 | 158 | | DEFAULT_VECTOR_DB_DIR | 向量库存储目录 | vectorStore | 159 | | DEFAULT_NUM_RESULTS | 默认查询结果数量 | 5 | 160 | 161 | ### 模型配置优先级 162 | 163 | 模型选择的优先级从高到低: 164 | 1. 命令行参数 (--model, --model_name) 165 | 2. 环境变量 DEFAULT_MODEL_NAME 166 | 3. 提供商特定的环境变量 (DEFAULT_GEMINI_MODEL, DEFAULT_OPENAI_MODEL, DEFAULT_OLLAMA_MODEL) 167 | 4. 代码内置默认值 168 | 169 | ### 特定模型配置示例 170 | 171 | #### Gemini配置示例 172 | ``` 173 | DEFAULT_MODEL_PROVIDER=gemini 174 | DEFAULT_GEMINI_MODEL=gemini-1.5-pro-latest # 使用1.5版本 175 | ``` 176 | 177 | #### OpenAI配置示例 178 | ``` 179 | DEFAULT_MODEL_PROVIDER=openai 180 | DEFAULT_OPENAI_MODEL=gpt-4-turbo # 使用GPT-4 Turbo 181 | ``` 182 | 183 | #### Ollama配置示例 184 | ``` 185 | DEFAULT_MODEL_PROVIDER=ollama 186 | DEFAULT_OLLAMA_MODEL=llama3 # 使用Llama 3 187 | OLLAMA_BASE_URL=http://192.168.1.100:11434 # 自定义Ollama服务地址 188 | ``` 189 | 190 | ## 使用方法 191 | 192 | ### 更新向量库 193 | 194 | 将PDF文档添加到向量库中: 195 | 196 | ```bash 197 | python main.py --update /path/to/pdf/folder 198 | ``` 199 | 200 | 这将使用.env文件中配置的默认模型提供商和模型名称。也可以在命令行指定: 201 | 202 | ```bash 203 | python main.py --update /path/to/pdf/folder --model openai --model_name gpt-4 204 | ``` 205 | 206 | ### 查询问题 207 | 208 | 使用本地存储查询,使用.env中配置的默认模型: 209 | 210 | ```bash 211 | python main.py --type local --question "你的问题" 212 | ``` 213 | 214 | 使用URL查询,覆盖.env中的默认配置: 215 | 216 | ```bash 217 | python main.py --type url --question "你的问题" --num 10 --model openai --model_name gpt-4 218 | ``` 219 | 220 | ## 参数说明 221 | 222 | - `--type`: 查询类型,可选`local`或`url` 223 | - `--question`: 要查询的问题 224 | - `--num`: 返回的相似文档数量,未指定时使用.env中的配置 225 | - `--update`: 要添加的PDF文件夹路径 226 | - `--model`: 模型提供商,可选`gemini`, `openai`或`ollama`,未指定时使用.env中的配置 227 | - `--model_name`: 指定模型名称,未指定时使用.env中的配置 228 | 229 | ## Ollama模型使用 230 | 231 | 要使用Ollama模型: 232 | 233 | 1. 从[Ollama官网](https://ollama.ai/)下载并安装Ollama 234 | 2. 拉取您想要使用的模型: 235 | ```bash 236 | ollama pull llama2 237 | ``` 238 | 3. 运行Ollama服务 239 | 4. 在.env文件中设置: 240 | ``` 241 | DEFAULT_MODEL_PROVIDER=ollama 242 | DEFAULT_OLLAMA_MODEL=llama2 # 或您拉取的其他模型 243 | OLLAMA_BASE_URL=http://localhost:11434 # 如果服务不在默认地址 244 | ``` 245 | 246 | ## OpenAI模型使用 247 | 248 | 要使用OpenAI模型: 249 | 1. 在.env文件中设置您的API密钥: 250 | ``` 251 | OPENAI_API_KEY=your_openai_api_key 252 | DEFAULT_MODEL_PROVIDER=openai 253 | DEFAULT_OPENAI_MODEL=gpt-4 # 或其他模型,如gpt-3.5-turbo 254 | ``` 255 | 256 | ## 注意事项 257 | 258 | - 对于Ollama模型,确保Ollama服务已在本地运行 259 | - 对于OpenAI和Google Gemini模型,确保已设置正确的API密钥 260 | - 向量库存储在不同的目录中,每个模型提供商有自己的向量库 261 | -------------------------------------------------------------------------------- /src/vectorstore.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | from dotenv import load_dotenv 4 | import logging 5 | from langchain_community.vectorstores import FAISS 6 | from langchain_core.documents import Document 7 | from langchain_core.vectorstores import VectorStore 8 | 9 | from src.chains import Chains 10 | from src.constants import VECTOR_DB_KNOWLEDGE_DIR 11 | from src.llm import get_embedding 12 | from src.utils import get_pdf_text, get_text_chunks 13 | 14 | # 配置日志 15 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 16 | logger = logging.getLogger('vectorstore') 17 | 18 | # 加载环境变量 19 | load_dotenv() 20 | logger.info("加载环境变量配置") 21 | 22 | # 从环境变量获取默认模型配置 23 | DEFAULT_MODEL_PROVIDER = os.getenv("DEFAULT_MODEL_PROVIDER", "gemini") 24 | DEFAULT_MODEL_NAME = os.getenv("DEFAULT_MODEL_NAME") 25 | DEFAULT_VECTOR_DB_DIR = os.getenv("DEFAULT_VECTOR_DB_DIR", VECTOR_DB_KNOWLEDGE_DIR) 26 | 27 | logger.info(f"Vectorstore默认模型提供商: {DEFAULT_MODEL_PROVIDER}") 28 | logger.info(f"Vectorstore默认模型名称: {DEFAULT_MODEL_NAME or '未指定,将使用提供商默认模型'}") 29 | logger.info(f"向量库目录: {DEFAULT_VECTOR_DB_DIR}") 30 | 31 | # 默认向量库路径 32 | VECTOR_DB_KNOWLEDGE_PATH = os.path.join(DEFAULT_VECTOR_DB_DIR, "XianzhiVectorStore") 33 | 34 | # 不同模型提供商的向量库路径 35 | VECTOR_DB_PATHS = { 36 | "gemini": os.path.join(DEFAULT_VECTOR_DB_DIR, "XianzhiVectorStore_gemini"), 37 | "openai": os.path.join(DEFAULT_VECTOR_DB_DIR, "XianzhiVectorStore_openai"), 38 | "ollama": os.path.join(DEFAULT_VECTOR_DB_DIR, "XianzhiVectorStore_ollama") 39 | } 40 | 41 | # 向量库实例缓存 42 | vector_dbs = {} 43 | 44 | def get_vectordb(model_provider=None, model_name=None) -> VectorStore: 45 | """ 46 | 获取或创建向量库实例 47 | 48 | Args: 49 | model_provider: 模型提供商,支持"gemini", "openai", "ollama",默认使用环境变量中的配置 50 | model_name: 模型名称,默认使用环境变量中的配置 51 | 52 | Returns: 53 | VectorStore: 向量库实例 54 | """ 55 | # 使用环境变量默认值(如果参数未指定) 56 | model_provider = model_provider or DEFAULT_MODEL_PROVIDER 57 | model_name = model_name or DEFAULT_MODEL_NAME 58 | 59 | cache_key = f"{model_provider}_{model_name or 'default'}" 60 | logger.info(f"获取向量库,提供商: {model_provider},模型: {model_name or '默认'},缓存键: {cache_key}") 61 | 62 | if cache_key in vector_dbs: 63 | logger.info(f"使用缓存的向量库: {cache_key}") 64 | return vector_dbs[cache_key] 65 | 66 | vector_db_path = VECTOR_DB_PATHS.get(model_provider, VECTOR_DB_KNOWLEDGE_PATH) 67 | logger.info(f"向量库路径: {vector_db_path}") 68 | 69 | try: 70 | # 获取嵌入模型 71 | logger.info(f"获取嵌入模型,提供商: {model_provider}") 72 | embedding = get_embedding(model_provider, model_name) 73 | 74 | if os.path.exists(vector_db_path): 75 | logger.info(f"加载已存在的向量库: {vector_db_path}") 76 | db = FAISS.load_local(vector_db_path, embedding, allow_dangerous_deserialization=True) 77 | logger.info("向量库加载成功") 78 | else: 79 | logger.info(f"向量库不存在,创建新的向量库: {vector_db_path}") 80 | # 如果向量库不存在,创建一个空的向量库 81 | db = FAISS.from_texts(["初始化向量库"], embedding) 82 | # 确保目录存在 83 | os.makedirs(os.path.dirname(vector_db_path), exist_ok=True) 84 | db.save_local(vector_db_path) 85 | logger.info("新向量库创建并保存成功") 86 | 87 | # 存入缓存 88 | vector_dbs[cache_key] = db 89 | return db 90 | except Exception as e: 91 | logger.error(f"获取向量库失败: {str(e)}") 92 | raise 93 | 94 | def query_vectordb(query: str, k: int = 20, model_provider=None, model_name=None) -> List[Document]: 95 | """ 96 | 查询向量库 97 | 98 | Args: 99 | query: 查询文本 100 | k: 返回的最相似文档数量 101 | model_provider: 模型提供商,默认使用环境变量中的配置 102 | model_name: 模型名称,默认使用环境变量中的配置 103 | 104 | Returns: 105 | List[Document]: 相似文档列表 106 | """ 107 | # 使用环境变量默认值(如果参数未指定) 108 | model_provider = model_provider or DEFAULT_MODEL_PROVIDER 109 | model_name = model_name or DEFAULT_MODEL_NAME 110 | 111 | logger.info(f"查询向量库,提供商: {model_provider},k: {k}") 112 | 113 | try: 114 | db = get_vectordb(model_provider, model_name) 115 | logger.info(f"执行相似度搜索,查询: '{query[:50]}...'") 116 | docs = db.similarity_search(query, k=k) 117 | logger.info(f"查询成功,返回 {len(docs)} 个结果") 118 | return docs 119 | except Exception as e: 120 | logger.error(f"查询向量库失败: {str(e)}") 121 | raise 122 | 123 | def update_vectordb(pdf_documents_path: str, model_provider=None, model_name=None): 124 | """ 125 | 更新向量库 126 | 127 | Args: 128 | pdf_documents_path: PDF文档路径 129 | model_provider: 模型提供商,默认使用环境变量中的配置 130 | model_name: 模型名称,默认使用环境变量中的配置 131 | """ 132 | # 使用环境变量默认值(如果参数未指定) 133 | model_provider = model_provider or DEFAULT_MODEL_PROVIDER 134 | model_name = model_name or DEFAULT_MODEL_NAME 135 | 136 | logger.info(f"更新向量库,文档路径: {pdf_documents_path},提供商: {model_provider}") 137 | 138 | if not os.path.exists(pdf_documents_path): 139 | error_msg = f"指定的目录不存在: {pdf_documents_path}" 140 | logger.error(error_msg) 141 | print(error_msg) 142 | return 143 | 144 | try: 145 | db = get_vectordb(model_provider, model_name) 146 | vector_db_path = VECTOR_DB_PATHS.get(model_provider, VECTOR_DB_KNOWLEDGE_PATH) 147 | 148 | pdf_files = [f for f in os.listdir(pdf_documents_path) if f.lower().endswith('.pdf')] 149 | logger.info(f"找到 {len(pdf_files)} 个PDF文件") 150 | 151 | if not pdf_files: 152 | logger.warning(f"目录中没有PDF文件: {pdf_documents_path}") 153 | print(f"警告: 目录中没有PDF文件: {pdf_documents_path}") 154 | return 155 | 156 | for filename in pdf_files: 157 | try: 158 | filepath = os.path.join(pdf_documents_path, filename) 159 | logger.info(f"处理文件: {filepath}") 160 | 161 | # 读取并分割文本 162 | text_chunks = get_text_chunks(get_pdf_text(filepath)) 163 | logger.info(f"文件 {filename} 分割为 {len(text_chunks)} 个文本块") 164 | 165 | contents = "" 166 | for index, text in enumerate(text_chunks): 167 | contents += str(index + 1) + "." + text.replace("\n", " ") + "\n" 168 | 169 | # 创建摘要 170 | logger.info(f"为文件 {filename} 创建摘要") 171 | chain = Chains(model_provider, model_name) 172 | text_abstract = chain.ContentAbstract_chain(contents) 173 | 174 | # 创建文档对象 175 | documents = [] 176 | documents.append( 177 | Document( 178 | page_content=text_abstract, 179 | metadata={ 180 | "FileName": filename 181 | } 182 | ) 183 | ) 184 | 185 | # 添加到向量库 186 | logger.info(f"将文件 {filename} 添加到向量库") 187 | db.add_documents(documents) 188 | db.save_local(vector_db_path) 189 | logger.info(f"文件 {filename} 成功添加到向量库") 190 | print(f"已添加文档 {filename} 到向量库") 191 | except Exception as e: 192 | logger.error(f"处理文件 {filename} 时出错: {str(e)}") 193 | print(f"处理文件 {filename} 时出错: {str(e)}") 194 | continue 195 | 196 | logger.info(f"向量库更新完成,保存到: {vector_db_path}") 197 | except Exception as e: 198 | logger.error(f"更新向量库时出错: {str(e)}") 199 | raise 200 | 201 | if __name__ == '__main__': 202 | import sys 203 | 204 | if len(sys.argv) > 1: 205 | pdf_dir = sys.argv[1] 206 | model_provider = sys.argv[2] if len(sys.argv) > 2 else None 207 | model_name = sys.argv[3] if len(sys.argv) > 3 else None 208 | update_vectordb(pdf_dir, model_provider, model_name) -------------------------------------------------------------------------------- /src/llm.py: -------------------------------------------------------------------------------- 1 | from google.generativeai.types import HarmCategory, HarmBlockThreshold 2 | from langchain_core.language_models import BaseChatModel 3 | from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings 4 | from langchain_openai import ChatOpenAI, OpenAIEmbeddings 5 | from langchain_community.llms import Ollama 6 | from langchain_community.embeddings import OllamaEmbeddings 7 | from dotenv import load_dotenv 8 | import os 9 | import logging 10 | 11 | # 配置日志 12 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 13 | logger = logging.getLogger('llm') 14 | 15 | # 从.env文件加载环境变量 16 | load_dotenv() 17 | logger.info("加载环境变量配置") 18 | 19 | # 从.env文件获取默认模型配置 20 | DEFAULT_MODEL_PROVIDER = os.getenv("DEFAULT_MODEL_PROVIDER", "gemini") 21 | DEFAULT_MODEL_NAME = os.getenv("DEFAULT_MODEL_NAME", None) 22 | GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") 23 | OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") 24 | OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") 25 | 26 | # 特定模型提供商的默认模型配置 27 | DEFAULT_GEMINI_MODEL = os.getenv("DEFAULT_GEMINI_MODEL", "gemini-pro") 28 | DEFAULT_OPENAI_MODEL = os.getenv("DEFAULT_OPENAI_MODEL", "gpt-3.5-turbo") 29 | DEFAULT_OLLAMA_MODEL = os.getenv("DEFAULT_OLLAMA_MODEL", "llama2") 30 | 31 | # 嵌入模型配置 32 | DEFAULT_GEMINI_EMBEDDING_MODEL = os.getenv("DEFAULT_GEMINI_EMBEDDING_MODEL", "models/embedding-001") 33 | DEFAULT_OPENAI_EMBEDDING_MODEL = os.getenv("DEFAULT_OPENAI_EMBEDDING_MODEL", "text-embedding-ada-002") 34 | 35 | logger.info(f"默认模型提供商: {DEFAULT_MODEL_PROVIDER}") 36 | logger.info(f"使用的模型: {DEFAULT_MODEL_NAME or '未指定,将使用提供商默认模型'}") 37 | 38 | def get_llm(model_provider=None, model_name=None) -> BaseChatModel: 39 | """ 40 | 获取LLM模型实例 41 | 42 | Args: 43 | model_provider: 模型提供商,支持"gemini", "openai", "ollama" 44 | model_name: 模型名称,为None时使用默认模型 45 | 46 | Returns: 47 | BaseChatModel: LLM模型实例 48 | """ 49 | # 如果未指定,使用环境变量中的默认值 50 | model_provider = model_provider or DEFAULT_MODEL_PROVIDER 51 | logger.info(f"使用模型提供商: {model_provider}") 52 | 53 | try: 54 | if model_provider == "gemini": 55 | # 如果没有指定model_name,使用DEFAULT_MODEL_NAME,如果仍为None则使用DEFAULT_GEMINI_MODEL 56 | specific_model = model_name or DEFAULT_MODEL_NAME or DEFAULT_GEMINI_MODEL 57 | logger.info(f"使用Gemini模型: {specific_model}") 58 | if not GOOGLE_API_KEY: 59 | raise ValueError("未设置GOOGLE_API_KEY环境变量,无法使用Gemini模型") 60 | return get_gemini(specific_model) 61 | elif model_provider == "openai": 62 | specific_model = model_name or DEFAULT_MODEL_NAME or DEFAULT_OPENAI_MODEL 63 | logger.info(f"使用OpenAI模型: {specific_model}") 64 | if not OPENAI_API_KEY: 65 | raise ValueError("未设置OPENAI_API_KEY环境变量,无法使用OpenAI模型") 66 | return get_openai(specific_model) 67 | elif model_provider == "ollama": 68 | specific_model = model_name or DEFAULT_MODEL_NAME or DEFAULT_OLLAMA_MODEL 69 | logger.info(f"使用Ollama模型: {specific_model}") 70 | return get_ollama(specific_model) 71 | else: 72 | raise ValueError(f"不支持的模型提供商: {model_provider}") 73 | except Exception as e: 74 | logger.error(f"获取LLM模型失败: {str(e)}") 75 | raise 76 | 77 | def get_embedding(model_provider=None, model_name=None): 78 | """ 79 | 获取嵌入模型实例 80 | 81 | Args: 82 | model_provider: 模型提供商,支持"gemini", "openai", "ollama" 83 | model_name: 模型名称,为None时使用默认模型 84 | 85 | Returns: 86 | 嵌入模型实例 87 | """ 88 | # 如果未指定,使用环境变量中的默认值 89 | model_provider = model_provider or DEFAULT_MODEL_PROVIDER 90 | logger.info(f"使用嵌入模型提供商: {model_provider}") 91 | 92 | try: 93 | if model_provider == "gemini": 94 | specific_model = model_name or DEFAULT_MODEL_NAME or DEFAULT_GEMINI_EMBEDDING_MODEL 95 | logger.info(f"使用Gemini嵌入模型: {specific_model}") 96 | if not GOOGLE_API_KEY: 97 | raise ValueError("未设置GOOGLE_API_KEY环境变量,无法使用Gemini嵌入模型") 98 | return get_gemini_embedding(specific_model) 99 | elif model_provider == "openai": 100 | specific_model = model_name or DEFAULT_MODEL_NAME or DEFAULT_OPENAI_EMBEDDING_MODEL 101 | logger.info(f"使用OpenAI嵌入模型: {specific_model}") 102 | if not OPENAI_API_KEY: 103 | raise ValueError("未设置OPENAI_API_KEY环境变量,无法使用OpenAI嵌入模型") 104 | return get_openai_embedding(specific_model) 105 | elif model_provider == "ollama": 106 | specific_model = model_name or DEFAULT_MODEL_NAME or DEFAULT_OLLAMA_MODEL 107 | logger.info(f"使用Ollama嵌入模型: {specific_model}") 108 | return get_ollama_embedding(specific_model) 109 | else: 110 | raise ValueError(f"不支持的模型提供商: {model_provider}") 111 | except Exception as e: 112 | logger.error(f"获取嵌入模型失败: {str(e)}") 113 | raise 114 | 115 | def get_gemini(model_name=DEFAULT_GEMINI_MODEL) -> BaseChatModel: 116 | """获取Gemini模型实例""" 117 | try: 118 | return ChatGoogleGenerativeAI( 119 | model=model_name, 120 | temperature=0, 121 | convert_system_message_to_human=True, 122 | transport="rest", 123 | google_api_key=GOOGLE_API_KEY, 124 | safety_settings={ 125 | HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, 126 | HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, 127 | HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, 128 | HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, 129 | } 130 | ) 131 | except Exception as e: 132 | logger.error(f"初始化Gemini模型失败: {str(e)}") 133 | raise 134 | 135 | # 保留向后兼容性的函数 136 | def get_gemini_pro_15() -> BaseChatModel: 137 | return get_gemini("gemini-1.5-pro-latest") 138 | 139 | # 保留向后兼容性的函数 140 | def get_gemini_pro() -> BaseChatModel: 141 | return get_gemini("gemini-pro") 142 | 143 | def get_gemini_embedding(model_name=DEFAULT_GEMINI_EMBEDDING_MODEL): 144 | """获取Gemini嵌入模型实例""" 145 | try: 146 | return GoogleGenerativeAIEmbeddings( 147 | model=model_name, 148 | transport="rest", 149 | google_api_key=GOOGLE_API_KEY 150 | ) 151 | except Exception as e: 152 | logger.error(f"初始化Gemini嵌入模型失败: {str(e)}") 153 | raise 154 | 155 | def get_openai(model_name=DEFAULT_OPENAI_MODEL) -> BaseChatModel: 156 | """获取OpenAI模型实例""" 157 | try: 158 | return ChatOpenAI( 159 | model=model_name, 160 | temperature=0, 161 | api_key=OPENAI_API_KEY, 162 | ) 163 | except Exception as e: 164 | logger.error(f"初始化OpenAI模型失败: {str(e)}") 165 | raise 166 | 167 | def get_openai_embedding(model_name=DEFAULT_OPENAI_EMBEDDING_MODEL): 168 | """获取OpenAI嵌入模型实例""" 169 | try: 170 | return OpenAIEmbeddings( 171 | model=model_name, 172 | api_key=OPENAI_API_KEY, 173 | ) 174 | except Exception as e: 175 | logger.error(f"初始化OpenAI嵌入模型失败: {str(e)}") 176 | raise 177 | 178 | def get_ollama(model_name=DEFAULT_OLLAMA_MODEL) -> BaseChatModel: 179 | """获取Ollama模型实例""" 180 | try: 181 | return Ollama( 182 | model=model_name, 183 | temperature=0, 184 | base_url=OLLAMA_BASE_URL, 185 | ) 186 | except Exception as e: 187 | logger.error(f"初始化Ollama模型失败: {str(e)}, 请确保Ollama服务正在运行且地址正确") 188 | raise 189 | 190 | def get_ollama_embedding(model_name=DEFAULT_OLLAMA_MODEL): 191 | """获取Ollama嵌入模型实例""" 192 | try: 193 | return OllamaEmbeddings( 194 | model=model_name, 195 | base_url=OLLAMA_BASE_URL, 196 | ) 197 | except Exception as e: 198 | logger.error(f"初始化Ollama嵌入模型失败: {str(e)}, 请确保Ollama服务正在运行且地址正确") 199 | raise 200 | 201 | 202 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | import os 4 | import logging 5 | import sys 6 | from dotenv import load_dotenv 7 | 8 | from src.chains import Chains 9 | from src.llm import get_llm 10 | from src.utils import documentScapy 11 | from src.vectorstore import query_vectordb, update_vectordb 12 | from src.agent import create_agent 13 | 14 | # 配置日志 15 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 16 | logger = logging.getLogger('main') 17 | 18 | # 加载环境变量 19 | load_dotenv() 20 | logger.info("加载环境变量配置") 21 | 22 | # 从环境变量获取默认配置 23 | DEFAULT_MODEL_PROVIDER = os.getenv("DEFAULT_MODEL_PROVIDER", "gemini") 24 | DEFAULT_MODEL_NAME = os.getenv("DEFAULT_MODEL_NAME") 25 | DEFAULT_NUM_RESULTS = int(os.getenv("DEFAULT_NUM_RESULTS", "5")) 26 | 27 | logger.info(f"默认模型提供商: {DEFAULT_MODEL_PROVIDER}") 28 | logger.info(f"默认模型名称: {DEFAULT_MODEL_NAME or '未指定,将使用提供商默认模型'}") 29 | logger.info(f"默认结果数量: {DEFAULT_NUM_RESULTS}") 30 | 31 | def search_document(llm, chains, question: str, k: int = 5): 32 | """ 33 | 根据问题搜索相关文档 34 | 35 | Args: 36 | llm: 语言模型实例 37 | chains: 链实例 38 | question: 问题 39 | k: 返回的文档数量 40 | 41 | Returns: 42 | 相关文档列表 43 | """ 44 | logger.info(f"根据问题搜索文档: '{question}'") 45 | try: 46 | logger.info("调用LLM生成内容") 47 | content = llm.invoke(question).content 48 | logger.info("生成内容摘要") 49 | abstract_content = chains.ContentAbstract_chain(content_by_question=content) 50 | logger.info(f"查询向量库,k={k}") 51 | ans = query_vectordb(abstract_content, k, chains.model_provider, chains.model_name) 52 | logger.info(f"查询到 {len(ans)} 个相关文档") 53 | return ans 54 | except Exception as e: 55 | logger.error(f"搜索文档失败: {str(e)}") 56 | raise 57 | 58 | def resCollection(chains, ans, question): 59 | """ 60 | 收集文档结果 61 | 62 | Args: 63 | chains: 链实例 64 | ans: 文档列表 65 | question: 问题 66 | 67 | Returns: 68 | 处理后的结果和文件名列表 69 | """ 70 | logger.info("收集文档结果") 71 | res_org = [] 72 | fileNames = [] 73 | pattern = r"^\d+" 74 | 75 | try: 76 | for i, doc in enumerate(ans): 77 | fileName = doc.metadata["FileName"] 78 | logger.info(f"处理文档 {i+1}/{len(ans)}: {fileName}") 79 | 80 | # 处理文件名,提取编号 81 | match = re.search(pattern, fileName) 82 | if match: 83 | filenum = match.group() 84 | url = f"https://xz.aliyun.com/t/{filenum}" 85 | fileNames.append(url) 86 | logger.info(f"生成URL: {url}") 87 | else: 88 | fileNames.append(fileName) 89 | logger.info(f"使用文件名: {fileName}") 90 | 91 | # 获取文档描述 92 | logger.info(f"获取文档描述: {fileName}") 93 | doc_description = chains.get_document_description_chain(fileName, question) 94 | res_org.append(doc_description) 95 | 96 | # 组合结果 97 | logger.info("组合结果") 98 | res_deal = "" 99 | for index, i in enumerate(res_org): 100 | res_deal += f"###参考答案{index + 1}" + "\n" + i.replace("\n", " ") + "\n" 101 | 102 | return res_deal, fileNames 103 | except Exception as e: 104 | logger.error(f"收集文档结果失败: {str(e)}") 105 | raise 106 | 107 | def anaylzeResultByUrl(chains, llm, question, fileNames: list): 108 | """ 109 | 通过URL分析结果 110 | 111 | Args: 112 | chains: 链实例 113 | llm: 语言模型实例 114 | question: 问题 115 | fileNames: URL列表 116 | 117 | Returns: 118 | 分析结果 119 | """ 120 | logger.info(f"通过URL分析结果,URLs数量: {len(fileNames)}") 121 | res_org = [] 122 | 123 | try: 124 | for i, fileName in enumerate(fileNames): 125 | logger.info(f"获取URL内容 {i+1}/{len(fileNames)}: {fileName}") 126 | content_page = documentScapy(fileName) 127 | 128 | logger.info(f"分析URL内容: {fileName}") 129 | chain = chains.analyzeResult_PromptTemplate | llm 130 | res = chain.invoke({"question": question, "contents": content_page}).content 131 | res_org.append(res) 132 | 133 | # 组合结果 134 | logger.info("组合URL分析结果") 135 | res_deal = "" 136 | for index, i in enumerate(res_org): 137 | res_deal += f"###参考答案{index + 1}" + "\n" + i.replace("\n", " ") + "\n" 138 | 139 | return res_deal 140 | except Exception as e: 141 | logger.error(f"通过URL分析结果失败: {str(e)}") 142 | raise 143 | 144 | def anaylzeResult(chains, res_collection, question): 145 | """ 146 | 分析结果集 147 | 148 | Args: 149 | chains: 链实例 150 | res_collection: 结果集 151 | question: 问题 152 | 153 | Returns: 154 | 分析结果 155 | """ 156 | logger.info("分析最终结果") 157 | try: 158 | result = chains.analyze_chain(question=question, contents=res_collection) 159 | logger.info("分析完成") 160 | return result 161 | except Exception as e: 162 | logger.error(f"分析结果失败: {str(e)}") 163 | raise 164 | 165 | def process_query(question: str, k: int = None, model_provider=None, model_name=None, store_type="local"): 166 | """ 167 | 处理查询 168 | 169 | Args: 170 | question: 查询问题 171 | k: 返回的相似文档数量,默认使用环境变量中的配置 172 | model_provider: 模型提供商,支持"gemini", "openai", "ollama",默认使用环境变量中的配置 173 | model_name: 模型名称,默认使用环境变量中的配置 174 | store_type: 存储类型,"local"或"url" 175 | """ 176 | # 使用环境变量默认值(如果参数未指定) 177 | k = k or DEFAULT_NUM_RESULTS 178 | model_provider = model_provider or DEFAULT_MODEL_PROVIDER 179 | 180 | logger.info(f"处理查询 - 问题: '{question}', 类型: {store_type}, 模型: {model_provider}/{model_name or '默认'}, k: {k}") 181 | 182 | try: 183 | # 获取模型和链 184 | logger.info("初始化LLM模型") 185 | llm = get_llm(model_provider, model_name) 186 | logger.info("初始化Chains") 187 | chains = Chains(model_provider, model_name) 188 | 189 | # 搜索相关文档 190 | logger.info("搜索相关文档") 191 | ans = search_document(llm, chains, question, k) 192 | 193 | # 获取结果 194 | logger.info("获取文档来源") 195 | _, fileNames = resCollection(chains, ans, question) 196 | 197 | # 根据存储类型选择处理方式 198 | if store_type == "url": 199 | logger.info("使用URL模式处理") 200 | res_collection = anaylzeResultByUrl(chains=chains, llm=llm, question=question, fileNames=fileNames) 201 | else: 202 | logger.info("使用本地模式处理") 203 | res_collection, _ = resCollection(chains, ans, question) 204 | 205 | # 分析结果 206 | logger.info("分析结果") 207 | res = anaylzeResult(chains, res_collection, question) 208 | 209 | # 输出结果 210 | logger.info("输出结果") 211 | print("参考来源:") 212 | for url in fileNames: 213 | print(f"- {url}") 214 | print("\n分析结果:") 215 | print(res) 216 | 217 | return res, fileNames 218 | except Exception as e: 219 | logger.error(f"处理查询失败: {str(e)}") 220 | print(f"处理查询失败: {str(e)}") 221 | raise 222 | 223 | def local_store(question: str, k: int = None, model_provider=None, model_name=None): 224 | """使用本地存储处理查询""" 225 | logger.info("使用本地存储处理查询") 226 | return process_query(question, k, model_provider, model_name, "local") 227 | 228 | def url_store(question: str, k: int = None, model_provider=None, model_name=None): 229 | """使用URL存储处理查询""" 230 | logger.info("使用URL存储处理查询") 231 | return process_query(question, k, model_provider, model_name, "url") 232 | 233 | if __name__ == '__main__': 234 | try: 235 | logger.info("程序启动") 236 | parser = argparse.ArgumentParser(description='使用AI处理安全领域问题') 237 | parser.add_argument('--type', choices=['local', 'url'], help=f'选择存储类型: local或url') 238 | parser.add_argument('--question', type=str, help='要处理的问题') 239 | parser.add_argument('--num', type=int, help=f'返回的相似文档数量 (默认: {DEFAULT_NUM_RESULTS})') 240 | parser.add_argument('--update', type=str, help='要添加的PDF文件夹路径') 241 | parser.add_argument('--model', choices=['gemini', 'openai', 'ollama'], help=f'选择模型提供商: gemini, openai或ollama (默认: {DEFAULT_MODEL_PROVIDER})') 242 | parser.add_argument('--model_name', type=str, help='指定模型名称,默认使用.env中配置值') 243 | # 创建agent开放工具权限列表 244 | parser.add_argument('--call_function', choices=['SearchWeb', 'ReadWebPage', 'CVEQuery'], help='调用Agent执行的工具名,例如: SearchWeb, ReadWebPage, CVEQuery') 245 | 246 | args = parser.parse_args() 247 | logger.info(f"命令行参数: {args}") 248 | 249 | if args.type is not None: 250 | if not args.question: 251 | error_msg = "错误: 必须提供--question参数" 252 | logger.error(error_msg) 253 | print(error_msg) 254 | parser.print_help() 255 | sys.exit(1) 256 | 257 | model_provider = args.model 258 | model_name = args.model_name 259 | 260 | logger.info(f"执行查询 - 类型: {args.type}, 模型: {model_provider or DEFAULT_MODEL_PROVIDER}/{model_name or DEFAULT_MODEL_NAME or '默认'}") 261 | 262 | match args.type: 263 | case 'url': 264 | url_store(args.question, args.num, model_provider, model_name) 265 | case 'local': 266 | local_store(args.question, args.num, model_provider, model_name) 267 | elif args.update is not None: 268 | model_provider = args.model 269 | model_name = args.model_name 270 | 271 | logger.info(f"更新向量库 - 路径: {args.update}, 模型: {model_provider or DEFAULT_MODEL_PROVIDER}/{model_name or DEFAULT_MODEL_NAME or '默认'}") 272 | update_vectordb(args.update, model_provider, model_name) 273 | # 实现agent工具调用功能 274 | elif args.call_function is not None: 275 | logger.info(f"调用Agent工具: {args.call_function}") 276 | agent = create_agent( 277 | tool_name=args.call_function, 278 | model_provider=args.model or DEFAULT_MODEL_PROVIDER, 279 | model_name=args.model_name or DEFAULT_MODEL_NAME 280 | ) 281 | if not args.question: 282 | print("错误:使用 --call_function 时必须提供 --question 参数作为输入。") 283 | sys.exit(1) 284 | 285 | result = agent.run(args.question) 286 | print("Agent响应:", result) 287 | 288 | else: 289 | logger.info("未提供操作参数,显示帮助信息") 290 | parser.print_help() 291 | except Exception as e: 292 | logger.error(f"程序执行失败: {str(e)}") 293 | print(f"错误: {str(e)}") 294 | sys.exit(1) 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | --------------------------------------------------------------------------------