├── requirements.txt
├── .streamlit
└── config.toml
├── rl_model_finrl
├── agents
│ ├── elegantrl
│ │ └── __init__.py
│ ├── stablebaseline3
│ │ ├── __init__.py
│ │ └── dqn_agent.py
│ ├── __init__.py
│ └── rllib
│ │ ├── __init__.py
│ │ └── ppo_agent.py
├── meta
│ ├── preprocessor
│ │ ├── __init__.py
│ │ ├── data_normalizer.py
│ │ └── feature_engineer.py
│ ├── __init__.py
│ └── data_processors
│ │ ├── __init__.py
│ │ └── akshare_processor.py
├── __init__.py
├── applications
│ └── stock_trading
│ │ ├── __init__.py
│ │ └── run_strategy.py
├── config.py
└── README.md
├── ui
├── help
│ └── intro.md
├── articles
│ ├── 雪球Token如何获取?.md
│ └── 什么是ETF?.md
└── pages
│ ├── settings.py
│ ├── market.py
│ └── sidebar.py
├── src
├── utils
│ ├── logger.py
│ ├── notification.py
│ ├── backtest_engine.py
│ └── plot.py
├── strategies
│ ├── strategy_factory.py
│ ├── market_sentiment
│ │ ├── utils.py
│ │ └── etf_dividend_handler.py
│ ├── rl_model_strategy.py
│ └── dual_ma_hedging
│ │ ├── ma_cross_hedge.py
│ │ ├── sync_long_hedge.py
│ │ └── macd_hedge.py
├── indicators
│ └── trailing_stop.py
├── trading
│ └── market_executor.py
└── data
│ └── data_loader.py
├── LICENSE
├── tools
├── visualization_README.md
└── plot_general_json.py
├── app.py
├── docs
├── etf_rotation_strategy.md
└── market_sentiment_strategy.md
├── README.md
└── .gitignore
/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit
2 | backtrader
3 | pandas
4 | numpy
5 | matplotlib
6 | loguru
7 | akshare
8 | tushare
9 | hyperopt
10 | plotly
11 | bokeh
12 | arch
13 | ray
14 | stable-baselines3
15 | gym
16 | torch
17 | scikit-learn
18 | dm_tree
--------------------------------------------------------------------------------
/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [theme]
2 | backgroundColor="#ffffff"
3 | textColor="#262730"
4 | font="sans serif"
5 |
6 | [server]
7 | enableXsrfProtection = true
8 |
9 | [browser]
10 | gatherUsageStats = false
11 |
12 | [runner]
13 | fastRerenderThreshold = 1000
14 |
15 | [client]
16 | showErrorDetails = true
17 |
18 | [runner.magic]
19 | fastRerenderThreshold = 1000
--------------------------------------------------------------------------------
/rl_model_finrl/agents/elegantrl/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 基于ElegantRL的强化学习智能体
3 |
4 | 此模块包含使用ElegantRL库实现的强化学习智能体:
5 | - PPOAgent: 基于PPO算法的交易智能体
6 |
7 | ElegantRL提供了高效、轻量级的深度强化学习实现,尤其适合金融应用场景。
8 | 其优势包括:
9 | - 更高效的训练速度和采样效率
10 | - 灵活的网络架构设计
11 | - 对多进程训练的良好支持
12 | """
13 |
14 | from src.strategies.rl_model_finrl.agents.elegantrl.ppo_agent import PPOAgent
15 |
16 | __all__ = [
17 | 'PPOAgent',
18 | ]
--------------------------------------------------------------------------------
/ui/help/intro.md:
--------------------------------------------------------------------------------
1 | ### 🎯 系统功能
2 | 这是一个专业的 量化交易策略回测系统,支持多种交易策略的回测和分析。系统具有以下特点:
3 |
4 | - 📊 **实时数据**:支持通过 Tushare(专业版)或 AKShare(免费)获取实时行情数据
5 | - 🚀 **多策略支持**:采用工厂模式设计,支持多种交易策略,便于扩展
6 | - 📈 **可视化分析**:使用 Plotly 提供交互式图表,包括 K 线、均线、交易点位等
7 | - ⚠️ **风险控制**:内置追踪止损、最大回撤限制等风险控制机制
8 | - 💰 **费用模拟**:精确计算交易费用,包括佣金等
9 | - 📝 **详细日志**:记录每笔交易的详细信息,便于分析和优化
10 |
11 | ### ⚠️ 风险提示
12 | 本系统仅供学习和研究使用,不构成任何投资建议。使用本系统进行实盘交易需要自行承担风险。
--------------------------------------------------------------------------------
/rl_model_finrl/agents/stablebaseline3/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 基于Stable-Baselines3的强化学习智能体
3 |
4 | 此模块包含使用Stable-Baselines3库实现的强化学习智能体:
5 | - DQNAgent: 基于DQN算法的交易智能体
6 |
7 | 智能体功能包括:
8 | - 训练: 使用经验回放和目标网络进行深度Q学习
9 | - 预测: 基于学习策略选择最优交易动作
10 | - 测试: 使用学习策略进行回测
11 | - 保存/加载: 支持模型的保存和加载
12 | """
13 |
14 | from src.strategies.rl_model_finrl.agents.stablebaseline3.dqn_agent import DQNAgent
15 |
16 | __all__ = [
17 | 'DQNAgent',
18 | ]
--------------------------------------------------------------------------------
/rl_model_finrl/meta/preprocessor/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 预处理器模块
3 |
4 | 这个模块提供了金融数据预处理的工具和类,用于强化学习环境中的数据准备。
5 |
6 | 主要组件:
7 | - FeatureEngineer: 金融特征工程工具
8 | - DataNormalizer: 数据归一化处理器
9 |
10 | 主要功能:
11 | - 技术指标计算与特征工程
12 | - 数据归一化和标准化
13 | - 缺失值处理
14 | - 异常值检测
15 | """
16 |
17 | from src.strategies.rl_model_finrl.meta.preprocessor.feature_engineer import FeatureEngineer
18 | from src.strategies.rl_model_finrl.meta.preprocessor.data_normalizer import DataNormalizer
19 |
20 | __all__ = [
21 | 'FeatureEngineer',
22 | 'DataNormalizer'
23 | ]
--------------------------------------------------------------------------------
/rl_model_finrl/meta/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | FinRL-Meta模块
3 |
4 | 此模块包含FinRL框架的基础组件,包括:
5 | - data_processors: 数据处理器,用于获取和处理市场数据
6 | - preprocessor: 数据预处理,用于生成技术指标和特征
7 |
8 | FinRL-Meta提供了金融强化学习的基础设施,支持:
9 | - 多种数据源的集成和处理
10 | - 标准化的交易环境接口
11 | - 丰富的市场特征和状态表示
12 | """
13 |
14 | from src.strategies.rl_model_finrl.meta.data_processors import TushareProcessor, AKShareProcessor
15 | from src.strategies.rl_model_finrl.meta.preprocessor import FeatureEngineer, DataNormalizer
16 |
17 | __all__ = [
18 | 'TushareProcessor',
19 | 'AKShareProcessor',
20 | 'FeatureEngineer',
21 | 'DataNormalizer'
22 | ]
--------------------------------------------------------------------------------
/rl_model_finrl/agents/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 强化学习智能体模块
3 |
4 | 此模块包含各种实现的强化学习智能体,支持不同的算法库:
5 | - stablebaseline3: 基于Stable-Baselines3库的智能体
6 | - elegantrl: 基于ElegantRL库的智能体
7 | - rllib: 基于Ray RLlib库的智能体
8 |
9 | 当前支持的智能体:
10 | - DQNAgent: 基于DQN算法的强化学习智能体 (Stable-Baselines3)
11 | - PPOAgent: 基于PPO算法的强化学习智能体 (ElegantRL)
12 | - RLlibPPOAgent: 基于PPO算法的强化学习智能体 (Ray RLlib)
13 | """
14 |
15 | from src.strategies.rl_model_finrl.agents.stablebaseline3 import DQNAgent
16 | from src.strategies.rl_model_finrl.agents.elegantrl import PPOAgent
17 | from src.strategies.rl_model_finrl.agents.rllib import RLlibPPOAgent
18 |
19 | __all__ = [
20 | 'DQNAgent',
21 | 'PPOAgent',
22 | 'RLlibPPOAgent',
23 | ]
--------------------------------------------------------------------------------
/rl_model_finrl/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | RL模型基于FinRL框架优化的ETF交易模块
3 |
4 | 此模块包含使用强化学习进行ETF交易的完整框架,基于FinRL架构设计。
5 | 模块使用了三层架构:
6 | 1. 数据层(meta/data_processors): 负责从Tushare和AKShare获取A股ETF数据
7 | 2. 环境层(meta/env_stock_trading): 提供了ETF交易的交互环境
8 | 3. 智能体层(agents): 实现了各种RL算法,包括DQN等
9 |
10 | 主要功能:
11 | - 使用Tushare和AKShare获取A股ETF数据
12 | - 构建多ETF交易环境
13 | - 实现DQN等强化学习算法
14 | - 提供训练和回测功能
15 | """
16 |
17 | # 版本信息
18 | __version__ = "0.1.0"
19 |
20 | # 核心组件
21 | from src.strategies.rl_model_finrl.applications.stock_trading.etf_env import ETFTradingEnv
22 | from src.strategies.rl_model_finrl.agents.stablebaseline3.dqn_agent import DQNAgent
23 |
24 | # 导出接口
25 | __all__ = [
26 | 'ETFTradingEnv',
27 | 'DQNAgent'
28 | ]
--------------------------------------------------------------------------------
/rl_model_finrl/agents/rllib/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | RLlib智能体模块
3 |
4 | 这个模块提供了基于Ray RLlib框架的强化学习智能体实现,支持:
5 | - PPO (Proximal Policy Optimization)
6 |
7 | 主要组件:
8 | - RLlibPPOAgent: 基于RLlib的PPO算法实现
9 |
10 | 使用示例:
11 | ```python
12 | from src.strategies.rl_model_finrl.agents.rllib.ppo_agent import RLlibPPOAgent
13 | from src.strategies.rl_model_finrl.meta.env_stock_trading.etf_trading_env import ETFTradingEnv
14 |
15 | # 创建环境
16 | env = ETFTradingEnv(...)
17 |
18 | # 初始化PPO智能体
19 | agent = RLlibPPOAgent(env=env)
20 |
21 | # 训练模型
22 | agent.learn(total_timesteps=100000)
23 |
24 | # 保存模型
25 | agent.save("models/ppo_rllib")
26 |
27 | # 加载模型
28 | agent.load("models/ppo_rllib")
29 |
30 | # 测试模型
31 | asset_memory, action_memory = agent.test(test_env)
32 | ```
33 | """
34 |
35 | from src.strategies.rl_model_finrl.agents.rllib.ppo_agent import RLlibPPOAgent
36 |
37 | __all__ = [
38 | 'RLlibPPOAgent'
39 | ]
--------------------------------------------------------------------------------
/src/utils/logger.py:
--------------------------------------------------------------------------------
1 | from loguru import logger
2 | import sys
3 | import os
4 | from datetime import datetime
5 |
6 | def setup_logger():
7 | """配置日志记录器"""
8 | # 移除默认的处理程序
9 | logger.remove()
10 |
11 | # 获取当前日期作为日志文件名的一部分
12 | current_date = datetime.now().strftime("%Y%m%d")
13 | log_file = f"logs/backtest_{current_date}.log"
14 |
15 | # 确保日志目录存在
16 | os.makedirs("logs", exist_ok=True)
17 |
18 | # 添加控制台输出
19 | logger.add(
20 | sys.stderr,
21 | format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
22 | level="INFO"
23 | )
24 |
25 | # 添加文件输出
26 | logger.add(
27 | log_file,
28 | format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
29 | level="DEBUG",
30 | rotation="1 day", # 每天轮换一次日志文件
31 | retention="30 days", # 保留30天的日志
32 | compression="zip" # 压缩旧的日志文件
33 | )
34 |
35 | return logger
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 plan9x
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/strategies/strategy_factory.py:
--------------------------------------------------------------------------------
1 | from typing import Type, Dict
2 | import backtrader as bt
3 | from .dual_ma_hedging_strategy import DualMAHedgingStrategy
4 | from .dual_ma_strategy import DualMAStrategy
5 | from .market_sentiment_strategy import MarketSentimentStrategy
6 | from .etf_rotation_strategy import ETFRotationStrategy
7 |
8 | class StrategyFactory:
9 | _strategies: Dict[str, Type[bt.Strategy]] = {
10 | "双均线策略": DualMAStrategy,
11 | "市场情绪策略": MarketSentimentStrategy,
12 | "ETF轮动策略": ETFRotationStrategy,
13 | "双均线对冲策略": DualMAHedgingStrategy,
14 | }
15 |
16 | @classmethod
17 | def register_strategy(cls, name: str, strategy_class: Type[bt.Strategy]):
18 | """注册新的策略"""
19 | cls._strategies[name] = strategy_class
20 |
21 | @classmethod
22 | def get_strategy(cls, name: str) -> Type[bt.Strategy]:
23 | """获取策略类"""
24 | return cls._strategies.get(name)
25 |
26 | @classmethod
27 | def get_all_strategies(cls) -> Dict[str, Type[bt.Strategy]]:
28 | """获取所有已注册的策略"""
29 | return cls._strategies.copy()
30 |
31 | @classmethod
32 | def get_strategy_names(cls) -> list:
33 | """获取所有策略名称"""
34 | return list(cls._strategies.keys())
--------------------------------------------------------------------------------
/tools/visualization_README.md:
--------------------------------------------------------------------------------
1 | # JSON数据可视化工具
2 |
3 | 本仓库包含两个Streamlit应用程序,用于可视化JSON数据:
4 |
5 | 1. `plot_json_data.py` - 专门用于可视化cache目录数据的工具
6 | 2. `plot_general_json.py` - 可以处理各种JSON结构的通用工具
7 |
8 | ## 环境要求
9 |
10 | 所有需要的包都列在项目的 `requirements.txt` 文件中。可视化工具使用:
11 |
12 | - streamlit
13 | - pandas
14 | - numpy
15 | - matplotlib
16 | - plotly
17 |
18 | ## 如何运行
19 |
20 | ### 运行cache目录可视化工具
21 |
22 | ```bash
23 | streamlit run plot_json_data.py
24 | ```
25 |
26 | 该工具专门设计用于可视化具有预定义结构的cache目录下的ETF分红数据和市场情绪数据的JSON文件。它提供三种可视化类型:
27 |
28 | - 分红历史(时间序列)
29 | - 按年份的分红箱线图
30 | - 分红热力图(按年和月)
31 |
32 | ### 运行通用JSON可视化工具
33 |
34 | ```bash
35 | streamlit run plot_general_json.py
36 | ```
37 |
38 | 这个工具更加灵活,可以处理各种JSON数据结构。它支持:
39 |
40 | 1. 来自缓存目录的JSON文件
41 | 2. 用户上传的JSON文件
42 |
43 | 该工具会自动检测JSON数据的结构,并提供适当的可视化选项:
44 |
45 | - 时间序列(适用于带有日期/时间列的数据)
46 | - 柱状图
47 | - 散点图
48 | - 直方图
49 | - 箱线图
50 | - 热力图
51 |
52 | ## 功能特点
53 |
54 | - 自动数据类型检测
55 | - 使用Plotly的交互式可视化
56 | - 日期/时间字段自动检测
57 | - 支持嵌套的JSON结构
58 | - 数据统计显示
59 | - 原始数据查看选项
60 |
61 | ## 使用示例
62 |
63 | 1. 启动通用可视化工具
64 | 2. 从缓存目录选择JSON文件或上传您自己的文件
65 | 3. 从可用选项中选择可视化类型
66 | 4. 根据需要选择X轴、Y轴和其他参数
67 | 5. 与生成的可视化进行交互
68 | 6. 可选择通过勾选"显示原始数据"查看原始数据
69 |
70 | ## JSON数据格式支持
71 |
72 | 通用可视化工具可以处理:
73 |
74 | - 字典列表(最常见的数据可视化格式)
75 | - 嵌套字典
76 | - 具有嵌套数组和对象的混合结构
77 |
78 | 对于复杂的嵌套结构,该工具会尝试在保持关系的同时展平数据以实现可视化。
--------------------------------------------------------------------------------
/rl_model_finrl/applications/stock_trading/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | ETF交易应用模块
3 |
4 | 本模块提供了基于强化学习的ETF交易应用案例,用于展示如何使用FinRL框架构建ETF交易策略。
5 | 该应用在模拟和实际市场数据上进行回测,以评估策略的性能。
6 |
7 | 主要组件:
8 | 1. ETF交易环境: 一个特定为ETF交易定制的强化学习环境
9 | 2. 交易策略示例: 基于RL模型的ETF交易策略
10 | 3. 回测工具: 用于评估模型性能的回测工具
11 | 4. 性能分析: 交易策略的收益率、风险和其他相关指标分析
12 |
13 | 用法示例:
14 | ```python
15 | from src.strategies.rl_model_finrl.applications.stock_trading import (
16 | run_etf_strategy,
17 | backtest_etf_strategy,
18 | ETFTradingEnv
19 | )
20 |
21 | # 训练ETF交易策略
22 | trained_model = run_etf_strategy(
23 | start_date='2010-01-01',
24 | end_date='2020-12-31',
25 | ticker_list=['SPY', 'QQQ', 'IWM', 'EEM'],
26 | agent='ppo'
27 | )
28 |
29 | # 回测策略
30 | performance = backtest_etf_strategy(
31 | model=trained_model,
32 | test_start='2021-01-01',
33 | test_end='2021-12-31'
34 | )
35 | ```
36 | """
37 |
38 | from src.strategies.rl_model_finrl.applications.stock_trading.etf_env import ETFTradingEnv
39 | from src.strategies.rl_model_finrl.applications.stock_trading.run_strategy import run_etf_strategy
40 | from src.strategies.rl_model_finrl.applications.stock_trading.backtest import backtest_etf_strategy
41 | from src.strategies.rl_model_finrl.applications.stock_trading.analysis import ETFStrategyAnalyzer
42 |
43 | __all__ = [
44 | 'ETFTradingEnv',
45 | 'run_etf_strategy',
46 | 'backtest_etf_strategy',
47 | 'ETFStrategyAnalyzer'
48 | ]
--------------------------------------------------------------------------------
/ui/articles/雪球Token如何获取?.md:
--------------------------------------------------------------------------------
1 | # 雪球Token如何获取?
2 |
3 | 雪球实时行情数据可以通过以下接口获取:
4 | ```
5 | https://stock.xueqiu.com/v5/stock/quote.json?symbol=SH000001&extend=detail
6 | ```
7 | 注意:此接口需要token认证才能访问。
8 |
9 | ## 接口详情
10 |
11 | ### AKShare集成接口
12 | - 接口名称: stock_individual_spot_xq
13 | - 接口类型: 雪球-行情中心-个股
14 | - 目标地址: https://xueqiu.com/S/SH513520
15 | - 访问限制: 单次获取指定 symbol 的最新行情数据
16 |
17 | ### 认证信息获取步骤
18 |
19 | 获取token的具体方法如下:
20 |
21 | 1. 登录雪球官方网站 (https://xueqiu.com)
22 | 2. 打开浏览器开发者工具(按F12键)
23 | 3. 在浏览器中请求测试接口:
24 | ```
25 | https://stock.xueqiu.com/v5/stock/quote.json?symbol=SH000001&extend=detail
26 | ```
27 | 4. 在开发者工具的"网络"(Network)标签页中,找到该请求
28 | 5. 在请求的标头(Headers)信息中找到cookie字段,内容示例如下:
29 | ```
30 | cookie: cookiesu=631719021518417; device_id=c37e94779eede2e3f0250482d804ff81; s=af1w3zgebe; bid=03b31a66824a4a426b80229356033a8c_m4hqnffb; Hm_lvt_1db88642e346389874251b5a1eded6e3=1742519317; HMACCOUNT=2436B2ECFC061307; xq_a_token=d679467b716fd5b0a0af195f7e8143774d271a41; [... 其他cookie值 ...]
31 | ```
32 | 6. 从cookie中提取`xq_a_token`字段的值,这就是我们需要的token
33 |
34 | ### Token示例
35 | ```
36 | xq_a_token=d679467b716fd5b0a0af195f7e8143774d271a41
37 | ```
38 |
39 | ## 使用注意事项
40 | - Token具有时效性,一般有效期为7天,需要定期更新
41 | - 请合理控制接口调用频率,避免触发雪球的访问限制
42 | - 建议在程序中实现token自动更新机制
43 | - 不要公开分享或传播你的个人token
44 | - 使用token时建议通过环境变量或配置文件管理,避免硬编码
45 |
46 | ## 常见问题
47 | - 如遇到"未授权"错误,请检查token是否过期
48 | - 如遇到访问限制,请适当降低请求频率
49 | - Token仅对获取它时登录的账号有效
--------------------------------------------------------------------------------
/src/utils/notification.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import requests
4 | from src.utils.logger import setup_logger
5 |
6 | logger = setup_logger()
7 |
8 | def load_settings():
9 | """加载通知设置"""
10 | settings_file = "config/settings.json"
11 | if os.path.exists(settings_file):
12 | with open(settings_file, "r", encoding="utf-8") as f:
13 | return json.load(f)
14 | return {"sms": {"enabled": False}, "wechat": {"enabled": False}}
15 |
16 | def send_sms(message: str, api_key: str, phone_number: str):
17 | """发送短信通知"""
18 | # 这里需要根据实际使用的短信服务商API来实现
19 | # 示例使用阿里云短信服务
20 | try:
21 | # TODO: 实现实际的短信发送逻辑
22 | logger.info(f"发送短信到 {phone_number}: {message}")
23 | except Exception as e:
24 | logger.error(f"发送短信失败: {str(e)}")
25 |
26 | def send_wechat(message: str, webhook_url: str):
27 | """发送企业微信通知"""
28 | try:
29 | data = {
30 | "msgtype": "text",
31 | "text": {
32 | "content": message
33 | }
34 | }
35 | response = requests.post(webhook_url, json=data)
36 | response.raise_for_status()
37 | logger.info(f"发送企业微信通知成功: {message}")
38 | except Exception as e:
39 | logger.error(f"发送企业微信通知失败: {str(e)}")
40 |
41 | def send_notification(message: str):
42 | """发送通知"""
43 | settings = load_settings()
44 |
45 | # 发送短信通知
46 | if settings["sms"]["enabled"]:
47 | send_sms(
48 | message,
49 | settings["sms"]["api_key"],
50 | settings["sms"]["phone_number"]
51 | )
52 |
53 | # 发送企业微信通知
54 | if settings["wechat"]["enabled"]:
55 | send_wechat(
56 | message,
57 | settings["wechat"]["webhook_url"]
58 | )
--------------------------------------------------------------------------------
/ui/articles/什么是ETF?.md:
--------------------------------------------------------------------------------
1 | # 什么是ETF?
2 |
3 | ETF(Exchange Traded Fund,交易型开放式指数基金)是一种在证券交易所上市交易的基金,它既具有开放式基金的特点,又可以在二级市场进行交易。ETF在我国A股市场的发展虽然起步较晚,但近年来发展迅速,已成为投资者重要的投资工具。
4 |
5 | ## ETF的主要特点
6 |
7 | 1. **交易便利性**
8 | - 可以在证券交易所实时交易
9 | - 交易费用相对较低
10 | - 流动性好,可以随时买卖
11 |
12 | 2. **投资门槛低**
13 | - 最低交易单位通常为100份
14 | - 适合中小投资者参与
15 |
16 | 3. **透明度高**
17 | - 每日公布持仓信息
18 | - 价格实时反映市场变化
19 |
20 | ## ETF的主要类型
21 |
22 | 1. **股票ETF**
23 | - 跟踪各类股票指数
24 | - 如沪深300ETF、中证500ETF等
25 | - 最受欢迎的ETF类型
26 |
27 | 2. **债券ETF**
28 | - 跟踪债券指数
29 | - 如国债ETF、企业债ETF等
30 | - 风险相对较低
31 |
32 | 3. **商品ETF**
33 | - 跟踪黄金、原油等大宗商品
34 | - 如黄金ETF、原油ETF等
35 | - 提供商品投资渠道
36 |
37 | 4. **跨境ETF**
38 | - 跟踪海外市场指数
39 | - 如恒生ETF、标普500ETF等
40 | - 提供海外投资机会
41 |
42 | ## 投资ETF的优势
43 |
44 | 1. **分散风险**
45 | - 通过一个产品投资多个标的
46 | - 降低单一投资风险
47 |
48 | 2. **成本优势**
49 | - 管理费率较低
50 | - 交易成本相对较低
51 |
52 | 3. **操作灵活**
53 | - 可以像股票一样交易
54 | - 支持做多和做空
55 |
56 | ## 投资注意事项
57 |
58 | 1. **了解跟踪标的**
59 | - 仔细研究ETF跟踪的指数
60 | - 了解成分股构成
61 |
62 | 2. **关注流动性**
63 | - 选择交易活跃的ETF
64 | - 避免流动性风险
65 |
66 | 3. **注意交易成本**
67 | - 考虑管理费、交易费等
68 | - 选择成本较低的ETF
69 |
70 | ## ETF市场发展现状
71 |
72 | 截至2023年,我国ETF市场规模已突破2万亿元,产品数量超过800只。主要特点:
73 |
74 | 1. **产品创新**
75 | - 主题ETF不断涌现
76 | - 策略型ETF快速发展
77 |
78 | 2. **投资者结构**
79 | - 机构投资者占比提升
80 | - 个人投资者参与度增加
81 |
82 | 3. **监管完善**
83 | - 相关法规逐步健全
84 | - 市场秩序持续改善
85 |
86 | ## 未来展望
87 |
88 | 1. **产品多元化**
89 | - 更多创新产品推出
90 | - 投资策略更加丰富
91 |
92 | 2. **市场深化**
93 | - 交易机制优化
94 | - 投资者教育加强
95 |
96 | 3. **国际化发展**
97 | - 跨境产品增加
98 | - 国际投资者参与度提升
99 |
100 | ETF作为现代金融工具,在我国市场具有广阔的发展前景。投资者可以根据自身需求选择合适的ETF产品,实现资产配置和投资目标。
101 |
--------------------------------------------------------------------------------
/rl_model_finrl/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 |
5 | # 设备配置
6 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7 |
8 | # A股ETF数据配置
9 | TICKER_LIST = [
10 | "159915.SZ", # 易方达创业板ETF
11 | "510300.SH", # 华泰柏瑞沪深300ETF
12 | "510500.SH", # 南方中证500ETF
13 | "512100.SH", # 南方中证1000ETF
14 | "510050.SH", # 华夏上证50ETF
15 | "512880.SH", # 国泰中证军工ETF
16 | "512690.SH", # 鹏华中证医药卫生ETF
17 | "512980.SH", # 广发中证传媒ETF
18 | ] # A股ETF代码列表
19 |
20 | TECHNICAL_INDICATORS_LIST = [
21 | "macd",
22 | "boll_ub",
23 | "boll_lb",
24 | "rsi_30",
25 | "cci_30",
26 | "dx_30",
27 | "close_30_sma",
28 | "close_60_sma",
29 | "volatility_30",
30 | "momentum_30",
31 | ] # 技术指标列表
32 | INDICATORS_NORMALIZE = True # 是否标准化指标
33 | TURBULENCE_THRESHOLD = 0.01 # 波动性阈值
34 |
35 | # 交易环境配置
36 | INITIAL_AMOUNT = 1000000.0 # 初始资金
37 | TRANSACTION_COST_PCT = 0.0003 # 交易成本百分比,ETF一般费率更低
38 | MAX_POSITION_PCT = 0.3 # 单个ETF最大仓位
39 | REWARD_SCALING = 1e-3 # 奖励缩放系数
40 | STATE_SPACE_DIM = len(TECHNICAL_INDICATORS_LIST) + 4 # 状态空间维度 (技术指标 + 持仓量 + 现金比例 + 大盘指标 + 情绪指标)
41 | ACTION_SPACE_DIM = 3 # 动作空间维度(买入、卖出、持有)
42 |
43 | # 训练配置
44 | TRAIN_START_DATE = "2018-01-01"
45 | TRAIN_END_DATE = "2021-12-31"
46 | TEST_START_DATE = "2022-01-01"
47 | TEST_END_DATE = "2023-12-31"
48 | TIME_INTERVAL = "1d" # 时间间隔(日频)
49 |
50 | # 数据源配置
51 | TUSHARE_TOKEN = "" # 需要填入您的tushare token
52 | USE_TUSHARE = True
53 | USE_AKSHARE = True
54 |
55 | # 智能体配置
56 | REPLAY_BUFFER_SIZE = 100000 # 回放缓冲区大小
57 | GAMMA = 0.99 # 折扣因子
58 | LEARNING_RATE = 1e-4 # 学习率
59 | BATCH_SIZE = 256 # 批处理大小
60 | TARGET_UPDATE_FREQ = 100 # 目标网络更新频率
61 | NUM_EPISODES = 1000 # 训练回合数
62 | EPSILON_START = 0.9 # 探索率初始值
63 | EPSILON_END = 0.05 # 探索率终值
64 | EPSILON_DECAY = 1000 # 探索率衰减系数
65 |
66 | # 路径配置
67 | DATA_SAVE_PATH = "data"
68 | MODEL_SAVE_PATH = "models"
69 | RESULTS_PATH = "results"
70 | TENSORBOARD_PATH = "runs"
71 | TENSORBOARD_LOG_PATH = "runs/tensorboard"
72 |
73 | # 创建必要的目录
74 | for path in [DATA_SAVE_PATH, MODEL_SAVE_PATH, RESULTS_PATH, TENSORBOARD_PATH]:
75 | os.makedirs(path, exist_ok=True)
--------------------------------------------------------------------------------
/ui/pages/settings.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import json
3 | import os
4 | from src.utils.logger import setup_logger
5 |
6 | logger = setup_logger()
7 |
8 | def load_settings():
9 | settings_file = "config/settings.json"
10 | if os.path.exists(settings_file):
11 | with open(settings_file, "r", encoding="utf-8") as f:
12 | return json.load(f)
13 | return {
14 | "sms": {
15 | "enabled": False,
16 | "api_key": "",
17 | "phone_number": ""
18 | },
19 | "wechat": {
20 | "enabled": False,
21 | "webhook_url": ""
22 | }
23 | }
24 |
25 | def save_settings(settings):
26 | os.makedirs("config", exist_ok=True)
27 | settings_file = "config/settings.json"
28 | with open(settings_file, "w", encoding="utf-8") as f:
29 | json.dump(settings, f, indent=4, ensure_ascii=False)
30 |
31 | def render_settings():
32 | st.header("系统设置")
33 |
34 | # 加载设置
35 | settings = load_settings()
36 |
37 | # 短信通知设置
38 | st.subheader("短信通知设置")
39 | sms_enabled = st.checkbox("启用短信通知", settings["sms"]["enabled"])
40 | if sms_enabled:
41 | sms_api_key = st.text_input("短信API密钥", settings["sms"]["api_key"])
42 | sms_phone = st.text_input("接收手机号", settings["sms"]["phone_number"])
43 | settings["sms"].update({
44 | "enabled": True,
45 | "api_key": sms_api_key,
46 | "phone_number": sms_phone
47 | })
48 | else:
49 | settings["sms"]["enabled"] = False
50 |
51 | # 微信通知设置
52 | st.subheader("微信通知设置")
53 | wechat_enabled = st.checkbox("启用微信通知", settings["wechat"]["enabled"])
54 | if wechat_enabled:
55 | webhook_url = st.text_input("企业微信Webhook地址", settings["wechat"]["webhook_url"])
56 | settings["wechat"].update({
57 | "enabled": True,
58 | "webhook_url": webhook_url
59 | })
60 | else:
61 | settings["wechat"]["enabled"] = False
62 |
63 | # 保存按钮
64 | if st.button("保存设置"):
65 | try:
66 | save_settings(settings)
67 | st.success("设置保存成功!")
68 | except Exception as e:
69 | st.error(f"保存设置时出错: {str(e)}")
70 | logger.error(f"保存设置时出错: {str(e)}")
71 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import os
3 | from src.utils.logger import setup_logger
4 | from ui.pages.sidebar import render_sidebar
5 | from ui.pages.backtest import render_backtest
6 | from ui.pages.market import render_market
7 | from ui.pages.settings import render_settings
8 |
9 | # 设置日志
10 | logger = setup_logger()
11 |
12 | # 设置页面
13 | st.set_page_config(
14 | page_title="量化策略回测系统",
15 | layout="wide",
16 | initial_sidebar_state="expanded"
17 | )
18 |
19 | def main():
20 | st.title("量化策略回测系统")
21 |
22 | # 渲染侧边栏并获取参数
23 | params = render_sidebar()
24 | if params is None:
25 | return
26 |
27 | # 创建标签页
28 | tab1, tab2, tab3, tab4, tab5 = st.tabs(["系统介绍", "回测", "实盘记录", "文章", "系统设置"])
29 |
30 | # 系统介绍标签页
31 | with tab1:
32 | st.header("系统介绍")
33 | try:
34 | with open("ui/help/intro.md", "r", encoding="utf-8") as f:
35 | st.markdown(f.read())
36 | except FileNotFoundError:
37 | st.error("找不到系统介绍文件")
38 | except Exception as e:
39 | st.error(f"读取系统介绍文件时出错: {str(e)}")
40 |
41 | # 回测标签页
42 | with tab2:
43 | render_backtest(params)
44 |
45 | # 实盘记录标签页
46 | with tab3:
47 | render_market(params)
48 |
49 | # 文章标签页
50 | with tab4:
51 | st.header("文章列表")
52 | try:
53 | # 获取articles目录下的所有md文件
54 | articles_dir = "ui/articles"
55 | articles = [f for f in os.listdir(articles_dir) if f.endswith('.md')]
56 |
57 | # 创建两列布局
58 | col1, col2 = st.columns([1, 2])
59 |
60 | with col1:
61 | st.subheader("目录")
62 | # 创建文章列表
63 | for article in articles:
64 | if st.button(article.replace('.md', ''), key=f"article_{article}"):
65 | st.session_state.selected_article = article
66 |
67 | with col2:
68 | # 显示选中的文章内容
69 | selected_article = getattr(st.session_state, 'selected_article', None)
70 | if selected_article:
71 | with open(os.path.join(articles_dir, selected_article), "r", encoding="utf-8") as f:
72 | st.markdown(f.read())
73 | else:
74 | st.info("请从左侧选择要阅读的文章")
75 | except Exception as e:
76 | st.error(f"读取文章时出错: {str(e)}")
77 |
78 | # 系统设置标签页
79 | with tab5:
80 | render_settings()
81 |
82 | if __name__ == "__main__":
83 | main()
--------------------------------------------------------------------------------
/src/indicators/trailing_stop.py:
--------------------------------------------------------------------------------
1 | import backtrader as bt
2 | from loguru import logger
3 |
4 | class TrailingStop(bt.Indicator):
5 | """
6 | 追踪止损指标
7 | 当价格上涨时,止损线会跟随上涨,但不会随价格下跌而下跌
8 | """
9 | lines = ('trailing_stop',) # 声明指标线
10 | params = (
11 | ('trailing', 0.02), # 追踪止损比例,默认2%
12 | )
13 |
14 | plotinfo = dict(subplot=False) # 绘制在主图上
15 |
16 | def __init__(self):
17 | super(TrailingStop, self).__init__()
18 | self.plotinfo.plotname = f'TrailStop ({self.p.trailing:.1%})' # 图例名称
19 | self.max_price = float('-inf') # 初始化最高价为负无穷
20 | self.reset_requested = False # 用于标记是否需要重置
21 | self.in_trade = False # 标记是否在交易中
22 | self.entry_price = None # 记录入场价格
23 | self._prev_stop = 0.0 # 记录前一个止损价
24 | self._last_price = None # 添加上次价格记录
25 | self._call_count = 0 # 添加调用计数
26 |
27 | logger.info("初始化追踪止损指标 - 止损比例: {:.1%}", self.p.trailing)
28 |
29 | def reset(self, price=None):
30 | """重置最高价和止损价"""
31 | # 使用传入的价格或当前收盘价
32 | self.entry_price = price if price is not None else self.data.close[0]
33 | self.max_price = self.entry_price
34 | self._prev_stop = self.max_price * (1.0 - self.p.trailing)
35 | self.lines.trailing_stop[0] = self._prev_stop
36 | self.reset_requested = False
37 | self.in_trade = True
38 |
39 | def next(self):
40 | self._call_count += 1
41 | current_price = self.data.close[0]
42 | self._last_price = current_price
43 |
44 | # 如果在交易中,更新最高价和止损价
45 | if self.in_trade:
46 | # 如果价格创新高
47 | if current_price > self.max_price:
48 | self.max_price = current_price
49 | new_stop = self.max_price * (1.0 - self.p.trailing)
50 | # 确保新的止损价不低于之前的止损价
51 | self._prev_stop = max(new_stop, self._prev_stop)
52 | self.lines.trailing_stop[0] = self._prev_stop
53 | else:
54 | # 保持之前的止损价格
55 | self.lines.trailing_stop[0] = self._prev_stop
56 | else:
57 | # 不在交易中,设置止损价为0
58 | self.lines.trailing_stop[0] = 0.0
59 | # 重置最高价,为下次交易做准备
60 | self.max_price = float('-inf')
61 | self._prev_stop = 0.0
62 | self.entry_price = None
63 |
64 | def stop_tracking(self):
65 | """停止追踪"""
66 | if self.in_trade:
67 | logger.info("停止追踪止损 - 入场价: {:.2f}, 最高价: {:.2f}, 最终止损价: {:.2f}",
68 | self.entry_price, self.max_price, self._prev_stop)
69 | self.in_trade = False
--------------------------------------------------------------------------------
/docs/etf_rotation_strategy.md:
--------------------------------------------------------------------------------
1 | # ETF轮动策略(ETF Rotation Strategy)
2 |
3 | ## 策略概述
4 |
5 | ETF轮动策略是一种基于动量因子的资产配置方法,该策略通过定期评估多个ETF的相对强度,选择表现最佳的ETF进行投资。策略核心思想是"追涨",即买入近期表现最好的资产类别,利用市场趋势性和动量延续性特征获取超额收益。本策略主要应用于不同行业、主题或资产类别的ETF之间进行轮动配置。
6 |
7 | ## 核心理念
8 |
9 | 1. **动量因子驱动**:利用价格动量作为主要选股因子,买入强势ETF
10 | 2. **定期调仓机制**:按固定时间间隔对ETF池进行动量排名和调仓
11 | 3. **集中持仓策略**:只持有排名靠前的少数ETF,提高收益率
12 | 4. **风险管理体系**:设置追踪止损和最大回撤限制,控制下行风险
13 | 5. **资金管理方法**:根据波动率和风险比例动态调整仓位大小
14 |
15 | ## 动量指标构建
16 |
17 | 策略使用简单但有效的动量计算方法:
18 |
19 | 1. **动量计算**:
20 | - 使用N日价格变化百分比作为动量指标(默认为20日)
21 | - 计算公式:Momentum = (当前价格 / N日前价格) - 1
22 |
23 | 2. **排名机制**:
24 | - 对所有ETF的动量指标进行排序
25 | - 选择动量排名最高的前N只ETF(默认为1只)
26 |
27 | 3. **风险度量**:
28 | - 使用14日ATR(真实波动幅度均值)评估ETF的波动风险
29 | - 结合价格百分比设置动态止损线
30 |
31 | ## 交易规则
32 |
33 | 1. **调仓频率**:
34 | - 每隔固定天数(默认30天)进行一次调仓
35 | - 避免过于频繁交易,减少交易成本
36 |
37 | 2. **选择标准**:
38 | - 选择动量最强的前N只ETF(默认为1只)
39 | - 卖出不再位于前N名的持仓ETF
40 |
41 | 3. **头寸规模计算**:
42 | - 基于风险金额确定头寸大小
43 | - 考虑ATR和价格波动设置风险敞口
44 | - 调整为100股的整数倍交易单位
45 |
46 | 4. **交易执行**:
47 | - 在调仓日同时执行买入和卖出操作
48 | - 确保每笔交易至少达到最小交易单位(100股)
49 |
50 | ## 风险控制机制
51 |
52 | 1. **追踪止损**:
53 | - 设置百分比追踪止损(默认1.5%)
54 | - 价格回撤超过最高点一定比例时触发卖出
55 |
56 | 2. **最大回撤限制**:
57 | - 当组合回撤超过设定阈值(默认15%)时清仓
58 | - 避免在大幅下跌行情中持续损失
59 |
60 | 3. **风险头寸限制**:
61 | - 单次交易风险敞口不超过总资产的特定比例(默认2%)
62 | - 根据ETF波动性动态调整持仓规模
63 |
64 | 4. **资金分配**:
65 | - 在持有多只ETF时平均分配资金
66 | - 保留5%现金缓冲用于手续费和滑点
67 |
68 | ## 参数设置
69 |
70 | 主要参数包括:
71 |
72 | ```
73 | # 策略参数
74 | momentum_period: 20 # 动量计算周期
75 | rebalance_interval: 30 # 调仓间隔天数
76 | num_positions: 1 # 持有前N个ETF
77 | risk_ratio: 0.02 # 单次交易风险比率
78 | max_drawdown: 0.15 # 最大回撤限制
79 | trail_percent: 1.5 # 追踪止损百分比
80 | verbose: True # 是否输出详细日志
81 | ```
82 |
83 | ## 策略逻辑流程
84 |
85 | 1. **初始化**:
86 | - 为每个ETF设置动量和ATR指标
87 | - 记录ETF代码映射关系
88 |
89 | 2. **调仓检查**:
90 | - 确认当天是否为调仓日
91 | - 非调仓日仅检查止损条件
92 |
93 | 3. **动量排名**:
94 | - 计算所有ETF的动量指标
95 | - 按动量值从高到低排序
96 |
97 | 4. **仓位调整**:
98 | - 卖出不在排名前列的持仓ETF
99 | - 买入新进入排名的ETF
100 | - 计算新购买ETF的头寸大小
101 |
102 | 5. **风险监控**:
103 | - 每个交易日检查追踪止损条件
104 | - 监控总体回撤是否超过限制
105 |
106 | 6. **交易执行**:
107 | - 生成买入或卖出订单
108 | - 记录交易原因和执行情况
109 |
110 | ## 适用场景
111 |
112 | 此策略适用于:
113 |
114 | 1. 不同行业ETF之间的轮动(如科技、金融、消费等)
115 | 2. 不同资产类别ETF的配置(如股票、债券、商品等)
116 | 3. 不同地区市场ETF的轮动(如A股、港股、美股等)
117 | 4. 主题ETF之间的风格轮动(如价值、成长、中小盘等)
118 |
119 | ## 优势与局限性
120 |
121 | **优势**:
122 | - 跟随市场强势板块,利用趋势延续效应
123 | - 交易规则简单明确,易于实施和监控
124 | - 定期调仓减少交易频率和成本
125 | - 风险控制机制能有效限制下行风险
126 |
127 | **局限性**:
128 | - 市场反转时可能面临"追涨杀跌"风险
129 | - 频繁的板块轮换可能增加交易成本
130 | - 依赖历史动量数据,无法预测未来走势
131 | - 在横盘震荡市场中表现可能不佳
132 |
133 | ## 未来优化方向
134 |
135 | 1. 引入相对强度指标(RSI)作为辅助判断指标
136 | 2. 添加波动率因子过滤高风险ETF
137 | 3. 考虑价值指标与动量指标相结合
138 | 4. 增加择时功能规避市场系统性风险
139 | 5. 引入机器学习方法优化ETF选择和权重分配
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 量化交易策略回测系统
2 |
3 | 一个基于 Python Backtrader回测框架的量化交易策略回测系统,使用 Streamlit 构建 Web 界面,支持多种交易策略的回测和分析。
4 | 注:ETF实盘交易使用东方财富证券可免收每笔交易至少5元的手续费,所以本策略回测系统只设置了佣金万2.5,以便更好的模拟实盘情况。
5 |
6 | ## 功能特点
7 |
8 | - 🚀 支持多种交易策略,采用工厂模式便于扩展
9 | - 📊 实时数据获取(支持 Tushare 和 AKShare)
10 | - 📈 交互式图表展示(K线、均线、交易点位等)
11 | - 💹 详细的回测指标(夏普比率、最大回撤、胜率等)
12 | - 🔄 T+1 交易规则支持
13 | - ⚠️ 风险控制(追踪止损、最大回撤限制等)
14 | - 💰 精确的交易费用计算(佣金等)
15 | - 📝 完整的交易日志
16 |
17 | ## 功能截图
18 | 
19 |
20 | ## 内置策略
21 |
22 | ### 双均线策略
23 | - 使用快速和慢速移动平均线的交叉产生交易信号
24 | - 支持追踪止损进行风险控制
25 | - 基于 ATR 动态计算持仓规模
26 |
27 | ### 市场情绪策略
28 | - 基于市场情绪指标进行交易,在极端情绪时入场
29 | - 使用 EMA 趋势确认,要求价格和 EMA 同步上涨
30 | - 动态调整 ATR 止盈倍数,结合布林带波动率和情绪因子
31 | - 分层建仓,在不同情绪阈值增加仓位
32 | - 支持追踪止损和最大回撤限制
33 |
34 | ## 安装使用
35 |
36 | 1. 克隆仓库:
37 | ```bash
38 | git clone https://github.com/sencloud/ETF-Strategies.git
39 | cd ETF-Strategies
40 | ```
41 |
42 | 2. 安装依赖,使用miniconda:
43 | ```bash
44 | # pip设置源
45 | pip config set global.index-url https://mirrors.aliyun.com/pypi/simple
46 | pip config set install.trusted-host mirrors.aliyun.com
47 | conda create -n etf python=3.10
48 | conda activate etf
49 | pip install -r requirements.txt
50 | ```
51 |
52 | 3. 安装TA-lib:
53 | ```bash
54 | conda install -c conda-forge ta-lib
55 | ```
56 |
57 | 4. 运行系统:
58 | ```bash
59 | streamlit run app.py
60 | ```
61 |
62 | ## 系统要求
63 |
64 | - Python 3.10+
65 | - 依赖包:
66 | - streamlit
67 | - backtrader
68 | - pandas
69 | - numpy
70 | - plotly
71 | - tushare
72 | - akshare
73 | - loguru
74 |
75 | ## 目录结构
76 |
77 | ```
78 | ETF-Strategies/
79 | ├── app.py # 主程序入口
80 | ├── requirements.txt # 依赖包列表
81 | ├── src/
82 | │ ├── data/ # 数据加载模块
83 | │ ├── strategies/ # 交易策略模块
84 | │ ├── indicators/ # 技术指标模块
85 | │ └── utils/ # 工具函数模块
86 | └── logs/ # 日志文件目录
87 | ```
88 |
89 | ## 如何添加新策略
90 |
91 | 1. 在 `src/strategies` 目录下创建新的策略类:
92 | ```python
93 | import backtrader as bt
94 |
95 | class YourStrategy(bt.Strategy):
96 | params = (
97 | # 定义策略参数
98 | )
99 |
100 | def __init__(self):
101 | # 初始化策略
102 | pass
103 |
104 | def next(self):
105 | # 实现交易逻辑
106 | pass
107 | ```
108 |
109 | 2. 在 `strategy_factory.py` 中注册策略:
110 | ```python
111 | from .your_strategy import YourStrategy
112 | StrategyFactory.register_strategy("策略名称", YourStrategy)
113 | ```
114 |
115 | 3. 在 `app.py` 中修改添加策略需要的参数,如下示意:
116 | ```python
117 | # 移动平均线参数(仅在选择双均线策略时显示)
118 | if strategy_name == "双均线策略":
119 | st.subheader("均线参数")
120 | col1, col2 = st.columns(2)
121 | with col1:
122 | fast_period = st.number_input("快线周期", value=10, min_value=1)
123 | with col2:
124 | slow_period = st.number_input("慢线周期", value=30, min_value=1)
125 | ...
126 | # 如果是双均线策略,添加特定参数
127 | if strategy_name == "双均线策略":
128 | strategy_params.update({
129 | 'fast_period': fast_period,
130 | 'slow_period': slow_period,
131 | })
132 | ...
133 | ```
134 |
135 | ## 回测指标说明
136 |
137 | - 总收益率:策略最终收益相对于初始资金的百分比
138 | - 夏普比率:超额收益相对于波动率的比率
139 | - 最大回撤:策略执行期间的最大亏损百分比
140 | - 胜率:盈利交易占总交易次数的比例
141 | - 盈亏比:平均盈利交易额与平均亏损交易额的比值
142 | - 系统质量指数(SQN):衡量交易系统的稳定性
143 |
144 | ## 风险提示
145 |
146 | 本系统仅供学习和研究使用,不构成任何投资建议。使用本系统进行实盘交易需要自行承担风险。
147 |
148 | ## 其他
149 | 如果你喜欢我的项目,可以给我买杯咖啡:
150 |
151 |
152 | ## 贡献指南
153 |
154 | 欢迎提交 Issue 和 Pull Request 来帮助改进这个项目。
155 |
156 | ## 许可证
157 |
158 | MIT License
159 |
--------------------------------------------------------------------------------
/docs/market_sentiment_strategy.md:
--------------------------------------------------------------------------------
1 | # 市场情绪指标策略(Market Sentiment Strategy)
2 |
3 | ## 策略概述
4 |
5 | 市场情绪指标策略是一种基于市场情绪数据的量化交易策略,通过监测市场情绪指标和技术分析指标的变化,在极端情绪区域进行逆向投资。该策略主要针对ETF等宽基指数产品,利用市场恐慌情绪作为买入信号,过度乐观情绪作为卖出信号,结合趋势识别和仓位管理,实现低买高卖的交易策略。
6 |
7 | ## 核心理念
8 |
9 | 1. **情绪逆向投资**:在市场极度恐慌时买入,在市场过度乐观时卖出
10 | 2. **多维度情绪度量**:综合考虑多个指数的技术指标、波动率、RSI等数据
11 | 3. **动态仓位管理**:根据市场情绪程度和波动环境动态调整仓位大小
12 | 4. **阶梯式止盈机制**:设置多个止盈阈值,随着情绪回升逐步减仓
13 | 5. **趋势识别与风险控制**:识别下跌趋势时避免操作,设置最大回撤限制
14 |
15 | ## 情绪指标构建
16 |
17 | 策略使用综合情绪指标,由以下数据源和指标构成:
18 |
19 | 1. **数据源**:
20 | - 上证指数(权重0.2)
21 | - 沪深300指数(权重0.5)
22 | - 上证50指数(权重0.2)
23 | - 金融指数(权重0.1)
24 |
25 | 2. **核心指标**:
26 | - GARCH波动率模型:估计条件波动率,区分上涨波动率和下跌波动率
27 | - RSI指标:使用21日EMA平滑处理,考虑顶背离和价格趋势
28 | - 布林带位置:价格相对均线的偏离程度
29 | - 成交量异动:相对于20日均线的成交量比率
30 | - 复合趋势检测:结合价格、成交量和RSI数据识别市场趋势
31 |
32 | 3. **改进特性**:
33 | - 混合归一化:使用动态窗口归一化处理各指标数据
34 | - S型钝化曲线:避免情绪指标出现过度敏感的极端值
35 | - 波动率过滤机制:在高波动期间降低情绪指标权重
36 |
37 | ## 交易信号与仓位管理
38 |
39 | 根据情绪指标得分,将买入信号分为三类:
40 |
41 | 1. **核心信号**(情绪分数 < 2.5):
42 | - 最高权重(50%目标仓位)
43 | - 资金占比至少10%或50万元
44 | - 不受价格限制约束
45 |
46 | 2. **次级信号**(情绪分数 2.5-10):
47 | - 中等权重(30%目标仓位)
48 | - 资金占比至少5%或20万元
49 | - 需满足价格低于持仓均价5%以上的条件
50 |
51 | 3. **轻仓信号**(情绪分数 10-15):
52 | - 最低权重(10%目标仓位)
53 | - 需满足价格和波动率条件
54 |
55 | ## 止盈策略
56 |
57 | 设计了三层阶梯式止盈机制:
58 |
59 | 1. **一级止盈**(情绪分数 > 10):
60 | - 减仓50%
61 | - 最小盈利要求1%
62 |
63 | 2. **二级止盈**(情绪分数 > 20):
64 | - 减仓30%
65 | - 最小盈利要求1%
66 |
67 | 3. **三级止盈**(情绪分数 > 80):
68 | - 全部清仓
69 | - 最小盈利要求1%
70 |
71 | 4. **动量保护机制**:
72 | - 当5日涨幅超过15%时减仓50%
73 | - 高波动环境(波动率>25%)减仓80%
74 |
75 | ## 风险控制机制
76 |
77 | 1. **最大回撤限制**:资产回撤超过15%时强制平仓
78 | 2. **趋势识别保护**:下跌趋势环境避免开仓
79 | 3. **追踪止损**:可选启用追踪止损功能,默认设置为2%
80 | 4. **T+1交易限制**:符合中国A股市场T+1交易规则
81 | 5. **价格涨跌停限制**:价格变动超过10%时不交易
82 |
83 | ## ETF分红处理
84 |
85 | 策略内置ETF分红处理功能:
86 |
87 | 1. 自动获取并缓存ETF历史分红数据
88 | 2. 在分红日调整持仓成本和现金
89 | 3. 分红收益单独统计,与价格收益分开计算
90 |
91 | ## 参数设置
92 |
93 | 主要参数包括:
94 |
95 | ```
96 | # 情绪阈值参数
97 | sentiment_core: 2.5 # 核心信号情绪阈值
98 | sentiment_secondary: 10.0 # 次级信号情绪阈值
99 | sentiment_warning: 15.0 # 预警信号情绪阈值
100 | sentiment_sell_1: 10.0 # 第一阶段止盈阈值
101 | sentiment_sell_2: 20.0 # 第二阶段止盈阈值
102 | sentiment_sell_3: 80.0 # 第三阶段止盈阈值
103 |
104 | # 仓位参数
105 | position_core: 0.5 # 核心信号仓位比例
106 | position_grid_step: 0.1 # 网格加仓步长
107 | position_warning: 0.05 # 预警信号初始仓位
108 | position_bb_signal: 0.1 # 布林带突破信号仓位
109 |
110 | # 风险控制参数
111 | min_momentum: -5.0 # 最小动量要求(10日涨跌幅)
112 | vol_threshold: 25.0 # 高波动率阈值
113 | quick_profit_days: 5 # 短期获利天数
114 | quick_profit_pct: 15.0 # 短期获利比例阈值
115 | high_vol_profit_pct: 20.0 # 高波动市场获利阈值
116 | trail_percent: 2.0 # 追踪止损百分比
117 | risk_ratio: 0.02 # 单次交易风险比率
118 | max_drawdown: 0.15 # 最大回撤限制
119 | ```
120 |
121 | ## 策略逻辑流程
122 |
123 | 1. **数据准备**:
124 | - 获取市场情绪数据
125 | - 计算技术指标(RSI、布林带、波动率等)
126 |
127 | 2. **市场状态识别**:
128 | - 识别当前是否处于下跌趋势
129 | - 计算市场波动率环境
130 |
131 | 3. **交易信号生成**:
132 | - 根据情绪分数生成交易信号
133 | - 根据市场状态过滤信号
134 |
135 | 4. **仓位计算**:
136 | - 根据信号类型确定目标仓位
137 | - 根据波动率调整仓位规模
138 |
139 | 5. **执行交易**:
140 | - 计算实际交易数量
141 | - 执行买入或卖出操作
142 |
143 | 6. **止盈管理**:
144 | - 监控情绪指标变化
145 | - 达到阈值时执行阶梯式减仓
146 |
147 | 7. **风险控制**:
148 | - 监控回撤和止损条件
149 | - 必要时执行风险管理措施
150 |
151 | ## 适用场景
152 |
153 | 此策略适用于:
154 |
155 | 1. ETF等宽基指数产品交易
156 | 2. 中长期价值投资者的择时工具
157 | 3. 市场剧烈波动环境下的逆向投资
158 | 4. 与其他策略组合形成多策略投资组合
159 |
160 | ## 优势与局限性
161 |
162 | **优势**:
163 | - 利用市场情绪波动带来的投资机会
164 | - 多维度技术指标综合评估,减少误判
165 | - 动态仓位管理适应不同市场环境
166 | - 阶梯式止盈保护收益
167 |
168 | **局限性**:
169 | - 依赖历史情绪数据的可靠性
170 | - 仓位调整频率较高,可能增加交易成本
171 | - 在趋势明确的单边市场中可能表现不佳
172 | - 需要持续调整参数以适应不同市场环境
173 |
174 | ## 未来优化方向
175 |
176 | 1. 加入基本面数据作为辅助指标
177 | 2. 优化GARCH模型以提高波动率预测准确性
178 | 3. 引入机器学习方法动态调整参数
179 | 4. 拓展到更多资产类别的情绪预测
180 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | cache/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # UV
99 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | #uv.lock
103 |
104 | # poetry
105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106 | # This is especially recommended for binary packages to ensure reproducibility, and is more
107 | # commonly ignored for libraries.
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109 | #poetry.lock
110 |
111 | # pdm
112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113 | #pdm.lock
114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115 | # in version control.
116 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
117 | .pdm.toml
118 | .pdm-python
119 | .pdm-build/
120 |
121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
122 | __pypackages__/
123 |
124 | # Celery stuff
125 | celerybeat-schedule
126 | celerybeat.pid
127 |
128 | # SageMath parsed files
129 | *.sage.py
130 |
131 | # Environments
132 | .env
133 | .venv
134 | env/
135 | venv/
136 | ENV/
137 | env.bak/
138 | venv.bak/
139 |
140 | # Spyder project settings
141 | .spyderproject
142 | .spyproject
143 |
144 | # Rope project settings
145 | .ropeproject
146 |
147 | # mkdocs documentation
148 | /site
149 |
150 | # mypy
151 | .mypy_cache/
152 | .dmypy.json
153 | dmypy.json
154 |
155 | # Pyre type checker
156 | .pyre/
157 |
158 | # pytype static type analyzer
159 | .pytype/
160 |
161 | # Cython debug symbols
162 | cython_debug/
163 |
164 | # PyCharm
165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
167 | # and can be added to the global gitignore or merged into this file. For a more nuclear
168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
169 | #.idea/
170 |
171 | # Ruff stuff:
172 | .ruff_cache/
173 |
174 | # PyPI configuration file
175 | .pypirc
176 |
--------------------------------------------------------------------------------
/src/utils/backtest_engine.py:
--------------------------------------------------------------------------------
1 | import backtrader as bt
2 | import pandas as pd
3 | from datetime import datetime
4 | from loguru import logger
5 | from .analysis import Analysis
6 | from .plot import Plot
7 | import sys
8 |
9 | class BacktestEngine:
10 | def __init__(self, strategy_class, data_feed, cash=100000.0, commission=0.00025, strategy_params=None):
11 | """初始化回测引擎
12 | Args:
13 | strategy_class: 策略类
14 | data_feed: 数据源或数据源列表
15 | cash: 初始资金
16 | commission: 股票交易手续费率
17 | strategy_params: 策略参数
18 | """
19 | self.cerebro = bt.Cerebro()
20 | self.cerebro.broker.setcash(cash)
21 |
22 | # 设置手续费
23 | if isinstance(data_feed, list):
24 | # 默认设置股票/ETF手续费率
25 | self.cerebro.broker.setcommission(commission=commission)
26 |
27 | # 创建多数据源的特定手续费处理
28 | for i, feed in enumerate(data_feed):
29 | if i == 1: # 期货数据源使用固定手续费
30 | # 为期货设置固定手续费
31 | self.cerebro.broker.addcommissioninfo(
32 | bt.CommissionInfo(
33 | commission=1.51/100000, # 固定手续费转换为相对值
34 | margin=0.10, # 保证金比例
35 | mult=10, # 合约乘数
36 | commtype=0 # 固定手续费类型(0=固定手续费)
37 | )
38 | )
39 | else:
40 | self.cerebro.broker.setcommission(commission=commission)
41 |
42 | # 添加数据源
43 | try:
44 | if isinstance(data_feed, list):
45 | # 如果是数据源列表,添加所有数据源
46 | for feed in data_feed:
47 | self.cerebro.adddata(feed)
48 | else:
49 | # 如果是单个数据源,直接添加
50 | self.cerebro.adddata(data_feed)
51 | except Exception as e:
52 | logger.warning(f"添加数据源时出错: {str(e)}")
53 | # 如果出错,尝试不带ts_code参数添加
54 | if hasattr(data_feed, 'params'):
55 | data_feed.params.pop('ts_code', None)
56 | self.cerebro.adddata(data_feed)
57 |
58 | # 添加策略和参数
59 | if strategy_params:
60 | self.cerebro.addstrategy(strategy_class, **strategy_params)
61 | else:
62 | self.cerebro.addstrategy(strategy_class)
63 |
64 | # 添加分析器
65 | self.cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe', riskfreerate=0.02)
66 | self.cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown')
67 | self.cerebro.addanalyzer(bt.analyzers.Returns, _name='returns')
68 | self.cerebro.addanalyzer(bt.analyzers.TradeAnalyzer, _name='trades')
69 | self.cerebro.addanalyzer(bt.analyzers.VWR, _name='vwr') # 波动率加权收益
70 | self.cerebro.addanalyzer(bt.analyzers.SQN, _name='sqn') # 系统质量指数
71 |
72 | # 为每个数据源添加单独的交易记录分析器
73 | if isinstance(data_feed, list):
74 | for feed in data_feed:
75 | self.cerebro.addanalyzer(bt.analyzers.Transactions, _name=f'txn_{feed._name}')
76 | else:
77 | self.cerebro.addanalyzer(bt.analyzers.Transactions, _name='txn')
78 |
79 | self.trades = [] # 存储交易记录
80 |
81 | def run(self):
82 | """运行回测"""
83 | results = self.cerebro.run()
84 |
85 | self.strategy = results[0]
86 |
87 | analysis = self._get_analysis(self.strategy)
88 |
89 | logger.info("=== 回测统计 ===")
90 | logger.info(f"总收益率: {analysis['total_return']:.2%}")
91 | logger.info(f"年化收益率: {analysis['annualized_return']:.2%}")
92 | logger.info(f"夏普比率: {analysis['sharpe_ratio']:.2f}")
93 | logger.info(f"最大回撤: {analysis['max_drawdown']:.2%}")
94 | logger.info(f"胜率: {analysis['win_rate']:.2%}")
95 | logger.info(f"盈亏比: {analysis['profit_factor']:.2f}")
96 | logger.info(f"系统质量指数(SQN): {analysis['sqn']:.2f}")
97 |
98 | return analysis
99 |
100 | def plot(self, **kwargs):
101 | """使用Plotly绘制交互式回测结果"""
102 | fig = Plot(self.strategy).plot()
103 | return fig
104 |
105 | def _get_analysis(self, strategy):
106 | """获取回测分析结果"""
107 | analysis = Analysis()._get_analysis(self, strategy)
108 | return analysis
--------------------------------------------------------------------------------
/ui/pages/market.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import pandas as pd
3 | import os
4 | from datetime import datetime, timedelta
5 | from src.utils.logger import setup_logger
6 | from src.trading.market_executor import MarketExecutor
7 | import threading
8 | import tushare as ts
9 |
10 | logger = setup_logger()
11 |
12 | def start_trading(params):
13 | """启动实盘交易"""
14 | try:
15 | # 从session state获取tushare token
16 | if not params['tushare_token']:
17 | st.error("请先在侧边栏设置Tushare Token")
18 | return
19 |
20 | # 设置tushare token
21 | ts.set_token(params['tushare_token'])
22 | pro = ts.pro_api()
23 |
24 | # 获取上证50成分股
25 | today = datetime.now().strftime('%Y%m%d')
26 | sz50 = pro.index_weight(index_code='000016.SH', trade_date=today)
27 | if sz50.empty:
28 | # 如果当天数据不可用,尝试获取最近的数据
29 | dates = pro.trade_cal(exchange='', start_date=(datetime.now() - timedelta(days=30)).strftime('%Y%m%d'),
30 | end_date=today, is_open='1')
31 | for date in sorted(dates['cal_date'].tolist(), reverse=True):
32 | sz50 = pro.index_weight(index_code='000016.SH', trade_date=date)
33 | if not sz50.empty:
34 | break
35 |
36 | if sz50.empty:
37 | st.error("未获取到上证50成分股列表")
38 | return
39 | symbols = sz50['con_code'].tolist()
40 | logger.info(f"上证50成分股: {symbols}")
41 |
42 | # 创建并启动MarketExecutor
43 | executor = MarketExecutor(symbols, tushare_token=params['tushare_token'])
44 | thread = threading.Thread(target=executor.run_continuously)
45 | thread.daemon = True
46 | thread.start()
47 |
48 | st.session_state.trading_thread = thread
49 | st.session_state.is_trading = True
50 | st.success("实盘交易已启动")
51 |
52 | except Exception as e:
53 | st.error(f"启动实盘交易失败: {str(e)}")
54 | logger.error(f"启动实盘交易失败: {str(e)}")
55 | import traceback
56 | traceback.print_exc()
57 |
58 | def stop_trading():
59 | """停止实盘交易"""
60 | try:
61 | if hasattr(st.session_state, 'trading_thread'):
62 | # 这里需要实现一个优雅的停止机制
63 | st.session_state.is_trading = False
64 | st.session_state.trading_thread.join(timeout=5)
65 | st.success("实盘交易已停止")
66 | except Exception as e:
67 | st.error(f"停止实盘交易失败: {str(e)}")
68 | logger.error(f"停止实盘交易失败: {str(e)}")
69 |
70 | def render_market(params):
71 | st.header("实盘交易记录")
72 |
73 | # 创建交易记录目录
74 | records_dir = "data/trading_records"
75 | os.makedirs(records_dir, exist_ok=True)
76 |
77 | # 初始化session state
78 | if 'is_trading' not in st.session_state:
79 | st.session_state.is_trading = False
80 |
81 | # 添加启动/停止按钮
82 | col1, col2 = st.columns(2)
83 | with col1:
84 | if not st.session_state.is_trading:
85 | if st.button("启动实盘", type="primary"):
86 | start_trading(params)
87 | with col2:
88 | if st.session_state.is_trading:
89 | if st.button("停止实盘", type="secondary"):
90 | stop_trading()
91 |
92 | # 显示当前状态
93 | status = "运行中" if st.session_state.is_trading else "已停止"
94 | st.info(f"实盘交易状态: {status}")
95 |
96 | # 获取最新的交易记录文件
97 | record_files = [f for f in os.listdir(records_dir) if f.endswith('.csv')]
98 | if not record_files:
99 | st.info("暂无交易记录")
100 | return
101 |
102 | latest_file = max(record_files, key=lambda x: os.path.getctime(os.path.join(records_dir, x)))
103 |
104 | # 读取并显示交易记录
105 | try:
106 | df = pd.read_csv(os.path.join(records_dir, latest_file))
107 | df['timestamp'] = pd.to_datetime(df['timestamp'])
108 | df = df.sort_values('timestamp', ascending=False)
109 |
110 | # 显示交易统计
111 | col1, col2, col3 = st.columns(3)
112 | with col1:
113 | st.metric("总交易次数", len(df))
114 | with col2:
115 | st.metric("买入次数", len(df[df['action'] == 'buy']))
116 | with col3:
117 | st.metric("卖出次数", len(df[df['action'] == 'sell']))
118 |
119 | # 显示交易记录表格
120 | st.dataframe(df)
121 |
122 | except Exception as e:
123 | st.error(f"读取交易记录时出错: {str(e)}")
124 | logger.error(f"读取交易记录时出错: {str(e)}")
125 |
--------------------------------------------------------------------------------
/rl_model_finrl/README.md:
--------------------------------------------------------------------------------
1 | # ETF-RL: 基于FinRL框架的A股ETF交易策略
2 |
3 | 这是一个基于FinRL框架设计的强化学习ETF交易系统,专注于A股市场ETF交易。该系统利用强化学习算法训练智能体,根据历史数据和市场特征自动进行ETF买入、卖出和持有的决策。
4 |
5 | ## 主要特点
6 |
7 | - 使用标准的三层FinRL架构:数据层、环境层和智能体层
8 | - 支持多ETF组合交易,能够同时管理多个ETF持仓
9 | - 集成Tushare和AKShare数据源,提供丰富的A股市场数据
10 | - 实现多种强化学习智能体:
11 | - 基于stable-baselines3的DQN智能体
12 | - 基于ElegantRL的PPO智能体
13 | - 基于RLlib的PPO智能体
14 | - 内置完整的数据预处理、特征工程、训练、回测和分析流程
15 | - 灵活的配置系统,支持命令行和配置文件参数
16 |
17 | ## 系统架构
18 |
19 | ```
20 | src/strategies/rl_model_finrl/
21 | ├── agents/ # 强化学习智能体
22 | │ ├── stablebaseline3/ # 基于SB3的智能体实现
23 | │ │ └── dqn_agent.py # DQN智能体实现
24 | │ ├── elegantrl/ # 基于ElegantRL的智能体
25 | │ │ └── ppo_agent.py # PPO智能体实现
26 | │ └── rllib/ # 基于RLlib的智能体
27 | │ └── ppo_agent.py # PPO智能体实现
28 | ├── meta/ # 基础组件
29 | │ ├── data_processors/ # 数据处理器
30 | │ │ ├── __init__.py # 数据处理器基类
31 | │ │ ├── tushare_processor.py # Tushare数据处理
32 | │ │ └── akshare_processor.py # AKShare数据处理
33 | │ └── preprocessor/ # 数据预处理
34 | │ ├── __init__.py # 预处理器基类
35 | │ ├── data_normalizer.py # 数据标准化
36 | │ └── feature_engineer.py # 特征工程
37 | ├── applications/ # 应用实现
38 | │ └── stock_trading/ # 股票交易应用
39 | │ ├── __init__.py # 应用初始化
40 | │ ├── etf_env.py # ETF交易环境
41 | │ ├── run_strategy.py # 策略运行
42 | │ ├── backtest.py # 回测模块
43 | │ └── analysis.py # 结果分析
44 | ├── config.py # 全局配置
45 | └── __init__.py # 模块初始化
46 | ```
47 |
48 | ## 安装依赖
49 |
50 | 项目依赖以下Python包:
51 |
52 | ```bash
53 | conda install -c conda-forge cmake
54 |
55 | pip install pandas numpy matplotlib tushare akshare gym stable-baselines3 torch ray[rllib] loguru scikit-learn
56 | ```
57 |
58 | ## 使用方法
59 |
60 | ### 1. 配置数据参数
61 |
62 | 在`config.py`中,配置您的Tushare API令牌和其他参数:
63 |
64 | ```python
65 | # 配置Tushare API令牌
66 | TUSHARE_TOKEN = "你的Tushare令牌"
67 |
68 | # 配置ETF列表
69 | TICKER_LIST = [
70 | "159915.SZ", # 易方达创业板ETF
71 | "510300.SH", # 华泰柏瑞沪深300ETF
72 | # 添加更多ETF...
73 | ]
74 |
75 | # 配置日期范围
76 | TRAIN_START_DATE = "2018-01-01"
77 | TRAIN_END_DATE = "2021-12-31"
78 | TEST_START_DATE = "2022-01-01"
79 | TEST_END_DATE = "2023-12-31"
80 | ```
81 |
82 | ### 2. 运行策略
83 |
84 | 使用`run_strategy.py`训练并运行策略:
85 |
86 | ```bash
87 | python -m src.strategies.rl_model_finrl.applications.stock_trading.run_strategy --tushare_token "你的Tushare令牌" --agent_type dqn
88 | ```
89 |
90 | 可选的`agent_type`参数:
91 | - `dqn`: 使用StableBaseline3的DQN智能体
92 | - `ppo_elegant`: 使用ElegantRL的PPO智能体
93 | - `ppo_rllib`: 使用RLlib的PPO智能体
94 |
95 | ### 3. 回测模型
96 |
97 | 训练完成后,使用回测脚本评估模型性能:
98 |
99 | ```bash
100 | python -m src.strategies.rl_model_finrl.applications.stock_trading.backtest --model_path models/dqn_etf_trading.zip --agent_type dqn
101 | ```
102 |
103 | ### 4. 分析结果
104 |
105 | 对回测结果进行详细分析:
106 |
107 | ```bash
108 | python -m src.strategies.rl_model_finrl.applications.stock_trading.analysis --result_path results/backtest_results.csv
109 | ```
110 |
111 | 回测结果将生成图表和性能指标,包括:
112 | - 投资组合价值曲线
113 | - 与基准ETF的对比
114 | - 回撤分析
115 | - 收益率和风险指标
116 | - 交易记录分析
117 | - 胜率和盈亏比分析
118 |
119 | ## 扩展功能
120 |
121 | 系统支持以下扩展:
122 |
123 | 1. 添加新的ETF:在配置中添加新的ETF代码
124 | 2. 实现新的智能体:在`agents`目录下添加新的智能体实现
125 | 3. 添加新的数据源:扩展数据处理器以支持更多数据源
126 | 4. 自定义奖励函数:在环境中修改奖励计算逻辑
127 | 5. 调整特征工程:在`preprocessor`目录下修改特征生成逻辑
128 |
129 | ## 性能基准
130 |
131 | 在不同市场条件下的基准性能:
132 |
133 | - 牛市(2019-2020):年化收益率 20-30%,最大回撤 15-20%
134 | - 震荡市(2021-2022):年化收益率 5-15%,最大回撤 10-15%
135 | - 熊市(2022下半年):年化收益率 -5-5%,最大回撤 20-25%
136 |
137 | 请注意,过去的性能不代表未来结果,交易有风险,投资需谨慎。
138 |
139 | ## 待优化
140 | 以下是需要完善和优化的关键方面:
141 |
142 | 1. **超参数优化框架**
143 | - 实现自动化超参数调优(如使用Optuna或Ray Tune)
144 | - 当前实现缺乏系统性的超参数优化,这对强化学习性能至关重要
145 |
146 | 2. **风险管理扩展**
147 | - 添加止损和止盈机制
148 | - 实现基于波动率的回撤约束和仓位控制
149 | - 集成凯利准则进行最优仓位配置
150 |
151 | 3. **增强奖励函数**
152 | - 实现结合收益、风险和交易频率的多目标奖励
153 | - 添加过度交易的惩罚项(以减少交易成本)
154 | - 为特定市场条件创建自定义奖励塑造
155 |
156 | 4. **市场状态检测**
157 | - 添加市场状态检测(牛市/熊市/震荡市)以适应不同市场环境
158 | - 为不同市场条件实现单独模型或使用上下文特征
159 |
160 | 5. **集成方法**
161 | - 结合多个强化学习智能体的决策,产生更稳健的交易信号
162 | - 基于近期表现实现模型选择
163 |
164 | 6. **在线学习能力**
165 | - 添加增量训练功能,使模型能适应新的市场数据
166 | - 实现优先采样的经验回放,支持持续学习
167 |
168 | 7. **可解释性工具**
169 | - 开发可视化工具以理解智能体决策过程
170 | - 实现归因分析,了解哪些特征驱动决策
171 |
172 | 8. **扩展特征工程**
173 | - 添加新闻/社交媒体的市场情绪分析
174 | - 纳入宏观经济指标
175 | - 包含资金流向数据和行业轮动指标
176 |
177 | 9. **基准比较**
178 | - 实现与传统策略(动量、均值回归)的比较
179 | - 为性能指标添加统计显著性测试
180 |
181 | 10. **生产环境准备**
182 | - 添加强健的错误处理和日志记录
183 | - 实现生产部署的系统监控
184 | - 创建模型版本控制和自动化回测管道
185 |
186 | 11. **多时间框架分析**
187 | - 整合多个时间框架的数据(日线、周线、月线)
188 | - 为不同决策周期实现分层强化学习
189 |
190 | 12. **迁移学习**
191 | - 在相似市场/资产上实现预训练
192 | - 添加领域适应技术实现模型在不同市场间的迁移
193 |
194 | 13. **替代数据集成**
195 | - 创建价格数据之外的替代数据源接口
196 | - 添加期权链数据作为隐含波动率信号
197 |
198 |
--------------------------------------------------------------------------------
/src/strategies/market_sentiment/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import talib
4 | from loguru import logger
5 |
6 | class TrendStateDetector:
7 | def __init__(self, window=60):
8 | self.window = window
9 | self.last_regime = None
10 | self.regime_change_date = None
11 | # 缓存历史计算结果
12 | self.cached_ema5 = None
13 | self.cached_ema13 = None
14 | self.cached_ema55 = None
15 | self.cached_momentum = None
16 | self.crossover_threshold = 0.001 # 0.1%的最小上穿幅度要求
17 |
18 | def detect(self, price_series, volume_series, current_date=None):
19 | """趋势检测逻辑,基于EMA交叉和价格位置"""
20 | price_series_pd = pd.Series(price_series)
21 |
22 | # 计算EMA指标
23 | ema5 = talib.EMA(price_series_pd, timeperiod=5)
24 | ema13 = talib.EMA(price_series_pd, timeperiod=10)
25 | ema55 = talib.EMA(price_series_pd, timeperiod=55)
26 |
27 | # 动量计算
28 | momentum = talib.MOM(price_series_pd, timeperiod=5)
29 | sma_momentum = talib.SMA(momentum, timeperiod=3)
30 |
31 | # 成交量指标
32 | volume_series_pd = pd.Series(volume_series)
33 | obv = talib.OBV(price_series_pd, volume_series_pd)
34 | obv_ema = talib.EMA(obv, timeperiod=5)
35 |
36 | current_regime = 'normal'
37 |
38 | if len(price_series) >= 55:
39 | # 获取当前和前一日的值
40 | ema5_current = ema5.iloc[-1]
41 | ema5_prev = ema5.iloc[-2]
42 | ema13_current = ema13.iloc[-1]
43 | ema13_prev = ema13.iloc[-2]
44 | ema55_current = ema55.iloc[-1]
45 | ema55_prev = ema55.iloc[-2]
46 | price_current = price_series_pd.iloc[-1]
47 | price_prev = price_series_pd.iloc[-2]
48 |
49 | # 成交量确认
50 | volume_confirmed = obv.iloc[-1] > obv_ema.iloc[-1]
51 |
52 | # 1. EMA5上穿EMA13时设置为potential_uptrend
53 | crossover_pct = (ema5_current - ema13_current) / ema13_current
54 | if (ema5_current > ema13_current and ema5_prev <= ema13_prev and
55 | crossover_pct > self.crossover_threshold):
56 | current_regime = 'potential_uptrend'
57 |
58 | # 2. EMA5上穿EMA13后,收盘价格连续3个交易日站上EMA55,设置为uptrend
59 | if (ema5_current > ema13_current and
60 | price_current > ema55_current and
61 | price_series_pd.iloc[-2] > ema55.iloc[-2] and
62 | price_series_pd.iloc[-3] > ema55.iloc[-3] and
63 | volume_confirmed):
64 | current_regime = 'uptrend'
65 |
66 | # 3. EMA5下穿EMA13,但收盘价格在EMA55之上,设置为potential_downtrend
67 | if (ema5_current < ema13_current and
68 | ema5_prev >= ema13_prev and
69 | price_current > ema55_current):
70 | current_regime = 'potential_downtrend'
71 |
72 | # 4. EMA5下穿EMA55,设置为downtrend
73 | if ema5_current < ema55_current and ema5_prev >= ema55_prev:
74 | current_regime = 'downtrend'
75 |
76 | # 趋势变化检测
77 | if current_regime != self.last_regime:
78 | self.regime_change_date = current_date
79 | if current_date is not None:
80 | logger.info(f"趋势变化: {self.last_regime} -> {current_regime}, 日期: {current_date}")
81 | self.last_regime = current_regime
82 |
83 | return current_regime
84 |
85 | class PositionManager:
86 | def __init__(self, max_risk=0.45):
87 | self.max_risk = max_risk
88 |
89 | def adjust_position(self, target_ratio, volatility):
90 | """简化的仓位管理,只根据波动率调整仓位"""
91 | # 波动率调整系数 - 这里的2是百分比形式的波动率基准值(2%)
92 | # 实际波动率通常在1%~3%之间
93 | vol_adj = np.clip(volatility / 2, 0.5, 1.5)
94 |
95 | # 根据波动率调整目标仓位
96 | adjusted_ratio = target_ratio * vol_adj
97 |
98 | # 设置风险上限
99 | if adjusted_ratio > self.max_risk * 2: # 最大允许仓位90%
100 | adjusted_ratio = self.max_risk * 2
101 |
102 | logger.info(f"波动率: {volatility:.2f}, 波动率调整系数: {vol_adj:.2f}, 目标仓位: {target_ratio:.2f}, 调整后仓位: {adjusted_ratio:.2f}")
103 |
104 | return adjusted_ratio
105 |
106 | # # 信号生成参数
107 | # SENTIMENT_THRESHOLDS = {
108 | # 'core': 2.5, # 核心信号阈值
109 | # 'secondary': 9.0, # 次级信号阈值
110 | # 'light': 9.0 # 轻仓信号阈值
111 | # }
112 |
113 | # POSITION_WEIGHTS = {
114 | # 'core': 0.95, # 核心信号仓位
115 | # 'secondary': 0.9, # 次级信号仓位
116 | # 'light': 0.8 # 轻仓信号仓位
117 | # }
118 |
119 | # def generate_signals(sentiment_score, regime, volatility):
120 | # """简化的信号生成,只区分下跌趋势和正常趋势"""
121 | # signals = []
122 |
123 | # # 下跌趋势不开仓
124 | # if regime == 'downtrend':
125 | # if sentiment_score < SENTIMENT_THRESHOLDS['core']:
126 | # signals.append({'type': 'buy', 'weight': POSITION_WEIGHTS['core']}) # 核心信号
127 | # return signals
128 | # else:
129 | # return []
130 |
131 | # # 其他情况根据情绪分数决定
132 | # if sentiment_score < SENTIMENT_THRESHOLDS['core']:
133 | # signals.append({'type': 'buy', 'weight': POSITION_WEIGHTS['core']}) # 核心信号
134 | # elif SENTIMENT_THRESHOLDS['core'] <= sentiment_score < SENTIMENT_THRESHOLDS['secondary']:
135 | # signals.append({'type': 'buy', 'weight': POSITION_WEIGHTS['secondary']}) # 次级信号
136 | # elif SENTIMENT_THRESHOLDS['secondary'] <= sentiment_score < SENTIMENT_THRESHOLDS['light']:
137 | # signals.append({'type': 'buy', 'weight': POSITION_WEIGHTS['light']}) # 轻仓信号
138 |
139 | # return signals
--------------------------------------------------------------------------------
/rl_model_finrl/meta/data_processors/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 数据处理器模块
3 |
4 | 这个模块提供了处理金融市场数据的各种处理器,支持不同的数据源:
5 | - TushareProcessor: 处理来自Tushare API的数据
6 | - AKShareProcessor: 处理来自AKShare的数据
7 | - DataProcessor: 数据处理器的基类
8 |
9 | 主要功能:
10 | - 下载和预处理ETF价格数据
11 | - 添加技术指标
12 | - 准备用于强化学习环境的数据
13 |
14 | 使用示例:
15 | ```python
16 | from src.strategies.rl_model_finrl.meta.data_processors import DataProcessor, TushareProcessor
17 |
18 | # 使用Tushare处理器
19 | processor = TushareProcessor(token="your_tushare_token")
20 | data = processor.download_data(
21 | start_date="2018-01-01",
22 | end_date="2021-12-31",
23 | ticker_list=["510050"]
24 | )
25 |
26 | # 添加技术指标
27 | data_with_indicators = processor.add_technical_indicators(data)
28 |
29 | # 准备训练数据
30 | train_data = processor.data_split(data_with_indicators, start_date="2018-01-01", end_date="2020-12-31")
31 | ```
32 | """
33 |
34 | import pandas as pd
35 | import numpy as np
36 | from abc import ABC, abstractmethod
37 | from typing import Dict, List, Tuple, Union, Optional
38 |
39 | class DataProcessor(ABC):
40 | """
41 | 数据处理器基类
42 |
43 | 定义了数据处理的通用接口,所有特定数据源的处理器都应该继承这个类
44 | """
45 |
46 | @abstractmethod
47 | def download_data(self, **kwargs) -> pd.DataFrame:
48 | """
49 | 从数据源下载数据
50 |
51 | 参数:
52 | **kwargs: 下载参数,如起止日期、股票代码等
53 |
54 | 返回:
55 | 下载的数据DataFrame
56 | """
57 | pass
58 |
59 | @abstractmethod
60 | def clean_data(self, data: pd.DataFrame) -> pd.DataFrame:
61 | """
62 | 清洗数据
63 |
64 | 参数:
65 | data: 原始数据DataFrame
66 |
67 | 返回:
68 | 清洗后的数据DataFrame
69 | """
70 | pass
71 |
72 | def add_technical_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
73 | """
74 | 添加技术指标
75 |
76 | 参数:
77 | data: 原始价格数据DataFrame
78 |
79 | 返回:
80 | 添加技术指标后的数据DataFrame
81 | """
82 | df = data.copy()
83 |
84 | # 确保列名标准化
85 | price_col = 'close' if 'close' in df.columns else 'Close'
86 | high_col = 'high' if 'high' in df.columns else 'High'
87 | low_col = 'low' if 'low' in df.columns else 'Low'
88 | volume_col = 'volume' if 'volume' in df.columns else 'Volume'
89 |
90 | # 计算MACD
91 | df['ema12'] = df[price_col].ewm(span=12, adjust=False).mean()
92 | df['ema26'] = df[price_col].ewm(span=26, adjust=False).mean()
93 | df['macd'] = df['ema12'] - df['ema26']
94 | df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
95 | df['macd_hist'] = df['macd'] - df['macd_signal']
96 |
97 | # 计算RSI (相对强弱指标)
98 | delta = df[price_col].diff()
99 | gain = delta.copy()
100 | loss = delta.copy()
101 | gain[gain < 0] = 0
102 | loss[loss > 0] = 0
103 | loss = -loss
104 |
105 | avg_gain = gain.rolling(window=14).mean()
106 | avg_loss = loss.rolling(window=14).mean()
107 |
108 | rs = avg_gain / avg_loss
109 | df['rsi'] = 100 - (100 / (1 + rs))
110 |
111 | # 计算CCI (商品通道指标)
112 | df['tp'] = (df[high_col] + df[low_col] + df[price_col]) / 3
113 | df['tp_ma'] = df['tp'].rolling(window=20).mean()
114 | mean_dev = df['tp'].rolling(window=20).apply(lambda x: abs(x - x.mean()).mean())
115 | df['cci'] = (df['tp'] - df['tp_ma']) / (0.015 * mean_dev)
116 |
117 | # 计算布林带
118 | df['sma20'] = df[price_col].rolling(window=20).mean()
119 | df['bollinger_upper'] = df['sma20'] + 2 * df[price_col].rolling(window=20).std()
120 | df['bollinger_lower'] = df['sma20'] - 2 * df[price_col].rolling(window=20).std()
121 |
122 | # 计算ATR (真实波幅均值)
123 | df['tr1'] = df[high_col] - df[low_col]
124 | df['tr2'] = abs(df[high_col] - df[price_col].shift())
125 | df['tr3'] = abs(df[low_col] - df[price_col].shift())
126 | df['tr'] = df[['tr1', 'tr2', 'tr3']].max(axis=1)
127 | df['atr'] = df['tr'].rolling(window=14).mean()
128 |
129 | # 移除临时列
130 | df = df.drop(['tr1', 'tr2', 'tr3', 'tr', 'tp', 'tp_ma', 'ema12', 'ema26'], axis=1, errors='ignore')
131 |
132 | # 使用涨跌幅代替价格
133 | df['daily_return'] = df[price_col].pct_change()
134 |
135 | return df
136 |
137 | def data_split(
138 | self,
139 | df: pd.DataFrame,
140 | start_date: str,
141 | end_date: str
142 | ) -> pd.DataFrame:
143 | """
144 | 按日期范围分割数据
145 |
146 | 参数:
147 | df: 原始数据DataFrame
148 | start_date: 开始日期
149 | end_date: 结束日期
150 |
151 | 返回:
152 | 指定日期范围的数据
153 | """
154 | data = df.copy()
155 |
156 | # 确保索引是日期类型
157 | if not isinstance(data.index, pd.DatetimeIndex):
158 | if 'date' in data.columns:
159 | data['date'] = pd.to_datetime(data['date'])
160 | data = data.set_index('date')
161 | else:
162 | raise ValueError("数据必须有日期列或日期索引")
163 |
164 | # 转换日期字符串为日期类型
165 | start_date = pd.to_datetime(start_date)
166 | end_date = pd.to_datetime(end_date)
167 |
168 | # 筛选日期范围内的数据
169 | data = data.loc[start_date:end_date]
170 |
171 | return data
172 |
173 | def prepare_data_for_training(self, **kwargs) -> Tuple[Dict[str, pd.DataFrame], Optional[pd.DataFrame]]:
174 | """
175 | 准备用于训练的数据
176 |
177 | 参数:
178 | **kwargs: 准备数据的参数
179 |
180 | 返回:
181 | (ETF数据字典, 市场指数数据)元组
182 | """
183 | pass
184 |
185 | from src.strategies.rl_model_finrl.meta.data_processors.tushare_processor import TushareProcessor
186 | from src.strategies.rl_model_finrl.meta.data_processors.akshare_processor import AKShareProcessor
187 |
188 | __all__ = [
189 | 'DataProcessor',
190 | 'TushareProcessor',
191 | 'AKShareProcessor'
192 | ]
--------------------------------------------------------------------------------
/src/strategies/market_sentiment/etf_dividend_handler.py:
--------------------------------------------------------------------------------
1 | import tushare as ts
2 | import pandas as pd
3 | from src.data.data_loader import DataLoader
4 | from datetime import datetime, timedelta
5 | from loguru import logger
6 | import os
7 | import time
8 | import json
9 | import numpy as np
10 |
11 | class ETFDividendHandler:
12 | """处理ETF分红的类"""
13 | def __init__(self, ts_code=None):
14 | self.ts_code = ts_code # ETF代码
15 | self.dividend_data = None # 分红数据DataFrame
16 | self.last_api_call = 0 # 上次API调用时间
17 | self.min_interval = 1.0 # 最小调用间隔(秒)
18 | self.cache_file = f'cache/dividend_{ts_code}.json' # 缓存文件路径
19 |
20 | # 确保缓存目录存在
21 | os.makedirs('cache', exist_ok=True)
22 |
23 | # 初始化Tushare
24 | tushare_token = os.getenv('TUSHARE_TOKEN')
25 | if not tushare_token:
26 | raise ValueError("未设置TUSHARE_TOKEN环境变量")
27 | ts.set_token(tushare_token)
28 | self.pro = ts.pro_api()
29 |
30 | def _wait_for_rate_limit(self):
31 | """等待以满足API调用频率限制"""
32 | current_time = time.time()
33 | elapsed = current_time - self.last_api_call
34 | if elapsed < self.min_interval:
35 | time.sleep(self.min_interval - elapsed)
36 | self.last_api_call = time.time()
37 |
38 | def _load_from_cache(self):
39 | """从缓存加载分红数据"""
40 | try:
41 | if os.path.exists(self.cache_file):
42 | with open(self.cache_file, 'r') as f:
43 | data = json.load(f)
44 | df = pd.DataFrame(data)
45 | df['date'] = pd.to_datetime(df['date'])
46 | return df
47 | except Exception as e:
48 | logger.warning(f"从缓存加载分红数据失败: {str(e)}")
49 | return None
50 |
51 | def _save_to_cache(self, df):
52 | """保存分红数据到缓存"""
53 | try:
54 | # 将DataFrame转换为可序列化的格式
55 | df_copy = df.copy()
56 | df_copy['date'] = df_copy['date'].dt.strftime('%Y-%m-%d')
57 | data = df_copy.to_dict('records')
58 |
59 | with open(self.cache_file, 'w') as f:
60 | json.dump(data, f)
61 | except Exception as e:
62 | logger.warning(f"保存分红数据到缓存失败: {str(e)}")
63 |
64 | def _clean_dividend_data(self, df):
65 | """清理分红数据,过滤掉非数值的分红记录"""
66 | try:
67 | # 将分红列转换为数值类型,非数值将变为NaN
68 | df['dividend'] = pd.to_numeric(df['dividend'], errors='coerce')
69 |
70 | # 删除分红为NaN的记录
71 | df = df.dropna(subset=['dividend'])
72 |
73 | # 删除分红为0的记录
74 | df = df[df['dividend'] > 0]
75 |
76 | return df
77 | except Exception as e:
78 | logger.error(f"清理分红数据时出错: {str(e)}")
79 | return df
80 |
81 | def update_dividend_data(self, start_date=None, end_date=None):
82 | """更新分红数据"""
83 | try:
84 | # 先尝试从缓存加载
85 | cached_data = self._load_from_cache()
86 | if cached_data is not None:
87 | # 检查缓存数据是否覆盖了所需的日期范围
88 | if start_date and end_date:
89 | cached_start = cached_data['date'].min().date()
90 | cached_end = cached_data['date'].max().date()
91 | if cached_start <= start_date and cached_end >= end_date:
92 | self.dividend_data = cached_data
93 | logger.info(f"从缓存加载ETF分红数据成功 - {self.ts_code}, 数据长度: {len(cached_data)}")
94 | return True
95 |
96 | # 如果缓存不存在或数据不完整,从API获取
97 | self._wait_for_rate_limit()
98 |
99 | # 获取分红数据
100 | df = self.pro.fund_div(
101 | ts_code=self.ts_code,
102 | start_date=start_date.strftime('%Y%m%d') if start_date else None,
103 | end_date=end_date.strftime('%Y%m%d') if end_date else None
104 | )
105 |
106 | if df is not None and not df.empty:
107 | # 重命名列
108 | df = df.rename(columns={
109 | 'ann_date': 'date',
110 | 'div_cash': 'dividend'
111 | })
112 |
113 | # 转换日期格式
114 | df['date'] = pd.to_datetime(df['date'])
115 |
116 | # 清理分红数据
117 | df = self._clean_dividend_data(df)
118 |
119 | # 按日期排序
120 | df = df.sort_values('date')
121 |
122 | # 保存到缓存
123 | self._save_to_cache(df)
124 |
125 | self.dividend_data = df
126 | logger.info(f"成功获取ETF分红数据 - {self.ts_code}, 数据长度: {len(df)}")
127 | return True
128 | else:
129 | logger.warning(f"未获取到ETF分红数据 - {self.ts_code}")
130 | return False
131 |
132 | except Exception as e:
133 | logger.error(f"获取ETF分红数据出错: {str(e)}")
134 | return False
135 |
136 | def process_dividend(self, date_str, position_size, current_price):
137 | """处理指定日期的分红"""
138 | if self.dividend_data is None:
139 | return 0.0
140 |
141 | try:
142 | # 将日期字符串转换为datetime对象
143 | date = pd.to_datetime(date_str)
144 |
145 | # 查找当天的分红记录
146 | dividend_record = self.dividend_data[self.dividend_data['date'].dt.date == date.date()]
147 |
148 | if not dividend_record.empty:
149 | # 获取每股分红金额
150 | dividend_per_share = float(dividend_record['dividend'].iloc[0])
151 |
152 | # 计算总分红金额
153 | total_dividend = dividend_per_share * position_size
154 |
155 | logger.info(f"处理ETF分红 - 日期: {date_str}, 每股分红: {dividend_per_share:.4f}, "
156 | f"持仓数量: {position_size}, 总分红: {total_dividend:.2f}")
157 |
158 | return total_dividend
159 | else:
160 | return 0.0
161 |
162 | except Exception as e:
163 | logger.error(f"处理ETF分红时出错: {str(e)}")
164 | return 0.0
--------------------------------------------------------------------------------
/rl_model_finrl/meta/preprocessor/data_normalizer.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from typing import Dict, List, Tuple, Union, Optional
4 | from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
5 |
6 | class DataNormalizer:
7 | """
8 | 数据归一化处理器
9 |
10 | 用于金融数据的归一化和标准化,以便更好地用于机器学习模型
11 | """
12 |
13 | def __init__(self,
14 | method: str = 'standard', # 'standard', 'minmax', 'robust'
15 | feature_range: Tuple[float, float] = (0, 1),
16 | ):
17 | """
18 | 初始化归一化处理器
19 |
20 | 参数:
21 | method: 归一化方法: 'standard'(标准化), 'minmax'(最小最大归一化), 'robust'(稳健归一化)
22 | feature_range: 特征范围 (针对minmax方法)
23 | """
24 | self.method = method
25 | self.feature_range = feature_range
26 | self.scalers = {}
27 | self.feature_columns = None
28 |
29 | def fit(self, df: pd.DataFrame, feature_columns: Optional[List[str]] = None) -> 'DataNormalizer':
30 | """
31 | 拟合归一化器
32 |
33 | 参数:
34 | df: 数据DataFrame
35 | feature_columns: 要归一化的特征列名列表,如果为None,将使用所有数值型列
36 |
37 | 返回:
38 | 归一化处理器实例
39 | """
40 | data = df.copy()
41 |
42 | # 如果未指定特征列,使用所有数值列
43 | if feature_columns is None:
44 | self.feature_columns = data.select_dtypes(include=['float64', 'int64']).columns.tolist()
45 | else:
46 | self.feature_columns = feature_columns
47 |
48 | # 拟合每个特征的归一化器
49 | for col in self.feature_columns:
50 | if col in data.columns:
51 | # 创建适当的缩放器
52 | if self.method == 'standard':
53 | scaler = StandardScaler()
54 | elif self.method == 'minmax':
55 | scaler = MinMaxScaler(feature_range=self.feature_range)
56 | elif self.method == 'robust':
57 | scaler = RobustScaler()
58 | else:
59 | raise ValueError(f"不支持的归一化方法: {self.method}")
60 |
61 | # 拟合缩放器
62 | values = data[col].values.reshape(-1, 1)
63 | scaler.fit(values)
64 | self.scalers[col] = scaler
65 |
66 | return self
67 |
68 | def transform(self, df: pd.DataFrame) -> pd.DataFrame:
69 | """
70 | 转换数据
71 |
72 | 参数:
73 | df: 要转换的数据DataFrame
74 |
75 | 返回:
76 | 归一化后的DataFrame
77 | """
78 | if not self.scalers:
79 | raise ValueError("请先调用fit方法")
80 |
81 | data = df.copy()
82 |
83 | # 转换每个特征
84 | for col, scaler in self.scalers.items():
85 | if col in data.columns:
86 | values = data[col].values.reshape(-1, 1)
87 | data[col] = scaler.transform(values)
88 |
89 | return data
90 |
91 | def fit_transform(self, df: pd.DataFrame, feature_columns: Optional[List[str]] = None) -> pd.DataFrame:
92 | """
93 | 拟合并转换数据
94 |
95 | 参数:
96 | df: 数据DataFrame
97 | feature_columns: 要归一化的特征列名列表
98 |
99 | 返回:
100 | 归一化后的DataFrame
101 | """
102 | self.fit(df, feature_columns)
103 | return self.transform(df)
104 |
105 | def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
106 | """
107 | 反向转换数据(恢复原始比例)
108 |
109 | 参数:
110 | df: 归一化后的数据DataFrame
111 |
112 | 返回:
113 | 原始比例的DataFrame
114 | """
115 | if not self.scalers:
116 | raise ValueError("请先调用fit方法")
117 |
118 | data = df.copy()
119 |
120 | # 反向转换每个特征
121 | for col, scaler in self.scalers.items():
122 | if col in data.columns:
123 | values = data[col].values.reshape(-1, 1)
124 | data[col] = scaler.inverse_transform(values)
125 |
126 | return data
127 |
128 | def save(self, path: str) -> None:
129 | """
130 | 保存归一化器状态
131 |
132 | 参数:
133 | path: 保存路径
134 | """
135 | import joblib
136 |
137 | state = {
138 | 'method': self.method,
139 | 'feature_range': self.feature_range,
140 | 'scalers': self.scalers,
141 | 'feature_columns': self.feature_columns
142 | }
143 |
144 | joblib.dump(state, path)
145 |
146 | @classmethod
147 | def load(cls, path: str) -> 'DataNormalizer':
148 | """
149 | 从文件加载归一化器
150 |
151 | 参数:
152 | path: 文件路径
153 |
154 | 返回:
155 | 加载的归一化器实例
156 | """
157 | import joblib
158 |
159 | state = joblib.load(path)
160 |
161 | normalizer = cls(
162 | method=state['method'],
163 | feature_range=state['feature_range']
164 | )
165 |
166 | normalizer.scalers = state['scalers']
167 | normalizer.feature_columns = state['feature_columns']
168 |
169 | return normalizer
170 |
171 | def normalize_price_data(self, df: pd.DataFrame, price_columns: List[str]) -> pd.DataFrame:
172 | """
173 | 归一化价格数据
174 |
175 | 参数:
176 | df: 价格数据DataFrame
177 | price_columns: 价格列名列表
178 |
179 | 返回:
180 | 归一化后的DataFrame
181 | """
182 | data = df.copy()
183 |
184 | # 当价格数据用于机器学习时,通常更好的做法是转换为收益率或相对变化
185 | for col in price_columns:
186 | if col in data.columns:
187 | # 计算收益率
188 | data[f'{col}_return'] = data[col].pct_change()
189 |
190 | # 计算相对于初始价格的变化
191 | data[f'{col}_rel'] = data[col] / data[col].iloc[0] - 1
192 |
193 | # 对原始价格列进行归一化
194 | if self.method == 'standard':
195 | scaler = StandardScaler()
196 | elif self.method == 'minmax':
197 | scaler = MinMaxScaler(feature_range=self.feature_range)
198 | elif self.method == 'robust':
199 | scaler = RobustScaler()
200 |
201 | values = data[col].values.reshape(-1, 1)
202 | data[f'{col}_norm'] = scaler.fit_transform(values)
203 |
204 | # 保存缩放器以便将来反归一化
205 | self.scalers[col] = scaler
206 |
207 | return data
--------------------------------------------------------------------------------
/src/utils/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import plotly.graph_objects as go
4 | from plotly.subplots import make_subplots
5 | from datetime import datetime
6 | import pandas as pd
7 |
8 | class Plot:
9 | def __init__(self, strategy):
10 | self.strategy = strategy
11 |
12 | def plot(self, **kwargs):
13 | """使用Plotly绘制交互式回测结果"""
14 | # 获取策略数据
15 | data = self.strategy.data
16 |
17 | # 将数据转换为numpy数组,处理日期时间
18 | dates = [datetime.fromordinal(int(d)).strftime('%Y-%m-%d') for d in data.datetime.array]
19 | opens = np.array(data.open.array)
20 | highs = np.array(data.high.array)
21 | lows = np.array(data.low.array)
22 | closes = np.array(data.close.array)
23 | volumes = np.array(data.volume.array)
24 |
25 | # 检查trailing_stop是否存在
26 | trailing_stop = getattr(self.strategy, 'trailing_stop', None)
27 | if trailing_stop is not None:
28 | trailing_stop_vals = np.array(trailing_stop.trailing_stop.array)
29 | else:
30 | # 如果trailing_stop为None,创建一个全为0的数组
31 | trailing_stop_vals = np.zeros_like(closes)
32 |
33 | # 创建DataFrame以便处理数据
34 | df = pd.DataFrame({
35 | 'date': dates,
36 | 'open': opens,
37 | 'high': highs,
38 | 'low': lows,
39 | 'close': closes,
40 | 'volume': volumes,
41 | 'trailing_stop': trailing_stop_vals
42 | })
43 |
44 | # 移除volume为0的行(非交易日)
45 | df = df[df['volume'] > 0].copy()
46 |
47 | # 创建子图
48 | fig = make_subplots(rows=2, cols=1,
49 | shared_xaxes=True,
50 | vertical_spacing=0.03,
51 | row_heights=[0.7, 0.3])
52 |
53 | # 添加K线图
54 | fig.add_trace(go.Candlestick(
55 | x=df['date'],
56 | open=df['open'],
57 | high=df['high'],
58 | low=df['low'],
59 | close=df['close'],
60 | name='K线',
61 | increasing_line_color='#ff0000', # 上涨为红色
62 | decreasing_line_color='#00ff00', # 下跌为绿色
63 | ), row=1, col=1)
64 |
65 | # 添加追踪止损线
66 | # 过滤掉追踪止损为0的点
67 | valid_stops = df[df['trailing_stop'] > 0].copy()
68 | if not valid_stops.empty:
69 | fig.add_trace(go.Scatter(
70 | x=valid_stops['date'],
71 | y=valid_stops['trailing_stop'],
72 | name='追踪止损',
73 | line=dict(color='#f1c40f', dash='dash')
74 | ), row=1, col=1)
75 |
76 | # 添加买卖点标记
77 | if hasattr(self, 'trades_df') and not self.trades_df.empty:
78 | # 获取买入点
79 | buy_points = self.trades_df[self.trades_df['类型'] == '买入']
80 | if not buy_points.empty:
81 | fig.add_trace(go.Scatter(
82 | x=buy_points['时间'].dt.strftime('%Y-%m-%d'),
83 | y=buy_points['价格'],
84 | mode='markers+text',
85 | marker=dict(symbol='triangle-up', size=15, color='#2ecc71'),
86 | text=[f"买入\n价格:{price:.2f}\n数量:{size}"
87 | for price, size in zip(buy_points['价格'], buy_points['数量'])],
88 | textposition="top center",
89 | name='买入点',
90 | hoverinfo='text'
91 | ), row=1, col=1)
92 |
93 | # 获取卖出点
94 | sell_points = self.trades_df[self.trades_df['类型'] == '卖出']
95 | if not sell_points.empty:
96 | fig.add_trace(go.Scatter(
97 | x=sell_points['时间'].dt.strftime('%Y-%m-%d'),
98 | y=sell_points['价格'],
99 | mode='markers+text',
100 | marker=dict(symbol='triangle-down', size=15, color='#e74c3c'),
101 | text=[f"卖出\n价格:{price:.2f}\n收益:{profit:.2f}"
102 | for price, profit in zip(sell_points['价格'], sell_points['累计收益'])],
103 | textposition="bottom center",
104 | name='卖出点',
105 | hoverinfo='text'
106 | ), row=1, col=1)
107 |
108 | # 添加成交量图
109 | colors = ['#ff0000' if close >= open_ else '#00ff00'
110 | for close, open_ in zip(df['close'], df['open'])]
111 |
112 | fig.add_trace(go.Bar(
113 | x=df['date'],
114 | y=df['volume'],
115 | name='成交量',
116 | marker_color=colors,
117 | marker=dict(
118 | color=colors,
119 | line=dict(color=colors, width=1)
120 | ),
121 | hovertemplate='日期: %{x}
成交量: %{y:,.0f}'
122 | ), row=2, col=1)
123 |
124 | # 更新布局
125 | fig.update_layout(
126 | title={'text': '回测结果', 'font': {'family': 'Arial'}},
127 | yaxis_title={'text': '价格', 'font': {'family': 'Arial'}},
128 | yaxis2_title={'text': '成交量', 'font': {'family': 'Arial'}},
129 | xaxis_rangeslider_visible=False,
130 | height=800,
131 | template='plotly_white',
132 | showlegend=True,
133 | legend=dict(
134 | yanchor="top",
135 | y=0.99,
136 | xanchor="left",
137 | x=0.01,
138 | font=dict(family='Arial')
139 | ),
140 | # 优化X轴显示
141 | xaxis=dict(
142 | type='category',
143 | rangeslider=dict(visible=False),
144 | showgrid=False, # 移除X轴网格线
145 | gridwidth=1,
146 | gridcolor='lightgrey'
147 | ),
148 | xaxis2=dict(
149 | type='category',
150 | rangeslider=dict(visible=False),
151 | showgrid=False, # 移除X轴网格线
152 | gridwidth=1,
153 | gridcolor='lightgrey'
154 | ),
155 | bargap=0, # 设置柱状图之间的间隔为0
156 | bargroupgap=0 # 设置柱状图组之间的间隔为0
157 | )
158 |
159 | # 更新Y轴格式
160 | fig.update_yaxes(
161 | title_text="价格",
162 | title_font=dict(family='Arial'),
163 | showgrid=True,
164 | gridwidth=1,
165 | gridcolor='lightgrey',
166 | row=1,
167 | col=1
168 | )
169 | fig.update_yaxes(
170 | title_text="成交量",
171 | title_font=dict(family='Arial'),
172 | showgrid=True,
173 | gridwidth=1,
174 | gridcolor='lightgrey',
175 | row=2,
176 | col=1
177 | )
178 |
179 | return fig
--------------------------------------------------------------------------------
/src/strategies/rl_model_strategy.py:
--------------------------------------------------------------------------------
1 | import backtrader as bt
2 | import numpy as np
3 | import pandas as pd
4 | import torch
5 | from loguru import logger
6 | import json
7 |
8 | class RLModelStrategy(bt.Strategy):
9 | params = (
10 | ('model_path', None), # 模型路径
11 | ('config_path', None), # 配置文件路径
12 | ('window_size', 10), # 观察窗口大小
13 | ('risk_ratio', 0.02), # 单次交易风险比率
14 | ('max_drawdown', 0.15), # 最大回撤限制
15 | ('price_limit', 0.10), # 涨跌停限制(10%)
16 | ('min_shares', 100), # 最小交易股数
17 | ('cash_buffer', 0.95), # 现金缓冲比例
18 | )
19 |
20 | def __init__(self):
21 | """初始化策略"""
22 | # 加载配置和模型
23 | if self.p.config_path:
24 | with open(self.p.config_path, 'r') as f:
25 | config_dict = json.load(f)
26 | # TODO: 从字典更新配置
27 |
28 | # 设置设备
29 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
30 |
31 | # 初始化智能体
32 | if self.p.model_path:
33 | self.agent = self._load_agent()
34 | else:
35 | logger.error("未提供模型路径")
36 | raise ValueError("请提供有效的模型路径")
37 |
38 | # 记录最高净值,用于计算回撤
39 | self.highest_value = self.broker.getvalue()
40 |
41 | # 用于跟踪订单和持仓
42 | self.order = None
43 | self.entry_price = None
44 | self.trade_reason = None
45 | self._orders = []
46 |
47 | # 技术指标
48 | self.atr = bt.indicators.ATR(self.data)
49 |
50 | # 价格和特征历史
51 | self.price_history = []
52 | self.feature_history = []
53 |
54 | logger.info("强化学习模型策略初始化完成")
55 |
56 | def _load_agent(self):
57 | """加载训练好的智能体"""
58 | # 计算状态维度
59 | pass
60 |
61 | def _get_state(self):
62 | """构建当前状态"""
63 | # 获取价格历史
64 | price_data = np.array([
65 | [self.data.open[i], self.data.high[i], self.data.low[i],
66 | self.data.close[i], self.data.volume[i]]
67 | for i in range(-self.p.window_size + 1, 1)
68 | ]).flatten()
69 |
70 | # 生成技术指标特征
71 | features = []
72 | for i in range(-self.p.window_size + 1, 1):
73 | # 趋势指标
74 | ma5 = np.mean([self.data.close[j] for j in range(i-4, i+1)])
75 | ma10 = np.mean([self.data.close[j] for j in range(i-9, i+1)])
76 | ma20 = np.mean([self.data.close[j] for j in range(i-19, i+1)])
77 |
78 | # 动量指标
79 | momentum = self.data.close[i] / self.data.close[i-5] - 1
80 |
81 | # 波动率指标
82 | volatility = np.std([self.data.close[j] for j in range(i-4, i+1)])
83 |
84 | # 成交量指标
85 | volume_ma5 = np.mean([self.data.volume[j] for j in range(i-4, i+1)])
86 |
87 | features.extend([ma5, ma10, ma20, momentum, volatility, volume_ma5])
88 |
89 | # 账户状态
90 | portfolio_value = self.broker.getvalue()
91 | position_value = self.position.size * self.data.close[0] if self.position else 0
92 | position_pct = position_value / portfolio_value if portfolio_value > 0 else 0
93 | cash_pct = self.broker.getcash() / portfolio_value
94 |
95 | # 组合状态
96 | state = np.concatenate([
97 | price_data,
98 | features,
99 | [cash_pct, position_pct]
100 | ])
101 |
102 | return state.astype(np.float32)
103 |
104 | def round_shares(self, shares):
105 | """将股数调整为100的整数倍"""
106 | return int(shares / 100) * 100
107 |
108 | def check_price_limit(self, price):
109 | """检查是否触及涨跌停"""
110 | prev_close = self.data.close[-1]
111 | upper_limit = prev_close * (1 + self.p.price_limit)
112 | lower_limit = prev_close * (1 - self.p.price_limit)
113 | return lower_limit <= price <= upper_limit
114 |
115 | def calculate_trade_size(self, price):
116 | """计算可交易的股数(考虑资金、手续费和100股整数倍)"""
117 | cash = self.broker.getcash() * self.p.cash_buffer
118 |
119 | # 计算风险金额(使用总资产的一定比例)
120 | total_value = self.broker.getvalue()
121 | risk_amount = total_value * self.p.risk_ratio
122 |
123 | # 使用ATR计算每股风险
124 | current_atr = self.atr[0]
125 | risk_per_share = current_atr * 1.5
126 |
127 | # 根据风险计算的股数
128 | risk_size = risk_amount / risk_per_share if risk_per_share > 0 else 0
129 |
130 | # 根据可用资金计算的股数
131 | cash_size = cash / price
132 |
133 | # 取较小值并调整为100股整数倍
134 | shares = min(risk_size, cash_size)
135 | shares = self.round_shares(shares)
136 |
137 | # 再次验证金额是否超过可用资金
138 | if shares * price > cash:
139 | shares = self.round_shares(cash / price)
140 |
141 | return shares if shares >= self.p.min_shares else 0
142 |
143 | def next(self):
144 | # 如果有未完成的订单,不执行新的交易
145 | if self.order:
146 | return
147 |
148 | # 计算当前回撤
149 | current_value = self.broker.getvalue()
150 | self.highest_value = max(self.highest_value, current_value)
151 | drawdown = (self.highest_value - current_value) / self.highest_value
152 |
153 | # 如果回撤超过限制,不开新仓
154 | if drawdown > self.p.max_drawdown:
155 | if self.position:
156 | self.trade_reason = f"触发最大回撤限制 ({drawdown:.2%})"
157 | self.close()
158 | logger.info(f"触发最大回撤限制 - 当前回撤: {drawdown:.2%}, 限制: {self.p.max_drawdown:.2%}")
159 | return
160 |
161 | # 检查是否触及涨跌停
162 | if not self.check_price_limit(self.data.close[0]):
163 | return
164 |
165 | # 获取当前状态
166 | state = self._get_state()
167 | state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
168 |
169 | # 使用智能体选择动作
170 | with torch.no_grad():
171 | q_values = self.agent.q_network(state_tensor)
172 | action = q_values.argmax().item()
173 |
174 | current_price = self.data.close[0]
175 |
176 | if action == 1: # 买入
177 | if not self.position: # 没有持仓
178 | shares = self.calculate_trade_size(current_price)
179 | if shares >= self.p.min_shares:
180 | self.trade_reason = "智能体买入信号"
181 | self.order = self.buy(size=shares)
182 | if self.order:
183 | self.entry_price = current_price
184 | logger.info(f"买入信号 - 数量: {shares}, 价格: {current_price:.2f}")
185 |
186 | elif action == 2: # 卖出
187 | if self.position: # 有持仓
188 | self.trade_reason = "智能体卖出信号"
189 | self.order = self.close()
190 | if self.order:
191 | logger.info(f"卖出信号 - 价格: {current_price:.2f}")
192 |
193 | def notify_order(self, order):
194 | if order.status in [order.Submitted, order.Accepted]:
195 | return
196 |
197 | if order.status in [order.Completed]:
198 | if order.isbuy():
199 | order.info = {
200 | 'reason': self.trade_reason,
201 | 'total_value': self.broker.getvalue(),
202 | 'position_value': self.position.size * order.executed.price if self.position else 0
203 | }
204 | self._orders.append(order)
205 | logger.info(
206 | f'买入执行 - 价格: {order.executed.price:.2f}, '
207 | f'数量: {order.executed.size}, '
208 | f'原因: {self.trade_reason}'
209 | )
210 | else:
211 | self.entry_price = None
212 | order.info = {
213 | 'reason': self.trade_reason,
214 | 'total_value': self.broker.getvalue(),
215 | 'position_value': self.position.size * order.executed.price if self.position else 0
216 | }
217 | self._orders.append(order)
218 | logger.info(
219 | f'卖出执行 - 价格: {order.executed.price:.2f}, '
220 | f'数量: {order.executed.size}, '
221 | f'原因: {self.trade_reason}'
222 | )
223 | elif order.status in [order.Canceled, order.Margin, order.Rejected]:
224 | logger.warning(f'订单失败 - 状态: {order.getstatusname()}')
225 |
226 | self.order = None
227 |
228 | def stop(self):
229 | """策略结束时的汇总信息"""
230 | portfolio_value = self.broker.getvalue()
231 | returns = (portfolio_value / self.broker.startingcash) - 1.0
232 | logger.info(f"策略结束 - 最终资金: {portfolio_value:.2f}, 收益率: {returns:.2%}")
233 |
--------------------------------------------------------------------------------
/src/trading/market_executor.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import os
3 | from datetime import datetime, timedelta
4 | import time
5 | import tushare as ts
6 | import akshare as ak
7 | import backtrader as bt
8 | from src.utils.logger import setup_logger
9 | from src.utils.notification import send_notification
10 | from src.strategies.market_sentiment_strategy import MarketSentimentStrategy
11 | from src.data.data_loader import DataLoader
12 | from src.utils.analysis import Analysis
13 |
14 | logger = setup_logger()
15 |
16 | class MarketExecutor:
17 | def __init__(self, symbols: list, tushare_token: str):
18 | self.symbols = symbols
19 | self.records_dir = "data/trading_records"
20 | self.data_loader = DataLoader(tushare_token=tushare_token)
21 | self.analysis = Analysis()
22 | os.makedirs(self.records_dir, exist_ok=True)
23 |
24 | def is_trading_day(self):
25 | """判断当前是否为交易日"""
26 | today = datetime.now().date()
27 | # 获取交易日历
28 | trade_cal = self.data_loader.pro.trade_cal(exchange='SSE',
29 | start_date=today.strftime('%Y%m%d'),
30 | end_date=today.strftime('%Y%m%d'))
31 | return trade_cal.iloc[0]['is_open'] == 1
32 |
33 | def get_realtime_data(self, ts_code):
34 | """获取实时行情数据"""
35 | try:
36 | # 获取实时行情
37 | # 转换为雪球指数代码格式
38 | if ts_code == '000001.SH':
39 | market_code = 'SH000001'
40 | elif ts_code == '000300.SH':
41 | market_code = 'SH000300'
42 | elif ts_code == '000016.SH':
43 | market_code = 'SH000016'
44 | elif ts_code == '399240.SZ':
45 | market_code = 'SZ399240'
46 | else:
47 | market_code = ts_code.replace('.SH', 'SH').replace('.SZ', 'SZ')
48 |
49 | realtime_data = ak.stock_individual_spot_xq(symbol=market_code)
50 | if realtime_data is not None and not realtime_data.empty:
51 | # 将数据转换为以item为索引的格式
52 | realtime_data = realtime_data.set_index('item')
53 | # 转换数据格式以匹配原有接口
54 | return pd.DataFrame({
55 | 'ts_code': [ts_code],
56 | 'open': [float(realtime_data.loc['今开', 'value'])],
57 | 'high': [float(realtime_data.loc['最高', 'value'])],
58 | 'low': [float(realtime_data.loc['最低', 'value'])],
59 | 'close': [float(realtime_data.loc['现价', 'value'])],
60 | 'pre_close': [float(realtime_data.loc['昨收', 'value'])],
61 | 'vol': [float(realtime_data.loc['成交量', 'value'])],
62 | 'amount': [float(realtime_data.loc['成交额', 'value'])]
63 | })
64 | return None
65 | except Exception as e:
66 | logger.error(f"获取实时数据失败: {str(e)}")
67 | import traceback
68 | traceback.print_exc()
69 | return None
70 |
71 | def execute(self):
72 | """执行策略回测并记录交易信号"""
73 | try:
74 | # 检查是否为交易日
75 | if not self.is_trading_day():
76 | logger.info("当前不是交易日,跳过执行")
77 | return
78 |
79 | # 获取当前时间
80 | current_time = datetime.now()
81 |
82 | # 遍历上证50成分股进行回测
83 | for symbol in self.symbols:
84 | try:
85 | # 创建回测引擎
86 | cerebro = bt.Cerebro()
87 |
88 | # 设置初始资金
89 | cerebro.broker.setcash(1000000.0)
90 |
91 | # 设置交易手续费
92 | cerebro.broker.setcommission(commission=0.0003)
93 |
94 | # 获取历史数据
95 | start_date = (current_time - timedelta(days=365))
96 | end_date = current_time
97 |
98 | # 获取实时数据
99 | # realtime_data = self.get_realtime_data(symbol)
100 | # if realtime_data is None:
101 | # logger.error("获取实时数据失败,跳过本次执行")
102 | # continue
103 |
104 | # 使用DataLoader加载数据
105 | data = self.data_loader.download_data(
106 | symbol=symbol,
107 | start_date=start_date,
108 | end_date=end_date
109 | )
110 |
111 | if data is None:
112 | logger.error("获取数据失败,跳过本次执行")
113 | continue
114 |
115 | # 添加数据到回测引擎
116 | cerebro.adddata(data)
117 |
118 | # 添加策略
119 | cerebro.addstrategy(MarketSentimentStrategy)
120 |
121 | # 添加分析器
122 | cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe', riskfreerate=0.02)
123 | cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown')
124 | cerebro.addanalyzer(bt.analyzers.Returns, _name='returns')
125 | cerebro.addanalyzer(bt.analyzers.TradeAnalyzer, _name='trades')
126 | cerebro.addanalyzer(bt.analyzers.VWR, _name='vwr')
127 | cerebro.addanalyzer(bt.analyzers.SQN, _name='sqn')
128 | cerebro.addanalyzer(bt.analyzers.Transactions, _name='txn')
129 |
130 | # 运行回测
131 | results = cerebro.run()
132 | strategy = results[0]
133 |
134 | # 初始化trades列表
135 | cerebro.trades = []
136 |
137 | # 使用Analysis类获取回测结果
138 | analysis = self.analysis._get_analysis(cerebro, strategy)
139 |
140 | logger.info(f"回测结果: {analysis}")
141 |
142 | # 获取交易记录
143 | if 'trades' in analysis and not analysis['trades'].empty:
144 | # 获取当天的交易记录
145 | # today_trades = analysis['trades'][
146 | # analysis['trades']['交易时间'] == current_time.strftime('%Y-%m-%d')
147 | # ]
148 | today_trades = analysis['trades']
149 | logger.info(f"当天交易记录: {today_trades}, 所有交易记录: {analysis['trades']}")
150 |
151 | # 记录当天的交易
152 | for _, trade in today_trades.iterrows():
153 | action = "buy" if trade['方向'] == '买入' else "sell"
154 | signal = 1 if action == "buy" else -1
155 |
156 | self._record_trade(
157 | symbol=symbol,
158 | action=action,
159 | timestamp=current_time,
160 | signal=signal,
161 | price=float(trade['成交价']),
162 | size=int(trade['数量']),
163 | reason=trade['交易原因']
164 | )
165 |
166 | # 发送通知
167 | message = f"交易提醒: {symbol} {action.upper()} 信号\n"
168 | message += f"价格: {trade['成交价']}, 数量: {trade['数量']}\n"
169 | message += f"原因: {trade['交易原因']}"
170 | send_notification(message)
171 |
172 | except Exception as e:
173 | logger.error(f"处理股票 {symbol} 时出错: {str(e)}")
174 | continue
175 |
176 | except Exception as e:
177 | logger.error(f"执行策略回测时出错: {str(e)}")
178 | send_notification(f"策略回测错误: {str(e)}")
179 |
180 | def _record_trade(self, symbol: str, action: str, timestamp: datetime, signal: float,
181 | price: float, size: int, reason: str):
182 | """记录交易到CSV文件"""
183 | record_file = os.path.join(self.records_dir, f"trades_{timestamp.strftime('%Y%m%d')}.csv")
184 |
185 | record = {
186 | "timestamp": timestamp,
187 | "symbol": symbol,
188 | "action": action,
189 | "signal": signal,
190 | "price": price,
191 | "size": size,
192 | "reason": reason,
193 | "strategy": "market_sentiment"
194 | }
195 |
196 | df = pd.DataFrame([record])
197 |
198 | if os.path.exists(record_file):
199 | df.to_csv(record_file, mode='a', header=False, index=False)
200 | else:
201 | df.to_csv(record_file, index=False)
202 |
203 | def run_continuously(self, interval: int = 3600):
204 | """持续运行策略回测,每小时执行一次"""
205 | logger.info("启动实盘交易系统")
206 | while True:
207 | self.execute()
208 | time.sleep(interval)
--------------------------------------------------------------------------------
/rl_model_finrl/agents/rllib/ppo_agent.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import logging
5 | from typing import Dict, List, Tuple, Any, Optional
6 | import ray
7 | from ray import tune
8 | from ray.rllib.algorithms.ppo import PPO
9 | from ray.rllib.utils.framework import try_import_tf, try_import_torch
10 | from ray.tune.registry import register_env
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 | class RLlibPPOAgent:
15 | """
16 | RLlib PPO 智能体实现
17 |
18 | 这个类封装了Ray RLlib的PPO算法,用于ETF交易环境
19 | """
20 |
21 | def __init__(
22 | self,
23 | env,
24 | model_name: str = "ppo_rllib",
25 | tensorboard_log: Optional[str] = None,
26 | seed: Optional[int] = None,
27 | verbose: int = 1,
28 | device: str = "auto",
29 | **kwargs
30 | ):
31 | """
32 | 初始化RLlib PPO智能体
33 |
34 | 参数:
35 | env: 训练环境
36 | model_name: 模型名称
37 | tensorboard_log: TensorBoard日志目录
38 | seed: 随机种子
39 | verbose: 详细程度
40 | device: 设备选择 ('cpu', 'cuda', 'auto')
41 | **kwargs: 传递给PPO构造函数的其他参数
42 | """
43 | self.env = env
44 | self.model_name = model_name
45 | self.tensorboard_log = tensorboard_log
46 | self.seed = seed
47 | self.verbose = verbose
48 | self.device = device
49 | self.ppo_config = {
50 | "framework": "torch",
51 | "num_gpus": 0 if device == "cpu" else 1,
52 | "seed": seed,
53 | # PPO特定参数
54 | "lambda": 0.95,
55 | "kl_coeff": 0.5,
56 | "clip_param": 0.2,
57 | "vf_clip_param": 10.0,
58 | "entropy_coeff": 0.01,
59 | "train_batch_size": 5000,
60 | "sgd_minibatch_size": 500,
61 | "num_sgd_iter": 10,
62 | # 通用参数
63 | "gamma": 0.99,
64 | "lr": 3e-4,
65 | "log_level": "WARN",
66 | }
67 |
68 | # 更新配置
69 | self.ppo_config.update(kwargs)
70 |
71 | # 注册环境
72 | self._register_env()
73 |
74 | # 初始化Ray(如果尚未初始化)
75 | if not ray.is_initialized():
76 | if self.verbose > 0:
77 | logger.info("初始化Ray...")
78 | ray.init(ignore_reinit_error=True, logging_level=logging.ERROR)
79 |
80 | # 创建PPO算法实例
81 | self.model = PPO(
82 | config=self.ppo_config,
83 | env=self.env.__class__.__name__,
84 | )
85 |
86 | if self.verbose > 0:
87 | logger.info(f"RLlib PPO智能体初始化完成: {model_name}")
88 |
89 | def _register_env(self):
90 | """注册环境到Ray"""
91 | env_name = self.env.__class__.__name__
92 |
93 | # 注册环境创建函数
94 | def env_creator(env_config):
95 | return self.env
96 |
97 | # 注册环境
98 | register_env(env_name, env_creator)
99 |
100 | def learn(
101 | self,
102 | total_timesteps: int = 100000,
103 | callback: Any = None,
104 | log_interval: int = 4,
105 | eval_env = None,
106 | eval_freq: int = -1,
107 | n_eval_episodes: int = 5,
108 | tb_log_name: str = "PPO",
109 | eval_log_path: Optional[str] = None,
110 | reset_num_timesteps: bool = True,
111 | ) -> "RLlibPPOAgent":
112 | """
113 | 训练模型
114 |
115 | 参数:
116 | total_timesteps: 总训练步数
117 | callback: 回调函数
118 | log_interval: 日志记录间隔
119 | eval_env: 评估环境
120 | eval_freq: 评估频率
121 | n_eval_episodes: 评估轮数
122 | tb_log_name: TensorBoard日志名称
123 | eval_log_path: 评估日志路径
124 | reset_num_timesteps: 是否重置时间步计数
125 |
126 | 返回:
127 | 训练后的智能体
128 | """
129 | if self.verbose > 0:
130 | logger.info(f"开始训练模型,总步数: {total_timesteps}")
131 |
132 | iterations = total_timesteps // self.ppo_config["train_batch_size"]
133 |
134 | for i in range(iterations):
135 | if self.verbose > 0 and i % log_interval == 0:
136 | logger.info(f"训练迭代 {i+1}/{iterations}")
137 |
138 | # 执行训练
139 | result = self.model.train()
140 |
141 | # 打印训练结果
142 | if self.verbose > 0 and i % log_interval == 0:
143 | logger.info(f" 训练回报: {result['episode_reward_mean']:.2f}")
144 | logger.info(f" 训练长度: {result['episode_len_mean']:.2f}")
145 |
146 | # 评估(如果需要)
147 | if eval_env is not None and eval_freq > 0 and i % eval_freq == 0:
148 | self.evaluate(eval_env, n_eval_episodes)
149 |
150 | return self
151 |
152 | def predict(
153 | self,
154 | observation: np.ndarray,
155 | state: Optional[Tuple] = None,
156 | deterministic: bool = True,
157 | ) -> Tuple[np.ndarray, Optional[Tuple]]:
158 | """
159 | 预测动作
160 |
161 | 参数:
162 | observation: 当前观察
163 | state: RNN状态(如果适用)
164 | deterministic: 是否确定性预测
165 |
166 | 返回:
167 | (动作, 状态)元组
168 | """
169 | action = self.model.compute_single_action(
170 | observation,
171 | explore=not deterministic
172 | )
173 | return action, state
174 |
175 | def save(self, path: str) -> None:
176 | """
177 | 保存模型
178 |
179 | 参数:
180 | path: 保存路径
181 | """
182 | os.makedirs(os.path.dirname(path), exist_ok=True)
183 | checkpoint_path = self.model.save(path)
184 | if self.verbose > 0:
185 | logger.info(f"模型已保存到: {checkpoint_path}")
186 |
187 | def load(self, path: str) -> None:
188 | """
189 | 加载模型
190 |
191 | 参数:
192 | path: 模型路径
193 | """
194 | self.model.restore(path)
195 | if self.verbose > 0:
196 | logger.info(f"已从{path}加载模型")
197 |
198 | def evaluate(
199 | self,
200 | eval_env,
201 | n_eval_episodes: int = 10,
202 | deterministic: bool = True,
203 | render: bool = False,
204 | ) -> Tuple[float, float]:
205 | """
206 | 评估模型
207 |
208 | 参数:
209 | eval_env: 评估环境
210 | n_eval_episodes: 评估轮数
211 | deterministic: 是否确定性预测
212 | render: 是否渲染环境
213 |
214 | 返回:
215 | (平均奖励, 标准差)元组
216 | """
217 | if self.verbose > 0:
218 | logger.info(f"开始评估模型,轮数: {n_eval_episodes}")
219 |
220 | episode_rewards = []
221 | episode_lengths = []
222 |
223 | for i in range(n_eval_episodes):
224 | obs = eval_env.reset()
225 | done = False
226 | episode_reward = 0.0
227 | episode_length = 0
228 |
229 | while not done:
230 | action, _ = self.predict(obs, deterministic=deterministic)
231 | obs, reward, done, info = eval_env.step(action)
232 |
233 | episode_reward += reward
234 | episode_length += 1
235 |
236 | if render:
237 | eval_env.render()
238 |
239 | episode_rewards.append(episode_reward)
240 | episode_lengths.append(episode_length)
241 |
242 | mean_reward = np.mean(episode_rewards)
243 | std_reward = np.std(episode_rewards)
244 |
245 | if self.verbose > 0:
246 | logger.info(f"评估结果 - 平均奖励: {mean_reward:.2f} +/- {std_reward:.2f}")
247 |
248 | return mean_reward, std_reward
249 |
250 | def test(
251 | self,
252 | test_env,
253 | num_episodes: int = 1,
254 | render: bool = False
255 | ) -> Tuple[pd.DataFrame, pd.DataFrame]:
256 | """
257 | 测试模型并返回资产和动作记录
258 |
259 | 参数:
260 | test_env: 测试环境
261 | num_episodes: 测试轮数
262 | render: 是否渲染环境
263 |
264 | 返回:
265 | (资产记录, 动作记录)元组
266 | """
267 | if self.verbose > 0:
268 | logger.info(f"开始测试模型,轮数: {num_episodes}")
269 |
270 | # 重置环境
271 | state = test_env.reset()
272 |
273 | # 初始化记录
274 | episode_rewards = []
275 |
276 | # 决策步骤
277 | done = False
278 | episode_reward = 0.0
279 |
280 | while not done:
281 | action, _ = self.predict(state, deterministic=True)
282 | next_state, reward, done, info = test_env.step(action)
283 |
284 | state = next_state
285 | episode_reward += reward
286 |
287 | if render:
288 | test_env.render()
289 |
290 | if self.verbose > 0:
291 | logger.info(f"测试完成,总奖励: {episode_reward:.2f}")
292 |
293 | # 获取回测数据
294 | asset_memory = test_env.save_asset_memory()
295 | action_memory = test_env.save_action_memory()
296 |
297 | return asset_memory, action_memory
--------------------------------------------------------------------------------
/src/data/data_loader.py:
--------------------------------------------------------------------------------
1 | import tushare as ts
2 | import akshare as ak
3 | import pandas as pd
4 | import backtrader as bt
5 | from datetime import datetime
6 | import os
7 | from loguru import logger
8 |
9 | class DataLoader:
10 | def __init__(self, tushare_token=None):
11 | """
12 | 初始化数据加载器
13 | :param tushare_token: Tushare的API token
14 | """
15 | self.tushare_token = tushare_token
16 | if tushare_token:
17 | ts.set_token(tushare_token)
18 | self.pro = ts.pro_api()
19 | logger.info("Tushare API初始化成功")
20 |
21 | def download_data(self, symbol, start_date, end_date):
22 | """
23 | 下载数据,支持A股、ETF和港股
24 | :param symbol: 股票代码(格式:000001.SZ, 510300.SH, 00700.HK等)或ETF代码列表
25 | :param start_date: 开始日期
26 | :param end_date: 结束日期
27 | :return: DataFrame或DataFrame列表
28 | """
29 | try:
30 | # 转换日期格式
31 | start_str = start_date.strftime("%Y%m%d")
32 | end_str = end_date.strftime("%Y%m%d")
33 |
34 | # 如果是ETF轮动策略,symbol应该是一个列表
35 | if isinstance(symbol, list):
36 | logger.info(f"开始下载多个ETF数据: {symbol}")
37 | data_list = []
38 | for etf in symbol:
39 | df = self._download_etf_data(etf, start_date, end_date)
40 | if not df.empty:
41 | data = PandasData(dataname=df, ts_code=etf, fromdate=start_date, todate=end_date)
42 | data_list.append(data)
43 | logger.info(f"成功下载 {len(data_list)} 个ETF数据")
44 | return data_list
45 |
46 | # 判断市场类型
47 | if symbol.endswith(('.SH', '.SZ')): # A股或ETF
48 | if symbol.startswith('51') or symbol.startswith('159'): # ETF
49 | logger.info(f"下载ETF数据: {symbol}")
50 | df = self._download_etf_data(symbol, start_date, end_date)
51 | else: # A股
52 | logger.info(f"下载A股数据: {symbol}")
53 | df = self._download_stock_data(symbol, start_str, end_str)
54 | elif symbol.endswith('.HK'): # 港股
55 | logger.info(f"使用AKShare下载港股数据: {symbol}")
56 | df = self._download_hk_data(symbol, start_date, end_date)
57 | else:
58 | raise ValueError(f"不支持的市场类型: {symbol}")
59 |
60 | if df.empty:
61 | logger.warning(f"下载数据为空: {symbol}")
62 | return None
63 |
64 | logger.info(f"下载数据成功: {symbol},数据长度: {len(df)}, 日期范围: {df.index[0]} 至 {df.index[-1]}")
65 |
66 | # 创建PandasData对象并设置股票代码和时间范围
67 | data = PandasData(dataname=df, ts_code=symbol, fromdate=start_date, todate=end_date)
68 | return data
69 |
70 | except Exception as e:
71 | logger.error(f"下载数据失败: {str(e)}")
72 | import traceback
73 | traceback.print_exc()
74 | raise
75 |
76 | def _download_stock_data(self, symbol, start_date, end_date):
77 | """下载A股数据"""
78 | if not self.tushare_token:
79 | raise ValueError("需要设置Tushare token才能下载A股数据")
80 |
81 | df = self.pro.daily(ts_code=symbol, start_date=start_date, end_date=end_date)
82 | if df.empty:
83 | return pd.DataFrame()
84 |
85 | # 重命名列以匹配backtrader要求
86 | df = df.rename(columns={
87 | 'trade_date': 'date',
88 | 'open': 'open',
89 | 'high': 'high',
90 | 'low': 'low',
91 | 'close': 'close',
92 | 'vol': 'volume'
93 | })
94 |
95 | # 转换日期格式并设置为索引
96 | df['date'] = pd.to_datetime(df['date'])
97 | df = df.set_index('date')
98 | df = df.sort_index() # 确保按日期升序
99 |
100 | return df[['open', 'high', 'low', 'close', 'volume']]
101 |
102 | def _download_etf_data(self, symbol, start_date, end_date):
103 | """下载ETF数据"""
104 | symbol_code = symbol.split('.')[0] # 去掉市场后缀
105 | start_str = start_date.strftime('%Y%m%d')
106 | end_str = end_date.strftime('%Y%m%d')
107 |
108 | try:
109 | # 判断是否有tushare_token
110 | if self.tushare_token:
111 | logger.info(f"使用Tushare下载ETF数据: {symbol}")
112 | df = self.pro.fund_daily(ts_code=symbol, start_date=start_str, end_date=end_str)
113 |
114 | if not df.empty:
115 | # 重命名列以匹配backtrader要求
116 | df = df.rename(columns={
117 | 'trade_date': 'date',
118 | 'open': 'open',
119 | 'high': 'high',
120 | 'low': 'low',
121 | 'close': 'close',
122 | 'vol': 'volume'
123 | })
124 |
125 | # 转换日期格式并设置为索引
126 | df['date'] = pd.to_datetime(df['date'])
127 | df = df.set_index('date')
128 | df = df.sort_index() # 确保按日期升序
129 | else:
130 | logger.warning(f"Tushare未返回ETF数据,尝试使用AKShare: {symbol}")
131 | df = self._download_etf_from_akshare(symbol_code, start_date, end_date)
132 | else:
133 | logger.info(f"未设置Tushare token,使用AKShare下载ETF数据: {symbol}")
134 | df = self._download_etf_from_akshare(symbol_code, start_date, end_date)
135 |
136 | if df.empty:
137 | logger.warning(f"下载ETF数据为空 - 股票代码: {symbol}, 日期范围: {start_date} 至 {end_date}")
138 | return pd.DataFrame()
139 |
140 | logger.info(f"下载ETF数据成功: {symbol},数据长度: {len(df)}, 日期范围: {df.index[0]} 至 {df.index[-1]}")
141 | return df[['open', 'high', 'low', 'close', 'volume']]
142 |
143 | except Exception as e:
144 | logger.error(f"下载ETF数据失败: {str(e)}")
145 | import traceback
146 | traceback.print_exc()
147 | return pd.DataFrame()
148 |
149 | def _download_etf_from_akshare(self, symbol_code, start_date, end_date):
150 | """使用AKShare下载ETF数据"""
151 | df = ak.fund_etf_hist_em(symbol=symbol_code, period="daily")
152 |
153 | # 重命名列以匹配backtrader要求
154 | df = df.rename(columns={
155 | '日期': 'date',
156 | '开盘': 'open',
157 | '最高': 'high',
158 | '最低': 'low',
159 | '收盘': 'close',
160 | '成交量': 'volume'
161 | })
162 |
163 | # 转换日期格式并设置为索引
164 | df['date'] = pd.to_datetime(df['date'])
165 | df = df.set_index('date')
166 | df = df.sort_index() # 确保按日期升序
167 |
168 | # 过滤日期范围
169 | logger.info(f"过滤日期范围: {start_date} 至 {end_date}")
170 | mask = (df.index >= pd.Timestamp(start_date)) & (df.index <= pd.Timestamp(end_date))
171 | df = df[mask]
172 |
173 | return df
174 |
175 | def _download_hk_data(self, symbol, start_date, end_date):
176 | """下载港股数据"""
177 | symbol_code = symbol.split('.')[0] # 去掉市场后缀
178 |
179 | try:
180 | df = ak.stock_hk_daily(symbol=symbol_code)
181 |
182 | # 重命名列以匹配backtrader要求
183 | df = df.rename(columns={
184 | '日期': 'date',
185 | '开盘': 'open',
186 | '最高': 'high',
187 | '最低': 'low',
188 | '收盘': 'close',
189 | '成交量': 'volume'
190 | })
191 |
192 | # 转换日期格式并设置为索引
193 | df['date'] = pd.to_datetime(df['date'])
194 | df = df.set_index('date')
195 |
196 | # 过滤日期范围
197 | df = df.loc[start_date:end_date]
198 |
199 | return df[['open', 'high', 'low', 'close', 'volume']]
200 |
201 | except Exception as e:
202 | logger.error(f"下载港股数据失败: {str(e)}")
203 | return pd.DataFrame()
204 |
205 | class PandasData(bt.feeds.PandasData):
206 | """自定义PandasData类,用于加载数据"""
207 | params = (
208 | ('datetime', None), # 使用索引作为日期
209 | ('open', 'open'),
210 | ('high', 'high'),
211 | ('low', 'low'),
212 | ('close', 'close'),
213 | ('volume', 'volume'),
214 | ('openinterest', None),
215 | ('ts_code', None), # 股票代码
216 | ('fromdate', None),
217 | ('todate', None),
218 | )
219 |
220 | def __init__(self, **kwargs):
221 | """初始化数据源"""
222 | # 从kwargs中获取ts_code和日期范围
223 | self.ts_code = kwargs.pop('ts_code', None)
224 | self.fromdate = kwargs.pop('fromdate', None)
225 | self.todate = kwargs.pop('todate', None)
226 |
227 | # 确保参数被正确设置
228 | if self.ts_code:
229 | self.params.ts_code = self.ts_code
230 | if self.fromdate:
231 | self.params.fromdate = self.fromdate
232 | if self.todate:
233 | self.params.todate = self.todate
234 |
235 | # 调用父类初始化
236 | super().__init__(**kwargs)
237 |
238 | # 设置数据源的名称
239 | if self.ts_code:
240 | self._name = self.ts_code
--------------------------------------------------------------------------------
/tools/plot_general_json.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import plotly.express as px
6 | import plotly.graph_objects as go
7 | import streamlit as st
8 | from datetime import datetime
9 | import os
10 |
11 | def load_json_data(file_path=None, uploaded_file=None):
12 | """从文件或上传的数据加载JSON数据"""
13 | try:
14 | if uploaded_file is not None:
15 | data = json.load(uploaded_file)
16 | elif file_path is not None:
17 | with open(file_path, 'r', encoding='utf-8') as f:
18 | data = json.load(f)
19 | else:
20 | st.error("未提供数据源")
21 | return None
22 | return data
23 | except Exception as e:
24 | st.error(f"加载JSON文件时出错:{e}")
25 | return None
26 |
27 | def flatten_json(data, parent_key='', sep='_'):
28 | """展平嵌套的JSON数据结构"""
29 | items = []
30 | for k, v in data.items():
31 | new_key = f"{parent_key}{sep}{k}" if parent_key else k
32 | if isinstance(v, dict):
33 | items.extend(flatten_json(v, new_key, sep=sep).items())
34 | elif isinstance(v, list):
35 | if len(v) > 0 and isinstance(v[0], dict):
36 | # 处理字典列表
37 | for i, item in enumerate(v):
38 | items.extend(flatten_json(item, f"{new_key}{sep}{i}", sep=sep).items())
39 | else:
40 | items.append((new_key, str(v)))
41 | else:
42 | items.append((new_key, v))
43 | return dict(items)
44 |
45 | def process_json_data(data):
46 | """处理JSON数据以准备可视化"""
47 | if isinstance(data, list):
48 | # 如果是列表,尝试将其转换为DataFrame
49 | if len(data) > 0 and isinstance(data[0], dict):
50 | df = pd.DataFrame(data)
51 | else:
52 | df = pd.DataFrame({"值": data})
53 | elif isinstance(data, dict):
54 | # 如果是字典,展平它
55 | flat_data = flatten_json(data)
56 | df = pd.DataFrame([flat_data])
57 | else:
58 | st.error("不支持的JSON数据格式")
59 | return None
60 |
61 | # 尝试转换日期列
62 | for col in df.columns:
63 | if df[col].dtype == object:
64 | try:
65 | df[col] = pd.to_datetime(df[col])
66 | except:
67 | pass
68 |
69 | return df
70 |
71 | def get_numeric_columns(df):
72 | """获取数值类型的列"""
73 | return df.select_dtypes(include=[np.number]).columns.tolist()
74 |
75 | def get_datetime_columns(df):
76 | """获取日期时间类型的列"""
77 | return df.select_dtypes(include=['datetime64']).columns.tolist()
78 |
79 | def get_categorical_columns(df):
80 | """获取分类类型的列"""
81 | return df.select_dtypes(include=['object', 'category']).columns.tolist()
82 |
83 | def plot_time_series(df, x_col, y_col, title=None):
84 | """创建时间序列图表"""
85 | fig = px.line(df, x=x_col, y=y_col,
86 | title=title or f'{y_col}随{x_col}的变化',
87 | labels={x_col: x_col, y_col: y_col})
88 |
89 | fig.update_layout(
90 | xaxis_title=x_col,
91 | yaxis_title=y_col,
92 | hovermode='x unified'
93 | )
94 |
95 | return fig
96 |
97 | def plot_scatter(df, x_col, y_col, color_col=None, title=None):
98 | """创建散点图"""
99 | fig = px.scatter(df, x=x_col, y=y_col, color=color_col,
100 | title=title or f'{x_col}与{y_col}的关系',
101 | labels={x_col: x_col, y_col: y_col})
102 |
103 | fig.update_layout(
104 | xaxis_title=x_col,
105 | yaxis_title=y_col,
106 | hovermode='closest'
107 | )
108 |
109 | return fig
110 |
111 | def plot_bar(df, x_col, y_col, title=None):
112 | """创建柱状图"""
113 | fig = px.bar(df, x=x_col, y=y_col,
114 | title=title or f'{x_col}的{y_col}分布',
115 | labels={x_col: x_col, y_col: y_col})
116 |
117 | fig.update_layout(
118 | xaxis_title=x_col,
119 | yaxis_title=y_col,
120 | bargap=0.2
121 | )
122 |
123 | return fig
124 |
125 | def plot_histogram(df, col, title=None):
126 | """创建直方图"""
127 | fig = px.histogram(df, x=col,
128 | title=title or f'{col}的分布',
129 | labels={col: col})
130 |
131 | fig.update_layout(
132 | xaxis_title=col,
133 | yaxis_title='频率',
134 | bargap=0.1
135 | )
136 |
137 | return fig
138 |
139 | def plot_box(df, x_col, y_col, title=None):
140 | """创建箱线图"""
141 | fig = px.box(df, x=x_col, y=y_col,
142 | title=title or f'{y_col}按{x_col}的分布',
143 | labels={x_col: x_col, y_col: y_col})
144 |
145 | fig.update_layout(
146 | xaxis_title=x_col,
147 | yaxis_title=y_col
148 | )
149 |
150 | return fig
151 |
152 | def plot_heatmap(df, x_col, y_col, values_col, title=None):
153 | """创建热力图"""
154 | heatmap_data = df.pivot_table(
155 | values=values_col,
156 | index=y_col,
157 | columns=x_col,
158 | aggfunc='mean'
159 | )
160 |
161 | fig = go.Figure(data=go.Heatmap(
162 | z=heatmap_data.values,
163 | x=heatmap_data.columns,
164 | y=heatmap_data.index,
165 | colorscale='Viridis'
166 | ))
167 |
168 | fig.update_layout(
169 | title=title or f'{x_col}和{y_col}的{values_col}热力图',
170 | xaxis_title=x_col,
171 | yaxis_title=y_col
172 | )
173 |
174 | return fig
175 |
176 | def main():
177 | st.title("通用JSON数据可视化工具")
178 | st.write("此应用程序可以可视化任何结构的JSON数据。")
179 |
180 | # 数据输入选项
181 | data_source = st.radio(
182 | "选择数据来源",
183 | ["从cache目录加载", "上传JSON文件"]
184 | )
185 |
186 | data = None
187 | if data_source == "从cache目录加载":
188 | json_files = [f for f in os.listdir('../cache') if f.endswith('.json')]
189 | if not json_files:
190 | st.error("在'cache'目录中未找到JSON文件。")
191 | return
192 |
193 | selected_file = st.selectbox("选择要可视化的JSON文件", json_files)
194 | file_path = os.path.join('../cache', selected_file)
195 | data = load_json_data(file_path=file_path)
196 | else:
197 | uploaded_file = st.file_uploader("上传JSON文件", type=['json'])
198 | if uploaded_file is not None:
199 | data = load_json_data(uploaded_file=uploaded_file)
200 |
201 | if data is None:
202 | return
203 |
204 | # 处理数据
205 | df = process_json_data(data)
206 | if df is None:
207 | return
208 |
209 | # 显示基本信息
210 | st.subheader("数据概览")
211 | st.write(f"记录数量:{len(df)}")
212 | st.write(f"列数量:{len(df.columns)}")
213 |
214 | # 获取不同类型的列
215 | numeric_cols = get_numeric_columns(df)
216 | datetime_cols = get_datetime_columns(df)
217 | categorical_cols = get_categorical_columns(df)
218 |
219 | # 可视化选项
220 | viz_type = st.selectbox(
221 | "选择可视化类型",
222 | ["时间序列图", "散点图", "柱状图", "直方图", "箱线图", "热力图"]
223 | )
224 |
225 | if viz_type == "时间序列图":
226 | if not datetime_cols:
227 | st.warning("数据中没有日期时间列,无法创建时间序列图。")
228 | return
229 |
230 | x_col = st.selectbox("选择时间列", datetime_cols)
231 | y_col = st.selectbox("选择数值列", numeric_cols)
232 |
233 | fig = plot_time_series(df, x_col, y_col)
234 | st.plotly_chart(fig)
235 |
236 | elif viz_type == "散点图":
237 | x_col = st.selectbox("选择X轴", numeric_cols)
238 | y_col = st.selectbox("选择Y轴", numeric_cols)
239 | color_col = st.selectbox("选择颜色分类(可选)", ["无"] + categorical_cols)
240 |
241 | fig = plot_scatter(df, x_col, y_col,
242 | color_col if color_col != "无" else None)
243 | st.plotly_chart(fig)
244 |
245 | elif viz_type == "柱状图":
246 | x_col = st.selectbox("选择分类列", categorical_cols)
247 | y_col = st.selectbox("选择数值列", numeric_cols)
248 |
249 | fig = plot_bar(df, x_col, y_col)
250 | st.plotly_chart(fig)
251 |
252 | elif viz_type == "直方图":
253 | col = st.selectbox("选择数值列", numeric_cols)
254 |
255 | fig = plot_histogram(df, col)
256 | st.plotly_chart(fig)
257 |
258 | elif viz_type == "箱线图":
259 | x_col = st.selectbox("选择分组列", categorical_cols)
260 | y_col = st.selectbox("选择数值列", numeric_cols)
261 |
262 | fig = plot_box(df, x_col, y_col)
263 | st.plotly_chart(fig)
264 |
265 | elif viz_type == "热力图":
266 | x_col = st.selectbox("选择X轴分类", categorical_cols)
267 | y_col = st.selectbox("选择Y轴分类", categorical_cols)
268 | values_col = st.selectbox("选择数值列", numeric_cols)
269 |
270 | fig = plot_heatmap(df, x_col, y_col, values_col)
271 | st.plotly_chart(fig)
272 |
273 | # 显示数据统计
274 | if st.checkbox("显示数据统计"):
275 | st.subheader("数据统计")
276 | st.write("数值列统计:")
277 | st.dataframe(df[numeric_cols].describe())
278 |
279 | if categorical_cols:
280 | st.write("分类列统计:")
281 | for col in categorical_cols:
282 | st.write(f"\n{col}的唯一值数量:{df[col].nunique()}")
283 | st.dataframe(df[col].value_counts().head())
284 |
285 | # 显示原始数据
286 | if st.checkbox("显示原始数据"):
287 | st.subheader("原始数据")
288 | st.dataframe(df)
289 |
290 | # 添加下载按钮
291 | csv = df.to_csv(index=False).encode('utf-8')
292 | st.download_button(
293 | label="下载数据为CSV",
294 | data=csv,
295 | file_name=f'data_{datetime.now().strftime("%Y%m%d")}.csv',
296 | mime='text/csv',
297 | )
298 |
299 | if __name__ == "__main__":
300 | main()
--------------------------------------------------------------------------------
/rl_model_finrl/meta/preprocessor/feature_engineer.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import talib
4 | from typing import Dict, List, Tuple, Union, Optional
5 |
6 | class FeatureEngineer:
7 | """
8 | 金融特征工程类
9 |
10 | 用于计算和添加金融技术指标和特征
11 | """
12 |
13 | def __init__(self,
14 | use_technical_indicators: bool = True,
15 | use_vix: bool = False,
16 | use_turbulence: bool = False,
17 | use_sentiment: bool = False):
18 | """
19 | 初始化特征工程器
20 |
21 | 参数:
22 | use_technical_indicators: 是否使用技术指标
23 | use_vix: 是否使用波动率指数
24 | use_turbulence: 是否使用市场波动指标
25 | use_sentiment: 是否使用情绪指标
26 | """
27 | self.use_technical_indicators = use_technical_indicators
28 | self.use_vix = use_vix
29 | self.use_turbulence = use_turbulence
30 | self.use_sentiment = use_sentiment
31 |
32 | def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
33 | """
34 | 预处理数据并添加特征
35 |
36 | 参数:
37 | df: 原始数据DataFrame
38 |
39 | 返回:
40 | 处理后的DataFrame
41 | """
42 | data = df.copy()
43 |
44 | # 添加技术指标
45 | if self.use_technical_indicators:
46 | data = self.add_technical_indicators(data)
47 |
48 | # 添加波动率指数
49 | if self.use_vix:
50 | data = self.add_vix(data)
51 |
52 | # 添加市场波动指标
53 | if self.use_turbulence:
54 | data = self.add_turbulence(data)
55 |
56 | # 添加情绪指标
57 | if self.use_sentiment:
58 | data = self.add_sentiment(data)
59 |
60 | # 填充NaN值
61 | data = self.fill_missing_values(data)
62 |
63 | return data
64 |
65 | def add_technical_indicators(self, data: pd.DataFrame) -> pd.DataFrame:
66 | """
67 | 添加技术指标
68 |
69 | 参数:
70 | data: 价格数据DataFrame
71 |
72 | 返回:
73 | 添加技术指标后的DataFrame
74 | """
75 | df = data.copy()
76 |
77 | # 确保列名标准化
78 | price_col = 'close' if 'close' in df.columns else 'Close'
79 | high_col = 'high' if 'high' in df.columns else 'High'
80 | low_col = 'low' if 'low' in df.columns else 'Low'
81 | volume_col = 'volume' if 'volume' in df.columns else 'Volume'
82 |
83 | # 使用TA-Lib计算技术指标 (如果可用)
84 | try:
85 | # MACD
86 | macd, macd_signal, macd_hist = talib.MACD(df[price_col])
87 | df['macd'] = macd
88 | df['macd_signal'] = macd_signal
89 | df['macd_hist'] = macd_hist
90 |
91 | # RSI
92 | df['rsi_6'] = talib.RSI(df[price_col], timeperiod=6)
93 | df['rsi_14'] = talib.RSI(df[price_col], timeperiod=14)
94 | df['rsi_30'] = talib.RSI(df[price_col], timeperiod=30)
95 |
96 | # CCI
97 | df['cci'] = talib.CCI(df[high_col], df[low_col], df[price_col], timeperiod=14)
98 |
99 | # ADX
100 | df['adx'] = talib.ADX(df[high_col], df[low_col], df[price_col], timeperiod=14)
101 |
102 | # 布林带
103 | df['boll_upper'], df['boll_middle'], df['boll_lower'] = talib.BBANDS(
104 | df[price_col], timeperiod=20, nbdevup=2, nbdevdn=2, matype=0)
105 |
106 | # ATR
107 | df['atr'] = talib.ATR(df[high_col], df[low_col], df[price_col], timeperiod=14)
108 |
109 | # 移动平均线
110 | df['sma_5'] = talib.SMA(df[price_col], timeperiod=5)
111 | df['sma_10'] = talib.SMA(df[price_col], timeperiod=10)
112 | df['sma_20'] = talib.SMA(df[price_col], timeperiod=20)
113 | df['sma_60'] = talib.SMA(df[price_col], timeperiod=60)
114 |
115 | # WILLR
116 | df['willr'] = talib.WILLR(df[high_col], df[low_col], df[price_col], timeperiod=14)
117 |
118 | # ROC
119 | df['roc'] = talib.ROC(df[price_col], timeperiod=10)
120 |
121 | # OBV
122 | df['obv'] = talib.OBV(df[price_col], df[volume_col])
123 |
124 | except (ImportError, AttributeError):
125 | # 如果TA-Lib不可用,使用Pandas计算
126 | # MACD
127 | df['ema12'] = df[price_col].ewm(span=12, adjust=False).mean()
128 | df['ema26'] = df[price_col].ewm(span=26, adjust=False).mean()
129 | df['macd'] = df['ema12'] - df['ema26']
130 | df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
131 | df['macd_hist'] = df['macd'] - df['macd_signal']
132 |
133 | # RSI
134 | delta = df[price_col].diff()
135 | gain = delta.copy()
136 | loss = delta.copy()
137 | gain[gain < 0] = 0
138 | loss[loss > 0] = 0
139 | loss = -loss
140 |
141 | # RSI 14
142 | avg_gain = gain.rolling(window=14).mean()
143 | avg_loss = loss.rolling(window=14).mean()
144 | rs = avg_gain / avg_loss
145 | df['rsi_14'] = 100 - (100 / (1 + rs))
146 |
147 | # CCI
148 | df['tp'] = (df[high_col] + df[low_col] + df[price_col]) / 3
149 | df['tp_ma'] = df['tp'].rolling(window=20).mean()
150 | mean_dev = df['tp'].rolling(window=20).apply(lambda x: abs(x - x.mean()).mean())
151 | df['cci'] = (df['tp'] - df['tp_ma']) / (0.015 * mean_dev)
152 |
153 | # 布林带
154 | df['sma_20'] = df[price_col].rolling(window=20).mean()
155 | df['boll_upper'] = df['sma_20'] + 2 * df[price_col].rolling(window=20).std()
156 | df['boll_lower'] = df['sma_20'] - 2 * df[price_col].rolling(window=20).std()
157 |
158 | # ATR
159 | df['tr1'] = df[high_col] - df[low_col]
160 | df['tr2'] = abs(df[high_col] - df[price_col].shift())
161 | df['tr3'] = abs(df[low_col] - df[price_col].shift())
162 | df['tr'] = df[['tr1', 'tr2', 'tr3']].max(axis=1)
163 | df['atr'] = df['tr'].rolling(window=14).mean()
164 |
165 | # 移动平均线
166 | df['sma_5'] = df[price_col].rolling(window=5).mean()
167 | df['sma_10'] = df[price_col].rolling(window=10).mean()
168 | df['sma_60'] = df[price_col].rolling(window=60).mean()
169 |
170 | # 清理临时列
171 | df = df.drop(['tr1', 'tr2', 'tr3', 'tr', 'tp', 'tp_ma', 'ema12', 'ema26'],
172 | axis=1, errors='ignore')
173 |
174 | # 涨跌幅
175 | df['daily_return'] = df[price_col].pct_change()
176 |
177 | # 波动率
178 | df['volatility'] = df['daily_return'].rolling(window=20).std() * np.sqrt(252)
179 |
180 | return df
181 |
182 | def add_vix(self, data: pd.DataFrame, vix_data: Optional[pd.DataFrame] = None) -> pd.DataFrame:
183 | """
184 | 添加VIX波动率指数
185 |
186 | 参数:
187 | data: 原始数据DataFrame
188 | vix_data: VIX数据DataFrame(可选)
189 |
190 | 返回:
191 | 添加VIX后的DataFrame
192 | """
193 | df = data.copy()
194 |
195 | if vix_data is not None:
196 | # 如果提供了VIX数据,合并到主数据中
197 | vix_data = vix_data.rename(columns={'close': 'vix'})
198 | df = pd.merge(df, vix_data[['vix']],
199 | left_index=True, right_index=True,
200 | how='left')
201 | else:
202 | # 如果没有VIX数据,使用收益率的滚动波动率作为VIX代理
203 | price_col = 'close' if 'close' in df.columns else 'Close'
204 | returns = df[price_col].pct_change()
205 | df['vix'] = returns.rolling(window=20).std() * np.sqrt(252) * 100
206 |
207 | return df
208 |
209 | def add_turbulence(self, data: pd.DataFrame, window: int = 252) -> pd.DataFrame:
210 | """
211 | 添加市场波动指标
212 |
213 | 参数:
214 | data: 原始数据DataFrame
215 | window: 计算窗口
216 |
217 | 返回:
218 | 添加波动指标后的DataFrame
219 | """
220 | df = data.copy()
221 |
222 | # 计算收益率
223 | price_col = 'close' if 'close' in df.columns else 'Close'
224 | df['return'] = df[price_col].pct_change()
225 |
226 | # 计算波动指标
227 | df['turbulence'] = df['return'].rolling(window=window).apply(
228 | lambda x: np.sum(np.square(x - x.mean())) / len(x)
229 | )
230 |
231 | return df
232 |
233 | def add_sentiment(self, data: pd.DataFrame, sentiment_data: Optional[pd.DataFrame] = None) -> pd.DataFrame:
234 | """
235 | 添加市场情绪指标
236 |
237 | 参数:
238 | data: 原始数据DataFrame
239 | sentiment_data: 情绪数据DataFrame(可选)
240 |
241 | 返回:
242 | 添加情绪指标后的DataFrame
243 | """
244 | df = data.copy()
245 |
246 | if sentiment_data is not None:
247 | # 如果提供了情绪数据,合并到主数据中
248 | df = pd.merge(df, sentiment_data,
249 | left_index=True, right_index=True,
250 | how='left')
251 | else:
252 | # 如果没有情绪数据,简单地添加一个基于技术指标的情绪代理
253 | if 'rsi_14' in df.columns:
254 | # RSI的简单情绪指标: RSI高表示乐观,RSI低表示悲观
255 | df['sentiment'] = (df['rsi_14'] - 50) / 50 # 归一化到[-1, 1]
256 | elif 'macd' in df.columns and 'macd_signal' in df.columns:
257 | # 基于MACD信号的情绪: MACD > 信号线表示乐观
258 | df['sentiment'] = np.where(df['macd'] > df['macd_signal'], 1, -1)
259 | else:
260 | # 如果没有其他指标,使用收益率的动量作为情绪代理
261 | price_col = 'close' if 'close' in df.columns else 'Close'
262 | returns = df[price_col].pct_change()
263 | df['sentiment'] = returns.rolling(window=5).mean() / returns.rolling(window=5).std()
264 |
265 | return df
266 |
267 | def fill_missing_values(self, data: pd.DataFrame) -> pd.DataFrame:
268 | """
269 | 填充缺失值
270 |
271 | 参数:
272 | data: 包含缺失值的DataFrame
273 |
274 | 返回:
275 | 填充缺失值后的DataFrame
276 | """
277 | df = data.copy()
278 |
279 | # 前向填充(填充交易日中间缺失的数据)
280 | df = df.fillna(method='ffill')
281 |
282 | # 后向填充(处理最早的数据)
283 | df = df.fillna(method='bfill')
284 |
285 | # 对于仍然缺失的值,用0填充
286 | df = df.fillna(0)
287 |
288 | return df
--------------------------------------------------------------------------------
/rl_model_finrl/agents/stablebaseline3/dqn_agent.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | from typing import Any, Dict, List, Tuple, Type, Union
5 | import torch
6 | import torch.nn as nn
7 | from stable_baselines3 import DQN
8 | from stable_baselines3.common.callbacks import BaseCallback
9 | from stable_baselines3.common.vec_env import DummyVecEnv
10 | from stable_baselines3.common.logger import configure
11 | import logging
12 | import matplotlib.pyplot as plt
13 |
14 | from src.strategies.rl_model_finrl.config import (
15 | GAMMA,
16 | LEARNING_RATE,
17 | BATCH_SIZE,
18 | REPLAY_BUFFER_SIZE,
19 | NUM_EPISODES,
20 | TENSORBOARD_PATH
21 | )
22 |
23 | class TensorboardCallback(BaseCallback):
24 | """用于记录训练过程中的指标的回调函数"""
25 |
26 | def __init__(self, verbose=0):
27 | super(TensorboardCallback, self).__init__(verbose)
28 | self.training_env = None
29 |
30 | def _on_training_start(self) -> None:
31 | """训练开始时设置训练环境引用"""
32 | self.training_env = self.model.get_env()
33 | self._log_freq = 1 # 记录频率 = 1个回合
34 |
35 | def _on_step(self) -> bool:
36 | """每步执行,并记录指标"""
37 | if self.n_calls % self._log_freq == 0:
38 | # 从训练环境中获取信息
39 | info = self.training_env.buf_infos[0]
40 | portfolio_value = info.get('portfolio_value', 0)
41 | step_return = info.get('step_return', 0)
42 |
43 | # 记录到TensorBoard
44 | self.logger.record('portfolio_value', portfolio_value)
45 | self.logger.record('step_return', step_return)
46 |
47 | return True
48 |
49 | class DQNAgent:
50 | """使用 stable-baselines3 的DQN智能体"""
51 |
52 | def __init__(
53 | self,
54 | env,
55 | model_name: str = "dqn_etf_trading",
56 | policy_type: str = "MlpPolicy",
57 | learning_rate: float = LEARNING_RATE,
58 | gamma: float = GAMMA,
59 | batch_size: int = BATCH_SIZE,
60 | buffer_size: int = REPLAY_BUFFER_SIZE,
61 | exploration_fraction: float = 0.2,
62 | exploration_initial_eps: float = 1.0,
63 | exploration_final_eps: float = 0.1,
64 | verbose: int = 1,
65 | tensorboard_log: str = TENSORBOARD_PATH,
66 | device: str = "auto"
67 | ):
68 | """
69 | 初始化DQN智能体
70 |
71 | 参数:
72 | env: 交易环境
73 | model_name: 模型名称
74 | policy_type: 策略类型,默认MlpPolicy
75 | learning_rate: 学习率
76 | gamma: 折扣因子
77 | batch_size: 批处理大小
78 | buffer_size: 经验回放缓冲区大小
79 | exploration_fraction: 探索阶段占总训练步数的比例
80 | exploration_initial_eps: 初始探索率
81 | exploration_final_eps: 最终探索率
82 | verbose: 日志详细程度
83 | tensorboard_log: TensorBoard日志目录
84 | device: 运行设备,"auto"会自动选择GPU或CPU
85 | """
86 | self.env = env
87 | self.model_name = model_name
88 |
89 | # 设置日志
90 | logging.basicConfig(level=logging.INFO)
91 | self.logger = logging.getLogger(__name__)
92 |
93 | # 检查环境是否被向量化
94 | if not isinstance(env, DummyVecEnv):
95 | self.env = DummyVecEnv([lambda: env])
96 |
97 | # 创建日志目录
98 | if not os.path.exists(tensorboard_log):
99 | os.makedirs(tensorboard_log)
100 |
101 | # 设置TensorBoard日志
102 | log_dir = os.path.join(tensorboard_log, model_name)
103 | self.logger.info(f"TensorBoard日志目录: {log_dir}")
104 |
105 | # 创建DQN模型
106 | self.model = DQN(
107 | policy=policy_type,
108 | env=self.env,
109 | learning_rate=learning_rate,
110 | gamma=gamma,
111 | batch_size=batch_size,
112 | buffer_size=buffer_size,
113 | exploration_fraction=exploration_fraction,
114 | exploration_initial_eps=exploration_initial_eps,
115 | exploration_final_eps=exploration_final_eps,
116 | tensorboard_log=log_dir,
117 | verbose=verbose,
118 | device=device
119 | )
120 |
121 | self.logger.info(f"初始化DQN智能体,模型名称: {model_name}")
122 |
123 | def train(
124 | self,
125 | total_timesteps: int = NUM_EPISODES * 100, # 假设每个回合约100步
126 | tb_log_name: str = "dqn_training",
127 | eval_freq: int = 10000,
128 | n_eval_episodes: int = 5,
129 | eval_env = None
130 | ):
131 | """
132 | 训练DQN智能体
133 |
134 | 参数:
135 | total_timesteps: 总训练步数
136 | tb_log_name: TensorBoard日志名称
137 | eval_freq: 评估频率
138 | n_eval_episodes: 评估回合数
139 | eval_env: 评估环境,如果None则使用训练环境
140 |
141 | 返回:
142 | 训练后的模型
143 | """
144 | # 设置评估环境
145 | if eval_env is None:
146 | eval_env = self.env
147 |
148 | # 创建TensorBoard回调
149 | callback = TensorboardCallback()
150 |
151 | # 开始训练
152 | self.logger.info(f"开始训练DQN模型,总步数: {total_timesteps}")
153 | self.model.learn(
154 | total_timesteps=total_timesteps,
155 | tb_log_name=tb_log_name,
156 | callback=callback
157 | )
158 |
159 | # 保存模型
160 | model_path = os.path.join("models", f"{self.model_name}.zip")
161 | self.model.save(model_path)
162 | self.logger.info(f"模型已保存到 {model_path}")
163 |
164 | return self.model
165 |
166 | def predict(self, observation, state=None, deterministic=True):
167 | """
168 | 使用模型进行预测
169 |
170 | 参数:
171 | observation: 观察状态
172 | state: 隐藏状态(如适用)
173 | deterministic: 是否确定性预测
174 |
175 | 返回:
176 | 预测的动作
177 | """
178 | return self.model.predict(observation, state, deterministic)
179 |
180 | def load(self, path):
181 | """
182 | 加载训练好的模型
183 |
184 | 参数:
185 | path: 模型路径
186 |
187 | 返回:
188 | 加载的模型
189 | """
190 | self.model = DQN.load(path, env=self.env)
191 | self.logger.info(f"模型已加载: {path}")
192 | return self.model
193 |
194 | def save(self, path):
195 | """
196 | 保存当前模型
197 |
198 | 参数:
199 | path: 保存路径
200 | """
201 | self.model.save(path)
202 | self.logger.info(f"模型已保存: {path}")
203 |
204 | def test(self, test_env, num_episodes=1, render=False):
205 | """
206 | 在测试环境中评估模型
207 |
208 | 参数:
209 | test_env: 测试环境
210 | num_episodes: 测试回合数
211 | render: 是否渲染环境
212 |
213 | 返回:
214 | 资产记忆和交易记忆
215 | """
216 | self.logger.info(f"开始测试模型,回合数: {num_episodes}")
217 |
218 | # 如果测试环境不是向量化的,则向量化
219 | if not isinstance(test_env, DummyVecEnv):
220 | test_env = DummyVecEnv([lambda: test_env])
221 |
222 | # 初始化统计信息
223 | total_rewards = []
224 | portfolio_values = []
225 | all_trades = []
226 |
227 | # 多次测试
228 | for episode in range(num_episodes):
229 | # 初始化环境
230 | obs = test_env.reset()
231 | done = False
232 | total_reward = 0
233 | step = 0
234 |
235 | while not done:
236 | # 使用模型选择动作
237 | action, _states = self.model.predict(obs, deterministic=True)
238 |
239 | # 执行动作
240 | obs, reward, done, info = test_env.step(action)
241 |
242 | # 更新统计信息
243 | total_reward += reward[0]
244 | step += 1
245 |
246 | # 收集交易信息
247 | if 'trades' in info[0]:
248 | all_trades.extend(info[0]['trades'])
249 |
250 | # 收集投资组合价值
251 | if 'portfolio_value' in info[0]:
252 | portfolio_values.append(info[0]['portfolio_value'])
253 |
254 | # 如果需要,渲染环境
255 | if render:
256 | test_env.render()
257 |
258 | # 记录回合奖励
259 | total_rewards.append(total_reward)
260 | self.logger.info(f"回合 {episode+1}/{num_episodes}, 总奖励: {total_reward:.4f}")
261 |
262 | # 获取资产记忆
263 | asset_memory = test_env.envs[0].save_asset_memory()
264 |
265 | # 获取交易记忆
266 | action_memory = test_env.envs[0].save_action_memory()
267 |
268 | # 绘制测试结果
269 | if render:
270 | self._plot_test_results(asset_memory, action_memory)
271 |
272 | # 获取最终统计信息
273 | stats = test_env.envs[0].get_final_stats()
274 | self.logger.info(f"测试结果统计: {stats}")
275 |
276 | return asset_memory, action_memory
277 |
278 | def _plot_test_results(self, asset_memory, action_memory):
279 | """
280 | 绘制测试结果
281 |
282 | 参数:
283 | asset_memory: 资产记忆DataFrame
284 | action_memory: 交易记忆DataFrame
285 | """
286 | # 绘制投资组合价值
287 | plt.figure(figsize=(12, 6))
288 | plt.plot(asset_memory.index, asset_memory['portfolio_value'], label='投资组合价值')
289 |
290 | # 添加买入卖出标记
291 | if not action_memory.empty:
292 | for idx, row in action_memory.iterrows():
293 | date = row['date']
294 | action = row['action']
295 | if date in asset_memory.index:
296 | portfolio_value = asset_memory.loc[date, 'portfolio_value']
297 | if action == 'buy':
298 | plt.scatter(date, portfolio_value, marker='^', color='green', s=100)
299 | elif action == 'sell':
300 | plt.scatter(date, portfolio_value, marker='v', color='red', s=100)
301 |
302 | plt.title('ETF交易投资组合价值')
303 | plt.xlabel('日期')
304 | plt.ylabel('价值')
305 | plt.grid(True)
306 | plt.legend()
307 | plt.tight_layout()
308 | plt.savefig(f"results/{self.model_name}_portfolio_value.png")
309 | plt.show()
310 |
311 | # 绘制收益率
312 | returns = asset_memory['portfolio_value'].pct_change().dropna()
313 | plt.figure(figsize=(12, 6))
314 | plt.plot(returns.index, returns, label='每日收益率')
315 | plt.title('ETF交易每日收益率')
316 | plt.xlabel('日期')
317 | plt.ylabel('收益率')
318 | plt.grid(True)
319 | plt.legend()
320 | plt.tight_layout()
321 | plt.savefig(f"results/{self.model_name}_daily_returns.png")
322 | plt.show()
--------------------------------------------------------------------------------
/ui/pages/sidebar.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from datetime import datetime, timedelta
3 | from src.strategies.strategy_factory import StrategyFactory
4 |
5 | def render_sidebar():
6 | """渲染侧边栏"""
7 | with st.sidebar:
8 | st.header("策略参数设置")
9 |
10 | # 策略选择
11 | st.subheader("策略选择")
12 | strategy_name = st.selectbox(
13 | "选择策略",
14 | options=StrategyFactory.get_strategy_names(),
15 | index=0
16 | )
17 |
18 | # 数据源设置
19 | st.subheader("数据源配置")
20 | if strategy_name in ["市场情绪策略", "双均线对冲策略"]:
21 | tushare_token = st.text_input("Tushare Token(必填)", value="", type="password", help="市场情绪策略需要使用Tushare数据源")
22 | if not tushare_token:
23 | st.error(f"{strategy_name}必须提供Tushare Token")
24 | else:
25 | tushare_token = st.text_input("Tushare Token(可选,如不填则使用akshare)", type="password")
26 |
27 | # ETF轮动策略的ETF选择
28 | if strategy_name == "ETF轮动策略":
29 | etf_list = [
30 | '510050.SH', # 上证50ETF
31 | '510300.SH', # 沪深300ETF
32 | '510500.SH', # 中证500ETF
33 | '159915.SZ', # 创业板ETF
34 | '512880.SH', # 证券ETF
35 | '512690.SH', # 酒ETF
36 | '512660.SH', # 军工ETF
37 | '512010.SH', # 医药ETF
38 | '512800.SH', # 银行ETF
39 | '512170.SH', # 医疗ETF
40 | '512760.SH', # 芯片ETF
41 | '159928.SZ', # 消费ETF
42 | '512480.SH', # 半导体ETF
43 | '512980.SH', # 科技ETF
44 | '512580.SH', # 环保ETF
45 | '512400.SH', # 有色金属ETF
46 | '512200.SH', # 地产ETF
47 | '516160.SH', # 新能源车ETF
48 | '159939.SZ', # 信息技术ETF
49 | '512600.SH', # 主要消费ETF
50 | '512070.SH', # 证券保险ETF
51 | '159869.SZ', # 新基建ETF
52 | '515030.SH', # 新能源ETF
53 | '515790.SH', # 光伏ETF
54 | '513050.SH', # 中概互联ETF
55 | ]
56 | selected_etfs = st.multiselect(
57 | "选择ETF",
58 | options=etf_list,
59 | default=etf_list[:5], # 默认选择前5个ETF
60 | help="选择要轮动的ETF,建议选择3-5个相关性较低的ETF"
61 | )
62 | if not selected_etfs:
63 | st.error("请至少选择一个ETF")
64 | return None
65 | else:
66 | if strategy_name == "双均线对冲策略":
67 | symbol = st.text_input("ETF代码", value="159985.SZ", help="使用M豆粕主力合约作为被对冲标的")
68 | else:
69 | symbol = st.text_input("ETF代码", value="510050.SH", help="支持:A股(000001.SZ)、ETF(510300.SH)、港股(00700.HK)")
70 |
71 | # 移动平均线参数(对双均线策略和双均线对冲策略显示)
72 | if strategy_name in ["双均线策略", "双均线对冲策略"]:
73 | st.subheader("均线参数")
74 | col1, col2 = st.columns(2)
75 | with col1:
76 | fast_period = st.number_input("快线周期", value=5, min_value=1)
77 | with col2:
78 | slow_period = st.number_input("慢线周期", value=10, min_value=1)
79 |
80 | # 止损止盈参数
81 | st.subheader("止损止盈参数")
82 | col1, col2 = st.columns(2)
83 | with col1:
84 | atr_profit_multiplier = st.number_input("ATR止盈倍数", value=3.0, min_value=0.5, max_value=5.0, step=0.1,
85 | help="ATR止盈倍数,值越大止盈距离越远")
86 | with col2:
87 | atr_loss_multiplier = st.number_input("ATR止损倍数", value=3.0, min_value=0.5, max_value=5.0, step=0.1,
88 | help="ATR止损倍数,值越大止损距离越远")
89 | with col1:
90 | atr_period = st.number_input("ATR周期", value=14, min_value=5, max_value=30, step=1)
91 |
92 | # 其他技术参数
93 | col1, col2 = st.columns(2)
94 | with col1:
95 | enable_trailing_stop = st.checkbox("启用追踪止损", value=False,
96 | help="追踪止损会在价格创新高后设置止损价,可以锁定更多利润")
97 | with col2:
98 | enable_death_cross = st.checkbox("启用死叉卖出", value=False,
99 | help="启用后,当快线下穿慢线时将卖出持仓")
100 |
101 | # 双均线对冲策略特有参数
102 | if strategy_name == "双均线对冲策略":
103 | st.subheader("对冲参数")
104 | col1, col2 = st.columns(2)
105 | with col1:
106 | hedge_contract_size = st.number_input("对冲合约手数", value=10, min_value=1, max_value=100, step=1,
107 | help="豆粕期货对冲合约手数,建议10-20手")
108 | with col2:
109 | hedge_profit_multiplier = st.number_input("对冲盈利倍数", value=1.0, min_value=0.5, max_value=5.0, step=0.1,
110 | help="对冲目标盈利 = 原始损失 × (1 + 盈利倍数)")
111 |
112 | # 对冲模块开关
113 | st.subheader("对冲模块开关")
114 | hedge_mode = st.radio(
115 | "选择对冲模式",
116 | ["无对冲", "ATR止损开空对冲", "MA交叉死叉做空对冲", "MACD零轴上方死叉做空对冲", "同步做多对冲"],
117 | help="选择要启用的对冲模式,只能选择一种"
118 | )
119 |
120 | # 根据选择设置对应的开关
121 | enable_hedging = (hedge_mode == "ATR止损开空对冲")
122 | enable_ma_cross_hedge = (hedge_mode == "MA交叉死叉做空对冲")
123 | enable_macd_hedge = (hedge_mode == "MACD零轴上方死叉做空对冲")
124 | enable_sync_long_hedge = (hedge_mode == "同步做多对冲")
125 |
126 | # ETF轮动策略参数
127 | if strategy_name == "ETF轮动策略":
128 | st.subheader("轮动参数")
129 | col1, col2 = st.columns(2)
130 | with col1:
131 | momentum_short = st.number_input("短期动量周期", value=10, min_value=1)
132 | with col2:
133 | momentum_long = st.number_input("长期动量周期", value=60, min_value=1)
134 | col1, col2 = st.columns(2)
135 | with col1:
136 | rebalance_interval = st.number_input("调仓间隔(天)", value=20, min_value=1)
137 | with col2:
138 | num_positions = st.number_input("持仓数量", value=3, min_value=1, max_value=10)
139 |
140 | # 止盈止损参数
141 | st.subheader("止盈止损参数")
142 | col1, col2 = st.columns(2)
143 | with col1:
144 | profit_target1 = st.number_input("第一止盈目标(%)", value=5.0, min_value=1.0, max_value=100.0, step=1.0)
145 | with col2:
146 | profit_target2 = st.number_input("第二止盈目标(%)", value=10.0, min_value=1.0, max_value=100.0, step=1.0)
147 |
148 | # 市场状态参数
149 | st.subheader("市场状态参数")
150 | col1, col2 = st.columns(2)
151 | with col1:
152 | market_trend_threshold = st.number_input("市场趋势阈值(%)", value=-5.0, min_value=-20.0, max_value=0.0, step=1.0)
153 | with col2:
154 | vix_threshold = st.number_input("波动率阈值(%)", value=3.0, min_value=1.0, max_value=10.0, step=0.5)
155 |
156 | # 动量衰减参数
157 | momentum_decay = st.slider("动量衰减阈值(%)", 10.0, 50.0, 30.0, 1.0)
158 | atr_multiplier = st.slider("ATR倍数", 1.0, 3.0, 2.0, 0.1)
159 |
160 | # 风险控制参数
161 | st.subheader("风险控制")
162 | trail_percent = st.slider("追踪止损比例(%)", 0.5, 5.0, 2.0, 0.1)
163 | risk_ratio = st.slider("单次交易风险比例(%)", 0.5, 5.0, 2.0, 0.1)
164 | max_drawdown = st.slider("最大回撤限制(%)", 5.0, 30.0, 15.0, 1.0)
165 |
166 | # 回测区间
167 | st.subheader("回测区间")
168 | start_date = st.date_input(
169 | "开始日期",
170 | datetime.now() - timedelta(days=365)
171 | )
172 | end_date = st.date_input("结束日期", datetime.now())
173 |
174 | # 资金设置
175 | st.subheader("资金设置")
176 | initial_cash = st.number_input("初始资金", value=100000.0, min_value=1000.0)
177 | commission = st.number_input("佣金费率(双向收取,默认万分之2.5)", value=0.00025, min_value=0.0, max_value=0.01, format="%.5f",
178 | help="双向收取,例如:0.00025表示万分之2.5")
179 |
180 | # 返回所有参数
181 | params = {
182 | 'strategy_name': strategy_name,
183 | 'tushare_token': tushare_token,
184 | 'selected_etfs': selected_etfs if strategy_name == "ETF轮动策略" else None,
185 | 'symbol': symbol if strategy_name != "ETF轮动策略" else None,
186 | 'fast_period': fast_period if strategy_name in ["双均线策略", "双均线对冲策略"] else None,
187 | 'slow_period': slow_period if strategy_name in ["双均线策略", "双均线对冲策略"] else None,
188 | 'atr_profit_multiplier': atr_profit_multiplier if strategy_name in ["双均线策略", "双均线对冲策略", "ETF轮动策略"] else None,
189 | 'atr_loss_multiplier': atr_loss_multiplier if strategy_name in ["双均线策略", "双均线对冲策略", "ETF轮动策略"] else None,
190 | 'atr_period': atr_period if strategy_name in ["双均线策略", "双均线对冲策略"] else None,
191 | 'enable_trailing_stop': enable_trailing_stop if strategy_name in ["双均线策略", "双均线对冲策略"] else None,
192 | 'enable_death_cross': enable_death_cross if strategy_name in ["双均线策略", "双均线对冲策略"] else None,
193 | 'hedge_contract_size': hedge_contract_size if strategy_name == "双均线对冲策略" else None,
194 | 'hedge_profit_multiplier': hedge_profit_multiplier if strategy_name == "双均线对冲策略" else None,
195 | 'enable_hedging': enable_hedging if strategy_name == "双均线对冲策略" else None,
196 | 'enable_ma_cross_hedge': enable_ma_cross_hedge if strategy_name == "双均线对冲策略" else None,
197 | 'enable_macd_hedge': enable_macd_hedge if strategy_name == "双均线对冲策略" else None,
198 | 'enable_sync_long_hedge': enable_sync_long_hedge if strategy_name == "双均线对冲策略" else None,
199 | 'momentum_short': momentum_short if strategy_name == "ETF轮动策略" else None,
200 | 'momentum_long': momentum_long if strategy_name == "ETF轮动策略" else None,
201 | 'rebalance_interval': rebalance_interval if strategy_name == "ETF轮动策略" else None,
202 | 'num_positions': num_positions if strategy_name == "ETF轮动策略" else None,
203 | 'profit_target1': profit_target1 / 100 if strategy_name == "ETF轮动策略" else None,
204 | 'profit_target2': profit_target2 / 100 if strategy_name == "ETF轮动策略" else None,
205 | 'market_trend_threshold': market_trend_threshold / 100 if strategy_name == "ETF轮动策略" else None,
206 | 'vix_threshold': vix_threshold / 100 if strategy_name == "ETF轮动策略" else None,
207 | 'momentum_decay': momentum_decay / 100 if strategy_name == "ETF轮动策略" else None,
208 | 'trail_percent': trail_percent,
209 | 'risk_ratio': risk_ratio,
210 | 'max_drawdown': max_drawdown,
211 | 'start_date': start_date,
212 | 'end_date': end_date,
213 | 'initial_cash': initial_cash,
214 | 'commission': commission
215 | }
216 |
217 | return params
--------------------------------------------------------------------------------
/src/strategies/dual_ma_hedging/ma_cross_hedge.py:
--------------------------------------------------------------------------------
1 | import backtrader as bt
2 | from loguru import logger
3 |
4 | class MACrossHedge:
5 | def __init__(self, strategy):
6 | self.strategy = strategy
7 | self.enabled = False
8 | self.hedge_position = None
9 | self.hedge_entry_price = None
10 | self.hedge_order = None
11 | self.hedge_contract_code = None
12 | self.hedge_entry_date = None # 添加入场日期记录
13 |
14 | def enable(self):
15 | """启用MA交叉对冲功能"""
16 | self.enabled = True
17 | logger.info("启用MA交叉对冲功能")
18 |
19 | def disable(self):
20 | """禁用MA交叉对冲功能"""
21 | self.enabled = False
22 | logger.info("禁用MA交叉对冲功能")
23 |
24 | def on_death_cross(self):
25 | """在快线下穿慢线时开空期货"""
26 | if not self.enabled:
27 | return
28 |
29 | if self.hedge_position is not None or self.hedge_order is not None:
30 | logger.info("已有对冲仓位或对冲订单,不再开仓")
31 | return
32 |
33 | try:
34 | # 计算ATR止盈止损价格
35 | current_atr = self.strategy.atr[0]
36 | stop_loss = self.strategy.data1.close[0] + (current_atr * self.strategy.p.atr_loss_multiplier)
37 | take_profit = self.strategy.data1.close[0] - (current_atr * self.strategy.p.atr_profit_multiplier)
38 |
39 | # 开空豆粕期货
40 | hedge_size = self.strategy.p.hedge_contract_size
41 | self.hedge_order = self.strategy.sell(data=self.strategy.data1, size=hedge_size)
42 |
43 | if self.hedge_order:
44 | # 记录入场价格和合约代码
45 | self.hedge_entry_price = self.strategy.data1.close[0]
46 | # 确保获取正确的合约代码
47 | current_date = self.strategy.data1.datetime.datetime(0)
48 | if hasattr(self.strategy.data1, 'contract_mapping') and current_date in self.strategy.data1.contract_mapping:
49 | self.hedge_contract_code = self.strategy.data1.contract_mapping[current_date]
50 | else:
51 | # 如果无法获取映射,使用数据名称
52 | self.hedge_contract_code = self.strategy.data1._name
53 | self.hedge_entry_date = self.strategy.data.datetime.date(0) # 记录入场日期
54 |
55 | # 计算保证金
56 | margin = self.hedge_entry_price * hedge_size * self.strategy.p.future_contract_multiplier * 0.10
57 |
58 | # 从期货账户扣除保证金
59 | pre_cash = self.strategy.future_cash
60 | self.strategy.future_cash -= margin
61 |
62 | logger.info(f"开仓扣除保证金 - 之前: {pre_cash:.2f}, 扣除: {margin:.2f}, 之后: {self.strategy.future_cash:.2f}")
63 |
64 | # 记录交易信息
65 | self.hedge_order.info = {
66 | 'reason': f"MA死叉开空 - 快线: {self.strategy.fast_ma[0]:.2f}, 慢线: {self.strategy.slow_ma[0]:.2f}",
67 | 'margin': margin,
68 | 'future_cash': self.strategy.future_cash,
69 | 'execution_date': self.hedge_entry_date,
70 | 'total_value': self.strategy.future_cash,
71 | 'position_value': abs(margin),
72 | 'position_ratio': margin / self.strategy.future_cash if self.strategy.future_cash > 0 else 0,
73 | 'etf_code': self.hedge_contract_code,
74 | 'pnl': 0,
75 | 'return': 0,
76 | 'stop_loss': stop_loss,
77 | 'take_profit': take_profit,
78 | 'avg_cost': self.hedge_entry_price
79 | }
80 |
81 | logger.info(f"MA死叉开空 - 合约: {self.hedge_contract_code}, 价格: {self.hedge_entry_price:.2f}, 数量: {hedge_size}手, "
82 | f"止损价: {stop_loss:.2f}, 止盈价: {take_profit:.2f}")
83 |
84 | except Exception as e:
85 | logger.error(f"MA死叉开空失败: {str(e)}")
86 |
87 | def check_exit(self):
88 | """检查是否需要平仓"""
89 | if not self.enabled or not self.hedge_position or self.hedge_order is not None:
90 | return
91 |
92 | current_price = self.strategy.data1.close[0]
93 | current_atr = self.strategy.atr[0]
94 |
95 | # 计算ATR止盈止损价格
96 | stop_loss = self.hedge_entry_price + (current_atr * self.strategy.p.atr_loss_multiplier)
97 | take_profit = self.hedge_entry_price - (current_atr * self.strategy.p.atr_profit_multiplier)
98 |
99 | # 获取当前日期
100 | current_date = self.strategy.data.datetime.date(0)
101 |
102 | # 检查是否触发止盈止损
103 | if current_price >= stop_loss or current_price <= take_profit:
104 | contract_code = self.hedge_contract_code
105 | self.hedge_order = self.strategy.close(data=self.strategy.data1)
106 | reason = "触发止盈" if current_price <= take_profit else "触发止损"
107 | logger.info(f"MA死叉对冲{reason} - 日期: {current_date}, 合约: {contract_code}, 当前价格: {current_price:.2f}, {reason}价: {take_profit if current_price <= take_profit else stop_loss:.2f}")
108 |
109 | def on_order_completed(self, order):
110 | """处理订单完成事件"""
111 | if not self.enabled:
112 | return
113 |
114 | if order.status in [order.Completed]:
115 | if order.isbuy(): # 买入豆粕期货(平空)
116 | # 确保有对应的入场价格
117 | if self.hedge_entry_price is None or self.hedge_contract_code is None:
118 | logger.error("平仓时找不到入场价格或合约代码,跳过处理")
119 | return
120 |
121 | # 记录平仓前的合约信息,用于日志
122 | entry_price = self.hedge_entry_price
123 | contract_code = self.hedge_contract_code
124 | entry_date = self.hedge_entry_date
125 |
126 | # 记录交易日期和价格
127 | trade_date = self.strategy.data.datetime.date(0)
128 | trade_price = order.executed.price
129 |
130 | # 先重置持仓相关变量,防止重复平仓
131 | self.hedge_position = None
132 | self.hedge_order = None
133 | self.hedge_entry_price = None
134 | self.hedge_contract_code = None
135 | self.hedge_entry_date = None
136 | self.hedge_target_profit = None
137 |
138 | # 计算对冲盈亏(空仓:入场价 - 平仓价)
139 | hedge_profit = (entry_price - trade_price) * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier
140 |
141 | # 减去开平仓手续费
142 | total_fee = self.strategy.p.hedge_fee * self.strategy.p.hedge_contract_size * 2
143 | net_profit = hedge_profit - total_fee
144 |
145 | # 归还保证金并添加盈亏到期货账户
146 | margin_returned = entry_price * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier * 0.10
147 |
148 | # 记录更新前的资金
149 | pre_cash = self.strategy.future_cash
150 |
151 | # 更新期货账户资金
152 | self.strategy.future_cash += (margin_returned + net_profit)
153 |
154 | # 记录资金变动
155 | logger.info(f"平仓资金变动 - 之前: {pre_cash:.2f}, 返还保证金: {margin_returned:.2f}, 盈亏: {net_profit:.2f}, 之后: {self.strategy.future_cash:.2f}")
156 |
157 | # 更新期货账户最高净值
158 | self.strategy.future_highest_value = max(self.strategy.future_highest_value, self.strategy.future_cash)
159 |
160 | # 计算期货账户回撤
161 | future_drawdown = (self.strategy.future_highest_value - self.strategy.future_cash) / self.strategy.future_highest_value if self.strategy.future_highest_value > 0 else 0
162 |
163 | # 计算收益率
164 | return_pct = (hedge_profit / (entry_price * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier)) * 100
165 |
166 | # 更新订单信息
167 | order.info.update({
168 | 'pnl': hedge_profit,
169 | 'return': return_pct,
170 | 'total_value': self.strategy.future_cash,
171 | 'position_value': 0, # 平仓后持仓价值为0
172 | 'avg_cost': entry_price,
173 | 'etf_code': contract_code, # 确保使用原始合约代码
174 | 'execution_date': trade_date, # 确保使用当前交易日期
175 | 'reason': f"MA死叉对冲平仓 - 合约: {contract_code}, 入场日期: {entry_date}, 入场价: {entry_price:.2f}, 平仓价: {trade_price:.2f}, 收益率: {return_pct:.2f}%"
176 | })
177 |
178 | logger.info(f"MA死叉对冲平仓 - 日期: {trade_date}, 合约: {contract_code}, 价格: {trade_price:.2f}, 盈利: {hedge_profit:.2f}, "
179 | f"手续费: {total_fee:.2f}, 净盈利: {net_profit:.2f}, "
180 | f"期货账户余额: {self.strategy.future_cash:.2f}, 回撤: {future_drawdown:.2%}, "
181 | f"收益率: {return_pct:.2f}%")
182 |
183 | else: # 卖出豆粕期货(开空)
184 | # 记录对冲持仓
185 | self.hedge_position = order
186 |
187 | # 更新订单信息
188 | order.info.update({
189 | 'total_value': self.strategy.future_cash,
190 | 'position_value': abs(order.info['margin']),
191 | 'avg_cost': order.executed.price,
192 | 'etf_code': self.hedge_contract_code # 确保合约代码正确
193 | })
194 |
195 | elif order.status in [order.Canceled, order.Margin, order.Rejected]:
196 | self.hedge_order = None
197 | logger.warning(f'MA死叉对冲订单失败 - 状态: {order.getstatusname()}')
198 |
199 | def on_strategy_stop(self):
200 | """策略结束时平掉所有期货仓位"""
201 | if not self.enabled or not self.hedge_position:
202 | return
203 |
204 | if self.hedge_order is None: # 确保没有未完成订单
205 | # 获取当前持仓的合约代码
206 | current_contract = self.hedge_contract_code
207 | if not current_contract:
208 | logger.error("策略结束时找不到期货合约代码,无法平仓")
209 | return
210 |
211 | # 获取当前日期
212 | current_date = self.strategy.data.datetime.date(0)
213 |
214 | self.hedge_order = self.strategy.close(data=self.strategy.data1)
215 | logger.info(f"策略结束,平掉期货仓位 - 日期: {current_date}, 合约: {current_contract}, 入场日期: {self.hedge_entry_date}, 入场价: {self.hedge_entry_price:.2f}")
216 |
217 | # 更新订单信息
218 | if self.hedge_order:
219 | self.hedge_order.info.update({
220 | 'etf_code': current_contract, # 确保使用正确的合约代码
221 | 'execution_date': current_date, # 使用当前日期
222 | 'reason': f"策略结束平仓 - 日期: {current_date}, 合约: {current_contract}, 入场日期: {self.hedge_entry_date}, 入场价: {self.hedge_entry_price:.2f}"
223 | })
--------------------------------------------------------------------------------
/rl_model_finrl/meta/data_processors/akshare_processor.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import akshare as ak
4 | import os
5 | from typing import List, Dict, Optional, Tuple
6 | import logging
7 | from datetime import datetime
8 |
9 | from src.strategies.rl_model_finrl.config import (
10 | DATA_SAVE_PATH,
11 | TRAIN_START_DATE,
12 | TRAIN_END_DATE,
13 | TEST_START_DATE,
14 | TEST_END_DATE,
15 | TICKER_LIST,
16 | )
17 |
18 | class AKShareProcessor:
19 | """
20 | AKShareProcessor 类负责从AKShare获取A股ETF相关数据
21 | 作为补充数据源使用
22 | """
23 |
24 | def __init__(self):
25 | """初始化AKShare数据处理器"""
26 | # 确保数据保存目录存在
27 | self.data_path = os.path.join(DATA_SAVE_PATH, "akshare_data")
28 | os.makedirs(self.data_path, exist_ok=True)
29 |
30 | # 设置日志
31 | logging.basicConfig(level=logging.INFO)
32 | self.logger = logging.getLogger(__name__)
33 |
34 | def download_etf_fund_info(self, ticker_list: List[str]) -> Dict[str, pd.DataFrame]:
35 | """
36 | 下载ETF基金基本信息
37 |
38 | 参数:
39 | ticker_list: ETF代码列表
40 |
41 | 返回:
42 | 包含ETF基本信息的字典
43 | """
44 | etf_info_dict = {}
45 |
46 | for ticker in ticker_list:
47 | try:
48 | # 处理代码格式
49 | ticker_code = ticker.split('.')[0]
50 |
51 | self.logger.info(f"获取ETF {ticker} 的基本信息...")
52 |
53 | # 获取ETF基本信息
54 | etf_info = ak.fund_etf_basic_info(symbol=ticker_code)
55 |
56 | if etf_info is not None and not etf_info.empty:
57 | # 保存到字典
58 | etf_info_dict[ticker] = etf_info
59 |
60 | # 保存到CSV
61 | file_name = f"{ticker}_info.csv"
62 | file_path = os.path.join(self.data_path, file_name)
63 | etf_info.to_csv(file_path, index=False)
64 | self.logger.info(f"ETF基本信息已保存到 {file_path}")
65 | else:
66 | self.logger.warning(f"无法获取ETF {ticker} 的基本信息")
67 |
68 | except Exception as e:
69 | self.logger.error(f"获取ETF {ticker} 的基本信息时出错: {str(e)}")
70 |
71 | return etf_info_dict
72 |
73 | def download_etf_daily_data(
74 | self,
75 | ticker_list: List[str],
76 | start_date: str,
77 | end_date: str
78 | ) -> Dict[str, pd.DataFrame]:
79 | """
80 | 从AKShare下载ETF日线数据
81 |
82 | 参数:
83 | ticker_list: ETF代码列表
84 | start_date: 开始日期,格式'YYYY-MM-DD'
85 | end_date: 结束日期,格式'YYYY-MM-DD'
86 |
87 | 返回:
88 | 字典,包含每个ETF代码及其对应的数据框
89 | """
90 | etf_data_dict = {}
91 |
92 | for ticker in ticker_list:
93 | try:
94 | # 处理代码格式
95 | ticker_code = ticker.split('.')[0]
96 |
97 | self.logger.info(f"从AKShare获取ETF {ticker} 的日线数据...")
98 |
99 | # 获取ETF历史数据
100 | etf_data = ak.fund_etf_hist_em(
101 | symbol=ticker_code,
102 | period="daily",
103 | start_date=start_date,
104 | end_date=end_date,
105 | adjust="qfq"
106 | )
107 |
108 | if etf_data is not None and not etf_data.empty:
109 | # 重命名列以匹配tushare格式
110 | etf_data = etf_data.rename(columns={
111 | '日期': 'date',
112 | '开盘': 'open',
113 | '收盘': 'close',
114 | '最高': 'high',
115 | '最低': 'low',
116 | '成交量': 'volume',
117 | '成交额': 'amount',
118 | '振幅': 'amplitude',
119 | '涨跌幅': 'change_pct',
120 | '涨跌额': 'change',
121 | '换手率': 'turnover'
122 | })
123 |
124 | # 转换日期格式并设置为索引
125 | etf_data['date'] = pd.to_datetime(etf_data['date'])
126 | etf_data = etf_data.sort_values('date')
127 | etf_data = etf_data.set_index('date')
128 |
129 | # 添加代码列
130 | etf_data['tic'] = ticker
131 |
132 | # 保存到字典
133 | etf_data_dict[ticker] = etf_data
134 |
135 | # 保存到CSV
136 | file_name = f"{ticker}_daily_akshare_{start_date.replace('-', '')}_{end_date.replace('-', '')}.csv"
137 | file_path = os.path.join(self.data_path, file_name)
138 | etf_data.to_csv(file_path)
139 | self.logger.info(f"ETF日线数据已保存到 {file_path}")
140 | else:
141 | self.logger.warning(f"无法获取ETF {ticker} 的日线数据")
142 |
143 | except Exception as e:
144 | self.logger.error(f"获取ETF {ticker} 的日线数据时出错: {str(e)}")
145 |
146 | return etf_data_dict
147 |
148 | def download_etf_fund_flow(
149 | self,
150 | ticker_list: List[str],
151 | start_date: str,
152 | end_date: str
153 | ) -> Dict[str, pd.DataFrame]:
154 | """
155 | 获取ETF资金流向数据
156 |
157 | 参数:
158 | ticker_list: ETF代码列表
159 | start_date: 开始日期,格式'YYYY-MM-DD'
160 | end_date: 结束日期,格式'YYYY-MM-DD'
161 |
162 | 返回:
163 | 包含资金流向的字典
164 | """
165 | fund_flow_dict = {}
166 |
167 | for ticker in ticker_list:
168 | try:
169 | # 处理代码格式
170 | ticker_code = ticker.split('.')[0]
171 |
172 | self.logger.info(f"获取ETF {ticker} 的资金流向数据...")
173 |
174 | # 获取ETF资金流向
175 | try:
176 | # 尝试获取资金流向数据,这个可能需要调整,因为AKShare的API可能有变化
177 | fund_flow = ak.fund_etf_fund_flow_rank()
178 |
179 | # 筛选特定ETF的数据
180 | if ticker_code in fund_flow['代码'].values:
181 | fund_flow_single = fund_flow[fund_flow['代码'] == ticker_code]
182 | fund_flow_dict[ticker] = fund_flow_single
183 |
184 | # 保存到CSV
185 | file_name = f"{ticker}_fund_flow.csv"
186 | file_path = os.path.join(self.data_path, file_name)
187 | fund_flow_single.to_csv(file_path, index=False)
188 | self.logger.info(f"ETF资金流向数据已保存到 {file_path}")
189 | else:
190 | self.logger.warning(f"未找到ETF {ticker} 的资金流向数据")
191 | except Exception as inner_e:
192 | self.logger.error(f"获取ETF资金流向时出现内部错误: {str(inner_e)}")
193 |
194 | except Exception as e:
195 | self.logger.error(f"获取ETF {ticker} 的资金流向数据时出错: {str(e)}")
196 |
197 | return fund_flow_dict
198 |
199 | def download_etf_holdings(self, ticker_list: List[str]) -> Dict[str, pd.DataFrame]:
200 | """
201 | 获取ETF持仓数据
202 |
203 | 参数:
204 | ticker_list: ETF代码列表
205 |
206 | 返回:
207 | 包含ETF持仓的字典
208 | """
209 | holdings_dict = {}
210 |
211 | for ticker in ticker_list:
212 | try:
213 | # 处理代码格式
214 | ticker_code = ticker.split('.')[0]
215 |
216 | self.logger.info(f"获取ETF {ticker} 的持仓数据...")
217 |
218 | # 获取ETF持仓
219 | try:
220 | # 尝试获取ETF持仓数据,这个API需要根据AKShare的最新文档调整
221 | holdings = ak.fund_etf_spot_deal_em()
222 |
223 | # 筛选特定ETF的数据
224 | if ticker_code in holdings['代码'].values:
225 | holdings_single = holdings[holdings['代码'] == ticker_code]
226 | holdings_dict[ticker] = holdings_single
227 |
228 | # 保存到CSV
229 | file_name = f"{ticker}_holdings.csv"
230 | file_path = os.path.join(self.data_path, file_name)
231 | holdings_single.to_csv(file_path, index=False)
232 | self.logger.info(f"ETF持仓数据已保存到 {file_path}")
233 | else:
234 | self.logger.warning(f"未找到ETF {ticker} 的持仓数据")
235 | except Exception as inner_e:
236 | self.logger.error(f"获取ETF持仓时出现内部错误: {str(inner_e)}")
237 |
238 | except Exception as e:
239 | self.logger.error(f"获取ETF {ticker} 的持仓数据时出错: {str(e)}")
240 |
241 | return holdings_dict
242 |
243 | def download_market_sentiment(self) -> pd.DataFrame:
244 | """
245 | 获取市场情绪数据(如恐慌指数等)
246 |
247 | 返回:
248 | 包含市场情绪的DataFrame
249 | """
250 | try:
251 | self.logger.info("获取A股市场情绪数据...")
252 |
253 | # 获取A股情绪指标
254 | # 注意:需要根据AKShare的最新文档确认获取情绪数据的API
255 | # 这里用股市情绪指标作为示例
256 | sentiment_data = ak.stock_market_emotion_baidu()
257 |
258 | if sentiment_data is not None and not sentiment_data.empty:
259 | # 保存到CSV
260 | file_name = f"market_sentiment_{datetime.now().strftime('%Y%m%d')}.csv"
261 | file_path = os.path.join(self.data_path, file_name)
262 | sentiment_data.to_csv(file_path, index=False)
263 | self.logger.info(f"市场情绪数据已保存到 {file_path}")
264 |
265 | return sentiment_data
266 | else:
267 | self.logger.warning("无法获取市场情绪数据")
268 | return pd.DataFrame()
269 |
270 | except Exception as e:
271 | self.logger.error(f"获取市场情绪数据时出错: {str(e)}")
272 | return pd.DataFrame()
273 |
274 | def prepare_supplementary_data(self) -> Dict[str, pd.DataFrame]:
275 | """
276 | 准备所有补充数据
277 |
278 | 返回:
279 | 包含所有补充数据的字典
280 | """
281 | supplementary_data = {}
282 |
283 | # 1. 获取ETF基本信息
284 | etf_info = self.download_etf_fund_info(TICKER_LIST)
285 | supplementary_data['etf_info'] = etf_info
286 |
287 | # 2. 获取ETF资金流向
288 | fund_flow = self.download_etf_fund_flow(
289 | TICKER_LIST,
290 | TRAIN_START_DATE,
291 | TEST_END_DATE
292 | )
293 | supplementary_data['fund_flow'] = fund_flow
294 |
295 | # 3. 获取ETF持仓
296 | holdings = self.download_etf_holdings(TICKER_LIST)
297 | supplementary_data['holdings'] = holdings
298 |
299 | # 4. 获取市场情绪
300 | sentiment = self.download_market_sentiment()
301 | supplementary_data['market_sentiment'] = sentiment
302 |
303 | return supplementary_data
--------------------------------------------------------------------------------
/src/strategies/dual_ma_hedging/sync_long_hedge.py:
--------------------------------------------------------------------------------
1 | import backtrader as bt
2 | from loguru import logger
3 |
4 | class SyncLongHedge:
5 | def __init__(self, strategy):
6 | self.strategy = strategy
7 | self.enabled = False
8 | self.hedge_position = None
9 | self.hedge_entry_price = None
10 | self.hedge_order = None
11 | self.hedge_contract_code = None
12 | self.hedge_entry_date = None # 添加入场日期记录
13 |
14 | def enable(self):
15 | """启用同步做多对冲功能"""
16 | self.enabled = True
17 | logger.info("启用同步做多对冲功能")
18 |
19 | def disable(self):
20 | """禁用同步做多对冲功能"""
21 | self.enabled = False
22 | logger.info("禁用同步做多对冲功能")
23 |
24 | def on_golden_cross(self):
25 | """在ETF金叉时同步开多期货"""
26 | if not self.enabled:
27 | return
28 |
29 | if self.hedge_position is not None or self.hedge_order is not None:
30 | logger.info("已有对冲仓位或对冲订单,不再开仓")
31 | return
32 |
33 | try:
34 | # 计算ATR止盈止损价格
35 | current_atr = self.strategy.atr[0]
36 | stop_loss = self.strategy.data1.close[0] - (current_atr * self.strategy.p.atr_loss_multiplier)
37 | take_profit = self.strategy.data1.close[0] + (current_atr * self.strategy.p.atr_profit_multiplier)
38 |
39 | # 开多豆粕期货
40 | hedge_size = self.strategy.p.hedge_contract_size
41 | self.hedge_order = self.strategy.buy(data=self.strategy.data1, size=hedge_size)
42 |
43 | if self.hedge_order:
44 | # 记录入场价格和合约代码
45 | self.hedge_entry_price = self.strategy.data1.close[0]
46 | # 确保获取正确的合约代码
47 | current_date = self.strategy.data1.datetime.datetime(0)
48 | if hasattr(self.strategy.data1, 'contract_mapping') and current_date in self.strategy.data1.contract_mapping:
49 | self.hedge_contract_code = self.strategy.data1.contract_mapping[current_date]
50 | else:
51 | # 如果无法获取映射,使用数据名称
52 | self.hedge_contract_code = self.strategy.data1._name
53 | self.hedge_entry_date = self.strategy.data.datetime.date(0) # 记录入场日期
54 |
55 | # 计算保证金
56 | margin = self.hedge_entry_price * hedge_size * self.strategy.p.future_contract_multiplier * 0.10
57 |
58 | # 从期货账户扣除保证金
59 | pre_cash = self.strategy.future_cash
60 | self.strategy.future_cash -= margin
61 |
62 | logger.info(f"开仓扣除保证金 - 之前: {pre_cash:.2f}, 扣除: {margin:.2f}, 之后: {self.strategy.future_cash:.2f}")
63 |
64 | # 记录交易信息
65 | self.hedge_order.info = {
66 | 'reason': f"ETF金叉同步开多 - 快线: {self.strategy.fast_ma[0]:.2f}, 慢线: {self.strategy.slow_ma[0]:.2f}",
67 | 'margin': margin,
68 | 'future_cash': self.strategy.future_cash,
69 | 'execution_date': self.hedge_entry_date,
70 | 'total_value': self.strategy.future_cash,
71 | 'position_value': abs(margin),
72 | 'position_ratio': margin / self.strategy.future_cash if self.strategy.future_cash > 0 else 0,
73 | 'etf_code': self.hedge_contract_code,
74 | 'pnl': 0,
75 | 'return': 0,
76 | 'stop_loss': stop_loss,
77 | 'take_profit': take_profit,
78 | 'avg_cost': self.hedge_entry_price
79 | }
80 |
81 | logger.info(f"ETF金叉同步开多 - 合约: {self.hedge_contract_code}, 价格: {self.hedge_entry_price:.2f}, 数量: {hedge_size}手, "
82 | f"止损价: {stop_loss:.2f}, 止盈价: {take_profit:.2f}")
83 |
84 | except Exception as e:
85 | logger.error(f"ETF金叉同步开多失败: {str(e)}")
86 |
87 | def on_etf_close(self):
88 | """在ETF平仓时同步平多仓"""
89 | if not self.enabled or not self.hedge_position:
90 | return
91 |
92 | if self.hedge_order is None: # 确保没有未完成订单
93 | self.hedge_order = self.strategy.close(data=self.strategy.data1)
94 | logger.info("ETF平仓,同步平多仓")
95 |
96 | def check_exit(self):
97 | """检查是否需要平仓"""
98 | if not self.enabled or not self.hedge_position or self.hedge_order is not None:
99 | return
100 |
101 | current_price = self.strategy.data1.close[0]
102 | current_atr = self.strategy.atr[0]
103 |
104 | # 计算ATR止盈止损价格
105 | stop_loss = self.hedge_entry_price - (current_atr * self.strategy.p.atr_loss_multiplier)
106 | take_profit = self.hedge_entry_price + (current_atr * self.strategy.p.atr_profit_multiplier)
107 |
108 | # 获取当前日期
109 | current_date = self.strategy.data.datetime.date(0)
110 |
111 | # 检查是否触发止盈止损
112 | if current_price <= stop_loss or current_price >= take_profit:
113 | contract_code = self.hedge_contract_code
114 | self.hedge_order = self.strategy.close(data=self.strategy.data1)
115 | reason = "触发止盈" if current_price >= take_profit else "触发止损"
116 | logger.info(f"同步做多对冲{reason} - 日期: {current_date}, 合约: {contract_code}, 当前价格: {current_price:.2f}, {reason}价: {take_profit if current_price >= take_profit else stop_loss:.2f}")
117 |
118 | def on_order_completed(self, order):
119 | """处理订单完成事件"""
120 | if not self.enabled:
121 | return
122 |
123 | if order.status in [order.Completed]:
124 | if order.issell(): # 卖出豆粕期货(平多)
125 | # 确保有对应的入场价格
126 | if self.hedge_entry_price is None or self.hedge_contract_code is None:
127 | logger.error("平仓时找不到入场价格或合约代码,跳过处理")
128 | return
129 |
130 | # 记录平仓前的合约信息,用于日志
131 | entry_price = self.hedge_entry_price
132 | contract_code = self.hedge_contract_code
133 | entry_date = self.hedge_entry_date
134 |
135 | # 记录交易日期和价格
136 | trade_date = self.strategy.data.datetime.date(0)
137 | trade_price = order.executed.price
138 |
139 | # 先重置持仓相关变量,防止重复平仓
140 | self.hedge_position = None
141 | self.hedge_order = None
142 | self.hedge_entry_price = None
143 | self.hedge_contract_code = None
144 | self.hedge_entry_date = None
145 | self.hedge_target_profit = None
146 |
147 | # 计算对冲盈亏
148 | hedge_profit = (trade_price - entry_price) * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier
149 |
150 | # 减去开平仓手续费
151 | total_fee = self.strategy.p.hedge_fee * self.strategy.p.hedge_contract_size * 2
152 | net_profit = hedge_profit - total_fee
153 |
154 | # 归还保证金并添加盈亏到期货账户
155 | margin_returned = entry_price * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier * 0.10
156 |
157 | # 记录更新前的资金
158 | pre_cash = self.strategy.future_cash
159 |
160 | # 更新期货账户资金
161 | self.strategy.future_cash += (margin_returned + net_profit)
162 |
163 | # 记录资金变动
164 | logger.info(f"平仓资金变动 - 之前: {pre_cash:.2f}, 返还保证金: {margin_returned:.2f}, 盈亏: {net_profit:.2f}, 之后: {self.strategy.future_cash:.2f}")
165 |
166 | # 更新期货账户最高净值
167 | self.strategy.future_highest_value = max(self.strategy.future_highest_value, self.strategy.future_cash)
168 |
169 | # 计算期货账户回撤
170 | future_drawdown = (self.strategy.future_highest_value - self.strategy.future_cash) / self.strategy.future_highest_value if self.strategy.future_highest_value > 0 else 0
171 |
172 | # 计算收益率
173 | return_pct = (hedge_profit / (entry_price * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier)) * 100
174 |
175 | # 更新订单信息
176 | order.info.update({
177 | 'pnl': hedge_profit,
178 | 'return': return_pct,
179 | 'total_value': self.strategy.future_cash,
180 | 'position_value': 0, # 平仓后持仓价值为0
181 | 'avg_cost': entry_price,
182 | 'etf_code': contract_code, # 确保使用原始合约代码
183 | 'execution_date': trade_date, # 确保使用当前交易日期
184 | 'reason': f"同步做多对冲平仓 - 合约: {contract_code}, 入场日期: {entry_date}, 入场价: {entry_price:.2f}, 平仓价: {trade_price:.2f}, 收益率: {return_pct:.2f}%"
185 | })
186 |
187 | logger.info(f"同步做多对冲平仓 - 日期: {trade_date}, 合约: {contract_code}, 价格: {trade_price:.2f}, 盈利: {hedge_profit:.2f}, "
188 | f"手续费: {total_fee:.2f}, 净盈利: {net_profit:.2f}, "
189 | f"期货账户余额: {self.strategy.future_cash:.2f}, 回撤: {future_drawdown:.2%}, "
190 | f"收益率: {return_pct:.2f}%")
191 |
192 | else: # 买入豆粕期货(开多)
193 | # 记录对冲持仓
194 | self.hedge_position = order
195 |
196 | # 更新订单信息
197 | order.info.update({
198 | 'total_value': self.strategy.future_cash,
199 | 'position_value': abs(order.info['margin']),
200 | 'avg_cost': order.executed.price,
201 | 'etf_code': self.hedge_contract_code # 确保合约代码正确
202 | })
203 |
204 | elif order.status in [order.Canceled, order.Margin, order.Rejected]:
205 | self.hedge_order = None
206 | logger.warning(f'同步做多对冲订单失败 - 状态: {order.getstatusname()}')
207 |
208 | def on_strategy_stop(self):
209 | """策略结束时平掉所有期货仓位"""
210 | if not self.enabled or not self.hedge_position:
211 | return
212 |
213 | if self.hedge_order is None: # 确保没有未完成订单
214 | # 获取当前持仓的合约代码
215 | current_contract = self.hedge_contract_code
216 | if not current_contract:
217 | logger.error("策略结束时找不到期货合约代码,无法平仓")
218 | return
219 |
220 | # 获取当前日期
221 | current_date = self.strategy.data.datetime.date(0)
222 |
223 | self.hedge_order = self.strategy.close(data=self.strategy.data1)
224 | logger.info(f"策略结束,平掉期货仓位 - 日期: {current_date}, 合约: {current_contract}, 入场日期: {self.hedge_entry_date}, 入场价: {self.hedge_entry_price:.2f}")
225 |
226 | # 更新订单信息
227 | if self.hedge_order:
228 | self.hedge_order.info.update({
229 | 'etf_code': current_contract, # 确保使用正确的合约代码
230 | 'execution_date': current_date, # 使用当前日期
231 | 'reason': f"策略结束平仓 - 日期: {current_date}, 合约: {current_contract}, 入场日期: {self.hedge_entry_date}, 入场价: {self.hedge_entry_price:.2f}"
232 | })
--------------------------------------------------------------------------------
/rl_model_finrl/applications/stock_trading/run_strategy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pandas as pd
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import datetime
6 | from typing import List, Dict, Any, Optional, Union, Tuple
7 | import logging
8 |
9 | from stable_baselines3 import PPO, A2C, DDPG, TD3, SAC, DQN
10 | from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
11 |
12 | from src.strategies.rl_model_finrl.meta.data_processors import DataProcessor
13 | from src.strategies.rl_model_finrl.applications.stock_trading.etf_env import ETFTradingEnv
14 | from src.strategies.rl_model_finrl.agents.stablebaseline3 import DQNAgent
15 | from src.strategies.rl_model_finrl.agents.elegantrl import PPOAgent
16 |
17 | from src.strategies.rl_model_finrl.config import (
18 | INITIAL_AMOUNT,
19 | TRANSACTION_COST_PCT,
20 | TURBULENCE_THRESHOLD,
21 | TRAIN_START_DATE,
22 | TRAIN_END_DATE,
23 | TEST_START_DATE,
24 | TEST_END_DATE,
25 | TECHNICAL_INDICATORS_LIST,
26 | TICKER_LIST,
27 | MODEL_SAVE_PATH,
28 | NUM_EPISODES,
29 | TENSORBOARD_LOG_PATH
30 | )
31 |
32 |
33 | class TensorboardCallback(BaseCallback):
34 | """
35 | 自定义回调,记录训练过程中的奖励和组合价值
36 | """
37 | def __init__(self, verbose=0):
38 | super(TensorboardCallback, self).__init__(verbose)
39 | self.rewards = []
40 |
41 | def _on_step(self) -> bool:
42 | # 获取最近的奖励
43 | if len(self.model.env.envs[0].rewards_memory) > 0:
44 | latest_reward = self.model.env.envs[0].rewards_memory[-1]
45 | self.rewards.append(latest_reward)
46 | # 记录到tensorboard
47 | self.logger.record('train/reward', latest_reward)
48 | self.logger.record('train/portfolio_value', self.model.env.envs[0].total_asset)
49 |
50 | return True
51 |
52 |
53 | def prepare_etf_data(
54 | processor: DataProcessor,
55 | ticker_list: List[str],
56 | start_date: str,
57 | end_date: str,
58 | data_source: str = "tushare"
59 | ) -> pd.DataFrame:
60 | """
61 | 准备ETF交易数据
62 |
63 | 参数:
64 | processor: 数据处理器
65 | ticker_list: ETF代码列表
66 | start_date: 开始日期
67 | end_date: 结束日期
68 | data_source: 数据源
69 |
70 | 返回:
71 | 处理后的DataFrame
72 | """
73 | # 如果没有提供ETF列表,使用默认列表
74 | if not ticker_list:
75 | ticker_list = TICKER_LIST
76 |
77 | # 获取原始数据
78 | df = processor.download_data(
79 | ticker_list=ticker_list,
80 | start_date=start_date,
81 | end_date=end_date,
82 | data_source=data_source
83 | )
84 |
85 | # 处理数据
86 | df = processor.clean_data(df)
87 |
88 | # 添加技术指标
89 | df = processor.add_technical_indicators(df, TECHNICAL_INDICATORS_LIST)
90 |
91 | # 添加波动性指标
92 | df = processor.add_turbulence(df)
93 |
94 | # 填充缺失值
95 | df = df.fillna(method='ffill').fillna(method='bfill')
96 |
97 | # 确保日期是正确的格式
98 | df.index = pd.to_datetime(df.index)
99 |
100 | return df
101 |
102 |
103 | def run_etf_strategy(
104 | start_date: str = TRAIN_START_DATE,
105 | end_date: str = TRAIN_END_DATE,
106 | ticker_list: List[str] = None,
107 | data_source: str = "tushare",
108 | time_interval: str = "1d",
109 | technical_indicator_list: List[str] = TECHNICAL_INDICATORS_LIST,
110 | initial_amount: float = INITIAL_AMOUNT,
111 | transaction_cost_pct: float = TRANSACTION_COST_PCT,
112 | agent: str = "ppo",
113 | model_name: str = None,
114 | turbulence_threshold: float = TURBULENCE_THRESHOLD,
115 | if_store_model: bool = True,
116 | num_episodes: int = NUM_EPISODES,
117 | **kwargs
118 | ) -> Any:
119 | """
120 | 运行ETF交易策略
121 |
122 | 参数:
123 | start_date: 训练开始日期
124 | end_date: 训练结束日期
125 | ticker_list: ETF代码列表
126 | data_source: 数据源
127 | time_interval: 时间间隔
128 | technical_indicator_list: 技术指标列表
129 | initial_amount: 初始资金
130 | transaction_cost_pct: 交易成本百分比
131 | agent: 智能体类型 (ppo, dqn, a2c等)
132 | model_name: 模型名称
133 | turbulence_threshold: 市场波动阈值
134 | if_store_model: 是否存储模型
135 | num_episodes: 训练回合数
136 | **kwargs: 传递给agent的其他参数
137 |
138 | 返回:
139 | 训练好的模型
140 | """
141 | # 设置日志
142 | logging.basicConfig(level=logging.INFO)
143 | logger = logging.getLogger(__name__)
144 |
145 | # 如果未指定模型名称,则自动生成
146 | if model_name is None:
147 | model_name = f"{agent}_etf_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
148 |
149 | logger.info(f"开始运行ETF交易策略, 智能体: {agent}, 模型名称: {model_name}")
150 |
151 | # 初始化数据处理器
152 | processor = DataProcessor(data_source=data_source, time_interval=time_interval)
153 |
154 | # 准备数据
155 | logger.info(f"准备ETF数据: {start_date} 到 {end_date}")
156 | df = prepare_etf_data(
157 | processor=processor,
158 | ticker_list=ticker_list or TICKER_LIST,
159 | start_date=start_date,
160 | end_date=end_date,
161 | data_source=data_source
162 | )
163 |
164 | # 创建ETF交易环境
165 | stock_dimension = len(df['tic'].unique())
166 | env = ETFTradingEnv(
167 | df=df,
168 | stock_dim=stock_dimension,
169 | hmax=100,
170 | initial_amount=initial_amount,
171 | transaction_cost_pct=transaction_cost_pct,
172 | reward_scaling=1.0,
173 | tech_indicator_list=technical_indicator_list,
174 | turbulence_threshold=turbulence_threshold,
175 | day_trade=True,
176 | reward_type='sharpe',
177 | cash_penalty_proportion=0.1,
178 | )
179 |
180 | # 创建向量化环境
181 | env_vec = env.get_sb_env()
182 |
183 | # 创建智能体
184 | if agent.lower() == "ppo_elegant":
185 | # ElegantRL PPO实现
186 | model = PPOAgent(
187 | env=env_vec,
188 | model_name=model_name,
189 | learning_rate=kwargs.get('learning_rate', 0.0003),
190 | gamma=kwargs.get('gamma', 0.99),
191 | tensorboard_log=TENSORBOARD_LOG_PATH
192 | )
193 | elif agent.lower() == "dqn":
194 | # Stable-Baselines3 DQN实现
195 | model = DQNAgent(
196 | env=env_vec,
197 | model_name=model_name,
198 | learning_rate=kwargs.get('learning_rate', 0.0001),
199 | gamma=kwargs.get('gamma', 0.99),
200 | tensorboard_log=TENSORBOARD_LOG_PATH
201 | )
202 | elif agent.lower() == "ppo":
203 | # Stable-Baselines3 PPO实现
204 | model = PPO(
205 | "MlpPolicy",
206 | env_vec,
207 | verbose=1,
208 | learning_rate=kwargs.get('learning_rate', 0.0003),
209 | gamma=kwargs.get('gamma', 0.99),
210 | tensorboard_log=TENSORBOARD_LOG_PATH
211 | )
212 | elif agent.lower() == "a2c":
213 | # Stable-Baselines3 A2C实现
214 | model = A2C(
215 | "MlpPolicy",
216 | env_vec,
217 | verbose=1,
218 | learning_rate=kwargs.get('learning_rate', 0.0007),
219 | gamma=kwargs.get('gamma', 0.99),
220 | tensorboard_log=TENSORBOARD_LOG_PATH
221 | )
222 | elif agent.lower() == "ddpg":
223 | # Stable-Baselines3 DDPG实现
224 | model = DDPG(
225 | "MlpPolicy",
226 | env_vec,
227 | verbose=1,
228 | learning_rate=kwargs.get('learning_rate', 0.0001),
229 | gamma=kwargs.get('gamma', 0.99),
230 | tensorboard_log=TENSORBOARD_LOG_PATH
231 | )
232 | elif agent.lower() == "sac":
233 | # Stable-Baselines3 SAC实现
234 | model = SAC(
235 | "MlpPolicy",
236 | env_vec,
237 | verbose=1,
238 | learning_rate=kwargs.get('learning_rate', 0.0003),
239 | gamma=kwargs.get('gamma', 0.99),
240 | tensorboard_log=TENSORBOARD_LOG_PATH
241 | )
242 | else:
243 | raise ValueError(f"不支持的智能体类型: {agent}")
244 |
245 | # 创建回调
246 | callback = TensorboardCallback()
247 |
248 | # 创建模型保存路径
249 | if not os.path.exists(MODEL_SAVE_PATH):
250 | os.makedirs(MODEL_SAVE_PATH)
251 |
252 | # 训练模型
253 | logger.info(f"开始训练模型: {model_name}")
254 |
255 | # 根据不同模型类型采用不同的训练方法
256 | if agent.lower() in ["ppo_elegant"]:
257 | # ElegantRL训练方法
258 | model.train(
259 | total_timesteps=num_episodes * 100,
260 | eval_freq=1000,
261 | n_eval_episodes=5,
262 | log_interval=100
263 | )
264 | else:
265 | # Stable-Baselines3训练方法
266 | model.learn(
267 | total_timesteps=num_episodes * 100,
268 | callback=callback,
269 | tb_log_name=model_name
270 | )
271 |
272 | # 保存模型
273 | if if_store_model:
274 | if agent.lower() in ["ppo_elegant"]:
275 | # ElegantRL模型保存
276 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.pt")
277 | model.save(model_path)
278 | else:
279 | # Stable-Baselines3模型保存
280 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.zip")
281 | model.save(model_path)
282 |
283 | logger.info(f"模型已保存至: {model_path}")
284 |
285 | # 绘制训练曲线
286 | if len(callback.rewards) > 0:
287 | plt.figure(figsize=(12, 6))
288 | plt.plot(callback.rewards)
289 | plt.title(f'{agent.upper()} ETF交易策略 - 训练奖励')
290 | plt.xlabel('步数')
291 | plt.ylabel('奖励')
292 | plt.grid(True)
293 | plt.savefig(f"results/{model_name}_train_rewards.png")
294 | plt.close()
295 |
296 | # 返回训练好的模型
297 | return model
298 |
299 |
300 | def load_etf_model(
301 | model_name: str,
302 | agent: str = "ppo",
303 | env = None
304 | ) -> Any:
305 | """
306 | 加载已训练的ETF交易模型
307 |
308 | 参数:
309 | model_name: 模型名称
310 | agent: 智能体类型
311 | env: 环境实例(可选)
312 |
313 | 返回:
314 | 加载的模型
315 | """
316 | if agent.lower() == "ppo_elegant":
317 | # ElegantRL模型加载
318 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.pt")
319 | model = PPOAgent(env=env)
320 | model.load(model_path)
321 | elif agent.lower() == "dqn":
322 | # Stable-Baselines3 DQN模型加载
323 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.zip")
324 | model = DQN.load(model_path, env=env)
325 | elif agent.lower() == "ppo":
326 | # Stable-Baselines3 PPO模型加载
327 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.zip")
328 | model = PPO.load(model_path, env=env)
329 | elif agent.lower() == "a2c":
330 | # Stable-Baselines3 A2C模型加载
331 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.zip")
332 | model = A2C.load(model_path, env=env)
333 | elif agent.lower() == "ddpg":
334 | # Stable-Baselines3 DDPG模型加载
335 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.zip")
336 | model = DDPG.load(model_path, env=env)
337 | elif agent.lower() == "sac":
338 | # Stable-Baselines3 SAC模型加载
339 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.zip")
340 | model = SAC.load(model_path, env=env)
341 | else:
342 | raise ValueError(f"不支持的智能体类型: {agent}")
343 |
344 | return model
345 |
346 |
347 | if __name__ == "__main__":
348 | # 示例用法
349 | model = run_etf_strategy(
350 | start_date=TRAIN_START_DATE,
351 | end_date=TRAIN_END_DATE,
352 | ticker_list=TICKER_LIST,
353 | agent="ppo",
354 | model_name="ppo_etf_demo",
355 | num_episodes=NUM_EPISODES
356 | )
--------------------------------------------------------------------------------
/src/strategies/dual_ma_hedging/macd_hedge.py:
--------------------------------------------------------------------------------
1 | import backtrader as bt
2 | from loguru import logger
3 |
4 | class MACDHedge:
5 | def __init__(self, strategy):
6 | self.strategy = strategy
7 | self.enabled = False
8 | self.hedge_position = None
9 | self.hedge_entry_price = None
10 | self.hedge_order = None
11 | self.hedge_contract_code = None
12 | self.hedge_entry_date = None # 添加入场日期记录
13 |
14 | # 初始化MACD指标
15 | self.macd = bt.indicators.MACD(
16 | self.strategy.data.close,
17 | period_me1=12,
18 | period_me2=26,
19 | period_signal=9
20 | )
21 |
22 | def enable(self):
23 | """启用MACD对冲功能"""
24 | self.enabled = True
25 | logger.info("启用MACD对冲功能")
26 |
27 | def disable(self):
28 | """禁用MACD对冲功能"""
29 | self.enabled = False
30 | logger.info("禁用MACD对冲功能")
31 |
32 | def on_death_cross(self):
33 | """在MACD零轴上方死叉时开空期货"""
34 | if not self.enabled:
35 | return
36 |
37 | if self.hedge_position is not None or self.hedge_order is not None:
38 | logger.info("已有对冲仓位或对冲订单,不再开仓")
39 | return
40 |
41 | # 检查MACD是否在零轴上方
42 | if self.macd.macd[0] <= 0:
43 | logger.info("MACD不在零轴上方,不开仓")
44 | return
45 |
46 | # 检查是否形成死叉(MACD线下穿信号线)
47 | if self.macd.macd[0] > self.macd.signal[0] or self.macd.macd[-1] <= self.macd.signal[-1]:
48 | logger.info("未形成死叉,不开仓")
49 | return
50 |
51 | try:
52 | # 计算ATR止盈止损价格
53 | current_atr = self.strategy.atr[0]
54 | stop_loss = self.strategy.data1.close[0] + (current_atr * self.strategy.p.atr_loss_multiplier)
55 | take_profit = self.strategy.data1.close[0] - (current_atr * self.strategy.p.atr_profit_multiplier)
56 |
57 | # 开空豆粕期货
58 | hedge_size = self.strategy.p.hedge_contract_size
59 |
60 | # 检查期货账户资金是否足够
61 | future_price = self.strategy.data1.close[0]
62 | margin_requirement = future_price * hedge_size * self.strategy.p.future_contract_multiplier * self.strategy.p.m_margin_ratio
63 |
64 | if margin_requirement > self.strategy.future_cash:
65 | logger.warning(f"期货账户资金不足,需要{margin_requirement:.2f},当前可用{self.strategy.future_cash:.2f}")
66 | # 根据可用资金调整手数
67 | adjusted_size = int(self.strategy.future_cash / (future_price * self.strategy.p.future_contract_multiplier * self.strategy.p.m_margin_ratio))
68 | if adjusted_size < 1:
69 | logger.error("期货账户资金不足以开仓一手")
70 | return
71 | hedge_size = adjusted_size
72 | logger.info(f"已调整对冲手数为: {hedge_size}")
73 |
74 | self.hedge_order = self.strategy.sell(data=self.strategy.data1, size=hedge_size)
75 |
76 | if self.hedge_order:
77 | # 记录入场价格和合约代码
78 | self.hedge_entry_price = self.strategy.data1.close[0]
79 | # 确保获取正确的合约代码
80 | current_date = self.strategy.data1.datetime.datetime(0)
81 | if hasattr(self.strategy.data1, 'contract_mapping') and current_date in self.strategy.data1.contract_mapping:
82 | self.hedge_contract_code = self.strategy.data1.contract_mapping[current_date]
83 | else:
84 | # 如果无法获取映射,使用数据名称
85 | self.hedge_contract_code = self.strategy.data1._name
86 | self.hedge_entry_date = self.strategy.data.datetime.date(0) # 记录入场日期
87 |
88 | # 计算保证金
89 | margin = self.hedge_entry_price * hedge_size * self.strategy.p.future_contract_multiplier * self.strategy.p.m_margin_ratio
90 |
91 | # 从期货账户扣除保证金
92 | pre_cash = self.strategy.future_cash
93 | self.strategy.future_cash -= margin
94 |
95 | logger.info(f"开仓扣除保证金 - 之前: {pre_cash:.2f}, 扣除: {margin:.2f}, 之后: {self.strategy.future_cash:.2f}")
96 |
97 | # 记录交易信息
98 | self.hedge_order.info = {
99 | 'reason': f"MACD死叉开空 - MACD: {self.macd.macd[0]:.2f}, Signal: {self.macd.signal[0]:.2f}",
100 | 'margin': margin,
101 | 'future_cash': self.strategy.future_cash,
102 | 'execution_date': self.hedge_entry_date,
103 | 'total_value': self.strategy.future_cash,
104 | 'position_value': abs(margin),
105 | 'position_ratio': margin / self.strategy.future_cash if self.strategy.future_cash > 0 else 0,
106 | 'etf_code': self.hedge_contract_code,
107 | 'pnl': 0,
108 | 'return': 0,
109 | 'stop_loss': stop_loss,
110 | 'take_profit': take_profit,
111 | 'avg_cost': self.hedge_entry_price
112 | }
113 |
114 | logger.info(f"MACD死叉开空 - 合约: {self.hedge_contract_code}, 价格: {self.hedge_entry_price:.2f}, 数量: {hedge_size}手, "
115 | f"止损价: {stop_loss:.2f}, 止盈价: {take_profit:.2f}")
116 |
117 | except Exception as e:
118 | logger.error(f"MACD死叉开空失败: {str(e)}")
119 |
120 | def check_exit(self):
121 | """检查是否需要平仓"""
122 | if not self.enabled or not self.hedge_position:
123 | return
124 |
125 | current_price = self.strategy.data1.close[0]
126 | current_atr = self.strategy.atr[0]
127 |
128 | # 计算ATR止盈止损价格
129 | stop_loss = self.hedge_entry_price + (current_atr * self.strategy.p.atr_loss_multiplier)
130 | take_profit = self.hedge_entry_price - (current_atr * self.strategy.p.atr_profit_multiplier)
131 |
132 | # 检查是否触发止盈止损
133 | if current_price >= stop_loss or current_price <= take_profit:
134 | if self.hedge_order is None: # 确保没有未完成订单
135 | self.hedge_order = self.strategy.close(data=self.strategy.data1)
136 | reason = "触发止盈" if current_price <= take_profit else "触发止损"
137 | logger.info(f"MACD死叉对冲{reason} - 当前价格: {current_price:.2f}, {reason}价: {take_profit if current_price <= take_profit else stop_loss:.2f}")
138 |
139 | def on_order_completed(self, order):
140 | """处理订单完成事件"""
141 | if not self.enabled:
142 | return
143 |
144 | if order.status in [order.Completed]:
145 | if order.isbuy(): # 买入豆粕期货(平空)
146 | # 确保有对应的入场价格
147 | if self.hedge_entry_price is None or self.hedge_contract_code is None:
148 | logger.error("平仓时找不到入场价格或合约代码,跳过处理")
149 | return
150 |
151 | # 记录平仓前的合约信息,用于日志
152 | entry_price = self.hedge_entry_price
153 | contract_code = self.hedge_contract_code
154 | entry_date = self.hedge_entry_date
155 |
156 | # 记录交易日期和价格
157 | trade_date = self.strategy.data.datetime.date(0)
158 | trade_price = order.executed.price
159 |
160 | # 先重置持仓相关变量,防止重复平仓
161 | self.hedge_position = None
162 | self.hedge_order = None
163 | self.hedge_entry_price = None
164 | self.hedge_contract_code = None
165 | self.hedge_entry_date = None
166 | self.hedge_target_profit = None
167 |
168 | # 计算对冲盈亏(空仓:入场价 - 平仓价)
169 | hedge_profit = (entry_price - trade_price) * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier
170 |
171 | # 减去开平仓手续费
172 | total_fee = self.strategy.p.hedge_fee * self.strategy.p.hedge_contract_size * 2
173 | net_profit = hedge_profit - total_fee
174 |
175 | # 归还保证金并添加盈亏到期货账户
176 | margin_returned = entry_price * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier * self.strategy.p.m_margin_ratio
177 |
178 | # 记录更新前的资金
179 | pre_cash = self.strategy.future_cash
180 |
181 | # 更新期货账户资金
182 | self.strategy.future_cash += (margin_returned + net_profit)
183 |
184 | # 记录资金变动
185 | logger.info(f"平仓资金变动 - 之前: {pre_cash:.2f}, 返还保证金: {margin_returned:.2f}, 盈亏: {net_profit:.2f}, 之后: {self.strategy.future_cash:.2f}")
186 |
187 | # 更新期货账户最高净值
188 | self.strategy.future_highest_value = max(self.strategy.future_highest_value, self.strategy.future_cash)
189 |
190 | # 计算期货账户回撤
191 | future_drawdown = (self.strategy.future_highest_value - self.strategy.future_cash) / self.strategy.future_highest_value if self.strategy.future_highest_value > 0 else 0
192 |
193 | # 计算收益率
194 | return_pct = (hedge_profit / (entry_price * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier)) * 100
195 |
196 | # 更新订单信息
197 | order.info.update({
198 | 'pnl': hedge_profit,
199 | 'return': return_pct,
200 | 'total_value': self.strategy.future_cash,
201 | 'position_value': 0, # 平仓后持仓价值为0
202 | 'avg_cost': entry_price,
203 | 'etf_code': contract_code, # 确保使用原始合约代码
204 | 'execution_date': trade_date, # 确保使用当前交易日期
205 | 'reason': f"MACD死叉对冲平仓 - 合约: {contract_code}, 入场日期: {entry_date}, 入场价: {entry_price:.2f}, 平仓价: {trade_price:.2f}, 收益率: {return_pct:.2f}%"
206 | })
207 |
208 | logger.info(f"MACD死叉对冲平仓 - 日期: {trade_date}, 合约: {contract_code}, 价格: {trade_price:.2f}, 盈利: {hedge_profit:.2f}, "
209 | f"手续费: {total_fee:.2f}, 净盈利: {net_profit:.2f}, "
210 | f"期货账户余额: {self.strategy.future_cash:.2f}, 回撤: {future_drawdown:.2%}, "
211 | f"收益率: {return_pct:.2f}%")
212 |
213 | else: # 卖出豆粕期货(开空)
214 | # 记录对冲持仓
215 | self.hedge_position = order
216 |
217 | # 更新订单信息
218 | order.info.update({
219 | 'total_value': self.strategy.future_cash,
220 | 'position_value': abs(order.info['margin']),
221 | 'avg_cost': order.executed.price,
222 | 'etf_code': self.hedge_contract_code # 确保合约代码正确
223 | })
224 |
225 | elif order.status in [order.Canceled, order.Margin, order.Rejected]:
226 | self.hedge_order = None
227 | logger.warning(f'MACD死叉对冲订单失败 - 状态: {order.getstatusname()}')
228 |
229 | def on_strategy_stop(self):
230 | """策略结束时平掉所有持仓"""
231 | if not self.enabled or not self.hedge_position:
232 | return
233 |
234 | try:
235 | if self.hedge_order is None: # 确保没有未完成订单
236 | self.hedge_order = self.strategy.close(data=self.strategy.data1)
237 | logger.info("策略结束,平掉MACD对冲持仓")
238 | except Exception as e:
239 | logger.error(f"策略结束时平仓失败: {str(e)}")
--------------------------------------------------------------------------------