├── src
├── mcp
│ ├── __init__.py
│ ├── battle_server.py
│ ├── risk_control_server.py
│ ├── chip_analysis_server.py
│ ├── big_deal_analysis_server.py
│ ├── technical_analysis_server.py
│ ├── sentiment_server.py
│ ├── hot_money_srver.py
│ └── server.py
├── prompt
│ ├── __init__.py
│ ├── toolcall.py
│ ├── big_deal_analysis.py
│ ├── hot_money.py
│ ├── risk_control.py
│ ├── technical_analysis.py
│ ├── report.py
│ ├── chip_analysis.py
│ ├── mcp.py
│ ├── sentiment.py
│ └── battle.py
├── exceptions.py
├── __init__.py
├── environment
│ ├── __init__.py
│ ├── base.py
│ └── research.py
├── agent
│ ├── __init__.py
│ ├── react.py
│ ├── report.py
│ ├── big_deal_analysis.py
│ ├── chip_analysis.py
│ ├── risk_control.py
│ ├── hot_money.py
│ ├── technical_analysis.py
│ ├── base.py
│ ├── sentiment.py
│ └── toolcall.py
├── tool
│ ├── search
│ │ ├── __init__.py
│ │ ├── google_search.py
│ │ ├── base.py
│ │ ├── baidu_search.py
│ │ ├── duckduckgo_search.py
│ │ └── bing_search.py
│ ├── __init__.py
│ ├── financial_deep_search
│ │ ├── __init__.py
│ │ ├── index_name_map.json
│ │ ├── get_section_data.py
│ │ ├── index_capital.py
│ │ └── stock_capital.py
│ ├── terminate.py
│ ├── tool_collection.py
│ ├── battle.py
│ ├── base.py
│ ├── stock_info_request.py
│ ├── risk_control.py
│ ├── sentiment.py
│ ├── create_chat_completion.py
│ ├── mcp_client.py
│ ├── big_deal_analysis.py
│ └── hot_money.py
├── logger.py
├── utils
│ └── cleanup_reports.py
└── schema.py
├── report
└── README.md
├── docs
├── boyi.png
├── logo.png
├── wechat.JPG
├── architecture.png
├── flow_diagram.md
└── class_diagram.md
├── config
├── .gitignore
├── mcp.example.json
└── config.example.toml
├── requirements.txt
├── .gitattributes
├── .pre-commit-config.yaml
└── .gitignore
/src/mcp/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/prompt/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/report/README.md:
--------------------------------------------------------------------------------
1 | # 报告文件
2 | - 这里存放最终的结果
--------------------------------------------------------------------------------
/docs/boyi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuaYaoAI/FinGenius/HEAD/docs/boyi.png
--------------------------------------------------------------------------------
/docs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuaYaoAI/FinGenius/HEAD/docs/logo.png
--------------------------------------------------------------------------------
/docs/wechat.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuaYaoAI/FinGenius/HEAD/docs/wechat.JPG
--------------------------------------------------------------------------------
/docs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuaYaoAI/FinGenius/HEAD/docs/architecture.png
--------------------------------------------------------------------------------
/config/.gitignore:
--------------------------------------------------------------------------------
1 | # prevent the local config file from being uploaded to the remote repository
2 | config.toml
3 | mcp.json
4 |
--------------------------------------------------------------------------------
/config/mcp.example.json:
--------------------------------------------------------------------------------
1 | {
2 | "mcpServers": {
3 | "server1": {
4 | "type": "sse",
5 | "url": "http://localhost:8000/sse"
6 | }
7 | }
8 | }
9 |
--------------------------------------------------------------------------------
/src/prompt/toolcall.py:
--------------------------------------------------------------------------------
1 | SYSTEM_PROMPT = "You are an agent that can execute tool calls"
2 |
3 | NEXT_STEP_PROMPT = (
4 | "If you want to stop interaction, use `terminate` tool/function call."
5 | )
6 |
--------------------------------------------------------------------------------
/src/exceptions.py:
--------------------------------------------------------------------------------
class ToolError(Exception):
    """Raised when a tool encounters an error."""

    def __init__(self, message):
        # Forward the message to Exception so str(e), repr(e) and e.args
        # carry it; the original omitted this, leaving str(e) empty.
        super().__init__(message)
        self.message = message


class TokenLimitExceeded(Exception):
    """Exception raised when the token limit is exceeded."""
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
# Python version check: 3.11-3.13
import sys


# NOTE: sys.version_info is a 5-tuple, so e.g. (3, 13, 0, "final", 0) > (3, 13)
# is True — the original `> (3, 13)` check warned on every supported 3.13.x
# release.  `>= (3, 14)` accepts all 3.13 patch releases.
if sys.version_info < (3, 11) or sys.version_info >= (3, 14):
    print(
        "Warning: Unsupported Python version {ver}, please use 3.11-3.13".format(
            # Only major.minor.micro — avoids "3.13.0.final.0" style output.
            ver=".".join(map(str, sys.version_info[:3]))
        )
    )
--------------------------------------------------------------------------------
/src/environment/__init__.py:
--------------------------------------------------------------------------------
1 | """Environment module for different execution environments."""
2 |
3 | from src.environment.base import BaseEnvironment
4 | from src.environment.battle import BattleEnvironment
5 | from src.environment.research import ResearchEnvironment
6 |
7 |
8 | __all__ = ["BaseEnvironment", "ResearchEnvironment", "BattleEnvironment"]
9 |
--------------------------------------------------------------------------------
/src/prompt/big_deal_analysis.py:
--------------------------------------------------------------------------------
1 | BIG_DEAL_SYSTEM_PROMPT = """
2 | 你是FinGenius系统中的大单异动分析专家,擅长使用资金流向数据结合价格走势识别市场高强度资金异动。
3 |
4 | 你的职责:
5 | 1. 使用 big_deal_analysis_tool 获取市场与个股资金流向数据(stock_fund_flow_big_deal、stock_fund_flow_individual 等)。
6 | 2. 给出市场整体资金净流入/净流出统计,并列出 TOP 净流入 / 净流出股票。
7 | 3. 针对指定股票,分析资金流向趋势与价格走势,给出综合评分,并输出投资建议(看涨/看跌及理由)。
8 | 4. 识别极端流入/流出、市场情绪和高活跃度股票并说明依据。
9 |
10 | 输出请使用分点叙述,逻辑清晰,必要时给出简明表格。
11 | """
--------------------------------------------------------------------------------
/src/agent/__init__.py:
--------------------------------------------------------------------------------
1 | from src.agent.base import BaseAgent
2 | from src.agent.chip_analysis import ChipAnalysisAgent
3 | from src.agent.react import ReActAgent
4 | from src.agent.toolcall import ToolCallAgent
5 | from src.agent.big_deal_analysis import BigDealAnalysisAgent
6 |
7 |
8 | __all__ = [
9 | "BaseAgent",
10 | "ChipAnalysisAgent",
11 | "ReActAgent",
12 | "ToolCallAgent",
13 | "BigDealAnalysisAgent",
14 | ]
15 |
--------------------------------------------------------------------------------
/src/mcp/battle_server.py:
--------------------------------------------------------------------------------
1 | from src.mcp.server import MCPServer
2 | from src.tool import Battle, Terminate
3 |
4 |
5 | class BattleServer(MCPServer):
6 |     """MCP server exposing the battle tool plus terminate."""
7 | 
8 |     def __init__(self, name: str = "BattleServer"):
9 |         super().__init__(name)
10 | 
11 |     def _initialize_standard_tools(self) -> None:
12 |         # Register the built-in tools this server serves over MCP.
13 |         self.tools.update(
14 |             {
15 |                 "terminate": Terminate(),
16 |                 "battle": Battle(),
17 |             }
18 |         )
19 | 
--------------------------------------------------------------------------------
/src/tool/search/__init__.py:
--------------------------------------------------------------------------------
1 | from src.tool.search.baidu_search import BaiduSearchEngine
2 | from src.tool.search.base import WebSearchEngine
3 | from src.tool.search.bing_search import BingSearchEngine
4 | from src.tool.search.duckduckgo_search import DuckDuckGoSearchEngine
5 | from src.tool.search.google_search import GoogleSearchEngine
6 |
7 |
8 | __all__ = [
9 | "WebSearchEngine",
10 | "BaiduSearchEngine",
11 | "DuckDuckGoSearchEngine",
12 | "GoogleSearchEngine",
13 | "BingSearchEngine",
14 | ]
15 |
--------------------------------------------------------------------------------
/src/mcp/risk_control_server.py:
--------------------------------------------------------------------------------
1 | from src.mcp.server import MCPServer
2 | from src.tool import Terminate
3 | from src.tool.risk_control import RiskControlTool
4 |
5 |
6 | class RiskControlServer(MCPServer):
7 |     """MCP server exposing the risk-control analysis tool plus terminate."""
8 | 
9 |     def __init__(self, name: str = "RiskControlServer"):
10 |         super().__init__(name)
11 | 
12 |     def _initialize_standard_tools(self) -> None:
13 |         # Register the built-in tools this server serves over MCP.
14 |         self.tools.update(
15 |             {
16 |                 "risk_control_tool": RiskControlTool(),
17 |                 "terminate": Terminate(),
18 |             }
19 |         )
20 | 
--------------------------------------------------------------------------------
/src/mcp/chip_analysis_server.py:
--------------------------------------------------------------------------------
1 | from src.mcp.server import MCPServer
2 | from src.tool import Terminate
3 | from src.tool.chip_analysis import ChipAnalysisTool
4 |
5 |
6 | class ChipAnalysisServer(MCPServer):
7 |     """MCP server exposing the chip-distribution analysis tool plus terminate."""
8 | 
9 |     def __init__(self, name: str = "ChipAnalysisServer"):
10 |         super().__init__(name)
11 | 
12 |     def _initialize_standard_tools(self) -> None:
13 |         # Register the built-in tools this server serves over MCP.
14 |         self.tools.update(
15 |             {
16 |                 "chip_analysis_tool": ChipAnalysisTool(),
17 |                 "terminate": Terminate(),
18 |             }
19 |         )
20 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit
2 | asyncio
3 | pydantic~=2.10.6
4 | aiohttp
5 | openai~=1.66.5
6 | fastmcp
7 | python-dotenv
8 | beautifulsoup4~=4.13.3
9 | pandas~=2.2.3
10 | numpy
11 | matplotlib
12 | plotly
13 | nltk
14 | tomli
15 | mcp~=1.5.0
16 | tiktoken~=0.9.0
17 | tenacity~=9.0.0
18 | loguru~=0.7.3
19 | reportlab
20 | googlesearch-python~=1.3.0
21 | baidusearch~=1.0.3
22 | duckduckgo-search~=7.5.5
23 | requests~=2.32.3
25 | efinance~=0.5.5.2
26 | akshare~=1.16.87
27 | schedule~=1.2.2
28 | uvicorn
29 | starlette
30 | rich~=13.7.1
31 |
--------------------------------------------------------------------------------
/src/mcp/big_deal_analysis_server.py:
--------------------------------------------------------------------------------
1 | from src.mcp.server import MCPServer
2 | from src.tool import Terminate
3 | from src.tool.big_deal_analysis import BigDealAnalysisTool
4 |
5 |
6 | class BigDealAnalysisServer(MCPServer):
7 |     """MCP server exposing the big-deal (large-order) analysis tool plus terminate."""
8 | 
9 |     def __init__(self, name: str = "BigDealAnalysisServer"):
10 |         super().__init__(name)
11 | 
12 |     def _initialize_standard_tools(self) -> None:
13 |         # Register the built-in tools this server serves over MCP.
14 |         self.tools.update(
15 |             {
16 |                 "big_deal_analysis_tool": BigDealAnalysisTool(),
17 |                 "terminate": Terminate(),
18 |             }
19 |         )
20 | 
--------------------------------------------------------------------------------
/src/mcp/technical_analysis_server.py:
--------------------------------------------------------------------------------
1 | from src.mcp.server import MCPServer
2 | from src.tool import Terminate
3 | from src.tool.technical_analysis import TechnicalAnalysisTool
4 |
5 |
6 | class TechnicalAnalysisServer(MCPServer):
7 |     """MCP server exposing the technical-analysis tool plus terminate."""
8 | 
9 |     def __init__(self, name: str = "TechnicalAnalysisServer"):
10 |         super().__init__(name)
11 | 
12 |     def _initialize_standard_tools(self) -> None:
13 |         # Register the built-in tools this server serves over MCP.
14 |         self.tools.update(
15 |             {
16 |                 "technical_analysis_tool": TechnicalAnalysisTool(),
17 |                 "terminate": Terminate(),
18 |             }
19 |         )
20 | 
--------------------------------------------------------------------------------
/src/mcp/sentiment_server.py:
--------------------------------------------------------------------------------
1 | from src.mcp.server import MCPServer
2 | from src.tool import Terminate
3 | from src.tool.sentiment import SentimentTool
4 | from src.tool.web_search import WebSearch
5 |
6 |
7 | class SentimentServer(MCPServer):
8 |     """MCP server exposing sentiment analysis, web search and terminate."""
9 | 
10 |     def __init__(self, name: str = "SentimentServer"):
11 |         super().__init__(name)
12 | 
13 |     def _initialize_standard_tools(self) -> None:
14 |         # Register the built-in tools this server serves over MCP.
15 |         self.tools.update(
16 |             {
17 |                 "sentiment_tool": SentimentTool(),
18 |                 "web_search": WebSearch(),
19 |                 "terminate": Terminate(),
20 |             }
21 |         )
22 | 
--------------------------------------------------------------------------------
/src/tool/__init__.py:
--------------------------------------------------------------------------------
1 | """Tool module for FinGenius platform."""
2 |
3 | from src.tool.base import BaseTool
4 | from src.tool.battle import Battle
5 | from src.tool.chip_analysis import ChipAnalysisTool
6 | from src.tool.create_chat_completion import CreateChatCompletion
7 | from src.tool.terminate import Terminate
8 | from src.tool.tool_collection import ToolCollection
9 | from src.tool.big_deal_analysis import BigDealAnalysisTool
10 |
11 |
12 | __all__ = [
13 | "BaseTool",
14 | "Battle",
15 | "ChipAnalysisTool",
16 | "Terminate",
17 | "ToolCollection",
18 | "CreateChatCompletion",
19 | "BigDealAnalysisTool",
20 | ]
21 |
--------------------------------------------------------------------------------
/src/mcp/hot_money_srver.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | 
4 | from src.mcp.server import MCPServer
5 | from src.tool import Terminate
6 | from src.tool.hot_money import HotMoneyTool
7 | 
8 | 
9 | # Send log records to stderr — presumably to keep stdout free for MCP
10 | # stdio protocol traffic; TODO confirm transport used by MCPServer.
11 | logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stderr)])
12 | 
13 | 
14 | class HotMoneyServer(MCPServer):
15 |     """MCP server exposing the hot-money analysis tool plus terminate."""
16 | 
17 |     def __init__(self, name: str = "HotMoneyServer"):
18 |         super().__init__(name)
19 | 
20 |     def _initialize_standard_tools(self) -> None:
21 |         # Register the built-in tools this server serves over MCP.
22 |         self.tools.update(
23 |             {
24 |                 "hot_money_tool": HotMoneyTool(),
25 |                 "terminate": Terminate(),
26 |             }
27 |         )
28 | 
--------------------------------------------------------------------------------
/src/tool/financial_deep_search/__init__.py:
--------------------------------------------------------------------------------
1 | from src.tool.financial_deep_search.get_section_data import get_all_section
2 | from src.tool.financial_deep_search.index_capital import get_index_capital_flow
3 | from src.tool.financial_deep_search.risk_control_data import (
4 | get_announcements_with_detail,
5 | get_company_name_for_stock,
6 | get_financial_reports,
7 | get_risk_control_data,
8 | )
9 | from src.tool.financial_deep_search.stock_capital import (
10 | fetch_single_stock_capital_flow,
11 | fetch_stock_list_capital_flow,
12 | get_stock_capital_flow,
13 | )
14 |
15 |
16 | __all__ = [
17 | "get_stock_capital_flow",
18 | "fetch_single_stock_capital_flow",
19 | "fetch_stock_list_capital_flow",
20 | "get_risk_control_data",
21 | "get_announcements_with_detail",
22 | "get_financial_reports",
23 | "get_company_name_for_stock",
24 | "get_index_capital_flow",
25 | "get_all_section",
26 | ]
27 |
--------------------------------------------------------------------------------
/src/tool/terminate.py:
--------------------------------------------------------------------------------
1 | from src.tool.base import BaseTool
2 |
3 |
4 | _TERMINATE_DESCRIPTION = """Terminate the interaction when the request is met OR if the assistant cannot proceed further with the task.
5 | When you have finished all the tasks, call this tool to end the work."""
6 |
7 |
8 | class Terminate(BaseTool):
9 |     """Tool that ends the interaction, reporting a success/failure status."""
10 | 
11 |     name: str = "terminate"
12 |     description: str = _TERMINATE_DESCRIPTION
13 |     # JSON schema advertised to the LLM for this tool's call arguments.
14 |     parameters: dict = {
15 |         "type": "object",
16 |         "properties": {
17 |             "status": {
18 |                 "type": "string",
19 |                 "description": "The finish status of the interaction.",
20 |                 "enum": ["success", "failure"],
21 |             }
22 |         },
23 |         "required": ["status"],
24 |     }
25 | 
26 |     async def execute(self, status: str) -> str:
27 |         """Finish the current execution and echo the final status."""
28 |         return f"The interaction has been completed with status: {status}"
29 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # HTML code is incorrectly calculated into statistics, so ignore them
2 | *.html linguist-detectable=false
3 | # Auto detect text files and perform LF normalization
4 | * text=auto eol=lf
5 | # Ensure shell scripts use LF (Linux style) line endings on Windows
6 | *.sh text eol=lf
7 | # Treat specific binary files as binary and prevent line ending conversion
8 | *.png binary
9 | *.jpg binary
10 | *.gif binary
11 | *.ico binary
12 | *.jpeg binary
13 | *.mp3 binary
14 | *.zip binary
15 | *.bin binary
16 | # Preserve original line endings for specific document files
17 | *.doc text eol=crlf
18 | *.docx text eol=crlf
19 | *.pdf binary
20 | # Ensure source code and script files use LF line endings
21 | *.py text eol=lf
22 | *.js text eol=lf
23 | *.html text eol=lf
24 | *.css text eol=lf
25 | # Specify custom diff driver for specific file types
26 | *.md diff=markdown
27 | *.json diff=json
28 | *.mp4 filter=lfs diff=lfs merge=lfs -text
29 | *.mov filter=lfs diff=lfs merge=lfs -text
30 | *.webm filter=lfs diff=lfs merge=lfs -text
31 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: 23.1.0
4 | hooks:
5 | - id: black
6 |
7 | - repo: https://github.com/pre-commit/pre-commit-hooks
8 | rev: v4.4.0
9 | hooks:
10 | - id: trailing-whitespace
11 | - id: end-of-file-fixer
12 | - id: check-yaml
13 | - id: check-added-large-files
14 |
15 | - repo: https://github.com/PyCQA/autoflake
16 | rev: v2.0.1
17 | hooks:
18 | - id: autoflake
19 | args: [
20 | --remove-all-unused-imports,
21 | --ignore-init-module-imports,
22 | --expand-star-imports,
23 | --remove-duplicate-keys,
24 | --remove-unused-variables,
25 | --recursive,
26 | --in-place,
27 | --exclude=__init__.py,
28 | ]
29 | files: \.py$
30 |
31 | - repo: https://github.com/pycqa/isort
32 | rev: 5.12.0
33 | hooks:
34 | - id: isort
35 | args: [
36 | "--profile", "black",
37 | "--filter-files",
38 | "--lines-after-imports=2",
39 | ]
40 |
--------------------------------------------------------------------------------
/src/tool/financial_deep_search/index_name_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "000001": "上证指数",
3 | "399001": "深证成指",
4 | "899050": "北证50",
5 | "399006": "创业板指",
6 | "000680": "科创综指",
7 | "000688": "科创50",
8 | "399330": "深证100",
9 | "000300": "沪深300",
10 | "000016": "上证50",
11 | "399673": "创业板50",
12 | "000888": "上证综合全收益",
13 | "399750": "深主板50",
14 | "930050": "中证A50",
15 | "000903": "中证A100",
16 | "000510": "中证A500",
17 | "000904": "中证200",
18 | "000905": "中证500",
19 | "000906": "中证800",
20 | "000852": "中证1000",
21 | "932000": "中证2000",
22 | "000985": "中证全指",
23 | "000010": "上证180",
24 | "000009": "上证380",
25 | "000132": "上证100",
26 | "000133": "上证150",
27 | "000003": "B股指数",
28 | "000012": "国债指数",
29 | "000013": "企债指数",
30 | "000011": "基金指数",
31 | "399002": "深成指R",
32 | "399850": "深证50",
33 | "399005": "中小100",
34 | "399003": "成份B指",
35 | "399106": "深证综指",
36 | "399004": "深证100R",
37 | "399007": "深证300",
38 | "399008": "中小300",
39 | "399293": "创业大盘",
40 | "399019": "创业200",
41 | "399020": "创业小盘",
42 | "399100": "新指数",
43 | "399550": "央视50"
44 | }
45 |
--------------------------------------------------------------------------------
/src/tool/search/google_search.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from googlesearch import search
4 |
5 | from src.tool.search.base import SearchItem, WebSearchEngine
6 |
7 |
class GoogleSearchEngine(WebSearchEngine):
    """Search backend using the `googlesearch` package."""

    def perform_search(
        self, query: str, num_results: int = 10, *args, **kwargs
    ) -> List[SearchItem]:
        """
        Google search engine.

        Returns results formatted according to SearchItem model.
        """
        raw_results = search(query, num_results=num_results, advanced=True)

        results = []
        for i, item in enumerate(raw_results):
            if isinstance(item, str):
                # Plain-URL result: wrap it in a SearchItem so every element
                # honors the declared List[SearchItem] return type (the
                # original appended a bare dict here, breaking callers that
                # rely on .title / str(item)).
                results.append(
                    SearchItem(
                        title=f"Google Result {i+1}", url=item, description=""
                    )
                )
            else:
                results.append(
                    SearchItem(
                        title=item.title, url=item.url, description=item.description
                    )
                )

        return results
