├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── codes ├── demo_exec.py └── prepare_data.py ├── data └── prepare_data.py ├── deprecated └── polars_group │ ├── __init__.py │ ├── code.py │ ├── printer.py │ └── template.py.j2 ├── examples ├── demo_exec_pl.py ├── demo_express.py ├── demo_min.py ├── demo_sql.py ├── demo_tdx.py ├── demo_transformer.py ├── output_alpha101.py ├── output_pandas.py ├── output_polars.py ├── prefilter.py ├── show_tree.py ├── sympy_define.py └── tail_n.py ├── expr_codegen ├── __init__.py ├── _version.py ├── codes.py ├── dag.py ├── expr.py ├── latex │ ├── __init__.py │ └── printer.py ├── model.py ├── pandas │ ├── __init__.py │ ├── code.py │ ├── helper.py │ ├── printer.py │ ├── ta.py │ └── template.py.j2 ├── polars │ ├── __init__.py │ ├── code.py │ ├── printer.py │ └── template.py.j2 ├── sql │ ├── __init__.py │ ├── code.py │ ├── printer.py │ └── template.sql.j2 └── tool.py ├── pyproject.toml ├── requirements.txt ├── streamlit_app.py └── tests ├── expr_order.py ├── formula_transformer.py ├── speed_pandas.py ├── speed_pandas2.py ├── speed_polars.py └── speed_polars2.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, wukan 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # expr_codegen 表达式转译器 2 | 3 | ## 项目背景 4 | 5 | 在本人新推出[polars_ta](https://github.com/wukan1986/polars_ta)这个库后,再回头反思`expr_codegen`是什么。 6 | 7 | > `expr_codegen`本质是`DSL`,领域特定语⾔(Domain Specific Language)。但它没有定义新的语法 8 | 9 | 它解决了两个问题: 10 | 11 | 1. `polars_ta`已经能很方便的写出特征计算表达式,但遇到`混用时序与截面`的表达式,利用`expr_codegen`能自动分组大大节省工作 12 | 2. `expr_codegen`利用了`Common Subexpression Elimination`公共子表达式消除,大量减少重复计算,提高效率 13 | 14 | 就算在量化领域,初级研究员局限于时序指标,仅用`polars_ta`即可,中高级研究员使用截面指标,推荐用`expr_codegen` 15 | 16 | 虽然现在此项目与`polars_ta`依赖非常紧密,但也是支持翻译成其它库,如`pandas / cudf.pandas`,只是目前缺乏一个比较简易的库 17 | 18 | ## 在线演示 19 | 20 | https://exprcodegen.streamlit.app 21 | 22 | 初级用户可以直接访问此链接进行表达式转译,不需要另外安装软件。(此工具免费部署在国外,打开可能有些慢) 23 | 24 | 更完整示例访问[alpha_examples](https://github.com/wukan1986/alpha_examples) 25 | 26 | ## 使用示例 27 | 28 | ```python 29 | import sys 30 | from io import StringIO 31 | 32 | import polars as pl 33 | 34 | from expr_codegen import codegen_exec 35 | 36 | 37 | def _code_block_1(): 38 | # 因子编辑区,可利用IDE的智能提示在此区域编辑因子 39 | LOG_MC_ZS = cs_mad_zscore(log1p(market_cap)) 40 | 41 | 42 | def _code_block_2(): 43 | # 模板中已经默认导入了from polars_ta.prefix下大量的算子,但 44 | # talib在模板中没有默认导入。这种写法可实现在生成的代码中导入 45 | from polars_ta.prefix.talib import ts_LINEARREG_SLOPE # noqa 46 | 47 | # 1. 下划线开头的变量只是中间变量,会被自动更名,最终输出时会被剔除 48 | # 2. 下划线开头的变量可以重复使用。多个复杂因子多行书写时有重复中间变时不再冲突 49 | _avg = ts_mean(corr, 20) 50 | _std = ts_std_dev(corr, 20) 51 | _beta = ts_LINEARREG_SLOPE(corr, 20) 52 | 53 | # 3. 
    # 下划线开头的变量可以循环赋值。在调试时可快速用注释进行切换
    _avg = cs_mad_zscore_resid(_avg, LOG_MC_ZS, ONE)
    _std = cs_mad_zscore_resid(_std, LOG_MC_ZS, ONE)
    # _beta = cs_mad_zscore_resid(_beta, LOG_MC_ZS, ONE)

    _corr = cs_zscore(_avg) + cs_zscore(_std)
    CPV = cs_zscore(_corr) + cs_zscore(_beta)


code = codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by', output_file=sys.stdout)  # 打印代码
code = codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by', output_file="output.py")  # 保存到文件
code = codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by')  # 只执行,不保存代码

code = StringIO()
codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by', output_file=code)  # 保存到字符串
code.seek(0)
code.read()  # 读取代码

# TODO 替换成合适的数据
df = pl.DataFrame()
df = codegen_exec(df.lazy(), _code_block_1, _code_block_2, over_null='partition_by').collect()  # Lazy CPU
df = codegen_exec(df.lazy(), _code_block_1, _code_block_2, over_null='partition_by').collect(engine="gpu")  # Lazy GPU

```

## 目录结构

```commandline
│  requirements.txt # 通过`pip install -r requirements.txt`安装依赖
├─data
│      prepare_data.py # 准备数据
├─examples
│      demo_express.py # 速成示例。演示如何将表达式转换成代码
│      demo_exec_pl.py # 演示调用转换后代码并绘图
│      demo_transformer.py # 演示将第三方表达式转成内部表达式
│      output.py # 结果输出。可不修改代码,直接被其它项目导入
│      show_tree.py # 画表达式树形图。可用于分析对比优化结果
│      sympy_define.py # 符号定义,由于太多地方重复使用到,所以统一提取到此处
├─expr_codegen
│  │  expr.py # 表达式处理基本函数
│  │  tool.py # 核心工具代码
│  ├─polars
│  │  │  code.py # 针对polars语法的代码生成功能
│  │  │  template.py.j2 # `Jinja2`模板。用于生成对应py文件,一般不需修改
│  │  │  printer.py # 继承于`Sympy`中的`StrPrinter`,添加新函数时可能需修改此文件
```

## 工作原理

本项目依赖于`sympy`项目。所用到的主要函数如下:

1. `simplify`: 对复杂表达式进行化简
2. `cse`: `Common Subexpression Elimination`公共子表达式消除
3. `StrPrinter`: 根据不同的函数输出不同字符串。定制此代码可以支持其它语种或库

因为`groupby`、`sort`都比较占用时间,如果提前将公式分类,不同的类别使用不同的`groupby`,可以减少计算时间。

1. `ts_xxx(ts_xxx)`: 可在同一`groupby`中进行计算
2. `cs_xxx(cs_xxx)`: 可在同一`groupby`中进行计算
3. `ts_xxx(cs_xxx)`: 需在不同`groupby`中进行计算(拆分示意见本节末尾的代码)
4. `cs_xxx(ts_xxx(cs_xxx))`: 需在三个不同的`groupby`中进行计算
5. `gp_xxx(aa, )+gp_xxx(bb, )`: 因`aa`、`bb`不同,需在两个不同的`groupby`中进行计算

所以

1. 需要有一个函数能获取当前表达式的类别(`get_current`)和子表达式的类别(`get_children`)
2. 如果当前类别与子类别不同就可以提取出短公式(`extract`)。不同层的同类别表达式有先后关系,不能放同一`groupby`
3. 利用`cse`的特点,将长表达式替换成前期提取出来的短表达式,然后输入到有向无环图(`DAG`)
4. 利用有向无环图的流转,进行分层。同一层的`ts`、`cs`、`gp`不区分先后
5. 同一层对`ts`、`cs`、`gp`分组,然后生成代码(`codegen`)即可

隐含信息

1. `ts`: sort(by=[ASSET, DATE]).groupby(by=[ASSET], maintain_order=True)
2. `cs`: sort(by=[DATE]).groupby(by=[DATE], maintain_order=False)
3. `gp`: sort(by=[DATE, GROUP]).groupby(by=[DATE, GROUP], maintain_order=False)

即

1. 时序函数隐藏了两个字段`ASSET, DATE`,横截面函数隐藏了一个字段`DATE`
2. 分组函数传入了一个字段`GROUP`,同时隐藏了一个字段`DATE`

两种分类方法

1. 根据算子前缀分类(`get_current_by_prefix`),限制算子必须以`ts_`、`cs_`、`gp_`开头
2. 根据算子全名分类(`get_current_by_name`),不再限制算子名。比如`cs_rank`可以叫`rank`
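下面用一个最小示意说明第3类`ts_xxx(cs_xxx)`为什么要拆成两次`groupby`:先按`date`做截面计算并落地成中间列,再按`asset`做时序计算。此代码只是手工模拟拆分思路,并非`expr_codegen`实际生成的代码;其中的列名、数据与窗口长度均为演示用假设。

```python
import polars as pl

df = pl.DataFrame({
    "date": ["2024-01-01", "2024-01-01", "2024-01-02", "2024-01-02"],
    "asset": ["A", "B", "A", "B"],
    "CLOSE": [10.0, 11.0, 10.5, 10.8],
})

# 目标:ts_mean(cs_rank(CLOSE), 2),需要两次分组
df = df.with_columns(
    # 第一次分组:cs_ 部分,按日期做截面计算,结果存入中间列 _x_0
    _x_0=pl.col("CLOSE").rank().over("date"),
).with_columns(
    # 第二次分组:ts_ 部分,按品种分组、组内按日期排序后做时序计算
    OUT=pl.col("_x_0").rolling_mean(2).over("asset", order_by="date"),
)
print(df)
```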
## Null处理

`null`是如何产生的?

1. 停牌导致。在计算前就直接过滤掉了,不会对后续计算产生影响。
2. 不同品种交易时段不同
3. 计算产生。`null`在数列两端不影响后续时序算子结果,但中间出现`null`会影响。例如: `if_else(close<2, None, close)`

https://github.com/pola-rs/polars/issues/12925#issuecomment-2552764629

非常棒的点子,总结下来有两种实现方式:

1. 将`null`分成一组,`not_null`分成另一组。要调用两次
2. 仅一组,但复合排序,将`null`排在前面,`not_null`排后面。只调用一次,略快一些

```python
X1 = (ts_returns(CLOSE, 3)).over(CLOSE.is_not_null(), _ASSET_, order_by=_DATE_),
X2 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=[CLOSE.is_not_null(), _DATE_]),
X3 = (ts_returns(CLOSE, 3)).over(_ASSET_, order_by=_DATE_),
```

第2种方式开头的`null`区域是否影响结果,由算子决定;特别是多列输入时,`null`区域中可能有数据

1. `over_null='partition_by'`。分到两个区域
2. `over_null='order_by'`。分到一个区域,`null`排在前面
3. `over_null=None`。不处理,直接调用,速度更快。如果确信不会在中间产生`null`,建议使用此参数

`codegen_exec(over_null='partition_by')`为全局使用`partition_by`。但遇到`ts_count_nulls`这类`null`函数就得使用`over_null=None`,所以本工具还新添了注释功能来指定单行表达式参数

1. `# --over_null partition_by`。单行`over_null='partition_by'`
2. `# --over_null=order_by`。单行`over_null='order_by'`
3. `# --over_null`。单行`over_null=None`
4. `# `。取`codegen_exec`参数传入的`over_null`值

注意:

1. `# --over_null`传参注释只能写在单行表达式的后面,不能独立成一行,否则会被忽略
2. `# --over_null # --over_null=order_by`多个`#`时,只取第一个有效
3. 只对最外层`ts`函数有效。如果`ts`函数不在外层,需要人工提炼。例如:
```python
X1 = cs_rank(ts_mean(CLOSE, 3))  # --over_null=order_by # 应用在cs_rank上,没有意义
X2 = ts_rank(ts_mean(CLOSE, 3), 5)  # --over_null=order_by # 本以为应用在ts_rank(ts_mean)上,但由于出现了公共ts_mean,其实是应用在ts_rank(_x_0)上
```

需写成

```python
_x_0 = ts_mean(CLOSE, 3)  # --over_null=order_by
X1 = cs_rank(_x_0)
X2 = ts_rank(_x_0, 5)
```
4. 由于很容易搞错,强烈建议生成`output_file`,检查生成的代码是否正确。

## `expr_codegen`局限性

1. `DAG`只能增加列无法删除。增加列时,遇到同名列会覆盖
2. 不支持`删除行`,但可以添加删除标记列,然后在外部删除行。删除行影响了所有列,不满足`DAG`
3. 不支持`重采样`,原理同不支持删除行,需在外部进行
4. 可以将`删除行`与`重采样`作为分割线,把一大块代码分成多个`DAG`串联。但复杂不易理解,所以最终没有实现

## 特别语法

1. 支持`C?T:F`三元表达式(仅可在字符串中使用),底层会先转成`C or True if( T )else F`,然后修正成`T if C else F`,最后转成`if_else(C,T,F)`。支持与`if else`混用
2. `(A
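上面第1条三元表达式的一个最小用法示意(`codegen_exec`也接受字符串形式的因子源码,参见`examples/demo_min.py`;下方的列名与数据均为演示用假设):

```python
import sys

import polars as pl

from expr_codegen import codegen_exec

df = pl.DataFrame({
    "date": ["2024-01-01", "2024-01-02"] * 2,
    "asset": ["A", "A", "B", "B"],
    "OPEN": [9.8, 10.6, 20.2, 18.8],
    "CLOSE": [10.0, 10.5, 20.0, 19.0],
})

# 字符串中的 C?T:F 最终会被转译成 if_else(C, T, F)
df = codegen_exec(df, "SIGNAL = CLOSE > OPEN ? 1 : -1", output_file=sys.stdout)
print(df)
```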
转译后的代码片段,详细代码请参考[Polars版](examples/output_polars.py)

```python
def func_0_ts__asset(df: pl.DataFrame) -> pl.DataFrame:
    df = df.sort(by=[_DATE_])
    # ========================================
    df = df.with_columns(
        _x_0=1 / ts_delay(OPEN, -1),
        LABEL_CC_1=(-CLOSE + ts_delay(CLOSE, -1)) / CLOSE,
    )
    # ========================================
    df = df.with_columns(
        LABEL_OO_1=_x_0 * ts_delay(OPEN, -2) - 1,
        LABEL_OO_2=_x_0 * ts_delay(OPEN, -3) - 1,
    )
    return df
```

转译后的代码片段,详细代码请参考[Pandas版](examples/output_pandas.py)

```python
def func_2_cs__date(df: pd.DataFrame) -> pd.DataFrame:
    # expr_4 = cs_rank(x_7)
    df["expr_4"] = (df["x_7"]).rank(pct=True)
    return df


def func_3_ts__asset__date(df: pd.DataFrame) -> pd.DataFrame:
    # expr_5 = -ts_corr(OPEN, CLOSE, 10)
    df["expr_5"] = -(df["OPEN"]).rolling(10).corr(df["CLOSE"])
    # expr_6 = ts_delta(OPEN, 10)
    df["expr_6"] = df["OPEN"].diff(10)
    return df

```

## 本地部署交互网页

只需运行`streamlit run streamlit_app.py`
--------------------------------------------------------------------------------
/codes/demo_exec.py:
--------------------------------------------------------------------------------
# this code is auto generated by the expr_codegen
# https://github.com/wukan1986/expr_codegen
# 此段代码由 expr_codegen 自动生成,欢迎提交 issue 或 pull request
import re

import numpy as np  # noqa
import pandas as pd  # noqa
import polars as pl  # noqa
import polars.selectors as cs  # noqa
from loguru import logger  # noqa

# ===================================
# 导入优先级,例如:ts_RSI在ta与talib中都出现了,优先使用ta
# 运行时,后导入覆盖前导入,但IDE智能提示是显示先导入的
_ = pl  # 只要之前出现了语句,之后的import位置不参与调整
# from polars_ta.prefix.talib import *  # noqa
from polars_ta.prefix.tdx import *  # noqa
from polars_ta.prefix.ta import *  # noqa
from polars_ta.prefix.wq import *  # noqa

# ===================================


_ = (
    "CLOSE",
    "移动平均_10",
)
(
    CLOSE,
    移动平均_10,
) = (pl.col(i) for i in _)

_ = (
    "移动平均_10",
    "移动平均_20",
    "MAMA_20",
)
(
    移动平均_10,
    移动平均_20,
    MAMA_20,
) = (pl.col(i) for i in _)

_DATE_ = "date"
_ASSET_ = "asset"


def func_0_ts__asset(df: pl.DataFrame) -> pl.DataFrame:
    df = df.sort(by=[_DATE_])
    # ========================================
    df = df.with_columns(
        移动平均_10=ts_mean(CLOSE, 10),
        移动平均_20=ts_mean(CLOSE, 20),
    )
    # ========================================
    df = df.with_columns(
        MAMA_20=ts_mean(移动平均_10, 20),
    )
    return df


"""
#========================================func_0_ts__asset
移动平均_10 = ts_mean(CLOSE, 10)
移动平均_20 = ts_mean(CLOSE, 20)
#========================================func_0_ts__asset
MAMA_20 = ts_mean(移动平均_10, 20)
"""

"""
移动平均_10 = ts_mean(CLOSE, 10)
移动平均_20 = ts_mean(CLOSE, 20)
MAMA_20 = ts_mean(移动平均_10, 20)
"""


def main(df: pl.DataFrame):
    # logger.info("start...")

    df = df.sort(by=[_DATE_, _ASSET_])
    df = df.group_by(by=[_ASSET_]).map_groups(func_0_ts__asset)

    # drop intermediate columns
    df = df.drop(columns=list(filter(lambda x: re.search(r"^_x_\d+", x), df.columns)))

    # shrink
    df = df.select(cs.all().shrink_dtype())
    df = df.shrink_to_fit()

    # logger.info('done')

    # save
    #
df.write_parquet('output.parquet', compression='zstd') 94 | 95 | return df 96 | 97 | 98 | if __name__ in ("__main__", "builtins"): 99 | # TODO: 数据加载或外部传入 100 | df_output = main(df_input) 101 | -------------------------------------------------------------------------------- /codes/prepare_data.py: -------------------------------------------------------------------------------- 1 | # this code is auto generated by the expr_codegen 2 | # https://github.com/wukan1986/expr_codegen 3 | # 此段代码由 expr_codegen 自动生成,欢迎提交 issue 或 pull request 4 | 5 | import numpy as np # noqa 6 | import pandas as pd # noqa 7 | import polars as pl # noqa 8 | import polars.selectors as cs # noqa 9 | from loguru import logger # noqa 10 | 11 | # =================================== 12 | # 导入优先级,例如:ts_RSI在ta与talib中都出现了,优先使用ta 13 | # 运行时,后导入覆盖前导入,但IDE智能提示是显示先导入的 14 | _ = 0 # 只要之前出现了语句,之后的import位置不参与调整 15 | # from polars_ta.prefix.talib import * # noqa 16 | from polars_ta.prefix.tdx import * # noqa 17 | from polars_ta.prefix.ta import * # noqa 18 | from polars_ta.prefix.wq import * # noqa 19 | from polars_ta.prefix.cdl import * # noqa 20 | 21 | # =================================== 22 | 23 | _ = ( 24 | "OPEN", 25 | "CLOSE", 26 | ) 27 | ( 28 | OPEN, 29 | CLOSE, 30 | ) = (pl.col(i) for i in _) 31 | 32 | _ = ( 33 | "_x_0", 34 | "RETURN_CC_1", 35 | "RETURN_OO_1", 36 | "RETURN_OO_2", 37 | "RETURN_OO_5", 38 | ) 39 | ( 40 | _x_0, 41 | RETURN_CC_1, 42 | RETURN_OO_1, 43 | RETURN_OO_2, 44 | RETURN_OO_5, 45 | ) = (pl.col(i) for i in _) 46 | 47 | _DATE_ = "date" 48 | _ASSET_ = "asset" 49 | 50 | 51 | def func_0_ts__asset(df: pl.DataFrame) -> pl.DataFrame: 52 | df = df.sort(_DATE_) 53 | # ======================================== 54 | df = df.with_columns( 55 | _x_0=1 / ts_delay(OPEN, -1), 56 | RETURN_CC_1=(-CLOSE + ts_delay(CLOSE, -1)) / CLOSE, 57 | ) 58 | # ======================================== 59 | df = df.with_columns( 60 | RETURN_OO_1=_x_0 * ts_delay(OPEN, -2) - 1, 61 | RETURN_OO_2=_x_0 * ts_delay(OPEN, -3) - 1, 62 | RETURN_OO_5=_x_0 * ts_delay(OPEN, -6) - 1, 63 | ) 64 | return df 65 | 66 | 67 | """ 68 | #========================================func_0_ts__asset 69 | _x_0 = 1/ts_delay(OPEN, -1) 70 | RETURN_CC_1 = (-CLOSE + ts_delay(CLOSE, -1))/CLOSE 71 | #========================================func_0_ts__asset 72 | RETURN_OO_1 = _x_0*ts_delay(OPEN, -2) - 1 73 | RETURN_OO_2 = _x_0*ts_delay(OPEN, -3) - 1 74 | RETURN_OO_5 = _x_0*ts_delay(OPEN, -6) - 1 75 | """ 76 | 77 | """ 78 | RETURN_OO_1 = ts_delay(OPEN, -2)/ts_delay(OPEN, -1) - 1 79 | RETURN_OO_2 = ts_delay(OPEN, -3)/ts_delay(OPEN, -1) - 1 80 | RETURN_OO_5 = ts_delay(OPEN, -6)/ts_delay(OPEN, -1) - 1 81 | RETURN_CC_1 = -1 + ts_delay(CLOSE, -1)/CLOSE 82 | """ 83 | 84 | 85 | def main(df: pl.DataFrame) -> pl.DataFrame: 86 | # logger.info("start...") 87 | 88 | df = df.sort(_ASSET_, _DATE_).group_by(_ASSET_).map_groups(func_0_ts__asset).drop(*["_x_0"]) 89 | 90 | # drop intermediate columns 91 | # df = df.select(pl.exclude(r'^_x_\d+$')) 92 | df = df.select(~cs.starts_with("_")) 93 | 94 | # shrink 95 | df = df.select(cs.all().shrink_dtype()) 96 | df = df.shrink_to_fit() 97 | 98 | # logger.info('done') 99 | 100 | # save 101 | # df.write_parquet('output.parquet') 102 | 103 | return df 104 | 105 | 106 | if __name__ in ("__main__", "builtins"): 107 | # TODO: 数据加载或外部传入 108 | df_output = main(df_input) -------------------------------------------------------------------------------- /data/prepare_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
准备测试数据,可用于遗传算法 3 | 4 | """ 5 | import numpy as np 6 | import pandas as pd 7 | import polars as pl 8 | 9 | _N = 250 * 10 10 | _K = 500 11 | 12 | asset = [f's_{i:04d}' for i in range(_K)] 13 | date = pd.date_range('2015-1-1', periods=_N) 14 | 15 | df = pd.DataFrame({ 16 | 'OPEN': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1), 17 | 'HIGH': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1), 18 | 'LOW': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1), 19 | 'CLOSE': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1), 20 | }, index=pd.MultiIndex.from_product([date, asset], names=['date', 'asset'])).reset_index() 21 | 22 | # 向脚本输入数据 23 | df = pl.from_pandas(df) 24 | 25 | """ 26 | RETURN_OO_1 = ts_delay(OPEN, -2) / ts_delay(OPEN, -1) - 1 27 | RETURN_OO_2 = ts_delay(OPEN, -3) / ts_delay(OPEN, -1) - 1 28 | RETURN_OO_5 = ts_delay(OPEN, -6) / ts_delay(OPEN, -1) - 1 29 | RETURN_CC_1 = ts_delay(CLOSE, -1) / CLOSE - 1 30 | """ 31 | from codes.prepare_data import main 32 | 33 | df = main(df) 34 | 35 | # save 36 | df.write_parquet('data.parquet', compression='zstd') 37 | -------------------------------------------------------------------------------- /deprecated/polars_group/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wukan1986/expr_codegen/faa577832438d2f4765c289a7d56c29e09e5cf3b/deprecated/polars_group/__init__.py -------------------------------------------------------------------------------- /deprecated/polars_group/code.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Sequence 3 | 4 | import jinja2 5 | from jinja2 import FileSystemLoader, TemplateNotFound 6 | 7 | from expr_codegen.expr import TS, CS, GP 8 | from expr_codegen.model import ListDictList 9 | from expr_codegen.polars_group.printer import PolarsStrPrinter 10 | 11 | 12 | def get_groupby_from_tuple(tup, func_name, drop_cols): 13 | """从传入的元组中生成分组运行代码""" 14 | prefix2, *_ = tup 15 | 16 | if prefix2 == TS: 17 | # 组内需要按时间进行排序,需要维持顺序 18 | prefix2, asset = tup 19 | return f'df = df.sort(_ASSET_, _DATE_).group_by(_ASSET_).map_groups({func_name}).drop(*{drop_cols})' 20 | if prefix2 == CS: 21 | prefix2, date = tup 22 | return f'df = df.sort(_DATE_).group_by(_DATE_).map_groups({func_name}).drop(*{drop_cols})' 23 | if prefix2 == GP: 24 | prefix2, date, group = tup 25 | return f'df = df.sort(_DATE_, "{group}").group_by(_DATE_, "{group}").map_groups({func_name}).drop(*{drop_cols})' 26 | 27 | return f'df = {func_name}(df).drop(*{drop_cols})' 28 | 29 | 30 | def symbols_to_code(syms): 31 | a = [f"{s}" for s in syms] 32 | b = [f"'{s}'" for s in syms] 33 | return f"""_ = [{','.join(b)}] 34 | [{','.join(a)}] = [pl.col(i) for i in _]""" 35 | 36 | 37 | def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst, 38 | filename, 39 | date='date', asset='asset', 40 | extra_codes: Sequence[str] = (), 41 | **kwargs): 42 | """基于模板的代码生成""" 43 | if filename is None: 44 | filename = 'template.py.j2' 45 | 46 | # 打印Polars风格代码 47 | p = PolarsStrPrinter() 48 | 49 | # polars风格代码 50 | funcs = {} 51 | # 分组应用代码。这里利用了字典按插入顺序排序的特点,将排序放在最前 52 | groupbys = {'sort': ''} 53 | # 处理过后的表达式 54 | exprs_dst = [] 55 | syms_out = [] 56 | 57 | drop_symbols = exprs_ldl.drop_symbols() 58 | j = -1 59 | for i, row in enumerate(exprs_ldl.values()): 60 | for k, vv in row.items(): 61 | j += 1 62 | if len(vv) == 0: 63 | continue 
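            # (说明)下面的循环体为每个(层i, 分组键k)生成一个func_i_xxx函数:
            # 把该组内的表达式渲染成with_columns代码,并登记对应的分组调用语句,最后交给Jinja2模板拼装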
            # 函数名
            func_name = f'func_{i}_{"__".join(k)}'
            func_code = []
            for kv in vv:
                if kv is None:
                    func_code.append(f"    )")
                    func_code.append(f"# " + '=' * 40)
                    func_code.append(f"    df = df.with_columns(")
                    exprs_dst.append(f"#" + '=' * 40 + func_name)
                else:
                    va, ex, sym, comment = kv
                    s1 = str(ex)
                    s2 = p.doprint(ex)
                    if s1 != s2:
                        # 不相等时,打印注释,显示会更直观
                        func_code.append(f"# {va} = {s1}")

                    func_code.append(f"{va}={s2},")
                    exprs_dst.append(f"{va} = {s1} {comment}")
                    if va not in syms_dst:
                        syms_out.append(va)
            func_code.append(f"    )")
            func_code = func_code[1:]

            if k[0] == TS:
                # if len(groupbys['sort']) == 0:
                #     groupbys['sort'] = f'df = df.sort(_ASSET_, _DATE_)'
                # 时序需要排序
                func_code = [f'    df = df.sort(_DATE_)'] + func_code

            # polars风格代码列表
            funcs[func_name] = '\n'.join(func_code)
            # 只有下划线开头再删除
            ds = [x for x in drop_symbols[j] if x.startswith('_')]
            # 分组应用代码
            groupbys[func_name] = get_groupby_from_tuple(k, func_name, ds)

    syms1 = symbols_to_code(syms_dst)
    syms2 = symbols_to_code(syms_out)

    try:
        env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(__file__)))
        template = env.get_template(filename)
    except TemplateNotFound:
        env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(filename)))
        template = env.get_template(os.path.basename(filename))

    return template.render(funcs=funcs, groupbys=groupbys,
                           exprs_src=exprs_src, exprs_dst=exprs_dst,
                           syms1=syms1, syms2=syms2,
                           date=date, asset=asset,
                           extra_codes=extra_codes)
--------------------------------------------------------------------------------
/deprecated/polars_group/printer.py:
--------------------------------------------------------------------------------
from sympy import Basic, Function, StrPrinter
from sympy.printing.precedence import precedence, PRECEDENCE


# TODO: 如有新添加的函数,且表达式有变更,才需要在此补充对应的打印代码,否则可以省略

class PolarsStrPrinter(StrPrinter):
    def _print(self, expr, **kwargs) -> str:
        """Internal dispatcher

        Tries the following concepts to print an expression:
            1. Let the object print itself if it knows how.
            2. Take the best fitting method defined in the printer.
            3. As fall-back use the emptyPrinter method for the printer.
        """
        self._print_level += 1
        try:
            # If the printer defines a name for a printing method
            # (Printer.printmethod) and the object knows for itself how it
            # should be printed, use that method.
21 | if self.printmethod and hasattr(expr, self.printmethod): 22 | if not (isinstance(expr, type) and issubclass(expr, Basic)): 23 | return getattr(expr, self.printmethod)(self, **kwargs) 24 | 25 | # See if the class of expr is known, or if one of its super 26 | # classes is known, and use that print function 27 | # Exception: ignore the subclasses of Undefined, so that, e.g., 28 | # Function('gamma') does not get dispatched to _print_gamma 29 | classes = type(expr).__mro__ 30 | # if AppliedUndef in classes: 31 | # classes = classes[classes.index(AppliedUndef):] 32 | # if UndefinedFunction in classes: 33 | # classes = classes[classes.index(UndefinedFunction):] 34 | # Another exception: if someone subclasses a known function, e.g., 35 | # gamma, and changes the name, then ignore _print_gamma 36 | if Function in classes: 37 | i = classes.index(Function) 38 | classes = tuple(c for c in classes[:i] if \ 39 | c.__name__ == classes[0].__name__ or \ 40 | c.__name__.endswith("Base")) + classes[i:] 41 | for cls in classes: 42 | printmethodname = '_print_' + cls.__name__ 43 | 44 | # 所有以gp_开头的函数都转换成cs_开头 45 | if printmethodname.startswith('_print_gp_'): 46 | printmethodname = "_print_gp_" 47 | 48 | printmethod = getattr(self, printmethodname, None) 49 | if printmethod is not None: 50 | return printmethod(expr, **kwargs) 51 | # Unknown object, fall back to the emptyPrinter. 52 | return self.emptyPrinter(expr) 53 | finally: 54 | self._print_level -= 1 55 | 56 | def _print_Symbol(self, expr): 57 | return expr.name 58 | 59 | def _print_Equality(self, expr): 60 | PREC = precedence(expr) 61 | return "%s==%s" % (self.parenthesize(expr.args[0], PREC), self.parenthesize(expr.args[1], PREC)) 62 | 63 | def _print_Or(self, expr): 64 | PREC = PRECEDENCE["Mul"] 65 | return " | ".join(self.parenthesize(arg, PREC) for arg in expr.args) 66 | 67 | def _print_Xor(self, expr): 68 | PREC = PRECEDENCE["Mul"] 69 | return " ^ ".join(self.parenthesize(arg, PREC) for arg in expr.args) 70 | 71 | def _print_And(self, expr): 72 | PREC = PRECEDENCE["Mul"] 73 | return " & ".join(self.parenthesize(arg, PREC) for arg in expr.args) 74 | 75 | def _print_Not(self, expr): 76 | PREC = PRECEDENCE["Mul"] 77 | return "~%s" % self.parenthesize(expr.args[0], PREC) 78 | 79 | def _print_gp_(self, expr): 80 | """gp_函数都转换成cs_函数,但要丢弃第一个参数""" 81 | new_args = [self._print(arg) for arg in expr.args[1:]] 82 | func_name = expr.func.__name__[3:] 83 | return "cs_%s(%s)" % (func_name, ",".join(new_args)) 84 | -------------------------------------------------------------------------------- /deprecated/polars_group/template.py.j2: -------------------------------------------------------------------------------- 1 | # this code is auto generated by the expr_codegen 2 | # https://github.com/wukan1986/expr_codegen 3 | # 此段代码由 expr_codegen 自动生成,欢迎提交 issue 或 pull request 4 | from typing import TypeVar 5 | 6 | import polars as pl # noqa 7 | import polars.selectors as cs # noqa 8 | # from loguru import logger # noqa 9 | from polars import DataFrame as _pl_DataFrame 10 | from polars import LazyFrame as _pl_LazyFrame 11 | # =================================== 12 | # 导入优先级,例如:ts_RSI在ta与talib中都出现了,优先使用ta 13 | # 运行时,后导入覆盖前导入,但IDE智能提示是显示先导入的 14 | _ = 0 # 只要之前出现了语句,之后的import位置不参与调整 15 | # from polars_ta.prefix.talib import * # noqa 16 | from polars_ta.prefix.tdx import * # noqa 17 | from polars_ta.prefix.ta import * # noqa 18 | from polars_ta.prefix.wq import * # noqa 19 | from polars_ta.prefix.cdl import * # noqa 20 | 21 | DataFrame = TypeVar('DataFrame', _pl_LazyFrame, 
_pl_DataFrame) 22 | # =================================== 23 | 24 | {{ syms1 }} 25 | 26 | {{ syms2 }} 27 | 28 | _DATE_ = '{{ date }}' 29 | _ASSET_ = '{{ asset }}' 30 | _NONE_ = None 31 | _TRUE_ = True 32 | _FALSE_ = False 33 | 34 | def unpack(x: pl.Expr, idx: int = 0) -> pl.Expr: 35 | return x.struct[idx] 36 | 37 | {%-for row in extra_codes %} 38 | {{ row-}} 39 | {% endfor %} 40 | 41 | {% for key, value in funcs.items() %} 42 | def {{ key }}(df: DataFrame) -> DataFrame: 43 | {{ value }} 44 | return df 45 | {% endfor %} 46 | 47 | """ 48 | {%-for row in exprs_dst %} 49 | {{ row-}} 50 | {% endfor %} 51 | """ 52 | 53 | """ 54 | {%-for a,b,c in exprs_src %} 55 | {{ a }} = {{ b}} {{c-}} 56 | {% endfor %} 57 | """ 58 | 59 | 60 | def main(df: DataFrame) -> DataFrame: 61 | # logger.info("start...") 62 | {% for key, value in groupbys.items() %} 63 | {{ value-}} 64 | {% endfor %} 65 | 66 | # drop intermediate columns 67 | # df = df.select(pl.exclude(r'^_x_\d+$')) 68 | df = df.select(~cs.starts_with("_")) 69 | 70 | # shrink 71 | df = df.select(cs.all().shrink_dtype()) 72 | # df = df.shrink_to_fit() 73 | 74 | # logger.info('done') 75 | 76 | # save 77 | # df.write_parquet('output.parquet') 78 | 79 | return df 80 | 81 | # if __name__ in ("__main__", "builtins"): 82 | # # TODO: 数据加载或外部传入 83 | # df_output = main(df_input) 84 | -------------------------------------------------------------------------------- /examples/demo_exec_pl.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import sys 4 | from pathlib import Path 5 | 6 | # 修改当前目录到上层目录,方便跨不同IDE中使用 7 | pwd = str(Path(__file__).parents[1]) 8 | os.chdir(pwd) 9 | sys.path.append(pwd) 10 | 11 | import polars as pl 12 | from loguru import logger # noqa 13 | 14 | # %% 15 | df_input = pl.read_parquet('data/data.parquet') 16 | # print(df.tail()) 17 | 18 | from codes.demo_exec import main 19 | 20 | df = main(df_input) 21 | print(df.tail()) 22 | 23 | # %% 24 | columns = ['CLOSE', '移动平均_10', '移动平均_20', 'MAMA_20'] 25 | df1 = df.filter(pl.col('asset') == 's_100').select('date', 'asset', *columns) 26 | df2 = df.filter(pl.col('asset') == 's_200').select('date', 'asset', *columns) 27 | # %% 28 | # 此绘图需要安装hvplot 29 | # 需要在notebook环境中使用 30 | plot1 = df1.plot(x='date', y=columns, label='s_100') 31 | plot2 = df2.plot(x='date', y=columns, label='s_200') 32 | 33 | # hvplot叠加特方便,但缺点是不够灵活 34 | (plot1 + plot2).cols(1) 35 | # %% 36 | -------------------------------------------------------------------------------- /examples/demo_express.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from io import StringIO 3 | 4 | import polars as pl 5 | 6 | from expr_codegen import codegen_exec 7 | 8 | 9 | def _code_block_1(): 10 | # 因子编辑区,可利用IDE的智能提示在此区域编辑因子 11 | LOG_MC_ZS = cs_mad_zscore(log1p(market_cap)) 12 | 13 | 14 | def _code_block_2(): 15 | # 模板中已经默认导入了from polars_ta.prefix下大量的算子,但 16 | # talib在模板中没有默认导入。这种写法可实现在生成的代码中导入 17 | from polars_ta.prefix.talib import ts_LINEARREG_SLOPE # noqa 18 | 19 | # 1. 下划线开头的变量只是中间变量,会被自动更名,最终输出时会被剔除 20 | # 2. 下划线开头的变量可以重复使用。多个复杂因子多行书写时有重复中间变时不再冲突 21 | _avg = ts_mean(corr, 20) 22 | _std = ts_std_dev(corr, 20) 23 | _beta = ts_LINEARREG_SLOPE(corr, 20) 24 | 25 | # 3. 
    # 下划线开头的变量可以循环赋值。在调试时可快速用注释进行切换
    _avg = cs_mad_zscore_resid(_avg, LOG_MC_ZS, ONE)
    _std = cs_mad_zscore_resid(_std, LOG_MC_ZS, ONE)
    # _beta = cs_mad_zscore_resid(_beta, LOG_MC_ZS, ONE)

    _corr = cs_zscore(_avg) + cs_zscore(_std)
    CPV = cs_zscore(_corr) + cs_zscore(_beta)


code = codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by', output_file=sys.stdout)  # 打印代码
code = codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by', output_file="output.py")  # 保存到文件
code = codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by')  # 只执行,不保存代码

code = StringIO()
codegen_exec(None, _code_block_1, _code_block_2, over_null='partition_by', output_file=code)  # 保存到字符串
code.seek(0)
code.read()  # 读取代码

# TODO 替换成合适的数据
df = pl.DataFrame()
df = codegen_exec(df.lazy(), _code_block_1, _code_block_2, over_null='partition_by').collect()  # Lazy CPU
df = codegen_exec(df.lazy(), _code_block_1, _code_block_2, over_null='partition_by').collect(engine="gpu")  # Lazy GPU
--------------------------------------------------------------------------------
/examples/demo_min.py:
--------------------------------------------------------------------------------
"""
分钟线处理示例

在分钟线预处理时常常需要按天分别处理。例如ts_delay如果是对多天分钟数据处理,只有第一天第一条为null;
如果是对每天分钟数据处理,则每天第一条为null。

在日线时,默认date参数的freq是1d,asset参数是股票代码
在按日划分的分钟时,默认date参数的freq是1min,asset参数是股票代码+日,这样才能每天独立处理

如果分钟数据已经按日期分好了文件,也可以直接多进程并行处理,就没这么麻烦

"""
from datetime import datetime

import numpy as np
import pandas as pd
import polars as pl
from loguru import logger

from expr_codegen import codegen_exec  # noqa

np.random.seed(42)

ASSET_COUNT = 500
DATE_COUNT = 250 * 24 * 10 * 1
DATE = pd.date_range(datetime(2020, 1, 1), periods=DATE_COUNT, freq='1min').repeat(ASSET_COUNT)
ASSET = [f'A{i:04d}' for i in range(ASSET_COUNT)] * DATE_COUNT

df = pl.DataFrame(
    {
        'datetime': DATE,
        'asset': ASSET,
        "OPEN": np.random.rand(DATE_COUNT * ASSET_COUNT),
        "HIGH": np.random.rand(DATE_COUNT * ASSET_COUNT),
        "LOW": np.random.rand(DATE_COUNT * ASSET_COUNT),
        "CLOSE": np.random.rand(DATE_COUNT * ASSET_COUNT),
        "VOLUME": np.random.rand(DATE_COUNT * ASSET_COUNT),
        "OPEN_INTEREST": np.random.rand(DATE_COUNT * ASSET_COUNT),
        "FILTER": np.tri(DATE_COUNT, ASSET_COUNT, k=-2).reshape(-1),
    }
).lazy()

df = df.filter(pl.col('FILTER') == 1)

logger.info('时间戳调整开始')
# 交易日,期货夜盘属于下一个交易日,后移4小时夜盘日期就一样了
df = df.with_columns(trading_day=pl.col('datetime').dt.offset_by("4h"))
# 周五晚已经变成了周六,双休要移动到周一
df = df.with_columns(trading_day=pl.when(pl.col('trading_day').dt.weekday() > 5)
                     .then(pl.col("trading_day").dt.offset_by("2d"))
                     .otherwise(pl.col("trading_day")))
df = df.with_columns(
    # 交易日
    trading_day=pl.col("trading_day").dt.date(),
    # 工作日
    action_day=pl.col('datetime').dt.date(),
)
df = df.collect()
logger.info('时间戳调整完成')
# ---
# !!!
重要代码,生成复合字段,用来ts_排序 62 | # _asset_date以下划线开头,会自动删除,如要保留,可去了下划线 63 | # 股票用action_day,期货用trading_day 64 | df = df.with_columns(_asset_date=pl.struct("asset", "trading_day")) 65 | df = codegen_exec(df, """OPEN_RANK = cs_rank(OPEN[1]) # 仅演示""", over_null='partition_by', 66 | # !!!使用时一定要分清分组是用哪个字段 67 | date='datetime', asset='_asset_date') 68 | # --- 69 | logger.info('1分钟转15分钟线开始') 70 | df1 = df.sort('asset', 'datetime').group_by_dynamic('datetime', every="15m", closed='left', label="left", group_by=['asset', 'trading_day']).agg( 71 | open_dt=pl.first("datetime"), 72 | close_dt=pl.last("datetime"), 73 | OPEN=pl.first("OPEN"), 74 | HIGH=pl.max("HIGH"), 75 | LOW=pl.min("LOW"), 76 | CLOSE=pl.last("CLOSE"), 77 | VOLUME=pl.sum("VOLUME"), 78 | OPEN_INTEREST=pl.last("OPEN_INTEREST"), 79 | ) 80 | logger.info('1分钟转15分钟线结束') 81 | print(df1) 82 | # --- 83 | logger.info('1分钟转日线开始') 84 | # 也可以使用group_by_dynamic,只是日线隐含了label="left" 85 | df1 = df.sort('asset', 'datetime').group_by('asset', 'trading_day', maintain_order=True).agg( 86 | open_dt=pl.first("datetime"), 87 | close_dt=pl.last("datetime"), 88 | OPEN=pl.first("OPEN"), 89 | HIGH=pl.max("HIGH"), 90 | LOW=pl.min("LOW"), 91 | CLOSE=pl.last("CLOSE"), 92 | VOLUME=pl.sum("VOLUME"), 93 | OPEN_INTEREST=pl.last("OPEN_INTEREST"), 94 | ) 95 | logger.info('1分钟转日线结束') 96 | print(df1) 97 | -------------------------------------------------------------------------------- /examples/demo_sql.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import polars as pl 3 | from polars_ta.prefix.wq import * 4 | 5 | from expr_codegen import codegen_exec 6 | 7 | _N = 250 * 1 8 | _K = 500 # TODO 如要单资产,改此处为1即可 9 | 10 | asset = [f's_{i:04d}' for i in range(_K)] 11 | date = pd.date_range('2015-1-1', periods=_N) 12 | 13 | df = pd.DataFrame({ 14 | # 原始价格 15 | 'CLOSE': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1), 16 | 'OPEN': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1), 17 | # TODO 这只是为了制造长度不同的数据而设计 18 | "FILTER": np.tri(_N, _K, k=100).reshape(-1), 19 | }, index=pd.MultiIndex.from_product([date, asset], names=['date', 'asset'])).reset_index() 20 | 21 | # 向脚本输入数据 22 | df = pl.from_pandas(df) 23 | 24 | 25 | def _code_block_1(): 26 | # 因子编辑区,可利用IDE的智能提示在此区域编辑因子 27 | 28 | A1 = floor(log1p(ceiling(abs_(CLOSE * 100) * 2))) 29 | 30 | 31 | df = codegen_exec(df, _code_block_1, over_null='partition_by', output_file='1_out.sql', style='sql', 32 | table_name='self') # 打印代码 33 | 34 | print(df) 35 | 36 | code = codegen_exec(None, _code_block_1, over_null='partition_by', output_file='1_out.sql', style='sql', 37 | date='date1', asset='asset2', table_name='df3') # 打印代码 38 | 39 | print(code) 40 | -------------------------------------------------------------------------------- /examples/demo_tdx.py: -------------------------------------------------------------------------------- 1 | """尝试套通达信风格指数""" 2 | import os 3 | import sys 4 | import time 5 | from pathlib import Path 6 | 7 | # 修改当前目录到上层目录,方便跨不同IDE中使用 8 | pwd = str(Path(__file__).parents[1]) 9 | os.chdir(pwd) 10 | sys.path.append(pwd) 11 | print("pwd:", os.getcwd()) 12 | # ==================== 13 | import polars as pl 14 | from expr_codegen import codegen_exec 15 | from loguru import logger 16 | from polars_ta.prefix.wq import * 17 | 18 | 19 | def _code_block_1(): 20 | # 基础字段准备=================== 21 | 涨跌幅 = CLOSE / CLOSE[1] - 1 22 | 振幅 = (HIGH - LOW) / CLOSE[1] 23 | 24 | 开盘涨停 = open >= high_limit - 0.001 25 | 最高涨停 = high 
>= high_limit - 0.001
    一字涨停 = low >= high_limit - 0.001
    收盘涨停 = close >= high_limit - 0.001

    开盘跌停 = open <= low_limit + 0.001
    一字跌停 = high <= low_limit + 0.001
    最低跌停 = low <= low_limit + 0.001
    收盘跌停 = close <= low_limit + 0.001

    连板天数 = ts_cum_sum_reset(收盘涨停)
    涨停T天, 涨停N板, _ = ts_up_stat(收盘涨停)


def _code_block_2():
    # 通达信风格板块===================
    _昨日强势1 = (0.07 < 涨跌幅) & (涨跌幅 < 0.1) & (上海主板 | 深圳主板)
    _昨日强势2 = (0.14 < 涨跌幅) & (涨跌幅 < 0.2) & (创业板 | 科创板)
    _昨日强势3 = (0.21 < 涨跌幅) & (涨跌幅 < 0.3) & 北交所
    昨日强势 = _昨日强势1 | _昨日强势2 | _昨日强势3
    _昨日弱势1 = (-0.1 < 涨跌幅) & (涨跌幅 < -0.07) & (上海主板 | 深圳主板)
    _昨日弱势2 = (-0.2 < 涨跌幅) & (涨跌幅 < -0.14) & (创业板 | 科创板)
    _昨日弱势3 = (-0.3 < 涨跌幅) & (涨跌幅 < -0.21) & 北交所
    昨日弱势 = _昨日弱势1 | _昨日弱势2 | _昨日弱势3
    _昨日较弱1 = (-0.07 <= 涨跌幅) & (涨跌幅 <= -0.05) & (上海主板 | 深圳主板)
    _昨日较弱2 = (-0.14 <= 涨跌幅) & (涨跌幅 <= -0.10) & (创业板 | 科创板)
    _昨日较弱3 = (-0.21 <= 涨跌幅) & (涨跌幅 <= -0.15) & 北交所
    昨日较弱 = _昨日较弱1 | _昨日较弱2 | _昨日较弱3
    _昨日较强1 = (0.05 <= 涨跌幅) & (涨跌幅 <= 0.07) & (上海主板 | 深圳主板)
    _昨日较强2 = (0.10 <= 涨跌幅) & (涨跌幅 <= 0.14) & (创业板 | 科创板)
    _昨日较强3 = (0.15 <= 涨跌幅) & (涨跌幅 <= 0.21) & 北交所
    昨日较强 = _昨日较强1 | _昨日较强2 | _昨日较强3
    _最近异动1 = (ts_sum(turnover_ratio, 3) > 25) & (ts_mean(振幅, 3) > 0.07) & ~(创业板 | 科创板)
    _最近异动2 = (ts_sum(turnover_ratio, 3) > 50) & (ts_mean(振幅, 3) > 0.14) & (创业板 | 科创板)
    最近异动 = _最近异动1 | _最近异动2
    _昨高换手1 = (turnover_ratio > 15) & ~科创板
    _昨高换手2 = (turnover_ratio > 30) & 科创板
    昨高换手 = _昨高换手1 | _昨高换手2
    近期强势 = (ts_returns(CLOSE, 20) >= 0.3) & (ts_returns(CLOSE, 3) > 0)
    近期弱势 = (ts_returns(CLOSE, 20) <= -0.2) & (ts_returns(CLOSE, 3) < 0)
    # ===================
    最近情绪 = ts_count(收盘涨停 | 收盘跌停, 5) > 0
    昨日跌停 = 收盘跌停
    昨曾跌停 = 最低跌停 & ~收盘跌停
    昨日首板 = 连板天数 == 1
    最近多板 = 涨停N板 >= 2
    昨日连板 = 连板天数 >= 2
    昨日涨停 = 收盘涨停
    昨曾涨停 = 最高涨停 & ~收盘涨停

    昨成交20 = cs_rank(-amount, False) <= 20
    大盘股 = cs_rank(-market_cap, False) <= 200
    微盘股 = cs_rank(market_cap, False) <= 400
    高市盈率 = cs_rank(-pe_ratio, False) <= 200
    低市盈率 = cs_rank(pe_ratio, False) <= 200
    高市净率 = cs_rank(-pb_ratio, False) <= 200
    低市净率 = cs_rank(pb_ratio, False) <= 200
    活跃股 = cs_rank(-ts_sum(turnover_ratio, 5)) <= 100
    不活跃股 = ts_sum(turnover_ratio, 5) < 20
    昨日振荡 = (振幅 > 0.08) & (LOW < CLOSE[1]) & (HIGH > CLOSE[1])
    近期新高 = ts_max(HIGH, 3) == ts_max(HIGH, 250)
    近期新低 = ts_min(LOW, 3) == ts_min(LOW, 250)
    百元股 = (ts_max(high, 5) > 100) & (close[1] > 90)
    低价股 = close <= 3


# 由于读写多,推荐放到内存盘,加快速度
PATH_INPUT1 = r'M:\preprocessing\data2.parquet'
# 去除停牌后的基础数据
PATH_OUTPUT = r'M:\preprocessing\out1.parquet'

if __name__ == '__main__':
    logger.info('数据准备开始')
    df = pl.read_parquet(PATH_INPUT1)

    logger.info('数据准备完成')
    # =====================================
    logger.info('计算开始')
    t1 = time.perf_counter()
    df = codegen_exec(df, _code_block_1, _code_block_2, over_null=None, output_file='1_out.py', run_file=False)
    t2 = time.perf_counter()
    df = codegen_exec(df, _code_block_1, _code_block_2, over_null=None, output_file='1_out.py', run_file=True)
    t3 = time.perf_counter()
    df = codegen_exec(df, _code_block_1, _code_block_2, over_null=None, output_file='1_out.py', run_file=True)
    t4 = time.perf_counter()
    print(t2 - t1, t3 - t2, t4 - t3)
    logger.info('计算结束')
    df = df.filter(
        ~pl.col('is_st'),
    )
    print(df)
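# (说明)行过滤(如剔除is_st)放在codegen_exec之后单独进行:
# 如README"局限性"一节所述,DAG内部只增加列、不删除行,删除行需在表达式计算之外完成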
-------------------------------------------------------------------------------- /examples/demo_transformer.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | from expr_codegen.codes import source_replace, SyntaxTransformer, RenameTransformer 4 | 5 | encoding = 'utf-8' 6 | input_file = 'factors_test1.txt' 7 | output_file = 'factors_test2.txt' 8 | 9 | # ========================== 10 | # 观察区,查看存在哪些变量和函数 11 | with open(input_file, 'r', encoding=encoding) as f: 12 | sources = f.readlines() 13 | 14 | # 不要太大,防止内存不足 15 | source = '\n'.join(sources[:1000]) 16 | 17 | tree = ast.parse(source_replace(source)) 18 | t1 = SyntaxTransformer(True) 19 | t1.visit(tree) 20 | t = RenameTransformer({}, {}) 21 | t.visit(tree) 22 | 23 | print('=' * 60) 24 | print(t.funcs_old) 25 | print(t.args_old) 26 | print(t.targets_old) 27 | print('=' * 60) 28 | 29 | # ========================== 30 | # 映射 31 | funcs_map = {'abs': 'abs_', 32 | 'max': 'max_', 33 | 'min': 'min_', 34 | 'delta': 'ts_delta', 35 | 'delay': 'ts_delay', 36 | 'ts_argmin': 'ts_arg_min', 37 | 'ts_argmax': 'ts_arg_max', 38 | # TODO 目前不支持的操作 39 | # 'cs_corr': '', 40 | # 'cs_std': '', 41 | } 42 | args_map = {} 43 | targets_map = {} 44 | 45 | # TODO 如果后面文件太大,耗时太久,需要手工放开后面一段 46 | # sys.exit(-1) 47 | # ========================== 48 | with open(input_file, 'r', encoding=encoding) as f: 49 | sources = f.readlines() 50 | 51 | t1 = SyntaxTransformer(True) 52 | t = RenameTransformer(funcs_map, targets_map, args_map) 53 | 54 | outputs = [] 55 | for i in range(0, len(sources), 1000): 56 | print(i) 57 | source = '\n'.join(sources[i:i + 1000]) 58 | 59 | tree = ast.parse(source_replace(source)) 60 | t1.visit(tree) 61 | t.visit(tree) 62 | outputs.append(ast.unparse(tree)) 63 | 64 | print('转码完成') 65 | with open(output_file, 'w') as f2: 66 | f2.writelines(outputs) 67 | print('保存成功') 68 | -------------------------------------------------------------------------------- /examples/output_alpha101.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from loguru import logger 7 | 8 | # TODO: load data 9 | df = pd.DataFrame() 10 | 11 | 12 | def signed_power(x, y): 13 | return x.sign() * (x.abs() ** y) 14 | 15 | 16 | def func_0_ts__asset__date(df: pd.DataFrame) -> pd.DataFrame: 17 | # ======================================== 18 | # x_10 = ts_sum(RETURNS, 5) 19 | df["x_10"] = df["RETURNS"].rolling(5).sum() 20 | # x_22 = ts_delta(RETURNS, 3) 21 | df["x_22"] = df["RETURNS"].diff(3) 22 | # x_32 = ts_sum(RETURNS, 250) + 1 23 | df["x_32"] = df["RETURNS"].rolling(250).sum() + 1 24 | # x_0 = ts_arg_max(signed_power(if_else(RETURNS < 0, ts_std_dev(RETURNS, 20), CLOSE), 2.0), 5) 25 | df["x_0"] = signed_power(df["RETURNS"].rolling(20).std(ddof=0).where(df["RETURNS"] < 0, df["CLOSE"]), 2.0).rolling(5).apply(np.argmax, engine="numba", raw=True) 26 | # x_12 = ts_delta(CLOSE, 1) 27 | df["x_12"] = df["CLOSE"].diff(1) 28 | # x_29 = ts_rank(CLOSE, 10) 29 | df["x_29"] = df["CLOSE"].rolling(10).rank(pct=True) 30 | # x_38 = ts_delta(CLOSE, 7) 31 | df["x_38"] = df["CLOSE"].diff(7) 32 | # x_1 = ts_delta(log(VOLUME), 2) 33 | df["x_1"] = np.log(df["VOLUME"]).diff(2) 34 | # x_19 = ts_delta(VOLUME, 3) 35 | df["x_19"] = df["VOLUME"].diff(3) 36 | # x_9 = ts_sum(OPEN, 5) 37 | df["x_9"] = df["OPEN"].rolling(5).sum() 38 | # x_34 = OPEN - ts_delay(CLOSE, 1) 39 | df["x_34"] = df["OPEN"] - df["CLOSE"].shift(1) 40 | # x_37 = ts_corr(OPEN, VOLUME, 10) 41 | 
df["x_37"] = df["OPEN"].rolling(10).corr(df["VOLUME"], ddof=0) 42 | # x_36 = OPEN - ts_delay(LOW, 1) 43 | df["x_36"] = df["OPEN"] - df["LOW"].shift(1) 44 | # x_8 = OPEN - ts_sum(VWAP, 10)/10 45 | df["x_8"] = df["OPEN"] - df["VWAP"].rolling(10).sum() / 10 46 | # x_35 = OPEN - ts_delay(HIGH, 1) 47 | df["x_35"] = df["OPEN"] - df["HIGH"].shift(1) 48 | # x_30 = ts_rank(VOLUME/ADV20, 5) 49 | df["x_30"] = (df["VOLUME"] / df["ADV20"]).rolling(5).rank(pct=True) 50 | return df 51 | 52 | 53 | def func_0_cs__date(df: pd.DataFrame) -> pd.DataFrame: 54 | # ======================================== 55 | # x_20 = cs_rank(CLOSE) 56 | df["x_20"] = df["CLOSE"].rank(pct=True) 57 | # x_6 = cs_rank(VOLUME) 58 | df["x_6"] = df["VOLUME"].rank(pct=True) 59 | # x_5 = cs_rank(OPEN) 60 | df["x_5"] = df["OPEN"].rank(pct=True) 61 | # x_7 = cs_rank(LOW) 62 | df["x_7"] = df["LOW"].rank(pct=True) 63 | # x_24 = cs_rank(HIGH) 64 | df["x_24"] = df["HIGH"].rank(pct=True) 65 | # ======================================== 66 | # x_3 = CLOSE - OPEN 67 | df["x_3"] = df["CLOSE"] - df["OPEN"] 68 | # x_15 = CLOSE - VWAP 69 | df["x_15"] = df["CLOSE"] - df["VWAP"] 70 | # ======================================== 71 | # x_23 = cs_rank(x_22) 72 | df["x_23"] = df["x_22"].rank(pct=True) 73 | # x_33 = cs_rank(x_32) + 1 74 | df["x_33"] = df["x_32"].rank(pct=True) + 1 75 | # alpha_001 = cs_rank(x_0) - 0.5 76 | df["alpha_001"] = df["x_0"].rank(pct=True) - 0.5 77 | # x_2 = cs_rank(x_1) 78 | df["x_2"] = df["x_1"].rank(pct=True) 79 | # x_4 = cs_rank(x_3/OPEN) 80 | df["x_4"] = (df["x_3"] / df["OPEN"]).rank(pct=True) 81 | # alpha_005 = -abs(cs_rank(x_15))*cs_rank(x_8) 82 | df["alpha_005"] = -df["x_15"].rank(pct=True).abs() * df["x_8"].rank(pct=True) 83 | # alpha_020 = -cs_rank(x_34)*cs_rank(x_35)*cs_rank(x_36) 84 | df["alpha_020"] = -df["x_34"].rank(pct=True) * df["x_35"].rank(pct=True) * df["x_36"].rank(pct=True) 85 | # ======================================== 86 | # x_13 = -x_12 87 | df["x_13"] = -df["x_12"] 88 | # alpha_006 = -x_37 89 | df["alpha_006"] = -df["x_37"] 90 | # x_16 = -x_15 91 | df["x_16"] = -df["x_15"] 92 | return df 93 | 94 | 95 | def func_1_ts__asset__date(df: pd.DataFrame) -> pd.DataFrame: 96 | # ======================================== 97 | # x_28 = ts_delta(x_12, 1) 98 | df["x_28"] = df["x_12"].diff(1) 99 | # alpha_012 = -x_12*sign(ts_delta(VOLUME, 1)) 100 | df["alpha_012"] = -df["x_12"] * df["VOLUME"].diff(1).sign() 101 | # alpha_007 = if_else(ADV20 < VOLUME, -sign(x_38)*ts_rank(abs(x_38), 60), -1) 102 | df["alpha_007"] = (-df["x_38"].sign() * df["x_38"].abs().rolling(60).rank(pct=True)).where(df["ADV20"] < df["VOLUME"], -1) 103 | # x_21 = ts_covariance(x_20, x_6, 5) 104 | df["x_21"] = df["x_20"].rolling(5).cov(df["x_6"], ddof=0) 105 | # x_31 = x_3 + ts_corr(CLOSE, OPEN, 10) + ts_std_dev(abs(x_3), 5) 106 | df["x_31"] = df["x_3"] + df["CLOSE"].rolling(10).corr(df["OPEN"], ddof=0) + df["x_3"].abs().rolling(5).std(ddof=0) 107 | # alpha_003 = -ts_corr(x_5, x_6, 10) 108 | df["alpha_003"] = -df["x_5"].rolling(10).corr(df["x_6"], ddof=0) 109 | # x_11 = x_10*x_9 - ts_delay(x_10*x_9, 10) 110 | df["x_11"] = df["x_10"] * df["x_9"] - (df["x_10"] * df["x_9"]).shift(10) 111 | # alpha_004 = -ts_rank(x_7, 9) 112 | df["alpha_004"] = -df["x_7"].rolling(9).rank(pct=True) 113 | # x_25 = ts_corr(x_24, x_6, 3) 114 | df["x_25"] = df["x_24"].rolling(3).corr(df["x_6"], ddof=0) 115 | # x_27 = ts_covariance(x_24, x_6, 5) 116 | df["x_27"] = df["x_24"].rolling(5).cov(df["x_6"], ddof=0) 117 | # ======================================== 118 | # alpha_014 = 
-x_23*x_37 119 | df["alpha_014"] = -df["x_23"] * df["x_37"] 120 | # ======================================== 121 | # alpha_019 = -x_33*sign(CLOSE + x_38 - ts_delay(CLOSE, 7)) 122 | df["alpha_019"] = -df["x_33"] * (df["CLOSE"] + df["x_38"] - df["CLOSE"].shift(7)).sign() 123 | # x_14 = if_else(ts_min(x_12, 4) > 0, x_12, if_else(ts_max(x_12, 4) < 0, x_12, x_13)) 124 | df["x_14"] = df["x_12"].where(df["x_12"].rolling(4).min() > 0, df["x_12"].where(df["x_12"].rolling(4).max() < 0, df["x_13"])) 125 | # alpha_009 = if_else(ts_min(x_12, 5) > 0, x_12, if_else(ts_max(x_12, 5) < 0, x_12, x_13)) 126 | df["alpha_009"] = df["x_12"].where(df["x_12"].rolling(5).min() > 0, df["x_12"].where(df["x_12"].rolling(5).max() < 0, df["x_13"])) 127 | # alpha_002 = -ts_corr(x_2, x_4, 6) 128 | df["alpha_002"] = -df["x_2"].rolling(6).corr(df["x_4"], ddof=0) 129 | # x_17 = ts_max(x_16, 3) 130 | df["x_17"] = df["x_16"].rolling(3).max() 131 | # x_18 = ts_min(x_16, 3) 132 | df["x_18"] = df["x_16"].rolling(3).min() 133 | return df 134 | 135 | 136 | def func_2_cs__date(df: pd.DataFrame) -> pd.DataFrame: 137 | # ======================================== 138 | # alpha_017 = -cs_rank(x_28)*cs_rank(x_29)*cs_rank(x_30) 139 | df["alpha_017"] = -df["x_28"].rank(pct=True) * df["x_29"].rank(pct=True) * df["x_30"].rank(pct=True) 140 | # alpha_013 = -cs_rank(x_21) 141 | df["alpha_013"] = -df["x_21"].rank(pct=True) 142 | # alpha_018 = -cs_rank(x_31) 143 | df["alpha_018"] = -df["x_31"].rank(pct=True) 144 | # alpha_008 = -cs_rank(x_11) 145 | df["alpha_008"] = -df["x_11"].rank(pct=True) 146 | # x_26 = cs_rank(x_25) 147 | df["x_26"] = df["x_25"].rank(pct=True) 148 | # alpha_016 = -cs_rank(x_27) 149 | df["alpha_016"] = -df["x_27"].rank(pct=True) 150 | # ======================================== 151 | # alpha_010 = cs_rank(x_14) 152 | df["alpha_010"] = df["x_14"].rank(pct=True) 153 | # alpha_011 = (cs_rank(x_17) + cs_rank(x_18))*cs_rank(x_19) 154 | df["alpha_011"] = (df["x_17"].rank(pct=True) + df["x_18"].rank(pct=True)) * df["x_19"].rank(pct=True) 155 | return df 156 | 157 | 158 | def func_3_ts__asset__date(df: pd.DataFrame) -> pd.DataFrame: 159 | # ======================================== 160 | # alpha_015 = -ts_sum(x_26, 3) 161 | df["alpha_015"] = -df["x_26"].rolling(3).sum() 162 | return df 163 | 164 | 165 | logger.info("start...") 166 | 167 | 168 | df = df.sort_values(by=["asset", "date"]).groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date) 169 | df = df.groupby(by=["date"], group_keys=False).apply(func_0_cs__date) 170 | df = df.sort_values(by=["asset", "date"]).groupby(by=["asset"], group_keys=False).apply(func_1_ts__asset__date) 171 | df = df.groupby(by=["date"], group_keys=False).apply(func_2_cs__date) 172 | df = df.sort_values(by=["asset", "date"]).groupby(by=["asset"], group_keys=False).apply(func_3_ts__asset__date) 173 | 174 | 175 | # #========================================func_0_ts__asset__date 176 | # x_10 = ts_sum(RETURNS, 5) 177 | # x_22 = ts_delta(RETURNS, 3) 178 | # x_32 = ts_sum(RETURNS, 250) + 1 179 | # x_0 = ts_arg_max(signed_power(if_else(RETURNS < 0, ts_std_dev(RETURNS, 20), CLOSE), 2.0), 5) 180 | # x_12 = ts_delta(CLOSE, 1) 181 | # x_29 = ts_rank(CLOSE, 10) 182 | # x_38 = ts_delta(CLOSE, 7) 183 | # x_1 = ts_delta(log(VOLUME), 2) 184 | # x_19 = ts_delta(VOLUME, 3) 185 | # x_9 = ts_sum(OPEN, 5) 186 | # x_34 = OPEN - ts_delay(CLOSE, 1) 187 | # x_37 = ts_corr(OPEN, VOLUME, 10) 188 | # x_36 = OPEN - ts_delay(LOW, 1) 189 | # x_8 = OPEN - ts_sum(VWAP, 10)/10 190 | # x_35 = OPEN - ts_delay(HIGH, 1) 191 | # 
x_30 = ts_rank(VOLUME/ADV20, 5) 192 | # #========================================func_0_cs__date 193 | # x_20 = cs_rank(CLOSE) 194 | # x_6 = cs_rank(VOLUME) 195 | # x_5 = cs_rank(OPEN) 196 | # x_7 = cs_rank(LOW) 197 | # x_24 = cs_rank(HIGH) 198 | # #========================================func_0_cs__date 199 | # x_3 = CLOSE - OPEN 200 | # x_15 = CLOSE - VWAP 201 | # #========================================func_0_cs__date 202 | # x_23 = cs_rank(x_22) 203 | # x_33 = cs_rank(x_32) + 1 204 | # alpha_001 = cs_rank(x_0) - 0.5 205 | # x_2 = cs_rank(x_1) 206 | # x_4 = cs_rank(x_3/OPEN) 207 | # alpha_005 = -abs(cs_rank(x_15))*cs_rank(x_8) 208 | # alpha_020 = -cs_rank(x_34)*cs_rank(x_35)*cs_rank(x_36) 209 | # #========================================func_0_cs__date 210 | # x_13 = -x_12 211 | # alpha_006 = -x_37 212 | # x_16 = -x_15 213 | # #========================================func_1_ts__asset__date 214 | # x_28 = ts_delta(x_12, 1) 215 | # alpha_012 = -x_12*sign(ts_delta(VOLUME, 1)) 216 | # alpha_007 = if_else(ADV20 < VOLUME, -sign(x_38)*ts_rank(abs(x_38), 60), -1) 217 | # x_21 = ts_covariance(x_20, x_6, 5) 218 | # x_31 = x_3 + ts_corr(CLOSE, OPEN, 10) + ts_std_dev(abs(x_3), 5) 219 | # alpha_003 = -ts_corr(x_5, x_6, 10) 220 | # x_11 = x_10*x_9 - ts_delay(x_10*x_9, 10) 221 | # alpha_004 = -ts_rank(x_7, 9) 222 | # x_25 = ts_corr(x_24, x_6, 3) 223 | # x_27 = ts_covariance(x_24, x_6, 5) 224 | # #========================================func_1_ts__asset__date 225 | # alpha_014 = -x_23*x_37 226 | # #========================================func_1_ts__asset__date 227 | # alpha_019 = -x_33*sign(CLOSE + x_38 - ts_delay(CLOSE, 7)) 228 | # x_14 = if_else(ts_min(x_12, 4) > 0, x_12, if_else(ts_max(x_12, 4) < 0, x_12, x_13)) 229 | # alpha_009 = if_else(ts_min(x_12, 5) > 0, x_12, if_else(ts_max(x_12, 5) < 0, x_12, x_13)) 230 | # alpha_002 = -ts_corr(x_2, x_4, 6) 231 | # x_17 = ts_max(x_16, 3) 232 | # x_18 = ts_min(x_16, 3) 233 | # #========================================func_2_cs__date 234 | # alpha_017 = -cs_rank(x_28)*cs_rank(x_29)*cs_rank(x_30) 235 | # alpha_013 = -cs_rank(x_21) 236 | # alpha_018 = -cs_rank(x_31) 237 | # alpha_008 = -cs_rank(x_11) 238 | # x_26 = cs_rank(x_25) 239 | # alpha_016 = -cs_rank(x_27) 240 | # #========================================func_2_cs__date 241 | # alpha_010 = cs_rank(x_14) 242 | # alpha_011 = (cs_rank(x_17) + cs_rank(x_18))*cs_rank(x_19) 243 | # #========================================func_3_ts__asset__date 244 | # alpha_015 = -ts_sum(x_26, 3) 245 | 246 | # alpha_001 = -0.5 + cs_rank(ts_arg_max(signed_power(if_else(RETURNS < 0, ts_std_dev(RETURNS, 20), CLOSE), 2.0), 5)) 247 | # alpha_002 = -ts_corr(cs_rank(ts_delta(log(VOLUME), 2)), cs_rank((CLOSE - OPEN)/OPEN), 6) 248 | # alpha_003 = -ts_corr(cs_rank(OPEN), cs_rank(VOLUME), 10) 249 | # alpha_004 = -ts_rank(cs_rank(LOW), 9) 250 | # alpha_005 = -abs(cs_rank(CLOSE - VWAP))*cs_rank(OPEN - ts_sum(VWAP, 10)/10) 251 | # alpha_006 = -ts_corr(OPEN, VOLUME, 10) 252 | # alpha_007 = if_else(ADV20 < VOLUME, -sign(ts_delta(CLOSE, 7))*ts_rank(abs(ts_delta(CLOSE, 7)), 60), -1) 253 | # alpha_008 = -cs_rank(-ts_delay(ts_sum(OPEN, 5)*ts_sum(RETURNS, 5), 10) + ts_sum(OPEN, 5)*ts_sum(RETURNS, 5)) 254 | # alpha_009 = if_else(ts_min(ts_delta(CLOSE, 1), 5) > 0, ts_delta(CLOSE, 1), if_else(ts_max(ts_delta(CLOSE, 1), 5) < 0, ts_delta(CLOSE, 1), -ts_delta(CLOSE, 1))) 255 | # alpha_010 = cs_rank(if_else(ts_min(ts_delta(CLOSE, 1), 4) > 0, ts_delta(CLOSE, 1), if_else(ts_max(ts_delta(CLOSE, 1), 4) < 0, ts_delta(CLOSE, 1), -ts_delta(CLOSE, 1)))) 256 | # 
alpha_011 = (cs_rank(ts_max(-CLOSE + VWAP, 3)) + cs_rank(ts_min(-CLOSE + VWAP, 3)))*cs_rank(ts_delta(VOLUME, 3)) 257 | # alpha_012 = -sign(ts_delta(VOLUME, 1))*ts_delta(CLOSE, 1) 258 | # alpha_013 = -cs_rank(ts_covariance(cs_rank(CLOSE), cs_rank(VOLUME), 5)) 259 | # alpha_014 = -cs_rank(ts_delta(RETURNS, 3))*ts_corr(OPEN, VOLUME, 10) 260 | # alpha_015 = -ts_sum(cs_rank(ts_corr(cs_rank(HIGH), cs_rank(VOLUME), 3)), 3) 261 | # alpha_016 = -cs_rank(ts_covariance(cs_rank(HIGH), cs_rank(VOLUME), 5)) 262 | # alpha_017 = -cs_rank(ts_delta(ts_delta(CLOSE, 1), 1))*cs_rank(ts_rank(CLOSE, 10))*cs_rank(ts_rank(VOLUME/ADV20, 5)) 263 | # alpha_018 = -cs_rank(CLOSE - OPEN + ts_corr(CLOSE, OPEN, 10) + ts_std_dev(abs(CLOSE - OPEN), 5)) 264 | # alpha_019 = -(cs_rank(ts_sum(RETURNS, 250) + 1) + 1)*sign(CLOSE - ts_delay(CLOSE, 7) + ts_delta(CLOSE, 7)) 265 | # alpha_020 = -cs_rank(OPEN - ts_delay(CLOSE, 1))*cs_rank(OPEN - ts_delay(HIGH, 1))*cs_rank(OPEN - ts_delay(LOW, 1)) 266 | 267 | # drop intermediate columns 268 | df = df.drop(columns=filter(lambda x: re.search(r"^x_\d+", x), df.columns)) 269 | 270 | 271 | logger.info("done") 272 | 273 | # save 274 | # df.to_parquet('output.parquet', compression='zstd') 275 | 276 | print(df.tail(5)) 277 | -------------------------------------------------------------------------------- /examples/output_pandas.py: -------------------------------------------------------------------------------- 1 | # this code is auto generated by the expr_codegen 2 | # https://github.com/wukan1986/expr_codegen 3 | # 此段代码由 expr_codegen 自动生成,欢迎提交 issue 或 pull request 4 | 5 | import re 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import talib as ta 10 | 11 | from loguru import logger 12 | 13 | 14 | # TODO: 数据加载或外部传入 15 | df = df_input 16 | 17 | 18 | def signed_power(x, y): 19 | return x.sign() * (x.abs() ** y) 20 | 21 | 22 | def scale(x, scale=1): 23 | return x / x.abs().sum() * scale 24 | 25 | 26 | def neutralize(x): 27 | return x - x.mean() 28 | 29 | 30 | def func_0_ts__asset__date(df: pd.DataFrame) -> pd.DataFrame: 31 | df = df.sort_values(by=["date"]) 32 | # ======================================== 33 | # _x_0 = ts_mean(OPEN, 10) 34 | df["_x_0"] = df["OPEN"].rolling(10).mean() 35 | # expr_6 = ts_delta(OPEN, 10) 36 | df["expr_6"] = df["OPEN"].diff(10) 37 | # expr_7 = ts_rank(OPEN + 1, 10) 38 | df["expr_7"] = (df["OPEN"] + 1).rolling(10).rank(pct=True) 39 | # _x_1 = ts_mean(CLOSE, 10) 40 | df["_x_1"] = df["CLOSE"].rolling(10).mean() 41 | # expr_5 = -ts_corr(OPEN, CLOSE, 10) 42 | df["expr_5"] = -df["OPEN"].rolling(10).corr(df["CLOSE"], ddof=0) 43 | return df 44 | 45 | 46 | def func_0_cs__date(df: pd.DataFrame) -> pd.DataFrame: 47 | # ======================================== 48 | # _x_5 = cs_rank(OPEN) 49 | df["_x_5"] = df["OPEN"].rank(pct=True) 50 | return df 51 | 52 | 53 | def func_1_cs__date(df: pd.DataFrame) -> pd.DataFrame: 54 | # ======================================== 55 | # _x_2 = cs_rank(_x_0) 56 | df["_x_2"] = df["_x_0"].rank(pct=True) 57 | # _x_3 = cs_rank(_x_1) 58 | df["_x_3"] = df["_x_1"].rank(pct=True) 59 | return df 60 | 61 | 62 | def func_1_ts__asset__date(df: pd.DataFrame) -> pd.DataFrame: 63 | df = df.sort_values(by=["date"]) 64 | # ======================================== 65 | # _x_6 = ts_mean(_x_5, 10) 66 | df["_x_6"] = df["_x_5"].rolling(10).mean() 67 | # expr_8 = ts_rank(expr_7 + 1, 10) 68 | df["expr_8"] = (df["expr_7"] + 1).rolling(10).rank(pct=True) 69 | return df 70 | 71 | 72 | def func_2_cl(df: pd.DataFrame) -> pd.DataFrame: 73 | # 
======================================== 74 | # expr_2 = _x_2 - Abs(log(_x_1)) 75 | df["expr_2"] = df["_x_2"] - (np.log(df["_x_1"])).abs() 76 | return df 77 | 78 | 79 | def func_2_ts__asset__date(df: pd.DataFrame) -> pd.DataFrame: 80 | df = df.sort_values(by=["date"]) 81 | # ======================================== 82 | # expr_3 = ts_mean(_x_2, 10) 83 | df["expr_3"] = df["_x_2"].rolling(10).mean() 84 | # expr_1 = -ts_corr(_x_2, _x_3, 10) 85 | df["expr_1"] = -df["_x_2"].rolling(10).corr(df["_x_3"], ddof=0) 86 | return df 87 | 88 | 89 | def func_2_cs__date(df: pd.DataFrame) -> pd.DataFrame: 90 | # ======================================== 91 | # expr_4 = cs_rank(_x_6) 92 | df["expr_4"] = df["_x_6"].rank(pct=True) 93 | return df 94 | 95 | 96 | # logger.info("start...") 97 | 98 | 99 | df = df.sort_values(by=["date", "asset"]).reset_index(drop=True) 100 | df = df.groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date) 101 | df = df.groupby(by=["date"], group_keys=False).apply(func_0_cs__date) 102 | df = df.groupby(by=["date"], group_keys=False).apply(func_1_cs__date) 103 | df = df.groupby(by=["asset"], group_keys=False).apply(func_1_ts__asset__date) 104 | df = func_2_cl(df) 105 | df = df.groupby(by=["asset"], group_keys=False).apply(func_2_ts__asset__date) 106 | df = df.groupby(by=["date"], group_keys=False).apply(func_2_cs__date) 107 | 108 | 109 | # #========================================func_0_ts__asset__date 110 | # _x_0 = ts_mean(OPEN, 10) 111 | # expr_6 = ts_delta(OPEN, 10) 112 | # expr_7 = ts_rank(OPEN + 1, 10) 113 | # _x_1 = ts_mean(CLOSE, 10) 114 | # expr_5 = -ts_corr(OPEN, CLOSE, 10) 115 | # #========================================func_0_cs__date 116 | # _x_5 = cs_rank(OPEN) 117 | # #========================================func_1_cs__date 118 | # _x_2 = cs_rank(_x_0) 119 | # _x_3 = cs_rank(_x_1) 120 | # #========================================func_1_ts__asset__date 121 | # _x_6 = ts_mean(_x_5, 10) 122 | # expr_8 = ts_rank(expr_7 + 1, 10) 123 | # #========================================func_2_cl 124 | # expr_2 = _x_2 - Abs(log(_x_1)) 125 | # #========================================func_2_ts__asset__date 126 | # expr_3 = ts_mean(_x_2, 10) 127 | # expr_1 = -ts_corr(_x_2, _x_3, 10) 128 | # #========================================func_2_cs__date 129 | # expr_4 = cs_rank(_x_6) 130 | 131 | """ 132 | [OPEN, CLOSE, expr_7] 133 | """ 134 | 135 | """ 136 | expr_1 = -ts_corr(cs_rank(ts_mean(OPEN, 10)), cs_rank(ts_mean(CLOSE, 10)), 10) 137 | expr_2 = cs_rank(ts_mean(OPEN, 10)) - Abs(log(ts_mean(CLOSE, 10))) 138 | expr_3 = ts_mean(cs_rank(ts_mean(OPEN, 10)), 10) 139 | expr_4 = cs_rank(ts_mean(cs_rank(OPEN), 10)) 140 | expr_5 = -ts_corr(OPEN, CLOSE, 10) 141 | expr_6 = ts_delta(OPEN, 10) 142 | expr_8 = ts_rank(expr_7 + 1, 10) 143 | expr_7 = ts_rank(OPEN + 1, 10) 144 | """ 145 | 146 | # drop intermediate columns 147 | df = df.drop(columns=list(filter(lambda x: re.search(r"^_x_\d+", x), df.columns))) 148 | 149 | 150 | # logger.info('done') 151 | 152 | # save 153 | # df.to_parquet('output.parquet', compression='zstd') 154 | 155 | # print(df.tail(5)) 156 | 157 | # 向外部传出数据 158 | df_output = df 159 | -------------------------------------------------------------------------------- /examples/output_polars.py: -------------------------------------------------------------------------------- 1 | # this code is auto generated by the expr_codegen 2 | # https://github.com/wukan1986/expr_codegen 3 | # 此段代码由 expr_codegen 自动生成,欢迎提交 issue 或 pull request 4 | 5 | import numpy as np # noqa 6 | import 
pandas as pd # noqa 7 | import polars as pl # noqa 8 | import polars.selectors as cs # noqa 9 | from loguru import logger # noqa 10 | 11 | # =================================== 12 | # 导入优先级,例如:ts_RSI在ta与talib中都出现了,优先使用ta 13 | # 运行时,后导入覆盖前导入,但IDE智能提示是显示先导入的 14 | _ = 0 # 只要之前出现了语句,之后的import位置不参与调整 15 | # from polars_ta.prefix.talib import * # noqa 16 | from polars_ta.prefix.tdx import * # noqa 17 | from polars_ta.prefix.ta import * # noqa 18 | from polars_ta.prefix.wq import * # noqa 19 | from polars_ta.prefix.cdl import * # noqa 20 | 21 | # =================================== 22 | 23 | _ = ( 24 | "CLOSE", 25 | "HIGH", 26 | "LOW", 27 | "OPEN", 28 | "SMA_005", 29 | "SMA_010", 30 | "SMA_020", 31 | "IMAX_005", 32 | "IMIN_005", 33 | "IMAX_010", 34 | "IMIN_010", 35 | "IMAX_020", 36 | "IMIN_020", 37 | "IMAX_060", 38 | "IMIN_060", 39 | ) 40 | ( 41 | CLOSE, 42 | HIGH, 43 | LOW, 44 | OPEN, 45 | SMA_005, 46 | SMA_010, 47 | SMA_020, 48 | IMAX_005, 49 | IMIN_005, 50 | IMAX_010, 51 | IMIN_010, 52 | IMAX_020, 53 | IMIN_020, 54 | IMAX_060, 55 | IMIN_060, 56 | ) = (pl.col(i) for i in _) 57 | 58 | _ = ( 59 | "_x_0", 60 | "ROCP_001", 61 | "ROCP_003", 62 | "ROCP_005", 63 | "ROCP_010", 64 | "ROCP_020", 65 | "ROCP_060", 66 | "ROCP_120", 67 | "TS_RANK_005", 68 | "TS_RANK_010", 69 | "TS_RANK_020", 70 | "TS_RANK_060", 71 | "TS_RANK_120", 72 | "TS_SCALE_005", 73 | "TS_SCALE_010", 74 | "TS_SCALE_020", 75 | "TS_SCALE_060", 76 | "TS_SCALE_120", 77 | "RSI_006", 78 | "RSI_012", 79 | "RSI_024", 80 | "PSY_010", 81 | "PSY_020", 82 | "PPO_12_26", 83 | "IMAX_120", 84 | "IMIN_120", 85 | "RSV_005", 86 | "RSV_010", 87 | "RSV_020", 88 | "RSV_060", 89 | "WILLR_006", 90 | "WILLR_010", 91 | "NATR_006", 92 | "NATR_012", 93 | "NATR_024", 94 | "ADX_014", 95 | "ADX_021", 96 | "AR_010", 97 | "AR_020", 98 | "HHV_005", 99 | "HHV_010", 100 | "HHV_020", 101 | "HHV_060", 102 | "HHV_120", 103 | "LLV_005", 104 | "LLV_010", 105 | "LLV_020", 106 | "LLV_060", 107 | "LLV_120", 108 | "SMA_060", 109 | "SMA_120", 110 | "STD_005", 111 | "STD_010", 112 | "STD_020", 113 | "STD_060", 114 | "STD_120", 115 | "IMXD_005", 116 | "IMXD_010", 117 | "IMXD_020", 118 | "IMXD_060", 119 | "SMA_001_005", 120 | "SMA_005_010", 121 | "SMA_010_020", 122 | ) 123 | ( 124 | _x_0, 125 | ROCP_001, 126 | ROCP_003, 127 | ROCP_005, 128 | ROCP_010, 129 | ROCP_020, 130 | ROCP_060, 131 | ROCP_120, 132 | TS_RANK_005, 133 | TS_RANK_010, 134 | TS_RANK_020, 135 | TS_RANK_060, 136 | TS_RANK_120, 137 | TS_SCALE_005, 138 | TS_SCALE_010, 139 | TS_SCALE_020, 140 | TS_SCALE_060, 141 | TS_SCALE_120, 142 | RSI_006, 143 | RSI_012, 144 | RSI_024, 145 | PSY_010, 146 | PSY_020, 147 | PPO_12_26, 148 | IMAX_120, 149 | IMIN_120, 150 | RSV_005, 151 | RSV_010, 152 | RSV_020, 153 | RSV_060, 154 | WILLR_006, 155 | WILLR_010, 156 | NATR_006, 157 | NATR_012, 158 | NATR_024, 159 | ADX_014, 160 | ADX_021, 161 | AR_010, 162 | AR_020, 163 | HHV_005, 164 | HHV_010, 165 | HHV_020, 166 | HHV_060, 167 | HHV_120, 168 | LLV_005, 169 | LLV_010, 170 | LLV_020, 171 | LLV_060, 172 | LLV_120, 173 | SMA_060, 174 | SMA_120, 175 | STD_005, 176 | STD_010, 177 | STD_020, 178 | STD_060, 179 | STD_120, 180 | IMXD_005, 181 | IMXD_010, 182 | IMXD_020, 183 | IMXD_060, 184 | SMA_001_005, 185 | SMA_005_010, 186 | SMA_010_020, 187 | ) = (pl.col(i) for i in _) 188 | 189 | _DATE_ = "date" 190 | _ASSET_ = "asset" 191 | 192 | 193 | def func_0_cl(df: pl.DataFrame) -> pl.DataFrame: 194 | # ======================================== 195 | df = df.with_columns( 196 | _x_0=1 / CLOSE, 197 | ) 198 | return df 199 | 200 | 201 | def 
func_0_ts__asset(df: pl.DataFrame) -> pl.DataFrame: 202 | df = df.sort(_DATE_) 203 | # ======================================== 204 | df = df.with_columns( 205 | ROCP_001=ts_returns(CLOSE, 1), 206 | ROCP_003=ts_returns(CLOSE, 3), 207 | ROCP_005=ts_returns(CLOSE, 5), 208 | ROCP_010=ts_returns(CLOSE, 10), 209 | ROCP_020=ts_returns(CLOSE, 20), 210 | ROCP_060=ts_returns(CLOSE, 60), 211 | ROCP_120=ts_returns(CLOSE, 120), 212 | TS_RANK_005=ts_rank(CLOSE, 5), 213 | TS_RANK_010=ts_rank(CLOSE, 10), 214 | TS_RANK_020=ts_rank(CLOSE, 20), 215 | TS_RANK_060=ts_rank(CLOSE, 60), 216 | TS_RANK_120=ts_rank(CLOSE, 120), 217 | TS_SCALE_005=ts_scale(CLOSE, 5), 218 | TS_SCALE_010=ts_scale(CLOSE, 10), 219 | TS_SCALE_020=ts_scale(CLOSE, 20), 220 | TS_SCALE_060=ts_scale(CLOSE, 60), 221 | TS_SCALE_120=ts_scale(CLOSE, 120), 222 | RSI_006=ts_RSI(CLOSE, 6), 223 | RSI_012=ts_RSI(CLOSE, 12), 224 | RSI_024=ts_RSI(CLOSE, 24), 225 | PSY_010=ts_PSY(CLOSE, 10), 226 | PSY_020=ts_PSY(CLOSE, 20), 227 | PPO_12_26=ts_PPO(CLOSE, 12, 26), 228 | IMAX_005=ts_arg_max(HIGH, 5) / 5, 229 | IMAX_010=ts_arg_max(HIGH, 10) / 10, 230 | IMAX_020=ts_arg_max(HIGH, 20) / 20, 231 | IMAX_060=ts_arg_max(HIGH, 60) / 60, 232 | IMAX_120=ts_arg_max(HIGH, 120) / 120, 233 | IMIN_005=ts_arg_min(LOW, 5) / 5, 234 | IMIN_010=ts_arg_min(LOW, 10) / 10, 235 | IMIN_020=ts_arg_min(LOW, 20) / 20, 236 | IMIN_060=ts_arg_min(LOW, 60) / 60, 237 | IMIN_120=ts_arg_min(LOW, 120) / 120, 238 | RSV_005=ts_RSV(HIGH, LOW, CLOSE, 5), 239 | RSV_010=ts_RSV(HIGH, LOW, CLOSE, 10), 240 | RSV_020=ts_RSV(HIGH, LOW, CLOSE, 20), 241 | RSV_060=ts_RSV(HIGH, LOW, CLOSE, 60), 242 | WILLR_006=ts_WILLR(HIGH, LOW, CLOSE, 6), 243 | WILLR_010=ts_WILLR(HIGH, LOW, CLOSE, 10), 244 | NATR_006=ts_NATR(HIGH, LOW, CLOSE, 6), 245 | NATR_012=ts_NATR(HIGH, LOW, CLOSE, 12), 246 | NATR_024=ts_NATR(HIGH, LOW, CLOSE, 24), 247 | ADX_014=ts_ADX(HIGH, LOW, CLOSE, 14), 248 | ADX_021=ts_ADX(HIGH, LOW, CLOSE, 21), 249 | AR_010=ts_BRAR_AR(OPEN, HIGH, LOW, CLOSE, 10), 250 | AR_020=ts_BRAR_AR(OPEN, HIGH, LOW, CLOSE, 20), 251 | ) 252 | # ======================================== 253 | df = df.with_columns( 254 | HHV_005=_x_0 * ts_max(HIGH, 5), 255 | HHV_010=_x_0 * ts_max(HIGH, 10), 256 | HHV_020=_x_0 * ts_max(HIGH, 20), 257 | HHV_060=_x_0 * ts_max(HIGH, 60), 258 | HHV_120=_x_0 * ts_max(HIGH, 120), 259 | LLV_005=_x_0 * ts_min(LOW, 5), 260 | LLV_010=_x_0 * ts_min(LOW, 10), 261 | LLV_020=_x_0 * ts_min(LOW, 20), 262 | LLV_060=_x_0 * ts_min(LOW, 60), 263 | LLV_120=_x_0 * ts_min(LOW, 120), 264 | SMA_005=_x_0 * ts_mean(CLOSE, 5), 265 | SMA_010=_x_0 * ts_mean(CLOSE, 10), 266 | SMA_020=_x_0 * ts_mean(CLOSE, 20), 267 | SMA_060=_x_0 * ts_mean(CLOSE, 60), 268 | SMA_120=_x_0 * ts_mean(CLOSE, 120), 269 | STD_005=_x_0 * ts_std_dev(CLOSE, 5), 270 | STD_010=_x_0 * ts_std_dev(CLOSE, 10), 271 | STD_020=_x_0 * ts_std_dev(CLOSE, 20), 272 | STD_060=_x_0 * ts_std_dev(CLOSE, 60), 273 | STD_120=_x_0 * ts_std_dev(CLOSE, 120), 274 | ) 275 | return df 276 | 277 | 278 | def func_1_cl(df: pl.DataFrame) -> pl.DataFrame: 279 | # ======================================== 280 | df = df.with_columns( 281 | IMXD_005=IMAX_005 - IMIN_005, 282 | IMXD_010=IMAX_010 - IMIN_010, 283 | IMXD_020=IMAX_020 - IMIN_020, 284 | IMXD_060=IMAX_060 - IMIN_060, 285 | ) 286 | # ======================================== 287 | df = df.with_columns( 288 | SMA_001_005=CLOSE / SMA_005, 289 | SMA_005_010=SMA_005 / SMA_010, 290 | SMA_010_020=SMA_010 / SMA_020, 291 | ) 292 | return df 293 | 294 | 295 | """ 296 | #========================================func_0_cl 297 | _x_0 = 1/CLOSE 
298 | #========================================func_0_ts__asset 299 | ROCP_001 = ts_returns(CLOSE, 1) 300 | ROCP_003 = ts_returns(CLOSE, 3) 301 | ROCP_005 = ts_returns(CLOSE, 5) 302 | ROCP_010 = ts_returns(CLOSE, 10) 303 | ROCP_020 = ts_returns(CLOSE, 20) 304 | ROCP_060 = ts_returns(CLOSE, 60) 305 | ROCP_120 = ts_returns(CLOSE, 120) 306 | TS_RANK_005 = ts_rank(CLOSE, 5) 307 | TS_RANK_010 = ts_rank(CLOSE, 10) 308 | TS_RANK_020 = ts_rank(CLOSE, 20) 309 | TS_RANK_060 = ts_rank(CLOSE, 60) 310 | TS_RANK_120 = ts_rank(CLOSE, 120) 311 | TS_SCALE_005 = ts_scale(CLOSE, 5) 312 | TS_SCALE_010 = ts_scale(CLOSE, 10) 313 | TS_SCALE_020 = ts_scale(CLOSE, 20) 314 | TS_SCALE_060 = ts_scale(CLOSE, 60) 315 | TS_SCALE_120 = ts_scale(CLOSE, 120) 316 | RSI_006 = ts_RSI(CLOSE, 6) 317 | RSI_012 = ts_RSI(CLOSE, 12) 318 | RSI_024 = ts_RSI(CLOSE, 24) 319 | PSY_010 = ts_PSY(CLOSE, 10) 320 | PSY_020 = ts_PSY(CLOSE, 20) 321 | PPO_12_26 = ts_PPO(CLOSE, 12, 26) 322 | IMAX_005 = ts_arg_max(HIGH, 5)/5 323 | IMAX_010 = ts_arg_max(HIGH, 10)/10 324 | IMAX_020 = ts_arg_max(HIGH, 20)/20 325 | IMAX_060 = ts_arg_max(HIGH, 60)/60 326 | IMAX_120 = ts_arg_max(HIGH, 120)/120 327 | IMIN_005 = ts_arg_min(LOW, 5)/5 328 | IMIN_010 = ts_arg_min(LOW, 10)/10 329 | IMIN_020 = ts_arg_min(LOW, 20)/20 330 | IMIN_060 = ts_arg_min(LOW, 60)/60 331 | IMIN_120 = ts_arg_min(LOW, 120)/120 332 | RSV_005 = ts_RSV(HIGH, LOW, CLOSE, 5) 333 | RSV_010 = ts_RSV(HIGH, LOW, CLOSE, 10) 334 | RSV_020 = ts_RSV(HIGH, LOW, CLOSE, 20) 335 | RSV_060 = ts_RSV(HIGH, LOW, CLOSE, 60) 336 | WILLR_006 = ts_WILLR(HIGH, LOW, CLOSE, 6) 337 | WILLR_010 = ts_WILLR(HIGH, LOW, CLOSE, 10) 338 | NATR_006 = ts_NATR(HIGH, LOW, CLOSE, 6) 339 | NATR_012 = ts_NATR(HIGH, LOW, CLOSE, 12) 340 | NATR_024 = ts_NATR(HIGH, LOW, CLOSE, 24) 341 | ADX_014 = ts_ADX(HIGH, LOW, CLOSE, 14) 342 | ADX_021 = ts_ADX(HIGH, LOW, CLOSE, 21) 343 | AR_010 = ts_BRAR_AR(OPEN, HIGH, LOW, CLOSE, 10) 344 | AR_020 = ts_BRAR_AR(OPEN, HIGH, LOW, CLOSE, 20) 345 | #========================================func_0_ts__asset 346 | HHV_005 = _x_0*ts_max(HIGH, 5) 347 | HHV_010 = _x_0*ts_max(HIGH, 10) 348 | HHV_020 = _x_0*ts_max(HIGH, 20) 349 | HHV_060 = _x_0*ts_max(HIGH, 60) 350 | HHV_120 = _x_0*ts_max(HIGH, 120) 351 | LLV_005 = _x_0*ts_min(LOW, 5) 352 | LLV_010 = _x_0*ts_min(LOW, 10) 353 | LLV_020 = _x_0*ts_min(LOW, 20) 354 | LLV_060 = _x_0*ts_min(LOW, 60) 355 | LLV_120 = _x_0*ts_min(LOW, 120) 356 | SMA_005 = _x_0*ts_mean(CLOSE, 5) 357 | SMA_010 = _x_0*ts_mean(CLOSE, 10) 358 | SMA_020 = _x_0*ts_mean(CLOSE, 20) 359 | SMA_060 = _x_0*ts_mean(CLOSE, 60) 360 | SMA_120 = _x_0*ts_mean(CLOSE, 120) 361 | STD_005 = _x_0*ts_std_dev(CLOSE, 5) 362 | STD_010 = _x_0*ts_std_dev(CLOSE, 10) 363 | STD_020 = _x_0*ts_std_dev(CLOSE, 20) 364 | STD_060 = _x_0*ts_std_dev(CLOSE, 60) 365 | STD_120 = _x_0*ts_std_dev(CLOSE, 120) 366 | #========================================func_1_cl 367 | IMXD_005 = IMAX_005 - IMIN_005 368 | IMXD_010 = IMAX_010 - IMIN_010 369 | IMXD_020 = IMAX_020 - IMIN_020 370 | IMXD_060 = IMAX_060 - IMIN_060 371 | #========================================func_1_cl 372 | SMA_001_005 = CLOSE/SMA_005 373 | SMA_005_010 = SMA_005/SMA_010 374 | SMA_010_020 = SMA_010/SMA_020 375 | """ 376 | 377 | """ 378 | HHV_005 = ts_max(HIGH, 5)/CLOSE 379 | HHV_010 = ts_max(HIGH, 10)/CLOSE 380 | HHV_020 = ts_max(HIGH, 20)/CLOSE 381 | HHV_060 = ts_max(HIGH, 60)/CLOSE 382 | HHV_120 = ts_max(HIGH, 120)/CLOSE 383 | LLV_005 = ts_min(LOW, 5)/CLOSE 384 | LLV_010 = ts_min(LOW, 10)/CLOSE 385 | LLV_020 = ts_min(LOW, 20)/CLOSE 386 | LLV_060 = ts_min(LOW, 
60)/CLOSE 387 | LLV_120 = ts_min(LOW, 120)/CLOSE 388 | SMA_005 = ts_mean(CLOSE, 5)/CLOSE 389 | SMA_010 = ts_mean(CLOSE, 10)/CLOSE 390 | SMA_020 = ts_mean(CLOSE, 20)/CLOSE 391 | SMA_060 = ts_mean(CLOSE, 60)/CLOSE 392 | SMA_120 = ts_mean(CLOSE, 120)/CLOSE 393 | STD_005 = ts_std_dev(CLOSE, 5)/CLOSE 394 | STD_010 = ts_std_dev(CLOSE, 10)/CLOSE 395 | STD_020 = ts_std_dev(CLOSE, 20)/CLOSE 396 | STD_060 = ts_std_dev(CLOSE, 60)/CLOSE 397 | STD_120 = ts_std_dev(CLOSE, 120)/CLOSE 398 | ROCP_001 = ts_returns(CLOSE, 1) 399 | ROCP_003 = ts_returns(CLOSE, 3) 400 | ROCP_005 = ts_returns(CLOSE, 5) 401 | ROCP_010 = ts_returns(CLOSE, 10) 402 | ROCP_020 = ts_returns(CLOSE, 20) 403 | ROCP_060 = ts_returns(CLOSE, 60) 404 | ROCP_120 = ts_returns(CLOSE, 120) 405 | TS_RANK_005 = ts_rank(CLOSE, 5) 406 | TS_RANK_010 = ts_rank(CLOSE, 10) 407 | TS_RANK_020 = ts_rank(CLOSE, 20) 408 | TS_RANK_060 = ts_rank(CLOSE, 60) 409 | TS_RANK_120 = ts_rank(CLOSE, 120) 410 | TS_SCALE_005 = ts_scale(CLOSE, 5) 411 | TS_SCALE_010 = ts_scale(CLOSE, 10) 412 | TS_SCALE_020 = ts_scale(CLOSE, 20) 413 | TS_SCALE_060 = ts_scale(CLOSE, 60) 414 | TS_SCALE_120 = ts_scale(CLOSE, 120) 415 | IMAX_005 = ts_arg_max(HIGH, 5)/5 416 | IMAX_010 = ts_arg_max(HIGH, 10)/10 417 | IMAX_020 = ts_arg_max(HIGH, 20)/20 418 | IMAX_060 = ts_arg_max(HIGH, 60)/60 419 | IMAX_120 = ts_arg_max(HIGH, 120)/120 420 | IMIN_005 = ts_arg_min(LOW, 5)/5 421 | IMIN_010 = ts_arg_min(LOW, 10)/10 422 | IMIN_020 = ts_arg_min(LOW, 20)/20 423 | IMIN_060 = ts_arg_min(LOW, 60)/60 424 | IMIN_120 = ts_arg_min(LOW, 120)/120 425 | RSI_006 = ts_RSI(CLOSE, 6) 426 | RSI_012 = ts_RSI(CLOSE, 12) 427 | RSI_024 = ts_RSI(CLOSE, 24) 428 | RSV_005 = ts_RSV(HIGH, LOW, CLOSE, 5) 429 | RSV_010 = ts_RSV(HIGH, LOW, CLOSE, 10) 430 | RSV_020 = ts_RSV(HIGH, LOW, CLOSE, 20) 431 | RSV_060 = ts_RSV(HIGH, LOW, CLOSE, 60) 432 | WILLR_006 = ts_WILLR(HIGH, LOW, CLOSE, 6) 433 | WILLR_010 = ts_WILLR(HIGH, LOW, CLOSE, 10) 434 | NATR_006 = ts_NATR(HIGH, LOW, CLOSE, 6) 435 | NATR_012 = ts_NATR(HIGH, LOW, CLOSE, 12) 436 | NATR_024 = ts_NATR(HIGH, LOW, CLOSE, 24) 437 | ADX_014 = ts_ADX(HIGH, LOW, CLOSE, 14) 438 | ADX_021 = ts_ADX(HIGH, LOW, CLOSE, 21) 439 | AR_010 = ts_BRAR_AR(OPEN, HIGH, LOW, CLOSE, 10) 440 | AR_020 = ts_BRAR_AR(OPEN, HIGH, LOW, CLOSE, 20) 441 | PSY_010 = ts_PSY(CLOSE, 10) 442 | PSY_020 = ts_PSY(CLOSE, 20) 443 | PPO_12_26 = ts_PPO(CLOSE, 12, 26) 444 | SMA_001_005 = CLOSE/SMA_005 445 | SMA_005_010 = SMA_005/SMA_010 446 | SMA_010_020 = SMA_010/SMA_020 447 | IMXD_005 = IMAX_005 - IMIN_005 448 | IMXD_010 = IMAX_010 - IMIN_010 449 | IMXD_020 = IMAX_020 - IMIN_020 450 | IMXD_060 = IMAX_060 - IMIN_060 451 | """ 452 | 453 | 454 | def main(df: pl.DataFrame) -> pl.DataFrame: 455 | # logger.info("start...") 456 | 457 | df = func_0_cl(df).drop(*[]) 458 | df = df.sort(_ASSET_, _DATE_).group_by(_ASSET_).map_groups(func_0_ts__asset).drop(*["_x_0"]) 459 | df = func_1_cl(df).drop(*[]) 460 | 461 | # drop intermediate columns 462 | # df = df.select(pl.exclude(r'^_x_\d+$')) 463 | df = df.select(~cs.starts_with("_")) 464 | 465 | # shrink 466 | df = df.select(cs.all().shrink_dtype()) 467 | df = df.shrink_to_fit() 468 | 469 | # logger.info('done') 470 | 471 | # save 472 | # df.write_parquet('output.parquet') 473 | 474 | return df 475 | 476 | 477 | if __name__ in ("__main__", "builtins"): 478 | # TODO: 数据加载或外部传入 479 | df_output = main(df_input) -------------------------------------------------------------------------------- /examples/prefilter.py: -------------------------------------------------------------------------------- 
1 | """ 2 | 由于因子的计算步骤很多,又很耗时,如果能提前横截面过滤,只计算重要的品种,是否能加快计算? 3 | """ 4 | import time 5 | 6 | import pandas as pd 7 | import polars as pl 8 | from polars_ta.prefix.wq import * 9 | 10 | from expr_codegen import codegen_exec 11 | 12 | _N = 500 13 | _K = 5000 14 | 15 | asset = [f's_{i:04d}' for i in range(_K)] 16 | date = pd.date_range('2015-1-1', periods=_N) 17 | 18 | df = pd.DataFrame({ 19 | 'RETURNS': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1), 20 | 'VWAP': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1), 21 | 'LOW': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1), 22 | 'CLOSE': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1), 23 | }, index=pd.MultiIndex.from_product([date, asset], names=['date', 'asset'])).reset_index() 24 | df = pl.from_pandas(df) 25 | 26 | 27 | def _code_block_1(): 28 | cond1 = ts_returns(CLOSE, 5) 29 | cond2 = cond1 > 0 30 | 31 | 32 | df1 = codegen_exec(df, _code_block_1, over_null='partition_by', filter_last=True).select('asset', 'cond2').filter(pl.col('cond2')) 33 | print(df1) 34 | # 后面只对cond2=true的计算 35 | 36 | 37 | 38 | t1 = time.perf_counter() 39 | # 方案1 40 | df2 = pl.concat([df, df1], how='align_left').filter(pl.col('cond2')) 41 | t2 = time.perf_counter() 42 | # 方案2 43 | df2 = df.join(df1, on=['asset'], how='left').filter(pl.col('cond2')) 44 | t3 = time.perf_counter() 45 | # 方案3 46 | assets = set(df1['asset'].to_list()) 47 | df2 = df.filter(pl.col('asset').is_in(assets)) 48 | t4 = time.perf_counter() 49 | print("耗时比较", t2 - t1, t3 - t2, t4 - t3) 50 | 51 | 52 | def _code_block_2(): 53 | MA1 = ts_mean(CLOSE, 5) 54 | MA2 = ts_mean(CLOSE, 10) 55 | MA3 = ts_mean(CLOSE, 20) 56 | 57 | 58 | df3 = codegen_exec(df2, _code_block_2, over_null='partition_by', filter_last=True) 59 | print(df3) 60 | -------------------------------------------------------------------------------- /examples/show_tree.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | from polars_ta.prefix.cdl import * # noqa 3 | from polars_ta.prefix.ta import * # noqa 4 | from polars_ta.prefix.tdx import * # noqa 5 | from polars_ta.prefix.wq import * # noqa 6 | from sympy import numbered_symbols 7 | 8 | from examples.sympy_define import * 9 | from expr_codegen.codes import sources_to_exprs 10 | from expr_codegen.dag import zero_outdegree 11 | from expr_codegen.model import create_dag_exprs, init_dag_exprs, draw_expr_tree, merge_nodes_1, merge_nodes_2 12 | from expr_codegen.tool import ExprTool 13 | 14 | RETURNS, VWAP, = symbols('RETURNS, VWAP, ', cls=Symbol) 15 | 16 | exprs_src = """ 17 | alpha_001=( 18 | cs_rank(ts_arg_max(signed_power(if_else((RETURNS < 0), ts_std_dev(RETURNS, 20), CLOSE), 2.), 5)) - 0.5) 19 | alpha_002=(-1 * ts_corr(cs_rank(ts_delta(log(VOLUME), 2)), cs_rank(((CLOSE - OPEN) / OPEN)), 6)) 20 | alpha_003=(-1 * ts_corr(cs_rank(OPEN), cs_rank(VOLUME), 10)) 21 | alpha_004=(-1 * ts_rank(cs_rank(LOW), 9)) 22 | alpha_005=(cs_rank((OPEN - (ts_sum(VWAP, 10) / 10))) * (-1 * abs_(cs_rank((CLOSE - VWAP))))) 23 | alpha_006= -1 * ts_corr(OPEN, VOLUME, 10) 24 | """ 25 | # # 表达式设置 26 | # exprs_src = """ 27 | # _A = OPEN * CLOSE 28 | # _B = CLOSE * VOLUME 29 | # C = _A > _B or True if _A else _B 30 | # """ 31 | exprs_src = """ 32 | _A = OPEN * CLOSE 33 | _B = CLOSE * VOLUME 34 | C = (_A > _B) + (_A == _B) 35 | """ 36 | raw, exprs_src = sources_to_exprs(globals().copy(), exprs_src, convert_xor=False) 37 | 38 
--------------------------------------------------------------------------------
/examples/show_tree.py:
--------------------------------------------------------------------------------
from matplotlib import pyplot as plt
from polars_ta.prefix.cdl import *  # noqa
from polars_ta.prefix.ta import *  # noqa
from polars_ta.prefix.tdx import *  # noqa
from polars_ta.prefix.wq import *  # noqa
from sympy import numbered_symbols

from examples.sympy_define import *
from expr_codegen.codes import sources_to_exprs
from expr_codegen.dag import zero_outdegree
from expr_codegen.model import create_dag_exprs, init_dag_exprs, draw_expr_tree, merge_nodes_1, merge_nodes_2
from expr_codegen.tool import ExprTool

RETURNS, VWAP, = symbols('RETURNS, VWAP, ', cls=Symbol)

exprs_src = """
alpha_001=(
cs_rank(ts_arg_max(signed_power(if_else((RETURNS < 0), ts_std_dev(RETURNS, 20), CLOSE), 2.), 5)) - 0.5)
alpha_002=(-1 * ts_corr(cs_rank(ts_delta(log(VOLUME), 2)), cs_rank(((CLOSE - OPEN) / OPEN)), 6))
alpha_003=(-1 * ts_corr(cs_rank(OPEN), cs_rank(VOLUME), 10))
alpha_004=(-1 * ts_rank(cs_rank(LOW), 9))
alpha_005=(cs_rank((OPEN - (ts_sum(VWAP, 10) / 10))) * (-1 * abs_(cs_rank((CLOSE - VWAP)))))
alpha_006= -1 * ts_corr(OPEN, VOLUME, 10)
"""
# # expression setup
# exprs_src = """
# _A = OPEN * CLOSE
# _B = CLOSE * VOLUME
# C = _A > _B or True if _A else _B
# """
exprs_src = """
_A = OPEN * CLOSE
_B = CLOSE * VOLUME
C = (_A > _B) + (_A == _B)
"""
raw, exprs_src = sources_to_exprs(globals().copy(), exprs_src, convert_xor=False)

tool = ExprTool()
# subexpressions come first, the original expressions come last
exprs_dst, syms_dst = tool.merge("date", "asset", **exprs_src)

# extract common subexpressions
exprs_dict = tool.cse(exprs_dst, symbols_repl=numbered_symbols('x_'), symbols_redu=exprs_src.keys())

# build the DAG
G = create_dag_exprs(exprs_dict)
G = init_dag_exprs(G, tool.get_current_func, tool.get_current_func_kwargs, "date", "asset")

keep_nodes = [k for k in exprs_src.keys() if not k.startswith('_')]
# keep_nodes = exprs_src.keys()
# the node-merging process can be watched step by step below
zero = zero_outdegree(G)
for z in zero:
    print(z)
    # draw the before/after trees stacked on one canvas
    fig, axs = plt.subplots(2, 1)
    draw_expr_tree(G, z, ax=axs[0])
    merge_nodes_1(G, keep_nodes, z)
    merge_nodes_2(G, keep_nodes, z)
    draw_expr_tree(G, z, ax=axs[1])
    plt.show()
--------------------------------------------------------------------------------
/examples/sympy_define.py:
--------------------------------------------------------------------------------
"""
!!! Every newly added `Function` with a special textual representation needs matching
handling code in `printer.py`

# max_, min_, abs_ are redefined because the plain names clash with Python builtins
# sign would be translated into a Piecewise, so a custom sign function is used instead

"""
from sympy import Symbol, Function, symbols  # noqa

# Operator names self-register on import, so importing is all that is needed;
# edit this section if you want to pull in other libraries
_ = 0  # once any statement has appeared, the import positions below are no longer reordered
# talib fails to install on the free Streamlit hosting platform, so it is disabled for now
# from polars_ta.prefix.talib import *  # noqa
from polars_ta.prefix.tdx import *  # noqa
from polars_ta.prefix.ta import *  # noqa
from polars_ta.prefix.wq import *  # noqa
from polars_ta.prefix.cdl import *  # noqa

# TODO: general operators, usable on time series, cross-sections and whole columns. Extend as needed

# TODO: time-series operators. Data must be grouped by asset first and sorted by time within each group. Extend as needed. Names must start with `ts_`

# TODO: cross-section operators. Data must be grouped by time first. Extend as needed. Names must start with `cs_`

# TODO: group operators. Data must be grouped by time and industry first. Names must start with `gp_`
gp_rank, gp_demean, = symbols('gp_rank, gp_demean, ', cls=Function)

# TODO: factors. Extend as needed
OPEN, HIGH, LOW, CLOSE, VOLUME, AMOUNT, OPEN_INTEREST, VWAP, = symbols('OPEN, HIGH, LOW, CLOSE, VOLUME, AMOUNT, OPEN_INTEREST, VWAP, ', cls=Symbol)
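The registration pattern above extends naturally to user-defined operators. A minimal
sketch follows; the operator names are hypothetical examples, not part of polars_ta:

# declare the sympy-side symbols; the ts_/cs_/gp_ prefix is what routes each call
# to the correct groupby in the generated code
ts_my_slope, cs_my_zscore, = symbols('ts_my_slope, cs_my_zscore, ', cls=Function)

# a runtime implementation with the same name must also be importable from the
# module the generated code imports, otherwise execution will fail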
--------------------------------------------------------------------------------
/examples/tail_n.py:
--------------------------------------------------------------------------------
"""
Estimate the minimum data length

1. Suspensions and newly listed stocks both affect the result
2. EMA is special: with a parameter of 10, feeding it 10 rows and feeding it 11 rows give different results
3. MACD and quite a few other indicators are built on EMA underneath
"""
import numpy as np
import pandas as pd
import polars as pl

from expr_codegen import codegen_exec

# TODO reserve 500 trading days
_N = 500
# TODO can be replaced with the number of stocks
_K = 10

asset = [f's_{i:04d}' for i in range(_K)]
date = pd.date_range('2015-1-1', periods=_N)

df = pd.DataFrame({
    'VWAP': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1),
    'LOW': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1),
    'CLOSE': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1),
}, index=pd.MultiIndex.from_product([date, asset], names=['date', 'asset'])).reset_index()
# TODO real data could be used here instead, but suspensions will affect the result
df = pl.from_pandas(df)


def _code_block_1():
    # TODO replace with your own factor expressions
    A1 = ts_mean(CLOSE, 5)
    A2 = ts_returns(A1, 5)
    A3 = cs_rank(A2)
    A4 = ts_returns(A3, 10)


# TODO can be used afterwards to check whether the computed parameter is correct
df1: pl.DataFrame = codegen_exec(df, _code_block_1, over_null='partition_by')
# inspect the null counts
print(df1.null_count())
# for a single stock, the maximum null count along the time axis plus 1 is the theoretical minimum tail parameter
# but some cross-sectional factors cannot yield valid values on a single stock, so the multi-stock version is needed
n = df1.null_count().max_horizontal()[0] + _K
print("theoretical minimum tail parameter", n)
print("theoretical minimum rows per stock", n / _K)

# sorting before tail here is essential
df2: pl.DataFrame = codegen_exec(df.sort('date', 'asset').tail(n), _code_block_1, over_null='partition_by')
# if the parameter is set correctly, the last row should be non-null and the second-to-last row null
print(df2.sort('asset', 'date').tail(3))
--------------------------------------------------------------------------------
/expr_codegen/__init__.py:
--------------------------------------------------------------------------------
from expr_codegen._version import __version__
from expr_codegen.tool import codegen_exec
--------------------------------------------------------------------------------
/expr_codegen/_version.py:
--------------------------------------------------------------------------------
__version__ = "0.13.2"
--------------------------------------------------------------------------------
/expr_codegen/codes.py:
--------------------------------------------------------------------------------
import ast
import re
from ast import expr

import ast_comments
from black import Mode, format_str
from sympy import Add, Mul, Pow, Eq, Not, Xor

from expr_codegen.expr import register_symbols, list_to_exprs


class SyntaxTransformer(ast.NodeTransformer):
    """Rewrites the syntax. Note: syntax must be rewritten before any renaming"""

    def __init__(self, convert_xor):
        # is ^ exclusive-or, or power?
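        # visit_BinOp below resolves this: with convert_xor=True, `A ^ 2` is
        # rewritten as Pow(A, 2); with convert_xor=False it becomes Xor(A, 2)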
17 | self.convert_xor = convert_xor 18 | 19 | def visit_Assign(self, node): 20 | t = node.targets[0] 21 | nodes = [] 22 | if isinstance(t, ast.Tuple): 23 | for i, dim in enumerate(t.dims): 24 | _v = ast.Call( 25 | func=ast.Name(id='unpack', ctx=ast.Load()), 26 | args=[node.value, ast.Constant(i)], 27 | keywords=[], 28 | ) 29 | n = ast.Assign([dim], _v, ctx=ast.Load()) 30 | nodes.append(n) 31 | return nodes 32 | 33 | self.generic_visit(node) 34 | return node 35 | 36 | def visit_Compare(self, node): 37 | assert len(node.comparators) == 1, f"不支持连续等号,请手工添加括号, {ast.unparse(node)}" 38 | 39 | self.generic_visit(node) 40 | return node 41 | 42 | def visit_IfExp(self, node): 43 | # 三元表达式。需要在外部提前替换成or True if else 44 | # 只要body区域,出现了or True,就认为是特殊处理过的 45 | if isinstance(node.body, ast.BoolOp) and isinstance(node.body.op, ast.Or): 46 | if isinstance(node.body.values[-1], ast.Constant): 47 | if node.body.values[-1].value: 48 | node.test, node.body = node.body.values[0], node.test 49 | 50 | node = ast.Call( 51 | func=ast.Name(id='if_else', ctx=ast.Load()), 52 | args=[node.test, node.body, node.orelse], 53 | keywords=[], 54 | ) 55 | 56 | self.generic_visit(node) 57 | return node 58 | 59 | def visit_BinOp(self, node): 60 | # TypeError: unsupported operand type(s) for *: 'StrictLessThan' and 'int' 61 | if isinstance(node.op, (ast.Mult, ast.Add, ast.Div, ast.Sub)): 62 | # (OPEN < CLOSE) * -1 63 | if isinstance(node.left, ast.Compare): 64 | node.left = ast.Call( 65 | func=ast.Name(id='int_', ctx=ast.Load()), 66 | args=[node.left], 67 | keywords=[], 68 | ) 69 | # -1*(OPEN < CLOSE) 70 | if isinstance(node.right, ast.Compare): 71 | node.right = ast.Call( 72 | func=ast.Name(id='int_', ctx=ast.Load()), 73 | args=[node.right], 74 | keywords=[], 75 | ) 76 | # 这种情况,已经包含 77 | # (OPEN < CLOSE)*(OPEN < CLOSE) 78 | 79 | if isinstance(node.op, ast.BitXor): 80 | # ^ 运算符,转换为pow还是xor 81 | if self.convert_xor: 82 | node = ast.Call( 83 | func=ast.Name(id='Pow', ctx=ast.Load()), 84 | args=[node.left, node.right], 85 | keywords=[], 86 | ) 87 | else: 88 | node = ast.Call( 89 | func=ast.Name(id='Xor', ctx=ast.Load()), 90 | args=[node.left, node.right], 91 | keywords=[], 92 | ) 93 | 94 | self.generic_visit(node) 95 | return node 96 | 97 | def visit_UnaryOp(self, node): 98 | # ~ts_delay 报错,替换成Not(ts_delay) 99 | if isinstance(node.op, ast.Invert): 100 | node = ast.Call( 101 | func=ast.Name(id='Not', ctx=ast.Load()), 102 | args=[node.operand], 103 | keywords=[], 104 | ) 105 | 106 | self.generic_visit(node) 107 | return node 108 | 109 | def visit_Subscript(self, node): 110 | if isinstance(node.slice, ast.Constant) and node.slice.value == 0: 111 | node = node.value 112 | elif isinstance(node.slice, ast.UnaryOp) and isinstance(node.slice.operand, 113 | ast.Constant) and node.slice.operand.value == 0: 114 | node = node.value 115 | else: 116 | node = ast.Call( 117 | func=ast.Name(id='ts_delay', ctx=ast.Load()), 118 | args=[node.value, node.slice], 119 | keywords=[], 120 | ) 121 | self.generic_visit(node) 122 | return node 123 | 124 | 125 | class RenameTransformer(ast.NodeTransformer): 126 | """改名处理。改名前需要语法规范""" 127 | 128 | def __init__(self, funcs_map, targets_map, args_map=None): 129 | 130 | if args_map is None: 131 | # 保留字 132 | args_map = {'True': "_TRUE_", 'False': "_FALSE_", 'None': "_NONE_"} 133 | self.funcs_old = set() 134 | self.args_old = set() 135 | self.targets_old = set() 136 | self.funcs_new = set() 137 | self.args_new = set() 138 | self.targets_new = set() 139 | # 映射 140 | self.funcs_map = funcs_map 141 | # 
由于None等常量无法在sympy中正确处理,只能改成Symbol变量 142 | # !!!一定要在drop_symbols时排除 143 | self.args_map = args_map 144 | # 只对非下划线开头的生效 145 | self.targets_map = targets_map 146 | 147 | def visit_Call(self, node): 148 | # 提取函数名 149 | self.funcs_old.add(node.func.id) 150 | node.func.id = self.funcs_map.get(node.func.id, node.func.id) 151 | self.funcs_new.add(node.func.id) 152 | # 提取参数名 153 | for i, arg in enumerate(node.args): 154 | if isinstance(arg, ast.Name): 155 | self.args_old.add(arg.id) 156 | arg.id = self.args_map.get(arg.id, arg.id) 157 | self.args_new.add(arg.id) 158 | if isinstance(arg, ast.Constant): 159 | old_arg_value = str(arg.value) 160 | if old_arg_value in self.args_map: 161 | new_arg_value = self.args_map.get(old_arg_value, old_arg_value) 162 | self.args_old.add(old_arg_value) 163 | node.args[i] = ast.Name(new_arg_value, ctx=ast.Load()) 164 | self.args_new.add(new_arg_value) 165 | 166 | self.generic_visit(node) 167 | return node 168 | 169 | def __visit_Assign(self, target: expr): 170 | old_target_id = target.id 171 | new_target_id = self.targets_map.get(old_target_id, old_target_id) 172 | self.targets_old.add(old_target_id) 173 | 174 | # 赋值给下划线开头代码时,对其进行重命名,方便重复书写表达式时不冲突 175 | if old_target_id.startswith('_'): 176 | # 减少与cse中_x_冲突 177 | new_target_id = f'{old_target_id}_{len(self.targets_new)}_' 178 | 179 | if old_target_id != new_target_id: 180 | self.targets_new.add(new_target_id) 181 | target.id = new_target_id 182 | # 记录修改的变量名,之后会使用到 183 | self.args_map[old_target_id] = new_target_id 184 | 185 | if isinstance(target, ast.Constant): 186 | old_target_value = str(target.value) 187 | if old_target_value in self.args_map: 188 | new_target_value = self.args_map.get(old_target_value, old_target_value) 189 | self.args_old.add(old_target_value) 190 | target = ast.Name(new_target_value, ctx=ast.Load()) 191 | self.args_new.add(new_target_value) 192 | 193 | return target 194 | 195 | def visit_Assign(self, node): 196 | # 调整位置,支持循环赋值 197 | # _A = _A+1 调整成 _A_001 = _A_000 + 1 198 | self.generic_visit(node) 199 | 200 | # 提取输出变量名 201 | for i, target in enumerate(node.targets): 202 | if isinstance(target, ast.Tuple): 203 | for j, t in enumerate(target.elts): 204 | target.elts[j] = self.__visit_Assign(t) 205 | else: 206 | node.targets[i] = self.__visit_Assign(target) 207 | 208 | # 处理 alpha=close 这种情况 209 | if isinstance(node.value, ast.Name): 210 | self.args_old.add(node.value.id) 211 | node.value.id = self.args_map.get(node.value.id, node.value.id) 212 | self.args_new.add(node.value.id) 213 | if isinstance(node.value, ast.Constant): 214 | old_node_value = str(node.value.value) 215 | if old_node_value in self.args_map: 216 | new_node_value = self.args_map.get(old_node_value, old_node_value) 217 | self.args_old.add(old_node_value) 218 | node.value = ast.Name(new_node_value, ctx=ast.Load()) 219 | self.args_new.add(new_node_value) 220 | 221 | return node 222 | 223 | def visit_Compare(self, node): 224 | # 比较符的左右也可能是变量,要处理 225 | if isinstance(node.left, ast.Name): 226 | self.args_old.add(node.left.id) 227 | node.left.id = self.args_map.get(node.left.id, node.left.id) 228 | self.args_new.add(node.left.id) 229 | 230 | for i, com in enumerate(node.comparators): 231 | if isinstance(com, ast.Name): 232 | self.args_old.add(com.id) 233 | com.id = self.args_map.get(com.id, com.id) 234 | self.args_new.add(com.id) 235 | if isinstance(com, ast.Constant): 236 | old_com_value = str(com.value) 237 | if old_com_value in self.args_map: 238 | new_com_value = self.args_map.get(old_com_value, old_com_value) 239 | 
self.args_old.add(old_com_value) 240 | node.comparators[i] = ast.Name(new_com_value, ctx=ast.Load()) 241 | self.args_new.add(new_com_value) 242 | 243 | self.generic_visit(node) 244 | return node 245 | 246 | def visit_IfExp(self, node): 247 | if isinstance(node.body, ast.Name): 248 | self.args_old.add(node.body.id) 249 | node.body.id = self.args_map.get(node.body.id, node.body.id) 250 | self.args_new.add(node.body.id) 251 | if isinstance(node.orelse, ast.Name): 252 | self.args_old.add(node.orelse.id) 253 | node.orelse.id = self.args_map.get(node.orelse.id, node.orelse.id) 254 | self.args_new.add(node.orelse.id) 255 | 256 | self.generic_visit(node) 257 | return node 258 | 259 | def visit_BinOp(self, node): 260 | if isinstance(node.left, ast.Name): 261 | self.args_old.add(node.left.id) 262 | node.left.id = self.args_map.get(node.left.id, node.left.id) 263 | self.args_new.add(node.left.id) 264 | if isinstance(node.right, ast.Name): 265 | self.args_old.add(node.right.id) 266 | node.right.id = self.args_map.get(node.right.id, node.right.id) 267 | self.args_new.add(node.right.id) 268 | if isinstance(node.left, ast.Constant): 269 | old_node_value = str(node.left.value) 270 | if old_node_value in self.args_map: 271 | new_node_value = self.args_map.get(old_node_value, old_node_value) 272 | self.args_old.add(old_node_value) 273 | node.left = ast.Name(new_node_value, ctx=ast.Load()) 274 | self.args_new.add(new_node_value) 275 | if isinstance(node.right, ast.Constant): 276 | old_node_value = str(node.right.value) 277 | if old_node_value in self.args_map: 278 | new_node_value = self.args_map.get(old_node_value, old_node_value) 279 | self.args_old.add(old_node_value) 280 | node.right = ast.Name(new_node_value, ctx=ast.Load()) 281 | self.args_new.add(new_node_value) 282 | 283 | self.generic_visit(node) 284 | return node 285 | 286 | def visit_UnaryOp(self, node): 287 | # -x 288 | if isinstance(node.operand, ast.Name): 289 | self.args_old.add(node.operand.id) 290 | node.operand.id = self.args_map.get(node.operand.id, node.operand.id) 291 | self.args_new.add(node.operand.id) 292 | if isinstance(node.operand, ast.Constant): 293 | old_operand_value = str(node.operand.value) 294 | if old_operand_value in self.args_map: 295 | new_operand_value = self.args_map.get(old_operand_value, old_operand_value) 296 | self.args_old.add(old_operand_value) 297 | node.operand = ast.Name(new_operand_value, ctx=ast.Load()) 298 | self.args_new.add(new_operand_value) 299 | 300 | self.generic_visit(node) 301 | return node 302 | 303 | def visit_Subscript(self, node): 304 | self.args_old.add(node.value.id) 305 | node.value.id = self.args_map.get(node.value.id, node.value.id) 306 | self.args_new.add(node.value.id) 307 | 308 | self.generic_visit(node) 309 | return node 310 | 311 | 312 | def source_replace(source: str) -> str: 313 | # 三元表达式转换成 错误版if( )else,一定得在Transformer中修正 314 | num = 1 315 | while num > 0: 316 | # 利用or 的优先级最低,构造特殊的if else,只要出现,就认为位置要替换 317 | # C?T:F --> C or True if( T )else F 318 | source, num = re.subn(r'\?(.+?):(.+?)', r' or True if( \1 )else \2', source, flags=re.S) 319 | # break 320 | # 或、与 321 | source = source.replace('||', '|').replace('&&', '&') 322 | # IndentationError: unexpected indent 323 | # 嵌套函数前有空格,会报错 324 | source = format_str(source, mode=Mode(line_length=600, magic_trailing_comma=True)) 325 | return source 326 | 327 | 328 | def assigns_to_dict(assigns): 329 | """赋值表达式转成字典""" 330 | return {ast.unparse(a.targets): ast.unparse(a.value) for a in assigns} 331 | 332 | 333 | def assigns_to_list(assigns): 
334 | """赋值表达式转成列表""" 335 | outputs = [] 336 | for i, a in enumerate(assigns): 337 | comment = "#" 338 | if i + 1 < len(assigns): 339 | b = assigns[i + 1] 340 | if isinstance(b, ast_comments.Comment): 341 | # comment = ast_comments.unparse(b) 342 | comment = b.value 343 | if isinstance(a, ast.Assign): 344 | outputs.append((ast.unparse(a.targets), ast.unparse(a.value), comment)) 345 | return outputs 346 | 347 | 348 | def raw_to_code(raw): 349 | """导入语句转字符列表""" 350 | return '\n'.join([ast.unparse(a) for a in raw]) 351 | 352 | 353 | def sources_to_asts(*sources, convert_xor: bool): 354 | """输入多份源代码""" 355 | 356 | def _source_to_asts(source): 357 | """源代码""" 358 | tree = ast_comments.parse(source_replace(source)) 359 | 360 | if isinstance(tree.body[0], ast.FunctionDef): 361 | body = tree.body[0].body 362 | else: 363 | body = tree.body 364 | 365 | return body 366 | 367 | tree = ast_comments.parse("") 368 | for arg in sources: 369 | tree.body.extend(_source_to_asts(arg)) 370 | 371 | t1 = SyntaxTransformer(convert_xor) 372 | t1.visit(tree) 373 | t = RenameTransformer({}, {}) 374 | t.visit(tree) 375 | 376 | raw = [] 377 | assigns = [] 378 | 379 | for i, node in enumerate(tree.body): 380 | # 特殊处理的节点 381 | if isinstance(node, ast.Assign): 382 | assigns.append(node) 383 | continue 384 | # TODO 是否要把其它语句也加入?是否有安全问题? 385 | if isinstance(node, (ast.Import, ast.ImportFrom, ast.FunctionDef, ast.ClassDef)): 386 | raw.append(node) 387 | continue 388 | if isinstance(node, ast_comments.Comment): 389 | # 添加注释 390 | if node.inline and isinstance(tree.body[i - 1], ast.Assign): 391 | assigns.append(node) 392 | continue 393 | 394 | return raw_to_code(raw), assigns_to_list(assigns), t.funcs_new, t.args_new, t.targets_new 395 | 396 | 397 | def _add_default_type(globals_): 398 | # 这种写法可以省去由用户导入Eq一类的工作 399 | globals_['Add'] = Add 400 | globals_['Mul'] = Mul 401 | globals_['Pow'] = Pow 402 | globals_['Eq'] = Eq 403 | globals_['Not'] = Not 404 | globals_['Xor'] = Xor 405 | return globals_ 406 | 407 | 408 | def sources_to_exprs(globals_, *sources, convert_xor: bool): 409 | """将源代码转换成表达式""" 410 | 411 | globals_ = _add_default_type(globals_) 412 | 413 | raw, assigns, funcs_new, args_new, targets_new = sources_to_asts(*sources, convert_xor=convert_xor) 414 | # 支持OPEN[1]转ts_delay(OPEN,1) 415 | funcs_new.add('ts_delay') 416 | 417 | register_symbols(funcs_new, globals_, is_function=True) 418 | register_symbols(args_new, globals_, is_function=False) 419 | register_symbols(targets_new, globals_, is_function=False) 420 | return raw, list_to_exprs(assigns, globals_) 421 | -------------------------------------------------------------------------------- /expr_codegen/dag.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import networkx as nx 4 | 5 | 6 | def zero_indegree(G: nx.DiGraph): 7 | """入度为0的所有节点""" 8 | return [v for v, d in G.in_degree() if d == 0] 9 | 10 | 11 | def zero_outdegree(G: nx.DiGraph): 12 | """出度为0的所有节点""" 13 | return [v for v, d in G.out_degree() if d == 0] 14 | 15 | 16 | def skip_node(G: nx.DiGraph, node): 17 | """跳过中间节点,将两端的节点直接连接起来 18 | 19 | 1. (A,B,C) 模式,直接成 (A,C) 20 | 2. 
(A,B,C), (D, B) 模式,变成 (A,C),(D,C) 21 | """ 22 | pred = G.pred[node] 23 | succ = G.succ[node] 24 | if len(pred) == 0 or len(succ) == 0: 25 | return G 26 | # 这里用了product生成多个关联边 27 | G.add_edges_from(product(pred, succ)) 28 | G.remove_node(node) 29 | return G 30 | 31 | 32 | def remove_paths(G: nx.DiGraph, *args): 33 | """删除路径。选择一个叶子节点,会删到不影响其它分支停止 34 | 35 | 对于Y型,会全删除 36 | """ 37 | # 准备一个当前节点列表 38 | this_pred = args 39 | # 下一步不为空就继续 40 | while this_pred: 41 | next_pred = [] 42 | for node in this_pred: 43 | # 有可能多路径删除时,已经被删 44 | if not G.has_node(node): 45 | continue 46 | # 出度为0 47 | if len(G.succ[node]) == 0: 48 | # 找到所有上游节点 49 | next_pred.extend(G.pred[node]) 50 | G.remove_node(node) 51 | # 更新下一次循环 52 | this_pred = list(set(next_pred)) 53 | return G 54 | 55 | 56 | def remove_paths_by_zero_outdegree(G: nx.DiGraph, exclude): 57 | """删除悬空路径 58 | 59 | 注意:如果没有设置要排除,可能全图被删""" 60 | nodes = zero_outdegree(G) 61 | # 悬空 62 | dangling = set(nodes) - set(exclude) 63 | return remove_paths(G, *dangling) 64 | 65 | 66 | def show_nodes(G): 67 | for i, generation in enumerate(nx.topological_generations(G)): 68 | print(i, '=' * 20, generation) 69 | for node in generation: 70 | print(G.nodes[node]) 71 | 72 | 73 | def show_paths(G: nx.DiGraph, *args): 74 | """显示路径 75 | """ 76 | # 准备一个当前节点列表 77 | this_pred = args 78 | # 下一步不为空就继续 79 | while this_pred: 80 | next_pred = [] 81 | for node in this_pred: 82 | print(G.nodes[node]) 83 | next_pred.extend(G.pred[node]) 84 | 85 | # 更新下一次循环 86 | this_pred = list(set(next_pred)) 87 | return G 88 | 89 | 90 | def node_included_path(G: nx.DiGraph, source): 91 | """"得到节点所在路径 92 | 93 | TODO: 总感觉官方提供了类似方法,就是没找到 94 | """ 95 | pred = nx.ancestors(G, source) 96 | succ = nx.descendants(G, source) 97 | # set先后没有区别 98 | return pred | succ | {source} 99 | 100 | 101 | # 根据原版做了修改,树结构顶部为根,向下生成。与表达式正好相反,所以这里特意将找节点的方向反过来 102 | # https://stackoverflow.com/questions/29586520/can-one-get-hierarchical-graphs-from-networkx-with-python-3/ 103 | def hierarchy_pos(G, root, levels=None, width=1., height=1.): 104 | """If there is a cycle that is reachable from root, then this will see infinite recursion. 
105 | G: the graph 106 | root: the root node 107 | levels: a dictionary 108 | key: level number (starting from 0) 109 | value: number of nodes in this level 110 | width: horizontal space allocated for drawing 111 | height: vertical space allocated for drawing""" 112 | TOTAL = "total" 113 | CURRENT = "current" 114 | 115 | def make_levels(levels, node=root, currentLevel=0, parent=None): 116 | """Compute the number of nodes for each level 117 | """ 118 | if not currentLevel in levels: 119 | levels[currentLevel] = {TOTAL: 0, CURRENT: 0} 120 | levels[currentLevel][TOTAL] += 1 121 | neighbors = G.predecessors(node) 122 | for neighbor in neighbors: 123 | if not neighbor == parent: 124 | levels = make_levels(levels, neighbor, currentLevel + 1, node) 125 | return levels 126 | 127 | def make_pos(pos, node=root, currentLevel=0, parent=None, vert_loc=0): 128 | dx = 1 / levels[currentLevel][TOTAL] 129 | left = dx / 2 130 | pos[node] = ((left + dx * levels[currentLevel][CURRENT]) * width, vert_loc) 131 | levels[currentLevel][CURRENT] += 1 132 | neighbors = G.predecessors(node) 133 | for neighbor in neighbors: 134 | if not neighbor == parent: 135 | pos = make_pos(pos, neighbor, currentLevel + 1, node, vert_loc - vert_gap) 136 | return pos 137 | 138 | if levels is None: 139 | levels = make_levels({}) 140 | else: 141 | levels = {l: {TOTAL: levels[l], CURRENT: 0} for l in levels} 142 | vert_gap = height / (max([l for l in levels]) + 1) 143 | return make_pos({}) 144 | -------------------------------------------------------------------------------- /expr_codegen/expr.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | from sympy import Mul, preorder_traversal, symbols, Function, simplify, Add, Basic, Symbol, sympify, FunctionClass 4 | 5 | # 预定义前缀,算子用前缀进行区分更方便。 6 | # 当然也可以用是否在某容器中进行分类 7 | CL = 'cl' # 列算子, column 8 | TS = 'ts' # 时序算子, time-series 9 | CS = 'cs' # 横截面算子 cross section 10 | GP = 'gp' # 分组算子。group 分组越小,速度越慢 11 | 12 | CL_TUP = (CL,) # 整列元组 13 | CL_SET = {CL_TUP} # 整列集合 14 | 15 | 16 | def is_symbol(x, globals_): 17 | s = globals_.get(x, None) 18 | if s is None: 19 | return False 20 | if isinstance(s, Symbol): 21 | # OPEN 22 | return True 23 | if type(s) is type: 24 | # Eq 25 | return issubclass(s, Basic) 26 | if isinstance(s, FunctionClass): 27 | # Not 28 | return True 29 | return False 30 | 31 | 32 | def register_symbols(syms, globals_, is_function: bool): 33 | """注册sympy中需要使用的符号""" 34 | # Eq等已经是sympy的符号不需注册 35 | syms = [s for s in syms if not is_symbol(s, globals_)] 36 | if len(syms) == 0: 37 | return globals_ 38 | 39 | if is_function: 40 | # 函数被注册后不能再调用,所以一定要用globals().copy() 41 | syms = symbols(','.join(syms), cls=Function, seq=True) 42 | else: 43 | syms = symbols(','.join(syms), cls=Symbol, seq=True) 44 | syms = {s.name: s for s in syms} 45 | globals_.update(syms) 46 | return globals_ 47 | 48 | 49 | def list_to_exprs(exprs_src, globals_): 50 | return [(k, sympify(v, globals_, evaluate=False), c) for k, v, c in exprs_src] 51 | 52 | 53 | def append_node(node, output_exprs): 54 | """添加到队列。其中,-x将转为x 55 | 56 | 此举是为了防止公共表达式中出现大量-x这种情况 57 | 58 | Parameters 59 | ---------- 60 | node 61 | 表达式当前节点 62 | output_exprs 63 | 表达式列表 64 | 65 | Returns 66 | ------- 67 | 表达式列表 68 | 69 | """ 70 | if isinstance(node, Mul): 71 | if node.args[0] == -1 or node.args[0] == 1: 72 | # 可能是-1也可能是1.0 73 | for arg in node.args[1:]: 74 | if arg.is_Atom: 75 | continue 76 | output_exprs.append(arg) 77 | # print(1, arg) 78 | else: 79 | 
output_exprs.append(node) 80 | # print(2, node) 81 | else: 82 | output_exprs.append(node) 83 | # print(3, node) 84 | 85 | return output_exprs 86 | 87 | 88 | def get_symbols(expr, syms=None, return_str=True): 89 | """得到""" 90 | if syms is None: 91 | syms = [] 92 | 93 | for arg in expr.args: 94 | if arg.is_Symbol: 95 | if return_str: 96 | syms.append(arg.name) 97 | else: 98 | syms.append(arg) 99 | elif arg.is_Number: 100 | # alpha_001 = log(1)+1 101 | if return_str: 102 | syms.append(str(arg)) 103 | else: 104 | syms.append(arg) 105 | else: 106 | get_symbols(arg, syms, return_str) 107 | return syms 108 | 109 | 110 | def is_NegativeX(expr): 111 | """-x, 但-ts_sum格式不返回False""" 112 | if isinstance(expr, Mul): 113 | if expr.args[0] == -1 and len(expr.args) == 2 and expr.args[1].is_Atom: 114 | return True 115 | return False 116 | 117 | 118 | def is_simple_expr(expr): 119 | if isinstance(expr, Mul): 120 | if expr.args[0] == -1 and len(expr.args) == 2 and expr.args[1].is_Atom: 121 | return True 122 | if isinstance(expr, Symbol): 123 | return True 124 | return False 125 | 126 | 127 | def get_current_by_prefix(expr, date, asset, **kwargs): 128 | """表达式根节点信息。按名称前缀。例如 129 | 130 | OPEN取的是OPEN,得cl 131 | ts_mean取的ts_mean,得ts 132 | -ts_mean取的是-,得cl 133 | """ 134 | if expr.is_Function: 135 | if hasattr(expr, 'name'): # Or 没有名字 136 | prefix1 = expr.name[2] 137 | if prefix1 == '_': 138 | prefix2 = expr.name[:2] 139 | 140 | if prefix2 == TS: 141 | return TS, asset 142 | if prefix2 == CS: 143 | return CS, date 144 | if prefix2 == GP: 145 | return GP, date, expr.args[0].name 146 | # 不需分组 147 | return CL_TUP 148 | 149 | 150 | def get_current_by_name(expr, ts_names, cs_names, gp_names, date, asset, **kwargs): 151 | """表达式根节点信息。按名称。 152 | 153 | Parameters 154 | ---------- 155 | expr 156 | ts_names 157 | 时序算子名称字符串集合 158 | cs_names 159 | 横截面算子名称字符串集合 160 | gp_names 161 | 分组算子名称字符串集合 162 | date 163 | 日期字符串 164 | asset 165 | 资产字符串 166 | kwargs 167 | 168 | """ 169 | if expr.is_Function: 170 | if hasattr(expr, 'name'): # Or 没有名字 171 | if expr.name in ts_names: 172 | return TS, asset 173 | if expr.name in cs_names: 174 | return CS, date 175 | if expr.name in gp_names: 176 | return GP, date, expr.args[0].name 177 | 178 | # 不需分组 179 | return CL_TUP 180 | 181 | 182 | # 调试用,勿删 183 | # __level__ = 0 184 | 185 | 186 | def get_children(func, func_kwargs, expr, output_exprs, output_symbols, date, asset): 187 | """表达式整体信息。例如 188 | 189 | -ts_corr返回{ts}而不是 {cl} 190 | -ts_corr+cs_rank返回{ts,cs} 191 | -OPEN-CLOSE返回{cl} 192 | 193 | Parameters 194 | ---------- 195 | func 196 | 表达式根分类函数 197 | func_kwargs 198 | func对应的参数字典 199 | expr 200 | 表达式 201 | output_exprs 202 | 输出分割后的了表达式 203 | output_symbols 204 | 输出每个子表达式中的符号 205 | date 206 | asset 207 | 208 | Returns 209 | ------- 210 | 211 | """ 212 | # global __level__ 213 | # __level__ += 1 214 | 215 | try: 216 | curr = func(expr, date, asset, **func_kwargs) 217 | children = [get_children(func, func_kwargs, a, output_exprs, output_symbols, date, asset) for a in expr.args] 218 | 219 | # print(expr, curr, children, __level__) 220 | # if __level__ == 6: 221 | # print(expr, curr, children) 222 | 223 | # 多个集合合成一个去重 224 | unique = reduce(lambda x, y: x | y, children, set()) - CL_SET 225 | 226 | if len(unique) >= 2: 227 | # 大于1,表示内部不统一,内部都要处理 228 | for i, child in enumerate(children): 229 | # alpha_047无法正确输出 230 | if expr.args[i].is_Atom: 231 | # print(expr.args[i], 'is_Atom 1') 232 | continue 233 | output_exprs = append_node(expr.args[i], output_exprs) 234 | elif curr[0] != CL: 235 | # 外部与内部不同,需处理 236 
| for i, child in enumerate(children): 237 | # 不在子中即表示不同 238 | if curr in child: 239 | continue 240 | if expr.args[i].is_Atom: 241 | # print(expr.args[i], 'is_Atom 2') 242 | continue 243 | output_exprs = append_node(expr.args[i], output_exprs) 244 | else: 245 | # ts_sum(OPEN, 5)*ts_sum(RETURNS, 5) ('cl',) [{('ts', 'asset', 'date')}, {('ts', 'asset', 'date')}] 6 alpha_008 246 | pass 247 | # if isinstance(expr, Mul): 248 | # output_exprs = append_node(expr, output_exprs) 249 | 250 | # 按需返回,当前是基础算子就返回下一层信息,否则返回当前 251 | if curr[0] == CL: 252 | if expr.is_Symbol: 253 | # 汇总符号列表 254 | output_symbols.append(expr) 255 | # 返回子中出现过的集合{ts cs gp} 256 | return unique 257 | else: 258 | # 当前算子,非列算子,直接返回,如{ts} {cs} {gp} 259 | return {curr} 260 | finally: 261 | # __level__ -= 1 262 | pass 263 | 264 | 265 | def get_key(children): 266 | """!!!此函数只能在先抽取出子表达式后再cse,然后才能调用。否则报错。 267 | 268 | 为了保证expr能正确分组,只有一种分法 269 | 270 | Parameters 271 | ---------- 272 | 273 | 274 | Returns 275 | ------- 276 | 用于字典的键 277 | 278 | """ 279 | if len(children) == 0: 280 | # OPEN等因子会走这一步 281 | return CL_TUP 282 | elif len(children) == 1: 283 | # 只有一种分法,最合适的方法 284 | return list(children)[0] 285 | else: 286 | assert False, f'{children} 无法正确分类,之前没有分清' 287 | 288 | 289 | def replace_exprs(exprs): 290 | """使用替换的方式简化表达式""" 291 | # Alpha101中大量ts_sum(x, 10)/10, 转成ts_mean(x, 10) 292 | exprs = [(k, _replace__ts_sum__to__ts_mean(v), c) for k, v, c in exprs] 293 | # alpha_031中大量cs_rank(cs_rank(x)) 转成cs_rank(x) 294 | exprs = [(k, _replace__repeat(v), c) for k, v, c in exprs] 295 | # 1.0*VWAP转VWAP 296 | exprs = [(k, _replace__one_mul(v), c) for k, v, c in exprs] 297 | # 将部分参数为1的ts函数进行简化 298 | exprs = [(k, _replace__ts_xxx_1(v), c) for k, v, c in exprs] 299 | # ts_delay转成ts_delta 300 | exprs = [(k, _replace__ts_delay__to__ts_delta(v), c) for k, v, c in exprs] 301 | 302 | return exprs 303 | 304 | 305 | def get_node_name(node): 306 | """得到节点名""" 307 | if hasattr(node, 'name'): 308 | # 如 ts_arg_max 309 | node_name = node.name 310 | else: 311 | # 如 log 312 | node_name = str(node.func) 313 | return node_name 314 | 315 | 316 | def _replace__ts_sum__to__ts_mean(e): 317 | """将ts_sum(x, y)/y 转成 ts_mean(x, y)""" 318 | if not isinstance(e, Basic): 319 | return e 320 | 321 | # TODO: 这里重新定义的ts_mean与外部已经定义好的是否同一个? 
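    # regarding the TODO above: sympy compares undefined functions by name, so the
    # `ts_mean` created here should be interchangeable with a `ts_mean` registered
    # elsewhere, e.g. symbols('ts_mean', cls=Function) == symbols('ts_mean', cls=Function)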
322 | ts_mean = symbols('ts_mean', cls=Function) 323 | 324 | replacements = [] 325 | for node in preorder_traversal(e): 326 | if node.is_Mul and node.args[0].is_Rational and node.args[1].is_Function: 327 | node_name = get_node_name(node.args[1]) 328 | if node_name == 'ts_sum': 329 | if node.args[1].args[1] == node.args[0].q and node.args[0].p == 1: 330 | replacements.append((node, ts_mean(node.args[1].args[0], node.args[1].args[1]))) 331 | for node, replacement in replacements: 332 | print(node, ' -> ', replacement) 333 | e = e.xreplace({node: replacement}) 334 | return e 335 | 336 | 337 | def _replace__repeat(e): 338 | """cs_rank(cs_rank(x)) 转成 cs_rank(x) 339 | sign(sign(x)) 转成 sign(x) 340 | Abs(Abs(x)) 转成 Abs(x) 341 | """ 342 | if not isinstance(e, Basic): 343 | return e 344 | 345 | replacements = [] 346 | for node in preorder_traversal(e): 347 | # print(node) 348 | if len(node.args) == 0: 349 | continue 350 | node_name = get_node_name(node) 351 | node_args0_name = get_node_name(node.args[0]) 352 | if node_name == node_args0_name: 353 | if node_name in ('cs_rank', 'sign', 'Abs', 'abs_'): 354 | replacements.append((node, node.args[0])) 355 | for node, replacement in replacements: 356 | print(node, ' -> ', replacement) 357 | e = e.xreplace({node: replacement}) 358 | return e 359 | 360 | 361 | def _replace__one_mul(e): 362 | """1.0*VWAP转成VWAP""" 363 | if not isinstance(e, Basic): 364 | return e 365 | 366 | replacements = [] 367 | for node in preorder_traversal(e): 368 | # print(node) 369 | if isinstance(node, Mul) and node.args[0] == 1: 370 | if len(node.args) > 2: 371 | replacements.append((node, Mul._from_args(node.args[1:]))) 372 | else: 373 | replacements.append((node, node.args[1])) 374 | for node, replacement in replacements: 375 | print(node, ' -> ', replacement) 376 | e = e.xreplace({node: replacement}) 377 | return e 378 | 379 | 380 | def _replace__ts_xxx_1(e): 381 | """ts_xxx部分函数如果参数为1,可直接丢弃""" 382 | if not isinstance(e, Basic): 383 | return e 384 | 385 | replacements = [] 386 | for node in preorder_traversal(e): 387 | node_name = get_node_name(node) 388 | if node_name in ('ts_mean', 'ts_sum', 'ts_decay_linear', 389 | 'ts_max', 'ts_min', 'ts_arg_max', 'ts_arg_min', 390 | 'ts_product', 'ts_std_dev', 'ts_rank'): 391 | try: 392 | if node.args[1] <= 1: 393 | replacements.append((node, node.args[0])) 394 | except: 395 | print(node_name) 396 | print(e) 397 | raise 398 | for node, replacement in replacements: 399 | print(node, ' -> ', replacement) 400 | e = e.xreplace({node: replacement}) 401 | return e 402 | 403 | 404 | def _replace__ts_delay__to__ts_delta(e): 405 | """ 将-ts_delay(x, y)转成ts_delta(x, y)-x 406 | 407 | 本质上为x-ts_delay(x, y) 转成 ts_delta(x, y) 408 | 409 | 例如 OPEN - ts_delay(OPEN, 5) + (CLOSE - ts_delay(CLOSE, 5)) 410 | 结果 ts_delta(CLOSE, 5) + ts_delta(OPEN, 5) 411 | """ 412 | if not isinstance(e, Basic): 413 | return e 414 | 415 | ts_delta = symbols('ts_delta', cls=Function) 416 | 417 | replacements = [] 418 | for node in preorder_traversal(e): 419 | if node.is_Add: 420 | new_args = [] 421 | for arg in node.args: 422 | if arg.is_Mul: 423 | if arg.args[0] == -1 and arg.args[1].is_Function and get_node_name(arg.args[1]) == 'ts_delay': 424 | # 添加ts_delta(x, y) 425 | new_args.append(ts_delta(arg.args[1].args[0], arg.args[1].args[1])) 426 | # 添加-x 427 | new_args.append(-arg.args[1].args[0]) 428 | else: 429 | new_args.append(arg) 430 | else: 431 | new_args.append(arg) 432 | if len(new_args) > len(node.args): 433 | # 长度变长,表示成功实现了调整 434 | tmp_args = simplify(Add._from_args(new_args)) 
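                    # worked example from the docstring: for OPEN - ts_delay(OPEN, 5),
                    # new_args is [OPEN, ts_delta(OPEN, 5), -OPEN]; simplify cancels
                    # the +OPEN/-OPEN pair, leaving just ts_delta(OPEN, 5)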
435 | # 优化后长度变短,表示有变量对冲掉了,成功 436 | if len(tmp_args.args) < len(new_args): 437 | replacements.append((node, tmp_args)) 438 | for node, replacement in replacements: 439 | print(node, ' -> ', replacement) 440 | e = e.xreplace({node: replacement}) 441 | return e 442 | 443 | # def is_meaningless(e): 444 | # if _meaningless__ts_xxx_1(e): 445 | # return True 446 | # if _meaningless__xx_xx(e): 447 | # return True 448 | # return False 449 | # 450 | # 451 | # def _meaningless__ts_xxx_1(e): 452 | # """ts_xxx部分函数如果参数为1,可直接丢弃""" 453 | # for node in preorder_traversal(e): 454 | # if len(node.args) >= 2: 455 | # node_name = get_node_name(node) 456 | # if node_name in ('ts_delay', 'ts_delta'): 457 | # if not node.args[1].is_Integer: 458 | # return True 459 | # if node_name.startswith('ts_'): 460 | # if not node.args[-1].is_Number: 461 | # return True 462 | # if node.args[-1] <= 1: 463 | # return True 464 | # return False 465 | # 466 | # 467 | # def _meaningless__xx_xx(e): 468 | # """部分函数如果两参数完全一样,可直接丢弃""" 469 | # for node in preorder_traversal(e): 470 | # if len(node.args) >= 2: 471 | # if node.args[0] == node.args[1]: 472 | # return True 473 | # return False 474 | -------------------------------------------------------------------------------- /expr_codegen/latex/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wukan1986/expr_codegen/faa577832438d2f4765c289a7d56c29e09e5cf3b/expr_codegen/latex/__init__.py -------------------------------------------------------------------------------- /expr_codegen/latex/printer.py: -------------------------------------------------------------------------------- 1 | from sympy import Symbol, Function, Basic 2 | from sympy.core.sorting import default_sort_key 3 | from sympy.printing.latex import LatexPrinter, accepted_latex_functions 4 | 5 | 6 | class ExprLatexPrinter(LatexPrinter): 7 | """改版的Latex表达式打印。 8 | 9 | 主要解决不少函数和符号中的下划线被转成下标的问题 10 | """ 11 | 12 | def _print(self, expr, **kwargs) -> str: 13 | """Internal dispatcher 14 | 15 | Tries the following concepts to print an expression: 16 | 1. Let the object print itself if it knows how. 17 | 2. Take the best fitting method defined in the printer. 18 | 3. As fall-back use the emptyPrinter method for the printer. 19 | """ 20 | self._print_level += 1 21 | try: 22 | # If the printer defines a name for a printing method 23 | # (Printer.printmethod) and the object knows for itself how it 24 | # should be printed, use that method. 
25 | if self.printmethod and hasattr(expr, self.printmethod): 26 | if not (isinstance(expr, type) and issubclass(expr, Basic)): 27 | return getattr(expr, self.printmethod)(self, **kwargs) 28 | 29 | # See if the class of expr is known, or if one of its super 30 | # classes is known, and use that print function 31 | # Exception: ignore the subclasses of Undefined, so that, e.g., 32 | # Function('gamma') does not get dispatched to _print_gamma 33 | classes = type(expr).__mro__ 34 | # if AppliedUndef in classes: 35 | # classes = classes[classes.index(AppliedUndef):] 36 | # if UndefinedFunction in classes: 37 | # classes = classes[classes.index(UndefinedFunction):] 38 | # Another exception: if someone subclasses a known function, e.g., 39 | # gamma, and changes the name, then ignore _print_gamma 40 | if Function in classes: 41 | i = classes.index(Function) 42 | classes = tuple(c for c in classes[:i] if \ 43 | c.__name__ == classes[0].__name__ or \ 44 | c.__name__.endswith("Base")) + classes[i:] 45 | for cls in classes: 46 | printmethodname = '_print_' + cls.__name__ 47 | printmethod = getattr(self, printmethodname, None) 48 | if printmethod is not None: 49 | return printmethod(expr, **kwargs) 50 | # Unknown object, fall back to the emptyPrinter. 51 | return self.emptyPrinter(expr) 52 | finally: 53 | self._print_level -= 1 54 | 55 | def _hprint_Function(self, func: str) -> str: 56 | # func = self._deal_with_super_sub(func) 57 | superscriptidx = -1 # func.find("^") 58 | subscriptidx = -1 # func.find("_") 59 | func = func.replace('_', r'\_') 60 | if func in accepted_latex_functions: 61 | name = r"\%s" % func 62 | elif len(func) == 1 or func.startswith('\\') or subscriptidx == 1 or superscriptidx == 1: 63 | name = func 64 | else: 65 | if superscriptidx > 0 and subscriptidx > 0: 66 | name = r"\operatorname{%s}%s" % ( 67 | func[:min(subscriptidx, superscriptidx)], 68 | func[min(subscriptidx, superscriptidx):]) 69 | elif superscriptidx > 0: 70 | name = r"\operatorname{%s}%s" % ( 71 | func[:superscriptidx], 72 | func[superscriptidx:]) 73 | elif subscriptidx > 0: 74 | name = r"\operatorname{%s}%s" % ( 75 | func[:subscriptidx], 76 | func[subscriptidx:]) 77 | else: 78 | name = r"\operatorname{%s}" % func 79 | return name 80 | 81 | def _print_Symbol(self, expr: Symbol, style='plain'): 82 | name: str = self._settings['symbol_names'].get(expr) 83 | if name is not None: 84 | return name 85 | 86 | return expr.name.replace('_', r'\_') 87 | 88 | def _print_abs_(self, expr, exp=None): 89 | return self._print_Abs(expr, exp) 90 | 91 | def _print_log_(self, expr, exp=None): 92 | return self._print_log(expr, exp) 93 | 94 | def _print_max_(self, expr, exp=None): 95 | # return self._print_Max(expr, exp) 96 | args = sorted(expr.args, key=default_sort_key) 97 | texargs = [r"%s" % self._print(symbol) for symbol in args] 98 | tex = r"\%s\left(%s\right)" % ('max', ", ".join(texargs)) 99 | if exp is not None: 100 | return r"%s^{%s}" % (tex, exp) 101 | else: 102 | return tex 103 | 104 | def _print_min_(self, expr, exp=None): 105 | args = sorted(expr.args, key=default_sort_key) 106 | texargs = [r"%s" % self._print(symbol) for symbol in args] 107 | tex = r"\%s\left(%s\right)" % ('min', ", ".join(texargs)) 108 | if exp is not None: 109 | return r"%s^{%s}" % (tex, exp) 110 | else: 111 | return tex 112 | 113 | 114 | def latex(expr, mode='equation*', mul_symbol='times', **settings): 115 | """表达式转LATEX字符串""" 116 | settings.update({'mode': mode, 'mul_symbol': mul_symbol}) 117 | return ExprLatexPrinter(settings).doprint(expr) 118 | 
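# Usage sketch (illustrative names, not part of the original file): underscores in
# factor and function names survive as literal characters instead of subscripts:
#   from sympy import symbols, Function
#   MA_10 = symbols('MA_10')
#   ts_mean = symbols('ts_mean', cls=Function)
#   latex(ts_mean(MA_10, 20))
#   # -> roughly r'\begin{equation*}\operatorname{ts\_mean}{\left(MA\_10,20 \right)}\end{equation*}'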
119 | 120 | def display_latex(expr): 121 | """显示LATEX表达式,在VSCode或Notebook中显示正常""" 122 | from IPython.display import Markdown, display 123 | 124 | return display(Markdown(latex(expr))) 125 | -------------------------------------------------------------------------------- /expr_codegen/model.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from itertools import product, permutations 3 | 4 | import networkx as nx 5 | from sympy import symbols 6 | 7 | from expr_codegen.dag import zero_indegree, hierarchy_pos, remove_paths_by_zero_outdegree 8 | from expr_codegen.expr import CL, get_symbols, get_children, get_key, is_simple_expr 9 | 10 | _RESERVED_WORD_ = {'_NONE_', '_TRUE_', '_FALSE_'} 11 | 12 | 13 | class ListDictList: 14 | """嵌套列表 15 | 16 | 1. 最外层是 列表[] 17 | 2. 第二层是 字典{} 18 | 3. 第三层是 列表[] 19 | 20 | [ 21 | {'ts': [1, 2], 'cs': [2], 'gp_date': [2], 'gp_key': [2], } 22 | {'ts': [1], 'cs': [1],} 23 | ] 24 | 25 | """ 26 | 27 | def __init__(self): 28 | self._list = [] 29 | 30 | def clear(self): 31 | """清空""" 32 | self._list = [] 33 | 34 | def values(self): 35 | return self._list 36 | 37 | def next_row(self): 38 | """移动到新的一行""" 39 | self._list.append({}) 40 | 41 | def append(self, key, item): 42 | """自动放入同key的字典中""" 43 | last_row = self._list[-1] 44 | v = last_row.get(key, None) 45 | if v is None: 46 | # 同一行的新一列 47 | last_row[key] = [None, item] 48 | else: 49 | last_row[key].append(item) 50 | 51 | def filter_empty(self): 52 | """过滤空值""" 53 | new_list = [] 54 | for row in self._list: 55 | try_del1 = [] 56 | for k, v in row.items(): 57 | if len(v) == 0: 58 | try_del1.append(k) 59 | for k in try_del1: 60 | row.pop(k) 61 | if len(row) > 0: 62 | new_list.append(row) 63 | self._list = new_list 64 | 65 | def back_merge(self): 66 | """向上合并,将CL类型向前合并""" 67 | keys = reduce(lambda x, y: x + list(y.keys()), self._list, []) 68 | values = reduce(lambda x, y: x + list(y.values()), self._list, []) 69 | 70 | new_keys = [] 71 | new_values = [] 72 | last_k = None 73 | last_v = None 74 | for k, v in zip(keys, values): 75 | # 当前是整列时可以向上合并,但前一个是gp_xxx一类时不合并,因为循环太多次了 76 | # if (last_v is not None) and (k[0] == CL) and (last_k[0] != GP): 77 | if (last_v is not None) and (k == last_k): 78 | # print(1, k, last_k) 79 | last_v.extend(v) 80 | v.clear() 81 | else: 82 | # print(2, k, last_k) 83 | new_keys.append(k) 84 | new_values.append(v) 85 | last_v = v 86 | last_k = k 87 | 88 | def optimize(self, merge: bool): 89 | """将多组groupby根据规则进行合并,减少运行时间""" 90 | # 接龙。groupby的数量没少,首尾接龙数据比较整齐 91 | self._list = chain_create(self._list) 92 | if merge: 93 | # 首尾一样,接上去 94 | self.back_merge() 95 | # 出现了空行,删除 96 | self.filter_empty() 97 | 98 | def drop_symbols(self): 99 | """组装一种数据结构,用来存储之后会用到的变量名,用于提前删除不需要的变量""" 100 | # 获取每一小块所用到的所有变量名 101 | l1 = [] 102 | for row in self._list: 103 | for k, v in row.items(): 104 | vv = [] 105 | for v1 in v: 106 | if v1 is None: 107 | continue 108 | vv.extend(v1[2]) 109 | l1.append(set(vv)) 110 | 111 | # 得到此行与之后都会出现的变量名 112 | l2 = [set()] 113 | s = set() 114 | for i in reversed(l1): 115 | s = s | i # - {'_NONE_', '_TRUE_', '_FALSE_'} 116 | l2.append(s) 117 | l2 = list(reversed(l2)) 118 | 119 | # 计算之后不会再出现的变量名 120 | l3 = [list(s - e) for s, e in zip(l2[:-1], l2[1:])] 121 | 122 | return l3 123 | 124 | 125 | def score1(row) -> int: 126 | # 首尾相连打分加1 127 | lst = [None] + [key for r in row for key in dict(r).keys()] 128 | return sum([x == y for x, y in zip(lst[:-1], lst[1:])]) 129 | 130 | 131 | def score2(row) -> float: 132 | # 最后一个ts越靠前,打分越高 
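# In other words: lst[::-1].index('ts') measures how far the last 'ts' block sits
# from the end of the flattened key sequence, so layouts that finish their
# time-series work early receive a higher score.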
133 | lst = ['ts'] + [key[0] for r in row for key in dict(r).keys()] 134 | return lst[::-1].index('ts') / len(lst) 135 | 136 | 137 | def chain_create(nested_list): 138 | """Chaining: join multiple lists head to tail 139 | 140 | Test expressions 141 | ma_10 = ts_mean(CLOSE, 10) 142 | MAMA_20 = ts_mean(ma_10, 20) 143 | alpha_031 = ((cs_rank(cs_rank(cs_rank(ts_decay_linear((-1 * cs_rank(cs_rank(ts_delta(CLOSE, 10)))), 10)))))) 144 | 145 | """ 146 | perms = [] 147 | for d in nested_list: 148 | # generate the permutations of each layer 149 | perms.append(permutations(d.items())) 150 | 151 | last_score = float('-inf') 152 | last_row = None 153 | # generate the Cartesian product 154 | for row in product(*perms): 155 | result = score1(row) + score2(row) 156 | # print(result, row) 157 | if result > last_score: 158 | last_score = result 159 | last_row = row 160 | 161 | return [dict(ro) for ro in last_row] 162 | 163 | 164 | # ========================== 165 | 166 | def create_dag_exprs(exprs): 167 | """Generate the DAG from the expression dict""" 168 | # create a directed acyclic graph 169 | G = nx.DiGraph() 170 | 171 | for symbol, expr, comment in exprs: 172 | # if symbol.name == 'GP_0': 173 | # test = 1 174 | if expr.is_Symbol: 175 | G.add_node(symbol.name, symbol=symbol, expr=expr, comment=comment) 176 | G.add_edge(expr.name, symbol.name) 177 | else: 178 | # add an intermediate node 179 | G.add_node(symbol.name, symbol=symbol, expr=expr, comment=comment) 180 | syms = get_symbols(expr, return_str=True) 181 | for sym in syms: 182 | # because of the edges, some source nodes get created implicitly here 183 | G.add_edge(sym, symbol.name) 184 | if len(syms) == 0: 185 | # GP_0033=log(1/2400) 186 | if hasattr(expr, 'name'): 187 | G.add_edge(expr.name, symbol.name) 188 | else: 189 | G.add_edge(str(expr), symbol.name) 190 | 191 | # original source factors: attach their attributes 192 | for node in zero_indegree(G): 193 | s = symbols(node) 194 | G.nodes[node]['symbol'] = s 195 | G.nodes[node]['expr'] = s 196 | G.nodes[node]['comment'] = "#" 197 | # 198 | # for node in zero_outdegree(G): 199 | # print(11, G.nodes[node]['comment']) 200 | return G 201 | 202 | 203 | def init_dag_exprs(G, func, func_kwargs, date, asset): 204 | """Initialize the DAG with the expression information""" 205 | for i, generation in enumerate(nx.topological_generations(G)): 206 | # print(i, generation) 207 | for node in generation: 208 | expr = G.nodes[node]['expr'] 209 | syms = [] 210 | children = get_children(func, func_kwargs, expr, [], syms, date, asset) 211 | G.nodes[node]['children'] = children 212 | G.nodes[node]['key'] = get_key(children) 213 | G.nodes[node]['symbols'] = [str(s) for s in syms] 214 | G.nodes[node]['gen'] = i 215 | # print(G.nodes[node]) 216 | return G 217 | 218 | 219 | def merge_nodes_1(G: nx.DiGraph, keep_nodes, *args): 220 | """Merge nodes: starting from the current node, check whether the nodes at either end can be replaced""" 221 | # prepare a list of current nodes 222 | this_pred = args 223 | # keep going while the next step is non-empty 224 | while this_pred: 225 | next_pred = [] 226 | for node in this_pred: 227 | if not G.has_node(node): 228 | continue 229 | pred = G.pred[node] 230 | if len(pred) == 0: 231 | # reached a top-level factor, must stop 232 | continue 233 | dic = G.nodes[node] 234 | key = dic['key'] 235 | expr = dic['expr'] 236 | symbols = dic['symbols'] 237 | if key[0] == CL: 238 | if is_simple_expr(expr): 239 | # check whether the expression is simple enough; if so replace it, possibly several at once 240 | skip_expr_node(G, node, keep_nodes) 241 | else: 242 | succ = G.succ[node] 243 | # only one downstream node, replace directly. 244 | if len(succ) == 1: 245 | for s in succ: 246 | # _A appears twice in if_else(_A>_B,_A,_B), so it must not be removed 247 | if G.nodes[s]['symbols'].count(node) > 1: 248 | continue 249 | skip_expr_node(G, node, keep_nodes) 250 | else: 251 | # copy once to avoid errors after modification 252 | for p in pred.copy(): 253 | # used more than once in the same downstream expression, do not replace 254 | if symbols.count(p) > 1: 255 | continue 256 | d = G.nodes[p] 257 | k = d['key'] 258 | e = d['expr'] 259 | if key
== k: 260 | # 同类型 261 | succ = G.succ[p] 262 | # 下游只有一个,直接替换。 263 | if len(succ) == 1: 264 | for s in succ: 265 | if G.nodes[s]['symbols'].count(p) > 1: 266 | continue 267 | skip_expr_node(G, p, keep_nodes) 268 | next_pred.extend(pred) 269 | # 更新下一次循环 270 | this_pred = list(set(next_pred)) 271 | return G 272 | 273 | 274 | def merge_nodes_2(G: nx.DiGraph, keep_nodes, *args): 275 | """合并节点,从当前节点开始,查看是否需要被替换,只做用于根节点""" 276 | # 准备一个当前节点列表 277 | this_pred = args 278 | # 下一步不为空就继续 279 | while this_pred: 280 | next_pred = [] 281 | for node in this_pred: 282 | dic = G.nodes[node] 283 | expr = dic['expr'] 284 | if not is_simple_expr(expr): 285 | continue 286 | pred = G.pred[node] 287 | for p in pred.copy(): 288 | succ = G.succ[p] 289 | if len(succ) > 1: 290 | # 上游节点只有一个下游,当前就是自己了 291 | continue 292 | for s in succ: 293 | if G.nodes[s]['symbols'].count(p) > 1: 294 | continue 295 | skip_expr_node(G, p, keep_nodes) 296 | # 只做根节点,所以没有下一次了 297 | # next_pred.extend(pred) 298 | # 更新下一次循环 299 | this_pred = list(set(next_pred)) 300 | return G 301 | 302 | 303 | def get_expr_labels(G, nodes=None): 304 | """得到表达式标签""" 305 | labels = {} 306 | if nodes is None: 307 | for n, d in G.nodes(data=True): 308 | labels[n] = '{symbol}={expr}'.format(**d) 309 | else: 310 | for n, d in G.nodes(data=True): 311 | if n not in nodes: 312 | continue 313 | labels[n] = '{symbol}={expr}'.format(**d) 314 | return labels 315 | 316 | 317 | def draw_expr_tree(G: nx.DiGraph, root: str, ax=None): 318 | """画表达式树""" 319 | # 查找上游节点 320 | nodes = nx.ancestors(G, root) | {root} 321 | labels = get_expr_labels(G, nodes) 322 | # 子图 323 | view = nx.subgraph(G, nodes) 324 | # 位置 325 | pos = hierarchy_pos(G, root) 326 | nx.draw(view, ax=ax, pos=pos, labels=labels) 327 | 328 | 329 | def skip_expr_node(G: nx.DiGraph, node, keep_nodes): 330 | """跳过中间节点,将两端的节点直接连接起来,同时更新表达式 331 | 332 | 1. (A,B,C) 模式,直接成 (A,C) 333 | 2. 
(A,B,C), (D, B) 模式,变成 (A,C),(D,C) 334 | """ 335 | if node in keep_nodes: 336 | return G 337 | 338 | pred = G.pred[node] 339 | succ = G.succ[node] 340 | if len(pred) == 0 or len(succ) == 0: 341 | return G 342 | 343 | # 取当前节点表达式 344 | d = G.nodes[node] 345 | expr = d['expr'] 346 | symbol = d['symbol'] 347 | 348 | for s in succ: 349 | e = G.nodes[s]['expr'] 350 | e = e.xreplace({symbol: expr}) 351 | G.nodes[s]['expr'] = e 352 | 353 | # 这里用了product生成多个关联边 354 | G.add_edges_from(product(pred, succ)) 355 | G.remove_node(node) 356 | return G 357 | 358 | 359 | def dag_start(exprs_list, func, func_kwargs, date, asset): 360 | """初始生成DAG""" 361 | G = create_dag_exprs(exprs_list) 362 | G = init_dag_exprs(G, func, func_kwargs, date, asset) 363 | 364 | # 分层输出 365 | return G 366 | 367 | 368 | def dag_middle(G, exprs_names, func, func_kwargs, date, asset): 369 | """删除几个没有必要的节点""" 370 | G = remove_paths_by_zero_outdegree(G, exprs_names) 371 | # 以下划线开头的节点,不保留 372 | keep_nodes = [k for k in exprs_names if not k.startswith('_')] 373 | G = merge_nodes_1(G, keep_nodes, *keep_nodes) 374 | G = merge_nodes_2(G, keep_nodes, *keep_nodes) 375 | 376 | # 由于表达式修改,需再次更新表达式 377 | G = init_dag_exprs(G, func, func_kwargs, date, asset) 378 | 379 | # 分层输出 380 | return G 381 | 382 | 383 | def dag_end(G): 384 | """有向无环图流转""" 385 | exprs_ldl = ListDictList() 386 | 387 | for i, generation in enumerate(nx.topological_generations(G)): 388 | exprs_ldl.next_row() 389 | for node in generation: 390 | key = G.nodes[node]['key'] 391 | expr = G.nodes[node]['expr'] 392 | comment = G.nodes[node]['comment'] 393 | symbols = G.nodes[node]['symbols'] 394 | # 这几个特殊的不算成字段名 395 | symbols = list(set(symbols) - _RESERVED_WORD_) 396 | 397 | exprs_ldl.append(key, (node, expr, symbols, comment)) 398 | 399 | # 第0层是CLOSE等基础因子,剔除 400 | exprs_ldl._list = exprs_ldl.values()[1:] 401 | 402 | return exprs_ldl, G 403 | -------------------------------------------------------------------------------- /expr_codegen/pandas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wukan1986/expr_codegen/faa577832438d2f4765c289a7d56c29e09e5cf3b/expr_codegen/pandas/__init__.py -------------------------------------------------------------------------------- /expr_codegen/pandas/code.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Sequence 3 | 4 | import jinja2 5 | from jinja2 import FileSystemLoader, TemplateNotFound 6 | 7 | from expr_codegen.expr import TS, CS, GP 8 | from expr_codegen.model import ListDictList 9 | from expr_codegen.pandas.printer import PandasStrPrinter 10 | 11 | 12 | def get_groupby_from_tuple(tup, func_name, drop_cols): 13 | """从传入的元组中生成分组运行代码""" 14 | prefix2, *_ = tup 15 | 16 | if prefix2 == TS: 17 | # 组内需要按时间进行排序,需要维持顺序 18 | prefix2, asset = tup 19 | return f'df = df.groupby(by=[_ASSET_], group_keys=False).apply({func_name}).drop(columns={drop_cols})' 20 | if prefix2 == CS: 21 | prefix2, date = tup 22 | return f'df = df.groupby(by=[_DATE_], group_keys=False).apply({func_name}).drop(columns={drop_cols})' 23 | if prefix2 == GP: 24 | prefix2, date, group = tup 25 | return f'df = df.groupby(by=[_DATE_, "{group}"], group_keys=False).apply({func_name}).drop(columns={drop_cols})' 26 | 27 | return f'df = {func_name}(df).drop(columns={drop_cols})' 28 | 29 | 30 | def symbols_to_code(syms): 31 | a = [f"{s}" for s in syms] 32 | b = [f"'{s}'" for s in syms] 33 | return f"""_ = [{','.join(b)}] 34 | [{','.join(a)}] = 
_""" 35 | 36 | 37 | def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst, 38 | filename, 39 | date='date', asset='asset', 40 | extra_codes: Sequence[str] = (), 41 | filter_last: bool = False, 42 | **kwargs): 43 | """基于模板的代码生成""" 44 | if filename is None: 45 | filename = 'template.py.j2' 46 | 47 | # 打印Pandas风格代码 48 | p = PandasStrPrinter() 49 | 50 | # polars风格代码 51 | funcs = {} 52 | # 分组应用代码。这里利用了字典按插入顺序排序的特点,将排序放在最前 53 | groupbys = {'sort': ''} 54 | # 处理过后的表达式 55 | exprs_dst = [] 56 | syms_out = [] 57 | ts_func_name = None 58 | drop_symbols = exprs_ldl.drop_symbols() 59 | j = -1 60 | for i, row in enumerate(exprs_ldl.values()): 61 | for k, vv in row.items(): 62 | j += 1 63 | if len(vv) == 0: 64 | continue 65 | # 函数名 66 | func_name = f'func_{i}_{"__".join(k)}' 67 | func_code = [] 68 | for kv in vv: 69 | if kv is None: 70 | func_code.append(f" # " + '=' * 40) 71 | exprs_dst.append(f"#" + '=' * 40 + func_name) 72 | else: 73 | va, ex, sym, comment = kv 74 | func_code.append(f" # {va} = {ex}\n g[{va}] = {p.doprint(ex)}") 75 | exprs_dst.append(f"{va} = {ex} {comment}") 76 | if va not in syms_dst: 77 | syms_out.append(va) 78 | 79 | if len(groupbys['sort']) == 0: 80 | groupbys['sort'] = f'df = df.sort_values(by=[_ASSET_, _DATE_]).reset_index(drop=True)' 81 | if k[0] == TS: 82 | ts_func_name = func_name 83 | # 时序需要排序 84 | func_code = [f' g.df = df.sort_values(by=[_DATE_])'] + func_code 85 | else: 86 | # 时序需要排序 87 | func_code = [f' g.df = df'] + func_code 88 | 89 | # polars风格代码列表 90 | funcs[func_name] = '\n'.join(func_code) 91 | # 只有下划线开头再删除 92 | ds = [x for x in drop_symbols[j] if x.startswith('_')] 93 | # 分组应用代码 94 | groupbys[func_name] = get_groupby_from_tuple(k, func_name, ds) 95 | 96 | syms1 = symbols_to_code(syms_dst) 97 | syms2 = symbols_to_code(syms_out) 98 | if filter_last: 99 | _groupbys = {'sort': groupbys['sort']} 100 | if ts_func_name is None: 101 | _groupbys['_filter_last'] = "df = filter_last(df.sort_values(by=[_DATE_]))" 102 | for k, v in groupbys.items(): 103 | _groupbys[k] = v 104 | if k == ts_func_name: 105 | _groupbys[k + '_filter_last'] = "df = filter_last(df)" 106 | groupbys = _groupbys 107 | 108 | try: 109 | env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(__file__))) 110 | template = env.get_template(filename) 111 | except TemplateNotFound: 112 | env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(filename))) 113 | template = env.get_template(os.path.basename(filename)) 114 | 115 | return template.render(funcs=funcs, groupbys=groupbys, 116 | exprs_src=exprs_src, exprs_dst=exprs_dst, 117 | syms1=syms1, syms2=syms2, 118 | date=date, asset=asset, 119 | extra_codes=extra_codes) 120 | -------------------------------------------------------------------------------- /expr_codegen/pandas/helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | A、B、C=MACD()无法生成DAG,所以变通的改成 3 | 4 | A=unpack(MACD(),0) 5 | B=unpack(MACD(),1) 6 | C=unpack(MACD(),2) 7 | 8 | cse能自动提取成 9 | 10 | _x_0 = MACD() 11 | 12 | 但 df['_x_0'] 是无法放入tuple的,所以决定用另一个类来实现兼容 13 | 14 | """ 15 | import pandas as pd 16 | 17 | 18 | class GlobalVariable(object): 19 | def __init__(self): 20 | self.dict = {} 21 | self.df = pd.DataFrame() 22 | 23 | def __getitem__(self, item): 24 | if item in self.dict: 25 | return self.dict[item] 26 | return self.df[item] 27 | 28 | def __setitem__(self, key, value): 29 | if isinstance(value, tuple): 30 | # tuple存字典中 31 | self.dict[key] = value 32 | # 占位,避免drop时报错 33 | self.df[key] = False 34 | else: 35 | self.df[key] = 
value 36 | -------------------------------------------------------------------------------- /expr_codegen/pandas/printer.py: -------------------------------------------------------------------------------- 1 | from sympy import Basic, Function, StrPrinter 2 | from sympy.printing.precedence import precedence, PRECEDENCE 3 | 4 | 5 | # TODO: 如有新添加函数,但表达式有变更才需要在此补充对应的打印代码,否则可以省略 6 | 7 | class PandasStrPrinter(StrPrinter): 8 | def _print(self, expr, **kwargs) -> str: 9 | """Internal dispatcher 10 | 11 | Tries the following concepts to print an expression: 12 | 1. Let the object print itself if it knows how. 13 | 2. Take the best fitting method defined in the printer. 14 | 3. As fall-back use the emptyPrinter method for the printer. 15 | """ 16 | self._print_level += 1 17 | try: 18 | # If the printer defines a name for a printing method 19 | # (Printer.printmethod) and the object knows for itself how it 20 | # should be printed, use that method. 21 | if self.printmethod and hasattr(expr, self.printmethod): 22 | if not (isinstance(expr, type) and issubclass(expr, Basic)): 23 | return getattr(expr, self.printmethod)(self, **kwargs) 24 | 25 | # See if the class of expr is known, or if one of its super 26 | # classes is known, and use that print function 27 | # Exception: ignore the subclasses of Undefined, so that, e.g., 28 | # Function('gamma') does not get dispatched to _print_gamma 29 | classes = type(expr).__mro__ 30 | # if AppliedUndef in classes: 31 | # classes = classes[classes.index(AppliedUndef):] 32 | # if UndefinedFunction in classes: 33 | # classes = classes[classes.index(UndefinedFunction):] 34 | # Another exception: if someone subclasses a known function, e.g., 35 | # gamma, and changes the name, then ignore _print_gamma 36 | if Function in classes: 37 | i = classes.index(Function) 38 | classes = tuple(c for c in classes[:i] if \ 39 | c.__name__ == classes[0].__name__ or \ 40 | c.__name__.endswith("Base")) + classes[i:] 41 | for cls in classes: 42 | printmethodname = '_print_' + cls.__name__ 43 | 44 | # 所有以gp_开头的函数都转换成cs_开头 45 | if printmethodname.startswith('_print_gp_'): 46 | printmethodname = "_print_gp_" 47 | 48 | printmethod = getattr(self, printmethodname, None) 49 | if printmethod is not None: 50 | return printmethod(expr, **kwargs) 51 | # Unknown object, fall back to the emptyPrinter. 
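# What the dispatcher produces in practice (illustrative, assuming plain sympy
# symbols; see _print_Symbol and _print_gp_ below):
#   ts_mean(CLOSE, 5)        -> "ts_mean(g[CLOSE], 5)"
#   gp_rank(INDUSTRY, CLOSE) -> "cs_rank(g[CLOSE])"   # grouping column dropped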
52 | return self.emptyPrinter(expr) 53 | finally: 54 | self._print_level -= 1 55 | 56 | def _print_Symbol(self, expr): 57 | if expr.name in ('_NONE_', '_TRUE_', '_FALSE_'): 58 | return expr.name 59 | return f"g[{expr.name}]" 60 | 61 | def _print_Equality(self, expr): 62 | PREC = precedence(expr) 63 | return "%s==%s" % (self.parenthesize(expr.args[0], PREC), self.parenthesize(expr.args[1], PREC)) 64 | 65 | def _print_Or(self, expr): 66 | PREC = PRECEDENCE["Mul"] 67 | return " | ".join(self.parenthesize(arg, PREC) for arg in expr.args) 68 | 69 | def _print_Xor(self, expr): 70 | PREC = PRECEDENCE["Mul"] 71 | return " ^ ".join(self.parenthesize(arg, PREC) for arg in expr.args) 72 | 73 | def _print_And(self, expr): 74 | PREC = PRECEDENCE["Mul"] 75 | return " & ".join(self.parenthesize(arg, PREC) for arg in expr.args) 76 | 77 | def _print_Not(self, expr): 78 | PREC = PRECEDENCE["Mul"] 79 | return "~%s" % self.parenthesize(expr.args[0], PREC) 80 | 81 | def _print_gp_(self, expr): 82 | """every gp_ function is converted to its cs_ counterpart, dropping the first (grouping) argument""" 83 | new_args = [self._print(arg) for arg in expr.args[1:]] 84 | func_name = expr.func.__name__[3:] 85 | return "cs_%s(%s)" % (func_name, ",".join(new_args)) 86 | -------------------------------------------------------------------------------- /expr_codegen/pandas/ta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally all of this code lived in printer.py, but that had several drawbacks: 3 | 1. Every newly added function required editing printer.py, which was laborious and inflexible 4 | 2. The generated code was too low-level, which made it inconvenient to study and analyze 5 | 3. Pure string concatenation has no IDE syntax checking and is very error-prone 6 | 4. Some constructs were highly invasive to the template, and the imports became messy 7 | 8 | So a shared library similar to polars_ta is needed; since no suitable library has been found yet, the code below is a temporary version that should eventually be split out 9 | """ 10 | from typing import Tuple 11 | 12 | import numpy as np 13 | import pandas as pd 14 | 15 | try: 16 | import talib 17 | except ImportError: 18 | pass 19 | 20 | 21 | def abs_(x: pd.Series) -> pd.Series: 22 | return np.abs(x) 23 | 24 | 25 | def cs_demean(x: pd.Series) -> pd.Series: 26 | return x - x.mean() 27 | 28 | 29 | def cs_rank(x: pd.Series, pct: bool = True) -> pd.Series: 30 | return x.rank(pct=pct) 31 | 32 | 33 | def cs_scale(x: pd.Series, scale: float = 1) -> pd.Series: 34 | return x / x.abs().sum() * scale 35 | 36 | 37 | def if_else(input1: pd.Series, input2: pd.Series, input3: pd.Series = None): 38 | return np.where(input1, input2, input3) 39 | 40 | 41 | def log(x: pd.Series) -> pd.Series: 42 | return np.log(x) 43 | 44 | 45 | def max_(a: pd.Series, b: pd.Series) -> pd.Series: 46 | return np.maximum(a, b) 47 | 48 | 49 | def min_(a: pd.Series, b: pd.Series) -> pd.Series: 50 | return np.minimum(a, b) 51 | 52 | 53 | def sign(x: pd.Series) -> pd.Series: 54 | return np.sign(x) 55 | 56 | 57 | def signed_power(x: pd.Series, y: float) -> pd.Series: 58 | return np.sign(x) * (x.abs() ** y) 59 | 60 | 61 | def ts_corr(x: pd.Series, y: pd.Series, d: int = 5, ddof: int = 1) -> pd.Series: 62 | return x.rolling(d).corr(y, ddof=ddof) 63 | 64 | 65 | def ts_covariance(x: pd.Series, y: pd.Series, d: int = 5, ddof: int = 1) -> pd.Series: 66 | return x.rolling(d).cov(y, ddof=ddof) 67 | 68 | 69 | def ts_delay(x: pd.Series, d: int = 1) -> pd.Series: 70 | return x.shift(d) 71 | 72 | 73 | def ts_delta(x: pd.Series, d: int = 1) -> pd.Series: 74 | return x.diff(d) 75 | 76 | def ts_returns(x: pd.Series, d: int = 1) -> pd.Series: 77 | return x.pct_change(d) 78 | 79 | 80 | def ts_max(x: pd.Series, d: int = 5) -> pd.Series: 81 | return x.rolling(d).max() 82 | 83 | 84 | def ts_mean(x: pd.Series, d: int = 5) -> pd.Series: 85 | return x.rolling(d).mean() 86 | 87 | 88 | def ts_min(x: pd.Series, d: int = 5) -> pd.Series: 89 | return x.rolling(d).min() 90 |
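# Usage sketch (illustrative): each helper receives a Series that has already been
# restricted to a single asset; the generated code applies them per group via
# df.groupby(by=[_ASSET_], group_keys=False).apply(...):
#   s = pd.Series([1.0, 2.0, 3.0, 4.0])
#   ts_mean(s, 2).tolist()    # -> [nan, 1.5, 2.5, 3.5]
#   ts_delta(s, 1).tolist()   # -> [nan, 1.0, 1.0, 1.0]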
91 | 92 | def ts_product(x: pd.Series, d: int = 5) -> pd.Series: 93 | return x.rolling(d).apply(np.prod, raw=True) 94 | 95 | 96 | def ts_rank(x: pd.Series, d: int = 5, pct: bool = True) -> pd.Series: 97 | return x.rolling(d).rank(pct=pct) 98 | 99 | 100 | def ts_std_dev(x: pd.Series, d: int = 5, ddof: int = 0) -> pd.Series: 101 | return x.rolling(d).std(ddof=ddof) 102 | 103 | 104 | def ts_sum(x: pd.Series, d: int = 5) -> pd.Series: 105 | return x.rolling(d).sum() 106 | 107 | 108 | def ts_MACD(close: pd.Series, fastperiod: int = 12, slowperiod: int = 26, signalperiod: int = 9) -> Tuple[pd.Series, pd.Series, pd.Series]: 109 | return talib.MACD(close, fastperiod, slowperiod, signalperiod) 110 | -------------------------------------------------------------------------------- /expr_codegen/pandas/template.py.j2: -------------------------------------------------------------------------------- 1 | # this code is auto generated by the expr_codegen 2 | # https://github.com/wukan1986/expr_codegen 3 | # 此段代码由 expr_codegen 自动生成,欢迎提交 issue 或 pull request 4 | from typing import Tuple 5 | 6 | import numpy as np # noqa 7 | import pandas as pd # noqa 8 | from loguru import logger # noqa 9 | 10 | from expr_codegen.pandas.helper import GlobalVariable 11 | from expr_codegen.pandas.ta import * # noqa 12 | 13 | {{ syms1 }} 14 | 15 | {{ syms2 }} 16 | 17 | _DATE_ = '{{ date }}' 18 | _ASSET_ = '{{ asset }}' 19 | _NONE_ = None 20 | _TRUE_ = True 21 | _FALSE_ = False 22 | 23 | g = GlobalVariable() 24 | 25 | 26 | def unpack(x: Tuple, idx: int = 0) -> pd.Series: 27 | return x[idx] 28 | 29 | {%-for row in extra_codes %} 30 | {{ row-}} 31 | {% endfor %} 32 | 33 | {% for key, value in funcs.items() %} 34 | 35 | def {{ key }}(df: pd.DataFrame) -> pd.DataFrame: 36 | {{ value }} 37 | return g.df 38 | 39 | {% endfor %} 40 | 41 | """ 42 | {%-for row in exprs_dst %} 43 | {{ row-}} 44 | {% endfor %} 45 | """ 46 | 47 | """ 48 | {%-for a,b,c in exprs_src %} 49 | {{ a }} = {{ b}} {{c-}} 50 | {% endfor %} 51 | """ 52 | 53 | 54 | def filter_last(df: pd.DataFrame) -> pd.DataFrame: 55 | """过滤数据,只取最后一天。实盘时可用于减少计算量""" 56 | return df[df[_DATE_] >= df[_DATE_].iloc[-1]] 57 | 58 | 59 | def main(df: pd.DataFrame) -> pd.DataFrame: 60 | {% for key, value in groupbys.items() %} 61 | {{ value-}} 62 | {% endfor %} 63 | 64 | # drop intermediate columns 65 | df = df.drop(columns=list(filter(lambda x: x.startswith("_"), df.columns))) 66 | 67 | return df 68 | -------------------------------------------------------------------------------- /expr_codegen/polars/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wukan1986/expr_codegen/faa577832438d2f4765c289a7d56c29e09e5cf3b/expr_codegen/polars/__init__.py -------------------------------------------------------------------------------- /expr_codegen/polars/code.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Sequence, Literal 4 | 5 | import jinja2 6 | from jinja2 import FileSystemLoader, TemplateNotFound 7 | 8 | from expr_codegen.expr import TS, CS, GP 9 | from expr_codegen.model import ListDictList 10 | from expr_codegen.polars.printer import PolarsStrPrinter 11 | 12 | 13 | def get_groupby_from_tuple(tup, func_name, drop_cols): 14 | """从传入的元组中生成分组运行代码""" 15 | prefix2, *_ = tup 16 | 17 | if prefix2 == TS: 18 | # 组内需要按时间进行排序,需要维持顺序 19 | prefix2, asset = tup 20 | return f'df = {func_name}(df.sort(_ASSET_, _DATE_)).drop(*{drop_cols})' 21 | 
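# e.g. for k == (TS, 'asset') with drop_cols == ['_x_0'] the line above renders as
# (illustrative): df = func_0_ts__asset(df.sort(_ASSET_, _DATE_)).drop(*['_x_0'])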
if prefix2 == CS: 22 | prefix2, date = tup 23 | return f'df = {func_name}(df.sort(_DATE_)).drop(*{drop_cols})' 24 | if prefix2 == GP: 25 | prefix2, date, group = tup 26 | return f'df = {func_name}(df.sort(_DATE_, "{group}")).drop(*{drop_cols})' 27 | 28 | return f'df = {func_name}(df).drop(*{drop_cols})' 29 | 30 | 31 | def symbols_to_code(syms): 32 | a = [f"{s}" for s in syms] 33 | b = [f"'{s}'" for s in syms] 34 | return f"""_ = [{','.join(b)}] 35 | [{','.join(a)}] = [pl.col(i) for i in _]""" 36 | 37 | 38 | def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst, 39 | filename, 40 | date='date', asset='asset', 41 | extra_codes: Sequence[str] = (), 42 | over_null: Literal['order_by', 'partition_by', None] = 'partition_by', 43 | filter_last: bool = False, 44 | **kwargs): 45 | """Template-based code generation""" 46 | if filename is None: 47 | filename = 'template.py.j2' 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--over_null", type=str, nargs="?", default=over_null) 51 | 52 | # printer for polars-style code 53 | p = PolarsStrPrinter() 54 | 55 | # polars-style code 56 | funcs = {} 57 | # grouped-apply code. This exploits dicts preserving insertion order to put the sort step first 58 | groupbys = {'sort': ''} 59 | # post-processed expressions 60 | exprs_dst = [] 61 | syms_out = [] 62 | ts_func_name = None 63 | drop_symbols = exprs_ldl.drop_symbols() 64 | j = -1 65 | for i, row in enumerate(exprs_ldl.values()): 66 | for k, vv in row.items(): 67 | j += 1 68 | if len(vv) == 0: 69 | continue 70 | # function name 71 | func_name = f'func_{i}_{"__".join(k)}' 72 | func_code = [] 73 | for kv in vv: 74 | if kv is None: 75 | func_code.append(f" )") 76 | func_code.append(f"# " + '=' * 40) 77 | func_code.append(f" df = df.with_columns(") 78 | exprs_dst.append(f"#" + '=' * 40 + func_name) 79 | else: 80 | va, ex, sym, comment = kv 81 | # when there are several '#', only the arguments after the first '#' are used 82 | args, argv = parser.parse_known_args(args=comment.split("#")[1].split(" ")) 83 | s1 = str(ex) 84 | s2 = p.doprint(ex) 85 | if s1 != s2: 86 | # not equal: print the original expression as a comment, which reads more intuitively 87 | func_code.append(f"# {va} = {s1}") 88 | if k[0] == TS: 89 | ts_func_name = func_name 90 | # https://github.com/pola-rs/polars/issues/12925#issuecomment-2552764629 91 | _sym = [f"{s}.is_not_null()" for s in set(sym)] 92 | if len(_sym) > 1: 93 | _sym = f"pl.all_horizontal({','.join(_sym)})" 94 | else: 95 | _sym = ','.join(_sym) 96 | if args.over_null == 'partition_by': 97 | func_code.append(f"{va}=({s2}).over({_sym}, _ASSET_, order_by=_DATE_),") 98 | elif args.over_null == 'order_by': 99 | func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=[{_sym}, _DATE_]),") 100 | else: 101 | func_code.append(f"{va}=({s2}).over(_ASSET_, order_by=_DATE_),") 102 | elif k[0] == CS: 103 | func_code.append(f"{va}=({s2}).over(_DATE_),") 104 | elif k[0] == GP: 105 | func_code.append(f"{va}=({s2}).over(_DATE_, '{k[2]}'),") 106 | else: 107 | func_code.append(f"{va}={s2},") 108 | exprs_dst.append(f"{va} = {s1} {comment}") 109 | if va not in syms_dst: 110 | syms_out.append(va) 111 | func_code.append(f" )") 112 | func_code = func_code[1:] 113 | 114 | # list of polars-style code 115 | funcs[func_name] = '\n'.join(func_code) 116 | # only drop names that start with an underscore 117 | ds = [x for x in drop_symbols[j] if x.startswith('_')] 118 | # grouped-apply code 119 | groupbys[func_name] = get_groupby_from_tuple(k, func_name, ds) 120 | 121 | syms1 = symbols_to_code(syms_dst) 122 | syms2 = symbols_to_code(syms_out) 123 | if filter_last: 124 | _groupbys = {'sort': groupbys['sort']} 125 | if ts_func_name is None: 126 | _groupbys['_filter_last'] = "df = filter_last(df.sort(_DATE_))" 127 | for k, v in groupbys.items(): 128 | _groupbys[k] = v 129 | if k == ts_func_name: 130 |
_groupbys[k + '_filter_last'] = "df = filter_last(df)" 131 | groupbys = _groupbys 132 | 133 | try: 134 | env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(__file__))) 135 | template = env.get_template(filename) 136 | except TemplateNotFound: 137 | env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(filename))) 138 | template = env.get_template(os.path.basename(filename)) 139 | 140 | return template.render(funcs=funcs, groupbys=groupbys, 141 | exprs_src=exprs_src, exprs_dst=exprs_dst, 142 | syms1=syms1, syms2=syms2, 143 | date=date, asset=asset, 144 | extra_codes=extra_codes) 145 | -------------------------------------------------------------------------------- /expr_codegen/polars/printer.py: -------------------------------------------------------------------------------- 1 | from sympy import Basic, Function, StrPrinter 2 | from sympy.printing.precedence import precedence, PRECEDENCE 3 | 4 | 5 | # TODO: 如有新添加函数,但表达式有变更才需要在此补充对应的打印代码,否则可以省略 6 | 7 | class PolarsStrPrinter(StrPrinter): 8 | def _print(self, expr, **kwargs) -> str: 9 | """Internal dispatcher 10 | 11 | Tries the following concepts to print an expression: 12 | 1. Let the object print itself if it knows how. 13 | 2. Take the best fitting method defined in the printer. 14 | 3. As fall-back use the emptyPrinter method for the printer. 15 | """ 16 | self._print_level += 1 17 | try: 18 | # If the printer defines a name for a printing method 19 | # (Printer.printmethod) and the object knows for itself how it 20 | # should be printed, use that method. 21 | if self.printmethod and hasattr(expr, self.printmethod): 22 | if not (isinstance(expr, type) and issubclass(expr, Basic)): 23 | return getattr(expr, self.printmethod)(self, **kwargs) 24 | 25 | # See if the class of expr is known, or if one of its super 26 | # classes is known, and use that print function 27 | # Exception: ignore the subclasses of Undefined, so that, e.g., 28 | # Function('gamma') does not get dispatched to _print_gamma 29 | classes = type(expr).__mro__ 30 | # if AppliedUndef in classes: 31 | # classes = classes[classes.index(AppliedUndef):] 32 | # if UndefinedFunction in classes: 33 | # classes = classes[classes.index(UndefinedFunction):] 34 | # Another exception: if someone subclasses a known function, e.g., 35 | # gamma, and changes the name, then ignore _print_gamma 36 | if Function in classes: 37 | i = classes.index(Function) 38 | classes = tuple(c for c in classes[:i] if \ 39 | c.__name__ == classes[0].__name__ or \ 40 | c.__name__.endswith("Base")) + classes[i:] 41 | for cls in classes: 42 | printmethodname = '_print_' + cls.__name__ 43 | 44 | # 所有以gp_开头的函数都转换成cs_开头 45 | if printmethodname.startswith('_print_gp_'): 46 | printmethodname = "_print_gp_" 47 | 48 | printmethod = getattr(self, printmethodname, None) 49 | if printmethod is not None: 50 | return printmethod(expr, **kwargs) 51 | # Unknown object, fall back to the emptyPrinter. 
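# What the dispatcher produces in practice (illustrative): symbols print bare,
# since the generated module binds them with pl.col() via symbols_to_code:
#   ts_mean(CLOSE, 5)        -> "ts_mean(CLOSE, 5)"
#   gp_rank(INDUSTRY, CLOSE) -> "cs_rank(CLOSE)"   # via _print_gp_ below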
52 | return self.emptyPrinter(expr) 53 | finally: 54 | self._print_level -= 1 55 | 56 | def _print_Symbol(self, expr): 57 | return expr.name 58 | 59 | def _print_Equality(self, expr): 60 | PREC = precedence(expr) 61 | return "%s==%s" % (self.parenthesize(expr.args[0], PREC), self.parenthesize(expr.args[1], PREC)) 62 | 63 | def _print_Or(self, expr): 64 | PREC = PRECEDENCE["Mul"] 65 | return " | ".join(self.parenthesize(arg, PREC) for arg in expr.args) 66 | 67 | def _print_Xor(self, expr): 68 | PREC = PRECEDENCE["Mul"] 69 | return " ^ ".join(self.parenthesize(arg, PREC) for arg in expr.args) 70 | 71 | def _print_And(self, expr): 72 | PREC = PRECEDENCE["Mul"] 73 | return " & ".join(self.parenthesize(arg, PREC) for arg in expr.args) 74 | 75 | def _print_Not(self, expr): 76 | PREC = PRECEDENCE["Mul"] 77 | return "~%s" % self.parenthesize(expr.args[0], PREC) 78 | 79 | def _print_gp_(self, expr): 80 | """gp_函数都转换成cs_函数,但要丢弃第一个参数""" 81 | new_args = [self._print(arg) for arg in expr.args[1:]] 82 | func_name = expr.func.__name__[3:] 83 | return "cs_%s(%s)" % (func_name, ",".join(new_args)) 84 | -------------------------------------------------------------------------------- /expr_codegen/polars/template.py.j2: -------------------------------------------------------------------------------- 1 | # this code is auto generated by the expr_codegen 2 | # https://github.com/wukan1986/expr_codegen 3 | # 此段代码由 expr_codegen 自动生成,欢迎提交 issue 或 pull request 4 | from typing import TypeVar 5 | 6 | import polars as pl # noqa 7 | import polars.selectors as cs # noqa 8 | # from loguru import logger # noqa 9 | from polars import DataFrame as _pl_DataFrame 10 | from polars import LazyFrame as _pl_LazyFrame 11 | 12 | # =================================== 13 | # 导入优先级,例如:ts_RSI在ta与talib中都出现了,优先使用ta 14 | # 运行时,后导入覆盖前导入,但IDE智能提示是显示先导入的 15 | _ = 0 # 只要之前出现了语句,之后的import位置不参与调整 16 | # from polars_ta.prefix.talib import * # noqa 17 | from polars_ta.prefix.tdx import * # noqa 18 | from polars_ta.prefix.ta import * # noqa 19 | from polars_ta.prefix.wq import * # noqa 20 | from polars_ta.prefix.cdl import * # noqa 21 | from polars_ta.prefix.vec import * # noqa 22 | 23 | DataFrame = TypeVar('DataFrame', _pl_LazyFrame, _pl_DataFrame) 24 | # =================================== 25 | 26 | {{ syms1 }} 27 | 28 | {{ syms2 }} 29 | 30 | _DATE_ = '{{ date }}' 31 | _ASSET_ = '{{ asset }}' 32 | _NONE_ = None 33 | _TRUE_ = True 34 | _FALSE_ = False 35 | 36 | 37 | def unpack(x: pl.Expr, idx: int = 0) -> pl.Expr: 38 | return x.struct[idx] 39 | 40 | {%-for row in extra_codes %} 41 | {{ row-}} 42 | {% endfor %} 43 | 44 | {% for key, value in funcs.items() %} 45 | 46 | def {{ key }}(df: DataFrame) -> DataFrame: 47 | {{ value }} 48 | return df 49 | 50 | {% endfor %} 51 | 52 | """ 53 | {%-for row in exprs_dst %} 54 | {{ row-}} 55 | {% endfor %} 56 | """ 57 | 58 | """ 59 | {%-for a,b,c in exprs_src %} 60 | {{ a }} = {{ b}} {{c-}} 61 | {% endfor %} 62 | """ 63 | 64 | 65 | def filter_last(df: DataFrame) -> DataFrame: 66 | """过滤数据,只取最后一天。实盘时可用于减少计算量 67 | 前一个调用的ts,这里可以直接调用,可以认为已经排序好 68 | `df = filter_last(df)` 69 | 反之 70 | `df = filter_last(df.sort(_DATE_))` 71 | """ 72 | return df.filter(pl.col(_DATE_) >= df.select(pl.last(_DATE_))[0, 0]) 73 | 74 | 75 | def main(df: DataFrame) -> DataFrame: 76 | {% for key, value in groupbys.items() %} 77 | {{ value-}} 78 | {% endfor %} 79 | 80 | # drop intermediate columns 81 | # df = df.select(pl.exclude(r'^_x_\d+$')) 82 | df = df.select(~cs.starts_with("_")) 83 | 84 | # shrink 85 | df = 
df.select(cs.all().shrink_dtype()) 86 | 87 | return df 88 | 89 | -------------------------------------------------------------------------------- /expr_codegen/sql/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wukan1986/expr_codegen/faa577832438d2f4765c289a7d56c29e09e5cf3b/expr_codegen/sql/__init__.py -------------------------------------------------------------------------------- /expr_codegen/sql/code.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Sequence, Literal 4 | 5 | import jinja2 6 | from jinja2 import FileSystemLoader, TemplateNotFound 7 | 8 | from expr_codegen.expr import TS, CS, GP 9 | from expr_codegen.model import ListDictList 10 | from expr_codegen.sql.printer import SQLStrPrinter 11 | 12 | 13 | def codegen(exprs_ldl: ListDictList, exprs_src, syms_dst, 14 | filename, 15 | date='date', asset='asset', 16 | extra_codes: Sequence[str] = (), 17 | over_null: Literal['order_by', 'partition_by', None] = 'partition_by', 18 | table_name: str = 'self', 19 | filter_last: bool = False, 20 | **kwargs): 21 | """基于模板的代码生成""" 22 | if filename is None: 23 | filename = 'template.sql.j2' 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--over_null", type=str, nargs="?", default=over_null) 27 | 28 | # 打印Polars风格代码 29 | p = SQLStrPrinter() 30 | 31 | # polars风格代码 32 | funcs = {} 33 | # 分组应用代码。这里利用了字典按插入顺序排序的特点,将排序放在最前 34 | groupbys = {'sort': ''} 35 | # 处理过后的表达式 36 | exprs_dst = [] 37 | syms_out = [] 38 | ts_func_name = None 39 | drop_symbols = exprs_ldl.drop_symbols() 40 | j = -1 41 | last_func_name = table_name 42 | for i, row in enumerate(exprs_ldl.values()): 43 | for k, vv in row.items(): 44 | j += 1 45 | if len(vv) == 0: 46 | continue 47 | # 函数名 48 | func_name = f'cte_{i}_{"__".join(k)}' 49 | func_code = [] 50 | for kv in vv: 51 | if kv is None: 52 | func_code.append(f"{func_name} AS (SELECT *,") 53 | exprs_dst.append(f"#" + '=' * 40 + func_name) 54 | else: 55 | va, ex, sym, comment = kv 56 | # 多个#时,只取第一个#后的参数 57 | args, argv = parser.parse_known_args(args=comment.split("#")[1].split(" ")) 58 | s1 = str(ex) 59 | s2 = p.doprint(ex) 60 | if k[0] == TS: 61 | ts_func_name = func_name 62 | # https://github.com/pola-rs/polars/issues/12925#issuecomment-2552764629 63 | _sym = [f"`{s}` IS NOT NULL" for s in set(sym)] 64 | if len(_sym) > 1: 65 | _sym = f"({' AND '.join(_sym)})" 66 | else: 67 | _sym = ','.join(_sym) 68 | if args.over_null == 'partition_by': 69 | func_code.append(f"{s2} OVER(PARTITION BY {_sym},`{asset}` ORDER BY `{date}`) AS {va},") 70 | elif args.over_null == 'order_by': 71 | func_code.append(f"{s2} OVER(PARTITION BY `{asset}` ORDER BY {_sym},`{date}`) AS {va},") 72 | else: 73 | func_code.append(f"{s2} OVER(PARTITION BY `{asset}` ORDER BY `{date}`) AS {va},") 74 | elif k[0] == CS: 75 | func_code.append(f"{s2} OVER(PARTITION BY `{date}`) AS {va},") 76 | elif k[0] == GP: 77 | func_code.append(f"{s2} OVER(PARTITION BY `{date}`,`{k[2]}`) AS {va},") 78 | else: 79 | func_code.append(f"{s2} AS {va},") 80 | exprs_dst.append(f"{va} = {s1} {comment}") 81 | if va not in syms_dst: 82 | syms_out.append(va) 83 | func_code.append(f"FROM {last_func_name}),") 84 | last_func_name = func_name 85 | 86 | # sql风格代码列表 87 | funcs[func_name] = '\n '.join(func_code) 88 | # 只有下划线开头再删除 89 | ds = [x for x in drop_symbols[j] if x.startswith('_')] 90 | 91 | if filter_last: 92 | # TODO 没有实现 93 | pass 94 | 95 | try: 96 | 
env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(__file__))) 97 | template = env.get_template(filename) 98 | except TemplateNotFound: 99 | env = jinja2.Environment(loader=FileSystemLoader(os.path.dirname(filename))) 100 | template = env.get_template(os.path.basename(filename)) 101 | 102 | return template.render(funcs=funcs, 103 | exprs_src=exprs_src, exprs_dst=exprs_dst, 104 | date=date, asset=asset, 105 | extra_codes=extra_codes, 106 | last_func_name=last_func_name) 107 | -------------------------------------------------------------------------------- /expr_codegen/sql/printer.py: -------------------------------------------------------------------------------- 1 | from sympy import Basic, Function, StrPrinter 2 | from sympy.printing.precedence import precedence, PRECEDENCE 3 | 4 | 5 | # TODO: 如有新添加函数,但表达式有变更才需要在此补充对应的打印代码,否则可以省略 6 | 7 | class SQLStrPrinter(StrPrinter): 8 | # https://docs.pola.rs/api/python/stable/reference/sql/functions/index.html 9 | # https://github.com/pola-rs/polars/blob/main/crates/polars-sql/src/functions.rs 10 | # 将polars_ta中的函数转换成SQL中的函数 11 | convert_funcs = { 12 | # Math functions 13 | 'abs_': 'ABS', 14 | 'ceiling': 'CEIL', 15 | 'div': 'DIV', 16 | 'exp': 'EXP', 17 | 'floor': 'FLOOR', 18 | 'log': 'LN', 19 | 'log10': 'LOG10', 20 | 'log1p': 'LOG1P', 21 | 'log2': 'LOG2', 22 | 'mod': 'MOD', 23 | 'sign': 'SIGN', 24 | 'sqrt': 'SQRT', 25 | 'power': 'POW', 26 | 'round_': 'ROUND', 27 | # Trig functions 28 | 'arc_cos': 'ACOS', 29 | 'arc_sin': 'ASIN', 30 | 'arc_tan': 'ATAN', 31 | 'arc_tan2': 'ATAN2', 32 | 'cot': 'COT', 33 | 'cos': 'COS', 34 | 'degrees': 'DEGREES', 35 | 'radians': 'RADIANS', 36 | 'sin': 'SIN', 37 | 'tan': 'TAN', 38 | } 39 | 40 | def _print(self, expr, **kwargs) -> str: 41 | """Internal dispatcher 42 | 43 | Tries the following concepts to print an expression: 44 | 1. Let the object print itself if it knows how. 45 | 2. Take the best fitting method defined in the printer. 46 | 3. As fall-back use the emptyPrinter method for the printer. 47 | """ 48 | self._print_level += 1 49 | try: 50 | # If the printer defines a name for a printing method 51 | # (Printer.printmethod) and the object knows for itself how it 52 | # should be printed, use that method. 
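# What the dispatcher produces in practice (illustrative): functions listed in
# convert_funcs are renamed via _print_Rename_ below, and symbols are backtick-quoted:
#   abs_(CLOSE)       -> "ABS(`CLOSE`)"
#   ts_mean(CLOSE, 5) -> "ts_mean(`CLOSE`, 5)"   # no SQL rename registered, printed as-is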
53 | if self.printmethod and hasattr(expr, self.printmethod): 54 | if not (isinstance(expr, type) and issubclass(expr, Basic)): 55 | return getattr(expr, self.printmethod)(self, **kwargs) 56 | 57 | # See if the class of expr is known, or if one of its super 58 | # classes is known, and use that print function 59 | # Exception: ignore the subclasses of Undefined, so that, e.g., 60 | # Function('gamma') does not get dispatched to _print_gamma 61 | classes = type(expr).__mro__ 62 | # if AppliedUndef in classes: 63 | # classes = classes[classes.index(AppliedUndef):] 64 | # if UndefinedFunction in classes: 65 | # classes = classes[classes.index(UndefinedFunction):] 66 | # Another exception: if someone subclasses a known function, e.g., 67 | # gamma, and changes the name, then ignore _print_gamma 68 | if Function in classes: 69 | i = classes.index(Function) 70 | classes = tuple(c for c in classes[:i] if \ 71 | c.__name__ == classes[0].__name__ or \ 72 | c.__name__.endswith("Base")) + classes[i:] 73 | for cls in classes: 74 | printmethodname = '_print_' + cls.__name__ 75 | 76 | # 所有以gp_开头的函数都转换成cs_开头 77 | if printmethodname.startswith('_print_gp_'): 78 | printmethodname = "_print_gp_" 79 | 80 | # polars_ta中的函数转换成SQL函数 81 | if cls.__name__ in self.convert_funcs: 82 | return self._print_Rename_(expr, self.convert_funcs[cls.__name__]) 83 | 84 | printmethod = getattr(self, printmethodname, None) 85 | if printmethod is not None: 86 | return printmethod(expr, **kwargs) 87 | # Unknown object, fall back to the emptyPrinter. 88 | return self.emptyPrinter(expr) 89 | finally: 90 | self._print_level -= 1 91 | 92 | def _print_Symbol(self, expr): 93 | return f'`{expr.name}`' 94 | 95 | def _print_Equality(self, expr): 96 | PREC = precedence(expr) 97 | return "%s=%s" % (self.parenthesize(expr.args[0], PREC), self.parenthesize(expr.args[1], PREC)) 98 | 99 | def _print_Or(self, expr): 100 | PREC = PRECEDENCE["Mul"] 101 | return " OR ".join(self.parenthesize(arg, PREC) for arg in expr.args) 102 | 103 | def _print_Xor(self, expr): 104 | PREC = PRECEDENCE["Mul"] 105 | return " XOR ".join(self.parenthesize(arg, PREC) for arg in expr.args) 106 | 107 | def _print_And(self, expr): 108 | PREC = PRECEDENCE["Mul"] 109 | return " AND ".join(self.parenthesize(arg, PREC) for arg in expr.args) 110 | 111 | def _print_Not(self, expr): 112 | PREC = PRECEDENCE["Mul"] 113 | return "NOT %s" % self.parenthesize(expr.args[0], PREC) 114 | 115 | def _print_gp_(self, expr): 116 | """gp_函数都转换成cs_函数,但要丢弃第一个参数""" 117 | new_args = [self._print(arg) for arg in expr.args[1:]] 118 | func_name = expr.func.__name__[3:] 119 | return "cs_%s(%s)" % (func_name, ",".join(new_args)) 120 | 121 | def _print_Rename_(self, expr, new_name): 122 | l = [self._print(o) for o in expr.args] 123 | return new_name + "(%s)" % ", ".join(l) 124 | -------------------------------------------------------------------------------- /expr_codegen/sql/template.sql.j2: -------------------------------------------------------------------------------- 1 | /* 2 | this code is auto generated by the expr_codegen 3 | # https://github.com/wukan1986/expr_codegen 4 | # 此段代码由 expr_codegen 自动生成,欢迎提交 issue 或 pull request 5 | */ 6 | 7 | /* 8 | {%-for row in extra_codes %} 9 | {{ row-}} 10 | {% endfor %} 11 | */ 12 | 13 | WITH 14 | {% for key, value in funcs.items() %} 15 | {{ value }} 16 | {% endfor %} 17 | SELECT * FROM {{ last_func_name }} 18 | 19 | /* 20 | {%-for row in exprs_dst %} 21 | {{ row-}} 22 | {% endfor %} 23 | */ 24 | 25 | /* 26 | {%-for a,b,c in exprs_src %} 27 | {{ a }} = {{ 
b}} {{c-}} 28 | {% endfor %} 29 | */ 30 | -------------------------------------------------------------------------------- /expr_codegen/tool.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import pathlib 3 | from functools import lru_cache 4 | from io import TextIOBase 5 | from typing import Sequence, Union, TypeVar, Optional, Literal 6 | 7 | import polars as pl 8 | from black import Mode, format_str 9 | from loguru import logger 10 | from sympy import simplify, cse, symbols, numbered_symbols 11 | from sympy.core.expr import Expr 12 | from sympy.logic import boolalg 13 | 14 | from expr_codegen.codes import sources_to_exprs 15 | from expr_codegen.expr import get_current_by_prefix, get_children, replace_exprs 16 | from expr_codegen.model import dag_start, dag_end, dag_middle, _RESERVED_WORD_ 17 | 18 | try: 19 | from pandas import DataFrame as _pd_DataFrame 20 | except ImportError: 21 | _pd_DataFrame = None 22 | 23 | try: 24 | from polars import DataFrame as _pl_DataFrame 25 | from polars import LazyFrame as _pl_LazyFrame 26 | except ImportError: 27 | _pl_DataFrame = None 28 | _pl_LazyFrame = None 29 | 30 | DataFrame = TypeVar('DataFrame', _pl_LazyFrame, _pl_DataFrame, _pd_DataFrame) 31 | 32 | # =============================== 33 | # TypeError: expecting bool or Boolean, not `ts_delay(X, 3)`. 34 | # ts_delay(X, 3) & ts_delay(Y, 3) 35 | boolalg.as_Boolean = lambda x: x 36 | 37 | 38 | # AttributeError: 'StrictGreaterThan' object has no attribute 'diff' 39 | # ts_count(open > 1, 2) == 2 40 | def _diff(self, *symbols, **assumptions): 41 | assumptions.setdefault("evaluate", False) 42 | from sympy.core.function import _derivative_dispatch 43 | return _derivative_dispatch(self, *symbols, **assumptions) 44 | 45 | 46 | Expr.diff = _diff 47 | 48 | 49 | # =============================== 50 | 51 | def simplify2(expr): 52 | # return simplify(expr) 53 | try: 54 | expr = simplify(expr) 55 | except AttributeError as e: 56 | print(f'{expr} ,表达式无法简化, {e}') 57 | return expr 58 | 59 | 60 | class ExprTool: 61 | 62 | def __init__(self): 63 | self.get_current_func = get_current_by_prefix 64 | self.get_current_func_kwargs = {} 65 | self.exprs_list = {} 66 | self.exprs_names = [] 67 | self.globals_ = {} 68 | 69 | def set_current(self, func, **kwargs): 70 | self.get_current_func = func 71 | self.get_current_func_kwargs = kwargs 72 | 73 | def extract(self, expr, date, asset): 74 | """抽取分割后的子公式 75 | 76 | Parameters 77 | ---------- 78 | expr 79 | 单表达式 80 | 81 | Returns 82 | ------- 83 | 表达式列表 84 | 85 | """ 86 | exprs = [] 87 | syms = [] 88 | get_children(self.get_current_func, self.get_current_func_kwargs, 89 | expr, 90 | output_exprs=exprs, output_symbols=syms, 91 | date=date, asset=asset) 92 | # print('=' * 20, expr) 93 | # print(exprs) 94 | return exprs, syms 95 | 96 | def merge(self, date, asset, args): 97 | """合并多个表达式 98 | 99 | 1. 先抽取分割子公式 100 | 2. 
合并 子公式+长公式,去重 101 | 102 | Parameters 103 | ---------- 104 | args 105 | 表达式列表 106 | 107 | Returns 108 | ------- 109 | 表达式列表 110 | """ 111 | # 抽取前先化简 112 | args = [(k, simplify2(v), c) for k, v, c in args] 113 | 114 | # 保留了注释信息 115 | exprs_syms = [(self.extract(v, date, asset), c) for k, v, c in args] 116 | exprs = [] 117 | syms = [] 118 | for (e, s), c in exprs_syms: 119 | syms.extend(s) 120 | for _ in e: 121 | # 抽取的表达式添加注释 122 | exprs.append((_, c)) 123 | 124 | syms = sorted(set(syms), key=syms.index) 125 | # 如果目标有重复表达式,这里会混乱 126 | exprs = sorted(set(exprs), key=exprs.index) 127 | # 这里不能合并简化与未简化的表达式,会导致cse时失败,需要简化表达式合并 128 | exprs = exprs + [(v, c) for k, v, c in args] 129 | 130 | # print(exprs) 131 | syms = [str(s) for s in syms] 132 | return exprs, syms 133 | 134 | def reduce(self, repl, redu): 135 | """减少中间变量数量,有利用减少内存占用""" 136 | 137 | exprs_list = [] 138 | 139 | # cse前简化一次,cse后不再简化 140 | # (~开盘涨停 & 昨收涨停) | (~收盘涨停 & 最高涨停) 141 | for k, v in repl: 142 | exprs_list.append((k, v, "#")) 143 | for k, v, c in redu: 144 | exprs_list.append((k, v, c)) 145 | 146 | return exprs_list 147 | 148 | def cse(self, exprs, symbols_repl=None, exprs_src=None): 149 | """多个子公式+长公式,提取公共公式 150 | 151 | Parameters 152 | ---------- 153 | exprs 154 | 表达式列表 155 | symbols_repl 156 | 中间字段名迭代器 157 | exprs_src 158 | 最终字段名列表 159 | 160 | Returns 161 | ------- 162 | graph_dag 163 | 依赖关系的有向无环图 164 | graph_key 165 | 每个函数分组用key 166 | graph_exp 167 | 表达式 168 | 169 | """ 170 | self.exprs_names = [k for k, v, c in exprs_src] 171 | # 包含了注释信息 172 | _exprs = [k for k, v in exprs] 173 | 174 | # 注意:对于表达式右边相同,左边不同的情况,会当成一个处理 175 | repl, redu = cse(_exprs, symbols_repl, optimizations="basic") 176 | outputs_len = len(exprs_src) 177 | 178 | new_redu = [] 179 | symbols_redu = iter(exprs_src) 180 | for expr in redu[-outputs_len:]: 181 | # 可能部分表达式只在之前出现过,后面完全用不到如,ts_rank(ts_decay_linear(x_147, 11.4157), 6.72611) 182 | variable = next(symbols_redu) 183 | a = symbols(variable[0]) 184 | new_redu.append((a, expr, variable[2])) 185 | 186 | self.exprs_list = self.reduce(repl, new_redu) 187 | 188 | # with open("exprs.pickle", "wb") as file: 189 | # pickle.dump(exprs_dict, file) 190 | 191 | return self.exprs_list 192 | 193 | def dag(self, merge: bool, date, asset): 194 | """生成DAG""" 195 | G = dag_start(self.exprs_list, self.get_current_func, self.get_current_func_kwargs, date, asset) 196 | if merge: 197 | G = dag_middle(G, self.exprs_names, self.get_current_func, self.get_current_func_kwargs, date, asset) 198 | return dag_end(G) 199 | 200 | def all(self, exprs_src, style: Literal['pandas', 'polars', 'sql'] = 'polars', 201 | template_file: Optional[str] = None, 202 | replace: bool = True, regroup: bool = False, format: bool = True, 203 | date='date', asset='asset', 204 | extra_codes: Sequence[object] = (), 205 | over_null: Literal['order_by', 'partition_by', None] = 'partition_by', 206 | table_name: str = 'self', 207 | filter_last: bool = False, 208 | **kwargs): 209 | """功能集成版,将几个功能写到一起方便使用 210 | 211 | Parameters 212 | ---------- 213 | exprs_src: list 214 | 表达式列表 215 | style: str 216 | 代码风格。可选值 ('polars', 'pandas', 'sql') 217 | template_file: str 218 | 根据需求可定制模板 219 | replace:bool 220 | 表达式提换 221 | regroup:bool 222 | 分组重排。注意:目前好像不稳定 223 | format:bool 224 | 代码格式化 225 | date:str 226 | 日期字段名 227 | asset:str 228 | 资产字段名 229 | extra_codes: Sequence[object] 230 | 需要复制到模板中的额外代码 231 | table_name 232 | filter_last 233 | 234 | Returns 235 | ------- 236 | 代码字符串 237 | 238 | """ 239 | assert style in ('pandas', 'polars', 'sql') 240 | 241 | if replace: 242 | 

    def dag(self, merge: bool, date, asset):
        """Build the DAG"""
        G = dag_start(self.exprs_list, self.get_current_func, self.get_current_func_kwargs, date, asset)
        if merge:
            G = dag_middle(G, self.exprs_names, self.get_current_func, self.get_current_func_kwargs, date, asset)
        return dag_end(G)

    def all(self, exprs_src, style: Literal['pandas', 'polars', 'sql'] = 'polars',
            template_file: Optional[str] = None,
            replace: bool = True, regroup: bool = False, format: bool = True,
            date='date', asset='asset',
            extra_codes: Sequence[object] = (),
            over_null: Literal['order_by', 'partition_by', None] = 'partition_by',
            table_name: str = 'self',
            filter_last: bool = False,
            **kwargs):
        """All-in-one entry point that chains the steps above for convenience

        Parameters
        ----------
        exprs_src: list
            List of expressions
        style: str
            Code style. One of ('polars', 'pandas', 'sql')
        template_file: str
            Custom template, if needed
        replace: bool
            Apply expression replacement
        regroup: bool
            Regroup and reorder. Note: currently seems unstable
        format: bool
            Format the generated code
        date: str
            Date field name
        asset: str
            Asset field name
        extra_codes: Sequence[object]
            Extra code to copy into the template
        table_name
            See codegen_exec
        filter_last
            See codegen_exec

        Returns
        -------
        Code as a string

        """
        assert style in ('pandas', 'polars', 'sql')

        if replace:
            exprs_src = replace_exprs(exprs_src)

        # Sub-expressions first, original expressions last
        exprs_dst, syms_dst = self.merge(date, asset, exprs_src)
        syms_dst = list(set(syms_dst) - _RESERVED_WORD_)

        # Extract common subexpressions
        self.cse(exprs_dst, symbols_repl=numbered_symbols('_x_'), exprs_src=exprs_src)
        # Flow through the directed acyclic graph
        exprs_ldl, G = self.dag(True, date, asset)

        if regroup:
            exprs_ldl.optimize(merge=style != 'sql')

        if style == 'polars':
            from expr_codegen.polars.code import codegen
        elif style == 'pandas':
            from expr_codegen.pandas.code import codegen
        elif style == 'sql':
            from expr_codegen.sql.code import codegen
            format = False
        else:
            raise ValueError(f'unknown style {style}')

        extra_codes = [c if isinstance(c, str) else inspect.getsource(c) for c in extra_codes]

        codes = codegen(exprs_ldl, exprs_src, syms_dst,
                        filename=template_file, date=date, asset=asset,
                        extra_codes=extra_codes,
                        over_null=over_null,
                        table_name=table_name,
                        filter_last=filter_last,
                        **kwargs)

        logger.info(f'{style} code is generated')

        if format:
            # Formatting; unnecessary inside a genetic algorithm
            codes = format_str(codes, mode=Mode(line_length=600, magic_trailing_comma=True))

        return codes, G

    @lru_cache(maxsize=64)
    def _get_code(self,
                  source: str, *more_sources: str,
                  extra_codes: str,
                  output_file: str,
                  convert_xor: bool,
                  style: Literal['pandas', 'polars', 'sql'] = 'polars',
                  template_file: Optional[str] = None,
                  date: str = 'date', asset: str = 'asset',
                  over_null: Literal['order_by', 'partition_by', None] = 'partition_by',
                  table_name: str = 'self',
                  filter_last: bool = False,
                  **kwargs) -> str:
        """Generate code from strings. Cached, so repeated calls do not regenerate"""
        raw, exprs_list = sources_to_exprs(self.globals_, source, *more_sources, convert_xor=convert_xor)

        # Generate code
        code, G = _TOOL_.all(exprs_list, style=style, template_file=template_file,
                             replace=True, regroup=True, format=True,
                             date=date, asset=asset,
                             # Copies the functions that are needed, plus the original expressions
                             extra_codes=(raw,
                                          # Hook for methods that take multiple columns
                                          extra_codes,
                                          ),
                             over_null=over_null,
                             table_name=table_name,
                             filter_last=filter_last,
                             **kwargs)

        # Kept inside the cached function, so repeated calls do not save repeatedly
        if hasattr(output_file, "write"):
            output_file.write(code)
        elif output_file is not None:
            output_file = pathlib.Path(output_file)
            logger.info(f'save to {output_file.absolute()}')
            with open(output_file, 'w', encoding='utf-8') as f:
                f.write(code)

        return code


@lru_cache(maxsize=64, typed=True)
def _get_func_from_code_py(code: str):
    logger.info('get func from code py')
    globals_ = {}
    exec(code, globals_)
    return globals_['main']


@lru_cache(maxsize=64, typed=True)
def _get_func_from_module(module: str):
    """Importable, so breakpoints can be set for debugging"""
    m = __import__(module, fromlist=['*'])
    logger.info(f'get func from module {m}')
    return m.main


@lru_cache(maxsize=64, typed=True)
def _get_func_from_file_py(file: str):
    file = pathlib.Path(file)
    logger.info(f'get func from file "{file.absolute()}"')
    with open(file, 'r', encoding='utf-8') as f:
        globals_ = {}
        exec(f.read(), globals_)
        return globals_['main']


@lru_cache(maxsize=64, typed=True)
def _get_code_from_file(file: str):
    file = pathlib.Path(file)
    logger.info(f'get code from file "{file.absolute()}"')
    with open(file, 'r', encoding='utf-8') as f:
        return f.read()


_TOOL_ = ExprTool()
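# Sketch of the debug workflow the helpers above enable (file names are
# hypothetical, and the module must be importable from sys.path): write the
# generated code out once, then run it as a module so breakpoints inside it work.
#
#   codegen_exec(df, "MA_5 = ts_mean(CLOSE, 5)", over_null=None,
#                output_file='my_factors.py')   # generates, saves and runs
#   codegen_exec(df, "MA_5 = ts_mean(CLOSE, 5)", over_null=None,
#                run_file='my_factors')         # imports my_factors; debuggable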


def codegen_exec(df: Union[DataFrame, None],
                 *codes,
                 over_null: Literal['partition_by', 'order_by', None],
                 extra_codes: str = r'CS_SW_L1 = r"^sw_l1_\d+$"',
                 output_file: Union[str, TextIOBase, None] = None,
                 run_file: Union[bool, str] = False,
                 convert_xor: bool = False,
                 style: Literal['pandas', 'polars', 'sql'] = 'polars',
                 template_file: Optional[str] = None,
                 date: str = 'date', asset: str = 'asset',
                 table_name: str = 'self',
                 filter_last: bool = False,
                 **kwargs) -> Union[DataFrame, str]:
    """Translate the source and execute it in one step

    Parameters
    ----------
    df: pl.DataFrame, pd.DataFrame, pl.LazyFrame, None
        With a DataFrame as input, returns a DataFrame.
        With None as input, returns the generated code.
    codes:
        Function bodies. The expressions in them are translated into target code
    extra_codes: str
        Extra code. Copied into the target code verbatim, without processing
    output_file: str | TextIOBase
        File to save the generated target code to
    run_file: bool or str
        Run existing code instead of generating a script. Note: cached, so repeated calls do not regenerate
        - If True, the code is read from output_file
        - If a string, the code is read from run_file
        - If a module name, the code is imported from that module (debuggable)
        - Note: a same-named module in another directory could be picked up, so use distinctive file names
    convert_xor: bool
        Whether `^` means XOR or exponentiation
    style: str
        Code style. One of 'pandas', 'polars', 'sql'
        - pandas: struct is not supported
        - sql: generates a SQL statement; it is executed through polars SQLContext when df is given
    template_file: str
        Code template
    date: str
        Date field
    asset: str
        Asset field
    over_null: str
        How to handle nulls in time series. Only effective when style is 'polars' or 'sql'
        - partition_by: nulls go into a separate partition
        - order_by: nulls sort to the front of the same partition
        - None: no handling
    table_name: str
        Table name. Only effective when style is 'sql'
    filter_last: bool
        In live trading only the last date is needed, so the data can be filtered after
        the last `ts` step. Currently only effective when style is 'polars' or 'pandas'


    Returns
    -------
    DataFrame
        Output DataFrame
    str
        Generated code

    Notes
    -----
    Caching happens everywhere, so watch the log output while developing formulas,
    or you may keep debugging a stale formula:

    1. make sure the code was regenerated: `code is generated`
    2. make sure the function was rebuilt from the code: `get func from code`
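
    Examples
    --------
    A minimal sketch. The column names and the `ts_mean` call are assumptions
    about your data and the installed function library (polars_ta)::

        import polars as pl

        df = pl.DataFrame({
            'date': [1, 1, 2, 2, 3, 3],
            'asset': ['A', 'B'] * 3,
            'CLOSE': [10., 20., 11., 19., 12., 18.],
        })
        df = codegen_exec(df, "MA_2 = ts_mean(CLOSE, 2)", over_null='partition_by')

        # With df=None, only the generated source code is returned
        code = codegen_exec(None, "MA_2 = ts_mean(CLOSE, 2)", over_null='partition_by')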
    """
    if df is not None:
        input_file = None
        # All the code below is cached
        if run_file is True:
            assert output_file is not None, 'output_file is required'
            input_file = str(output_file)
        elif run_file is not False:
            input_file = str(run_file)

        if input_file is not None:
            if input_file.endswith('.py'):
                return _get_func_from_file_py(input_file)(df)
            elif input_file.endswith('.sql'):
                with pl.SQLContext(frames={table_name: df}) as ctx:
                    return ctx.execute(_get_code_from_file(input_file), eager=isinstance(df, _pl_DataFrame))
            else:
                return _get_func_from_module(input_file)(df)  # breakpoints work here
        else:
            pass

    # This code is adapted from sympy.var
    frame = inspect.currentframe().f_back
    _TOOL_.globals_ = frame.f_globals.copy()
    del frame

    more_sources = [c if isinstance(c, str) else inspect.getsource(c) for c in codes]

    code = _TOOL_._get_code(
        *more_sources,
        extra_codes=extra_codes,
        output_file=output_file,
        convert_xor=convert_xor,
        style=style, template_file=template_file,
        date=date, asset=asset,
        over_null=over_null,
        table_name=table_name,
        filter_last=filter_last,
        **kwargs
    )

    if df is None:
        # With no df, return the code directly
        return code
    elif style == 'sql':
        with pl.SQLContext(frames={table_name: df}) as ctx:
            return ctx.execute(code, eager=isinstance(df, _pl_DataFrame))
    else:
        # When the code is identical, the function comes from the cache
        return _get_func_from_code_py(code)(df)
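# Sketch of the style='sql' path done by hand (this mirrors the branch above;
# 'self' is the default table_name):
#
#   sql = codegen_exec(None, "MA_5 = ts_mean(CLOSE, 5)",
#                      over_null='partition_by', style='sql')
#   with pl.SQLContext(frames={'self': df}) as ctx:
#       out = ctx.execute(sql, eager=True)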
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[project]
name = "expr_codegen"
authors = [
    { name = "wukan", email = "wu-kan@163.com" },
]
description = "symbol expression to polars expression tool"
readme = "README.md"
requires-python = ">=3.9"
keywords = ["polars", "expression", "talib"]
license = { file = "LICENSE" }
classifiers = [
    "Development Status :: 4 - Beta",
    "Programming Language :: Python"
]
dependencies = [
    'black',
    'Jinja2',
    'networkx',
    'loguru',
    'sympy',
    'ast-comments',
]
dynamic = ["version"]

[project.optional-dependencies]
streamlit = [
    'streamlit',
    'streamlit-ace',
    'more_itertools',
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.version]
path = "expr_codegen/_version.py"

[tool.hatch.build.targets.wheel]
packages = ["expr_codegen"]
include-package-data = true

[tool.hatch.build.targets.sdist]
include = ["expr_codegen*"]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
black
Jinja2
loguru
sympy
networkx
ast-comments
streamlit
streamlit-ace
more_itertools
polars_ta>=0.5.5
--------------------------------------------------------------------------------
/streamlit_app.py:
--------------------------------------------------------------------------------
import base64

import streamlit as st
from streamlit_ace import st_ace

from expr_codegen import codegen_exec, __version__

st.set_page_config(page_title='Expr Codegen', layout="wide")

with st.sidebar:
    st.subheader("Configuration")

    date_name = st.text_input('Date field name', 'date')
    asset_name = st.text_input('Asset field name', 'asset')

    # Code generation options
    style = st.radio('Code style', ('polars', 'pandas', 'sql'))
    over_null = st.radio('over_null mode', ('partition_by', 'order_by', None))
    filter_last = st.checkbox('filter_last', False)

    convert_xor = st.checkbox('Convert `^` to `**`', True)

    st.subheader("About")
    st.markdown(f"""[GitHub repo](https://github.com/wukan1986/expr_codegen)

[Issue tracker](http://github.com/wukan1986/expr_codegen/issues)

Author: wukan

Disclaimer:
1. This site does not store the expressions you enter, so they stay private
2. The generated code may contain errors; please report any you find

version: {__version__}
""")

st.subheader('Custom expressions')

exprs_src = st_ace(value="""# Add expressions here: right of `=` is the expression, left of `=` is the new factor name
alpha_003=-1 * ts_corr(cs_rank(OPEN), cs_rank(VOLUME), 10)
alpha_006=-1 * ts_corr(OPEN, VOLUME, 10)
alpha_101=(CLOSE - OPEN) / ((HIGH - LOW) + 0.001)
alpha_201=alpha_101+CLOSE  # intermediate variable example

LABEL_OO_1=OPEN[-2]/OPEN[-1]-1  # trade at next day's open
LABEL_OO_2=OPEN[-3]/OPEN[-1]-1  # trade at next day's open, hold for two days
LABEL_CC_1=CLOSE[-1]/CLOSE-1  # trade at each day's close
""",
                   language="python",
                   auto_update=True,
                   )

if st.button('Generate code'):
    with st.spinner('Generating, please wait...'):
        res = codegen_exec(None, exprs_src, over_null=over_null, convert_xor=convert_xor, style=style, filter_last=filter_last)
        b64 = base64.b64encode(res.encode('utf-8'))
        st.markdown(f'<a href="data:text/plain;base64,{b64.decode()}" download="output.py">Download code</a>',
                    unsafe_allow_html=True)
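        # The anchor above embeds the generated file as a base64 data URI
        # (data:text/plain;base64,...), so the browser saves it directly with no
        # extra request and no server-side storage of the expressions.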
        # The download button reruns the page on click; not recommended
        # st.download_button(label="Download code", data=res, file_name='output.py')

    language = "python"
    if style == 'sql':
        language = "sql"

    with st.expander(label="Preview code"):
        st.code(res, language=language)
--------------------------------------------------------------------------------
/tests/expr_order.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import polars as pl
from loguru import logger

from expr_codegen import codegen_exec

_N = 250 * 10
_K = 500

asset = [f's_{i:04d}' for i in range(_K)]
date = pd.date_range('2015-1-1', periods=_N)

df = pd.DataFrame({
    'RETURNS': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1),
    'VWAP': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1),
    'OPEN': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1),
    'CLOSE': np.cumprod(1 + np.random.uniform(-0.1, 0.1, size=(_N, _K)), axis=0).reshape(-1),
}, index=pd.MultiIndex.from_product([date, asset], names=['date', 'asset'])).reset_index()

# Feed the data into the script
df = pl.from_pandas(df)


# Note: _code_block_1 is redefined several times below; only the last definition is used
def _code_block_1():
    # Expect the ts_ calls to be hoisted
    B = cs_rank(CLOSE, False)
    C = cs_rank(CLOSE, True)
    A = ts_returns(CLOSE, 5)
    D = ts_returns(CLOSE, 10)
    E = A + B


def _code_block_2():
    # Expect the ts_ calls to be hoisted
    ma_10 = ts_mean(CLOSE, 10)
    MAMA_20 = ts_mean(ma_10, 20)
    alpha_031 = ((cs_rank(cs_rank(cs_rank(ts_decay_linear((-1 * cs_rank(cs_rank(ts_delta(CLOSE, 10)))), 10))))))


def _code_block_1():
    # Expect the ts_ calls to be hoisted
    A = ts_returns(CLOSE, 5)
    D = ts_returns(A, 10) + cs_rank(CLOSE)
    E = A + D


def _code_block_1():
    # Expect the ts_ calls to be hoisted
    B = cs_rank(CLOSE, False)
    C = cs_rank(CLOSE, True)
    E = B + C


def _code_block_1():
    _OO_02 = OPEN[-3] / OPEN[-1]
    _OO_05 = OPEN[-6] / OPEN[-1]
    _OO_10 = OPEN[-11] / OPEN[-1]

    # Geometric mean
    RETURN_OO_02 = _OO_02 ** (1 / 2) - 1
    RETURN_OO_05 = _OO_05 ** (1 / 5) - 1
    RETURN_OO_10 = _OO_10 ** (1 / 10) - 1
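    # i.e. the average daily return over an n-day holding period, compounded
    # geometrically; for n = 2, RETURN_OO_02 solves
    #   (1 + r) ** 2 = OPEN[-3] / OPEN[-1]
    # for r.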


def _code_block_1():
    ONE = 1
    TWO = add(CLOSE, False)


logger.info("1")
df = codegen_exec(df, _code_block_1, over_null='partition_by', output_file="1_out.py", style='polars', filter_last=True)
print(df)
logger.info("2")
--------------------------------------------------------------------------------
/tests/formula_transformer.py:
--------------------------------------------------------------------------------
import ast

from expr_codegen.codes import source_replace, RenameTransformer

source = """
OPEN>=CLOSE?1:0
OPEN>CLOSE?A==B?3:DE>FG?5:6:0
A=OPEN==CLOSE
B,C=(OPEN < CLOSE) * -1,1

ts_sum(min(((delta(vol,8)-(high*vol))+7.1),max(((open>=low?amt:low)/4.2),ts_sum((open=vol?amt:vol),3)+((vol==low?vwap:vol)*10.3))

"""

source = """
_A = 1+2
_B = 3+4
C = _A+_B
_A = 10+20
_B = 30+40
D = _A+_B
"""

# No more fear of cycles appearing
source = """
_A = True+None
# F = 1>None
# _A = Add(1,True)
#
# _A = _A+_A
# _B = _A+2
# _C = _A+_B
#
# D = _A

"""
funcs_map = {'abs': 'abs_',
             'max': 'max_',
             'min': 'min_',
             'delta': 'ts_delta',
             'delay': 'ts_delay',
             }
args_map = {'True': "TRUE", 'False': "FALSE", 'None': "NONE"}
targets_map = {'_A': '_12'}

tree = ast.parse(source_replace(source))
t = RenameTransformer(funcs_map, targets_map, args_map)

t.visit(tree)
print('=' * 60)
print(t.funcs_old)
print(t.args_old)
print(t.targets_old)
print('=' * 60)
print(t.funcs_new)
print(t.args_new)
print(t.targets_new)
print('=' * 60)
print(ast.unparse(tree))
--------------------------------------------------------------------------------
/tests/speed_pandas.py:
--------------------------------------------------------------------------------
import timeit

import numpy as np
import pandas as pd


def getCols(k) -> list:
    return [f'x_{i}' for i in range(k)]


pd.testing.getCols = getCols
pd.testing._N = 250 * 10
pd.testing._K = 5000

# Generating 5000 stocks this way is far too slow, so take another approach
# CLOSE = pd._testing.makeTimeDataFrame()

_N = 250 * 10
_K = 5000

asset = [f'x_{i}' for i in range(_K)]
date = pd.date_range('2000-01-1', periods=_N)

df_sort_by_date = pd.DataFrame({
    'OPEN': np.random.rand(_K * _N),
    'HIGH': np.random.rand(_K * _N),
    'LOW': np.random.rand(_K * _N),
    'CLOSE': np.random.rand(_K * _N),
}, index=pd.MultiIndex.from_product([date, asset], names=['date', 'asset'])).reset_index()

df_sort_by_asset = pd.DataFrame({
    'OPEN': np.random.rand(_K * _N),
    'HIGH': np.random.rand(_K * _N),
    'LOW': np.random.rand(_K * _N),
    'CLOSE': np.random.rand(_K * _N),
}, index=pd.MultiIndex.from_product([asset, date], names=['asset', 'date'])).reset_index()

df_sort_by_date.info()
# Shuffle
df = df_sort_by_date.sample(_K * _N * 2, replace=True, ignore_index=True)
print(df.tail(10))


def func_0_ts__asset__date_1(df: pd.DataFrame) -> pd.DataFrame:
    # ========================================
    # x_0 = ts_mean(OPEN, 10)
    df["x_0"] = df["OPEN"].rolling(10).mean()
    # expr_6 = ts_delta(OPEN, 10)
    df["expr_6"] = df["OPEN"].diff(10)
    # expr_7 = ts_delta(OPEN + 1, 10)
    df["expr_7"] = (df["OPEN"] + 1).diff(10)
    # x_1 = ts_mean(CLOSE, 10)
    df["x_1"] = df["CLOSE"].rolling(10).mean()
    return df


def func_0_ts__asset__date_2(df: pd.DataFrame) -> pd.DataFrame:
    df = df.sort_values(by=["date"])
    # ========================================
    # x_0 = ts_mean(OPEN, 10)
    df["x_0"] = df["OPEN"].rolling(10).mean()
    # expr_6 = ts_delta(OPEN, 10)
    df["expr_6"] = df["OPEN"].diff(10)
    # expr_7 = ts_delta(OPEN + 1, 10)
    df["expr_7"] = (df["OPEN"] + 1).diff(10)
    # x_1 = ts_mean(CLOSE, 10)
    df["x_1"] = df["CLOSE"].rolling(10).mean()
    return df


def func_0_cs__date(df: pd.DataFrame) -> pd.DataFrame:
    # ========================================
    # x_2 = cs_rank(x_0)
    df["x_2"] = df["x_0"].rank(pct=True)
    # x_3 = cs_rank(x_1)
    df["x_3"] = df["x_1"].rank(pct=True)
    return df


print('=' * 60)
# 10 years, already sorted by date then asset: the three variants show little difference
# 21.95759929995984
# 23.93896960001439
# 23.232979799970053

print(timeit.timeit('df_sort_by_date.sort_values(by=["date"]).groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date_1)', number=3, globals=locals()))
print(timeit.timeit('df_sort_by_date.sort_values(by=["asset", "date"]).groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date_1)', number=3, globals=locals()))
print(timeit.timeit('df_sort_by_date.groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date_2)', number=3, globals=locals()))

print('=' * 60)
# 10 years, already sorted by asset then date: the three variants show little difference
# 25.781703099957667
# 20.82362669997383
# 20.364632499986328
print(timeit.timeit('df_sort_by_asset.sort_values(by=["date"]).groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date_1)', number=3, globals=locals()))
print(timeit.timeit('df_sort_by_asset.sort_values(by=["asset", "date"]).groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date_1)', number=3, globals=locals()))
print(timeit.timeit('df_sort_by_asset.groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date_2)', number=3, globals=locals()))

print('=' * 60)
# 2 years, shuffled: the results differ a lot
# 56.11242270004004
# 45.11343280004803
# 34.87692619999871
print(timeit.timeit('df.sort_values(by=["date"]).groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date_1)', number=3, globals=locals()))
print(timeit.timeit('df.sort_values(by=["asset", "date"]).groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date_1)', number=3, globals=locals()))
print(timeit.timeit('df.groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date_2)', number=3, globals=locals()))
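# Takeaway from the timings above: when the input is pre-sorted the three variants
# are comparable; on shuffled data, sorting by date inside each per-asset group
# (func_0_ts__asset__date_2) was the fastest here.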
--------------------------------------------------------------------------------
/tests/speed_pandas2.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd


def getCols(k) -> list:
    return [f'x_{i}' for i in range(k)]


pd.testing.getCols = getCols
pd.testing._N = 250 * 10
pd.testing._K = 5000

# Generating 5000 stocks this way is far too slow, so take another approach
# CLOSE = pd._testing.makeTimeDataFrame()

_N = 250 * 10
_K = 5000

asset = [f'x_{i}' for i in range(_K)]
date = pd.date_range('2000-01-1', periods=_N)

df_sort_by_asset = pd.DataFrame({
    'OPEN': np.random.rand(_K * _N),
    'HIGH': np.random.rand(_K * _N),
    'LOW': np.random.rand(_K * _N),
    'CLOSE': np.random.rand(_K * _N),
}, index=pd.MultiIndex.from_product([asset, date], names=['asset', 'date'])).reset_index()


def func_0_ts__asset__date(df: pd.DataFrame) -> pd.DataFrame:
    # ========================================
    # x_0 = ts_mean(OPEN, 10)
    df["x_0"] = df["OPEN"].rolling(10).mean()
    # expr_6 = ts_delta(OPEN, 10)
    df["expr_6"] = df["OPEN"].diff(10)
    # expr_7 = ts_delta(OPEN + 1, 10)
    df["expr_7"] = (df["OPEN"] + 1).diff(10)
    # x_1 = ts_mean(CLOSE, 10)
    df["x_1"] = df["CLOSE"].rolling(10).mean()
    assert df['date'].is_monotonic_increasing
    return df


def func_0_cs__date(df: pd.DataFrame) -> pd.DataFrame:
    # ========================================
    # x_2 = cs_rank(x_0)
    df["x_2"] = df["x_0"].rank(pct=True)
    # x_3 = cs_rank(x_1)
    df["x_3"] = df["x_1"].rank(pct=True)
    return df


def run1(df):
    df = df.groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date)
    assert df['date'].is_monotonic_increasing
    df = df.groupby(by=["date"], group_keys=False).apply(func_0_cs__date)
    assert df['date'].is_monotonic_increasing
    df = df.groupby(by=["asset"], group_keys=False).apply(func_0_ts__asset__date)
    assert df['date'].is_monotonic_increasing


# Test when to sort: groups are shuffled internally, but the outer order stays unchanged
df = df_sort_by_asset.sort_values(by=["date", "asset"]).reset_index(drop=True)
run1(df)
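# Takeaway: the asserts above pass because, with group_keys=False and a function
# that returns a like-indexed frame, groupby().apply() reassembles the result in
# the original row order, so a single upfront sort by ["date", "asset"] survives
# both the per-asset (ts_) and per-date (cs_) passes.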
--------------------------------------------------------------------------------
/tests/speed_polars.py:
--------------------------------------------------------------------------------
import timeit

import numpy as np
import pandas as pd
import polars as pl

_N = 250 * 10
_K = 5000

asset = [f'x_{i}' for i in range(_K)]
date = pd.date_range('2000-01-1', periods=_N)

df_sort_by_date = pd.DataFrame({
    'OPEN': np.random.rand(_K * _N),
    'HIGH': np.random.rand(_K * _N),
    'LOW': np.random.rand(_K * _N),
    'CLOSE': np.random.rand(_K * _N),
}, index=pd.MultiIndex.from_product([date, asset], names=['date', 'asset'])).reset_index()

df_sort_by_asset = pd.DataFrame({
    'OPEN': np.random.rand(_K * _N),
    'HIGH': np.random.rand(_K * _N),
    'LOW': np.random.rand(_K * _N),
    'CLOSE': np.random.rand(_K * _N),
}, index=pd.MultiIndex.from_product([asset, date], names=['asset', 'date'])).reset_index()

df_sort_by_date.info()
# Shuffle
df = df_sort_by_date.sample(_K * _N * 2, replace=True, ignore_index=True)

df_sort_by_date = pl.from_pandas(df_sort_by_date)
df_sort_by_asset = pl.from_pandas(df_sort_by_asset)
df = pl.from_pandas(df)

print(df.tail(10))


def func_0_ts__asset__date_1(df: pl.DataFrame) -> pl.DataFrame:
    # ========================================
    df = df.with_columns(
        # x_0 = ts_mean(OPEN, 10)
        x_0=(pl.col("OPEN").rolling_mean(10)),
        # expr_6 = ts_delta(OPEN, 10)
        expr_6=(pl.col("OPEN").diff(10)),
        # x_1 = ts_mean(CLOSE, 10)
        x_1=(pl.col("CLOSE").rolling_mean(10)),
    )
    # print(df['date'])
    return df


def func_0_ts__asset__date_2(df: pl.DataFrame) -> pl.DataFrame:
    df = df.sort(by=["date"])
    df = df.with_columns(
        # x_0 = ts_mean(OPEN, 10)
        x_0=(pl.col("OPEN").rolling_mean(10)),
        # expr_6 = ts_delta(OPEN, 10)
        expr_6=(pl.col("OPEN").diff(10)),
        # x_1 = ts_mean(CLOSE, 10)
        x_1=(pl.col("CLOSE").rolling_mean(10)),
    )
    # print(df['date'])
    return df


print('=' * 60)
# 10 years, already sorted by date then asset: the three variants show little difference
# 6.80373189994134
# 9.654270599945448
# 8.68796220002696

print(timeit.timeit('df_sort_by_date.sort(by=["date"]).groupby(by=["asset"], maintain_order=True).apply(func_0_ts__asset__date_1)', number=5, globals=locals()))
print(timeit.timeit('df_sort_by_date.sort(by=["asset", "date"]).groupby(by=["asset"], maintain_order=True).apply(func_0_ts__asset__date_1)', number=5, globals=locals()))
print(timeit.timeit('df_sort_by_date.groupby(by=["asset"], maintain_order=False).apply(func_0_ts__asset__date_2)', number=5, globals=locals()))

print('=' * 60)
# 10 years, already sorted by asset then date: the three variants show little difference
# 8.119568099966273
# 7.845328400027938
# 7.50117709999904
print(timeit.timeit('df_sort_by_asset.sort(by=["date"]).groupby(by=["asset"], maintain_order=True).apply(func_0_ts__asset__date_1)', number=5, globals=locals()))
print(timeit.timeit('df_sort_by_asset.sort(by=["asset", "date"]).groupby(by=["asset"], maintain_order=True).apply(func_0_ts__asset__date_1)', number=5, globals=locals()))
print(timeit.timeit('df_sort_by_asset.groupby(by=["asset"], maintain_order=False).apply(func_0_ts__asset__date_2)', number=5, globals=locals()))

print('=' * 60)
# 2 years, shuffled: the results differ a lot
# 16.66170910000801
# 23.977682299911976
# 16.773866499890573
print(timeit.timeit('df.sort(by=["date"]).groupby(by=["asset"], maintain_order=True).apply(func_0_ts__asset__date_1)', number=5, globals=locals()))
print(timeit.timeit('df.sort(by=["asset", "date"]).groupby(by=["asset"], maintain_order=True).apply(func_0_ts__asset__date_1)', number=5, globals=locals()))
print(timeit.timeit('df.groupby(by=["asset"], maintain_order=False).apply(func_0_ts__asset__date_2)', number=5, globals=locals()))
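# Takeaway: mirroring the pandas benchmark, the variants are close on pre-sorted
# input; on shuffled input the full ["asset", "date"] sort was the slowest of the
# three here.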
--------------------------------------------------------------------------------
/tests/speed_polars2.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import polars as pl

_N = 250 * 10
_K = 5000

asset = [f'x_{i}' for i in range(_K)]
date = pd.date_range('2000-01-1', periods=_N)

df_sort_by_asset = pd.DataFrame({
    'OPEN': np.random.rand(_K * _N),
    'HIGH': np.random.rand(_K * _N),
    'LOW': np.random.rand(_K * _N),
    'CLOSE': np.random.rand(_K * _N),
}, index=pd.MultiIndex.from_product([asset, date], names=['asset', 'date'])).reset_index()

df = pl.from_pandas(df_sort_by_asset)


def rank_pct(expr: pl.Expr) -> pl.Expr:
    """Polars equivalent of pandas rank(pct=True): rank divided by the non-null count"""
    return expr.rank() / (expr.len() - expr.null_count())


def func_0_ts__asset__date(df: pl.DataFrame) -> pl.DataFrame:
    assert df['date'].to_pandas().is_monotonic_increasing
    # ========================================
    df = df.with_columns(
        # x_0 = ts_mean(OPEN, 10)
        x_0=(pl.col("OPEN").rolling_mean(10)),
        # expr_6 = ts_delta(OPEN, 10)
        expr_6=(pl.col("OPEN").diff(10)),
        # x_1 = ts_mean(CLOSE, 10)
        x_1=(pl.col("CLOSE").rolling_mean(10)),
    )
    return df


def func_0_cs__date(df: pl.DataFrame) -> pl.DataFrame:
    # ========================================
    df = df.with_columns(
        # x_6 = cs_rank(OPEN)
        x_6=(rank_pct(pl.col("OPEN"))),
    )
    # ========================================
    df = df.with_columns(
        # x_2 = cs_rank(x_0)
        x_2=(rank_pct(pl.col("x_0"))),
        # x_3 = cs_rank(x_1)
        x_3=(rank_pct(pl.col("x_1"))),
    )
    return df


def run1(df):
    df = df.groupby(by=["asset"], maintain_order=True).apply(func_0_ts__asset__date)
    # The output also keeps rows of the same stock together, so the global check cannot pass
    # assert df['date'].to_pandas().is_monotonic_increasing
    df = df.groupby(by=["date"], maintain_order=False).apply(func_0_cs__date)
    # The output is even more scrambled
    # assert df['date'].to_pandas().is_monotonic_increasing
    df = df.groupby(by=["asset"], maintain_order=True).apply(func_0_ts__asset__date)
    # assert df['date'].to_pandas().is_monotonic_increasing


# Test when to sort: the output order changes very quickly
run1(df)
--------------------------------------------------------------------------------