├── .gitignore ├── .pre-commit-config.yaml ├── backtrader_app.py ├── charts ├── __init__.py ├── results.py └── stock.py ├── config └── strategy.yaml ├── demo.gif ├── frames ├── __init__.py ├── form.py └── sidebar.py ├── readme.md ├── requirements.txt ├── strategy ├── __init__.py ├── base.py ├── ma.py └── macross.py ├── tests ├── __init__.py ├── base_test.py ├── ma_test.py └── macross_test.py └── utils ├── __init__.py ├── load.py ├── logs.py ├── processing.py └── schemas.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.h5 2 | *.ipynb 3 | .pytest_cache/ 4 | .vscode/ 5 | tests/unused/* 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | .idea/ 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *,cover 51 | .hypothesis/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask instance folder 62 | instance/ 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # IPython Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | venv/ 87 | ENV/ 88 | 89 | # Spyder project settings 90 | .spyderproject 91 | 92 | # Rope project settings 93 | .ropeproject 94 | 95 | # ========================= 96 | # Operating System Files 97 | # ========================= 98 | 99 | # OSX 100 | # ========================= 101 | 102 | .DS_Store 103 | .AppleDouble 104 | .LSOverride 105 | 106 | # Thumbnails 107 | ._* 108 | 109 | # Files that might appear in the root of a volume 110 | .DocumentRevisions-V100 111 | .fseventsd 112 | .Spotlight-V100 113 | .TemporaryItems 114 | .Trashes 115 | .VolumeIcon.icns 116 | 117 | # Directories potentially created on remote AFP share 118 | .AppleDB 119 | .AppleDesktop 120 | Network Trash Folder 121 | Temporary Items 122 | .apdisk 123 | 124 | # Windows 125 | # ========================= 126 | 127 | # Windows image file caches 128 | Thumbs.db 129 | ehthumbs.db 130 | 131 | # Folder config file 132 | Desktop.ini 133 | 134 | # Recycle Bin used on file shares 135 | $RECYCLE.BIN/ 136 | 137 | # Windows Installer files 138 | *.cab 139 | *.msi 140 | *.msm 141 | *.msp 142 | 143 | # Windows shortcuts 144 | *.lnk 145 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: 
-------------------------------------------------------------------------------- 1 | default_stages: [ pre-commit ] 2 | 3 | repos: 4 | - repo: https://github.com/pycqa/isort 5 | rev: 6.0.1 6 | hooks: 7 | - id: isort 8 | args: ['--profile', 'black'] 9 | 10 | - repo: https://github.com/astral-sh/ruff-pre-commit 11 | rev: v0.11.3 12 | hooks: 13 | - id: ruff 14 | args: [ --fix ] 15 | 16 | - repo: https://github.com/psf/black 17 | rev: 25.1.0 18 | hooks: 19 | - id: black 20 | args: ['--line-length', '120'] -------------------------------------------------------------------------------- /backtrader_app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from streamlit_echarts import st_pyecharts 3 | 4 | from charts import draw_pro_kline, draw_result_bar 5 | from frames import akshare_selector_ui, backtrader_selector_ui, params_selector_ui 6 | from utils.load import load_strategy 7 | from utils.logs import logger 8 | from utils.processing import gen_stock_df, run_backtrader 9 | from utils.schemas import StrategyBase 10 | 11 | st.set_page_config(page_title="backtrader", page_icon=":chart_with_upwards_trend:", layout="wide") 12 | 13 | 14 | def main(): 15 | ak_params = akshare_selector_ui() 16 | bt_params = backtrader_selector_ui() 17 | if ak_params.symbol: 18 | stock_df = gen_stock_df(ak_params) 19 | if stock_df.empty: 20 | st.error("Get stock data failed!") 21 | return 22 | 23 | st.subheader("Kline") 24 | kline = draw_pro_kline(stock_df) 25 | st_pyecharts(kline, height="500px") 26 | 27 | st.subheader("Strategy") 28 | name = st.selectbox("strategy", list(strategy_dict.keys())) 29 | submitted, params = params_selector_ui(strategy_dict[name]) 30 | if submitted: 31 | logger.info(f"akshare: {ak_params}") 32 | logger.info(f"backtrader: {bt_params}") 33 | stock_df = stock_df.rename( 34 | columns={ 35 | "日期": "date", 36 | "开盘": "open", 37 | "收盘": "close", 38 | "最高": "high", 39 | "最低": "low", 40 | "成交量": "volume", 41 | 
def draw_result_bar(df: pd.DataFrame, n_scors: int = 3) -> Bar:
    """Plot optimisation results as a bar chart with one series per score.

    The trailing ``n_scors`` columns of ``df`` are drawn as y-series. Every
    column in front of them is treated as a strategy parameter and folded
    into the x-axis label of its row.

    Args:
        df: Result table, one row per parameter combination.
        n_scors: Number of trailing columns that hold scores.

    Returns:
        The configured ``Bar`` chart.
    """
    param_cols = df.columns[:-n_scors]
    score_cols = df.columns[-n_scors:]

    def _row_label(row: pd.Series) -> str:
        # One "name_value" fragment per parameter, stacked on separate lines.
        return "\n".join([f"{name}_{value}" for name, value in zip(param_cols, row)])

    labels = df[param_cols].apply(_row_label, axis=1).values.tolist()

    bar = Bar()
    bar.add_xaxis(labels)
    bar.set_global_opts(
        tooltip_opts=opts.TooltipOpts(trigger="axis"),
        legend_opts=opts.LegendOpts(selected_mode="single"),
    )

    # Series are added after the global options, in score-column order.
    for score in score_cols:
        bar.add_yaxis(score, df[score].values.tolist())

    # Mark the largest and smallest value of every series, hide per-bar labels.
    extremes = [
        opts.MarkPointItem(type_="max", name="最大值"),
        opts.MarkPointItem(type_="min", name="最小值"),
    ]
    bar.set_series_opts(
        label_opts=opts.LabelOpts(is_show=False),
        markpoint_opts=opts.MarkPointOpts(data=extremes),
    )
    return bar
def split_data(df: pd.DataFrame) -> tuple[list[str], list[list[float]], pd.Series, list[list[float]]]:
    """Split a quote table into the series consumed by the K-line chart.

    ``df`` is only read; no columns are added to the caller's frame.

    Args:
        df: Quote table with the columns 日期, 开盘, 收盘, 最低, 最高 and 成交量.

    Returns:
        A 4-tuple ``(x_data, y_data, df_close, y_vol)``:

        * ``x_data`` - values of the 日期 column.
        * ``y_data`` - one ``[开盘, 收盘, 最低, 最高]`` row per record.
        * ``df_close`` - the 收盘 column as a Series.
        * ``y_vol`` - one ``[index label, 成交量, flag]`` row per record, where
          ``flag`` is 1 when 开盘 is greater than 收盘 and -1 otherwise.
    """
    x_data = df["日期"].values.tolist()
    y_data = df[["开盘", "收盘", "最低", "最高"]].values.tolist()
    df_close = df["收盘"]

    # Vectorised comparison instead of a row-wise apply; comparisons that
    # involve NaN are False and therefore map to -1, as before.
    flag = (df["开盘"] > df["收盘"]).map({True: 1, False: -1})

    # Build the volume rows on a separate frame so that ``df`` stays untouched.
    volume = df[["成交量"]].assign(index=df.index, rise=flag.to_numpy())
    y_vol = volume[["index", "成交量", "rise"]].values.tolist()
    return x_data, y_data, df_close, y_vol


def calculate_ma(day_count: int, df: pd.DataFrame) -> list[float]:
    """Return the rolling mean of ``df`` over ``day_count`` rows.

    Values are rounded to two decimals. Positions without a full window are
    returned as the string ``"-"`` rather than NaN, so the result can mix
    floats and strings.

    Args:
        day_count: Window length in rows.
        df: Values to average; ``draw_pro_kline`` passes the close-price Series.

    Returns:
        The averaged values as a plain list.
    """
    df_ma = df.rolling(day_count).mean().round(2).fillna("-")
    return df_ma.values.tolist()


def draw_pro_kline(df: pd.DataFrame) -> Grid:
    """Build the candlestick chart with moving averages and a volume panel.

    Args:
        df: Quote table with the columns 日期, 开盘, 收盘, 最低, 最高 and 成交量.

    Returns:
        A ``Grid`` with the candlestick/moving-average overlay in the upper
        panel and the volume bars in the lower panel.
    """
    x_data, y_data, df_close, y_vol = split_data(df)

    kline = (
        Kline()
        .add_xaxis(xaxis_data=x_data)
        .add_yaxis(
            series_name="日K",
            y_axis=y_data,
            itemstyle_opts=opts.ItemStyleOpts(color="#ec0000", color0="#00da3c"),
        )
        .set_global_opts(
            legend_opts=opts.LegendOpts(is_show=False, pos_bottom=10, pos_left="center"),
            datazoom_opts=[
                opts.DataZoomOpts(
                    is_show=False,
                    type_="inside",
                    xaxis_index=[0, 1],
                    range_start=80,
                    range_end=100,
                ),
                opts.DataZoomOpts(
                    is_show=True,
                    xaxis_index=[0, 1],
                    type_="slider",
                    pos_top="85%",
                    range_start=80,
                    range_end=100,
                ),
            ],
            yaxis_opts=opts.AxisOpts(
                is_scale=True,
                splitarea_opts=opts.SplitAreaOpts(is_show=True, areastyle_opts=opts.AreaStyleOpts(opacity=1)),
            ),
            tooltip_opts=opts.TooltipOpts(
                trigger="axis",
                axis_pointer_type="cross",
                background_color="rgba(245, 245, 245, 0.8)",
                border_width=1,
                border_color="#ccc",
                textstyle_opts=opts.TextStyleOpts(color="#000"),
            ),
            # Colours the volume bars by the flag stored in dimension 2 of
            # y_vol. NOTE(review): series_index=5 presumably targets the
            # volume series (K-line, four MA lines, then volume) - keep it in
            # sync if the number of MA lines changes.
            visualmap_opts=opts.VisualMapOpts(
                is_show=False,
                dimension=2,
                series_index=5,
                is_piecewise=True,
                pieces=[
                    {"value": 1, "color": "#00da3c"},
                    {"value": -1, "color": "#ec0000"},
                ],
            ),
            axispointer_opts=opts.AxisPointerOpts(
                is_show=True,
                link=[{"xAxisIndex": "all"}],
                label=opts.LabelOpts(background_color="#777"),
            ),
            brush_opts=opts.BrushOpts(
                x_axis_index="all",
                brush_link="all",
                out_of_brush={"colorAlpha": 0.1},
                brush_type="lineX",
            ),
        )
    )

    # One smoothed line per moving-average window, all with the same styling.
    line = Line().add_xaxis(xaxis_data=x_data)
    for window in (5, 10, 20, 30):
        line.add_yaxis(
            series_name=f"MA{window}",
            y_axis=calculate_ma(window, df_close),
            is_smooth=True,
            is_hover_animation=False,
            linestyle_opts=opts.LineStyleOpts(width=3, opacity=0.5),
            label_opts=opts.LabelOpts(is_show=False),
        )
    line.set_global_opts(xaxis_opts=opts.AxisOpts(type_="category"))

    bar = (
        Bar()
        .add_xaxis(xaxis_data=x_data)
        .add_yaxis(
            series_name="Volume",
            y_axis=y_vol,
            xaxis_index=1,
            yaxis_index=1,
            label_opts=opts.LabelOpts(is_show=False),
        )
        .set_global_opts(
            xaxis_opts=opts.AxisOpts(
                type_="category",
                is_scale=True,
                grid_index=1,
                boundary_gap=True,
                axisline_opts=opts.AxisLineOpts(is_on_zero=False),
                axistick_opts=opts.AxisTickOpts(is_show=False),
                splitline_opts=opts.SplitLineOpts(is_show=False),
                axislabel_opts=opts.LabelOpts(is_show=False),
                split_number=20,
                min_="dataMin",
                max_="dataMax",
            ),
            yaxis_opts=opts.AxisOpts(
                grid_index=1,
                is_scale=True,
                split_number=2,
                axislabel_opts=opts.LabelOpts(is_show=False),
                axisline_opts=opts.AxisLineOpts(is_show=False),
                axistick_opts=opts.AxisTickOpts(is_show=False),
                splitline_opts=opts.SplitLineOpts(is_show=False),
            ),
            legend_opts=opts.LegendOpts(is_show=False),
        )
    )

    # Candlesticks and moving averages share the upper panel.
    overlap_kline_line = kline.overlap(line)

    # Upper panel: overlay; lower panel: volume bars.
    grid_chart = Grid(
        init_opts=opts.InitOpts(
            animation_opts=opts.AnimationOpts(animation=False),
        )
    )
    grid_chart.add(
        overlap_kline_line,
        grid_opts=opts.GridOpts(pos_left="10%", pos_right="8%", height="50%"),
    )
    grid_chart.add(
        bar,
        grid_opts=opts.GridOpts(pos_left="10%", pos_right="8%", pos_top="63%", height="16%"),
    )

    return grid_chart
def params_selector_ui(params: dict) -> tuple[bool, dict]:
    """Render the strategy-parameter form and collect the chosen ranges.

    Each item of ``params`` is read through the keys ``type``, ``name``,
    ``min``, ``max`` and ``step``. Items whose ``type`` is ``"int"`` get a
    pair of number inputs; any other type is skipped.

    Args:
        params: Parameter specifications of the selected strategy.

    Returns:
        ``(submitted, ranges)`` where ``submitted`` is the value returned by
        the form's submit button and ``ranges`` maps each parameter name to
        ``range(min, max, step)`` built from the current inputs.
    """
    ranges = {}
    with st.form("params"):
        for spec in params:
            if spec["type"] != "int":
                # No editor exists for non-integer parameters.
                continue
            name = spec["name"]
            left, right = st.columns(2)
            with left:
                lower = st.number_input("min " + name, value=spec["min"])
            with right:
                upper = st.number_input("max " + name, value=spec["max"])
            ranges[name] = range(lower, upper, spec["step"])
        submitted = st.form_submit_button("Submit")
    return submitted, ranges
def backtrader_selector_ui() -> BacktraderParams:
    """Draw the "BackTrader Config" sidebar section and collect its values.

    :return: BacktraderParams built from the current widget values
    """
    sidebar = st.sidebar
    sidebar.markdown("# BackTrader Config")

    # Keyword arguments are evaluated top to bottom, so the widgets are
    # created in the order they are listed here.
    return BacktraderParams(
        start_date=sidebar.date_input("backtrader start date", datetime.date(2000, 1, 1)),
        end_date=sidebar.date_input("backtrader end date", datetime.datetime.today()),
        start_cash=sidebar.number_input("start cash", min_value=0, value=100000, step=10000),
        commission_fee=sidebar.number_input(
            "commission fee", min_value=0.0, max_value=1.0, value=0.001, step=0.0001
        ),
        stake=sidebar.number_input("stake", min_value=0, value=100, step=10),
    )
[官方仓库](https://github.com/akfamily/akshare) | 22 | | **Backtrader** | 执行量化交易策略回测 | [官方仓库](https://github.com/mementum/backtrader) | 23 | | **Pyecharts** | 生成专业金融数据图表 | [官方仓库](https://github.com/pyecharts/pyecharts) | 24 | 25 | ## 快速开始 26 | 27 | ### 环境准备 28 | 29 | 确保已安装所有依赖包: 30 | 31 | ```bash 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ### 启动应用 36 | 37 | 执行以下命令启动Web界面: 38 | 39 | ```bash 40 | streamlit run backtrader_app.py 41 | ``` 42 | 43 | ### 策略测试 44 | 45 | 运行内置策略的单元测试: 46 | 47 | ```bash 48 | python -m unittest tests.MaStrategyTest 49 | ``` 50 | 51 | ## 支持的策略 52 | 53 | 本项目实现了以下量化交易策略: 54 | 55 | - **MA策略** - 基于单一移动平均线的趋势跟踪策略 56 | - **MACross策略** - 基于快慢双均线交叉的信号策略 57 | 58 | ## 参数配置指南 59 | 60 | ### AkShare数据参数 61 | 62 | | 参数 | 说明 | 63 | |------|------| 64 | | **symbol** | 股票代码(如:600070) | 65 | | **period** | 数据周期(日线、周线、月线) | 66 | | **start date** | 数据起始日期 | 67 | | **end date** | 数据结束日期 | 68 | | **adjust** | 复权方式(qfq:前复权,hfq:后复权) | 69 | 70 | ### Backtrader回测参数 71 | 72 | | 参数 | 说明 | 73 | |------|------| 74 | | **start date** | 回测开始日期 | 75 | | **end date** | 回测结束日期 | 76 | | **start cash** | 初始资金 | 77 | | **commission fee** | 交易佣金比例 | 78 | | **stake** | 每次交易股数 | 79 | 80 | ## 相关推荐 81 | 82 | - [**FinVizAI**](https://github.com/chenwr727/FinVizAI.git) - 一键生成股票与期货分析视频 83 | - [**akshare-gpt**](https://github.com/chenwr727/akshare-gpt.git) - 将AkShare集成到GPT中实现自然语言金融数据查询 84 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | akshare==1.16.81 2 | backtrader==1.9.78.123 3 | loguru==0.7.3 4 | pandas==2.2.3 5 | pre-commit==4.2.0 6 | pydantic==2.11.3 7 | pyecharts==2.0.8 8 | PyYAML==6.0.2 9 | streamlit==1.44.1 10 | streamlit_echarts==0.4.0 11 | -------------------------------------------------------------------------------- /strategy/__init__.py: -------------------------------------------------------------------------------- 1 | from .ma 
class BaseStrategy(bt.Strategy):
    """Common logging and order/trade notification handling for strategies."""

    _name = "base"
    params = (("printlog", False),)

    def log(self, txt: str, dt: Optional[bt.datetime.date] = None, doprint: bool = False) -> None:
        """Write ``txt`` to the logger, prefixed with a date.

        Nothing is logged unless the ``printlog`` parameter or ``doprint`` is
        truthy. When ``dt`` is not given, the date of the current bar of the
        first data feed is used.
        """
        if not (self.params.printlog or doprint):
            return
        stamp = dt or self.datas[0].datetime.date(0)
        logger.info("%s, %s" % (stamp.isoformat(), txt))

    def notify_order(self, order: bt.OrderBase) -> None:
        """Log the outcome of an order and clear the pending-order marker."""
        status = order.status

        # Submitted/accepted orders are still in flight - nothing to record.
        if status in (order.Submitted, order.Accepted):
            return

        if status == order.Completed:
            executed = order.executed
            if order.isbuy():
                self.log(
                    "BUY EXECUTED, Price: %.2f, Cost: %.2f, Comm %.2f"
                    % (executed.price, executed.value, executed.comm)
                )
                # Remember the entry price and commission of the last buy.
                self.buyprice = executed.price
                self.buycomm = executed.comm
            else:
                self.log(
                    "SELL EXECUTED, Price: %.2f, Cost: %.2f, Comm %.2f"
                    % (executed.price, executed.value, executed.comm)
                )
            self.bar_executed = len(self)
        elif status in (order.Canceled, order.Margin, order.Rejected):
            self.log("Order Canceled/Margin/Rejected")

        # Whatever the outcome, no order is pending any more.
        self.order = None

    def notify_trade(self, trade: bt.Trade) -> None:
        """Log gross and net profit once a trade has been closed."""
        if trade.isclosed:
            self.log("OPERATION PROFIT, GROSS %.2f, NET %.2f" % (trade.pnl, trade.pnlcomm))

    def next(self) -> None:
        """Per-bar hook; the base class does nothing."""
        pass

    def stop(self) -> None:
        """Log the final portfolio value together with the run's parameters."""
        shown = [f"{key}_{value}" for key, value in self.params._getkwargs().items() if key != "printlog"]
        summary = "(%s %s) Ending Value %.2f" % (self._name, " ".join(shown), self.broker.getvalue())
        # Always printed, independent of the printlog parameter.
        self.log(summary, doprint=True)
class MaCrossStrategy(BaseStrategy):
    """MaCross strategy - dual moving-average crossover."""

    _name = "MaCross"
    params = (("printlog", False), ("fast_length", 10), ("slow_length", 50))

    def __init__(self) -> None:
        """Set up the close-price reference, order state and the crossover signal."""
        super().__init__()
        # Close line of the first data feed.
        self.dataclose = self.datas[0].close

        # Pending order plus price/commission of the last buy.
        self.order = None
        self.buyprice = None
        self.buycomm = None

        # Fast and slow simple moving averages feed the crossover indicator.
        fast_sma = bt.ind.SMA(period=self.params.fast_length)
        slow_sma = bt.ind.SMA(period=self.params.slow_length)
        self.crossover = bt.ind.CrossOver(fast_sma, slow_sma)

    def next(self) -> None:
        """Log the close, then buy or sell depending on the crossover value."""
        self.log(f"Close, {self.dataclose[0]:.2f}")

        # Never stack a second order on top of a pending one.
        if self.order:
            return

        if self.position:
            # In the market: leave when the crossover value is negative.
            if self.crossover < 0:
                self.log(f"SELL CREATE, {self.dataclose[0]:.2f}")
                self.order = self.sell()
        elif self.crossover > 0:
            # Flat: enter when the crossover value is positive.
            self.log(f"BUY CREATE, {self.dataclose[0]:.2f}")
            self.order = self.buy()
def run_back_trader(cerebro: bt.Cerebro, strategy: Type[BaseStrategy], **kwargs) -> pd.DataFrame:
    """Run a parameter sweep and tabulate the analyzer results.

    Args:
        cerebro (bt.Cerebro): Engine with data, broker settings and the
            ``returns``, ``drawdown`` and ``sharpe`` analyzers already added.
        strategy (Type[BaseStrategy]): Strategy class to optimise.
        **kwargs: Strategy parameters, passed on to ``optstrategy``.

    Returns:
        pd.DataFrame: One row per run, holding the swept parameter values
        followed by the columns ``return``, ``dd`` and ``sharpe``.
    """
    cerebro.optstrategy(strategy, **kwargs)

    param_names = list(kwargs.keys())
    rows = []
    # The sweep is run on a single CPU.
    for run in cerebro.run(maxcpus=1):
        strat = run[0]
        row = [strat.params._getkwargs()[name] for name in param_names]
        row += [
            strat.analyzers.returns.get_analysis()["rnorm100"],
            strat.analyzers.drawdown.get_analysis()["max"]["drawdown"],
            strat.analyzers.sharpe.get_analysis()["sharperatio"],
        ]
        rows.append(row)

    return pd.DataFrame(rows, columns=param_names + ["return", "dd", "sharpe"])
def load_strategy(yaml_file: str) -> Dict[str, Any]:
    """Load the strategy configuration from a YAML file.

    Args:
        yaml_file (str): Path of the strategy configuration file.

    Returns:
        Dict[str, Any]: The parsed configuration.

    Raises:
        FileNotFoundError: The file does not exist.
        yaml.YAMLError: The file is not valid YAML.
    """
    try:
        with open(yaml_file, "r", encoding="utf-8") as stream:
            config = yaml.safe_load(stream)
    except (FileNotFoundError, yaml.YAMLError) as e:
        # Record the failure, then let the caller see the original exception.
        logger.error(f"加载策略配置失败: {e}")
        raise
    else:
        return config
@st.cache_data(hash_funcs={StrategyBase: model_hash_func, BacktraderParams: model_hash_func})
def run_backtrader(stock_df: pd.DataFrame, strategy: StrategyBase, bt_params: BacktraderParams) -> pd.DataFrame:
    """Run a parameter sweep for one strategy and tabulate the results.

    Args:
        stock_df (pd.DataFrame): Quote data with a ``date`` column; it is not
            modified.
        strategy (StrategyBase): Strategy name and the parameters to sweep.
            The class ``<name>Strategy`` is looked up in the ``strategy``
            package.
        bt_params (BacktraderParams): Date window, cash, commission and stake.

    Returns:
        pd.DataFrame: One row per run, holding the swept parameter values
        followed by the columns ``return``, ``dd`` and ``sharpe``.

    Raises:
        ValueError: The strategy class cannot be imported or does not exist.
    """
    # Work on a copy so the caller's DataFrame keeps its original index.
    feed_df = stock_df.copy()
    feed_df.index = pd.to_datetime(feed_df["date"])

    # Data feed restricted to the requested window.
    data = bt.feeds.PandasData(dataname=feed_df, fromdate=bt_params.start_date, todate=bt_params.end_date)

    # Engine, broker and sizer set-up.
    cerebro = bt.Cerebro()
    cerebro.adddata(data)
    cerebro.broker.setcash(bt_params.start_cash)
    cerebro.broker.setcommission(commission=bt_params.commission_fee)
    cerebro.addsizer(bt.sizers.FixedSize, stake=bt_params.stake)

    # Analyzers read back after the run.
    cerebro.addanalyzer(btanalyzers.SharpeRatio, _name="sharpe", riskfreerate=0.0)
    cerebro.addanalyzer(btanalyzers.DrawDown, _name="drawdown")
    cerebro.addanalyzer(btanalyzers.Returns, _name="returns")

    # Resolve the strategy class by name. Only the lookup is guarded, so an
    # error raised while registering the strategy is not misreported as a
    # missing strategy.
    try:
        strategy_cli = getattr(__import__("strategy"), f"{strategy.name}Strategy")
    except (ImportError, AttributeError) as e:
        logger.error(f"策略导入失败: {e}")
        raise ValueError(f"无法找到策略: {strategy.name}Strategy") from e
    cerebro.optstrategy(strategy_cli, **strategy.params)

    back = cerebro.run()

    # One result row per run: swept parameters first, then the three metrics.
    param_names = list(strategy.params.keys())
    par_list = []
    for x in back:
        strat = x[0]
        strat_kwargs = strat.params._getkwargs()
        par = [strat_kwargs[param] for param in param_names]
        par.extend(
            [
                strat.analyzers.returns.get_analysis()["rnorm100"],
                strat.analyzers.drawdown.get_analysis()["max"]["drawdown"],
                strat.analyzers.sharpe.get_analysis()["sharperatio"],
            ]
        )
        par_list.append(par)

    return pd.DataFrame(par_list, columns=param_names + ["return", "dd", "sharpe"])