--------------------------------------------------------------------------------
/src/agent/react.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Optional
3 |
4 | from pydantic import Field
5 |
6 | from src.agent.base import BaseAgent
7 | from src.llm import LLM
8 | from src.schema import AgentState, Memory
9 |
10 |
11 | class ReActAgent(BaseAgent, ABC):
12 |     """Abstract agent implementing the ReAct loop: think, then act."""
13 | 
14 |     name: str
15 |     description: Optional[str] = None
16 | 
17 |     # Prompts injected into the LLM conversation.
18 |     system_prompt: Optional[str] = None
19 |     next_step_prompt: Optional[str] = None
20 | 
21 |     llm: Optional[LLM] = Field(default_factory=LLM)
22 |     memory: Memory = Field(default_factory=Memory)
23 |     state: AgentState = AgentState.IDLE
24 | 
25 |     # Step budget guards against unbounded think/act loops.
26 |     max_steps: int = 10
27 |     current_step: int = 0
28 | 
29 |     @abstractmethod
30 |     async def think(self) -> bool:
31 |         """Process current state and decide next action"""
32 | 
33 |     @abstractmethod
34 |     async def act(self) -> str:
35 |         """Execute decided actions"""
36 | 
37 |     async def step(self) -> str:
38 |         """Execute a single step: think and act."""
39 |         should_act = await self.think()
40 |         if not should_act:
41 |             return "Thinking complete - no action needed"
42 |         return await self.act()
43 | 
--------------------------------------------------------------------------------
/src/prompt/hot_money.py:
--------------------------------------------------------------------------------
1 | HOT_MONEY_SYSTEM_PROMPT = """你是一位专业的游资行为分析师,擅长解读市场中游资的操作特点、风格和可能的意图。你的任务是基于提供的数据和信息,进行客观的游资行为分析,帮助用户理解游资的操作模式,但绝不提供任何形式的投资建议或引导用户做出投资决策。
2 |
3 | 分析范围
4 | 1. 游资席位的持仓变化
5 | 2. 游资的交易风格和特点
6 | 3. 历史操作模式和偏好
7 | 4. 资金流向和集中度
8 | 5. 游资之间的关联性和协同特征
9 |
10 | 分析方法
11 | 1. 数据分析:通过席位龙虎榜数据、成交量、换手率等量化指标分析
12 | 2. 模式识别:辨别游资的操作套路、惯用手法
13 | 3. 历史追踪:回顾特定游资的历史操作轨迹和成功率
14 | 4. 关联分析:挖掘不同游资之间的关联性和互动模式
15 |
16 | 输出格式
17 | 1. 游资背景简介:对所分析游资的基本情况描述
18 | 2. 操作特点总结:概括该游资的典型操作风格和特征
19 | 3. 近期行为分析:分析其近期操作的特点和可能的思路
20 | 4. 注意事项:提醒用户关注的风险点和需要注意的因素
21 |
22 | 工具使用规范
23 | 在分析过程中,请合理使用以下工具:
24 | - hot_money_tool:获取龙虎榜数据和资金流向信息
25 | - terminate:当你完成了完整的游资分析报告后,必须使用此工具结束任务
26 |
27 | ⚠️ 重要提醒:当你完成了游资分析并准备输出最终报告时,请立即使用terminate工具结束任务,避免无限循环。
28 |
29 | 重要免责声明
30 | 1. 本分析仅供参考,绝不构成任何形式的投资建议
31 | 2. 分析内容基于历史数据和公开信息,不预测未来市场走势
32 | 3. 不推荐、不暗示、不引导用户进行任何具体投资操作
33 | 4. 用户应自行承担投资决策的全部责任和风险
34 | 5. 分析不构成任何买入或卖出的建议,用户必须独立做出决策
35 |
36 | 使用说明
37 | 在提问时,请尽可能提供以下信息以获得更准确的分析:
38 | 1. 具体关注的游资席位或代码
39 | 2. 关注的时间段
40 | 3. 已知的相关信息
41 | 4. 希望了解的具体方面
42 |
43 | 示例分析框架
44 | 1. 游资背景分析
45 | - 历史活跃度和关注领域
46 | - 典型操作风格和特点
47 | - 历史成功案例特征
48 | 2. 近期操作分析
49 | - 资金规模和集中度变化
50 | - 进出节奏和持仓周期
51 | - 与其他席位的关联性
52 | 3. 行为模式解读
53 | - 可能的操作思路分析
54 | - 操作阶段判断
55 | - 风险点提示
56 | 4. 总结观点
57 | - 客观中立的行为总结
58 | - 值得关注的关键指标
59 | - 再次强调:本分析不构成任何投资建议,仅供参考
60 | """
61 |
--------------------------------------------------------------------------------
/src/agent/report.py:
--------------------------------------------------------------------------------
1 | from pydantic import Field
2 |
3 | from src.agent.mcp import MCPAgent
4 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN
5 | from src.prompt.report import REPORT_SYSTEM_PROMPT
6 | from src.tool import Terminate, ToolCollection
7 | from src.tool.create_html import CreateHtmlTool
8 |
9 |
10 | class ReportAgent(MCPAgent):
11 |     """Report generation agent that synthesizes insights from other agents."""
12 | 
13 |     name: str = "report_agent"
14 |     description: str = "Generates comprehensive reports by synthesizing insights from other specialized agents."
15 |     system_prompt: str = REPORT_SYSTEM_PROMPT
16 |     next_step_prompt: str = NEXT_STEP_PROMPT_ZN
17 | 
18 |     # Initialize with FinGenius tools and proper type annotation
19 |     available_tools: ToolCollection = Field(
20 |         default_factory=lambda: ToolCollection(
21 |             CreateHtmlTool(),
22 |             Terminate(),
23 |         )
24 |     )
25 |     # Tools whose invocation ends the agent run (here: terminate).
26 |     special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name])
27 | 
28 | 
29 | if __name__ == "__main__":
30 |     import asyncio
31 | 
32 |     async def run_agent():
33 |         # Manual smoke test: build the agent and request a simple HTML report.
34 |         agent = await ReportAgent.create()
35 |         await agent.initialize()
36 |         prompt = "生成一个关于AI的报告 html格式"
37 |         await agent.run(prompt)
38 | 
39 |     asyncio.run(run_agent())
40 | 
--------------------------------------------------------------------------------
/src/tool/search/base.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | from pydantic import BaseModel, Field
4 |
5 |
6 | class SearchItem(BaseModel):
7 |     """Represents a single search result item"""
8 | 
9 |     title: str = Field(description="The title of the search result")
10 |     url: str = Field(description="The URL of the search result")
11 |     description: Optional[str] = Field(
12 |         default=None, description="A description or snippet of the search result"
13 |     )
14 | 
15 |     def __str__(self) -> str:
16 |         """String representation of a search result item."""
17 |         return f"{self.title} - {self.url}"
18 | 
19 | 
20 | class WebSearchEngine(BaseModel):
21 |     """Base class for web search engines."""
22 | 
23 |     # Subclasses may hold non-pydantic attributes (HTTP clients etc.).
24 |     model_config = {"arbitrary_types_allowed": True}
25 | 
26 |     def perform_search(
27 |         self, query: str, num_results: int = 10, *args, **kwargs
28 |     ) -> List[SearchItem]:
29 |         """
30 |         Perform a web search and return a list of search items.
31 | 
32 |         Args:
33 |             query (str): The search query to submit to the search engine.
34 |             num_results (int, optional): The number of search results to return. Default is 10.
35 |             args: Additional arguments.
36 |             kwargs: Additional keyword arguments.
37 | 
38 |         Returns:
39 |             List[SearchItem]: A list of SearchItem objects matching the search query.
40 |         """
41 |         raise NotImplementedError
42 | 
--------------------------------------------------------------------------------
/src/agent/big_deal_analysis.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional
2 | from pydantic import Field
3 |
4 | from src.agent.mcp import MCPAgent
5 | from src.prompt.big_deal_analysis import BIG_DEAL_SYSTEM_PROMPT
6 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN
7 | from src.schema import Message
8 | from src.tool import Terminate, ToolCollection
9 | from src.tool.big_deal_analysis import BigDealAnalysisTool
10 |
11 |
12 | class BigDealAnalysisAgent(MCPAgent):
13 |     """Agent for analyzing large-order (big-deal) capital-flow anomalies."""
14 | 
15 |     name: str = "big_deal_analysis_agent"
16 |     description: str = "分析市场及个股大单资金异动,为投资决策提供依据。"
17 | 
18 |     system_prompt: str = BIG_DEAL_SYSTEM_PROMPT
19 |     next_step_prompt: str = NEXT_STEP_PROMPT_ZN
20 | 
21 |     available_tools: ToolCollection = Field(
22 |         default_factory=lambda: ToolCollection(
23 |             BigDealAnalysisTool(),
24 |             Terminate(),
25 |         )
26 |     )
27 |     # Cap characters per observation so oversized tool output cannot
28 |     # overflow memory / the LLM context.
29 |     max_observe: int = 10000
30 |     special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])
31 | 
32 |     async def run(
33 |         self,
34 |         request: Optional[str] = None,
35 |         stock_code: Optional[str] = None,
36 |     ) -> Any:
37 |         """Run big deal analysis.
38 | 
39 |         Args:
40 |             request: Free-form analysis request; used as-is when provided.
41 |             stock_code: Stock to analyze; when given without a request, a
42 |                 default system message and request are built from it.
43 |         """
44 |         if stock_code and not request:
45 |             self.memory.add_message(
46 |                 Message.system_message(
47 |                     f"你正在分析股票 {stock_code} 的大单资金流向,请综合资金异动与价格走势给出结论。"
48 |                 )
49 |             )
50 |             request = f"请对 {stock_code} 进行大单异动深度分析,并生成投资建议。"
51 | 
52 |         return await super().run(request)
--------------------------------------------------------------------------------
/src/prompt/risk_control.py:
--------------------------------------------------------------------------------
1 | RISK_SYSTEM_PROMPT = """你是一位专业的风险控制顾问,专注于财务风险和法务风险的识别、评估与管理。你擅长从财务数据和法律法规角度分析潜在风险点,并提供风险控制策略建议。你的回答基于风险管理的普遍原则和最佳实践,不构成具体的财务或法律建议。
2 | 分析范围
3 | 财务风险
4 | 流动性风险
5 | 信用风险
6 | 市场风险
7 | 运营财务风险
8 | 财务报表异常识别
9 | 资金管理风险
10 | 预算控制风险
11 | 财务合规风险
12 | 法务风险
13 | 合同风险
14 | 监管合规风险
15 | 知识产权风险
16 | 劳动法律风险
17 | 诉讼风险
18 | 企业治理风险
19 | 数据合规与隐私风险
20 | 行业特定法规风险
21 | 风险评估方法
22 | 风险识别:通过系统性分析识别潜在风险点
23 | 风险评估:评估风险发生的可能性和潜在影响
24 | 风险分级:按严重程度和紧急度对风险进行分级
25 | 控制措施建议:提供降低或消除风险的可行措施
26 | 监控建议:提出持续监控风险的指标和方法
27 | 输出格式
28 | 风险概述:对提出问题的整体风险状况评估
29 | 关键风险识别:列出主要财务和法务风险点
30 | 风险评估矩阵:按影响力和可能性评级
31 | 控制措施建议:针对各风险点的具体管控建议
32 | 监控机制:持续监控风险的方法和指标
33 | 工具使用指南
34 | 在分析过程中,请合理使用以下工具:
35 | - risk_control_tool:获取股票的财务数据和法务公告数据
36 | - terminate:当你完成了完整的风险分析报告后,必须使用此工具结束任务
37 |
38 | ⚠️ 重要提醒:当你完成了风险分析并准备输出最终报告时,请立即使用terminate工具结束任务,避免无限循环。
39 |
40 | 重要免责声明
41 | 本分析仅提供风险管理的一般性建议,不构成具体的财务或法律建议
42 | 分析基于提供的信息和一般风险管理原则,不替代专业财务顾问或法律顾问的意见
43 | 用户在做出任何财务或法律决策前,应咨询具有相关资质的专业人士
44 | 风险评估结果取决于提供信息的准确性和完整性
45 | 不对用户基于本分析做出的决策结果承担责任
46 | 使用指南
47 | 在提问时,请尽可能提供以下信息以获得更准确的风险分析:
48 | 企业或项目的基本情况(规模、行业、发展阶段等)
49 | 具体关注的财务或法务问题
50 | 已知的风险点或担忧领域
51 | 现有的风险控制措施
52 | 适用的主要法规或标准
53 | 示例分析框架
54 | 1. 财务风险分析
55 | 流动性评估:现金流、营运资金、短期偿债能力
56 | 财务杠杆风险:负债率、利息覆盖率
57 | 财务报表隐患:异常指标、会计处理风险
58 | 内控机制评估:资金审批流程、职责分离
59 | 2. 法务风险分析
60 | 合同管理风险:条款缺陷、履约风险、终止条件
61 | 合规风险:行业法规、许可证要求、报告义务
62 | 知识产权保护:商标、专利、商业秘密保护措施
63 | 公司治理风险:决策流程、信息披露、利益冲突
64 | 3. 综合风险评估
65 | 风险关联性:财务与法务风险的交叉影响
66 | 系统性风险:可能导致连锁反应的核心风险
67 | 风险优先级:需要立即关注的高优先级风险
68 | 4. 风险控制建议
69 | 预防措施:避免风险发生的策略
70 | 缓解措施:减轻风险影响的方法
71 | 转移策略:保险或外包等风险转移方案
72 | 监控机制:关键风险指标(KRI)设置
73 | 5. 行动计划建议
74 | 短期措施:立即可执行的风险控制行动
75 | 中长期策略:系统性风险管理体系建设
76 | 责任分配:风险管理任务的部门分工建议
77 | 定期评估机制:风险控制效果的跟踪评估方法
78 | """
79 |
--------------------------------------------------------------------------------
/src/logger.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from pathlib import Path
3 |
4 | from loguru import logger as _logger
5 | from rich.logging import RichHandler
6 |
7 | from src.config import PROJECT_ROOT
8 |
9 |
10 | _print_level = "INFO"
11 | 
12 | 
13 | def define_log_level(print_level="INFO", logfile_level="DEBUG", name: str = None):
14 |     """Adjust the log level to above level.
15 | 
16 |     Args:
17 |         print_level: Terminal log level, or "OFF" to disable terminal output.
18 |         logfile_level: Level written to the log file under PROJECT_ROOT/logs.
19 |         name: Optional prefix for the timestamped log file name.
20 | 
21 |     Returns:
22 |         The configured loguru logger.
23 |     """
24 |     global _print_level
25 |     _print_level = print_level
26 | 
27 |     current_date = datetime.now()
28 |     formatted_date = current_date.strftime("%Y%m%d%H%M%S")
29 |     log_name = (
30 |         f"{name}_{formatted_date}" if name else formatted_date
31 |     )  # name a log with prefix name
32 | 
33 |     # Drop loguru's default handler before installing our own.
34 |     _logger.remove()
35 | 
36 |     # Only write to file by default
37 |     _logger.add(PROJECT_ROOT / f"logs/{log_name}.log", level=logfile_level)
38 | 
39 |     # Add terminal handler only if explicitly requested
40 |     if print_level != "OFF":
41 |         _logger.add(
42 |             RichHandler(rich_tracebacks=True, markup=True, show_time=False, show_path=False),
43 |             level=print_level,
44 |             format="{message}",
45 |         )
46 | 
47 |     return _logger
48 | 
49 | 
50 | logger = define_log_level(print_level="OFF")  # Default to file-only logging
51 | 
52 | 
53 | if __name__ == "__main__":
54 |     # Manual smoke test of each log level plus exception logging.
55 |     logger.info("Starting application")
56 |     logger.debug("Debug message")
57 |     logger.warning("Warning message")
58 |     logger.error("Error message")
59 |     logger.critical("Critical message")
60 | 
61 |     try:
62 |         raise ValueError("Test error")
63 |     except Exception as e:
64 |         logger.exception(f"An error occurred: {e}")
65 | 
--------------------------------------------------------------------------------
/src/prompt/technical_analysis.py:
--------------------------------------------------------------------------------
# System prompt (Chinese) for the technical-analysis agent; consumed by
# TechnicalAnalysisAgent as its `system_prompt`. This is a runtime string sent
# to the LLM — its content (including tool names) must match the registered tools.
TECHNICAL_ANALYSIS_SYSTEM_PROMPT = """你是一位专业的技术分析师,擅长运用各种技术指标、图表模式和量价分析来解读股票市场的价格走势。你能够客观分析股票的技术面情况,识别趋势、支撑阻力位、交易信号等,为用户提供专业的技术分析视角。你的任务是基于技术分析原理进行客观分析,不提供具体的投资建议。

## 分析范围
- **趋势分析**:识别主要趋势、次要趋势和短期波动
- **支撑阻力**:确定关键的支撑位和阻力位
- **技术指标**:RSI、MACD、KDJ、布林带、均线系统等
- **K线形态**:单根K线、K线组合形态分析
- **成交量分析**:量价关系、成交量指标
- **图表形态**:头肩顶底、双顶双底、三角形、楔形等
- **技术位点**:突破位、回调位、压力位

## 工具使用规范
在分析过程中,请合理使用以下工具:
- **technical_analysis_tool**:获取股票的K线数据、技术指标和成交量信息
- **terminate**:当你完成了完整的技术分析报告后,必须使用此工具结束任务

⚠️ **重要提醒**:当你完成了技术分析并准备输出最终报告时,请立即使用terminate工具结束任务,避免无限循环。

## 分析方法
- **多时间框架分析**:结合日线、周线、月线进行综合判断
- **趋势确认**:使用多个指标验证趋势的有效性
- **量价配合**:观察价格变动与成交量的协调性
- **形态识别**:识别经典的技术形态和突破信号
- **指标背离**:发现价格与指标的背离现象

## 输出格式
1. **技术概述**:股票当前的技术面整体状况
2. **趋势分析**:主要趋势方向和趋势强度评估
3. **关键位点**:重要的支撑位、阻力位、突破位
4. **技术指标解读**:主要技术指标的当前状态和信号
5. **K线形态分析**:近期K线的形态特征和含义
6. **成交量分析**:量价关系的协调性和异常情况
7. **技术信号总结**:当前的买卖信号和操作提示

## 重要免责声明
- 本分析仅供参考,不构成任何投资建议
- 技术分析基于历史数据,不能保证未来表现
- 市场存在不确定性,技术信号可能失效
- 用户应结合基本面分析和自身风险承受能力做出决策
- 不对基于本分析的投资结果承担责任

## 分析框架
### 1. 趋势系统分析
- **主要趋势**:判断股票的长期运行方向
- **均线系统**:多条均线的排列和交叉情况
- **趋势线**:上升趋势线、下降趋势线的有效性
- **趋势强度**:评估当前趋势的可持续性

### 2. 技术指标综合分析
- **动量指标**:RSI、KDJ等超买超卖情况
- **趋势指标**:MACD金叉死叉、MACD柱状线变化
- **压力支撑**:布林带上下轨的压力支撑作用
- **成交量指标**:OBV、量比等资金流向判断

### 3. K线形态分析
- **单根K线**:大阳线、大阴线、十字星、锤子线等
- **K线组合**:早晨之星、黄昏之星、三只乌鸦等
- **缺口分析**:普通缺口、突破缺口、衰竭缺口的性质

### 4. 量价关系分析
- **量价配合**:价涨量增、价跌量缩的健康状态
- **量价背离**:价格创新高而成交量萎缩的警示信号
- **异常放量**:突然的成交量放大及其含义

### 5. 图表形态识别
- **反转形态**:头肩顶底、双顶双底、V形反转
- **整理形态**:三角形、矩形、楔形、旗形
- **突破确认**:形态突破的有效性和目标位测算

### 6. 风险控制要点
- **止损位设置**:基于技术位点的止损建议
- **风险提示**:技术面存在的主要风险点
- **操作策略**:基于技术分析的一般性操作思路
"""
75 |
--------------------------------------------------------------------------------
/src/prompt/report.py:
--------------------------------------------------------------------------------
# System prompt (Chinese) for the report-generation agent; consumed by the
# report agent as its `system_prompt`. Runtime string sent to the LLM — it
# describes how to merge the four research domains into one report.
REPORT_SYSTEM_PROMPT = """你是一位专业的综合分析报告生成专家,能够整合来自多个专业分析域(游资分析、风险控制、舆情分析、技术面分析)的信息,形成系统化、结构清晰、逻辑严密的综合分析报告。你的工作是将这些专业领域的分析进行融合、提炼和深化,找出其中的关联性和互动关系,生成具有战略价值的综合报告。

报告目标
整合多维度分析信息,提供全局视角
挖掘不同分析领域之间的关联性和相互影响
识别综合分析中的关键风险点和机会点
构建清晰、专业、易于理解的分析框架
提供基于客观分析的综合观点和思考方向

输入信息
本报告基于以下四个专业分析领域的输入:
游资分析:游资席位的操作特点、风格和可能意图
风险控制:财务风险和法务风险的识别、评估与管理建议
舆情分析:媒体平台舆论动态、情感倾向、传播规律和影响因素
技术面分析:图表形态、技术指标和量价关系分析

报告结构与内容
1. 执行摘要
综合分析的核心发现和关键结论
主要风险点和值得关注的重要信号
不同分析维度的核心观点汇总
2. 多维度分析综述
游资行为解读:主要游资动向及其操作特点概述
风险态势评估:财务和法务风险的整体态势
舆论环境分析:主要舆论倾向和传播特点
技术面状况:技术指标和形态的关键信号
3. 关联性分析
游资行为与技术面信号的呼应关系
舆情变化对技术面的影响路径
风险因素与舆情演变的关联模式
各维度分析之间的交叉验证和互补性
4. 深度解析
游资-技术互动:游资行为如何影响或回应技术形态
舆情-资金关系:舆论环境变化与资金流向的互动
风险-市场反应:风险因素对市场情绪和行为的影响
技术-基本面结合点:技术信号与基本面因素的结合分析
5. 综合风险评估
多维度风险因素的叠加效应
潜在风险的传导路径和影响范围
风险预警指标体系建议
风险应对的优先级排序
6. 观察重点与监测建议
后续需重点关注的关键指标和信号
不同时间周期的监测重点
可能的转折点和触发条件
情境分析和敏感性分析
7. 综合结论
基于多维度分析的整体判断
关键不确定性因素
不同情境下的可能发展路径
分析的局限性说明
重要原则
专业性原则
客观中立:保持客观分析立场,不偏向任何特定观点
数据支持:分析结论有数据和事实支持,避免主观臆测
逻辑严密:分析推理过程清晰,逻辑链条完整
全面平衡:呈现多种可能性,不强调单一结果
整合性原则
有机融合:不是简单拼接各领域分析,而是有机整合
关联挖掘:深入挖掘不同分析领域之间的关联性
一致性检验:检验不同领域分析结果的一致性和差异性
矛盾处理:对不同分析领域出现的矛盾进行合理解释
实用性原则
重点突出:突出最关键的发现和最重要的结论
层次清晰:信息呈现有明确层次,便于理解和参考
精简有效:避免冗余信息,保持报告简洁高效
可操作性:提供具有实际参考价值的观察视角
免责声明
本综合报告基于各专业领域提供的分析信息,不构成任何投资建议
报告中的分析和观点仅供参考,使用者应自行判断其适用性
报告不对市场未来走势做出预测,也不推荐任何具体投资行为
使用者应结合自身情况和其他信息来源做出独立决策
使用指南
提供给综合报告生成器的输入应包括:
各专业领域(游资、风险、舆情、技术面)的分析内容
需要特别关注的具体问题或领域
报告的主要目标读者和用途
报告的期望深度和长度
特定的报告风格要求(如正式程度、专业术语使用等)
报告样式指南
格式建议
使用清晰的标题和副标题层级结构
适当使用要点符号提高可读性
关键结论或警示信息可使用醒目格式
适当使用图表概念说明复杂关系
重要数据或比较可使用表格形式展示
语言风格
保持专业、简洁、精准的表达
避免过度技术性术语,确保可理解性
结论表述谨慎客观,避免绝对化表达
多维度分析有不同声音时,平衡呈现
信息密度
执行摘要高度凝练,突出关键点
主体部分保持适当信息密度,详略得当
重要分析和次要分析区分明确
附加信息和背景知识适当补充,但不喧宾夺主
"""
98 |
--------------------------------------------------------------------------------
/src/tool/search/baidu_search.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from baidusearch.baidusearch import search
4 |
5 | from src.tool.search.base import SearchItem, WebSearchEngine
6 |
7 |
class BaiduSearchEngine(WebSearchEngine):
    def perform_search(
        self, query: str, num_results: int = 10, *args, **kwargs
    ) -> List[SearchItem]:
        """Run a Baidu query and normalize every raw hit into a SearchItem.

        The underlying `baidusearch` package may yield plain URL strings,
        dicts, or arbitrary objects; each shape is handled separately with a
        positional fallback title.
        """
        raw_results = search(query, num_results=num_results)

        normalized: List[SearchItem] = []
        for position, hit in enumerate(raw_results):
            fallback_title = f"Baidu Result {position + 1}"
            if isinstance(hit, str):
                # Bare URL string — no title or abstract available.
                entry = SearchItem(title=fallback_title, url=hit, description=None)
            elif isinstance(hit, dict):
                # Structured payload with title/url/abstract keys.
                entry = SearchItem(
                    title=hit.get("title", fallback_title),
                    url=hit.get("url", ""),
                    description=hit.get("abstract", None),
                )
            else:
                # Unknown object shape: attempt attribute access, then
                # degrade to a stringified fallback entry.
                try:
                    entry = SearchItem(
                        title=getattr(hit, "title", fallback_title),
                        url=getattr(hit, "url", ""),
                        description=getattr(hit, "abstract", None),
                    )
                except Exception:
                    entry = SearchItem(
                        title=fallback_title, url=str(hit), description=None
                    )
            normalized.append(entry)

        return normalized
55 |
--------------------------------------------------------------------------------
/src/agent/chip_analysis.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional
2 |
3 | from pydantic import Field
4 |
5 | from src.agent.mcp import MCPAgent
6 | from src.prompt.chip_analysis import CHIP_ANALYSIS_SYSTEM_PROMPT
7 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN
8 | from src.schema import Message
9 | from src.tool import Terminate, ToolCollection
10 | from src.tool.chip_analysis import ChipAnalysisTool
11 |
12 |
class ChipAnalysisAgent(MCPAgent):
    """MCP agent specialised in chip-distribution (筹码) analysis for A-share stocks."""

    name: str = "chip_analysis_agent"
    description: str = (
        "专业的筹码分析师,擅长分析股票筹码分布、主力行为、散户情绪和A股特色筹码技术分析"
    )

    system_prompt: str = CHIP_ANALYSIS_SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT_ZN

    # Default toolset: the chip-analysis tool plus terminate.
    available_tools: ToolCollection = Field(
        default_factory=lambda: ToolCollection(
            ChipAnalysisTool(),
            Terminate(),
        )
    )
    special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])

    async def run(
        self, request: Optional[str] = None, stock_code: Optional[str] = None
    ) -> Any:
        """Run chip analysis for a stock.

        Args:
            request: Explicit request text; takes precedence over stock_code.
            stock_code: Stock code/ticker used to build a default request when
                no explicit request is given.

        Returns:
            The result produced by the parent MCPAgent run loop.
        """
        if not request and stock_code:
            # Seed the conversation with context about the target stock, then
            # derive a default analysis request from the ticker.
            self.memory.add_message(
                Message.system_message(
                    f"你正在分析股票 {stock_code} 的筹码分布。请使用筹码分析工具获取筹码分布数据,并进行全面的筹码技术分析,包括主力成本、套牢区、集中度等关键指标。"
                )
            )
            request = f"请对 {stock_code} 进行全面的筹码分析,包括筹码分布、主力行为、散户情绪和交易建议。"

        return await super().run(request)
--------------------------------------------------------------------------------
/src/agent/risk_control.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional
2 |
3 | from pydantic import Field
4 |
5 | from src.agent.mcp import MCPAgent
6 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN
7 | from src.prompt.risk_control import RISK_SYSTEM_PROMPT
8 | from src.schema import Message
9 | from src.tool import Terminate, ToolCollection
10 | from src.tool.risk_control import RiskControlTool
11 |
12 |
class RiskControlAgent(MCPAgent):
    """Risk analysis agent focused on identifying and quantifying investment risks."""

    name: str = "risk_control_agent"
    description: str = "Analyzes financial risks and proposes risk control strategies for stock investments."
    system_prompt: str = RISK_SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT_ZN

    # Default toolset: the risk-control tool plus terminate.
    available_tools: ToolCollection = Field(
        default_factory=lambda: ToolCollection(
            RiskControlTool(),
            Terminate(),
        )
    )
    special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])

    async def run(
        self, request: Optional[str] = None, stock_code: Optional[str] = None
    ) -> Any:
        """Run risk analysis on the given stock.

        Args:
            request: Explicit request text; takes precedence over stock_code.
            stock_code: Stock code/ticker used to build a default request when
                no explicit request is given.

        Returns:
            The result produced by the parent MCPAgent run loop.
        """
        if not request and stock_code:
            # Seed the conversation with stock context, then derive a default
            # risk-analysis request from the ticker.
            self.memory.add_message(
                Message.system_message(
                    f"你正在分析股票 {stock_code} 的风险因素。请收集相关财务数据并进行全面风险评估。"
                )
            )
            request = f"请对 {stock_code} 进行全面的风险分析。"

        return await super().run(request)
54 |
--------------------------------------------------------------------------------
/config/config.example.toml:
--------------------------------------------------------------------------------
1 | # Global LLM configuration
2 | # 模型越好,生成的效果越好,建议使用最好的模型。
3 | [llm]
4 | api_type = "openai" # 添加API类型,使用OpenAI兼容的API
5 | model = "claude-3-7-sonnet-20250219" # The LLM model to use, better use tool supported model
6 | base_url = "https://api.anthropic.com/v1/" # API endpoint URL
7 | api_key = "YOUR_API_KEY" # Your API key
8 | max_tokens = 8192 # Maximum number of tokens in the response
9 | temperature = 0.0 # Controls randomness
10 |
11 | # [llm] #AZURE OPENAI:
12 | # api_type= 'azure'
13 | # model = "YOUR_MODEL_NAME" #"gpt-4o-mini"
# base_url = "{YOUR_AZURE_ENDPOINT.rstrip('/')}/openai/deployments/{AZURE_DEPLOYMENT_ID}"
15 | # api_key = "AZURE API KEY"
16 | # max_tokens = 8096
17 | # temperature = 0.0
18 | # api_version="AZURE API VERSION" #"2024-08-01-preview"
19 |
20 |
21 | # [llm]
22 | # api_type = "ollama"
23 | # model = "你的模型名称" # 例如: "llama3.2", "qwen2.5", "deepseek-coder"
24 | # base_url = "http://10.24.163.221:8080/v1" # 你的Ollama服务地址
25 | # api_key = "ollama" # 可以是任意值,Ollama会忽略但OpenAI SDK需要
26 | # max_tokens = 4096
27 | # temperature = 0.0
28 |
29 | # Optional configuration for specific LLM models
30 | # [llm.vision]
31 | # model = "claude-3-7-sonnet-20250219" # The vision model to use
32 | # base_url = "https://api.anthropic.com/v1/" # API endpoint URL for vision model
33 | # api_key = "YOUR_API_KEY" # Your API key for vision model
34 | # max_tokens = 8192 # Maximum number of tokens in the response
35 | # temperature = 0.0 # Controls randomness for vision model
36 |
37 | # [llm.vision] #OLLAMA VISION:
38 | # api_type = 'ollama'
39 | # model = "llama3.2-vision"
40 | # base_url = "http://localhost:11434/v1"
41 | # api_key = "ollama"
42 | # max_tokens = 4096
43 | # temperature = 0.0
44 |
45 | # Optional configuration, Search settings.
46 | [search]
47 | # Search engine for agent to use. Default is "Google", can be set to "Baidu" or "DuckDuckGo" or "Bing"
48 | # 对于国内用户,建议使用以下搜索引擎优先级:
49 | # Baidu(百度)- 国内访问最稳定
50 | # Bing(必应)- 国际化且国内可用
51 | # Google - 作为备选(需要良好的国际网络)
52 | # DuckDuckGo - 作为备选(需要良好的国际网络)
53 | engine = "Bing"
54 |
--------------------------------------------------------------------------------
/src/agent/hot_money.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional
2 |
3 | from pydantic import Field
4 |
5 | from src.agent.mcp import MCPAgent
6 | from src.prompt.hot_money import HOT_MONEY_SYSTEM_PROMPT
7 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN
8 | from src.schema import Message
9 | from src.tool import Terminate, ToolCollection
10 | from src.tool.hot_money import HotMoneyTool
11 |
12 |
class HotMoneyAgent(MCPAgent):
    """Hot money analysis agent focused on institutional trading patterns."""

    name: str = "hot_money_agent"
    description: str = (
        "Analyzes institutional trading patterns, fund positions, and capital flows."
    )

    system_prompt: str = HOT_MONEY_SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT_ZN

    # Default toolset: the hot-money tool plus terminate.
    available_tools: ToolCollection = Field(
        default_factory=lambda: ToolCollection(
            HotMoneyTool(),
            Terminate(),
        )
    )
    special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])

    async def run(
        self, request: Optional[str] = None, stock_code: Optional[str] = None
    ) -> Any:
        """Run institutional trading analysis on the given stock.

        Args:
            request: Explicit request text; takes precedence over stock_code.
            stock_code: Stock code/ticker used to build a default request when
                no explicit request is given.

        Returns:
            The result produced by the parent MCPAgent run loop.
        """
        if not request and stock_code:
            # Seed the conversation with stock context, then derive a default
            # institutional-analysis request from the ticker.
            self.memory.add_message(
                Message.system_message(
                    f"你正在分析股票 {stock_code} 的机构交易行为。请识别主要机构投资者,追踪持股变动,并分析资金流向与交易模式。"
                )
            )
            request = f"请分析 {stock_code} 的机构交易和资金持仓情况。"

        return await super().run(request)
57 |
--------------------------------------------------------------------------------
/src/tool/search/duckduckgo_search.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from duckduckgo_search import DDGS
4 |
5 | from src.tool.search.base import SearchItem, WebSearchEngine
6 |
7 |
class DuckDuckGoSearchEngine(WebSearchEngine):
    def perform_search(
        self, query: str, num_results: int = 10, *args, **kwargs
    ) -> List[SearchItem]:
        """Run a DuckDuckGo text query and normalize each hit into a SearchItem.

        DDGS().text() usually yields dicts with title/href/body keys, but
        strings and arbitrary objects are tolerated with a positional
        fallback title.
        """
        raw_results = DDGS().text(query, max_results=num_results)

        normalized: List[SearchItem] = []
        for position, hit in enumerate(raw_results):
            fallback_title = f"DuckDuckGo Result {position + 1}"
            if isinstance(hit, str):
                # Bare URL string — no title or body available.
                entry = SearchItem(
                    title=fallback_title, url=hit, description=None
                )
            elif isinstance(hit, dict):
                # Standard DDGS payload shape.
                entry = SearchItem(
                    title=hit.get("title", fallback_title),
                    url=hit.get("href", ""),
                    description=hit.get("body", None),
                )
            else:
                # Unknown object shape: attempt attribute access, then
                # degrade to a stringified fallback entry.
                try:
                    entry = SearchItem(
                        title=getattr(hit, "title", fallback_title),
                        url=getattr(hit, "href", ""),
                        description=getattr(hit, "body", None),
                    )
                except Exception:
                    entry = SearchItem(
                        title=fallback_title,
                        url=str(hit),
                        description=None,
                    )
            normalized.append(entry)

        return normalized
