├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── examples ├── group_daily.py ├── group_monthly.py ├── speed.py └── stocks.py ├── lightbt ├── __init__.py ├── _version.py ├── base.py ├── callbacks.py ├── enums.py ├── portfolio.py ├── position.py ├── signals.py ├── stats.py └── utils.py ├── pyproject.toml ├── requirements.txt └── setup.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | 
.Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 wukan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LightBT 2 | 3 | 轻量级回测工具 4 | 5 | ## 注意 6 | 7 | 每种工具都有其适合的应用场景。遇到需要快速评测大批量参数或因子的场景,我们只要相对准确。所以使用对数收益累加更合适 8 | 9 | ## 现状 10 | 11 | 1. 收益率乘权重的向量计算极快,但在连续做空的场景下,收益不正确,仅仅是因为每天收益率变化不大,所以误差很难察觉。 12 | 2. 大部分回测库策略与绩效统计结合过于紧密,达不到策略与平台分离的设想 13 | 3. 大部分回测库都很慢,如:`zipline`、`rqalpha`、`backtrader` 14 | 4. 回测快的库,底层一般是用`C++`、`Rust`等语言实现。跨语言部署和二次开发对量化研究员而言比较困难 15 | 5. `vectorbt`计算准确,回测也快。但不支持保证金无法直接用在期货市场,输入宽表比较占内存 16 | 6. `bt`策略树设计很好,但也不支持保证金概念 17 | 18 | ## 目标 19 | 20 | 1. 
正确处理保证金和做空 21 | 2. 架构简单易扩展 22 | 3. 回测速度快 23 | 24 | ## 技术选型 25 | 26 | 1. `C++/Cython`开发,模仿`TA-Lib`的技术方案,`Cython`版库部署麻烦,开发也复杂 27 | 2. `Rust`开发,模仿`polars`的技术方案,使用`pyo3`进行跨语言调用,但`Rust`入门困难 28 | 3. `Numba`支持`JIT`,安装部署方便,易于调试和二次开发 29 | 30 | ## 三层结构 31 | 32 | 1. 开仓成交回报产生的持仓导致持仓盈亏、平仓成交回报产生平仓盈亏。将两种盈亏累计便成盈亏曲线 33 | - 已经可以统计盈亏、胜率、最大回撤等指标 34 | 2. 叠加初始资金即可构成资金曲线和净值曲线,可计算收益率、最大回撤率等指标 35 | - 没有考虑资金不足、是否能成交等情况。已经是简化回测绩效统计工具。仅能按手数进行交易 36 | 3. 初始资金和保证金率决定了可开手数 37 | - 关注交易细节,考虑资金不足等情况。可以按比例进行资金分配下单 38 | 39 | ## 工作原理 40 | 41 | 1. 使用对象来维护每种资产的交易信息。其中的成员变量全是最新值 42 | 2. 根据时间的推进和交易指令,更新对应对象 43 | 3. 指定时间间隔获取对象成员变量的快照,将所有快照连起来便成总体绩效 44 | - 月频交易,但日频取快照 45 | - 周频交易,周频取快照 46 | - 分钟交易,日频取快照 47 | 48 | ## 安装 49 | 50 | ```commandline 51 | pip install lightbt -U 52 | ``` 53 | 54 | ## 使用 55 | 56 | 以下是代码片段。完整代码请参考[stocks.py](examples/stocks.py) 57 | 58 | ```python 59 | # %% 60 | 61 | from lightbt import LightBT, warmup 62 | from lightbt.stats import total_equity 63 | from lightbt.utils import Timer 64 | 65 | # 省略代码...... 66 | 67 | # %% 热身 68 | print('warmup:', warmup()) 69 | 70 | # %% 初始化 71 | bt = LightBT(init_cash=0.0, 72 | positions_precision=1.0, 73 | max_trades=_N * _K * 2 // 1, # 反手占两条记录,所以预留2倍空间比较安全 74 | max_performances=_N * _K) 75 | # 入金。必需先入金,否则资金为0无法交易 76 | bt.deposit(10000 * 100) 77 | 78 | # %% 配置资产信息 79 | with Timer(): 80 | bt.setup(config) 81 | 82 | # %% 资产转换,只做一次即可 83 | df['asset'] = df['asset'].map(bt.mapping_asset_int) 84 | 85 | # %% 交易 86 | with Timer(): 87 | bt.run_bars(groupby(orders_daily(df, sort=True), by='date', dtype=order_outside_dt)) 88 | 89 | # %% 查看最终持仓 90 | positions = bt.positions() 91 | print(positions) 92 | # %% 查看所有交易记录 93 | trades = bt.trades(return_all=True) 94 | print(trades) 95 | trades_stats = bt.trades_stats() 96 | print(trades_stats) 97 | roundtrips = bt.roundtrips() 98 | print(roundtrips) 99 | roundtrips_stats = bt.roundtrips_stats() 100 | print(roundtrips_stats) 101 | 102 | # %% 查看绩效 103 | perf = bt.performances(return_all=True) 104 | print(perf) 105 | # %% 
总体绩效 106 | equity = total_equity(perf) 107 | print(equity) 108 | equity.plot() 109 | 110 | ``` 111 | 112 | ## 注意 113 | 114 | `run_bars`的输入是可迭代对象,它按批处理数据,同一批中的处理优先级如下: 115 | 116 | 1. 平仓优先。先平后开 117 | 2. 时间优先。后面的订单可能因资金不足而少报 118 | 119 | 所以,每天早上提供一张目标仓位清单后,内部会分成两组,先下平仓单,再下开仓单。 120 | 121 | - 本回测系统由于会立即成交,所以平仓后资金立即释放,可以立即开仓 122 | - 但实盘中平仓要等成交释放资金后才能再开仓,如果资金非常多,一起开平也可以 123 | 124 | `groupby`是用来分批的工具,可以使用多个参数进行多重分组, 125 | 如参数为`groupby(by=['date'])`时就是一天一个交易清单;如果需要收盘开仓、早盘平仓,可以将`date`中的时间精确到小时,做成每天两批 126 | 127 | 同一批中,主要参数需要完全一样,系统只取每批的最后一组。例如:`size_type`在同一批中如果取值不同,就会造成概念冲突。 128 | 129 | ### 结果稳定性 130 | 131 | 1. 部分工具的“选出前10名”等功能,可能由于前20个值都一样而无法确定顺序,这时一定要考察更多的指标来确定顺序,比如多考察股票名。否则结果可能每次都不一样。 132 | 2. 传入`LightBT.setup`的`config`中,`asset`也需要提前排序 133 | 3. `groupby`前也要排序 134 | 135 | ## 输入格式 136 | 137 | 1. date: int64 138 | - 时间日期。需在外部提前转成数字。可用`astype(np.int64)`或`to_records(dtype)`等方法来实现 139 | 2. size_type: int 140 | - 数量类型。参考`lightbt.enums.SizeType` 141 | 3. asset: int 142 | - 资产ID。由`LightBT.setup`执行结果所确定。可通过`LightBT.asset_str2int`和`LightBT.asset_int2str`进行相互转换 143 | 4. size: float 144 | - 数量。具体含义需根据`size_type`决定。`nan`是一个特殊值,用来表示当前这一行不交易,在只更新最新价的场景中将频繁用到。 145 | 5. fill_price: float 146 | - 成交价。成交价不同于最新价,也不同于收盘价,可以用成交均价等有意义的价格代替。 147 | 6. last_price: float 148 | - 最新价。可用收盘价、结算价等代替。它影响持仓的浮动盈亏,所以在做绩效快照前一定要更新一次 149 | 7. date_diff: bool 150 | - 是否换日。在换日的最后时刻需要更新最新价和记录绩效 151 | 152 | ## 配置格式 153 | 154 | 通过`LightBT.setup`进行设置 155 | 156 | 1. asset: str 157 | - 资产名。内部将使用对应的int进行各项处理 158 | 2. mult: float 159 | - 合约乘数。股票的合约乘数为1.0 160 | 3. margin_ratio: float 161 | - 保证金率。股票的保证金率为1.0 162 | 4. commission_ratio: float 163 | - 手续费率参数。具体含义参考`commission_fn` 164 | 5.
commission_fn: 165 | - 手续费处理函数 166 | 167 | ## 调试 168 | 169 | ```python 170 | import os 171 | 172 | os.environ['NUMBA_DISABLE_JIT'] = '1' 173 | ``` 174 | 175 | `numba`的JIT模式下是无法直接调试的,编译也花时间,可以先添加环境变量`NUMBA_DISABLE_JIT=1`,禁止JIT模式。 176 | 177 | 数据量较小时,禁用JIT模式反而速度更快。 178 | 179 | ## 二次开发 180 | 181 | ```commandline 182 | git clone https://github.com/wukan1986/LightBT.git 183 | cd LightBT 184 | pip install -e . 185 | ``` -------------------------------------------------------------------------------- /examples/group_daily.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """ 3 | 资金分成5份,每份持有5天,每天入场 4 | 将因子分成10组,每组进行分别统计,然后画分组曲线 5 | 6 | Notes 7 | ===== 8 | 9 | 10 | """ 11 | # %% 12 | # import os 13 | # 14 | # os.environ['NUMBA_DISABLE_JIT'] = '1' 15 | 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | import pandas as pd 19 | 20 | from lightbt import LightBT, warmup 21 | from lightbt.callbacks import commission_by_value 22 | from lightbt.enums import SizeType, order_outside_dt 23 | from lightbt.signals import orders_daily 24 | from lightbt.stats import total_equity 25 | from lightbt.utils import Timer, groupby 26 | 27 | pd.set_option('display.max_columns', None) 28 | pd.set_option('display.width', 1000) 29 | # pd.options.plotting.backend = 'plotly' 30 | 31 | # %% 32 | _Q = 10 # 分组数量 33 | _P = 5 # 分成多少份入场 34 | _K = 5000 # 多支股票 35 | 36 | asset = [f's_{i:04d}' for i in range(_K)] 37 | date = pd.date_range(start='2000-01-01', end='2005-12-31', freq='B') 38 | _N = len(date) # 10年 39 | 40 | CLOSE = np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0) * np.random.randint(10, 100, _K) 41 | CLOSE = pd.DataFrame(CLOSE, index=date, columns=asset) 42 | 43 | SMA10 = CLOSE.rolling(10).mean() 44 | SMA20 = CLOSE.rolling(20).mean() 45 | 46 | # 因子构建 47 | factor: pd.DataFrame = SMA10 / SMA20 - 1.0 # 因子 48 | 49 | # 收盘时产生信号,第二天开盘交易 50 | factor = factor.shift(1) 51 | 52 | # 分组 53 | factor = factor.stack() 54 |
factor.index.names = ['date', 'asset'] 55 | quantile: pd.DataFrame = factor.groupby(by=['date'], group_keys=False).apply( 56 | lambda x: pd.qcut(x, _Q, labels=False, duplicates='drop')).unstack() 57 | quantile, _ = quantile.align(CLOSE, fill_value=-1) # dates/assets missing from `quantile` get -1, which never matches a group id 0.._Q-1 58 | row_num = pd.DataFrame(0, index=CLOSE.index, columns=CLOSE.columns, dtype=int) 59 | row_num[:] = np.arange(len(CLOSE)).reshape(-1, 1) % _P # entry slot (0.._P-1) of each date, cycling day by day 60 | 61 | size = pd.DataFrame(0.0, index=CLOSE.index, columns=CLOSE.columns, dtype=float) 62 | 63 | df = pd.DataFrame({ # wide (date x asset) -> long rows; row-major reshape matches from_product([date, asset]) 64 | 'CLOSE': CLOSE.to_numpy().reshape(-1), 65 | 'size': size.to_numpy().reshape(-1), 66 | 'quantile': quantile.to_numpy().reshape(-1), 67 | 'row_num': row_num.to_numpy().reshape(-1), 68 | }, index=pd.MultiIndex.from_product([date, asset], names=['date', 'asset'])).reset_index() 69 | df['size_type'] = SizeType.TargetValueScale 70 | 71 | del CLOSE 72 | del SMA10 73 | del SMA20 74 | del size 75 | del factor 76 | del quantile 77 | del row_num 78 | df.columns = ['date', 'asset', 'CLOSE', 'size', 'quantile', 'row_num', 'size_type'] 79 | 80 | df['fill_price'] = df['CLOSE'] 81 | df['last_price'] = df['fill_price'] 82 | 83 | # %% Warm up 84 | print('warmup:', warmup()) 85 | 86 | # %% Initialize 87 | unit = df['date'].dtype.name[-3:-1] # e.g. 'datetime64[ns]' -> 'ns' 88 | bt = LightBT(init_cash=10000 * 100, # initial cash 89 | positions_precision=1.0, 90 | max_trades=_N * _K * 2 // 1, # a position reversal takes two records, so reserving 2x space is safer 91 | max_performances=_N * _K, 92 | unit=unit) 93 | 94 | # %% Configure assets 95 | asset = sorted(df['asset'].unique()) 96 | config = pd.DataFrame({'asset': asset, 'mult': 1.0, 'margin_ratio': 1.0, 97 | 'commission_ratio': 0.0005, 'commission_fn': commission_by_value}) 98 | with Timer(): 99 | bt.setup(config) 100 | 101 | # %% Convert asset names to integer ids (only needs to be done once) 102 | df['asset'] = df['asset'].map(bt.mapping_asset_int) 103 | 104 | # %% Trade 105 | equities1 = pd.DataFrame() # one equity curve per (group i, slot j) 106 | equities2 = pd.DataFrame() # per group: sum over its _P slots 107 | for i in range(_Q): 108 | _equities = pd.DataFrame() 109 | 110 | df['size'] = np.where(df['quantile'] == i, 1, 0).astype(float) # hold only the assets of quantile group i 111 | for j in
range(_P): 112 | print(i, '\t', j, '\t', end='') 113 | df['size_type'] = np.where(df['row_num'] == j, SizeType.TargetValueScale, SizeType.NOP).astype(int) # sub-portfolio j only rebalances on dates in slot j 114 | 115 | bt.reset() # cash must be set via init_cash at construction; deposits made later do not survive reset() 116 | with Timer(): 117 | # one batch per date, so equity is updated daily 118 | bt.run_bars(groupby(orders_daily(df, sort=True), by='date', dtype=order_outside_dt)) 119 | 120 | perf = bt.performances(return_all=True) 121 | s1 = total_equity(perf)['equity'] 122 | equities1[f"{i}_{j}"] = s1 123 | _equities[f"{i}_{j}"] = s1 124 | s2 = _equities.sum(axis=1) 125 | equities2[i] = s2 126 | equities3 = equities2.sum(axis=1) 127 | 128 | # %% 129 | fig, axes = plt.subplots(1, 3) 130 | equities1.plot(ax=axes[0]) 131 | equities2.plot(ax=axes[1]) 132 | equities3.plot(ax=axes[2]) 133 | 134 | plt.show() 135 | -------------------------------------------------------------------------------- /examples/group_monthly.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """ 3 | Rebalance at the start of every month. 4 | Split the factor into 10 groups, backtest each group separately, then plot the group curves. 5 | 6 | Notes 7 | ===== 8 | To only check how well the factor separates returns, summing log returns is faster. 9 | This script just demonstrates that grouped (layered) backtests can be done. 10 | 11 | """ 12 | # import os 13 | # 14 | # os.environ['NUMBA_DISABLE_JIT'] = '1' 15 | 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | import pandas as pd 19 | 20 | from lightbt import LightBT, warmup 21 | from lightbt.callbacks import commission_by_value 22 | from lightbt.enums import SizeType, order_outside_dt 23 | from lightbt.signals import orders_daily 24 | from lightbt.stats import total_equity 25 | from lightbt.utils import Timer, groupby 26 | 27 | pd.set_option('display.max_columns', None) 28 | pd.set_option('display.width', 1000) 29 | # pd.options.plotting.backend = 'plotly' 30 | 31 | # %% 32 | 33 | _K = 5000 # number of stocks 34 | 35 | asset = [f's_{i:04d}' for i in range(_K)] 36 | date = pd.date_range(start='2000-01-01', end='2010-12-31', freq='B') 37 | _N = len(date) # ~11 years of business days (2000-2010) 38 | 39 | CLOSE = np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0) *
np.random.randint(10, 100, _K) 40 | CLOSE = pd.DataFrame(CLOSE, index=date, columns=asset) 41 | 42 | SMA10 = CLOSE.rolling(10).mean() 43 | SMA20 = CLOSE.rolling(20).mean() 44 | 45 | # Date handling: rebalance on the first trading day of each month (the 1st of a month may not be a trading day) 46 | dt = pd.DataFrame(index=CLOSE.index) 47 | dt['start'] = dt.index 48 | dt['end'] = dt.index 49 | dt = dt.resample('ME').agg({'start': 'first', 'end': 'last'}) 50 | 51 | # Target market value: NOP (no order) by default, TargetValueScale on each month's first trading day 52 | size_type = pd.DataFrame(SizeType.NOP, index=CLOSE.index, columns=CLOSE.columns, dtype=int) 53 | size_type.loc[dt['start']] = SizeType.TargetValueScale 54 | 55 | # Build the factor 56 | factor: pd.DataFrame = SMA10 / SMA20 - 1.0 # factor 57 | 58 | # Signals form at the close; shift by one day so they are traded on the next day 59 | factor = factor.shift(1) 60 | 61 | # Quantile groups, computed on rebalance days only 62 | factor = factor.loc[dt['start']].stack() 63 | factor.index.names = ['date', 'asset'] 64 | quantile: pd.DataFrame = factor.groupby(by=['date'], group_keys=False).apply( 65 | lambda x: pd.qcut(x, 10, labels=False, duplicates='drop')).unstack() 66 | quantile, _ = quantile.align(CLOSE, fill_value=-1) # non-rebalance dates get -1, which never matches a group id 67 | 68 | size = pd.DataFrame(0.0, index=CLOSE.index, columns=CLOSE.columns, dtype=float) 69 | 70 | df = pd.DataFrame({ # wide (date x asset) -> long rows; row-major reshape matches from_product([date, asset]) 71 | 'CLOSE': CLOSE.to_numpy().reshape(-1), 72 | 'size_type': size_type.to_numpy().reshape(-1), 73 | 'size': size.to_numpy().reshape(-1), 74 | 'quantile': quantile.to_numpy().reshape(-1), 75 | }, index=pd.MultiIndex.from_product([date, asset], names=['date', 'asset'])).reset_index() 76 | 77 | del CLOSE 78 | del SMA10 79 | del SMA20 80 | del size_type 81 | del size 82 | del factor 83 | del quantile 84 | df.columns = ['date', 'asset', 'CLOSE', 'size_type', 'size', 'quantile'] 85 | 86 | df['fill_price'] = df['CLOSE'] 87 | df['last_price'] = df['fill_price'] 88 | 89 | # %% Warm up 90 | print('warmup:', warmup()) 91 | 92 | # %% Initialize 93 | unit = df['date'].dtype.name[-3:-1] # e.g. 'datetime64[ns]' -> 'ns' 94 | bt = LightBT(init_cash=10000 * 100, # initial cash 95 | positions_precision=1.0, 96 | max_trades=_N * _K * 2 // 1, # a position reversal takes two records, so reserving 2x space is safer 97 | max_performances=_N * _K, 98 | unit=unit) 99 | 100 | asset =
sorted(df['asset'].unique()) 101 | config = pd.DataFrame({'asset': asset, 'mult': 1.0, 'margin_ratio': 1.0, 102 | 'commission_ratio': 0.0005, 'commission_fn': commission_by_value}) 103 | 104 | # %% Configure assets 105 | with Timer(): 106 | bt.setup(config) 107 | 108 | # %% Convert asset names to integer ids (only needs to be done once) 109 | df['asset'] = df['asset'].map(bt.mapping_asset_int) 110 | 111 | # %% Trade 112 | equities = pd.DataFrame() 113 | # It would be nice to run these independent groups in parallel 114 | for i in range(10): # 10 groups, matching pd.qcut(x, 10, ...) above 115 | print(i, '\t', end='') 116 | df['size'] = (df['quantile'] == i).astype(float) 117 | 118 | bt.reset() # cash must be set via init_cash at construction; deposits made later do not survive reset() 119 | with Timer(): 120 | # one batch per date, so equity is updated daily 121 | bt.run_bars(groupby(orders_daily(df, sort=True), by='date', dtype=order_outside_dt)) 122 | 123 | perf = bt.performances(return_all=True) 124 | equities[i] = total_equity(perf)['equity'] 125 | 126 | # %% 127 | equities.plot() 128 | plt.show() 129 | -------------------------------------------------------------------------------- /examples/speed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compare the speed of the first and the second warmup call, 3 | then time many subsequent runs. 4 | """ 5 | import os 6 | # os.environ['NUMBA_DISABLE_JIT'] = '1' 7 | import timeit 8 | 9 | import pandas as pd 10 | 11 | from lightbt import warmup 12 | 13 | _ = os # presumably keeps the otherwise-unused `os` import referenced (see the commented-out NUMBA line) 14 | pd.set_option('display.max_columns', None) 15 | pd.set_option('display.width', 1000) 16 | 17 | if __name__ == '__main__': 18 | print('start') 19 | print('warmup:', warmup()) 20 | print('warmup:', warmup()) 21 | print(timeit.timeit('warmup()', number=1000, globals=locals())) 22 | -------------------------------------------------------------------------------- /examples/stocks.py: -------------------------------------------------------------------------------- 1 | # %% 2 | """ 3 | At the start of each month, go long the top 100 stocks and short the bottom 100. 4 | Weights are allocated by factor value; the factor is standardized before allocation. 5 | 6 | Trades happen the morning after the factor is computed. 7 | 8 | Ex-rights/ex-dividend events are not supported, so all prices are backward-adjusted. 9 | """ 10 | import numpy as np 11 | import pandas as pd 12 | 13 | from lightbt import LightBT, warmup 14 | from lightbt.callbacks import commission_by_value 15 |
from lightbt.enums import order_outside_dt, SizeType 16 | from lightbt.signals import orders_daily 17 | from lightbt.stats import total_equity, pnl_by_asset, pnl_by_assets 18 | from lightbt.utils import Timer, groupby 19 | 20 | # %% 21 | # import os 22 | # os.environ['NUMBA_DISABLE_JIT'] = '1' 23 | 24 | pd.set_option('display.max_columns', None) 25 | pd.set_option('display.width', 1000) 26 | pd.options.plotting.backend = 'plotly' 27 | 28 | # %% 29 | 30 | _K = 5000 # 多支股票 31 | 32 | asset = [f's_{i:04d}' for i in range(_K)] 33 | date = pd.date_range(start='2000-01-01', end='2010-12-31', freq='B') 34 | _N = len(date) # 10年 35 | 36 | CLOSE = np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0) * np.random.randint(10, 100, _K) 37 | CLOSE = pd.DataFrame(CLOSE, index=date, columns=asset) 38 | 39 | OPEN = np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0) * np.random.randint(10, 100, _K) 40 | 41 | SMA10 = CLOSE.rolling(10).mean() 42 | SMA20 = CLOSE.rolling(20).mean() 43 | 44 | # 时间处理,每月第一个交易日调仓,每月第一天可能不是交易日 45 | dt = pd.DataFrame(index=CLOSE.index) 46 | dt['start'] = dt.index 47 | dt['end'] = dt.index 48 | dt = dt.resample('ME').agg({'start': 'first', 'end': 'last'}) 49 | 50 | # 目标市值 51 | size_type = pd.DataFrame(SizeType.NOP, index=CLOSE.index, columns=CLOSE.columns, dtype=int) 52 | # size_type.loc[dt['start']] = SizeType.TargetValuePercent 53 | # size_type[:] = SizeType.TargetValuePercent 54 | 55 | # 因子构建,过滤多头与空头 56 | factor: pd.DataFrame = SMA10 / SMA20 - 1.0 # 因子 57 | 58 | # 收盘时产生信号,第二天开盘交易 59 | factor = factor.shift(1) 60 | 61 | # 因为之后将按因子值进行权重分配,这里需要提前做做标准化 62 | # 标准化后,前N一定是正数,后N一定是负数 63 | factor = factor.subtract(factor.mean(axis=1), axis=0).div(factor.std(axis=1, ddof=0), axis=0) 64 | 65 | top = factor.rank(axis=1, pct=False, ascending=False) <= 100 # 横截面按从大到小排序 66 | bottom = factor.rank(axis=1, pct=False, ascending=True) <= 100 # 横截面按从小到大排序 67 | 68 | size = pd.DataFrame(0.0, index=CLOSE.index, columns=CLOSE.columns, dtype=float) 69 
| size[top] = factor[top] # long the top N 70 | size[bottom] = factor[bottom] # short the bottom N 71 | # Factor weighting: larger factor values get larger weights (equal weighting also works) 72 | size = size.div(size.abs().sum(axis=1), axis=0) # gross weights sum to 1; rows with no candidates become NaN (0/0) 73 | 74 | df = pd.DataFrame({ 75 | 'OPEN': OPEN.reshape(-1), 76 | 'CLOSE': CLOSE.to_numpy().reshape(-1), 77 | 'size_type': size_type.to_numpy().reshape(-1), 78 | 'size': size.to_numpy().reshape(-1), 79 | }, index=pd.MultiIndex.from_product([date, asset], names=['date', 'asset'])).reset_index() 80 | 81 | del OPEN 82 | del CLOSE 83 | del SMA10 84 | del SMA20 85 | del size_type 86 | del size 87 | del top 88 | del bottom 89 | del factor 90 | df.columns = ['date', 'asset', 'OPEN', 'CLOSE', 'size_type', 'size'] 91 | 92 | # Trade at the morning open: use the open price from the call auction, or e.g. the first-5-minute VWAP 93 | df['fill_price'] = df['OPEN'] 94 | # Daily close or settlement price 95 | df['last_price'] = df['CLOSE'] 96 | 97 | df.to_parquet('tmp.parquet') # round-trip through parquet; writes tmp.parquet into the working directory 98 | df = pd.read_parquet('tmp.parquet') 99 | 100 | # %% Warm up 101 | print('warmup:', warmup()) 102 | 103 | # %% Initialize 104 | unit = df['date'].dtype.name[-3:-1] # e.g. 'datetime64[ns]' -> 'ns' 105 | bt = LightBT(init_cash=0.0, 106 | positions_precision=1.0, 107 | max_trades=_N * _K * 2 // 1, # a position reversal takes two records, so reserving 2x space is safer 108 | max_performances=_N * _K, 109 | unit=unit) 110 | # Deposit. Cash must be deposited first (init_cash=0.0 above), otherwise nothing can be traded 111 | bt.deposit(10000 * 100) 112 | 113 | # %% Configure assets 114 | asset = sorted(df['asset'].unique()) 115 | config = pd.DataFrame({'asset': asset, 'mult': 1.0, 'margin_ratio': 1.0, 116 | 'commission_ratio': 0.0005, 'commission_fn': commission_by_value}) 117 | with Timer(): 118 | bt.setup(config) 119 | 120 | # %% Convert asset names to integer ids (only needs to be done once) 121 | df['asset'] = df['asset'].map(bt.mapping_asset_int) 122 | 123 | # %% Trade 124 | with Timer(): 125 | # one batch per date, so equity is updated daily 126 | bt.run_bars(groupby(orders_daily(df, sort=True), by='date', dtype=order_outside_dt)) 127 | 128 | # perf = bt.performances(return_all=True) 129 | # s1 = total_equity(perf)['equity'] 130 | # print(s1.tail()) 131 | 132 | 133 | # %% Inspect final positions 134 | positions = bt.positions() 135 | print(positions) 136 | # %% Inspect all trade records 137 | trades = bt.trades(return_all=True) 138 |
print(trades) 139 | trades_stats = bt.trades_stats() 140 | print(trades_stats) 141 | roundtrips = bt.roundtrips() 142 | print(roundtrips) 143 | roundtrips_stats = bt.roundtrips_stats() 144 | print(roundtrips_stats) 145 | 146 | # %% Inspect performance records 147 | perf = bt.performances(return_all=True) 148 | print(perf) 149 | # %% Total (portfolio-level) equity 150 | equity = total_equity(perf) 151 | print(equity) 152 | equity.plot() 153 | 154 | # %% PnL curves of several assets 155 | pnls = pnl_by_assets(perf, ['s_0000', 's_0100', 's_0300'], bt.mapping_asset_int, bt.mapping_int_asset) 156 | print(pnls) 157 | pnls.plot() 158 | 159 | # %% Performance details of a single asset 160 | pnls = pnl_by_asset(perf, 's_0000', df[['date', 'asset', 'CLOSE']], bt.mapping_asset_int, bt.mapping_int_asset) 161 | print(pnls) 162 | pnls.plot() 163 | # %% 164 | pd.options.plotting.backend = 'matplotlib' 165 | pnls[['PnL', 'CLOSE']].plot(secondary_y='CLOSE') 166 | # %% 167 | print(df) 168 | # %% 169 | -------------------------------------------------------------------------------- /lightbt/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import __version__ 2 | from .base import LightBT, warmup 3 | -------------------------------------------------------------------------------- /lightbt/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.4" 2 | -------------------------------------------------------------------------------- /lightbt/base.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import List, Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from lightbt.stats import calc_trades_stats, calc_roundtrips_stats, trades_to_roundtrips 7 | 8 | 9 | class LightBT: 10 | def __init__(self, 11 | init_cash: float = 10000, 12 | positions_precision: float = 1.0, 13 | max_trades: int = 10000, 14 | max_performances: int = 10000, 15 | unit: str = 'ns' 16 | ) -> None: 17 | """Initialize the backtester and deposit ``init_cash`` (it is deposited again on every ``reset()``) 18 |
19 | Parameters 20 | ---------- 21 | positions_precision: float 22 | Position precision, passed through to ``Portfolio`` 23 | max_trades: int 24 | Size of the trade-record buffer; records beyond it are dropped 25 | max_performances: int 26 | Size of the performance-record buffer; records beyond it are dropped 27 | unit:str 28 | ``unit`` argument used with pd.to_datetime / pd.to_timedelta when converting dates 29 | 30 | """ 31 | from lightbt.portfolio import Portfolio 32 | 33 | self._init_cash = init_cash 34 | self._positions_precision = positions_precision 35 | self._max_trades = max_trades 36 | self._max_performances = max_performances 37 | self._unit = unit 38 | 39 | self.pf = Portfolio(positions_precision=self._positions_precision, 40 | max_trades=self._max_trades, 41 | max_performances=self._max_performances) 42 | # Deposit the initial cash 43 | self.deposit(self._init_cash) 44 | 45 | # The engine has no asset-name strings, only integer ids; these dicts translate between the two 46 | self.mapping_asset_int = {} 47 | self.mapping_int_asset = {} 48 | self.conf: pd.DataFrame = pd.DataFrame() 49 | 50 | def reset(self): 51 | """Reset the portfolio and re-deposit ``init_cash``. No need to call `setup` again; just rerun a `run_` method""" 52 | self.pf.reset() 53 | # Re-deposit the initial cash 54 | self.deposit(self._init_cash) 55 | 56 | def setup(self, df: pd.DataFrame) -> None: 57 | """Map assets to integer ids and configure contract multiplier, margin ratio and commission 58 | 59 | Configs accumulate across calls; for an asset name seen before, drop_duplicates(keep='first') keeps the FIRST configuration and discards the new row. NOTE(review): if re-configuring an asset is meant to replace it, this needs changing (while keeping ids stable) - confirm intent 60 | 61 | Parameters 62 | ---------- 63 | df: pd.DataFrame 64 | - asset 65 | - mult 66 | - margin_ratio 67 | - commission_ratio 68 | commission rate 69 | - commission_fn 70 | commission calculation function 71 | 72 | """ 73 | self.conf = pd.concat([self.conf, df]) 74 | self.conf.drop_duplicates(subset='asset', keep='first', inplace=True) 75 | 76 | # Integer id of each asset = its row position in the accumulated config 77 | conf = self.conf.reset_index(drop=True) 78 | self.mapping_int_asset = conf['asset'].to_dict() 79 | self.mapping_asset_int = {v: k for k, v in self.mapping_int_asset.items()} 80 | 81 | # Convert to plain arrays for the engine 82 | asset = np.asarray(conf.index, dtype=int) 83 | mult = np.asarray(conf['mult'], dtype=float) 84 | margin_ratio = np.asarray(conf['margin_ratio'], dtype=float) 85 | commission_ratio = np.asarray(conf['commission_ratio'], dtype=float) 86 | commission_fn = np.asarray(conf['commission_fn']) 87 | 88 | # Batch setup in the engine 89 | self.pf.setup(asset, mult, margin_ratio, commission_ratio) 90 | # Register each asset's commission function 91 | for aid, fn in
zip(asset, commission_fn): 92 | self.pf.set_commission_fn(aid, fn) 93 | 94 | def asset_str2int(self, strings: Union[List[str], str]) -> Union[List[int], int]: 95 | """资产转换。字符串转数字""" 96 | if isinstance(strings, list) and len(strings) == 1: 97 | strings = strings[0] 98 | if isinstance(strings, str): 99 | return self.mapping_asset_int.get(strings) 100 | 101 | return list(map(self.mapping_asset_int.get, strings)) 102 | 103 | def asset_int2str(self, integers: Union[List[int], int]) -> Union[List[str], str]: 104 | """资产转换。数字转字符串""" 105 | if isinstance(integers, list) and len(integers) == 1: 106 | integers = integers[0] 107 | if isinstance(integers, int): 108 | return self.mapping_int_asset.get(integers) 109 | 110 | return list(map(self.mapping_int_asset.get, integers)) 111 | 112 | def deposit(self, cash: float) -> float: 113 | """入金 114 | 115 | Parameters 116 | ---------- 117 | cash: float 118 | 119 | Returns 120 | ------- 121 | float 122 | 123 | Notes 124 | ----- 125 | 默认资金为0,所以交易前需要入金 126 | 127 | """ 128 | return self.pf.deposit(cash) 129 | 130 | def withdraw(self, cash: float) -> float: 131 | """出金""" 132 | return self.pf.withdraw(cash) 133 | 134 | def positions(self, readable: bool = True) -> Union[pd.DataFrame, np.ndarray]: 135 | """持仓记录""" 136 | records = self.pf.positions() 137 | if not readable: 138 | return records 139 | 140 | df = pd.DataFrame.from_records(records) 141 | df['asset'] = df['asset'].map(self.mapping_int_asset) 142 | return df 143 | 144 | def trades(self, return_all: bool, readable: bool = True) -> Union[pd.DataFrame, np.ndarray]: 145 | """成交记录 146 | 147 | Parameters 148 | ---------- 149 | return_all: bool 150 | 返回所有记录或返回最近一批记录 151 | readable: bool 152 | 返回可读格式 153 | unit: str 154 | 时间单位 155 | 156 | Returns 157 | ------- 158 | pd.DataFrame or np.ndarray 159 | 160 | """ 161 | records = self.pf.trades(return_all) 162 | if not readable: 163 | return records 164 | 165 | df = pd.DataFrame.from_records(records) 166 | df['date'] = 
pd.to_datetime(df['date'], unit=self._unit) 167 | df['asset'] = df['asset'].map(self.mapping_int_asset) 168 | return df 169 | 170 | def performances(self, return_all: bool, readable: bool = True) -> Union[ 171 | pd.DataFrame, np.ndarray]: 172 | """Performance records; with ``readable`` a DataFrame with datetimes and asset names""" 173 | records = self.pf.performances(return_all) 174 | if not readable: 175 | return records 176 | 177 | df = pd.DataFrame.from_records(records) 178 | df['date'] = pd.to_datetime(df['date'], unit=self._unit) 179 | df['asset'] = df['asset'].map(self.mapping_int_asset) 180 | return df 181 | 182 | def trades_stats(self, readable: bool = True) -> Union[pd.DataFrame, np.ndarray]: 183 | """Trade statistics computed from all recorded trades""" 184 | trades = self.pf.trades(True) 185 | stats = calc_trades_stats(trades, len(self.mapping_int_asset)) # second arg: number of configured assets 186 | if not readable: 187 | return stats 188 | 189 | df = pd.DataFrame.from_records(stats) 190 | df['start'] = pd.to_datetime(df['start'], unit=self._unit) 191 | df['end'] = pd.to_datetime(df['end'], unit=self._unit) 192 | df['period'] = pd.to_timedelta(df['period'], unit=self._unit) 193 | df['asset'] = df['asset'].map(self.mapping_int_asset) 194 | return df 195 | 196 | def roundtrips(self, readable: bool = True) -> Union[pd.DataFrame, np.ndarray]: 197 | """Round-trip records built from all recorded trades""" 198 | trades = self.pf.trades(True) 199 | rounds = trades_to_roundtrips(trades, len(self.mapping_int_asset)) 200 | if not readable: 201 | return rounds 202 | 203 | df = pd.DataFrame.from_records(rounds) 204 | df['entry_date'] = pd.to_datetime(df['entry_date'], unit=self._unit) 205 | df['exit_date'] = pd.to_datetime(df['exit_date'], unit=self._unit) 206 | df['asset'] = df['asset'].map(self.mapping_int_asset) 207 | return df 208 | 209 | def roundtrips_stats(self, readable: bool = True) -> Union[pd.DataFrame, np.ndarray]: 210 | """Round-trip statistics computed from the raw ``roundtrips(False)`` records""" 211 | rounds = self.roundtrips(False) 212 | stats = calc_roundtrips_stats(rounds, len(self.mapping_int_asset)) 213 | if not readable: 214 | return stats 215 | df = pd.DataFrame.from_records(stats) 216 | df['asset'] =
df['asset'].map(self.mapping_int_asset) 217 | return df 218 | 219 | def run_bar(self, arr) -> None: 220 | """同一时点,截面所有资产立即执行 221 | 222 | Parameters 223 | ---------- 224 | arr 225 | - date 226 | - size_type 227 | - asset 228 | - size 229 | nan可用于只更新价格但不交易 230 | - fill_price 231 | - last_price 232 | - commission 233 | - date_diff 234 | 235 | """ 236 | self.pf.run_bar2(arr) 237 | 238 | def run_bars(self, arrs) -> None: 239 | """多时点,循序分批执行 240 | 241 | Parameters 242 | ---------- 243 | arrs 244 | - date 245 | - size_type 246 | - asset 247 | - size: 248 | nan可用于只更新价格但不交易 249 | - fill_price 250 | - last_price 251 | - commission 252 | - date_diff 253 | 254 | """ 255 | for arr in arrs: 256 | self.pf.run_bar2(arr) 257 | 258 | 259 | def warmup() -> float: 260 | """热身 261 | 262 | 由于Numba JIT编译要占去不少时间,提前将大部分路径都跑一遍,之后调用就快了""" 263 | # import os 264 | # os.environ['NUMBA_DISABLE_JIT'] = '1' 265 | # pd.set_option('display.max_columns', None) 266 | # pd.set_option('display.width', 1000) 267 | 268 | from lightbt.enums import SizeType 269 | from lightbt.callbacks import commission_by_qty, commission_by_value 270 | from lightbt.enums import order_outside_dt 271 | from lightbt.signals import orders_daily 272 | from lightbt.utils import groupby 273 | 274 | symbols = [('510300', 1, 1, 0.001, commission_by_qty), ('IF2309', 300, 0.2, 0.0005, commission_by_value), ] 275 | config = pd.DataFrame.from_records(symbols, 276 | columns=['asset', 'mult', 'margin_ratio', 'commission_ratio', 'commission_fn']) 277 | 278 | df1 = pd.DataFrame({'asset': ['510300', 'IF2309'], 279 | 'size': [np.nan, -0.5], 280 | 'fill_price': [4.0, 4000.0], 281 | 'last_price': [4.0, 4000.0], 282 | 'date': '2023-08-01', 283 | 'size_type': SizeType.TargetValuePercent}) 284 | 285 | df2 = pd.DataFrame({'asset': ['510300', 'IF2309'], 286 | 'size': [0.5, 0.5], 287 | 'fill_price': [4.0, 4000.0], 288 | 'last_price': [4.0, 4000.0], 289 | 'date': '2023-08-02', 290 | 'size_type': SizeType.TargetValuePercent}) 291 | 292 | df = 
pd.concat([df1, df2]) 293 | df['date'] = pd.to_datetime(df['date']) 294 | unit = df['date'].dtype.name[-3:-1] 295 | 296 | tic = time.perf_counter() 297 | 298 | bt = LightBT(init_cash=10000 * 50, unit=unit) 299 | bt.deposit(10000 * 20) 300 | bt.withdraw(10000 * 10) 301 | 302 | bt.setup(config) 303 | # 只能在setup后才能做map 304 | df['asset'] = df['asset'].map(bt.mapping_asset_int) 305 | 306 | bt.run_bars(groupby(orders_daily(df, sort=True), by='date', dtype=order_outside_dt)) 307 | 308 | bt.positions() 309 | bt.trades(return_all=True) 310 | bt.performances(return_all=True) 311 | bt.reset() 312 | 313 | toc = time.perf_counter() 314 | return toc - tic 315 | -------------------------------------------------------------------------------- /lightbt/callbacks.py: -------------------------------------------------------------------------------- 1 | from numba import cfunc, float64, bool_ 2 | 3 | 4 | # 以下是手续费处理函数,用户也可以自己定义,通过setup进行设置 5 | 6 | @cfunc(float64(bool_, bool_, float64, float64, float64)) 7 | def commission_0(is_buy: bool, is_open: bool, value: float, qty: float, commission_ratio: float) -> float: 8 | """0手续费""" 9 | return 0.0 10 | 11 | 12 | @cfunc(float64(bool_, bool_, float64, float64, float64)) 13 | def commission_by_qty(is_buy: bool, is_open: bool, value: float, qty: float, commission_ratio: float) -> float: 14 | """按数量计算手续费""" 15 | return qty * commission_ratio 16 | 17 | 18 | @cfunc(float64(bool_, bool_, float64, float64, float64)) 19 | def commission_by_value(is_buy: bool, is_open: bool, value: float, qty: float, commission_ratio: float) -> float: 20 | """按市值计算手续费""" 21 | return value * commission_ratio 22 | 23 | 24 | @cfunc(float64(bool_, bool_, float64, float64, float64)) 25 | def commission_AStock(is_buy: bool, is_open: bool, value: float, qty: float, commission_ratio: float) -> float: 26 | """按市值计算手续费""" 27 | if is_open: 28 | commission = value * commission_ratio 29 | else: 30 | # 卖出平仓,多收千1的税 31 | commission = value * (commission_ratio + 0.001) 32 | 33 | if 
commission < 5.0: 34 | return 5.0 35 | return commission 36 | -------------------------------------------------------------------------------- /lightbt/enums.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | import numpy as np 4 | 5 | 6 | # 此枚举定义参考于vectorbt。多加了与保证金有关类别 7 | class SizeTypeT(NamedTuple): 8 | # 空操作指令。通过此值比size全nan能减少代码执行 9 | NOP: int = 0 10 | # 下单数量和方向 11 | Amount: int = 1 12 | # 下单市值和方向 13 | Value: int = 2 14 | # 下单保证金和方向 15 | Margin: int = 3 16 | # 正数使用现金比例,负数卖出持仓比例 17 | Percent: int = 4 18 | # 目标数量和方向 19 | TargetAmount: int = 5 20 | # 目标市值和方向 21 | TargetValue: int = 6 22 | # 目标市值百分比。size绝对值之和范围[0,1] 23 | TargetValuePercent: int = 7 24 | # 目标市值比例。size值可能为1.5:1:-1等 25 | TargetValueScale: int = 8 26 | # 目标保证金和方向 27 | TargetMargin: int = 9 28 | # 目标保证金百分比。size绝对值之和范围[0,1] 29 | TargetMarginPercent: int = 10 30 | # 目标保证金比例。size值可能为1.5:1:-1等 31 | TargetMarginScale: int = 11 32 | 33 | 34 | SizeType = SizeTypeT() 35 | 36 | # 绩效统计。为减少内存,使用float32 37 | performance_dt = np.dtype([ 38 | ('date', np.uint64), # 日期时间 39 | ('asset', np.uint32), # 资产ID 40 | ('amount', np.float32), # 净持仓量。负数表示空头 41 | ('value', np.float32), # 净持仓市值。负数表示空头 42 | ('cash', np.float32), # 当前批量单执行后所剩余现金 43 | ('margin', np.float32), # 占用保证金 44 | ('upnl', np.float32), # 持仓盈亏 45 | ('cum_pnl', np.float32), # 累计平仓盈亏(未扣除手续费) 46 | ('cum_commission', np.float32), # 累计手续费 47 | ], align=True) 48 | 49 | # 成交记录。为减少内存,使用float32 50 | trade_dt = np.dtype([ 51 | ('date', np.uint64), # 日期时间 52 | ('asset', np.uint32), # 资产ID 53 | ('is_buy', np.bool_), # 是否买入。 54 | ('is_open', np.bool_), # 是否开平。开一定是开,平有可能含反手。反手也可以拆成两单 55 | ('fill_price', np.float32), # 成交价 56 | ('qty', np.float32), # 当前交易数量 57 | ('amount', np.float32), # 持仓量和方向 58 | ('margin', np.float32), # 保证金 59 | ('commission', np.float32), # 手续费 60 | ('upnl', np.float32), # 持仓盈亏 61 | ('pnl', np.float32), # 平仓盈亏(未扣除手续费) 62 | ('cash_flow', np.float32), # 现金流=平仓盈亏-保证金-手续费 63 | 
('cash', np.float32), # 现金 64 | ], align=True) 65 | 66 | # 持仓记录。中间计算用字段与计算字段类型一致,而展示用字段减少内存 67 | position_dt = np.dtype([ 68 | ('asset', np.uint32), # 资产ID 69 | # 用于相关下单数量计算 70 | ('mult', float), # 合约乘数 71 | ('margin_ratio', float), # 保证金率 72 | ('amount', float), # 净持仓数量 73 | # 展示用字段 74 | ('value', np.float32), # 市值 75 | ('open_value', np.float32), # 开仓市值 76 | ('avg_price', np.float32), # 平均价 77 | ('last_price', np.float32), # 最新价 78 | ('margin', np.float32), # 保证金 79 | ('upnl', np.float32), # 持仓盈亏 80 | ('cum_pnl', np.float32), # 累计平仓盈亏(未扣除手续费) 81 | ('cum_commission', np.float32), # 累计手续费 82 | ], align=True) 83 | 84 | # 外部下单指令,用于将用户的指令转成内部指令 85 | order_outside_dt = np.dtype([ 86 | ('date', np.uint64), # 日期时间 87 | ('size_type', int), # size字段类型 88 | ('asset', int), # 资产ID 89 | ('size', float), # nan时表示此行不参与交易。可用于有持仓但不交易的资产更新最新价 90 | ('fill_price', float), # 成交价 91 | ('last_price', float), # 最新价 92 | ('date_diff', bool), # 标记换日,会触发绩效更新 93 | ], align=True) 94 | 95 | # 内部下单指令。用于将上层的目标持仓等信息转换成实际下单指令 96 | order_inside_dt = np.dtype([ 97 | ('asset', int), # 资产ID 98 | ('is_buy', bool), # 是否买入 99 | ('is_open', bool), # 是否开仓 100 | ('fill_price', float), # 成交价 101 | ('qty', float), # 成交数量 102 | ], align=True) 103 | 104 | # 成交统计。条目数一般等于资产数量 105 | trades_stats_dt = np.dtype([ 106 | ('asset', np.uint32), # 资产ID 107 | ('start', np.uint64), # 第一条记录时间 108 | ('end', np.uint64), # 最后一条记录时间 109 | ('period', np.uint64), # 期 110 | ('total_count', np.uint32), # 总条数 111 | ('buy_count', np.uint32), # 买入条数 112 | ('sell_count', np.uint32), # 卖出条数 113 | ('min_qty', np.float32), # 最小交易量 114 | ('max_qty', np.float32), # 最大交易量 115 | ('avg_qty', np.float32), # 平均交易量 116 | ('avg_buy_qty', np.float32), # 平均买入量 117 | ('avg_sell_qty', np.float32), # 平均卖出量 118 | ('avg_buy_price', np.float32), # 平均买入价 119 | ('avg_sell_price', np.float32), # 平均卖出价 120 | ('total_commission', np.float32), # 总手续费 121 | ('min_commission', np.float32), # 最小手续费 122 | ('max_commission', np.float32), # 最大手续费 123 | 
('avg_commission', np.float32), # 平均手续费 124 | ('avg_buy_commission', np.float32), # 平均买入手续费 125 | ('avg_sell_commission', np.float32), # 平均卖出手续费 126 | ], align=True) 127 | 128 | # 交易每轮。由入场和出场组成 129 | roundtrip_dt = np.dtype([ 130 | ('asset', np.uint32), # 资产ID 131 | ('is_long', np.bool_), # 是否多头 132 | ('is_close', np.bool_), # 是否已平 133 | ('qty', np.float32), # 数量 134 | ('entry_date', np.uint64), # 入场时间 135 | ('entry_price', np.float32), # 入场价 136 | ('entry_commission', np.float32), # 入场手续费 137 | ('exit_date', np.uint64), # 出场时间 138 | ('exit_price', np.float32), # 出场价 139 | ('exit_commission', np.float32), # 出场手续费 140 | ('pnl', np.float32), # 本轮平仓盈亏 141 | ('pnl_com', np.float32), # 本轮平仓盈亏(已减手续费) 142 | ], align=True) 143 | 144 | # 每轮统计 145 | roundtrip_stats_dt = np.dtype([ 146 | ('asset', np.uint32), # 资产ID 147 | ('total_count', np.uint32), # 总条数 148 | ('long_count', np.uint32), # 多头条数 149 | ('short_count', np.uint32), # 空头条数 150 | ('winning_count', np.uint32), # 盈利条数 151 | ('losing_count', np.uint32), # 亏损条数 152 | ('win_rate', np.float32), # 胜率 153 | ], align=True) 154 | -------------------------------------------------------------------------------- /lightbt/portfolio.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from numba import typeof, objmode, types, prange 5 | from numba.experimental import jitclass 6 | from numba.typed.typedlist import List 7 | 8 | from lightbt.enums import SizeType, trade_dt, position_dt, performance_dt, order_inside_dt 9 | from lightbt.position import Position 10 | 11 | 12 | class Portfolio: 13 | _positions_precision: float 14 | _cash: float 15 | 16 | _idx_curr_trade: int 17 | _idx_curr_performance: int 18 | _idx_last_trade: int 19 | _idx_last_performance: int 20 | _max_trades: int 21 | _max_performances: int 22 | 23 | def __init__(self, 24 | positions_precision: float = 1.0, 25 | max_trades: int = 1024, 26 | max_performances: int = 1024) -> None: 27 | 
"""初始化 28 | 29 | Parameters 30 | ---------- 31 | positions_precision: float 32 | 持仓精度 33 | - 1.0 表示整数量 34 | - 0.01 表示持仓可以精确到0.01,用于数字货币等场景 35 | - 0.000001 精度高,相当于对持仓精度不做限制 36 | max_trades: int 37 | 记录成交的缓存大小。空间不足时将丢弃 38 | max_performances: int 39 | 记录绩效的缓存大小。空间不足时将丢弃 40 | """ 41 | # https://github.com/numba/numba/issues/8733 42 | list_tmp = List() 43 | list_tmp.append(Position(0)) 44 | list_tmp.clear() 45 | self._positions = list_tmp 46 | 47 | self._trade_records = np.empty(max_trades, dtype=trade_dt) 48 | self._position_records = np.empty(1, dtype=position_dt) 49 | self._performance_records = np.empty(max_performances, dtype=performance_dt) 50 | 51 | self._positions_precision = positions_precision 52 | 53 | self._max_trades = max_trades 54 | self._max_performances = max_performances 55 | 56 | self.reset() 57 | 58 | def reset(self): 59 | self._cash = 0.0 60 | 61 | self._idx_curr_trade = 0 62 | self._idx_curr_performance = 0 63 | self._idx_last_trade = 0 64 | self._idx_last_performance = 0 65 | 66 | for p in self._positions: 67 | p.reset() 68 | 69 | @property 70 | def Cash(self) -> float: 71 | """现金""" 72 | return self._cash 73 | 74 | @property 75 | def Value(self) -> float: 76 | """持仓市值。空头为负数""" 77 | return np.sum(np.array([pos.Value for pos in self._positions])) 78 | 79 | @property 80 | def Margin(self) -> float: 81 | """保证金占用""" 82 | return np.sum(np.array([pos.Margin for pos in self._positions])) 83 | 84 | @property 85 | def UPnL(self) -> float: 86 | """未平仓盈亏""" 87 | return np.sum(np.array([pos.UPnL for pos in self._positions])) 88 | 89 | @property 90 | def Equity(self) -> float: 91 | """权益=子权益+现金""" 92 | return np.sum(np.array([pos.Equity for pos in self._positions])) + self._cash 93 | 94 | def deposit(self, cash: float) -> float: 95 | """入金 96 | 97 | Parameters 98 | ---------- 99 | cash: float 100 | 101 | Returns 102 | ------- 103 | float 104 | 105 | Notes 106 | ----- 107 | 默认资金为0,所以交易前需要入金 108 | 109 | """ 110 | self._cash += cash 111 | return self._cash 112 | 
113 | def withdraw(self, cash: float) -> float: 114 | """出金""" 115 | self._cash -= cash 116 | return self._cash 117 | 118 | def set_commission_fn(self, asset: int, func) -> None: 119 | """设置手续费函数""" 120 | pos: Position = self._positions[asset] 121 | pos.set_commission_fn(func) 122 | 123 | def setup(self, asset: np.ndarray, 124 | mult: np.ndarray, margin_ratio: np.ndarray, 125 | commission_ratio: np.ndarray) -> None: 126 | """批量配置各品种的参数。 127 | 128 | 1. 有部分品种以前是一种配置,后来又换了配置. 如黄金 129 | 2. 新添品种 130 | 131 | Parameters 132 | ---------- 133 | asset: np.ndarray 134 | 资产ID 135 | mult: np.ndarray 136 | 合约乘数 137 | margin_ratio: np.ndarray 138 | 保证金率 139 | commission_ratio: np.ndarray 140 | 手续费率 141 | 142 | """ 143 | # 指定长度进行初始化 144 | count = len(mult) 145 | while len(self._positions) < count: 146 | self._positions.append(Position(len(self._positions))) 147 | 148 | # 创建记录体,用于最终显示持仓 149 | self._position_records = np.empty(len(self._positions), dtype=position_dt) 150 | 151 | for i in prange(count): 152 | # 153 | self._positions[asset[i]].setup(mult[i], margin_ratio[i], commission_ratio[i]) 154 | 155 | def _fill_trade_record(self, 156 | date: np.uint64, asset: int, 157 | is_buy: bool, is_open: bool, fill_price: float, qty: float) -> None: 158 | """遇到有效成交时自动更新,所以内容直接取即可""" 159 | if self._idx_curr_trade >= self._max_trades: 160 | return 161 | rec = self._trade_records[self._idx_curr_trade] 162 | 163 | self._positions[asset].to_record_trade(rec, date, is_buy, is_open, fill_price, qty, self._cash) 164 | 165 | self._idx_curr_trade += 1 166 | 167 | def _fill_position_records(self, detail: bool) -> None: 168 | """更新持仓记录""" 169 | for i, pos in enumerate(self._positions): 170 | rec = self._position_records[i] 171 | pos.to_record_position(rec, detail) 172 | 173 | def update_last_price(self, asset: np.ndarray, last_price: np.ndarray) -> None: 174 | """更新结算价""" 175 | for i in prange(len(asset)): 176 | pos: Position = self._positions[asset[i]] 177 | if pos.Amount == 0: 178 | # 只对有持仓的更新最新价即可 
179 | continue 180 | pos.update_last_price(last_price[i]) 181 | 182 | def update_performances(self, date: np.uint64) -> None: 183 | """更新绩效""" 184 | cash: float = self._cash 185 | # 上次的位置 186 | self._idx_last_performance = self._idx_curr_performance 187 | for i, pos in enumerate(self._positions): 188 | if self._idx_curr_performance >= self._max_performances: 189 | return 190 | 191 | rec = self._performance_records[self._idx_curr_performance] 192 | pos.to_record_performance(rec, date, cash) 193 | 194 | self._idx_curr_performance += 1 195 | 196 | def update(self, date: np.uint64, asset: np.ndarray, last_price: np.ndarray, do_settlement: bool) -> None: 197 | """更新价格。记录绩效 198 | 199 | Parameters 200 | ---------- 201 | date: np.int64 202 | 日期。可以转成pandas时间 203 | asset: np.ndarray 204 | 需更新行情的资产 205 | last_price: np.ndarray 206 | 最新价。日频可以是结算价 207 | do_settlement: bool 208 | 是否结算 209 | 210 | """ 211 | self.update_last_price(asset, last_price) 212 | if do_settlement: 213 | self.settlement() 214 | self.update_performances(date) 215 | 216 | def settlement(self) -> None: 217 | """结算""" 218 | for i, pos in enumerate(self._positions): 219 | self._cash += pos.settlement() 220 | 221 | def performances(self, return_all: bool) -> np.ndarray: 222 | """绩效记录""" 223 | if return_all: 224 | return self._performance_records[:self._idx_curr_performance] 225 | else: 226 | return self._performance_records[self._idx_last_performance:self._idx_curr_performance] 227 | 228 | def trades(self, return_all: bool) -> np.ndarray: 229 | """很多变量只记录了瞬时值,当需要时序值时,通过此函数记录下来备用""" 230 | if return_all: 231 | return self._trade_records[:self._idx_curr_trade] 232 | else: 233 | return self._trade_records[self._idx_last_trade:self._idx_curr_trade] 234 | 235 | def positions(self) -> np.ndarray: 236 | """最新持仓记录""" 237 | self._fill_position_records(True) 238 | return self._position_records 239 | 240 | def order(self, date: np.uint64, asset: int, is_buy: bool, is_open: bool, fill_price: float, qty: float) -> bool: 241 | 
"""下单 242 | 243 | Parameters 244 | ---------- 245 | date: int 246 | asset: int 247 | is_buy: bool 248 | is_open: bool 249 | 是否开仓。反手暂时归属于平仓。 250 | fill_price: float 251 | qty: float 252 | 253 | Returns 254 | ------- 255 | bool 256 | 257 | """ 258 | # convert_size时已经过滤了不合法的数量,所以这里注释了 259 | 260 | # if qty <= 0.0: 261 | # # 数量不合法,返回。可用于刷新行情但不产生交易记录 262 | # return False 263 | 264 | pos: Position = self._positions[asset] 265 | # 成交价所对应的市值和手续费 266 | value = pos.calc_value(fill_price, qty) 267 | commission = pos.calc_commission(is_buy, is_open, value, qty) 268 | if is_open: 269 | # 可开手数检查 270 | if not pos.openable(self._cash, value, commission): 271 | return False 272 | else: 273 | # TODO: 可能有反手情况。这个以后再处理 274 | pass 275 | 276 | pos.fill(is_buy, is_open, value, fill_price, qty, commission) 277 | self._cash += pos.CashFlow 278 | 279 | self._fill_trade_record(date, asset, is_buy, is_open, fill_price, qty) 280 | 281 | return True 282 | 283 | def convert_size(self, size_type: int, asset: np.ndarray, size: np.ndarray, fill_price: np.ndarray) -> np.ndarray: 284 | """交易数量转换 285 | 286 | Parameters 287 | ---------- 288 | size_type 289 | asset 290 | size: float 291 | nan时表示不交易 292 | fill_price 293 | 294 | """ 295 | self._fill_position_records(False) 296 | # asset不能出现重复 297 | _rs: np.ndarray = self._position_records[asset] 298 | margin_ratio: np.ndarray = _rs['margin_ratio'] 299 | amount: np.ndarray = _rs['amount'] 300 | mult: np.ndarray = _rs['mult'] 301 | 302 | # 所有的TargetXxx类型,如果出现size=0, 直接处理更精确 303 | is_target: bool = size_type >= SizeType.TargetAmount 304 | is_zero: np.ndarray = size == 0 305 | 306 | # 归一时做分母。但必需是没有上游改动 307 | _equity: float = 0.0 308 | equity: float = self.Equity 309 | cash: float = self._cash 310 | size_abs_sum: float = np.nansum(np.abs(size)) 311 | if size_abs_sum == 0: 312 | # 全0表示清仓 313 | if size_type > SizeType.TargetAmount: 314 | size_type = SizeType.TargetAmount 315 | 316 | # 目标保证金比率相关计算。最后转成目标市值 317 | if size_type >= SizeType.TargetMargin: 318 | if 
size_type == SizeType.TargetMarginScale: 319 | size /= size_abs_sum # 归一。最终size和是1 320 | _equity = equity 321 | size *= _equity 322 | if size_type == SizeType.TargetMarginPercent: 323 | _equity = equity * size_abs_sum 324 | size *= _equity 325 | if size_type == SizeType.TargetMargin: 326 | pass 327 | 328 | # 统一保证金 329 | size /= margin_ratio 330 | size_type = SizeType.TargetValue 331 | 332 | # 目标市值比率相关计算。最后转成目标市值 333 | if size_type > SizeType.TargetValue: 334 | if size_type == SizeType.TargetValueScale: 335 | size /= size_abs_sum 336 | _equity = equity 337 | if size_type == SizeType.TargetValuePercent: 338 | _equity = equity * size_abs_sum 339 | 340 | # 特殊处理,通过保证金率还原市值占比 341 | _ratio: float = np.nansum((np.abs(size) * margin_ratio)) 342 | size *= _equity 343 | if _ratio != 0: 344 | size /= _ratio 345 | size_type = SizeType.TargetValue 346 | 347 | # 使用次数最多的类型 348 | if size_type == SizeType.TargetValue: 349 | # 前后市值之差 350 | size -= (fill_price * mult * amount) 351 | size_type = SizeType.Value 352 | if size_type == SizeType.TargetAmount: 353 | # 前后Amout差值 354 | size -= amount 355 | size_type = SizeType.Amount 356 | if size_type == SizeType.Percent: 357 | # 买入开仓,用现金转市值 358 | # 卖出平仓,持仓市值的百分比 359 | # TODO: 由于无法表示卖出开仓。所以只能用在股票市场 360 | size *= np.where(size >= 0, cash / (fill_price * mult * margin_ratio), amount) 361 | size_type = SizeType.Amount 362 | if size_type == SizeType.Margin: 363 | # 将保证金转换成市值 364 | size /= margin_ratio 365 | size_type = SizeType.Value 366 | if size_type == SizeType.Value: 367 | # 将市值转成手数 368 | size /= (fill_price * mult) 369 | size_type = SizeType.Amount 370 | if size_type == SizeType.Amount: 371 | pass 372 | 373 | if is_target: 374 | # 直接取反,回避了前期各种计算导致的误差 375 | size[is_zero] = -amount[is_zero] 376 | 377 | is_open: np.ndarray = np.sign(amount) * np.sign(size) 378 | is_open = np.where(is_open == 0, amount == 0, is_open > 0) 379 | 380 | amount_abs = np.abs(amount) 381 | size_abs = np.abs(size) 382 | 383 | # 创建一个原始订单表,其中存在反手单 384 | orders = 
np.empty(len(asset), dtype=order_inside_dt) 385 | orders['asset'][:] = asset 386 | orders['fill_price'][:] = fill_price 387 | orders['qty'][:] = size_abs 388 | orders['is_buy'][:] = size >= 0 389 | orders['is_open'][:] = is_open 390 | 391 | # 是否有反手单 392 | is_reverse = (~is_open) & (size_abs > amount_abs) 393 | 394 | # 将反手单分离成两单。注意:trades表占用翻倍 395 | if np.any(is_reverse): 396 | orders1 = orders.copy() 397 | orders2 = orders.copy() 398 | orders2['is_open'][:] = True 399 | 400 | orders1['qty'][is_reverse] = amount_abs[is_reverse] 401 | orders2['qty'][is_reverse] -= amount_abs[is_reverse] 402 | # print(orders2[is_reverse]) 403 | 404 | orders = np.concatenate((orders1, orders2[is_reverse])) 405 | 406 | qty = orders['qty'] 407 | is_open = orders['is_open'] 408 | 409 | if self._positions_precision == 1.0: 410 | # 提前条件判断,速度能快一些 411 | qty = np.where(is_open, np.floor(qty + 1e-9), np.ceil(qty - 1e-9)) 412 | else: 413 | # 开仓用小数量,平仓用大数量。接近于0时自动调整为0 414 | qty /= self._positions_precision 415 | # 10.2/0.2=50.99999999999999 416 | qty = np.where(is_open, np.floor(qty + 1e-9), np.ceil(qty - 1e-9)) * self._positions_precision 417 | # 原数字处理后会有小尾巴,简单处理一下 418 | qty = np.round(qty, 9) 419 | 420 | orders['qty'][:] = qty 421 | 422 | # 过滤无效操作。nan正好也被过滤了不会下单 423 | return orders[orders['qty'] > 0] 424 | 425 | def run_bar1(self, 426 | date: np.uint64, size_type: int, 427 | asset: np.ndarray, size: np.ndarray, fill_price: np.ndarray) -> None: 428 | """一层截面信号处理。只处理同时间截面上所有资产的交易信号 429 | 430 | Parameters 431 | ---------- 432 | date 433 | size_type 434 | asset 435 | size 436 | fill_price 437 | 438 | """ 439 | # 空指令直接返回 440 | if size_type == SizeType.NOP: 441 | return 442 | # 全空,返回 443 | if np.all(np.isnan(size)): 444 | return 445 | 446 | # size转换 447 | orders: np.ndarray = self.convert_size(size_type, asset, size, fill_price) 448 | 449 | # 过滤后为空 450 | if len(orders) == 0: 451 | return 452 | 453 | # 记录上次位置 454 | self._idx_last_trade = self._idx_curr_trade 455 | 456 | # 先平仓 457 | orders_close = 
orders[~orders['is_open']] 458 | for i in prange(len(orders_close)): 459 | _o = orders_close[i] 460 | self.order(date, _o['asset'], _o['is_buy'], _o['is_open'], _o['fill_price'], _o['qty']) 461 | 462 | # 后开仓 463 | orders_open = orders[orders['is_open']] 464 | for i in prange(len(orders_open)): 465 | _o = orders_open[i] 466 | self.order(date, _o['asset'], _o['is_buy'], _o['is_open'], _o['fill_price'], _o['qty']) 467 | 468 | def run_bar2(self, arr: np.ndarray) -> None: 469 | """二层截面信号处理。在一层截面信号的基础上多了最新价更新,以及绩效记录 470 | 471 | Parameters 472 | ---------- 473 | arr 474 | - date 475 | - size_type 476 | - asset 477 | - size 478 | - fill_price 479 | - last_price 480 | - date_diff 481 | 482 | """ 483 | _date: np.uint64 = arr['date'][-1] 484 | _size_type: int = arr['size_type'][-1] 485 | _date_diff: bool = arr['date_diff'][-1] 486 | _asset = arr['asset'] 487 | 488 | # 先执行交易 489 | self.run_bar1(_date, _size_type, _asset, arr['size'], arr['fill_price']) 490 | # 更新最新价。浮动盈亏得到了调整 491 | self.update_last_price(_asset, arr['last_price']) 492 | # 每日收盘记录绩效 493 | if _date_diff: 494 | self.update_performances(_date) 495 | 496 | def __str__(self): 497 | # 这个要少调用,很慢 498 | with objmode(string=types.unicode_type): 499 | string = f'Portfolio(Value={self.Value}, Cash={self.Cash}, Equity={self.Equity})' 500 | return string 501 | 502 | 503 | # 这种写法是为了方便开关断点调试 504 | if os.environ.get('NUMBA_DISABLE_JIT', '0') != '1': 505 | # TODO: List支持有问题,不得不这么写,等以后numba修复了再改回来 506 | list_tmp = List() 507 | list_tmp.append(Position(0)) 508 | position_list_type = typeof(list_tmp) 509 | 510 | trade_type = typeof(np.empty(1, dtype=trade_dt)) 511 | position_type = typeof(np.empty(1, dtype=position_dt)) 512 | performance_type = typeof(np.empty(1, dtype=performance_dt)) 513 | 514 | Portfolio = jitclass(Portfolio, 515 | [('_positions', position_list_type), 516 | ('_trade_records', trade_type), 517 | ('_position_records', position_type), 518 | ('_performance_records', performance_type), ]) 519 | 
-------------------------------------------------------------------------------- /lightbt/position.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from numba import objmode, types, njit, float64, typeof, bool_ 5 | from numba.experimental import jitclass 6 | 7 | from lightbt.callbacks import commission_0 8 | 9 | __TOL__: float = 1e-6 10 | 11 | 12 | @njit(float64(float64, float64, float64), fastmath=True, nogil=True, cache=True) 13 | def _value_with_mult(price: float, qty: float, mult: float) -> float: 14 | """计算市值""" 15 | if mult == 1.0: 16 | # 少算一次乘法,速度快一些 17 | return price * qty 18 | return price * qty * mult 19 | 20 | 21 | @njit(float64(float64, float64, float64), fastmath=True, nogil=True, cache=True) 22 | def _avg_with_mult(value: float, qty: float, mult: float) -> float: 23 | """计算均价""" 24 | if mult == 1.0: 25 | # 少算一次除法法,速度快一些 26 | return value / qty 27 | return value / qty / mult 28 | 29 | 30 | @njit(float64(float64, float64), fastmath=True, nogil=True, cache=True) 31 | def _net_cash_flow_with_margin(value: float, margin_ratio: float) -> float: 32 | """计算净现金流""" 33 | if margin_ratio == 1.0: 34 | # 少算一次乘法,速度快一些 35 | return value 36 | return value * margin_ratio 37 | 38 | 39 | @njit(bool_(float64), fastmath=True, nogil=True) 40 | def _is_zero(x: float) -> bool: 41 | """是否为0 42 | 43 | float,double分别遵循R32-24,R64-53的标准。 44 | 所以float的精度误差在1e-6;double精度误差在1e-15 45 | """ 46 | return (x <= __TOL__) and (x >= -__TOL__) 47 | 48 | 49 | # 部分参考了SmartQuant部分代码,但又做了大量调整 50 | class Position: 51 | Asset: int 52 | # TODO: 由于List中使用嵌套jitclass有问题,不得得将SubPosition简化成float 53 | # 多头数量 54 | LongQty: float 55 | # 空头数量 56 | ShortQty: float 57 | 58 | # !!! 
本来应当将合约乘数,保证金率等信息存在Instrument对象中,但为了简化,安排在Position对象中 59 | # 合约乘数 60 | _mult: float 61 | # 保证金率 62 | _margin_ratio: float 63 | # 手续费率 64 | _commission_ratio: float 65 | 66 | # 最新价。用于计算Value/UPnL 67 | LastPrice: float 68 | # 持仓量。空头为负数 69 | Amount: float 70 | # 累计买入量 71 | QtyBought: float 72 | # 累计卖出量 73 | QtySold: float 74 | # 开仓均价 75 | AvgPx: float 76 | 77 | # 手续费 78 | _commission: float 79 | # 累计手续费 80 | _cum_commission: float 81 | # 盈亏 82 | _pnl: float 83 | # 累计盈亏 84 | _cum_pnl: float 85 | # 保证金 86 | _margin: float 87 | # 开仓市值 88 | _open_value: float 89 | # 市值流变化 90 | _value_flow: float 91 | # 净现金流 92 | _net_cash_flow: float 93 | # 现金流 94 | _cash_flow: float 95 | 96 | def __init__(self, asset: int) -> None: 97 | """初始化 98 | 99 | Parameters 100 | ---------- 101 | asset: int 102 | 预设的资产顺序ID 103 | """ 104 | self.Asset = asset 105 | 106 | self._mult = 1.0 107 | self._margin_ratio = 1.0 108 | self._commission_ratio = 0.0 109 | self._commission_fn = commission_0 110 | 111 | self.reset() 112 | 113 | def reset(self): 114 | self.LongQty = 0.0 115 | self.ShortQty = 0.0 116 | self.LastPrice = 0.0 117 | self.Amount = 0.0 118 | self.QtyBought = 0.0 119 | self.QtySold = 0.0 120 | self.AvgPx = 0.0 121 | self._commission = 0.0 122 | self._cum_commission = 0.0 123 | self._margin = 0.0 124 | self._open_value = 0.0 125 | self._pnl = 0.0 126 | self._cum_pnl = 0.0 127 | self._value_flow = 0.0 128 | self._net_cash_flow = 0.0 129 | self._cash_flow = 0.0 130 | 131 | @property 132 | def is_long(self) -> bool: 133 | """是否多头""" 134 | return self.Amount >= 0 135 | 136 | @property 137 | def Qty(self) -> float: 138 | """持仓数量""" 139 | if self.Amount >= 0: 140 | return self.Amount 141 | else: 142 | return -self.Amount 143 | 144 | @property 145 | def Value(self) -> float: 146 | """持仓市值。受last_prce影响""" 147 | if self.Amount == 0: 148 | return 0.0 149 | return self.LastPrice * self.Amount * self._mult 150 | 151 | @property 152 | def OpenValue(self) -> float: 153 | """开仓市值""" 154 | if self.Amount < 
0: 155 | return -self._open_value 156 | else: 157 | return self._open_value 158 | 159 | @property 160 | def Margin(self) -> float: 161 | """保证金占用""" 162 | return self._margin 163 | 164 | @property 165 | def UPnL(self) -> float: 166 | """持仓盈亏=持仓市值-开仓市值。受`last_prce`影响""" 167 | if self.Amount == 0: 168 | return 0.0 169 | return self.Value - self.OpenValue 170 | 171 | @property 172 | def PnL(self) -> float: 173 | """平仓盈亏(未减手续费)""" 174 | return self._pnl 175 | 176 | @property 177 | def CumPnL(self) -> float: 178 | """累计平仓盈亏(未减手续费)""" 179 | return self._cum_pnl 180 | 181 | @property 182 | def CumCommission(self) -> float: 183 | """累计手续费""" 184 | return self._cum_commission 185 | 186 | @property 187 | def CashFlow(self) -> float: 188 | """现金流=平仓盈亏-保证金-手续费""" 189 | return self._cash_flow 190 | 191 | @property 192 | def Equity(self) -> float: 193 | """持仓权益=保证金+浮动盈亏。受`last_price`影响""" 194 | return self.Margin + self.UPnL 195 | 196 | def calc_value(self, price: float, qty: float)->float: 197 | """计算市值""" 198 | return _value_with_mult(price, qty, self._mult) 199 | 200 | def calc_commission(self, is_buy: bool, is_open: bool, value: float, qty: float) -> float: 201 | """计算手续费""" 202 | return self._commission_fn(is_buy, is_open, value, qty, self._commission_ratio) 203 | 204 | def openable(self, cash: float, value: float, commission: float) -> bool: 205 | """可开手数 206 | 207 | 需考虑负价格情况 208 | 1. 原油、天然气、电价都出现过负价格 209 | 2. 套利合约的价差也可能是负数 210 | 211 | Parameters 212 | ---------- 213 | cash: float 214 | 分配的可用现金。注意:不是总现金。 215 | value: float 216 | 报单市值 217 | commission: float 218 | 手续费 219 | 220 | Returns 221 | ------- 222 | bool 223 | 是否成功 224 | 225 | """ 226 | # TODO: 负手续费的情况下是如何处理? 
227 | if cash < 0: 228 | return False 229 | 230 | if self._margin_ratio == 1.0: 231 | return (cash + commission) >= value 232 | else: 233 | return (cash + commission) >= value * self._margin_ratio 234 | 235 | def closable(self, is_long: bool) -> float: 236 | """可平手数 237 | 238 | Parameters 239 | ---------- 240 | is_long: bool 241 | 是否多头 242 | 243 | Returns 244 | ------- 245 | float 246 | 247 | """ 248 | if is_long: 249 | return self.LongQty 250 | else: 251 | return self.ShortQty 252 | 253 | def set_commission_fn(self, func=commission_0) -> None: 254 | self._commission_fn = func 255 | 256 | def setup(self, mult: float = 1.0, margin_ratio: float = 1.0, commission_ratio: float = 0.0) -> None: 257 | """配置资产信息 258 | 259 | Parameters 260 | ---------- 261 | mult: float 262 | 合约乘数 263 | margin_ratio: float 264 | 保证金率 265 | commission_ratio: float 266 | 手续费率 267 | 268 | """ 269 | self._mult = mult 270 | self._margin_ratio = margin_ratio 271 | self._commission_ratio = commission_ratio 272 | 273 | def settlement(self) -> float: 274 | """结算。结算后可能产生现金变动 275 | 276 | 1. 逆回购返利息 277 | 2. 分红 278 | 3. 
手续费减免 279 | """ 280 | return 0.0 281 | 282 | def update_last_price(self, last_price: float) -> None: 283 | """更新最新价。用于计算资金曲线 284 | 285 | Parameters 286 | ---------- 287 | last_price: float 288 | 289 | """ 290 | self.LastPrice = last_price 291 | 292 | def fill(self, is_buy: bool, is_open: bool, value: float, fill_price: float, qty: float, 293 | commission: float = 0.0) -> None: 294 | """通过成交回报,更新各字段 295 | 296 | Parameters 297 | ---------- 298 | is_buy: bool 299 | 是否买入 300 | is_open: bool 301 | 是否开仓。反手需要标记成平仓 302 | value: float 303 | 开仓市值 304 | fill_price: float 305 | 成交价 306 | qty: float 307 | 成交量 308 | commission: float 309 | 手续费 310 | 311 | """ 312 | self._net_cash_flow = 0.0 313 | self._cash_flow = 0.0 314 | 315 | # 计算开仓市值,平均价。返回改变的持仓市值 316 | self._calculate(is_open, value, fill_price, qty, commission, self._mult) 317 | 318 | if is_buy: 319 | self.QtyBought += qty 320 | else: 321 | self.QtySold += qty 322 | 323 | # 此处需要更新正确的子持仓对像。如股票买入时只能更新昨仓对象,而买入时只能更新今仓对象 324 | # !!!注意: is_open与is_long是有区别的 325 | if is_open: 326 | if is_buy: 327 | self.LongQty += qty 328 | else: 329 | self.ShortQty += qty 330 | else: 331 | if is_buy: 332 | self.ShortQty -= qty 333 | else: 334 | self.LongQty -= qty 335 | 336 | # 新持仓量需做部分计算后再更新 337 | self.Amount = self.QtyBought - self.QtySold 338 | # 净现金流 339 | self._net_cash_flow = -self._value_flow * self._margin_ratio 340 | # 现金流。考虑了盈亏和手续费 341 | self._cash_flow = self._pnl + self._net_cash_flow - commission 342 | # 保证金占用 343 | self._margin -= self._net_cash_flow 344 | # 更新最新价,用于计算盈亏 345 | self.LastPrice = fill_price 346 | # 累计盈亏 347 | self._cum_pnl += self._pnl 348 | # 累计手续费 349 | self._cum_commission += self._commission 350 | 351 | # 这几个值出现接近0时调整成0 352 | # 有了这个调整后,回测速度反而加快 353 | if _is_zero(self.Amount): 354 | self.Amount = 0.0 355 | self._margin = 0.0 356 | 357 | def _calculate_pnl(self, is_long: bool, avg_price: float, fill_price: float, qty: float, mult: float) -> float: 358 | """根据每笔成交计算盈亏。只有平仓才会调用此函数""" 359 | value: float = fill_price - 
avg_price if is_long else avg_price - fill_price 360 | return qty * value * mult 361 | 362 | def _calculate(self, is_open: bool, value: float, fill_price: float, qty: float, commission: float, 363 | mult: float) -> None: 364 | """更新开仓市值和平均价。需用到合约乘数。 365 | 内部函数,不检查合法性。检查提前,有利于加快速度""" 366 | self._pnl = 0.0 367 | self._value_flow = 0.0 368 | self._commission = 0.0 369 | 370 | # 当前空仓 371 | if self.Amount == 0.0: 372 | self._value_flow = value # _value_with_mult(fill_price, qty, mult) 373 | self._open_value = self._value_flow 374 | self.AvgPx = fill_price 375 | self._commission = commission 376 | return 377 | 378 | # 开仓。已经到这只能是加仓 379 | if is_open: 380 | # 开仓改变的市值流与外部计算结果一样 381 | self._value_flow = value # _value_with_mult(fill_price, qty, mult) 382 | self._open_value += self._value_flow 383 | self.AvgPx = _avg_with_mult(self._open_value, self.Qty + qty, mult) 384 | self._commission = commission 385 | return 386 | 387 | # 平仓 388 | self._pnl = self._calculate_pnl(self.is_long, self.AvgPx, fill_price, qty, mult) 389 | 390 | if _is_zero(self.Qty - qty): 391 | # 清仓,市值流正好是之前持仓市值 392 | self._value_flow = -self._open_value 393 | self._open_value = 0.0 394 | self.AvgPx = 0.0 395 | self._commission = commission 396 | return 397 | elif self.Qty > qty: 398 | # 减仓 399 | self._value_flow = -_value_with_mult(self.AvgPx, qty, mult) 400 | self._open_value += self._value_flow 401 | self._commission = commission 402 | return 403 | 404 | # !!! 
为简化外部操作。对于反手情况也支持,但is_open=False 405 | 406 | # 反手。平仓后开仓 407 | num: float = qty - self.Qty 408 | old_frozen_value: float = self._open_value 409 | new_frozen_value: float = _value_with_mult(fill_price, num, mult) 410 | self._open_value = new_frozen_value 411 | self.AvgPx = fill_price 412 | self._value_flow = new_frozen_value - old_frozen_value 413 | self._commission = commission 414 | return 415 | 416 | def to_record_position(self, rec: np.ndarray, detail: bool) -> np.ndarray: 417 | """持仓对象转持仓记录 418 | 419 | Parameters 420 | ---------- 421 | rec: np.ndarray 422 | detail: bool 423 | 424 | Returns 425 | ------- 426 | np.ndarray 427 | 428 | """ 429 | rec['mult'] = self._mult 430 | rec['margin_ratio'] = self._margin_ratio 431 | rec['amount'] = self.Amount 432 | 433 | if detail: 434 | rec['upnl'] = self.UPnL 435 | rec['value'] = self.Value 436 | rec['open_value'] = self.OpenValue 437 | rec['avg_price'] = self.AvgPx 438 | rec['last_price'] = self.LastPrice 439 | rec['margin'] = self._margin 440 | rec['asset'] = self.Asset 441 | rec['cum_pnl'] = self._cum_pnl 442 | rec['cum_commission'] = self._cum_commission 443 | 444 | return rec 445 | 446 | def to_record_trade(self, rec: np.ndarray, 447 | date: np.uint64, is_buy: bool, is_open: bool, fill_price: float, qty: float, 448 | cash: float) -> np.ndarray: 449 | """订单对象转订单记录""" 450 | rec['asset'] = self.Asset 451 | rec['amount'] = self.Amount 452 | rec['margin'] = self._margin 453 | rec['commission'] = self._commission 454 | rec['upnl'] = self.UPnL # 最新价会导致此项发生变化 455 | rec['pnl'] = self._pnl 456 | rec['cash_flow'] = self._cash_flow 457 | 458 | rec['date'] = date 459 | rec['is_buy'] = is_buy 460 | rec['is_open'] = is_open 461 | rec['fill_price'] = fill_price 462 | rec['qty'] = qty 463 | rec['cash'] = cash 464 | 465 | return rec 466 | 467 | def to_record_performance(self, rec: np.ndarray, date: np.uint64, cash: float) -> np.ndarray: 468 | """转绩效""" 469 | rec['date'] = date 470 | rec['cash'] = cash 471 | rec['asset'] = self.Asset 472 
| rec['amount'] = self.Amount 473 | rec['value'] = self.Value 474 | rec['margin'] = self._margin 475 | rec['upnl'] = self.UPnL 476 | # pnl与commission只记录了产生交易舜时的值,还需要累计值 477 | rec['cum_pnl'] = self._cum_pnl 478 | rec['cum_commission'] = self._cum_commission 479 | 480 | return rec 481 | 482 | def __str__(self) -> str: 483 | # f-string的实现方法比较特殊 484 | # https://github.com/numba/numba/issues/8969 485 | with objmode(string=types.unicode_type): # declare that the "escaping" string variable is of unicode type. 486 | string = f'Position(Asset={self.Asset}, Value={self.Value}, OpenValue={self.OpenValue}, Margin={self.Margin}, UPnL={self.UPnL}, PnL={self.PnL}, Amount={self.Amount})' 487 | return string 488 | 489 | 490 | # 这种写法是为了方便开关断点调试 491 | if os.environ.get('NUMBA_DISABLE_JIT', '0') != '1': 492 | commission_fn_type = typeof(commission_0) 493 | 494 | Position = jitclass(Position, 495 | [('_commission_fn', commission_fn_type)]) 496 | -------------------------------------------------------------------------------- /lightbt/signals.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from numba import njit, float64, bool_ 4 | 5 | 6 | def state_to_action(state: pd.DataFrame): 7 | """将持仓状态转换成下单操作。 8 | 9 | 1. 开头出现nan 10 | 2. 
中段出现nan 11 | 12 | Parameters 13 | ---------- 14 | state: pd.DataFrame 15 | 持仓状态宽表。值应当都是整数,浮点和nan都是不合法数据 16 | 17 | Returns 18 | ------- 19 | pd.DataFrame 20 | 下单操作 21 | 22 | """ 23 | state = state.fillna(0) # 防止nan计算有问题 24 | action = state.diff() 25 | action.iloc[0] = state.iloc[0] 26 | return action 27 | 28 | 29 | def action_to_state(action: pd.DataFrame): 30 | """将操作转换成状态 31 | 32 | Parameters 33 | ---------- 34 | action: pd.DataFrame 35 | 下单操作宽表 36 | 37 | Examples 38 | -------- 39 | s = pd.DataFrame({ 40 | 'a': [1, 1, 2, 0, -1, 0], 41 | 'b': [np.nan, np.nan, 1, 0, 0, 0], 42 | }) 43 | print(s) 44 | a = state_to_action(s) 45 | print(a) 46 | s = action_to_state(a) 47 | print(s) 48 | 49 | """ 50 | action = action.fillna(0) 51 | return action.cumsum() 52 | 53 | 54 | @njit(float64[:](bool_[:], bool_[:], bool_[:], bool_[:], bool_, bool_), fastmath=True, nogil=True, cache=True) 55 | def signals_to_amount(is_long_entry: np.ndarray, is_long_exit: np.ndarray, 56 | is_short_entry: np.ndarray, is_short_exit: np.ndarray, 57 | accumulate: bool = False, 58 | action: bool = False) -> np.ndarray: 59 | """将4路信号转换成持仓状态。适合按资产分组后的长表 60 | 61 | 在`LongOnly`场景下,`is_short_entry`和`is_short_exit`输入数据值都为`False`即可 62 | 63 | Parameters 64 | ---------- 65 | is_long_entry: np.ndarray 66 | 是否多头入场 67 | is_long_exit: np.ndarray 68 | 是否多头出场 69 | is_short_entry: np.ndarray 70 | 是否空头入场 71 | is_short_exit: np.ndarray 72 | 是否空头出场 73 | accumulate: bool 74 | 遇到重复信号时是否累计 75 | action: bool 76 | 返回持仓状态还是下单操作 77 | 78 | Returns 79 | ------- 80 | np.ndarray 81 | 持仓状态 82 | 83 | Examples 84 | -------- 85 | ```python 86 | long_entry = np.array([True, True, False, False, False]) 87 | long_exit = np.array([False, False, True, False, False]) 88 | short_entry = np.array([False, False, True, False, False]) 89 | short_exit = np.array([False, False, False, True, False]) 90 | 91 | amount = signals_to_amount(long_entry, long_exit, short_entry, short_exit, accumulate=True, action=False) 92 | ``` 93 | 94 | """ 95 | _amount: 
float = 0.0 # 持仓状态 96 | _action: float = 0.0 # 下单方向 97 | out_amount = np.zeros(len(is_long_entry), dtype=float) 98 | out_action = np.zeros(len(is_long_entry), dtype=float) 99 | for i in range(len(is_long_entry)): 100 | if _amount == 0.0: 101 | # 多头信号优先级高于空头信号 102 | if is_long_entry[i]: 103 | _amount += 1.0 104 | _action = 1.0 105 | elif is_short_entry[i]: 106 | _amount -= 1.0 107 | _action = -1.0 108 | elif _amount > 0.0: 109 | if is_long_exit[i]: 110 | _amount -= 1.0 111 | _action = -1.0 112 | elif is_long_entry[i] and accumulate: 113 | _amount += 1.0 114 | _action = 1.0 115 | else: 116 | if is_short_exit[i]: 117 | _amount += 1.0 118 | _action = 1.0 119 | elif is_short_entry[i] and accumulate: 120 | _amount -= 1.0 121 | _action = -1.0 122 | out_amount[i] = _amount 123 | out_action[i] = _action 124 | 125 | if action: 126 | return out_action 127 | else: 128 | return out_amount 129 | 130 | 131 | def orders_daily(df: pd.DataFrame, sort: bool = True) -> pd.DataFrame: 132 | """ 133 | 134 | Parameters 135 | ---------- 136 | df 137 | sort: bool 138 | 默认按时间、资产名进行排序 139 | 140 | Returns 141 | ------- 142 | pd.DataFrame 143 | 1. 已经按时间进行了排序。sort_values 144 | 2. 添加了日期标记,用于触发内部的绩效快照 145 | 146 | Notes 147 | ----- 148 | 有多处修改了数据,所以需要`copy`。`sort_values`隐含了`copy` 149 | 150 | """ 151 | # 全体数据排序,并复制 152 | if sort: 153 | df = df.sort_values(by=['date', 'asset']) 154 | 155 | # 按日期标记,每段的最后一条标记为True。一定要提前排序 156 | date_0 = df['date'].dt.date 157 | df['date_diff'] = date_0 != date_0.shift(-1, fill_value=0) 158 | 159 | return df 160 | 161 | 162 | def orders_weekly(df: pd.DataFrame, sort: bool = True) -> pd.DataFrame: 163 | """ 164 | 165 | Parameters 166 | ---------- 167 | df 168 | sort: bool 169 | 默认按时间、资产名进行排序 170 | 171 | Returns 172 | ------- 173 | pd.DataFrame 174 | 1. 已经按时间进行了排序。sort_values 175 | 2. 
添加了日期标记,用于触发内部的绩效快照 176 | 177 | Notes 178 | ----- 179 | 有多处修改了数据,所以需要`copy`。`sort_values`隐含了`copy` 180 | 181 | 一定得是每周只交易一次的清单,如果是每天都交易的清单输入,会将5张单子一起,先平后开,导入顺序混乱 182 | 183 | """ 184 | # 全体数据排序,并复制 185 | if sort: 186 | df = df.sort_values(by=['date', 'asset']) 187 | 188 | # 按日期标记,每段的最后一条标记为True。一定要提前排序 189 | date_0 = df['date'].dt.isocalendar().week 190 | # week的范围是1~52,所以可以fill_value写0 或>52的值 191 | df['date_diff'] = date_0 != date_0.shift(-1, fill_value=99) 192 | 193 | return df 194 | -------------------------------------------------------------------------------- /lightbt/stats.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Union 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from lightbt.enums import trades_stats_dt, roundtrip_stats_dt, roundtrip_dt 7 | from lightbt.utils import groupby 8 | 9 | 10 | def total_equity(perf: pd.DataFrame) -> pd.DataFrame: 11 | """总权益统计。 12 | 13 | Parameters 14 | ---------- 15 | perf: pd.DataFrame 16 | 输入为`bt.performances()`的输出 17 | 18 | Returns 19 | ------- 20 | pd.DataFrame 21 | 多加了总权益 22 | 23 | Examples 24 | -------- 25 | >>> pd.options.plotting.backend = 'plotly' 26 | >>> perf = bt.performances() 27 | >>> equity = total_equity(perf) 28 | >>> equity.plot() 29 | 30 | """ 31 | # 此处的amount返回的是净持仓数量 32 | agg = {'cash': 'last', 'value': 'sum', 'margin': 'sum', 'upnl': 'sum', 'cum_pnl': 'sum', 'cum_commission': 'sum', 'amount': 'sum'} 33 | p = perf.set_index(['date', 'asset']).groupby(by=['date']).agg(agg) 34 | # 总权益曲线。cash中已经包含了pnls和commissions 35 | p['equity'] = p['cash'] + p['margin'] + p['upnl'] 36 | return p 37 | 38 | 39 | def pnl_by_asset(perf, asset: Union[int, str], close: pd.DataFrame, 40 | mapping_asset_int: Dict[str, int], 41 | mapping_int_asset: Dict[int, str]) -> pd.DataFrame: 42 | """单资产的盈亏信息 43 | 44 | Parameters 45 | ---------- 46 | perf: pd.DataFrame 47 | 输入为`bt.performances()`的输出 48 | asset: int or str 49 | 资产id 50 | close: pd.DataFrame 
51 | 行情 52 | mapping_asset_int: dict 53 | 资产 字符串转数字 54 | mapping_int_asset: dict 55 | 资产 数字转字符串 56 | 57 | Returns 58 | ------- 59 | pd.DataFrame 60 | 多加了盈亏曲线 61 | 62 | Examples 63 | -------- 64 | >>> pd.options.plotting.backend = 'plotly' 65 | >>> pnls = pnls = pnl_by_asset(perf, 's_0000', df[['date', 'asset', 'CLOSE']], bt.mapping_asset_int, bt.mapping_int_asset) 66 | >>> pnls.plot() 67 | >>> pd.options.plotting.backend = 'matplotlib' 68 | >>> pnls[['PnL', 'CLOSE']].plot(secondary_y='CLOSE') 69 | 70 | """ 71 | if pd.api.types.is_string_dtype(perf['asset']): 72 | if isinstance(asset, int): 73 | asset = mapping_int_asset.get(asset) 74 | elif pd.api.types.is_integer_dtype(perf['asset']): 75 | if isinstance(asset, str): 76 | asset = mapping_asset_int.get(asset) 77 | 78 | df1 = perf[perf['asset'] == asset] 79 | 80 | if close is None: 81 | df = df1 82 | agg = {'value': 'sum', 'margin': 'sum', 'upnl': 'sum', 'cum_pnl': 'sum', 'cum_commission': 'sum', 'amount': 'sum'} 83 | else: 84 | if pd.api.types.is_string_dtype(close['asset']): 85 | if isinstance(asset, int): 86 | close = close.copy() 87 | close['asset'] = close['asset'].map(mapping_asset_int) 88 | elif pd.api.types.is_integer_dtype(close['asset']): 89 | if isinstance(asset, str): 90 | close = close.copy() 91 | close['asset'] = close['asset'].map(mapping_int_asset) 92 | 93 | df2 = close[close['asset'] == asset] 94 | df = pd.merge(left=df1, right=df2, left_on=['date', 'asset'], right_on=['date', 'asset']) 95 | agg = {'value': 'sum', 'margin': 'sum', 'upnl': 'sum', 'cum_pnl': 'sum', 'cum_commission': 'sum', 'amount': 'sum', close.columns[-1]: 'last'} 96 | 97 | p = df.set_index(['date', 'asset']).groupby(by=['date']).agg(agg) 98 | # 盈亏曲线=持仓盈亏+累计盈亏+累计手续费 99 | p['PnL'] = p['upnl'] + p['cum_pnl'] - p['cum_commission'] 100 | return p 101 | 102 | 103 | def pnl_by_assets(perf: pd.DataFrame, 104 | assets: Union[List[str], List[int]], 105 | mapping_asset_int: Dict[str, int], 106 | mapping_int_asset: Dict[int, str]) -> 
pd.DataFrame: 107 | """多个资产的盈亏曲线 108 | 109 | Parameters 110 | ---------- 111 | perf: pd.DataFrame 112 | 输入为`bt.performances()`的输出 113 | assets: list[int] or list[str] 114 | 关注的资产列表 115 | mapping_asset_int: dict 116 | 资产 字符串转数字 117 | mapping_int_asset: dict 118 | 资产 数字转字符串 119 | 120 | Returns 121 | ------- 122 | pd.DataFrame 123 | 多资产盈亏曲线矩阵 124 | 125 | Examples 126 | -------- 127 | >>> perf = bt.performances() 128 | >>> pnls = pnl_by_assets(perf, ['s_0000', 's_0100', 's_0300'], bt.mapping_asset_int, bt.mapping_int_asset) 129 | >>> pnls.plot() 130 | 131 | """ 132 | if pd.api.types.is_string_dtype(perf['asset']): 133 | if isinstance(assets[0], int): 134 | assets = list(map(mapping_int_asset.get, assets)) 135 | elif pd.api.types.is_integer_dtype(perf['asset']): 136 | if isinstance(assets[0], str): 137 | assets = list(map(mapping_asset_int.get, assets)) 138 | 139 | # 单资产的盈亏曲线 140 | df = perf[perf['asset'].isin(assets)] 141 | df = df.set_index(['date', 'asset']) 142 | df['PnL'] = df['upnl'] + df['cum_pnl'] - df['cum_commission'] 143 | return df['PnL'].unstack().ffill() 144 | 145 | 146 | def trades_to_roundtrips(trades: np.ndarray, asset_count: int) -> np.ndarray: 147 | """多笔成交转为成对的交易轮 148 | 149 | Parameters 150 | ---------- 151 | trades: np.ndarray 152 | 全体成交记录 153 | asset_count: int 154 | 总资产数。用于分配足够的空间用于返回 155 | 156 | Returns 157 | ------- 158 | np.ndarray 159 | 160 | """ 161 | trades = trades[trades['asset'].argsort(kind='stable')] 162 | groups = groupby(trades, by='asset', dtype=None) 163 | 164 | records = np.zeros(len(trades) // 2 + 1 + asset_count, dtype=roundtrip_dt) 165 | k = 0 # 目标位置 166 | for group in groups: 167 | # 每段开始位置 168 | flag = np.ones(shape=len(group) + 1, dtype=bool) 169 | flag[1:] = group['amount'] == 0.0 170 | flag[-1] = True 171 | idx = np.argwhere(flag).flatten() 172 | for i, j in zip(idx[:-1], idx[1:]): 173 | g = group[i:j] 174 | rec = records[k] 175 | 176 | is_open = g[g['is_open']] 177 | is_close = g[~g['is_open']] 178 | 179 | rec['asset'] = 
g['asset'][0] 180 | rec['is_long'] = g['is_buy'][0] 181 | rec['is_close'] = g['amount'][-1] == 0.0 182 | rec['qty'] = np.sum(is_open['qty']) 183 | rec['entry_date'] = g['date'][0] 184 | rec['entry_price'] = np.mean(is_open['fill_price']) 185 | rec['entry_commission'] = np.sum(is_open['commission']) 186 | rec['exit_date'] = g['date'][-1] 187 | if len(is_close) > 0: 188 | rec['exit_price'] = np.mean(is_close['fill_price']) 189 | rec['exit_commission'] = np.sum(is_close['commission']) 190 | rec['pnl'] = np.sum(g['pnl']) 191 | rec['pnl_com'] = rec['pnl'] - rec['entry_commission'] - rec['exit_commission'] 192 | 193 | k += 1 194 | return records[:k] 195 | 196 | 197 | def calc_roundtrips_stats(roundtrips: np.ndarray, asset_count: int) -> np.ndarray: 198 | """每轮交易统计 199 | 200 | Parameters 201 | ---------- 202 | roundtrips: np.ndarray 203 | 全体每轮交易 204 | asset_count: int 205 | 总资产数。用于分配足够的空间用于返回 206 | 207 | Returns 208 | ------- 209 | np.ndarray 210 | 211 | """ 212 | roundtrips = roundtrips[roundtrips['asset'].argsort(kind='stable')] 213 | groups = groupby(roundtrips, by='asset', dtype=None) 214 | 215 | records = np.zeros(asset_count, dtype=roundtrip_stats_dt) 216 | i = 0 217 | for i, g in enumerate(groups): 218 | rec = records[i] 219 | 220 | is_long = g[g['is_long']] 221 | is_short = g[~g['is_long']] 222 | winning = g[g['pnl_com'] > 0.0] 223 | losing = g[g['pnl_com'] < 0.0] 224 | 225 | rec['asset'] = g['asset'][0] 226 | rec['total_count'] = len(g) 227 | rec['long_count'] = len(is_long) 228 | rec['short_count'] = len(is_short) 229 | rec['long_count'] = len(is_long) 230 | rec['short_count'] = len(is_short) 231 | rec['winning_count'] = len(winning) 232 | rec['losing_count'] = len(losing) 233 | rec['win_rate'] = rec['winning_count'] / rec['total_count'] 234 | 235 | return records[:i + 1] 236 | 237 | 238 | def calc_trades_stats(trades: np.ndarray, asset_count: int) -> np.ndarray: 239 | """成交统计 240 | 241 | Parameters 242 | ---------- 243 | trades: np.ndarray 244 | 全体交易记录 245 | 
asset_count: int 246 | 总资产数。用于分配足够的空间用于返回 247 | 248 | Returns 249 | ------- 250 | np.ndarray 251 | 252 | """ 253 | trades = trades[trades['asset'].argsort(kind='stable')] # stable一定要有,否则乱序 254 | groups = groupby(trades, by='asset', dtype=None) 255 | 256 | records = np.zeros(asset_count, dtype=trades_stats_dt) 257 | i = 0 258 | for i, g in enumerate(groups): 259 | rec = records[i] 260 | 261 | is_buy = g[g['is_buy']] 262 | is_sell = g[~g['is_buy']] 263 | 264 | rec['asset'] = g['asset'][0] 265 | rec['start'] = g['date'][0] 266 | rec['end'] = g['date'][-1] 267 | rec['period'] = rec['end'] - rec['start'] 268 | rec['total_count'] = len(g) 269 | rec['buy_count'] = len(is_buy) 270 | rec['sell_count'] = len(is_sell) 271 | rec['min_qty'] = np.min(g['qty']) 272 | rec['max_qty'] = np.max(g['qty']) 273 | rec['avg_qty'] = np.mean(g['qty']) 274 | rec['total_commission'] = np.sum(g['commission']) 275 | rec['min_commission'] = np.min(g['commission']) 276 | rec['max_commission'] = np.max(g['commission']) 277 | rec['avg_commission'] = np.mean(g['commission']) 278 | 279 | if len(is_buy) > 0: 280 | rec['avg_buy_qty'] = np.mean(is_buy['qty']) 281 | rec['avg_buy_price'] = np.mean(is_buy['fill_price']) 282 | rec['avg_buy_commission'] = np.mean(is_buy['commission']) 283 | if len(is_sell) > 0: 284 | rec['avg_sell_qty'] = np.mean(is_sell['qty']) 285 | rec['avg_sell_price'] = np.mean(is_sell['fill_price']) 286 | rec['avg_sell_commission'] = np.mean(is_sell['commission']) 287 | 288 | return records[:i + 1] 289 | -------------------------------------------------------------------------------- /lightbt/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | def timer(func): 9 | def wrapper(*args, **kwargs): 10 | start_time = time.perf_counter() 11 | result = func(*args, **kwargs) 12 | end_time = time.perf_counter() 13 | print(f"{func.__name__} executed 
in {end_time - start_time} seconds") 14 | return result 15 | 16 | return wrapper 17 | 18 | 19 | class Timer: 20 | def __init__(self): 21 | self.start_time = None 22 | 23 | def __enter__(self): 24 | self.start_time = time.perf_counter() 25 | return self 26 | 27 | def __exit__(self, exc_type, exc_val, exc_tb): 28 | end_time = time.perf_counter() 29 | print(f"code executed in {end_time - self.start_time} seconds") 30 | 31 | 32 | def get_dtypes(dtype): 33 | return {a: b for a, b in dtype.descr} 34 | 35 | 36 | def groupby(df: pd.DataFrame, by: str, dtype: Optional[np.dtype] = None) -> np.ndarray: 37 | """简版数据分组 38 | 39 | Parameters 40 | ---------- 41 | df: pd.DataFrame 42 | by: str 43 | dtype: np.dtype 44 | 指定类型能提速 45 | 46 | Returns 47 | ------- 48 | np.ndarray 49 | 迭代器 50 | 51 | Notes 52 | ----- 53 | `df`一定要提前按`by`排序,否则结果是错的 54 | 55 | """ 56 | if dtype is not None: 57 | # 控制同样的位置。否则record转dtype失败会导致效率低 58 | df = df[list(dtype.names)] 59 | 60 | if isinstance(df, pd.DataFrame): 61 | # recarray转np.ndarray 62 | arr = df.to_records(index=False, column_dtypes=get_dtypes(dtype)) 63 | 64 | # 这里支持复合分组 65 | idx = df.groupby(by=by)['asset'].count().cumsum().to_numpy() 66 | idx = np.insert(idx, 0, 0) 67 | else: 68 | # 原数据是否需要复制?从代码上看没有复制之处 69 | arr = df # .copy() 70 | 71 | dt = arr[by] 72 | flag = np.ones(shape=len(dt) + 1, dtype=bool) 73 | # 前后都为True 74 | flag[1:-1] = dt[:-1] != dt[1:] 75 | idx = np.argwhere(flag).flatten() 76 | 77 | for i, j in zip(idx[:-1], idx[1:]): 78 | # 由于标记的位置正好是每段的开始位置,所以j不需加1 79 | yield arr[i:j] 80 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "lightbt" 3 | authors = [ 4 | { name = "wukan", email = "wu-kan@163.com" }, 5 | ] 6 | description = "lightweight backtester" 7 | readme = "README.md" 8 | requires-python = ">=3.7" 9 | keywords = ["backtest"] 10 | license = { text = "BSD-3-Clause" } 11 | 
classifiers = [ 12 | "Development Status :: 4 - Beta", 13 | "Programming Language :: Python" 14 | ] 15 | dependencies = [ 16 | "numba>=0.57.1", 17 | ] 18 | dynamic = ["version"] 19 | 20 | [build-system] 21 | requires = ["hatchling"] 22 | build-backend = "hatchling.build" 23 | 24 | [tool.hatch.version] 25 | path = "lightbt/_version.py" 26 | 27 | [tool.hatch.build.targets.wheel] 28 | packages = ["lightbt"] 29 | include-package-data = true 30 | 31 | [tool.hatch.build.targets.sdist] 32 | include = ["lightbt*"] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | numba -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | --------------------------------------------------------------------------------