├── requirements.txt ├── .streamlit └── config.toml ├── rl_model_finrl ├── agents │ ├── elegantrl │ │ └── __init__.py │ ├── stablebaseline3 │ │ ├── __init__.py │ │ └── dqn_agent.py │ ├── __init__.py │ └── rllib │ │ ├── __init__.py │ │ └── ppo_agent.py ├── meta │ ├── preprocessor │ │ ├── __init__.py │ │ ├── data_normalizer.py │ │ └── feature_engineer.py │ ├── __init__.py │ └── data_processors │ │ ├── __init__.py │ │ └── akshare_processor.py ├── __init__.py ├── applications │ └── stock_trading │ │ ├── __init__.py │ │ └── run_strategy.py ├── config.py └── README.md ├── ui ├── help │ └── intro.md ├── articles │ ├── 雪球Token如何获取?.md │ └── 什么是ETF?.md └── pages │ ├── settings.py │ ├── market.py │ └── sidebar.py ├── src ├── utils │ ├── logger.py │ ├── notification.py │ ├── backtest_engine.py │ └── plot.py ├── strategies │ ├── strategy_factory.py │ ├── market_sentiment │ │ ├── utils.py │ │ └── etf_dividend_handler.py │ ├── rl_model_strategy.py │ └── dual_ma_hedging │ │ ├── ma_cross_hedge.py │ │ ├── sync_long_hedge.py │ │ └── macd_hedge.py ├── indicators │ └── trailing_stop.py ├── trading │ └── market_executor.py └── data │ └── data_loader.py ├── LICENSE ├── tools ├── visualization_README.md └── plot_general_json.py ├── app.py ├── docs ├── etf_rotation_strategy.md └── market_sentiment_strategy.md ├── README.md └── .gitignore /requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit 2 | backtrader 3 | pandas 4 | numpy 5 | matplotlib 6 | loguru 7 | akshare 8 | tushare 9 | hyperopt 10 | plotly 11 | bokeh 12 | arch 13 | ray 14 | stable-baselines3 15 | gym 16 | torch 17 | scikit-learn 18 | dm_tree -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | backgroundColor="#ffffff" 3 | textColor="#262730" 4 | font="sans serif" 5 | 6 | [server] 7 | enableXsrfProtection = true 8 | 9 | [browser] 10 | gatherUsageStats = false 11 | 12 | [runner] 13 | fastRerenderThreshold = 1000 14 | 15 | [client] 16 | showErrorDetails = true 17 | 18 | [runner.magic] 19 | fastRerenderThreshold = 1000 -------------------------------------------------------------------------------- /rl_model_finrl/agents/elegantrl/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 基于ElegantRL的强化学习智能体 3 | 4 | 此模块包含使用ElegantRL库实现的强化学习智能体: 5 | - PPOAgent: 基于PPO算法的交易智能体 6 | 7 | ElegantRL提供了高效、轻量级的深度强化学习实现,尤其适合金融应用场景。 8 | 其优势包括: 9 | - 更高效的训练速度和采样效率 10 | - 灵活的网络架构设计 11 | - 对多进程训练的良好支持 12 | """ 13 | 14 | from src.strategies.rl_model_finrl.agents.elegantrl.ppo_agent import PPOAgent 15 | 16 | __all__ = [ 17 | 'PPOAgent', 18 | ] -------------------------------------------------------------------------------- /ui/help/intro.md: -------------------------------------------------------------------------------- 1 | ### 🎯 系统功能 2 | 这是一个专业的 量化交易策略回测系统,支持多种交易策略的回测和分析。系统具有以下特点: 3 | 4 | - 📊 **实时数据**:支持通过 Tushare(专业版)或 AKShare(免费)获取实时行情数据 5 | - 🚀 **多策略支持**:采用工厂模式设计,支持多种交易策略,便于扩展 6 | - 📈 **可视化分析**:使用 Plotly 提供交互式图表,包括 K 线、均线、交易点位等 7 | - ⚠️ **风险控制**:内置追踪止损、最大回撤限制等风险控制机制 8 | - 💰 **费用模拟**:精确计算交易费用,包括佣金等 9 | - 📝 **详细日志**:记录每笔交易的详细信息,便于分析和优化 10 | 11 | ### ⚠️ 风险提示 12 | 本系统仅供学习和研究使用,不构成任何投资建议。使用本系统进行实盘交易需要自行承担风险。 -------------------------------------------------------------------------------- /rl_model_finrl/agents/stablebaseline3/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 基于Stable-Baselines3的强化学习智能体 3 | 4 | 此模块包含使用Stable-Baselines3库实现的强化学习智能体: 5 | - DQNAgent: 基于DQN算法的交易智能体 6 | 7 | 智能体功能包括: 8 | - 训练: 使用经验回放和目标网络进行深度Q学习 9 | - 预测: 基于学习策略选择最优交易动作 10 | - 测试: 使用学习策略进行回测 11 | - 保存/加载: 支持模型的保存和加载 12 | """ 13 | 14 | from src.strategies.rl_model_finrl.agents.stablebaseline3.dqn_agent import DQNAgent 15 | 16 | __all__ = [ 17 | 'DQNAgent', 18 | ] -------------------------------------------------------------------------------- /rl_model_finrl/meta/preprocessor/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 预处理器模块 3 | 4 | 这个模块提供了金融数据预处理的工具和类,用于强化学习环境中的数据准备。 5 | 6 | 主要组件: 7 | - FeatureEngineer: 金融特征工程工具 8 | - DataNormalizer: 数据归一化处理器 9 | 10 | 主要功能: 11 | - 技术指标计算与特征工程 12 | - 数据归一化和标准化 13 | - 缺失值处理 14 | - 异常值检测 15 | """ 16 | 17 | from src.strategies.rl_model_finrl.meta.preprocessor.feature_engineer import FeatureEngineer 18 | from src.strategies.rl_model_finrl.meta.preprocessor.data_normalizer import DataNormalizer 19 | 20 | __all__ = [ 21 | 'FeatureEngineer', 22 | 'DataNormalizer' 23 | ] -------------------------------------------------------------------------------- /rl_model_finrl/meta/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | FinRL-Meta模块 3 | 4 | 此模块包含FinRL框架的基础组件,包括: 5 | - data_processors: 数据处理器,用于获取和处理市场数据 6 | - preprocessor: 数据预处理,用于生成技术指标和特征 7 | 8 | FinRL-Meta提供了金融强化学习的基础设施,支持: 9 | - 多种数据源的集成和处理 10 | - 标准化的交易环境接口 11 | - 丰富的市场特征和状态表示 12 | """ 13 | 14 | from src.strategies.rl_model_finrl.meta.data_processors import TushareProcessor, AKShareProcessor 15 | from src.strategies.rl_model_finrl.meta.preprocessor import FeatureEngineer, DataNormalizer 16 | 17 | __all__ = [ 18 | 'TushareProcessor', 19 | 'AKShareProcessor', 20 | 'FeatureEngineer', 21 | 'DataNormalizer' 22 | ] -------------------------------------------------------------------------------- /rl_model_finrl/agents/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 强化学习智能体模块 3 | 4 | 此模块包含各种实现的强化学习智能体,支持不同的算法库: 5 | - stablebaseline3: 基于Stable-Baselines3库的智能体 6 | - elegantrl: 基于ElegantRL库的智能体 7 | - rllib: 基于Ray RLlib库的智能体 8 | 9 | 当前支持的智能体: 10 | - DQNAgent: 基于DQN算法的强化学习智能体 (Stable-Baselines3) 11 | - PPOAgent: 基于PPO算法的强化学习智能体 (ElegantRL) 12 | - RLlibPPOAgent: 基于PPO算法的强化学习智能体 (Ray RLlib) 13 | """ 14 | 15 | from src.strategies.rl_model_finrl.agents.stablebaseline3 import DQNAgent 16 | from src.strategies.rl_model_finrl.agents.elegantrl import PPOAgent 17 | from src.strategies.rl_model_finrl.agents.rllib import RLlibPPOAgent 18 | 19 | __all__ = [ 20 | 'DQNAgent', 21 | 'PPOAgent', 22 | 'RLlibPPOAgent', 23 | ] -------------------------------------------------------------------------------- /rl_model_finrl/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | RL模型基于FinRL框架优化的ETF交易模块 3 | 4 | 此模块包含使用强化学习进行ETF交易的完整框架,基于FinRL架构设计。 5 | 模块使用了三层架构: 6 | 1. 数据层(meta/data_processors): 负责从Tushare和AKShare获取A股ETF数据 7 | 2. 环境层(meta/env_stock_trading): 提供了ETF交易的交互环境 8 | 3. 智能体层(agents): 实现了各种RL算法,包括DQN等 9 | 10 | 主要功能: 11 | - 使用Tushare和AKShare获取A股ETF数据 12 | - 构建多ETF交易环境 13 | - 实现DQN等强化学习算法 14 | - 提供训练和回测功能 15 | """ 16 | 17 | # 版本信息 18 | __version__ = "0.1.0" 19 | 20 | # 核心组件 21 | from src.strategies.rl_model_finrl.applications.stock_trading.etf_env import ETFTradingEnv 22 | from src.strategies.rl_model_finrl.agents.stablebaseline3.dqn_agent import DQNAgent 23 | 24 | # 导出接口 25 | __all__ = [ 26 | 'ETFTradingEnv', 27 | 'DQNAgent' 28 | ] -------------------------------------------------------------------------------- /rl_model_finrl/agents/rllib/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | RLlib智能体模块 3 | 4 | 这个模块提供了基于Ray RLlib框架的强化学习智能体实现,支持: 5 | - PPO (Proximal Policy Optimization) 6 | 7 | 主要组件: 8 | - RLlibPPOAgent: 基于RLlib的PPO算法实现 9 | 10 | 使用示例: 11 | ```python 12 | from src.strategies.rl_model_finrl.agents.rllib.ppo_agent import RLlibPPOAgent 13 | from src.strategies.rl_model_finrl.meta.env_stock_trading.etf_trading_env import ETFTradingEnv 14 | 15 | # 创建环境 16 | env = ETFTradingEnv(...) 17 | 18 | # 初始化PPO智能体 19 | agent = RLlibPPOAgent(env=env) 20 | 21 | # 训练模型 22 | agent.learn(total_timesteps=100000) 23 | 24 | # 保存模型 25 | agent.save("models/ppo_rllib") 26 | 27 | # 加载模型 28 | agent.load("models/ppo_rllib") 29 | 30 | # 测试模型 31 | asset_memory, action_memory = agent.test(test_env) 32 | ``` 33 | """ 34 | 35 | from src.strategies.rl_model_finrl.agents.rllib.ppo_agent import RLlibPPOAgent 36 | 37 | __all__ = [ 38 | 'RLlibPPOAgent' 39 | ] -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import sys 3 | import os 4 | from datetime import datetime 5 | 6 | def setup_logger(): 7 | """配置日志记录器""" 8 | # 移除默认的处理程序 9 | logger.remove() 10 | 11 | # 获取当前日期作为日志文件名的一部分 12 | current_date = datetime.now().strftime("%Y%m%d") 13 | log_file = f"logs/backtest_{current_date}.log" 14 | 15 | # 确保日志目录存在 16 | os.makedirs("logs", exist_ok=True) 17 | 18 | # 添加控制台输出 19 | logger.add( 20 | sys.stderr, 21 | format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}", 22 | level="INFO" 23 | ) 24 | 25 | # 添加文件输出 26 | logger.add( 27 | log_file, 28 | format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}", 29 | level="DEBUG", 30 | rotation="1 day", # 每天轮换一次日志文件 31 | retention="30 days", # 保留30天的日志 32 | compression="zip" # 压缩旧的日志文件 33 | ) 34 | 35 | return logger -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 plan9x 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/strategies/strategy_factory.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Dict 2 | import backtrader as bt 3 | from .dual_ma_hedging_strategy import DualMAHedgingStrategy 4 | from .dual_ma_strategy import DualMAStrategy 5 | from .market_sentiment_strategy import MarketSentimentStrategy 6 | from .etf_rotation_strategy import ETFRotationStrategy 7 | 8 | class StrategyFactory: 9 | _strategies: Dict[str, Type[bt.Strategy]] = { 10 | "双均线策略": DualMAStrategy, 11 | "市场情绪策略": MarketSentimentStrategy, 12 | "ETF轮动策略": ETFRotationStrategy, 13 | "双均线对冲策略": DualMAHedgingStrategy, 14 | } 15 | 16 | @classmethod 17 | def register_strategy(cls, name: str, strategy_class: Type[bt.Strategy]): 18 | """注册新的策略""" 19 | cls._strategies[name] = strategy_class 20 | 21 | @classmethod 22 | def get_strategy(cls, name: str) -> Type[bt.Strategy]: 23 | """获取策略类""" 24 | return cls._strategies.get(name) 25 | 26 | @classmethod 27 | def get_all_strategies(cls) -> Dict[str, Type[bt.Strategy]]: 28 | """获取所有已注册的策略""" 29 | return cls._strategies.copy() 30 | 31 | @classmethod 32 | def get_strategy_names(cls) -> list: 33 | """获取所有策略名称""" 34 | return list(cls._strategies.keys()) -------------------------------------------------------------------------------- /tools/visualization_README.md: -------------------------------------------------------------------------------- 1 | # JSON数据可视化工具 2 | 3 | 本仓库包含两个Streamlit应用程序,用于可视化JSON数据: 4 | 5 | 1. `plot_json_data.py` - 专门用于可视化cache目录数据的工具 6 | 2. `plot_general_json.py` - 可以处理各种JSON结构的通用工具 7 | 8 | ## 环境要求 9 | 10 | 所有需要的包都列在项目的 `requirements.txt` 文件中。可视化工具使用: 11 | 12 | - streamlit 13 | - pandas 14 | - numpy 15 | - matplotlib 16 | - plotly 17 | 18 | ## 如何运行 19 | 20 | ### 运行cache目录可视化工具 21 | 22 | ```bash 23 | streamlit run plot_json_data.py 24 | ``` 25 | 26 | 该工具专门设计用于可视化具有预定义结构的cache目录下的ETF分红数据和市场情绪数据的JSON文件。它提供三种可视化类型: 27 | 28 | - 分红历史(时间序列) 29 | - 按年份的分红箱线图 30 | - 分红热力图(按年和月) 31 | 32 | ### 运行通用JSON可视化工具 33 | 34 | ```bash 35 | streamlit run plot_general_json.py 36 | ``` 37 | 38 | 这个工具更加灵活,可以处理各种JSON数据结构。它支持: 39 | 40 | 1. 来自缓存目录的JSON文件 41 | 2. 用户上传的JSON文件 42 | 43 | 该工具会自动检测JSON数据的结构,并提供适当的可视化选项: 44 | 45 | - 时间序列(适用于带有日期/时间列的数据) 46 | - 柱状图 47 | - 散点图 48 | - 直方图 49 | - 箱线图 50 | - 热力图 51 | 52 | ## 功能特点 53 | 54 | - 自动数据类型检测 55 | - 使用Plotly的交互式可视化 56 | - 日期/时间字段自动检测 57 | - 支持嵌套的JSON结构 58 | - 数据统计显示 59 | - 原始数据查看选项 60 | 61 | ## 使用示例 62 | 63 | 1. 启动通用可视化工具 64 | 2. 从缓存目录选择JSON文件或上传您自己的文件 65 | 3. 从可用选项中选择可视化类型 66 | 4. 根据需要选择X轴、Y轴和其他参数 67 | 5. 与生成的可视化进行交互 68 | 6. 可选择通过勾选"显示原始数据"查看原始数据 69 | 70 | ## JSON数据格式支持 71 | 72 | 通用可视化工具可以处理: 73 | 74 | - 字典列表(最常见的数据可视化格式) 75 | - 嵌套字典 76 | - 具有嵌套数组和对象的混合结构 77 | 78 | 对于复杂的嵌套结构,该工具会尝试在保持关系的同时展平数据以实现可视化。 -------------------------------------------------------------------------------- /rl_model_finrl/applications/stock_trading/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ETF交易应用模块 3 | 4 | 本模块提供了基于强化学习的ETF交易应用案例,用于展示如何使用FinRL框架构建ETF交易策略。 5 | 该应用在模拟和实际市场数据上进行回测,以评估策略的性能。 6 | 7 | 主要组件: 8 | 1. ETF交易环境: 一个特定为ETF交易定制的强化学习环境 9 | 2. 交易策略示例: 基于RL模型的ETF交易策略 10 | 3. 回测工具: 用于评估模型性能的回测工具 11 | 4. 性能分析: 交易策略的收益率、风险和其他相关指标分析 12 | 13 | 用法示例: 14 | ```python 15 | from src.strategies.rl_model_finrl.applications.stock_trading import ( 16 | run_etf_strategy, 17 | backtest_etf_strategy, 18 | ETFTradingEnv 19 | ) 20 | 21 | # 训练ETF交易策略 22 | trained_model = run_etf_strategy( 23 | start_date='2010-01-01', 24 | end_date='2020-12-31', 25 | ticker_list=['SPY', 'QQQ', 'IWM', 'EEM'], 26 | agent='ppo' 27 | ) 28 | 29 | # 回测策略 30 | performance = backtest_etf_strategy( 31 | model=trained_model, 32 | test_start='2021-01-01', 33 | test_end='2021-12-31' 34 | ) 35 | ``` 36 | """ 37 | 38 | from src.strategies.rl_model_finrl.applications.stock_trading.etf_env import ETFTradingEnv 39 | from src.strategies.rl_model_finrl.applications.stock_trading.run_strategy import run_etf_strategy 40 | from src.strategies.rl_model_finrl.applications.stock_trading.backtest import backtest_etf_strategy 41 | from src.strategies.rl_model_finrl.applications.stock_trading.analysis import ETFStrategyAnalyzer 42 | 43 | __all__ = [ 44 | 'ETFTradingEnv', 45 | 'run_etf_strategy', 46 | 'backtest_etf_strategy', 47 | 'ETFStrategyAnalyzer' 48 | ] -------------------------------------------------------------------------------- /ui/articles/雪球Token如何获取?.md: -------------------------------------------------------------------------------- 1 | # 雪球Token如何获取? 2 | 3 | 雪球实时行情数据可以通过以下接口获取: 4 | ``` 5 | https://stock.xueqiu.com/v5/stock/quote.json?symbol=SH000001&extend=detail 6 | ``` 7 | 注意:此接口需要token认证才能访问。 8 | 9 | ## 接口详情 10 | 11 | ### AKShare集成接口 12 | - 接口名称: stock_individual_spot_xq 13 | - 接口类型: 雪球-行情中心-个股 14 | - 目标地址: https://xueqiu.com/S/SH513520 15 | - 访问限制: 单次获取指定 symbol 的最新行情数据 16 | 17 | ### 认证信息获取步骤 18 | 19 | 获取token的具体方法如下: 20 | 21 | 1. 登录雪球官方网站 (https://xueqiu.com) 22 | 2. 打开浏览器开发者工具(按F12键) 23 | 3. 在浏览器中请求测试接口: 24 | ``` 25 | https://stock.xueqiu.com/v5/stock/quote.json?symbol=SH000001&extend=detail 26 | ``` 27 | 4. 在开发者工具的"网络"(Network)标签页中,找到该请求 28 | 5. 在请求的标头(Headers)信息中找到cookie字段,内容示例如下: 29 | ``` 30 | cookie: cookiesu=631719021518417; device_id=c37e94779eede2e3f0250482d804ff81; s=af1w3zgebe; bid=03b31a66824a4a426b80229356033a8c_m4hqnffb; Hm_lvt_1db88642e346389874251b5a1eded6e3=1742519317; HMACCOUNT=2436B2ECFC061307; xq_a_token=d679467b716fd5b0a0af195f7e8143774d271a41; [... 其他cookie值 ...] 31 | ``` 32 | 6. 从cookie中提取`xq_a_token`字段的值,这就是我们需要的token 33 | 34 | ### Token示例 35 | ``` 36 | xq_a_token=d679467b716fd5b0a0af195f7e8143774d271a41 37 | ``` 38 | 39 | ## 使用注意事项 40 | - Token具有时效性,一般有效期为7天,需要定期更新 41 | - 请合理控制接口调用频率,避免触发雪球的访问限制 42 | - 建议在程序中实现token自动更新机制 43 | - 不要公开分享或传播你的个人token 44 | - 使用token时建议通过环境变量或配置文件管理,避免硬编码 45 | 46 | ## 常见问题 47 | - 如遇到"未授权"错误,请检查token是否过期 48 | - 如遇到访问限制,请适当降低请求频率 49 | - Token仅对获取它时登录的账号有效 -------------------------------------------------------------------------------- /src/utils/notification.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import requests 4 | from src.utils.logger import setup_logger 5 | 6 | logger = setup_logger() 7 | 8 | def load_settings(): 9 | """加载通知设置""" 10 | settings_file = "config/settings.json" 11 | if os.path.exists(settings_file): 12 | with open(settings_file, "r", encoding="utf-8") as f: 13 | return json.load(f) 14 | return {"sms": {"enabled": False}, "wechat": {"enabled": False}} 15 | 16 | def send_sms(message: str, api_key: str, phone_number: str): 17 | """发送短信通知""" 18 | # 这里需要根据实际使用的短信服务商API来实现 19 | # 示例使用阿里云短信服务 20 | try: 21 | # TODO: 实现实际的短信发送逻辑 22 | logger.info(f"发送短信到 {phone_number}: {message}") 23 | except Exception as e: 24 | logger.error(f"发送短信失败: {str(e)}") 25 | 26 | def send_wechat(message: str, webhook_url: str): 27 | """发送企业微信通知""" 28 | try: 29 | data = { 30 | "msgtype": "text", 31 | "text": { 32 | "content": message 33 | } 34 | } 35 | response = requests.post(webhook_url, json=data) 36 | response.raise_for_status() 37 | logger.info(f"发送企业微信通知成功: {message}") 38 | except Exception as e: 39 | logger.error(f"发送企业微信通知失败: {str(e)}") 40 | 41 | def send_notification(message: str): 42 | """发送通知""" 43 | settings = load_settings() 44 | 45 | # 发送短信通知 46 | if settings["sms"]["enabled"]: 47 | send_sms( 48 | message, 49 | settings["sms"]["api_key"], 50 | settings["sms"]["phone_number"] 51 | ) 52 | 53 | # 发送企业微信通知 54 | if settings["wechat"]["enabled"]: 55 | send_wechat( 56 | message, 57 | settings["wechat"]["webhook_url"] 58 | ) -------------------------------------------------------------------------------- /ui/articles/什么是ETF?.md: -------------------------------------------------------------------------------- 1 | # 什么是ETF? 2 | 3 | ETF(Exchange Traded Fund,交易型开放式指数基金)是一种在证券交易所上市交易的基金,它既具有开放式基金的特点,又可以在二级市场进行交易。ETF在我国A股市场的发展虽然起步较晚,但近年来发展迅速,已成为投资者重要的投资工具。 4 | 5 | ## ETF的主要特点 6 | 7 | 1. **交易便利性** 8 | - 可以在证券交易所实时交易 9 | - 交易费用相对较低 10 | - 流动性好,可以随时买卖 11 | 12 | 2. **投资门槛低** 13 | - 最低交易单位通常为100份 14 | - 适合中小投资者参与 15 | 16 | 3. **透明度高** 17 | - 每日公布持仓信息 18 | - 价格实时反映市场变化 19 | 20 | ## ETF的主要类型 21 | 22 | 1. **股票ETF** 23 | - 跟踪各类股票指数 24 | - 如沪深300ETF、中证500ETF等 25 | - 最受欢迎的ETF类型 26 | 27 | 2. **债券ETF** 28 | - 跟踪债券指数 29 | - 如国债ETF、企业债ETF等 30 | - 风险相对较低 31 | 32 | 3. **商品ETF** 33 | - 跟踪黄金、原油等大宗商品 34 | - 如黄金ETF、原油ETF等 35 | - 提供商品投资渠道 36 | 37 | 4. **跨境ETF** 38 | - 跟踪海外市场指数 39 | - 如恒生ETF、标普500ETF等 40 | - 提供海外投资机会 41 | 42 | ## 投资ETF的优势 43 | 44 | 1. **分散风险** 45 | - 通过一个产品投资多个标的 46 | - 降低单一投资风险 47 | 48 | 2. **成本优势** 49 | - 管理费率较低 50 | - 交易成本相对较低 51 | 52 | 3. **操作灵活** 53 | - 可以像股票一样交易 54 | - 支持做多和做空 55 | 56 | ## 投资注意事项 57 | 58 | 1. **了解跟踪标的** 59 | - 仔细研究ETF跟踪的指数 60 | - 了解成分股构成 61 | 62 | 2. **关注流动性** 63 | - 选择交易活跃的ETF 64 | - 避免流动性风险 65 | 66 | 3. **注意交易成本** 67 | - 考虑管理费、交易费等 68 | - 选择成本较低的ETF 69 | 70 | ## ETF市场发展现状 71 | 72 | 截至2023年,我国ETF市场规模已突破2万亿元,产品数量超过800只。主要特点: 73 | 74 | 1. **产品创新** 75 | - 主题ETF不断涌现 76 | - 策略型ETF快速发展 77 | 78 | 2. **投资者结构** 79 | - 机构投资者占比提升 80 | - 个人投资者参与度增加 81 | 82 | 3. **监管完善** 83 | - 相关法规逐步健全 84 | - 市场秩序持续改善 85 | 86 | ## 未来展望 87 | 88 | 1. **产品多元化** 89 | - 更多创新产品推出 90 | - 投资策略更加丰富 91 | 92 | 2. **市场深化** 93 | - 交易机制优化 94 | - 投资者教育加强 95 | 96 | 3. **国际化发展** 97 | - 跨境产品增加 98 | - 国际投资者参与度提升 99 | 100 | ETF作为现代金融工具,在我国市场具有广阔的发展前景。投资者可以根据自身需求选择合适的ETF产品,实现资产配置和投资目标。 101 | -------------------------------------------------------------------------------- /rl_model_finrl/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | # 设备配置 6 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 7 | 8 | # A股ETF数据配置 9 | TICKER_LIST = [ 10 | "159915.SZ", # 易方达创业板ETF 11 | "510300.SH", # 华泰柏瑞沪深300ETF 12 | "510500.SH", # 南方中证500ETF 13 | "512100.SH", # 南方中证1000ETF 14 | "510050.SH", # 华夏上证50ETF 15 | "512880.SH", # 国泰中证军工ETF 16 | "512690.SH", # 鹏华中证医药卫生ETF 17 | "512980.SH", # 广发中证传媒ETF 18 | ] # A股ETF代码列表 19 | 20 | TECHNICAL_INDICATORS_LIST = [ 21 | "macd", 22 | "boll_ub", 23 | "boll_lb", 24 | "rsi_30", 25 | "cci_30", 26 | "dx_30", 27 | "close_30_sma", 28 | "close_60_sma", 29 | "volatility_30", 30 | "momentum_30", 31 | ] # 技术指标列表 32 | INDICATORS_NORMALIZE = True # 是否标准化指标 33 | TURBULENCE_THRESHOLD = 0.01 # 波动性阈值 34 | 35 | # 交易环境配置 36 | INITIAL_AMOUNT = 1000000.0 # 初始资金 37 | TRANSACTION_COST_PCT = 0.0003 # 交易成本百分比,ETF一般费率更低 38 | MAX_POSITION_PCT = 0.3 # 单个ETF最大仓位 39 | REWARD_SCALING = 1e-3 # 奖励缩放系数 40 | STATE_SPACE_DIM = len(TECHNICAL_INDICATORS_LIST) + 4 # 状态空间维度 (技术指标 + 持仓量 + 现金比例 + 大盘指标 + 情绪指标) 41 | ACTION_SPACE_DIM = 3 # 动作空间维度(买入、卖出、持有) 42 | 43 | # 训练配置 44 | TRAIN_START_DATE = "2018-01-01" 45 | TRAIN_END_DATE = "2021-12-31" 46 | TEST_START_DATE = "2022-01-01" 47 | TEST_END_DATE = "2023-12-31" 48 | TIME_INTERVAL = "1d" # 时间间隔(日频) 49 | 50 | # 数据源配置 51 | TUSHARE_TOKEN = "" # 需要填入您的tushare token 52 | USE_TUSHARE = True 53 | USE_AKSHARE = True 54 | 55 | # 智能体配置 56 | REPLAY_BUFFER_SIZE = 100000 # 回放缓冲区大小 57 | GAMMA = 0.99 # 折扣因子 58 | LEARNING_RATE = 1e-4 # 学习率 59 | BATCH_SIZE = 256 # 批处理大小 60 | TARGET_UPDATE_FREQ = 100 # 目标网络更新频率 61 | NUM_EPISODES = 1000 # 训练回合数 62 | EPSILON_START = 0.9 # 探索率初始值 63 | EPSILON_END = 0.05 # 探索率终值 64 | EPSILON_DECAY = 1000 # 探索率衰减系数 65 | 66 | # 路径配置 67 | DATA_SAVE_PATH = "data" 68 | MODEL_SAVE_PATH = "models" 69 | RESULTS_PATH = "results" 70 | TENSORBOARD_PATH = "runs" 71 | TENSORBOARD_LOG_PATH = "runs/tensorboard" 72 | 73 | # 创建必要的目录 74 | for path in [DATA_SAVE_PATH, MODEL_SAVE_PATH, RESULTS_PATH, TENSORBOARD_PATH]: 75 | os.makedirs(path, exist_ok=True) -------------------------------------------------------------------------------- /ui/pages/settings.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import json 3 | import os 4 | from src.utils.logger import setup_logger 5 | 6 | logger = setup_logger() 7 | 8 | def load_settings(): 9 | settings_file = "config/settings.json" 10 | if os.path.exists(settings_file): 11 | with open(settings_file, "r", encoding="utf-8") as f: 12 | return json.load(f) 13 | return { 14 | "sms": { 15 | "enabled": False, 16 | "api_key": "", 17 | "phone_number": "" 18 | }, 19 | "wechat": { 20 | "enabled": False, 21 | "webhook_url": "" 22 | } 23 | } 24 | 25 | def save_settings(settings): 26 | os.makedirs("config", exist_ok=True) 27 | settings_file = "config/settings.json" 28 | with open(settings_file, "w", encoding="utf-8") as f: 29 | json.dump(settings, f, indent=4, ensure_ascii=False) 30 | 31 | def render_settings(): 32 | st.header("系统设置") 33 | 34 | # 加载设置 35 | settings = load_settings() 36 | 37 | # 短信通知设置 38 | st.subheader("短信通知设置") 39 | sms_enabled = st.checkbox("启用短信通知", settings["sms"]["enabled"]) 40 | if sms_enabled: 41 | sms_api_key = st.text_input("短信API密钥", settings["sms"]["api_key"]) 42 | sms_phone = st.text_input("接收手机号", settings["sms"]["phone_number"]) 43 | settings["sms"].update({ 44 | "enabled": True, 45 | "api_key": sms_api_key, 46 | "phone_number": sms_phone 47 | }) 48 | else: 49 | settings["sms"]["enabled"] = False 50 | 51 | # 微信通知设置 52 | st.subheader("微信通知设置") 53 | wechat_enabled = st.checkbox("启用微信通知", settings["wechat"]["enabled"]) 54 | if wechat_enabled: 55 | webhook_url = st.text_input("企业微信Webhook地址", settings["wechat"]["webhook_url"]) 56 | settings["wechat"].update({ 57 | "enabled": True, 58 | "webhook_url": webhook_url 59 | }) 60 | else: 61 | settings["wechat"]["enabled"] = False 62 | 63 | # 保存按钮 64 | if st.button("保存设置"): 65 | try: 66 | save_settings(settings) 67 | st.success("设置保存成功!") 68 | except Exception as e: 69 | st.error(f"保存设置时出错: {str(e)}") 70 | logger.error(f"保存设置时出错: {str(e)}") 71 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import os 3 | from src.utils.logger import setup_logger 4 | from ui.pages.sidebar import render_sidebar 5 | from ui.pages.backtest import render_backtest 6 | from ui.pages.market import render_market 7 | from ui.pages.settings import render_settings 8 | 9 | # 设置日志 10 | logger = setup_logger() 11 | 12 | # 设置页面 13 | st.set_page_config( 14 | page_title="量化策略回测系统", 15 | layout="wide", 16 | initial_sidebar_state="expanded" 17 | ) 18 | 19 | def main(): 20 | st.title("量化策略回测系统") 21 | 22 | # 渲染侧边栏并获取参数 23 | params = render_sidebar() 24 | if params is None: 25 | return 26 | 27 | # 创建标签页 28 | tab1, tab2, tab3, tab4, tab5 = st.tabs(["系统介绍", "回测", "实盘记录", "文章", "系统设置"]) 29 | 30 | # 系统介绍标签页 31 | with tab1: 32 | st.header("系统介绍") 33 | try: 34 | with open("ui/help/intro.md", "r", encoding="utf-8") as f: 35 | st.markdown(f.read()) 36 | except FileNotFoundError: 37 | st.error("找不到系统介绍文件") 38 | except Exception as e: 39 | st.error(f"读取系统介绍文件时出错: {str(e)}") 40 | 41 | # 回测标签页 42 | with tab2: 43 | render_backtest(params) 44 | 45 | # 实盘记录标签页 46 | with tab3: 47 | render_market(params) 48 | 49 | # 文章标签页 50 | with tab4: 51 | st.header("文章列表") 52 | try: 53 | # 获取articles目录下的所有md文件 54 | articles_dir = "ui/articles" 55 | articles = [f for f in os.listdir(articles_dir) if f.endswith('.md')] 56 | 57 | # 创建两列布局 58 | col1, col2 = st.columns([1, 2]) 59 | 60 | with col1: 61 | st.subheader("目录") 62 | # 创建文章列表 63 | for article in articles: 64 | if st.button(article.replace('.md', ''), key=f"article_{article}"): 65 | st.session_state.selected_article = article 66 | 67 | with col2: 68 | # 显示选中的文章内容 69 | selected_article = getattr(st.session_state, 'selected_article', None) 70 | if selected_article: 71 | with open(os.path.join(articles_dir, selected_article), "r", encoding="utf-8") as f: 72 | st.markdown(f.read()) 73 | else: 74 | st.info("请从左侧选择要阅读的文章") 75 | except Exception as e: 76 | st.error(f"读取文章时出错: {str(e)}") 77 | 78 | # 系统设置标签页 79 | with tab5: 80 | render_settings() 81 | 82 | if __name__ == "__main__": 83 | main() -------------------------------------------------------------------------------- /src/indicators/trailing_stop.py: -------------------------------------------------------------------------------- 1 | import backtrader as bt 2 | from loguru import logger 3 | 4 | class TrailingStop(bt.Indicator): 5 | """ 6 | 追踪止损指标 7 | 当价格上涨时,止损线会跟随上涨,但不会随价格下跌而下跌 8 | """ 9 | lines = ('trailing_stop',) # 声明指标线 10 | params = ( 11 | ('trailing', 0.02), # 追踪止损比例,默认2% 12 | ) 13 | 14 | plotinfo = dict(subplot=False) # 绘制在主图上 15 | 16 | def __init__(self): 17 | super(TrailingStop, self).__init__() 18 | self.plotinfo.plotname = f'TrailStop ({self.p.trailing:.1%})' # 图例名称 19 | self.max_price = float('-inf') # 初始化最高价为负无穷 20 | self.reset_requested = False # 用于标记是否需要重置 21 | self.in_trade = False # 标记是否在交易中 22 | self.entry_price = None # 记录入场价格 23 | self._prev_stop = 0.0 # 记录前一个止损价 24 | self._last_price = None # 添加上次价格记录 25 | self._call_count = 0 # 添加调用计数 26 | 27 | logger.info("初始化追踪止损指标 - 止损比例: {:.1%}", self.p.trailing) 28 | 29 | def reset(self, price=None): 30 | """重置最高价和止损价""" 31 | # 使用传入的价格或当前收盘价 32 | self.entry_price = price if price is not None else self.data.close[0] 33 | self.max_price = self.entry_price 34 | self._prev_stop = self.max_price * (1.0 - self.p.trailing) 35 | self.lines.trailing_stop[0] = self._prev_stop 36 | self.reset_requested = False 37 | self.in_trade = True 38 | 39 | def next(self): 40 | self._call_count += 1 41 | current_price = self.data.close[0] 42 | self._last_price = current_price 43 | 44 | # 如果在交易中,更新最高价和止损价 45 | if self.in_trade: 46 | # 如果价格创新高 47 | if current_price > self.max_price: 48 | self.max_price = current_price 49 | new_stop = self.max_price * (1.0 - self.p.trailing) 50 | # 确保新的止损价不低于之前的止损价 51 | self._prev_stop = max(new_stop, self._prev_stop) 52 | self.lines.trailing_stop[0] = self._prev_stop 53 | else: 54 | # 保持之前的止损价格 55 | self.lines.trailing_stop[0] = self._prev_stop 56 | else: 57 | # 不在交易中,设置止损价为0 58 | self.lines.trailing_stop[0] = 0.0 59 | # 重置最高价,为下次交易做准备 60 | self.max_price = float('-inf') 61 | self._prev_stop = 0.0 62 | self.entry_price = None 63 | 64 | def stop_tracking(self): 65 | """停止追踪""" 66 | if self.in_trade: 67 | logger.info("停止追踪止损 - 入场价: {:.2f}, 最高价: {:.2f}, 最终止损价: {:.2f}", 68 | self.entry_price, self.max_price, self._prev_stop) 69 | self.in_trade = False -------------------------------------------------------------------------------- /docs/etf_rotation_strategy.md: -------------------------------------------------------------------------------- 1 | # ETF轮动策略(ETF Rotation Strategy) 2 | 3 | ## 策略概述 4 | 5 | ETF轮动策略是一种基于动量因子的资产配置方法,该策略通过定期评估多个ETF的相对强度,选择表现最佳的ETF进行投资。策略核心思想是"追涨",即买入近期表现最好的资产类别,利用市场趋势性和动量延续性特征获取超额收益。本策略主要应用于不同行业、主题或资产类别的ETF之间进行轮动配置。 6 | 7 | ## 核心理念 8 | 9 | 1. **动量因子驱动**:利用价格动量作为主要选股因子,买入强势ETF 10 | 2. **定期调仓机制**:按固定时间间隔对ETF池进行动量排名和调仓 11 | 3. **集中持仓策略**:只持有排名靠前的少数ETF,提高收益率 12 | 4. **风险管理体系**:设置追踪止损和最大回撤限制,控制下行风险 13 | 5. **资金管理方法**:根据波动率和风险比例动态调整仓位大小 14 | 15 | ## 动量指标构建 16 | 17 | 策略使用简单但有效的动量计算方法: 18 | 19 | 1. **动量计算**: 20 | - 使用N日价格变化百分比作为动量指标(默认为20日) 21 | - 计算公式:Momentum = (当前价格 / N日前价格) - 1 22 | 23 | 2. **排名机制**: 24 | - 对所有ETF的动量指标进行排序 25 | - 选择动量排名最高的前N只ETF(默认为1只) 26 | 27 | 3. **风险度量**: 28 | - 使用14日ATR(真实波动幅度均值)评估ETF的波动风险 29 | - 结合价格百分比设置动态止损线 30 | 31 | ## 交易规则 32 | 33 | 1. **调仓频率**: 34 | - 每隔固定天数(默认30天)进行一次调仓 35 | - 避免过于频繁交易,减少交易成本 36 | 37 | 2. **选择标准**: 38 | - 选择动量最强的前N只ETF(默认为1只) 39 | - 卖出不再位于前N名的持仓ETF 40 | 41 | 3. **头寸规模计算**: 42 | - 基于风险金额确定头寸大小 43 | - 考虑ATR和价格波动设置风险敞口 44 | - 调整为100股的整数倍交易单位 45 | 46 | 4. **交易执行**: 47 | - 在调仓日同时执行买入和卖出操作 48 | - 确保每笔交易至少达到最小交易单位(100股) 49 | 50 | ## 风险控制机制 51 | 52 | 1. **追踪止损**: 53 | - 设置百分比追踪止损(默认1.5%) 54 | - 价格回撤超过最高点一定比例时触发卖出 55 | 56 | 2. **最大回撤限制**: 57 | - 当组合回撤超过设定阈值(默认15%)时清仓 58 | - 避免在大幅下跌行情中持续损失 59 | 60 | 3. **风险头寸限制**: 61 | - 单次交易风险敞口不超过总资产的特定比例(默认2%) 62 | - 根据ETF波动性动态调整持仓规模 63 | 64 | 4. **资金分配**: 65 | - 在持有多只ETF时平均分配资金 66 | - 保留5%现金缓冲用于手续费和滑点 67 | 68 | ## 参数设置 69 | 70 | 主要参数包括: 71 | 72 | ``` 73 | # 策略参数 74 | momentum_period: 20 # 动量计算周期 75 | rebalance_interval: 30 # 调仓间隔天数 76 | num_positions: 1 # 持有前N个ETF 77 | risk_ratio: 0.02 # 单次交易风险比率 78 | max_drawdown: 0.15 # 最大回撤限制 79 | trail_percent: 1.5 # 追踪止损百分比 80 | verbose: True # 是否输出详细日志 81 | ``` 82 | 83 | ## 策略逻辑流程 84 | 85 | 1. **初始化**: 86 | - 为每个ETF设置动量和ATR指标 87 | - 记录ETF代码映射关系 88 | 89 | 2. **调仓检查**: 90 | - 确认当天是否为调仓日 91 | - 非调仓日仅检查止损条件 92 | 93 | 3. **动量排名**: 94 | - 计算所有ETF的动量指标 95 | - 按动量值从高到低排序 96 | 97 | 4. **仓位调整**: 98 | - 卖出不在排名前列的持仓ETF 99 | - 买入新进入排名的ETF 100 | - 计算新购买ETF的头寸大小 101 | 102 | 5. **风险监控**: 103 | - 每个交易日检查追踪止损条件 104 | - 监控总体回撤是否超过限制 105 | 106 | 6. **交易执行**: 107 | - 生成买入或卖出订单 108 | - 记录交易原因和执行情况 109 | 110 | ## 适用场景 111 | 112 | 此策略适用于: 113 | 114 | 1. 不同行业ETF之间的轮动(如科技、金融、消费等) 115 | 2. 不同资产类别ETF的配置(如股票、债券、商品等) 116 | 3. 不同地区市场ETF的轮动(如A股、港股、美股等) 117 | 4. 主题ETF之间的风格轮动(如价值、成长、中小盘等) 118 | 119 | ## 优势与局限性 120 | 121 | **优势**: 122 | - 跟随市场强势板块,利用趋势延续效应 123 | - 交易规则简单明确,易于实施和监控 124 | - 定期调仓减少交易频率和成本 125 | - 风险控制机制能有效限制下行风险 126 | 127 | **局限性**: 128 | - 市场反转时可能面临"追涨杀跌"风险 129 | - 频繁的板块轮换可能增加交易成本 130 | - 依赖历史动量数据,无法预测未来走势 131 | - 在横盘震荡市场中表现可能不佳 132 | 133 | ## 未来优化方向 134 | 135 | 1. 引入相对强度指标(RSI)作为辅助判断指标 136 | 2. 添加波动率因子过滤高风险ETF 137 | 3. 考虑价值指标与动量指标相结合 138 | 4. 增加择时功能规避市场系统性风险 139 | 5. 引入机器学习方法优化ETF选择和权重分配 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 量化交易策略回测系统 2 | 3 | 一个基于 Python Backtrader回测框架的量化交易策略回测系统,使用 Streamlit 构建 Web 界面,支持多种交易策略的回测和分析。 4 | 注:ETF实盘交易使用东方财富证券可免收每笔交易至少5元的手续费,所以本策略回测系统只设置了佣金万2.5,以便更好的模拟实盘情况。 5 | 6 | ## 功能特点 7 | 8 | - 🚀 支持多种交易策略,采用工厂模式便于扩展 9 | - 📊 实时数据获取(支持 Tushare 和 AKShare) 10 | - 📈 交互式图表展示(K线、均线、交易点位等) 11 | - 💹 详细的回测指标(夏普比率、最大回撤、胜率等) 12 | - 🔄 T+1 交易规则支持 13 | - ⚠️ 风险控制(追踪止损、最大回撤限制等) 14 | - 💰 精确的交易费用计算(佣金等) 15 | - 📝 完整的交易日志 16 | 17 | ## 功能截图 18 | ![image](https://github.com/user-attachments/assets/0af62636-dd11-44d5-8e8d-775db56df64e) 19 | 20 | ## 内置策略 21 | 22 | ### 双均线策略 23 | - 使用快速和慢速移动平均线的交叉产生交易信号 24 | - 支持追踪止损进行风险控制 25 | - 基于 ATR 动态计算持仓规模 26 | 27 | ### 市场情绪策略 28 | - 基于市场情绪指标进行交易,在极端情绪时入场 29 | - 使用 EMA 趋势确认,要求价格和 EMA 同步上涨 30 | - 动态调整 ATR 止盈倍数,结合布林带波动率和情绪因子 31 | - 分层建仓,在不同情绪阈值增加仓位 32 | - 支持追踪止损和最大回撤限制 33 | 34 | ## 安装使用 35 | 36 | 1. 克隆仓库: 37 | ```bash 38 | git clone https://github.com/sencloud/ETF-Strategies.git 39 | cd ETF-Strategies 40 | ``` 41 | 42 | 2. 安装依赖,使用miniconda: 43 | ```bash 44 | # pip设置源 45 | pip config set global.index-url https://mirrors.aliyun.com/pypi/simple 46 | pip config set install.trusted-host mirrors.aliyun.com 47 | conda create -n etf python=3.10 48 | conda activate etf 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | 3. 安装TA-lib: 53 | ```bash 54 | conda install -c conda-forge ta-lib 55 | ``` 56 | 57 | 4. 运行系统: 58 | ```bash 59 | streamlit run app.py 60 | ``` 61 | 62 | ## 系统要求 63 | 64 | - Python 3.10+ 65 | - 依赖包: 66 | - streamlit 67 | - backtrader 68 | - pandas 69 | - numpy 70 | - plotly 71 | - tushare 72 | - akshare 73 | - loguru 74 | 75 | ## 目录结构 76 | 77 | ``` 78 | ETF-Strategies/ 79 | ├── app.py # 主程序入口 80 | ├── requirements.txt # 依赖包列表 81 | ├── src/ 82 | │ ├── data/ # 数据加载模块 83 | │ ├── strategies/ # 交易策略模块 84 | │ ├── indicators/ # 技术指标模块 85 | │ └── utils/ # 工具函数模块 86 | └── logs/ # 日志文件目录 87 | ``` 88 | 89 | ## 如何添加新策略 90 | 91 | 1. 在 `src/strategies` 目录下创建新的策略类: 92 | ```python 93 | import backtrader as bt 94 | 95 | class YourStrategy(bt.Strategy): 96 | params = ( 97 | # 定义策略参数 98 | ) 99 | 100 | def __init__(self): 101 | # 初始化策略 102 | pass 103 | 104 | def next(self): 105 | # 实现交易逻辑 106 | pass 107 | ``` 108 | 109 | 2. 在 `strategy_factory.py` 中注册策略: 110 | ```python 111 | from .your_strategy import YourStrategy 112 | StrategyFactory.register_strategy("策略名称", YourStrategy) 113 | ``` 114 | 115 | 3. 在 `app.py` 中修改添加策略需要的参数,如下示意: 116 | ```python 117 | # 移动平均线参数(仅在选择双均线策略时显示) 118 | if strategy_name == "双均线策略": 119 | st.subheader("均线参数") 120 | col1, col2 = st.columns(2) 121 | with col1: 122 | fast_period = st.number_input("快线周期", value=10, min_value=1) 123 | with col2: 124 | slow_period = st.number_input("慢线周期", value=30, min_value=1) 125 | ... 126 | # 如果是双均线策略,添加特定参数 127 | if strategy_name == "双均线策略": 128 | strategy_params.update({ 129 | 'fast_period': fast_period, 130 | 'slow_period': slow_period, 131 | }) 132 | ... 133 | ``` 134 | 135 | ## 回测指标说明 136 | 137 | - 总收益率:策略最终收益相对于初始资金的百分比 138 | - 夏普比率:超额收益相对于波动率的比率 139 | - 最大回撤:策略执行期间的最大亏损百分比 140 | - 胜率:盈利交易占总交易次数的比例 141 | - 盈亏比:平均盈利交易额与平均亏损交易额的比值 142 | - 系统质量指数(SQN):衡量交易系统的稳定性 143 | 144 | ## 风险提示 145 | 146 | 本系统仅供学习和研究使用,不构成任何投资建议。使用本系统进行实盘交易需要自行承担风险。 147 | 148 | ## 其他 149 | 如果你喜欢我的项目,可以给我买杯咖啡: 150 | image 151 | 152 | ## 贡献指南 153 | 154 | 欢迎提交 Issue 和 Pull Request 来帮助改进这个项目。 155 | 156 | ## 许可证 157 | 158 | MIT License 159 | -------------------------------------------------------------------------------- /docs/market_sentiment_strategy.md: -------------------------------------------------------------------------------- 1 | # 市场情绪指标策略(Market Sentiment Strategy) 2 | 3 | ## 策略概述 4 | 5 | 市场情绪指标策略是一种基于市场情绪数据的量化交易策略,通过监测市场情绪指标和技术分析指标的变化,在极端情绪区域进行逆向投资。该策略主要针对ETF等宽基指数产品,利用市场恐慌情绪作为买入信号,过度乐观情绪作为卖出信号,结合趋势识别和仓位管理,实现低买高卖的交易策略。 6 | 7 | ## 核心理念 8 | 9 | 1. **情绪逆向投资**:在市场极度恐慌时买入,在市场过度乐观时卖出 10 | 2. **多维度情绪度量**:综合考虑多个指数的技术指标、波动率、RSI等数据 11 | 3. **动态仓位管理**:根据市场情绪程度和波动环境动态调整仓位大小 12 | 4. **阶梯式止盈机制**:设置多个止盈阈值,随着情绪回升逐步减仓 13 | 5. **趋势识别与风险控制**:识别下跌趋势时避免操作,设置最大回撤限制 14 | 15 | ## 情绪指标构建 16 | 17 | 策略使用综合情绪指标,由以下数据源和指标构成: 18 | 19 | 1. **数据源**: 20 | - 上证指数(权重0.2) 21 | - 沪深300指数(权重0.5) 22 | - 上证50指数(权重0.2) 23 | - 金融指数(权重0.1) 24 | 25 | 2. **核心指标**: 26 | - GARCH波动率模型:估计条件波动率,区分上涨波动率和下跌波动率 27 | - RSI指标:使用21日EMA平滑处理,考虑顶背离和价格趋势 28 | - 布林带位置:价格相对均线的偏离程度 29 | - 成交量异动:相对于20日均线的成交量比率 30 | - 复合趋势检测:结合价格、成交量和RSI数据识别市场趋势 31 | 32 | 3. **改进特性**: 33 | - 混合归一化:使用动态窗口归一化处理各指标数据 34 | - S型钝化曲线:避免情绪指标出现过度敏感的极端值 35 | - 波动率过滤机制:在高波动期间降低情绪指标权重 36 | 37 | ## 交易信号与仓位管理 38 | 39 | 根据情绪指标得分,将买入信号分为三类: 40 | 41 | 1. **核心信号**(情绪分数 < 2.5): 42 | - 最高权重(50%目标仓位) 43 | - 资金占比至少10%或50万元 44 | - 不受价格限制约束 45 | 46 | 2. **次级信号**(情绪分数 2.5-10): 47 | - 中等权重(30%目标仓位) 48 | - 资金占比至少5%或20万元 49 | - 需满足价格低于持仓均价5%以上的条件 50 | 51 | 3. **轻仓信号**(情绪分数 10-15): 52 | - 最低权重(10%目标仓位) 53 | - 需满足价格和波动率条件 54 | 55 | ## 止盈策略 56 | 57 | 设计了三层阶梯式止盈机制: 58 | 59 | 1. **一级止盈**(情绪分数 > 10): 60 | - 减仓50% 61 | - 最小盈利要求1% 62 | 63 | 2. **二级止盈**(情绪分数 > 20): 64 | - 减仓30% 65 | - 最小盈利要求1% 66 | 67 | 3. **三级止盈**(情绪分数 > 80): 68 | - 全部清仓 69 | - 最小盈利要求1% 70 | 71 | 4. **动量保护机制**: 72 | - 当5日涨幅超过15%时减仓50% 73 | - 高波动环境(波动率>25%)减仓80% 74 | 75 | ## 风险控制机制 76 | 77 | 1. **最大回撤限制**:资产回撤超过15%时强制平仓 78 | 2. **趋势识别保护**:下跌趋势环境避免开仓 79 | 3. **追踪止损**:可选启用追踪止损功能,默认设置为2% 80 | 4. **T+1交易限制**:符合中国A股市场T+1交易规则 81 | 5. **价格涨跌停限制**:价格变动超过10%时不交易 82 | 83 | ## ETF分红处理 84 | 85 | 策略内置ETF分红处理功能: 86 | 87 | 1. 自动获取并缓存ETF历史分红数据 88 | 2. 在分红日调整持仓成本和现金 89 | 3. 分红收益单独统计,与价格收益分开计算 90 | 91 | ## 参数设置 92 | 93 | 主要参数包括: 94 | 95 | ``` 96 | # 情绪阈值参数 97 | sentiment_core: 2.5 # 核心信号情绪阈值 98 | sentiment_secondary: 10.0 # 次级信号情绪阈值 99 | sentiment_warning: 15.0 # 预警信号情绪阈值 100 | sentiment_sell_1: 10.0 # 第一阶段止盈阈值 101 | sentiment_sell_2: 20.0 # 第二阶段止盈阈值 102 | sentiment_sell_3: 80.0 # 第三阶段止盈阈值 103 | 104 | # 仓位参数 105 | position_core: 0.5 # 核心信号仓位比例 106 | position_grid_step: 0.1 # 网格加仓步长 107 | position_warning: 0.05 # 预警信号初始仓位 108 | position_bb_signal: 0.1 # 布林带突破信号仓位 109 | 110 | # 风险控制参数 111 | min_momentum: -5.0 # 最小动量要求(10日涨跌幅) 112 | vol_threshold: 25.0 # 高波动率阈值 113 | quick_profit_days: 5 # 短期获利天数 114 | quick_profit_pct: 15.0 # 短期获利比例阈值 115 | high_vol_profit_pct: 20.0 # 高波动市场获利阈值 116 | trail_percent: 2.0 # 追踪止损百分比 117 | risk_ratio: 0.02 # 单次交易风险比率 118 | max_drawdown: 0.15 # 最大回撤限制 119 | ``` 120 | 121 | ## 策略逻辑流程 122 | 123 | 1. **数据准备**: 124 | - 获取市场情绪数据 125 | - 计算技术指标(RSI、布林带、波动率等) 126 | 127 | 2. **市场状态识别**: 128 | - 识别当前是否处于下跌趋势 129 | - 计算市场波动率环境 130 | 131 | 3. **交易信号生成**: 132 | - 根据情绪分数生成交易信号 133 | - 根据市场状态过滤信号 134 | 135 | 4. **仓位计算**: 136 | - 根据信号类型确定目标仓位 137 | - 根据波动率调整仓位规模 138 | 139 | 5. **执行交易**: 140 | - 计算实际交易数量 141 | - 执行买入或卖出操作 142 | 143 | 6. **止盈管理**: 144 | - 监控情绪指标变化 145 | - 达到阈值时执行阶梯式减仓 146 | 147 | 7. **风险控制**: 148 | - 监控回撤和止损条件 149 | - 必要时执行风险管理措施 150 | 151 | ## 适用场景 152 | 153 | 此策略适用于: 154 | 155 | 1. ETF等宽基指数产品交易 156 | 2. 中长期价值投资者的择时工具 157 | 3. 市场剧烈波动环境下的逆向投资 158 | 4. 与其他策略组合形成多策略投资组合 159 | 160 | ## 优势与局限性 161 | 162 | **优势**: 163 | - 利用市场情绪波动带来的投资机会 164 | - 多维度技术指标综合评估,减少误判 165 | - 动态仓位管理适应不同市场环境 166 | - 阶梯式止盈保护收益 167 | 168 | **局限性**: 169 | - 依赖历史情绪数据的可靠性 170 | - 仓位调整频率较高,可能增加交易成本 171 | - 在趋势明确的单边市场中可能表现不佳 172 | - 需要持续调整参数以适应不同市场环境 173 | 174 | ## 未来优化方向 175 | 176 | 1. 加入基本面数据作为辅助指标 177 | 2. 优化GARCH模型以提高波动率预测准确性 178 | 3. 引入机器学习方法动态调整参数 179 | 4. 拓展到更多资产类别的情绪预测 180 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | cache/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # UV 99 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | #uv.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 117 | .pdm.toml 118 | .pdm-python 119 | .pdm-build/ 120 | 121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 169 | #.idea/ 170 | 171 | # Ruff stuff: 172 | .ruff_cache/ 173 | 174 | # PyPI configuration file 175 | .pypirc 176 | -------------------------------------------------------------------------------- /src/utils/backtest_engine.py: -------------------------------------------------------------------------------- 1 | import backtrader as bt 2 | import pandas as pd 3 | from datetime import datetime 4 | from loguru import logger 5 | from .analysis import Analysis 6 | from .plot import Plot 7 | import sys 8 | 9 | class BacktestEngine: 10 | def __init__(self, strategy_class, data_feed, cash=100000.0, commission=0.00025, strategy_params=None): 11 | """初始化回测引擎 12 | Args: 13 | strategy_class: 策略类 14 | data_feed: 数据源或数据源列表 15 | cash: 初始资金 16 | commission: 股票交易手续费率 17 | strategy_params: 策略参数 18 | """ 19 | self.cerebro = bt.Cerebro() 20 | self.cerebro.broker.setcash(cash) 21 | 22 | # 设置手续费 23 | if isinstance(data_feed, list): 24 | # 默认设置股票/ETF手续费率 25 | self.cerebro.broker.setcommission(commission=commission) 26 | 27 | # 创建多数据源的特定手续费处理 28 | for i, feed in enumerate(data_feed): 29 | if i == 1: # 期货数据源使用固定手续费 30 | # 为期货设置固定手续费 31 | self.cerebro.broker.addcommissioninfo( 32 | bt.CommissionInfo( 33 | commission=1.51/100000, # 固定手续费转换为相对值 34 | margin=0.10, # 保证金比例 35 | mult=10, # 合约乘数 36 | commtype=0 # 固定手续费类型(0=固定手续费) 37 | ) 38 | ) 39 | else: 40 | self.cerebro.broker.setcommission(commission=commission) 41 | 42 | # 添加数据源 43 | try: 44 | if isinstance(data_feed, list): 45 | # 如果是数据源列表,添加所有数据源 46 | for feed in data_feed: 47 | self.cerebro.adddata(feed) 48 | else: 49 | # 如果是单个数据源,直接添加 50 | self.cerebro.adddata(data_feed) 51 | except Exception as e: 52 | logger.warning(f"添加数据源时出错: {str(e)}") 53 | # 如果出错,尝试不带ts_code参数添加 54 | if hasattr(data_feed, 'params'): 55 | data_feed.params.pop('ts_code', None) 56 | self.cerebro.adddata(data_feed) 57 | 58 | # 添加策略和参数 59 | if strategy_params: 60 | self.cerebro.addstrategy(strategy_class, **strategy_params) 61 | else: 62 | self.cerebro.addstrategy(strategy_class) 63 | 64 | # 添加分析器 65 | self.cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe', riskfreerate=0.02) 66 | self.cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown') 67 | self.cerebro.addanalyzer(bt.analyzers.Returns, _name='returns') 68 | self.cerebro.addanalyzer(bt.analyzers.TradeAnalyzer, _name='trades') 69 | self.cerebro.addanalyzer(bt.analyzers.VWR, _name='vwr') # 波动率加权收益 70 | self.cerebro.addanalyzer(bt.analyzers.SQN, _name='sqn') # 系统质量指数 71 | 72 | # 为每个数据源添加单独的交易记录分析器 73 | if isinstance(data_feed, list): 74 | for feed in data_feed: 75 | self.cerebro.addanalyzer(bt.analyzers.Transactions, _name=f'txn_{feed._name}') 76 | else: 77 | self.cerebro.addanalyzer(bt.analyzers.Transactions, _name='txn') 78 | 79 | self.trades = [] # 存储交易记录 80 | 81 | def run(self): 82 | """运行回测""" 83 | results = self.cerebro.run() 84 | 85 | self.strategy = results[0] 86 | 87 | analysis = self._get_analysis(self.strategy) 88 | 89 | logger.info("=== 回测统计 ===") 90 | logger.info(f"总收益率: {analysis['total_return']:.2%}") 91 | logger.info(f"年化收益率: {analysis['annualized_return']:.2%}") 92 | logger.info(f"夏普比率: {analysis['sharpe_ratio']:.2f}") 93 | logger.info(f"最大回撤: {analysis['max_drawdown']:.2%}") 94 | logger.info(f"胜率: {analysis['win_rate']:.2%}") 95 | logger.info(f"盈亏比: {analysis['profit_factor']:.2f}") 96 | logger.info(f"系统质量指数(SQN): {analysis['sqn']:.2f}") 97 | 98 | return analysis 99 | 100 | def plot(self, **kwargs): 101 | """使用Plotly绘制交互式回测结果""" 102 | fig = Plot(self.strategy).plot() 103 | return fig 104 | 105 | def _get_analysis(self, strategy): 106 | """获取回测分析结果""" 107 | analysis = Analysis()._get_analysis(self, strategy) 108 | return analysis -------------------------------------------------------------------------------- /ui/pages/market.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import pandas as pd 3 | import os 4 | from datetime import datetime, timedelta 5 | from src.utils.logger import setup_logger 6 | from src.trading.market_executor import MarketExecutor 7 | import threading 8 | import tushare as ts 9 | 10 | logger = setup_logger() 11 | 12 | def start_trading(params): 13 | """启动实盘交易""" 14 | try: 15 | # 从session state获取tushare token 16 | if not params['tushare_token']: 17 | st.error("请先在侧边栏设置Tushare Token") 18 | return 19 | 20 | # 设置tushare token 21 | ts.set_token(params['tushare_token']) 22 | pro = ts.pro_api() 23 | 24 | # 获取上证50成分股 25 | today = datetime.now().strftime('%Y%m%d') 26 | sz50 = pro.index_weight(index_code='000016.SH', trade_date=today) 27 | if sz50.empty: 28 | # 如果当天数据不可用,尝试获取最近的数据 29 | dates = pro.trade_cal(exchange='', start_date=(datetime.now() - timedelta(days=30)).strftime('%Y%m%d'), 30 | end_date=today, is_open='1') 31 | for date in sorted(dates['cal_date'].tolist(), reverse=True): 32 | sz50 = pro.index_weight(index_code='000016.SH', trade_date=date) 33 | if not sz50.empty: 34 | break 35 | 36 | if sz50.empty: 37 | st.error("未获取到上证50成分股列表") 38 | return 39 | symbols = sz50['con_code'].tolist() 40 | logger.info(f"上证50成分股: {symbols}") 41 | 42 | # 创建并启动MarketExecutor 43 | executor = MarketExecutor(symbols, tushare_token=params['tushare_token']) 44 | thread = threading.Thread(target=executor.run_continuously) 45 | thread.daemon = True 46 | thread.start() 47 | 48 | st.session_state.trading_thread = thread 49 | st.session_state.is_trading = True 50 | st.success("实盘交易已启动") 51 | 52 | except Exception as e: 53 | st.error(f"启动实盘交易失败: {str(e)}") 54 | logger.error(f"启动实盘交易失败: {str(e)}") 55 | import traceback 56 | traceback.print_exc() 57 | 58 | def stop_trading(): 59 | """停止实盘交易""" 60 | try: 61 | if hasattr(st.session_state, 'trading_thread'): 62 | # 这里需要实现一个优雅的停止机制 63 | st.session_state.is_trading = False 64 | st.session_state.trading_thread.join(timeout=5) 65 | st.success("实盘交易已停止") 66 | except Exception as e: 67 | st.error(f"停止实盘交易失败: {str(e)}") 68 | logger.error(f"停止实盘交易失败: {str(e)}") 69 | 70 | def render_market(params): 71 | st.header("实盘交易记录") 72 | 73 | # 创建交易记录目录 74 | records_dir = "data/trading_records" 75 | os.makedirs(records_dir, exist_ok=True) 76 | 77 | # 初始化session state 78 | if 'is_trading' not in st.session_state: 79 | st.session_state.is_trading = False 80 | 81 | # 添加启动/停止按钮 82 | col1, col2 = st.columns(2) 83 | with col1: 84 | if not st.session_state.is_trading: 85 | if st.button("启动实盘", type="primary"): 86 | start_trading(params) 87 | with col2: 88 | if st.session_state.is_trading: 89 | if st.button("停止实盘", type="secondary"): 90 | stop_trading() 91 | 92 | # 显示当前状态 93 | status = "运行中" if st.session_state.is_trading else "已停止" 94 | st.info(f"实盘交易状态: {status}") 95 | 96 | # 获取最新的交易记录文件 97 | record_files = [f for f in os.listdir(records_dir) if f.endswith('.csv')] 98 | if not record_files: 99 | st.info("暂无交易记录") 100 | return 101 | 102 | latest_file = max(record_files, key=lambda x: os.path.getctime(os.path.join(records_dir, x))) 103 | 104 | # 读取并显示交易记录 105 | try: 106 | df = pd.read_csv(os.path.join(records_dir, latest_file)) 107 | df['timestamp'] = pd.to_datetime(df['timestamp']) 108 | df = df.sort_values('timestamp', ascending=False) 109 | 110 | # 显示交易统计 111 | col1, col2, col3 = st.columns(3) 112 | with col1: 113 | st.metric("总交易次数", len(df)) 114 | with col2: 115 | st.metric("买入次数", len(df[df['action'] == 'buy'])) 116 | with col3: 117 | st.metric("卖出次数", len(df[df['action'] == 'sell'])) 118 | 119 | # 显示交易记录表格 120 | st.dataframe(df) 121 | 122 | except Exception as e: 123 | st.error(f"读取交易记录时出错: {str(e)}") 124 | logger.error(f"读取交易记录时出错: {str(e)}") 125 | -------------------------------------------------------------------------------- /rl_model_finrl/README.md: -------------------------------------------------------------------------------- 1 | # ETF-RL: 基于FinRL框架的A股ETF交易策略 2 | 3 | 这是一个基于FinRL框架设计的强化学习ETF交易系统,专注于A股市场ETF交易。该系统利用强化学习算法训练智能体,根据历史数据和市场特征自动进行ETF买入、卖出和持有的决策。 4 | 5 | ## 主要特点 6 | 7 | - 使用标准的三层FinRL架构:数据层、环境层和智能体层 8 | - 支持多ETF组合交易,能够同时管理多个ETF持仓 9 | - 集成Tushare和AKShare数据源,提供丰富的A股市场数据 10 | - 实现多种强化学习智能体: 11 | - 基于stable-baselines3的DQN智能体 12 | - 基于ElegantRL的PPO智能体 13 | - 基于RLlib的PPO智能体 14 | - 内置完整的数据预处理、特征工程、训练、回测和分析流程 15 | - 灵活的配置系统,支持命令行和配置文件参数 16 | 17 | ## 系统架构 18 | 19 | ``` 20 | src/strategies/rl_model_finrl/ 21 | ├── agents/ # 强化学习智能体 22 | │ ├── stablebaseline3/ # 基于SB3的智能体实现 23 | │ │ └── dqn_agent.py # DQN智能体实现 24 | │ ├── elegantrl/ # 基于ElegantRL的智能体 25 | │ │ └── ppo_agent.py # PPO智能体实现 26 | │ └── rllib/ # 基于RLlib的智能体 27 | │ └── ppo_agent.py # PPO智能体实现 28 | ├── meta/ # 基础组件 29 | │ ├── data_processors/ # 数据处理器 30 | │ │ ├── __init__.py # 数据处理器基类 31 | │ │ ├── tushare_processor.py # Tushare数据处理 32 | │ │ └── akshare_processor.py # AKShare数据处理 33 | │ └── preprocessor/ # 数据预处理 34 | │ ├── __init__.py # 预处理器基类 35 | │ ├── data_normalizer.py # 数据标准化 36 | │ └── feature_engineer.py # 特征工程 37 | ├── applications/ # 应用实现 38 | │ └── stock_trading/ # 股票交易应用 39 | │ ├── __init__.py # 应用初始化 40 | │ ├── etf_env.py # ETF交易环境 41 | │ ├── run_strategy.py # 策略运行 42 | │ ├── backtest.py # 回测模块 43 | │ └── analysis.py # 结果分析 44 | ├── config.py # 全局配置 45 | └── __init__.py # 模块初始化 46 | ``` 47 | 48 | ## 安装依赖 49 | 50 | 项目依赖以下Python包: 51 | 52 | ```bash 53 | conda install -c conda-forge cmake 54 | 55 | pip install pandas numpy matplotlib tushare akshare gym stable-baselines3 torch ray[rllib] loguru scikit-learn 56 | ``` 57 | 58 | ## 使用方法 59 | 60 | ### 1. 配置数据参数 61 | 62 | 在`config.py`中,配置您的Tushare API令牌和其他参数: 63 | 64 | ```python 65 | # 配置Tushare API令牌 66 | TUSHARE_TOKEN = "你的Tushare令牌" 67 | 68 | # 配置ETF列表 69 | TICKER_LIST = [ 70 | "159915.SZ", # 易方达创业板ETF 71 | "510300.SH", # 华泰柏瑞沪深300ETF 72 | # 添加更多ETF... 73 | ] 74 | 75 | # 配置日期范围 76 | TRAIN_START_DATE = "2018-01-01" 77 | TRAIN_END_DATE = "2021-12-31" 78 | TEST_START_DATE = "2022-01-01" 79 | TEST_END_DATE = "2023-12-31" 80 | ``` 81 | 82 | ### 2. 运行策略 83 | 84 | 使用`run_strategy.py`训练并运行策略: 85 | 86 | ```bash 87 | python -m src.strategies.rl_model_finrl.applications.stock_trading.run_strategy --tushare_token "你的Tushare令牌" --agent_type dqn 88 | ``` 89 | 90 | 可选的`agent_type`参数: 91 | - `dqn`: 使用StableBaseline3的DQN智能体 92 | - `ppo_elegant`: 使用ElegantRL的PPO智能体 93 | - `ppo_rllib`: 使用RLlib的PPO智能体 94 | 95 | ### 3. 回测模型 96 | 97 | 训练完成后,使用回测脚本评估模型性能: 98 | 99 | ```bash 100 | python -m src.strategies.rl_model_finrl.applications.stock_trading.backtest --model_path models/dqn_etf_trading.zip --agent_type dqn 101 | ``` 102 | 103 | ### 4. 分析结果 104 | 105 | 对回测结果进行详细分析: 106 | 107 | ```bash 108 | python -m src.strategies.rl_model_finrl.applications.stock_trading.analysis --result_path results/backtest_results.csv 109 | ``` 110 | 111 | 回测结果将生成图表和性能指标,包括: 112 | - 投资组合价值曲线 113 | - 与基准ETF的对比 114 | - 回撤分析 115 | - 收益率和风险指标 116 | - 交易记录分析 117 | - 胜率和盈亏比分析 118 | 119 | ## 扩展功能 120 | 121 | 系统支持以下扩展: 122 | 123 | 1. 添加新的ETF:在配置中添加新的ETF代码 124 | 2. 实现新的智能体:在`agents`目录下添加新的智能体实现 125 | 3. 添加新的数据源:扩展数据处理器以支持更多数据源 126 | 4. 自定义奖励函数:在环境中修改奖励计算逻辑 127 | 5. 调整特征工程:在`preprocessor`目录下修改特征生成逻辑 128 | 129 | ## 性能基准 130 | 131 | 在不同市场条件下的基准性能: 132 | 133 | - 牛市(2019-2020):年化收益率 20-30%,最大回撤 15-20% 134 | - 震荡市(2021-2022):年化收益率 5-15%,最大回撤 10-15% 135 | - 熊市(2022下半年):年化收益率 -5-5%,最大回撤 20-25% 136 | 137 | 请注意,过去的性能不代表未来结果,交易有风险,投资需谨慎。 138 | 139 | ## 待优化 140 | 以下是需要完善和优化的关键方面: 141 | 142 | 1. **超参数优化框架** 143 | - 实现自动化超参数调优(如使用Optuna或Ray Tune) 144 | - 当前实现缺乏系统性的超参数优化,这对强化学习性能至关重要 145 | 146 | 2. **风险管理扩展** 147 | - 添加止损和止盈机制 148 | - 实现基于波动率的回撤约束和仓位控制 149 | - 集成凯利准则进行最优仓位配置 150 | 151 | 3. **增强奖励函数** 152 | - 实现结合收益、风险和交易频率的多目标奖励 153 | - 添加过度交易的惩罚项(以减少交易成本) 154 | - 为特定市场条件创建自定义奖励塑造 155 | 156 | 4. **市场状态检测** 157 | - 添加市场状态检测(牛市/熊市/震荡市)以适应不同市场环境 158 | - 为不同市场条件实现单独模型或使用上下文特征 159 | 160 | 5. **集成方法** 161 | - 结合多个强化学习智能体的决策,产生更稳健的交易信号 162 | - 基于近期表现实现模型选择 163 | 164 | 6. **在线学习能力** 165 | - 添加增量训练功能,使模型能适应新的市场数据 166 | - 实现优先采样的经验回放,支持持续学习 167 | 168 | 7. **可解释性工具** 169 | - 开发可视化工具以理解智能体决策过程 170 | - 实现归因分析,了解哪些特征驱动决策 171 | 172 | 8. **扩展特征工程** 173 | - 添加新闻/社交媒体的市场情绪分析 174 | - 纳入宏观经济指标 175 | - 包含资金流向数据和行业轮动指标 176 | 177 | 9. **基准比较** 178 | - 实现与传统策略(动量、均值回归)的比较 179 | - 为性能指标添加统计显著性测试 180 | 181 | 10. **生产环境准备** 182 | - 添加强健的错误处理和日志记录 183 | - 实现生产部署的系统监控 184 | - 创建模型版本控制和自动化回测管道 185 | 186 | 11. **多时间框架分析** 187 | - 整合多个时间框架的数据(日线、周线、月线) 188 | - 为不同决策周期实现分层强化学习 189 | 190 | 12. **迁移学习** 191 | - 在相似市场/资产上实现预训练 192 | - 添加领域适应技术实现模型在不同市场间的迁移 193 | 194 | 13. **替代数据集成** 195 | - 创建价格数据之外的替代数据源接口 196 | - 添加期权链数据作为隐含波动率信号 197 | 198 | -------------------------------------------------------------------------------- /src/strategies/market_sentiment/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import talib 4 | from loguru import logger 5 | 6 | class TrendStateDetector: 7 | def __init__(self, window=60): 8 | self.window = window 9 | self.last_regime = None 10 | self.regime_change_date = None 11 | # 缓存历史计算结果 12 | self.cached_ema5 = None 13 | self.cached_ema13 = None 14 | self.cached_ema55 = None 15 | self.cached_momentum = None 16 | self.crossover_threshold = 0.001 # 0.1%的最小上穿幅度要求 17 | 18 | def detect(self, price_series, volume_series, current_date=None): 19 | """趋势检测逻辑,基于EMA交叉和价格位置""" 20 | price_series_pd = pd.Series(price_series) 21 | 22 | # 计算EMA指标 23 | ema5 = talib.EMA(price_series_pd, timeperiod=5) 24 | ema13 = talib.EMA(price_series_pd, timeperiod=10) 25 | ema55 = talib.EMA(price_series_pd, timeperiod=55) 26 | 27 | # 动量计算 28 | momentum = talib.MOM(price_series_pd, timeperiod=5) 29 | sma_momentum = talib.SMA(momentum, timeperiod=3) 30 | 31 | # 成交量指标 32 | volume_series_pd = pd.Series(volume_series) 33 | obv = talib.OBV(price_series_pd, volume_series_pd) 34 | obv_ema = talib.EMA(obv, timeperiod=5) 35 | 36 | current_regime = 'normal' 37 | 38 | if len(price_series) >= 55: 39 | # 获取当前和前一日的值 40 | ema5_current = ema5.iloc[-1] 41 | ema5_prev = ema5.iloc[-2] 42 | ema13_current = ema13.iloc[-1] 43 | ema13_prev = ema13.iloc[-2] 44 | ema55_current = ema55.iloc[-1] 45 | ema55_prev = ema55.iloc[-2] 46 | price_current = price_series_pd.iloc[-1] 47 | price_prev = price_series_pd.iloc[-2] 48 | 49 | # 成交量确认 50 | volume_confirmed = obv.iloc[-1] > obv_ema.iloc[-1] 51 | 52 | # 1. EMA5上穿EMA13时设置为potential_uptrend 53 | crossover_pct = (ema5_current - ema13_current) / ema13_current 54 | if (ema5_current > ema13_current and ema5_prev <= ema13_prev and 55 | crossover_pct > self.crossover_threshold): 56 | current_regime = 'potential_uptrend' 57 | 58 | # 2. EMA5上穿EMA13后,收盘价格连续3个交易日站上EMA55,设置为uptrend 59 | if (ema5_current > ema13_current and 60 | price_current > ema55_current and 61 | price_series_pd.iloc[-2] > ema55.iloc[-2] and 62 | price_series_pd.iloc[-3] > ema55.iloc[-3] and 63 | volume_confirmed): 64 | current_regime = 'uptrend' 65 | 66 | # 3. EMA5下穿EMA13,但收盘价格在EMA55之上,设置为potential_downtrend 67 | if (ema5_current < ema13_current and 68 | ema5_prev >= ema13_prev and 69 | price_current > ema55_current): 70 | current_regime = 'potential_downtrend' 71 | 72 | # 4. EMA5下穿EMA55,设置为downtrend 73 | if ema5_current < ema55_current and ema5_prev >= ema55_prev: 74 | current_regime = 'downtrend' 75 | 76 | # 趋势变化检测 77 | if current_regime != self.last_regime: 78 | self.regime_change_date = current_date 79 | if current_date is not None: 80 | logger.info(f"趋势变化: {self.last_regime} -> {current_regime}, 日期: {current_date}") 81 | self.last_regime = current_regime 82 | 83 | return current_regime 84 | 85 | class PositionManager: 86 | def __init__(self, max_risk=0.45): 87 | self.max_risk = max_risk 88 | 89 | def adjust_position(self, target_ratio, volatility): 90 | """简化的仓位管理,只根据波动率调整仓位""" 91 | # 波动率调整系数 - 这里的2是百分比形式的波动率基准值(2%) 92 | # 实际波动率通常在1%~3%之间 93 | vol_adj = np.clip(volatility / 2, 0.5, 1.5) 94 | 95 | # 根据波动率调整目标仓位 96 | adjusted_ratio = target_ratio * vol_adj 97 | 98 | # 设置风险上限 99 | if adjusted_ratio > self.max_risk * 2: # 最大允许仓位90% 100 | adjusted_ratio = self.max_risk * 2 101 | 102 | logger.info(f"波动率: {volatility:.2f}, 波动率调整系数: {vol_adj:.2f}, 目标仓位: {target_ratio:.2f}, 调整后仓位: {adjusted_ratio:.2f}") 103 | 104 | return adjusted_ratio 105 | 106 | # # 信号生成参数 107 | # SENTIMENT_THRESHOLDS = { 108 | # 'core': 2.5, # 核心信号阈值 109 | # 'secondary': 9.0, # 次级信号阈值 110 | # 'light': 9.0 # 轻仓信号阈值 111 | # } 112 | 113 | # POSITION_WEIGHTS = { 114 | # 'core': 0.95, # 核心信号仓位 115 | # 'secondary': 0.9, # 次级信号仓位 116 | # 'light': 0.8 # 轻仓信号仓位 117 | # } 118 | 119 | # def generate_signals(sentiment_score, regime, volatility): 120 | # """简化的信号生成,只区分下跌趋势和正常趋势""" 121 | # signals = [] 122 | 123 | # # 下跌趋势不开仓 124 | # if regime == 'downtrend': 125 | # if sentiment_score < SENTIMENT_THRESHOLDS['core']: 126 | # signals.append({'type': 'buy', 'weight': POSITION_WEIGHTS['core']}) # 核心信号 127 | # return signals 128 | # else: 129 | # return [] 130 | 131 | # # 其他情况根据情绪分数决定 132 | # if sentiment_score < SENTIMENT_THRESHOLDS['core']: 133 | # signals.append({'type': 'buy', 'weight': POSITION_WEIGHTS['core']}) # 核心信号 134 | # elif SENTIMENT_THRESHOLDS['core'] <= sentiment_score < SENTIMENT_THRESHOLDS['secondary']: 135 | # signals.append({'type': 'buy', 'weight': POSITION_WEIGHTS['secondary']}) # 次级信号 136 | # elif SENTIMENT_THRESHOLDS['secondary'] <= sentiment_score < SENTIMENT_THRESHOLDS['light']: 137 | # signals.append({'type': 'buy', 'weight': POSITION_WEIGHTS['light']}) # 轻仓信号 138 | 139 | # return signals -------------------------------------------------------------------------------- /rl_model_finrl/meta/data_processors/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 数据处理器模块 3 | 4 | 这个模块提供了处理金融市场数据的各种处理器,支持不同的数据源: 5 | - TushareProcessor: 处理来自Tushare API的数据 6 | - AKShareProcessor: 处理来自AKShare的数据 7 | - DataProcessor: 数据处理器的基类 8 | 9 | 主要功能: 10 | - 下载和预处理ETF价格数据 11 | - 添加技术指标 12 | - 准备用于强化学习环境的数据 13 | 14 | 使用示例: 15 | ```python 16 | from src.strategies.rl_model_finrl.meta.data_processors import DataProcessor, TushareProcessor 17 | 18 | # 使用Tushare处理器 19 | processor = TushareProcessor(token="your_tushare_token") 20 | data = processor.download_data( 21 | start_date="2018-01-01", 22 | end_date="2021-12-31", 23 | ticker_list=["510050"] 24 | ) 25 | 26 | # 添加技术指标 27 | data_with_indicators = processor.add_technical_indicators(data) 28 | 29 | # 准备训练数据 30 | train_data = processor.data_split(data_with_indicators, start_date="2018-01-01", end_date="2020-12-31") 31 | ``` 32 | """ 33 | 34 | import pandas as pd 35 | import numpy as np 36 | from abc import ABC, abstractmethod 37 | from typing import Dict, List, Tuple, Union, Optional 38 | 39 | class DataProcessor(ABC): 40 | """ 41 | 数据处理器基类 42 | 43 | 定义了数据处理的通用接口,所有特定数据源的处理器都应该继承这个类 44 | """ 45 | 46 | @abstractmethod 47 | def download_data(self, **kwargs) -> pd.DataFrame: 48 | """ 49 | 从数据源下载数据 50 | 51 | 参数: 52 | **kwargs: 下载参数,如起止日期、股票代码等 53 | 54 | 返回: 55 | 下载的数据DataFrame 56 | """ 57 | pass 58 | 59 | @abstractmethod 60 | def clean_data(self, data: pd.DataFrame) -> pd.DataFrame: 61 | """ 62 | 清洗数据 63 | 64 | 参数: 65 | data: 原始数据DataFrame 66 | 67 | 返回: 68 | 清洗后的数据DataFrame 69 | """ 70 | pass 71 | 72 | def add_technical_indicators(self, data: pd.DataFrame) -> pd.DataFrame: 73 | """ 74 | 添加技术指标 75 | 76 | 参数: 77 | data: 原始价格数据DataFrame 78 | 79 | 返回: 80 | 添加技术指标后的数据DataFrame 81 | """ 82 | df = data.copy() 83 | 84 | # 确保列名标准化 85 | price_col = 'close' if 'close' in df.columns else 'Close' 86 | high_col = 'high' if 'high' in df.columns else 'High' 87 | low_col = 'low' if 'low' in df.columns else 'Low' 88 | volume_col = 'volume' if 'volume' in df.columns else 'Volume' 89 | 90 | # 计算MACD 91 | df['ema12'] = df[price_col].ewm(span=12, adjust=False).mean() 92 | df['ema26'] = df[price_col].ewm(span=26, adjust=False).mean() 93 | df['macd'] = df['ema12'] - df['ema26'] 94 | df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean() 95 | df['macd_hist'] = df['macd'] - df['macd_signal'] 96 | 97 | # 计算RSI (相对强弱指标) 98 | delta = df[price_col].diff() 99 | gain = delta.copy() 100 | loss = delta.copy() 101 | gain[gain < 0] = 0 102 | loss[loss > 0] = 0 103 | loss = -loss 104 | 105 | avg_gain = gain.rolling(window=14).mean() 106 | avg_loss = loss.rolling(window=14).mean() 107 | 108 | rs = avg_gain / avg_loss 109 | df['rsi'] = 100 - (100 / (1 + rs)) 110 | 111 | # 计算CCI (商品通道指标) 112 | df['tp'] = (df[high_col] + df[low_col] + df[price_col]) / 3 113 | df['tp_ma'] = df['tp'].rolling(window=20).mean() 114 | mean_dev = df['tp'].rolling(window=20).apply(lambda x: abs(x - x.mean()).mean()) 115 | df['cci'] = (df['tp'] - df['tp_ma']) / (0.015 * mean_dev) 116 | 117 | # 计算布林带 118 | df['sma20'] = df[price_col].rolling(window=20).mean() 119 | df['bollinger_upper'] = df['sma20'] + 2 * df[price_col].rolling(window=20).std() 120 | df['bollinger_lower'] = df['sma20'] - 2 * df[price_col].rolling(window=20).std() 121 | 122 | # 计算ATR (真实波幅均值) 123 | df['tr1'] = df[high_col] - df[low_col] 124 | df['tr2'] = abs(df[high_col] - df[price_col].shift()) 125 | df['tr3'] = abs(df[low_col] - df[price_col].shift()) 126 | df['tr'] = df[['tr1', 'tr2', 'tr3']].max(axis=1) 127 | df['atr'] = df['tr'].rolling(window=14).mean() 128 | 129 | # 移除临时列 130 | df = df.drop(['tr1', 'tr2', 'tr3', 'tr', 'tp', 'tp_ma', 'ema12', 'ema26'], axis=1, errors='ignore') 131 | 132 | # 使用涨跌幅代替价格 133 | df['daily_return'] = df[price_col].pct_change() 134 | 135 | return df 136 | 137 | def data_split( 138 | self, 139 | df: pd.DataFrame, 140 | start_date: str, 141 | end_date: str 142 | ) -> pd.DataFrame: 143 | """ 144 | 按日期范围分割数据 145 | 146 | 参数: 147 | df: 原始数据DataFrame 148 | start_date: 开始日期 149 | end_date: 结束日期 150 | 151 | 返回: 152 | 指定日期范围的数据 153 | """ 154 | data = df.copy() 155 | 156 | # 确保索引是日期类型 157 | if not isinstance(data.index, pd.DatetimeIndex): 158 | if 'date' in data.columns: 159 | data['date'] = pd.to_datetime(data['date']) 160 | data = data.set_index('date') 161 | else: 162 | raise ValueError("数据必须有日期列或日期索引") 163 | 164 | # 转换日期字符串为日期类型 165 | start_date = pd.to_datetime(start_date) 166 | end_date = pd.to_datetime(end_date) 167 | 168 | # 筛选日期范围内的数据 169 | data = data.loc[start_date:end_date] 170 | 171 | return data 172 | 173 | def prepare_data_for_training(self, **kwargs) -> Tuple[Dict[str, pd.DataFrame], Optional[pd.DataFrame]]: 174 | """ 175 | 准备用于训练的数据 176 | 177 | 参数: 178 | **kwargs: 准备数据的参数 179 | 180 | 返回: 181 | (ETF数据字典, 市场指数数据)元组 182 | """ 183 | pass 184 | 185 | from src.strategies.rl_model_finrl.meta.data_processors.tushare_processor import TushareProcessor 186 | from src.strategies.rl_model_finrl.meta.data_processors.akshare_processor import AKShareProcessor 187 | 188 | __all__ = [ 189 | 'DataProcessor', 190 | 'TushareProcessor', 191 | 'AKShareProcessor' 192 | ] -------------------------------------------------------------------------------- /src/strategies/market_sentiment/etf_dividend_handler.py: -------------------------------------------------------------------------------- 1 | import tushare as ts 2 | import pandas as pd 3 | from src.data.data_loader import DataLoader 4 | from datetime import datetime, timedelta 5 | from loguru import logger 6 | import os 7 | import time 8 | import json 9 | import numpy as np 10 | 11 | class ETFDividendHandler: 12 | """处理ETF分红的类""" 13 | def __init__(self, ts_code=None): 14 | self.ts_code = ts_code # ETF代码 15 | self.dividend_data = None # 分红数据DataFrame 16 | self.last_api_call = 0 # 上次API调用时间 17 | self.min_interval = 1.0 # 最小调用间隔(秒) 18 | self.cache_file = f'cache/dividend_{ts_code}.json' # 缓存文件路径 19 | 20 | # 确保缓存目录存在 21 | os.makedirs('cache', exist_ok=True) 22 | 23 | # 初始化Tushare 24 | tushare_token = os.getenv('TUSHARE_TOKEN') 25 | if not tushare_token: 26 | raise ValueError("未设置TUSHARE_TOKEN环境变量") 27 | ts.set_token(tushare_token) 28 | self.pro = ts.pro_api() 29 | 30 | def _wait_for_rate_limit(self): 31 | """等待以满足API调用频率限制""" 32 | current_time = time.time() 33 | elapsed = current_time - self.last_api_call 34 | if elapsed < self.min_interval: 35 | time.sleep(self.min_interval - elapsed) 36 | self.last_api_call = time.time() 37 | 38 | def _load_from_cache(self): 39 | """从缓存加载分红数据""" 40 | try: 41 | if os.path.exists(self.cache_file): 42 | with open(self.cache_file, 'r') as f: 43 | data = json.load(f) 44 | df = pd.DataFrame(data) 45 | df['date'] = pd.to_datetime(df['date']) 46 | return df 47 | except Exception as e: 48 | logger.warning(f"从缓存加载分红数据失败: {str(e)}") 49 | return None 50 | 51 | def _save_to_cache(self, df): 52 | """保存分红数据到缓存""" 53 | try: 54 | # 将DataFrame转换为可序列化的格式 55 | df_copy = df.copy() 56 | df_copy['date'] = df_copy['date'].dt.strftime('%Y-%m-%d') 57 | data = df_copy.to_dict('records') 58 | 59 | with open(self.cache_file, 'w') as f: 60 | json.dump(data, f) 61 | except Exception as e: 62 | logger.warning(f"保存分红数据到缓存失败: {str(e)}") 63 | 64 | def _clean_dividend_data(self, df): 65 | """清理分红数据,过滤掉非数值的分红记录""" 66 | try: 67 | # 将分红列转换为数值类型,非数值将变为NaN 68 | df['dividend'] = pd.to_numeric(df['dividend'], errors='coerce') 69 | 70 | # 删除分红为NaN的记录 71 | df = df.dropna(subset=['dividend']) 72 | 73 | # 删除分红为0的记录 74 | df = df[df['dividend'] > 0] 75 | 76 | return df 77 | except Exception as e: 78 | logger.error(f"清理分红数据时出错: {str(e)}") 79 | return df 80 | 81 | def update_dividend_data(self, start_date=None, end_date=None): 82 | """更新分红数据""" 83 | try: 84 | # 先尝试从缓存加载 85 | cached_data = self._load_from_cache() 86 | if cached_data is not None: 87 | # 检查缓存数据是否覆盖了所需的日期范围 88 | if start_date and end_date: 89 | cached_start = cached_data['date'].min().date() 90 | cached_end = cached_data['date'].max().date() 91 | if cached_start <= start_date and cached_end >= end_date: 92 | self.dividend_data = cached_data 93 | logger.info(f"从缓存加载ETF分红数据成功 - {self.ts_code}, 数据长度: {len(cached_data)}") 94 | return True 95 | 96 | # 如果缓存不存在或数据不完整,从API获取 97 | self._wait_for_rate_limit() 98 | 99 | # 获取分红数据 100 | df = self.pro.fund_div( 101 | ts_code=self.ts_code, 102 | start_date=start_date.strftime('%Y%m%d') if start_date else None, 103 | end_date=end_date.strftime('%Y%m%d') if end_date else None 104 | ) 105 | 106 | if df is not None and not df.empty: 107 | # 重命名列 108 | df = df.rename(columns={ 109 | 'ann_date': 'date', 110 | 'div_cash': 'dividend' 111 | }) 112 | 113 | # 转换日期格式 114 | df['date'] = pd.to_datetime(df['date']) 115 | 116 | # 清理分红数据 117 | df = self._clean_dividend_data(df) 118 | 119 | # 按日期排序 120 | df = df.sort_values('date') 121 | 122 | # 保存到缓存 123 | self._save_to_cache(df) 124 | 125 | self.dividend_data = df 126 | logger.info(f"成功获取ETF分红数据 - {self.ts_code}, 数据长度: {len(df)}") 127 | return True 128 | else: 129 | logger.warning(f"未获取到ETF分红数据 - {self.ts_code}") 130 | return False 131 | 132 | except Exception as e: 133 | logger.error(f"获取ETF分红数据出错: {str(e)}") 134 | return False 135 | 136 | def process_dividend(self, date_str, position_size, current_price): 137 | """处理指定日期的分红""" 138 | if self.dividend_data is None: 139 | return 0.0 140 | 141 | try: 142 | # 将日期字符串转换为datetime对象 143 | date = pd.to_datetime(date_str) 144 | 145 | # 查找当天的分红记录 146 | dividend_record = self.dividend_data[self.dividend_data['date'].dt.date == date.date()] 147 | 148 | if not dividend_record.empty: 149 | # 获取每股分红金额 150 | dividend_per_share = float(dividend_record['dividend'].iloc[0]) 151 | 152 | # 计算总分红金额 153 | total_dividend = dividend_per_share * position_size 154 | 155 | logger.info(f"处理ETF分红 - 日期: {date_str}, 每股分红: {dividend_per_share:.4f}, " 156 | f"持仓数量: {position_size}, 总分红: {total_dividend:.2f}") 157 | 158 | return total_dividend 159 | else: 160 | return 0.0 161 | 162 | except Exception as e: 163 | logger.error(f"处理ETF分红时出错: {str(e)}") 164 | return 0.0 -------------------------------------------------------------------------------- /rl_model_finrl/meta/preprocessor/data_normalizer.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from typing import Dict, List, Tuple, Union, Optional 4 | from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler 5 | 6 | class DataNormalizer: 7 | """ 8 | 数据归一化处理器 9 | 10 | 用于金融数据的归一化和标准化,以便更好地用于机器学习模型 11 | """ 12 | 13 | def __init__(self, 14 | method: str = 'standard', # 'standard', 'minmax', 'robust' 15 | feature_range: Tuple[float, float] = (0, 1), 16 | ): 17 | """ 18 | 初始化归一化处理器 19 | 20 | 参数: 21 | method: 归一化方法: 'standard'(标准化), 'minmax'(最小最大归一化), 'robust'(稳健归一化) 22 | feature_range: 特征范围 (针对minmax方法) 23 | """ 24 | self.method = method 25 | self.feature_range = feature_range 26 | self.scalers = {} 27 | self.feature_columns = None 28 | 29 | def fit(self, df: pd.DataFrame, feature_columns: Optional[List[str]] = None) -> 'DataNormalizer': 30 | """ 31 | 拟合归一化器 32 | 33 | 参数: 34 | df: 数据DataFrame 35 | feature_columns: 要归一化的特征列名列表,如果为None,将使用所有数值型列 36 | 37 | 返回: 38 | 归一化处理器实例 39 | """ 40 | data = df.copy() 41 | 42 | # 如果未指定特征列,使用所有数值列 43 | if feature_columns is None: 44 | self.feature_columns = data.select_dtypes(include=['float64', 'int64']).columns.tolist() 45 | else: 46 | self.feature_columns = feature_columns 47 | 48 | # 拟合每个特征的归一化器 49 | for col in self.feature_columns: 50 | if col in data.columns: 51 | # 创建适当的缩放器 52 | if self.method == 'standard': 53 | scaler = StandardScaler() 54 | elif self.method == 'minmax': 55 | scaler = MinMaxScaler(feature_range=self.feature_range) 56 | elif self.method == 'robust': 57 | scaler = RobustScaler() 58 | else: 59 | raise ValueError(f"不支持的归一化方法: {self.method}") 60 | 61 | # 拟合缩放器 62 | values = data[col].values.reshape(-1, 1) 63 | scaler.fit(values) 64 | self.scalers[col] = scaler 65 | 66 | return self 67 | 68 | def transform(self, df: pd.DataFrame) -> pd.DataFrame: 69 | """ 70 | 转换数据 71 | 72 | 参数: 73 | df: 要转换的数据DataFrame 74 | 75 | 返回: 76 | 归一化后的DataFrame 77 | """ 78 | if not self.scalers: 79 | raise ValueError("请先调用fit方法") 80 | 81 | data = df.copy() 82 | 83 | # 转换每个特征 84 | for col, scaler in self.scalers.items(): 85 | if col in data.columns: 86 | values = data[col].values.reshape(-1, 1) 87 | data[col] = scaler.transform(values) 88 | 89 | return data 90 | 91 | def fit_transform(self, df: pd.DataFrame, feature_columns: Optional[List[str]] = None) -> pd.DataFrame: 92 | """ 93 | 拟合并转换数据 94 | 95 | 参数: 96 | df: 数据DataFrame 97 | feature_columns: 要归一化的特征列名列表 98 | 99 | 返回: 100 | 归一化后的DataFrame 101 | """ 102 | self.fit(df, feature_columns) 103 | return self.transform(df) 104 | 105 | def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame: 106 | """ 107 | 反向转换数据(恢复原始比例) 108 | 109 | 参数: 110 | df: 归一化后的数据DataFrame 111 | 112 | 返回: 113 | 原始比例的DataFrame 114 | """ 115 | if not self.scalers: 116 | raise ValueError("请先调用fit方法") 117 | 118 | data = df.copy() 119 | 120 | # 反向转换每个特征 121 | for col, scaler in self.scalers.items(): 122 | if col in data.columns: 123 | values = data[col].values.reshape(-1, 1) 124 | data[col] = scaler.inverse_transform(values) 125 | 126 | return data 127 | 128 | def save(self, path: str) -> None: 129 | """ 130 | 保存归一化器状态 131 | 132 | 参数: 133 | path: 保存路径 134 | """ 135 | import joblib 136 | 137 | state = { 138 | 'method': self.method, 139 | 'feature_range': self.feature_range, 140 | 'scalers': self.scalers, 141 | 'feature_columns': self.feature_columns 142 | } 143 | 144 | joblib.dump(state, path) 145 | 146 | @classmethod 147 | def load(cls, path: str) -> 'DataNormalizer': 148 | """ 149 | 从文件加载归一化器 150 | 151 | 参数: 152 | path: 文件路径 153 | 154 | 返回: 155 | 加载的归一化器实例 156 | """ 157 | import joblib 158 | 159 | state = joblib.load(path) 160 | 161 | normalizer = cls( 162 | method=state['method'], 163 | feature_range=state['feature_range'] 164 | ) 165 | 166 | normalizer.scalers = state['scalers'] 167 | normalizer.feature_columns = state['feature_columns'] 168 | 169 | return normalizer 170 | 171 | def normalize_price_data(self, df: pd.DataFrame, price_columns: List[str]) -> pd.DataFrame: 172 | """ 173 | 归一化价格数据 174 | 175 | 参数: 176 | df: 价格数据DataFrame 177 | price_columns: 价格列名列表 178 | 179 | 返回: 180 | 归一化后的DataFrame 181 | """ 182 | data = df.copy() 183 | 184 | # 当价格数据用于机器学习时,通常更好的做法是转换为收益率或相对变化 185 | for col in price_columns: 186 | if col in data.columns: 187 | # 计算收益率 188 | data[f'{col}_return'] = data[col].pct_change() 189 | 190 | # 计算相对于初始价格的变化 191 | data[f'{col}_rel'] = data[col] / data[col].iloc[0] - 1 192 | 193 | # 对原始价格列进行归一化 194 | if self.method == 'standard': 195 | scaler = StandardScaler() 196 | elif self.method == 'minmax': 197 | scaler = MinMaxScaler(feature_range=self.feature_range) 198 | elif self.method == 'robust': 199 | scaler = RobustScaler() 200 | 201 | values = data[col].values.reshape(-1, 1) 202 | data[f'{col}_norm'] = scaler.fit_transform(values) 203 | 204 | # 保存缩放器以便将来反归一化 205 | self.scalers[col] = scaler 206 | 207 | return data -------------------------------------------------------------------------------- /src/utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import plotly.graph_objects as go 4 | from plotly.subplots import make_subplots 5 | from datetime import datetime 6 | import pandas as pd 7 | 8 | class Plot: 9 | def __init__(self, strategy): 10 | self.strategy = strategy 11 | 12 | def plot(self, **kwargs): 13 | """使用Plotly绘制交互式回测结果""" 14 | # 获取策略数据 15 | data = self.strategy.data 16 | 17 | # 将数据转换为numpy数组,处理日期时间 18 | dates = [datetime.fromordinal(int(d)).strftime('%Y-%m-%d') for d in data.datetime.array] 19 | opens = np.array(data.open.array) 20 | highs = np.array(data.high.array) 21 | lows = np.array(data.low.array) 22 | closes = np.array(data.close.array) 23 | volumes = np.array(data.volume.array) 24 | 25 | # 检查trailing_stop是否存在 26 | trailing_stop = getattr(self.strategy, 'trailing_stop', None) 27 | if trailing_stop is not None: 28 | trailing_stop_vals = np.array(trailing_stop.trailing_stop.array) 29 | else: 30 | # 如果trailing_stop为None,创建一个全为0的数组 31 | trailing_stop_vals = np.zeros_like(closes) 32 | 33 | # 创建DataFrame以便处理数据 34 | df = pd.DataFrame({ 35 | 'date': dates, 36 | 'open': opens, 37 | 'high': highs, 38 | 'low': lows, 39 | 'close': closes, 40 | 'volume': volumes, 41 | 'trailing_stop': trailing_stop_vals 42 | }) 43 | 44 | # 移除volume为0的行(非交易日) 45 | df = df[df['volume'] > 0].copy() 46 | 47 | # 创建子图 48 | fig = make_subplots(rows=2, cols=1, 49 | shared_xaxes=True, 50 | vertical_spacing=0.03, 51 | row_heights=[0.7, 0.3]) 52 | 53 | # 添加K线图 54 | fig.add_trace(go.Candlestick( 55 | x=df['date'], 56 | open=df['open'], 57 | high=df['high'], 58 | low=df['low'], 59 | close=df['close'], 60 | name='K线', 61 | increasing_line_color='#ff0000', # 上涨为红色 62 | decreasing_line_color='#00ff00', # 下跌为绿色 63 | ), row=1, col=1) 64 | 65 | # 添加追踪止损线 66 | # 过滤掉追踪止损为0的点 67 | valid_stops = df[df['trailing_stop'] > 0].copy() 68 | if not valid_stops.empty: 69 | fig.add_trace(go.Scatter( 70 | x=valid_stops['date'], 71 | y=valid_stops['trailing_stop'], 72 | name='追踪止损', 73 | line=dict(color='#f1c40f', dash='dash') 74 | ), row=1, col=1) 75 | 76 | # 添加买卖点标记 77 | if hasattr(self, 'trades_df') and not self.trades_df.empty: 78 | # 获取买入点 79 | buy_points = self.trades_df[self.trades_df['类型'] == '买入'] 80 | if not buy_points.empty: 81 | fig.add_trace(go.Scatter( 82 | x=buy_points['时间'].dt.strftime('%Y-%m-%d'), 83 | y=buy_points['价格'], 84 | mode='markers+text', 85 | marker=dict(symbol='triangle-up', size=15, color='#2ecc71'), 86 | text=[f"买入\n价格:{price:.2f}\n数量:{size}" 87 | for price, size in zip(buy_points['价格'], buy_points['数量'])], 88 | textposition="top center", 89 | name='买入点', 90 | hoverinfo='text' 91 | ), row=1, col=1) 92 | 93 | # 获取卖出点 94 | sell_points = self.trades_df[self.trades_df['类型'] == '卖出'] 95 | if not sell_points.empty: 96 | fig.add_trace(go.Scatter( 97 | x=sell_points['时间'].dt.strftime('%Y-%m-%d'), 98 | y=sell_points['价格'], 99 | mode='markers+text', 100 | marker=dict(symbol='triangle-down', size=15, color='#e74c3c'), 101 | text=[f"卖出\n价格:{price:.2f}\n收益:{profit:.2f}" 102 | for price, profit in zip(sell_points['价格'], sell_points['累计收益'])], 103 | textposition="bottom center", 104 | name='卖出点', 105 | hoverinfo='text' 106 | ), row=1, col=1) 107 | 108 | # 添加成交量图 109 | colors = ['#ff0000' if close >= open_ else '#00ff00' 110 | for close, open_ in zip(df['close'], df['open'])] 111 | 112 | fig.add_trace(go.Bar( 113 | x=df['date'], 114 | y=df['volume'], 115 | name='成交量', 116 | marker_color=colors, 117 | marker=dict( 118 | color=colors, 119 | line=dict(color=colors, width=1) 120 | ), 121 | hovertemplate='日期: %{x}
成交量: %{y:,.0f}' 122 | ), row=2, col=1) 123 | 124 | # 更新布局 125 | fig.update_layout( 126 | title={'text': '回测结果', 'font': {'family': 'Arial'}}, 127 | yaxis_title={'text': '价格', 'font': {'family': 'Arial'}}, 128 | yaxis2_title={'text': '成交量', 'font': {'family': 'Arial'}}, 129 | xaxis_rangeslider_visible=False, 130 | height=800, 131 | template='plotly_white', 132 | showlegend=True, 133 | legend=dict( 134 | yanchor="top", 135 | y=0.99, 136 | xanchor="left", 137 | x=0.01, 138 | font=dict(family='Arial') 139 | ), 140 | # 优化X轴显示 141 | xaxis=dict( 142 | type='category', 143 | rangeslider=dict(visible=False), 144 | showgrid=False, # 移除X轴网格线 145 | gridwidth=1, 146 | gridcolor='lightgrey' 147 | ), 148 | xaxis2=dict( 149 | type='category', 150 | rangeslider=dict(visible=False), 151 | showgrid=False, # 移除X轴网格线 152 | gridwidth=1, 153 | gridcolor='lightgrey' 154 | ), 155 | bargap=0, # 设置柱状图之间的间隔为0 156 | bargroupgap=0 # 设置柱状图组之间的间隔为0 157 | ) 158 | 159 | # 更新Y轴格式 160 | fig.update_yaxes( 161 | title_text="价格", 162 | title_font=dict(family='Arial'), 163 | showgrid=True, 164 | gridwidth=1, 165 | gridcolor='lightgrey', 166 | row=1, 167 | col=1 168 | ) 169 | fig.update_yaxes( 170 | title_text="成交量", 171 | title_font=dict(family='Arial'), 172 | showgrid=True, 173 | gridwidth=1, 174 | gridcolor='lightgrey', 175 | row=2, 176 | col=1 177 | ) 178 | 179 | return fig -------------------------------------------------------------------------------- /src/strategies/rl_model_strategy.py: -------------------------------------------------------------------------------- 1 | import backtrader as bt 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | from loguru import logger 6 | import json 7 | 8 | class RLModelStrategy(bt.Strategy): 9 | params = ( 10 | ('model_path', None), # 模型路径 11 | ('config_path', None), # 配置文件路径 12 | ('window_size', 10), # 观察窗口大小 13 | ('risk_ratio', 0.02), # 单次交易风险比率 14 | ('max_drawdown', 0.15), # 最大回撤限制 15 | ('price_limit', 0.10), # 涨跌停限制(10%) 16 | ('min_shares', 100), # 最小交易股数 17 | ('cash_buffer', 0.95), # 现金缓冲比例 18 | ) 19 | 20 | def __init__(self): 21 | """初始化策略""" 22 | # 加载配置和模型 23 | if self.p.config_path: 24 | with open(self.p.config_path, 'r') as f: 25 | config_dict = json.load(f) 26 | # TODO: 从字典更新配置 27 | 28 | # 设置设备 29 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 30 | 31 | # 初始化智能体 32 | if self.p.model_path: 33 | self.agent = self._load_agent() 34 | else: 35 | logger.error("未提供模型路径") 36 | raise ValueError("请提供有效的模型路径") 37 | 38 | # 记录最高净值,用于计算回撤 39 | self.highest_value = self.broker.getvalue() 40 | 41 | # 用于跟踪订单和持仓 42 | self.order = None 43 | self.entry_price = None 44 | self.trade_reason = None 45 | self._orders = [] 46 | 47 | # 技术指标 48 | self.atr = bt.indicators.ATR(self.data) 49 | 50 | # 价格和特征历史 51 | self.price_history = [] 52 | self.feature_history = [] 53 | 54 | logger.info("强化学习模型策略初始化完成") 55 | 56 | def _load_agent(self): 57 | """加载训练好的智能体""" 58 | # 计算状态维度 59 | pass 60 | 61 | def _get_state(self): 62 | """构建当前状态""" 63 | # 获取价格历史 64 | price_data = np.array([ 65 | [self.data.open[i], self.data.high[i], self.data.low[i], 66 | self.data.close[i], self.data.volume[i]] 67 | for i in range(-self.p.window_size + 1, 1) 68 | ]).flatten() 69 | 70 | # 生成技术指标特征 71 | features = [] 72 | for i in range(-self.p.window_size + 1, 1): 73 | # 趋势指标 74 | ma5 = np.mean([self.data.close[j] for j in range(i-4, i+1)]) 75 | ma10 = np.mean([self.data.close[j] for j in range(i-9, i+1)]) 76 | ma20 = np.mean([self.data.close[j] for j in range(i-19, i+1)]) 77 | 78 | # 动量指标 79 | momentum = self.data.close[i] / self.data.close[i-5] - 1 80 | 81 | # 波动率指标 82 | volatility = np.std([self.data.close[j] for j in range(i-4, i+1)]) 83 | 84 | # 成交量指标 85 | volume_ma5 = np.mean([self.data.volume[j] for j in range(i-4, i+1)]) 86 | 87 | features.extend([ma5, ma10, ma20, momentum, volatility, volume_ma5]) 88 | 89 | # 账户状态 90 | portfolio_value = self.broker.getvalue() 91 | position_value = self.position.size * self.data.close[0] if self.position else 0 92 | position_pct = position_value / portfolio_value if portfolio_value > 0 else 0 93 | cash_pct = self.broker.getcash() / portfolio_value 94 | 95 | # 组合状态 96 | state = np.concatenate([ 97 | price_data, 98 | features, 99 | [cash_pct, position_pct] 100 | ]) 101 | 102 | return state.astype(np.float32) 103 | 104 | def round_shares(self, shares): 105 | """将股数调整为100的整数倍""" 106 | return int(shares / 100) * 100 107 | 108 | def check_price_limit(self, price): 109 | """检查是否触及涨跌停""" 110 | prev_close = self.data.close[-1] 111 | upper_limit = prev_close * (1 + self.p.price_limit) 112 | lower_limit = prev_close * (1 - self.p.price_limit) 113 | return lower_limit <= price <= upper_limit 114 | 115 | def calculate_trade_size(self, price): 116 | """计算可交易的股数(考虑资金、手续费和100股整数倍)""" 117 | cash = self.broker.getcash() * self.p.cash_buffer 118 | 119 | # 计算风险金额(使用总资产的一定比例) 120 | total_value = self.broker.getvalue() 121 | risk_amount = total_value * self.p.risk_ratio 122 | 123 | # 使用ATR计算每股风险 124 | current_atr = self.atr[0] 125 | risk_per_share = current_atr * 1.5 126 | 127 | # 根据风险计算的股数 128 | risk_size = risk_amount / risk_per_share if risk_per_share > 0 else 0 129 | 130 | # 根据可用资金计算的股数 131 | cash_size = cash / price 132 | 133 | # 取较小值并调整为100股整数倍 134 | shares = min(risk_size, cash_size) 135 | shares = self.round_shares(shares) 136 | 137 | # 再次验证金额是否超过可用资金 138 | if shares * price > cash: 139 | shares = self.round_shares(cash / price) 140 | 141 | return shares if shares >= self.p.min_shares else 0 142 | 143 | def next(self): 144 | # 如果有未完成的订单,不执行新的交易 145 | if self.order: 146 | return 147 | 148 | # 计算当前回撤 149 | current_value = self.broker.getvalue() 150 | self.highest_value = max(self.highest_value, current_value) 151 | drawdown = (self.highest_value - current_value) / self.highest_value 152 | 153 | # 如果回撤超过限制,不开新仓 154 | if drawdown > self.p.max_drawdown: 155 | if self.position: 156 | self.trade_reason = f"触发最大回撤限制 ({drawdown:.2%})" 157 | self.close() 158 | logger.info(f"触发最大回撤限制 - 当前回撤: {drawdown:.2%}, 限制: {self.p.max_drawdown:.2%}") 159 | return 160 | 161 | # 检查是否触及涨跌停 162 | if not self.check_price_limit(self.data.close[0]): 163 | return 164 | 165 | # 获取当前状态 166 | state = self._get_state() 167 | state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device) 168 | 169 | # 使用智能体选择动作 170 | with torch.no_grad(): 171 | q_values = self.agent.q_network(state_tensor) 172 | action = q_values.argmax().item() 173 | 174 | current_price = self.data.close[0] 175 | 176 | if action == 1: # 买入 177 | if not self.position: # 没有持仓 178 | shares = self.calculate_trade_size(current_price) 179 | if shares >= self.p.min_shares: 180 | self.trade_reason = "智能体买入信号" 181 | self.order = self.buy(size=shares) 182 | if self.order: 183 | self.entry_price = current_price 184 | logger.info(f"买入信号 - 数量: {shares}, 价格: {current_price:.2f}") 185 | 186 | elif action == 2: # 卖出 187 | if self.position: # 有持仓 188 | self.trade_reason = "智能体卖出信号" 189 | self.order = self.close() 190 | if self.order: 191 | logger.info(f"卖出信号 - 价格: {current_price:.2f}") 192 | 193 | def notify_order(self, order): 194 | if order.status in [order.Submitted, order.Accepted]: 195 | return 196 | 197 | if order.status in [order.Completed]: 198 | if order.isbuy(): 199 | order.info = { 200 | 'reason': self.trade_reason, 201 | 'total_value': self.broker.getvalue(), 202 | 'position_value': self.position.size * order.executed.price if self.position else 0 203 | } 204 | self._orders.append(order) 205 | logger.info( 206 | f'买入执行 - 价格: {order.executed.price:.2f}, ' 207 | f'数量: {order.executed.size}, ' 208 | f'原因: {self.trade_reason}' 209 | ) 210 | else: 211 | self.entry_price = None 212 | order.info = { 213 | 'reason': self.trade_reason, 214 | 'total_value': self.broker.getvalue(), 215 | 'position_value': self.position.size * order.executed.price if self.position else 0 216 | } 217 | self._orders.append(order) 218 | logger.info( 219 | f'卖出执行 - 价格: {order.executed.price:.2f}, ' 220 | f'数量: {order.executed.size}, ' 221 | f'原因: {self.trade_reason}' 222 | ) 223 | elif order.status in [order.Canceled, order.Margin, order.Rejected]: 224 | logger.warning(f'订单失败 - 状态: {order.getstatusname()}') 225 | 226 | self.order = None 227 | 228 | def stop(self): 229 | """策略结束时的汇总信息""" 230 | portfolio_value = self.broker.getvalue() 231 | returns = (portfolio_value / self.broker.startingcash) - 1.0 232 | logger.info(f"策略结束 - 最终资金: {portfolio_value:.2f}, 收益率: {returns:.2%}") 233 | -------------------------------------------------------------------------------- /src/trading/market_executor.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | from datetime import datetime, timedelta 4 | import time 5 | import tushare as ts 6 | import akshare as ak 7 | import backtrader as bt 8 | from src.utils.logger import setup_logger 9 | from src.utils.notification import send_notification 10 | from src.strategies.market_sentiment_strategy import MarketSentimentStrategy 11 | from src.data.data_loader import DataLoader 12 | from src.utils.analysis import Analysis 13 | 14 | logger = setup_logger() 15 | 16 | class MarketExecutor: 17 | def __init__(self, symbols: list, tushare_token: str): 18 | self.symbols = symbols 19 | self.records_dir = "data/trading_records" 20 | self.data_loader = DataLoader(tushare_token=tushare_token) 21 | self.analysis = Analysis() 22 | os.makedirs(self.records_dir, exist_ok=True) 23 | 24 | def is_trading_day(self): 25 | """判断当前是否为交易日""" 26 | today = datetime.now().date() 27 | # 获取交易日历 28 | trade_cal = self.data_loader.pro.trade_cal(exchange='SSE', 29 | start_date=today.strftime('%Y%m%d'), 30 | end_date=today.strftime('%Y%m%d')) 31 | return trade_cal.iloc[0]['is_open'] == 1 32 | 33 | def get_realtime_data(self, ts_code): 34 | """获取实时行情数据""" 35 | try: 36 | # 获取实时行情 37 | # 转换为雪球指数代码格式 38 | if ts_code == '000001.SH': 39 | market_code = 'SH000001' 40 | elif ts_code == '000300.SH': 41 | market_code = 'SH000300' 42 | elif ts_code == '000016.SH': 43 | market_code = 'SH000016' 44 | elif ts_code == '399240.SZ': 45 | market_code = 'SZ399240' 46 | else: 47 | market_code = ts_code.replace('.SH', 'SH').replace('.SZ', 'SZ') 48 | 49 | realtime_data = ak.stock_individual_spot_xq(symbol=market_code) 50 | if realtime_data is not None and not realtime_data.empty: 51 | # 将数据转换为以item为索引的格式 52 | realtime_data = realtime_data.set_index('item') 53 | # 转换数据格式以匹配原有接口 54 | return pd.DataFrame({ 55 | 'ts_code': [ts_code], 56 | 'open': [float(realtime_data.loc['今开', 'value'])], 57 | 'high': [float(realtime_data.loc['最高', 'value'])], 58 | 'low': [float(realtime_data.loc['最低', 'value'])], 59 | 'close': [float(realtime_data.loc['现价', 'value'])], 60 | 'pre_close': [float(realtime_data.loc['昨收', 'value'])], 61 | 'vol': [float(realtime_data.loc['成交量', 'value'])], 62 | 'amount': [float(realtime_data.loc['成交额', 'value'])] 63 | }) 64 | return None 65 | except Exception as e: 66 | logger.error(f"获取实时数据失败: {str(e)}") 67 | import traceback 68 | traceback.print_exc() 69 | return None 70 | 71 | def execute(self): 72 | """执行策略回测并记录交易信号""" 73 | try: 74 | # 检查是否为交易日 75 | if not self.is_trading_day(): 76 | logger.info("当前不是交易日,跳过执行") 77 | return 78 | 79 | # 获取当前时间 80 | current_time = datetime.now() 81 | 82 | # 遍历上证50成分股进行回测 83 | for symbol in self.symbols: 84 | try: 85 | # 创建回测引擎 86 | cerebro = bt.Cerebro() 87 | 88 | # 设置初始资金 89 | cerebro.broker.setcash(1000000.0) 90 | 91 | # 设置交易手续费 92 | cerebro.broker.setcommission(commission=0.0003) 93 | 94 | # 获取历史数据 95 | start_date = (current_time - timedelta(days=365)) 96 | end_date = current_time 97 | 98 | # 获取实时数据 99 | # realtime_data = self.get_realtime_data(symbol) 100 | # if realtime_data is None: 101 | # logger.error("获取实时数据失败,跳过本次执行") 102 | # continue 103 | 104 | # 使用DataLoader加载数据 105 | data = self.data_loader.download_data( 106 | symbol=symbol, 107 | start_date=start_date, 108 | end_date=end_date 109 | ) 110 | 111 | if data is None: 112 | logger.error("获取数据失败,跳过本次执行") 113 | continue 114 | 115 | # 添加数据到回测引擎 116 | cerebro.adddata(data) 117 | 118 | # 添加策略 119 | cerebro.addstrategy(MarketSentimentStrategy) 120 | 121 | # 添加分析器 122 | cerebro.addanalyzer(bt.analyzers.SharpeRatio, _name='sharpe', riskfreerate=0.02) 123 | cerebro.addanalyzer(bt.analyzers.DrawDown, _name='drawdown') 124 | cerebro.addanalyzer(bt.analyzers.Returns, _name='returns') 125 | cerebro.addanalyzer(bt.analyzers.TradeAnalyzer, _name='trades') 126 | cerebro.addanalyzer(bt.analyzers.VWR, _name='vwr') 127 | cerebro.addanalyzer(bt.analyzers.SQN, _name='sqn') 128 | cerebro.addanalyzer(bt.analyzers.Transactions, _name='txn') 129 | 130 | # 运行回测 131 | results = cerebro.run() 132 | strategy = results[0] 133 | 134 | # 初始化trades列表 135 | cerebro.trades = [] 136 | 137 | # 使用Analysis类获取回测结果 138 | analysis = self.analysis._get_analysis(cerebro, strategy) 139 | 140 | logger.info(f"回测结果: {analysis}") 141 | 142 | # 获取交易记录 143 | if 'trades' in analysis and not analysis['trades'].empty: 144 | # 获取当天的交易记录 145 | # today_trades = analysis['trades'][ 146 | # analysis['trades']['交易时间'] == current_time.strftime('%Y-%m-%d') 147 | # ] 148 | today_trades = analysis['trades'] 149 | logger.info(f"当天交易记录: {today_trades}, 所有交易记录: {analysis['trades']}") 150 | 151 | # 记录当天的交易 152 | for _, trade in today_trades.iterrows(): 153 | action = "buy" if trade['方向'] == '买入' else "sell" 154 | signal = 1 if action == "buy" else -1 155 | 156 | self._record_trade( 157 | symbol=symbol, 158 | action=action, 159 | timestamp=current_time, 160 | signal=signal, 161 | price=float(trade['成交价']), 162 | size=int(trade['数量']), 163 | reason=trade['交易原因'] 164 | ) 165 | 166 | # 发送通知 167 | message = f"交易提醒: {symbol} {action.upper()} 信号\n" 168 | message += f"价格: {trade['成交价']}, 数量: {trade['数量']}\n" 169 | message += f"原因: {trade['交易原因']}" 170 | send_notification(message) 171 | 172 | except Exception as e: 173 | logger.error(f"处理股票 {symbol} 时出错: {str(e)}") 174 | continue 175 | 176 | except Exception as e: 177 | logger.error(f"执行策略回测时出错: {str(e)}") 178 | send_notification(f"策略回测错误: {str(e)}") 179 | 180 | def _record_trade(self, symbol: str, action: str, timestamp: datetime, signal: float, 181 | price: float, size: int, reason: str): 182 | """记录交易到CSV文件""" 183 | record_file = os.path.join(self.records_dir, f"trades_{timestamp.strftime('%Y%m%d')}.csv") 184 | 185 | record = { 186 | "timestamp": timestamp, 187 | "symbol": symbol, 188 | "action": action, 189 | "signal": signal, 190 | "price": price, 191 | "size": size, 192 | "reason": reason, 193 | "strategy": "market_sentiment" 194 | } 195 | 196 | df = pd.DataFrame([record]) 197 | 198 | if os.path.exists(record_file): 199 | df.to_csv(record_file, mode='a', header=False, index=False) 200 | else: 201 | df.to_csv(record_file, index=False) 202 | 203 | def run_continuously(self, interval: int = 3600): 204 | """持续运行策略回测,每小时执行一次""" 205 | logger.info("启动实盘交易系统") 206 | while True: 207 | self.execute() 208 | time.sleep(interval) -------------------------------------------------------------------------------- /rl_model_finrl/agents/rllib/ppo_agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import logging 5 | from typing import Dict, List, Tuple, Any, Optional 6 | import ray 7 | from ray import tune 8 | from ray.rllib.algorithms.ppo import PPO 9 | from ray.rllib.utils.framework import try_import_tf, try_import_torch 10 | from ray.tune.registry import register_env 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | class RLlibPPOAgent: 15 | """ 16 | RLlib PPO 智能体实现 17 | 18 | 这个类封装了Ray RLlib的PPO算法,用于ETF交易环境 19 | """ 20 | 21 | def __init__( 22 | self, 23 | env, 24 | model_name: str = "ppo_rllib", 25 | tensorboard_log: Optional[str] = None, 26 | seed: Optional[int] = None, 27 | verbose: int = 1, 28 | device: str = "auto", 29 | **kwargs 30 | ): 31 | """ 32 | 初始化RLlib PPO智能体 33 | 34 | 参数: 35 | env: 训练环境 36 | model_name: 模型名称 37 | tensorboard_log: TensorBoard日志目录 38 | seed: 随机种子 39 | verbose: 详细程度 40 | device: 设备选择 ('cpu', 'cuda', 'auto') 41 | **kwargs: 传递给PPO构造函数的其他参数 42 | """ 43 | self.env = env 44 | self.model_name = model_name 45 | self.tensorboard_log = tensorboard_log 46 | self.seed = seed 47 | self.verbose = verbose 48 | self.device = device 49 | self.ppo_config = { 50 | "framework": "torch", 51 | "num_gpus": 0 if device == "cpu" else 1, 52 | "seed": seed, 53 | # PPO特定参数 54 | "lambda": 0.95, 55 | "kl_coeff": 0.5, 56 | "clip_param": 0.2, 57 | "vf_clip_param": 10.0, 58 | "entropy_coeff": 0.01, 59 | "train_batch_size": 5000, 60 | "sgd_minibatch_size": 500, 61 | "num_sgd_iter": 10, 62 | # 通用参数 63 | "gamma": 0.99, 64 | "lr": 3e-4, 65 | "log_level": "WARN", 66 | } 67 | 68 | # 更新配置 69 | self.ppo_config.update(kwargs) 70 | 71 | # 注册环境 72 | self._register_env() 73 | 74 | # 初始化Ray(如果尚未初始化) 75 | if not ray.is_initialized(): 76 | if self.verbose > 0: 77 | logger.info("初始化Ray...") 78 | ray.init(ignore_reinit_error=True, logging_level=logging.ERROR) 79 | 80 | # 创建PPO算法实例 81 | self.model = PPO( 82 | config=self.ppo_config, 83 | env=self.env.__class__.__name__, 84 | ) 85 | 86 | if self.verbose > 0: 87 | logger.info(f"RLlib PPO智能体初始化完成: {model_name}") 88 | 89 | def _register_env(self): 90 | """注册环境到Ray""" 91 | env_name = self.env.__class__.__name__ 92 | 93 | # 注册环境创建函数 94 | def env_creator(env_config): 95 | return self.env 96 | 97 | # 注册环境 98 | register_env(env_name, env_creator) 99 | 100 | def learn( 101 | self, 102 | total_timesteps: int = 100000, 103 | callback: Any = None, 104 | log_interval: int = 4, 105 | eval_env = None, 106 | eval_freq: int = -1, 107 | n_eval_episodes: int = 5, 108 | tb_log_name: str = "PPO", 109 | eval_log_path: Optional[str] = None, 110 | reset_num_timesteps: bool = True, 111 | ) -> "RLlibPPOAgent": 112 | """ 113 | 训练模型 114 | 115 | 参数: 116 | total_timesteps: 总训练步数 117 | callback: 回调函数 118 | log_interval: 日志记录间隔 119 | eval_env: 评估环境 120 | eval_freq: 评估频率 121 | n_eval_episodes: 评估轮数 122 | tb_log_name: TensorBoard日志名称 123 | eval_log_path: 评估日志路径 124 | reset_num_timesteps: 是否重置时间步计数 125 | 126 | 返回: 127 | 训练后的智能体 128 | """ 129 | if self.verbose > 0: 130 | logger.info(f"开始训练模型,总步数: {total_timesteps}") 131 | 132 | iterations = total_timesteps // self.ppo_config["train_batch_size"] 133 | 134 | for i in range(iterations): 135 | if self.verbose > 0 and i % log_interval == 0: 136 | logger.info(f"训练迭代 {i+1}/{iterations}") 137 | 138 | # 执行训练 139 | result = self.model.train() 140 | 141 | # 打印训练结果 142 | if self.verbose > 0 and i % log_interval == 0: 143 | logger.info(f" 训练回报: {result['episode_reward_mean']:.2f}") 144 | logger.info(f" 训练长度: {result['episode_len_mean']:.2f}") 145 | 146 | # 评估(如果需要) 147 | if eval_env is not None and eval_freq > 0 and i % eval_freq == 0: 148 | self.evaluate(eval_env, n_eval_episodes) 149 | 150 | return self 151 | 152 | def predict( 153 | self, 154 | observation: np.ndarray, 155 | state: Optional[Tuple] = None, 156 | deterministic: bool = True, 157 | ) -> Tuple[np.ndarray, Optional[Tuple]]: 158 | """ 159 | 预测动作 160 | 161 | 参数: 162 | observation: 当前观察 163 | state: RNN状态(如果适用) 164 | deterministic: 是否确定性预测 165 | 166 | 返回: 167 | (动作, 状态)元组 168 | """ 169 | action = self.model.compute_single_action( 170 | observation, 171 | explore=not deterministic 172 | ) 173 | return action, state 174 | 175 | def save(self, path: str) -> None: 176 | """ 177 | 保存模型 178 | 179 | 参数: 180 | path: 保存路径 181 | """ 182 | os.makedirs(os.path.dirname(path), exist_ok=True) 183 | checkpoint_path = self.model.save(path) 184 | if self.verbose > 0: 185 | logger.info(f"模型已保存到: {checkpoint_path}") 186 | 187 | def load(self, path: str) -> None: 188 | """ 189 | 加载模型 190 | 191 | 参数: 192 | path: 模型路径 193 | """ 194 | self.model.restore(path) 195 | if self.verbose > 0: 196 | logger.info(f"已从{path}加载模型") 197 | 198 | def evaluate( 199 | self, 200 | eval_env, 201 | n_eval_episodes: int = 10, 202 | deterministic: bool = True, 203 | render: bool = False, 204 | ) -> Tuple[float, float]: 205 | """ 206 | 评估模型 207 | 208 | 参数: 209 | eval_env: 评估环境 210 | n_eval_episodes: 评估轮数 211 | deterministic: 是否确定性预测 212 | render: 是否渲染环境 213 | 214 | 返回: 215 | (平均奖励, 标准差)元组 216 | """ 217 | if self.verbose > 0: 218 | logger.info(f"开始评估模型,轮数: {n_eval_episodes}") 219 | 220 | episode_rewards = [] 221 | episode_lengths = [] 222 | 223 | for i in range(n_eval_episodes): 224 | obs = eval_env.reset() 225 | done = False 226 | episode_reward = 0.0 227 | episode_length = 0 228 | 229 | while not done: 230 | action, _ = self.predict(obs, deterministic=deterministic) 231 | obs, reward, done, info = eval_env.step(action) 232 | 233 | episode_reward += reward 234 | episode_length += 1 235 | 236 | if render: 237 | eval_env.render() 238 | 239 | episode_rewards.append(episode_reward) 240 | episode_lengths.append(episode_length) 241 | 242 | mean_reward = np.mean(episode_rewards) 243 | std_reward = np.std(episode_rewards) 244 | 245 | if self.verbose > 0: 246 | logger.info(f"评估结果 - 平均奖励: {mean_reward:.2f} +/- {std_reward:.2f}") 247 | 248 | return mean_reward, std_reward 249 | 250 | def test( 251 | self, 252 | test_env, 253 | num_episodes: int = 1, 254 | render: bool = False 255 | ) -> Tuple[pd.DataFrame, pd.DataFrame]: 256 | """ 257 | 测试模型并返回资产和动作记录 258 | 259 | 参数: 260 | test_env: 测试环境 261 | num_episodes: 测试轮数 262 | render: 是否渲染环境 263 | 264 | 返回: 265 | (资产记录, 动作记录)元组 266 | """ 267 | if self.verbose > 0: 268 | logger.info(f"开始测试模型,轮数: {num_episodes}") 269 | 270 | # 重置环境 271 | state = test_env.reset() 272 | 273 | # 初始化记录 274 | episode_rewards = [] 275 | 276 | # 决策步骤 277 | done = False 278 | episode_reward = 0.0 279 | 280 | while not done: 281 | action, _ = self.predict(state, deterministic=True) 282 | next_state, reward, done, info = test_env.step(action) 283 | 284 | state = next_state 285 | episode_reward += reward 286 | 287 | if render: 288 | test_env.render() 289 | 290 | if self.verbose > 0: 291 | logger.info(f"测试完成,总奖励: {episode_reward:.2f}") 292 | 293 | # 获取回测数据 294 | asset_memory = test_env.save_asset_memory() 295 | action_memory = test_env.save_action_memory() 296 | 297 | return asset_memory, action_memory -------------------------------------------------------------------------------- /src/data/data_loader.py: -------------------------------------------------------------------------------- 1 | import tushare as ts 2 | import akshare as ak 3 | import pandas as pd 4 | import backtrader as bt 5 | from datetime import datetime 6 | import os 7 | from loguru import logger 8 | 9 | class DataLoader: 10 | def __init__(self, tushare_token=None): 11 | """ 12 | 初始化数据加载器 13 | :param tushare_token: Tushare的API token 14 | """ 15 | self.tushare_token = tushare_token 16 | if tushare_token: 17 | ts.set_token(tushare_token) 18 | self.pro = ts.pro_api() 19 | logger.info("Tushare API初始化成功") 20 | 21 | def download_data(self, symbol, start_date, end_date): 22 | """ 23 | 下载数据,支持A股、ETF和港股 24 | :param symbol: 股票代码(格式:000001.SZ, 510300.SH, 00700.HK等)或ETF代码列表 25 | :param start_date: 开始日期 26 | :param end_date: 结束日期 27 | :return: DataFrame或DataFrame列表 28 | """ 29 | try: 30 | # 转换日期格式 31 | start_str = start_date.strftime("%Y%m%d") 32 | end_str = end_date.strftime("%Y%m%d") 33 | 34 | # 如果是ETF轮动策略,symbol应该是一个列表 35 | if isinstance(symbol, list): 36 | logger.info(f"开始下载多个ETF数据: {symbol}") 37 | data_list = [] 38 | for etf in symbol: 39 | df = self._download_etf_data(etf, start_date, end_date) 40 | if not df.empty: 41 | data = PandasData(dataname=df, ts_code=etf, fromdate=start_date, todate=end_date) 42 | data_list.append(data) 43 | logger.info(f"成功下载 {len(data_list)} 个ETF数据") 44 | return data_list 45 | 46 | # 判断市场类型 47 | if symbol.endswith(('.SH', '.SZ')): # A股或ETF 48 | if symbol.startswith('51') or symbol.startswith('159'): # ETF 49 | logger.info(f"下载ETF数据: {symbol}") 50 | df = self._download_etf_data(symbol, start_date, end_date) 51 | else: # A股 52 | logger.info(f"下载A股数据: {symbol}") 53 | df = self._download_stock_data(symbol, start_str, end_str) 54 | elif symbol.endswith('.HK'): # 港股 55 | logger.info(f"使用AKShare下载港股数据: {symbol}") 56 | df = self._download_hk_data(symbol, start_date, end_date) 57 | else: 58 | raise ValueError(f"不支持的市场类型: {symbol}") 59 | 60 | if df.empty: 61 | logger.warning(f"下载数据为空: {symbol}") 62 | return None 63 | 64 | logger.info(f"下载数据成功: {symbol},数据长度: {len(df)}, 日期范围: {df.index[0]} 至 {df.index[-1]}") 65 | 66 | # 创建PandasData对象并设置股票代码和时间范围 67 | data = PandasData(dataname=df, ts_code=symbol, fromdate=start_date, todate=end_date) 68 | return data 69 | 70 | except Exception as e: 71 | logger.error(f"下载数据失败: {str(e)}") 72 | import traceback 73 | traceback.print_exc() 74 | raise 75 | 76 | def _download_stock_data(self, symbol, start_date, end_date): 77 | """下载A股数据""" 78 | if not self.tushare_token: 79 | raise ValueError("需要设置Tushare token才能下载A股数据") 80 | 81 | df = self.pro.daily(ts_code=symbol, start_date=start_date, end_date=end_date) 82 | if df.empty: 83 | return pd.DataFrame() 84 | 85 | # 重命名列以匹配backtrader要求 86 | df = df.rename(columns={ 87 | 'trade_date': 'date', 88 | 'open': 'open', 89 | 'high': 'high', 90 | 'low': 'low', 91 | 'close': 'close', 92 | 'vol': 'volume' 93 | }) 94 | 95 | # 转换日期格式并设置为索引 96 | df['date'] = pd.to_datetime(df['date']) 97 | df = df.set_index('date') 98 | df = df.sort_index() # 确保按日期升序 99 | 100 | return df[['open', 'high', 'low', 'close', 'volume']] 101 | 102 | def _download_etf_data(self, symbol, start_date, end_date): 103 | """下载ETF数据""" 104 | symbol_code = symbol.split('.')[0] # 去掉市场后缀 105 | start_str = start_date.strftime('%Y%m%d') 106 | end_str = end_date.strftime('%Y%m%d') 107 | 108 | try: 109 | # 判断是否有tushare_token 110 | if self.tushare_token: 111 | logger.info(f"使用Tushare下载ETF数据: {symbol}") 112 | df = self.pro.fund_daily(ts_code=symbol, start_date=start_str, end_date=end_str) 113 | 114 | if not df.empty: 115 | # 重命名列以匹配backtrader要求 116 | df = df.rename(columns={ 117 | 'trade_date': 'date', 118 | 'open': 'open', 119 | 'high': 'high', 120 | 'low': 'low', 121 | 'close': 'close', 122 | 'vol': 'volume' 123 | }) 124 | 125 | # 转换日期格式并设置为索引 126 | df['date'] = pd.to_datetime(df['date']) 127 | df = df.set_index('date') 128 | df = df.sort_index() # 确保按日期升序 129 | else: 130 | logger.warning(f"Tushare未返回ETF数据,尝试使用AKShare: {symbol}") 131 | df = self._download_etf_from_akshare(symbol_code, start_date, end_date) 132 | else: 133 | logger.info(f"未设置Tushare token,使用AKShare下载ETF数据: {symbol}") 134 | df = self._download_etf_from_akshare(symbol_code, start_date, end_date) 135 | 136 | if df.empty: 137 | logger.warning(f"下载ETF数据为空 - 股票代码: {symbol}, 日期范围: {start_date} 至 {end_date}") 138 | return pd.DataFrame() 139 | 140 | logger.info(f"下载ETF数据成功: {symbol},数据长度: {len(df)}, 日期范围: {df.index[0]} 至 {df.index[-1]}") 141 | return df[['open', 'high', 'low', 'close', 'volume']] 142 | 143 | except Exception as e: 144 | logger.error(f"下载ETF数据失败: {str(e)}") 145 | import traceback 146 | traceback.print_exc() 147 | return pd.DataFrame() 148 | 149 | def _download_etf_from_akshare(self, symbol_code, start_date, end_date): 150 | """使用AKShare下载ETF数据""" 151 | df = ak.fund_etf_hist_em(symbol=symbol_code, period="daily") 152 | 153 | # 重命名列以匹配backtrader要求 154 | df = df.rename(columns={ 155 | '日期': 'date', 156 | '开盘': 'open', 157 | '最高': 'high', 158 | '最低': 'low', 159 | '收盘': 'close', 160 | '成交量': 'volume' 161 | }) 162 | 163 | # 转换日期格式并设置为索引 164 | df['date'] = pd.to_datetime(df['date']) 165 | df = df.set_index('date') 166 | df = df.sort_index() # 确保按日期升序 167 | 168 | # 过滤日期范围 169 | logger.info(f"过滤日期范围: {start_date} 至 {end_date}") 170 | mask = (df.index >= pd.Timestamp(start_date)) & (df.index <= pd.Timestamp(end_date)) 171 | df = df[mask] 172 | 173 | return df 174 | 175 | def _download_hk_data(self, symbol, start_date, end_date): 176 | """下载港股数据""" 177 | symbol_code = symbol.split('.')[0] # 去掉市场后缀 178 | 179 | try: 180 | df = ak.stock_hk_daily(symbol=symbol_code) 181 | 182 | # 重命名列以匹配backtrader要求 183 | df = df.rename(columns={ 184 | '日期': 'date', 185 | '开盘': 'open', 186 | '最高': 'high', 187 | '最低': 'low', 188 | '收盘': 'close', 189 | '成交量': 'volume' 190 | }) 191 | 192 | # 转换日期格式并设置为索引 193 | df['date'] = pd.to_datetime(df['date']) 194 | df = df.set_index('date') 195 | 196 | # 过滤日期范围 197 | df = df.loc[start_date:end_date] 198 | 199 | return df[['open', 'high', 'low', 'close', 'volume']] 200 | 201 | except Exception as e: 202 | logger.error(f"下载港股数据失败: {str(e)}") 203 | return pd.DataFrame() 204 | 205 | class PandasData(bt.feeds.PandasData): 206 | """自定义PandasData类,用于加载数据""" 207 | params = ( 208 | ('datetime', None), # 使用索引作为日期 209 | ('open', 'open'), 210 | ('high', 'high'), 211 | ('low', 'low'), 212 | ('close', 'close'), 213 | ('volume', 'volume'), 214 | ('openinterest', None), 215 | ('ts_code', None), # 股票代码 216 | ('fromdate', None), 217 | ('todate', None), 218 | ) 219 | 220 | def __init__(self, **kwargs): 221 | """初始化数据源""" 222 | # 从kwargs中获取ts_code和日期范围 223 | self.ts_code = kwargs.pop('ts_code', None) 224 | self.fromdate = kwargs.pop('fromdate', None) 225 | self.todate = kwargs.pop('todate', None) 226 | 227 | # 确保参数被正确设置 228 | if self.ts_code: 229 | self.params.ts_code = self.ts_code 230 | if self.fromdate: 231 | self.params.fromdate = self.fromdate 232 | if self.todate: 233 | self.params.todate = self.todate 234 | 235 | # 调用父类初始化 236 | super().__init__(**kwargs) 237 | 238 | # 设置数据源的名称 239 | if self.ts_code: 240 | self._name = self.ts_code -------------------------------------------------------------------------------- /tools/plot_general_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import plotly.express as px 6 | import plotly.graph_objects as go 7 | import streamlit as st 8 | from datetime import datetime 9 | import os 10 | 11 | def load_json_data(file_path=None, uploaded_file=None): 12 | """从文件或上传的数据加载JSON数据""" 13 | try: 14 | if uploaded_file is not None: 15 | data = json.load(uploaded_file) 16 | elif file_path is not None: 17 | with open(file_path, 'r', encoding='utf-8') as f: 18 | data = json.load(f) 19 | else: 20 | st.error("未提供数据源") 21 | return None 22 | return data 23 | except Exception as e: 24 | st.error(f"加载JSON文件时出错:{e}") 25 | return None 26 | 27 | def flatten_json(data, parent_key='', sep='_'): 28 | """展平嵌套的JSON数据结构""" 29 | items = [] 30 | for k, v in data.items(): 31 | new_key = f"{parent_key}{sep}{k}" if parent_key else k 32 | if isinstance(v, dict): 33 | items.extend(flatten_json(v, new_key, sep=sep).items()) 34 | elif isinstance(v, list): 35 | if len(v) > 0 and isinstance(v[0], dict): 36 | # 处理字典列表 37 | for i, item in enumerate(v): 38 | items.extend(flatten_json(item, f"{new_key}{sep}{i}", sep=sep).items()) 39 | else: 40 | items.append((new_key, str(v))) 41 | else: 42 | items.append((new_key, v)) 43 | return dict(items) 44 | 45 | def process_json_data(data): 46 | """处理JSON数据以准备可视化""" 47 | if isinstance(data, list): 48 | # 如果是列表,尝试将其转换为DataFrame 49 | if len(data) > 0 and isinstance(data[0], dict): 50 | df = pd.DataFrame(data) 51 | else: 52 | df = pd.DataFrame({"值": data}) 53 | elif isinstance(data, dict): 54 | # 如果是字典,展平它 55 | flat_data = flatten_json(data) 56 | df = pd.DataFrame([flat_data]) 57 | else: 58 | st.error("不支持的JSON数据格式") 59 | return None 60 | 61 | # 尝试转换日期列 62 | for col in df.columns: 63 | if df[col].dtype == object: 64 | try: 65 | df[col] = pd.to_datetime(df[col]) 66 | except: 67 | pass 68 | 69 | return df 70 | 71 | def get_numeric_columns(df): 72 | """获取数值类型的列""" 73 | return df.select_dtypes(include=[np.number]).columns.tolist() 74 | 75 | def get_datetime_columns(df): 76 | """获取日期时间类型的列""" 77 | return df.select_dtypes(include=['datetime64']).columns.tolist() 78 | 79 | def get_categorical_columns(df): 80 | """获取分类类型的列""" 81 | return df.select_dtypes(include=['object', 'category']).columns.tolist() 82 | 83 | def plot_time_series(df, x_col, y_col, title=None): 84 | """创建时间序列图表""" 85 | fig = px.line(df, x=x_col, y=y_col, 86 | title=title or f'{y_col}随{x_col}的变化', 87 | labels={x_col: x_col, y_col: y_col}) 88 | 89 | fig.update_layout( 90 | xaxis_title=x_col, 91 | yaxis_title=y_col, 92 | hovermode='x unified' 93 | ) 94 | 95 | return fig 96 | 97 | def plot_scatter(df, x_col, y_col, color_col=None, title=None): 98 | """创建散点图""" 99 | fig = px.scatter(df, x=x_col, y=y_col, color=color_col, 100 | title=title or f'{x_col}与{y_col}的关系', 101 | labels={x_col: x_col, y_col: y_col}) 102 | 103 | fig.update_layout( 104 | xaxis_title=x_col, 105 | yaxis_title=y_col, 106 | hovermode='closest' 107 | ) 108 | 109 | return fig 110 | 111 | def plot_bar(df, x_col, y_col, title=None): 112 | """创建柱状图""" 113 | fig = px.bar(df, x=x_col, y=y_col, 114 | title=title or f'{x_col}的{y_col}分布', 115 | labels={x_col: x_col, y_col: y_col}) 116 | 117 | fig.update_layout( 118 | xaxis_title=x_col, 119 | yaxis_title=y_col, 120 | bargap=0.2 121 | ) 122 | 123 | return fig 124 | 125 | def plot_histogram(df, col, title=None): 126 | """创建直方图""" 127 | fig = px.histogram(df, x=col, 128 | title=title or f'{col}的分布', 129 | labels={col: col}) 130 | 131 | fig.update_layout( 132 | xaxis_title=col, 133 | yaxis_title='频率', 134 | bargap=0.1 135 | ) 136 | 137 | return fig 138 | 139 | def plot_box(df, x_col, y_col, title=None): 140 | """创建箱线图""" 141 | fig = px.box(df, x=x_col, y=y_col, 142 | title=title or f'{y_col}按{x_col}的分布', 143 | labels={x_col: x_col, y_col: y_col}) 144 | 145 | fig.update_layout( 146 | xaxis_title=x_col, 147 | yaxis_title=y_col 148 | ) 149 | 150 | return fig 151 | 152 | def plot_heatmap(df, x_col, y_col, values_col, title=None): 153 | """创建热力图""" 154 | heatmap_data = df.pivot_table( 155 | values=values_col, 156 | index=y_col, 157 | columns=x_col, 158 | aggfunc='mean' 159 | ) 160 | 161 | fig = go.Figure(data=go.Heatmap( 162 | z=heatmap_data.values, 163 | x=heatmap_data.columns, 164 | y=heatmap_data.index, 165 | colorscale='Viridis' 166 | )) 167 | 168 | fig.update_layout( 169 | title=title or f'{x_col}和{y_col}的{values_col}热力图', 170 | xaxis_title=x_col, 171 | yaxis_title=y_col 172 | ) 173 | 174 | return fig 175 | 176 | def main(): 177 | st.title("通用JSON数据可视化工具") 178 | st.write("此应用程序可以可视化任何结构的JSON数据。") 179 | 180 | # 数据输入选项 181 | data_source = st.radio( 182 | "选择数据来源", 183 | ["从cache目录加载", "上传JSON文件"] 184 | ) 185 | 186 | data = None 187 | if data_source == "从cache目录加载": 188 | json_files = [f for f in os.listdir('../cache') if f.endswith('.json')] 189 | if not json_files: 190 | st.error("在'cache'目录中未找到JSON文件。") 191 | return 192 | 193 | selected_file = st.selectbox("选择要可视化的JSON文件", json_files) 194 | file_path = os.path.join('../cache', selected_file) 195 | data = load_json_data(file_path=file_path) 196 | else: 197 | uploaded_file = st.file_uploader("上传JSON文件", type=['json']) 198 | if uploaded_file is not None: 199 | data = load_json_data(uploaded_file=uploaded_file) 200 | 201 | if data is None: 202 | return 203 | 204 | # 处理数据 205 | df = process_json_data(data) 206 | if df is None: 207 | return 208 | 209 | # 显示基本信息 210 | st.subheader("数据概览") 211 | st.write(f"记录数量:{len(df)}") 212 | st.write(f"列数量:{len(df.columns)}") 213 | 214 | # 获取不同类型的列 215 | numeric_cols = get_numeric_columns(df) 216 | datetime_cols = get_datetime_columns(df) 217 | categorical_cols = get_categorical_columns(df) 218 | 219 | # 可视化选项 220 | viz_type = st.selectbox( 221 | "选择可视化类型", 222 | ["时间序列图", "散点图", "柱状图", "直方图", "箱线图", "热力图"] 223 | ) 224 | 225 | if viz_type == "时间序列图": 226 | if not datetime_cols: 227 | st.warning("数据中没有日期时间列,无法创建时间序列图。") 228 | return 229 | 230 | x_col = st.selectbox("选择时间列", datetime_cols) 231 | y_col = st.selectbox("选择数值列", numeric_cols) 232 | 233 | fig = plot_time_series(df, x_col, y_col) 234 | st.plotly_chart(fig) 235 | 236 | elif viz_type == "散点图": 237 | x_col = st.selectbox("选择X轴", numeric_cols) 238 | y_col = st.selectbox("选择Y轴", numeric_cols) 239 | color_col = st.selectbox("选择颜色分类(可选)", ["无"] + categorical_cols) 240 | 241 | fig = plot_scatter(df, x_col, y_col, 242 | color_col if color_col != "无" else None) 243 | st.plotly_chart(fig) 244 | 245 | elif viz_type == "柱状图": 246 | x_col = st.selectbox("选择分类列", categorical_cols) 247 | y_col = st.selectbox("选择数值列", numeric_cols) 248 | 249 | fig = plot_bar(df, x_col, y_col) 250 | st.plotly_chart(fig) 251 | 252 | elif viz_type == "直方图": 253 | col = st.selectbox("选择数值列", numeric_cols) 254 | 255 | fig = plot_histogram(df, col) 256 | st.plotly_chart(fig) 257 | 258 | elif viz_type == "箱线图": 259 | x_col = st.selectbox("选择分组列", categorical_cols) 260 | y_col = st.selectbox("选择数值列", numeric_cols) 261 | 262 | fig = plot_box(df, x_col, y_col) 263 | st.plotly_chart(fig) 264 | 265 | elif viz_type == "热力图": 266 | x_col = st.selectbox("选择X轴分类", categorical_cols) 267 | y_col = st.selectbox("选择Y轴分类", categorical_cols) 268 | values_col = st.selectbox("选择数值列", numeric_cols) 269 | 270 | fig = plot_heatmap(df, x_col, y_col, values_col) 271 | st.plotly_chart(fig) 272 | 273 | # 显示数据统计 274 | if st.checkbox("显示数据统计"): 275 | st.subheader("数据统计") 276 | st.write("数值列统计:") 277 | st.dataframe(df[numeric_cols].describe()) 278 | 279 | if categorical_cols: 280 | st.write("分类列统计:") 281 | for col in categorical_cols: 282 | st.write(f"\n{col}的唯一值数量:{df[col].nunique()}") 283 | st.dataframe(df[col].value_counts().head()) 284 | 285 | # 显示原始数据 286 | if st.checkbox("显示原始数据"): 287 | st.subheader("原始数据") 288 | st.dataframe(df) 289 | 290 | # 添加下载按钮 291 | csv = df.to_csv(index=False).encode('utf-8') 292 | st.download_button( 293 | label="下载数据为CSV", 294 | data=csv, 295 | file_name=f'data_{datetime.now().strftime("%Y%m%d")}.csv', 296 | mime='text/csv', 297 | ) 298 | 299 | if __name__ == "__main__": 300 | main() -------------------------------------------------------------------------------- /rl_model_finrl/meta/preprocessor/feature_engineer.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import talib 4 | from typing import Dict, List, Tuple, Union, Optional 5 | 6 | class FeatureEngineer: 7 | """ 8 | 金融特征工程类 9 | 10 | 用于计算和添加金融技术指标和特征 11 | """ 12 | 13 | def __init__(self, 14 | use_technical_indicators: bool = True, 15 | use_vix: bool = False, 16 | use_turbulence: bool = False, 17 | use_sentiment: bool = False): 18 | """ 19 | 初始化特征工程器 20 | 21 | 参数: 22 | use_technical_indicators: 是否使用技术指标 23 | use_vix: 是否使用波动率指数 24 | use_turbulence: 是否使用市场波动指标 25 | use_sentiment: 是否使用情绪指标 26 | """ 27 | self.use_technical_indicators = use_technical_indicators 28 | self.use_vix = use_vix 29 | self.use_turbulence = use_turbulence 30 | self.use_sentiment = use_sentiment 31 | 32 | def preprocess(self, df: pd.DataFrame) -> pd.DataFrame: 33 | """ 34 | 预处理数据并添加特征 35 | 36 | 参数: 37 | df: 原始数据DataFrame 38 | 39 | 返回: 40 | 处理后的DataFrame 41 | """ 42 | data = df.copy() 43 | 44 | # 添加技术指标 45 | if self.use_technical_indicators: 46 | data = self.add_technical_indicators(data) 47 | 48 | # 添加波动率指数 49 | if self.use_vix: 50 | data = self.add_vix(data) 51 | 52 | # 添加市场波动指标 53 | if self.use_turbulence: 54 | data = self.add_turbulence(data) 55 | 56 | # 添加情绪指标 57 | if self.use_sentiment: 58 | data = self.add_sentiment(data) 59 | 60 | # 填充NaN值 61 | data = self.fill_missing_values(data) 62 | 63 | return data 64 | 65 | def add_technical_indicators(self, data: pd.DataFrame) -> pd.DataFrame: 66 | """ 67 | 添加技术指标 68 | 69 | 参数: 70 | data: 价格数据DataFrame 71 | 72 | 返回: 73 | 添加技术指标后的DataFrame 74 | """ 75 | df = data.copy() 76 | 77 | # 确保列名标准化 78 | price_col = 'close' if 'close' in df.columns else 'Close' 79 | high_col = 'high' if 'high' in df.columns else 'High' 80 | low_col = 'low' if 'low' in df.columns else 'Low' 81 | volume_col = 'volume' if 'volume' in df.columns else 'Volume' 82 | 83 | # 使用TA-Lib计算技术指标 (如果可用) 84 | try: 85 | # MACD 86 | macd, macd_signal, macd_hist = talib.MACD(df[price_col]) 87 | df['macd'] = macd 88 | df['macd_signal'] = macd_signal 89 | df['macd_hist'] = macd_hist 90 | 91 | # RSI 92 | df['rsi_6'] = talib.RSI(df[price_col], timeperiod=6) 93 | df['rsi_14'] = talib.RSI(df[price_col], timeperiod=14) 94 | df['rsi_30'] = talib.RSI(df[price_col], timeperiod=30) 95 | 96 | # CCI 97 | df['cci'] = talib.CCI(df[high_col], df[low_col], df[price_col], timeperiod=14) 98 | 99 | # ADX 100 | df['adx'] = talib.ADX(df[high_col], df[low_col], df[price_col], timeperiod=14) 101 | 102 | # 布林带 103 | df['boll_upper'], df['boll_middle'], df['boll_lower'] = talib.BBANDS( 104 | df[price_col], timeperiod=20, nbdevup=2, nbdevdn=2, matype=0) 105 | 106 | # ATR 107 | df['atr'] = talib.ATR(df[high_col], df[low_col], df[price_col], timeperiod=14) 108 | 109 | # 移动平均线 110 | df['sma_5'] = talib.SMA(df[price_col], timeperiod=5) 111 | df['sma_10'] = talib.SMA(df[price_col], timeperiod=10) 112 | df['sma_20'] = talib.SMA(df[price_col], timeperiod=20) 113 | df['sma_60'] = talib.SMA(df[price_col], timeperiod=60) 114 | 115 | # WILLR 116 | df['willr'] = talib.WILLR(df[high_col], df[low_col], df[price_col], timeperiod=14) 117 | 118 | # ROC 119 | df['roc'] = talib.ROC(df[price_col], timeperiod=10) 120 | 121 | # OBV 122 | df['obv'] = talib.OBV(df[price_col], df[volume_col]) 123 | 124 | except (ImportError, AttributeError): 125 | # 如果TA-Lib不可用,使用Pandas计算 126 | # MACD 127 | df['ema12'] = df[price_col].ewm(span=12, adjust=False).mean() 128 | df['ema26'] = df[price_col].ewm(span=26, adjust=False).mean() 129 | df['macd'] = df['ema12'] - df['ema26'] 130 | df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean() 131 | df['macd_hist'] = df['macd'] - df['macd_signal'] 132 | 133 | # RSI 134 | delta = df[price_col].diff() 135 | gain = delta.copy() 136 | loss = delta.copy() 137 | gain[gain < 0] = 0 138 | loss[loss > 0] = 0 139 | loss = -loss 140 | 141 | # RSI 14 142 | avg_gain = gain.rolling(window=14).mean() 143 | avg_loss = loss.rolling(window=14).mean() 144 | rs = avg_gain / avg_loss 145 | df['rsi_14'] = 100 - (100 / (1 + rs)) 146 | 147 | # CCI 148 | df['tp'] = (df[high_col] + df[low_col] + df[price_col]) / 3 149 | df['tp_ma'] = df['tp'].rolling(window=20).mean() 150 | mean_dev = df['tp'].rolling(window=20).apply(lambda x: abs(x - x.mean()).mean()) 151 | df['cci'] = (df['tp'] - df['tp_ma']) / (0.015 * mean_dev) 152 | 153 | # 布林带 154 | df['sma_20'] = df[price_col].rolling(window=20).mean() 155 | df['boll_upper'] = df['sma_20'] + 2 * df[price_col].rolling(window=20).std() 156 | df['boll_lower'] = df['sma_20'] - 2 * df[price_col].rolling(window=20).std() 157 | 158 | # ATR 159 | df['tr1'] = df[high_col] - df[low_col] 160 | df['tr2'] = abs(df[high_col] - df[price_col].shift()) 161 | df['tr3'] = abs(df[low_col] - df[price_col].shift()) 162 | df['tr'] = df[['tr1', 'tr2', 'tr3']].max(axis=1) 163 | df['atr'] = df['tr'].rolling(window=14).mean() 164 | 165 | # 移动平均线 166 | df['sma_5'] = df[price_col].rolling(window=5).mean() 167 | df['sma_10'] = df[price_col].rolling(window=10).mean() 168 | df['sma_60'] = df[price_col].rolling(window=60).mean() 169 | 170 | # 清理临时列 171 | df = df.drop(['tr1', 'tr2', 'tr3', 'tr', 'tp', 'tp_ma', 'ema12', 'ema26'], 172 | axis=1, errors='ignore') 173 | 174 | # 涨跌幅 175 | df['daily_return'] = df[price_col].pct_change() 176 | 177 | # 波动率 178 | df['volatility'] = df['daily_return'].rolling(window=20).std() * np.sqrt(252) 179 | 180 | return df 181 | 182 | def add_vix(self, data: pd.DataFrame, vix_data: Optional[pd.DataFrame] = None) -> pd.DataFrame: 183 | """ 184 | 添加VIX波动率指数 185 | 186 | 参数: 187 | data: 原始数据DataFrame 188 | vix_data: VIX数据DataFrame(可选) 189 | 190 | 返回: 191 | 添加VIX后的DataFrame 192 | """ 193 | df = data.copy() 194 | 195 | if vix_data is not None: 196 | # 如果提供了VIX数据,合并到主数据中 197 | vix_data = vix_data.rename(columns={'close': 'vix'}) 198 | df = pd.merge(df, vix_data[['vix']], 199 | left_index=True, right_index=True, 200 | how='left') 201 | else: 202 | # 如果没有VIX数据,使用收益率的滚动波动率作为VIX代理 203 | price_col = 'close' if 'close' in df.columns else 'Close' 204 | returns = df[price_col].pct_change() 205 | df['vix'] = returns.rolling(window=20).std() * np.sqrt(252) * 100 206 | 207 | return df 208 | 209 | def add_turbulence(self, data: pd.DataFrame, window: int = 252) -> pd.DataFrame: 210 | """ 211 | 添加市场波动指标 212 | 213 | 参数: 214 | data: 原始数据DataFrame 215 | window: 计算窗口 216 | 217 | 返回: 218 | 添加波动指标后的DataFrame 219 | """ 220 | df = data.copy() 221 | 222 | # 计算收益率 223 | price_col = 'close' if 'close' in df.columns else 'Close' 224 | df['return'] = df[price_col].pct_change() 225 | 226 | # 计算波动指标 227 | df['turbulence'] = df['return'].rolling(window=window).apply( 228 | lambda x: np.sum(np.square(x - x.mean())) / len(x) 229 | ) 230 | 231 | return df 232 | 233 | def add_sentiment(self, data: pd.DataFrame, sentiment_data: Optional[pd.DataFrame] = None) -> pd.DataFrame: 234 | """ 235 | 添加市场情绪指标 236 | 237 | 参数: 238 | data: 原始数据DataFrame 239 | sentiment_data: 情绪数据DataFrame(可选) 240 | 241 | 返回: 242 | 添加情绪指标后的DataFrame 243 | """ 244 | df = data.copy() 245 | 246 | if sentiment_data is not None: 247 | # 如果提供了情绪数据,合并到主数据中 248 | df = pd.merge(df, sentiment_data, 249 | left_index=True, right_index=True, 250 | how='left') 251 | else: 252 | # 如果没有情绪数据,简单地添加一个基于技术指标的情绪代理 253 | if 'rsi_14' in df.columns: 254 | # RSI的简单情绪指标: RSI高表示乐观,RSI低表示悲观 255 | df['sentiment'] = (df['rsi_14'] - 50) / 50 # 归一化到[-1, 1] 256 | elif 'macd' in df.columns and 'macd_signal' in df.columns: 257 | # 基于MACD信号的情绪: MACD > 信号线表示乐观 258 | df['sentiment'] = np.where(df['macd'] > df['macd_signal'], 1, -1) 259 | else: 260 | # 如果没有其他指标,使用收益率的动量作为情绪代理 261 | price_col = 'close' if 'close' in df.columns else 'Close' 262 | returns = df[price_col].pct_change() 263 | df['sentiment'] = returns.rolling(window=5).mean() / returns.rolling(window=5).std() 264 | 265 | return df 266 | 267 | def fill_missing_values(self, data: pd.DataFrame) -> pd.DataFrame: 268 | """ 269 | 填充缺失值 270 | 271 | 参数: 272 | data: 包含缺失值的DataFrame 273 | 274 | 返回: 275 | 填充缺失值后的DataFrame 276 | """ 277 | df = data.copy() 278 | 279 | # 前向填充(填充交易日中间缺失的数据) 280 | df = df.fillna(method='ffill') 281 | 282 | # 后向填充(处理最早的数据) 283 | df = df.fillna(method='bfill') 284 | 285 | # 对于仍然缺失的值,用0填充 286 | df = df.fillna(0) 287 | 288 | return df -------------------------------------------------------------------------------- /rl_model_finrl/agents/stablebaseline3/dqn_agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from typing import Any, Dict, List, Tuple, Type, Union 5 | import torch 6 | import torch.nn as nn 7 | from stable_baselines3 import DQN 8 | from stable_baselines3.common.callbacks import BaseCallback 9 | from stable_baselines3.common.vec_env import DummyVecEnv 10 | from stable_baselines3.common.logger import configure 11 | import logging 12 | import matplotlib.pyplot as plt 13 | 14 | from src.strategies.rl_model_finrl.config import ( 15 | GAMMA, 16 | LEARNING_RATE, 17 | BATCH_SIZE, 18 | REPLAY_BUFFER_SIZE, 19 | NUM_EPISODES, 20 | TENSORBOARD_PATH 21 | ) 22 | 23 | class TensorboardCallback(BaseCallback): 24 | """用于记录训练过程中的指标的回调函数""" 25 | 26 | def __init__(self, verbose=0): 27 | super(TensorboardCallback, self).__init__(verbose) 28 | self.training_env = None 29 | 30 | def _on_training_start(self) -> None: 31 | """训练开始时设置训练环境引用""" 32 | self.training_env = self.model.get_env() 33 | self._log_freq = 1 # 记录频率 = 1个回合 34 | 35 | def _on_step(self) -> bool: 36 | """每步执行,并记录指标""" 37 | if self.n_calls % self._log_freq == 0: 38 | # 从训练环境中获取信息 39 | info = self.training_env.buf_infos[0] 40 | portfolio_value = info.get('portfolio_value', 0) 41 | step_return = info.get('step_return', 0) 42 | 43 | # 记录到TensorBoard 44 | self.logger.record('portfolio_value', portfolio_value) 45 | self.logger.record('step_return', step_return) 46 | 47 | return True 48 | 49 | class DQNAgent: 50 | """使用 stable-baselines3 的DQN智能体""" 51 | 52 | def __init__( 53 | self, 54 | env, 55 | model_name: str = "dqn_etf_trading", 56 | policy_type: str = "MlpPolicy", 57 | learning_rate: float = LEARNING_RATE, 58 | gamma: float = GAMMA, 59 | batch_size: int = BATCH_SIZE, 60 | buffer_size: int = REPLAY_BUFFER_SIZE, 61 | exploration_fraction: float = 0.2, 62 | exploration_initial_eps: float = 1.0, 63 | exploration_final_eps: float = 0.1, 64 | verbose: int = 1, 65 | tensorboard_log: str = TENSORBOARD_PATH, 66 | device: str = "auto" 67 | ): 68 | """ 69 | 初始化DQN智能体 70 | 71 | 参数: 72 | env: 交易环境 73 | model_name: 模型名称 74 | policy_type: 策略类型,默认MlpPolicy 75 | learning_rate: 学习率 76 | gamma: 折扣因子 77 | batch_size: 批处理大小 78 | buffer_size: 经验回放缓冲区大小 79 | exploration_fraction: 探索阶段占总训练步数的比例 80 | exploration_initial_eps: 初始探索率 81 | exploration_final_eps: 最终探索率 82 | verbose: 日志详细程度 83 | tensorboard_log: TensorBoard日志目录 84 | device: 运行设备,"auto"会自动选择GPU或CPU 85 | """ 86 | self.env = env 87 | self.model_name = model_name 88 | 89 | # 设置日志 90 | logging.basicConfig(level=logging.INFO) 91 | self.logger = logging.getLogger(__name__) 92 | 93 | # 检查环境是否被向量化 94 | if not isinstance(env, DummyVecEnv): 95 | self.env = DummyVecEnv([lambda: env]) 96 | 97 | # 创建日志目录 98 | if not os.path.exists(tensorboard_log): 99 | os.makedirs(tensorboard_log) 100 | 101 | # 设置TensorBoard日志 102 | log_dir = os.path.join(tensorboard_log, model_name) 103 | self.logger.info(f"TensorBoard日志目录: {log_dir}") 104 | 105 | # 创建DQN模型 106 | self.model = DQN( 107 | policy=policy_type, 108 | env=self.env, 109 | learning_rate=learning_rate, 110 | gamma=gamma, 111 | batch_size=batch_size, 112 | buffer_size=buffer_size, 113 | exploration_fraction=exploration_fraction, 114 | exploration_initial_eps=exploration_initial_eps, 115 | exploration_final_eps=exploration_final_eps, 116 | tensorboard_log=log_dir, 117 | verbose=verbose, 118 | device=device 119 | ) 120 | 121 | self.logger.info(f"初始化DQN智能体,模型名称: {model_name}") 122 | 123 | def train( 124 | self, 125 | total_timesteps: int = NUM_EPISODES * 100, # 假设每个回合约100步 126 | tb_log_name: str = "dqn_training", 127 | eval_freq: int = 10000, 128 | n_eval_episodes: int = 5, 129 | eval_env = None 130 | ): 131 | """ 132 | 训练DQN智能体 133 | 134 | 参数: 135 | total_timesteps: 总训练步数 136 | tb_log_name: TensorBoard日志名称 137 | eval_freq: 评估频率 138 | n_eval_episodes: 评估回合数 139 | eval_env: 评估环境,如果None则使用训练环境 140 | 141 | 返回: 142 | 训练后的模型 143 | """ 144 | # 设置评估环境 145 | if eval_env is None: 146 | eval_env = self.env 147 | 148 | # 创建TensorBoard回调 149 | callback = TensorboardCallback() 150 | 151 | # 开始训练 152 | self.logger.info(f"开始训练DQN模型,总步数: {total_timesteps}") 153 | self.model.learn( 154 | total_timesteps=total_timesteps, 155 | tb_log_name=tb_log_name, 156 | callback=callback 157 | ) 158 | 159 | # 保存模型 160 | model_path = os.path.join("models", f"{self.model_name}.zip") 161 | self.model.save(model_path) 162 | self.logger.info(f"模型已保存到 {model_path}") 163 | 164 | return self.model 165 | 166 | def predict(self, observation, state=None, deterministic=True): 167 | """ 168 | 使用模型进行预测 169 | 170 | 参数: 171 | observation: 观察状态 172 | state: 隐藏状态(如适用) 173 | deterministic: 是否确定性预测 174 | 175 | 返回: 176 | 预测的动作 177 | """ 178 | return self.model.predict(observation, state, deterministic) 179 | 180 | def load(self, path): 181 | """ 182 | 加载训练好的模型 183 | 184 | 参数: 185 | path: 模型路径 186 | 187 | 返回: 188 | 加载的模型 189 | """ 190 | self.model = DQN.load(path, env=self.env) 191 | self.logger.info(f"模型已加载: {path}") 192 | return self.model 193 | 194 | def save(self, path): 195 | """ 196 | 保存当前模型 197 | 198 | 参数: 199 | path: 保存路径 200 | """ 201 | self.model.save(path) 202 | self.logger.info(f"模型已保存: {path}") 203 | 204 | def test(self, test_env, num_episodes=1, render=False): 205 | """ 206 | 在测试环境中评估模型 207 | 208 | 参数: 209 | test_env: 测试环境 210 | num_episodes: 测试回合数 211 | render: 是否渲染环境 212 | 213 | 返回: 214 | 资产记忆和交易记忆 215 | """ 216 | self.logger.info(f"开始测试模型,回合数: {num_episodes}") 217 | 218 | # 如果测试环境不是向量化的,则向量化 219 | if not isinstance(test_env, DummyVecEnv): 220 | test_env = DummyVecEnv([lambda: test_env]) 221 | 222 | # 初始化统计信息 223 | total_rewards = [] 224 | portfolio_values = [] 225 | all_trades = [] 226 | 227 | # 多次测试 228 | for episode in range(num_episodes): 229 | # 初始化环境 230 | obs = test_env.reset() 231 | done = False 232 | total_reward = 0 233 | step = 0 234 | 235 | while not done: 236 | # 使用模型选择动作 237 | action, _states = self.model.predict(obs, deterministic=True) 238 | 239 | # 执行动作 240 | obs, reward, done, info = test_env.step(action) 241 | 242 | # 更新统计信息 243 | total_reward += reward[0] 244 | step += 1 245 | 246 | # 收集交易信息 247 | if 'trades' in info[0]: 248 | all_trades.extend(info[0]['trades']) 249 | 250 | # 收集投资组合价值 251 | if 'portfolio_value' in info[0]: 252 | portfolio_values.append(info[0]['portfolio_value']) 253 | 254 | # 如果需要,渲染环境 255 | if render: 256 | test_env.render() 257 | 258 | # 记录回合奖励 259 | total_rewards.append(total_reward) 260 | self.logger.info(f"回合 {episode+1}/{num_episodes}, 总奖励: {total_reward:.4f}") 261 | 262 | # 获取资产记忆 263 | asset_memory = test_env.envs[0].save_asset_memory() 264 | 265 | # 获取交易记忆 266 | action_memory = test_env.envs[0].save_action_memory() 267 | 268 | # 绘制测试结果 269 | if render: 270 | self._plot_test_results(asset_memory, action_memory) 271 | 272 | # 获取最终统计信息 273 | stats = test_env.envs[0].get_final_stats() 274 | self.logger.info(f"测试结果统计: {stats}") 275 | 276 | return asset_memory, action_memory 277 | 278 | def _plot_test_results(self, asset_memory, action_memory): 279 | """ 280 | 绘制测试结果 281 | 282 | 参数: 283 | asset_memory: 资产记忆DataFrame 284 | action_memory: 交易记忆DataFrame 285 | """ 286 | # 绘制投资组合价值 287 | plt.figure(figsize=(12, 6)) 288 | plt.plot(asset_memory.index, asset_memory['portfolio_value'], label='投资组合价值') 289 | 290 | # 添加买入卖出标记 291 | if not action_memory.empty: 292 | for idx, row in action_memory.iterrows(): 293 | date = row['date'] 294 | action = row['action'] 295 | if date in asset_memory.index: 296 | portfolio_value = asset_memory.loc[date, 'portfolio_value'] 297 | if action == 'buy': 298 | plt.scatter(date, portfolio_value, marker='^', color='green', s=100) 299 | elif action == 'sell': 300 | plt.scatter(date, portfolio_value, marker='v', color='red', s=100) 301 | 302 | plt.title('ETF交易投资组合价值') 303 | plt.xlabel('日期') 304 | plt.ylabel('价值') 305 | plt.grid(True) 306 | plt.legend() 307 | plt.tight_layout() 308 | plt.savefig(f"results/{self.model_name}_portfolio_value.png") 309 | plt.show() 310 | 311 | # 绘制收益率 312 | returns = asset_memory['portfolio_value'].pct_change().dropna() 313 | plt.figure(figsize=(12, 6)) 314 | plt.plot(returns.index, returns, label='每日收益率') 315 | plt.title('ETF交易每日收益率') 316 | plt.xlabel('日期') 317 | plt.ylabel('收益率') 318 | plt.grid(True) 319 | plt.legend() 320 | plt.tight_layout() 321 | plt.savefig(f"results/{self.model_name}_daily_returns.png") 322 | plt.show() -------------------------------------------------------------------------------- /ui/pages/sidebar.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from datetime import datetime, timedelta 3 | from src.strategies.strategy_factory import StrategyFactory 4 | 5 | def render_sidebar(): 6 | """渲染侧边栏""" 7 | with st.sidebar: 8 | st.header("策略参数设置") 9 | 10 | # 策略选择 11 | st.subheader("策略选择") 12 | strategy_name = st.selectbox( 13 | "选择策略", 14 | options=StrategyFactory.get_strategy_names(), 15 | index=0 16 | ) 17 | 18 | # 数据源设置 19 | st.subheader("数据源配置") 20 | if strategy_name in ["市场情绪策略", "双均线对冲策略"]: 21 | tushare_token = st.text_input("Tushare Token(必填)", value="", type="password", help="市场情绪策略需要使用Tushare数据源") 22 | if not tushare_token: 23 | st.error(f"{strategy_name}必须提供Tushare Token") 24 | else: 25 | tushare_token = st.text_input("Tushare Token(可选,如不填则使用akshare)", type="password") 26 | 27 | # ETF轮动策略的ETF选择 28 | if strategy_name == "ETF轮动策略": 29 | etf_list = [ 30 | '510050.SH', # 上证50ETF 31 | '510300.SH', # 沪深300ETF 32 | '510500.SH', # 中证500ETF 33 | '159915.SZ', # 创业板ETF 34 | '512880.SH', # 证券ETF 35 | '512690.SH', # 酒ETF 36 | '512660.SH', # 军工ETF 37 | '512010.SH', # 医药ETF 38 | '512800.SH', # 银行ETF 39 | '512170.SH', # 医疗ETF 40 | '512760.SH', # 芯片ETF 41 | '159928.SZ', # 消费ETF 42 | '512480.SH', # 半导体ETF 43 | '512980.SH', # 科技ETF 44 | '512580.SH', # 环保ETF 45 | '512400.SH', # 有色金属ETF 46 | '512200.SH', # 地产ETF 47 | '516160.SH', # 新能源车ETF 48 | '159939.SZ', # 信息技术ETF 49 | '512600.SH', # 主要消费ETF 50 | '512070.SH', # 证券保险ETF 51 | '159869.SZ', # 新基建ETF 52 | '515030.SH', # 新能源ETF 53 | '515790.SH', # 光伏ETF 54 | '513050.SH', # 中概互联ETF 55 | ] 56 | selected_etfs = st.multiselect( 57 | "选择ETF", 58 | options=etf_list, 59 | default=etf_list[:5], # 默认选择前5个ETF 60 | help="选择要轮动的ETF,建议选择3-5个相关性较低的ETF" 61 | ) 62 | if not selected_etfs: 63 | st.error("请至少选择一个ETF") 64 | return None 65 | else: 66 | if strategy_name == "双均线对冲策略": 67 | symbol = st.text_input("ETF代码", value="159985.SZ", help="使用M豆粕主力合约作为被对冲标的") 68 | else: 69 | symbol = st.text_input("ETF代码", value="510050.SH", help="支持:A股(000001.SZ)、ETF(510300.SH)、港股(00700.HK)") 70 | 71 | # 移动平均线参数(对双均线策略和双均线对冲策略显示) 72 | if strategy_name in ["双均线策略", "双均线对冲策略"]: 73 | st.subheader("均线参数") 74 | col1, col2 = st.columns(2) 75 | with col1: 76 | fast_period = st.number_input("快线周期", value=5, min_value=1) 77 | with col2: 78 | slow_period = st.number_input("慢线周期", value=10, min_value=1) 79 | 80 | # 止损止盈参数 81 | st.subheader("止损止盈参数") 82 | col1, col2 = st.columns(2) 83 | with col1: 84 | atr_profit_multiplier = st.number_input("ATR止盈倍数", value=3.0, min_value=0.5, max_value=5.0, step=0.1, 85 | help="ATR止盈倍数,值越大止盈距离越远") 86 | with col2: 87 | atr_loss_multiplier = st.number_input("ATR止损倍数", value=3.0, min_value=0.5, max_value=5.0, step=0.1, 88 | help="ATR止损倍数,值越大止损距离越远") 89 | with col1: 90 | atr_period = st.number_input("ATR周期", value=14, min_value=5, max_value=30, step=1) 91 | 92 | # 其他技术参数 93 | col1, col2 = st.columns(2) 94 | with col1: 95 | enable_trailing_stop = st.checkbox("启用追踪止损", value=False, 96 | help="追踪止损会在价格创新高后设置止损价,可以锁定更多利润") 97 | with col2: 98 | enable_death_cross = st.checkbox("启用死叉卖出", value=False, 99 | help="启用后,当快线下穿慢线时将卖出持仓") 100 | 101 | # 双均线对冲策略特有参数 102 | if strategy_name == "双均线对冲策略": 103 | st.subheader("对冲参数") 104 | col1, col2 = st.columns(2) 105 | with col1: 106 | hedge_contract_size = st.number_input("对冲合约手数", value=10, min_value=1, max_value=100, step=1, 107 | help="豆粕期货对冲合约手数,建议10-20手") 108 | with col2: 109 | hedge_profit_multiplier = st.number_input("对冲盈利倍数", value=1.0, min_value=0.5, max_value=5.0, step=0.1, 110 | help="对冲目标盈利 = 原始损失 × (1 + 盈利倍数)") 111 | 112 | # 对冲模块开关 113 | st.subheader("对冲模块开关") 114 | hedge_mode = st.radio( 115 | "选择对冲模式", 116 | ["无对冲", "ATR止损开空对冲", "MA交叉死叉做空对冲", "MACD零轴上方死叉做空对冲", "同步做多对冲"], 117 | help="选择要启用的对冲模式,只能选择一种" 118 | ) 119 | 120 | # 根据选择设置对应的开关 121 | enable_hedging = (hedge_mode == "ATR止损开空对冲") 122 | enable_ma_cross_hedge = (hedge_mode == "MA交叉死叉做空对冲") 123 | enable_macd_hedge = (hedge_mode == "MACD零轴上方死叉做空对冲") 124 | enable_sync_long_hedge = (hedge_mode == "同步做多对冲") 125 | 126 | # ETF轮动策略参数 127 | if strategy_name == "ETF轮动策略": 128 | st.subheader("轮动参数") 129 | col1, col2 = st.columns(2) 130 | with col1: 131 | momentum_short = st.number_input("短期动量周期", value=10, min_value=1) 132 | with col2: 133 | momentum_long = st.number_input("长期动量周期", value=60, min_value=1) 134 | col1, col2 = st.columns(2) 135 | with col1: 136 | rebalance_interval = st.number_input("调仓间隔(天)", value=20, min_value=1) 137 | with col2: 138 | num_positions = st.number_input("持仓数量", value=3, min_value=1, max_value=10) 139 | 140 | # 止盈止损参数 141 | st.subheader("止盈止损参数") 142 | col1, col2 = st.columns(2) 143 | with col1: 144 | profit_target1 = st.number_input("第一止盈目标(%)", value=5.0, min_value=1.0, max_value=100.0, step=1.0) 145 | with col2: 146 | profit_target2 = st.number_input("第二止盈目标(%)", value=10.0, min_value=1.0, max_value=100.0, step=1.0) 147 | 148 | # 市场状态参数 149 | st.subheader("市场状态参数") 150 | col1, col2 = st.columns(2) 151 | with col1: 152 | market_trend_threshold = st.number_input("市场趋势阈值(%)", value=-5.0, min_value=-20.0, max_value=0.0, step=1.0) 153 | with col2: 154 | vix_threshold = st.number_input("波动率阈值(%)", value=3.0, min_value=1.0, max_value=10.0, step=0.5) 155 | 156 | # 动量衰减参数 157 | momentum_decay = st.slider("动量衰减阈值(%)", 10.0, 50.0, 30.0, 1.0) 158 | atr_multiplier = st.slider("ATR倍数", 1.0, 3.0, 2.0, 0.1) 159 | 160 | # 风险控制参数 161 | st.subheader("风险控制") 162 | trail_percent = st.slider("追踪止损比例(%)", 0.5, 5.0, 2.0, 0.1) 163 | risk_ratio = st.slider("单次交易风险比例(%)", 0.5, 5.0, 2.0, 0.1) 164 | max_drawdown = st.slider("最大回撤限制(%)", 5.0, 30.0, 15.0, 1.0) 165 | 166 | # 回测区间 167 | st.subheader("回测区间") 168 | start_date = st.date_input( 169 | "开始日期", 170 | datetime.now() - timedelta(days=365) 171 | ) 172 | end_date = st.date_input("结束日期", datetime.now()) 173 | 174 | # 资金设置 175 | st.subheader("资金设置") 176 | initial_cash = st.number_input("初始资金", value=100000.0, min_value=1000.0) 177 | commission = st.number_input("佣金费率(双向收取,默认万分之2.5)", value=0.00025, min_value=0.0, max_value=0.01, format="%.5f", 178 | help="双向收取,例如:0.00025表示万分之2.5") 179 | 180 | # 返回所有参数 181 | params = { 182 | 'strategy_name': strategy_name, 183 | 'tushare_token': tushare_token, 184 | 'selected_etfs': selected_etfs if strategy_name == "ETF轮动策略" else None, 185 | 'symbol': symbol if strategy_name != "ETF轮动策略" else None, 186 | 'fast_period': fast_period if strategy_name in ["双均线策略", "双均线对冲策略"] else None, 187 | 'slow_period': slow_period if strategy_name in ["双均线策略", "双均线对冲策略"] else None, 188 | 'atr_profit_multiplier': atr_profit_multiplier if strategy_name in ["双均线策略", "双均线对冲策略", "ETF轮动策略"] else None, 189 | 'atr_loss_multiplier': atr_loss_multiplier if strategy_name in ["双均线策略", "双均线对冲策略", "ETF轮动策略"] else None, 190 | 'atr_period': atr_period if strategy_name in ["双均线策略", "双均线对冲策略"] else None, 191 | 'enable_trailing_stop': enable_trailing_stop if strategy_name in ["双均线策略", "双均线对冲策略"] else None, 192 | 'enable_death_cross': enable_death_cross if strategy_name in ["双均线策略", "双均线对冲策略"] else None, 193 | 'hedge_contract_size': hedge_contract_size if strategy_name == "双均线对冲策略" else None, 194 | 'hedge_profit_multiplier': hedge_profit_multiplier if strategy_name == "双均线对冲策略" else None, 195 | 'enable_hedging': enable_hedging if strategy_name == "双均线对冲策略" else None, 196 | 'enable_ma_cross_hedge': enable_ma_cross_hedge if strategy_name == "双均线对冲策略" else None, 197 | 'enable_macd_hedge': enable_macd_hedge if strategy_name == "双均线对冲策略" else None, 198 | 'enable_sync_long_hedge': enable_sync_long_hedge if strategy_name == "双均线对冲策略" else None, 199 | 'momentum_short': momentum_short if strategy_name == "ETF轮动策略" else None, 200 | 'momentum_long': momentum_long if strategy_name == "ETF轮动策略" else None, 201 | 'rebalance_interval': rebalance_interval if strategy_name == "ETF轮动策略" else None, 202 | 'num_positions': num_positions if strategy_name == "ETF轮动策略" else None, 203 | 'profit_target1': profit_target1 / 100 if strategy_name == "ETF轮动策略" else None, 204 | 'profit_target2': profit_target2 / 100 if strategy_name == "ETF轮动策略" else None, 205 | 'market_trend_threshold': market_trend_threshold / 100 if strategy_name == "ETF轮动策略" else None, 206 | 'vix_threshold': vix_threshold / 100 if strategy_name == "ETF轮动策略" else None, 207 | 'momentum_decay': momentum_decay / 100 if strategy_name == "ETF轮动策略" else None, 208 | 'trail_percent': trail_percent, 209 | 'risk_ratio': risk_ratio, 210 | 'max_drawdown': max_drawdown, 211 | 'start_date': start_date, 212 | 'end_date': end_date, 213 | 'initial_cash': initial_cash, 214 | 'commission': commission 215 | } 216 | 217 | return params -------------------------------------------------------------------------------- /src/strategies/dual_ma_hedging/ma_cross_hedge.py: -------------------------------------------------------------------------------- 1 | import backtrader as bt 2 | from loguru import logger 3 | 4 | class MACrossHedge: 5 | def __init__(self, strategy): 6 | self.strategy = strategy 7 | self.enabled = False 8 | self.hedge_position = None 9 | self.hedge_entry_price = None 10 | self.hedge_order = None 11 | self.hedge_contract_code = None 12 | self.hedge_entry_date = None # 添加入场日期记录 13 | 14 | def enable(self): 15 | """启用MA交叉对冲功能""" 16 | self.enabled = True 17 | logger.info("启用MA交叉对冲功能") 18 | 19 | def disable(self): 20 | """禁用MA交叉对冲功能""" 21 | self.enabled = False 22 | logger.info("禁用MA交叉对冲功能") 23 | 24 | def on_death_cross(self): 25 | """在快线下穿慢线时开空期货""" 26 | if not self.enabled: 27 | return 28 | 29 | if self.hedge_position is not None or self.hedge_order is not None: 30 | logger.info("已有对冲仓位或对冲订单,不再开仓") 31 | return 32 | 33 | try: 34 | # 计算ATR止盈止损价格 35 | current_atr = self.strategy.atr[0] 36 | stop_loss = self.strategy.data1.close[0] + (current_atr * self.strategy.p.atr_loss_multiplier) 37 | take_profit = self.strategy.data1.close[0] - (current_atr * self.strategy.p.atr_profit_multiplier) 38 | 39 | # 开空豆粕期货 40 | hedge_size = self.strategy.p.hedge_contract_size 41 | self.hedge_order = self.strategy.sell(data=self.strategy.data1, size=hedge_size) 42 | 43 | if self.hedge_order: 44 | # 记录入场价格和合约代码 45 | self.hedge_entry_price = self.strategy.data1.close[0] 46 | # 确保获取正确的合约代码 47 | current_date = self.strategy.data1.datetime.datetime(0) 48 | if hasattr(self.strategy.data1, 'contract_mapping') and current_date in self.strategy.data1.contract_mapping: 49 | self.hedge_contract_code = self.strategy.data1.contract_mapping[current_date] 50 | else: 51 | # 如果无法获取映射,使用数据名称 52 | self.hedge_contract_code = self.strategy.data1._name 53 | self.hedge_entry_date = self.strategy.data.datetime.date(0) # 记录入场日期 54 | 55 | # 计算保证金 56 | margin = self.hedge_entry_price * hedge_size * self.strategy.p.future_contract_multiplier * 0.10 57 | 58 | # 从期货账户扣除保证金 59 | pre_cash = self.strategy.future_cash 60 | self.strategy.future_cash -= margin 61 | 62 | logger.info(f"开仓扣除保证金 - 之前: {pre_cash:.2f}, 扣除: {margin:.2f}, 之后: {self.strategy.future_cash:.2f}") 63 | 64 | # 记录交易信息 65 | self.hedge_order.info = { 66 | 'reason': f"MA死叉开空 - 快线: {self.strategy.fast_ma[0]:.2f}, 慢线: {self.strategy.slow_ma[0]:.2f}", 67 | 'margin': margin, 68 | 'future_cash': self.strategy.future_cash, 69 | 'execution_date': self.hedge_entry_date, 70 | 'total_value': self.strategy.future_cash, 71 | 'position_value': abs(margin), 72 | 'position_ratio': margin / self.strategy.future_cash if self.strategy.future_cash > 0 else 0, 73 | 'etf_code': self.hedge_contract_code, 74 | 'pnl': 0, 75 | 'return': 0, 76 | 'stop_loss': stop_loss, 77 | 'take_profit': take_profit, 78 | 'avg_cost': self.hedge_entry_price 79 | } 80 | 81 | logger.info(f"MA死叉开空 - 合约: {self.hedge_contract_code}, 价格: {self.hedge_entry_price:.2f}, 数量: {hedge_size}手, " 82 | f"止损价: {stop_loss:.2f}, 止盈价: {take_profit:.2f}") 83 | 84 | except Exception as e: 85 | logger.error(f"MA死叉开空失败: {str(e)}") 86 | 87 | def check_exit(self): 88 | """检查是否需要平仓""" 89 | if not self.enabled or not self.hedge_position or self.hedge_order is not None: 90 | return 91 | 92 | current_price = self.strategy.data1.close[0] 93 | current_atr = self.strategy.atr[0] 94 | 95 | # 计算ATR止盈止损价格 96 | stop_loss = self.hedge_entry_price + (current_atr * self.strategy.p.atr_loss_multiplier) 97 | take_profit = self.hedge_entry_price - (current_atr * self.strategy.p.atr_profit_multiplier) 98 | 99 | # 获取当前日期 100 | current_date = self.strategy.data.datetime.date(0) 101 | 102 | # 检查是否触发止盈止损 103 | if current_price >= stop_loss or current_price <= take_profit: 104 | contract_code = self.hedge_contract_code 105 | self.hedge_order = self.strategy.close(data=self.strategy.data1) 106 | reason = "触发止盈" if current_price <= take_profit else "触发止损" 107 | logger.info(f"MA死叉对冲{reason} - 日期: {current_date}, 合约: {contract_code}, 当前价格: {current_price:.2f}, {reason}价: {take_profit if current_price <= take_profit else stop_loss:.2f}") 108 | 109 | def on_order_completed(self, order): 110 | """处理订单完成事件""" 111 | if not self.enabled: 112 | return 113 | 114 | if order.status in [order.Completed]: 115 | if order.isbuy(): # 买入豆粕期货(平空) 116 | # 确保有对应的入场价格 117 | if self.hedge_entry_price is None or self.hedge_contract_code is None: 118 | logger.error("平仓时找不到入场价格或合约代码,跳过处理") 119 | return 120 | 121 | # 记录平仓前的合约信息,用于日志 122 | entry_price = self.hedge_entry_price 123 | contract_code = self.hedge_contract_code 124 | entry_date = self.hedge_entry_date 125 | 126 | # 记录交易日期和价格 127 | trade_date = self.strategy.data.datetime.date(0) 128 | trade_price = order.executed.price 129 | 130 | # 先重置持仓相关变量,防止重复平仓 131 | self.hedge_position = None 132 | self.hedge_order = None 133 | self.hedge_entry_price = None 134 | self.hedge_contract_code = None 135 | self.hedge_entry_date = None 136 | self.hedge_target_profit = None 137 | 138 | # 计算对冲盈亏(空仓:入场价 - 平仓价) 139 | hedge_profit = (entry_price - trade_price) * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier 140 | 141 | # 减去开平仓手续费 142 | total_fee = self.strategy.p.hedge_fee * self.strategy.p.hedge_contract_size * 2 143 | net_profit = hedge_profit - total_fee 144 | 145 | # 归还保证金并添加盈亏到期货账户 146 | margin_returned = entry_price * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier * 0.10 147 | 148 | # 记录更新前的资金 149 | pre_cash = self.strategy.future_cash 150 | 151 | # 更新期货账户资金 152 | self.strategy.future_cash += (margin_returned + net_profit) 153 | 154 | # 记录资金变动 155 | logger.info(f"平仓资金变动 - 之前: {pre_cash:.2f}, 返还保证金: {margin_returned:.2f}, 盈亏: {net_profit:.2f}, 之后: {self.strategy.future_cash:.2f}") 156 | 157 | # 更新期货账户最高净值 158 | self.strategy.future_highest_value = max(self.strategy.future_highest_value, self.strategy.future_cash) 159 | 160 | # 计算期货账户回撤 161 | future_drawdown = (self.strategy.future_highest_value - self.strategy.future_cash) / self.strategy.future_highest_value if self.strategy.future_highest_value > 0 else 0 162 | 163 | # 计算收益率 164 | return_pct = (hedge_profit / (entry_price * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier)) * 100 165 | 166 | # 更新订单信息 167 | order.info.update({ 168 | 'pnl': hedge_profit, 169 | 'return': return_pct, 170 | 'total_value': self.strategy.future_cash, 171 | 'position_value': 0, # 平仓后持仓价值为0 172 | 'avg_cost': entry_price, 173 | 'etf_code': contract_code, # 确保使用原始合约代码 174 | 'execution_date': trade_date, # 确保使用当前交易日期 175 | 'reason': f"MA死叉对冲平仓 - 合约: {contract_code}, 入场日期: {entry_date}, 入场价: {entry_price:.2f}, 平仓价: {trade_price:.2f}, 收益率: {return_pct:.2f}%" 176 | }) 177 | 178 | logger.info(f"MA死叉对冲平仓 - 日期: {trade_date}, 合约: {contract_code}, 价格: {trade_price:.2f}, 盈利: {hedge_profit:.2f}, " 179 | f"手续费: {total_fee:.2f}, 净盈利: {net_profit:.2f}, " 180 | f"期货账户余额: {self.strategy.future_cash:.2f}, 回撤: {future_drawdown:.2%}, " 181 | f"收益率: {return_pct:.2f}%") 182 | 183 | else: # 卖出豆粕期货(开空) 184 | # 记录对冲持仓 185 | self.hedge_position = order 186 | 187 | # 更新订单信息 188 | order.info.update({ 189 | 'total_value': self.strategy.future_cash, 190 | 'position_value': abs(order.info['margin']), 191 | 'avg_cost': order.executed.price, 192 | 'etf_code': self.hedge_contract_code # 确保合约代码正确 193 | }) 194 | 195 | elif order.status in [order.Canceled, order.Margin, order.Rejected]: 196 | self.hedge_order = None 197 | logger.warning(f'MA死叉对冲订单失败 - 状态: {order.getstatusname()}') 198 | 199 | def on_strategy_stop(self): 200 | """策略结束时平掉所有期货仓位""" 201 | if not self.enabled or not self.hedge_position: 202 | return 203 | 204 | if self.hedge_order is None: # 确保没有未完成订单 205 | # 获取当前持仓的合约代码 206 | current_contract = self.hedge_contract_code 207 | if not current_contract: 208 | logger.error("策略结束时找不到期货合约代码,无法平仓") 209 | return 210 | 211 | # 获取当前日期 212 | current_date = self.strategy.data.datetime.date(0) 213 | 214 | self.hedge_order = self.strategy.close(data=self.strategy.data1) 215 | logger.info(f"策略结束,平掉期货仓位 - 日期: {current_date}, 合约: {current_contract}, 入场日期: {self.hedge_entry_date}, 入场价: {self.hedge_entry_price:.2f}") 216 | 217 | # 更新订单信息 218 | if self.hedge_order: 219 | self.hedge_order.info.update({ 220 | 'etf_code': current_contract, # 确保使用正确的合约代码 221 | 'execution_date': current_date, # 使用当前日期 222 | 'reason': f"策略结束平仓 - 日期: {current_date}, 合约: {current_contract}, 入场日期: {self.hedge_entry_date}, 入场价: {self.hedge_entry_price:.2f}" 223 | }) -------------------------------------------------------------------------------- /rl_model_finrl/meta/data_processors/akshare_processor.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import akshare as ak 4 | import os 5 | from typing import List, Dict, Optional, Tuple 6 | import logging 7 | from datetime import datetime 8 | 9 | from src.strategies.rl_model_finrl.config import ( 10 | DATA_SAVE_PATH, 11 | TRAIN_START_DATE, 12 | TRAIN_END_DATE, 13 | TEST_START_DATE, 14 | TEST_END_DATE, 15 | TICKER_LIST, 16 | ) 17 | 18 | class AKShareProcessor: 19 | """ 20 | AKShareProcessor 类负责从AKShare获取A股ETF相关数据 21 | 作为补充数据源使用 22 | """ 23 | 24 | def __init__(self): 25 | """初始化AKShare数据处理器""" 26 | # 确保数据保存目录存在 27 | self.data_path = os.path.join(DATA_SAVE_PATH, "akshare_data") 28 | os.makedirs(self.data_path, exist_ok=True) 29 | 30 | # 设置日志 31 | logging.basicConfig(level=logging.INFO) 32 | self.logger = logging.getLogger(__name__) 33 | 34 | def download_etf_fund_info(self, ticker_list: List[str]) -> Dict[str, pd.DataFrame]: 35 | """ 36 | 下载ETF基金基本信息 37 | 38 | 参数: 39 | ticker_list: ETF代码列表 40 | 41 | 返回: 42 | 包含ETF基本信息的字典 43 | """ 44 | etf_info_dict = {} 45 | 46 | for ticker in ticker_list: 47 | try: 48 | # 处理代码格式 49 | ticker_code = ticker.split('.')[0] 50 | 51 | self.logger.info(f"获取ETF {ticker} 的基本信息...") 52 | 53 | # 获取ETF基本信息 54 | etf_info = ak.fund_etf_basic_info(symbol=ticker_code) 55 | 56 | if etf_info is not None and not etf_info.empty: 57 | # 保存到字典 58 | etf_info_dict[ticker] = etf_info 59 | 60 | # 保存到CSV 61 | file_name = f"{ticker}_info.csv" 62 | file_path = os.path.join(self.data_path, file_name) 63 | etf_info.to_csv(file_path, index=False) 64 | self.logger.info(f"ETF基本信息已保存到 {file_path}") 65 | else: 66 | self.logger.warning(f"无法获取ETF {ticker} 的基本信息") 67 | 68 | except Exception as e: 69 | self.logger.error(f"获取ETF {ticker} 的基本信息时出错: {str(e)}") 70 | 71 | return etf_info_dict 72 | 73 | def download_etf_daily_data( 74 | self, 75 | ticker_list: List[str], 76 | start_date: str, 77 | end_date: str 78 | ) -> Dict[str, pd.DataFrame]: 79 | """ 80 | 从AKShare下载ETF日线数据 81 | 82 | 参数: 83 | ticker_list: ETF代码列表 84 | start_date: 开始日期,格式'YYYY-MM-DD' 85 | end_date: 结束日期,格式'YYYY-MM-DD' 86 | 87 | 返回: 88 | 字典,包含每个ETF代码及其对应的数据框 89 | """ 90 | etf_data_dict = {} 91 | 92 | for ticker in ticker_list: 93 | try: 94 | # 处理代码格式 95 | ticker_code = ticker.split('.')[0] 96 | 97 | self.logger.info(f"从AKShare获取ETF {ticker} 的日线数据...") 98 | 99 | # 获取ETF历史数据 100 | etf_data = ak.fund_etf_hist_em( 101 | symbol=ticker_code, 102 | period="daily", 103 | start_date=start_date, 104 | end_date=end_date, 105 | adjust="qfq" 106 | ) 107 | 108 | if etf_data is not None and not etf_data.empty: 109 | # 重命名列以匹配tushare格式 110 | etf_data = etf_data.rename(columns={ 111 | '日期': 'date', 112 | '开盘': 'open', 113 | '收盘': 'close', 114 | '最高': 'high', 115 | '最低': 'low', 116 | '成交量': 'volume', 117 | '成交额': 'amount', 118 | '振幅': 'amplitude', 119 | '涨跌幅': 'change_pct', 120 | '涨跌额': 'change', 121 | '换手率': 'turnover' 122 | }) 123 | 124 | # 转换日期格式并设置为索引 125 | etf_data['date'] = pd.to_datetime(etf_data['date']) 126 | etf_data = etf_data.sort_values('date') 127 | etf_data = etf_data.set_index('date') 128 | 129 | # 添加代码列 130 | etf_data['tic'] = ticker 131 | 132 | # 保存到字典 133 | etf_data_dict[ticker] = etf_data 134 | 135 | # 保存到CSV 136 | file_name = f"{ticker}_daily_akshare_{start_date.replace('-', '')}_{end_date.replace('-', '')}.csv" 137 | file_path = os.path.join(self.data_path, file_name) 138 | etf_data.to_csv(file_path) 139 | self.logger.info(f"ETF日线数据已保存到 {file_path}") 140 | else: 141 | self.logger.warning(f"无法获取ETF {ticker} 的日线数据") 142 | 143 | except Exception as e: 144 | self.logger.error(f"获取ETF {ticker} 的日线数据时出错: {str(e)}") 145 | 146 | return etf_data_dict 147 | 148 | def download_etf_fund_flow( 149 | self, 150 | ticker_list: List[str], 151 | start_date: str, 152 | end_date: str 153 | ) -> Dict[str, pd.DataFrame]: 154 | """ 155 | 获取ETF资金流向数据 156 | 157 | 参数: 158 | ticker_list: ETF代码列表 159 | start_date: 开始日期,格式'YYYY-MM-DD' 160 | end_date: 结束日期,格式'YYYY-MM-DD' 161 | 162 | 返回: 163 | 包含资金流向的字典 164 | """ 165 | fund_flow_dict = {} 166 | 167 | for ticker in ticker_list: 168 | try: 169 | # 处理代码格式 170 | ticker_code = ticker.split('.')[0] 171 | 172 | self.logger.info(f"获取ETF {ticker} 的资金流向数据...") 173 | 174 | # 获取ETF资金流向 175 | try: 176 | # 尝试获取资金流向数据,这个可能需要调整,因为AKShare的API可能有变化 177 | fund_flow = ak.fund_etf_fund_flow_rank() 178 | 179 | # 筛选特定ETF的数据 180 | if ticker_code in fund_flow['代码'].values: 181 | fund_flow_single = fund_flow[fund_flow['代码'] == ticker_code] 182 | fund_flow_dict[ticker] = fund_flow_single 183 | 184 | # 保存到CSV 185 | file_name = f"{ticker}_fund_flow.csv" 186 | file_path = os.path.join(self.data_path, file_name) 187 | fund_flow_single.to_csv(file_path, index=False) 188 | self.logger.info(f"ETF资金流向数据已保存到 {file_path}") 189 | else: 190 | self.logger.warning(f"未找到ETF {ticker} 的资金流向数据") 191 | except Exception as inner_e: 192 | self.logger.error(f"获取ETF资金流向时出现内部错误: {str(inner_e)}") 193 | 194 | except Exception as e: 195 | self.logger.error(f"获取ETF {ticker} 的资金流向数据时出错: {str(e)}") 196 | 197 | return fund_flow_dict 198 | 199 | def download_etf_holdings(self, ticker_list: List[str]) -> Dict[str, pd.DataFrame]: 200 | """ 201 | 获取ETF持仓数据 202 | 203 | 参数: 204 | ticker_list: ETF代码列表 205 | 206 | 返回: 207 | 包含ETF持仓的字典 208 | """ 209 | holdings_dict = {} 210 | 211 | for ticker in ticker_list: 212 | try: 213 | # 处理代码格式 214 | ticker_code = ticker.split('.')[0] 215 | 216 | self.logger.info(f"获取ETF {ticker} 的持仓数据...") 217 | 218 | # 获取ETF持仓 219 | try: 220 | # 尝试获取ETF持仓数据,这个API需要根据AKShare的最新文档调整 221 | holdings = ak.fund_etf_spot_deal_em() 222 | 223 | # 筛选特定ETF的数据 224 | if ticker_code in holdings['代码'].values: 225 | holdings_single = holdings[holdings['代码'] == ticker_code] 226 | holdings_dict[ticker] = holdings_single 227 | 228 | # 保存到CSV 229 | file_name = f"{ticker}_holdings.csv" 230 | file_path = os.path.join(self.data_path, file_name) 231 | holdings_single.to_csv(file_path, index=False) 232 | self.logger.info(f"ETF持仓数据已保存到 {file_path}") 233 | else: 234 | self.logger.warning(f"未找到ETF {ticker} 的持仓数据") 235 | except Exception as inner_e: 236 | self.logger.error(f"获取ETF持仓时出现内部错误: {str(inner_e)}") 237 | 238 | except Exception as e: 239 | self.logger.error(f"获取ETF {ticker} 的持仓数据时出错: {str(e)}") 240 | 241 | return holdings_dict 242 | 243 | def download_market_sentiment(self) -> pd.DataFrame: 244 | """ 245 | 获取市场情绪数据(如恐慌指数等) 246 | 247 | 返回: 248 | 包含市场情绪的DataFrame 249 | """ 250 | try: 251 | self.logger.info("获取A股市场情绪数据...") 252 | 253 | # 获取A股情绪指标 254 | # 注意:需要根据AKShare的最新文档确认获取情绪数据的API 255 | # 这里用股市情绪指标作为示例 256 | sentiment_data = ak.stock_market_emotion_baidu() 257 | 258 | if sentiment_data is not None and not sentiment_data.empty: 259 | # 保存到CSV 260 | file_name = f"market_sentiment_{datetime.now().strftime('%Y%m%d')}.csv" 261 | file_path = os.path.join(self.data_path, file_name) 262 | sentiment_data.to_csv(file_path, index=False) 263 | self.logger.info(f"市场情绪数据已保存到 {file_path}") 264 | 265 | return sentiment_data 266 | else: 267 | self.logger.warning("无法获取市场情绪数据") 268 | return pd.DataFrame() 269 | 270 | except Exception as e: 271 | self.logger.error(f"获取市场情绪数据时出错: {str(e)}") 272 | return pd.DataFrame() 273 | 274 | def prepare_supplementary_data(self) -> Dict[str, pd.DataFrame]: 275 | """ 276 | 准备所有补充数据 277 | 278 | 返回: 279 | 包含所有补充数据的字典 280 | """ 281 | supplementary_data = {} 282 | 283 | # 1. 获取ETF基本信息 284 | etf_info = self.download_etf_fund_info(TICKER_LIST) 285 | supplementary_data['etf_info'] = etf_info 286 | 287 | # 2. 获取ETF资金流向 288 | fund_flow = self.download_etf_fund_flow( 289 | TICKER_LIST, 290 | TRAIN_START_DATE, 291 | TEST_END_DATE 292 | ) 293 | supplementary_data['fund_flow'] = fund_flow 294 | 295 | # 3. 获取ETF持仓 296 | holdings = self.download_etf_holdings(TICKER_LIST) 297 | supplementary_data['holdings'] = holdings 298 | 299 | # 4. 获取市场情绪 300 | sentiment = self.download_market_sentiment() 301 | supplementary_data['market_sentiment'] = sentiment 302 | 303 | return supplementary_data -------------------------------------------------------------------------------- /src/strategies/dual_ma_hedging/sync_long_hedge.py: -------------------------------------------------------------------------------- 1 | import backtrader as bt 2 | from loguru import logger 3 | 4 | class SyncLongHedge: 5 | def __init__(self, strategy): 6 | self.strategy = strategy 7 | self.enabled = False 8 | self.hedge_position = None 9 | self.hedge_entry_price = None 10 | self.hedge_order = None 11 | self.hedge_contract_code = None 12 | self.hedge_entry_date = None # 添加入场日期记录 13 | 14 | def enable(self): 15 | """启用同步做多对冲功能""" 16 | self.enabled = True 17 | logger.info("启用同步做多对冲功能") 18 | 19 | def disable(self): 20 | """禁用同步做多对冲功能""" 21 | self.enabled = False 22 | logger.info("禁用同步做多对冲功能") 23 | 24 | def on_golden_cross(self): 25 | """在ETF金叉时同步开多期货""" 26 | if not self.enabled: 27 | return 28 | 29 | if self.hedge_position is not None or self.hedge_order is not None: 30 | logger.info("已有对冲仓位或对冲订单,不再开仓") 31 | return 32 | 33 | try: 34 | # 计算ATR止盈止损价格 35 | current_atr = self.strategy.atr[0] 36 | stop_loss = self.strategy.data1.close[0] - (current_atr * self.strategy.p.atr_loss_multiplier) 37 | take_profit = self.strategy.data1.close[0] + (current_atr * self.strategy.p.atr_profit_multiplier) 38 | 39 | # 开多豆粕期货 40 | hedge_size = self.strategy.p.hedge_contract_size 41 | self.hedge_order = self.strategy.buy(data=self.strategy.data1, size=hedge_size) 42 | 43 | if self.hedge_order: 44 | # 记录入场价格和合约代码 45 | self.hedge_entry_price = self.strategy.data1.close[0] 46 | # 确保获取正确的合约代码 47 | current_date = self.strategy.data1.datetime.datetime(0) 48 | if hasattr(self.strategy.data1, 'contract_mapping') and current_date in self.strategy.data1.contract_mapping: 49 | self.hedge_contract_code = self.strategy.data1.contract_mapping[current_date] 50 | else: 51 | # 如果无法获取映射,使用数据名称 52 | self.hedge_contract_code = self.strategy.data1._name 53 | self.hedge_entry_date = self.strategy.data.datetime.date(0) # 记录入场日期 54 | 55 | # 计算保证金 56 | margin = self.hedge_entry_price * hedge_size * self.strategy.p.future_contract_multiplier * 0.10 57 | 58 | # 从期货账户扣除保证金 59 | pre_cash = self.strategy.future_cash 60 | self.strategy.future_cash -= margin 61 | 62 | logger.info(f"开仓扣除保证金 - 之前: {pre_cash:.2f}, 扣除: {margin:.2f}, 之后: {self.strategy.future_cash:.2f}") 63 | 64 | # 记录交易信息 65 | self.hedge_order.info = { 66 | 'reason': f"ETF金叉同步开多 - 快线: {self.strategy.fast_ma[0]:.2f}, 慢线: {self.strategy.slow_ma[0]:.2f}", 67 | 'margin': margin, 68 | 'future_cash': self.strategy.future_cash, 69 | 'execution_date': self.hedge_entry_date, 70 | 'total_value': self.strategy.future_cash, 71 | 'position_value': abs(margin), 72 | 'position_ratio': margin / self.strategy.future_cash if self.strategy.future_cash > 0 else 0, 73 | 'etf_code': self.hedge_contract_code, 74 | 'pnl': 0, 75 | 'return': 0, 76 | 'stop_loss': stop_loss, 77 | 'take_profit': take_profit, 78 | 'avg_cost': self.hedge_entry_price 79 | } 80 | 81 | logger.info(f"ETF金叉同步开多 - 合约: {self.hedge_contract_code}, 价格: {self.hedge_entry_price:.2f}, 数量: {hedge_size}手, " 82 | f"止损价: {stop_loss:.2f}, 止盈价: {take_profit:.2f}") 83 | 84 | except Exception as e: 85 | logger.error(f"ETF金叉同步开多失败: {str(e)}") 86 | 87 | def on_etf_close(self): 88 | """在ETF平仓时同步平多仓""" 89 | if not self.enabled or not self.hedge_position: 90 | return 91 | 92 | if self.hedge_order is None: # 确保没有未完成订单 93 | self.hedge_order = self.strategy.close(data=self.strategy.data1) 94 | logger.info("ETF平仓,同步平多仓") 95 | 96 | def check_exit(self): 97 | """检查是否需要平仓""" 98 | if not self.enabled or not self.hedge_position or self.hedge_order is not None: 99 | return 100 | 101 | current_price = self.strategy.data1.close[0] 102 | current_atr = self.strategy.atr[0] 103 | 104 | # 计算ATR止盈止损价格 105 | stop_loss = self.hedge_entry_price - (current_atr * self.strategy.p.atr_loss_multiplier) 106 | take_profit = self.hedge_entry_price + (current_atr * self.strategy.p.atr_profit_multiplier) 107 | 108 | # 获取当前日期 109 | current_date = self.strategy.data.datetime.date(0) 110 | 111 | # 检查是否触发止盈止损 112 | if current_price <= stop_loss or current_price >= take_profit: 113 | contract_code = self.hedge_contract_code 114 | self.hedge_order = self.strategy.close(data=self.strategy.data1) 115 | reason = "触发止盈" if current_price >= take_profit else "触发止损" 116 | logger.info(f"同步做多对冲{reason} - 日期: {current_date}, 合约: {contract_code}, 当前价格: {current_price:.2f}, {reason}价: {take_profit if current_price >= take_profit else stop_loss:.2f}") 117 | 118 | def on_order_completed(self, order): 119 | """处理订单完成事件""" 120 | if not self.enabled: 121 | return 122 | 123 | if order.status in [order.Completed]: 124 | if order.issell(): # 卖出豆粕期货(平多) 125 | # 确保有对应的入场价格 126 | if self.hedge_entry_price is None or self.hedge_contract_code is None: 127 | logger.error("平仓时找不到入场价格或合约代码,跳过处理") 128 | return 129 | 130 | # 记录平仓前的合约信息,用于日志 131 | entry_price = self.hedge_entry_price 132 | contract_code = self.hedge_contract_code 133 | entry_date = self.hedge_entry_date 134 | 135 | # 记录交易日期和价格 136 | trade_date = self.strategy.data.datetime.date(0) 137 | trade_price = order.executed.price 138 | 139 | # 先重置持仓相关变量,防止重复平仓 140 | self.hedge_position = None 141 | self.hedge_order = None 142 | self.hedge_entry_price = None 143 | self.hedge_contract_code = None 144 | self.hedge_entry_date = None 145 | self.hedge_target_profit = None 146 | 147 | # 计算对冲盈亏 148 | hedge_profit = (trade_price - entry_price) * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier 149 | 150 | # 减去开平仓手续费 151 | total_fee = self.strategy.p.hedge_fee * self.strategy.p.hedge_contract_size * 2 152 | net_profit = hedge_profit - total_fee 153 | 154 | # 归还保证金并添加盈亏到期货账户 155 | margin_returned = entry_price * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier * 0.10 156 | 157 | # 记录更新前的资金 158 | pre_cash = self.strategy.future_cash 159 | 160 | # 更新期货账户资金 161 | self.strategy.future_cash += (margin_returned + net_profit) 162 | 163 | # 记录资金变动 164 | logger.info(f"平仓资金变动 - 之前: {pre_cash:.2f}, 返还保证金: {margin_returned:.2f}, 盈亏: {net_profit:.2f}, 之后: {self.strategy.future_cash:.2f}") 165 | 166 | # 更新期货账户最高净值 167 | self.strategy.future_highest_value = max(self.strategy.future_highest_value, self.strategy.future_cash) 168 | 169 | # 计算期货账户回撤 170 | future_drawdown = (self.strategy.future_highest_value - self.strategy.future_cash) / self.strategy.future_highest_value if self.strategy.future_highest_value > 0 else 0 171 | 172 | # 计算收益率 173 | return_pct = (hedge_profit / (entry_price * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier)) * 100 174 | 175 | # 更新订单信息 176 | order.info.update({ 177 | 'pnl': hedge_profit, 178 | 'return': return_pct, 179 | 'total_value': self.strategy.future_cash, 180 | 'position_value': 0, # 平仓后持仓价值为0 181 | 'avg_cost': entry_price, 182 | 'etf_code': contract_code, # 确保使用原始合约代码 183 | 'execution_date': trade_date, # 确保使用当前交易日期 184 | 'reason': f"同步做多对冲平仓 - 合约: {contract_code}, 入场日期: {entry_date}, 入场价: {entry_price:.2f}, 平仓价: {trade_price:.2f}, 收益率: {return_pct:.2f}%" 185 | }) 186 | 187 | logger.info(f"同步做多对冲平仓 - 日期: {trade_date}, 合约: {contract_code}, 价格: {trade_price:.2f}, 盈利: {hedge_profit:.2f}, " 188 | f"手续费: {total_fee:.2f}, 净盈利: {net_profit:.2f}, " 189 | f"期货账户余额: {self.strategy.future_cash:.2f}, 回撤: {future_drawdown:.2%}, " 190 | f"收益率: {return_pct:.2f}%") 191 | 192 | else: # 买入豆粕期货(开多) 193 | # 记录对冲持仓 194 | self.hedge_position = order 195 | 196 | # 更新订单信息 197 | order.info.update({ 198 | 'total_value': self.strategy.future_cash, 199 | 'position_value': abs(order.info['margin']), 200 | 'avg_cost': order.executed.price, 201 | 'etf_code': self.hedge_contract_code # 确保合约代码正确 202 | }) 203 | 204 | elif order.status in [order.Canceled, order.Margin, order.Rejected]: 205 | self.hedge_order = None 206 | logger.warning(f'同步做多对冲订单失败 - 状态: {order.getstatusname()}') 207 | 208 | def on_strategy_stop(self): 209 | """策略结束时平掉所有期货仓位""" 210 | if not self.enabled or not self.hedge_position: 211 | return 212 | 213 | if self.hedge_order is None: # 确保没有未完成订单 214 | # 获取当前持仓的合约代码 215 | current_contract = self.hedge_contract_code 216 | if not current_contract: 217 | logger.error("策略结束时找不到期货合约代码,无法平仓") 218 | return 219 | 220 | # 获取当前日期 221 | current_date = self.strategy.data.datetime.date(0) 222 | 223 | self.hedge_order = self.strategy.close(data=self.strategy.data1) 224 | logger.info(f"策略结束,平掉期货仓位 - 日期: {current_date}, 合约: {current_contract}, 入场日期: {self.hedge_entry_date}, 入场价: {self.hedge_entry_price:.2f}") 225 | 226 | # 更新订单信息 227 | if self.hedge_order: 228 | self.hedge_order.info.update({ 229 | 'etf_code': current_contract, # 确保使用正确的合约代码 230 | 'execution_date': current_date, # 使用当前日期 231 | 'reason': f"策略结束平仓 - 日期: {current_date}, 合约: {current_contract}, 入场日期: {self.hedge_entry_date}, 入场价: {self.hedge_entry_price:.2f}" 232 | }) -------------------------------------------------------------------------------- /rl_model_finrl/applications/stock_trading/run_strategy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import datetime 6 | from typing import List, Dict, Any, Optional, Union, Tuple 7 | import logging 8 | 9 | from stable_baselines3 import PPO, A2C, DDPG, TD3, SAC, DQN 10 | from stable_baselines3.common.callbacks import BaseCallback, EvalCallback 11 | 12 | from src.strategies.rl_model_finrl.meta.data_processors import DataProcessor 13 | from src.strategies.rl_model_finrl.applications.stock_trading.etf_env import ETFTradingEnv 14 | from src.strategies.rl_model_finrl.agents.stablebaseline3 import DQNAgent 15 | from src.strategies.rl_model_finrl.agents.elegantrl import PPOAgent 16 | 17 | from src.strategies.rl_model_finrl.config import ( 18 | INITIAL_AMOUNT, 19 | TRANSACTION_COST_PCT, 20 | TURBULENCE_THRESHOLD, 21 | TRAIN_START_DATE, 22 | TRAIN_END_DATE, 23 | TEST_START_DATE, 24 | TEST_END_DATE, 25 | TECHNICAL_INDICATORS_LIST, 26 | TICKER_LIST, 27 | MODEL_SAVE_PATH, 28 | NUM_EPISODES, 29 | TENSORBOARD_LOG_PATH 30 | ) 31 | 32 | 33 | class TensorboardCallback(BaseCallback): 34 | """ 35 | 自定义回调,记录训练过程中的奖励和组合价值 36 | """ 37 | def __init__(self, verbose=0): 38 | super(TensorboardCallback, self).__init__(verbose) 39 | self.rewards = [] 40 | 41 | def _on_step(self) -> bool: 42 | # 获取最近的奖励 43 | if len(self.model.env.envs[0].rewards_memory) > 0: 44 | latest_reward = self.model.env.envs[0].rewards_memory[-1] 45 | self.rewards.append(latest_reward) 46 | # 记录到tensorboard 47 | self.logger.record('train/reward', latest_reward) 48 | self.logger.record('train/portfolio_value', self.model.env.envs[0].total_asset) 49 | 50 | return True 51 | 52 | 53 | def prepare_etf_data( 54 | processor: DataProcessor, 55 | ticker_list: List[str], 56 | start_date: str, 57 | end_date: str, 58 | data_source: str = "tushare" 59 | ) -> pd.DataFrame: 60 | """ 61 | 准备ETF交易数据 62 | 63 | 参数: 64 | processor: 数据处理器 65 | ticker_list: ETF代码列表 66 | start_date: 开始日期 67 | end_date: 结束日期 68 | data_source: 数据源 69 | 70 | 返回: 71 | 处理后的DataFrame 72 | """ 73 | # 如果没有提供ETF列表,使用默认列表 74 | if not ticker_list: 75 | ticker_list = TICKER_LIST 76 | 77 | # 获取原始数据 78 | df = processor.download_data( 79 | ticker_list=ticker_list, 80 | start_date=start_date, 81 | end_date=end_date, 82 | data_source=data_source 83 | ) 84 | 85 | # 处理数据 86 | df = processor.clean_data(df) 87 | 88 | # 添加技术指标 89 | df = processor.add_technical_indicators(df, TECHNICAL_INDICATORS_LIST) 90 | 91 | # 添加波动性指标 92 | df = processor.add_turbulence(df) 93 | 94 | # 填充缺失值 95 | df = df.fillna(method='ffill').fillna(method='bfill') 96 | 97 | # 确保日期是正确的格式 98 | df.index = pd.to_datetime(df.index) 99 | 100 | return df 101 | 102 | 103 | def run_etf_strategy( 104 | start_date: str = TRAIN_START_DATE, 105 | end_date: str = TRAIN_END_DATE, 106 | ticker_list: List[str] = None, 107 | data_source: str = "tushare", 108 | time_interval: str = "1d", 109 | technical_indicator_list: List[str] = TECHNICAL_INDICATORS_LIST, 110 | initial_amount: float = INITIAL_AMOUNT, 111 | transaction_cost_pct: float = TRANSACTION_COST_PCT, 112 | agent: str = "ppo", 113 | model_name: str = None, 114 | turbulence_threshold: float = TURBULENCE_THRESHOLD, 115 | if_store_model: bool = True, 116 | num_episodes: int = NUM_EPISODES, 117 | **kwargs 118 | ) -> Any: 119 | """ 120 | 运行ETF交易策略 121 | 122 | 参数: 123 | start_date: 训练开始日期 124 | end_date: 训练结束日期 125 | ticker_list: ETF代码列表 126 | data_source: 数据源 127 | time_interval: 时间间隔 128 | technical_indicator_list: 技术指标列表 129 | initial_amount: 初始资金 130 | transaction_cost_pct: 交易成本百分比 131 | agent: 智能体类型 (ppo, dqn, a2c等) 132 | model_name: 模型名称 133 | turbulence_threshold: 市场波动阈值 134 | if_store_model: 是否存储模型 135 | num_episodes: 训练回合数 136 | **kwargs: 传递给agent的其他参数 137 | 138 | 返回: 139 | 训练好的模型 140 | """ 141 | # 设置日志 142 | logging.basicConfig(level=logging.INFO) 143 | logger = logging.getLogger(__name__) 144 | 145 | # 如果未指定模型名称,则自动生成 146 | if model_name is None: 147 | model_name = f"{agent}_etf_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}" 148 | 149 | logger.info(f"开始运行ETF交易策略, 智能体: {agent}, 模型名称: {model_name}") 150 | 151 | # 初始化数据处理器 152 | processor = DataProcessor(data_source=data_source, time_interval=time_interval) 153 | 154 | # 准备数据 155 | logger.info(f"准备ETF数据: {start_date} 到 {end_date}") 156 | df = prepare_etf_data( 157 | processor=processor, 158 | ticker_list=ticker_list or TICKER_LIST, 159 | start_date=start_date, 160 | end_date=end_date, 161 | data_source=data_source 162 | ) 163 | 164 | # 创建ETF交易环境 165 | stock_dimension = len(df['tic'].unique()) 166 | env = ETFTradingEnv( 167 | df=df, 168 | stock_dim=stock_dimension, 169 | hmax=100, 170 | initial_amount=initial_amount, 171 | transaction_cost_pct=transaction_cost_pct, 172 | reward_scaling=1.0, 173 | tech_indicator_list=technical_indicator_list, 174 | turbulence_threshold=turbulence_threshold, 175 | day_trade=True, 176 | reward_type='sharpe', 177 | cash_penalty_proportion=0.1, 178 | ) 179 | 180 | # 创建向量化环境 181 | env_vec = env.get_sb_env() 182 | 183 | # 创建智能体 184 | if agent.lower() == "ppo_elegant": 185 | # ElegantRL PPO实现 186 | model = PPOAgent( 187 | env=env_vec, 188 | model_name=model_name, 189 | learning_rate=kwargs.get('learning_rate', 0.0003), 190 | gamma=kwargs.get('gamma', 0.99), 191 | tensorboard_log=TENSORBOARD_LOG_PATH 192 | ) 193 | elif agent.lower() == "dqn": 194 | # Stable-Baselines3 DQN实现 195 | model = DQNAgent( 196 | env=env_vec, 197 | model_name=model_name, 198 | learning_rate=kwargs.get('learning_rate', 0.0001), 199 | gamma=kwargs.get('gamma', 0.99), 200 | tensorboard_log=TENSORBOARD_LOG_PATH 201 | ) 202 | elif agent.lower() == "ppo": 203 | # Stable-Baselines3 PPO实现 204 | model = PPO( 205 | "MlpPolicy", 206 | env_vec, 207 | verbose=1, 208 | learning_rate=kwargs.get('learning_rate', 0.0003), 209 | gamma=kwargs.get('gamma', 0.99), 210 | tensorboard_log=TENSORBOARD_LOG_PATH 211 | ) 212 | elif agent.lower() == "a2c": 213 | # Stable-Baselines3 A2C实现 214 | model = A2C( 215 | "MlpPolicy", 216 | env_vec, 217 | verbose=1, 218 | learning_rate=kwargs.get('learning_rate', 0.0007), 219 | gamma=kwargs.get('gamma', 0.99), 220 | tensorboard_log=TENSORBOARD_LOG_PATH 221 | ) 222 | elif agent.lower() == "ddpg": 223 | # Stable-Baselines3 DDPG实现 224 | model = DDPG( 225 | "MlpPolicy", 226 | env_vec, 227 | verbose=1, 228 | learning_rate=kwargs.get('learning_rate', 0.0001), 229 | gamma=kwargs.get('gamma', 0.99), 230 | tensorboard_log=TENSORBOARD_LOG_PATH 231 | ) 232 | elif agent.lower() == "sac": 233 | # Stable-Baselines3 SAC实现 234 | model = SAC( 235 | "MlpPolicy", 236 | env_vec, 237 | verbose=1, 238 | learning_rate=kwargs.get('learning_rate', 0.0003), 239 | gamma=kwargs.get('gamma', 0.99), 240 | tensorboard_log=TENSORBOARD_LOG_PATH 241 | ) 242 | else: 243 | raise ValueError(f"不支持的智能体类型: {agent}") 244 | 245 | # 创建回调 246 | callback = TensorboardCallback() 247 | 248 | # 创建模型保存路径 249 | if not os.path.exists(MODEL_SAVE_PATH): 250 | os.makedirs(MODEL_SAVE_PATH) 251 | 252 | # 训练模型 253 | logger.info(f"开始训练模型: {model_name}") 254 | 255 | # 根据不同模型类型采用不同的训练方法 256 | if agent.lower() in ["ppo_elegant"]: 257 | # ElegantRL训练方法 258 | model.train( 259 | total_timesteps=num_episodes * 100, 260 | eval_freq=1000, 261 | n_eval_episodes=5, 262 | log_interval=100 263 | ) 264 | else: 265 | # Stable-Baselines3训练方法 266 | model.learn( 267 | total_timesteps=num_episodes * 100, 268 | callback=callback, 269 | tb_log_name=model_name 270 | ) 271 | 272 | # 保存模型 273 | if if_store_model: 274 | if agent.lower() in ["ppo_elegant"]: 275 | # ElegantRL模型保存 276 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.pt") 277 | model.save(model_path) 278 | else: 279 | # Stable-Baselines3模型保存 280 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.zip") 281 | model.save(model_path) 282 | 283 | logger.info(f"模型已保存至: {model_path}") 284 | 285 | # 绘制训练曲线 286 | if len(callback.rewards) > 0: 287 | plt.figure(figsize=(12, 6)) 288 | plt.plot(callback.rewards) 289 | plt.title(f'{agent.upper()} ETF交易策略 - 训练奖励') 290 | plt.xlabel('步数') 291 | plt.ylabel('奖励') 292 | plt.grid(True) 293 | plt.savefig(f"results/{model_name}_train_rewards.png") 294 | plt.close() 295 | 296 | # 返回训练好的模型 297 | return model 298 | 299 | 300 | def load_etf_model( 301 | model_name: str, 302 | agent: str = "ppo", 303 | env = None 304 | ) -> Any: 305 | """ 306 | 加载已训练的ETF交易模型 307 | 308 | 参数: 309 | model_name: 模型名称 310 | agent: 智能体类型 311 | env: 环境实例(可选) 312 | 313 | 返回: 314 | 加载的模型 315 | """ 316 | if agent.lower() == "ppo_elegant": 317 | # ElegantRL模型加载 318 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.pt") 319 | model = PPOAgent(env=env) 320 | model.load(model_path) 321 | elif agent.lower() == "dqn": 322 | # Stable-Baselines3 DQN模型加载 323 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.zip") 324 | model = DQN.load(model_path, env=env) 325 | elif agent.lower() == "ppo": 326 | # Stable-Baselines3 PPO模型加载 327 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.zip") 328 | model = PPO.load(model_path, env=env) 329 | elif agent.lower() == "a2c": 330 | # Stable-Baselines3 A2C模型加载 331 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.zip") 332 | model = A2C.load(model_path, env=env) 333 | elif agent.lower() == "ddpg": 334 | # Stable-Baselines3 DDPG模型加载 335 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.zip") 336 | model = DDPG.load(model_path, env=env) 337 | elif agent.lower() == "sac": 338 | # Stable-Baselines3 SAC模型加载 339 | model_path = os.path.join(MODEL_SAVE_PATH, f"{model_name}.zip") 340 | model = SAC.load(model_path, env=env) 341 | else: 342 | raise ValueError(f"不支持的智能体类型: {agent}") 343 | 344 | return model 345 | 346 | 347 | if __name__ == "__main__": 348 | # 示例用法 349 | model = run_etf_strategy( 350 | start_date=TRAIN_START_DATE, 351 | end_date=TRAIN_END_DATE, 352 | ticker_list=TICKER_LIST, 353 | agent="ppo", 354 | model_name="ppo_etf_demo", 355 | num_episodes=NUM_EPISODES 356 | ) -------------------------------------------------------------------------------- /src/strategies/dual_ma_hedging/macd_hedge.py: -------------------------------------------------------------------------------- 1 | import backtrader as bt 2 | from loguru import logger 3 | 4 | class MACDHedge: 5 | def __init__(self, strategy): 6 | self.strategy = strategy 7 | self.enabled = False 8 | self.hedge_position = None 9 | self.hedge_entry_price = None 10 | self.hedge_order = None 11 | self.hedge_contract_code = None 12 | self.hedge_entry_date = None # 添加入场日期记录 13 | 14 | # 初始化MACD指标 15 | self.macd = bt.indicators.MACD( 16 | self.strategy.data.close, 17 | period_me1=12, 18 | period_me2=26, 19 | period_signal=9 20 | ) 21 | 22 | def enable(self): 23 | """启用MACD对冲功能""" 24 | self.enabled = True 25 | logger.info("启用MACD对冲功能") 26 | 27 | def disable(self): 28 | """禁用MACD对冲功能""" 29 | self.enabled = False 30 | logger.info("禁用MACD对冲功能") 31 | 32 | def on_death_cross(self): 33 | """在MACD零轴上方死叉时开空期货""" 34 | if not self.enabled: 35 | return 36 | 37 | if self.hedge_position is not None or self.hedge_order is not None: 38 | logger.info("已有对冲仓位或对冲订单,不再开仓") 39 | return 40 | 41 | # 检查MACD是否在零轴上方 42 | if self.macd.macd[0] <= 0: 43 | logger.info("MACD不在零轴上方,不开仓") 44 | return 45 | 46 | # 检查是否形成死叉(MACD线下穿信号线) 47 | if self.macd.macd[0] > self.macd.signal[0] or self.macd.macd[-1] <= self.macd.signal[-1]: 48 | logger.info("未形成死叉,不开仓") 49 | return 50 | 51 | try: 52 | # 计算ATR止盈止损价格 53 | current_atr = self.strategy.atr[0] 54 | stop_loss = self.strategy.data1.close[0] + (current_atr * self.strategy.p.atr_loss_multiplier) 55 | take_profit = self.strategy.data1.close[0] - (current_atr * self.strategy.p.atr_profit_multiplier) 56 | 57 | # 开空豆粕期货 58 | hedge_size = self.strategy.p.hedge_contract_size 59 | 60 | # 检查期货账户资金是否足够 61 | future_price = self.strategy.data1.close[0] 62 | margin_requirement = future_price * hedge_size * self.strategy.p.future_contract_multiplier * self.strategy.p.m_margin_ratio 63 | 64 | if margin_requirement > self.strategy.future_cash: 65 | logger.warning(f"期货账户资金不足,需要{margin_requirement:.2f},当前可用{self.strategy.future_cash:.2f}") 66 | # 根据可用资金调整手数 67 | adjusted_size = int(self.strategy.future_cash / (future_price * self.strategy.p.future_contract_multiplier * self.strategy.p.m_margin_ratio)) 68 | if adjusted_size < 1: 69 | logger.error("期货账户资金不足以开仓一手") 70 | return 71 | hedge_size = adjusted_size 72 | logger.info(f"已调整对冲手数为: {hedge_size}") 73 | 74 | self.hedge_order = self.strategy.sell(data=self.strategy.data1, size=hedge_size) 75 | 76 | if self.hedge_order: 77 | # 记录入场价格和合约代码 78 | self.hedge_entry_price = self.strategy.data1.close[0] 79 | # 确保获取正确的合约代码 80 | current_date = self.strategy.data1.datetime.datetime(0) 81 | if hasattr(self.strategy.data1, 'contract_mapping') and current_date in self.strategy.data1.contract_mapping: 82 | self.hedge_contract_code = self.strategy.data1.contract_mapping[current_date] 83 | else: 84 | # 如果无法获取映射,使用数据名称 85 | self.hedge_contract_code = self.strategy.data1._name 86 | self.hedge_entry_date = self.strategy.data.datetime.date(0) # 记录入场日期 87 | 88 | # 计算保证金 89 | margin = self.hedge_entry_price * hedge_size * self.strategy.p.future_contract_multiplier * self.strategy.p.m_margin_ratio 90 | 91 | # 从期货账户扣除保证金 92 | pre_cash = self.strategy.future_cash 93 | self.strategy.future_cash -= margin 94 | 95 | logger.info(f"开仓扣除保证金 - 之前: {pre_cash:.2f}, 扣除: {margin:.2f}, 之后: {self.strategy.future_cash:.2f}") 96 | 97 | # 记录交易信息 98 | self.hedge_order.info = { 99 | 'reason': f"MACD死叉开空 - MACD: {self.macd.macd[0]:.2f}, Signal: {self.macd.signal[0]:.2f}", 100 | 'margin': margin, 101 | 'future_cash': self.strategy.future_cash, 102 | 'execution_date': self.hedge_entry_date, 103 | 'total_value': self.strategy.future_cash, 104 | 'position_value': abs(margin), 105 | 'position_ratio': margin / self.strategy.future_cash if self.strategy.future_cash > 0 else 0, 106 | 'etf_code': self.hedge_contract_code, 107 | 'pnl': 0, 108 | 'return': 0, 109 | 'stop_loss': stop_loss, 110 | 'take_profit': take_profit, 111 | 'avg_cost': self.hedge_entry_price 112 | } 113 | 114 | logger.info(f"MACD死叉开空 - 合约: {self.hedge_contract_code}, 价格: {self.hedge_entry_price:.2f}, 数量: {hedge_size}手, " 115 | f"止损价: {stop_loss:.2f}, 止盈价: {take_profit:.2f}") 116 | 117 | except Exception as e: 118 | logger.error(f"MACD死叉开空失败: {str(e)}") 119 | 120 | def check_exit(self): 121 | """检查是否需要平仓""" 122 | if not self.enabled or not self.hedge_position: 123 | return 124 | 125 | current_price = self.strategy.data1.close[0] 126 | current_atr = self.strategy.atr[0] 127 | 128 | # 计算ATR止盈止损价格 129 | stop_loss = self.hedge_entry_price + (current_atr * self.strategy.p.atr_loss_multiplier) 130 | take_profit = self.hedge_entry_price - (current_atr * self.strategy.p.atr_profit_multiplier) 131 | 132 | # 检查是否触发止盈止损 133 | if current_price >= stop_loss or current_price <= take_profit: 134 | if self.hedge_order is None: # 确保没有未完成订单 135 | self.hedge_order = self.strategy.close(data=self.strategy.data1) 136 | reason = "触发止盈" if current_price <= take_profit else "触发止损" 137 | logger.info(f"MACD死叉对冲{reason} - 当前价格: {current_price:.2f}, {reason}价: {take_profit if current_price <= take_profit else stop_loss:.2f}") 138 | 139 | def on_order_completed(self, order): 140 | """处理订单完成事件""" 141 | if not self.enabled: 142 | return 143 | 144 | if order.status in [order.Completed]: 145 | if order.isbuy(): # 买入豆粕期货(平空) 146 | # 确保有对应的入场价格 147 | if self.hedge_entry_price is None or self.hedge_contract_code is None: 148 | logger.error("平仓时找不到入场价格或合约代码,跳过处理") 149 | return 150 | 151 | # 记录平仓前的合约信息,用于日志 152 | entry_price = self.hedge_entry_price 153 | contract_code = self.hedge_contract_code 154 | entry_date = self.hedge_entry_date 155 | 156 | # 记录交易日期和价格 157 | trade_date = self.strategy.data.datetime.date(0) 158 | trade_price = order.executed.price 159 | 160 | # 先重置持仓相关变量,防止重复平仓 161 | self.hedge_position = None 162 | self.hedge_order = None 163 | self.hedge_entry_price = None 164 | self.hedge_contract_code = None 165 | self.hedge_entry_date = None 166 | self.hedge_target_profit = None 167 | 168 | # 计算对冲盈亏(空仓:入场价 - 平仓价) 169 | hedge_profit = (entry_price - trade_price) * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier 170 | 171 | # 减去开平仓手续费 172 | total_fee = self.strategy.p.hedge_fee * self.strategy.p.hedge_contract_size * 2 173 | net_profit = hedge_profit - total_fee 174 | 175 | # 归还保证金并添加盈亏到期货账户 176 | margin_returned = entry_price * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier * self.strategy.p.m_margin_ratio 177 | 178 | # 记录更新前的资金 179 | pre_cash = self.strategy.future_cash 180 | 181 | # 更新期货账户资金 182 | self.strategy.future_cash += (margin_returned + net_profit) 183 | 184 | # 记录资金变动 185 | logger.info(f"平仓资金变动 - 之前: {pre_cash:.2f}, 返还保证金: {margin_returned:.2f}, 盈亏: {net_profit:.2f}, 之后: {self.strategy.future_cash:.2f}") 186 | 187 | # 更新期货账户最高净值 188 | self.strategy.future_highest_value = max(self.strategy.future_highest_value, self.strategy.future_cash) 189 | 190 | # 计算期货账户回撤 191 | future_drawdown = (self.strategy.future_highest_value - self.strategy.future_cash) / self.strategy.future_highest_value if self.strategy.future_highest_value > 0 else 0 192 | 193 | # 计算收益率 194 | return_pct = (hedge_profit / (entry_price * self.strategy.p.hedge_contract_size * self.strategy.p.future_contract_multiplier)) * 100 195 | 196 | # 更新订单信息 197 | order.info.update({ 198 | 'pnl': hedge_profit, 199 | 'return': return_pct, 200 | 'total_value': self.strategy.future_cash, 201 | 'position_value': 0, # 平仓后持仓价值为0 202 | 'avg_cost': entry_price, 203 | 'etf_code': contract_code, # 确保使用原始合约代码 204 | 'execution_date': trade_date, # 确保使用当前交易日期 205 | 'reason': f"MACD死叉对冲平仓 - 合约: {contract_code}, 入场日期: {entry_date}, 入场价: {entry_price:.2f}, 平仓价: {trade_price:.2f}, 收益率: {return_pct:.2f}%" 206 | }) 207 | 208 | logger.info(f"MACD死叉对冲平仓 - 日期: {trade_date}, 合约: {contract_code}, 价格: {trade_price:.2f}, 盈利: {hedge_profit:.2f}, " 209 | f"手续费: {total_fee:.2f}, 净盈利: {net_profit:.2f}, " 210 | f"期货账户余额: {self.strategy.future_cash:.2f}, 回撤: {future_drawdown:.2%}, " 211 | f"收益率: {return_pct:.2f}%") 212 | 213 | else: # 卖出豆粕期货(开空) 214 | # 记录对冲持仓 215 | self.hedge_position = order 216 | 217 | # 更新订单信息 218 | order.info.update({ 219 | 'total_value': self.strategy.future_cash, 220 | 'position_value': abs(order.info['margin']), 221 | 'avg_cost': order.executed.price, 222 | 'etf_code': self.hedge_contract_code # 确保合约代码正确 223 | }) 224 | 225 | elif order.status in [order.Canceled, order.Margin, order.Rejected]: 226 | self.hedge_order = None 227 | logger.warning(f'MACD死叉对冲订单失败 - 状态: {order.getstatusname()}') 228 | 229 | def on_strategy_stop(self): 230 | """策略结束时平掉所有持仓""" 231 | if not self.enabled or not self.hedge_position: 232 | return 233 | 234 | try: 235 | if self.hedge_order is None: # 确保没有未完成订单 236 | self.hedge_order = self.strategy.close(data=self.strategy.data1) 237 | logger.info("策略结束,平掉MACD对冲持仓") 238 | except Exception as e: 239 | logger.error(f"策略结束时平仓失败: {str(e)}") --------------------------------------------------------------------------------