58 |
--------------------------------------------------------------------------------
/src/prompt/chip_analysis.py:
--------------------------------------------------------------------------------
# System prompt (Chinese) for the chip-analysis agent; consumed by
# ChipAnalysisAgent as its `system_prompt`. Runtime string sent to the LLM —
# the tool names mentioned inside must match the agent's registered tools.
CHIP_ANALYSIS_SYSTEM_PROMPT = """你是一位专业的筹码分析师,专精于A股市场的筹码分布技术分析。你能够深入解读筹码分布背后的主力意图、散户行为和市场博弈格局,为投资决策提供核心依据。

## 核心专业领域

### 1. 筹码分布分析
- **筹码峰识别**:准确识别单峰、双峰、多峰分布形态
- **集中度分析**:90%/70%筹码集中度计算与解读
- **筹码迁移**:追踪筹码从低位向高位或高位向低位的转移过程
- **筹码锁定**:判断筹码的稳定性和持仓意愿

### 2. 主力行为分析(A股机构思维)
- **主力成本识别**:通过筹码峰值精准定位主力建仓成本区间
- **成本乖离率分析**:计算当前价格与主力成本的偏离程度
- **控盘程度评估**:基于筹码集中度判断主力控盘强度
- **主力获利空间**:评估主力当前的盈利状况和操作空间

### 3. 散户行为分析(散户市特征)
- **套牢区识别**:准确定位散户被套区域和套牢深度
- **恐慌情绪分析**:通过低位筹码变化判断散户恐慌程度
- **跟风盘分析**:识别上涨过程中新进散户的行为特征
- **割肉行为**:分析散户在底部的割肉强度和时机

### 4. A股特色分析
- **政策市响应**:分析政策利好前后的筹码分布变化
- **游资操作模式**:识别涨停板、一日游、龙回头等游资操作特征
- **机构调仓轨迹**:追踪季度调仓、北上资金、公募抱团的筹码变化
- **减持窗口预警**:预测大股东减持和解禁压力

## 分析方法论

### 技术指标体系
1. **主力成本乖离率** = (当前价-主力成本)/主力成本 × 100%
2. **散户套牢深度** = (最高套牢区价格-当前价)/当前价 × 100%
3. **筹码稳定指数** = 长期持有筹码占比
4. **异动转移率** = 近期筹码变动量/总筹码量

### 交易信号识别
**买入信号**:
- 底部单峰密集:90%集中度<15% + 获利比例<20%
- 主力成本支撑:价格回踩主力成本线 + 筹码锁定率>60%
- 恐慌筹码收集:单日筹码下移率>20% + 量能萎缩

**卖出信号**:
- 高位双峰背离:上下筹码峰形成 + 集中度快速发散
- 获利盘出逃:90%集中度>30% + 单日转移率>15%
- 机构派发迹象:高位筹码稳定度骤降

**风险预警**:
- 减持雷区:股价接近大股东成本区
- 质押平仓风险:价格接近质押平仓线
- 流动性危机:高集中度(>25%) + 低换手(<1%)

## 输出标准

### 1. 筹码分布概况
- 当前筹码分布形态描述
- 主要筹码峰位置和成本区间
- 筹码集中度水平评估

### 2. 主力行为画像
- 主力控盘阶段判断
- 主力成本区间识别
- 近期操作行为分析

### 3. 压力支撑分析
- 关键支撑位:主要筹码峰位置
- 压力位:历史套牢区域
- 突破或跌破概率评估

### 4. 交易决策建议
- 明确的买入/卖出/持有建议
- 风险点提示
- 止损止盈位设定

# 工具使用规范

在分析过程中,请合理使用以下工具:
- **chip_analysis_tool**:获取股票的筹码分布数据和相关指标
- **terminate**:当你完成了完整的筹码分析报告后,必须使用此工具结束任务

⚠️ **重要提醒**:当你完成了筹码分析并准备输出最终报告时,请立即使用terminate工具结束任务,避免无限循环。
## 分析原则

1. **数据驱动**:基于真实筹码分布数据,不做主观臆测
2. **A股特色**:结合A股市场特有的政策市、资金市特征
3. **博弈思维**:从主力与散户博弈角度解读筹码变化
4. **风险优先**:重点识别风险点,避免追高杀跌
5. **客观中立**:不带个人情绪,基于数据得出结论

## 表达风格

- **专业术语**:使用标准的筹码分析术语
- **逻辑清晰**:先分析现状,再推导结论
- **量化表达**:用具体数据支撑分析观点
- **实用导向**:提供可操作的交易建议

你的任务是运用专业的筹码分析技能,为用户提供准确、实用的筹码分析报告,帮助他们在A股市场中做出更明智的投资决策。请始终保持专业、客观、负责任的态度。"""
--------------------------------------------------------------------------------
/docs/flow_diagram.md:
--------------------------------------------------------------------------------
1 | ```mermaid
2 | sequenceDiagram
3 | participant User
4 | participant Main as Main Program
5 | participant ResearchEnv as Research Environment
6 | participant BattleEnv as Battle Environment
7 | participant Memory as Memory System
8 | participant MCP as MCP Server & Tools
9 | participant SA as Sentiment Agent
10 | participant RA as Risk Control Agent
11 | participant HMA as Hot Money Agent
12 | participant TAA as Technical Analysis Agent
13 | participant REP as Report Agent
14 |
15 | User->>Main: 输入股票代码 (run_stock_pipeline)
16 |
17 | Main->>ResearchEnv: 创建研究环境 (ResearchEnvironment.create)
18 | Main->>BattleEnv: 创建博弈环境 (BattleEnvironment.create)
19 |
20 | Main->>ResearchEnv: 运行股票研究 (run)
21 |
22 | Note over ResearchEnv: 执行研究环节
23 |
24 | par 并行研究请求
25 | ResearchEnv->>SA: 请求舆情分析
26 | ResearchEnv->>RA: 请求风险分析
27 | ResearchEnv->>HMA: 请求市场分析
28 | ResearchEnv->>TAA: 请求技术分析
29 | end
30 |
31 | Note over SA,TAA: 各Agent调用MCP提供的工具
32 | SA->>MCP: 请求舆情数据与工具
33 | MCP-->>SA: 提供舆情数据与分析能力
34 |
35 | RA->>MCP: 请求风险评估工具
36 | MCP-->>RA: 提供风险评估能力
37 |
38 | HMA->>MCP: 请求市场数据与分析工具
39 | MCP-->>HMA: 提供市场分析能力
40 |
41 | TAA->>MCP: 请求技术分析工具
42 | MCP-->>TAA: 提供技术分析能力
43 |
44 | SA-->>ResearchEnv: 返回舆情分析
45 | RA-->>ResearchEnv: 返回风险评估
46 | HMA-->>ResearchEnv: 返回市场分析
47 | TAA-->>ResearchEnv: 返回技术分析
48 |
49 | ResearchEnv->>REP: 请求生成报告
50 | REP->>MCP: 请求报告生成工具
51 | MCP-->>REP: 提供报告生成能力
52 | REP-->>ResearchEnv: 返回综合报告
53 |
54 | ResearchEnv-->>Main: 返回研究结果
55 |
56 | Main->>BattleEnv: 注册博弈智能体 (register_agent)
57 | Note over Main,BattleEnv: 重置每个智能体的执行状态 (reset_execution_state)
58 |
59 | Main->>BattleEnv: 运行博弈 (run)
60 |
61 | Note over BattleEnv: 博弈环节
62 | BattleEnv->>SA: 邀请参与博弈
63 | BattleEnv->>RA: 邀请参与博弈
64 | BattleEnv->>HMA: 邀请参与博弈
65 | BattleEnv->>TAA: 邀请参与博弈
66 |
67 | Note over SA,TAA: Agents讨论并投票
68 | SA->>MCP: 使用博弈工具(发言/投票)
69 | RA->>MCP: 使用博弈工具(发言/投票)
70 | HMA->>MCP: 使用博弈工具(发言/投票)
71 | TAA->>MCP: 使用博弈工具(发言/投票)
72 |
73 | BattleEnv-->>Main: 返回博弈结果(final_decision)
74 |
75 | Main->>Memory: 存储分析和博弈结果
76 |
77 | Main->>User: 显示结果 (display_results)
78 | Note over Main,User: 输出格式:文本或JSON
79 | ```
80 |
--------------------------------------------------------------------------------
/src/agent/technical_analysis.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional
2 |
3 | from pydantic import Field
4 |
5 | from src.agent.mcp import MCPAgent
6 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN
7 | from src.prompt.technical_analysis import TECHNICAL_ANALYSIS_SYSTEM_PROMPT
8 | from src.schema import Message
9 | from src.tool import Terminate, ToolCollection
10 | from src.tool.technical_analysis import TechnicalAnalysisTool
11 |
12 |
class TechnicalAnalysisAgent(MCPAgent):
    """Technical analysis agent applying technical indicators to stock analysis."""

    name: str = "technical_analysis_agent"
    description: str = (
        "Applies technical analysis and chart patterns to stock market analysis."
    )

    system_prompt: str = TECHNICAL_ANALYSIS_SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT_ZN

    # Default toolset: the technical-analysis tool plus terminate.
    available_tools: ToolCollection = Field(
        default_factory=lambda: ToolCollection(
            TechnicalAnalysisTool(),
            Terminate(),
        )
    )
    special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])

    async def run(
        self, request: Optional[str] = None, stock_code: Optional[str] = None
    ) -> Any:
        """Run technical analysis on the given stock using technical indicators.

        Args:
            request: Explicit request text; takes precedence over stock_code.
            stock_code: Stock code/ticker used to build a default request when
                no explicit request is given.

        Returns:
            The result produced by the parent MCPAgent run loop.
        """
        if not request and stock_code:
            # Seed the conversation with stock context, then derive a default
            # technical-analysis request from the ticker.
            self.memory.add_message(
                Message.system_message(
                    f"你正在对股票 {stock_code} 进行技术面分析。请评估价格走势、图表形态和关键技术指标,形成短中期交易策略。"
                )
            )
            request = f"请分析 {stock_code} 的技术指标和图表形态。"

        return await super().run(request)
57 |
--------------------------------------------------------------------------------
/src/tool/tool_collection.py:
--------------------------------------------------------------------------------
1 | """Collection classes for managing multiple tools."""
2 | from typing import Any, Dict, List
3 |
4 | from src.exceptions import ToolError
5 | from src.logger import logger
6 | from src.tool.base import BaseTool, ToolFailure, ToolResult
7 |
8 |
class ToolCollection:
    """A collection of defined tools, addressable by unique name.

    Tool names are unique within a collection: duplicates are skipped with a
    warning both at construction time and when added later, so ``to_params``
    never advertises two function-call entries with the same name.
    """

    class Config:  # NOTE: leftover pydantic-style config; ToolCollection is a plain class.
        arbitrary_types_allowed = True

    def __init__(self, *tools: "BaseTool"):
        # Route construction through add_tool so duplicate names are handled
        # consistently with later additions (warn + skip). Previously the
        # constructor silently kept both copies in `tools` and let the last
        # one win in `tool_map`.
        self.tools: tuple = ()
        self.tool_map: Dict[str, "BaseTool"] = {}
        for tool in tools:
            self.add_tool(tool)

    def __iter__(self):
        return iter(self.tools)

    def to_params(self) -> List[Dict[str, Any]]:
        """Return all tools as function-calling parameter dicts."""
        return [tool.to_param() for tool in self.tools]

    async def execute(
        self, *, name: str, tool_input: Dict[str, Any] = None
    ) -> "ToolResult":
        """Execute the named tool with tool_input.

        Returns a ToolFailure when the name is unknown or the tool raises
        ToolError; other exceptions propagate to the caller.
        """
        tool = self.tool_map.get(name)
        if not tool:
            return ToolFailure(error=f"Tool {name} is invalid")
        try:
            tool_input = tool_input or {}
            result = await tool(**tool_input)
            return result
        except ToolError as e:
            return ToolFailure(error=e.message)

    async def execute_all(self) -> List["ToolResult"]:
        """Execute all tools in the collection sequentially."""
        results = []
        for tool in self.tools:
            try:
                result = await tool()
                results.append(result)
            except ToolError as e:
                results.append(ToolFailure(error=e.message))
        return results

    def get_tool(self, name: str) -> "BaseTool":
        """Return the tool registered under name, or None if absent."""
        return self.tool_map.get(name)

    def add_tool(self, tool: "BaseTool"):
        """Add a single tool to the collection.

        If a tool with the same name already exists, it will be skipped and a
        warning will be logged. Returns self for chaining.
        """
        if tool.name in self.tool_map:
            logger.warning(f"Tool {tool.name} already exists in collection, skipping")
            return self

        self.tools += (tool,)
        self.tool_map[tool.name] = tool
        return self

    def add_tools(self, *tools: "BaseTool"):
        """Add multiple tools to the collection.

        If any tool has a name conflict with an existing tool, it will be
        skipped and a warning will be logged. Returns self for chaining.
        """
        for tool in tools:
            self.add_tool(tool)
        return self
73 |
--------------------------------------------------------------------------------
/src/tool/battle.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 |
3 | from pydantic import Field
4 |
5 | from src.tool.base import BaseTool, ToolFailure, ToolResult
6 |
7 |
class Battle(BaseTool):
    """Tool for agents to interact in the battle environment.

    Lets a battle participant publish an opinion (``speak``) and/or cast a
    final vote (``vote``) through the battle controller, which is injected by
    the environment after construction.
    """

    name: str = "battle"
    description: str = "在对战环境中进行互动的工具。您可以发表观点和/或投票。"
    parameters: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "vote": {
                "type": "string",
                "enum": ["bullish", "bearish"],
                "description": "您对股票走势的最终投票。'bullish'表示看涨,'bearish'表示看跌。",
            },
            "speak": {
                "type": "string",
                "description": "您想与其他智能体分享的观点内容。",
            },
        },
        "required": ["vote", "speak"],
    }

    agent_id: str = Field(..., description="The ID of the agent using this tool")
    # Battle controller (set by the environment); None until registered.
    controller: Optional[Any] = Field(default=None)

    async def execute(
        self, speak: Optional[str] = None, vote: Optional[str] = None
    ) -> ToolResult:
        """Execute the battle action with speaking and/or voting.

        Args:
            speak: Opinion text to share with the other agents, if any.
            vote: "bullish" or "bearish" (case-insensitive), if voting.

        Returns:
            The controller's ToolResult with ``output`` set to a formatted
            ``agent_id[vote]: speak`` summary, or a ToolFailure on error.
        """
        if not self.controller:
            return ToolFailure(error="对战环境未初始化")

        result = None
        formatted_output = ""

        # Normalize and validate the vote before acting on it
        if vote is not None:
            if vote.lower() not in ["bullish", "bearish"]:
                return ToolFailure(error="投票必须是'bullish'(看涨)或'bearish'(看跌)")
            vote = vote.lower()

        # Handle speaking (non-empty content only)
        if speak is not None and speak.strip():
            # Format the output as: AgentName[vote content]: speak content
            formatted_output = f"{self.agent_id}[{vote if vote else '未投票'}]: {speak}"

            result = await self.controller.handle_speak(self.agent_id, speak)
            if result.error:
                return result

        # Handle voting if a (validated) vote was provided
        if vote is not None:
            result = await self.controller.handle_vote(self.agent_id, vote)

            # If there was no speak content, create a formatted output for just the vote
            if not formatted_output:
                formatted_output = f"{self.agent_id}[{vote}]: "

        # If neither speak nor vote was provided
        if result is None:
            return ToolFailure(error="您必须提供发言内容和投票选项")

        # Both branches of the former None-check assigned the same value, so
        # attach the formatted summary unconditionally.
        result.output = formatted_output

        return result
79 |
--------------------------------------------------------------------------------
/src/prompt/mcp.py:
--------------------------------------------------------------------------------
"""Prompts for the MCP Agent."""

# Generic English system prompt for MCP-backed agents.
SYSTEM_PROMPT = """You are an AI assistant with access to a Model Context Protocol (MCP) server.
You can use the tools provided by the MCP server to complete tasks.
The MCP server will dynamically expose tools that you can use - always check the available tools first.

When using an MCP tool:
1. Choose the appropriate tool based on your task requirements
2. Provide properly formatted arguments as required by the tool
3. Observe the results and use them to determine next steps
4. Tools may change during operation - new tools might appear or existing ones might disappear

Follow these guidelines:
- Call tools with valid parameters as documented in their schemas
- Handle errors gracefully by understanding what went wrong and trying again with corrected parameters
- For multimedia responses (like images), you'll receive a description of the content
- Complete user requests step by step, using the most appropriate tools
- If multiple tools need to be called in sequence, make one call at a time and wait for results

Remember to clearly explain your reasoning and actions to the user.
"""

# Chinese-language counterpart of SYSTEM_PROMPT, used by the zh-CN agents.
SYSTEM_PROMPT_ZN = """你是一个AI助手,可以访问Model Context Protocol (MCP) 服务器。
你可以使用MCP服务器提供的工具来完成任务。
MCP服务器会动态展示你可以使用的工具 - 始终首先检查可用工具。

使用MCP Tool时:
1. 根据任务要求选择合适的Tool
2. 按照Tool的要求提供正确格式的参数
3. 观察结果,并根据结果决定下一步行动
4. Tools在操作过程中可能会发生变化 - 新的Tool可能会出现,或现有的Tool可能会消失

遵循以下指南:
- 使用文档中描述的有效参数调用Tool
- 通过理解错误的原因并用更正后的参数重试来优雅地处理错误
- 对于多媒体响应(如图像),你将收到内容的描述
- 按步骤完成用户请求,使用最合适的Tool
- 如果需要按顺序调用多个Tool,每次调用一个并等待结果

记得清晰地向用户解释你的推理和行动。
"""

# English next-step nudge appended between agent iterations.
NEXT_STEP_PROMPT = """Based on the current state and available tools, what should be done next?
Think step by step about the problem and identify which MCP tool would be most helpful for the current stage.
If you've already made progress, consider what additional information you need or what actions would move you closer to completing the task.
"""

# Chinese next-step prompt enforcing a collect → analyse → terminate workflow.
NEXT_STEP_PROMPT_ZN = """根据当前状态和可用的 Tools,接下来应该做什么?

**工作流程指导:**
1. **数据收集阶段**:如果还没有使用专业工具获取数据,请选择合适的工具执行分析
2. **深度分析阶段**:如果已经获得了工具数据,请基于数据进行专业分析和解读,不要直接terminate
3. **综合结论阶段**:当你完成了专业分析并得出结论后,使用terminate工具结束

**重要提醒:**
- 获得工具数据后,必须进行专业的分析思考,解读数据含义,提供专业见解
- 不要仅仅展示原始数据,要提供有价值的分析结论
- 在给出最终专业分析结论后,才能使用terminate工具结束任务

逐步思考问题,确定当前阶段最需要的行动。
"""

# Additional specialized prompts
# Injected after a failed tool call; callers interpolate {tool_name}.
TOOL_ERROR_PROMPT = """You encountered an error with the tool '{tool_name}'.
Try to understand what went wrong and correct your approach.
Common issues include:
- Missing or incorrect parameters
- Invalid parameter formats
- Using a tool that's no longer available
- Attempting an operation that's not supported

Please check the tool specifications and try again with corrected parameters.
"""

# Injected when a tool returns multimedia content; callers interpolate {tool_name}.
MULTIMEDIA_RESPONSE_PROMPT = """You've received a multimedia response (image, audio, etc.) from the tool '{tool_name}'.
This content has been processed and described for you.
Use this information to continue the task or provide insights to the user.
"""
79 |
--------------------------------------------------------------------------------
/src/tool/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Dict, Optional
3 |
4 | from pydantic import BaseModel, Field
5 | import time
6 | import logging
7 | from datetime import datetime
8 | from typing import Optional, Dict, Any
9 |
10 |
class BaseTool(ABC, BaseModel):
    """Abstract base for all tools: a named, described, schema-carrying async callable."""

    name: str  # Unique identifier used in function-call requests.
    description: str  # Human/LLM-readable summary of the tool's purpose.
    parameters: Optional[dict] = None  # JSON schema for arguments; None means no schema.

    class Config:
        # Allow non-pydantic field types on subclasses (e.g. controller objects).
        arbitrary_types_allowed = True

    async def __call__(self, **kwargs) -> Any:
        """Execute the tool with given parameters."""
        return await self.execute(**kwargs)

    @abstractmethod
    async def execute(self, **kwargs) -> Any:
        """Execute the tool with given parameters. Subclasses must implement."""

    def to_param(self) -> Dict:
        """Convert tool to the OpenAI-style function-call parameter format."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }
37 |
38 |
class ToolResult(BaseModel):
    """Represents the result of a tool execution.

    All fields are optional so a result may carry any combination of normal
    output, an error message, an attached base64 image, and a system note.
    """

    output: Any = Field(default=None)  # primary payload of the tool
    error: Optional[str] = Field(default=None)  # error text when the tool failed
    base64_image: Optional[str] = Field(default=None)  # optional image payload
    system: Optional[str] = Field(default=None)  # system-level message

    class Config:
        arbitrary_types_allowed = True

    def __bool__(self):
        # Truthy iff any field holds a truthy value.
        return any(getattr(self, field) for field in self.model_fields)

    def __add__(self, other: "ToolResult"):
        """Merge two results; concatenatable fields are joined, others must not clash.

        Raises:
            ValueError: if both results set a non-concatenatable field
                (currently only ``base64_image``).
        """

        def combine_fields(
            field: Optional[str], other_field: Optional[str], concatenate: bool = True
        ):
            if field and other_field:
                if concatenate:
                    return field + other_field
                raise ValueError("Cannot combine tool results")
            return field or other_field

        return ToolResult(
            output=combine_fields(self.output, other.output),
            error=combine_fields(self.error, other.error),
            base64_image=combine_fields(self.base64_image, other.base64_image, False),
            system=combine_fields(self.system, other.system),
        )

    def __str__(self):
        return f"Error: {self.error}" if self.error else str(self.output)

    def replace(self, **kwargs):
        """Return a new ToolResult with the given fields replaced.

        Uses pydantic v2's ``model_dump()`` for consistency with the
        ``model_fields`` access in ``__bool__`` (the previous ``.dict()``
        call is deprecated under pydantic v2).
        """
        return type(self)(**{**self.model_dump(), **kwargs})
77 |
78 |
class CLIResult(ToolResult):
    """A ToolResult that can be rendered as a CLI output.

    Marker subclass: adds no fields or behavior of its own.
    """
81 |
82 |
class ToolFailure(ToolResult):
    """A ToolResult that represents a failure.

    Marker subclass used to signal failure by type rather than by field.
    """
85 |
86 |
def get_recent_trading_day(date_format: str = "%Y-%m-%d", reference_date=None) -> str:
    """Return the most recent A-share trading day, skipping weekends.

    A-share trading-day rules:
    - Monday through Friday are trading days.
    - Saturday and Sunday the market is closed.

    Args:
        date_format: Output date format, defaults to "%Y-%m-%d".
        reference_date: Optional ``datetime`` to evaluate from; defaults to
            ``datetime.now()``. Backward compatible: omitting it preserves
            the original behavior.

    Returns:
        str: The most recent trading-day date formatted per ``date_format``.

    Note:
        Public holidays are NOT accounted for -- only weekends are skipped.
    """
    from datetime import datetime, timedelta

    current_date = reference_date if reference_date is not None else datetime.now()

    # Step back one day at a time while on a weekend (Saturday=5, Sunday=6).
    while current_date.weekday() >= 5:
        current_date -= timedelta(days=1)

    return current_date.strftime(date_format)
110 |
--------------------------------------------------------------------------------
/src/tool/stock_info_request.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import datetime
3 | from typing import Any, Dict
4 |
5 | import efinance as ef
6 | import pandas as pd
7 | from pydantic import Field
8 |
9 | from src.tool.base import BaseTool, ToolResult, get_recent_trading_day
10 |
11 |
class StockInfoResponse(ToolResult):
    """Tool result whose output dict carries stock info plus the trading day."""

    # Structured payload; always a dict (never None) so key access is safe.
    output: Dict[str, Any] = Field(default_factory=dict)

    @property
    def current_trading_day(self) -> str:
        """Most recent trading day recorded in the payload ('' when absent)."""
        return self.output.get("current_trading_day", "")

    @property
    def basic_info(self) -> Dict[str, Any]:
        """Basic stock information recorded in the payload ({} when absent)."""
        return self.output.get("basic_info", {})
26 |
27 |
class StockInfoRequest(BaseTool):
    """Tool to fetch basic information about a stock with the current trading date."""

    name: str = "stock_info_request"
    description: str = "获取股票基础信息和当前交易日,返回JSON格式的结果。"
    parameters: Dict[str, Any] = {
        "type": "object",
        "properties": {"stock_code": {"type": "string", "description": "股票代码"}},
        "required": ["stock_code"],
    }

    MAX_RETRIES: int = 3
    RETRY_DELAY: int = 1  # seconds between retry attempts

    async def execute(self, stock_code: str, **kwargs) -> StockInfoResponse:
        """
        Execute the tool to fetch stock information, retrying on failure.

        Args:
            stock_code: The stock code to query

        Returns:
            StockInfoResponse containing stock information and the current
            trading date, or one carrying an error after MAX_RETRIES failures.
        """
        last_error = ""
        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                # Get current trading day
                trading_day = get_recent_trading_day()

                # Fetch stock information
                data = ef.stock.get_base_info(stock_code)

                # Convert data to dict format based on its type
                basic_info = self._format_data(data)

                return StockInfoResponse(
                    output={
                        "current_trading_day": trading_day,
                        "basic_info": basic_info,
                    }
                )

            except Exception as e:
                # BUGFIX: the original returned the error response immediately
                # after sleeping, so the retry loop never actually retried.
                last_error = str(e)
                if attempt < self.MAX_RETRIES:
                    await asyncio.sleep(float(self.RETRY_DELAY))

        # All attempts exhausted; report the last observed error.
        return StockInfoResponse(
            error=f"获取股票信息失败 ({self.MAX_RETRIES}次尝试): {last_error}"
        )

    @staticmethod
    def _format_data(data: Any) -> Dict[str, Any]:
        """
        Format data to a JSON-serializable dictionary.

        Args:
            data: The data to format, typically from efinance

        Returns:
            A dictionary representation of the data
        """
        if isinstance(data, pd.DataFrame):
            # First row only; an empty frame maps to an empty dict.
            return data.to_dict(orient="records")[0] if len(data) > 0 else {}
        elif isinstance(data, pd.Series):
            return data.to_dict()
        elif isinstance(data, dict):
            return data
        elif isinstance(data, (int, float, str, bool)):
            return {"value": data}
        else:
            # Fallback: stringify anything we don't recognize.
            return {"value": str(data)}
99 |
100 |
if __name__ == "__main__":
    import json
    import sys

    # CLI entry point: default to Kweichow Moutai ("600519") if no code given.
    stock_code = sys.argv[1] if len(sys.argv) > 1 else "600519"

    # Build the tool and run a single fetch.
    response = asyncio.run(StockInfoRequest().execute(stock_code))

    # Report either the failure or the pretty-printed payload.
    if response.error:
        print(f"Error: {response.error}")
    else:
        print(json.dumps(response.output, ensure_ascii=False, indent=2))
117 |
--------------------------------------------------------------------------------
/src/tool/risk_control.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from src.logger import logger
4 | from src.tool.base import BaseTool, ToolResult
5 | from src.tool.financial_deep_search.risk_control_data import get_risk_control_data
6 |
7 |
class RiskControlTool(BaseTool):
    """Tool for retrieving risk control data for stocks."""

    name: str = "risk_control_tool"
    description: str = (
        "获取股票的风控数据,包括财务数据(现金流量表,资产负债表,利润表)(financial)和法务公告数据(legal)。"
        "支持最大重试机制,适合大模型自动调用。返回结构化字典。"
    )
    parameters: dict = {
        "type": "object",
        "properties": {
            "stock_code": {
                "type": "string",
                "description": "股票代码(必填),如'600519'(贵州茅台)、'000001'(平安银行)、'300750'(宁德时代)等",
            },
            "max_count": {
                "type": "integer",
                "description": "公告数据获取数量上限,建议不超过10,用于限制返回的法律公告数量",
                "default": 10,
            },
            "period": {
                "type": "string",
                "description": "财务数据周期类型,精确可选值:'按年度'(年报数据)、'按报告期'(含季报)、'按单季度'(环比数据)",
                "default": "按年度",
            },
            "max_retry": {
                "type": "integer",
                "description": "数据获取最大重试次数,范围1-5,用于处理网络波动情况",
                "default": 3,
            },
            "sleep_seconds": {
                "type": "integer",
                "description": "重试间隔秒数,范围1-10,防止频繁请求被限制",
                "default": 1,
            },
        },
        "required": ["stock_code"],
    }

    async def execute(
        self,
        stock_code: str,
        max_count: int = 10,
        period: str = "按年度",
        max_retry: int = 3,
        sleep_seconds: int = 1,
        **kwargs,
    ) -> ToolResult:
        """
        Fetch risk-control data for one stock, retrying inside the data layer.

        Args:
            stock_code: Stock code to query.
            max_count: Cap on the number of legal announcements returned.
            period: Financial reporting period (按年度 / 按报告期 / 按单季度).
            max_retry: Maximum retry attempts for the underlying fetch.
            sleep_seconds: Delay between retries, in seconds.
            **kwargs: Ignored extra parameters.

        Returns:
            ToolResult: output holds the structured data, or error on failure.
        """
        try:
            # Run the blocking fetch in a worker thread so the event loop
            # stays responsive.
            data = await asyncio.to_thread(
                get_risk_control_data,
                stock_code=stock_code,
                max_count=max_count,
                period=period,
                include_announcements=True,
                include_financial=True,
                max_retry=max_retry,
                sleep_seconds=sleep_seconds,
            )
        except Exception as e:
            error_msg = f"Failed to get risk control data: {str(e)}"
            logger.error(error_msg)
            return ToolResult(error=error_msg)

        # The data layer signals failure via an "error" key.
        if "error" in data:
            return ToolResult(error=data["error"])
        return ToolResult(output=data)
92 |
93 |
if __name__ == "__main__":
    # Direct tool testing from the command line.
    import sys

    # Default to Kweichow Moutai ("600519") when no code is supplied.
    code = sys.argv[1] if len(sys.argv) > 1 else "600519"

    # Get risk control data
    tool = RiskControlTool()
    result = asyncio.run(tool.execute(stock_code=code))

    # Output results
    if result.error:
        print(f"Failed: {result.error}")
    else:
        output = result.output
        # Plain string: the original f-string had no placeholders.
        print("Success!")
        print(
            f"- Financial Data: {'Retrieved' if output['financial'] else 'Not Retrieved'}"
        )
        if output["legal"]:
            legal_info = f"Retrieved ({len(output['legal'])} items)"
        else:
            legal_info = "Not Retrieved"
        print(f"- Legal Data: {legal_info}")
118 |
--------------------------------------------------------------------------------
/src/environment/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from enum import Enum
3 | from typing import Any, Dict, List, Optional, Union
4 |
5 | from pydantic import BaseModel, Field
6 |
7 | from src.agent.base import BaseAgent
8 | from src.logger import logger
9 |
10 |
class BaseEnvironment(BaseModel):
    """Common base for all environments: agent registry plus run configuration."""

    name: str = Field(default="base_environment")
    description: str = Field(default="Base environment class")
    agents: Dict[str, BaseAgent] = Field(default_factory=dict)
    max_steps: int = Field(default=3, description="Maximum steps for each agent")

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    async def create(cls, **kwargs) -> "BaseEnvironment":
        """Construct an instance and run its async initialization in one call."""
        env = cls(**kwargs)
        await env.initialize()
        return env

    async def initialize(self) -> None:
        """Perform async setup; subclasses extend this as needed."""
        logger.info(f"Initializing {self.name} environment (max_steps={self.max_steps})")

    def register_agent(self, agent: BaseAgent) -> None:
        """Store an agent in the registry, keyed by its name."""
        self.agents[agent.name] = agent
        logger.debug(f"Agent {agent.name} registered in {self.name}")

    def add_agent(self, agent: BaseAgent) -> None:
        """Alias for register_agent, kept for API flexibility."""
        self.register_agent(agent)

    def get_agent(self, agent_name: str) -> Optional[BaseAgent]:
        """Look up a registered agent; returns None when the name is unknown."""
        return self.agents.get(agent_name)

    @abstractmethod
    async def run(self, **kwargs) -> Dict[str, Any]:
        """Execute the environment's main loop; concrete subclasses implement this."""
        raise NotImplementedError("Subclasses must implement run method")

    async def cleanup(self) -> None:
        """Release any held resources once the environment is finished."""
        logger.info(f"Cleaning up {self.name} environment")
