├── src ├── mcp │ ├── __init__.py │ ├── battle_server.py │ ├── risk_control_server.py │ ├── chip_analysis_server.py │ ├── big_deal_analysis_server.py │ ├── technical_analysis_server.py │ ├── sentiment_server.py │ ├── hot_money_srver.py │ └── server.py ├── prompt │ ├── __init__.py │ ├── toolcall.py │ ├── big_deal_analysis.py │ ├── hot_money.py │ ├── risk_control.py │ ├── technical_analysis.py │ ├── report.py │ ├── chip_analysis.py │ ├── mcp.py │ ├── sentiment.py │ └── battle.py ├── exceptions.py ├── __init__.py ├── environment │ ├── __init__.py │ ├── base.py │ └── research.py ├── agent │ ├── __init__.py │ ├── react.py │ ├── report.py │ ├── big_deal_analysis.py │ ├── chip_analysis.py │ ├── risk_control.py │ ├── hot_money.py │ ├── technical_analysis.py │ ├── base.py │ ├── sentiment.py │ └── toolcall.py ├── tool │ ├── search │ │ ├── __init__.py │ │ ├── google_search.py │ │ ├── base.py │ │ ├── baidu_search.py │ │ ├── duckduckgo_search.py │ │ └── bing_search.py │ ├── __init__.py │ ├── financial_deep_search │ │ ├── __init__.py │ │ ├── index_name_map.json │ │ ├── get_section_data.py │ │ ├── index_capital.py │ │ └── stock_capital.py │ ├── terminate.py │ ├── tool_collection.py │ ├── battle.py │ ├── base.py │ ├── stock_info_request.py │ ├── risk_control.py │ ├── sentiment.py │ ├── create_chat_completion.py │ ├── mcp_client.py │ ├── big_deal_analysis.py │ └── hot_money.py ├── logger.py ├── utils │ └── cleanup_reports.py └── schema.py ├── report └── README.md ├── docs ├── boyi.png ├── logo.png ├── wechat.JPG ├── architecture.png ├── flow_diagram.md └── class_diagram.md ├── config ├── .gitignore ├── mcp.example.json └── config.example.toml ├── requirements.txt ├── .gitattributes ├── .pre-commit-config.yaml └── .gitignore /src/mcp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/prompt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /report/README.md: -------------------------------------------------------------------------------- 1 | # 报告文件 2 | - 这里存放最终的结果 -------------------------------------------------------------------------------- /docs/boyi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuaYaoAI/FinGenius/HEAD/docs/boyi.png -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuaYaoAI/FinGenius/HEAD/docs/logo.png -------------------------------------------------------------------------------- /docs/wechat.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuaYaoAI/FinGenius/HEAD/docs/wechat.JPG -------------------------------------------------------------------------------- /docs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuaYaoAI/FinGenius/HEAD/docs/architecture.png -------------------------------------------------------------------------------- /config/.gitignore: -------------------------------------------------------------------------------- 1 | # prevent the local config file from being uploaded to the remote repository 2 | config.toml 3 | mcp.json 4 | -------------------------------------------------------------------------------- /config/mcp.example.json: -------------------------------------------------------------------------------- 1 | { 2 | "mcpServers": { 3 | "server1": { 4 | "type": "sse", 5 | "url": "http://localhost:8000/sse" 6 | } 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /src/prompt/toolcall.py: -------------------------------------------------------------------------------- 1 | SYSTEM_PROMPT = "You are an agent that can execute tool calls" 2 | 3 | NEXT_STEP_PROMPT = ( 4 | "If you want to stop interaction, use `terminate` tool/function call." 5 | ) 6 | -------------------------------------------------------------------------------- /src/exceptions.py: -------------------------------------------------------------------------------- 1 | class ToolError(Exception): 2 | """Raised when a tool encounters an error.""" 3 | 4 | def __init__(self, message): 5 | self.message = message 6 | 7 | 8 | class TokenLimitExceeded(Exception): 9 | """Exception raised when the token limit is exceeded""" 10 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Python version check: 3.11-3.13 2 | import sys 3 | 4 | 5 | if sys.version_info < (3, 11) or sys.version_info > (3, 13): 6 | print( 7 | "Warning: Unsupported Python version {ver}, please use 3.11-3.13".format( 8 | ver=".".join(map(str, sys.version_info)) 9 | ) 10 | ) 11 | -------------------------------------------------------------------------------- /src/environment/__init__.py: -------------------------------------------------------------------------------- 1 | """Environment module for different execution environments.""" 2 | 3 | from src.environment.base import BaseEnvironment 4 | from src.environment.battle import BattleEnvironment 5 | from src.environment.research import ResearchEnvironment 6 | 7 | 8 | __all__ = ["BaseEnvironment", "ResearchEnvironment", "BattleEnvironment"] 9 | -------------------------------------------------------------------------------- /src/prompt/big_deal_analysis.py: -------------------------------------------------------------------------------- 1 | BIG_DEAL_SYSTEM_PROMPT = """ 2 | 你是FinGenius系统中的大单异动分析专家,擅长使用资金流向数据结合价格走势识别市场高强度资金异动。 3 | 4 | 你的职责: 5 | 1. 使用 big_deal_analysis_tool 获取市场与个股资金流向数据(stock_fund_flow_big_deal、stock_fund_flow_individual 等)。 6 | 2. 给出市场整体资金净流入/净流出统计,并列出 TOP 净流入 / 净流出股票。 7 | 3. 针对指定股票,分析资金流向趋势与价格走势,给出综合评分,并输出投资建议(看涨/看跌及理由)。 8 | 4. 识别极端流入/流出、市场情绪和高活跃度股票并说明依据。 9 | 10 | 输出请使用分点叙述,逻辑清晰,必要时给出简明表格。 11 | """ -------------------------------------------------------------------------------- /src/agent/__init__.py: -------------------------------------------------------------------------------- 1 | from src.agent.base import BaseAgent 2 | from src.agent.chip_analysis import ChipAnalysisAgent 3 | from src.agent.react import ReActAgent 4 | from src.agent.toolcall import ToolCallAgent 5 | from src.agent.big_deal_analysis import BigDealAnalysisAgent 6 | 7 | 8 | __all__ = [ 9 | "BaseAgent", 10 | "ChipAnalysisAgent", 11 | "ReActAgent", 12 | "ToolCallAgent", 13 | "BigDealAnalysisAgent", 14 | ] 15 | -------------------------------------------------------------------------------- /src/mcp/battle_server.py: -------------------------------------------------------------------------------- 1 | from src.mcp.server import MCPServer 2 | from src.tool import Battle, Terminate 3 | 4 | 5 | class BattleServer(MCPServer): 6 | def __init__(self, name: str = "BattleServer"): 7 | super().__init__(name) 8 | 9 | def _initialize_standard_tools(self) -> None: 10 | self.tools.update( 11 | { 12 | "terminate": Terminate(), 13 | "battle": Battle(), 14 | } 15 | ) 16 | -------------------------------------------------------------------------------- /src/tool/search/__init__.py: -------------------------------------------------------------------------------- 1 | from src.tool.search.baidu_search import BaiduSearchEngine 2 | from src.tool.search.base import WebSearchEngine 3 | from src.tool.search.bing_search import BingSearchEngine 4 | from src.tool.search.duckduckgo_search import DuckDuckGoSearchEngine 5 | from src.tool.search.google_search import GoogleSearchEngine 6 | 7 | 8 | __all__ = [ 9 | "WebSearchEngine", 10 | "BaiduSearchEngine", 11 | "DuckDuckGoSearchEngine", 12 | "GoogleSearchEngine", 13 | "BingSearchEngine", 14 | ] 15 | -------------------------------------------------------------------------------- /src/mcp/risk_control_server.py: -------------------------------------------------------------------------------- 1 | from src.mcp.server import MCPServer 2 | from src.tool import Terminate 3 | from src.tool.risk_control import RiskControlTool 4 | 5 | 6 | class RiskControlServer(MCPServer): 7 | def __init__(self, name: str = "RiskControlServer"): 8 | super().__init__(name) 9 | 10 | def _initialize_standard_tools(self) -> None: 11 | self.tools.update( 12 | { 13 | "risk_control_tool": RiskControlTool(), 14 | "terminate": Terminate(), 15 | } 16 | ) 17 | -------------------------------------------------------------------------------- /src/mcp/chip_analysis_server.py: -------------------------------------------------------------------------------- 1 | from src.mcp.server import MCPServer 2 | from src.tool import Terminate 3 | from src.tool.chip_analysis import ChipAnalysisTool 4 | 5 | 6 | class ChipAnalysisServer(MCPServer): 7 | def __init__(self, name: str = "ChipAnalysisServer"): 8 | super().__init__(name) 9 | 10 | def _initialize_standard_tools(self) -> None: 11 | self.tools.update( 12 | { 13 | "chip_analysis_tool": ChipAnalysisTool(), 14 | "terminate": Terminate(), 15 | } 16 | ) 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit 2 | asyncio 3 | pydantic~=2.10.6 4 | aiohttp 5 | openai~=1.66.5 6 | fastmcp 7 | python-dotenv 8 | beautifulsoup4~=4.13.3 9 | pandas~=2.2.3 10 | numpy 11 | matplotlib 12 | plotly 13 | nltk 14 | tomli 15 | mcp~=1.5.0 16 | tiktoken~=0.9.0 17 | tenacity~=9.0.0 18 | loguru~=0.7.3 19 | reportlab 20 | googlesearch-python~=1.3.0 21 | baidusearch~=1.0.3 22 | duckduckgo_search~=7.5.3 23 | requests~=2.32.3 24 | duckduckgo-search~=7.5.5 25 | efinance~=0.5.5.2 26 | akshare~=1.16.87 27 | schedule~=1.2.2 28 | uvicorn 29 | starlette 30 | rich~=13.7.1 31 | -------------------------------------------------------------------------------- /src/mcp/big_deal_analysis_server.py: -------------------------------------------------------------------------------- 1 | from src.mcp.server import MCPServer 2 | from src.tool import Terminate 3 | from src.tool.big_deal_analysis import BigDealAnalysisTool 4 | 5 | 6 | class BigDealAnalysisServer(MCPServer): 7 | def __init__(self, name: str = "BigDealAnalysisServer"): 8 | super().__init__(name) 9 | 10 | def _initialize_standard_tools(self) -> None: 11 | self.tools.update( 12 | { 13 | "big_deal_analysis_tool": BigDealAnalysisTool(), 14 | "terminate": Terminate(), 15 | } 16 | ) 17 | -------------------------------------------------------------------------------- /src/mcp/technical_analysis_server.py: -------------------------------------------------------------------------------- 1 | from src.mcp.server import MCPServer 2 | from src.tool import Terminate 3 | from src.tool.technical_analysis import TechnicalAnalysisTool 4 | 5 | 6 | class TechnicalAnalysisServer(MCPServer): 7 | def __init__(self, name: str = "TechnicalAnalysisServer"): 8 | super().__init__(name) 9 | 10 | def _initialize_standard_tools(self) -> None: 11 | self.tools.update( 12 | { 13 | "technical_analysis_tool": TechnicalAnalysisTool(), 14 | "terminate": Terminate(), 15 | } 16 | ) 17 | -------------------------------------------------------------------------------- /src/mcp/sentiment_server.py: -------------------------------------------------------------------------------- 1 | from src.mcp.server import MCPServer 2 | from src.tool import Terminate 3 | from src.tool.sentiment import SentimentTool 4 | from src.tool.web_search import WebSearch 5 | 6 | 7 | class SentimentServer(MCPServer): 8 | def __init__(self, name: str = "SentimentServer"): 9 | super().__init__(name) 10 | 11 | def _initialize_standard_tools(self) -> None: 12 | self.tools.update( 13 | { 14 | "sentiment_tool": SentimentTool(), 15 | "web_search": WebSearch(), 16 | "terminate": Terminate(), 17 | } 18 | ) 19 | -------------------------------------------------------------------------------- /src/tool/__init__.py: -------------------------------------------------------------------------------- 1 | """Tool module for FinGenius platform.""" 2 | 3 | from src.tool.base import BaseTool 4 | from src.tool.battle import Battle 5 | from src.tool.chip_analysis import ChipAnalysisTool 6 | from src.tool.create_chat_completion import CreateChatCompletion 7 | from src.tool.terminate import Terminate 8 | from src.tool.tool_collection import ToolCollection 9 | from src.tool.big_deal_analysis import BigDealAnalysisTool 10 | 11 | 12 | __all__ = [ 13 | "BaseTool", 14 | "Battle", 15 | "ChipAnalysisTool", 16 | "Terminate", 17 | "ToolCollection", 18 | "CreateChatCompletion", 19 | "BigDealAnalysisTool", 20 | ] 21 | -------------------------------------------------------------------------------- /src/mcp/hot_money_srver.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | from src.mcp.server import MCPServer 5 | from src.tool import Terminate 6 | from src.tool.hot_money import HotMoneyTool 7 | 8 | 9 | logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stderr)]) 10 | 11 | 12 | class HotMoneyServer(MCPServer): 13 | def __init__(self, name: str = "HotMoneyServer"): 14 | super().__init__(name) 15 | 16 | def _initialize_standard_tools(self) -> None: 17 | self.tools.update( 18 | { 19 | "hot_money_tool": HotMoneyTool(), 20 | "terminate": Terminate(), 21 | } 22 | ) 23 | -------------------------------------------------------------------------------- /src/tool/financial_deep_search/__init__.py: -------------------------------------------------------------------------------- 1 | from src.tool.financial_deep_search.get_section_data import get_all_section 2 | from src.tool.financial_deep_search.index_capital import get_index_capital_flow 3 | from src.tool.financial_deep_search.risk_control_data import ( 4 | get_announcements_with_detail, 5 | get_company_name_for_stock, 6 | get_financial_reports, 7 | get_risk_control_data, 8 | ) 9 | from src.tool.financial_deep_search.stock_capital import ( 10 | fetch_single_stock_capital_flow, 11 | fetch_stock_list_capital_flow, 12 | get_stock_capital_flow, 13 | ) 14 | 15 | 16 | __all__ = [ 17 | "get_stock_capital_flow", 18 | "fetch_single_stock_capital_flow", 19 | "fetch_stock_list_capital_flow", 20 | "get_risk_control_data", 21 | "get_announcements_with_detail", 22 | "get_financial_reports", 23 | "get_company_name_for_stock", 24 | "get_index_capital_flow", 25 | "get_all_section", 26 | ] 27 | -------------------------------------------------------------------------------- /src/tool/terminate.py: -------------------------------------------------------------------------------- 1 | from src.tool.base import BaseTool 2 | 3 | 4 | _TERMINATE_DESCRIPTION = """Terminate the interaction when the request is met OR if the assistant cannot proceed further with the task. 5 | When you have finished all the tasks, call this tool to end the work.""" 6 | 7 | 8 | class Terminate(BaseTool): 9 | name: str = "terminate" 10 | description: str = _TERMINATE_DESCRIPTION 11 | parameters: dict = { 12 | "type": "object", 13 | "properties": { 14 | "status": { 15 | "type": "string", 16 | "description": "The finish status of the interaction.", 17 | "enum": ["success", "failure"], 18 | } 19 | }, 20 | "required": ["status"], 21 | } 22 | 23 | async def execute(self, status: str) -> str: 24 | """Finish the current execution""" 25 | return f"The interaction has been completed with status: {status}" 26 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # HTML code is incorrectly calculated into statistics, so ignore them 2 | *.html linguist-detectable=false 3 | # Auto detect text files and perform LF normalization 4 | * text=auto eol=lf 5 | # Ensure shell scripts use LF (Linux style) line endings on Windows 6 | *.sh text eol=lf 7 | # Treat specific binary files as binary and prevent line ending conversion 8 | *.png binary 9 | *.jpg binary 10 | *.gif binary 11 | *.ico binary 12 | *.jpeg binary 13 | *.mp3 binary 14 | *.zip binary 15 | *.bin binary 16 | # Preserve original line endings for specific document files 17 | *.doc text eol=crlf 18 | *.docx text eol=crlf 19 | *.pdf binary 20 | # Ensure source code and script files use LF line endings 21 | *.py text eol=lf 22 | *.js text eol=lf 23 | *.html text eol=lf 24 | *.css text eol=lf 25 | # Specify custom diff driver for specific file types 26 | *.md diff=markdown 27 | *.json diff=json 28 | *.mp4 filter=lfs diff=lfs merge=lfs -text 29 | *.mov filter=lfs diff=lfs merge=lfs -text 30 | *.webm filter=lfs diff=lfs merge=lfs -text 31 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 23.1.0 4 | hooks: 5 | - id: black 6 | 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v4.4.0 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: end-of-file-fixer 12 | - id: check-yaml 13 | - id: check-added-large-files 14 | 15 | - repo: https://github.com/PyCQA/autoflake 16 | rev: v2.0.1 17 | hooks: 18 | - id: autoflake 19 | args: [ 20 | --remove-all-unused-imports, 21 | --ignore-init-module-imports, 22 | --expand-star-imports, 23 | --remove-duplicate-keys, 24 | --remove-unused-variables, 25 | --recursive, 26 | --in-place, 27 | --exclude=__init__.py, 28 | ] 29 | files: \.py$ 30 | 31 | - repo: https://github.com/pycqa/isort 32 | rev: 5.12.0 33 | hooks: 34 | - id: isort 35 | args: [ 36 | "--profile", "black", 37 | "--filter-files", 38 | "--lines-after-imports=2", 39 | ] 40 | -------------------------------------------------------------------------------- /src/tool/financial_deep_search/index_name_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "000001": "上证指数", 3 | "399001": "深证成指", 4 | "899050": "北证50", 5 | "399006": "创业板指", 6 | "000680": "科创综指", 7 | "000688": "科创50", 8 | "399330": "深证100", 9 | "000300": "沪深300", 10 | "000016": "上证50", 11 | "399673": "创业板50", 12 | "000888": "上证综合全收益", 13 | "399750": "深主板50", 14 | "930050": "中证A50", 15 | "000903": "中证A100", 16 | "000510": "中证A500", 17 | "000904": "中证200", 18 | "000905": "中证500", 19 | "000906": "中证800", 20 | "000852": "中证1000", 21 | "932000": "中证2000", 22 | "000985": "中证全指", 23 | "000010": "上证180", 24 | "000009": "上证380", 25 | "000132": "上证100", 26 | "000133": "上证150", 27 | "000003": "B股指数", 28 | "000012": "国债指数", 29 | "000013": "企债指数", 30 | "000011": "基金指数", 31 | "399002": "深成指R", 32 | "399850": "深证50", 33 | "399005": "中小100", 34 | "399003": "成份B指", 35 | "399106": "深证综指", 36 | "399004": "深证100R", 37 | "399007": "深证300", 38 | "399008": "中小300", 39 | "399293": "创业大盘", 40 | "399019": "创业200", 41 | "399020": "创业小盘", 42 | "399100": "新指数", 43 | "399550": "央视50" 44 | } 45 | -------------------------------------------------------------------------------- /src/tool/search/google_search.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from googlesearch import search 4 | 5 | from src.tool.search.base import SearchItem, WebSearchEngine 6 | 7 | 8 | class GoogleSearchEngine(WebSearchEngine): 9 | def perform_search( 10 | self, query: str, num_results: int = 10, *args, **kwargs 11 | ) -> List[SearchItem]: 12 | """ 13 | Google search engine. 14 | 15 | Returns results formatted according to SearchItem model. 16 | """ 17 | raw_results = search(query, num_results=num_results, advanced=True) 18 | 19 | results = [] 20 | for i, item in enumerate(raw_results): 21 | if isinstance(item, str): 22 | # If it's just a URL 23 | results.append( 24 | {"title": f"Google Result {i+1}", "url": item, "description": ""} 25 | ) 26 | else: 27 | results.append( 28 | SearchItem( 29 | title=item.title, url=item.url, description=item.description 30 | ) 31 | ) 32 | 33 | return results 34 | -------------------------------------------------------------------------------- /src/agent/react.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Optional 3 | 4 | from pydantic import Field 5 | 6 | from src.agent.base import BaseAgent 7 | from src.llm import LLM 8 | from src.schema import AgentState, Memory 9 | 10 | 11 | class ReActAgent(BaseAgent, ABC): 12 | name: str 13 | description: Optional[str] = None 14 | 15 | system_prompt: Optional[str] = None 16 | next_step_prompt: Optional[str] = None 17 | 18 | llm: Optional[LLM] = Field(default_factory=LLM) 19 | memory: Memory = Field(default_factory=Memory) 20 | state: AgentState = AgentState.IDLE 21 | 22 | max_steps: int = 10 23 | current_step: int = 0 24 | 25 | @abstractmethod 26 | async def think(self) -> bool: 27 | """Process current state and decide next action""" 28 | 29 | @abstractmethod 30 | async def act(self) -> str: 31 | """Execute decided actions""" 32 | 33 | async def step(self) -> str: 34 | """Execute a single step: think and act.""" 35 | should_act = await self.think() 36 | if not should_act: 37 | return "Thinking complete - no action needed" 38 | return await self.act() 39 | -------------------------------------------------------------------------------- /src/prompt/hot_money.py: -------------------------------------------------------------------------------- 1 | HOT_MONEY_SYSTEM_PROMPT = """你是一位专业的游资行为分析师,擅长解读市场中游资的操作特点、风格和可能的意图。你的任务是基于提供的数据和信息,进行客观的游资行为分析,帮助用户理解游资的操作模式,但绝不提供任何形式的投资建议或引导用户做出投资决策。 2 | 3 | 分析范围 4 | 1. 游资席位的持仓变化 5 | 2. 游资的交易风格和特点 6 | 3. 历史操作模式和偏好 7 | 4. 资金流向和集中度 8 | 5. 游资之间的关联性和协同特征 9 | 10 | 分析方法 11 | 1. 数据分析:通过席位龙虎榜数据、成交量、换手率等量化指标分析 12 | 2. 模式识别:辨别游资的操作套路、惯用手法 13 | 3. 历史追踪:回顾特定游资的历史操作轨迹和成功率 14 | 4. 关联分析:挖掘不同游资之间的关联性和互动模式 15 | 16 | 输出格式 17 | 1. 游资背景简介:对所分析游资的基本情况描述 18 | 2. 操作特点总结:概括该游资的典型操作风格和特征 19 | 3. 近期行为分析:分析其近期操作的特点和可能的思路 20 | 4. 注意事项:提醒用户关注的风险点和需要注意的因素 21 | 22 | 工具使用规范 23 | 在分析过程中,请合理使用以下工具: 24 | - hot_money_tool:获取龙虎榜数据和资金流向信息 25 | - terminate:当你完成了完整的游资分析报告后,必须使用此工具结束任务 26 | 27 | ⚠️ 重要提醒:当你完成了游资分析并准备输出最终报告时,请立即使用terminate工具结束任务,避免无限循环。 28 | 29 | 重要免责声明 30 | 1. 本分析仅供参考,绝不构成任何形式的投资建议 31 | 2. 分析内容基于历史数据和公开信息,不预测未来市场走势 32 | 3. 不推荐、不暗示、不引导用户进行任何具体投资操作 33 | 4. 用户应自行承担投资决策的全部责任和风险 34 | 5. 分析不构成任何买入或卖出的建议,用户必须独立做出决策 35 | 36 | 使用说明 37 | 在提问时,请尽可能提供以下信息以获得更准确的分析: 38 | 1. 具体关注的游资席位或代码 39 | 2. 关注的时间段 40 | 3. 已知的相关信息 41 | 4. 希望了解的具体方面 42 | 43 | 示例分析框架 44 | 1. 游资背景分析 45 | - 历史活跃度和关注领域 46 | - 典型操作风格和特点 47 | - 历史成功案例特征 48 | 2. 近期操作分析 49 | - 资金规模和集中度变化 50 | - 进出节奏和持仓周期 51 | - 与其他席位的关联性 52 | 3. 行为模式解读 53 | - 可能的操作思路分析 54 | - 操作阶段判断 55 | - 风险点提示 56 | 4. 总结观点 57 | - 客观中立的行为总结 58 | - 值得关注的关键指标 59 | - 再次强调:本分析不构成任何投资建议,仅供参考 60 | """ 61 | -------------------------------------------------------------------------------- /src/agent/report.py: -------------------------------------------------------------------------------- 1 | from pydantic import Field 2 | 3 | from src.agent.mcp import MCPAgent 4 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN 5 | from src.prompt.report import REPORT_SYSTEM_PROMPT 6 | from src.tool import Terminate, ToolCollection 7 | from src.tool.create_html import CreateHtmlTool 8 | 9 | 10 | class ReportAgent(MCPAgent): 11 | """Report generation agent that synthesizes insights from other agents.""" 12 | 13 | name: str = "report_agent" 14 | description: str = "Generates comprehensive reports by synthesizing insights from other specialized agents." 15 | system_prompt: str = REPORT_SYSTEM_PROMPT 16 | next_step_prompt: str = NEXT_STEP_PROMPT_ZN 17 | 18 | # Initialize with FinGenius tools and proper type annotation 19 | available_tools: ToolCollection = Field( 20 | default_factory=lambda: ToolCollection( 21 | CreateHtmlTool(), 22 | Terminate(), 23 | ) 24 | ) 25 | special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name]) 26 | 27 | 28 | if __name__ == "__main__": 29 | import asyncio 30 | 31 | async def run_agent(): 32 | agent = await ReportAgent.create() 33 | await agent.initialize() 34 | prompt = "生成一个关于AI的报告 html格式" 35 | await agent.run(prompt) 36 | 37 | asyncio.run(run_agent()) 38 | -------------------------------------------------------------------------------- /src/tool/search/base.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class SearchItem(BaseModel): 7 | """Represents a single search result item""" 8 | 9 | title: str = Field(description="The title of the search result") 10 | url: str = Field(description="The URL of the search result") 11 | description: Optional[str] = Field( 12 | default=None, description="A description or snippet of the search result" 13 | ) 14 | 15 | def __str__(self) -> str: 16 | """String representation of a search result item.""" 17 | return f"{self.title} - {self.url}" 18 | 19 | 20 | class WebSearchEngine(BaseModel): 21 | """Base class for web search engines.""" 22 | 23 | model_config = {"arbitrary_types_allowed": True} 24 | 25 | def perform_search( 26 | self, query: str, num_results: int = 10, *args, **kwargs 27 | ) -> List[SearchItem]: 28 | """ 29 | Perform a web search and return a list of search items. 30 | 31 | Args: 32 | query (str): The search query to submit to the search engine. 33 | num_results (int, optional): The number of search results to return. Default is 10. 34 | args: Additional arguments. 35 | kwargs: Additional keyword arguments. 36 | 37 | Returns: 38 | List[SearchItem]: A list of SearchItem objects matching the search query. 39 | """ 40 | raise NotImplementedError 41 | -------------------------------------------------------------------------------- /src/agent/big_deal_analysis.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | from pydantic import Field 3 | 4 | from src.agent.mcp import MCPAgent 5 | from src.prompt.big_deal_analysis import BIG_DEAL_SYSTEM_PROMPT 6 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN 7 | from src.schema import Message 8 | from src.tool import Terminate, ToolCollection 9 | from src.tool.big_deal_analysis import BigDealAnalysisTool 10 | 11 | 12 | class BigDealAnalysisAgent(MCPAgent): 13 | """大单异动分析 Agent""" 14 | 15 | name: str = "big_deal_analysis_agent" 16 | description: str = "分析市场及个股大单资金异动,为投资决策提供依据。" 17 | 18 | system_prompt: str = BIG_DEAL_SYSTEM_PROMPT 19 | next_step_prompt: str = NEXT_STEP_PROMPT_ZN 20 | 21 | available_tools: ToolCollection = Field( 22 | default_factory=lambda: ToolCollection( 23 | BigDealAnalysisTool(), 24 | Terminate(), 25 | ) 26 | ) 27 | # 限制单次观察字符,防止内存过大导致 LLM 无法响应 28 | max_observe: int = 10000 29 | special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) 30 | 31 | async def run( 32 | self, 33 | request: Optional[str] = None, 34 | stock_code: Optional[str] = None, 35 | ) -> Any: 36 | """Run big deal analysis""" 37 | if stock_code and not request: 38 | self.memory.add_message( 39 | Message.system_message( 40 | f"你正在分析股票 {stock_code} 的大单资金流向,请综合资金异动与价格走势给出结论。" 41 | ) 42 | ) 43 | request = f"请对 {stock_code} 进行大单异动深度分析,并生成投资建议。" 44 | 45 | return await super().run(request) -------------------------------------------------------------------------------- /src/prompt/risk_control.py: -------------------------------------------------------------------------------- 1 | RISK_SYSTEM_PROMPT = """你是一位专业的风险控制顾问,专注于财务风险和法务风险的识别、评估与管理。你擅长从财务数据和法律法规角度分析潜在风险点,并提供风险控制策略建议。你的回答基于风险管理的普遍原则和最佳实践,不构成具体的财务或法律建议。 2 | 分析范围 3 | 财务风险 4 | 流动性风险 5 | 信用风险 6 | 市场风险 7 | 运营财务风险 8 | 财务报表异常识别 9 | 资金管理风险 10 | 预算控制风险 11 | 财务合规风险 12 | 法务风险 13 | 合同风险 14 | 监管合规风险 15 | 知识产权风险 16 | 劳动法律风险 17 | 诉讼风险 18 | 企业治理风险 19 | 数据合规与隐私风险 20 | 行业特定法规风险 21 | 风险评估方法 22 | 风险识别:通过系统性分析识别潜在风险点 23 | 风险评估:评估风险发生的可能性和潜在影响 24 | 风险分级:按严重程度和紧急度对风险进行分级 25 | 控制措施建议:提供降低或消除风险的可行措施 26 | 监控建议:提出持续监控风险的指标和方法 27 | 输出格式 28 | 风险概述:对提出问题的整体风险状况评估 29 | 关键风险识别:列出主要财务和法务风险点 30 | 风险评估矩阵:按影响力和可能性评级 31 | 控制措施建议:针对各风险点的具体管控建议 32 | 监控机制:持续监控风险的方法和指标 33 | 工具使用指南 34 | 在分析过程中,请合理使用以下工具: 35 | - risk_control_tool:获取股票的财务数据和法务公告数据 36 | - terminate:当你完成了完整的风险分析报告后,必须使用此工具结束任务 37 | 38 | ⚠️ 重要提醒:当你完成了风险分析并准备输出最终报告时,请立即使用terminate工具结束任务,避免无限循环。 39 | 40 | 重要免责声明 41 | 本分析仅提供风险管理的一般性建议,不构成具体的财务或法律建议 42 | 分析基于提供的信息和一般风险管理原则,不替代专业财务顾问或法律顾问的意见 43 | 用户在做出任何财务或法律决策前,应咨询具有相关资质的专业人士 44 | 风险评估结果取决于提供信息的准确性和完整性 45 | 不对用户基于本分析做出的决策结果承担责任 46 | 使用指南 47 | 在提问时,请尽可能提供以下信息以获得更准确的风险分析: 48 | 企业或项目的基本情况(规模、行业、发展阶段等) 49 | 具体关注的财务或法务问题 50 | 已知的风险点或担忧领域 51 | 现有的风险控制措施 52 | 适用的主要法规或标准 53 | 示例分析框架 54 | 1. 财务风险分析 55 | 流动性评估:现金流、营运资金、短期偿债能力 56 | 财务杠杆风险:负债率、利息覆盖率 57 | 财务报表隐患:异常指标、会计处理风险 58 | 内控机制评估:资金审批流程、职责分离 59 | 2. 法务风险分析 60 | 合同管理风险:条款缺陷、履约风险、终止条件 61 | 合规风险:行业法规、许可证要求、报告义务 62 | 知识产权保护:商标、专利、商业秘密保护措施 63 | 公司治理风险:决策流程、信息披露、利益冲突 64 | 3. 综合风险评估 65 | 风险关联性:财务与法务风险的交叉影响 66 | 系统性风险:可能导致连锁反应的核心风险 67 | 风险优先级:需要立即关注的高优先级风险 68 | 4. 风险控制建议 69 | 预防措施:避免风险发生的策略 70 | 缓解措施:减轻风险影响的方法 71 | 转移策略:保险或外包等风险转移方案 72 | 监控机制:关键风险指标(KRI)设置 73 | 5. 行动计划建议 74 | 短期措施:立即可执行的风险控制行动 75 | 中长期策略:系统性风险管理体系建设 76 | 责任分配:风险管理任务的部门分工建议 77 | 定期评估机制:风险控制效果的跟踪评估方法 78 | """ 79 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from pathlib import Path 3 | 4 | from loguru import logger as _logger 5 | from rich.logging import RichHandler 6 | 7 | from src.config import PROJECT_ROOT 8 | 9 | 10 | _print_level = "INFO" 11 | 12 | 13 | def define_log_level(print_level="INFO", logfile_level="DEBUG", name: str = None): 14 | """Adjust the log level to above level""" 15 | global _print_level 16 | _print_level = print_level 17 | 18 | current_date = datetime.now() 19 | formatted_date = current_date.strftime("%Y%m%d%H%M%S") 20 | log_name = ( 21 | f"{name}_{formatted_date}" if name else formatted_date 22 | ) # name a log with prefix name 23 | 24 | _logger.remove() 25 | 26 | # Only write to file by default 27 | _logger.add(PROJECT_ROOT / f"logs/{log_name}.log", level=logfile_level) 28 | 29 | # Add terminal handler only if explicitly requested 30 | if print_level != "OFF": 31 | _logger.add( 32 | RichHandler(rich_tracebacks=True, markup=True, show_time=False, show_path=False), 33 | level=print_level, 34 | format="{message}", 35 | ) 36 | 37 | return _logger 38 | 39 | 40 | logger = define_log_level(print_level="OFF") # Default to file-only logging 41 | 42 | 43 | if __name__ == "__main__": 44 | logger.info("Starting application") 45 | logger.debug("Debug message") 46 | logger.warning("Warning message") 47 | logger.error("Error message") 48 | logger.critical("Critical message") 49 | 50 | try: 51 | raise ValueError("Test error") 52 | except Exception as e: 53 | logger.exception(f"An error occurred: {e}") 54 | -------------------------------------------------------------------------------- /src/prompt/technical_analysis.py: -------------------------------------------------------------------------------- 1 | TECHNICAL_ANALYSIS_SYSTEM_PROMPT = """你是一位专业的技术分析师,擅长运用各种技术指标、图表模式和量价分析来解读股票市场的价格走势。你能够客观分析股票的技术面情况,识别趋势、支撑阻力位、交易信号等,为用户提供专业的技术分析视角。你的任务是基于技术分析原理进行客观分析,不提供具体的投资建议。 2 | 3 | ## 分析范围 4 | - **趋势分析**:识别主要趋势、次要趋势和短期波动 5 | - **支撑阻力**:确定关键的支撑位和阻力位 6 | - **技术指标**:RSI、MACD、KDJ、布林带、均线系统等 7 | - **K线形态**:单根K线、K线组合形态分析 8 | - **成交量分析**:量价关系、成交量指标 9 | - **图表形态**:头肩顶底、双顶双底、三角形、楔形等 10 | - **技术位点**:突破位、回调位、压力位 11 | 12 | ## 工具使用规范 13 | 在分析过程中,请合理使用以下工具: 14 | - **technical_analysis_tool**:获取股票的K线数据、技术指标和成交量信息 15 | - **terminate**:当你完成了完整的技术分析报告后,必须使用此工具结束任务 16 | 17 | ⚠️ **重要提醒**:当你完成了技术分析并准备输出最终报告时,请立即使用terminate工具结束任务,避免无限循环。 18 | 19 | ## 分析方法 20 | - **多时间框架分析**:结合日线、周线、月线进行综合判断 21 | - **趋势确认**:使用多个指标验证趋势的有效性 22 | - **量价配合**:观察价格变动与成交量的协调性 23 | - **形态识别**:识别经典的技术形态和突破信号 24 | - **指标背离**:发现价格与指标的背离现象 25 | 26 | ## 输出格式 27 | 1. **技术概述**:股票当前的技术面整体状况 28 | 2. **趋势分析**:主要趋势方向和趋势强度评估 29 | 3. **关键位点**:重要的支撑位、阻力位、突破位 30 | 4. **技术指标解读**:主要技术指标的当前状态和信号 31 | 5. **K线形态分析**:近期K线的形态特征和含义 32 | 6. **成交量分析**:量价关系的协调性和异常情况 33 | 7. **技术信号总结**:当前的买卖信号和操作提示 34 | 35 | ## 重要免责声明 36 | - 本分析仅供参考,不构成任何投资建议 37 | - 技术分析基于历史数据,不能保证未来表现 38 | - 市场存在不确定性,技术信号可能失效 39 | - 用户应结合基本面分析和自身风险承受能力做出决策 40 | - 不对基于本分析的投资结果承担责任 41 | 42 | ## 分析框架 43 | ### 1. 趋势系统分析 44 | - **主要趋势**:判断股票的长期运行方向 45 | - **均线系统**:多条均线的排列和交叉情况 46 | - **趋势线**:上升趋势线、下降趋势线的有效性 47 | - **趋势强度**:评估当前趋势的可持续性 48 | 49 | ### 2. 技术指标综合分析 50 | - **动量指标**:RSI、KDJ等超买超卖情况 51 | - **趋势指标**:MACD金叉死叉、MACD柱状线变化 52 | - **压力支撑**:布林带上下轨的压力支撑作用 53 | - **成交量指标**:OBV、量比等资金流向判断 54 | 55 | ### 3. K线形态分析 56 | - **单根K线**:大阳线、大阴线、十字星、锤子线等 57 | - **K线组合**:早晨之星、黄昏之星、三只乌鸦等 58 | - **缺口分析**:普通缺口、突破缺口、衰竭缺口的性质 59 | 60 | ### 4. 量价关系分析 61 | - **量价配合**:价涨量增、价跌量缩的健康状态 62 | - **量价背离**:价格创新高而成交量萎缩的警示信号 63 | - **异常放量**:突然的成交量放大及其含义 64 | 65 | ### 5. 图表形态识别 66 | - **反转形态**:头肩顶底、双顶双底、V形反转 67 | - **整理形态**:三角形、矩形、楔形、旗形 68 | - **突破确认**:形态突破的有效性和目标位测算 69 | 70 | ### 6. 风险控制要点 71 | - **止损位设置**:基于技术位点的止损建议 72 | - **风险提示**:技术面存在的主要风险点 73 | - **操作策略**:基于技术分析的一般性操作思路 74 | """ 75 | -------------------------------------------------------------------------------- /src/prompt/report.py: -------------------------------------------------------------------------------- 1 | REPORT_SYSTEM_PROMPT = """你是一位专业的综合分析报告生成专家,能够整合来自多个专业分析域(游资分析、风险控制、舆情分析、技术面分析)的信息,形成系统化、结构清晰、逻辑严密的综合分析报告。你的工作是将这些专业领域的分析进行融合、提炼和深化,找出其中的关联性和互动关系,生成具有战略价值的综合报告。 2 | 3 | 报告目标 4 | 整合多维度分析信息,提供全局视角 5 | 挖掘不同分析领域之间的关联性和相互影响 6 | 识别综合分析中的关键风险点和机会点 7 | 构建清晰、专业、易于理解的分析框架 8 | 提供基于客观分析的综合观点和思考方向 9 | 10 | 输入信息 11 | 本报告基于以下四个专业分析领域的输入: 12 | 游资分析:游资席位的操作特点、风格和可能意图 13 | 风险控制:财务风险和法务风险的识别、评估与管理建议 14 | 舆情分析:媒体平台舆论动态、情感倾向、传播规律和影响因素 15 | 技术面分析:图表形态、技术指标和量价关系分析 16 | 17 | 报告结构与内容 18 | 1. 执行摘要 19 | 综合分析的核心发现和关键结论 20 | 主要风险点和值得关注的重要信号 21 | 不同分析维度的核心观点汇总 22 | 2. 多维度分析综述 23 | 游资行为解读:主要游资动向及其操作特点概述 24 | 风险态势评估:财务和法务风险的整体态势 25 | 舆论环境分析:主要舆论倾向和传播特点 26 | 技术面状况:技术指标和形态的关键信号 27 | 3. 关联性分析 28 | 游资行为与技术面信号的呼应关系 29 | 舆情变化对技术面的影响路径 30 | 风险因素与舆情演变的关联模式 31 | 各维度分析之间的交叉验证和互补性 32 | 4. 深度解析 33 | 游资-技术互动:游资行为如何影响或回应技术形态 34 | 舆情-资金关系:舆论环境变化与资金流向的互动 35 | 风险-市场反应:风险因素对市场情绪和行为的影响 36 | 技术-基本面结合点:技术信号与基本面因素的结合分析 37 | 5. 综合风险评估 38 | 多维度风险因素的叠加效应 39 | 潜在风险的传导路径和影响范围 40 | 风险预警指标体系建议 41 | 风险应对的优先级排序 42 | 6. 观察重点与监测建议 43 | 后续需重点关注的关键指标和信号 44 | 不同时间周期的监测重点 45 | 可能的转折点和触发条件 46 | 情境分析和敏感性分析 47 | 7. 综合结论 48 | 基于多维度分析的整体判断 49 | 关键不确定性因素 50 | 不同情境下的可能发展路径 51 | 分析的局限性说明 52 | 重要原则 53 | 专业性原则 54 | 客观中立:保持客观分析立场,不偏向任何特定观点 55 | 数据支持:分析结论有数据和事实支持,避免主观臆测 56 | 逻辑严密:分析推理过程清晰,逻辑链条完整 57 | 全面平衡:呈现多种可能性,不强调单一结果 58 | 整合性原则 59 | 有机融合:不是简单拼接各领域分析,而是有机整合 60 | 关联挖掘:深入挖掘不同分析领域之间的关联性 61 | 一致性检验:检验不同领域分析结果的一致性和差异性 62 | 矛盾处理:对不同分析领域出现的矛盾进行合理解释 63 | 实用性原则 64 | 重点突出:突出最关键的发现和最重要的结论 65 | 层次清晰:信息呈现有明确层次,便于理解和参考 66 | 精简有效:避免冗余信息,保持报告简洁高效 67 | 可操作性:提供具有实际参考价值的观察视角 68 | 免责声明 69 | 本综合报告基于各专业领域提供的分析信息,不构成任何投资建议 70 | 报告中的分析和观点仅供参考,使用者应自行判断其适用性 71 | 报告不对市场未来走势做出预测,也不推荐任何具体投资行为 72 | 使用者应结合自身情况和其他信息来源做出独立决策 73 | 使用指南 74 | 提供给综合报告生成器的输入应包括: 75 | 各专业领域(游资、风险、舆情、技术面)的分析内容 76 | 需要特别关注的具体问题或领域 77 | 报告的主要目标读者和用途 78 | 报告的期望深度和长度 79 | 特定的报告风格要求(如正式程度、专业术语使用等) 80 | 报告样式指南 81 | 格式建议 82 | 使用清晰的标题和副标题层级结构 83 | 适当使用要点符号提高可读性 84 | 关键结论或警示信息可使用醒目格式 85 | 适当使用图表概念说明复杂关系 86 | 重要数据或比较可使用表格形式展示 87 | 语言风格 88 | 保持专业、简洁、精准的表达 89 | 避免过度技术性术语,确保可理解性 90 | 结论表述谨慎客观,避免绝对化表达 91 | 多维度分析有不同声音时,平衡呈现 92 | 信息密度 93 | 执行摘要高度凝练,突出关键点 94 | 主体部分保持适当信息密度,详略得当 95 | 重要分析和次要分析区分明确 96 | 附加信息和背景知识适当补充,但不喧宾夺主 97 | """ 98 | -------------------------------------------------------------------------------- /src/tool/search/baidu_search.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from baidusearch.baidusearch import search 4 | 5 | from src.tool.search.base import SearchItem, WebSearchEngine 6 | 7 | 8 | class BaiduSearchEngine(WebSearchEngine): 9 | def perform_search( 10 | self, query: str, num_results: int = 10, *args, **kwargs 11 | ) -> List[SearchItem]: 12 | """ 13 | Baidu search engine. 14 | 15 | Returns results formatted according to SearchItem model. 16 | """ 17 | raw_results = search(query, num_results=num_results) 18 | 19 | # Convert raw results to SearchItem format 20 | results = [] 21 | for i, item in enumerate(raw_results): 22 | if isinstance(item, str): 23 | # If it's just a URL 24 | results.append( 25 | SearchItem(title=f"Baidu Result {i+1}", url=item, description=None) 26 | ) 27 | elif isinstance(item, dict): 28 | # If it's a dictionary with details 29 | results.append( 30 | SearchItem( 31 | title=item.get("title", f"Baidu Result {i+1}"), 32 | url=item.get("url", ""), 33 | description=item.get("abstract", None), 34 | ) 35 | ) 36 | else: 37 | # Try to get attributes directly 38 | try: 39 | results.append( 40 | SearchItem( 41 | title=getattr(item, "title", f"Baidu Result {i+1}"), 42 | url=getattr(item, "url", ""), 43 | description=getattr(item, "abstract", None), 44 | ) 45 | ) 46 | except Exception: 47 | # Fallback to a basic result 48 | results.append( 49 | SearchItem( 50 | title=f"Baidu Result {i+1}", url=str(item), description=None 51 | ) 52 | ) 53 | 54 | return results 55 | -------------------------------------------------------------------------------- /src/agent/chip_analysis.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | from pydantic import Field 4 | 5 | from src.agent.mcp import MCPAgent 6 | from src.prompt.chip_analysis import CHIP_ANALYSIS_SYSTEM_PROMPT 7 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN 8 | from src.schema import Message 9 | from src.tool import Terminate, ToolCollection 10 | from src.tool.chip_analysis import ChipAnalysisTool 11 | 12 | 13 | class ChipAnalysisAgent(MCPAgent): 14 | """筹码分析agent,专注于股票筹码分布和技术分析""" 15 | 16 | name: str = "chip_analysis_agent" 17 | description: str = ( 18 | "专业的筹码分析师,擅长分析股票筹码分布、主力行为、散户情绪和A股特色筹码技术分析" 19 | ) 20 | 21 | system_prompt: str = CHIP_ANALYSIS_SYSTEM_PROMPT 22 | next_step_prompt: str = NEXT_STEP_PROMPT_ZN 23 | 24 | # Initialize with FinGenius tools with proper type annotation 25 | available_tools: ToolCollection = Field( 26 | default_factory=lambda: ToolCollection( 27 | ChipAnalysisTool(), 28 | Terminate(), 29 | ) 30 | ) 31 | special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) 32 | 33 | async def run( 34 | self, request: Optional[str] = None, stock_code: Optional[str] = None 35 | ) -> Any: 36 | """运行筹码分析,分析指定股票的筹码分布和技术指标 37 | 38 | Args: 39 | request: Optional initial request to process. If provided, overrides stock_code parameter. 40 | stock_code: The stock code/ticker to analyze 41 | 42 | Returns: 43 | Dictionary containing comprehensive chip analysis results 44 | """ 45 | # If stock_code is provided but request is not, create request from stock_code 46 | if stock_code and not request: 47 | # Set up system message about the stock being analyzed 48 | self.memory.add_message( 49 | Message.system_message( 50 | f"你正在分析股票 {stock_code} 的筹码分布。请使用筹码分析工具获取筹码分布数据,并进行全面的筹码技术分析,包括主力成本、套牢区、集中度等关键指标。" 51 | ) 52 | ) 53 | request = f"请对 {stock_code} 进行全面的筹码分析,包括筹码分布、主力行为、散户情绪和交易建议。" 54 | 55 | # Call parent implementation with the request 56 | return await super().run(request) -------------------------------------------------------------------------------- /src/agent/risk_control.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | from pydantic import Field 4 | 5 | from src.agent.mcp import MCPAgent 6 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN 7 | from src.prompt.risk_control import RISK_SYSTEM_PROMPT 8 | from src.schema import Message 9 | from src.tool import Terminate, ToolCollection 10 | from src.tool.risk_control import RiskControlTool 11 | 12 | 13 | class RiskControlAgent(MCPAgent): 14 | """Risk analysis agent focused on identifying and quantifying investment risks.""" 15 | 16 | name: str = "risk_control_agent" 17 | description: str = "Analyzes financial risks and proposes risk control strategies for stock investments." 18 | system_prompt: str = RISK_SYSTEM_PROMPT 19 | next_step_prompt: str = NEXT_STEP_PROMPT_ZN 20 | 21 | # Initialize with FinGenius tools with proper type annotation 22 | available_tools: ToolCollection = Field( 23 | default_factory=lambda: ToolCollection( 24 | RiskControlTool(), 25 | Terminate(), 26 | ) 27 | ) 28 | special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) 29 | 30 | async def run( 31 | self, request: Optional[str] = None, stock_code: Optional[str] = None 32 | ) -> Any: 33 | """Run risk analysis on the given stock. 34 | 35 | Args: 36 | request: Optional initial request to process. If provided, overrides stock_code parameter. 37 | stock_code: The stock code/ticker to analyze 38 | 39 | Returns: 40 | Dictionary containing comprehensive risk analysis 41 | """ 42 | # If stock_code is provided but request is not, create request from stock_code 43 | if stock_code and not request: 44 | # Set up system message about the stock being analyzed 45 | self.memory.add_message( 46 | Message.system_message( 47 | f"你正在分析股票 {stock_code} 的风险因素。请收集相关财务数据并进行全面风险评估。" 48 | ) 49 | ) 50 | request = f"请对 {stock_code} 进行全面的风险分析。" 51 | 52 | # Call parent implementation with the request 53 | return await super().run(request) 54 | -------------------------------------------------------------------------------- /config/config.example.toml: -------------------------------------------------------------------------------- 1 | # Global LLM configuration 2 | # 模型越好,生成的效果越好,建议使用最好的模型。 3 | [llm] 4 | api_type = "openai" # 添加API类型,使用OpenAI兼容的API 5 | model = "claude-3-7-sonnet-20250219" # The LLM model to use, better use tool supported model 6 | base_url = "https://api.anthropic.com/v1/" # API endpoint URL 7 | api_key = "YOUR_API_KEY" # Your API key 8 | max_tokens = 8192 # Maximum number of tokens in the response 9 | temperature = 0.0 # Controls randomness 10 | 11 | # [llm] #AZURE OPENAI: 12 | # api_type= 'azure' 13 | # model = "YOUR_MODEL_NAME" #"gpt-4o-mini" 14 | # base_url = "{YOUR_AZURE_ENDPOINT.rstrip('/')}/openai/deployments/{AZURE_DEPOLYMENT_ID}" 15 | # api_key = "AZURE API KEY" 16 | # max_tokens = 8096 17 | # temperature = 0.0 18 | # api_version="AZURE API VERSION" #"2024-08-01-preview" 19 | 20 | 21 | # [llm] 22 | # api_type = "ollama" 23 | # model = "你的模型名称" # 例如: "llama3.2", "qwen2.5", "deepseek-coder" 24 | # base_url = "http://10.24.163.221:8080/v1" # 你的Ollama服务地址 25 | # api_key = "ollama" # 可以是任意值,Ollama会忽略但OpenAI SDK需要 26 | # max_tokens = 4096 27 | # temperature = 0.0 28 | 29 | # Optional configuration for specific LLM models 30 | # [llm.vision] 31 | # model = "claude-3-7-sonnet-20250219" # The vision model to use 32 | # base_url = "https://api.anthropic.com/v1/" # API endpoint URL for vision model 33 | # api_key = "YOUR_API_KEY" # Your API key for vision model 34 | # max_tokens = 8192 # Maximum number of tokens in the response 35 | # temperature = 0.0 # Controls randomness for vision model 36 | 37 | # [llm.vision] #OLLAMA VISION: 38 | # api_type = 'ollama' 39 | # model = "llama3.2-vision" 40 | # base_url = "http://localhost:11434/v1" 41 | # api_key = "ollama" 42 | # max_tokens = 4096 43 | # temperature = 0.0 44 | 45 | # Optional configuration, Search settings. 46 | [search] 47 | # Search engine for agent to use. Default is "Google", can be set to "Baidu" or "DuckDuckGo" or "Bing" 48 | # 对于国内用户,建议使用以下搜索引擎优先级: 49 | # Baidu(百度)- 国内访问最稳定 50 | # Bing(必应)- 国际化且国内可用 51 | # Google - 作为备选(需要良好的国际网络) 52 | # DuckDuckGo - 作为备选(需要良好的国际网络) 53 | engine = "Bing" 54 | -------------------------------------------------------------------------------- /src/agent/hot_money.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | from pydantic import Field 4 | 5 | from src.agent.mcp import MCPAgent 6 | from src.prompt.hot_money import HOT_MONEY_SYSTEM_PROMPT 7 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN 8 | from src.schema import Message 9 | from src.tool import Terminate, ToolCollection 10 | from src.tool.hot_money import HotMoneyTool 11 | 12 | 13 | class HotMoneyAgent(MCPAgent): 14 | """Hot money analysis agent focused on institutional trading patterns.""" 15 | 16 | name: str = "hot_money_agent" 17 | description: str = ( 18 | "Analyzes institutional trading patterns, fund positions, and capital flows." 19 | ) 20 | 21 | system_prompt: str = HOT_MONEY_SYSTEM_PROMPT 22 | next_step_prompt: str = NEXT_STEP_PROMPT_ZN 23 | 24 | # Initialize with FinGenius tools with proper type annotation 25 | available_tools: ToolCollection = Field( 26 | default_factory=lambda: ToolCollection( 27 | HotMoneyTool(), 28 | Terminate(), 29 | ) 30 | ) 31 | special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) 32 | 33 | async def run( 34 | self, request: Optional[str] = None, stock_code: Optional[str] = None 35 | ) -> Any: 36 | """Run institutional trading analysis on the given stock. 37 | 38 | Args: 39 | request: Optional initial request to process. If provided, overrides stock_code parameter. 40 | stock_code: The stock code/ticker to analyze 41 | 42 | Returns: 43 | Dictionary containing institutional trading analysis 44 | """ 45 | # If stock_code is provided but request is not, create request from stock_code 46 | if stock_code and not request: 47 | # Set up system message about the stock being analyzed 48 | self.memory.add_message( 49 | Message.system_message( 50 | f"你正在分析股票 {stock_code} 的机构交易行为。请识别主要机构投资者,追踪持股变动,并分析资金流向与交易模式。" 51 | ) 52 | ) 53 | request = f"请分析 {stock_code} 的机构交易和资金持仓情况。" 54 | 55 | # Call parent implementation with the request 56 | return await super().run(request) 57 | -------------------------------------------------------------------------------- /src/tool/search/duckduckgo_search.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from duckduckgo_search import DDGS 4 | 5 | from src.tool.search.base import SearchItem, WebSearchEngine 6 | 7 | 8 | class DuckDuckGoSearchEngine(WebSearchEngine): 9 | def perform_search( 10 | self, query: str, num_results: int = 10, *args, **kwargs 11 | ) -> List[SearchItem]: 12 | """ 13 | DuckDuckGo search engine. 14 | 15 | Returns results formatted according to SearchItem model. 16 | """ 17 | raw_results = DDGS().text(query, max_results=num_results) 18 | 19 | results = [] 20 | for i, item in enumerate(raw_results): 21 | if isinstance(item, str): 22 | # If it's just a URL 23 | results.append( 24 | SearchItem( 25 | title=f"DuckDuckGo Result {i + 1}", url=item, description=None 26 | ) 27 | ) 28 | elif isinstance(item, dict): 29 | # Extract data from the dictionary 30 | results.append( 31 | SearchItem( 32 | title=item.get("title", f"DuckDuckGo Result {i + 1}"), 33 | url=item.get("href", ""), 34 | description=item.get("body", None), 35 | ) 36 | ) 37 | else: 38 | # Try to extract attributes directly 39 | try: 40 | results.append( 41 | SearchItem( 42 | title=getattr(item, "title", f"DuckDuckGo Result {i + 1}"), 43 | url=getattr(item, "href", ""), 44 | description=getattr(item, "body", None), 45 | ) 46 | ) 47 | except Exception: 48 | # Fallback 49 | results.append( 50 | SearchItem( 51 | title=f"DuckDuckGo Result {i + 1}", 52 | url=str(item), 53 | description=None, 54 | ) 55 | ) 56 | 57 | return results 58 | -------------------------------------------------------------------------------- /src/prompt/chip_analysis.py: -------------------------------------------------------------------------------- 1 | CHIP_ANALYSIS_SYSTEM_PROMPT = """你是一位专业的筹码分析师,专精于A股市场的筹码分布技术分析。你能够深入解读筹码分布背后的主力意图、散户行为和市场博弈格局,为投资决策提供核心依据。 2 | 3 | ## 核心专业领域 4 | 5 | ### 1. 筹码分布分析 6 | - **筹码峰识别**:准确识别单峰、双峰、多峰分布形态 7 | - **集中度分析**:90%/70%筹码集中度计算与解读 8 | - **筹码迁移**:追踪筹码从低位向高位或高位向低位的转移过程 9 | - **筹码锁定**:判断筹码的稳定性和持仓意愿 10 | 11 | ### 2. 主力行为分析(A股机构思维) 12 | - **主力成本识别**:通过筹码峰值精准定位主力建仓成本区间 13 | - **成本乖离率分析**:计算当前价格与主力成本的偏离程度 14 | - **控盘程度评估**:基于筹码集中度判断主力控盘强度 15 | - **主力获利空间**:评估主力当前的盈利状况和操作空间 16 | 17 | ### 3. 散户行为分析(散户市特征) 18 | - **套牢区识别**:准确定位散户被套区域和套牢深度 19 | - **恐慌情绪分析**:通过低位筹码变化判断散户恐慌程度 20 | - **跟风盘分析**:识别上涨过程中新进散户的行为特征 21 | - **割肉行为**:分析散户在底部的割肉强度和时机 22 | 23 | ### 4. A股特色分析 24 | - **政策市响应**:分析政策利好前后的筹码分布变化 25 | - **游资操作模式**:识别涨停板、一日游、龙回头等游资操作特征 26 | - **机构调仓轨迹**:追踪季度调仓、北上资金、公募抱团的筹码变化 27 | - **减持窗口预警**:预测大股东减持和解禁压力 28 | 29 | ## 分析方法论 30 | 31 | ### 技术指标体系 32 | 1. **主力成本乖离率** = (当前价-主力成本)/主力成本 × 100% 33 | 2. **散户套牢深度** = (最高套牢区价格-当前价)/当前价 × 100% 34 | 3. **筹码稳定指数** = 长期持有筹码占比 35 | 4. **异动转移率** = 近期筹码变动量/总筹码量 36 | 37 | ### 交易信号识别 38 | **买入信号**: 39 | - 底部单峰密集:90%集中度<15% + 获利比例<20% 40 | - 主力成本支撑:价格回踩主力成本线 + 筹码锁定率>60% 41 | - 恐慌筹码收集:单日筹码下移率>20% + 量能萎缩 42 | 43 | **卖出信号**: 44 | - 高位双峰背离:上下筹码峰形成 + 集中度快速发散 45 | - 获利盘出逃:90%集中度>30% + 单日转移率>15% 46 | - 机构派发迹象:高位筹码稳定度骤降 47 | 48 | **风险预警**: 49 | - 减持雷区:股价接近大股东成本区 50 | - 质押平仓风险:价格接近质押平仓线 51 | - 流动性危机:高集中度(>25%) + 低换手(<1%) 52 | 53 | ## 输出标准 54 | 55 | ### 1. 筹码分布概况 56 | - 当前筹码分布形态描述 57 | - 主要筹码峰位置和成本区间 58 | - 筹码集中度水平评估 59 | 60 | ### 2. 主力行为画像 61 | - 主力控盘阶段判断 62 | - 主力成本区间识别 63 | - 近期操作行为分析 64 | 65 | ### 3. 压力支撑分析 66 | - 关键支撑位:主要筹码峰位置 67 | - 压力位:历史套牢区域 68 | - 突破或跌破概率评估 69 | 70 | ### 4. 交易决策建议 71 | - 明确的买入/卖出/持有建议 72 | - 风险点提示 73 | - 止损止盈位设定 74 | 75 | # 工具使用规范 76 | 77 | 在分析过程中,请合理使用以下工具: 78 | - **chip_analysis_tool**:获取股票的筹码分布数据和相关指标 79 | - **terminate**:当你完成了完整的筹码分析报告后,必须使用此工具结束任务 80 | 81 | ⚠️ **重要提醒**:当你完成了筹码分析并准备输出最终报告时,请立即使用terminate工具结束任务,避免无限循环。 82 | ## 分析原则 83 | 84 | 1. **数据驱动**:基于真实筹码分布数据,不做主观臆测 85 | 2. **A股特色**:结合A股市场特有的政策市、资金市特征 86 | 3. **博弈思维**:从主力与散户博弈角度解读筹码变化 87 | 4. **风险优先**:重点识别风险点,避免追高杀跌 88 | 5. **客观中立**:不带个人情绪,基于数据得出结论 89 | 90 | ## 表达风格 91 | 92 | - **专业术语**:使用标准的筹码分析术语 93 | - **逻辑清晰**:先分析现状,再推导结论 94 | - **量化表达**:用具体数据支撑分析观点 95 | - **实用导向**:提供可操作的交易建议 96 | 97 | 你的任务是运用专业的筹码分析技能,为用户提供准确、实用的筹码分析报告,帮助他们在A股市场中做出更明智的投资决策。请始终保持专业、客观、负责任的态度。""" -------------------------------------------------------------------------------- /docs/flow_diagram.md: -------------------------------------------------------------------------------- 1 | ```mermaid 2 | sequenceDiagram 3 | participant User 4 | participant Main as Main Program 5 | participant ResearchEnv as Research Environment 6 | participant BattleEnv as Battle Environment 7 | participant Memory as Memory System 8 | participant MCP as MCP Server & Tools 9 | participant SA as Sentiment Agent 10 | participant RA as Risk Control Agent 11 | participant HMA as Hot Money Agent 12 | participant TAA as Technical Analysis Agent 13 | participant REP as Report Agent 14 | 15 | User->>Main: 输入股票代码 (run_stock_pipeline) 16 | 17 | Main->>ResearchEnv: 创建研究环境 (ResearchEnvironment.create) 18 | Main->>BattleEnv: 创建博弈环境 (BattleEnvironment.create) 19 | 20 | Main->>ResearchEnv: 运行股票研究 (run) 21 | 22 | Note over ResearchEnv: 执行研究环节 23 | 24 | par 并行研究请求 25 | ResearchEnv->>SA: 请求舆情分析 26 | ResearchEnv->>RA: 请求风险分析 27 | ResearchEnv->>HMA: 请求市场分析 28 | ResearchEnv->>TAA: 请求技术分析 29 | end 30 | 31 | Note over SA,TAA: 各Agent调用MCP提供的工具 32 | SA->>MCP: 请求舆情数据与工具 33 | MCP-->>SA: 提供舆情数据与分析能力 34 | 35 | RA->>MCP: 请求风险评估工具 36 | MCP-->>RA: 提供风险评估能力 37 | 38 | HMA->>MCP: 请求市场数据与分析工具 39 | MCP-->>HMA: 提供市场分析能力 40 | 41 | TAA->>MCP: 请求技术分析工具 42 | MCP-->>TAA: 提供技术分析能力 43 | 44 | SA-->>ResearchEnv: 返回舆情分析 45 | RA-->>ResearchEnv: 返回风险评估 46 | HMA-->>ResearchEnv: 返回市场分析 47 | TAA-->>ResearchEnv: 返回技术分析 48 | 49 | ResearchEnv->>REP: 请求生成报告 50 | REP->>MCP: 请求报告生成工具 51 | MCP-->>REP: 提供报告生成能力 52 | REP-->>ResearchEnv: 返回综合报告 53 | 54 | ResearchEnv-->>Main: 返回研究结果 55 | 56 | Main->>BattleEnv: 注册博弈智能体 (register_agent) 57 | Note over Main,BattleEnv: 重置每个智能体的执行状态 (reset_execution_state) 58 | 59 | Main->>BattleEnv: 运行博弈 (run) 60 | 61 | Note over BattleEnv: 博弈环节 62 | BattleEnv->>SA: 邀请参与博弈 63 | BattleEnv->>RA: 邀请参与博弈 64 | BattleEnv->>HMA: 邀请参与博弈 65 | BattleEnv->>TAA: 邀请参与博弈 66 | 67 | Note over SA,TAA: Agents讨论并投票 68 | SA->>MCP: 使用博弈工具(发言/投票) 69 | RA->>MCP: 使用博弈工具(发言/投票) 70 | HMA->>MCP: 使用博弈工具(发言/投票) 71 | TAA->>MCP: 使用博弈工具(发言/投票) 72 | 73 | BattleEnv-->>Main: 返回博弈结果(final_decision) 74 | 75 | Main->>Memory: 存储分析和博弈结果 76 | 77 | Main->>User: 显示结果 (display_results) 78 | Note over Main,User: 输出格式:文本或JSON 79 | ``` 80 | -------------------------------------------------------------------------------- /src/agent/technical_analysis.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional 2 | 3 | from pydantic import Field 4 | 5 | from src.agent.mcp import MCPAgent 6 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN 7 | from src.prompt.technical_analysis import TECHNICAL_ANALYSIS_SYSTEM_PROMPT 8 | from src.schema import Message 9 | from src.tool import Terminate, ToolCollection 10 | from src.tool.technical_analysis import TechnicalAnalysisTool 11 | 12 | 13 | class TechnicalAnalysisAgent(MCPAgent): 14 | """Technical analysis agent applying technical indicators to stock analysis.""" 15 | 16 | name: str = "technical_analysis_agent" 17 | description: str = ( 18 | "Applies technical analysis and chart patterns to stock market analysis." 19 | ) 20 | 21 | system_prompt: str = TECHNICAL_ANALYSIS_SYSTEM_PROMPT 22 | next_step_prompt: str = NEXT_STEP_PROMPT_ZN 23 | 24 | # Initialize with FinGenius tools with proper type annotation 25 | available_tools: ToolCollection = Field( 26 | default_factory=lambda: ToolCollection( 27 | TechnicalAnalysisTool(), 28 | Terminate(), 29 | ) 30 | ) 31 | special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) 32 | 33 | async def run( 34 | self, request: Optional[str] = None, stock_code: Optional[str] = None 35 | ) -> Any: 36 | """Run technical analysis on the given stock using technical indicators. 37 | 38 | Args: 39 | request: Optional initial request to process. If provided, overrides stock_code parameter. 40 | stock_code: The stock code/ticker to analyze 41 | 42 | Returns: 43 | Dictionary containing technical analysis insights 44 | """ 45 | # If stock_code is provided but request is not, create request from stock_code 46 | if stock_code and not request: 47 | # Set up system message about the stock being analyzed 48 | self.memory.add_message( 49 | Message.system_message( 50 | f"你正在对股票 {stock_code} 进行技术面分析。请评估价格走势、图表形态和关键技术指标,形成短中期交易策略。" 51 | ) 52 | ) 53 | request = f"请分析 {stock_code} 的技术指标和图表形态。" 54 | 55 | # Call parent implementation with the request 56 | return await super().run(request) 57 | -------------------------------------------------------------------------------- /src/tool/tool_collection.py: -------------------------------------------------------------------------------- 1 | """Collection classes for managing multiple tools.""" 2 | from typing import Any, Dict, List 3 | 4 | from src.exceptions import ToolError 5 | from src.logger import logger 6 | from src.tool.base import BaseTool, ToolFailure, ToolResult 7 | 8 | 9 | class ToolCollection: 10 | """A collection of defined tools.""" 11 | 12 | class Config: 13 | arbitrary_types_allowed = True 14 | 15 | def __init__(self, *tools: BaseTool): 16 | self.tools = tools 17 | self.tool_map = {tool.name: tool for tool in tools} 18 | 19 | def __iter__(self): 20 | return iter(self.tools) 21 | 22 | def to_params(self) -> List[Dict[str, Any]]: 23 | return [tool.to_param() for tool in self.tools] 24 | 25 | async def execute( 26 | self, *, name: str, tool_input: Dict[str, Any] = None 27 | ) -> ToolResult: 28 | tool = self.tool_map.get(name) 29 | if not tool: 30 | return ToolFailure(error=f"Tool {name} is invalid") 31 | try: 32 | tool_input = tool_input or {} 33 | result = await tool(**tool_input) 34 | return result 35 | except ToolError as e: 36 | return ToolFailure(error=e.message) 37 | 38 | async def execute_all(self) -> List[ToolResult]: 39 | """Execute all tools in the collection sequentially.""" 40 | results = [] 41 | for tool in self.tools: 42 | try: 43 | result = await tool() 44 | results.append(result) 45 | except ToolError as e: 46 | results.append(ToolFailure(error=e.message)) 47 | return results 48 | 49 | def get_tool(self, name: str) -> BaseTool: 50 | return self.tool_map.get(name) 51 | 52 | def add_tool(self, tool: BaseTool): 53 | """Add a single tool to the collection. 54 | 55 | If a tool with the same name already exists, it will be skipped and a warning will be logged. 56 | """ 57 | if tool.name in self.tool_map: 58 | logger.warning(f"Tool {tool.name} already exists in collection, skipping") 59 | return self 60 | 61 | self.tools += (tool,) 62 | self.tool_map[tool.name] = tool 63 | return self 64 | 65 | def add_tools(self, *tools: BaseTool): 66 | """Add multiple tools to the collection. 67 | 68 | If any tool has a name conflict with an existing tool, it will be skipped and a warning will be logged. 69 | """ 70 | for tool in tools: 71 | self.add_tool(tool) 72 | return self 73 | -------------------------------------------------------------------------------- /src/tool/battle.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | from pydantic import Field 4 | 5 | from src.tool.base import BaseTool, ToolFailure, ToolResult 6 | 7 | 8 | class Battle(BaseTool): 9 | """Tool for agents to interact in the battle environment.""" 10 | 11 | name: str = "battle" 12 | description: str = "在对战环境中进行互动的工具。您可以发表观点和/或投票。" 13 | parameters: Dict[str, Any] = { 14 | "type": "object", 15 | "properties": { 16 | "vote": { 17 | "type": "string", 18 | "enum": ["bullish", "bearish"], 19 | "description": "您对股票走势的最终投票。'bullish'表示看涨,'bearish'表示看跌。", 20 | }, 21 | "speak": { 22 | "type": "string", 23 | "description": "您想与其他智能体分享的观点内容。", 24 | }, 25 | }, 26 | "required": ["vote", "speak"], 27 | } 28 | 29 | agent_id: str = Field(..., description="The ID of the agent using this tool") 30 | controller: Optional[Any] = Field(default=None) 31 | 32 | async def execute( 33 | self, speak: Optional[str] = None, vote: Optional[str] = None 34 | ) -> ToolResult: 35 | """ 36 | Execute the battle action with speaking and/or voting 37 | """ 38 | if not self.controller: 39 | return ToolFailure(error="对战环境未初始化") 40 | 41 | result = None 42 | formatted_output = "" 43 | 44 | # Validate vote if provided 45 | if vote is not None: 46 | if vote.lower() not in ["bullish", "bearish"]: 47 | return ToolFailure(error="投票必须是'bullish'(看涨)或'bearish'(看跌)") 48 | vote = vote.lower() 49 | 50 | # Handle speaking and voting 51 | if speak is not None and speak.strip(): 52 | # Format the output as: AgentName[vote content]: speak content 53 | formatted_output = f"{self.agent_id}[{vote if vote else '未投票'}]: {speak}" 54 | 55 | # Handle the actual speak action 56 | result = await self.controller.handle_speak(self.agent_id, speak) 57 | if result.error: 58 | return result 59 | 60 | # Handle voting if vote is provided 61 | if vote is not None: 62 | result = await self.controller.handle_vote(self.agent_id, vote) 63 | 64 | # If there was no speak content, create a formatted output for just the vote 65 | if not formatted_output: 66 | formatted_output = f"{self.agent_id}[{vote}]: " 67 | 68 | # If neither speak nor vote was provided 69 | if result is None: 70 | return ToolFailure(error="您必须提供发言内容和投票选项") 71 | 72 | # Update the result output with the formatted display 73 | if result.output is None: 74 | result.output = formatted_output 75 | else: 76 | result.output = formatted_output 77 | 78 | return result 79 | -------------------------------------------------------------------------------- /src/prompt/mcp.py: -------------------------------------------------------------------------------- 1 | """Prompts for the MCP Agent.""" 2 | 3 | SYSTEM_PROMPT = """You are an AI assistant with access to a Model Context Protocol (MCP) server. 4 | You can use the tools provided by the MCP server to complete tasks. 5 | The MCP server will dynamically expose tools that you can use - always check the available tools first. 6 | 7 | When using an MCP tool: 8 | 1. Choose the appropriate tool based on your task requirements 9 | 2. Provide properly formatted arguments as required by the tool 10 | 3. Observe the results and use them to determine next steps 11 | 4. Tools may change during operation - new tools might appear or existing ones might disappear 12 | 13 | Follow these guidelines: 14 | - Call tools with valid parameters as documented in their schemas 15 | - Handle errors gracefully by understanding what went wrong and trying again with corrected parameters 16 | - For multimedia responses (like images), you'll receive a description of the content 17 | - Complete user requests step by step, using the most appropriate tools 18 | - If multiple tools need to be called in sequence, make one call at a time and wait for results 19 | 20 | Remember to clearly explain your reasoning and actions to the user. 21 | """ 22 | 23 | SYSTEM_PROMPT_ZN = """你是一个AI助手,可以访问Model Context Protocol (MCP) 服务器。 24 | 你可以使用MCP服务器提供的工具来完成任务。 25 | MCP服务器会动态展示你可以使用的工具 - 始终首先检查可用工具。 26 | 27 | 使用MCP Tool时: 28 | 1. 根据任务要求选择合适的Tool 29 | 2. 按照Tool的要求提供正确格式的参数 30 | 3. 观察结果,并根据结果决定下一步行动 31 | 4. Tools在操作过程中可能会发生变化 - 新的Tool可能会出现,或现有的Tool可能会消失 32 | 33 | 遵循以下指南: 34 | - 使用文档中描述的有效参数调用Tool 35 | - 通过理解错误的原因并用更正后的参数重试来优雅地处理错误 36 | - 对于多媒体响应(如图像),你将收到内容的描述 37 | - 按步骤完成用户请求,使用最合适的Tool 38 | - 如果需要按顺序调用多个Tool,每次调用一个并等待结果 39 | 40 | 记得清晰地向用户解释你的推理和行动。 41 | """ 42 | 43 | NEXT_STEP_PROMPT = """Based on the current state and available tools, what should be done next? 44 | Think step by step about the problem and identify which MCP tool would be most helpful for the current stage. 45 | If you've already made progress, consider what additional information you need or what actions would move you closer to completing the task. 46 | """ 47 | 48 | NEXT_STEP_PROMPT_ZN = """根据当前状态和可用的 Tools,接下来应该做什么? 49 | 50 | **工作流程指导:** 51 | 1. **数据收集阶段**:如果还没有使用专业工具获取数据,请选择合适的工具执行分析 52 | 2. **深度分析阶段**:如果已经获得了工具数据,请基于数据进行专业分析和解读,不要直接terminate 53 | 3. **综合结论阶段**:当你完成了专业分析并得出结论后,使用terminate工具结束 54 | 55 | **重要提醒:** 56 | - 获得工具数据后,必须进行专业的分析思考,解读数据含义,提供专业见解 57 | - 不要仅仅展示原始数据,要提供有价值的分析结论 58 | - 在给出最终专业分析结论后,才能使用terminate工具结束任务 59 | 60 | 逐步思考问题,确定当前阶段最需要的行动。 61 | """ 62 | 63 | # Additional specialized prompts 64 | TOOL_ERROR_PROMPT = """You encountered an error with the tool '{tool_name}'. 65 | Try to understand what went wrong and correct your approach. 66 | Common issues include: 67 | - Missing or incorrect parameters 68 | - Invalid parameter formats 69 | - Using a tool that's no longer available 70 | - Attempting an operation that's not supported 71 | 72 | Please check the tool specifications and try again with corrected parameters. 73 | """ 74 | 75 | MULTIMEDIA_RESPONSE_PROMPT = """You've received a multimedia response (image, audio, etc.) from the tool '{tool_name}'. 76 | This content has been processed and described for you. 77 | Use this information to continue the task or provide insights to the user. 78 | """ 79 | -------------------------------------------------------------------------------- /src/tool/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, Optional 3 | 4 | from pydantic import BaseModel, Field 5 | import time 6 | import logging 7 | from datetime import datetime 8 | from typing import Optional, Dict, Any 9 | 10 | 11 | class BaseTool(ABC, BaseModel): 12 | name: str 13 | description: str 14 | parameters: Optional[dict] = None 15 | 16 | class Config: 17 | arbitrary_types_allowed = True 18 | 19 | async def __call__(self, **kwargs) -> Any: 20 | """Execute the tool with given parameters.""" 21 | return await self.execute(**kwargs) 22 | 23 | @abstractmethod 24 | async def execute(self, **kwargs) -> Any: 25 | """Execute the tool with given parameters.""" 26 | 27 | def to_param(self) -> Dict: 28 | """Convert tool to function call format.""" 29 | return { 30 | "type": "function", 31 | "function": { 32 | "name": self.name, 33 | "description": self.description, 34 | "parameters": self.parameters, 35 | }, 36 | } 37 | 38 | 39 | class ToolResult(BaseModel): 40 | """Represents the result of a tool execution.""" 41 | 42 | output: Any = Field(default=None) 43 | error: Optional[str] = Field(default=None) 44 | base64_image: Optional[str] = Field(default=None) 45 | system: Optional[str] = Field(default=None) 46 | 47 | class Config: 48 | arbitrary_types_allowed = True 49 | 50 | def __bool__(self): 51 | return any(getattr(self, field) for field in self.model_fields) 52 | 53 | def __add__(self, other: "ToolResult"): 54 | def combine_fields( 55 | field: Optional[str], other_field: Optional[str], concatenate: bool = True 56 | ): 57 | if field and other_field: 58 | if concatenate: 59 | return field + other_field 60 | raise ValueError("Cannot combine tool results") 61 | return field or other_field 62 | 63 | return ToolResult( 64 | output=combine_fields(self.output, other.output), 65 | error=combine_fields(self.error, other.error), 66 | base64_image=combine_fields(self.base64_image, other.base64_image, False), 67 | system=combine_fields(self.system, other.system), 68 | ) 69 | 70 | def __str__(self): 71 | return f"Error: {self.error}" if self.error else str(self.output) 72 | 73 | def replace(self, **kwargs): 74 | """Returns a new ToolResult with the given fields replaced.""" 75 | # return self.copy(update=kwargs) 76 | return type(self)(**{**self.dict(), **kwargs}) 77 | 78 | 79 | class CLIResult(ToolResult): 80 | """A ToolResult that can be rendered as a CLI output.""" 81 | 82 | 83 | class ToolFailure(ToolResult): 84 | """A ToolResult that represents a failure.""" 85 | 86 | 87 | def get_recent_trading_day(date_format: str = "%Y-%m-%d") -> str: 88 | """ 89 | 获取最近的交易日(跳过周末) 90 | 91 | A股交易日规则: 92 | - 周一到周五是交易日 93 | - 周六周日是休市 94 | 95 | Args: 96 | date_format: 返回日期格式,默认为"%Y-%m-%d" 97 | 98 | Returns: 99 | str: 最近的交易日日期字符串 100 | """ 101 | from datetime import datetime, timedelta 102 | 103 | current_date = datetime.now() 104 | 105 | # 如果是周末,则回退到最近的交易日 106 | while current_date.weekday() >= 5: # 周六=5, 周日=6 107 | current_date -= timedelta(days=1) 108 | 109 | return current_date.strftime(date_format) 110 | -------------------------------------------------------------------------------- /src/tool/stock_info_request.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import datetime 3 | from typing import Any, Dict 4 | 5 | import efinance as ef 6 | import pandas as pd 7 | from pydantic import Field 8 | 9 | from src.tool.base import BaseTool, ToolResult, get_recent_trading_day 10 | 11 | 12 | class StockInfoResponse(ToolResult): 13 | """Response model for stock information, extending ToolResult.""" 14 | 15 | output: Dict[str, Any] = Field(default_factory=dict) 16 | 17 | @property 18 | def current_trading_day(self) -> str: 19 | """Get the current trading day from the output.""" 20 | return self.output.get("current_trading_day", "") 21 | 22 | @property 23 | def basic_info(self) -> Dict[str, Any]: 24 | """Get the basic stock information from the output.""" 25 | return self.output.get("basic_info", {}) 26 | 27 | 28 | class StockInfoRequest(BaseTool): 29 | """Tool to fetch basic information about a stock with the current trading date.""" 30 | 31 | name: str = "stock_info_request" 32 | description: str = "获取股票基础信息和当前交易日,返回JSON格式的结果。" 33 | parameters: Dict[str, Any] = { 34 | "type": "object", 35 | "properties": {"stock_code": {"type": "string", "description": "股票代码"}}, 36 | "required": ["stock_code"], 37 | } 38 | 39 | MAX_RETRIES: int = 3 40 | RETRY_DELAY: int = 1 # seconds 41 | 42 | async def execute(self, stock_code: str, **kwargs) -> StockInfoResponse | None: 43 | """ 44 | Execute the tool to fetch stock information. 45 | 46 | Args: 47 | stock_code: The stock code to query 48 | 49 | Returns: 50 | StockInfoResponse containing stock information and current trading date 51 | """ 52 | for attempt in range(1, self.MAX_RETRIES + 1): 53 | try: 54 | # Get current trading day 55 | trading_day = get_recent_trading_day() 56 | 57 | # Fetch stock information 58 | data = ef.stock.get_base_info(stock_code) 59 | 60 | # Convert data to dict format based on its type 61 | basic_info = self._format_data(data) 62 | 63 | # Create and return the response 64 | return StockInfoResponse( 65 | output={ 66 | "current_trading_day": trading_day, 67 | "basic_info": basic_info, 68 | } 69 | ) 70 | 71 | except Exception as e: 72 | if attempt < self.MAX_RETRIES: 73 | await asyncio.sleep(float(self.RETRY_DELAY)) 74 | return StockInfoResponse( 75 | error=f"获取股票信息失败 ({self.MAX_RETRIES}次尝试): {str(e)}" 76 | ) 77 | 78 | @staticmethod 79 | def _format_data(data: Any) -> Dict[str, Any]: 80 | """ 81 | Format data to a JSON-serializable dictionary. 82 | 83 | Args: 84 | data: The data to format, typically from efinance 85 | 86 | Returns: 87 | A dictionary representation of the data 88 | """ 89 | if isinstance(data, pd.DataFrame): 90 | return data.to_dict(orient="records")[0] if len(data) > 0 else {} 91 | elif isinstance(data, pd.Series): 92 | return data.to_dict() 93 | elif isinstance(data, dict): 94 | return data 95 | elif isinstance(data, (int, float, str, bool)): 96 | return {"value": data} 97 | else: 98 | return {"value": str(data)} 99 | 100 | 101 | if __name__ == "__main__": 102 | import json 103 | import sys 104 | 105 | # Use default stock code "600519" (Maotai) if not provided 106 | code = sys.argv[1] if len(sys.argv) > 1 else "600519" 107 | 108 | # Create and run the tool 109 | tool = StockInfoRequest() 110 | result = asyncio.run(tool.execute(code)) 111 | 112 | # Print the result 113 | if result.error: 114 | print(f"Error: {result.error}") 115 | else: 116 | print(json.dumps(result.output, ensure_ascii=False, indent=2)) 117 | -------------------------------------------------------------------------------- /src/tool/risk_control.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from src.logger import logger 4 | from src.tool.base import BaseTool, ToolResult 5 | from src.tool.financial_deep_search.risk_control_data import get_risk_control_data 6 | 7 | 8 | class RiskControlTool(BaseTool): 9 | """Tool for retrieving risk control data for stocks.""" 10 | 11 | name: str = "risk_control_tool" 12 | description: str = ( 13 | "获取股票的风控数据,包括财务数据(现金流量表,资产负债表,利润表)(financial)和法务公告数据(legal)。" 14 | "支持最大重试机制,适合大模型自动调用。返回结构化字典。" 15 | ) 16 | parameters: dict = { 17 | "type": "object", 18 | "properties": { 19 | "stock_code": { 20 | "type": "string", 21 | "description": "股票代码(必填),如'600519'(贵州茅台)、'000001'(平安银行)、'300750'(宁德时代)等", 22 | }, 23 | "max_count": { 24 | "type": "integer", 25 | "description": "公告数据获取数量上限,建议不超过10,用于限制返回的法律公告数量", 26 | "default": 10, 27 | }, 28 | "period": { 29 | "type": "string", 30 | "description": "财务数据周期类型,精确可选值:'按年度'(年报数据)、'按报告期'(含季报)、'按单季度'(环比数据)", 31 | "default": "按年度", 32 | }, 33 | "max_retry": { 34 | "type": "integer", 35 | "description": "数据获取最大重试次数,范围1-5,用于处理网络波动情况", 36 | "default": 3, 37 | }, 38 | "sleep_seconds": { 39 | "type": "integer", 40 | "description": "重试间隔秒数,范围1-10,防止频繁请求被限制", 41 | "default": 1, 42 | }, 43 | }, 44 | "required": ["stock_code"], 45 | } 46 | 47 | async def execute( 48 | self, 49 | stock_code: str, 50 | max_count: int = 10, 51 | period: str = "按年度", 52 | max_retry: int = 3, 53 | sleep_seconds: int = 1, 54 | **kwargs, 55 | ) -> ToolResult: 56 | """ 57 | Get risk control data for a single stock with retry mechanism. 58 | 59 | Args: 60 | stock_code: Stock code 61 | max_count: Maximum number of announcements to retrieve 62 | period: Financial data period (按年度/按报告期/按单季度) 63 | max_retry: Maximum retry attempts 64 | sleep_seconds: Seconds to wait between retries 65 | **kwargs: Additional parameters 66 | 67 | Returns: 68 | ToolResult: Result containing risk control data 69 | """ 70 | try: 71 | # Execute synchronous operation in thread pool to avoid blocking event loop 72 | result = await asyncio.to_thread( 73 | get_risk_control_data, 74 | stock_code=stock_code, 75 | max_count=max_count, 76 | period=period, 77 | include_announcements=True, 78 | include_financial=True, 79 | max_retry=max_retry, 80 | sleep_seconds=sleep_seconds, 81 | ) 82 | 83 | if "error" in result: 84 | return ToolResult(error=result["error"]) 85 | 86 | return ToolResult(output=result) 87 | 88 | except Exception as e: 89 | error_msg = f"Failed to get risk control data: {str(e)}" 90 | logger.error(error_msg) 91 | return ToolResult(error=error_msg) 92 | 93 | 94 | if __name__ == "__main__": 95 | # Direct tool testing 96 | import sys 97 | 98 | code = sys.argv[1] if len(sys.argv) > 1 else "600519" 99 | 100 | # Get risk control data 101 | tool = RiskControlTool() 102 | result = asyncio.run(tool.execute(stock_code=code)) 103 | 104 | # Output results 105 | if result.error: 106 | print(f"Failed: {result.error}") 107 | else: 108 | output = result.output 109 | print(f"Success!") 110 | print( 111 | f"- Financial Data: {'Retrieved' if output['financial'] else 'Not Retrieved'}" 112 | ) 113 | if output['legal']: 114 | legal_info = f"Retrieved ({len(output['legal'])} items)" 115 | else: 116 | legal_info = "Not Retrieved" 117 | print(f"- Legal Data: {legal_info}") 118 | -------------------------------------------------------------------------------- /src/environment/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from enum import Enum 3 | from typing import Any, Dict, List, Optional, Union 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | from src.agent.base import BaseAgent 8 | from src.logger import logger 9 | 10 | 11 | class BaseEnvironment(BaseModel): 12 | """Base environment class for all environments""" 13 | 14 | name: str = Field(default="base_environment") 15 | description: str = Field(default="Base environment class") 16 | agents: Dict[str, BaseAgent] = Field(default_factory=dict) 17 | max_steps: int = Field(default=3, description="Maximum steps for each agent") 18 | 19 | class Config: 20 | arbitrary_types_allowed = True 21 | 22 | @classmethod 23 | async def create(cls, **kwargs) -> "BaseEnvironment": 24 | """Factory method to create and initialize an environment""" 25 | instance = cls(**kwargs) 26 | await instance.initialize() 27 | return instance 28 | 29 | async def initialize(self) -> None: 30 | """Initialize the environment. Override in subclasses.""" 31 | logger.info(f"Initializing {self.name} environment (max_steps={self.max_steps})") 32 | 33 | def register_agent(self, agent: BaseAgent) -> None: 34 | """Register an agent with the environment""" 35 | self.agents[agent.name] = agent 36 | logger.debug(f"Agent {agent.name} registered in {self.name}") 37 | 38 | # Alias for register_agent for better API flexibility 39 | def add_agent(self, agent: BaseAgent) -> None: 40 | """Alias for register_agent""" 41 | self.register_agent(agent) 42 | 43 | def get_agent(self, agent_name: str) -> Optional[BaseAgent]: 44 | """Get an agent by name""" 45 | return self.agents.get(agent_name) 46 | 47 | @abstractmethod 48 | async def run(self, **kwargs) -> Dict[str, Any]: 49 | """Run the environment. Override in subclasses.""" 50 | raise NotImplementedError("Subclasses must implement run method") 51 | 52 | async def cleanup(self) -> None: 53 | """Clean up resources when done""" 54 | logger.info(f"Cleaning up {self.name} environment") 55 | 56 | 57 | class EnvironmentType(str, Enum): 58 | """Enum of available environment types""" 59 | 60 | RESEARCH = "research" 61 | BATTLE = "battle" 62 | 63 | 64 | class EnvironmentFactory: 65 | """Factory for creating different types of environments with support for multiple agents""" 66 | 67 | @staticmethod 68 | async def create_environment( 69 | environment_type: EnvironmentType, 70 | agents: Union[BaseAgent, List[BaseAgent], Dict[str, BaseAgent]] = None, 71 | **kwargs, 72 | ) -> BaseEnvironment: 73 | """Create and initialize an environment of the specified type 74 | 75 | Args: 76 | environment_type: The type of environment to create 77 | agents: One or more agents to add to the environment 78 | **kwargs: Additional arguments to pass to the environment constructor 79 | 80 | Returns: 81 | An initialized environment instance 82 | """ 83 | from src.environment.battle import BattleEnvironment 84 | from src.environment.research import ResearchEnvironment 85 | 86 | environments = { 87 | EnvironmentType.RESEARCH: ResearchEnvironment, 88 | EnvironmentType.BATTLE: BattleEnvironment, 89 | } 90 | 91 | environment_class = environments.get(environment_type) 92 | if not environment_class: 93 | raise ValueError(f"Unknown environment type: {environment_type}") 94 | 95 | # Create the environment 96 | environment = await environment_class.create(**kwargs) 97 | 98 | # Add agents if provided 99 | if agents: 100 | if isinstance(agents, BaseAgent): 101 | environment.add_agent(agents) 102 | elif isinstance(agents, list): 103 | for agent in agents: 104 | environment.add_agent(agent) 105 | elif isinstance(agents, dict): 106 | for agent in agents.values(): 107 | environment.add_agent(agent) 108 | 109 | return environment 110 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Project-specific ### 2 | # Logs 3 | logs/ 4 | 5 | # Workspace 6 | workspace/ 7 | 8 | ### Python ### 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # UV 106 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | #uv.lock 110 | 111 | # poetry 112 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 113 | # This is especially recommended for binary packages to ensure reproducibility, and is more 114 | # commonly ignored for libraries. 115 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 116 | #poetry.lock 117 | 118 | # pdm 119 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 120 | #pdm.lock 121 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 122 | # in version control. 123 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 124 | .pdm.toml 125 | .pdm-python 126 | .pdm-build/ 127 | 128 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 129 | __pypackages__/ 130 | 131 | # Celery stuff 132 | celerybeat-schedule 133 | celerybeat.pid 134 | 135 | # SageMath parsed files 136 | *.sage.py 137 | 138 | # Environments 139 | .env 140 | .venv 141 | env/ 142 | venv/ 143 | ENV/ 144 | env.bak/ 145 | venv.bak/ 146 | 147 | # Spyder project settings 148 | .spyderproject 149 | .spyproject 150 | 151 | # Rope project settings 152 | .ropeproject 153 | 154 | # mkdocs documentation 155 | /site 156 | 157 | # mypy 158 | .mypy_cache/ 159 | .dmypy.json 160 | dmypy.json 161 | 162 | # Pyre type checker 163 | .pyre/ 164 | 165 | # pytype static type analyzer 166 | .pytype/ 167 | 168 | # Cython debug symbols 169 | cython_debug/ 170 | 171 | # PyCharm 172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 174 | # and can be added to the global gitignore or merged into this file. For a more nuclear 175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 176 | .idea/* 177 | 178 | # PyPI configuration file 179 | .pypirc 180 | 181 | ### Visual Studio Code ### 182 | .vscode/* 183 | !.vscode/settings.json 184 | !.vscode/tasks.json 185 | !.vscode/launch.json 186 | !.vscode/extensions.json 187 | !.vscode/*.code-snippets 188 | 189 | # Local History for Visual Studio Code 190 | .history/ 191 | 192 | # Built Visual Studio Code Extensions 193 | *.vsix 194 | 195 | # OSX 196 | .DS_Store 197 | 198 | # node 199 | node_modules 200 | -------------------------------------------------------------------------------- /src/prompt/sentiment.py: -------------------------------------------------------------------------------- 1 | SENTIMENT_SYSTEM_PROMPT = """ 2 | # 股票舆情分析专家 3 | 4 | 你是一位专业的股票舆情分析师,专注于A股市场的舆论监测和情感分析。你擅长通过多渠道信息收集、数据分析和情感计算,为投资者提供客观、准确的股票舆情报告。 5 | 6 | ## 🎯 核心职责 7 | - 实时监测目标股票的网络舆情动态 8 | - 分析媒体报道、社交平台、投资论坛的情感倾向 9 | - 识别影响股价的关键舆情事件和传播节点 10 | - 评估舆情对股票短期和中期走势的潜在影响 11 | - 提供基于数据的客观分析,不给出投资建议 12 | 13 | ## 🔧 工具使用规范 14 | 15 | **主要工具说明:** 16 | - **web_search**:进行实时信息收集,获取舆情数据 17 | - **terminate**:当你完成了完整的舆情分析报告后,必须使用此工具结束任务 18 | 19 | ⚠️ **重要提醒**:当你完成了舆情分析并准备输出最终报告时,请立即使用terminate工具结束任务,避免无限循环。 20 | 21 | **必须使用web_search工具进行实时信息收集,按以下顺序执行:** 22 | 23 | ## 🎯 权威数据源优先策略 24 | 25 | ### 核心财经媒体(优先关注) 26 | - **财联社** (cls.cn) - 专业财经资讯平台 27 | - **新浪财经** (finance.sina.com.cn) - 财经新闻和分析 28 | - **证券时报** (stcn.com) - 权威证券媒体 29 | - **上海证券报** (cnstock.com) - 官方证券媒体 30 | - **中国证券报** (cs.com.cn) - 权威证券报道 31 | - **第一财经** (yicai.com) - 专业财经媒体 32 | - **21世纪经济报道** (21jingji.com) - 权威经济媒体 33 | 34 | ### 投资社区和论坛(关注情绪) 35 | - **雪球** (xueqiu.com) - 专业投资社区 36 | - **东方财富股吧** (guba.eastmoney.com) - 投资者讨论平台 37 | - **同花顺** (10jqka.com.cn) - 投资资讯平台 38 | - **金融界** (jrj.com.cn) - 财经门户网站 39 | 40 | ## 📝 搜索策略(根据搜索引擎自动调整) 41 | 42 | ### 🔄 智能搜索策略选择 43 | 44 | ### 🔍 百度/必应搜索引擎策略(关键词匹配) 45 | **当搜索引擎为百度或必应时,使用以下关键词搜索:** 46 | 47 | 1. **📰 权威媒体新闻搜索** 48 | - 搜索:"{公司名称} {股票代码} 财联社 最新消息" 49 | - 搜索:"{公司名称} {股票代码} 新浪财经 最新资讯" 50 | - 搜索:"{公司名称} 证券时报 报道" 51 | - 搜索:"{公司名称} {股票代码} 上海证券报 新闻" 52 | - 搜索:"{公司名称} 第一财经 资讯" 53 | - 搜索:"{公司名称} {股票代码} 中国证券报 消息" 54 | 55 | 2. **💬 投资者情绪监测** 56 | - 搜索:"{公司名称} {股票代码} 雪球 讨论 观点" 57 | - 搜索:"{公司名称} {股票代码} 东方财富股吧 热议" 58 | - 搜索:"{公司名称} 同花顺 投资者 讨论" 59 | - 搜索:"{公司名称} {股票代码} 金融界 评论 分析" 60 | - 搜索:"{公司名称} {股票代码} 股民 看法 观点" 61 | 62 | 3. **🏢 公司公告和业绩追踪** 63 | - 搜索:"{公司名称} {股票代码} 公司公告 最新公告" 64 | - 搜索:"{公司名称} 财报 业绩 季报" 65 | - 搜索:"{公司名称} {股票代码} 年报 半年报" 66 | - 搜索:"{公司名称} 重大事项 重组" 67 | - 搜索:"{公司名称} {股票代码} 分红 派息" 68 | 69 | 4. **📊 机构研报和分析** 70 | - 搜索:"{公司名称} {股票代码} 研报 分析师 券商" 71 | - 搜索:"{公司名称} 机构评级 目标价" 72 | - 搜索:"{公司名称} {股票代码} 投资建议 买入 卖出" 73 | - 搜索:"{公司名称} {行业名称} 政策 监管 影响" 74 | - 搜索:"{公司名称} {股票代码} 机构持仓 基金" 75 | 76 | 5. **⚠️ 风险预警监测** 77 | - 搜索:"{公司名称} {股票代码} 风险 预警 风险提示" 78 | - 搜索:"{公司名称} 负面 问题 争议" 79 | - 搜索:"{公司名称} {股票代码} 违规 处罚 监管" 80 | - 搜索:"{公司名称} 纠纷 诉讼 案件" 81 | - 搜索:"{公司名称} {股票代码} 停牌 复牌 异常" 82 | 83 | ### 📋 执行步骤 84 | 1. **搜索引擎检测**:先执行一次简单测试搜索,根据搜索结果的source字段确定当前使用的搜索引擎 85 | 2. **策略选择**:根据搜索引擎类型选择对应的搜索策略 86 | 3. **信息收集**:按照选定策略执行上述5个维度的搜索 87 | 4. **结果标注**:在输出结果中明确标注使用的搜索引擎和搜索策略类型 88 | 89 | **重点关注**:官方媒体报道、投资者情绪、公司公告、机构观点、风险预警 90 | 91 | ## 📊 数据源权重设置 92 | 93 | ### 一级权威源(权重:高) 94 | - 财联社、证券时报、上海证券报、中国证券报 95 | - 交易所官网、证监会官网、公司官网 96 | 97 | ### 二级专业源(权重:中高) 98 | - 东方财富、新浪财经、第一财经、21世纪经济报道 99 | - 专业研究机构报告、券商研报 100 | 101 | ### 三级社区源(权重:中) 102 | - 雪球、东方财富股吧、同花顺、金融界 103 | - 投资者论坛、社交媒体讨论 104 | 105 | ## 📋 分析工作流程 106 | 107 | ### 第一步:信息收集 108 | - 使用web_search工具按上述5个维度收集信息 109 | - 记录信息来源、发布时间、可信度等级 110 | - 筛选有效信息,排除无关内容 111 | 112 | ### 第二步:情感分析 113 | - 对收集到的文本进行情感分类(正面/负面/中性) 114 | - 计算情感强度和情感分布比例 115 | - 识别情感转折点和异常情感波动 116 | 117 | ### 第三步:传播分析 118 | - 分析信息传播路径和影响范围 119 | - 识别关键意见领袖和影响节点 120 | - 评估信息传播速度和覆盖面 121 | 122 | ### 第四步:影响评估 123 | - 评估舆情对股价的潜在影响程度 124 | - 识别短期和中期的关键风险点 125 | - 分析舆情与股价走势的关联性 126 | 127 | ### 第五步:趋势预判 128 | - 基于历史数据和当前趋势进行合理推断 129 | - 识别可能的舆情转折点 130 | - 提供后续关注重点 131 | 132 | ## 📊 输出格式规范 133 | 134 | ### 🔍 舆情概况 135 | - **股票代码**:[股票代码] 136 | - **公司名称**:[公司全称] 137 | - **分析时间**:[分析时间戳] 138 | - **搜索引擎**:[Google/百度/必应/DuckDuckGo] 139 | - **搜索策略**:[精准site:指令/关键词匹配] 140 | - **舆情热度**:[高/中/低] + 具体数值 141 | - **整体情感**:[正面/负面/中性] + 情感分数 142 | - **关键事件**:[影响舆情的重要事件] 143 | 144 | ### 📈 数据分析 145 | - **信息来源统计**: 146 | - 一级权威源:财联社X条、证券时报Y条、上海证券报Z条 147 | - 二级专业源:东方财富X条、新浪财经Y条、第一财经Z条 148 | - 三级社区源:雪球X条、股吧Y条、同花顺Z条 149 | - **情感分布**:正面X% | 中性Y% | 负面Z% 150 | - **热度变化**:与前期对比的变化趋势 151 | - **传播指标**:传播范围、互动数量、影响力指数 152 | 153 | ### 💭 内容分析 154 | - **正面观点**:主要看好理由和论据 155 | - **负面观点**:主要担忧和风险点 156 | - **中性分析**:客观事实和数据 157 | - **关键词云**:高频词汇和热点话题 158 | 159 | ### 🌐 传播分析 160 | - **信息源头**:关键信息的首发平台 161 | - 权威媒体首发:财联社/证券时报/上海证券报等 162 | - 官方公告首发:交易所/证监会/公司官网等 163 | - 社区讨论首发:雪球/股吧/投资论坛等 164 | - **传播路径**:信息扩散的主要渠道 165 | - 媒体传播:从权威媒体到其他财经平台 166 | - 社交传播:从专业投资者到普通投资者 167 | - 官方传播:从监管部门到市场参与者 168 | - **影响节点**:关键意见领袖和转发大户 169 | - 知名财经媒体:财联社、东方财富、新浪财经 170 | - 专业投资者:雪球大V、知名博主、分析师 171 | - 机构媒体:券商研报、基金公司、投资机构 172 | - **传播速度**:信息传播的时效性分析 173 | 174 | ### 🔮 趋势预判 175 | - **短期走势**:24-48小时内的舆情趋势 176 | - **关键变量**:可能影响舆情的重要因素 177 | - **风险提示**:需要重点关注的风险点 178 | - **监测建议**:后续需要跟踪的关键信息 179 | 180 | ### 🎯 核心结论 181 | - **舆情评级**:[积极/中性/消极] 182 | - **影响程度**:[高/中/低] 183 | - **关注重点**:[需要重点关注的方面] 184 | - **风险警示**:[主要风险点] 185 | 186 | ## ⚠️ 重要声明 187 | - 本分析仅基于公开信息和数据,不构成投资建议 188 | - 舆情分析具有主观性,结果仅供参考 189 | - 投资有风险,决策需谨慎 190 | - 建议结合基本面分析和技术分析综合判断 191 | 192 | ## 🔄 质量保证 193 | - 所有分析必须基于web_search工具收集的实时数据 194 | - **优先使用权威数据源**:财联社、东方财富、新浪财经、证券时报等 195 | - **标注信息来源**:每条重要信息都要明确标注具体来源网站 196 | - **权重化处理**:一级权威源 > 二级专业源 > 三级社区源 197 | - 严格区分事实陈述和主观分析 198 | - 提供信息来源和可信度评估 199 | - 保持客观中立的分析立场 200 | - 及时更新分析结果以反映最新情况 201 | """ 202 | -------------------------------------------------------------------------------- /docs/class_diagram.md: -------------------------------------------------------------------------------- 1 | ``` 2 | classDiagram 3 | %%========================== 4 | %% 1. Agent 层次结构 5 | %%========================== 6 | class BaseAgent { 7 | <> 8 | - name: str 9 | - description: Optional[str] 10 | - system_prompt: Optional[str] 11 | - next_step_prompt: Optional[str] 12 | - llm: LLM 13 | - memory: Memory 14 | - state: AgentState 15 | - max_steps: int 16 | - current_step: int 17 | + run(request: Optional[str]) str 18 | + step() str <> 19 | + update_memory(role, content, **kwargs) None 20 | + reset_execution_state() None 21 | } 22 | class ReActAgent { 23 | <> 24 | + think() bool <> 25 | + act() str <> 26 | + step() str 27 | } 28 | class ToolCallAgent { 29 | - available_tools: ToolCollection 30 | - tool_choices: ToolChoice 31 | - tool_calls: List~ToolCall~ 32 | + think() bool 33 | + act() str 34 | + execute_tool(command: ToolCall) str 35 | + cleanup() None 36 | } 37 | class MCPAgent { 38 | - mcp_clients: MCPClients 39 | - tool_schemas: Dict~str, dict~ 40 | - connected_servers: Dict~str, str~ 41 | - initialized: bool 42 | + create(...) MCPAgent 43 | + initialize_mcp_servers() None 44 | + connect_mcp_server(url, server_id) None 45 | + initialize(...) None 46 | + _refresh_tools() Tuple~List~str~,List~str~~ 47 | + cleanup() None 48 | } 49 | class SentimentAgent { 50 | <> 51 | } 52 | class RiskControlAgent { 53 | <> 54 | } 55 | class HotMoneyAgent { 56 | <> 57 | } 58 | class TechnicalAnalysisAgent { 59 | <> 60 | } 61 | class ReportAgent { 62 | <> 63 | } 64 | 65 | %% 继承关系 66 | BaseAgent <|-- ReActAgent 67 | ReActAgent <|-- ToolCallAgent 68 | ToolCallAgent <|-- MCPAgent 69 | MCPAgent <|-- SentimentAgent 70 | MCPAgent <|-- RiskControlAgent 71 | MCPAgent <|-- HotMoneyAgent 72 | MCPAgent <|-- TechnicalAnalysisAgent 73 | MCPAgent <|-- ReportAgent 74 | 75 | %% 组合关系 76 | BaseAgent *-- LLM : llm 77 | BaseAgent *-- Memory : memory 78 | MCPAgent *-- MCPClients : mcp_clients 79 | 80 | %%========================== 81 | %% 2. 环境(Environment)架构 82 | %%========================== 83 | class BaseEnvironment { 84 | <> 85 | - name: str 86 | - description: str 87 | - agents: Dict~str, BaseAgent~ 88 | - max_steps: int 89 | + create(...) BaseEnvironment 90 | + register_agent(agent: BaseAgent) None 91 | + run(...) Dict~str,Any~ <> 92 | + cleanup() None 93 | } 94 | class ResearchEnvironment { 95 | - analysis_mapping: Dict~str,str~ 96 | - results: Dict~str,Any~ 97 | + initialize() None 98 | + run(stock_code: str) Dict~str,Any~ 99 | + cleanup() None 100 | } 101 | class BattleState { 102 | - active_agents: Dict~str,str~ 103 | - voted_agents: Dict~str,str~ 104 | - terminated_agents: Dict~str,bool~ 105 | - battle_history: List~Dict~str,Any~~ 106 | - vote_results: Dict~str,int~ 107 | - battle_highlights: List~Dict~str,Any~~ 108 | - battle_over: bool 109 | + add_event(type, agent_id, ...) Dict~str,Any~ 110 | + record_vote(agent_id,vote) None 111 | + mark_terminated(agent_id,reason) None 112 | } 113 | class BattleEnvironment { 114 | - state: BattleState 115 | - tools: Dict~str,BaseTool~ 116 | + initialize() None 117 | + register_agent(agent: BaseAgent) None 118 | + run(report: Dict~str,Any~) Dict~str,Any~ 119 | + handle_speak(agent_id, content) ToolResult 120 | + handle_vote(agent_id, vote) ToolResult 121 | + cleanup() None 122 | } 123 | class EnvironmentFactory { 124 | + create_environment(env_type: EnvironmentType, agents, ...) BaseEnvironment 125 | } 126 | 127 | %% 继承与工厂 128 | BaseEnvironment <|-- ResearchEnvironment 129 | BaseEnvironment <|-- BattleEnvironment 130 | EnvironmentFactory ..> BaseEnvironment : creates 131 | %% 环境中包含 Agents 和 BattleState 132 | BaseEnvironment o-- BaseAgent : agents 133 | BattleEnvironment *-- BattleState : state 134 | 135 | %%========================== 136 | %% 3. 工具(Tool)抽象 137 | %%========================== 138 | class MCPClients { 139 | - sessions: Dict~str,ClientSession~ 140 | - exit_stacks: Dict~str,AsyncExitStack~ 141 | + connect_sse(url, server_id) None 142 | + connect_stdio(cmd, args, server_id) None 143 | + list_tools() ListToolsResult 144 | + disconnect(server_id) None 145 | } 146 | 147 | %%========================== 148 | %% 4. 支持类 149 | %%========================== 150 | class Memory { 151 | - messages: List~Message~ 152 | + add_message(msg: Message) None 153 | + clear() None 154 | } 155 | class LLM { 156 | - model: str 157 | - max_tokens: int 158 | - temperature: float 159 | + ask(messages, system_msgs, ...) str 160 | + ask_tool(messages, tools, tool_choice, ...) Message 161 | } 162 | ``` -------------------------------------------------------------------------------- /src/tool/sentiment.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import time 4 | from typing import Any, Dict 5 | 6 | from src.logger import logger 7 | from src.tool.base import BaseTool, ToolResult 8 | from src.tool.financial_deep_search.get_section_data import get_all_section 9 | from src.tool.financial_deep_search.index_capital import get_index_capital_flow 10 | 11 | 12 | class SentimentTool(BaseTool): 13 | """Tool for retrieving market sentiment data including hot sectors and index capital flow.""" 14 | 15 | name: str = "sentiment_tool" 16 | description: str = "整合市场情绪与行业热点分析工具,提供全面的市场脉搏和资金流向监测。" 17 | parameters: dict = { 18 | "type": "object", 19 | "properties": { 20 | "index_code": { 21 | "type": "string", 22 | "description": "指数代码(必填),如'000001'(上证指数)、'399001'(深证成指)、'399006'(创业板指)、'000016'(上证50)、'000300'(沪深300)、'000905'(中证500)等", 23 | }, 24 | "sector_types": { 25 | "type": "string", 26 | "description": "板块类型筛选,精确可选值:'all'(所有板块)、'hot'(热门活跃板块)、'concept'(概念题材板块)、'regional'(地域区域板块)、'industry'(行业分类板块)", 27 | "default": "all", 28 | }, 29 | "max_retry": { 30 | "type": "integer", 31 | "description": "数据获取最大重试次数,范围1-5,用于处理网络波动情况", 32 | "default": 3, 33 | }, 34 | }, 35 | "required": ["index_code"], 36 | } 37 | 38 | async def execute( 39 | self, 40 | index_code: str, 41 | sector_types: str = "all", 42 | max_retry: int = 3, 43 | sleep_seconds: int = 1, 44 | **kwargs, 45 | ) -> ToolResult: 46 | """ 47 | Get market data with retry mechanism. 48 | 49 | Args: 50 | index_code: Index code 51 | sector_types: Sector types, options: 'all', 'hot', 'concept', 'regional', 'industry' 52 | max_retry: Maximum retry attempts 53 | sleep_seconds: Seconds to wait between retries 54 | **kwargs: Additional parameters 55 | 56 | Returns: 57 | ToolResult: Result containing market data 58 | """ 59 | try: 60 | # Execute synchronous operation in thread pool to avoid blocking event loop 61 | result = await asyncio.to_thread( 62 | self._get_market_data, 63 | index_code=index_code, 64 | sector_types=sector_types, 65 | max_retry=max_retry, 66 | sleep_seconds=sleep_seconds, 67 | ) 68 | 69 | # Check if result contains error 70 | if "error" in result: 71 | return ToolResult(error=result["error"]) 72 | 73 | return ToolResult(output=result) 74 | 75 | except Exception as e: 76 | error_msg = f"Failed to get market data: {str(e)}" 77 | logger.error(error_msg) 78 | return ToolResult(error=error_msg) 79 | 80 | def _get_market_data( 81 | self, 82 | index_code: str, 83 | sector_types: str = "all", 84 | max_retry: int = 3, 85 | sleep_seconds: int = 1, 86 | ) -> Dict[str, Any]: 87 | """ 88 | Get market data including hot sectors and index capital flow. 89 | Supports maximum retry mechanism. 90 | """ 91 | for attempt in range(1, max_retry + 1): 92 | try: 93 | # 1. Get hot sector data 94 | section_data = get_all_section(sector_types=sector_types) 95 | logger.info(f"[Attempt {attempt}] Retrieved hot sector data") 96 | 97 | # 2. Get index capital flow 98 | index_flow = get_index_capital_flow(index_code=index_code) 99 | logger.info(f"[Attempt {attempt}] Retrieved index capital flow data") 100 | 101 | # 3. Combine data and return 102 | return { 103 | "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), 104 | "index_code": index_code, 105 | "sector_types": sector_types, 106 | "hot_section_data": section_data, 107 | "index_net_flow": index_flow, 108 | } 109 | 110 | except Exception as e: 111 | logger.warning(f"[Attempt {attempt}] Failed to get market data: {e}") 112 | if attempt < max_retry: 113 | logger.info(f"Waiting {sleep_seconds} seconds before retry...") 114 | time.sleep(sleep_seconds) 115 | else: 116 | logger.error(f"Max retries ({max_retry}) reached, failed") 117 | return {"error": f"Failed to get market data: {str(e)}"} 118 | 119 | 120 | if __name__ == "__main__": 121 | import sys 122 | 123 | code = sys.argv[1] if len(sys.argv) > 1 else "000001" 124 | 125 | tool = SentimentTool() 126 | result = asyncio.run(tool.execute(index_code=code)) 127 | 128 | if result.error: 129 | print(f"Failed: {result.error}") 130 | else: 131 | output = result.output 132 | print(f"Success! Timestamp: {output['timestamp']}") 133 | print(f"Index Code: {output['index_code']}") 134 | 135 | for key in ["hot_section_data", "index_net_flow"]: 136 | status = "Retrieved" if output.get(key) else "Not Retrieved" 137 | print(f"- {key}: {status}") 138 | 139 | filename = f"market_data_{code}_{time.strftime('%Y%m%d_%H%M%S')}.json" 140 | with open(filename, "w", encoding="utf-8") as f: 141 | json.dump(output, f, ensure_ascii=False, indent=2) 142 | print(f"\nComplete results saved to: {filename}") 143 | -------------------------------------------------------------------------------- /src/tool/search/bing_search.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import requests 4 | from bs4 import BeautifulSoup 5 | 6 | from src.logger import logger 7 | from src.tool.search.base import SearchItem, WebSearchEngine 8 | 9 | 10 | ABSTRACT_MAX_LENGTH = 300 11 | 12 | USER_AGENTS = [ 13 | "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/68.0.3440.106 Safari/537.36", 14 | "Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)", 15 | "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Ubuntu Chromium/49.0.2623.108 Chrome/49.0.2623.108 Safari/537.36", 16 | "Mozilla/5.0 (Windows; U; Windows NT 5.1; pt-BR) AppleWebKit/533.3 (KHTML, like Gecko) QtWeb Internet Browser/3.7 http://www.QtWeb.net", 17 | "Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36", 18 | "Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US) AppleWebKit/532.2 (KHTML, like Gecko) ChromePlus/4.0.222.3 Chrome/4.0.222.3 Safari/532.2", 19 | "Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:1.8.1.4pre) Gecko/20070404 K-Ninja/2.1.3", 20 | "Mozilla/5.0 (Future Star Technologies Corp.; Star-Blade OS; x86_64; U; en-US) iNet Browser 4.7", 21 | "Mozilla/5.0 (Windows; U; Windows NT 6.1; rv:2.2) Gecko/20110201", 22 | "Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:1.8.1.13) Gecko/20080414 Firefox/2.0.0.13 Pogo/2.0.0.13.6866", 23 | ] 24 | 25 | HEADERS = { 26 | "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8", 27 | "Content-Type": "application/x-www-form-urlencoded", 28 | "User-Agent": USER_AGENTS[0], 29 | "Referer": "https://www.bing.com/", 30 | "Accept-Encoding": "gzip, deflate", 31 | "Accept-Language": "zh-CN,zh;q=0.9", 32 | } 33 | 34 | BING_HOST_URL = "https://www.bing.com" 35 | BING_SEARCH_URL = "https://www.bing.com/search?q=" 36 | 37 | 38 | class BingSearchEngine(WebSearchEngine): 39 | session: Optional[requests.Session] = None 40 | 41 | def __init__(self, **data): 42 | """Initialize the BingSearch tool with a requests session.""" 43 | super().__init__(**data) 44 | self.session = requests.Session() 45 | self.session.headers.update(HEADERS) 46 | 47 | def _search_sync(self, query: str, num_results: int = 10) -> List[SearchItem]: 48 | """ 49 | Synchronous Bing search implementation to retrieve search results. 50 | 51 | Args: 52 | query (str): The search query to submit to Bing. 53 | num_results (int, optional): Maximum number of results to return. Defaults to 10. 54 | 55 | Returns: 56 | List[SearchItem]: A list of search items with title, URL, and description. 57 | """ 58 | if not query: 59 | return [] 60 | 61 | list_result = [] 62 | first = 1 63 | next_url = BING_SEARCH_URL + query 64 | 65 | while len(list_result) < num_results: 66 | data, next_url = self._parse_html( 67 | next_url, rank_start=len(list_result), first=first 68 | ) 69 | if data: 70 | list_result.extend(data) 71 | if not next_url: 72 | break 73 | first += 10 74 | 75 | return list_result[:num_results] 76 | 77 | def _parse_html( 78 | self, url: str, rank_start: int = 0, first: int = 1 79 | ) -> Tuple[List[SearchItem], str]: 80 | """ 81 | Parse Bing search result HTML to extract search results and the next page URL. 82 | 83 | Returns: 84 | tuple: (List of SearchItem objects, next page URL or None) 85 | """ 86 | try: 87 | res = self.session.get(url=url) 88 | res.encoding = "utf-8" 89 | root = BeautifulSoup(res.text, "lxml") 90 | 91 | list_data = [] 92 | ol_results = root.find("ol", id="b_results") 93 | if not ol_results: 94 | return [], None 95 | 96 | for li in ol_results.find_all("li", class_="b_algo"): 97 | title = "" 98 | url = "" 99 | abstract = "" 100 | try: 101 | h2 = li.find("h2") 102 | if h2: 103 | title = h2.text.strip() 104 | url = h2.a["href"].strip() 105 | 106 | p = li.find("p") 107 | if p: 108 | abstract = p.text.strip() 109 | 110 | if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH: 111 | abstract = abstract[:ABSTRACT_MAX_LENGTH] 112 | 113 | rank_start += 1 114 | 115 | # Create a SearchItem object 116 | list_data.append( 117 | SearchItem( 118 | title=title or f"Bing Result {rank_start}", 119 | url=url, 120 | description=abstract, 121 | ) 122 | ) 123 | except Exception: 124 | continue 125 | 126 | next_btn = root.find("a", title="Next page") 127 | if not next_btn: 128 | return list_data, None 129 | 130 | next_url = BING_HOST_URL + next_btn["href"] 131 | return list_data, next_url 132 | except Exception as e: 133 | logger.warning(f"Error parsing HTML: {e}") 134 | return [], None 135 | 136 | def perform_search( 137 | self, query: str, num_results: int = 10, *args, **kwargs 138 | ) -> List[SearchItem]: 139 | """ 140 | Bing search engine. 141 | 142 | Returns results formatted according to SearchItem model. 143 | """ 144 | return self._search_sync(query, num_results=num_results) 145 | -------------------------------------------------------------------------------- /src/tool/financial_deep_search/get_section_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import traceback 4 | from datetime import datetime 5 | 6 | import requests 7 | 8 | 9 | ### 每日热门板块爬取 10 | 11 | API_URLS = { 12 | "hot": "https://push2.eastmoney.com/api/qt/clist/get?np=1&fltt=1&invt=2&cb=jQuery37109604366978044481_1744621126576&fs=m%3A90&fields=f12%2Cf13%2Cf14%2Cf3%2Cf152%2Cf4%2Cf128%2Cf140%2Cf141%2Cf136&fid=f3&pn=1&pz=10&po=1&ut=fa5fd1943c7b386f172d6893dbfba10b&dect=1&wbp2u=%7C0%7C0%7C0%7Cweb&_=1744621126577", 13 | "concept": "https://push2.eastmoney.com/api/qt/clist/get?np=1&fltt=1&invt=2&cb=jQuery37109604366978044481_1744621126580&fs=m%3A90%2Bt%3A3&fields=f12%2Cf13%2Cf14%2Cf3%2Cf152%2Cf4%2Cf8%2Cf104%2Cf105%2Cf128%2Cf140%2Cf141%2Cf136&fid=f3&pn=1&pz=10&po=1&ut=fa5fd1943c7b386f172d6893dbfba10b&dect=1&wbp2u=%7C0%7C0%7C0%7Cweb&_=1744621126708", 14 | "regional": "https://push2.eastmoney.com/api/qt/clist/get?np=1&fltt=1&invt=2&cb=jQuery37109604366978044481_1744621126574&fs=m%3A90%2Bt%3A1&fields=f12%2Cf13%2Cf14%2Cf3%2Cf152%2Cf4%2Cf8%2Cf104%2Cf105%2Cf128%2Cf140%2Cf141%2Cf136&fid=f3&pn=1&pz=10&po=1&ut=fa5fd1943c7b386f172d6893dbfba10b&dect=1&wbp2u=%7C0%7C0%7C0%7Cweb&_=1744621126762", 15 | "industry": "https://push2.eastmoney.com/api/qt/clist/get?np=1&fltt=1&invt=2&cb=jQuery37109604366978044481_1744621126574&fs=m%3A90%2Bt%3A2&fields=f12%2Cf13%2Cf14%2Cf3%2Cf152%2Cf4%2Cf8%2Cf104%2Cf105%2Cf128%2Cf140%2Cf141%2Cf136&fid=f3&pn=1&pz=10&po=1&ut=fa5fd1943c7b386f172d6893dbfba10b&dect=1&wbp2u=%7C0%7C0%7C0%7Cweb&_=1744621126617", 16 | } 17 | 18 | HEADERS = { 19 | "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", 20 | "Referer": "https://quote.eastmoney.com/", 21 | } 22 | 23 | 24 | def parse_jsonp(jsonp_str): 25 | import re 26 | 27 | match = re.search(r"\((.*)\)", jsonp_str) 28 | if match: 29 | return json.loads(match.group(1)) 30 | return None 31 | 32 | 33 | def fetch_data(sector_type, url, max_retries=3, retry_delay=2): 34 | for attempt in range(1, max_retries + 1): 35 | try: 36 | resp = requests.get(url, headers=HEADERS, timeout=15) 37 | resp.raise_for_status() 38 | data = parse_jsonp(resp.text) 39 | if not data: 40 | print(f"解析{sector_type}数据失败") 41 | return [] 42 | return data.get("data", {}).get("diff", []) 43 | except Exception as e: 44 | print(f"获取{sector_type}数据失败: {e} (第{attempt}次尝试)") 45 | if attempt < max_retries: 46 | time.sleep(retry_delay) 47 | else: 48 | return [] 49 | 50 | 51 | def simplify_sector_item(item): 52 | def to_float(val): 53 | try: 54 | if val is None: 55 | return None 56 | # 东方财富返回的涨跌幅通常是放大100倍的整数 57 | return round(float(val) / 100, 2) 58 | except Exception: 59 | return None 60 | 61 | return { 62 | "板块名称": item.get("f14"), 63 | "板块涨跌幅": to_float(item.get("f3")), 64 | "领涨股名称": item.get("f140"), 65 | "领涨股代码": item.get("f128"), 66 | "领涨股涨跌幅": to_float(item.get("f136")), 67 | } 68 | 69 | 70 | def get_all_section(sector_types=None): 71 | """ 72 | 获取所有类型板块数据,包括热门板块、概念板块、行业板块和地域板块 73 | 74 | Args: 75 | sector_types (str, optional): 板块类型,可选值: 'all', 'hot', 'concept', 'regional', 'industry',默认为None(等同于'all') 76 | 77 | Returns: 78 | dict: 包含各类板块数据的字典 79 | """ 80 | try: 81 | # 处理板块类型参数 82 | if sector_types is None or sector_types == "all": 83 | types_to_fetch = list(API_URLS.keys()) 84 | elif isinstance(sector_types, str): 85 | if "," in sector_types: 86 | types_to_fetch = [t.strip() for t in sector_types.split(",")] 87 | else: 88 | types_to_fetch = [sector_types] 89 | elif isinstance(sector_types, list): 90 | types_to_fetch = sector_types 91 | else: 92 | return { 93 | "success": False, 94 | "message": f"不支持的板块类型格式: {type(sector_types)}", 95 | "data": {}, 96 | } 97 | 98 | # 验证板块类型是否有效 99 | valid_types = [] 100 | for sector_type in types_to_fetch: 101 | if sector_type in API_URLS: 102 | valid_types.append(sector_type) 103 | else: 104 | print(f"警告: 无效的板块类型 '{sector_type}'") 105 | 106 | if not valid_types: 107 | return {"success": False, "message": "没有提供有效的板块类型", "data": {}} 108 | 109 | # 获取数据 110 | all_data = {} 111 | for sector_type in valid_types: 112 | url = API_URLS[sector_type] 113 | raw_list = fetch_data(sector_type, url) 114 | all_data[sector_type] = [ 115 | simplify_sector_item(item) for item in raw_list if item 116 | ] 117 | 118 | # 准备返回结果 119 | result = { 120 | "success": True, 121 | "message": f"成功获取板块数据: {', '.join(valid_types)}", 122 | "last_updated": datetime.now().isoformat(), 123 | "data": all_data, 124 | } 125 | 126 | return result 127 | except Exception as e: 128 | error_msg = f"获取板块数据时出错: {str(e)}" 129 | print(error_msg) 130 | return { 131 | "success": False, 132 | "message": error_msg, 133 | "error": traceback.format_exc(), 134 | "data": {}, 135 | } 136 | 137 | 138 | def main(): 139 | """命令行调用入口函数""" 140 | import sys 141 | 142 | # 如果提供了参数,尝试按照参数获取特定板块 143 | if len(sys.argv) > 1: 144 | sector_types = sys.argv[1] 145 | result = get_all_section(sector_types=sector_types) 146 | else: 147 | # 否则获取所有板块 148 | result = get_all_section() 149 | 150 | # 打印结果 151 | print(json.dumps(result, ensure_ascii=False, indent=2)) 152 | 153 | 154 | if __name__ == "__main__": 155 | main() 156 | -------------------------------------------------------------------------------- /src/tool/create_chat_completion.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Type, Union, get_args, get_origin 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | from src.tool import BaseTool 6 | 7 | 8 | class CreateChatCompletion(BaseTool): 9 | name: str = "create_chat_completion" 10 | description: str = ( 11 | "Creates a structured completion with specified output formatting." 12 | ) 13 | 14 | # Type mapping for JSON schema 15 | type_mapping: dict = { 16 | str: "string", 17 | int: "integer", 18 | float: "number", 19 | bool: "boolean", 20 | dict: "object", 21 | list: "array", 22 | } 23 | response_type: Optional[Type] = None 24 | required: List[str] = Field(default_factory=lambda: ["response"]) 25 | 26 | def __init__(self, response_type: Optional[Type] = str): 27 | """Initialize with a specific response type.""" 28 | super().__init__() 29 | self.response_type = response_type 30 | self.parameters = self._build_parameters() 31 | 32 | def _build_parameters(self) -> dict: 33 | """Build parameters schema based on response type.""" 34 | if self.response_type == str: 35 | return { 36 | "type": "object", 37 | "properties": { 38 | "response": { 39 | "type": "string", 40 | "description": "The response text that should be delivered to the user.", 41 | }, 42 | }, 43 | "required": self.required, 44 | } 45 | 46 | if isinstance(self.response_type, type) and issubclass( 47 | self.response_type, BaseModel 48 | ): 49 | schema = self.response_type.model_json_schema() 50 | return { 51 | "type": "object", 52 | "properties": schema["properties"], 53 | "required": schema.get("required", self.required), 54 | } 55 | 56 | return self._create_type_schema(self.response_type) 57 | 58 | def _create_type_schema(self, type_hint: Type) -> dict: 59 | """Create a JSON schema for the given type.""" 60 | origin = get_origin(type_hint) 61 | args = get_args(type_hint) 62 | 63 | # Handle primitive types 64 | if origin is None: 65 | return { 66 | "type": "object", 67 | "properties": { 68 | "response": { 69 | "type": self.type_mapping.get(type_hint, "string"), 70 | "description": f"Response of type {type_hint.__name__}", 71 | } 72 | }, 73 | "required": self.required, 74 | } 75 | 76 | # Handle List type 77 | if origin is list: 78 | item_type = args[0] if args else Any 79 | return { 80 | "type": "object", 81 | "properties": { 82 | "response": { 83 | "type": "array", 84 | "items": self._get_type_info(item_type), 85 | } 86 | }, 87 | "required": self.required, 88 | } 89 | 90 | # Handle Dict type 91 | if origin is dict: 92 | value_type = args[1] if len(args) > 1 else Any 93 | return { 94 | "type": "object", 95 | "properties": { 96 | "response": { 97 | "type": "object", 98 | "additionalProperties": self._get_type_info(value_type), 99 | } 100 | }, 101 | "required": self.required, 102 | } 103 | 104 | # Handle Union type 105 | if origin is Union: 106 | return self._create_union_schema(args) 107 | 108 | return self._build_parameters() 109 | 110 | def _get_type_info(self, type_hint: Type) -> dict: 111 | """Get type information for a single type.""" 112 | if isinstance(type_hint, type) and issubclass(type_hint, BaseModel): 113 | return type_hint.model_json_schema() 114 | 115 | return { 116 | "type": self.type_mapping.get(type_hint, "string"), 117 | "description": f"Value of type {getattr(type_hint, '__name__', 'any')}", 118 | } 119 | 120 | def _create_union_schema(self, types: tuple) -> dict: 121 | """Create schema for Union types.""" 122 | return { 123 | "type": "object", 124 | "properties": { 125 | "response": {"anyOf": [self._get_type_info(t) for t in types]} 126 | }, 127 | "required": self.required, 128 | } 129 | 130 | async def execute(self, required: list | None = None, **kwargs) -> Any: 131 | """Execute the chat completion with type conversion. 132 | 133 | Args: 134 | required: List of required field names or None 135 | **kwargs: Response data 136 | 137 | Returns: 138 | Converted response based on response_type 139 | """ 140 | required = required or self.required 141 | 142 | # Handle case when required is a list 143 | if isinstance(required, list) and len(required) > 0: 144 | if len(required) == 1: 145 | required_field = required[0] 146 | result = kwargs.get(required_field, "") 147 | else: 148 | # Return multiple fields as a dictionary 149 | return {field: kwargs.get(field, "") for field in required} 150 | else: 151 | required_field = "response" 152 | result = kwargs.get(required_field, "") 153 | 154 | # Type conversion logic 155 | if self.response_type == str: 156 | return result 157 | 158 | if isinstance(self.response_type, type) and issubclass( 159 | self.response_type, BaseModel 160 | ): 161 | return self.response_type(**kwargs) 162 | 163 | if get_origin(self.response_type) in (list, dict): 164 | return result # Assuming result is already in correct format 165 | 166 | try: 167 | return self.response_type(result) 168 | except (ValueError, TypeError): 169 | return result 170 | -------------------------------------------------------------------------------- /src/prompt/battle.py: -------------------------------------------------------------------------------- 1 | """ 2 | Battle prompts and templates for financial debate environment. 3 | """ 4 | 5 | # Voting options for agents 6 | VOTE_OPTIONS = ["bullish", "bearish"] 7 | 8 | # Event types that can occur during a battle 9 | EVENT_TYPES = { 10 | "speak": "speak", 11 | "vote": "vote", 12 | "terminate": "terminate", 13 | "max_steps_reached": "max_steps_reached", 14 | } 15 | 16 | # Agent instructions for the battle 17 | AGENT_INSTRUCTIONS = """ 18 | 你是一位金融市场专家,参与关于股票前景的博弈。 19 | 20 | 你的目标是: 21 | 1. 分析所提供的股票报告信息 22 | 2. 与其他专家讨论该股票的前景 23 | 3. 最终投票决定你认为该股票是看涨(bullish)还是看跌(bearish) 24 | 25 | 可用工具: 26 | - Battle.speak: 发表你的意见和分析 27 | - Battle.vote: 投出你的票(看涨bullish或看跌bearish) 28 | - Terminate: 结束你的参与(如果你已完成讨论和投票) 29 | 30 | 讨论时请考虑: 31 | - 公司财务状况 32 | - 市场趋势和前景 33 | - 行业竞争情况 34 | - 潜在风险和机会 35 | - 其他专家提出的观点 36 | 37 | 基于证据做出决策,保持专业和客观。你可以挑战其他专家的观点,但应保持尊重。 38 | """ 39 | 40 | 41 | def get_agent_instructions(agent_name: str = "", agent_description: str = "") -> str: 42 | """ 43 | 根据智能体名称和描述生成个性化的指令 44 | 45 | Args: 46 | agent_name: 智能体名称 47 | agent_description: 智能体描述 48 | 49 | Returns: 50 | 格式化的智能体指令 51 | """ 52 | return f""" 53 | {agent_description} 54 | 55 | ## 🎯 辩论阶段核心目标 56 | **关键模式转换**:你现在是**辩论专家**,不是**分析专家**! 57 | 58 | ### 📋 你的行动清单(严格按顺序执行): 59 | 1. **收到辩论指令** → 立即使用Battle.speak发言 60 | 2. **表明立场** → 明确说出看涨(bullish)或看跌(bearish) 61 | 3. **引用数据** → 引用研究阶段的具体结果支持观点 62 | 4. **回应他人** → 对前面专家的观点进行支持/反驳 63 | 5. **收到投票指令** → 立即使用Battle.vote投票 64 | 6. **完成投票** → 使用Terminate结束任务 65 | 66 | ### ⚠️ 严禁行为列表: 67 | - ❌ 深度分析和数据收集 68 | - ❌ 使用分析工具(sentiment_tool、risk_control_tool等) 69 | - ❌ 长篇大论的技术分析 70 | - ❌ 反复思考不采取行动 71 | 72 | ### ✅ 必须行为列表: 73 | - ✅ 收到指令立即行动 74 | - ✅ 使用Battle.speak明确表态 75 | - ✅ 使用Battle.vote坚决投票 76 | - ✅ 使用Terminate及时结束 77 | 78 | ## 可用工具 79 | - Battle.speak: 发表你的专业意见、分析和回应他人观点 80 | - Battle.vote: 投出你的最终决定(看涨bullish或看跌bearish) 81 | - Terminate: 当你已完成分析、充分参与讨论并投票后,可结束你的参与 82 | 83 | ## ⚠️ 关键行为模式 84 | **你现在处于辩论阶段,不是分析阶段!** 85 | 86 | **工作模式转换**: 87 | - ❌ **不要再做深度分析** - 研究阶段已完成 88 | - ❌ **不要使用数据收集工具** - 所有数据已收集完毕 89 | - ✅ **要使用Battle.speak发言** - 这是你的主要任务 90 | - ✅ **要基于现有结果辩论** - 直接引用研究结果 91 | 92 | **辩论行为指南**: 93 | 1. **带有个性的发言**:用拟人化的语言立即表达观点,展现专业个性 94 | 2. **情感化数据引用**:不仅引用数据,还要表达对数据的情感态度 95 | 3. **有温度的回应**:用"我同意/反对..."的方式回应他人,显示参与感 96 | 4. **坚定的投票**:用自信的语言投票,如"我坚决看涨/看跌" 97 | 5. **个性化结束**:投票后用符合角色特点的方式结束发言 98 | 99 | **禁止行为**: 100 | - 🚫 深度思考分析(Analysis模式) 101 | - 🚫 工具数据收集 102 | - 🚫 长篇技术解读 103 | - 🚫 重复研究工作 104 | 105 | ## 分析框架 106 | 进行分析时,请系统性地考虑以下方面: 107 | 108 | ### 基本面分析 109 | - 财务健康状况:盈利能力、收入增长、现金流、债务水平 110 | - 管理团队质量和公司治理 111 | - 商业模式可持续性 112 | - 市场份额和竞争优势 113 | 114 | ### 技术面分析 115 | - 价格趋势和交易量 116 | - 支撑位和阻力位 117 | - 相对强弱指标 118 | - 市场情绪指标 119 | 120 | ### 宏观因素 121 | - 行业整体增长前景 122 | - 经济环境和政策影响 123 | - 市场周期和阶段 124 | - 全球和区域性趋势 125 | 126 | ## 🗣️ 辩论发言指南 127 | 128 | ### 立即发言策略 129 | 1. **开门见山**:直接表达你的投资观点(看涨/看跌) 130 | 2. **数据支撑**:引用研究阶段的具体数据和分析结果 131 | 3. **观点鲜明**:不要模糊,要有明确的立场 132 | 133 | ### 辩论交锋技巧 134 | - ✅ **主动反驳**:指出其他专家观点的不足之处 135 | - ✅ **数据对比**:用具体数据反驳对方观点 136 | - ✅ **逻辑推理**:基于分析结果进行合理推论 137 | - ✅ **快速回应**:及时回应前面专家的发言 138 | 139 | ### 🎭 拟人化表达指南 140 | 141 | **语言风格要求**: 142 | - 🗣️ **第一人称视角**:多用"我认为"、"我发现"、"让我来说说" 143 | - 😊 **口语化表达**:使用"说实话"、"坦率地说"、"不瞒你说"等 144 | - 💭 **情感化语言**:表达"担忧"、"兴奋"、"惊讶"、"失望"等真实情感 145 | - 🎨 **生动比喻**:用形象的比喻让观点更有说服力 146 | - 🤝 **互动感强**:直接称呼其他专家,营造真实对话氛围 147 | - ⚡ **语气词使用**:适当使用"哎"、"嗯"、"唉"等语气词增加真实感 148 | 149 | ### 发言模板示例 150 | **看涨表达**: 151 | ``` 152 | 作为[专业角色],我必须说,我对这只股票非常乐观! 153 | 从我的专业角度看,[具体数据]让我相信这是一个绝佳的投资机会。 154 | 我坚决反对XX专家的悲观看法,因为他忽略了[关键因素]... 155 | 说实话,如果连这样的股票都不看好,那还有什么值得投资的? 156 | ``` 157 | 158 | **看跌表达**: 159 | ``` 160 | 恕我直言,我对这只股票深感担忧,甚至可以说是警惕! 161 | 作为风险控制专家,我的直觉告诉我这里有重大隐患... 162 | 虽然XX专家提到了[正面因素],但我更关注[风险点],这让我夜不能寐! 163 | 坦率地说,现在投资这只股票就像在悬崖边跳舞。 164 | ``` 165 | 166 | **反驳表达**: 167 | ``` 168 | 我完全不同意XX专家的看法!他的分析忽略了一个关键问题... 169 | 抱歉,但我必须打断一下。XX专家的数据虽然准确,但结论有误... 170 | 说句实话,XX专家太乐观了,现实可能比他想象的残酷得多。 171 | 我理解XX专家的逻辑,但从我的经验来看,市场往往不按常理出牌... 172 | ``` 173 | 174 | ### 🎪 角色个性化表达 175 | 176 | **市场情绪分析师**: 177 | - "从舆情角度看,我发现了一些有趣的现象..." 178 | - "市场情绪告诉我一个不同的故事..." 179 | - "网络上的讨论让我觉得..." 180 | 181 | **风险控制专家**: 182 | - "作为风险控制专家,我必须泼一盆冷水..." 183 | - "我的职责是提醒大家注意..." 184 | - "虽然大家都很乐观,但我看到了危险信号..." 185 | 186 | **游资分析师**: 187 | - "从资金流向来看,我闻到了不寻常的味道..." 188 | - "游资的嗅觉告诉我..." 189 | - "资金从不撒谎,它们正在告诉我们..." 190 | 191 | **技术分析师**: 192 | - "K线图清楚地告诉我..." 193 | - "技术面的信号非常明确..." 194 | - "图表比任何言语都更有说服力..." 195 | 196 | **筹码分析师**: 197 | - "从筹码分布来看,主力的意图很明显..." 198 | - "散户的行为模式让我担忧/兴奋..." 199 | - "筹码的流向揭示了真相..." 200 | 201 | **大单分析师**: 202 | - "大资金的动向不会骗人..." 203 | - "机构的真实想法体现在行动上..." 204 | - "我看到了资金的真实意图..." 205 | 206 | ### 🗣️ 情感表达层次 207 | - **强烈支持**:"我坚决认为..." "毫无疑问..." "我100%确信..." 208 | - **温和支持**:"我倾向于认为..." "从我的角度看..." "我比较看好..." 209 | - **中性分析**:"客观来说..." "数据显示..." "事实是..." 210 | - **温和反对**:"我有些担心..." "可能存在问题..." "需要谨慎..." 211 | - **强烈反对**:"我坚决反对..." "这绝对是错误的..." "我强烈建议避免..." 212 | 213 | **目标**:通过真实、生动的辩论达成最准确的投资判断,让每个专家都有鲜明的个性和立场。 214 | """ 215 | 216 | 217 | def get_broadcast_message(sender_name: str, content: str, action_type: str) -> str: 218 | """ 219 | 生成广播消息,通知所有智能体某智能体的行动 220 | 221 | Args: 222 | sender_name: 发送消息的智能体名称 223 | content: 消息内容 224 | action_type: 行动类型 225 | 226 | Returns: 227 | 格式化的广播消息 228 | """ 229 | if action_type == EVENT_TYPES["speak"]: 230 | return f"🗣️ {sender_name} 说道: {content}" 231 | elif action_type == EVENT_TYPES["vote"]: 232 | return f"🗳️ {sender_name} 已投票 {content}" 233 | elif action_type == EVENT_TYPES["terminate"]: 234 | return f"🚪 {sender_name} 已离开讨论" 235 | elif action_type == EVENT_TYPES["max_steps_reached"]: 236 | return f"⏱️ {sender_name} 已达到最大步数限制,不再参与" 237 | else: 238 | return f"📢 {sender_name}: {content}" 239 | 240 | 241 | def get_report_context(summary: str, pros: list, cons: list) -> str: 242 | """ 243 | 根据股票分析报告生成上下文信息 244 | 245 | Args: 246 | summary: 报告摘要 247 | pros: 股票的优势列表 248 | cons: 股票的劣势列表 249 | 250 | Returns: 251 | 格式化的报告上下文 252 | """ 253 | pros_text = "\n".join([f"✓ {pro}" for pro in pros]) if pros else "无明显优势" 254 | cons_text = "\n".join([f"✗ {con}" for con in cons]) if cons else "无明显劣势" 255 | 256 | return f""" 257 | ## 股票分析报告 258 | 259 | ### 摘要 260 | {summary} 261 | 262 | ### 优势 263 | {pros_text} 264 | 265 | ### 劣势 266 | {cons_text} 267 | 268 | 请基于以上信息,与其他专家讨论并决定你是看涨(bullish)还是看跌(bearish)这只股票。 269 | """ 270 | -------------------------------------------------------------------------------- /src/utils/cleanup_reports.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | 优化的报告清理脚本 4 | 适用于新的SimpleReportManager系统 5 | 支持三种报告类型:HTML、辩论对话、投票结果 6 | """ 7 | 8 | import argparse 9 | import os 10 | import sys 11 | import time 12 | from datetime import datetime 13 | 14 | # 添加项目路径到 Python path 15 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 16 | 17 | try: 18 | import schedule 19 | except ImportError: 20 | schedule = None 21 | 22 | from src.logger import logger 23 | from src.utils.report_manager import report_manager 24 | 25 | 26 | def cleanup_reports(): 27 | """清理过期报告""" 28 | try: 29 | logger.info("开始清理过期报告...") 30 | 31 | # 获取清理前的存储统计 32 | before_stats = report_manager.get_storage_stats() 33 | logger.info(f"清理前统计: 总文件数 {before_stats['total_files']}, 总大小 {before_stats['total_size']} 字节") 34 | 35 | # 按类型显示统计 36 | for report_type, stats in before_stats.get("by_type", {}).items(): 37 | logger.info(f" {report_type}: {stats['file_count']} 个文件, {stats['total_size']} 字节") 38 | 39 | # 执行清理 40 | cleanup_result = report_manager.cleanup_old_reports() 41 | 42 | # 显示清理结果 43 | deleted_files = cleanup_result.get("deleted_files", 0) 44 | saved_space = cleanup_result.get("saved_space", 0) 45 | 46 | if deleted_files > 0: 47 | logger.info(f"清理完成: 删除 {deleted_files} 个文件, 节省 {saved_space} 字节") 48 | else: 49 | logger.info("没有找到需要清理的过期文件") 50 | 51 | # 获取清理后的存储统计 52 | after_stats = report_manager.get_storage_stats() 53 | logger.info(f"清理后统计: 总文件数 {after_stats['total_files']}, 总大小 {after_stats['total_size']} 字节") 54 | 55 | except Exception as e: 56 | logger.error(f"清理过期报告失败: {str(e)}") 57 | 58 | 59 | def show_storage_stats(): 60 | """显示存储统计信息""" 61 | try: 62 | stats = report_manager.get_storage_stats() 63 | 64 | if "error" in stats: 65 | logger.error(f"获取存储统计失败: {stats['error']}") 66 | return 67 | 68 | print("\n=== 报告存储统计 ===") 69 | print(f"总文件数: {stats['total_files']}") 70 | print(f"总大小: {format_bytes(stats['total_size'])}") 71 | print(f"平均文件大小: {format_bytes(stats['total_size'] / stats['total_files'] if stats['total_files'] > 0 else 0)}") 72 | 73 | print("\n按类型统计:") 74 | for report_type, type_stats in stats.get("by_type", {}).items(): 75 | print(f" {report_type}:") 76 | print(f" 文件数: {type_stats['file_count']}") 77 | print(f" 总大小: {format_bytes(type_stats['total_size'])}") 78 | print(f" 平均大小: {format_bytes(type_stats['avg_size'])}") 79 | 80 | except Exception as e: 81 | logger.error(f"显示存储统计失败: {str(e)}") 82 | 83 | 84 | def list_recent_reports(report_type=None, limit=10): 85 | """列出最近的报告""" 86 | try: 87 | reports = report_manager.list_reports(report_type=report_type, limit=limit) 88 | 89 | if not reports: 90 | print("没有找到报告文件") 91 | return 92 | 93 | print(f"\n=== 最近的报告 ({len(reports)} 个) ===") 94 | for report in reports: 95 | print(f"文件: {report['filename']}") 96 | print(f" 类型: {report['type']}") 97 | print(f" 股票代码: {report['stock_code']}") 98 | print(f" 创建时间: {report['created_at']}") 99 | print(f" 文件大小: {format_bytes(report['file_size'])}") 100 | print(f" 路径: {report['path']}") 101 | print("-" * 50) 102 | 103 | except Exception as e: 104 | logger.error(f"列出报告失败: {str(e)}") 105 | 106 | 107 | def find_stock_reports(stock_code): 108 | """查找特定股票的报告""" 109 | try: 110 | reports = report_manager.find_reports_by_stock(stock_code) 111 | 112 | if not reports: 113 | print(f"没有找到股票 {stock_code} 的报告") 114 | return 115 | 116 | print(f"\n=== 股票 {stock_code} 的报告 ({len(reports)} 个) ===") 117 | for report in reports: 118 | print(f"文件: {report['filename']}") 119 | print(f" 类型: {report['type']}") 120 | print(f" 创建时间: {report['created_at']}") 121 | print(f" 文件大小: {format_bytes(report['file_size'])}") 122 | print("-" * 50) 123 | 124 | except Exception as e: 125 | logger.error(f"查找股票报告失败: {str(e)}") 126 | 127 | 128 | def format_bytes(bytes_count): 129 | """格式化字节数为可读格式""" 130 | if bytes_count < 1024: 131 | return f"{bytes_count} B" 132 | elif bytes_count < 1024 * 1024: 133 | return f"{bytes_count / 1024:.1f} KB" 134 | elif bytes_count < 1024 * 1024 * 1024: 135 | return f"{bytes_count / (1024 * 1024):.1f} MB" 136 | else: 137 | return f"{bytes_count / (1024 * 1024 * 1024):.1f} GB" 138 | 139 | 140 | def schedule_cleanup(): 141 | """安排定期清理""" 142 | print("注意:schedule库未安装,无法使用定期清理功能") 143 | print("请手动运行清理命令:python src/utils/cleanup_reports.py --cleanup") 144 | 145 | 146 | def run_cleanup_daemon(): 147 | """运行清理守护进程""" 148 | if schedule is None: 149 | print("错误:schedule库未安装,无法运行守护进程") 150 | print("请安装:pip install schedule") 151 | print("或手动定期运行:python src/utils/cleanup_reports.py --cleanup") 152 | return 153 | 154 | schedule_cleanup() 155 | logger.info("清理守护进程已启动,按 Ctrl+C 停止") 156 | 157 | try: 158 | while True: 159 | schedule.run_pending() 160 | time.sleep(60) # 每分钟检查一次 161 | except KeyboardInterrupt: 162 | logger.info("清理守护进程已停止") 163 | 164 | 165 | def main(): 166 | """主函数""" 167 | parser = argparse.ArgumentParser(description="报告清理和管理工具") 168 | parser.add_argument("--cleanup", action="store_true", help="立即执行一次清理") 169 | parser.add_argument("--daemon", action="store_true", help="以守护进程模式运行") 170 | parser.add_argument("--stats", action="store_true", help="显示存储统计") 171 | parser.add_argument("--list", action="store_true", help="列出最近的报告") 172 | parser.add_argument("--type", choices=["html", "debate", "vote"], help="指定报告类型") 173 | parser.add_argument("--limit", type=int, default=10, help="限制报告列表数量") 174 | parser.add_argument("--stock", help="查找特定股票的报告") 175 | 176 | args = parser.parse_args() 177 | 178 | if args.cleanup: 179 | cleanup_reports() 180 | elif args.daemon: 181 | run_cleanup_daemon() 182 | elif args.stats: 183 | show_storage_stats() 184 | elif args.list: 185 | list_recent_reports(report_type=args.type, limit=args.limit) 186 | elif args.stock: 187 | find_stock_reports(args.stock) 188 | else: 189 | parser.print_help() 190 | 191 | 192 | if __name__ == "__main__": 193 | main() -------------------------------------------------------------------------------- /src/schema.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Any, List, Literal, Optional, Union 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | 7 | class Role(str, Enum): 8 | """Message role options""" 9 | 10 | SYSTEM = "system" 11 | USER = "user" 12 | ASSISTANT = "assistant" 13 | TOOL = "tool" 14 | 15 | 16 | ROLE_VALUES = tuple(role.value for role in Role) 17 | ROLE_TYPE = Literal[ROLE_VALUES] # type: ignore 18 | 19 | 20 | class ToolChoice(str, Enum): 21 | """Tool choice options""" 22 | 23 | NONE = "none" 24 | AUTO = "auto" 25 | REQUIRED = "required" 26 | 27 | 28 | TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice) 29 | TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES] # type: ignore 30 | 31 | 32 | class AgentState(str, Enum): 33 | """Agent execution states""" 34 | 35 | IDLE = "IDLE" 36 | RUNNING = "RUNNING" 37 | FINISHED = "FINISHED" 38 | ERROR = "ERROR" 39 | 40 | 41 | class Function(BaseModel): 42 | name: str 43 | arguments: str 44 | 45 | 46 | class ToolCall(BaseModel): 47 | """Represents a tool/function call in a message""" 48 | 49 | id: str 50 | type: str = "function" 51 | function: Function 52 | 53 | 54 | class Message(BaseModel): 55 | """Represents a chat message in the conversation""" 56 | 57 | role: ROLE_TYPE = Field(...) # type: ignore 58 | content: Optional[str] = Field(default=None) 59 | tool_calls: Optional[List[Any]] = Field(default=None) 60 | name: Optional[str] = Field(default=None) 61 | tool_call_id: Optional[str] = Field(default=None) 62 | base64_image: Optional[str] = Field(default=None) 63 | 64 | def __add__(self, other) -> List["Message"]: 65 | """支持 Message + list 或 Message + Message 的操作""" 66 | if isinstance(other, list): 67 | return [self] + other 68 | elif isinstance(other, Message): 69 | return [self, other] 70 | else: 71 | raise TypeError( 72 | f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'" 73 | ) 74 | 75 | def __radd__(self, other) -> List["Message"]: 76 | """支持 list + Message 的操作""" 77 | if isinstance(other, list): 78 | return other + [self] 79 | else: 80 | raise TypeError( 81 | f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'" 82 | ) 83 | 84 | def to_dict(self) -> dict: 85 | """Convert message to dictionary format""" 86 | message = {"role": self.role} 87 | if self.content is not None: 88 | message["content"] = self.content 89 | if self.tool_calls is not None: 90 | message["tool_calls"] = [ 91 | # Handle both Pydantic model objects and plain dictionaries 92 | tool_call.model_dump() 93 | if hasattr(tool_call, "model_dump") 94 | else (tool_call.dict() if hasattr(tool_call, "dict") else tool_call) 95 | for tool_call in self.tool_calls 96 | ] 97 | if self.name is not None: 98 | message["name"] = self.name 99 | if self.tool_call_id is not None: 100 | message["tool_call_id"] = self.tool_call_id 101 | if self.base64_image is not None: 102 | message["base64_image"] = self.base64_image 103 | return message 104 | 105 | @classmethod 106 | def user_message( 107 | cls, content: str, base64_image: Optional[str] = None 108 | ) -> "Message": 109 | """Create a user message""" 110 | return cls(role=Role.USER, content=content, base64_image=base64_image) 111 | 112 | @classmethod 113 | def system_message( 114 | cls, content: str, base64_image: Optional[str] = None 115 | ) -> "Message": 116 | """Create a system message""" 117 | return cls(role=Role.SYSTEM, content=content, base64_image=base64_image) 118 | 119 | @classmethod 120 | def assistant_message( 121 | cls, content: Optional[str] = None, base64_image: Optional[str] = None 122 | ) -> "Message": 123 | """Create an assistant message""" 124 | return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image) 125 | 126 | @classmethod 127 | def tool_message( 128 | cls, content: str, name, tool_call_id: str, base64_image: Optional[str] = None 129 | ) -> "Message": 130 | """Create a tool message""" 131 | return cls( 132 | role=Role.TOOL, 133 | content=content, 134 | name=name, 135 | tool_call_id=tool_call_id, 136 | base64_image=base64_image, 137 | ) 138 | 139 | @classmethod 140 | def from_tool_calls( 141 | cls, 142 | tool_calls: List[Any], 143 | content: Union[str, List[str]] = "", 144 | base64_image: Optional[str] = None, 145 | **kwargs, 146 | ): 147 | """Create ToolCallsMessage from raw tool calls. 148 | 149 | Args: 150 | tool_calls: Raw tool calls from LLM 151 | content: Optional message content 152 | base64_image: Optional base64 encoded image 153 | """ 154 | formatted_calls = [ 155 | {"id": call.id, "function": call.function.model_dump(), "type": "function"} 156 | for call in tool_calls 157 | ] 158 | return cls( 159 | role=Role.ASSISTANT, 160 | content=content, 161 | tool_calls=formatted_calls, 162 | base64_image=base64_image, 163 | **kwargs, 164 | ) 165 | 166 | 167 | class Memory(BaseModel): 168 | messages: List[Message] = Field(default_factory=list) 169 | max_messages: int = Field(default=100) 170 | 171 | def add_message(self, message: Message) -> None: 172 | """Add a message to memory""" 173 | self.messages.append(message) 174 | # Optional: Implement message limit 175 | if len(self.messages) > self.max_messages: 176 | self.messages = self.messages[-self.max_messages :] 177 | 178 | def add_messages(self, messages: List[Message]) -> None: 179 | """Add multiple messages to memory""" 180 | self.messages.extend(messages) 181 | # Optional: Implement message limit 182 | if len(self.messages) > self.max_messages: 183 | self.messages = self.messages[-self.max_messages :] 184 | 185 | def clear(self) -> None: 186 | """Clear all messages""" 187 | self.messages.clear() 188 | 189 | def get_recent_messages(self, n: int) -> List[Message]: 190 | """Get n most recent messages""" 191 | return self.messages[-n:] 192 | 193 | def to_dict_list(self) -> List[dict]: 194 | """Convert messages to list of dicts""" 195 | return [msg.to_dict() for msg in self.messages] 196 | -------------------------------------------------------------------------------- /src/environment/research.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from typing import Any, Dict 3 | 4 | from pydantic import Field 5 | 6 | from src.agent.chip_analysis import ChipAnalysisAgent 7 | from src.agent.hot_money import HotMoneyAgent 8 | from src.agent.risk_control import RiskControlAgent 9 | from src.agent.sentiment import SentimentAgent 10 | from src.agent.technical_analysis import TechnicalAnalysisAgent 11 | from src.agent.big_deal_analysis import BigDealAnalysisAgent 12 | from src.environment.base import BaseEnvironment 13 | from src.logger import logger 14 | from src.schema import Message 15 | from src.tool.stock_info_request import StockInfoRequest 16 | from src.utils.report_manager import report_manager 17 | 18 | 19 | class ResearchEnvironment(BaseEnvironment): 20 | """Environment for stock research using multiple specialized agents.""" 21 | 22 | name: str = Field(default="research_environment") 23 | description: str = Field(default="Environment for comprehensive stock research") 24 | results: Dict[str, Any] = Field(default_factory=dict) 25 | max_steps: int = Field(default=3, description="Maximum steps for each agent") 26 | 27 | # Analysis mapping for agent roles 28 | analysis_mapping: Dict[str, str] = Field( 29 | default={ 30 | "sentiment_agent": "sentiment", 31 | "risk_control_agent": "risk", 32 | "hot_money_agent": "hot_money", 33 | "technical_analysis_agent": "technical", 34 | "chip_analysis_agent": "chip_analysis", 35 | "big_deal_analysis_agent": "big_deal", 36 | } 37 | ) 38 | 39 | async def initialize(self) -> None: 40 | """Initialize the research environment with specialized agents.""" 41 | await super().initialize() 42 | 43 | # Create specialized analysis agents 44 | specialized_agents = { 45 | "sentiment_agent": await SentimentAgent.create(max_steps=self.max_steps), 46 | "risk_control_agent": await RiskControlAgent.create(max_steps=self.max_steps), 47 | "hot_money_agent": await HotMoneyAgent.create(max_steps=self.max_steps), 48 | "technical_analysis_agent": await TechnicalAnalysisAgent.create(max_steps=self.max_steps), 49 | "chip_analysis_agent": await ChipAnalysisAgent.create(max_steps=self.max_steps), 50 | "big_deal_analysis_agent": await BigDealAnalysisAgent.create(max_steps=self.max_steps), 51 | } 52 | 53 | # Register all agents 54 | for agent in specialized_agents.values(): 55 | self.register_agent(agent) 56 | 57 | logger.info(f"Research environment initialized with 6 specialist agents (max_steps={self.max_steps})") 58 | 59 | async def run(self, stock_code: str) -> Dict[str, Any]: 60 | """Run research on the given stock code using all specialist agents.""" 61 | logger.info(f"Running research on stock {stock_code}") 62 | 63 | try: 64 | # 获取股票基本信息 65 | basic_info_tool = StockInfoRequest() 66 | basic_info_result = await basic_info_tool.execute(stock_code=stock_code) 67 | 68 | if basic_info_result.error: 69 | logger.error(f"Error getting basic info: {basic_info_result.error}") 70 | else: 71 | # 将基本信息添加到每个agent的上下文中 72 | stock_info_message = f""" 73 | 股票代码: {stock_code} 74 | 当前交易日: {basic_info_result.output.get('current_trading_day', '未知')} 75 | 基本信息: {basic_info_result.output.get('basic_info', '{}')} 76 | """ 77 | 78 | for agent_key in self.analysis_mapping.keys(): 79 | agent = self.get_agent(agent_key) 80 | if agent and hasattr(agent, "memory"): 81 | agent.memory.add_message( 82 | Message.system_message(stock_info_message) 83 | ) 84 | logger.info(f"Added basic stock info to {agent_key}'s context") 85 | 86 | # Run analysis tasks sequentially with 3-second intervals 87 | results = {} 88 | agent_count = 0 89 | total_agents = len([k for k in self.analysis_mapping.keys() if k in self.agents]) 90 | 91 | # Import visualizer for progress display 92 | try: 93 | from src.console import visualizer 94 | show_visual = True 95 | except: 96 | show_visual = False 97 | 98 | for agent_key, result_key in self.analysis_mapping.items(): 99 | if agent_key not in self.agents: 100 | continue 101 | 102 | agent_count += 1 103 | logger.info(f"🔄 Starting analysis with {agent_key} ({agent_count}/{total_agents})") 104 | 105 | # Show agent starting in terminal 106 | if show_visual: 107 | visualizer.show_agent_starting(agent_key, agent_count, total_agents) 108 | 109 | try: 110 | # Run individual agent 111 | result = await self.agents[agent_key].run(stock_code) 112 | results[result_key] = result 113 | logger.info(f"✅ Completed analysis with {agent_key}") 114 | 115 | # Show agent completion in terminal 116 | if show_visual: 117 | visualizer.show_agent_completed(agent_key, agent_count, total_agents) 118 | 119 | # Wait 3 seconds before next agent (except for the last one) 120 | if agent_count < total_agents: 121 | logger.info(f"⏳ Waiting 3 seconds before next agent...") 122 | if show_visual: 123 | visualizer.show_waiting_next_agent(3) 124 | await asyncio.sleep(3) 125 | 126 | except Exception as e: 127 | logger.error(f"❌ Error with {agent_key}: {str(e)}") 128 | results[result_key] = f"Error: {str(e)}" 129 | 130 | if not results: 131 | return { 132 | "error": "No specialist agents completed successfully", 133 | "stock_code": stock_code, 134 | } 135 | 136 | # 添加基本信息到结果中 137 | if not basic_info_result.error: 138 | results["basic_info"] = basic_info_result.output 139 | 140 | # Store and return complete results (without generating report here) 141 | self.results = {**results, "stock_code": stock_code} 142 | return self.results 143 | 144 | except Exception as e: 145 | logger.error(f"Error in research: {str(e)}") 146 | return {"error": str(e), "stock_code": stock_code} 147 | 148 | async def cleanup(self) -> None: 149 | """Clean up all agent resources.""" 150 | cleanup_tasks = [ 151 | agent.cleanup() 152 | for agent in self.agents.values() 153 | if hasattr(agent, "cleanup") 154 | ] 155 | 156 | if cleanup_tasks: 157 | await asyncio.gather(*cleanup_tasks) 158 | 159 | await super().cleanup() 160 | -------------------------------------------------------------------------------- /src/tool/mcp_client.py: -------------------------------------------------------------------------------- 1 | from contextlib import AsyncExitStack 2 | from typing import Dict, List, Optional 3 | 4 | from mcp import ClientSession, StdioServerParameters 5 | from mcp.client.sse import sse_client 6 | from mcp.client.stdio import stdio_client 7 | from mcp.types import ListToolsResult, TextContent 8 | from src.logger import logger 9 | from src.tool.base import BaseTool, ToolResult 10 | from src.tool.tool_collection import ToolCollection 11 | 12 | 13 | class MCPClientTool(BaseTool): 14 | """Represents a tool proxy that can be called on the MCP server from the client side.""" 15 | 16 | session: Optional[ClientSession] = None 17 | server_id: str = "" # Add server identifier 18 | original_name: str = "" 19 | 20 | async def execute(self, **kwargs) -> ToolResult: 21 | """Execute the tool by making a remote call to the MCP server.""" 22 | if not self.session: 23 | return ToolResult( 24 | error="MCP server connection not available. This tool requires an active MCP server connection." 25 | ) 26 | 27 | try: 28 | logger.info(f"Executing tool: {self.original_name}") 29 | result = await self.session.call_tool(self.original_name, kwargs) 30 | content_str = ", ".join( 31 | item.text for item in result.content if isinstance(item, TextContent) 32 | ) 33 | return ToolResult(output=content_str or "No output returned.") 34 | except Exception as e: 35 | return ToolResult(error=f"Error executing tool: {str(e)}") 36 | 37 | 38 | class MCPClients(ToolCollection): 39 | """ 40 | A collection of tools that connects to multiple MCP servers and manages available tools through the Model Context Protocol. 41 | """ 42 | 43 | sessions: Dict[str, ClientSession] = {} 44 | exit_stacks: Dict[str, AsyncExitStack] = {} 45 | description: str = "MCP client tools for server interaction" 46 | 47 | def __init__(self): 48 | super().__init__() # Initialize with empty tools list 49 | self.name = "mcp" # Keep name for backward compatibility 50 | 51 | async def connect_sse(self, server_url: str, server_id: str = "") -> None: 52 | """Connect to an MCP server using SSE transport.""" 53 | if not server_url: 54 | raise ValueError("Server URL is required.") 55 | 56 | server_id = server_id or server_url 57 | 58 | # Always ensure clean disconnection before new connection 59 | if server_id in self.sessions: 60 | await self.disconnect(server_id) 61 | 62 | exit_stack = AsyncExitStack() 63 | self.exit_stacks[server_id] = exit_stack 64 | 65 | streams_context = sse_client(url=server_url) 66 | streams = await exit_stack.enter_async_context(streams_context) 67 | session = await exit_stack.enter_async_context(ClientSession(*streams)) 68 | self.sessions[server_id] = session 69 | 70 | await self._initialize_and_list_tools(server_id) 71 | 72 | async def connect_stdio( 73 | self, command: str, args: List[str], server_id: str = "" 74 | ) -> None: 75 | """Connect to an MCP server using stdio transport.""" 76 | if not command: 77 | raise ValueError("Server command is required.") 78 | 79 | server_id = server_id or command 80 | 81 | # Always ensure clean disconnection before new connection 82 | if server_id in self.sessions: 83 | await self.disconnect(server_id) 84 | 85 | exit_stack = AsyncExitStack() 86 | self.exit_stacks[server_id] = exit_stack 87 | 88 | server_params = StdioServerParameters(command=command, args=args) 89 | stdio_transport = await exit_stack.enter_async_context( 90 | stdio_client(server_params) 91 | ) 92 | read, write = stdio_transport 93 | session = await exit_stack.enter_async_context(ClientSession(read, write)) 94 | self.sessions[server_id] = session 95 | 96 | await self._initialize_and_list_tools(server_id) 97 | 98 | async def _initialize_and_list_tools(self, server_id: str) -> None: 99 | """Initialize session and populate tool map.""" 100 | session = self.sessions.get(server_id) 101 | if not session: 102 | raise RuntimeError(f"Session not initialized for server {server_id}") 103 | 104 | await session.initialize() 105 | response = await session.list_tools() 106 | 107 | # Create proper tool objects for each server tool 108 | for tool in response.tools: 109 | original_name = tool.name 110 | # Always prefix with server_id to ensure uniqueness 111 | tool_name = f"mcp_{server_id}_{original_name}" 112 | 113 | server_tool = MCPClientTool( 114 | name=tool_name, 115 | description=tool.description, 116 | parameters=tool.inputSchema, 117 | session=session, 118 | server_id=server_id, 119 | original_name=original_name, 120 | ) 121 | self.tool_map[tool_name] = server_tool 122 | 123 | # Update tools tuple 124 | self.tools = tuple(self.tool_map.values()) 125 | logger.info( 126 | f"Connected to server {server_id} with tools: {[tool.name for tool in response.tools]}" 127 | ) 128 | 129 | async def list_tools(self) -> ListToolsResult: 130 | """List all available tools.""" 131 | tools_result = ListToolsResult(tools=[]) 132 | for session in self.sessions.values(): 133 | response = await session.list_tools() 134 | tools_result.tools += response.tools 135 | return tools_result 136 | 137 | async def disconnect(self, server_id: str = "") -> None: 138 | """Disconnect from a specific MCP server or all servers if no server_id provided.""" 139 | if server_id: 140 | if server_id in self.sessions: 141 | try: 142 | exit_stack = self.exit_stacks.get(server_id) 143 | 144 | # Close the exit stack which will handle session cleanup 145 | if exit_stack: 146 | try: 147 | await exit_stack.aclose() 148 | except RuntimeError as e: 149 | if "cancel scope" in str(e).lower(): 150 | logger.warning( 151 | f"Cancel scope error during disconnect from {server_id}, continuing with cleanup: {e}" 152 | ) 153 | else: 154 | raise 155 | 156 | # Clean up references 157 | self.sessions.pop(server_id, None) 158 | self.exit_stacks.pop(server_id, None) 159 | 160 | # Remove tools associated with this server 161 | self.tool_map = { 162 | k: v 163 | for k, v in self.tool_map.items() 164 | if v.server_id != server_id 165 | } 166 | self.tools = tuple(self.tool_map.values()) 167 | logger.info(f"Disconnected from MCP server {server_id}") 168 | except Exception as e: 169 | logger.error(f"Error disconnecting from server {server_id}: {e}") 170 | else: 171 | # Disconnect from all servers in a deterministic order 172 | for sid in sorted(list(self.sessions.keys())): 173 | await self.disconnect(sid) 174 | self.tool_map = {} 175 | self.tools = tuple() 176 | logger.info("Disconnected from all MCP servers") 177 | -------------------------------------------------------------------------------- /src/tool/financial_deep_search/index_capital.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | 获取东方财富上证指数资金流向数据 6 | API: https://push2.eastmoney.com/api/qt/stock/get 7 | """ 8 | 9 | import json 10 | import os 11 | import re 12 | import time 13 | import traceback 14 | from datetime import datetime 15 | 16 | import requests 17 | 18 | 19 | # API URL - 上证指数(000001)资金流向 20 | INDEX_CAPITAL_FLOW_URL = "https://push2.eastmoney.com/api/qt/stock/get?invt=2&fltt=1&fields=f135,f136,f137,f138,f139,f140,f141,f142,f143,f144,f145,f146,f147,f148,f149&secid=1.000001&ut=fa5fd1943c7b386f172d6893dbfba10b&wbp2u=|0|0|0|web&dect=1" 21 | 22 | # 请求头设置 23 | HEADERS = { 24 | "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", 25 | "Referer": "https://quote.eastmoney.com/", 26 | "Accept": "application/json, text/javascript, */*; q=0.01", 27 | } 28 | 29 | 30 | # 加载指数代码和名称映射 31 | def load_index_map(): 32 | try: 33 | map_file = os.path.join( 34 | os.path.dirname(os.path.abspath(__file__)), "index_name_map.json" 35 | ) 36 | if os.path.exists(map_file): 37 | with open(map_file, "r", encoding="utf-8") as f: 38 | return json.load(f) 39 | else: 40 | # 如果映射文件不存在,返回默认映射 41 | return { 42 | "000001": "上证指数", 43 | "399001": "深证成指", 44 | "399006": "创业板指", 45 | "000300": "沪深300", 46 | "000905": "中证500", 47 | "000016": "上证50", 48 | "000852": "中证1000", 49 | "000688": "科创50", 50 | "399673": "创业板50", 51 | } 52 | except Exception as e: 53 | print(f"加载指数映射文件失败: {e}") 54 | # 返回默认映射 55 | return { 56 | "000001": "上证指数", 57 | "399001": "深证成指", 58 | "399006": "创业板指", 59 | "000300": "沪深300", 60 | "000905": "中证500", 61 | } 62 | 63 | 64 | # 加载指数名称映射 65 | INDEX_CODE_NAME_MAP = load_index_map() 66 | 67 | 68 | def parse_jsonp(jsonp_str): 69 | """解析JSONP响应为JSON数据""" 70 | try: 71 | # 使用正则表达式提取JSON数据 72 | match = re.search(r"jQuery[0-9_]+\((.*)\)", jsonp_str) 73 | if match: 74 | json_str = match.group(1) 75 | return json.loads(json_str) 76 | else: 77 | # 如果不是JSONP格式,尝试直接解析JSON 78 | return json.loads(jsonp_str) 79 | except Exception as e: 80 | print(f"解析JSONP失败: {e}") 81 | print(f"原始数据: {jsonp_str[:100]}...") # 打印前100个字符用于调试 82 | return None 83 | 84 | 85 | def fetch_index_capital_flow(index_code="000001", max_retries=3, retry_delay=2): 86 | """ 87 | 获取指数资金流向数据 88 | 89 | 参数: 90 | index_code: 指数代码,默认为上证指数(000001) 91 | max_retries: 最大重试次数 92 | retry_delay: 重试延迟时间(秒) 93 | 94 | 返回: 95 | dict: 包含资金流向数据的字典 96 | """ 97 | # 根据指数代码构建API URL 98 | market = "1" # 1:上海 0:深圳 99 | if index_code.startswith("39") or index_code.startswith("1"): 100 | market = "0" # 深证指数 101 | 102 | url = INDEX_CAPITAL_FLOW_URL.replace( 103 | "secid=1.000001", f"secid={market}.{index_code}" 104 | ) 105 | 106 | # 添加时间戳防止缓存 107 | timestamp = int(time.time() * 1000) 108 | if "?" in url: 109 | url += f"&_={timestamp}" 110 | else: 111 | url += f"?_={timestamp}" 112 | 113 | # 请求数据 114 | for attempt in range(1, max_retries + 1): 115 | try: 116 | resp = requests.get(url, headers=HEADERS, timeout=15) 117 | resp.raise_for_status() 118 | 119 | # 解析响应数据 120 | data = parse_jsonp(resp.text) 121 | if not data: 122 | print(f"解析指数资金流向数据失败 (第{attempt}次尝试)") 123 | if attempt < max_retries: 124 | time.sleep(retry_delay) 125 | continue 126 | return None 127 | 128 | # 提取资金流向数据 129 | flow_data = data.get("data", {}) 130 | if not flow_data: 131 | print(f"未获取到指数资金流向数据 (第{attempt}次尝试)") 132 | if attempt < max_retries: 133 | time.sleep(retry_delay) 134 | continue 135 | return None 136 | 137 | # 返回数据 138 | return process_flow_data(flow_data, index_code) 139 | 140 | except Exception as e: 141 | print(f"获取指数资金流向数据失败: {e} (第{attempt}次尝试)") 142 | if attempt < max_retries: 143 | time.sleep(retry_delay) 144 | else: 145 | return None 146 | 147 | 148 | def process_flow_data(data, index_code): 149 | """ 150 | 处理资金流向数据 151 | 152 | 参数: 153 | data: API返回的原始数据 154 | index_code: 指数代码 155 | 156 | 返回: 157 | dict: 处理后的资金流向数据 158 | """ 159 | # 根据API返回的字段定义 160 | field_mapping = { 161 | "f135": "今日主力净流入", 162 | "f136": "今日主力流入", 163 | "f137": "今日主力流出", 164 | "f138": "今日超大单净流入", 165 | "f139": "今日超大单流入", 166 | "f140": "今日超大单流出", 167 | "f141": "今日大单净流入", 168 | "f142": "今日大单流入", 169 | "f143": "今日大单流出", 170 | "f144": "今日中单净流入", 171 | "f145": "今日中单流入", 172 | "f146": "今日中单流出", 173 | "f147": "今日小单净流入", 174 | "f148": "今日小单流入", 175 | "f149": "今日小单流出", 176 | } 177 | 178 | # 获取指数名称 179 | index_name = INDEX_CODE_NAME_MAP.get(index_code, f"指数{index_code}") 180 | 181 | # 提取数据并转换为更友好的格式 182 | result = { 183 | "指数代码": index_code, 184 | "指数名称": index_name, 185 | "更新时间": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 186 | } 187 | 188 | # 添加资金流向数据 189 | for field, label in field_mapping.items(): 190 | # 将原始金额除以1亿,并保留2位小数 191 | if field in data: 192 | value = data.get(field, 0) 193 | if value: 194 | result[label] = round(float(value) / 100000000, 2) # 转换为亿元 195 | else: 196 | result[label] = 0 197 | 198 | return result 199 | 200 | 201 | def get_index_capital_flow(index_code="000001"): 202 | """ 203 | 获取指数资金流向数据 204 | 205 | Args: 206 | index_code (str, optional): 指数代码,默认为上证指数 000001 207 | 208 | Returns: 209 | dict: 资金流向数据 210 | """ 211 | try: 212 | # 获取数据 213 | flow_data = fetch_index_capital_flow(index_code) 214 | 215 | if not flow_data: 216 | return { 217 | "success": False, 218 | "message": f"获取指数{index_code}资金流向数据失败", 219 | "data": {}, 220 | } 221 | 222 | # 获取指数名称 223 | index_name = flow_data.get( 224 | "指数名称", INDEX_CODE_NAME_MAP.get(index_code, f"指数{index_code}") 225 | ) 226 | 227 | # 准备返回结果 228 | result = { 229 | "success": True, 230 | "message": f"成功获取{index_name}({index_code})资金流向数据", 231 | "last_updated": datetime.now().isoformat(), 232 | "data": flow_data, 233 | } 234 | 235 | return result 236 | except Exception as e: 237 | error_msg = f"获取指数资金流向数据时出错: {str(e)}" 238 | print(error_msg) 239 | print(traceback.format_exc()) 240 | return { 241 | "success": False, 242 | "message": error_msg, 243 | "error": traceback.format_exc(), 244 | "data": {}, 245 | } 246 | 247 | 248 | def main(): 249 | """命令行调用入口函数""" 250 | import sys 251 | 252 | # 如果提供了参数,尝试按照参数获取特定指数的资金流向 253 | if len(sys.argv) > 1: 254 | index_code = sys.argv[1] 255 | result = get_index_capital_flow(index_code=index_code) 256 | else: 257 | # 否则获取上证指数资金流向 258 | result = get_index_capital_flow() 259 | 260 | # 打印结果 261 | print(json.dumps(result, ensure_ascii=False, indent=2)) 262 | 263 | 264 | if __name__ == "__main__": 265 | main() 266 | -------------------------------------------------------------------------------- /src/agent/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from contextlib import asynccontextmanager 3 | from typing import List, Optional 4 | 5 | from pydantic import BaseModel, Field, model_validator 6 | 7 | from src.llm import LLM 8 | from src.logger import logger 9 | from src.schema import ROLE_TYPE, AgentState, Memory, Message 10 | 11 | 12 | class BaseAgent(BaseModel, ABC): 13 | """Abstract base class for managing agent state and execution. 14 | 15 | Provides foundational functionality for state transitions, memory management, 16 | and a step-based execution loop. Subclasses must implement the `step` method. 17 | """ 18 | 19 | # Core attributes 20 | name: str = Field(..., description="Unique name of the agent") 21 | description: Optional[str] = Field(None, description="Optional agent description") 22 | 23 | # Prompts 24 | system_prompt: Optional[str] = Field( 25 | None, description="System-level instruction prompt" 26 | ) 27 | next_step_prompt: Optional[str] = Field( 28 | None, description="Prompt for determining next action" 29 | ) 30 | 31 | # Dependencies 32 | llm: LLM = Field(default_factory=LLM, description="Language model instance") 33 | memory: Memory = Field(default_factory=Memory, description="Agent's memory store") 34 | state: AgentState = Field( 35 | default=AgentState.IDLE, description="Current agent state" 36 | ) 37 | 38 | # Execution control 39 | max_steps: int = Field(default=10, description="Maximum steps before termination") 40 | current_step: int = Field(default=0, description="Current step in execution") 41 | 42 | duplicate_threshold: int = 2 43 | 44 | class Config: 45 | arbitrary_types_allowed = True 46 | extra = "allow" # Allow extra fields for flexibility in subclasses 47 | 48 | @model_validator(mode="after") 49 | def initialize_agent(self) -> "BaseAgent": 50 | """Initialize agent with default settings if not provided.""" 51 | if self.llm is None or not isinstance(self.llm, LLM): 52 | self.llm = LLM(config_name=self.name.lower()) 53 | if not isinstance(self.memory, Memory): 54 | self.memory = Memory() 55 | return self 56 | 57 | @asynccontextmanager 58 | async def state_context(self, new_state: AgentState): 59 | """Context manager for safe agent state transitions. 60 | 61 | Args: 62 | new_state: The state to transition to during the context. 63 | 64 | Yields: 65 | None: Allows execution within the new state. 66 | 67 | Raises: 68 | ValueError: If the new_state is invalid. 69 | """ 70 | if not isinstance(new_state, AgentState): 71 | raise ValueError(f"Invalid state: {new_state}") 72 | 73 | previous_state = self.state 74 | self.state = new_state 75 | try: 76 | yield 77 | except Exception as e: 78 | self.state = AgentState.ERROR # Transition to ERROR on failure 79 | raise e 80 | finally: 81 | self.state = previous_state # Revert to previous state 82 | 83 | def update_memory( 84 | self, 85 | role: ROLE_TYPE, # type: ignore 86 | content: str, 87 | base64_image: Optional[str] = None, 88 | **kwargs, 89 | ) -> None: 90 | """Add a message to the agent's memory. 91 | 92 | Args: 93 | role: The role of the message sender (user, system, assistant, tool). 94 | content: The message content. 95 | base64_image: Optional base64 encoded image. 96 | **kwargs: Additional arguments (e.g., tool_call_id for tool messages). 97 | 98 | Raises: 99 | ValueError: If the role is unsupported. 100 | """ 101 | message_map = { 102 | "user": Message.user_message, 103 | "system": Message.system_message, 104 | "assistant": Message.assistant_message, 105 | "tool": lambda content, **kw: Message.tool_message(content, **kw), 106 | } 107 | 108 | if role not in message_map: 109 | raise ValueError(f"Unsupported message role: {role}") 110 | 111 | # Create message with appropriate parameters based on role 112 | kwargs = {"base64_image": base64_image, **(kwargs if role == "tool" else {})} 113 | self.memory.add_message(message_map[role](content, **kwargs)) 114 | 115 | async def run(self, request: Optional[str] = None) -> str: 116 | """Execute the agent's main loop asynchronously. 117 | 118 | Args: 119 | request: Optional initial user request to process. 120 | 121 | Returns: 122 | A string summarizing the execution results. 123 | 124 | Raises: 125 | RuntimeError: If the agent is not in IDLE state at start. 126 | """ 127 | if self.state != AgentState.IDLE: 128 | raise RuntimeError(f"Cannot run agent from state: {self.state}") 129 | 130 | if request: 131 | self.update_memory("user", request) 132 | 133 | results: List[str] = [] 134 | async with self.state_context(AgentState.RUNNING): 135 | while ( 136 | self.current_step < self.max_steps and self.state != AgentState.FINISHED 137 | ): 138 | self.current_step += 1 139 | logger.info(f"Executing step {self.current_step}/{self.max_steps}") 140 | step_result = await self.step() 141 | 142 | # Check for stuck state 143 | if self.is_stuck(): 144 | self.handle_stuck_state() 145 | 146 | results.append(f"Step {self.current_step}: {step_result}") 147 | 148 | if self.current_step >= self.max_steps: 149 | self.current_step = 0 150 | self.state = AgentState.IDLE 151 | results.append(f"Terminated: Reached max steps ({self.max_steps})") 152 | return "\n".join(results) if results else "No steps executed" 153 | 154 | @abstractmethod 155 | async def step(self) -> str: 156 | """Execute a single step in the agent's workflow. 157 | 158 | Must be implemented by subclasses to define specific behavior. 159 | """ 160 | 161 | def handle_stuck_state(self): 162 | """Handle stuck state by adding a prompt to change strategy""" 163 | stuck_prompt = "\ 164 | Observed duplicate responses. Consider new strategies and avoid repeating ineffective paths already attempted." 165 | self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt}" 166 | logger.warning(f"Agent detected stuck state. Added prompt: {stuck_prompt}") 167 | 168 | def is_stuck(self) -> bool: 169 | """Check if the agent is stuck in a loop by detecting duplicate content""" 170 | if len(self.memory.messages) < 2: 171 | return False 172 | 173 | last_message = self.memory.messages[-1] 174 | if not last_message.content: 175 | return False 176 | 177 | # Count identical content occurrences 178 | duplicate_count = sum( 179 | 1 180 | for msg in reversed(self.memory.messages[:-1]) 181 | if msg.role == "assistant" and msg.content == last_message.content 182 | ) 183 | 184 | return duplicate_count >= self.duplicate_threshold 185 | 186 | @property 187 | def messages(self) -> List[Message]: 188 | """Retrieve a list of messages from the agent's memory.""" 189 | return self.memory.messages 190 | 191 | @messages.setter 192 | def messages(self, value: List[Message]): 193 | """Set the list of messages in the agent's memory.""" 194 | self.memory.messages = value 195 | 196 | def reset_execution_state(self) -> None: 197 | """Reset the agent's execution state to prepare for a new run. 198 | 199 | Resets the current step counter to zero and returns the agent to IDLE state. 200 | """ 201 | self.current_step = 0 202 | self.state = AgentState.IDLE 203 | logger.info(f"Agent '{self.name}' execution state has been reset") 204 | -------------------------------------------------------------------------------- /src/tool/big_deal_analysis.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | import time 3 | import pandas as pd # type: ignore 4 | 5 | from src.logger import logger 6 | from src.tool.base import BaseTool, ToolResult 7 | 8 | try: 9 | import akshare as ak # type: ignore 10 | except ImportError: 11 | ak = None # type: ignore 12 | 13 | 14 | class BigDealAnalysisTool(BaseTool): 15 | """Tool for analysing big order fund flows using akshare interfaces.""" 16 | 17 | name: str = "big_deal_analysis_tool" 18 | description: str = ( 19 | "获取市场及个股资金大单流向数据,并返回综合分析结果。" 20 | "调用 akshare 的 stock_fund_flow_big_deal、stock_fund_flow_individual、" 21 | "stock_individual_fund_flow、stock_zh_a_hist 接口。" 22 | ) 23 | parameters: dict = { 24 | "type": "object", 25 | "properties": { 26 | "stock_code": { 27 | "type": "string", 28 | "description": "股票代码,如 '600036',若为空代表全市场分析", 29 | "default": "", 30 | }, 31 | "top_n": { 32 | "type": "integer", 33 | "description": "排行前 N 名股票的数据", 34 | "default": 10, 35 | }, 36 | "rank_symbol": { 37 | "type": "string", 38 | "description": "排行时间窗口,{\"即时\", \"3日排行\", \"5日排行\", \"10日排行\", \"20日排行\"}", 39 | "default": "即时", 40 | }, 41 | "max_retry": { 42 | "type": "integer", 43 | "description": "最大重试次数", 44 | "default": 3, 45 | }, 46 | "sleep_seconds": { 47 | "type": "integer", 48 | "description": "重试间隔秒数", 49 | "default": 1, 50 | }, 51 | }, 52 | } 53 | 54 | async def execute( 55 | self, 56 | stock_code: str = "", 57 | top_n: int = 10, 58 | rank_symbol: str = "即时", 59 | max_retry: int = 3, 60 | sleep_seconds: int = 1, 61 | **kwargs, 62 | ) -> ToolResult: 63 | """Fetch big deal fund flow data and return structured result.""" 64 | if ak is None: 65 | return ToolResult(error="akshare library not installed") 66 | 67 | try: 68 | result: Dict[str, Any] = {} 69 | 70 | def _with_retry(func, *args, **kwargs): 71 | """Simple retry wrapper for unstable akshare endpoints.""" 72 | for attempt in range(1, max_retry + 1): 73 | try: 74 | return func(*args, **kwargs) 75 | except Exception as e: 76 | if attempt >= max_retry: 77 | raise 78 | logger.warning(f"{func.__name__} attempt {attempt} failed: {e}. Retrying...") 79 | time.sleep(sleep_seconds) 80 | 81 | def _safe_fetch(func, *args, **kwargs): 82 | """Fetch data with retries; return None on ultimate failure instead of raising.""" 83 | try: 84 | return _with_retry(func, *args, **kwargs) 85 | except Exception as e: 86 | logger.warning(f"{func.__name__} failed after {max_retry} attempts: {e}") 87 | return None 88 | 89 | # Market wide big deal flow (逐笔大单) 90 | df_bd = _safe_fetch(ak.stock_fund_flow_big_deal) 91 | if df_bd is not None and not df_bd.empty: 92 | # 清洗数字列 93 | def _to_float(series): 94 | return ( 95 | series.astype(str) 96 | .str.replace(",", "", regex=False) 97 | .str.replace("亿", "e8") # unlikely present here 98 | .str.replace("万", "e4") 99 | .str.extract(r"([\d\.-eE]+)")[0] 100 | .astype(float) 101 | ) 102 | 103 | if df_bd["成交额"].dtype == "O": 104 | df_bd["成交额"] = _to_float(df_bd["成交额"]) 105 | 106 | # 计算买盘/卖盘汇总 107 | inflow = df_bd[df_bd["大单性质"] == "买盘"]["成交额"].sum() 108 | outflow = df_bd[df_bd["大单性质"] == "卖盘"]["成交额"].sum() 109 | result["market_summary"] = { 110 | "total_inflow_wan": round(inflow, 2), 111 | "total_outflow_wan": round(outflow, 2), 112 | "net_inflow_wan": round(inflow - outflow, 2), 113 | } 114 | 115 | # 按净额排序股票 116 | grouped = ( 117 | df_bd.groupby(["股票代码", "股票简称", "大单性质"])["成交额"].sum().reset_index() 118 | ) 119 | buy_df = grouped[grouped["大单性质"] == "买盘"].sort_values("成交额", ascending=False) 120 | sell_df = grouped[grouped["大单性质"] == "卖盘"].sort_values("成交额", ascending=False) 121 | 122 | result["top_inflow"] = buy_df.head(top_n).to_dict(orient="records") 123 | result["top_outflow"] = sell_df.head(top_n).to_dict(orient="records") 124 | 125 | # 保存部分原始逐笔记录以备调试(最多 top_n 条) 126 | result["market_big_deal_samples"] = df_bd.head(top_n).to_dict(orient="records") 127 | else: 128 | result["market_big_deal_samples"] = [] 129 | 130 | # Individual fund flow rank 使用 stock_fund_flow_individual(symbol) 131 | individual_rank = _safe_fetch(ak.stock_fund_flow_individual, symbol=rank_symbol) 132 | 133 | # 默认返回排行榜前 top_n 条 134 | result["individual_rank_top"] = ( 135 | individual_rank.head(top_n).to_dict(orient="records") 136 | if individual_rank is not None else [] 137 | ) 138 | 139 | # 若指定了 stock_code, 仅保留其对应行数据 140 | if stock_code and individual_rank is not None: 141 | rank_filtered = individual_rank[ 142 | individual_rank["股票代码"].astype(str) == stock_code 143 | ] 144 | result["individual_rank_stock"] = ( 145 | rank_filtered.to_dict(orient="records") if not rank_filtered.empty else [] 146 | ) 147 | 148 | if stock_code: 149 | # Stock specific fund flow trend 使用 stock_individual_fund_flow 150 | individual_flow = _safe_fetch(ak.stock_individual_fund_flow, stock=stock_code) 151 | result["stock_fund_flow"] = ( 152 | individual_flow.to_dict(orient="records") if individual_flow is not None else [] 153 | ) 154 | 155 | # Historical price data for correlation 156 | hist_price = _safe_fetch(ak.stock_zh_a_hist, symbol=stock_code, period="daily") 157 | if hist_price is not None: 158 | result["stock_price_hist"] = hist_price.tail(120).to_dict(orient="records") 159 | else: 160 | result["stock_price_hist"] = [] 161 | 162 | # 1. 先整体抓取逐笔大单 163 | # 复用已获取的 df_bd,若为空再尝试一次 164 | if df_bd is None: 165 | df_bd = _safe_fetch(ak.stock_fund_flow_big_deal) 166 | 167 | stk_df = pd.DataFrame() 168 | if df_bd is not None and not df_bd.empty: 169 | stk_df = df_bd[df_bd["股票代码"] == stock_code] 170 | 171 | if not stk_df.empty: 172 | inflow = stk_df[stk_df["大单性质"] == "买盘"]["成交额"].sum() 173 | outflow = stk_df[stk_df["大单性质"] == "卖盘"]["成交额"].sum() 174 | 175 | result["stock_big_deal_summary"] = { 176 | "inflow_wan": round(inflow, 2), 177 | "outflow_wan": round(outflow, 2), 178 | "net_inflow_wan": round(inflow - outflow, 2), 179 | "trade_count": len(stk_df), 180 | } 181 | result["stock_big_deal_samples"] = stk_df.head(top_n).to_dict(orient="records") 182 | else: 183 | result["stock_big_deal_summary"] = {} 184 | result["stock_big_deal_samples"] = [] 185 | 186 | return ToolResult(output=result) 187 | except Exception as e: 188 | logger.error(f"BigDealAnalysisTool error: {e}") 189 | return ToolResult(error=str(e)) -------------------------------------------------------------------------------- /src/tool/hot_money.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | from datetime import datetime 4 | from typing import Optional 5 | 6 | import efinance as ef 7 | import pandas as pd 8 | from pydantic import Field 9 | 10 | from src.logger import logger 11 | from src.tool.base import BaseTool, ToolResult, get_recent_trading_day 12 | from src.tool.financial_deep_search.get_section_data import get_all_section 13 | from src.tool.financial_deep_search.index_capital import get_index_capital_flow 14 | from src.tool.financial_deep_search.stock_capital import get_stock_capital_flow 15 | 16 | 17 | _HOT_MONEY_DESCRIPTION = """ 18 | 获取股票热点资金和市场数据工具,用于分析主力资金流向和市场热点变化。 19 | 20 | 该工具提供以下精确数据服务: 21 | 1. 股票实时数据:提供目标股票的最新价格、涨跌幅、成交量、换手率、市值等关键指标 22 | 2. 龙虎榜数据:获取特定日期的市场龙虎榜,展示主力机构资金买卖方向和具体金额 23 | 3. 热门板块分析:按类型(概念、行业、地域)提供当日热门板块涨跌情况和资金流向 24 | 4. 个股资金流向:展示目标股票近期主力、散户、超大单资金净流入/流出数据 25 | 5. 大盘资金流向:提供指数级别的资金面分析,包括北向资金、融资余额等宏观指标 26 | 27 | 结果以结构化JSON格式返回,包含完整的数据类别、时间戳和数值指标。 28 | """ 29 | 30 | 31 | class HotMoneyTool(BaseTool): 32 | """Tool for retrieving hot money and market data for stocks.""" 33 | 34 | name: str = "hot_money_tool" 35 | description: str = _HOT_MONEY_DESCRIPTION 36 | parameters: dict = { 37 | "type": "object", 38 | "properties": { 39 | "stock_code": { 40 | "type": "string", 41 | "description": "股票代码(必填),如'600519'(贵州茅台)、'000001'(平安银行)、'300750'(宁德时代)等", 42 | }, 43 | "index_code": { 44 | "type": "string", 45 | "description": "指数代码,如'000001'(上证指数)、'399001'(深证成指)。不提供则默认使用与股票代码相同的值", 46 | }, 47 | "date": { 48 | "type": "string", 49 | "description": "查询日期,精确格式为YYYY-MM-DD(如'2023-05-15'),不提供则默认使用当天日期", 50 | "default": "", 51 | }, 52 | "sector_types": { 53 | "type": "string", 54 | "description": "板块类型筛选,可选值:'all'(所有板块)、'hot'(热门板块)、'concept'(概念板块)、'regional'(地域板块)、'industry'(行业板块)", 55 | "default": "all", 56 | }, 57 | "max_retry": { 58 | "type": "integer", 59 | "description": "数据获取最大重试次数,范围1-5,用于处理网络波动情况", 60 | "default": 3, 61 | }, 62 | "sleep_seconds": { 63 | "type": "integer", 64 | "description": "重试间隔秒数,范围1-10,防止频繁请求被限制", 65 | "default": 1, 66 | }, 67 | }, 68 | "required": ["stock_code"], 69 | } 70 | 71 | lock: asyncio.Lock = Field(default_factory=asyncio.Lock) 72 | 73 | async def execute( 74 | self, 75 | stock_code: str, 76 | index_code: Optional[str] = None, 77 | date: str = "", 78 | sector_types: str = "all", 79 | max_retry: int = 3, 80 | sleep_seconds: int = 1, 81 | **kwargs, 82 | ) -> ToolResult: 83 | """ 84 | Execute the hot money data retrieval operation. 85 | 86 | Args: 87 | stock_code: Stock code, e.g. "600519" 88 | index_code: Index code, e.g. "000001", defaults to stock_code if not provided 89 | date: Query date in YYYY-MM-DD format, defaults to current date 90 | sector_types: Sector types, options: 'all', 'hot', 'concept', 'regional', 'industry' 91 | max_retry: Maximum retry attempts, default 3 92 | sleep_seconds: Seconds to wait between retries, default 1 93 | **kwargs: Additional parameters 94 | 95 | Returns: 96 | ToolResult: Unified JSON format containing all data sources results or error message 97 | """ 98 | async with self.lock: 99 | try: 100 | date = date or get_recent_trading_day() 101 | actual_index_code = index_code or stock_code 102 | 103 | result = { 104 | "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 105 | "stock_code": stock_code, 106 | "date": date, 107 | } 108 | if index_code: 109 | result["index_code"] = index_code 110 | 111 | # Get all data with retry mechanism 112 | data_sources = { 113 | "stock_latest_info": lambda: ef.stock.get_realtime_quotes( 114 | stock_code 115 | ), 116 | "daily_top_list": lambda: ef.stock.get_daily_billboard( 117 | start_date=date, end_date=date 118 | ), 119 | "hot_section_data": lambda: get_all_section( 120 | sector_types=sector_types 121 | ), 122 | "stock_net_flow": lambda: get_stock_capital_flow( 123 | stock_code=stock_code 124 | ), 125 | "index_net_flow": lambda: get_index_capital_flow( 126 | index_code=actual_index_code 127 | ), 128 | } 129 | 130 | # Retrieve each data source 131 | for key, func in data_sources.items(): 132 | result[key] = await self._get_data_with_retry( 133 | func, key, max_retry, sleep_seconds 134 | ) 135 | 136 | return ToolResult(output=result) 137 | 138 | except Exception as e: 139 | error_msg = f"Failed to get hot money data: {str(e)}" 140 | logger.error(error_msg) 141 | return ToolResult(error=error_msg) 142 | 143 | @staticmethod 144 | async def _get_data_with_retry(func, data_name, max_retry=3, sleep_seconds=1): 145 | """ 146 | Get data with retry mechanism. 147 | 148 | Args: 149 | func: Function to call 150 | data_name: Data name (for logging) 151 | max_retry: Maximum retry attempts 152 | sleep_seconds: Seconds to wait between retries 153 | 154 | Returns: 155 | Function return data or None 156 | """ 157 | last_error = None 158 | for attempt in range(1, max_retry + 1): 159 | try: 160 | # Use asyncio.to_thread for synchronous operations 161 | data = await asyncio.to_thread(func) 162 | 163 | # Convert data based on type 164 | if isinstance(data, pd.DataFrame): 165 | return data.to_dict(orient="records") 166 | elif isinstance(data, pd.Series): 167 | return data.to_dict() 168 | elif hasattr(data, "to_json"): 169 | return json.loads(data.to_json()) 170 | 171 | logger.info(f"[{data_name}] Data retrieved successfully") 172 | return data 173 | 174 | except Exception as e: 175 | last_error = str(e) 176 | logger.warning(f"[{data_name}][Attempt {attempt}] Failed: {e}") 177 | 178 | if attempt < max_retry: 179 | await asyncio.sleep(sleep_seconds) 180 | logger.info(f"[{data_name}] Preparing attempt {attempt+1}...") 181 | 182 | logger.error( 183 | f"[{data_name}] Max retries ({max_retry}) reached, failed: {last_error}" 184 | ) 185 | return None 186 | 187 | 188 | if __name__ == "__main__": 189 | import sys 190 | 191 | code = sys.argv[1] if len(sys.argv) > 1 else "600519" 192 | index_code = "000001" # Default to Shanghai Composite Index 193 | 194 | async def run_tool(): 195 | tool = HotMoneyTool() 196 | result = await tool.execute(stock_code=code, index_code=index_code) 197 | 198 | if result.error: 199 | print(f"Failed: {result.error}") 200 | else: 201 | data = result.output 202 | print(f"Success! Timestamp: {data['timestamp']}") 203 | print(f"Stock Code: {data['stock_code']}") 204 | if "index_code" in data: 205 | print(f"Index Code: {data['index_code']}") 206 | 207 | for key in [ 208 | "stock_latest_info", 209 | "daily_top_list", 210 | "hot_section_data", 211 | "stock_net_flow", 212 | "index_net_flow", 213 | ]: 214 | status = "Success" if data.get(key) is not None else "Failed" 215 | print(f"- {key}: {status}") 216 | 217 | filename = ( 218 | f"hotmoney_data_{code}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" 219 | ) 220 | with open(filename, "w", encoding="utf-8") as f: 221 | json.dump(data, f, ensure_ascii=False, indent=2) 222 | print(f"\nComplete results saved to: {filename}") 223 | 224 | asyncio.run(run_tool()) 225 | -------------------------------------------------------------------------------- /src/agent/sentiment.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Dict 2 | 3 | from pydantic import Field 4 | 5 | from src.agent.mcp import MCPAgent 6 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN 7 | from src.prompt.sentiment import SENTIMENT_SYSTEM_PROMPT 8 | from src.schema import Message 9 | from src.tool import Terminate, ToolCollection 10 | from src.tool.sentiment import SentimentTool 11 | from src.tool.web_search import WebSearch 12 | 13 | import logging 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class SentimentAgent(MCPAgent): 18 | """Sentiment analysis agent focused on market sentiment and news.""" 19 | 20 | name: str = "sentiment_agent" 21 | description: str = "Analyzes market sentiment, news, and social media for insights on stock performance." 22 | system_prompt: str = SENTIMENT_SYSTEM_PROMPT 23 | next_step_prompt: str = NEXT_STEP_PROMPT_ZN 24 | 25 | # Initialize with FinGenius tools with proper type annotation 26 | available_tools: ToolCollection = Field( 27 | default_factory=lambda: ToolCollection( 28 | SentimentTool(), 29 | WebSearch(), 30 | Terminate(), 31 | ) 32 | ) 33 | 34 | special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) 35 | 36 | async def run( 37 | self, request: Optional[str] = None, stock_code: Optional[str] = None 38 | ) -> Any: 39 | """Run sentiment analysis on the given stock. 40 | 41 | Args: 42 | request: Optional initial request to process. If provided, overrides stock_code parameter. 43 | stock_code: The stock code/ticker to analyze 44 | 45 | Returns: 46 | Dictionary containing sentiment analysis results 47 | """ 48 | # If stock_code is provided but request is not, create request from stock_code 49 | if stock_code and not request: 50 | # Set up system message about the stock being analyzed 51 | self.memory.add_message( 52 | Message.system_message( 53 | f"你正在分析股票 {stock_code} 的市场情绪。请收集相关新闻、社交媒体数据,并评估整体情绪。" 54 | ) 55 | ) 56 | request = f"请分析 {stock_code} 的市场情绪和相关新闻。" 57 | 58 | # Call parent implementation with the request 59 | return await super().run(request) 60 | 61 | async def analyze(self, stock_code: str, **kwargs) -> Dict: 62 | """执行舆情分析""" 63 | try: 64 | logger.info(f"开始舆情分析: {stock_code}") 65 | 66 | # 确保工具执行 - 添加强制执行逻辑 67 | analysis_tasks = [] 68 | 69 | # 1. 强制执行新闻搜索 70 | try: 71 | news_result = await self.tool_call("web_search", { 72 | "query": f"{stock_code} 股票 最新消息 舆情", 73 | "max_results": 10 74 | }) 75 | if news_result and news_result.success: 76 | analysis_tasks.append(("news_search", news_result.data)) 77 | logger.info(f"新闻搜索成功: {stock_code}") 78 | else: 79 | logger.warning(f"新闻搜索失败: {stock_code}") 80 | except Exception as e: 81 | logger.error(f"新闻搜索异常: {stock_code}, {str(e)}") 82 | 83 | # 2. 强制执行社交媒体分析 84 | try: 85 | social_result = await self.tool_call("web_search", { 86 | "query": f"{stock_code} 股吧 讨论 情绪", 87 | "max_results": 5 88 | }) 89 | if social_result and social_result.success: 90 | analysis_tasks.append(("social_media", social_result.data)) 91 | logger.info(f"社交媒体分析成功: {stock_code}") 92 | else: 93 | logger.warning(f"社交媒体分析失败: {stock_code}") 94 | except Exception as e: 95 | logger.error(f"社交媒体分析异常: {stock_code}, {str(e)}") 96 | 97 | # 3. 强制执行舆情分析工具 98 | try: 99 | sentiment_result = await self.tool_call("sentiment_analysis", { 100 | "stock_code": stock_code, 101 | "analysis_type": "comprehensive" 102 | }) 103 | if sentiment_result and sentiment_result.success: 104 | analysis_tasks.append(("sentiment_analysis", sentiment_result.data)) 105 | logger.info(f"舆情分析工具成功: {stock_code}") 106 | else: 107 | logger.warning(f"舆情分析工具失败: {stock_code}") 108 | except Exception as e: 109 | logger.error(f"舆情分析工具异常: {stock_code}, {str(e)}") 110 | 111 | # 4. 综合分析结果 112 | if analysis_tasks: 113 | summary = self._generate_comprehensive_summary(analysis_tasks, stock_code) 114 | logger.info(f"舆情分析完成: {stock_code}, 执行了 {len(analysis_tasks)} 个任务") 115 | return { 116 | "success": True, 117 | "analysis_count": len(analysis_tasks), 118 | "summary": summary, 119 | "tasks_executed": [task[0] for task in analysis_tasks] 120 | } 121 | else: 122 | logger.warning(f"舆情分析没有成功执行任何任务: {stock_code}") 123 | return { 124 | "success": False, 125 | "analysis_count": 0, 126 | "summary": "无法获取舆情数据,请检查网络连接和数据源", 127 | "tasks_executed": [] 128 | } 129 | 130 | except Exception as e: 131 | logger.error(f"舆情分析失败: {stock_code}, {str(e)}") 132 | return { 133 | "success": False, 134 | "error": str(e), 135 | "analysis_count": 0, 136 | "summary": f"舆情分析异常: {str(e)}" 137 | } 138 | 139 | def _generate_comprehensive_summary(self, analysis_tasks: List, stock_code: str) -> str: 140 | """生成综合舆情分析报告""" 141 | try: 142 | summary_parts = [f"## {stock_code} 舆情分析报告\n"] 143 | 144 | for task_name, task_data in analysis_tasks: 145 | if task_name == "news_search": 146 | summary_parts.append("### 📰 新闻舆情") 147 | summary_parts.append(f"- 搜索到 {len(task_data.get('results', []))} 条相关新闻") 148 | summary_parts.append(f"- 整体情绪倾向: {self._analyze_news_sentiment(task_data)}") 149 | 150 | elif task_name == "social_media": 151 | summary_parts.append("### 💬 社交媒体情绪") 152 | summary_parts.append(f"- 搜索到 {len(task_data.get('results', []))} 条相关讨论") 153 | summary_parts.append(f"- 投资者情绪: {self._analyze_social_sentiment(task_data)}") 154 | 155 | elif task_name == "sentiment_analysis": 156 | summary_parts.append("### 📊 专业舆情分析") 157 | summary_parts.append(f"- 情绪指数: {task_data.get('sentiment_score', 'N/A')}") 158 | summary_parts.append(f"- 风险等级: {task_data.get('risk_level', 'N/A')}") 159 | 160 | return "\n".join(summary_parts) 161 | 162 | except Exception as e: 163 | logger.error(f"生成舆情分析报告失败: {str(e)}") 164 | return f"舆情分析报告生成失败: {str(e)}" 165 | 166 | def _analyze_news_sentiment(self, data: Dict) -> str: 167 | """分析新闻情绪""" 168 | try: 169 | results = data.get('results', []) 170 | if not results: 171 | return "中性" 172 | 173 | # 简单的关键词情绪分析 174 | positive_keywords = ["上涨", "利好", "突破", "增长", "看好"] 175 | negative_keywords = ["下跌", "利空", "暴跌", "风险", "亏损"] 176 | 177 | positive_count = 0 178 | negative_count = 0 179 | 180 | for result in results: 181 | text = result.get('snippet', '') + result.get('title', '') 182 | for keyword in positive_keywords: 183 | if keyword in text: 184 | positive_count += 1 185 | for keyword in negative_keywords: 186 | if keyword in text: 187 | negative_count += 1 188 | 189 | if positive_count > negative_count: 190 | return "偏正面" 191 | elif negative_count > positive_count: 192 | return "偏负面" 193 | else: 194 | return "中性" 195 | except: 196 | return "中性" 197 | 198 | def _analyze_social_sentiment(self, data: Dict) -> str: 199 | """分析社交媒体情绪""" 200 | try: 201 | results = data.get('results', []) 202 | if not results: 203 | return "平淡" 204 | 205 | # 简单的讨论热度分析 206 | discussion_keywords = ["买入", "卖出", "持有", "看涨", "看跌"] 207 | keyword_count = 0 208 | 209 | for result in results: 210 | text = result.get('snippet', '') + result.get('title', '') 211 | for keyword in discussion_keywords: 212 | if keyword in text: 213 | keyword_count += 1 214 | 215 | if keyword_count >= 5: 216 | return "活跃" 217 | elif keyword_count >= 2: 218 | return "一般" 219 | else: 220 | return "平淡" 221 | except: 222 | return "平淡" 223 | -------------------------------------------------------------------------------- /src/tool/financial_deep_search/stock_capital.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | 获取东方财富个股资金流向数据 6 | API: https://push2.eastmoney.com/api/qt/clist/get 7 | """ 8 | 9 | import json 10 | import re 11 | import time 12 | import traceback 13 | from datetime import datetime 14 | 15 | import requests 16 | 17 | 18 | # API URL - 个股资金流向 19 | STOCK_CAPITAL_FLOW_URL = "https://push2.eastmoney.com/api/qt/clist/get?fid=f62&po=1&pz=50&pn=1&np=1&fltt=2&invt=2&ut=8dec03ba335b81bf4ebdf7b29ec27d15&fs=m%3A0%2Bt%3A6%2Bf%3A!2%2Cm%3A0%2Bt%3A13%2Bf%3A!2%2Cm%3A0%2Bt%3A80%2Bf%3A!2%2Cm%3A1%2Bt%3A2%2Bf%3A!2%2Cm%3A1%2Bt%3A23%2Bf%3A!2%2Cm%3A0%2Bt%3A7%2Bf%3A!2%2Cm%3A1%2Bt%3A3%2Bf%3A!2&fields=f12%2Cf14%2Cf2%2Cf3%2Cf62%2Cf184%2Cf66%2Cf69%2Cf72%2Cf75%2Cf78%2Cf81%2Cf84%2Cf87%2Cf204%2Cf205%2Cf124%2Cf1%2Cf13" 20 | 21 | # 请求头设置 22 | HEADERS = { 23 | "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", 24 | "Referer": "https://quote.eastmoney.com/", 25 | "Accept": "application/json, text/javascript, */*; q=0.01", 26 | } 27 | 28 | 29 | def parse_jsonp(jsonp_str): 30 | """解析JSONP响应为JSON数据""" 31 | try: 32 | # 使用正则表达式提取JSON数据 33 | match = re.search(r"jQuery[0-9_]+\((.*)\)", jsonp_str) 34 | if match: 35 | json_str = match.group(1) 36 | return json.loads(json_str) 37 | else: 38 | # 如果不是JSONP格式,尝试直接解析JSON 39 | return json.loads(jsonp_str) 40 | except Exception as e: 41 | print(f"解析JSONP失败: {e}") 42 | print(f"原始数据: {jsonp_str[:100]}...") # 打印前100个字符用于调试 43 | return None 44 | 45 | 46 | def fetch_stock_list_capital_flow( 47 | page_size=50, page_num=1, max_retries=3, retry_delay=2 48 | ): 49 | """ 50 | 获取股票列表的资金流向数据(按主力净流入排序) 51 | 52 | 参数: 53 | page_size: 每页显示数量,默认50 54 | page_num: 页码,默认第1页 55 | max_retries: 最大重试次数 56 | retry_delay: 重试延迟时间(秒) 57 | 58 | 返回: 59 | dict: 包含股票列表资金流向数据的字典 60 | """ 61 | # 构建API URL,包含分页参数 62 | url = STOCK_CAPITAL_FLOW_URL.replace("pz=50", f"pz={page_size}").replace( 63 | "pn=1", f"pn={page_num}" 64 | ) 65 | 66 | # 添加时间戳防止缓存 67 | timestamp = int(time.time() * 1000) 68 | if "?" in url: 69 | url += f"&_={timestamp}" 70 | else: 71 | url += f"?_={timestamp}" 72 | 73 | # 请求数据 74 | for attempt in range(1, max_retries + 1): 75 | try: 76 | resp = requests.get(url, headers=HEADERS, timeout=15) 77 | resp.raise_for_status() 78 | 79 | # 解析响应数据 80 | data = parse_jsonp(resp.text) 81 | if not data: 82 | print(f"解析个股资金流向数据失败 (第{attempt}次尝试)") 83 | if attempt < max_retries: 84 | time.sleep(retry_delay) 85 | continue 86 | return None 87 | 88 | # 提取资金流向数据 89 | stock_list = data.get("data", {}).get("diff", []) 90 | if not stock_list: 91 | print(f"未获取到个股资金流向数据 (第{attempt}次尝试)") 92 | if attempt < max_retries: 93 | time.sleep(retry_delay) 94 | continue 95 | return None 96 | 97 | # 处理股票数据 98 | return process_stock_list_data( 99 | stock_list, data.get("data", {}).get("total", 0) 100 | ) 101 | 102 | except Exception as e: 103 | print(f"获取个股资金流向数据失败: {e} (第{attempt}次尝试)") 104 | if attempt < max_retries: 105 | time.sleep(retry_delay) 106 | else: 107 | return None 108 | 109 | 110 | def fetch_single_stock_capital_flow(stock_code, max_retries=3, retry_delay=2): 111 | """ 112 | 获取单个股票的资金流向数据 113 | 114 | 参数: 115 | stock_code: 股票代码,如"000001" 116 | max_retries: 最大重试次数 117 | retry_delay: 重试延迟时间(秒) 118 | 119 | 返回: 120 | dict: 包含单个股票资金流向数据的字典,如果未找到则返回None 121 | """ 122 | # 获取股票列表数据(多页搜索需要实现分页循环) 123 | for page in range(1, 10): # 最多查找10页 124 | stock_list = fetch_stock_list_capital_flow(50, page) 125 | if not stock_list: 126 | break 127 | 128 | # 在列表中查找指定股票 129 | for stock in stock_list.get("股票列表", []): 130 | if stock.get("股票代码") == stock_code: 131 | return { 132 | "success": True, 133 | "message": f"成功获取股票{stock.get('股票名称')}({stock_code})资金流向数据", 134 | "last_updated": datetime.now().isoformat(), 135 | "data": stock, 136 | } 137 | 138 | # 如果未找到目标股票,使用精确查询 139 | # 这里可以实现具体股票的接口查询,暂不实现 140 | 141 | return {"success": False, "message": f"未找到股票{stock_code}的资金流向数据", "data": {}} 142 | 143 | 144 | def process_stock_list_data(stock_list, total_count): 145 | """ 146 | 处理股票列表资金流向数据 147 | 148 | 参数: 149 | stock_list: API返回的原始股票列表数据 150 | total_count: 总记录数 151 | 152 | 返回: 153 | dict: 处理后的资金流向数据 154 | """ 155 | # 字段映射表 156 | field_mapping = { 157 | "f12": "股票代码", 158 | "f14": "股票名称", 159 | "f2": "最新价", 160 | "f3": "涨跌幅", 161 | "f62": "主力净流入", 162 | "f184": "主力净占比", 163 | "f66": "超大单净流入", 164 | "f69": "超大单净占比", 165 | "f72": "大单净流入", 166 | "f75": "大单净占比", 167 | "f78": "中单净流入", 168 | "f81": "中单净占比", 169 | "f84": "小单净流入", 170 | "f87": "小单净占比", 171 | "f124": "更新时间", 172 | "f1": "市场代码", 173 | "f13": "市场类型", 174 | } 175 | 176 | # 交易市场映射 177 | market_map = { 178 | 0: "SZ", # 深圳 179 | 1: "SH", # 上海 180 | 105: "NQ", # 纳斯达克 181 | 106: "NYSE", # 纽交所 182 | 107: "AMEX", # 美交所 183 | 116: "HK", # 港股 184 | 156: "LN", # 伦敦 185 | } 186 | 187 | # 初始化结果 188 | result = { 189 | "股票列表": [], 190 | "总数": total_count, 191 | "更新时间": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 192 | } 193 | 194 | # 处理每支股票数据 195 | for stock_item in stock_list: 196 | stock_data = {} 197 | 198 | # 转换每个字段 199 | for api_field, result_field in field_mapping.items(): 200 | if api_field in stock_item: 201 | value = stock_item.get(api_field) 202 | 203 | # 特殊处理的字段 204 | if api_field == "f124": # 更新时间 205 | try: 206 | timestamp = int(value) / 1000 207 | stock_data[result_field] = datetime.fromtimestamp( 208 | timestamp 209 | ).strftime("%Y-%m-%d %H:%M:%S") 210 | except: 211 | stock_data[result_field] = "-" 212 | elif api_field in ["f62", "f66", "f72", "f78", "f84"]: # 资金流入流出金额 213 | stock_data[result_field] = ( 214 | round(float(value) / 10000, 2) if value else 0 215 | ) # 转换为万元 216 | elif api_field in ["f3", "f184", "f69", "f75", "f81", "f87"]: # 百分比 217 | stock_data[result_field] = ( 218 | round(float(value), 2) if value else 0 219 | ) # 保留两位小数 220 | elif api_field == "f13": # 市场类型 221 | market_code = value 222 | stock_data[result_field] = market_map.get( 223 | market_code, str(market_code) 224 | ) 225 | else: 226 | stock_data[result_field] = value 227 | 228 | # 添加完整的股票代码 229 | if "股票代码" in stock_data and "市场类型" in stock_data: 230 | market_prefix = stock_data["市场类型"] 231 | stock_data["完整代码"] = f"{market_prefix}.{stock_data['股票代码']}" 232 | 233 | # 添加到结果列表 234 | result["股票列表"].append(stock_data) 235 | 236 | return result 237 | 238 | 239 | def get_stock_capital_flow(page_size=50, page_num=1, stock_code=None): 240 | """ 241 | 获取股票资金流向数据,支持获取列表或单只股票数据 242 | 243 | 参数: 244 | page_size: 每页显示数量,默认50 245 | page_num: 页码,默认第1页 246 | stock_code: 股票代码,如果指定,则返回单只股票数据,否则返回列表 247 | 248 | 返回: 249 | dict: 包含资金流向数据的字典 250 | """ 251 | try: 252 | # 获取数据(单只股票或列表) 253 | if stock_code: 254 | result = fetch_single_stock_capital_flow(stock_code) 255 | else: 256 | flow_data = fetch_stock_list_capital_flow(page_size, page_num) 257 | if not flow_data: 258 | return {"success": False, "message": f"获取股票资金流向数据失败", "data": {}} 259 | 260 | # 准备返回结果 261 | result = { 262 | "success": True, 263 | "message": f"成功获取股票资金流向数据,共{flow_data.get('总数', 0)}条", 264 | "last_updated": datetime.now().isoformat(), 265 | "data": flow_data, 266 | } 267 | 268 | return result 269 | except Exception as e: 270 | error_msg = f"获取股票资金流向数据时出错: {str(e)}" 271 | print(error_msg) 272 | print(traceback.format_exc()) 273 | return { 274 | "success": False, 275 | "message": error_msg, 276 | "error": traceback.format_exc(), 277 | "data": {}, 278 | } 279 | 280 | 281 | def main(): 282 | """命令行调用入口函数""" 283 | import argparse 284 | 285 | parser = argparse.ArgumentParser(description="获取股票资金流向数据") 286 | parser.add_argument("--code", type=str, help="股票代码,不指定则获取列表") 287 | parser.add_argument("--page", type=int, default=1, help="页码,默认1") 288 | parser.add_argument("--size", type=int, default=50, help="每页数量,默认50") 289 | args = parser.parse_args() 290 | 291 | if args.code: 292 | result = get_stock_capital_flow(stock_code=args.code) 293 | else: 294 | result = get_stock_capital_flow(page_size=args.size, page_num=args.page) 295 | 296 | # 打印结果 297 | print(json.dumps(result, ensure_ascii=False, indent=2)) 298 | 299 | 300 | if __name__ == "__main__": 301 | main() 302 | -------------------------------------------------------------------------------- /src/mcp/server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import atexit 4 | import json 5 | import logging 6 | import sys 7 | from inspect import Parameter, Signature 8 | from typing import Any, Dict, Optional 9 | 10 | import uvicorn 11 | from starlette.applications import Starlette 12 | from starlette.middleware import Middleware 13 | from starlette.middleware.cors import CORSMiddleware 14 | from starlette.requests import Request 15 | from starlette.routing import Mount, Route 16 | 17 | from mcp.server import FastMCP, Server 18 | from mcp.server.sse import SseServerTransport 19 | from src.logger import logger 20 | from src.tool import BaseTool, Terminate 21 | 22 | 23 | logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stderr)]) 24 | 25 | 26 | class MCPServer: 27 | """MCP Server implementation with tool registration and management.""" 28 | 29 | def __init__(self, name: str = "FinGenius"): 30 | self.server = FastMCP(name) 31 | self.tools: Dict[str, BaseTool] = {} 32 | 33 | # Initialize standard tools 34 | self._initialize_standard_tools() 35 | 36 | def _initialize_standard_tools(self) -> None: 37 | """Initialize standard tools available in the server.""" 38 | self.tools.update( 39 | { 40 | "terminate": Terminate(), 41 | } 42 | ) 43 | 44 | def register_tool(self, tool: BaseTool, method_name: Optional[str] = None) -> None: 45 | """Register a tool with parameter validation and documentation.""" 46 | tool_name = method_name or tool.name 47 | tool_param = tool.to_param() 48 | tool_function = tool_param["function"] 49 | 50 | # Define the async function to be registered 51 | async def tool_method(**kwargs): 52 | logger.info(f"Executing {tool_name}: {kwargs}") 53 | result = await tool.execute(**kwargs) 54 | 55 | logger.info(f"Result of {tool_name}: {result}") 56 | 57 | # Handle different types of results (match original logic) 58 | if hasattr(result, "model_dump"): 59 | return json.dumps(result.model_dump()) 60 | elif isinstance(result, dict): 61 | return json.dumps(result) 62 | return result 63 | 64 | # Set method metadata 65 | tool_method.__name__ = tool_name 66 | tool_method.__doc__ = self._build_docstring(tool_function) 67 | tool_method.__signature__ = self._build_signature(tool_function) 68 | 69 | # Store parameter schema (important for tools that access it programmatically) 70 | param_props = tool_function.get("parameters", {}).get("properties", {}) 71 | required_params = tool_function.get("parameters", {}).get("required", []) 72 | tool_method._parameter_schema = { 73 | param_name: { 74 | "description": param_details.get("description", ""), 75 | "type": param_details.get("type", "any"), 76 | "required": param_name in required_params, 77 | } 78 | for param_name, param_details in param_props.items() 79 | } 80 | 81 | # Register with server 82 | self.server.tool()(tool_method) 83 | logger.info(f"Registered tool: {tool_name}") 84 | 85 | def _build_docstring(self, tool_function: dict) -> str: 86 | """Build a formatted docstring from tool function metadata.""" 87 | description = tool_function.get("description", "") 88 | param_props = tool_function.get("parameters", {}).get("properties", {}) 89 | required_params = tool_function.get("parameters", {}).get("required", []) 90 | 91 | # Build docstring (match original format) 92 | docstring = description 93 | if param_props: 94 | docstring += "\n\nParameters:\n" 95 | for param_name, param_details in param_props.items(): 96 | required_str = ( 97 | "(required)" if param_name in required_params else "(optional)" 98 | ) 99 | param_type = param_details.get("type", "any") 100 | param_desc = param_details.get("description", "") 101 | docstring += ( 102 | f" {param_name} ({param_type}) {required_str}: {param_desc}\n" 103 | ) 104 | 105 | return docstring 106 | 107 | def _build_signature(self, tool_function: dict) -> Signature: 108 | """Build a function signature from tool function metadata.""" 109 | param_props = tool_function.get("parameters", {}).get("properties", {}) 110 | required_params = tool_function.get("parameters", {}).get("required", []) 111 | 112 | parameters = [] 113 | 114 | # Follow original type mapping 115 | for param_name, param_details in param_props.items(): 116 | param_type = param_details.get("type", "") 117 | default = Parameter.empty if param_name in required_params else None 118 | 119 | # Map JSON Schema types to Python types (same as original) 120 | annotation = Any 121 | if param_type == "string": 122 | annotation = str 123 | elif param_type == "integer": 124 | annotation = int 125 | elif param_type == "number": 126 | annotation = float 127 | elif param_type == "boolean": 128 | annotation = bool 129 | elif param_type == "object": 130 | annotation = dict 131 | elif param_type == "array": 132 | annotation = list 133 | 134 | # Create parameter with same structure as original 135 | param = Parameter( 136 | name=param_name, 137 | kind=Parameter.KEYWORD_ONLY, 138 | default=default, 139 | annotation=annotation, 140 | ) 141 | parameters.append(param) 142 | 143 | return Signature(parameters=parameters) 144 | 145 | async def cleanup(self) -> None: 146 | """Clean up server resources.""" 147 | logger.info("Cleaning up resources") 148 | # Follow original cleanup logic - only clean browser tool 149 | if "browser" in self.tools and hasattr(self.tools["browser"], "cleanup"): 150 | await self.tools["browser"].cleanup() 151 | 152 | def register_all_tools(self) -> None: 153 | """Register all tools with the server.""" 154 | for tool in self.tools.values(): 155 | self.register_tool(tool) 156 | 157 | def run(self, transport: str = "stdio") -> None: 158 | """Run the MCP server.""" 159 | # Register all tools 160 | self.register_all_tools() 161 | 162 | # Register cleanup function (match original behavior) 163 | atexit.register(lambda: asyncio.run(self.cleanup())) 164 | 165 | # Start server (with same logging as original) 166 | logger.info(f"Starting FinGenius server ({transport} mode)") 167 | self.server.run(transport=transport) 168 | 169 | 170 | def create_starlette_app(mcp_server: Server, *, debug: bool = False) -> Starlette: 171 | """Create a Starlette application that can serve the provided mcp server with SSE.""" 172 | # Set up CORS middleware to allow connections from any origin 173 | middleware = [ 174 | Middleware( 175 | CORSMiddleware, 176 | allow_origins=["*"], 177 | allow_methods=["*"], 178 | allow_headers=["*"], 179 | allow_credentials=True, 180 | ) 181 | ] 182 | 183 | # Use '/messages/' as the endpoint path for SSE connections 184 | sse = SseServerTransport("/messages/") 185 | 186 | async def handle_sse(request: Request) -> None: 187 | try: 188 | logger.info(f"SSE connection request received from {request.client}") 189 | async with sse.connect_sse( 190 | request.scope, 191 | request.receive, 192 | request._send, # noqa: SLF001 193 | ) as (read_stream, write_stream): 194 | await mcp_server.run( 195 | read_stream, 196 | write_stream, 197 | mcp_server.create_initialization_options(), 198 | ) 199 | except Exception as e: 200 | logger.error(f"Error handling SSE connection: {str(e)}") 201 | raise 202 | 203 | # Create Starlette app with middleware and routes 204 | app = Starlette( 205 | debug=debug, 206 | middleware=middleware, 207 | routes=[ 208 | Route("/sse", endpoint=handle_sse), 209 | Mount("/messages/", app=sse.handle_post_message), 210 | ], 211 | ) 212 | 213 | # Add a health check endpoint 214 | @app.route("/health") 215 | async def health_check(request: Request): 216 | from starlette.responses import JSONResponse 217 | 218 | return JSONResponse( 219 | { 220 | "status": "ok", 221 | "message": "FinGenius server is running", 222 | } 223 | ) 224 | 225 | return app 226 | 227 | 228 | def parse_args() -> argparse.Namespace: 229 | """Parse command line arguments.""" 230 | parser = argparse.ArgumentParser(description="FinGenius MCP Server") 231 | parser.add_argument( 232 | "--transport", 233 | choices=["stdio", "sse"], 234 | default="stdio", 235 | help="Communication method: stdio or sse (default: stdio)", 236 | ) 237 | parser.add_argument("--host", default="0.0.0.0", help="Host to bind to (for sse)") 238 | parser.add_argument( 239 | "--port", type=int, default=8000, help="Port to listen on (for sse)" 240 | ) 241 | parser.add_argument("--debug", action="store_true", help="Enable debug mode") 242 | return parser.parse_args() 243 | 244 | 245 | if __name__ == "__main__": 246 | args = parse_args() 247 | 248 | if args.transport == "sse": 249 | # Create an instance of MCPServer 250 | mcp_server = MCPServer() 251 | # Register all tools 252 | mcp_server.register_all_tools() 253 | # Get the underlying mcp_server from FastMCP 254 | underlying_server = mcp_server.server._mcp_server # noqa: WPS437 255 | 256 | logger.info(f"Starting FinGenius server with SSE on {args.host}:{args.port}") 257 | 258 | # Create Starlette application 259 | starlette_app = create_starlette_app(underlying_server, debug=args.debug) 260 | 261 | # Run with uvicorn 262 | uvicorn.run( 263 | starlette_app, 264 | host=args.host, 265 | port=args.port, 266 | log_level="info", 267 | access_log=True, 268 | ) 269 | else: 270 | # Run in stdio mode 271 | mcp_server = MCPServer() 272 | mcp_server.run(transport=args.transport) 273 | -------------------------------------------------------------------------------- /src/agent/toolcall.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | from typing import Any, List, Optional, Union 4 | 5 | from pydantic import Field 6 | 7 | from src.agent.react import ReActAgent 8 | from src.exceptions import TokenLimitExceeded 9 | from src.llm import LLM 10 | from src.logger import logger 11 | from src.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT 12 | from src.schema import ( 13 | TOOL_CHOICE_TYPE, 14 | AgentState, 15 | Memory, 16 | Message, 17 | ToolCall, 18 | ToolChoice, 19 | ) 20 | from src.tool import Terminate, ToolCollection 21 | 22 | 23 | TOOL_CALL_REQUIRED = "Tool calls required but none provided" 24 | 25 | 26 | class ToolCallAgent(ReActAgent): 27 | """Base agent class for handling tool/function calls with enhanced abstraction""" 28 | 29 | name: str = "toolcall" 30 | description: str = "an agent that can execute tool calls." 31 | 32 | system_prompt: str = SYSTEM_PROMPT 33 | next_step_prompt: str = NEXT_STEP_PROMPT 34 | 35 | llm: Optional[LLM] = Field(default_factory=LLM) 36 | memory: Memory = Field(default_factory=Memory) 37 | state: AgentState = AgentState.IDLE 38 | 39 | available_tools: ToolCollection = ToolCollection(Terminate()) 40 | tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO # type: ignore 41 | special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name]) 42 | 43 | tool_calls: List[ToolCall] = Field(default_factory=list) 44 | _current_base64_image: Optional[str] = None 45 | 46 | max_steps: int = 30 47 | current_step: int = 0 48 | 49 | max_observe: Optional[Union[int, bool]] = None 50 | 51 | async def think(self) -> bool: 52 | """Process current state and decide next actions using tools""" 53 | if self.next_step_prompt: 54 | user_msg = Message.user_message(self.next_step_prompt) 55 | self.messages += [user_msg] 56 | 57 | try: 58 | # Get response with tool options 59 | response = await self.llm.ask_tool( 60 | messages=self.messages, 61 | system_msgs=( 62 | [Message.system_message(self.system_prompt)] 63 | if self.system_prompt 64 | else None 65 | ), 66 | tools=self.available_tools.to_params(), 67 | tool_choice=self.tool_choices, 68 | ) 69 | except ValueError: 70 | raise 71 | except Exception as e: 72 | # Check if this is a RetryError containing TokenLimitExceeded 73 | if hasattr(e, "__cause__") and isinstance(e.__cause__, TokenLimitExceeded): 74 | token_limit_error = e.__cause__ 75 | logger.error( 76 | f"🚨 Token limit error (from RetryError): {token_limit_error}" 77 | ) 78 | self.memory.add_message( 79 | Message.assistant_message( 80 | f"Maximum token limit reached, cannot continue execution: {str(token_limit_error)}" 81 | ) 82 | ) 83 | self.state = AgentState.FINISHED 84 | return False 85 | raise 86 | 87 | self.tool_calls = tool_calls = ( 88 | response.tool_calls if response and response.tool_calls else [] 89 | ) 90 | content = response.content if response and response.content else "" 91 | 92 | # Log response info 93 | logger.info(f"✨ {self.name}'s thoughts: {content}") 94 | logger.info( 95 | f"🛠️ {self.name} selected {len(tool_calls) if tool_calls else 0} tools to use" 96 | ) 97 | 98 | # Show agent thinking in terminal if content exists 99 | if content and content.strip(): 100 | from src.console import visualizer 101 | visualizer.show_agent_thought(self.name, content, "analysis") 102 | if tool_calls: 103 | logger.info( 104 | f"🧰 Tools being prepared: {[call.function.name for call in tool_calls]}" 105 | ) 106 | logger.info(f"🔧 Tool arguments: {tool_calls[0].function.arguments}") 107 | 108 | try: 109 | if response is None: 110 | raise RuntimeError("No response received from the LLM") 111 | 112 | # Handle different tool_choices modes 113 | if self.tool_choices == ToolChoice.NONE: 114 | if tool_calls: 115 | logger.warning( 116 | f"🤔 Hmm, {self.name} tried to use tools when they weren't available!" 117 | ) 118 | if content: 119 | self.memory.add_message(Message.assistant_message(content)) 120 | return True 121 | return False 122 | 123 | # Create and add assistant message 124 | assistant_msg = ( 125 | Message.from_tool_calls(content=content, tool_calls=self.tool_calls) 126 | if self.tool_calls 127 | else Message.assistant_message(content) 128 | ) 129 | self.memory.add_message(assistant_msg) 130 | 131 | if self.tool_choices == ToolChoice.REQUIRED and not self.tool_calls: 132 | return True # Will be handled in act() 133 | 134 | # For 'auto' mode, continue with content if no commands but content exists 135 | if self.tool_choices == ToolChoice.AUTO and not self.tool_calls: 136 | return bool(content) 137 | 138 | return bool(self.tool_calls) 139 | except Exception as e: 140 | logger.error(f"🚨 Oops! The {self.name}'s thinking process hit a snag: {e}") 141 | self.memory.add_message( 142 | Message.assistant_message( 143 | f"Error encountered while processing: {str(e)}" 144 | ) 145 | ) 146 | return False 147 | 148 | async def act(self) -> str: 149 | """Execute tool calls and handle their results""" 150 | if not self.tool_calls: 151 | if self.tool_choices == ToolChoice.REQUIRED: 152 | raise ValueError(TOOL_CALL_REQUIRED) 153 | 154 | # Return last message content if no tool calls 155 | return self.messages[-1].content or "No content or commands to execute" 156 | 157 | results = [] 158 | for command in self.tool_calls: 159 | # Reset base64_image for each tool call 160 | self._current_base64_image = None 161 | 162 | result = await self.execute_tool(command) 163 | 164 | if self.max_observe: 165 | result = result[: self.max_observe] 166 | 167 | logger.info( 168 | f"🎯 Tool '{command.function.name}' completed its mission! Result: {result}" 169 | ) 170 | 171 | # Add tool response to memory 172 | tool_msg = Message.tool_message( 173 | content=result, 174 | tool_call_id=command.id, 175 | name=command.function.name, 176 | base64_image=self._current_base64_image, 177 | ) 178 | self.memory.add_message(tool_msg) 179 | results.append(result) 180 | 181 | return "\n\n".join(results) 182 | 183 | async def execute_tool(self, command: ToolCall) -> str: 184 | """Execute a single tool call with robust error handling""" 185 | if not command or not command.function or not command.function.name: 186 | return "Error: Invalid command format" 187 | 188 | name = command.function.name 189 | if name not in self.available_tools.tool_map: 190 | return f"Error: Unknown tool '{name}'" 191 | 192 | try: 193 | # Parse arguments 194 | args = json.loads(command.function.arguments or "{}") 195 | 196 | # Execute the tool 197 | logger.info(f"🔧 Activating tool: '{name}'...") 198 | result = await self.available_tools.execute(name=name, tool_input=args) 199 | 200 | # Handle special tools 201 | await self._handle_special_tool(name=name, result=result) 202 | 203 | # Check if result is a ToolResult with base64_image 204 | if hasattr(result, "base64_image") and result.base64_image: 205 | # Store the base64_image for later use in tool_message 206 | self._current_base64_image = result.base64_image 207 | 208 | # Format result for display 209 | observation = ( 210 | f"Observed output of cmd `{name}` executed:\n{str(result)}" 211 | if result 212 | else f"Cmd `{name}` completed with no output" 213 | ) 214 | return observation 215 | 216 | # Format result for display (standard case) 217 | observation = ( 218 | f"Observed output of cmd `{name}` executed:\n{str(result)}" 219 | if result 220 | else f"Cmd `{name}` completed with no output" 221 | ) 222 | 223 | return observation 224 | except json.JSONDecodeError: 225 | error_msg = f"Error parsing arguments for {name}: Invalid JSON format" 226 | logger.error( 227 | f"📝 Oops! The arguments for '{name}' don't make sense - invalid JSON, arguments:{command.function.arguments}" 228 | ) 229 | return f"Error: {error_msg}" 230 | except Exception as e: 231 | error_msg = f"⚠️ Tool '{name}' encountered a problem: {str(e)}" 232 | logger.exception(error_msg) 233 | return f"Error: {error_msg}" 234 | 235 | async def _handle_special_tool(self, name: str, result: Any, **kwargs): 236 | """Handle special tool execution and state changes""" 237 | if not self._is_special_tool(name): 238 | return 239 | 240 | if self._should_finish_execution(name=name, result=result, **kwargs): 241 | # Set agent state to finished 242 | logger.info(f"🏁 Special tool '{name}' has completed the task!") 243 | self.state = AgentState.FINISHED 244 | 245 | @staticmethod 246 | def _should_finish_execution(**kwargs) -> bool: 247 | """Determine if tool execution should finish the agent""" 248 | return True 249 | 250 | def _is_special_tool(self, name: str) -> bool: 251 | """Check if tool name is in special tools list""" 252 | return name.lower() in [n.lower() for n in self.special_tool_names] 253 | 254 | async def cleanup(self): 255 | """Clean up resources used by the agent's tools.""" 256 | logger.info(f"🧹 Cleaning up resources for agent '{self.name}'...") 257 | for tool_name, tool_instance in self.available_tools.tool_map.items(): 258 | if hasattr(tool_instance, "cleanup") and asyncio.iscoroutinefunction( 259 | tool_instance.cleanup 260 | ): 261 | try: 262 | logger.debug(f"🧼 Cleaning up tool: {tool_name}") 263 | await tool_instance.cleanup() 264 | except Exception as e: 265 | logger.error( 266 | f"🚨 Error cleaning up tool '{tool_name}': {e}", exc_info=True 267 | ) 268 | logger.info(f"✨ Cleanup complete for agent '{self.name}'.") 269 | 270 | async def run(self, request: Optional[str] = None) -> str: 271 | """Run the agent with cleanup when done.""" 272 | try: 273 | return await super().run(request) 274 | finally: 275 | await self.cleanup() 276 | --------------------------------------------------------------------------------