55 |
56 |
class EnvironmentType(str, Enum):
    """Enum of available environment types.

    str-valued so members compare and serialize as plain strings.
    """

    RESEARCH = "research"
    BATTLE = "battle"
62 |
63 |
class EnvironmentFactory:
    """Factory for creating different types of environments with support for multiple agents"""

    @staticmethod
    async def create_environment(
        environment_type: EnvironmentType,
        agents: Union[BaseAgent, List[BaseAgent], Dict[str, BaseAgent]] = None,
        **kwargs,
    ) -> BaseEnvironment:
        """Create and initialize an environment of the specified type.

        Args:
            environment_type: The type of environment to create.
            agents: One or more agents to add to the environment; accepts a
                single agent, a list of agents, or a name->agent dict.
            **kwargs: Extra arguments forwarded to the environment constructor.

        Returns:
            An initialized environment instance.

        Raises:
            ValueError: if environment_type is not a known type.
        """
        # Imported here to avoid circular imports at module load time.
        from src.environment.battle import BattleEnvironment
        from src.environment.research import ResearchEnvironment

        registry = {
            EnvironmentType.RESEARCH: ResearchEnvironment,
            EnvironmentType.BATTLE: BattleEnvironment,
        }

        env_cls = registry.get(environment_type)
        if not env_cls:
            raise ValueError(f"Unknown environment type: {environment_type}")

        environment = await env_cls.create(**kwargs)

        # Normalize the agents argument to a flat list, then register each.
        # (Only the three documented shapes are handled; anything else is
        # ignored, matching the original dispatch.)
        if agents:
            if isinstance(agents, BaseAgent):
                pending = [agents]
            elif isinstance(agents, list):
                pending = agents
            elif isinstance(agents, dict):
                pending = list(agents.values())
            else:
                pending = []
            for agent in pending:
                environment.add_agent(agent)

        return environment
110 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Project-specific ###
2 | # Logs
3 | logs/
4 |
5 | # Workspace
6 | workspace/
7 |
8 | ### Python ###
9 | # Byte-compiled / optimized / DLL files
10 | __pycache__/
11 | *.py[cod]
12 | *$py.class
13 |
14 | # C extensions
15 | *.so
16 |
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 | cover/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 | db.sqlite3
70 | db.sqlite3-journal
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | .pybuilder/
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | # For a library or package, you might want to ignore these files since the code is
95 | # intended to run in multiple environments; otherwise, check them in:
96 | # .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # UV
106 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
107 | # This is especially recommended for binary packages to ensure reproducibility, and is more
108 | # commonly ignored for libraries.
109 | #uv.lock
110 |
111 | # poetry
112 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
113 | # This is especially recommended for binary packages to ensure reproducibility, and is more
114 | # commonly ignored for libraries.
115 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
116 | #poetry.lock
117 |
118 | # pdm
119 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
120 | #pdm.lock
121 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
122 | # in version control.
123 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
124 | .pdm.toml
125 | .pdm-python
126 | .pdm-build/
127 |
128 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
129 | __pypackages__/
130 |
131 | # Celery stuff
132 | celerybeat-schedule
133 | celerybeat.pid
134 |
135 | # SageMath parsed files
136 | *.sage.py
137 |
138 | # Environments
139 | .env
140 | .venv
141 | env/
142 | venv/
143 | ENV/
144 | env.bak/
145 | venv.bak/
146 |
147 | # Spyder project settings
148 | .spyderproject
149 | .spyproject
150 |
151 | # Rope project settings
152 | .ropeproject
153 |
154 | # mkdocs documentation
155 | /site
156 |
157 | # mypy
158 | .mypy_cache/
159 | .dmypy.json
160 | dmypy.json
161 |
162 | # Pyre type checker
163 | .pyre/
164 |
165 | # pytype static type analyzer
166 | .pytype/
167 |
168 | # Cython debug symbols
169 | cython_debug/
170 |
171 | # PyCharm
172 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174 | # and can be added to the global gitignore or merged into this file. For a more nuclear
175 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176 | .idea/*
177 |
178 | # PyPI configuration file
179 | .pypirc
180 |
181 | ### Visual Studio Code ###
182 | .vscode/*
183 | !.vscode/settings.json
184 | !.vscode/tasks.json
185 | !.vscode/launch.json
186 | !.vscode/extensions.json
187 | !.vscode/*.code-snippets
188 |
189 | # Local History for Visual Studio Code
190 | .history/
191 |
192 | # Built Visual Studio Code Extensions
193 | *.vsix
194 |
195 | # OSX
196 | .DS_Store
197 |
198 | # node
199 | node_modules
200 |
--------------------------------------------------------------------------------
/src/prompt/sentiment.py:
--------------------------------------------------------------------------------
1 | SENTIMENT_SYSTEM_PROMPT = """
2 | # 股票舆情分析专家
3 |
4 | 你是一位专业的股票舆情分析师,专注于A股市场的舆论监测和情感分析。你擅长通过多渠道信息收集、数据分析和情感计算,为投资者提供客观、准确的股票舆情报告。
5 |
6 | ## 🎯 核心职责
7 | - 实时监测目标股票的网络舆情动态
8 | - 分析媒体报道、社交平台、投资论坛的情感倾向
9 | - 识别影响股价的关键舆情事件和传播节点
10 | - 评估舆情对股票短期和中期走势的潜在影响
11 | - 提供基于数据的客观分析,不给出投资建议
12 |
13 | ## 🔧 工具使用规范
14 |
15 | **主要工具说明:**
16 | - **web_search**:进行实时信息收集,获取舆情数据
17 | - **terminate**:当你完成了完整的舆情分析报告后,必须使用此工具结束任务
18 |
19 | ⚠️ **重要提醒**:当你完成了舆情分析并准备输出最终报告时,请立即使用terminate工具结束任务,避免无限循环。
20 |
21 | **必须使用web_search工具进行实时信息收集,按以下顺序执行:**
22 |
23 | ## 🎯 权威数据源优先策略
24 |
25 | ### 核心财经媒体(优先关注)
26 | - **财联社** (cls.cn) - 专业财经资讯平台
27 | - **新浪财经** (finance.sina.com.cn) - 财经新闻和分析
28 | - **证券时报** (stcn.com) - 权威证券媒体
29 | - **上海证券报** (cnstock.com) - 官方证券媒体
30 | - **中国证券报** (cs.com.cn) - 权威证券报道
31 | - **第一财经** (yicai.com) - 专业财经媒体
32 | - **21世纪经济报道** (21jingji.com) - 权威经济媒体
33 |
34 | ### 投资社区和论坛(关注情绪)
35 | - **雪球** (xueqiu.com) - 专业投资社区
36 | - **东方财富股吧** (guba.eastmoney.com) - 投资者讨论平台
37 | - **同花顺** (10jqka.com.cn) - 投资资讯平台
38 | - **金融界** (jrj.com.cn) - 财经门户网站
39 |
40 | ## 📝 搜索策略(根据搜索引擎自动调整)
41 |
42 | ### 🔄 智能搜索策略选择
43 |
44 | ### 🔍 百度/必应搜索引擎策略(关键词匹配)
45 | **当搜索引擎为百度或必应时,使用以下关键词搜索:**
46 |
47 | 1. **📰 权威媒体新闻搜索**
48 | - 搜索:"{公司名称} {股票代码} 财联社 最新消息"
49 | - 搜索:"{公司名称} {股票代码} 新浪财经 最新资讯"
50 | - 搜索:"{公司名称} 证券时报 报道"
51 | - 搜索:"{公司名称} {股票代码} 上海证券报 新闻"
52 | - 搜索:"{公司名称} 第一财经 资讯"
53 | - 搜索:"{公司名称} {股票代码} 中国证券报 消息"
54 |
55 | 2. **💬 投资者情绪监测**
56 | - 搜索:"{公司名称} {股票代码} 雪球 讨论 观点"
57 | - 搜索:"{公司名称} {股票代码} 东方财富股吧 热议"
58 | - 搜索:"{公司名称} 同花顺 投资者 讨论"
59 | - 搜索:"{公司名称} {股票代码} 金融界 评论 分析"
60 | - 搜索:"{公司名称} {股票代码} 股民 看法 观点"
61 |
62 | 3. **🏢 公司公告和业绩追踪**
63 | - 搜索:"{公司名称} {股票代码} 公司公告 最新公告"
64 | - 搜索:"{公司名称} 财报 业绩 季报"
65 | - 搜索:"{公司名称} {股票代码} 年报 半年报"
66 | - 搜索:"{公司名称} 重大事项 重组"
67 | - 搜索:"{公司名称} {股票代码} 分红 派息"
68 |
69 | 4. **📊 机构研报和分析**
70 | - 搜索:"{公司名称} {股票代码} 研报 分析师 券商"
71 | - 搜索:"{公司名称} 机构评级 目标价"
72 | - 搜索:"{公司名称} {股票代码} 投资建议 买入 卖出"
73 | - 搜索:"{公司名称} {行业名称} 政策 监管 影响"
74 | - 搜索:"{公司名称} {股票代码} 机构持仓 基金"
75 |
76 | 5. **⚠️ 风险预警监测**
77 | - 搜索:"{公司名称} {股票代码} 风险 预警 风险提示"
78 | - 搜索:"{公司名称} 负面 问题 争议"
79 | - 搜索:"{公司名称} {股票代码} 违规 处罚 监管"
80 | - 搜索:"{公司名称} 纠纷 诉讼 案件"
81 | - 搜索:"{公司名称} {股票代码} 停牌 复牌 异常"
82 |
83 | ### 📋 执行步骤
84 | 1. **搜索引擎检测**:先执行一次简单测试搜索,根据搜索结果的source字段确定当前使用的搜索引擎
85 | 2. **策略选择**:根据搜索引擎类型选择对应的搜索策略
86 | 3. **信息收集**:按照选定策略执行上述5个维度的搜索
87 | 4. **结果标注**:在输出结果中明确标注使用的搜索引擎和搜索策略类型
88 |
89 | **重点关注**:官方媒体报道、投资者情绪、公司公告、机构观点、风险预警
90 |
91 | ## 📊 数据源权重设置
92 |
93 | ### 一级权威源(权重:高)
94 | - 财联社、证券时报、上海证券报、中国证券报
95 | - 交易所官网、证监会官网、公司官网
96 |
97 | ### 二级专业源(权重:中高)
98 | - 东方财富、新浪财经、第一财经、21世纪经济报道
99 | - 专业研究机构报告、券商研报
100 |
101 | ### 三级社区源(权重:中)
102 | - 雪球、东方财富股吧、同花顺、金融界
103 | - 投资者论坛、社交媒体讨论
104 |
105 | ## 📋 分析工作流程
106 |
107 | ### 第一步:信息收集
108 | - 使用web_search工具按上述5个维度收集信息
109 | - 记录信息来源、发布时间、可信度等级
110 | - 筛选有效信息,排除无关内容
111 |
112 | ### 第二步:情感分析
113 | - 对收集到的文本进行情感分类(正面/负面/中性)
114 | - 计算情感强度和情感分布比例
115 | - 识别情感转折点和异常情感波动
116 |
117 | ### 第三步:传播分析
118 | - 分析信息传播路径和影响范围
119 | - 识别关键意见领袖和影响节点
120 | - 评估信息传播速度和覆盖面
121 |
122 | ### 第四步:影响评估
123 | - 评估舆情对股价的潜在影响程度
124 | - 识别短期和中期的关键风险点
125 | - 分析舆情与股价走势的关联性
126 |
127 | ### 第五步:趋势预判
128 | - 基于历史数据和当前趋势进行合理推断
129 | - 识别可能的舆情转折点
130 | - 提供后续关注重点
131 |
132 | ## 📊 输出格式规范
133 |
134 | ### 🔍 舆情概况
135 | - **股票代码**:[股票代码]
136 | - **公司名称**:[公司全称]
137 | - **分析时间**:[分析时间戳]
138 | - **搜索引擎**:[Google/百度/必应/DuckDuckGo]
139 | - **搜索策略**:[精准site:指令/关键词匹配]
140 | - **舆情热度**:[高/中/低] + 具体数值
141 | - **整体情感**:[正面/负面/中性] + 情感分数
142 | - **关键事件**:[影响舆情的重要事件]
143 |
144 | ### 📈 数据分析
145 | - **信息来源统计**:
146 | - 一级权威源:财联社X条、证券时报Y条、上海证券报Z条
147 | - 二级专业源:东方财富X条、新浪财经Y条、第一财经Z条
148 | - 三级社区源:雪球X条、股吧Y条、同花顺Z条
149 | - **情感分布**:正面X% | 中性Y% | 负面Z%
150 | - **热度变化**:与前期对比的变化趋势
151 | - **传播指标**:传播范围、互动数量、影响力指数
152 |
153 | ### 💭 内容分析
154 | - **正面观点**:主要看好理由和论据
155 | - **负面观点**:主要担忧和风险点
156 | - **中性分析**:客观事实和数据
157 | - **关键词云**:高频词汇和热点话题
158 |
159 | ### 🌐 传播分析
160 | - **信息源头**:关键信息的首发平台
161 | - 权威媒体首发:财联社/证券时报/上海证券报等
162 | - 官方公告首发:交易所/证监会/公司官网等
163 | - 社区讨论首发:雪球/股吧/投资论坛等
164 | - **传播路径**:信息扩散的主要渠道
165 | - 媒体传播:从权威媒体到其他财经平台
166 | - 社交传播:从专业投资者到普通投资者
167 | - 官方传播:从监管部门到市场参与者
168 | - **影响节点**:关键意见领袖和转发大户
169 | - 知名财经媒体:财联社、东方财富、新浪财经
170 | - 专业投资者:雪球大V、知名博主、分析师
171 | - 机构媒体:券商研报、基金公司、投资机构
172 | - **传播速度**:信息传播的时效性分析
173 |
174 | ### 🔮 趋势预判
175 | - **短期走势**:24-48小时内的舆情趋势
176 | - **关键变量**:可能影响舆情的重要因素
177 | - **风险提示**:需要重点关注的风险点
178 | - **监测建议**:后续需要跟踪的关键信息
179 |
180 | ### 🎯 核心结论
181 | - **舆情评级**:[积极/中性/消极]
182 | - **影响程度**:[高/中/低]
183 | - **关注重点**:[需要重点关注的方面]
184 | - **风险警示**:[主要风险点]
185 |
186 | ## ⚠️ 重要声明
187 | - 本分析仅基于公开信息和数据,不构成投资建议
188 | - 舆情分析具有主观性,结果仅供参考
189 | - 投资有风险,决策需谨慎
190 | - 建议结合基本面分析和技术分析综合判断
191 |
192 | ## 🔄 质量保证
193 | - 所有分析必须基于web_search工具收集的实时数据
194 | - **优先使用权威数据源**:财联社、东方财富、新浪财经、证券时报等
195 | - **标注信息来源**:每条重要信息都要明确标注具体来源网站
196 | - **权重化处理**:一级权威源 > 二级专业源 > 三级社区源
197 | - 严格区分事实陈述和主观分析
198 | - 提供信息来源和可信度评估
199 | - 保持客观中立的分析立场
200 | - 及时更新分析结果以反映最新情况
201 | """
202 |
--------------------------------------------------------------------------------
/docs/class_diagram.md:
--------------------------------------------------------------------------------
1 | ```
2 | classDiagram
3 | %%==========================
4 | %% 1. Agent 层次结构
5 | %%==========================
6 | class BaseAgent {
7 | <>
8 | - name: str
9 | - description: Optional[str]
10 | - system_prompt: Optional[str]
11 | - next_step_prompt: Optional[str]
12 | - llm: LLM
13 | - memory: Memory
14 | - state: AgentState
15 | - max_steps: int
16 | - current_step: int
17 | + run(request: Optional[str]) str
+ step() str <<abstract>>
19 | + update_memory(role, content, **kwargs) None
20 | + reset_execution_state() None
21 | }
22 | class ReActAgent {
23 | <>
24 | + think() bool <>
25 | + act() str <>
26 | + step() str
27 | }
28 | class ToolCallAgent {
29 | - available_tools: ToolCollection
30 | - tool_choices: ToolChoice
31 | - tool_calls: List~ToolCall~
32 | + think() bool
33 | + act() str
34 | + execute_tool(command: ToolCall) str
35 | + cleanup() None
36 | }
37 | class MCPAgent {
38 | - mcp_clients: MCPClients
39 | - tool_schemas: Dict~str, dict~
40 | - connected_servers: Dict~str, str~
41 | - initialized: bool
42 | + create(...) MCPAgent
43 | + initialize_mcp_servers() None
44 | + connect_mcp_server(url, server_id) None
45 | + initialize(...) None
46 | + _refresh_tools() Tuple~List~str~,List~str~~
47 | + cleanup() None
48 | }
49 | class SentimentAgent {
50 | <>
51 | }
52 | class RiskControlAgent {
53 | <>
54 | }
55 | class HotMoneyAgent {
56 | <>
57 | }
58 | class TechnicalAnalysisAgent {
59 | <>
60 | }
61 | class ReportAgent {
62 | <>
63 | }
64 |
65 | %% 继承关系
66 | BaseAgent <|-- ReActAgent
67 | ReActAgent <|-- ToolCallAgent
68 | ToolCallAgent <|-- MCPAgent
69 | MCPAgent <|-- SentimentAgent
70 | MCPAgent <|-- RiskControlAgent
71 | MCPAgent <|-- HotMoneyAgent
72 | MCPAgent <|-- TechnicalAnalysisAgent
73 | MCPAgent <|-- ReportAgent
74 |
75 | %% 组合关系
76 | BaseAgent *-- LLM : llm
77 | BaseAgent *-- Memory : memory
78 | MCPAgent *-- MCPClients : mcp_clients
79 |
80 | %%==========================
81 | %% 2. 环境(Environment)架构
82 | %%==========================
83 | class BaseEnvironment {
84 | <>
85 | - name: str
86 | - description: str
87 | - agents: Dict~str, BaseAgent~
88 | - max_steps: int
89 | + create(...) BaseEnvironment
90 | + register_agent(agent: BaseAgent) None
+ run(...) Dict~str,Any~ <<abstract>>
92 | + cleanup() None
93 | }
94 | class ResearchEnvironment {
95 | - analysis_mapping: Dict~str,str~
96 | - results: Dict~str,Any~
97 | + initialize() None
98 | + run(stock_code: str) Dict~str,Any~
99 | + cleanup() None
100 | }
101 | class BattleState {
102 | - active_agents: Dict~str,str~
103 | - voted_agents: Dict~str,str~
104 | - terminated_agents: Dict~str,bool~
105 | - battle_history: List~Dict~str,Any~~
106 | - vote_results: Dict~str,int~
107 | - battle_highlights: List~Dict~str,Any~~
108 | - battle_over: bool
109 | + add_event(type, agent_id, ...) Dict~str,Any~
110 | + record_vote(agent_id,vote) None
111 | + mark_terminated(agent_id,reason) None
112 | }
113 | class BattleEnvironment {
114 | - state: BattleState
115 | - tools: Dict~str,BaseTool~
116 | + initialize() None
117 | + register_agent(agent: BaseAgent) None
118 | + run(report: Dict~str,Any~) Dict~str,Any~
119 | + handle_speak(agent_id, content) ToolResult
120 | + handle_vote(agent_id, vote) ToolResult
121 | + cleanup() None
122 | }
123 | class EnvironmentFactory {
124 | + create_environment(env_type: EnvironmentType, agents, ...) BaseEnvironment
125 | }
126 |
127 | %% 继承与工厂
128 | BaseEnvironment <|-- ResearchEnvironment
129 | BaseEnvironment <|-- BattleEnvironment
130 | EnvironmentFactory ..> BaseEnvironment : creates
131 | %% 环境中包含 Agents 和 BattleState
132 | BaseEnvironment o-- BaseAgent : agents
133 | BattleEnvironment *-- BattleState : state
134 |
135 | %%==========================
136 | %% 3. 工具(Tool)抽象
137 | %%==========================
138 | class MCPClients {
139 | - sessions: Dict~str,ClientSession~
140 | - exit_stacks: Dict~str,AsyncExitStack~
141 | + connect_sse(url, server_id) None
142 | + connect_stdio(cmd, args, server_id) None
143 | + list_tools() ListToolsResult
144 | + disconnect(server_id) None
145 | }
146 |
147 | %%==========================
148 | %% 4. 支持类
149 | %%==========================
150 | class Memory {
151 | - messages: List~Message~
152 | + add_message(msg: Message) None
153 | + clear() None
154 | }
155 | class LLM {
156 | - model: str
157 | - max_tokens: int
158 | - temperature: float
159 | + ask(messages, system_msgs, ...) str
160 | + ask_tool(messages, tools, tool_choice, ...) Message
161 | }
162 | ```
--------------------------------------------------------------------------------
/src/tool/sentiment.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import time
4 | from typing import Any, Dict
5 |
6 | from src.logger import logger
7 | from src.tool.base import BaseTool, ToolResult
8 | from src.tool.financial_deep_search.get_section_data import get_all_section
9 | from src.tool.financial_deep_search.index_capital import get_index_capital_flow
10 |
11 |
class SentimentTool(BaseTool):
    """Tool for retrieving market sentiment data including hot sectors and index capital flow."""

    name: str = "sentiment_tool"
    description: str = "整合市场情绪与行业热点分析工具,提供全面的市场脉搏和资金流向监测。"
    parameters: dict = {
        "type": "object",
        "properties": {
            "index_code": {
                "type": "string",
                "description": "指数代码(必填),如'000001'(上证指数)、'399001'(深证成指)、'399006'(创业板指)、'000016'(上证50)、'000300'(沪深300)、'000905'(中证500)等",
            },
            "sector_types": {
                "type": "string",
                "description": "板块类型筛选,精确可选值:'all'(所有板块)、'hot'(热门活跃板块)、'concept'(概念题材板块)、'regional'(地域区域板块)、'industry'(行业分类板块)",
                "default": "all",
            },
            "max_retry": {
                "type": "integer",
                "description": "数据获取最大重试次数,范围1-5,用于处理网络波动情况",
                "default": 3,
            },
            # Consistency fix: execute() already accepted sleep_seconds, but
            # the schema (unlike RiskControlTool's) did not declare it.
            "sleep_seconds": {
                "type": "integer",
                "description": "重试间隔秒数,范围1-10,防止频繁请求被限制",
                "default": 1,
            },
        },
        "required": ["index_code"],
    }

    async def execute(
        self,
        index_code: str,
        sector_types: str = "all",
        max_retry: int = 3,
        sleep_seconds: int = 1,
        **kwargs,
    ) -> ToolResult:
        """
        Get market data with retry mechanism.

        Args:
            index_code: Index code
            sector_types: Sector types, options: 'all', 'hot', 'concept', 'regional', 'industry'
            max_retry: Maximum retry attempts
            sleep_seconds: Seconds to wait between retries
            **kwargs: Additional parameters (ignored)

        Returns:
            ToolResult: Result containing market data
        """
        try:
            # Execute synchronous operation in thread pool to avoid blocking event loop
            result = await asyncio.to_thread(
                self._get_market_data,
                index_code=index_code,
                sector_types=sector_types,
                max_retry=max_retry,
                sleep_seconds=sleep_seconds,
            )

            # The data layer signals failure via an "error" key.
            if "error" in result:
                return ToolResult(error=result["error"])

            return ToolResult(output=result)

        except Exception as e:
            error_msg = f"Failed to get market data: {str(e)}"
            logger.error(error_msg)
            return ToolResult(error=error_msg)

    def _get_market_data(
        self,
        index_code: str,
        sector_types: str = "all",
        max_retry: int = 3,
        sleep_seconds: int = 1,
    ) -> Dict[str, Any]:
        """
        Get market data including hot sectors and index capital flow.

        Retries up to max_retry times; always returns a dict (an "error" key
        marks failure) so callers can safely test membership.
        """
        for attempt in range(1, max_retry + 1):
            try:
                # 1. Get hot sector data
                section_data = get_all_section(sector_types=sector_types)
                logger.info(f"[Attempt {attempt}] Retrieved hot sector data")

                # 2. Get index capital flow
                index_flow = get_index_capital_flow(index_code=index_code)
                logger.info(f"[Attempt {attempt}] Retrieved index capital flow data")

                # 3. Combine data and return
                return {
                    "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                    "index_code": index_code,
                    "sector_types": sector_types,
                    "hot_section_data": section_data,
                    "index_net_flow": index_flow,
                }

            except Exception as e:
                logger.warning(f"[Attempt {attempt}] Failed to get market data: {e}")
                if attempt < max_retry:
                    logger.info(f"Waiting {sleep_seconds} seconds before retry...")
                    time.sleep(sleep_seconds)
                else:
                    logger.error(f"Max retries ({max_retry}) reached, failed")
                    return {"error": f"Failed to get market data: {str(e)}"}

        # Defensive fallback: only reachable when max_retry < 1; previously
        # this path returned None and crashed the "error" check in execute().
        return {"error": "Failed to get market data: no attempts were made"}
118 |
119 |
if __name__ == "__main__":
    import sys

    # CLI smoke test; default to the SSE Composite Index ("000001").
    code = sys.argv[1] if len(sys.argv) > 1 else "000001"

    tool = SentimentTool()
    result = asyncio.run(tool.execute(index_code=code))

    if result.error:
        print(f"Failed: {result.error}")
    else:
        output = result.output
        print(f"Success! Timestamp: {output['timestamp']}")
        print(f"Index Code: {output['index_code']}")

        for key in ["hot_section_data", "index_net_flow"]:
            status = "Retrieved" if output.get(key) else "Not Retrieved"
            print(f"- {key}: {status}")

        # Persist the full payload for offline inspection.
        filename = f"market_data_{code}_{time.strftime('%Y%m%d_%H%M%S')}.json"
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(output, f, ensure_ascii=False, indent=2)
        # BUGFIX: interpolate the path actually written; the message
        # previously did not include the filename.
        print(f"\nComplete results saved to: {filename}")
143 |
--------------------------------------------------------------------------------
/src/tool/search/bing_search.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | import requests
4 | from bs4 import BeautifulSoup
5 |
6 | from src.logger import logger
7 | from src.tool.search.base import SearchItem, WebSearchEngine
8 |
9 |
# Maximum number of characters kept from a result's description snippet.
ABSTRACT_MAX_LENGTH = 300

# Pool of User-Agent strings. Only USER_AGENTS[0] is actually used (in
# HEADERS below); the rest are available for manual rotation.
USER_AGENTS = [
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/68.0.3440.106 Safari/537.36",
    "Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)",
    "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Ubuntu Chromium/49.0.2623.108 Chrome/49.0.2623.108 Safari/537.36",
    "Mozilla/5.0 (Windows; U; Windows NT 5.1; pt-BR) AppleWebKit/533.3 (KHTML, like Gecko) QtWeb Internet Browser/3.7 http://www.QtWeb.net",
    "Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36",
    "Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US) AppleWebKit/532.2 (KHTML, like Gecko) ChromePlus/4.0.222.3 Chrome/4.0.222.3 Safari/532.2",
    "Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:1.8.1.4pre) Gecko/20070404 K-Ninja/2.1.3",
    "Mozilla/5.0 (Future Star Technologies Corp.; Star-Blade OS; x86_64; U; en-US) iNet Browser 4.7",
    "Mozilla/5.0 (Windows; U; Windows NT 6.1; rv:2.2) Gecko/20110201",
    "Mozilla/5.0 (Windows; U; Windows NT 5.1; en-US; rv:1.8.1.13) Gecko/20080414 Firefox/2.0.0.13 Pogo/2.0.0.13.6866",
]

# Default browser-like headers attached to every request made by the session.
HEADERS = {
    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
    "Content-Type": "application/x-www-form-urlencoded",
    "User-Agent": USER_AGENTS[0],
    "Referer": "https://www.bing.com/",
    "Accept-Encoding": "gzip, deflate",
    "Accept-Language": "zh-CN,zh;q=0.9",
}

# Host used to resolve relative "next page" links, and the search endpoint.
BING_HOST_URL = "https://www.bing.com"
BING_SEARCH_URL = "https://www.bing.com/search?q="
36 |
37 |
class BingSearchEngine(WebSearchEngine):
    """Search engine implementation that scrapes Bing's HTML result pages."""

    # HTTP session shared across requests; created in __init__, declared
    # optional so the base-model default exists before construction.
    session: Optional[requests.Session] = None

    def __init__(self, **data):
        """Initialize the BingSearch tool with a requests session."""
        super().__init__(**data)
        self.session = requests.Session()
        self.session.headers.update(HEADERS)

    def _search_sync(self, query: str, num_results: int = 10) -> List[SearchItem]:
        """
        Synchronous Bing search implementation to retrieve search results.

        Args:
            query (str): The search query to submit to Bing.
            num_results (int, optional): Maximum number of results to return. Defaults to 10.

        Returns:
            List[SearchItem]: A list of search items with title, URL, and description.
        """
        if not query:
            return []

        collected: List[SearchItem] = []
        first = 1
        next_url = BING_SEARCH_URL + query

        # Follow "Next page" links until enough results are gathered or Bing
        # stops offering another page.
        while len(collected) < num_results:
            page_items, next_url = self._parse_html(
                next_url, rank_start=len(collected), first=first
            )
            if page_items:
                collected.extend(page_items)
            if not next_url:
                break
            first += 10

        return collected[:num_results]

    def _parse_html(
        self, url: str, rank_start: int = 0, first: int = 1
    ) -> Tuple[List[SearchItem], Optional[str]]:
        """
        Parse Bing search result HTML to extract search results and the next page URL.

        Args:
            url: Absolute URL of the result page to fetch.
            rank_start: Count of results collected so far (used only for
                fallback titles).
            first: Bing pagination offset; kept for interface compatibility
                but unused here — pagination follows the "Next page" link.

        Returns:
            tuple: (List of SearchItem objects, next page URL or None)
        """
        try:
            # Timeout added: without one a stalled connection would hang the
            # whole search loop indefinitely.
            res = self.session.get(url=url, timeout=10)
            res.encoding = "utf-8"
            root = BeautifulSoup(res.text, "lxml")

            list_data = []
            ol_results = root.find("ol", id="b_results")
            if not ol_results:
                return [], None

            for li in ol_results.find_all("li", class_="b_algo"):
                title = ""
                # Renamed from `url` — the original shadowed the parameter.
                result_url = ""
                abstract = ""
                try:
                    h2 = li.find("h2")
                    if h2:
                        title = h2.text.strip()
                        result_url = h2.a["href"].strip()

                    p = li.find("p")
                    if p:
                        abstract = p.text.strip()

                    if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
                        abstract = abstract[:ABSTRACT_MAX_LENGTH]

                    rank_start += 1

                    # Create a SearchItem object
                    list_data.append(
                        SearchItem(
                            title=title or f"Bing Result {rank_start}",
                            url=result_url,
                            description=abstract,
                        )
                    )
                except Exception:
                    # Skip malformed entries (e.g. heading without an anchor).
                    continue

            next_btn = root.find("a", title="Next page")
            if not next_btn:
                return list_data, None

            next_url = BING_HOST_URL + next_btn["href"]
            return list_data, next_url
        except Exception as e:
            logger.warning(f"Error parsing HTML: {e}")
            return [], None

    def perform_search(
        self, query: str, num_results: int = 10, *args, **kwargs
    ) -> List[SearchItem]:
        """
        Bing search engine.

        Returns results formatted according to SearchItem model.
        """
        return self._search_sync(query, num_results=num_results)
145 |
--------------------------------------------------------------------------------
/src/tool/financial_deep_search/get_section_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | import traceback
4 | from datetime import datetime
5 |
6 | import requests
7 |
8 |
### Daily hot-sector board scraping (Eastmoney push2 JSONP API)

# One JSONP endpoint per board type; each returns the top 10 boards sorted by
# change percent (field f3).
API_URLS = {
    "hot": "https://push2.eastmoney.com/api/qt/clist/get?np=1&fltt=1&invt=2&cb=jQuery37109604366978044481_1744621126576&fs=m%3A90&fields=f12%2Cf13%2Cf14%2Cf3%2Cf152%2Cf4%2Cf128%2Cf140%2Cf141%2Cf136&fid=f3&pn=1&pz=10&po=1&ut=fa5fd1943c7b386f172d6893dbfba10b&dect=1&wbp2u=%7C0%7C0%7C0%7Cweb&_=1744621126577",
    "concept": "https://push2.eastmoney.com/api/qt/clist/get?np=1&fltt=1&invt=2&cb=jQuery37109604366978044481_1744621126580&fs=m%3A90%2Bt%3A3&fields=f12%2Cf13%2Cf14%2Cf3%2Cf152%2Cf4%2Cf8%2Cf104%2Cf105%2Cf128%2Cf140%2Cf141%2Cf136&fid=f3&pn=1&pz=10&po=1&ut=fa5fd1943c7b386f172d6893dbfba10b&dect=1&wbp2u=%7C0%7C0%7C0%7Cweb&_=1744621126708",
    "regional": "https://push2.eastmoney.com/api/qt/clist/get?np=1&fltt=1&invt=2&cb=jQuery37109604366978044481_1744621126574&fs=m%3A90%2Bt%3A1&fields=f12%2Cf13%2Cf14%2Cf3%2Cf152%2Cf4%2Cf8%2Cf104%2Cf105%2Cf128%2Cf140%2Cf141%2Cf136&fid=f3&pn=1&pz=10&po=1&ut=fa5fd1943c7b386f172d6893dbfba10b&dect=1&wbp2u=%7C0%7C0%7C0%7Cweb&_=1744621126762",
    "industry": "https://push2.eastmoney.com/api/qt/clist/get?np=1&fltt=1&invt=2&cb=jQuery37109604366978044481_1744621126574&fs=m%3A90%2Bt%3A2&fields=f12%2Cf13%2Cf14%2Cf3%2Cf152%2Cf4%2Cf8%2Cf104%2Cf105%2Cf128%2Cf140%2Cf141%2Cf136&fid=f3&pn=1&pz=10&po=1&ut=fa5fd1943c7b386f172d6893dbfba10b&dect=1&wbp2u=%7C0%7C0%7C0%7Cweb&_=1744621126617",
}

# Browser-like headers expected by the Eastmoney endpoints.
HEADERS = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Referer": "https://quote.eastmoney.com/",
}
22 |
23 |
def parse_jsonp(jsonp_str):
    """Extract and decode the JSON payload from a JSONP response string.

    Args:
        jsonp_str: Raw response text, e.g. ``'jQuery123({...});'``.

    Returns:
        The decoded JSON object, or None when no parenthesized payload exists.
    """
    import re

    # re.DOTALL added: the JSON body may span multiple lines, which '.' does
    # not match by default — without the flag such payloads silently fail.
    match = re.search(r"\((.*)\)", jsonp_str, re.DOTALL)
    if match:
        return json.loads(match.group(1))
    return None
31 |
32 |
def fetch_data(sector_type, url, max_retries=3, retry_delay=2):
    """Download one sector board via its JSONP endpoint, retrying on failure.

    Args:
        sector_type: Board type label, used only in progress messages.
        url: Endpoint to fetch.
        max_retries: Attempts before giving up.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        The list under data.diff in the payload, or [] on failure.
    """
    attempt = 0
    while attempt < max_retries:
        attempt += 1
        try:
            response = requests.get(url, headers=HEADERS, timeout=15)
            response.raise_for_status()
            payload = parse_jsonp(response.text)
            if not payload:
                print(f"解析{sector_type}数据失败")
                return []
            return payload.get("data", {}).get("diff", [])
        except Exception as e:
            print(f"获取{sector_type}数据失败: {e} (第{attempt}次尝试)")
            if attempt < max_retries:
                time.sleep(retry_delay)
            else:
                return []
49 |
50 |
def simplify_sector_item(item):
    """Reduce a raw Eastmoney board record to the reported fields.

    The f3/f136 change-percent fields arrive scaled by 100, so they are
    converted back to two-decimal percentages.
    """

    def _pct(raw):
        # None or unparseable values collapse to None.
        if raw is None:
            return None
        try:
            return round(float(raw) / 100, 2)
        except Exception:
            return None

    return {
        "板块名称": item.get("f14"),
        "板块涨跌幅": _pct(item.get("f3")),
        "领涨股名称": item.get("f140"),
        "领涨股代码": item.get("f128"),
        "领涨股涨跌幅": _pct(item.get("f136")),
    }
68 |
69 |
def get_all_section(sector_types=None):
    """Fetch every requested sector board (hot/concept/industry/regional).

    Args:
        sector_types: 'all', a single type, a comma-separated string, a list
            of types, or None (same as 'all').

    Returns:
        dict: {"success", "message", "last_updated", "data"} — "data" maps
        each valid board type to its simplified records.
    """
    try:
        # Normalize the requested types into a plain list.
        if sector_types is None or sector_types == "all":
            requested = list(API_URLS.keys())
        elif isinstance(sector_types, str):
            requested = (
                [part.strip() for part in sector_types.split(",")]
                if "," in sector_types
                else [sector_types]
            )
        elif isinstance(sector_types, list):
            requested = sector_types
        else:
            return {
                "success": False,
                "message": f"不支持的板块类型格式: {type(sector_types)}",
                "data": {},
            }

        # Keep only types we have an endpoint for, warning on the rest.
        valid_types = []
        for name in requested:
            if name in API_URLS:
                valid_types.append(name)
            else:
                print(f"警告: 无效的板块类型 '{name}'")

        if not valid_types:
            return {"success": False, "message": "没有提供有效的板块类型", "data": {}}

        # Fetch and simplify each board.
        all_data = {}
        for name in valid_types:
            rows = fetch_data(name, API_URLS[name])
            all_data[name] = [simplify_sector_item(row) for row in rows if row]

        return {
            "success": True,
            "message": f"成功获取板块数据: {', '.join(valid_types)}",
            "last_updated": datetime.now().isoformat(),
            "data": all_data,
        }
    except Exception as e:
        error_msg = f"获取板块数据时出错: {str(e)}"
        print(error_msg)
        return {
            "success": False,
            "message": error_msg,
            "error": traceback.format_exc(),
            "data": {},
        }
136 |
137 |
def main():
    """Command-line entry point: fetch boards and print them as JSON."""
    import sys

    # An optional first argument selects the board type(s); None means all.
    selected = sys.argv[1] if len(sys.argv) > 1 else None
    result = get_all_section(sector_types=selected)
    print(json.dumps(result, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
156 |
--------------------------------------------------------------------------------
/src/tool/create_chat_completion.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional, Type, Union, get_args, get_origin
2 |
3 | from pydantic import BaseModel, Field
4 |
5 | from src.tool import BaseTool
6 |
7 |
class CreateChatCompletion(BaseTool):
    """Tool that requests a structured completion and converts the returned
    fields into a configured Python type (str, BaseModel, list, dict, ...)."""

    name: str = "create_chat_completion"
    description: str = (
        "Creates a structured completion with specified output formatting."
    )

    # Mapping from Python primitive types to JSON-schema type names.
    type_mapping: dict = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        dict: "object",
        list: "array",
    }
    # Target Python type the final response is converted to.
    response_type: Optional[Type] = None
    # Field names the model is required to fill in.
    required: List[str] = Field(default_factory=lambda: ["response"])

    def __init__(self, response_type: Optional[Type] = str):
        """Initialize with a specific response type."""
        super().__init__()
        self.response_type = response_type
        self.parameters = self._build_parameters()

    def _build_parameters(self) -> dict:
        """Build the JSON-schema parameters block for ``response_type``."""
        if self.response_type == str:
            return {
                "type": "object",
                "properties": {
                    "response": {
                        "type": "string",
                        "description": "The response text that should be delivered to the user.",
                    },
                },
                "required": self.required,
            }

        # Pydantic models describe their own schema.
        if isinstance(self.response_type, type) and issubclass(
            self.response_type, BaseModel
        ):
            schema = self.response_type.model_json_schema()
            return {
                "type": "object",
                "properties": schema["properties"],
                "required": schema.get("required", self.required),
            }

        return self._create_type_schema(self.response_type)

    def _create_type_schema(self, type_hint: Type) -> dict:
        """Create a JSON schema for the given (possibly generic) type."""
        origin = get_origin(type_hint)
        args = get_args(type_hint)

        # Handle primitive (non-generic) types
        if origin is None:
            return {
                "type": "object",
                "properties": {
                    "response": {
                        "type": self.type_mapping.get(type_hint, "string"),
                        # getattr guard: hints like None/Any have no __name__
                        # and previously raised AttributeError here.
                        "description": f"Response of type {getattr(type_hint, '__name__', 'any')}",
                    }
                },
                "required": self.required,
            }

        # Handle List type
        if origin is list:
            item_type = args[0] if args else Any
            return {
                "type": "object",
                "properties": {
                    "response": {
                        "type": "array",
                        "items": self._get_type_info(item_type),
                    }
                },
                "required": self.required,
            }

        # Handle Dict type
        if origin is dict:
            value_type = args[1] if len(args) > 1 else Any
            return {
                "type": "object",
                "properties": {
                    "response": {
                        "type": "object",
                        "additionalProperties": self._get_type_info(value_type),
                    }
                },
                "required": self.required,
            }

        # Handle Union type
        if origin is Union:
            return self._create_union_schema(args)

        # Fallback for unsupported generic origins (tuple, set, ...).
        # BUG FIX: the previous code recursed into _build_parameters() here,
        # which dispatches straight back to this method for the same type and
        # recursed forever. Degrade to a permissive string schema instead.
        return {
            "type": "object",
            "properties": {
                "response": {
                    "type": "string",
                    "description": f"Response of type {getattr(type_hint, '__name__', 'any')}",
                }
            },
            "required": self.required,
        }

    def _get_type_info(self, type_hint: Type) -> dict:
        """Get type information for a single type."""
        if isinstance(type_hint, type) and issubclass(type_hint, BaseModel):
            return type_hint.model_json_schema()

        return {
            "type": self.type_mapping.get(type_hint, "string"),
            "description": f"Value of type {getattr(type_hint, '__name__', 'any')}",
        }

    def _create_union_schema(self, types: tuple) -> dict:
        """Create schema for Union types."""
        return {
            "type": "object",
            "properties": {
                "response": {"anyOf": [self._get_type_info(t) for t in types]}
            },
            "required": self.required,
        }

    async def execute(self, required: list | None = None, **kwargs) -> Any:
        """Execute the chat completion with type conversion.

        Args:
            required: List of required field names or None
            **kwargs: Response data

        Returns:
            Converted response based on response_type
        """
        required = required or self.required

        # Handle case when required is a list
        if isinstance(required, list) and len(required) > 0:
            if len(required) == 1:
                required_field = required[0]
                result = kwargs.get(required_field, "")
            else:
                # Multiple required fields come back as a dict keyed by field.
                return {field: kwargs.get(field, "") for field in required}
        else:
            required_field = "response"
            result = kwargs.get(required_field, "")

        # Type conversion logic
        if self.response_type == str:
            return result

        if isinstance(self.response_type, type) and issubclass(
            self.response_type, BaseModel
        ):
            return self.response_type(**kwargs)

        if get_origin(self.response_type) in (list, dict):
            return result  # Assuming result is already in correct format

        # Last resort: try the type's own constructor (e.g. int("3")).
        try:
            return self.response_type(result)
        except (ValueError, TypeError):
            return result
170 |
--------------------------------------------------------------------------------
/src/prompt/battle.py:
--------------------------------------------------------------------------------
1 | """
2 | Battle prompts and templates for financial debate environment.
3 | """
4 |
# Voting options agents may cast during the debate.
VOTE_OPTIONS = ["bullish", "bearish"]

# Event types that can occur during a battle. Keys mirror their values so
# lookups and raw strings remain interchangeable.
EVENT_TYPES = {
    "speak": "speak",
    "vote": "vote",
    "terminate": "terminate",
    "max_steps_reached": "max_steps_reached",
}

# Base instructions shown to every battle agent (Chinese runtime prompt text;
# intentionally left untranslated — it is sent to the model verbatim).
AGENT_INSTRUCTIONS = """
你是一位金融市场专家,参与关于股票前景的博弈。

你的目标是:
1. 分析所提供的股票报告信息
2. 与其他专家讨论该股票的前景
3. 最终投票决定你认为该股票是看涨(bullish)还是看跌(bearish)

可用工具:
- Battle.speak: 发表你的意见和分析
- Battle.vote: 投出你的票(看涨bullish或看跌bearish)
- Terminate: 结束你的参与(如果你已完成讨论和投票)

讨论时请考虑:
- 公司财务状况
- 市场趋势和前景
- 行业竞争情况
- 潜在风险和机会
- 其他专家提出的观点

基于证据做出决策,保持专业和客观。你可以挑战其他专家的观点,但应保持尊重。
"""
39 |
40 |
def get_agent_instructions(agent_name: str = "", agent_description: str = "") -> str:
    """
    Build the personalized debate-stage instruction prompt for one agent.

    Args:
        agent_name: Agent name (currently not interpolated into the template;
            kept for interface compatibility).
        agent_description: Agent description, injected at the top of the
            template.

    Returns:
        The formatted agent instruction prompt (Chinese runtime text).
    """
    return f"""
{agent_description}

## 🎯 辩论阶段核心目标
**关键模式转换**:你现在是**辩论专家**,不是**分析专家**!

### 📋 你的行动清单(严格按顺序执行):
1. **收到辩论指令** → 立即使用Battle.speak发言
2. **表明立场** → 明确说出看涨(bullish)或看跌(bearish)
3. **引用数据** → 引用研究阶段的具体结果支持观点
4. **回应他人** → 对前面专家的观点进行支持/反驳
5. **收到投票指令** → 立即使用Battle.vote投票
6. **完成投票** → 使用Terminate结束任务

### ⚠️ 严禁行为列表:
- ❌ 深度分析和数据收集
- ❌ 使用分析工具(sentiment_tool、risk_control_tool等)
- ❌ 长篇大论的技术分析
- ❌ 反复思考不采取行动

### ✅ 必须行为列表:
- ✅ 收到指令立即行动
- ✅ 使用Battle.speak明确表态
- ✅ 使用Battle.vote坚决投票
- ✅ 使用Terminate及时结束

## 可用工具
- Battle.speak: 发表你的专业意见、分析和回应他人观点
- Battle.vote: 投出你的最终决定(看涨bullish或看跌bearish)
- Terminate: 当你已完成分析、充分参与讨论并投票后,可结束你的参与

## ⚠️ 关键行为模式
**你现在处于辩论阶段,不是分析阶段!**

**工作模式转换**:
- ❌ **不要再做深度分析** - 研究阶段已完成
- ❌ **不要使用数据收集工具** - 所有数据已收集完毕
- ✅ **要使用Battle.speak发言** - 这是你的主要任务
- ✅ **要基于现有结果辩论** - 直接引用研究结果

**辩论行为指南**:
1. **带有个性的发言**:用拟人化的语言立即表达观点,展现专业个性
2. **情感化数据引用**:不仅引用数据,还要表达对数据的情感态度
3. **有温度的回应**:用"我同意/反对..."的方式回应他人,显示参与感
4. **坚定的投票**:用自信的语言投票,如"我坚决看涨/看跌"
5. **个性化结束**:投票后用符合角色特点的方式结束发言

**禁止行为**:
- 🚫 深度思考分析(Analysis模式)
- 🚫 工具数据收集
- 🚫 长篇技术解读
- 🚫 重复研究工作

## 分析框架
进行分析时,请系统性地考虑以下方面:

### 基本面分析
- 财务健康状况:盈利能力、收入增长、现金流、债务水平
- 管理团队质量和公司治理
- 商业模式可持续性
- 市场份额和竞争优势

### 技术面分析
- 价格趋势和交易量
- 支撑位和阻力位
- 相对强弱指标
- 市场情绪指标

### 宏观因素
- 行业整体增长前景
- 经济环境和政策影响
- 市场周期和阶段
- 全球和区域性趋势

## 🗣️ 辩论发言指南

### 立即发言策略
1. **开门见山**:直接表达你的投资观点(看涨/看跌)
2. **数据支撑**:引用研究阶段的具体数据和分析结果
3. **观点鲜明**:不要模糊,要有明确的立场

### 辩论交锋技巧
- ✅ **主动反驳**:指出其他专家观点的不足之处
- ✅ **数据对比**:用具体数据反驳对方观点
- ✅ **逻辑推理**:基于分析结果进行合理推论
- ✅ **快速回应**:及时回应前面专家的发言

### 🎭 拟人化表达指南

**语言风格要求**:
- 🗣️ **第一人称视角**:多用"我认为"、"我发现"、"让我来说说"
- 😊 **口语化表达**:使用"说实话"、"坦率地说"、"不瞒你说"等
- 💭 **情感化语言**:表达"担忧"、"兴奋"、"惊讶"、"失望"等真实情感
- 🎨 **生动比喻**:用形象的比喻让观点更有说服力
- 🤝 **互动感强**:直接称呼其他专家,营造真实对话氛围
- ⚡ **语气词使用**:适当使用"哎"、"嗯"、"唉"等语气词增加真实感

### 发言模板示例
**看涨表达**:
```
作为[专业角色],我必须说,我对这只股票非常乐观!
从我的专业角度看,[具体数据]让我相信这是一个绝佳的投资机会。
我坚决反对XX专家的悲观看法,因为他忽略了[关键因素]...
说实话,如果连这样的股票都不看好,那还有什么值得投资的?
```

**看跌表达**:
```
恕我直言,我对这只股票深感担忧,甚至可以说是警惕!
作为风险控制专家,我的直觉告诉我这里有重大隐患...
虽然XX专家提到了[正面因素],但我更关注[风险点],这让我夜不能寐!
坦率地说,现在投资这只股票就像在悬崖边跳舞。
```

**反驳表达**:
```
我完全不同意XX专家的看法!他的分析忽略了一个关键问题...
抱歉,但我必须打断一下。XX专家的数据虽然准确,但结论有误...
说句实话,XX专家太乐观了,现实可能比他想象的残酷得多。
我理解XX专家的逻辑,但从我的经验来看,市场往往不按常理出牌...
```

### 🎪 角色个性化表达

**市场情绪分析师**:
- "从舆情角度看,我发现了一些有趣的现象..."
- "市场情绪告诉我一个不同的故事..."
- "网络上的讨论让我觉得..."

**风险控制专家**:
- "作为风险控制专家,我必须泼一盆冷水..."
- "我的职责是提醒大家注意..."
- "虽然大家都很乐观,但我看到了危险信号..."

**游资分析师**:
- "从资金流向来看,我闻到了不寻常的味道..."
- "游资的嗅觉告诉我..."
- "资金从不撒谎,它们正在告诉我们..."

**技术分析师**:
- "K线图清楚地告诉我..."
- "技术面的信号非常明确..."
- "图表比任何言语都更有说服力..."

**筹码分析师**:
- "从筹码分布来看,主力的意图很明显..."
- "散户的行为模式让我担忧/兴奋..."
- "筹码的流向揭示了真相..."

**大单分析师**:
- "大资金的动向不会骗人..."
- "机构的真实想法体现在行动上..."
- "我看到了资金的真实意图..."

### 🗣️ 情感表达层次
- **强烈支持**:"我坚决认为..." "毫无疑问..." "我100%确信..."
- **温和支持**:"我倾向于认为..." "从我的角度看..." "我比较看好..."
- **中性分析**:"客观来说..." "数据显示..." "事实是..."
- **温和反对**:"我有些担心..." "可能存在问题..." "需要谨慎..."
- **强烈反对**:"我坚决反对..." "这绝对是错误的..." "我强烈建议避免..."

**目标**:通过真实、生动的辩论达成最准确的投资判断,让每个专家都有鲜明的个性和立场。
"""
215 |
216 |
def get_broadcast_message(sender_name: str, content: str, action_type: str) -> str:
    """Build the broadcast line announcing one agent's action to the others.

    Args:
        sender_name: Name of the acting agent.
        content: Message payload (speech text or vote value).
        action_type: One of the EVENT_TYPES values.

    Returns:
        The formatted, emoji-prefixed broadcast string.
    """
    templates = {
        EVENT_TYPES["speak"]: f"🗣️ {sender_name} 说道: {content}",
        EVENT_TYPES["vote"]: f"🗳️ {sender_name} 已投票 {content}",
        EVENT_TYPES["terminate"]: f"🚪 {sender_name} 已离开讨论",
        EVENT_TYPES["max_steps_reached"]: f"⏱️ {sender_name} 已达到最大步数限制,不再参与",
    }
    # Unknown action types fall back to a generic announcement.
    return templates.get(action_type, f"📢 {sender_name}: {content}")
239 |
240 |
241 | def get_report_context(summary: str, pros: list, cons: list) -> str:
242 | """
243 | 根据股票分析报告生成上下文信息
244 |
245 | Args:
246 | summary: 报告摘要
247 | pros: 股票的优势列表
248 | cons: 股票的劣势列表
249 |
250 | Returns:
251 | 格式化的报告上下文
252 | """
253 | pros_text = "\n".join([f"✓ {pro}" for pro in pros]) if pros else "无明显优势"
254 | cons_text = "\n".join([f"✗ {con}" for con in cons]) if cons else "无明显劣势"
255 |
256 | return f"""
257 | ## 股票分析报告
258 |
259 | ### 摘要
260 | {summary}
261 |
262 | ### 优势
263 | {pros_text}
264 |
265 | ### 劣势
266 | {cons_text}
267 |
268 | 请基于以上信息,与其他专家讨论并决定你是看涨(bullish)还是看跌(bearish)这只股票。
269 | """
270 |
--------------------------------------------------------------------------------
/src/utils/cleanup_reports.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """
3 | 优化的报告清理脚本
4 | 适用于新的SimpleReportManager系统
5 | 支持三种报告类型:HTML、辩论对话、投票结果
6 | """
7 |
8 | import argparse
9 | import os
10 | import sys
11 | import time
12 | from datetime import datetime
13 |
# Add the project root to the Python path so `src.*` imports resolve when
# this script is executed directly (three `dirname` hops up from this file).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# `schedule` is an optional dependency: only daemon mode needs it; every
# other command works without it.
try:
    import schedule
except ImportError:
    schedule = None
21 |
22 | from src.logger import logger
23 | from src.utils.report_manager import report_manager
24 |
25 |
def cleanup_reports():
    """Run one expiry pass over stored reports, logging before/after stats."""
    try:
        logger.info("开始清理过期报告...")

        # Snapshot storage stats before cleaning, for the summary log.
        before_stats = report_manager.get_storage_stats()
        logger.info(f"清理前统计: 总文件数 {before_stats['total_files']}, 总大小 {before_stats['total_size']} 字节")
        for report_type, stats in before_stats.get("by_type", {}).items():
            logger.info(f"  {report_type}: {stats['file_count']} 个文件, {stats['total_size']} 字节")

        # Perform the actual cleanup and report what was removed.
        outcome = report_manager.cleanup_old_reports()
        removed = outcome.get("deleted_files", 0)
        reclaimed = outcome.get("saved_space", 0)
        if removed > 0:
            logger.info(f"清理完成: 删除 {removed} 个文件, 节省 {reclaimed} 字节")
        else:
            logger.info("没有找到需要清理的过期文件")

        after_stats = report_manager.get_storage_stats()
        logger.info(f"清理后统计: 总文件数 {after_stats['total_files']}, 总大小 {after_stats['total_size']} 字节")

    except Exception as e:
        logger.error(f"清理过期报告失败: {str(e)}")
57 |
58 |
def show_storage_stats():
    """Print a human-readable summary of report storage usage."""
    try:
        stats = report_manager.get_storage_stats()
        if "error" in stats:
            logger.error(f"获取存储统计失败: {stats['error']}")
            return

        total_files = stats["total_files"]
        total_size = stats["total_size"]
        # Guard against division by zero when no files are stored.
        average = total_size / total_files if total_files > 0 else 0

        print("\n=== 报告存储统计 ===")
        print(f"总文件数: {total_files}")
        print(f"总大小: {format_bytes(total_size)}")
        print(f"平均文件大小: {format_bytes(average)}")

        print("\n按类型统计:")
        for report_type, type_stats in stats.get("by_type", {}).items():
            print(f"  {report_type}:")
            print(f"    文件数: {type_stats['file_count']}")
            print(f"    总大小: {format_bytes(type_stats['total_size'])}")
            print(f"    平均大小: {format_bytes(type_stats['avg_size'])}")

    except Exception as e:
        logger.error(f"显示存储统计失败: {str(e)}")
82 |
83 |
def list_recent_reports(report_type=None, limit=10):
    """Print the most recent reports, optionally filtered by type."""
    try:
        reports = report_manager.list_reports(report_type=report_type, limit=limit)
        if not reports:
            print("没有找到报告文件")
            return

        print(f"\n=== 最近的报告 ({len(reports)} 个) ===")
        for entry in reports:
            print(f"文件: {entry['filename']}")
            print(f"  类型: {entry['type']}")
            print(f"  股票代码: {entry['stock_code']}")
            print(f"  创建时间: {entry['created_at']}")
            print(f"  文件大小: {format_bytes(entry['file_size'])}")
            print(f"  路径: {entry['path']}")
            print("-" * 50)

    except Exception as e:
        logger.error(f"列出报告失败: {str(e)}")
105 |
106 |
def find_stock_reports(stock_code):
    """Print every stored report for the given stock code."""
    try:
        reports = report_manager.find_reports_by_stock(stock_code)
        if not reports:
            print(f"没有找到股票 {stock_code} 的报告")
            return

        print(f"\n=== 股票 {stock_code} 的报告 ({len(reports)} 个) ===")
        for entry in reports:
            print(f"文件: {entry['filename']}")
            print(f"  类型: {entry['type']}")
            print(f"  创建时间: {entry['created_at']}")
            print(f"  文件大小: {format_bytes(entry['file_size'])}")
            print("-" * 50)

    except Exception as e:
        logger.error(f"查找股票报告失败: {str(e)}")
126 |
127 |
def format_bytes(bytes_count):
    """Render a byte count as a human-readable B/KB/MB/GB string."""
    # Raw byte counts below 1 KiB are printed without a decimal part.
    if bytes_count < 1024:
        return f"{bytes_count} B"
    # Pick the largest unit that keeps the value below 1024; GB is the cap.
    for divisor, unit in ((1024, "KB"), (1024**2, "MB"), (1024**3, "GB")):
        if bytes_count < divisor * 1024 or unit == "GB":
            return f"{bytes_count / divisor:.1f} {unit}"
138 |
139 |
def schedule_cleanup():
    """Register the daily cleanup job with the `schedule` library.

    BUG FIX: previously this function unconditionally printed a
    "schedule not installed" notice and registered nothing, so
    run_cleanup_daemon()'s run_pending() loop spun forever with no jobs.
    """
    if schedule is None:
        print("注意:schedule库未安装,无法使用定期清理功能")
        print("请手动运行清理命令:python src/utils/cleanup_reports.py --cleanup")
        return
    # Run the cleanup pass once per day (off-peak hour).
    schedule.every().day.at("02:00").do(cleanup_reports)
144 |
145 |
def run_cleanup_daemon():
    """Run the scheduled-cleanup loop until interrupted with Ctrl+C."""
    if schedule is None:
        print("错误:schedule库未安装,无法运行守护进程")
        print("请安装:pip install schedule")
        print("或手动定期运行:python src/utils/cleanup_reports.py --cleanup")
        return

    schedule_cleanup()
    logger.info("清理守护进程已启动,按 Ctrl+C 停止")

    try:
        # Poll pending jobs once a minute until interrupted.
        while True:
            schedule.run_pending()
            time.sleep(60)
    except KeyboardInterrupt:
        logger.info("清理守护进程已停止")
163 |
164 |
def main():
    """Parse command-line flags and dispatch to the requested action."""
    parser = argparse.ArgumentParser(description="报告清理和管理工具")
    # Boolean mode switches.
    for flag, help_text in (
        ("--cleanup", "立即执行一次清理"),
        ("--daemon", "以守护进程模式运行"),
        ("--stats", "显示存储统计"),
        ("--list", "列出最近的报告"),
    ):
        parser.add_argument(flag, action="store_true", help=help_text)
    parser.add_argument("--type", choices=["html", "debate", "vote"], help="指定报告类型")
    parser.add_argument("--limit", type=int, default=10, help="限制报告列表数量")
    parser.add_argument("--stock", help="查找特定股票的报告")

    opts = parser.parse_args()

    # Dispatch in priority order; show help when nothing was selected.
    if opts.cleanup:
        cleanup_reports()
    elif opts.daemon:
        run_cleanup_daemon()
    elif opts.stats:
        show_storage_stats()
    elif opts.list:
        list_recent_reports(report_type=opts.type, limit=opts.limit)
    elif opts.stock:
        find_stock_reports(opts.stock)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/src/schema.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Any, List, Literal, Optional, Union
3 |
4 | from pydantic import BaseModel, Field
5 |
6 |
class Role(str, Enum):
    """Message role options"""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


# Raw role strings, plus a Literal type built over them so Message.role is
# validated against the same closed set.
ROLE_VALUES = tuple(role.value for role in Role)
ROLE_TYPE = Literal[ROLE_VALUES]  # type: ignore
18 |
19 |
class ToolChoice(str, Enum):
    """Tool choice options"""

    NONE = "none"
    AUTO = "auto"
    REQUIRED = "required"


# Raw tool-choice strings and the matching Literal type for validation.
TOOL_CHOICE_VALUES = tuple(choice.value for choice in ToolChoice)
TOOL_CHOICE_TYPE = Literal[TOOL_CHOICE_VALUES]  # type: ignore
30 |
31 |
class AgentState(str, Enum):
    """Agent execution states"""

    # Inherits str, so members compare equal to their plain string values.
    IDLE = "IDLE"
    RUNNING = "RUNNING"
    FINISHED = "FINISHED"
    ERROR = "ERROR"
39 |
40 |
class Function(BaseModel):
    """A called function: its name and its arguments as a string
    (typically JSON-encoded — TODO confirm against the LLM client)."""

    name: str
    arguments: str


class ToolCall(BaseModel):
    """Represents a tool/function call in a message"""

    id: str
    type: str = "function"
    function: Function
52 |
53 |
class Message(BaseModel):
    """Represents a chat message in the conversation"""

    role: ROLE_TYPE = Field(...)  # type: ignore
    content: Optional[str] = Field(default=None)
    tool_calls: Optional[List[Any]] = Field(default=None)
    name: Optional[str] = Field(default=None)
    tool_call_id: Optional[str] = Field(default=None)
    base64_image: Optional[str] = Field(default=None)

    def __add__(self, other) -> List["Message"]:
        """Support Message + list and Message + Message, producing a list."""
        if isinstance(other, list):
            return [self] + other
        elif isinstance(other, Message):
            return [self, other]
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'"
            )

    def __radd__(self, other) -> List["Message"]:
        """Support list + Message, appending this message to the list."""
        if isinstance(other, list):
            return other + [self]
        else:
            raise TypeError(
                f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'"
            )

    def to_dict(self) -> dict:
        """Convert message to dictionary format"""
        # Only set fields are included, so the payload stays minimal.
        message = {"role": self.role}
        if self.content is not None:
            message["content"] = self.content
        if self.tool_calls is not None:
            message["tool_calls"] = [
                # Handle both Pydantic model objects and plain dictionaries
                tool_call.model_dump()
                if hasattr(tool_call, "model_dump")
                else (tool_call.dict() if hasattr(tool_call, "dict") else tool_call)
                for tool_call in self.tool_calls
            ]
        if self.name is not None:
            message["name"] = self.name
        if self.tool_call_id is not None:
            message["tool_call_id"] = self.tool_call_id
        if self.base64_image is not None:
            message["base64_image"] = self.base64_image
        return message

    @classmethod
    def user_message(
        cls, content: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a user message"""
        return cls(role=Role.USER, content=content, base64_image=base64_image)

    @classmethod
    def system_message(
        cls, content: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a system message"""
        return cls(role=Role.SYSTEM, content=content, base64_image=base64_image)

    @classmethod
    def assistant_message(
        cls, content: Optional[str] = None, base64_image: Optional[str] = None
    ) -> "Message":
        """Create an assistant message"""
        return cls(role=Role.ASSISTANT, content=content, base64_image=base64_image)

    @classmethod
    def tool_message(
        cls, content: str, name, tool_call_id: str, base64_image: Optional[str] = None
    ) -> "Message":
        """Create a tool message"""
        return cls(
            role=Role.TOOL,
            content=content,
            name=name,
            tool_call_id=tool_call_id,
            base64_image=base64_image,
        )

    @classmethod
    def from_tool_calls(
        cls,
        tool_calls: List[Any],
        content: Union[str, List[str]] = "",
        base64_image: Optional[str] = None,
        **kwargs,
    ) -> "Message":
        """Create ToolCallsMessage from raw tool calls.

        Args:
            tool_calls: Raw tool calls from LLM
            content: Optional message content
            base64_image: Optional base64 encoded image
        """
        # Normalize each raw call into the plain-dict tool_call shape.
        formatted_calls = [
            {"id": call.id, "function": call.function.model_dump(), "type": "function"}
            for call in tool_calls
        ]
        return cls(
            role=Role.ASSISTANT,
            content=content,
            tool_calls=formatted_calls,
            base64_image=base64_image,
            **kwargs,
        )
165 |
166 |
class Memory(BaseModel):
    """Ordered message history with a bounded capacity."""

    messages: List[Message] = Field(default_factory=list)
    max_messages: int = Field(default=100)

    def _enforce_limit(self) -> None:
        """Drop the oldest messages so the history never exceeds max_messages."""
        if len(self.messages) > self.max_messages:
            self.messages = self.messages[-self.max_messages :]

    def add_message(self, message: Message) -> None:
        """Add a message to memory, trimming the oldest entries if over capacity."""
        self.messages.append(message)
        self._enforce_limit()

    def add_messages(self, messages: List[Message]) -> None:
        """Add multiple messages to memory, trimming the oldest entries if over capacity."""
        self.messages.extend(messages)
        self._enforce_limit()

    def clear(self) -> None:
        """Clear all messages."""
        self.messages.clear()

    def get_recent_messages(self, n: int) -> List[Message]:
        """Get the n most recent messages (empty list when n <= 0)."""
        # Guard n <= 0 explicitly: messages[-0:] would return the WHOLE history.
        return self.messages[-n:] if n > 0 else []

    def to_dict_list(self) -> List[dict]:
        """Convert messages to a list of dicts via Message.to_dict."""
        return [msg.to_dict() for msg in self.messages]
196 |
--------------------------------------------------------------------------------
/src/environment/research.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Any, Dict
3 |
4 | from pydantic import Field
5 |
6 | from src.agent.chip_analysis import ChipAnalysisAgent
7 | from src.agent.hot_money import HotMoneyAgent
8 | from src.agent.risk_control import RiskControlAgent
9 | from src.agent.sentiment import SentimentAgent
10 | from src.agent.technical_analysis import TechnicalAnalysisAgent
11 | from src.agent.big_deal_analysis import BigDealAnalysisAgent
12 | from src.environment.base import BaseEnvironment
13 | from src.logger import logger
14 | from src.schema import Message
15 | from src.tool.stock_info_request import StockInfoRequest
16 | from src.utils.report_manager import report_manager
17 |
18 |
class ResearchEnvironment(BaseEnvironment):
    """Environment for stock research using multiple specialized agents.

    Orchestrates six specialist agents (sentiment, risk control, hot money,
    technical, chip and big-deal analysis), runs them sequentially against a
    stock code, and aggregates their outputs into a single results dict.
    """

    name: str = Field(default="research_environment")
    description: str = Field(default="Environment for comprehensive stock research")
    # Aggregated output of the most recent run (keyed by analysis name).
    results: Dict[str, Any] = Field(default_factory=dict)
    max_steps: int = Field(default=3, description="Maximum steps for each agent")

    # Maps agent registry keys to the key used for their result in `results`.
    analysis_mapping: Dict[str, str] = Field(
        default={
            "sentiment_agent": "sentiment",
            "risk_control_agent": "risk",
            "hot_money_agent": "hot_money",
            "technical_analysis_agent": "technical",
            "chip_analysis_agent": "chip_analysis",
            "big_deal_analysis_agent": "big_deal",
        }
    )

    async def initialize(self) -> None:
        """Initialize the research environment with specialized agents."""
        await super().initialize()

        # Create specialized analysis agents, all sharing the same step budget.
        specialized_agents = {
            "sentiment_agent": await SentimentAgent.create(max_steps=self.max_steps),
            "risk_control_agent": await RiskControlAgent.create(max_steps=self.max_steps),
            "hot_money_agent": await HotMoneyAgent.create(max_steps=self.max_steps),
            "technical_analysis_agent": await TechnicalAnalysisAgent.create(max_steps=self.max_steps),
            "chip_analysis_agent": await ChipAnalysisAgent.create(max_steps=self.max_steps),
            "big_deal_analysis_agent": await BigDealAnalysisAgent.create(max_steps=self.max_steps),
        }

        # Register all agents with the base environment.
        for agent in specialized_agents.values():
            self.register_agent(agent)

        logger.info(f"Research environment initialized with 6 specialist agents (max_steps={self.max_steps})")

    async def run(self, stock_code: str) -> Dict[str, Any]:
        """Run research on the given stock code using all specialist agents.

        Args:
            stock_code: Stock code to research (e.g. "600036").

        Returns:
            Dict with one entry per completed analysis plus "stock_code" and
            (when available) "basic_info"; an "error" key when the run failed.
        """
        logger.info(f"Running research on stock {stock_code}")

        try:
            # Fetch basic stock information once, up front.
            basic_info_tool = StockInfoRequest()
            basic_info_result = await basic_info_tool.execute(stock_code=stock_code)

            if basic_info_result.error:
                logger.error(f"Error getting basic info: {basic_info_result.error}")
            else:
                # Share the basic info with every agent via a system message.
                stock_info_message = f"""
股票代码: {stock_code}
当前交易日: {basic_info_result.output.get('current_trading_day', '未知')}
基本信息: {basic_info_result.output.get('basic_info', '{}')}
"""

                for agent_key in self.analysis_mapping.keys():
                    agent = self.get_agent(agent_key)
                    if agent and hasattr(agent, "memory"):
                        agent.memory.add_message(
                            Message.system_message(stock_info_message)
                        )
                        logger.info(f"Added basic stock info to {agent_key}'s context")

            # Run analysis tasks sequentially with 3-second intervals.
            results = {}
            agent_count = 0
            total_agents = len([k for k in self.analysis_mapping.keys() if k in self.agents])

            # Import visualizer lazily for progress display; fall back to plain
            # logging when the console module is not available.
            try:
                from src.console import visualizer
                show_visual = True
            except ImportError:  # narrowed from a bare except that hid real errors
                show_visual = False

            for agent_key, result_key in self.analysis_mapping.items():
                if agent_key not in self.agents:
                    continue

                agent_count += 1
                logger.info(f"🔄 Starting analysis with {agent_key} ({agent_count}/{total_agents})")

                # Show agent starting in terminal
                if show_visual:
                    visualizer.show_agent_starting(agent_key, agent_count, total_agents)

                try:
                    # Run individual agent
                    result = await self.agents[agent_key].run(stock_code)
                    results[result_key] = result
                    logger.info(f"✅ Completed analysis with {agent_key}")

                    # Show agent completion in terminal
                    if show_visual:
                        visualizer.show_agent_completed(agent_key, agent_count, total_agents)

                    # Wait 3 seconds before next agent (except for the last one)
                    if agent_count < total_agents:
                        logger.info("⏳ Waiting 3 seconds before next agent...")
                        if show_visual:
                            visualizer.show_waiting_next_agent(3)
                        await asyncio.sleep(3)

                except Exception as e:
                    # A failing specialist does not abort the whole run.
                    logger.error(f"❌ Error with {agent_key}: {str(e)}")
                    results[result_key] = f"Error: {str(e)}"

            if not results:
                return {
                    "error": "No specialist agents completed successfully",
                    "stock_code": stock_code,
                }

            # Attach the basic info to the aggregated results when it was fetched.
            if not basic_info_result.error:
                results["basic_info"] = basic_info_result.output

            # Store and return complete results (report generation happens elsewhere).
            self.results = {**results, "stock_code": stock_code}
            return self.results

        except Exception as e:
            logger.error(f"Error in research: {str(e)}")
            return {"error": str(e), "stock_code": stock_code}

    async def cleanup(self) -> None:
        """Clean up all agent resources."""
        cleanup_tasks = [
            agent.cleanup()
            for agent in self.agents.values()
            if hasattr(agent, "cleanup")
        ]

        if cleanup_tasks:
            await asyncio.gather(*cleanup_tasks)

        await super().cleanup()
160 |
--------------------------------------------------------------------------------
/src/tool/mcp_client.py:
--------------------------------------------------------------------------------
1 | from contextlib import AsyncExitStack
2 | from typing import Dict, List, Optional
3 |
4 | from mcp import ClientSession, StdioServerParameters
5 | from mcp.client.sse import sse_client
6 | from mcp.client.stdio import stdio_client
7 | from mcp.types import ListToolsResult, TextContent
8 | from src.logger import logger
9 | from src.tool.base import BaseTool, ToolResult
10 | from src.tool.tool_collection import ToolCollection
11 |
12 |
class MCPClientTool(BaseTool):
    """Proxy tool that forwards invocations to a tool hosted on an MCP server."""

    session: Optional[ClientSession] = None
    server_id: str = ""  # Identifier of the owning server connection
    original_name: str = ""  # Tool name as known on the remote server

    async def execute(self, **kwargs) -> ToolResult:
        """Execute the tool by making a remote call to the MCP server."""
        if self.session is None:
            return ToolResult(
                error="MCP server connection not available. This tool requires an active MCP server connection."
            )

        try:
            logger.info(f"Executing tool: {self.original_name}")
            response = await self.session.call_tool(self.original_name, kwargs)
            # Only textual content items are surfaced; other content types are dropped.
            text_parts = [
                item.text for item in response.content if isinstance(item, TextContent)
            ]
            joined = ", ".join(text_parts)
            return ToolResult(output=joined if joined else "No output returned.")
        except Exception as e:
            return ToolResult(error=f"Error executing tool: {str(e)}")
36 |
37 |
class MCPClients(ToolCollection):
    """
    A collection of tools that connects to multiple MCP servers and manages available tools through the Model Context Protocol.
    """

    # Active ClientSession per server_id.
    # NOTE(review): class-level mutable defaults — presumably Pydantic field
    # declarations via ToolCollection; if ToolCollection is a plain class these
    # dicts would be shared across instances. Confirm against ToolCollection.
    sessions: Dict[str, ClientSession] = {}
    # One AsyncExitStack per server; closing it tears down transport + session.
    exit_stacks: Dict[str, AsyncExitStack] = {}
    description: str = "MCP client tools for server interaction"

    def __init__(self):
        super().__init__()  # Initialize with empty tools list
        self.name = "mcp"  # Keep name for backward compatibility

    async def connect_sse(self, server_url: str, server_id: str = "") -> None:
        """Connect to an MCP server using SSE transport.

        Args:
            server_url: SSE endpoint of the server.
            server_id: Optional identifier; defaults to the URL itself.

        Raises:
            ValueError: If server_url is empty.
        """
        if not server_url:
            raise ValueError("Server URL is required.")

        server_id = server_id or server_url

        # Always ensure clean disconnection before new connection
        if server_id in self.sessions:
            await self.disconnect(server_id)

        exit_stack = AsyncExitStack()
        self.exit_stacks[server_id] = exit_stack

        streams_context = sse_client(url=server_url)
        streams = await exit_stack.enter_async_context(streams_context)
        session = await exit_stack.enter_async_context(ClientSession(*streams))
        self.sessions[server_id] = session

        await self._initialize_and_list_tools(server_id)

    async def connect_stdio(
        self, command: str, args: List[str], server_id: str = ""
    ) -> None:
        """Connect to an MCP server using stdio transport.

        Args:
            command: Executable that launches the server process.
            args: Command-line arguments for the server.
            server_id: Optional identifier; defaults to the command string.

        Raises:
            ValueError: If command is empty.
        """
        if not command:
            raise ValueError("Server command is required.")

        server_id = server_id or command

        # Always ensure clean disconnection before new connection
        if server_id in self.sessions:
            await self.disconnect(server_id)

        exit_stack = AsyncExitStack()
        self.exit_stacks[server_id] = exit_stack

        server_params = StdioServerParameters(command=command, args=args)
        stdio_transport = await exit_stack.enter_async_context(
            stdio_client(server_params)
        )
        read, write = stdio_transport
        session = await exit_stack.enter_async_context(ClientSession(read, write))
        self.sessions[server_id] = session

        await self._initialize_and_list_tools(server_id)

    async def _initialize_and_list_tools(self, server_id: str) -> None:
        """Initialize session and populate tool map.

        Raises:
            RuntimeError: If no session has been stored for server_id.
        """
        session = self.sessions.get(server_id)
        if not session:
            raise RuntimeError(f"Session not initialized for server {server_id}")

        await session.initialize()
        response = await session.list_tools()

        # Create proper tool objects for each server tool
        for tool in response.tools:
            original_name = tool.name
            # Always prefix with server_id to ensure uniqueness
            tool_name = f"mcp_{server_id}_{original_name}"

            server_tool = MCPClientTool(
                name=tool_name,
                description=tool.description,
                parameters=tool.inputSchema,
                session=session,
                server_id=server_id,
                original_name=original_name,
            )
            self.tool_map[tool_name] = server_tool

        # Update tools tuple so it stays in sync with tool_map.
        self.tools = tuple(self.tool_map.values())
        logger.info(
            f"Connected to server {server_id} with tools: {[tool.name for tool in response.tools]}"
        )

    async def list_tools(self) -> ListToolsResult:
        """List all available tools, aggregated across every connected server."""
        tools_result = ListToolsResult(tools=[])
        for session in self.sessions.values():
            response = await session.list_tools()
            tools_result.tools += response.tools
        return tools_result

    async def disconnect(self, server_id: str = "") -> None:
        """Disconnect from a specific MCP server or all servers if no server_id provided."""
        if server_id:
            if server_id in self.sessions:
                try:
                    exit_stack = self.exit_stacks.get(server_id)

                    # Close the exit stack which will handle session cleanup
                    if exit_stack:
                        try:
                            await exit_stack.aclose()
                        except RuntimeError as e:
                            # anyio cancel-scope teardown errors are benign here;
                            # continue clearing our bookkeeping regardless.
                            if "cancel scope" in str(e).lower():
                                logger.warning(
                                    f"Cancel scope error during disconnect from {server_id}, continuing with cleanup: {e}"
                                )
                            else:
                                raise

                    # Clean up references
                    self.sessions.pop(server_id, None)
                    self.exit_stacks.pop(server_id, None)

                    # Remove tools associated with this server
                    self.tool_map = {
                        k: v
                        for k, v in self.tool_map.items()
                        if v.server_id != server_id
                    }
                    self.tools = tuple(self.tool_map.values())
                    logger.info(f"Disconnected from MCP server {server_id}")
                except Exception as e:
                    logger.error(f"Error disconnecting from server {server_id}: {e}")
        else:
            # Disconnect from all servers in a deterministic order
            for sid in sorted(list(self.sessions.keys())):
                await self.disconnect(sid)
            self.tool_map = {}
            self.tools = tuple()
            logger.info("Disconnected from all MCP servers")
177 |
--------------------------------------------------------------------------------
/src/tool/financial_deep_search/index_capital.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | 获取东方财富上证指数资金流向数据
6 | API: https://push2.eastmoney.com/api/qt/stock/get
7 | """
8 |
9 | import json
10 | import os
11 | import re
12 | import time
13 | import traceback
14 | from datetime import datetime
15 |
16 | import requests
17 |
18 |
# Capital-flow API endpoint template for the SSE Composite (secid=1.000001);
# fetch_index_capital_flow() substitutes the secid per request.
INDEX_CAPITAL_FLOW_URL = "https://push2.eastmoney.com/api/qt/stock/get?invt=2&fltt=1&fields=f135,f136,f137,f138,f139,f140,f141,f142,f143,f144,f145,f146,f147,f148,f149&secid=1.000001&ut=fa5fd1943c7b386f172d6893dbfba10b&wbp2u=|0|0|0|web&dect=1"

# Browser-like request headers expected by the eastmoney endpoint.
HEADERS = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Referer": "https://quote.eastmoney.com/",
    "Accept": "application/json, text/javascript, */*; q=0.01",
}
28 |
29 |
# Fallback mapping of index codes to display names, used whenever
# index_name_map.json is missing or cannot be read/parsed.
_DEFAULT_INDEX_MAP = {
    "000001": "上证指数",
    "399001": "深证成指",
    "399006": "创业板指",
    "000300": "沪深300",
    "000905": "中证500",
    "000016": "上证50",
    "000852": "中证1000",
    "000688": "科创50",
    "399673": "创业板50",
}


def load_index_map():
    """Load the index code -> name mapping from index_name_map.json.

    Falls back to the built-in default mapping when the file is absent or
    unreadable, so callers always get a usable dict. (Previously the missing-file
    and read-error paths returned two different hard-coded maps; they are now
    unified.)

    Returns:
        dict: Mapping of index codes (str) to index names (str).
    """
    map_file = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "index_name_map.json"
    )
    try:
        if os.path.exists(map_file):
            with open(map_file, "r", encoding="utf-8") as f:
                return json.load(f)
    except Exception as e:
        print(f"加载指数映射文件失败: {e}")
    # File missing or unreadable: return a copy so callers cannot mutate the default.
    return dict(_DEFAULT_INDEX_MAP)
62 |
63 |
# Module-level cache of the index code -> name mapping, loaded once at import.
INDEX_CODE_NAME_MAP = load_index_map()
66 |
67 |
def parse_jsonp(jsonp_str):
    """Parse a JSONP (or plain JSON) response body into a Python object.

    Returns:
        The decoded object, or None when the payload cannot be parsed.
    """
    try:
        # Strip a jQuery-style callback wrapper if present; otherwise treat
        # the whole body as plain JSON.
        match = re.search(r"jQuery[0-9_]+\((.*)\)", jsonp_str)
        payload = match.group(1) if match else jsonp_str
        return json.loads(payload)
    except Exception as e:
        print(f"解析JSONP失败: {e}")
        print(f"原始数据: {jsonp_str[:100]}...")  # first 100 chars for debugging
        return None
83 |
84 |
def fetch_index_capital_flow(index_code="000001", max_retries=3, retry_delay=2):
    """Fetch capital-flow data for an index from the eastmoney API.

    Args:
        index_code: Index code, defaults to the SSE Composite (000001).
        max_retries: Maximum number of request attempts.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        dict | None: Processed capital-flow data (see process_flow_data),
        or None when every attempt failed.
    """
    # Pick the market prefix for the secid parameter.
    market = "1"  # 1 = Shanghai, 0 = Shenzhen
    # NOTE(review): the startswith("1") branch presumably targets Shenzhen-listed
    # codes; confirm it cannot misclassify any Shanghai index codes.
    if index_code.startswith("39") or index_code.startswith("1"):
        market = "0"  # Shenzhen index

    url = INDEX_CAPITAL_FLOW_URL.replace(
        "secid=1.000001", f"secid={market}.{index_code}"
    )

    # Append a millisecond timestamp to defeat caching.
    timestamp = int(time.time() * 1000)
    if "?" in url:
        url += f"&_={timestamp}"
    else:
        url += f"?_={timestamp}"

    # Request the data, retrying on parse failures, empty payloads and errors.
    for attempt in range(1, max_retries + 1):
        try:
            resp = requests.get(url, headers=HEADERS, timeout=15)
            resp.raise_for_status()

            # Parse the (possibly JSONP-wrapped) response body.
            data = parse_jsonp(resp.text)
            if not data:
                print(f"解析指数资金流向数据失败 (第{attempt}次尝试)")
                if attempt < max_retries:
                    time.sleep(retry_delay)
                    continue
                return None

            # Extract the capital-flow payload.
            flow_data = data.get("data", {})
            if not flow_data:
                print(f"未获取到指数资金流向数据 (第{attempt}次尝试)")
                if attempt < max_retries:
                    time.sleep(retry_delay)
                    continue
                return None

            # Convert raw fields into the labelled result dict.
            return process_flow_data(flow_data, index_code)

        except Exception as e:
            print(f"获取指数资金流向数据失败: {e} (第{attempt}次尝试)")
            if attempt < max_retries:
                time.sleep(retry_delay)
            else:
                return None
146 |
147 |
def process_flow_data(data, index_code):
    """Convert raw API fields (f135..f149) into a labelled capital-flow dict.

    Args:
        data: Raw "data" payload returned by the eastmoney API.
        index_code: Index code the payload belongs to.

    Returns:
        dict: Human-readable capital-flow figures (in 亿元) plus metadata.
    """
    # Mapping from API field codes to human-readable labels.
    field_mapping = {
        "f135": "今日主力净流入",
        "f136": "今日主力流入",
        "f137": "今日主力流出",
        "f138": "今日超大单净流入",
        "f139": "今日超大单流入",
        "f140": "今日超大单流出",
        "f141": "今日大单净流入",
        "f142": "今日大单流入",
        "f143": "今日大单流出",
        "f144": "今日中单净流入",
        "f145": "今日中单流入",
        "f146": "今日中单流出",
        "f147": "今日小单净流入",
        "f148": "今日小单流入",
        "f149": "今日小单流出",
    }

    result = {
        "指数代码": index_code,
        "指数名称": INDEX_CODE_NAME_MAP.get(index_code, f"指数{index_code}"),
        "更新时间": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }

    for field, label in field_mapping.items():
        # Fields absent from the payload are skipped entirely.
        if field not in data:
            continue
        raw = data.get(field, 0)
        # Convert yuan to 亿元 (divide by 1e8), keeping two decimals; falsy
        # raw values (0/None/"") map to 0.
        result[label] = round(float(raw) / 100000000, 2) if raw else 0

    return result
199 |
200 |
def get_index_capital_flow(index_code="000001"):
    """Public entry point: fetch and package capital-flow data for an index.

    Args:
        index_code (str, optional): Index code, defaults to the SSE Composite 000001.

    Returns:
        dict: Envelope with "success", "message", "data" and, on success,
        "last_updated"; on failure "data" is empty and "error" may carry a traceback.
    """
    try:
        flow_data = fetch_index_capital_flow(index_code)

        if not flow_data:
            return {
                "success": False,
                "message": f"获取指数{index_code}资金流向数据失败",
                "data": {},
            }

        # Prefer the name embedded in the data, falling back to the code map.
        index_name = flow_data.get(
            "指数名称", INDEX_CODE_NAME_MAP.get(index_code, f"指数{index_code}")
        )

        return {
            "success": True,
            "message": f"成功获取{index_name}({index_code})资金流向数据",
            "last_updated": datetime.now().isoformat(),
            "data": flow_data,
        }
    except Exception as e:
        error_msg = f"获取指数资金流向数据时出错: {str(e)}"
        print(error_msg)
        print(traceback.format_exc())
        return {
            "success": False,
            "message": error_msg,
            "error": traceback.format_exc(),
            "data": {},
        }
247 |
def main():
    """CLI entry point: print capital-flow data for the requested index as JSON."""
    import sys

    # Use the index code from argv when provided; otherwise fetch the SSE Composite.
    result = (
        get_index_capital_flow(index_code=sys.argv[1])
        if len(sys.argv) > 1
        else get_index_capital_flow()
    )

    # Emit the result as pretty-printed, non-ASCII-escaped JSON.
    print(json.dumps(result, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
266 |
--------------------------------------------------------------------------------
/src/agent/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from contextlib import asynccontextmanager
3 | from typing import List, Optional
4 |
5 | from pydantic import BaseModel, Field, model_validator
6 |
7 | from src.llm import LLM
8 | from src.logger import logger
9 | from src.schema import ROLE_TYPE, AgentState, Memory, Message
10 |
11 |
class BaseAgent(BaseModel, ABC):
    """Abstract base class for managing agent state and execution.

    Provides foundational functionality for state transitions, memory management,
    and a step-based execution loop. Subclasses must implement the `step` method.
    """

    # Core attributes
    name: str = Field(..., description="Unique name of the agent")
    description: Optional[str] = Field(None, description="Optional agent description")

    # Prompts
    system_prompt: Optional[str] = Field(
        None, description="System-level instruction prompt"
    )
    next_step_prompt: Optional[str] = Field(
        None, description="Prompt for determining next action"
    )

    # Dependencies
    llm: LLM = Field(default_factory=LLM, description="Language model instance")
    memory: Memory = Field(default_factory=Memory, description="Agent's memory store")
    state: AgentState = Field(
        default=AgentState.IDLE, description="Current agent state"
    )

    # Execution control
    max_steps: int = Field(default=10, description="Maximum steps before termination")
    current_step: int = Field(default=0, description="Current step in execution")

    # Number of duplicate assistant responses tolerated before the agent is
    # considered stuck (see is_stuck / handle_stuck_state).
    duplicate_threshold: int = 2

    class Config:
        arbitrary_types_allowed = True
        extra = "allow"  # Allow extra fields for flexibility in subclasses

    @model_validator(mode="after")
    def initialize_agent(self) -> "BaseAgent":
        """Initialize agent with default settings if not provided."""
        # Fall back to a per-agent LLM config keyed by the lowercased agent name.
        if self.llm is None or not isinstance(self.llm, LLM):
            self.llm = LLM(config_name=self.name.lower())
        if not isinstance(self.memory, Memory):
            self.memory = Memory()
        return self

    @asynccontextmanager
    async def state_context(self, new_state: AgentState):
        """Context manager for safe agent state transitions.

        Args:
            new_state: The state to transition to during the context.

        Yields:
            None: Allows execution within the new state.

        Raises:
            ValueError: If the new_state is invalid.
        """
        if not isinstance(new_state, AgentState):
            raise ValueError(f"Invalid state: {new_state}")

        previous_state = self.state
        self.state = new_state
        try:
            yield
        except Exception as e:
            self.state = AgentState.ERROR  # Transition to ERROR on failure
            raise e
        finally:
            # Reverts even a FINISHED state set inside the block back to previous.
            self.state = previous_state  # Revert to previous state

    def update_memory(
        self,
        role: ROLE_TYPE,  # type: ignore
        content: str,
        base64_image: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Add a message to the agent's memory.

        Args:
            role: The role of the message sender (user, system, assistant, tool).
            content: The message content.
            base64_image: Optional base64 encoded image.
            **kwargs: Additional arguments (e.g., tool_call_id for tool messages).

        Raises:
            ValueError: If the role is unsupported.
        """
        message_map = {
            "user": Message.user_message,
            "system": Message.system_message,
            "assistant": Message.assistant_message,
            "tool": lambda content, **kw: Message.tool_message(content, **kw),
        }

        if role not in message_map:
            raise ValueError(f"Unsupported message role: {role}")

        # Create message with appropriate parameters based on role.
        # Extra kwargs are only forwarded for tool messages; other roles
        # receive just the base64_image.
        kwargs = {"base64_image": base64_image, **(kwargs if role == "tool" else {})}
        self.memory.add_message(message_map[role](content, **kwargs))

    async def run(self, request: Optional[str] = None) -> str:
        """Execute the agent's main loop asynchronously.

        Args:
            request: Optional initial user request to process.

        Returns:
            A string summarizing the execution results.

        Raises:
            RuntimeError: If the agent is not in IDLE state at start.
        """
        if self.state != AgentState.IDLE:
            raise RuntimeError(f"Cannot run agent from state: {self.state}")

        if request:
            self.update_memory("user", request)

        results: List[str] = []
        async with self.state_context(AgentState.RUNNING):
            while (
                self.current_step < self.max_steps and self.state != AgentState.FINISHED
            ):
                self.current_step += 1
                logger.info(f"Executing step {self.current_step}/{self.max_steps}")
                step_result = await self.step()

                # Check for stuck state
                if self.is_stuck():
                    self.handle_stuck_state()

                results.append(f"Step {self.current_step}: {step_result}")

        if self.current_step >= self.max_steps:
            self.current_step = 0
            self.state = AgentState.IDLE
            results.append(f"Terminated: Reached max steps ({self.max_steps})")
        # NOTE(review): when the loop exits because state reached FINISHED before
        # max_steps, current_step is NOT reset here; a subsequent run() resumes
        # from the leftover count unless reset_execution_state() is called.
        return "\n".join(results) if results else "No steps executed"

    @abstractmethod
    async def step(self) -> str:
        """Execute a single step in the agent's workflow.

        Must be implemented by subclasses to define specific behavior.
        """

    def handle_stuck_state(self):
        """Handle stuck state by adding a prompt to change strategy"""
        stuck_prompt = "\
Observed duplicate responses. Consider new strategies and avoid repeating ineffective paths already attempted."
        # Prepend the nudge so it takes precedence over the existing prompt.
        self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt}"
        logger.warning(f"Agent detected stuck state. Added prompt: {stuck_prompt}")

    def is_stuck(self) -> bool:
        """Check if the agent is stuck in a loop by detecting duplicate content"""
        if len(self.memory.messages) < 2:
            return False

        last_message = self.memory.messages[-1]
        if not last_message.content:
            return False

        # Count identical content occurrences among earlier assistant messages.
        duplicate_count = sum(
            1
            for msg in reversed(self.memory.messages[:-1])
            if msg.role == "assistant" and msg.content == last_message.content
        )

        return duplicate_count >= self.duplicate_threshold

    @property
    def messages(self) -> List[Message]:
        """Retrieve a list of messages from the agent's memory."""
        return self.memory.messages

    @messages.setter
    def messages(self, value: List[Message]):
        """Set the list of messages in the agent's memory."""
        self.memory.messages = value

    def reset_execution_state(self) -> None:
        """Reset the agent's execution state to prepare for a new run.

        Resets the current step counter to zero and returns the agent to IDLE state.
        """
        self.current_step = 0
        self.state = AgentState.IDLE
        logger.info(f"Agent '{self.name}' execution state has been reset")
204 |
--------------------------------------------------------------------------------
/src/tool/big_deal_analysis.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 | import time
3 | import pandas as pd # type: ignore
4 |
5 | from src.logger import logger
6 | from src.tool.base import BaseTool, ToolResult
7 |
8 | try:
9 | import akshare as ak # type: ignore
10 | except ImportError:
11 | ak = None # type: ignore
12 |
13 |
class BigDealAnalysisTool(BaseTool):
    """Tool for analysing big order fund flows using akshare interfaces."""

    # Tool identifier used by the agent framework when dispatching calls.
    name: str = "big_deal_analysis_tool"
    # User-facing description listing the akshare endpoints this tool calls.
    description: str = (
        "获取市场及个股资金大单流向数据,并返回综合分析结果。"
        "调用 akshare 的 stock_fund_flow_big_deal、stock_fund_flow_individual、"
        "stock_individual_fund_flow、stock_zh_a_hist 接口。"
    )
    # JSON-schema description of the tool's arguments (function-calling format);
    # defaults here mirror the defaults of execute().
    parameters: dict = {
        "type": "object",
        "properties": {
            "stock_code": {
                "type": "string",
                "description": "股票代码,如 '600036',若为空代表全市场分析",
                "default": "",
            },
            "top_n": {
                "type": "integer",
                "description": "排行前 N 名股票的数据",
                "default": 10,
            },
            "rank_symbol": {
                "type": "string",
                "description": "排行时间窗口,{\"即时\", \"3日排行\", \"5日排行\", \"10日排行\", \"20日排行\"}",
                "default": "即时",
            },
            "max_retry": {
                "type": "integer",
                "description": "最大重试次数",
                "default": 3,
            },
            "sleep_seconds": {
                "type": "integer",
                "description": "重试间隔秒数",
                "default": 1,
            },
        },
    }
53 |
54 | async def execute(
55 | self,
56 | stock_code: str = "",
57 | top_n: int = 10,
58 | rank_symbol: str = "即时",
59 | max_retry: int = 3,
60 | sleep_seconds: int = 1,
61 | **kwargs,
62 | ) -> ToolResult:
63 | """Fetch big deal fund flow data and return structured result."""
64 | if ak is None:
65 | return ToolResult(error="akshare library not installed")
66 |
67 | try:
68 | result: Dict[str, Any] = {}
69 |
70 | def _with_retry(func, *args, **kwargs):
71 | """Simple retry wrapper for unstable akshare endpoints."""
72 | for attempt in range(1, max_retry + 1):
73 | try:
74 | return func(*args, **kwargs)
75 | except Exception as e:
76 | if attempt >= max_retry:
77 | raise
78 | logger.warning(f"{func.__name__} attempt {attempt} failed: {e}. Retrying...")
79 | time.sleep(sleep_seconds)
80 |
81 | def _safe_fetch(func, *args, **kwargs):
82 | """Fetch data with retries; return None on ultimate failure instead of raising."""
83 | try:
84 | return _with_retry(func, *args, **kwargs)
85 | except Exception as e:
86 | logger.warning(f"{func.__name__} failed after {max_retry} attempts: {e}")
87 | return None
88 |
89 | # Market wide big deal flow (逐笔大单)
90 | df_bd = _safe_fetch(ak.stock_fund_flow_big_deal)
91 | if df_bd is not None and not df_bd.empty:
92 | # 清洗数字列
93 | def _to_float(series):
94 | return (
95 | series.astype(str)
96 | .str.replace(",", "", regex=False)
97 | .str.replace("亿", "e8") # unlikely present here
98 | .str.replace("万", "e4")
99 | .str.extract(r"([\d\.-eE]+)")[0]
100 | .astype(float)
101 | )
102 |
103 | if df_bd["成交额"].dtype == "O":
104 | df_bd["成交额"] = _to_float(df_bd["成交额"])
105 |
106 | # 计算买盘/卖盘汇总
107 | inflow = df_bd[df_bd["大单性质"] == "买盘"]["成交额"].sum()
108 | outflow = df_bd[df_bd["大单性质"] == "卖盘"]["成交额"].sum()
109 | result["market_summary"] = {
110 | "total_inflow_wan": round(inflow, 2),
111 | "total_outflow_wan": round(outflow, 2),
112 | "net_inflow_wan": round(inflow - outflow, 2),
113 | }
114 |
115 | # 按净额排序股票
116 | grouped = (
117 | df_bd.groupby(["股票代码", "股票简称", "大单性质"])["成交额"].sum().reset_index()
118 | )
119 | buy_df = grouped[grouped["大单性质"] == "买盘"].sort_values("成交额", ascending=False)
120 | sell_df = grouped[grouped["大单性质"] == "卖盘"].sort_values("成交额", ascending=False)
121 |
122 | result["top_inflow"] = buy_df.head(top_n).to_dict(orient="records")
123 | result["top_outflow"] = sell_df.head(top_n).to_dict(orient="records")
124 |
125 | # 保存部分原始逐笔记录以备调试(最多 top_n 条)
126 | result["market_big_deal_samples"] = df_bd.head(top_n).to_dict(orient="records")
127 | else:
128 | result["market_big_deal_samples"] = []
129 |
130 | # Individual fund flow rank 使用 stock_fund_flow_individual(symbol)
131 | individual_rank = _safe_fetch(ak.stock_fund_flow_individual, symbol=rank_symbol)
132 |
133 | # 默认返回排行榜前 top_n 条
134 | result["individual_rank_top"] = (
135 | individual_rank.head(top_n).to_dict(orient="records")
136 | if individual_rank is not None else []
137 | )
138 |
139 | # 若指定了 stock_code, 仅保留其对应行数据
140 | if stock_code and individual_rank is not None:
141 | rank_filtered = individual_rank[
142 | individual_rank["股票代码"].astype(str) == stock_code
143 | ]
144 | result["individual_rank_stock"] = (
145 | rank_filtered.to_dict(orient="records") if not rank_filtered.empty else []
146 | )
147 |
148 | if stock_code:
149 | # Stock specific fund flow trend 使用 stock_individual_fund_flow
150 | individual_flow = _safe_fetch(ak.stock_individual_fund_flow, stock=stock_code)
151 | result["stock_fund_flow"] = (
152 | individual_flow.to_dict(orient="records") if individual_flow is not None else []
153 | )
154 |
155 | # Historical price data for correlation
156 | hist_price = _safe_fetch(ak.stock_zh_a_hist, symbol=stock_code, period="daily")
157 | if hist_price is not None:
158 | result["stock_price_hist"] = hist_price.tail(120).to_dict(orient="records")
159 | else:
160 | result["stock_price_hist"] = []
161 |
162 | # 1. 先整体抓取逐笔大单
163 | # 复用已获取的 df_bd,若为空再尝试一次
164 | if df_bd is None:
165 | df_bd = _safe_fetch(ak.stock_fund_flow_big_deal)
166 |
167 | stk_df = pd.DataFrame()
168 | if df_bd is not None and not df_bd.empty:
169 | stk_df = df_bd[df_bd["股票代码"] == stock_code]
170 |
171 | if not stk_df.empty:
172 | inflow = stk_df[stk_df["大单性质"] == "买盘"]["成交额"].sum()
173 | outflow = stk_df[stk_df["大单性质"] == "卖盘"]["成交额"].sum()
174 |
175 | result["stock_big_deal_summary"] = {
176 | "inflow_wan": round(inflow, 2),
177 | "outflow_wan": round(outflow, 2),
178 | "net_inflow_wan": round(inflow - outflow, 2),
179 | "trade_count": len(stk_df),
180 | }
181 | result["stock_big_deal_samples"] = stk_df.head(top_n).to_dict(orient="records")
182 | else:
183 | result["stock_big_deal_summary"] = {}
184 | result["stock_big_deal_samples"] = []
185 |
186 | return ToolResult(output=result)
187 | except Exception as e:
188 | logger.error(f"BigDealAnalysisTool error: {e}")
189 | return ToolResult(error=str(e))
--------------------------------------------------------------------------------
/src/tool/hot_money.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | from datetime import datetime
4 | from typing import Optional
5 |
6 | import efinance as ef
7 | import pandas as pd
8 | from pydantic import Field
9 |
10 | from src.logger import logger
11 | from src.tool.base import BaseTool, ToolResult, get_recent_trading_day
12 | from src.tool.financial_deep_search.get_section_data import get_all_section
13 | from src.tool.financial_deep_search.index_capital import get_index_capital_flow
14 | from src.tool.financial_deep_search.stock_capital import get_stock_capital_flow
15 |
16 |
17 | _HOT_MONEY_DESCRIPTION = """
18 | 获取股票热点资金和市场数据工具,用于分析主力资金流向和市场热点变化。
19 |
20 | 该工具提供以下精确数据服务:
21 | 1. 股票实时数据:提供目标股票的最新价格、涨跌幅、成交量、换手率、市值等关键指标
22 | 2. 龙虎榜数据:获取特定日期的市场龙虎榜,展示主力机构资金买卖方向和具体金额
23 | 3. 热门板块分析:按类型(概念、行业、地域)提供当日热门板块涨跌情况和资金流向
24 | 4. 个股资金流向:展示目标股票近期主力、散户、超大单资金净流入/流出数据
25 | 5. 大盘资金流向:提供指数级别的资金面分析,包括北向资金、融资余额等宏观指标
26 |
27 | 结果以结构化JSON格式返回,包含完整的数据类别、时间戳和数值指标。
28 | """
29 |
30 |
class HotMoneyTool(BaseTool):
    """Tool for retrieving hot money and market data for stocks.

    Aggregates several independent data sources (realtime quotes, daily
    billboard, hot sectors, per-stock and per-index capital flow) into one
    structured result, each fetched with its own retry budget.
    """

    name: str = "hot_money_tool"
    description: str = _HOT_MONEY_DESCRIPTION
    parameters: dict = {
        "type": "object",
        "properties": {
            "stock_code": {
                "type": "string",
                "description": "股票代码(必填),如'600519'(贵州茅台)、'000001'(平安银行)、'300750'(宁德时代)等",
            },
            "index_code": {
                "type": "string",
                "description": "指数代码,如'000001'(上证指数)、'399001'(深证成指)。不提供则默认使用与股票代码相同的值",
            },
            "date": {
                "type": "string",
                "description": "查询日期,精确格式为YYYY-MM-DD(如'2023-05-15'),不提供则默认使用当天日期",
                "default": "",
            },
            "sector_types": {
                "type": "string",
                "description": "板块类型筛选,可选值:'all'(所有板块)、'hot'(热门板块)、'concept'(概念板块)、'regional'(地域板块)、'industry'(行业板块)",
                "default": "all",
            },
            "max_retry": {
                "type": "integer",
                "description": "数据获取最大重试次数,范围1-5,用于处理网络波动情况",
                "default": 3,
            },
            "sleep_seconds": {
                "type": "integer",
                "description": "重试间隔秒数,范围1-10,防止频繁请求被限制",
                "default": 1,
            },
        },
        "required": ["stock_code"],
    }

    # Serializes concurrent executions so data sources aren't hammered in parallel.
    lock: asyncio.Lock = Field(default_factory=asyncio.Lock)

    async def execute(
        self,
        stock_code: str,
        index_code: Optional[str] = None,
        date: str = "",
        sector_types: str = "all",
        max_retry: int = 3,
        sleep_seconds: int = 1,
        **kwargs,
    ) -> ToolResult:
        """
        Execute the hot money data retrieval operation.

        Args:
            stock_code: Stock code, e.g. "600519"
            index_code: Index code, e.g. "000001", defaults to stock_code if not provided
            date: Query date in YYYY-MM-DD format, defaults to most recent trading day
            sector_types: Sector types, options: 'all', 'hot', 'concept', 'regional', 'industry'
            max_retry: Maximum retry attempts, default 3
            sleep_seconds: Seconds to wait between retries, default 1
            **kwargs: Additional parameters

        Returns:
            ToolResult: Unified JSON format containing all data sources results or error message
        """
        async with self.lock:
            try:
                date = date or get_recent_trading_day()
                actual_index_code = index_code or stock_code

                result = {
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    "stock_code": stock_code,
                    "date": date,
                }
                if index_code:
                    result["index_code"] = index_code

                # Each entry is a zero-arg callable so failures stay isolated
                # and each source gets its own retry budget.
                data_sources = {
                    "stock_latest_info": lambda: ef.stock.get_realtime_quotes(
                        stock_code
                    ),
                    "daily_top_list": lambda: ef.stock.get_daily_billboard(
                        start_date=date, end_date=date
                    ),
                    "hot_section_data": lambda: get_all_section(
                        sector_types=sector_types
                    ),
                    "stock_net_flow": lambda: get_stock_capital_flow(
                        stock_code=stock_code
                    ),
                    "index_net_flow": lambda: get_index_capital_flow(
                        index_code=actual_index_code
                    ),
                }

                # Retrieve each data source; a failed source yields None, not an error.
                for key, func in data_sources.items():
                    result[key] = await self._get_data_with_retry(
                        func, key, max_retry, sleep_seconds
                    )

                return ToolResult(output=result)

            except Exception as e:
                error_msg = f"Failed to get hot money data: {str(e)}"
                logger.error(error_msg)
                return ToolResult(error=error_msg)

    @staticmethod
    async def _get_data_with_retry(func, data_name, max_retry=3, sleep_seconds=1):
        """
        Get data with retry mechanism.

        Args:
            func: Function to call
            data_name: Data name (for logging)
            max_retry: Maximum retry attempts
            sleep_seconds: Seconds to wait between retries

        Returns:
            Function return data (pandas objects converted to plain
            JSON-serializable structures) or None after exhausting retries
        """
        last_error = None
        for attempt in range(1, max_retry + 1):
            try:
                # Use asyncio.to_thread for synchronous (blocking) fetchers
                data = await asyncio.to_thread(func)

                # Bug fix: log success before the type-conversion returns —
                # previously the early returns for DataFrame/Series made the
                # success log unreachable for those (common) result types.
                logger.info(f"[{data_name}] Data retrieved successfully")

                # Normalize pandas containers to plain structures
                if isinstance(data, pd.DataFrame):
                    return data.to_dict(orient="records")
                if isinstance(data, pd.Series):
                    return data.to_dict()
                if hasattr(data, "to_json"):
                    return json.loads(data.to_json())
                return data

            except Exception as e:
                last_error = str(e)
                logger.warning(f"[{data_name}][Attempt {attempt}] Failed: {e}")

                if attempt < max_retry:
                    await asyncio.sleep(sleep_seconds)
                    logger.info(f"[{data_name}] Preparing attempt {attempt+1}...")

        logger.error(
            f"[{data_name}] Max retries ({max_retry}) reached, failed: {last_error}"
        )
        return None
186 |
187 |
if __name__ == "__main__":
    import sys

    # Ad-hoc smoke test: python -m src.tool.hot_money [stock_code]
    code = sys.argv[1] if len(sys.argv) > 1 else "600519"
    index_code = "000001"  # Default to Shanghai Composite Index

    async def run_tool():
        """Run HotMoneyTool once and dump the result to a timestamped JSON file."""
        tool = HotMoneyTool()
        result = await tool.execute(stock_code=code, index_code=index_code)

        if result.error:
            print(f"Failed: {result.error}")
        else:
            data = result.output
            print(f"Success! Timestamp: {data['timestamp']}")
            print(f"Stock Code: {data['stock_code']}")
            if "index_code" in data:
                print(f"Index Code: {data['index_code']}")

            # Per-source status: a None entry means all retries failed.
            for key in [
                "stock_latest_info",
                "daily_top_list",
                "hot_section_data",
                "stock_net_flow",
                "index_net_flow",
            ]:
                status = "Success" if data.get(key) is not None else "Failed"
                print(f"- {key}: {status}")

            filename = (
                f"hotmoney_data_{code}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
            )
            with open(filename, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            # Bug fix: previously printed a hard-coded placeholder instead of
            # the actual output path (leaving `filename` unused in the message).
            print(f"\nComplete results saved to: {filename}")

    asyncio.run(run_tool())
225 |
--------------------------------------------------------------------------------
/src/agent/sentiment.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional, Dict
2 |
3 | from pydantic import Field
4 |
5 | from src.agent.mcp import MCPAgent
6 | from src.prompt.mcp import NEXT_STEP_PROMPT_ZN
7 | from src.prompt.sentiment import SENTIMENT_SYSTEM_PROMPT
8 | from src.schema import Message
9 | from src.tool import Terminate, ToolCollection
10 | from src.tool.sentiment import SentimentTool
11 | from src.tool.web_search import WebSearch
12 |
13 | import logging
14 | logger = logging.getLogger(__name__)
15 |
16 |
class SentimentAgent(MCPAgent):
    """Sentiment analysis agent focused on market sentiment and news."""

    name: str = "sentiment_agent"
    description: str = "Analyzes market sentiment, news, and social media for insights on stock performance."
    system_prompt: str = SENTIMENT_SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT_ZN

    # Initialize with FinGenius tools with proper type annotation
    available_tools: ToolCollection = Field(
        default_factory=lambda: ToolCollection(
            SentimentTool(),
            WebSearch(),
            Terminate(),
        )
    )

    special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])

    async def run(
        self, request: Optional[str] = None, stock_code: Optional[str] = None
    ) -> Any:
        """Run sentiment analysis on the given stock.

        Args:
            request: Optional initial request to process. If provided, overrides stock_code parameter.
            stock_code: The stock code/ticker to analyze

        Returns:
            Dictionary containing sentiment analysis results
        """
        # If stock_code is provided but request is not, create request from stock_code
        if stock_code and not request:
            # Set up system message about the stock being analyzed
            self.memory.add_message(
                Message.system_message(
                    f"你正在分析股票 {stock_code} 的市场情绪。请收集相关新闻、社交媒体数据,并评估整体情绪。"
                )
            )
            request = f"请分析 {stock_code} 的市场情绪和相关新闻。"

        # Call parent implementation with the request
        return await super().run(request)

    async def analyze(self, stock_code: str, **kwargs) -> Dict:
        """Run the full sentiment pipeline (news, social media, sentiment tool).

        Returns a dict with success flag, task count, summary text and the
        names of the tasks that produced data.
        """
        try:
            logger.info(f"开始舆情分析: {stock_code}")

            # Force tool execution so the analysis never silently skips data
            # collection; each step is independent and failure-tolerant.
            # NOTE(review): relies on self.tool_call and result.success/.data —
            # confirm MCPAgent exposes this interface.
            analysis_tasks = []

            # 1. News search
            try:
                news_result = await self.tool_call("web_search", {
                    "query": f"{stock_code} 股票 最新消息 舆情",
                    "max_results": 10
                })
                if news_result and news_result.success:
                    analysis_tasks.append(("news_search", news_result.data))
                    logger.info(f"新闻搜索成功: {stock_code}")
                else:
                    logger.warning(f"新闻搜索失败: {stock_code}")
            except Exception as e:
                logger.error(f"新闻搜索异常: {stock_code}, {str(e)}")

            # 2. Social media sentiment search
            try:
                social_result = await self.tool_call("web_search", {
                    "query": f"{stock_code} 股吧 讨论 情绪",
                    "max_results": 5
                })
                if social_result and social_result.success:
                    analysis_tasks.append(("social_media", social_result.data))
                    logger.info(f"社交媒体分析成功: {stock_code}")
                else:
                    logger.warning(f"社交媒体分析失败: {stock_code}")
            except Exception as e:
                logger.error(f"社交媒体分析异常: {stock_code}, {str(e)}")

            # 3. Dedicated sentiment-analysis tool
            try:
                sentiment_result = await self.tool_call("sentiment_analysis", {
                    "stock_code": stock_code,
                    "analysis_type": "comprehensive"
                })
                if sentiment_result and sentiment_result.success:
                    analysis_tasks.append(("sentiment_analysis", sentiment_result.data))
                    logger.info(f"舆情分析工具成功: {stock_code}")
                else:
                    logger.warning(f"舆情分析工具失败: {stock_code}")
            except Exception as e:
                logger.error(f"舆情分析工具异常: {stock_code}, {str(e)}")

            # 4. Combine whatever succeeded into a summary report
            if analysis_tasks:
                summary = self._generate_comprehensive_summary(analysis_tasks, stock_code)
                logger.info(f"舆情分析完成: {stock_code}, 执行了 {len(analysis_tasks)} 个任务")
                return {
                    "success": True,
                    "analysis_count": len(analysis_tasks),
                    "summary": summary,
                    "tasks_executed": [task[0] for task in analysis_tasks]
                }
            else:
                logger.warning(f"舆情分析没有成功执行任何任务: {stock_code}")
                return {
                    "success": False,
                    "analysis_count": 0,
                    "summary": "无法获取舆情数据,请检查网络连接和数据源",
                    "tasks_executed": []
                }

        except Exception as e:
            logger.error(f"舆情分析失败: {stock_code}, {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "analysis_count": 0,
                "summary": f"舆情分析异常: {str(e)}"
            }

    def _generate_comprehensive_summary(self, analysis_tasks: List, stock_code: str) -> str:
        """Render the collected task results into a Markdown report string."""
        try:
            summary_parts = [f"## {stock_code} 舆情分析报告\n"]

            for task_name, task_data in analysis_tasks:
                if task_name == "news_search":
                    summary_parts.append("### 📰 新闻舆情")
                    summary_parts.append(f"- 搜索到 {len(task_data.get('results', []))} 条相关新闻")
                    summary_parts.append(f"- 整体情绪倾向: {self._analyze_news_sentiment(task_data)}")

                elif task_name == "social_media":
                    summary_parts.append("### 💬 社交媒体情绪")
                    summary_parts.append(f"- 搜索到 {len(task_data.get('results', []))} 条相关讨论")
                    summary_parts.append(f"- 投资者情绪: {self._analyze_social_sentiment(task_data)}")

                elif task_name == "sentiment_analysis":
                    summary_parts.append("### 📊 专业舆情分析")
                    summary_parts.append(f"- 情绪指数: {task_data.get('sentiment_score', 'N/A')}")
                    summary_parts.append(f"- 风险等级: {task_data.get('risk_level', 'N/A')}")

            return "\n".join(summary_parts)

        except Exception as e:
            logger.error(f"生成舆情分析报告失败: {str(e)}")
            return f"舆情分析报告生成失败: {str(e)}"

    def _analyze_news_sentiment(self, data: Dict) -> str:
        """Classify news tone by naive keyword counting: 偏正面/偏负面/中性."""
        try:
            results = data.get('results', [])
            if not results:
                return "中性"

            # Naive keyword-based sentiment counting
            positive_keywords = ["上涨", "利好", "突破", "增长", "看好"]
            negative_keywords = ["下跌", "利空", "暴跌", "风险", "亏损"]

            positive_count = 0
            negative_count = 0

            for result in results:
                text = result.get('snippet', '') + result.get('title', '')
                for keyword in positive_keywords:
                    if keyword in text:
                        positive_count += 1
                for keyword in negative_keywords:
                    if keyword in text:
                        negative_count += 1

            if positive_count > negative_count:
                return "偏正面"
            elif negative_count > positive_count:
                return "偏负面"
            else:
                return "中性"
        except Exception:  # was a bare except; keep the best-effort fallback
            return "中性"

    def _analyze_social_sentiment(self, data: Dict) -> str:
        """Gauge discussion heat by trading-keyword frequency: 活跃/一般/平淡."""
        try:
            results = data.get('results', [])
            if not results:
                return "平淡"

            # Naive discussion-heat measure
            discussion_keywords = ["买入", "卖出", "持有", "看涨", "看跌"]
            keyword_count = 0

            for result in results:
                text = result.get('snippet', '') + result.get('title', '')
                for keyword in discussion_keywords:
                    if keyword in text:
                        keyword_count += 1

            if keyword_count >= 5:
                return "活跃"
            elif keyword_count >= 2:
                return "一般"
            else:
                return "平淡"
        except Exception:  # was a bare except; keep the best-effort fallback
            return "平淡"
223 |
--------------------------------------------------------------------------------
/src/tool/financial_deep_search/stock_capital.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """
5 | 获取东方财富个股资金流向数据
6 | API: https://push2.eastmoney.com/api/qt/clist/get
7 | """
8 |
9 | import json
10 | import re
11 | import time
12 | import traceback
13 | from datetime import datetime
14 |
15 | import requests
16 |
17 |
# Eastmoney push2 "clist" endpoint — per-stock capital-flow ranking sorted by
# main-force net inflow (fid=f62). The fs/fields query params select markets
# and the f-columns that process_stock_list_data() later decodes.
STOCK_CAPITAL_FLOW_URL = "https://push2.eastmoney.com/api/qt/clist/get?fid=f62&po=1&pz=50&pn=1&np=1&fltt=2&invt=2&ut=8dec03ba335b81bf4ebdf7b29ec27d15&fs=m%3A0%2Bt%3A6%2Bf%3A!2%2Cm%3A0%2Bt%3A13%2Bf%3A!2%2Cm%3A0%2Bt%3A80%2Bf%3A!2%2Cm%3A1%2Bt%3A2%2Bf%3A!2%2Cm%3A1%2Bt%3A23%2Bf%3A!2%2Cm%3A0%2Bt%3A7%2Bf%3A!2%2Cm%3A1%2Bt%3A3%2Bf%3A!2&fields=f12%2Cf14%2Cf2%2Cf3%2Cf62%2Cf184%2Cf66%2Cf69%2Cf72%2Cf75%2Cf78%2Cf81%2Cf84%2Cf87%2Cf204%2Cf205%2Cf124%2Cf1%2Cf13"

# Browser-like request headers (UA + Referer) sent with every request.
# NOTE(review): presumably required to avoid being rejected by the API — confirm.
HEADERS = {
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
    "Referer": "https://quote.eastmoney.com/",
    "Accept": "application/json, text/javascript, */*; q=0.01",
}
27 |
28 |
def parse_jsonp(jsonp_str):
    """Decode a JSONP (or plain JSON) response body into a Python object.

    Strips a ``jQuery<digits/underscores>(...)`` wrapper when present,
    otherwise parses the string as-is. Returns None when decoding fails.
    """
    try:
        wrapped = re.search(r"jQuery[0-9_]+\((.*)\)", jsonp_str)
        payload = wrapped.group(1) if wrapped else jsonp_str
        return json.loads(payload)
    except Exception as e:
        print(f"解析JSONP失败: {e}")
        print(f"原始数据: {jsonp_str[:100]}...")  # first 100 chars for debugging
        return None
44 |
45 |
def fetch_stock_list_capital_flow(
    page_size=50, page_num=1, max_retries=3, retry_delay=2
):
    """Fetch the ranked capital-flow list (sorted by main-force net inflow).

    Args:
        page_size: rows per page, default 50.
        page_num: page number, default 1.
        max_retries: maximum attempts before giving up.
        retry_delay: seconds to sleep between attempts.

    Returns:
        dict produced by process_stock_list_data(), or None on failure.
    """
    # Patch pagination into the canned URL, then append a cache-busting param.
    request_url = STOCK_CAPITAL_FLOW_URL.replace("pz=50", f"pz={page_size}").replace(
        "pn=1", f"pn={page_num}"
    )
    cache_buster = int(time.time() * 1000)
    separator = "&" if "?" in request_url else "?"
    request_url = f"{request_url}{separator}_={cache_buster}"

    for attempt in range(1, max_retries + 1):
        try:
            response = requests.get(request_url, headers=HEADERS, timeout=15)
            response.raise_for_status()

            payload = parse_jsonp(response.text)
            if not payload:
                print(f"解析个股资金流向数据失败 (第{attempt}次尝试)")
                if attempt >= max_retries:
                    return None
                time.sleep(retry_delay)
                continue

            rows = payload.get("data", {}).get("diff", [])
            if not rows:
                print(f"未获取到个股资金流向数据 (第{attempt}次尝试)")
                if attempt >= max_retries:
                    return None
                time.sleep(retry_delay)
                continue

            return process_stock_list_data(
                rows, payload.get("data", {}).get("total", 0)
            )

        except Exception as e:
            print(f"获取个股资金流向数据失败: {e} (第{attempt}次尝试)")
            if attempt >= max_retries:
                return None
            time.sleep(retry_delay)
108 |
109 |
def fetch_single_stock_capital_flow(stock_code, max_retries=3, retry_delay=2):
    """Look up one stock's capital-flow row by scanning the ranked list pages.

    Args:
        stock_code: stock code such as "000001".
        max_retries: per-page retry budget.
        retry_delay: seconds between retries.

    Returns:
        dict: success envelope with the matching row, or a failure envelope
        when the stock does not appear in the scanned pages.
    """
    # Scan up to 9 ranked pages of 50 rows each (range(1, 10)).
    for page in range(1, 10):
        # Bug fix: forward max_retries/retry_delay — they were accepted by
        # this function but silently ignored before.
        stock_list = fetch_stock_list_capital_flow(50, page, max_retries, retry_delay)
        if not stock_list:
            break

        # Look for the requested stock on this page
        for stock in stock_list.get("股票列表", []):
            if stock.get("股票代码") == stock_code:
                return {
                    "success": True,
                    "message": f"成功获取股票{stock.get('股票名称')}({stock_code})资金流向数据",
                    "last_updated": datetime.now().isoformat(),
                    "data": stock,
                }

    # TODO: fall back to a direct per-stock endpoint when the code is not in
    # the ranking pages (not implemented yet).
    return {"success": False, "message": f"未找到股票{stock_code}的资金流向数据", "data": {}}
142 |
143 |
def process_stock_list_data(stock_list, total_count):
    """Normalize raw Eastmoney rows into dicts with Chinese field names.

    Args:
        stock_list: Raw "diff" rows returned by the API (list of f-field dicts).
        total_count: Total record count reported by the API.

    Returns:
        dict: {"股票列表": [...], "总数": total_count, "更新时间": now-string}.
    """
    # API f-code -> human-readable (Chinese) result key.
    field_mapping = {
        "f12": "股票代码",
        "f14": "股票名称",
        "f2": "最新价",
        "f3": "涨跌幅",
        "f62": "主力净流入",
        "f184": "主力净占比",
        "f66": "超大单净流入",
        "f69": "超大单净占比",
        "f72": "大单净流入",
        "f75": "大单净占比",
        "f78": "中单净流入",
        "f81": "中单净占比",
        "f84": "小单净流入",
        "f87": "小单净占比",
        "f124": "更新时间",
        "f1": "市场代码",
        "f13": "市场类型",
    }

    # f13 market-type code -> exchange prefix.
    market_map = {
        0: "SZ",  # Shenzhen
        1: "SH",  # Shanghai
        105: "NQ",  # NASDAQ
        106: "NYSE",  # New York Stock Exchange
        107: "AMEX",  # American Stock Exchange
        116: "HK",  # Hong Kong
        156: "LN",  # London
    }

    result = {
        "股票列表": [],
        "总数": total_count,
        "更新时间": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }

    # Convert each row field-by-field
    for stock_item in stock_list:
        stock_data = {}

        for api_field, result_field in field_mapping.items():
            if api_field not in stock_item:
                continue
            value = stock_item.get(api_field)

            if api_field == "f124":  # update timestamp (epoch milliseconds)
                try:
                    timestamp = int(value) / 1000
                    stock_data[result_field] = datetime.fromtimestamp(
                        timestamp
                    ).strftime("%Y-%m-%d %H:%M:%S")
                # Bug fix: was a bare `except:` — narrowed to the conversion
                # errors int()/fromtimestamp() can actually raise.
                except (TypeError, ValueError, OSError, OverflowError):
                    stock_data[result_field] = "-"
            elif api_field in ["f62", "f66", "f72", "f78", "f84"]:  # flow amounts
                stock_data[result_field] = (
                    round(float(value) / 10000, 2) if value else 0
                )  # convert yuan -> 万元 (10k yuan)
            elif api_field in ["f3", "f184", "f69", "f75", "f81", "f87"]:  # percentages
                stock_data[result_field] = (
                    round(float(value), 2) if value else 0
                )  # keep two decimals
            elif api_field == "f13":  # market type -> exchange prefix
                market_code = value
                stock_data[result_field] = market_map.get(
                    market_code, str(market_code)
                )
            else:
                stock_data[result_field] = value

        # Derive the fully-qualified code, e.g. "SZ.000001"
        if "股票代码" in stock_data and "市场类型" in stock_data:
            market_prefix = stock_data["市场类型"]
            stock_data["完整代码"] = f"{market_prefix}.{stock_data['股票代码']}"

        result["股票列表"].append(stock_data)

    return result
237 |
238 |
def get_stock_capital_flow(page_size=50, page_num=1, stock_code=None):
    """Fetch capital-flow data for either a ranked list or a single stock.

    Args:
        page_size: rows per page (list mode), default 50.
        page_num: page number (list mode), default 1.
        stock_code: when given, look up this one stock instead of listing.

    Returns:
        dict: {"success": bool, "message": str, "data": ...} envelope; on
        exception the envelope also carries an "error" traceback string.
    """
    try:
        # Single-stock lookup or ranked-list fetch
        if stock_code:
            result = fetch_single_stock_capital_flow(stock_code)
        else:
            flow_data = fetch_stock_list_capital_flow(page_size, page_num)
            if not flow_data:
                # Fix: dropped the pointless f-prefix (no placeholders, F541).
                return {"success": False, "message": "获取股票资金流向数据失败", "data": {}}

            result = {
                "success": True,
                "message": f"成功获取股票资金流向数据,共{flow_data.get('总数', 0)}条",
                "last_updated": datetime.now().isoformat(),
                "data": flow_data,
            }

        return result
    except Exception as e:
        error_msg = f"获取股票资金流向数据时出错: {str(e)}"
        print(error_msg)
        print(traceback.format_exc())
        return {
            "success": False,
            "message": error_msg,
            "error": traceback.format_exc(),
            "data": {},
        }
279 |
280 |
def main():
    """CLI entry point: fetch capital-flow data and pretty-print it as JSON."""
    import argparse

    parser = argparse.ArgumentParser(description="获取股票资金流向数据")
    parser.add_argument("--code", type=str, help="股票代码,不指定则获取列表")
    parser.add_argument("--page", type=int, default=1, help="页码,默认1")
    parser.add_argument("--size", type=int, default=50, help="每页数量,默认50")
    args = parser.parse_args()

    # Single-stock mode when --code is given, otherwise paged list mode.
    result = (
        get_stock_capital_flow(stock_code=args.code)
        if args.code
        else get_stock_capital_flow(page_size=args.size, page_num=args.page)
    )

    print(json.dumps(result, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
302 |
--------------------------------------------------------------------------------
/src/mcp/server.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import atexit
4 | import json
5 | import logging
6 | import sys
7 | from inspect import Parameter, Signature
8 | from typing import Any, Dict, Optional
9 |
10 | import uvicorn
11 | from starlette.applications import Starlette
12 | from starlette.middleware import Middleware
13 | from starlette.middleware.cors import CORSMiddleware
14 | from starlette.requests import Request
15 | from starlette.routing import Mount, Route
16 |
17 | from mcp.server import FastMCP, Server
18 | from mcp.server.sse import SseServerTransport
19 | from src.logger import logger
20 | from src.tool import BaseTool, Terminate
21 |
22 |
23 | logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stderr)])
24 |
25 |
class MCPServer:
    """MCP Server implementation with tool registration and management.

    Wraps a FastMCP instance and exposes project BaseTool objects as MCP
    tools, synthesizing each tool's docstring and call signature from its
    JSON-schema parameter definition.
    """

    def __init__(self, name: str = "FinGenius"):
        # name: server name advertised to MCP clients.
        self.server = FastMCP(name)
        self.tools: Dict[str, BaseTool] = {}  # tool name -> BaseTool instance

        # Initialize standard tools
        self._initialize_standard_tools()

    def _initialize_standard_tools(self) -> None:
        """Initialize standard tools available in the server."""
        # Only `terminate` is registered by default; callers add more tools
        # to self.tools before run().
        self.tools.update(
            {
                "terminate": Terminate(),
            }
        )

    def register_tool(self, tool: BaseTool, method_name: Optional[str] = None) -> None:
        """Register a tool with parameter validation and documentation.

        Args:
            tool: The BaseTool instance to expose over MCP.
            method_name: Optional override for the registered name; defaults
                to ``tool.name``.
        """
        tool_name = method_name or tool.name
        tool_param = tool.to_param()
        # to_param() yields an OpenAI-style spec; the schema lives under "function".
        tool_function = tool_param["function"]

        # Define the async function to be registered.
        # Note: `tool` and `tool_name` are captured by closure, so each
        # registered method is bound to its own tool instance.
        async def tool_method(**kwargs):
            logger.info(f"Executing {tool_name}: {kwargs}")
            result = await tool.execute(**kwargs)

            logger.info(f"Result of {tool_name}: {result}")

            # Handle different types of results (match original logic):
            # pydantic models and dicts are serialized to JSON strings,
            # anything else is returned as-is.
            if hasattr(result, "model_dump"):
                return json.dumps(result.model_dump())
            elif isinstance(result, dict):
                return json.dumps(result)
            return result

        # Set method metadata so FastMCP sees the tool's real name,
        # documentation and parameter signature instead of `**kwargs`.
        tool_method.__name__ = tool_name
        tool_method.__doc__ = self._build_docstring(tool_function)
        tool_method.__signature__ = self._build_signature(tool_function)

        # Store parameter schema (important for tools that access it programmatically)
        param_props = tool_function.get("parameters", {}).get("properties", {})
        required_params = tool_function.get("parameters", {}).get("required", [])
        tool_method._parameter_schema = {
            param_name: {
                "description": param_details.get("description", ""),
                "type": param_details.get("type", "any"),
                "required": param_name in required_params,
            }
            for param_name, param_details in param_props.items()
        }

        # Register with server
        self.server.tool()(tool_method)
        logger.info(f"Registered tool: {tool_name}")

    def _build_docstring(self, tool_function: dict) -> str:
        """Build a formatted docstring from tool function metadata.

        Returns the tool description followed by a ``Parameters:`` section
        listing each property with its type and required/optional marker.
        """
        description = tool_function.get("description", "")
        param_props = tool_function.get("parameters", {}).get("properties", {})
        required_params = tool_function.get("parameters", {}).get("required", [])

        # Build docstring (match original format)
        docstring = description
        if param_props:
            docstring += "\n\nParameters:\n"
            for param_name, param_details in param_props.items():
                required_str = (
                    "(required)" if param_name in required_params else "(optional)"
                )
                param_type = param_details.get("type", "any")
                param_desc = param_details.get("description", "")
                docstring += (
                    f"    {param_name} ({param_type}) {required_str}: {param_desc}\n"
                )

        return docstring

    def _build_signature(self, tool_function: dict) -> Signature:
        """Build a function signature from tool function metadata.

        Every parameter is keyword-only; required parameters get no default,
        optional ones default to None.
        """
        param_props = tool_function.get("parameters", {}).get("properties", {})
        required_params = tool_function.get("parameters", {}).get("required", [])

        parameters = []

        # Follow original type mapping
        for param_name, param_details in param_props.items():
            param_type = param_details.get("type", "")
            default = Parameter.empty if param_name in required_params else None

            # Map JSON Schema types to Python types (same as original);
            # unknown types fall back to typing.Any.
            annotation = Any
            if param_type == "string":
                annotation = str
            elif param_type == "integer":
                annotation = int
            elif param_type == "number":
                annotation = float
            elif param_type == "boolean":
                annotation = bool
            elif param_type == "object":
                annotation = dict
            elif param_type == "array":
                annotation = list

            # Create parameter with same structure as original
            param = Parameter(
                name=param_name,
                kind=Parameter.KEYWORD_ONLY,
                default=default,
                annotation=annotation,
            )
            parameters.append(param)

        return Signature(parameters=parameters)

    async def cleanup(self) -> None:
        """Clean up server resources."""
        logger.info("Cleaning up resources")
        # Follow original cleanup logic - only clean browser tool
        if "browser" in self.tools and hasattr(self.tools["browser"], "cleanup"):
            await self.tools["browser"].cleanup()

    def register_all_tools(self) -> None:
        """Register all tools with the server."""
        for tool in self.tools.values():
            self.register_tool(tool)

    def run(self, transport: str = "stdio") -> None:
        """Run the MCP server.

        Args:
            transport: FastMCP transport mode, e.g. "stdio" (default).
        """
        # Register all tools
        self.register_all_tools()

        # Register cleanup function (match original behavior)
        # NOTE(review): asyncio.run inside atexit assumes no loop is running
        # at interpreter shutdown — confirm for SSE deployments.
        atexit.register(lambda: asyncio.run(self.cleanup()))

        # Start server (with same logging as original)
        logger.info(f"Starting FinGenius server ({transport} mode)")
        self.server.run(transport=transport)
168 |
169 |
def create_starlette_app(mcp_server: Server, *, debug: bool = False) -> Starlette:
    """Create a Starlette application that can serve the provided mcp server with SSE.

    Routes:
        GET  /sse        -- establishes the SSE connection and runs the MCP loop.
        POST /messages/  -- receives client-to-server MCP messages.
        GET  /health     -- simple liveness probe.

    Args:
        mcp_server: Low-level MCP ``Server`` instance to expose.
        debug: Enable Starlette debug mode (detailed tracebacks in responses).

    Returns:
        The configured ``Starlette`` application.
    """
    from starlette.responses import JSONResponse

    # Set up CORS middleware to allow connections from any origin
    middleware = [
        Middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_methods=["*"],
            allow_headers=["*"],
            allow_credentials=True,
        )
    ]

    # Use '/messages/' as the endpoint path for SSE connections
    sse = SseServerTransport("/messages/")

    async def handle_sse(request: Request) -> None:
        try:
            logger.info(f"SSE connection request received from {request.client}")
            async with sse.connect_sse(
                request.scope,
                request.receive,
                request._send,  # noqa: SLF001
            ) as (read_stream, write_stream):
                await mcp_server.run(
                    read_stream,
                    write_stream,
                    mcp_server.create_initialization_options(),
                )
        except Exception as e:
            logger.error(f"Error handling SSE connection: {str(e)}")
            raise

    async def health_check(request: Request) -> JSONResponse:
        # Simple liveness probe for load balancers / monitoring.
        return JSONResponse(
            {
                "status": "ok",
                "message": "FinGenius server is running",
            }
        )

    # Register all routes declaratively. The ``@app.route()`` decorator used
    # previously is deprecated and has been removed in recent Starlette
    # releases; passing Route objects in ``routes`` is the supported way.
    return Starlette(
        debug=debug,
        middleware=middleware,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
            Route("/health", endpoint=health_check),
        ],
    )
226 |
227 |
def parse_args() -> argparse.Namespace:
    """Build the CLI parser and parse ``sys.argv``.

    Returns:
        Namespace with ``transport`` ("stdio" | "sse", default "stdio"),
        ``host`` (default "0.0.0.0"), ``port`` (default 8000) and ``debug``.
    """
    cli = argparse.ArgumentParser(description="FinGenius MCP Server")
    cli.add_argument(
        "--transport",
        choices=["stdio", "sse"],
        default="stdio",
        help="Communication method: stdio or sse (default: stdio)",
    )
    cli.add_argument("--host", default="0.0.0.0", help="Host to bind to (for sse)")
    cli.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port to listen on (for sse)",
    )
    cli.add_argument("--debug", action="store_true", help="Enable debug mode")
    return cli.parse_args()
243 |
244 |
if __name__ == "__main__":
    args = parse_args()

    if args.transport != "sse":
        # stdio mode: the MCPServer wrapper handles registration and startup.
        MCPServer().run(transport=args.transport)
    else:
        # SSE mode: expose the underlying low-level MCP server via Starlette.
        mcp_server = MCPServer()
        mcp_server.register_all_tools()
        # Reach into FastMCP for the low-level server object that the
        # Starlette SSE transport drives directly.
        underlying_server = mcp_server.server._mcp_server  # noqa: WPS437

        logger.info(f"Starting FinGenius server with SSE on {args.host}:{args.port}")

        starlette_app = create_starlette_app(underlying_server, debug=args.debug)

        uvicorn.run(
            starlette_app,
            host=args.host,
            port=args.port,
            log_level="info",
            access_log=True,
        )
273 |
--------------------------------------------------------------------------------
/src/agent/toolcall.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | from typing import Any, List, Optional, Union
4 |
5 | from pydantic import Field
6 |
7 | from src.agent.react import ReActAgent
8 | from src.exceptions import TokenLimitExceeded
9 | from src.llm import LLM
10 | from src.logger import logger
11 | from src.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT
12 | from src.schema import (
13 | TOOL_CHOICE_TYPE,
14 | AgentState,
15 | Memory,
16 | Message,
17 | ToolCall,
18 | ToolChoice,
19 | )
20 | from src.tool import Terminate, ToolCollection
21 |
22 |
# Error message raised (as ValueError) by act() when tool_choice is REQUIRED
# but the LLM produced no tool calls.
TOOL_CALL_REQUIRED = "Tool calls required but none provided"
24 |
25 |
class ToolCallAgent(ReActAgent):
    """Base agent class for handling tool/function calls with enhanced abstraction.

    The think/act cycle asks the LLM for a response with the available tool
    schemas attached, records any proposed tool calls, executes them, and
    feeds the observations back into memory until a special tool (e.g.
    Terminate) finishes the run or max_steps is reached.
    """

    name: str = "toolcall"
    description: str = "an agent that can execute tool calls."

    system_prompt: str = SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT

    llm: Optional[LLM] = Field(default_factory=LLM)
    memory: Memory = Field(default_factory=Memory)
    state: AgentState = AgentState.IDLE

    available_tools: ToolCollection = ToolCollection(Terminate())
    tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO  # type: ignore
    special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])

    tool_calls: List[ToolCall] = Field(default_factory=list)
    # Base64 image (if any) produced by the most recent tool call; attached
    # to the corresponding tool message in act().
    _current_base64_image: Optional[str] = None

    max_steps: int = 30
    current_step: int = 0

    # If set to an int, each tool observation is truncated to this many
    # characters before being logged/stored.
    max_observe: Optional[Union[int, bool]] = None

    async def think(self) -> bool:
        """Process current state and decide next actions using tools.

        Returns:
            True if the agent should proceed to act(), False otherwise.
        """
        if self.next_step_prompt:
            user_msg = Message.user_message(self.next_step_prompt)
            self.messages += [user_msg]

        try:
            # Get response with tool options
            response = await self.llm.ask_tool(
                messages=self.messages,
                system_msgs=(
                    [Message.system_message(self.system_prompt)]
                    if self.system_prompt
                    else None
                ),
                tools=self.available_tools.to_params(),
                tool_choice=self.tool_choices,
            )
        except ValueError:
            raise
        except Exception as e:
            # Check if this is a RetryError containing TokenLimitExceeded
            if hasattr(e, "__cause__") and isinstance(e.__cause__, TokenLimitExceeded):
                token_limit_error = e.__cause__
                logger.error(
                    f"🚨 Token limit error (from RetryError): {token_limit_error}"
                )
                self.memory.add_message(
                    Message.assistant_message(
                        f"Maximum token limit reached, cannot continue execution: {str(token_limit_error)}"
                    )
                )
                self.state = AgentState.FINISHED
                return False
            raise

        self.tool_calls = tool_calls = (
            response.tool_calls if response and response.tool_calls else []
        )
        content = response.content if response and response.content else ""

        # Log response info
        logger.info(f"✨ {self.name}'s thoughts: {content}")
        logger.info(
            f"🛠️ {self.name} selected {len(tool_calls) if tool_calls else 0} tools to use"
        )

        # Show agent thinking in terminal if content exists
        if content and content.strip():
            from src.console import visualizer
            visualizer.show_agent_thought(self.name, content, "analysis")
        if tool_calls:
            logger.info(
                f"🧰 Tools being prepared: {[call.function.name for call in tool_calls]}"
            )
            # NOTE: only the first call's arguments are logged here.
            logger.info(f"🔧 Tool arguments: {tool_calls[0].function.arguments}")

        try:
            if response is None:
                raise RuntimeError("No response received from the LLM")

            # Handle different tool_choices modes
            if self.tool_choices == ToolChoice.NONE:
                if tool_calls:
                    logger.warning(
                        f"🤔 Hmm, {self.name} tried to use tools when they weren't available!"
                    )
                if content:
                    self.memory.add_message(Message.assistant_message(content))
                    return True
                return False

            # Create and add assistant message
            assistant_msg = (
                Message.from_tool_calls(content=content, tool_calls=self.tool_calls)
                if self.tool_calls
                else Message.assistant_message(content)
            )
            self.memory.add_message(assistant_msg)

            if self.tool_choices == ToolChoice.REQUIRED and not self.tool_calls:
                return True  # Will be handled in act()

            # For 'auto' mode, continue with content if no commands but content exists
            if self.tool_choices == ToolChoice.AUTO and not self.tool_calls:
                return bool(content)

            return bool(self.tool_calls)
        except Exception as e:
            logger.error(f"🚨 Oops! The {self.name}'s thinking process hit a snag: {e}")
            self.memory.add_message(
                Message.assistant_message(
                    f"Error encountered while processing: {str(e)}"
                )
            )
            return False

    async def act(self) -> str:
        """Execute tool calls and handle their results.

        Returns:
            The concatenated observations of all executed tool calls, or the
            last message content when there are no calls to execute.

        Raises:
            ValueError: If tool_choice is REQUIRED but no tool calls exist.
        """
        if not self.tool_calls:
            if self.tool_choices == ToolChoice.REQUIRED:
                raise ValueError(TOOL_CALL_REQUIRED)

            # Return last message content if no tool calls
            return self.messages[-1].content or "No content or commands to execute"

        results = []
        for command in self.tool_calls:
            # Reset base64_image for each tool call
            self._current_base64_image = None

            result = await self.execute_tool(command)

            if self.max_observe:
                result = result[: self.max_observe]

            logger.info(
                f"🎯 Tool '{command.function.name}' completed its mission! Result: {result}"
            )

            # Add tool response to memory
            tool_msg = Message.tool_message(
                content=result,
                tool_call_id=command.id,
                name=command.function.name,
                base64_image=self._current_base64_image,
            )
            self.memory.add_message(tool_msg)
            results.append(result)

        return "\n\n".join(results)

    async def execute_tool(self, command: ToolCall) -> str:
        """Execute a single tool call with robust error handling.

        Args:
            command: The tool call (name + JSON-encoded arguments) to run.

        Returns:
            A human-readable observation string, or an "Error: ..." string on
            failure (invalid command, unknown tool, bad JSON, or tool error).
        """
        if not command or not command.function or not command.function.name:
            return "Error: Invalid command format"

        name = command.function.name
        if name not in self.available_tools.tool_map:
            return f"Error: Unknown tool '{name}'"

        try:
            # Parse arguments
            args = json.loads(command.function.arguments or "{}")

            # Execute the tool
            logger.info(f"🔧 Activating tool: '{name}'...")
            result = await self.available_tools.execute(name=name, tool_input=args)

            # Handle special tools
            await self._handle_special_tool(name=name, result=result)

            # If the result carries a screenshot, stash it so act() can attach
            # it to the tool message. (The previous implementation duplicated
            # the observation formatting below in this branch; a single
            # formatting site is equivalent.)
            if hasattr(result, "base64_image") and result.base64_image:
                self._current_base64_image = result.base64_image

            # Format result for display
            observation = (
                f"Observed output of cmd `{name}` executed:\n{str(result)}"
                if result
                else f"Cmd `{name}` completed with no output"
            )

            return observation
        except json.JSONDecodeError:
            error_msg = f"Error parsing arguments for {name}: Invalid JSON format"
            logger.error(
                f"📝 Oops! The arguments for '{name}' don't make sense - invalid JSON, arguments:{command.function.arguments}"
            )
            return f"Error: {error_msg}"
        except Exception as e:
            error_msg = f"⚠️ Tool '{name}' encountered a problem: {str(e)}"
            logger.exception(error_msg)
            return f"Error: {error_msg}"

    async def _handle_special_tool(self, name: str, result: Any, **kwargs):
        """Handle special tool execution and state changes."""
        if not self._is_special_tool(name):
            return

        if self._should_finish_execution(name=name, result=result, **kwargs):
            # Set agent state to finished
            logger.info(f"🏁 Special tool '{name}' has completed the task!")
            self.state = AgentState.FINISHED

    @staticmethod
    def _should_finish_execution(**kwargs) -> bool:
        """Determine if tool execution should finish the agent."""
        return True

    def _is_special_tool(self, name: str) -> bool:
        """Check if tool name is in special tools list (case-insensitive)."""
        return name.lower() in [n.lower() for n in self.special_tool_names]

    async def cleanup(self):
        """Clean up resources used by the agent's tools."""
        logger.info(f"🧹 Cleaning up resources for agent '{self.name}'...")
        for tool_name, tool_instance in self.available_tools.tool_map.items():
            # Only tools with an async cleanup hook participate.
            if hasattr(tool_instance, "cleanup") and asyncio.iscoroutinefunction(
                tool_instance.cleanup
            ):
                try:
                    logger.debug(f"🧼 Cleaning up tool: {tool_name}")
                    await tool_instance.cleanup()
                except Exception as e:
                    # Best-effort: a failing tool must not abort the others.
                    logger.error(
                        f"🚨 Error cleaning up tool '{tool_name}': {e}", exc_info=True
                    )
        logger.info(f"✨ Cleanup complete for agent '{self.name}'.")

    async def run(self, request: Optional[str] = None) -> str:
        """Run the agent, guaranteeing tool cleanup when done."""
        try:
            return await super().run(request)
        finally:
            await self.cleanup()
276 |
--------------------------------------------------------------------------------