├── .gitignore ├── LICENSE ├── README.md ├── best_practices ├── README.md ├── step1_data.py ├── step2_feature.py └── utils.py ├── examples ├── alpha101.py ├── alpha191.py ├── dynamic_parameter_long.py ├── dynamic_parameter_wide.py ├── formula_dataframe.py ├── formula_series.py ├── formula_tdx.py ├── group_func.py ├── group_nav.py ├── talib_2d.py └── 市场宽度.py ├── requirements.txt ├── setup.py ├── ta_cn ├── __init__.py ├── _version.py ├── aggregate.py ├── alphas │ ├── __init__.py │ ├── alpha.py │ ├── alpha101.py │ ├── alpha191.py │ ├── test101.py │ └── test191.py ├── bn_wraps.py ├── candle.py ├── chip.py ├── ema.py ├── ewm_nb.py ├── imports │ ├── __init__.py │ ├── gtja_long.py │ ├── gtja_wide.py │ ├── long.py │ ├── long_ta.py │ ├── long_tdx.py │ ├── long_wq.py │ ├── wide.py │ ├── wide_ta.py │ ├── wide_tdx.py │ ├── wide_wq.py │ ├── wq_long.py │ └── wq_wide.py ├── nb.py ├── noise.py ├── performance.py ├── preprocess.py ├── regress.py ├── research_report │ ├── __init__.py │ ├── gdzq.py │ └── gfzq.py ├── slow │ ├── __init__.py │ └── slow.py ├── split.py ├── talib │ └── __init__.py ├── tdx │ ├── __init__.py │ ├── logical.py │ ├── over_bought_over_sold.py │ ├── pressure_support.py │ ├── reference.py │ ├── statistics.py │ ├── trend.py │ └── volume.py ├── utils.py ├── utils_long.py ├── utils_wide.py └── wq │ ├── __init__.py │ ├── arithmetic.py │ ├── cross_sectional.py │ ├── group.py │ ├── logical.py │ ├── special.py │ ├── time_series.py │ ├── transformational.py │ └── vector.py ├── tests ├── MyTT.py ├── atr_.py ├── atr_cn.py ├── avedev_.py ├── boll_.py ├── cci_.py ├── chip_.py ├── covar_.py ├── cross_.py ├── dmi_.py ├── ema_.py ├── forcast_.py ├── macd_.py ├── mfi_.py ├── obv_.py ├── ols_.py ├── ray_test.py ├── reg_.py ├── rsi_.py ├── slope_.py ├── speed_test.py ├── stddev_.py ├── tr_.py ├── trix_.py ├── var_.py ├── wls_.py └── wma_.py ├── 加速.md ├── 参数.md ├── 复权.md └── 指标对比.xlsx /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 wukan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ta_cn 中国版技术指标 2 | 3 | ## !!! 注意:如要在`polars`中使用技术指标,请使用[polars_ta](https://github.com/wukan1986/polars_ta) 4 | 5 | ## 项目背景 6 | 7 | 初学量化,技术指标一般使用`TA-Lib`,但存在以下问题 8 | 9 | 1. 部分技术指标与国内不同。但大部分中国股民都是参考国内技术指标进行交易。需要实现中国版指标 10 | 2. `TA-Lib`只支持单支序列,要计算多支股票需循环,耗时久。 11 | 12 | 在实现遗传算法挖因子时,意识到如果能将常用的操作都转成算子,将大大简化策略的研究, 13 | 特别是将`+`、`-`、`*`、`/`等中缀操作符转成`ADD()`、`SUB()`、`MUL()`、`DIV()`前缀函数,可直接输到遗传算法工具中 14 | 15 | 所以开始参考`Alpha101`和各券商金融工程研报,试着实现一些算子,但后期实现中发现一些问题 16 | 17 | 1. 每家金工的研报指标命名上都有区别,难以统一 18 | 2. 指标太多,实现工作太大 19 | 20 | 直到看到了`MyTT`这个项目才意识到,指标命名参考股票软件的公式才是最方便直接的,可以直接到各股软中复制公式。遇到性能问题再针对性转换即可。 21 | 22 | ## 本人为何不直接用`MyTT`,而是重复造轮子呢? 23 | 24 | 1. 大部分公式只支持单条数据,遇到几千支股票的DataFrame,循环太慢 25 | 2. `TA-Lib`与国内指标不同,区别在哪,没有对比。错过了很好的教学机会 26 | 3. 为了行数短牺牲了可读性 27 | 4. 部分函数直接复制于股票软件,代码没有优化,有重复计算 28 | 29 | ## 再次大迭代,仿WorldQuant 30 | 31 | 1. 2022年9月初,知道WorldQuant Websim重新开放为WorldQuant BRAIN后,开始研究国外的平台 32 | 2. WQ公式更加科学。例如: 33 | 1. WQ时序有`ts_`前缀 34 | 2. WQ有横截面函数和分组函数,通达信没有 35 | 3. 通达信将大量不相关的指标都归类为引用函数 36 | 4. WQ公式为Alpha因子而设计,有大量的权重处理等函数 37 | 38 | ## 目标 39 | 40 | 1. 优先实现WorldQuant公式,然后实现通达信公式 41 | 2. 通过在通达信中导入WQ公式并别名,来实现通达信公式覆盖 42 | 3. 支持二维矩阵计算 43 | 4. 支持长表和宽表,支持NaN跳过 44 | 5. 
最终实现WQ的本地版 45 | 46 | ## 实现方案优先级 47 | 48 | 1. bottleneck。支持二维数据,优先使用 49 | 2. TA-Lib。封装了常用函数,次要选择 50 | 3. numba。速度受影响,最后才用它 51 | 52 | ## 安装 53 | 54 | 1. 只想使用二维矩阵TA-Lib,只需安装基础版即可 55 | 56 | ```commandline 57 | pip install ta_cn -i https://mirrors.aliyun.com/pypi/simple --upgrade 58 | ``` 59 | 60 | 2. 使用中国版指标加速 61 | 62 | ```commandline 63 | pip install ta_cn[cn] -i https://mirrors.aliyun.com/pypi/simple --upgrade 64 | ``` 65 | 66 | 3. 开发人员安装。开发迭代很快,只有版本稳定才会发布到`PyPI`,需要时效更高的安装方法 67 | 1. 从github下载zip文件 68 | 2. 解压zip, 进入解压后目录,输入以下命令 69 | 70 | ```commandline 71 | pip install .[cn] -i https://mirrors.aliyun.com/pypi/simple --upgrade 72 | ``` 73 | 74 | 4. 库维护者安装。可修改本地文件 75 | 76 | ```commandline 77 | pip install -e . 78 | ``` 79 | 80 | ## 常见使用方法 81 | 82 | 1. 转发原生talib,输入一维向量 83 | - 优点: 本库提供了跳过空值的功能 84 | - 缺点: 不要在大量循环中调用,因为跳过空值的功能每调用一次就要预分配内存 85 | 2. 封装原生talib,输入二维矩阵,同时支持参数一维向量化 86 | - 优点:可为不同股票指定不同参数,可用于按天遍历计算指标。只分配一次内存 87 | 3. 直接调用包中定义的指标,如KDJ等 88 | - 优点:符合中国习惯的技术指标 89 | - 缺点:指标数目前比较少。一般没有跳过空值功能 90 | 4. 输入为长表,分组计算 91 | - 优点:使用简单,可进行指标嵌套 92 | - 缺点:速度会慢一些。准备工作偏复杂 93 | 5. 输入为宽表 94 | - 优点:计算快 95 | - 缺点:计算前需要准备数据为指定格式,占大量内存 96 | 97 | ## 停牌处理,跳过空值 98 | 99 | 1. TA-Lib遇到空值后面结果全为NaN 100 | 2. 跳过NaN后,导致数据长度不够,部分函数可能报错 101 | 3. 方案一:将所有数据进行移动,时序指标移动到最后,横截面指标移动到最右。 102 | - 优点:原指标不需要改动,只要提前处理数据。处理速度也快 103 | - 缺点:时序指标与横截面指标不能混合使用,得分别处理 104 | 4. 方案二:预先初始化空白区,计算指标时屏蔽NaN,算完后回填 105 | - 优点:外部调用简单,不需要对数据提前处理 106 | - 缺点:由于有大量的是否跳过NaN的处理,所以速度慢 107 | 108 | ### 常见示例 109 | 110 | ```python 111 | import numpy as np 112 | 113 | # 新版talib,只要替换引用,并添加一句init即可 114 | import ta_cn.talib as ta 115 | from ta_cn.utils_wide import pushna, pullna 116 | 117 | # 原版talib,不支持二维数据 118 | # import talib as ta 119 | 120 | # 准备数据 121 | h = np.random.rand(1000000).reshape(-1, 5000) + 10 122 | l = np.random.rand(1000000).reshape(-1, 5000) 123 | c = np.random.rand(1000000).reshape(-1, 5000) 124 | # 指定模式,否则对talib封装的所有函数都不存在 125 | ta.init(mode=2, skipna=False, to_globals=True) 126 | 127 | # 几个调用函数演示 128 | r = ta.ATR(h, l, c, timeperiod=10) 129 | print(r) 130 | x, y, z = ta.BBANDS(c, timeperiod=10, nbdevup=2, nbdevdn=2) 131 | print(z) 132 | 133 | # 将少量值设置为空,用来模拟停牌 134 | c[c < 0.4] = np.nan 135 | 136 | # 提前处理数据,跳过停牌进行计算,再还原的演示 137 | # 嵌套指标时,全为时序指标使用down,或全为截面使用right。混合时此方法不要轻易使用 138 | arr, row, col = pushna(c, direction='down') 139 | rr = ta.SMA(arr, timeperiod=10) 140 | r = pullna(rr, row, col) 141 | print(r) 142 | 143 | ``` 144 | 145 | ### 使用ta_cn中定义的公式 146 | 147 | ```python 148 | import numpy as np 149 | 150 | from ta_cn.talib import init, set_compatibility_enable, set_compatibility 151 | from ta_cn.tdx.over_bought_over_sold import ATR_CN 152 | from ta_cn.tdx.trend import MACD 153 | 154 | # ta_cn.talib库底层是循环调用talib,部分计算效率不高 155 | # 可导入ta_cn中的公式 156 | 157 | # 准备数据 158 | h = np.random.rand(10000000).reshape(-1, 50000) + 10 159 | l = np.random.rand(10000000).reshape(-1, 50000) 160 | c = np.random.rand(10000000).reshape(-1, 50000) 161 | 162 | init(mode=2, skipna=False) 163 | 164 | r = ATR_CN(h, l, c, timeperiod=10) 165 | print(r) 166 | 167 | # 设置参数,让MACD中的EMA算法与国内算法相同 168 | set_compatibility_enable(True) 169 | set_compatibility(1) 170 | set_compatibility_enable(False) 171 | 172 | x, y, z = MACD(c, fastperiod=12, slowperiod=26, signalperiod=9) 173 | print(z) 174 | ``` 175 | 176 | ## 长宽表处理 177 | 178 | 二维矩阵计算,的确方便,`Alpha101`中的公式很快就可以实现,即支持时序又支持截面,但其中有一个难点, 179 | 就是NaN值的处理。`pushna`和`pullna`可用于解决此问题,但在公式中嵌入就比较棘手。 180 | 181 | 所以,本项目还特别提供了长表与宽表的装饰器,按照一定的要求套用装饰器,就能让原本不可跳过空值的函数自动跳过空值。 182 | 如果明确数据内不会产生空值,可以不使用长宽表装饰器,效率会更快。 183 | 184 | ### 长表 185 | 186 | 
处理慢一些,但结果更适合于机器学习。 187 | 底层主要通过`series_groupby_apply`(针对单列输入)和`dataframe_groupby_apply`(针对多列输入)装饰器来实现跳过空值。 188 | 189 | ```python 190 | import pandas as pd 191 | 192 | from ta_cn.imports.long_ta import ATR, SMA 193 | from ta_cn.imports.long_wq import group_neutralize, rank 194 | 195 | pd._testing._N = 500 196 | pd._testing._K = 30 197 | 198 | open_ = pd._testing.makeTimeDataFrame() + 5 199 | high = pd._testing.makeTimeDataFrame() + 10 200 | low = pd._testing.makeTimeDataFrame() + 5 201 | close = pd._testing.makeTimeDataFrame() + 5 202 | group = close.copy() * 100 // 1 % 5 203 | 204 | df = { 205 | 'open_': open_.stack(), 206 | 'high': high.stack(), 207 | 'low': low.stack(), 208 | 'close': close.stack(), 209 | 'group': group.stack(), 210 | } 211 | df = pd.DataFrame(df) 212 | df.index.names = ['date', 'asset'] 213 | kwargs = df.to_dict(orient='series') 214 | 215 | # 单输入 216 | r = SMA(df['close'], timeperiod=10) 217 | print(r.unstack()) 218 | # 多输入 219 | r = ATR(df['high'], df['low'], df['close'], 10) 220 | print(r.unstack()) 221 | # 横截面 222 | r = rank(df['close']) 223 | print(r.unstack()) 224 | r = group_neutralize(df['close'], df['group']) 225 | 226 | print(r.unstack()) 227 | 228 | 229 | ``` 230 | 231 | ### 宽表 232 | 233 | 处理速度通常比长表要快。核心是输入需要封装成`WArr`,输出要`.raw()`提取。 234 | 底层通过`wide_wraps`装饰器来实现空值跳过,通过`long_wraps`装饰器来实现长表函数转宽表函数 235 | 236 | ```python 237 | import pandas as pd 238 | 239 | from ta_cn.imports.wide_ta import ATR 240 | from ta_cn.utils import np_to_pd 241 | from ta_cn.utils_wide import WArr 242 | 243 | pd._testing._N = 250 244 | pd._testing._K = 30 245 | h = pd._testing.makeTimeDataFrame() + 10 246 | l = pd._testing.makeTimeDataFrame() 247 | c = pd._testing.makeTimeDataFrame() 248 | 249 | # 数据需要封装成特殊对象,实现NaN的堆叠和还原 250 | h_ = WArr.from_array(h, direction='down') 251 | l_ = WArr.from_array(l, direction='down') 252 | c_ = WArr.from_array(c, direction='down') 253 | 254 | r = ATR(h_, l_, c_, 10) 255 | # 返回的数据可能是np.ndarray 256 | print(r.raw()) 257 | 258 | # 可以再封装回pd.DataFrame 259 | d = np_to_pd(r.raw(), copy=False, index=c.index, columns=c.columns) 260 | print(d.iloc[-5:]) 261 | 262 | 263 | ``` 264 | 265 | ## 指标对比清单 266 | 267 | 参考 [指标对比](指标对比.xlsx) 未完工,待补充 268 | 269 | ## Alpha101/Alpha191 270 | 271 | 本项目,试着用公式系统实现`Alpha101`、`Alpha191`,请参考examples文件下的测试示例。它最大的特点是尽量保持原公式的形式, 272 | 少改动,防止乱中出错。然后再优化代码提高效率。 273 | 274 | ## 停牌处理,空值填充 275 | 276 | 1. 板块指数,停牌了也要最近的行情进行计算,否则指数过小 277 | 2. 停牌期的开高低收都是最近的收盘价,收盘价可以ffill 278 | 279 | ## 参考项目 280 | 281 | 1. [TA-Lib](https://github.com/TA-Lib/ta-lib) TA-Lib C语言版,非官方镜像 282 | 2. [ta-lib](https://github.com/mrjbq7/ta-lib) TA-Lib Python版封装 283 | 3. [MyTT](https://github.com/mpquant/MyTT) My麦语言 T通达信 T同花顺 284 | 4. [funcat](https://github.com/cedricporter/funcat) 公式移植 285 | 5. [pandas-ta](https://github.com/twopirllc/pandas-ta) 支持Pandas扩展的技术指标 286 | 6. [ta](https://github.com/bukosabino/ta) 通过类实现的技术指标 287 | 7. [WorldQuant算子](https://platform.worldquantbrain.com/learn/data-and-operators/operators) 288 | 8. [WorldQuant算子详情](https://platform.worldquantbrain.com/learn/data-and-operators/detailed-operator-descriptions) 289 | 290 | ## 交流群 291 | 292 | ta_cn技术指标交流群: 601477228 -------------------------------------------------------------------------------- /best_practices/README.md: -------------------------------------------------------------------------------- 1 | !!! 注意:这只是一个试验性的演示,并不一定达到了最佳 2 | 3 | # 说明 4 | 本演示经历了两个阶段 5 | 1. 多文件多进程并行计算。没有找到合适的多进程计算库,所以前期用的此方案。缺点是计算前需要合并数据,占用大量资源 6 | 2. 
使用支持多线程的库
7 | 
8 | 不同于普通的机器学习任务可以直接用sklearn等库处理数据,金融机器学习任务有特殊性,需要分日期(横截面)、分股票(时序)进行处理,
9 | 所以对groupby和apply有特殊要求:
10 | 1. apply最好支持调用numpy、talib、numba等函数
11 | 2. groupby最好支持多线程,否则只能通过切分文件后多进程处理
12 | 
13 | # 可能的加速方案
14 | Python受GIL限制,无法直接利用多核加速。可能的加速方案有:
15 | 1. 调用C++、Rust等语言实现的库,由底层实现多线程计算
16 |     - 一般无法直接复用pandas、talib等库的函数,学习难度大,开发工作量更大
17 | 2. 多进程处理。实际上数据会被序列化到硬盘,再反序列化到每个进程,数据大时极为耗时
18 |     - 如果能砍掉序列化这个步骤,每个进程只反序列化自己的那部分,速度能快不少
19 |     - 但只反序列化一部分需要先建立索引,否则无法定位到指定位置
20 |     - 在单个文件内部建索引相当于设计一种全新的文件格式,过于复杂,不考虑
21 |     - 最终,拆成多个文件实现起来更简单
22 | 3. 如果Python 3.12之后GIL真能移除,实现真正的多线程,那么这种分文件的方式也就没有必要了,
23 |    只要提前分配好对应内存,后面分区计算即可
24 | 
25 | ## 多文件方案
26 | 1. 数据使用长表保存。方便从末尾追加数据,节约存储空间,容易添加新股票,方便机器学习,但计算指标速度慢
27 | 2. 同一属性的数据有两个维度:时间(行)、合约(列)
28 | 3. 二维数据可以按九宫格方式分割成多块,每格文件标记成`行__列`。目前行按年划分,列按股票代码最后一位数字分成10组。如果数据量大,可以分得更细
29 | 4. 做时序计算时,将同列的数据纵向合并成长表,如`*__列`为长时序数据,然后按`股票`分组计算`时序指标`
30 | 5. 做截面计算时,将同行的数据横向合并成长表,如`行__*`为横截面数据,然后按`时间`分组计算`截面指标`
31 | 6. 时序计算根据指标特性,有可能只用一段数据即可算得正确或近似;而横截面计算可能整行都得参与才正确
32 | 7. 每个块由输入、计算、输出三个环节组成。需根据计算要求,设计输入是整行(时序)加载还是整列(截面)加载
33 | 8. 默认每个块的输出数据与输入同形式,同为整行或整列。如果下一块形式不同,则当前块要输出成九宫格式来适应下一块
34 |     - 原数据按`田`划分,第一块计算时序指标,所以指定按`川`读入,2个进程
35 |     - 第二块计算截面指标,若直接把`川`整行读入,会导致单个进程加载全部数据
36 |     - 所以第一块的输出必须还原成`田`,第二块就能按`三`读入,2个进程
37 | 9. 一般不再需要跳过空值的调整,也不需要二维函数了。因为长表在stack时默认就丢弃了空值,直接计算即可
38 | 
39 | ## 进一步优化
40 | 1. 输入时的多文件加载合并、输出时的分组写入,都是瓶颈
41 | 2. 硬盘换成固态硬盘;或加大内存,把数据写入共享内存
42 | 3. 进程数多不一定能提高速度,因为可能存在IO瓶颈,硬盘忙不过来
43 | 4. 进程过多还有可能导致内存占用过高,进而因内存不足而崩溃
44 | 5. 每次时序或截面计算时都要考虑减少内存占用
45 | 
46 | ## 结论
47 | 1. 三年4000多支股票的日线数据,计算十几个指标,numba启用缓存
48 | 2. 3进程初始准备,5进程时序指标,3进程截面指标:用时约80秒
49 | 3. 1进程初始准备,1进程时序指标,1进程截面指标:用时约250秒
50 | 
51 | ## 宽表版
52 | 1. 此处长表版与ta_cn长表版的区别在于groupby的时机
53 |     - ta_cn版每个公式函数都要groupby,几十个指标就要groupby几十次
54 |     - 而这里不管有多少公式,只groupby一次
55 | 2. 缺点是不如ta_cn长表版易用,比如想先做横截面排序、再算时序相关系数,工序比较多
56 | 3. 宽表版需要将数据全部整理成宽表,且每个宽表需要一层文件,所以需要另行设计文件路径
57 |     - `行__列`已经被占用,可以以字段名为文件夹,内再放parquet文件
58 |     - 考虑以后再测试此功能
59 | 
60 | # 多线程方案
61 | 考察了polars和vaex后,最终选择了polars
62 | 1. polars: 语法需要重新学习
63 | 2. vaex: 函数与pandas比较接近,但apply时可能走多进程并出现数据复制
64 | 
65 | ## 进一步优化
66 | 1. 结果相同的情况下,优先使用polars提供的函数
67 | 2. 
加内存条,或者使用mmap模式等 -------------------------------------------------------------------------------- /best_practices/step1_data.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from itertools import product 4 | 5 | import pandas as pd 6 | import parmap 7 | from loguru import logger 8 | 9 | from best_practices.utils import func_load_calc_save, dataframe_save, timer, func_load_parquet 10 | 11 | # 行范围,字符串,只用在文件名上 12 | # 分钟数据太大,可改为按月划分 13 | RANGE_I = pd.date_range('2020-01-01', '2022-12-31', freq='Y') 14 | RANGE_I = [f'{i:%Y}' for i in RANGE_I] 15 | 16 | # 列范围,字符串,只用在文件名上 17 | # 横向10组在数据量大时也比较慢,可改为20组,50组等 18 | RANGE_J = [f'{i:02d}' for i in range(10)] 19 | 20 | # 数据源路径 21 | PATH_STEP0_INPUT1 = r'D:\data\jqresearch\get_price_stock_daily' 22 | PATH_STEP0_INPUT2 = r'D:\data\jqresearch\get_price_stock_factor' 23 | PATH_STEP0_INPUT3 = r'D:\data\jqresearch\get_extras_stock_is_st' 24 | # 中间过程路径。可以考虑这部分保存到共享内存或内存盘中,加快速度 25 | PATH_STEP0_OUTPUT = r'M:\data3\step0' 26 | PATH_STEP1_OUTPUT = r'M:\data3\step1' 27 | 28 | 29 | def func_prepare(df, i, j): 30 | """准备工作。从外部导入数据后,对数据做一些准备工作 31 | 32 | Parameters 33 | ---------- 34 | df 35 | i 36 | j 37 | 38 | Returns 39 | ------- 40 | 41 | """ 42 | # 统一表头,防止不同数据源要分别写代码 43 | df.rename(columns={'code': 'asset', 'time': 'date', 'money': 'amount'}, inplace=True) 44 | # 不做科创板,不做创业板 45 | df = df[~df['asset'].str.startswith(('688', '300'))].copy() 46 | # 先转类型 47 | df['asset'] = df['asset'].astype('category') 48 | # 只转换小样本,速度更快 49 | to_replace = {r'(\d+)\.XSHG': r'\1.SH', r'(\d+)\.XSHE': r'\1.SZ'} 50 | cat = df['asset'].cat.categories 51 | mapping = pd.Series(cat, index=cat).replace(to_replace, regex=True) 52 | df['asset'] = df['asset'].cat.rename_categories(mapping.to_dict()) 53 | 54 | # 过滤停牌。计算技术指标和和横截面时会剔除停牌,但计算板块和指数时,停牌也参与计算 55 | df = df[df['paused'] == 0].copy() 56 | 57 | # # 用数字节约内存 58 | # df[AXIS_I] = int(i) 59 | # # [5: 6]一位用来当关键字,数据量大就用两位[4:6] 60 | # df[AXIS_J] = df['asset'].str[5:6].astype(int) 61 | # # set_index处理后,index中的类型变成了uint64,很无语 62 | # # df[AXIS_I] = df[AXIS_I].astype(np.uint16) # 请按数据情况填写,因为可能是2022年,又有可能是202211月,也可能是20221121日 63 | # # df[AXIS_J] = df[AXIS_J].astype(np.uint8) # 划分数量一般小时CPU真核数量 64 | 65 | # 内存优化 66 | df['paused'] = df['paused'].astype(bool) 67 | 68 | # 整理排序,加入了两个特殊列, 这样在后面的文件划分时处理更方便 69 | # df = df.set_index(['asset', 'date', AXIS_I, AXIS_J]).sort_index() 70 | df = df.set_index(['asset', 'date']).sort_index() 71 | 72 | return df 73 | 74 | 75 | def func_is_st(df: pd.DataFrame): 76 | df = df.stack(dropna=True) 77 | df.name = 'is_st' 78 | df.index.names = ['time', 'code'] 79 | df = df.reset_index() 80 | return df 81 | 82 | 83 | @timer 84 | def step0(): 85 | """计算技术指标,并提取用来算横截面的字段""" 86 | 87 | ii = RANGE_I # 年份 88 | jj = ['*'] # 品种 89 | parmap.map(func_load_calc_save, 90 | product(ii, jj), 91 | load_func=func_load_parquet, load_kwargs= 92 | { 93 | 'path': [PATH_STEP0_INPUT1, PATH_STEP0_INPUT2, PATH_STEP0_INPUT3], 94 | 'pattern': ['{0}*', '{0}*', '{0}*'], 95 | 'axis': [0, 0, 0], 96 | 'func': [None, None, func_is_st], 97 | # 源数据来自聚宽,股票代码和时间如下 98 | 'on': [['code', 'time'], ['code', 'time'], ['code', 'time']], 99 | 'index': [False, False, False], 100 | }, 101 | calc_func=func_prepare, calc_args=[], 102 | save_func=dataframe_save, save_args= 103 | [ 104 | # 这里没有按品种分别保存 105 | {'split_axis': None, 'path': PATH_STEP0_OUTPUT, 'exclude': []}, # 上流输入`三`型,输出转`田`型 106 | ], 107 | pm_processes=len(ii), 108 | pm_parallel=True) 109 | 110 | 111 | @timer 112 | def step1(): 113 | 
"""合并多个parquet文件,因为在polars中不支持分类变量合并加载,所以在这进行合并,否则是可以不需要此步""" 114 | path1 = pathlib.Path(PATH_STEP0_OUTPUT) 115 | path2 = pathlib.Path(PATH_STEP1_OUTPUT) 116 | path2.mkdir(parents=True, exist_ok=True) 117 | 118 | df = pd.read_parquet(path1) # 读数文件夹 119 | df.to_parquet(path2 / 'data.parquet', compression='zstd') 120 | 121 | 122 | if __name__ == '__main__': 123 | # 原始数据合并 124 | step0() 125 | # 再次合并数据 126 | step1() 127 | 128 | logger.info('done') 129 | -------------------------------------------------------------------------------- /best_practices/step2_feature.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import polars as pl 4 | from loguru import logger 5 | from polars import LazyFrame 6 | from talib import * 7 | 8 | from best_practices.utils import pl_np_wraps 9 | from ta_cn import EPSILON 10 | from ta_cn.tdx.logical import CROSS 11 | from ta_cn.tdx.over_bought_over_sold import KDJ 12 | from ta_cn.tdx.reference import BARSLASTCOUNT 13 | from ta_cn.tdx.statistics import limit_count 14 | from ta_cn.wq.cross_sectional import rank 15 | from ta_cn.wq.time_series import ts_returns, ts_rank 16 | 17 | 18 | def calc_col(df: LazyFrame) -> LazyFrame: 19 | """整列运算,不用groupby""" 20 | 21 | # 第一阶段,只利用原字段 22 | df = df.with_columns([ 23 | 24 | (pl.col('high') + pl.col('low')).alias('mid_price'), 25 | 26 | # 需要传入到talib,所以先不转float32 27 | (pl.col('open') * pl.col('factor')).alias('open_adj'), 28 | (pl.col('high') * pl.col('factor')).alias('high_adj'), 29 | (pl.col('low') * pl.col('factor')).alias('low_adj'), 30 | (pl.col('close') * pl.col('factor')).alias('close_adj'), 31 | 32 | (pl.col('close') >= pl.col('high_limit') - EPSILON).alias('涨停'), 33 | (pl.col('high') >= pl.col('high_limit') - EPSILON).alias('曾涨停'), 34 | (pl.col('close') <= pl.col('low_limit') + EPSILON).alias('跌停'), 35 | (pl.col('low') <= pl.col('low_limit') + EPSILON).alias('曾跌停'), 36 | 37 | (pl.col('low') >= pl.col('high_limit') - EPSILON).alias('一字板'), 38 | 39 | (pl.col('close') / pl.col('pre_close') - 1).cast(pl.Float32).alias('pct') 40 | ]) 41 | 42 | # 第二阶段,利用新字段 43 | df = df.with_columns([ 44 | 45 | (pl.col('涨停') & (pl.col('high') > pl.col('low'))).alias('T字板'), 46 | ]) 47 | 48 | return df 49 | 50 | 51 | def calc_ts(df) -> LazyFrame: 52 | """时序方向上计算,按股票分组。注意时序有先后顺序""" 53 | # 第一阶段 54 | df = df.with_columns([ 55 | # 收益率标签 56 | pl.col('close_adj').pct_change(1).shift(-1).alias('returns_1'), 57 | pl.col('close_adj').pct_change(5).shift(-5).alias('returns_5'), 58 | pl.col('close_adj').pct_change(10).shift(-10).alias('returns_10'), 59 | 60 | pl.col('close_adj').pct_change(1).alias('mom_1'), 61 | pl.col('close_adj').pct_change(5).alias('mom_5'), 62 | pl.col('close_adj').pct_change(10).alias('mom_10'), 63 | # 64 | *[pl.col('CLOSE').pct_change(i).alias(f'ROCP_{i}') for i in (1, 3, 5, 10, 20, 60)], 65 | 66 | pl.col('close_adj').map(lambda x: RSI(x, 6)).cast(pl.Float32).alias('rsi_6'), 67 | pl.col('close_adj').map(lambda x: RSI(x, 12)).cast(pl.Float32).alias('rsi_12'), 68 | pl.col('close_adj').map(lambda x: RSI(x, 24)).cast(pl.Float32).alias('rsi_24'), 69 | 70 | pl.col('close_adj').map(lambda x: STDDEV(x, 20)).cast(pl.Float32).alias('std'), 71 | pl.col('close_adj').map(lambda x: ROCP(x, 10)).cast(pl.Float32).alias('rocp'), 72 | 73 | pl.col('amount').map(lambda x: pl_np_wraps(ts_returns)(x, 1)).alias('amount_ratio'), 74 | 75 | pl.col('close_adj').map(lambda x: pl_np_wraps(ts_rank)(x, 10)).alias('收盘价10日排序'), 76 | pl.col('close_adj').map(lambda x: pl_np_wraps(ts_rank)(x, 
20)).alias('收盘价20日排序'), 77 | 78 | # 多输入,单输出 79 | pl.map(['close_adj', 'volume'], lambda x: OBV(*x)).cast(pl.Float32).alias('obv'), 80 | pl.map(['high_adj', 'low_adj', 'close_adj'], lambda x: CCI(*x, 14)).cast(pl.Float32).alias('cci'), 81 | pl.map(['high_adj', 'low_adj', 'close_adj'], lambda x: WILLR(*x, 6)).cast(pl.Float32).alias('wr_6'), 82 | pl.map(['high_adj', 'low_adj', 'close_adj'], lambda x: WILLR(*x, 10)).cast(pl.Float32).alias('wr_10'), 83 | pl.map(['high_adj', 'low_adj', 'close_adj'], lambda x: NATR(*x, 14)).cast(pl.Float32).alias('atr_14'), 84 | 85 | # 这个为何严重拖慢速度?多线程优势为何没了? 86 | pl.col('涨停').map(lambda x: pl_np_wraps(BARSLASTCOUNT)(x)).alias('连板'), 87 | ]) 88 | 89 | # 第二阶段 90 | df = df.with_columns([ 91 | pl.map(['sma_5', 'sma_10'], lambda x: pl_np_wraps(CROSS)(*x)).alias('MA金叉'), 92 | ]) 93 | 94 | # 多输出 95 | df = df.hstack( 96 | pl.DataFrame([ 97 | *MACDEXT(df["close_adj"], 98 | fastperiod=12, fastmatype=1, 99 | slowperiod=26, slowmatype=1, 100 | signalperiod=9, signalmatype=1), 101 | *BBANDS(df['close_adj'], 102 | timeperiod=20, nbdevup=2, nbdevdn=2, matype=0), 103 | *pl_np_wraps(KDJ, 3, 3)(df['high_adj'], df['low_adj'], df['close_adj'], 9, 3, 3), 104 | *pl_np_wraps(limit_count, 1, 2)(df['涨停'], 2), 105 | ], 106 | columns=[ 107 | "macd", "macdsignal", "macdhist", 108 | "upperband", "middleband", "lowerband", 109 | "kdj_k", "kdj_d", "kdj_j", 110 | "N天", "M板", 111 | ] 112 | ) 113 | ) 114 | 115 | return df 116 | 117 | 118 | def calc_cs(df: LazyFrame) -> LazyFrame: 119 | """截面处理""" 120 | 121 | df = df.with_columns([ 122 | pl.col('close').map(lambda x: pl_np_wraps(rank)(x)).alias('close_rank'), 123 | ]) 124 | 125 | # 目前没有行情业数据,只能模拟行业分组 126 | df = df.with_columns([ 127 | (pl.col('close_rank') * 10 // 2).cast(pl.Int8).alias('industry'), 128 | ]) 129 | 130 | df = df.with_columns([ 131 | pl.col('returns_1').map(lambda x: pl_np_wraps(rank)(x)).alias('label_1'), 132 | pl.col('returns_5').map(lambda x: pl_np_wraps(rank)(x)).alias('label_5'), 133 | pl.col('returns_10').map(lambda x: pl_np_wraps(rank)(x)).alias('label_10'), 134 | ]) 135 | 136 | return df 137 | 138 | 139 | def calc_cs2(df: LazyFrame): 140 | """行业中性处理演示""" 141 | df = df.with_columns([ 142 | pl.col('amount').map(lambda x: pl_np_wraps(rank)(x)).alias('amount_rank'), 143 | 144 | ]) 145 | return df 146 | 147 | 148 | if __name__ == '__main__': 149 | PATH_STEP1_OUTPUT = r'M:\data3\step1' 150 | PATH_STEP2_OUTPUT = r'M:\data3\step2' 151 | 152 | # 路径准备 153 | PATH_STEP1_OUTPUT = pathlib.Path(PATH_STEP1_OUTPUT) 154 | PATH_STEP2_OUTPUT = pathlib.Path(PATH_STEP2_OUTPUT) 155 | PATH_STEP2_OUTPUT.mkdir(parents=True, exist_ok=True) 156 | 157 | logger.info('开始 数据加载') 158 | df = pl.read_parquet(PATH_STEP1_OUTPUT / '*.parquet', use_pyarrow=False, memory_map=True) 159 | # 调整表头顺序,方便观察 160 | df = df.select([ 161 | pl.col(['asset', 'date']), 162 | pl.all().exclude(['asset', 'date']) 163 | ]) 164 | print(df.head()) 165 | logger.info('开始 列计算') 166 | df = calc_col(df) 167 | 168 | logger.info('开始 时序计算') 169 | # 计算时序指标时date一定要保证顺序 170 | df = df.sort(by=['asset', 'date']) 171 | df = df.groupby(by=['asset']).apply(calc_ts) 172 | 173 | logger.info('开始 截面计算') 174 | # 排序后的数据groupby会更快 175 | df = df.sort(by=['date', 'asset']) 176 | df = df.groupby(by=['date']).apply(calc_cs) 177 | 178 | logger.info('开始 行业处理') 179 | df = df.groupby(by=['date', 'industry']).apply(calc_cs2) 180 | 181 | logger.info('开始 保存') 182 | # gzip格式耗时长,压缩率也没有明显提高, zstd格式好像更合适 183 | df.write_parquet(PATH_STEP2_OUTPUT / 'feature.parquet', compression='zstd') 184 | 185 | logger.info('完成') 186 | 
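    # 用法示例(假设性代码,文件名与上方写入路径一致):保存后可用 polars 惰性扫描验证落盘结果,
    # 只读取需要的列,避免全量加载:
    # df_check = pl.scan_parquet(PATH_STEP2_OUTPUT / 'feature.parquet').select(['asset', 'date', 'macd']).collect()
    # print(df_check.head())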
print(df.tail()) 187 | -------------------------------------------------------------------------------- /examples/alpha101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import pandas as pd 5 | from numpy.testing import assert_allclose 6 | from pandas._testing import assert_series_equal, assert_numpy_array_equal 7 | 8 | from ta_cn.utils_wide import WArr 9 | 10 | os.environ['TA_CN_MODE'] = 'WIDE' 11 | import ta_cn.alphas.alpha101 as w 12 | 13 | # 移除,这样就可以重复导入包了 14 | sys.modules.pop('ta_cn.alphas.alpha101') 15 | 16 | os.environ['TA_CN_MODE'] = 'LONG' 17 | import ta_cn.alphas.alpha101 as l 18 | import ta_cn.alphas.test101 as t 19 | 20 | if __name__ == '__main__': 21 | pd._testing._N = 500 22 | pd._testing._K = 30 23 | 24 | open_ = pd._testing.makeTimeDataFrame() + 5 25 | high = pd._testing.makeTimeDataFrame() + 10 26 | low = pd._testing.makeTimeDataFrame() + 5 27 | close = pd._testing.makeTimeDataFrame() + 5 28 | volume = pd._testing.makeTimeDataFrame() * 10 + 100 29 | vwap = pd._testing.makeTimeDataFrame() 30 | adv20 = pd._testing.makeTimeDataFrame() 31 | returns = pd._testing.makeTimeDataFrame() 32 | cap = pd._testing.makeTimeDataFrame() * 100 + 100 33 | group = close.copy() * 100 // 1 % 5 34 | 35 | df = { 36 | 'open': open_, 37 | 'high': high, 38 | 'low': low, 39 | 'close': close, 40 | 'returns': returns, 41 | 'volume': volume, 42 | 'vwap': vwap, 43 | 'adv5': adv20, 44 | 'adv10': adv20, 45 | 'adv15': adv20, 46 | 'adv20': adv20, 47 | 'adv30': adv20, 48 | 'adv40': adv20, 49 | 'adv50': adv20, 50 | 'adv60': adv20, 51 | 'adv81': adv20, 52 | 'adv120': adv20, 53 | 'adv150': adv20, 54 | 'adv180': adv20, 55 | 'subindustry': group, 56 | 'sector': group, 57 | 'industry': group, 58 | 'cap': cap, 59 | } 60 | 61 | kwargs_w = {k: WArr.from_array(v, direction='down') for k, v in df.items()} 62 | 63 | kwargs_l = {k: v.stack() for k, v in df.items()} 64 | kwargs_l = pd.DataFrame(kwargs_l) 65 | kwargs_l.index.names = ['date', 'asset'] 66 | kwargs_l = kwargs_l.to_dict(orient='series') 67 | 68 | for i in range(1, 101 + 1): 69 | # if i in (-100,): 70 | # continue 71 | name = f'alpha_{i:03d}' 72 | ft = getattr(t, name, None) 73 | fl = getattr(l, name, None) 74 | fw = getattr(w, name, None) 75 | 76 | print(name) 77 | rt = ft(**kwargs_l) 78 | rl = fl(**kwargs_l) 79 | rw = fw(**kwargs_w) 80 | # 比较 原版公式 与 优化后公式的 结果是否相同 81 | assert_series_equal(rt, rl) 82 | # 比较 长表 与 宽表 结果是否相同 83 | if i == 100: 84 | # alpha 100 scale后有少量误差,只能用allclose 85 | # scale(indneutralize(t1, group=subindustry), 1.) 
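            # assert_allclose 默认 rtol=1e-07、atol=0,足以容纳 scale 归一化带来的浮点误差;
            # 其余因子仍用逐元素严格相等的 assert_numpy_array_equal 验证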
86 | assert_allclose(rw.raw(), rl.unstack().values) 87 | else: 88 | assert_numpy_array_equal(rw.raw(), rl.unstack().values) 89 | -------------------------------------------------------------------------------- /examples/alpha191.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import pandas as pd 5 | from pandas._testing import assert_series_equal, assert_numpy_array_equal 6 | 7 | from ta_cn.utils_wide import WArr 8 | 9 | os.environ['TA_CN_MODE'] = 'WIDE' 10 | import ta_cn.alphas.alpha191 as w 11 | 12 | # 移除,这样就可以重复导入包了 13 | sys.modules.pop('ta_cn.alphas.alpha191') 14 | 15 | os.environ['TA_CN_MODE'] = 'LONG' 16 | import ta_cn.alphas.alpha191 as l 17 | import ta_cn.alphas.test191 as t 18 | 19 | if __name__ == '__main__': 20 | pd._testing._N = 500 21 | pd._testing._K = 30 22 | 23 | open_ = pd._testing.makeTimeDataFrame() + 5 24 | high = pd._testing.makeTimeDataFrame() + 10 25 | low = pd._testing.makeTimeDataFrame() + 5 26 | close = pd._testing.makeTimeDataFrame() + 5 27 | volume = pd._testing.makeTimeDataFrame() * 10 + 100 28 | vwap = pd._testing.makeTimeDataFrame() 29 | adv20 = pd._testing.makeTimeDataFrame() 30 | returns = pd._testing.makeTimeDataFrame() 31 | cap = pd._testing.makeTimeDataFrame() * 100 + 100 32 | group = close.copy() * 100 // 1 % 5 33 | 34 | df = { 35 | 'OPEN': open_, 36 | 'HIGH': high, 37 | 'LOW': low, 38 | 'CLOSE': close, 39 | 'RET': returns, 40 | 'VOLUME': volume, 41 | 'AMOUNT': volume * 100, 42 | 'VWAP': vwap, 43 | 'DTM': high, 44 | 'DBM': low, 45 | 'MKT': high, 46 | 'SMB': low, 47 | 'HML': close, 48 | 'BANCHMARKINDEXOPEN': high, 49 | 'BANCHMARKINDEXCLOSE': low, 50 | } 51 | 52 | kwargs_w = {k: WArr.from_array(v, direction='down') for k, v in df.items()} 53 | 54 | kwargs_l = {k: v.stack() for k, v in df.items()} 55 | kwargs_l = pd.DataFrame(kwargs_l) 56 | kwargs_l.index.names = ['date', 'asset'] 57 | kwargs_l = kwargs_l.to_dict(orient='series') 58 | 59 | for i in range(1, 191 + 1): 60 | # 165 183 是MAX 与 SUMAC 问题 61 | if i in (165, 183, 30): 62 | continue 63 | name = f'alpha_{i:03d}' 64 | ft = getattr(t, name, None) 65 | fl = getattr(l, name, None) 66 | fw = getattr(w, name, None) 67 | 68 | print(name) 69 | rt = ft(**kwargs_l) 70 | rl = fl(**kwargs_l) 71 | rw = fw(**kwargs_w) 72 | # 比较 原版公式 与 优化后公式的 结果是否相同 73 | assert_series_equal(rt, rl) 74 | # 比较 长表 与 宽表 结果是否相同 75 | assert_numpy_array_equal(rw.raw(), rl.unstack().values) 76 | -------------------------------------------------------------------------------- /examples/dynamic_parameter_long.py: -------------------------------------------------------------------------------- 1 | """ 2 | 以下是在长表上的动态复权和动态参数的示例 3 | """ 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import talib as ta 8 | 9 | # 准备数据 10 | 11 | o = np.random.rand(1000000).reshape(-1, 5000) 12 | h = np.random.rand(1000000).reshape(-1, 5000) + 10 13 | l = np.random.rand(1000000).reshape(-1, 5000) 14 | c = np.random.rand(1000000).reshape(-1, 5000) 15 | 16 | # 周期参数 17 | p1 = np.empty_like(c) 18 | p1[:] = 20 19 | p1[:150] = 5 20 | p1[:120] = 10 21 | p1[:50] = 30 22 | 23 | p2 = np.empty_like(c) 24 | p2[:] = 10 25 | p2[:180] = 15 26 | p2[:110] = 20 27 | p2[:70] = 10 28 | 29 | # 后复权因子 30 | f = np.empty_like(c) 31 | f[:] = 1.5 32 | f[:180] = 1.3 33 | f[:110] = 1.2 34 | f[:60] = 1.1 35 | 36 | # 输出 37 | up = np.empty_like(c) 38 | down = np.empty_like(c) 39 | 40 | # 周期等参数需要两行之间比较不同,然后不同周期再合并考虑 41 | 42 | df = { 43 | 'open': pd.DataFrame(o).stack(), 44 | 'high': pd.DataFrame(h).stack(), 45 | 
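    # 说明:stack() 把宽表(行=日期,列=资产)堆叠成 MultiIndex 长表,
    # 后面才能按 asset 分组做时序计算、按 date 切片做动态参数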
'low': pd.DataFrame(l).stack(), 46 | 'close': pd.DataFrame(c).stack(), 47 | 'p1': pd.DataFrame(p1).stack(), 48 | 'p2': pd.DataFrame(p2).stack(), 49 | 'factor': pd.DataFrame(f).stack(), 50 | 'up': pd.DataFrame(up).stack(), 51 | 'down': pd.DataFrame(down).stack(), 52 | } 53 | df = pd.DataFrame(df) 54 | df.index.names = ['date', 'asset'] 55 | 56 | date = df.index.levels[0] 57 | 58 | 59 | def func(df): 60 | """分块计算指标""" 61 | 62 | def f0(x): 63 | return x['factor'] / x['factor'].iloc[-1] 64 | 65 | def f1(x): 66 | return ta.SMA(x['close'] * x['f'], timeperiod=x['p1'].iloc[-1]) 67 | 68 | def f2(x): 69 | return ta.ATR(x['high'] * x['f'], x['low'] * x['f'], x['close'] * x['f'], timeperiod=x['p2'].iloc[-1]) 70 | 71 | df['f'] = df.groupby(by='asset', group_keys=False).apply(f0) 72 | ma = df.groupby(by='asset', group_keys=False).apply(f1) 73 | atr = df.groupby(by='asset', group_keys=False).apply(f2) 74 | 75 | # 这里的改动并不影响结果 76 | df['up'] = (ma + atr) 77 | df['down'] = (ma - atr) 78 | return df[['up', 'down']] 79 | 80 | 81 | def loop(func, 82 | df, 83 | date, 84 | input_args, 85 | slicer, 86 | periods, 87 | output_args, 88 | is_backtest=True, 89 | pre_load=0, 90 | ): 91 | """按天进行循环 92 | 93 | Parameters 94 | ---------- 95 | func: 96 | 计算函数 97 | input_args: list 98 | 输入矩阵列表 99 | slicer: list 100 | 切片器 101 | periods: list 102 | 周期矩阵 103 | output_args: list 104 | 输出矩阵列表 105 | is_backtest: bool 106 | 是否回测模式: 107 | True: 回测。有全部数据,所以可以通过提前加载参数,分块计算,实现加速 108 | False: 仿真。行情按天推送,只能每天都算 109 | pre_load: int 110 | 调整值,可以加载更多数据 111 | """ 112 | 113 | def same_slice(df, slicer, date0, date1): 114 | p0 = df.loc[date0, slicer] 115 | p1 = df.loc[date1, slicer] 116 | return np.allclose(p0, p1) 117 | 118 | def max_period(df, periods, date0): 119 | p = df.loc[date0, periods] 120 | return max(p.max()) 121 | 122 | # 将日期准备 123 | end = 0 124 | for i, d in enumerate(date): 125 | if is_backtest and i < len(date) - 1: 126 | if same_slice(df, slicer, date[i], date[i + 1]): 127 | continue 128 | 129 | start = end 130 | 131 | pre_start = int(max_period(df, periods, date[i])) + pre_load 132 | pre_start = max(start - pre_start, 0) 133 | end = i + 1 134 | 135 | print(f'当前:{i}, 更新:[{start},{end}), 加载:[{pre_start},{end}), 步长:{end - start}') 136 | pre_start_date = date[pre_start] 137 | end_date = date[end - 1] # 时间切片时是左闭右闭 138 | 139 | # 切片 140 | # 迭代计算, 可能会有累积计算结果的需求 141 | func_outputs = func(df.loc[pre_start_date:end_date, input_args + periods + output_args]) 142 | df.loc[pre_start_date:end_date, output_args] = func_outputs 143 | 144 | 145 | loop(func, 146 | df, 147 | date, 148 | ['open', 'high', 'low', 'close', 'factor'], 149 | ['p1', 'p2', 'factor'], 150 | ['p1', 'p2'], 151 | ['up', 'down'], 152 | is_backtest=True) 153 | 154 | # up与down是我们想要的结果 155 | print(df[['up', 'down']]) 156 | -------------------------------------------------------------------------------- /examples/dynamic_parameter_wide.py: -------------------------------------------------------------------------------- 1 | """ 2 | 以下是在宽表上的动态复权和动态参数的示例 3 | """ 4 | import numpy as np 5 | 6 | # 准备数据 7 | import ta_cn.imports.wide_ta as W_TA 8 | 9 | o = np.random.rand(1000000).reshape(-1, 5000) 10 | h = np.random.rand(1000000).reshape(-1, 5000) + 10 11 | l = np.random.rand(1000000).reshape(-1, 5000) 12 | c = np.random.rand(1000000).reshape(-1, 5000) 13 | 14 | # 周期参数 15 | p1 = np.empty_like(c) 16 | p1[:] = 20 17 | p1[:150] = 5 18 | p1[:120] = 10 19 | p1[:50] = 30 20 | 21 | p2 = np.empty_like(c) 22 | p2[:] = 10 23 | p2[:180] = 15 24 | p2[:110] = 20 25 | p2[:70] = 10 26 | 27 | # 后复权因子 28 
| f = np.empty_like(c) 29 | f[:] = 1.5 30 | f[:180] = 1.3 31 | f[:110] = 1.2 32 | f[:60] = 1.1 33 | 34 | # 输出 35 | up = np.empty_like(c) 36 | down = np.empty_like(c) 37 | 38 | 39 | def func(open_, high, low, close, factor, 40 | period1, period2): 41 | """分块计算指标""" 42 | 43 | # 后复权转前复权 44 | f = factor / factor[-1] 45 | open_ *= f 46 | high *= f 47 | low *= f 48 | close *= f 49 | 50 | # 计算指标 51 | ma = W_TA.SMA(close, timeperiod=period1).raw() 52 | atr = W_TA.ATR(high, low, close, timeperiod=period2).raw() 53 | 54 | # 每天记录新值 55 | return (ma + atr), (ma - atr) 56 | 57 | 58 | def loop(func, 59 | input_args, 60 | slicer, 61 | periods, 62 | output_args, 63 | is_backtest=True, 64 | pre_load=0, 65 | ): 66 | """按天进行循环 67 | 68 | Parameters 69 | ---------- 70 | func: 71 | 计算函数 72 | input_args: list 73 | 输入矩阵列表 74 | slicer: list 75 | 切片器 76 | periods: list 77 | 周期矩阵 78 | output_args: list 79 | 输出矩阵列表 80 | is_backtest: bool 81 | 是否回测模式: 82 | True: 回测。有全部数据,所以可以通过提前加载参数,分块计算,实现加速 83 | False: 仿真。行情按天推送,只能每天都算 84 | pre_load: int 85 | 调整值,可以加载更多数据 86 | """ 87 | 88 | def same_slice(ps, i): 89 | for p in ps: 90 | if not np.all(p[i] == p[i + 1]): 91 | return False 92 | return True 93 | 94 | def max_period(ps, i): 95 | m = -np.inf 96 | for p in ps: 97 | m = max(max(p[i]), m) 98 | return m 99 | 100 | end = 0 101 | real = input_args[0] 102 | for i in range(len(real)): 103 | # 1. 明天的参数是未知的,而这里为了排除重复计算,提前预取了明天的值 104 | # 2. 实盘时,每天都是最后一天 105 | if is_backtest and i < len(real) - 1: 106 | # 1. 数据最后,没有明天,不得跳过,必须计算一次 107 | if same_slice(slicer, i): 108 | # 2. 与明天的参数一样,跳过 109 | continue 110 | 111 | # 上期的结束就是这期的开始 112 | start = end 113 | # 预加载长度,对于EMA可能需要加载的更多,这里要灵活变动 114 | pre_start = int(max_period(periods, i)) + pre_load 115 | pre_start = max(start - pre_start, 0) 116 | end = i + 1 117 | 118 | print(f'当前:{i}, 更新:[{start},{end}), 加载:[{pre_start},{end}), 步长:{end - start}') 119 | 120 | # 切片 121 | func_args = [v[pre_start:end] for v in input_args] 122 | func_periods = [v[i] for v in periods] 123 | # 迭代计算, 可能会有累积计算结果的需求 124 | func_outputs = func(*func_args, *func_periods) 125 | 126 | # 只取后一段进行更新 127 | for x, y in zip(output_args, func_outputs): 128 | x[start:end] = y[start - end:] 129 | 130 | 131 | loop(func, 132 | [o, h, l, c, f], 133 | [p1, p2, f], 134 | [p1, p2], 135 | [up, down], 136 | is_backtest=False) 137 | 138 | # up与down是我们想要的结果 139 | print(up) 140 | print(down) 141 | -------------------------------------------------------------------------------- /examples/formula_dataframe.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from ta_cn.imports.wide_ta import ATR 4 | from ta_cn.utils import np_to_pd 5 | from ta_cn.utils_wide import WArr 6 | 7 | pd._testing._N = 250 8 | pd._testing._K = 30 9 | h = pd._testing.makeTimeDataFrame() + 10 10 | l = pd._testing.makeTimeDataFrame() 11 | c = pd._testing.makeTimeDataFrame() 12 | 13 | # 数据需要封装成特殊对象,实现NaN的堆叠和还原 14 | h_ = WArr.from_array(h, direction='down') 15 | l_ = WArr.from_array(l, direction='down') 16 | c_ = WArr.from_array(c, direction='down') 17 | 18 | r = ATR(h_, l_, c_, 10) 19 | # 返回的数据可能是np.ndarray 20 | print(r.raw()) 21 | 22 | # 可以再封装回pd.DataFrame 23 | d = np_to_pd(r.raw(), copy=False, index=c.index, columns=c.columns) 24 | print(d.iloc[-5:]) 25 | -------------------------------------------------------------------------------- /examples/formula_series.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from ta_cn.imports.long_ta import ATR, 
SMA 4 | from ta_cn.imports.long_wq import group_neutralize, rank 5 | 6 | pd._testing._N = 500 7 | pd._testing._K = 30 8 | 9 | open_ = pd._testing.makeTimeDataFrame() + 5 10 | high = pd._testing.makeTimeDataFrame() + 10 11 | low = pd._testing.makeTimeDataFrame() + 5 12 | close = pd._testing.makeTimeDataFrame() + 5 13 | group = close.copy() * 100 // 1 % 5 14 | 15 | df = { 16 | 'open_': open_.stack(), 17 | 'high': high.stack(), 18 | 'low': low.stack(), 19 | 'close': close.stack(), 20 | 'group': group.stack(), 21 | } 22 | df = pd.DataFrame(df) 23 | df.index.names = ['date', 'asset'] 24 | kwargs = df.to_dict(orient='series') 25 | 26 | # 单输入 27 | r = SMA(df['close'], timeperiod=10) 28 | print(r.unstack()) 29 | # 多输入 30 | r = ATR(df['high'], df['low'], df['close'], 10) 31 | print(r.unstack()) 32 | # 横截面 33 | r = rank(df['close']) 34 | print(r.unstack()) 35 | r = group_neutralize(df['close'], df['group']) 36 | 37 | print(r.unstack()) 38 | -------------------------------------------------------------------------------- /examples/formula_tdx.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ta_cn.talib import init, set_compatibility_enable, set_compatibility 4 | from ta_cn.tdx.over_bought_over_sold import ATR_CN 5 | from ta_cn.tdx.trend import MACD 6 | 7 | # ta_cn.talib库底层是循环调用talib,部分计算效率不高 8 | # 可导入ta_cn中的公式 9 | 10 | # 准备数据 11 | h = np.random.rand(10000000).reshape(-1, 50000) + 10 12 | l = np.random.rand(10000000).reshape(-1, 50000) 13 | c = np.random.rand(10000000).reshape(-1, 50000) 14 | 15 | init(mode=2, skipna=False) 16 | 17 | r = ATR_CN(h, l, c, timeperiod=10) 18 | print(r) 19 | 20 | # 设置参数,让MACD中的EMA算法与国内算法相同 21 | set_compatibility_enable(True) 22 | set_compatibility(1) 23 | set_compatibility_enable(False) 24 | 25 | x, y, z = MACD(c, fastperiod=12, slowperiod=26, signalperiod=9) 26 | print(z) 27 | 28 | """ 29 | 三种不同调用MACD的方法: 30 | 31 | from ta_cn.imports import * 32 | %timeit MACD(c) 33 | 499 ms ± 10 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 34 | 35 | from ta_cn.slow import MACD_CN 36 | %timeit MACD_CN(c) 37 | 3.59 s ± 58.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 38 | 39 | import ta_cn.talib as ta 40 | %timeit ta.MACD(c) 41 | 426 ms ± 1.46 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each) 42 | """ 43 | -------------------------------------------------------------------------------- /examples/group_func.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pandas._testing import assert_series_equal 3 | 4 | from ta_cn.wq.group import group_mean 5 | 6 | if __name__ == '__main__': 7 | pd._testing._N = 500 8 | pd._testing._K = 30 9 | 10 | close = pd._testing.makeTimeDataFrame() + 5 11 | group = close.copy() * 100 // 1 % 5 12 | 13 | df = { 14 | 'close': close, 15 | 'sector': group, 16 | } 17 | 18 | kwargs_l = {k: v.stack() for k, v in df.items()} 19 | kwargs_l = pd.DataFrame(kwargs_l) 20 | kwargs_l.index.names = ['date', 'asset'] 21 | kwargs_l = kwargs_l.to_dict(orient='series') 22 | 23 | # r = group_percentage(kwargs_l['close'], kwargs_l['sector'], percentage=0.1) 24 | # print(r) 25 | # r = group_percentage(kwargs_l['close'], kwargs_l['sector'], percentage=0.9) 26 | # print(r) 27 | # r = group_neutralize(kwargs_l['close'], kwargs_l['sector']) 28 | # print(r) 29 | # r = group_normalize(kwargs_l['close'], kwargs_l['sector']) 30 | # print(r) 31 | r = group_mean(kwargs_l['close'], kwargs_l['close'], kwargs_l['sector']) 32 | print(r) 33 | -------------------------------------------------------------------------------- /examples/group_nav.py: -------------------------------------------------------------------------------- 1 | """ 2 | 本示例使用的是1期收益率*1期持仓权重。有以下注意事项: 3 | 1. 收益率是每次持有的时间。如持有5天,则将第5天价格除以第1天价格,得到的收益率移动到对应位置 4 | 2. 持有期不能重叠,否则结果错误 5 | 3. 只能用在做多场景。做空场景下复利结果有误差, 6 | 7 | 错误的例子如下: 8 | 1. 5天收益率,每天都产生信号。相当于有5份资金,每天都入了1份并持有了5天 9 | 10 | 问题:如果想计算因子持有期是1天、5天、10天收益更好,怎么处理? 11 | 1. 因子重采样成1天,5天,10天。然后相乘 12 | 不同入场时间对结果有很大影响,时间长影响大。比如月度调仓抢跑 13 | 不同周期的曲线不能画在同一图上 14 | 不同周期的终值可以画在同一图上 15 | 2. 因子每天都入场,然后乘以1天、5天、10天的收益率。 16 | 1. 相当于资金多了N份 17 | 2. 
每个收益都错开了一段,时间不同,不能直接相加。最后的总收益可以粗略相加N期求平均 18 | 19 | """ 20 | import numpy as np 21 | import pandas as pd 22 | 23 | from ta_cn.performance import weighted_index, ic, ir, ic_decay 24 | from ta_cn.split import top_k, quantile_n 25 | 26 | # 因子值 27 | f = np.random.rand(10000000).reshape(-1, 5000) 28 | # 持有1期收益 29 | r = (np.random.rand(10000000).reshape(-1, 5000) - 0.5) / 100 30 | 31 | ics = ic(f, r) 32 | print(ics) 33 | print(ir(ics)) 34 | print(ic_decay(f, r).mean()) 35 | 36 | # 前50,前100,前200,分三组的净值 37 | d = {} 38 | topK = top_k(-f, bins=[0, 50, 100, 200]) 39 | for k, v in topK.items(): 40 | d[k] = weighted_index(~np.isnan(v), returns=r, need_one=True) 41 | 42 | df = pd.DataFrame(d) 43 | print(df) 44 | 45 | # 分十组的净值 46 | d = {} 47 | qN = quantile_n(f, n=10) 48 | for k, v in qN.items(): 49 | d[k] = weighted_index(~np.isnan(v), returns=r, need_one=True) 50 | 51 | df = pd.DataFrame(d) 52 | print(df) 53 | -------------------------------------------------------------------------------- /examples/talib_2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # 新版talib,只要替换引用,并添加一句init即可 4 | import ta_cn.talib as ta 5 | from ta_cn.utils_wide import pushna, pullna 6 | 7 | # 原版talib,不支持二维数据 8 | # import talib as ta 9 | 10 | # 准备数据 11 | h = np.random.rand(1000000).reshape(-1, 5000) + 10 12 | l = np.random.rand(1000000).reshape(-1, 5000) 13 | c = np.random.rand(1000000).reshape(-1, 5000) 14 | # 指定模式,否则对talib封装的所有函数都不存在 15 | ta.init(mode=2, skipna=False, to_globals=True) 16 | 17 | # 几个调用函数演示 18 | r = ta.ATR(h, l, c, timeperiod=10) 19 | print(r) 20 | x, y, z = ta.BBANDS(c, timeperiod=10, nbdevup=2, nbdevdn=2) 21 | print(z) 22 | 23 | # 将少量值设置为空,用来模拟停牌 24 | c[c < 0.4] = np.nan 25 | 26 | # 提前处理数据,跳过停牌进行计算,再还原的演示 27 | # 嵌套指标时,全为时序指标使用down,或全为截面使用right。混合时此方法不要轻易使用 28 | arr, row, col = pushna(c, direction='down') 29 | rr = ta.SMA(arr, timeperiod=10) 30 | r = pullna(rr, row, col) 31 | print(r) 32 | -------------------------------------------------------------------------------- /examples/市场宽度.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pandas._testing import assert_series_equal 3 | 4 | from ta_cn.imports.long import A_div_AB_1, A_div_AB_2 5 | from ta_cn.imports.wide_wq import ts_mean 6 | from ta_cn.utils import np_to_pd 7 | 8 | if __name__ == '__main__': 9 | pd._testing._N = 500 10 | pd._testing._K = 30 11 | 12 | close = pd._testing.makeTimeDataFrame() 13 | group = close.copy() * 500 // 1 % 5 14 | 15 | # 移动平均需要处理一下 16 | sma20 = ts_mean(close, 20) 17 | sma20 = np_to_pd(sma20.raw(), index=close.index, columns=close.columns) 18 | 19 | df = { 20 | 'close': close, 21 | 'group': group, 22 | 'c_gt_sma': close > sma20, 23 | } 24 | 25 | kwargs_l = {k: v.stack() for k, v in df.items()} 26 | kwargs_l = pd.DataFrame(kwargs_l) 27 | kwargs_l.index.names = ['date', 'asset'] 28 | kwargs_l = kwargs_l.to_dict(orient='series') 29 | 30 | r = A_div_AB_1(kwargs_l['close']) 31 | print(r) 32 | r = A_div_AB_2(kwargs_l['c_gt_sma'], kwargs_l['group']) 33 | print(r) 34 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bottleneck 2 | numpy 3 | pandas 4 | TA-Lib>=0.4.19 5 | numba 6 | loguru -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import 
setuptools 2 | 3 | with open("README.md", "r", encoding='utf-8') as fp: 4 | long_description = fp.read() 5 | 6 | version = {} 7 | with open("ta_cn/_version.py", encoding="utf-8") as fp: 8 | exec(fp.read(), version) 9 | 10 | setuptools.setup( 11 | name="ta_cn", 12 | version=version['__version__'], 13 | author="wukan", 14 | author_email="wu-kan@163.com", 15 | description="Technical Analysis Indicators", 16 | long_description=long_description, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/wukan1986/ta_cn", 19 | packages=setuptools.find_packages(), 20 | install_requires=[ 21 | 'numpy', # 基础版,希望用户自己能解决talib的安装问题 22 | 'loguru', 23 | ], 24 | extras_require={ 25 | 'talib': [ 26 | 'TA-Lib>=0.4.19', # 有编译条件的用户通过它也能编译安装 27 | ], 28 | 'cn': [ 29 | 'numpy>=1.20.0', # 主要是为了sliding_window_view 30 | 'pandas', 31 | 'bottleneck', 32 | 'numba', 33 | # 'TA-Lib>=0.4.19', # 担心talib安装不过导致失败,所以注释此句 34 | ], 35 | 'parallel': [ 36 | 'parmap', # 并行库,带进度条 37 | ], 38 | }, 39 | classifiers=[ 40 | "Programming Language :: Python :: 3", 41 | "License :: OSI Approved :: MIT License", 42 | "Operating System :: OS Independent", 43 | 'Intended Audience :: Developers', 44 | ], 45 | python_requires=">=3.7", 46 | ) 47 | -------------------------------------------------------------------------------- /ta_cn/__init__.py: -------------------------------------------------------------------------------- 1 | # 此处引入库要小心,有可能打破只想用talib二维库的需求 2 | 3 | # 按股票分组,计算时序指标。注意,组内时序需要已经排序 4 | BY_ASSET = 'asset' 5 | # 按时间分组。计算横截面 6 | BY_DATE = 'date' 7 | # 横截面上进行行业中性化 8 | BY_GROUP = ['date', 'group'] 9 | 10 | # 浮点数比较精度 11 | # 实测遇到了1.6518123e-06这种实际为0的情况 12 | # 再考虑到np.allclose(rtol=1e-05, atol=1e-08),所以将EPSILON改成1e-05 13 | EPSILON = 1e-05 14 | -------------------------------------------------------------------------------- /ta_cn/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.5.2" 2 | -------------------------------------------------------------------------------- /ta_cn/aggregate.py: -------------------------------------------------------------------------------- 1 | import numpy as _np 2 | 3 | 4 | def A_div_AB(x): 5 | """A/(A+B) 6 | 7 | 输入一维,得到一个值 8 | 输入二维,得到一列值 9 | 10 | 可用于计算市场宽度等指标 11 | """ 12 | if x.ndim == 2: 13 | t1 = _np.nansum(x, axis=1, keepdims=True) 14 | t2 = _np.nansum(~_np.isnan(x), axis=1, keepdims=True) 15 | else: 16 | t1 = _np.nansum(x) 17 | t2 = _np.nansum(~_np.isnan(x)) 18 | return t1 / t2 19 | -------------------------------------------------------------------------------- /ta_cn/alphas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wukan1986/ta_cn/a569a618109daa804541c5d67fa4c0407f03fd49/ta_cn/alphas/__init__.py -------------------------------------------------------------------------------- /ta_cn/alphas/alpha.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | """ 4 | import numpy as np 5 | 6 | 7 | def FILTER_191(A, condition): 8 | """注意:与TDX中的FILTER不同""" 9 | return np.where(condition, A, 0) 10 | 11 | 12 | def CUMPROD(A): 13 | return np.cumprod(A, axis=0) 14 | 15 | 16 | def CUMSUM(A): 17 | return np.cumsum(A, axis=0) 18 | -------------------------------------------------------------------------------- /ta_cn/bn_wraps.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | 3 | import numpy as np 4 | from bottleneck import * 5 | 6 | 7 | 
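# 说明:bottleneck 的 move_* 系列要求 window 不超过所在轴的长度,
# 数据长度不足时会抛 ValueError;下面的装饰器改为返回全 NaN 数组,
# 与 pandas.rolling 在 min_periods 不足时的行为保持一致。
# 用法示意(假设输入为一维 float 数组):
#   move_mean(np.arange(5.0), window=10)  # 包装后得到 5 个 NaN,而不是报错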
def bn_move_window(func): 8 | """bottleneck在数据量少返回nan""" 9 | 10 | @wraps(func) 11 | def decorated(a, window, *args, axis=-1, **kwargs): 12 | if a.shape[axis] < window: 13 | return np.full_like(a, fill_value=np.nan) 14 | return func(a, window, *args, axis=axis, **kwargs) 15 | 16 | return decorated 17 | 18 | 19 | move_argmax = bn_move_window(move_argmax) 20 | move_argmin = bn_move_window(move_argmin) 21 | move_max = bn_move_window(move_max) 22 | move_mean = bn_move_window(move_mean) 23 | move_median = bn_move_window(move_median) 24 | move_min = bn_move_window(move_min) 25 | move_rank = bn_move_window(move_rank) 26 | move_std = bn_move_window(move_std) 27 | move_sum = bn_move_window(move_sum) 28 | move_var = bn_move_window(move_var) 29 | -------------------------------------------------------------------------------- /ta_cn/candle.py: -------------------------------------------------------------------------------- 1 | """ 2 | K线相关计算 3 | """ 4 | # https://github.com/TA-Lib/ta-lib/blob/master/src/ta_func/ta_utility.h#L327 5 | import numpy as np 6 | 7 | from . import EPSILON 8 | from .wq.time_series import ts_delay 9 | 10 | 11 | # ========================================= 12 | # 单K线计算 13 | 14 | def REALBODY(open_, high, low, close): 15 | """实体""" 16 | return (close - open_).abs() 17 | 18 | 19 | def UPPERSHADOW(open_, high, low, close): 20 | """上影""" 21 | return high - np.maximum(open_, close) 22 | 23 | 24 | def LOWERSHADOW(open_, high, low, close): 25 | """下影""" 26 | return np.minimum(open_, close) - low 27 | 28 | 29 | def HIGHLOWRANGE(open_, high, low, close): 30 | """总长""" 31 | return high - low 32 | 33 | 34 | def CANDLECOLOR(open_, high, low, close): 35 | """K线颜色""" 36 | return close >= open_ 37 | 38 | 39 | def UPPERBODY(open_, high, low, close): 40 | """实体上沿""" 41 | return np.maximum(open_, close) 42 | 43 | 44 | def LOWERBODY(open_, high, low, close): 45 | """实体下沿""" 46 | return np.minimum(open_, close) 47 | 48 | 49 | def efficiency_ratio_candle(open_, high, low, close): 50 | """K线内的市场效率。两个总长减去一个实体长就是路程 51 | 52 | 比较粗略的计算市场效率的方法。丢失了部分路程信息,所以结果会偏大 53 | """ 54 | displacement = REALBODY(open_, high, low, close) 55 | distance = 2 * HIGHLOWRANGE(open_, high, low, close) - displacement 56 | return displacement / (distance + EPSILON) 57 | 58 | 59 | # ========================================= 60 | # 双K线计算 61 | 62 | def REALBODYGAPUP(open_, high, low, close): 63 | """实体跳空高开。当天实体下界大于昨天实体上界""" 64 | upper = UPPERBODY(open_, high, low, close) 65 | lower = LOWERBODY(open_, high, low, close) 66 | return lower > ts_delay(upper, 1) 67 | 68 | 69 | def REALBODYGAPDOWN(open_, high, low, close): 70 | """实体跳空低开。当天实体上界小于昨天实体下界""" 71 | upper = UPPERBODY(open_, high, low, close) 72 | lower = LOWERBODY(open_, high, low, close) 73 | return upper < ts_delay(lower, 1) 74 | 75 | 76 | def CANDLEGAPUP(open_, high, low, close): 77 | """跳空高开。当天最低大于昨天最高""" 78 | return low > ts_delay(high, 1) 79 | 80 | 81 | def CANDLEGAPDOWN(open_, high, low, close): 82 | """跳空低开。当天最高小于昨天最低""" 83 | return high < ts_delay(low, 1) 84 | -------------------------------------------------------------------------------- /ta_cn/chip.py: -------------------------------------------------------------------------------- 1 | import numba 2 | import numpy as np 3 | 4 | 5 | @numba.jit(nopython=True, cache=True, nogil=True) 6 | def chip(high, low, avg, turnover, 7 | start=None, stop=None, step=0.2): 8 | """筹码分布,可用于WINNER或COST指标 9 | 10 | 不可能完全还原真实的筹码分布,只能接近。所以做了一下特别处理 11 | 12 | 1. 三角分布,比平均分布更接近 13 | 2. 
步长。没有必要每个价格都统计,特别是复权后价格也无法正好是0.01间隔 14 | 高价股建议步长设大些,低价股步长需设小些 15 | 16 | Parameters 17 | ---------- 18 | high 19 | low 20 | avg 21 | 一维序列 22 | turnover: 23 | 换手率,需要在外转成0~1范围内 24 | start 25 | 开始价格 26 | stop 27 | 结束价格 28 | step 29 | 步长。一字涨停时,三角分布的底为1,高为2。但无法当成梯形计算面积,所以从中用半步长切开计算 30 | 31 | Returns 32 | ------- 33 | out 34 | 筹码分布 35 | columns 36 | 价格表头 37 | 38 | """ 39 | # 将类型进行转换,要提速时可以将此在外部实现 40 | # open_ = pd_to_np(open_) 41 | # high = pd_to_np(high) 42 | # low = pd_to_np(low) 43 | # close = pd_to_np(close) 44 | # turnover = pd_to_np(turnover) 45 | 46 | # 网格范围 47 | if start is None: 48 | start = np.min(low) 49 | if stop is None: 50 | stop = np.max(high) 51 | 52 | left = round(start / step) * 2 - 1 53 | right = round(stop / step) * 2 + 1 54 | 55 | # 最小最大值左右要留半格,range是左闭右开,长度必须为2n+1 56 | columns = np.arange(left, right + 1) 57 | grid_shape = (len(turnover), len(columns)) 58 | 59 | # numba中round写法特殊 60 | _high = np.empty_like(high) 61 | _low = np.empty_like(low) 62 | _avg = np.empty_like(avg) 63 | 64 | # high和low必须落到边缘上 65 | _high = np.round(high / step, 0, _high) * 2 + 1 66 | _low = np.round(low / step, 0, _low) * 2 - 1 67 | # avg必须落在实体上 68 | _avg = np.round(avg / step, 0, _avg) * 2 69 | tri_height = 2 / ((_high - _low) // 2) # 三角形高度 70 | 71 | # 得到三组值在网格中的位置 72 | high_arg = np.argwhere(columns == _high.reshape(-1, 1))[:, 1] 73 | avg_arg = np.argwhere(columns == _avg.reshape(-1, 1))[:, 1] 74 | low_arg = np.argwhere(columns == _low.reshape(-1, 1))[:, 1] 75 | 76 | # 高度表 77 | height = np.zeros(grid_shape) 78 | for i in range(len(height)): 79 | la = low_arg[i] 80 | aa = avg_arg[i] 81 | ha = high_arg[i] 82 | th = tri_height[i] 83 | height[i, la:aa + 1] = np.linspace(0, th, aa - la + 1) 84 | height[i, aa:ha + 1] = np.linspace(th, 0, ha - aa + 1) 85 | 86 | # 计算半块面积, 三角形的高变成了梯形的上下底,梯形高固定为0.5,*0.5/2=/4 87 | # 宽度-1,例如,原长度为5,-1后为4 88 | area = (height[:, :-1] + height[:, 1:]) / 4 89 | # 合成一块。宽度/2,例如原长度为4,/2后为2 90 | weight = area[:, ::2] + area[:, 1::2] 91 | 92 | # 输出 93 | out = np.zeros_like(weight) 94 | # 剩余换手率 95 | turnover2 = 1 - turnover 96 | # 第一天其实应当用上市发行价,过于麻烦,还是将第一天等权 97 | # 取巧方法,利用-1的特性,可减少if判断, 98 | out[-1] = weight[0] 99 | # 这里现在用的numpy, 还要快可考虑numba 100 | for i in range(len(turnover)): 101 | out[i] = out[i - 1] * turnover2[i] + weight[i] * turnover[i] 102 | 103 | # print(out.sum(axis=1)) 104 | return out, (step / 2) * columns[1::2] 105 | 106 | 107 | def WINNER(out, columns, close): 108 | """获利盘比例 109 | 110 | Parameters 111 | ---------- 112 | out 113 | chip函数生成的筹码分布矩阵 114 | columns 115 | 奇数位置为边缘价格 116 | 偶数位置为梯形价格 117 | close 118 | 收盘价。或指定价 119 | 120 | Examples 121 | -------- 122 | >>> out, columns = chip(high, low, avg, turnover, step=0.1) 123 | >>> WINNER(out, columns, close) 124 | 125 | 126 | """ 127 | if isinstance(close, np.ndarray): 128 | close = close.reshape(-1, 1) 129 | cheap = np.where(columns <= close, out, 0) 130 | return np.sum(cheap, axis=1) 131 | 132 | 133 | def COST(out, columns, cost): 134 | """成本分布 135 | 136 | Parameters 137 | ---------- 138 | out 139 | chip函数生成的筹码分布矩阵 140 | columns 141 | 价格网格 142 | cost 143 | 成本 144 | 145 | Examples 146 | -------- 147 | >>> out, columns = chip(high, low, avg, turnover, step=0.1) 148 | >>> COST(out, columns, 0.5) 149 | 150 | """ 151 | if isinstance(cost, np.ndarray): 152 | cost = cost.reshape(-1, 1) 153 | cum = np.cumsum(out, axis=1) 154 | prices = np.where(cum <= cost, columns, 0) 155 | return np.max(prices, axis=1) 156 | -------------------------------------------------------------------------------- /ta_cn/ema.py: 
-------------------------------------------------------------------------------- /ta_cn/ema.py: --------------------------------------------------------------------------------
1 | import pandas as pd
2 | 
3 | from . import talib as ta
4 | from .ewm_nb import ewm_mean, ma_1st, sum_1st
5 | from .talib import set_compatibility, TA_COMPATIBILITY_DEFAULT, TA_COMPATIBILITY_METASTOCK
6 | from .utils import np_to_pd
7 | 
8 | """
9 | 由于MA有太多种了,单独提到此处,方便对比
10 | 
11 | 默认EMA算法中,上一期值权重(timeperiod-1)/(timeperiod+1),当前值权重2/(timeperiod+1)
12 | ta.set_compatibility建议放在循环前执行
13 | 
14 | References
15 | ----------
16 | https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.ewm.html?highlight=ewm#pandas.DataFrame.ewm
17 | https://stackoverflow.com/questions/42869495/numpy-version-of-exponential-weighted-moving-average-equivalent-to-pandas-ewm
18 | 
19 | Warnings
20 | --------
21 | 由于EMA的计算特点,只要调用了此文件中的公式,都需要预留一些数据,数据太短可能导致起点不同值不同
22 | 
23 | """
24 | _ta2d = ta.init(mode=2, skipna=False)
25 | 
26 | 
27 | def EMA_0_TA(real, timeperiod: int):
28 |     """EMA第一个值用MA
29 | 
30 |     EMA_0_TA(real, timeperiod=24)
31 |     """
32 |     set_compatibility(TA_COMPATIBILITY_DEFAULT)
33 |     return _ta2d.EMA(real, timeperiod=timeperiod)
34 | 
35 | 
36 | def EXPMEMA(real, timeperiod):
37 |     """EMA第一个值用MA
38 | 
39 |     EXPMEMA(real, timeperiod=24)
40 | 
41 |     return np_to_pd(ma_1st(real, timeperiod)).ewm(span=timeperiod, min_periods=0, adjust=False).mean()
42 |     """
43 |     return ewm_mean(ma_1st(real, timeperiod), span=timeperiod, min_periods=0, adjust=False)
44 | 
45 | 
46 | def EMA_1_TA(real, timeperiod):
47 |     """EMA第一个值用Price
48 | 
49 |     EMA_1_TA(real, timeperiod=24)
50 |     """
51 |     set_compatibility(TA_COMPATIBILITY_METASTOCK)
52 |     return _ta2d.EMA(real, timeperiod=timeperiod)
53 | 
54 | 
55 | def EMA_1_PD(real, timeperiod):
56 |     """EMA第一个值用Price
57 | 
58 |     EMA_1_PD(real, timeperiod=24)
59 | 
60 |     return np_to_pd(real).ewm(span=timeperiod, min_periods=timeperiod, adjust=False).mean()
61 |     """
62 |     return ewm_mean(real, span=timeperiod, min_periods=timeperiod, adjust=False)
63 | 
64 | 
65 | def SMA_CN(real, timeperiod, M):
66 |     """EMA第一个值用MA
67 | 
68 |     SMA(real, timeperiod=24, M=1)
69 | 
70 |     return np_to_pd(ma_1st(real, timeperiod)).ewm(alpha=M / timeperiod, min_periods=0, adjust=False).mean()
71 |     """
72 |     return ewm_mean(ma_1st(real, timeperiod), alpha=M / timeperiod, min_periods=0, adjust=False)
73 | 
74 | 
75 | def DMA(real, alpha):
76 |     """求X的动态移动平均。0<alpha<1
87 |     out[0] = weighted_avg if (nobs >= minp) else np.nan
88 |     old_wt = 1.
89 | 
90 |     for i in range(1, N):
91 |         cur = a[i]
92 |         is_observation = (cur == cur)  # 非NaN
93 |         nobs += is_observation
94 | 
95 |         if weighted_avg == weighted_avg:
96 |             old_wt *= old_wt_factor[i]
97 |             new_wt = new_wt_factor[i]
98 |             if is_observation:
99 |                 # avoid numerical errors on constant series
100 |                 if weighted_avg != cur:
101 |                     weighted_avg = ((old_wt * weighted_avg) + (new_wt * cur)) / (old_wt + new_wt)
102 |                 if adjust:
103 |                     old_wt += new_wt
104 |                 else:
105 |                     old_wt = 1.
106 |         elif is_observation:
107 |             weighted_avg = cur
108 |         out[i] = weighted_avg if (nobs >= minp) else np.nan
109 |     return out
110 | 
111 | 
112 | @numba.jit(nopython=True, cache=True, nogil=True)
113 | def ewm_mean_nb(a, out, alpha, minp: int = 0, adjust: bool = False):
114 |     """2-dim version of `ewm_mean_1d_nb`."""
115 |     for col in range(a.shape[1]):
116 |         out[:, col] = ewm_mean_1d_nb(a[:, col], out[:, col], alpha[:, col], minp=minp, adjust=adjust)
117 |     return out
118 | 
119 | 
120 | def ewm_mean(a, alpha=None, span=None, com=None, min_periods: int = 0, adjust: bool = False):
121 |     """指数移动平均
122 | 
123 |     pandas使用cython技术,本处使用numba。实测发现numba版更快
124 | 
125 |     Parameters
126 |     ----------
127 |     a
128 |     alpha
129 |     span
130 |     com
131 |     min_periods
132 |     adjust:
133 |         是否使用调整算法。一般使用False
134 | 
135 |     References
136 |     ----------
137 |     https://github.com/pandas-dev/pandas/blob/main/pandas/_libs/window/aggregations.pyx
138 |     https://github.com/polakowo/vectorbt/blob/master/vectorbt/generic/nb.py
139 | 
140 |     """
141 |     # 三个参数只使用其中一个即可
142 |     if span is not None:
143 |         com = (span - 1) / 2.0
144 |     if com is not None:
145 |         alpha = 1. / (1. + com)
146 | 
147 |     a = pd_to_np(a)
148 |     out = np.empty_like(a)
149 | 
150 |     if isinstance(alpha, (int, float)):
151 |         # 单数字就扩展
152 |         alpha = np.full_like(a, fill_value=alpha)
153 | 
154 |     if a.ndim == 2:
155 |         return ewm_mean_nb(a, out, alpha, minp=min_periods, adjust=adjust)
156 |     else:
157 |         return ewm_mean_1d_nb(a, out, alpha, minp=min_periods, adjust=adjust)
-------------------------------------------------------------------------------- /ta_cn/imports/__init__.py: --------------------------------------------------------------------------------
1 | """
2 | 本目录下提供二维函数跳过空值的封装
3 | 1. 长表跳空值
4 | 2. 宽表跳空值
5 | """
-------------------------------------------------------------------------------- /ta_cn/imports/gtja_long.py: --------------------------------------------------------------------------------
1 | """
2 | 公式转alpha191
3 | 
4 | 国泰君安-基于短周期价量特征的多因子选股
5 | """
6 | 
7 | from ..imports import long as L
8 | from ..imports import long_ta as L_TA
9 | from ..imports import long_tdx as L_TDX
10 | from ..imports import long_wq as L_WQ
11 | 
12 | CORR = L_WQ.ts_corr
13 | REGSLOPE = L_TA.LINEARREG_SLOPE
14 | MEAN = L_TA.SMA
15 | WMA = L_TA.SMA  # !!!WMA的公式没看懂,所以用另一个替代,以后再改
16 | DECAYLINEAR = L_WQ.ts_decay_linear
17 | 
18 | CUMPROD = L.CUMPROD
19 | FILTER = L.FILTER_191
20 | RANK = L_WQ.rank
21 | TSRANK = L_WQ.ts_rank
22 | LessThan = L_WQ.less
23 | 
24 | IF = L_WQ.if_else
25 | ABS = L_WQ.abs_
26 | LOG = L_WQ.log  # 这里是用的自然对数
27 | MAX = L_WQ.max_
28 | MIN = L_WQ.min_
29 | SIGN = L_WQ.sign
30 | 
31 | SMA = L_TDX.SMA_CN
32 | 
33 | COUNT = L_WQ.ts_count
34 | DELTA = L_WQ.ts_delta
35 | TSMAX = L_WQ.ts_max
36 | HIGHDAY = L_WQ.ts_arg_max
37 | TSMIN = L_WQ.ts_min
38 | LOWDAY = L_WQ.ts_arg_min
39 | MA = L_WQ.ts_mean
40 | PROD = L_WQ.ts_product
41 | DELAY = L_WQ.ts_delay
42 | SUM = L_WQ.ts_sum
43 | SUMIF = L_TDX.SUMIF  # 注意,SUMIF参数的位置与常用的方式不同
44 | 
45 | REGBETA = L.SLOPE_YX
46 | REGRESI = L.REGRESI4
47 | 
48 | COVIANCE = L_WQ.ts_covariance
49 | STD = L_WQ.ts_std_dev  # 引入的是全体标准差
50 | 
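借助上面 gtja_long.py 中的这批别名,研报里的 alpha191 公式基本可以原样翻译成Python表达式。下面是一个假设性的示意:长表的 MultiIndex 级别名以 'date'/'asset' 为例,实际应与 ta_cn 的 BY_DATE/BY_ASSET 常量保持一致:

```python
import numpy as np
import pandas as pd
import ta_cn.imports.gtja_long as gtja

# 虚构的长表数据。索引级别名这里假设为'date'/'asset'
dates = pd.date_range('2022-01-03', periods=20)
idx = pd.MultiIndex.from_product([dates, ['000001', '600000']], names=['date', 'asset'])
close = pd.Series(np.random.rand(len(idx)) + 10, index=idx)

# 研报公式 RANK(DELTA(LOG(CLOSE), 1)) 可以按原样翻译
factor = gtja.RANK(gtja.DELTA(gtja.LOG(close), 1))
```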
-------------------------------------------------------------------------------- /ta_cn/imports/gtja_wide.py: --------------------------------------------------------------------------------
1 | """
2 | 公式转alpha191
3 | 
4 | 国泰君安-基于短周期价量特征的多因子选股
5 | """
6 | from ..imports import wide as W
7 | from ..imports import wide_ta as W_TA
8 | from ..imports import wide_tdx as W_TDX
9 | from ..imports import wide_wq as W_WQ
10 | 
11 | CORR = W_WQ.ts_corr
12 | REGSLOPE = W_TA.LINEARREG_SLOPE
13 | MEAN = W_TA.SMA
14 | WMA = W_TA.SMA  # !!!WMA的公式没看懂,所以用另一个替代,以后再改
15 | DECAYLINEAR = W_WQ.ts_decay_linear
16 | 
17 | CUMPROD = W.CUMPROD
18 | FILTER = W.FILTER_191
19 | RANK = W_WQ.rank
20 | TSRANK = W_WQ.ts_rank
21 | LessThan = W_WQ.less
22 | 
23 | IF = W_WQ.if_else
24 | ABS = W_WQ.abs_
25 | LOG = W_WQ.log  # 这里是用的自然对数
26 | MAX = W_WQ.max_
27 | MIN = W_WQ.min_
28 | SIGN = W_WQ.sign
29 | 
30 | SMA = W_TDX.SMA_CN
31 | 
32 | COUNT = W_WQ.ts_count
33 | DELTA = W_WQ.ts_delta
34 | TSMAX = W_WQ.ts_max
35 | HIGHDAY = W_WQ.ts_arg_max
36 | TSMIN = W_WQ.ts_min
37 | LOWDAY = W_WQ.ts_arg_min
38 | MA = W_WQ.ts_mean
39 | PROD = W_WQ.ts_product
40 | DELAY = W_WQ.ts_delay
41 | SUM = W_WQ.ts_sum
42 | SUMIF = W_TDX.SUMIF  # 注意,SUMIF参数的位置与常用的方式不同
43 | 
44 | REGBETA = W.SLOPE_YX
45 | REGRESI = W.REGRESI4
46 | 
47 | COVIANCE = W_WQ.ts_covariance
48 | STD = W_WQ.ts_std_dev  # 引入的是全体标准差
49 | 
-------------------------------------------------------------------------------- /ta_cn/imports/long.py: --------------------------------------------------------------------------------
1 | """
2 | 对指标的算子化包装
3 | 1. 包装成只支持 长表 输入和输出
4 | 2. 简化参数输入,命名参数也可当成位置参数输入
5 | 3. 通过dropna的方法,自动跳过停牌
6 | 
7 | !!!函数太多,又想要智能提示,只能手工按需补充
8 | """
9 | from .. import BY_ASSET, BY_DATE, BY_GROUP
10 | from ..aggregate import A_div_AB
11 | from ..alphas.alpha import CUMPROD
12 | from ..alphas.alpha import FILTER_191
13 | from ..regress import REGRESI
14 | from ..regress import SLOPE_YX
15 | from ..utils_long import dataframe_groupby_apply, series_groupby_apply
16 | 
17 | # 特殊
18 | CUMPROD = series_groupby_apply(CUMPROD, by=BY_ASSET, to_kwargs={})
19 | FILTER_191 = dataframe_groupby_apply(FILTER_191, by=BY_ASSET, to_kwargs={}, dropna=False)
20 | #
21 | 
22 | 
23 | SLOPE_YX = dataframe_groupby_apply(SLOPE_YX, by=BY_ASSET)
24 | REGRESI4 = dataframe_groupby_apply(REGRESI, by=BY_ASSET, to_df=[0, 1, 2, 3], to_kwargs={4: 'timeperiod'})
25 | 
26 | # 可用于 全部市场宽度
27 | A_div_AB_1 = series_groupby_apply(A_div_AB, by=BY_DATE, to_kwargs={})
28 | # 可用于 板块市场宽度
29 | A_div_AB_2 = dataframe_groupby_apply(A_div_AB, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={})
-------------------------------------------------------------------------------- /ta_cn/imports/long_ta.py: --------------------------------------------------------------------------------
1 | """
2 | TALIB库,长表模式,跳过空值
3 | """
4 | import talib as _talib
5 | from talib import abstract as _abstract
6 | 
7 | from ..
import BY_ASSET 8 | from ..utils_long import dataframe_groupby_apply, series_groupby_apply 9 | 10 | __FUNCS__ = {} 11 | for i, func_name in enumerate(_talib.get_functions()): 12 | """talib遍历""" 13 | 14 | 15 | def _input_names_len(kvs): 16 | cnt = 0 17 | for k, v in kvs.items(): 18 | if isinstance(v, list): 19 | cnt += len(v) 20 | else: 21 | cnt += 1 22 | return cnt 23 | 24 | 25 | # 导入原版talib 26 | _ta_func = getattr(_talib, func_name) 27 | info = _abstract.Function(func_name).info 28 | 29 | input_num = _input_names_len(info['input_names']) 30 | para_num = len(info['parameters']) 31 | output_num = len(info['output_names']) 32 | 33 | to_kwargs = {i + input_num: p for i, p in enumerate(info['parameters'])} 34 | dropna = True if info['group'] in ('Math Transform',) else False 35 | 36 | if input_num == 1: 37 | _ta_long = series_groupby_apply(_ta_func, by=BY_ASSET, dropna=dropna, 38 | to_kwargs=to_kwargs, output_num=output_num) 39 | else: 40 | _ta_long = dataframe_groupby_apply(_ta_func, by=BY_ASSET, dropna=dropna, 41 | to_df=range(input_num), to_kwargs=to_kwargs, output_num=output_num) 42 | 43 | __FUNCS__[func_name] = _ta_long 44 | 45 | # Overlap Studies 46 | BBANDS = __FUNCS__['BBANDS'] 47 | DEMA = __FUNCS__['DEMA'] 48 | EMA = __FUNCS__['EMA'] 49 | HT_TRENDLINE = __FUNCS__['HT_TRENDLINE'] 50 | KAMA = __FUNCS__['KAMA'] 51 | MA = __FUNCS__['MA'] 52 | MAMA = __FUNCS__['MAMA'] 53 | MAVP = __FUNCS__['MAVP'] 54 | MIDPOINT = __FUNCS__['MIDPOINT'] 55 | MIDPRICE = __FUNCS__['MIDPRICE'] 56 | SAR = __FUNCS__['SAR'] 57 | SAREXT = __FUNCS__['SAREXT'] 58 | SMA = __FUNCS__['SMA'] 59 | T3 = __FUNCS__['T3'] 60 | TEMA = __FUNCS__['TEMA'] 61 | TRIMA = __FUNCS__['TRIMA'] 62 | WMA = __FUNCS__['WMA'] 63 | 64 | # Momentum Indicators 65 | ADX = __FUNCS__['ADX'] 66 | ADXR = __FUNCS__['ADXR'] 67 | APO = __FUNCS__['APO'] 68 | AROON = __FUNCS__['AROON'] 69 | AROONOSC = __FUNCS__['AROONOSC'] 70 | BOP = __FUNCS__['BOP'] 71 | CCI = __FUNCS__['CCI'] 72 | CMO = __FUNCS__['CMO'] 73 | DX = __FUNCS__['DX'] 74 | MACD = __FUNCS__['MACD'] 75 | MACDEXT = __FUNCS__['MACDEXT'] 76 | MACDFIX = __FUNCS__['MACDFIX'] 77 | MFI = __FUNCS__['MFI'] 78 | MINUS_DI = __FUNCS__['MINUS_DI'] 79 | MINUS_DM = __FUNCS__['MINUS_DM'] 80 | MOM = __FUNCS__['MOM'] 81 | PLUS_DI = __FUNCS__['PLUS_DI'] 82 | PLUS_DM = __FUNCS__['PLUS_DM'] 83 | PPO = __FUNCS__['PPO'] 84 | ROC = __FUNCS__['ROC'] 85 | ROCP = __FUNCS__['ROCP'] 86 | ROCR = __FUNCS__['ROCR'] 87 | ROCR100 = __FUNCS__['ROCR100'] 88 | RSI = __FUNCS__['RSI'] 89 | STOCH = __FUNCS__['STOCH'] 90 | STOCHF = __FUNCS__['STOCHF'] 91 | STOCHRSI = __FUNCS__['STOCHRSI'] 92 | TRIX = __FUNCS__['TRIX'] 93 | ULTOSC = __FUNCS__['ULTOSC'] 94 | WILLR = __FUNCS__['WILLR'] 95 | 96 | # Volume Indicators 97 | AD = __FUNCS__['AD'] 98 | ADOSC = __FUNCS__['ADOSC'] 99 | OBV = __FUNCS__['OBV'] 100 | 101 | # Volatility Indicators 102 | ATR = __FUNCS__['ATR'] 103 | NATR = __FUNCS__['NATR'] 104 | TRANGE = __FUNCS__['TRANGE'] 105 | 106 | # Price Transform 107 | AVGPRICE = __FUNCS__['AVGPRICE'] 108 | MEDPRICE = __FUNCS__['MEDPRICE'] 109 | TYPPRICE = __FUNCS__['TYPPRICE'] 110 | WCLPRICE = __FUNCS__['WCLPRICE'] 111 | 112 | # Cycle Indicators 113 | HT_DCPERIOD = __FUNCS__['HT_DCPERIOD'] 114 | HT_DCPHASE = __FUNCS__['HT_DCPHASE'] 115 | HT_PHASOR = __FUNCS__['HT_PHASOR'] 116 | HT_SINE = __FUNCS__['HT_SINE'] 117 | HT_TRENDMODE = __FUNCS__['HT_TRENDMODE'] 118 | 119 | # Pattern Recognition 120 | CDL2CROWS = __FUNCS__['CDL2CROWS'] 121 | CDL3BLACKCROWS = __FUNCS__['CDL3BLACKCROWS'] 122 | CDL3INSIDE = __FUNCS__['CDL3INSIDE'] 123 | CDL3LINESTRIKE 
= __FUNCS__['CDL3LINESTRIKE'] 124 | CDL3OUTSIDE = __FUNCS__['CDL3OUTSIDE'] 125 | CDL3STARSINSOUTH = __FUNCS__['CDL3STARSINSOUTH'] 126 | CDL3WHITESOLDIERS = __FUNCS__['CDL3WHITESOLDIERS'] 127 | CDLABANDONEDBABY = __FUNCS__['CDLABANDONEDBABY'] 128 | CDLADVANCEBLOCK = __FUNCS__['CDLADVANCEBLOCK'] 129 | CDLBELTHOLD = __FUNCS__['CDLBELTHOLD'] 130 | CDLBREAKAWAY = __FUNCS__['CDLBREAKAWAY'] 131 | CDLCLOSINGMARUBOZU = __FUNCS__['CDLCLOSINGMARUBOZU'] 132 | CDLCONCEALBABYSWALL = __FUNCS__['CDLCONCEALBABYSWALL'] 133 | CDLCOUNTERATTACK = __FUNCS__['CDLCOUNTERATTACK'] 134 | CDLDARKCLOUDCOVER = __FUNCS__['CDLDARKCLOUDCOVER'] 135 | CDLDOJI = __FUNCS__['CDLDOJI'] 136 | CDLDOJISTAR = __FUNCS__['CDLDOJISTAR'] 137 | CDLDRAGONFLYDOJI = __FUNCS__['CDLDRAGONFLYDOJI'] 138 | CDLENGULFING = __FUNCS__['CDLENGULFING'] 139 | CDLEVENINGDOJISTAR = __FUNCS__['CDLEVENINGDOJISTAR'] 140 | CDLEVENINGSTAR = __FUNCS__['CDLEVENINGSTAR'] 141 | CDLGAPSIDESIDEWHITE = __FUNCS__['CDLGAPSIDESIDEWHITE'] 142 | CDLGRAVESTONEDOJI = __FUNCS__['CDLGRAVESTONEDOJI'] 143 | CDLHAMMER = __FUNCS__['CDLHAMMER'] 144 | CDLHANGINGMAN = __FUNCS__['CDLHANGINGMAN'] 145 | CDLHARAMI = __FUNCS__['CDLHARAMI'] 146 | CDLHARAMICROSS = __FUNCS__['CDLHARAMICROSS'] 147 | CDLHIGHWAVE = __FUNCS__['CDLHIGHWAVE'] 148 | CDLHIKKAKE = __FUNCS__['CDLHIKKAKE'] 149 | CDLHIKKAKEMOD = __FUNCS__['CDLHIKKAKEMOD'] 150 | CDLHOMINGPIGEON = __FUNCS__['CDLHOMINGPIGEON'] 151 | CDLIDENTICAL3CROWS = __FUNCS__['CDLIDENTICAL3CROWS'] 152 | CDLINNECK = __FUNCS__['CDLINNECK'] 153 | CDLINVERTEDHAMMER = __FUNCS__['CDLINVERTEDHAMMER'] 154 | CDLKICKING = __FUNCS__['CDLKICKING'] 155 | CDLKICKINGBYLENGTH = __FUNCS__['CDLKICKINGBYLENGTH'] 156 | CDLLADDERBOTTOM = __FUNCS__['CDLLADDERBOTTOM'] 157 | CDLLONGLEGGEDDOJI = __FUNCS__['CDLLONGLEGGEDDOJI'] 158 | CDLLONGLINE = __FUNCS__['CDLLONGLINE'] 159 | CDLMARUBOZU = __FUNCS__['CDLMARUBOZU'] 160 | CDLMATCHINGLOW = __FUNCS__['CDLMATCHINGLOW'] 161 | CDLMATHOLD = __FUNCS__['CDLMATHOLD'] 162 | CDLMORNINGDOJISTAR = __FUNCS__['CDLMORNINGDOJISTAR'] 163 | CDLMORNINGSTAR = __FUNCS__['CDLMORNINGSTAR'] 164 | CDLONNECK = __FUNCS__['CDLONNECK'] 165 | CDLPIERCING = __FUNCS__['CDLPIERCING'] 166 | CDLRICKSHAWMAN = __FUNCS__['CDLRICKSHAWMAN'] 167 | CDLRISEFALL3METHODS = __FUNCS__['CDLRISEFALL3METHODS'] 168 | CDLSEPARATINGLINES = __FUNCS__['CDLSEPARATINGLINES'] 169 | CDLSHOOTINGSTAR = __FUNCS__['CDLSHOOTINGSTAR'] 170 | CDLSHORTLINE = __FUNCS__['CDLSHORTLINE'] 171 | CDLSPINNINGTOP = __FUNCS__['CDLSPINNINGTOP'] 172 | CDLSTALLEDPATTERN = __FUNCS__['CDLSTALLEDPATTERN'] 173 | CDLSTICKSANDWICH = __FUNCS__['CDLSTICKSANDWICH'] 174 | CDLTAKURI = __FUNCS__['CDLTAKURI'] 175 | CDLTASUKIGAP = __FUNCS__['CDLTASUKIGAP'] 176 | CDLTHRUSTING = __FUNCS__['CDLTHRUSTING'] 177 | CDLTRISTAR = __FUNCS__['CDLTRISTAR'] 178 | CDLUNIQUE3RIVER = __FUNCS__['CDLUNIQUE3RIVER'] 179 | CDLUPSIDEGAP2CROWS = __FUNCS__['CDLUPSIDEGAP2CROWS'] 180 | CDLXSIDEGAP3METHODS = __FUNCS__['CDLXSIDEGAP3METHODS'] 181 | 182 | # Statistic Functions 183 | BETA = __FUNCS__['BETA'] 184 | CORREL = __FUNCS__['CORREL'] 185 | LINEARREG = __FUNCS__['LINEARREG'] 186 | LINEARREG_ANGLE = __FUNCS__['LINEARREG_ANGLE'] 187 | LINEARREG_INTERCEPT = __FUNCS__['LINEARREG_INTERCEPT'] 188 | LINEARREG_SLOPE = __FUNCS__['LINEARREG_SLOPE'] 189 | STDDEV = __FUNCS__['STDDEV'] 190 | TSF = __FUNCS__['TSF'] 191 | VAR = __FUNCS__['VAR'] 192 | 193 | # Math Transform 194 | ACOS = __FUNCS__['ACOS'] 195 | ASIN = __FUNCS__['ASIN'] 196 | ATAN = __FUNCS__['ATAN'] 197 | CEIL = __FUNCS__['CEIL'] 198 | COS = __FUNCS__['COS'] 199 | COSH = __FUNCS__['COSH'] 200 
| EXP = __FUNCS__['EXP'] 201 | FLOOR = __FUNCS__['FLOOR'] 202 | LN = __FUNCS__['LN'] 203 | LOG10 = __FUNCS__['LOG10'] 204 | SIN = __FUNCS__['SIN'] 205 | SINH = __FUNCS__['SINH'] 206 | SQRT = __FUNCS__['SQRT'] 207 | TAN = __FUNCS__['TAN'] 208 | TANH = __FUNCS__['TANH'] 209 | 210 | # Math Operator 211 | ADD = __FUNCS__['ADD'] 212 | DIV = __FUNCS__['DIV'] 213 | MAX = __FUNCS__['MAX'] 214 | MAXINDEX = __FUNCS__['MAXINDEX'] 215 | MIN = __FUNCS__['MIN'] 216 | MININDEX = __FUNCS__['MININDEX'] 217 | MINMAX = __FUNCS__['MINMAX'] 218 | MINMAXINDEX = __FUNCS__['MINMAXINDEX'] 219 | MULT = __FUNCS__['MULT'] 220 | SUB = __FUNCS__['SUB'] 221 | SUM = __FUNCS__['SUM'] 222 | -------------------------------------------------------------------------------- /ta_cn/imports/long_tdx.py: -------------------------------------------------------------------------------- 1 | """ 2 | 通达信公式,长表模式,跳过空值 3 | """ 4 | from .. import BY_ASSET 5 | from ..ema import SMA_CN 6 | from ..tdx.logical import BETWEEN 7 | from ..tdx.logical import CROSS 8 | from ..tdx.logical import EVERY 9 | from ..tdx.logical import EXIST 10 | from ..tdx.logical import LAST 11 | from ..tdx.logical import VALUEWHEN 12 | from ..tdx.over_bought_over_sold import ATR_CN 13 | from ..tdx.over_bought_over_sold import BIAS 14 | from ..tdx.over_bought_over_sold import KDJ 15 | from ..tdx.over_bought_over_sold import MEDPRICE 16 | from ..tdx.over_bought_over_sold import ROC 17 | from ..tdx.over_bought_over_sold import RSI 18 | from ..tdx.over_bought_over_sold import TYPPRICE 19 | from ..tdx.over_bought_over_sold import WR 20 | from ..tdx.pressure_support import BOLL 21 | from ..tdx.reference import BARSLAST 22 | from ..tdx.reference import BARSLASTCOUNT 23 | from ..tdx.reference import BARSSINCEN 24 | from ..tdx.reference import CONST 25 | from ..tdx.reference import FILTER 26 | from ..tdx.reference import SUMIF 27 | from ..tdx.reference import TR 28 | from ..tdx.statistics import AVEDEV 29 | from ..tdx.statistics import STD 30 | from ..tdx.statistics import STDP 31 | from ..tdx.statistics import VAR 32 | from ..tdx.statistics import VARP 33 | from ..tdx.trend import BBI 34 | from ..tdx.trend import DI 35 | from ..tdx.trend import DM 36 | from ..tdx.trend import DMI 37 | from ..tdx.trend import DPO 38 | from ..tdx.trend import MACD 39 | from ..tdx.trend import MTM 40 | from ..tdx.trend import PSY 41 | from ..tdx.trend import TRIX 42 | from ..tdx.volume import OBV 43 | from ..tdx.volume import VR 44 | from ..utils_long import dataframe_groupby_apply, series_groupby_apply 45 | 46 | # 逻辑函数 47 | CROSS = dataframe_groupby_apply(CROSS, by=BY_ASSET, to_kwargs={}) 48 | EVERY = series_groupby_apply(EVERY, by=BY_ASSET) 49 | EXIST = series_groupby_apply(EXIST, by=BY_ASSET) 50 | BETWEEN = dataframe_groupby_apply(BETWEEN, by=BY_ASSET, to_df=[0, 1, 2], to_kwargs={}) 51 | VALUEWHEN = dataframe_groupby_apply(VALUEWHEN, by=BY_ASSET, to_df=[0, 1], to_kwargs={}) 52 | LAST = series_groupby_apply(LAST, by=BY_ASSET, to_kwargs={1: 'n', 2: 'm'}) 53 | 54 | # 超买超卖 55 | ATR_CN = dataframe_groupby_apply(ATR_CN, by=BY_ASSET, to_df=[0, 1, 2], to_kwargs={3: 'timeperiod'}) 56 | BIAS = series_groupby_apply(BIAS, by=BY_ASSET) 57 | KDJ = dataframe_groupby_apply(KDJ, by=BY_ASSET, to_df=[0, 1, 2], 58 | to_kwargs={3: 'fastk_period', 4: 'M1', 5: 'M2'}, output_num=3) 59 | ROC = series_groupby_apply(ROC, by=BY_ASSET) 60 | TYPPRICE = dataframe_groupby_apply(TYPPRICE, by=BY_ASSET, to_df=[0, 1, 2], to_kwargs={}) 61 | MEDPRICE = dataframe_groupby_apply(MEDPRICE, by=BY_ASSET, to_df=[0, 1], to_kwargs={}) 62 
| WR = dataframe_groupby_apply(WR, by=BY_ASSET, to_df=[0, 1, 2], to_kwargs={3: 'timeperiod'}) 63 | RSI = series_groupby_apply(RSI, by=BY_ASSET) 64 | 65 | # 压力支撑 66 | BOLL = series_groupby_apply(BOLL, by=BY_ASSET, to_kwargs={1: 'timeperiod', 2: 'nbdevup', 3: 'nbdevdn'}, output_num=3) 67 | 68 | # 引用 69 | CONST = series_groupby_apply(CONST, by=BY_ASSET, to_kwargs={}) 70 | SUMIF = dataframe_groupby_apply(SUMIF, by=BY_ASSET, to_df=[0, 1], to_kwargs={2: 'timeperiod'}) 71 | TR = dataframe_groupby_apply(TR, by=BY_ASSET, to_df=[0, 1, 2], to_kwargs={}) 72 | FILTER = series_groupby_apply(FILTER, by=BY_ASSET, to_kwargs={1: 'N'}) 73 | BARSLAST = series_groupby_apply(BARSLAST, by=BY_ASSET, to_kwargs={}) 74 | BARSLASTCOUNT = series_groupby_apply(BARSLASTCOUNT, by=BY_ASSET, to_kwargs={}) 75 | BARSSINCEN = series_groupby_apply(BARSSINCEN, by=BY_ASSET) 76 | 77 | # 统计 78 | AVEDEV = series_groupby_apply(AVEDEV, by=BY_ASSET) 79 | STD = series_groupby_apply(STD, by=BY_ASSET, to_kwargs={1: 'd'}) 80 | STDP = series_groupby_apply(STDP, by=BY_ASSET, to_kwargs={1: 'd'}) 81 | VAR = series_groupby_apply(VAR, by=BY_ASSET, to_kwargs={1: 'd'}) 82 | VARP = series_groupby_apply(VARP, by=BY_ASSET, to_kwargs={1: 'd'}) 83 | 84 | # 趋势 85 | BBI = series_groupby_apply(BBI, by=BY_ASSET, 86 | to_kwargs={1: 'timeperiod1', 2: 'timeperiod2', 3: 'timeperiod3', 4: 'timeperiod4'}) 87 | DPO = series_groupby_apply(DPO, by=BY_ASSET) 88 | MACD = series_groupby_apply(MACD, by=BY_ASSET, to_kwargs={1: 'fastperiod', 2: 'slowperiod', 3: 'signalperiod'}, 89 | output_num=3) 90 | MTM = series_groupby_apply(MTM, by=BY_ASSET) 91 | PSY = series_groupby_apply(PSY, by=BY_ASSET) 92 | DM = dataframe_groupby_apply(DM, by=BY_ASSET, to_df=[0, 1], to_kwargs={2: 'timeperiod'}, output_num=2) 93 | DI = dataframe_groupby_apply(DI, by=BY_ASSET, to_df=[0, 1, 2], to_kwargs={3: 'timeperiod'}, output_num=2) 94 | DMI = dataframe_groupby_apply(DMI, by=BY_ASSET, to_df=[0, 1, 2], to_kwargs={3: 'timeperiod'}, output_num=2) 95 | TRIX = series_groupby_apply(TRIX, by=BY_ASSET) 96 | 97 | # 成交量 98 | OBV = dataframe_groupby_apply(OBV, by=BY_ASSET, to_df=[0, 1], to_kwargs={2: 'scale'}) 99 | VR = dataframe_groupby_apply(VR, by=BY_ASSET, to_df=[0, 1], to_kwargs={2: 'timeperiod'}) 100 | 101 | # EMA系列 102 | SMA_CN = series_groupby_apply(SMA_CN, by=BY_ASSET, to_kwargs={1: 'timeperiod', 2: 'M'}) 103 | 104 | # WQ已经定义过的公式,通达信中别名 105 | from .long_wq import abs_ as ABS 106 | from .long_wq import add as ADD 107 | from .long_wq import divide as DIV 108 | from .long_wq import log as LN # 自然对数 109 | from .long_wq import log10 as LOG # 10为底的对数 110 | from .long_wq import max_ as MAX 111 | from .long_wq import mean as MEAN 112 | from .long_wq import min_ as MIN 113 | from .long_wq import multiply as MUL 114 | from .long_wq import round_ as ROUND 115 | from .long_wq import sign as SGN 116 | from .long_wq import subtract as SUB 117 | from .long_wq import if_else as IF 118 | from .long_wq import ts_count as COUNT 119 | from .long_wq import ts_delay as REF 120 | from .long_wq import ts_delta as DIFF 121 | from .long_wq import ts_max as HHV 122 | from .long_wq import ts_arg_max as HHVBARS 123 | from .long_wq import ts_mean as MA 124 | from .long_wq import ts_min as LLV 125 | from .long_wq import ts_arg_min as LLVBARS 126 | from .long_wq import ts_sum as SUM 127 | from .long_wq import rank as RANK 128 | from .long_wq import power as POW 129 | 130 | ABS 131 | MAX 132 | MIN 133 | REF 134 | HHV 135 | HHVBARS 136 | MA 137 | LLV 138 | LLVBARS 139 | SUM 140 | ADD 141 | SUB 142 | MUL 143 | DIV 144 | ROUND 145 
| MEAN 146 | LN 147 | LOG 148 | SGN 149 | DIFF 150 | IF 151 | COUNT 152 | RANK 153 | POW 154 | -------------------------------------------------------------------------------- /ta_cn/imports/long_wq.py: -------------------------------------------------------------------------------- 1 | """ 2 | WQ公式,长表模式,跳过空值 3 | """ 4 | from .. import BY_ASSET, BY_DATE 5 | from ..utils import to_pd 6 | from ..utils_long import series_groupby_apply, dataframe_groupby_apply 7 | from ..wq.arithmetic import abs_ 8 | from ..wq.arithmetic import add 9 | from ..wq.arithmetic import ceiling 10 | from ..wq.arithmetic import densify 11 | from ..wq.arithmetic import divide 12 | from ..wq.arithmetic import exp 13 | from ..wq.arithmetic import floor 14 | from ..wq.arithmetic import fraction 15 | from ..wq.arithmetic import inverse 16 | from ..wq.arithmetic import log 17 | from ..wq.arithmetic import log10 18 | from ..wq.arithmetic import log_diff 19 | from ..wq.arithmetic import max_ 20 | from ..wq.arithmetic import mean 21 | from ..wq.arithmetic import min_ 22 | from ..wq.arithmetic import multiply 23 | from ..wq.arithmetic import nan_mask 24 | from ..wq.arithmetic import nan_out 25 | from ..wq.arithmetic import power 26 | from ..wq.arithmetic import purify 27 | from ..wq.arithmetic import replace 28 | from ..wq.arithmetic import reverse 29 | from ..wq.arithmetic import round_ 30 | from ..wq.arithmetic import round_down 31 | from ..wq.arithmetic import s_log_1p 32 | from ..wq.arithmetic import sign 33 | from ..wq.arithmetic import signed_power 34 | from ..wq.arithmetic import sqrt 35 | from ..wq.arithmetic import subtract 36 | from ..wq.arithmetic import to_nan 37 | from ..wq.cross_sectional import rank 38 | from ..wq.cross_sectional import scale 39 | from ..wq.group import group_neutralize 40 | from ..wq.logical import if_else 41 | from ..wq.logical import less 42 | from ..wq.time_series import ts_arg_max 43 | from ..wq.time_series import ts_arg_min 44 | from ..wq.time_series import ts_corr 45 | from ..wq.time_series import ts_count 46 | from ..wq.time_series import ts_covariance 47 | from ..wq.time_series import ts_decay_linear 48 | from ..wq.time_series import ts_delay 49 | from ..wq.time_series import ts_delta 50 | from ..wq.time_series import ts_max 51 | from ..wq.time_series import ts_mean 52 | from ..wq.time_series import ts_min 53 | from ..wq.time_series import ts_product 54 | from ..wq.time_series import ts_rank 55 | from ..wq.time_series import ts_std_dev 56 | from ..wq.time_series import ts_sum 57 | 58 | # Arithmetic Operators 59 | abs_ = to_pd(abs_) 60 | add = to_pd(add) 61 | ceiling = to_pd(ceiling) 62 | divide = to_pd(divide) 63 | exp = to_pd(exp) 64 | floor = to_pd(floor) 65 | fraction = to_pd(fraction) 66 | inverse = to_pd(inverse) 67 | log = to_pd(log) 68 | log_diff = to_pd(log_diff) 69 | max_ = to_pd(max_) 70 | min_ = to_pd(min_) 71 | multiply = to_pd(multiply) 72 | nan_mask = to_pd(nan_mask) 73 | nan_out = to_pd(nan_out) 74 | power = to_pd(power) 75 | purify = to_pd(purify) 76 | replace = to_pd(replace) 77 | reverse = to_pd(reverse) 78 | round_ = to_pd(round_) 79 | round_down = to_pd(round_down) 80 | sign = to_pd(sign) 81 | signed_power = to_pd(signed_power) 82 | s_log_1p = to_pd(s_log_1p) 83 | sqrt = to_pd(sqrt) 84 | subtract = to_pd(subtract) 85 | to_nan = to_pd(to_nan) 86 | densify = to_pd(densify) 87 | log10 = to_pd(log10) 88 | mean = to_pd(mean) 89 | 90 | # Vector Operators 91 | # Logical Operators 92 | less = to_pd(less) 93 | if_else = to_pd(if_else) 94 | 95 | # Transformational Operators 96 
| 97 | # Cross Sectional Operators 98 | rank = series_groupby_apply(rank, by=BY_DATE, to_kwargs={}) 99 | scale = series_groupby_apply(scale, by=BY_DATE, to_kwargs={1: 'scale'}, dropna=False) 100 | 101 | # Group Operators 102 | group_neutralize = group_neutralize 103 | 104 | # Time Series Operators 105 | ts_arg_max = series_groupby_apply(ts_arg_max, by=BY_ASSET, to_kwargs={1: 'd'}) 106 | ts_arg_min = series_groupby_apply(ts_arg_min, by=BY_ASSET, to_kwargs={1: 'd'}) 107 | ts_corr = dataframe_groupby_apply(ts_corr, by=BY_ASSET, to_kwargs={2: 'd'}) 108 | ts_count = series_groupby_apply(ts_count, by=BY_ASSET, to_kwargs={1: 'd'}) 109 | ts_covariance = dataframe_groupby_apply(ts_covariance, by=BY_ASSET, to_kwargs={2: 'd'}) 110 | ts_decay_linear = series_groupby_apply(ts_decay_linear, by=BY_ASSET, to_kwargs={1: 'd'}) 111 | ts_delay = series_groupby_apply(ts_delay, by=BY_ASSET, to_kwargs={1: 'd'}) 112 | ts_delta = series_groupby_apply(ts_delta, by=BY_ASSET, to_kwargs={1: 'd'}) 113 | ts_max = series_groupby_apply(ts_max, by=BY_ASSET, to_kwargs={1: 'd'}) 114 | ts_mean = series_groupby_apply(ts_mean, by=BY_ASSET, to_kwargs={1: 'd'}) 115 | ts_min = series_groupby_apply(ts_min, by=BY_ASSET, to_kwargs={1: 'd'}) 116 | ts_product = series_groupby_apply(ts_product, by=BY_ASSET, to_kwargs={1: 'd'}, dropna=False) 117 | ts_rank = series_groupby_apply(ts_rank, by=BY_ASSET, to_kwargs={1: 'd'}) 118 | ts_std_dev = series_groupby_apply(ts_std_dev, by=BY_ASSET, to_kwargs={1: 'd'}) 119 | ts_sum = series_groupby_apply(ts_sum, by=BY_ASSET, to_kwargs={1: 'd'}) 120 | 121 | # Special Operators 122 | -------------------------------------------------------------------------------- /ta_cn/imports/wide.py: -------------------------------------------------------------------------------- 1 | """ 2 | 对指标的算子化包装 3 | 1. 包装成只支持 宽表 输入,输出是特殊格式,需要处理得到输出 4 | 2. 简化参数输入,命名参数也可当成位置参数输入 5 | 3. 通过堆叠的方法,自动跳过停牌 6 | 7 | !!!函数太多,又想要智能提示,只能手工按需补充 8 | """ 9 | from ..alphas.alpha import CUMPROD 10 | from ..alphas.alpha import FILTER_191 11 | from ..regress import REGRESI 12 | from ..regress import SLOPE_YX 13 | from ..utils_wide import wide_wraps 14 | 15 | # 16 | CUMPROD = wide_wraps(CUMPROD, to_kwargs={}) 17 | FILTER_191 = wide_wraps(FILTER_191, input_num=2, to_kwargs={}) 18 | 19 | 20 | 21 | SLOPE_YX = wide_wraps(SLOPE_YX, input_num=2, to_kwargs={2: 'timeperiod'}) 22 | REGRESI4 = wide_wraps(REGRESI, input_num=4, to_kwargs={4: 'timeperiod'}) 23 | 24 | # 长表转宽表 25 | 26 | -------------------------------------------------------------------------------- /ta_cn/imports/wide_ta.py: -------------------------------------------------------------------------------- 1 | """ 2 | TALIB库,宽表模式,跳过空值 3 | """ 4 | import talib as _talib 5 | from talib import abstract as _abstract 6 | 7 | from .. 
import talib as ta 8 | from ..utils_wide import wide_wraps 9 | 10 | # 二维TALIB 11 | _ta2d = ta.init(mode=2, skipna=False, to_globals=False) 12 | 13 | __FUNCS__ = {} 14 | for i, func_name in enumerate(_talib.get_functions()): 15 | """talib遍历""" 16 | 17 | 18 | def _input_names_len(kvs): 19 | cnt = 0 20 | for k, v in kvs.items(): 21 | if isinstance(v, list): 22 | cnt += len(v) 23 | else: 24 | cnt += 1 25 | return cnt 26 | 27 | 28 | # 导入二维封装版talib 29 | _ta_func = getattr(_ta2d, func_name) 30 | info = _abstract.Function(func_name).info 31 | 32 | input_num = _input_names_len(info['input_names']) 33 | para_num = len(info['parameters']) 34 | output_num = len(info['output_names']) 35 | 36 | to_kwargs = {i + input_num: p for i, p in enumerate(info['parameters'])} 37 | direction = None if info['group'] in ('Math Transform',) else 'down' 38 | 39 | _ta_wide = wide_wraps(_ta_func, direction=direction, 40 | input_num=input_num, to_kwargs=to_kwargs, output_num=output_num) 41 | 42 | __FUNCS__[func_name] = _ta_wide 43 | 44 | # Overlap Studies 45 | BBANDS = __FUNCS__['BBANDS'] 46 | DEMA = __FUNCS__['DEMA'] 47 | EMA = __FUNCS__['EMA'] 48 | HT_TRENDLINE = __FUNCS__['HT_TRENDLINE'] 49 | KAMA = __FUNCS__['KAMA'] 50 | MA = __FUNCS__['MA'] 51 | MAMA = __FUNCS__['MAMA'] 52 | MAVP = __FUNCS__['MAVP'] 53 | MIDPOINT = __FUNCS__['MIDPOINT'] 54 | MIDPRICE = __FUNCS__['MIDPRICE'] 55 | SAR = __FUNCS__['SAR'] 56 | SAREXT = __FUNCS__['SAREXT'] 57 | SMA = __FUNCS__['SMA'] 58 | T3 = __FUNCS__['T3'] 59 | TEMA = __FUNCS__['TEMA'] 60 | TRIMA = __FUNCS__['TRIMA'] 61 | WMA = __FUNCS__['WMA'] 62 | 63 | # Momentum Indicators 64 | ADX = __FUNCS__['ADX'] 65 | ADXR = __FUNCS__['ADXR'] 66 | APO = __FUNCS__['APO'] 67 | AROON = __FUNCS__['AROON'] 68 | AROONOSC = __FUNCS__['AROONOSC'] 69 | BOP = __FUNCS__['BOP'] 70 | CCI = __FUNCS__['CCI'] 71 | CMO = __FUNCS__['CMO'] 72 | DX = __FUNCS__['DX'] 73 | MACD = __FUNCS__['MACD'] 74 | MACDEXT = __FUNCS__['MACDEXT'] 75 | MACDFIX = __FUNCS__['MACDFIX'] 76 | MFI = __FUNCS__['MFI'] 77 | MINUS_DI = __FUNCS__['MINUS_DI'] 78 | MINUS_DM = __FUNCS__['MINUS_DM'] 79 | MOM = __FUNCS__['MOM'] 80 | PLUS_DI = __FUNCS__['PLUS_DI'] 81 | PLUS_DM = __FUNCS__['PLUS_DM'] 82 | PPO = __FUNCS__['PPO'] 83 | ROC = __FUNCS__['ROC'] 84 | ROCP = __FUNCS__['ROCP'] 85 | ROCR = __FUNCS__['ROCR'] 86 | ROCR100 = __FUNCS__['ROCR100'] 87 | RSI = __FUNCS__['RSI'] 88 | STOCH = __FUNCS__['STOCH'] 89 | STOCHF = __FUNCS__['STOCHF'] 90 | STOCHRSI = __FUNCS__['STOCHRSI'] 91 | TRIX = __FUNCS__['TRIX'] 92 | ULTOSC = __FUNCS__['ULTOSC'] 93 | WILLR = __FUNCS__['WILLR'] 94 | 95 | # Volume Indicators 96 | AD = __FUNCS__['AD'] 97 | ADOSC = __FUNCS__['ADOSC'] 98 | OBV = __FUNCS__['OBV'] 99 | 100 | # Volatility Indicators 101 | ATR = __FUNCS__['ATR'] 102 | NATR = __FUNCS__['NATR'] 103 | TRANGE = __FUNCS__['TRANGE'] 104 | 105 | # Price Transform 106 | AVGPRICE = __FUNCS__['AVGPRICE'] 107 | MEDPRICE = __FUNCS__['MEDPRICE'] 108 | TYPPRICE = __FUNCS__['TYPPRICE'] 109 | WCLPRICE = __FUNCS__['WCLPRICE'] 110 | 111 | # Cycle Indicators 112 | HT_DCPERIOD = __FUNCS__['HT_DCPERIOD'] 113 | HT_DCPHASE = __FUNCS__['HT_DCPHASE'] 114 | HT_PHASOR = __FUNCS__['HT_PHASOR'] 115 | HT_SINE = __FUNCS__['HT_SINE'] 116 | HT_TRENDMODE = __FUNCS__['HT_TRENDMODE'] 117 | 118 | # Pattern Recognition 119 | CDL2CROWS = __FUNCS__['CDL2CROWS'] 120 | CDL3BLACKCROWS = __FUNCS__['CDL3BLACKCROWS'] 121 | CDL3INSIDE = __FUNCS__['CDL3INSIDE'] 122 | CDL3LINESTRIKE = __FUNCS__['CDL3LINESTRIKE'] 123 | CDL3OUTSIDE = __FUNCS__['CDL3OUTSIDE'] 124 | CDL3STARSINSOUTH = 
__FUNCS__['CDL3STARSINSOUTH'] 125 | CDL3WHITESOLDIERS = __FUNCS__['CDL3WHITESOLDIERS'] 126 | CDLABANDONEDBABY = __FUNCS__['CDLABANDONEDBABY'] 127 | CDLADVANCEBLOCK = __FUNCS__['CDLADVANCEBLOCK'] 128 | CDLBELTHOLD = __FUNCS__['CDLBELTHOLD'] 129 | CDLBREAKAWAY = __FUNCS__['CDLBREAKAWAY'] 130 | CDLCLOSINGMARUBOZU = __FUNCS__['CDLCLOSINGMARUBOZU'] 131 | CDLCONCEALBABYSWALL = __FUNCS__['CDLCONCEALBABYSWALL'] 132 | CDLCOUNTERATTACK = __FUNCS__['CDLCOUNTERATTACK'] 133 | CDLDARKCLOUDCOVER = __FUNCS__['CDLDARKCLOUDCOVER'] 134 | CDLDOJI = __FUNCS__['CDLDOJI'] 135 | CDLDOJISTAR = __FUNCS__['CDLDOJISTAR'] 136 | CDLDRAGONFLYDOJI = __FUNCS__['CDLDRAGONFLYDOJI'] 137 | CDLENGULFING = __FUNCS__['CDLENGULFING'] 138 | CDLEVENINGDOJISTAR = __FUNCS__['CDLEVENINGDOJISTAR'] 139 | CDLEVENINGSTAR = __FUNCS__['CDLEVENINGSTAR'] 140 | CDLGAPSIDESIDEWHITE = __FUNCS__['CDLGAPSIDESIDEWHITE'] 141 | CDLGRAVESTONEDOJI = __FUNCS__['CDLGRAVESTONEDOJI'] 142 | CDLHAMMER = __FUNCS__['CDLHAMMER'] 143 | CDLHANGINGMAN = __FUNCS__['CDLHANGINGMAN'] 144 | CDLHARAMI = __FUNCS__['CDLHARAMI'] 145 | CDLHARAMICROSS = __FUNCS__['CDLHARAMICROSS'] 146 | CDLHIGHWAVE = __FUNCS__['CDLHIGHWAVE'] 147 | CDLHIKKAKE = __FUNCS__['CDLHIKKAKE'] 148 | CDLHIKKAKEMOD = __FUNCS__['CDLHIKKAKEMOD'] 149 | CDLHOMINGPIGEON = __FUNCS__['CDLHOMINGPIGEON'] 150 | CDLIDENTICAL3CROWS = __FUNCS__['CDLIDENTICAL3CROWS'] 151 | CDLINNECK = __FUNCS__['CDLINNECK'] 152 | CDLINVERTEDHAMMER = __FUNCS__['CDLINVERTEDHAMMER'] 153 | CDLKICKING = __FUNCS__['CDLKICKING'] 154 | CDLKICKINGBYLENGTH = __FUNCS__['CDLKICKINGBYLENGTH'] 155 | CDLLADDERBOTTOM = __FUNCS__['CDLLADDERBOTTOM'] 156 | CDLLONGLEGGEDDOJI = __FUNCS__['CDLLONGLEGGEDDOJI'] 157 | CDLLONGLINE = __FUNCS__['CDLLONGLINE'] 158 | CDLMARUBOZU = __FUNCS__['CDLMARUBOZU'] 159 | CDLMATCHINGLOW = __FUNCS__['CDLMATCHINGLOW'] 160 | CDLMATHOLD = __FUNCS__['CDLMATHOLD'] 161 | CDLMORNINGDOJISTAR = __FUNCS__['CDLMORNINGDOJISTAR'] 162 | CDLMORNINGSTAR = __FUNCS__['CDLMORNINGSTAR'] 163 | CDLONNECK = __FUNCS__['CDLONNECK'] 164 | CDLPIERCING = __FUNCS__['CDLPIERCING'] 165 | CDLRICKSHAWMAN = __FUNCS__['CDLRICKSHAWMAN'] 166 | CDLRISEFALL3METHODS = __FUNCS__['CDLRISEFALL3METHODS'] 167 | CDLSEPARATINGLINES = __FUNCS__['CDLSEPARATINGLINES'] 168 | CDLSHOOTINGSTAR = __FUNCS__['CDLSHOOTINGSTAR'] 169 | CDLSHORTLINE = __FUNCS__['CDLSHORTLINE'] 170 | CDLSPINNINGTOP = __FUNCS__['CDLSPINNINGTOP'] 171 | CDLSTALLEDPATTERN = __FUNCS__['CDLSTALLEDPATTERN'] 172 | CDLSTICKSANDWICH = __FUNCS__['CDLSTICKSANDWICH'] 173 | CDLTAKURI = __FUNCS__['CDLTAKURI'] 174 | CDLTASUKIGAP = __FUNCS__['CDLTASUKIGAP'] 175 | CDLTHRUSTING = __FUNCS__['CDLTHRUSTING'] 176 | CDLTRISTAR = __FUNCS__['CDLTRISTAR'] 177 | CDLUNIQUE3RIVER = __FUNCS__['CDLUNIQUE3RIVER'] 178 | CDLUPSIDEGAP2CROWS = __FUNCS__['CDLUPSIDEGAP2CROWS'] 179 | CDLXSIDEGAP3METHODS = __FUNCS__['CDLXSIDEGAP3METHODS'] 180 | 181 | # Statistic Functions 182 | BETA = __FUNCS__['BETA'] 183 | CORREL = __FUNCS__['CORREL'] 184 | LINEARREG = __FUNCS__['LINEARREG'] 185 | LINEARREG_ANGLE = __FUNCS__['LINEARREG_ANGLE'] 186 | LINEARREG_INTERCEPT = __FUNCS__['LINEARREG_INTERCEPT'] 187 | LINEARREG_SLOPE = __FUNCS__['LINEARREG_SLOPE'] 188 | STDDEV = __FUNCS__['STDDEV'] 189 | TSF = __FUNCS__['TSF'] 190 | VAR = __FUNCS__['VAR'] 191 | 192 | # Math Transform 193 | ACOS = __FUNCS__['ACOS'] 194 | ASIN = __FUNCS__['ASIN'] 195 | ATAN = __FUNCS__['ATAN'] 196 | CEIL = __FUNCS__['CEIL'] 197 | COS = __FUNCS__['COS'] 198 | COSH = __FUNCS__['COSH'] 199 | EXP = __FUNCS__['EXP'] 200 | FLOOR = __FUNCS__['FLOOR'] 201 | LN = __FUNCS__['LN'] 202 | LOG10 = 
__FUNCS__['LOG10'] 203 | SIN = __FUNCS__['SIN'] 204 | SINH = __FUNCS__['SINH'] 205 | SQRT = __FUNCS__['SQRT'] 206 | TAN = __FUNCS__['TAN'] 207 | TANH = __FUNCS__['TANH'] 208 | 209 | # Math Operator 210 | ADD = __FUNCS__['ADD'] 211 | DIV = __FUNCS__['DIV'] 212 | MAX = __FUNCS__['MAX'] 213 | MAXINDEX = __FUNCS__['MAXINDEX'] 214 | MIN = __FUNCS__['MIN'] 215 | MININDEX = __FUNCS__['MININDEX'] 216 | MINMAX = __FUNCS__['MINMAX'] 217 | MINMAXINDEX = __FUNCS__['MINMAXINDEX'] 218 | MULT = __FUNCS__['MULT'] 219 | SUB = __FUNCS__['SUB'] 220 | SUM = __FUNCS__['SUM'] 221 | -------------------------------------------------------------------------------- /ta_cn/imports/wide_tdx.py: -------------------------------------------------------------------------------- 1 | """ 2 | 通达信公式,宽表模式,跳过空值 3 | """ 4 | from ..ema import SMA_CN 5 | from ..tdx.logical import BETWEEN 6 | from ..tdx.logical import CROSS 7 | from ..tdx.logical import EVERY 8 | from ..tdx.logical import EXIST 9 | from ..tdx.logical import LAST 10 | from ..tdx.logical import VALUEWHEN 11 | from ..tdx.over_bought_over_sold import ATR_CN 12 | from ..tdx.over_bought_over_sold import BIAS 13 | from ..tdx.over_bought_over_sold import KDJ 14 | from ..tdx.over_bought_over_sold import MEDPRICE 15 | from ..tdx.over_bought_over_sold import ROC 16 | from ..tdx.over_bought_over_sold import RSI 17 | from ..tdx.over_bought_over_sold import TYPPRICE 18 | from ..tdx.pressure_support import BOLL 19 | from ..tdx.reference import BARSLAST 20 | from ..tdx.reference import BARSLASTCOUNT 21 | from ..tdx.reference import BARSSINCEN 22 | from ..tdx.reference import CONST 23 | from ..tdx.reference import FILTER 24 | from ..tdx.reference import SUMIF 25 | from ..tdx.reference import TR 26 | from ..tdx.statistics import AVEDEV 27 | from ..tdx.statistics import STD 28 | from ..tdx.statistics import STDP 29 | from ..tdx.statistics import VAR 30 | from ..tdx.statistics import VARP 31 | from ..tdx.trend import BBI 32 | from ..tdx.trend import DI 33 | from ..tdx.trend import DM 34 | from ..tdx.trend import DMI 35 | from ..tdx.trend import DPO 36 | from ..tdx.trend import MACD 37 | from ..tdx.trend import MTM 38 | from ..tdx.trend import PSY 39 | from ..tdx.trend import TRIX 40 | from ..tdx.volume import OBV 41 | from ..tdx.volume import VR 42 | 43 | from ..utils_wide import wide_wraps 44 | 45 | # 逻辑函数 46 | CROSS = wide_wraps(CROSS, input_num=2, to_kwargs={}) 47 | EVERY = wide_wraps(EVERY) 48 | EXIST = wide_wraps(EXIST) 49 | BETWEEN = wide_wraps(BETWEEN, input_num=3, to_kwargs={}) 50 | VALUEWHEN = wide_wraps(VALUEWHEN, input_num=2, to_kwargs={}) 51 | LAST = wide_wraps(LAST, to_kwargs={1: 'n', 2: 'm'}) 52 | 53 | # 超买超卖 54 | ATR_CN = wide_wraps(ATR_CN, input_num=3, to_kwargs={3: 'timeperiod'}) 55 | BIAS = wide_wraps(BIAS) 56 | KDJ = wide_wraps(KDJ, input_num=3, to_kwargs={3: 'fastk_period', 4: 'M1', 5: 'M2'}, output_num=3) 57 | ROC = wide_wraps(ROC) 58 | TYPPRICE = wide_wraps(TYPPRICE, input_num=3, to_kwargs={}) 59 | MEDPRICE = wide_wraps(MEDPRICE, input_num=2, to_kwargs={}) 60 | RSI = wide_wraps(RSI) 61 | 62 | # 压力支撑 63 | BOLL = wide_wraps(BOLL, to_kwargs={1: 'timeperiod', 2: 'nbdevup', 3: 'nbdevdn'}, output_num=3) 64 | 65 | # 引用 66 | CONST = wide_wraps(CONST, to_kwargs={}) 67 | SUMIF = wide_wraps(SUMIF, input_num=2, to_kwargs={2: 'timeperiod'}) 68 | TR = wide_wraps(TR, input_num=3, to_kwargs={}) 69 | FILTER = wide_wraps(FILTER, to_kwargs={1: 'N'}) 70 | BARSLAST = wide_wraps(BARSLAST, to_kwargs={}) 71 | BARSLASTCOUNT = wide_wraps(BARSLASTCOUNT, to_kwargs={}) 72 | BARSSINCEN = 
wide_wraps(BARSSINCEN) 73 | 74 | # 统计 75 | AVEDEV = wide_wraps(AVEDEV) 76 | STD = wide_wraps(STD, to_kwargs={1: 'd'}) 77 | STDP = wide_wraps(STDP, to_kwargs={1: 'd'}) 78 | VAR = wide_wraps(VAR, to_kwargs={1: 'd'}) 79 | VARP = wide_wraps(VARP, to_kwargs={1: 'd'}) 80 | 81 | # 趋势 82 | BBI = wide_wraps(BBI, to_kwargs={1: 'timeperiod1', 2: 'timeperiod2', 3: 'timeperiod3', 4: 'timeperiod4'}) 83 | DPO = wide_wraps(DPO) 84 | MACD = wide_wraps(MACD, to_kwargs={1: 'fastperiod', 2: 'slowperiod', 3: 'signalperiod'}, output_num=3) 85 | MTM = wide_wraps(MTM) 86 | PSY = wide_wraps(PSY) 87 | DM = wide_wraps(DM, input_num=2, to_kwargs={2: 'timeperiod'}, output_num=2) 88 | DI = wide_wraps(DI, input_num=3, to_kwargs={3: 'timeperiod'}, output_num=2) 89 | DMI = wide_wraps(DMI, input_num=3, to_kwargs={3: 'timeperiod'}, output_num=2) 90 | TRIX = wide_wraps(TRIX) 91 | 92 | # 成交量 93 | OBV = wide_wraps(OBV, input_num=2, to_kwargs={2: 'scale'}) 94 | VR = wide_wraps(VR, input_num=2, to_kwargs={2: 'timeperiod'}) 95 | 96 | # EMA系列 97 | SMA_CN = wide_wraps(SMA_CN, to_kwargs={1: 'timeperiod', 2: 'M'}) 98 | 99 | # WQ已经定义过的公式,通达信中别名 100 | from .wide_wq import abs_ as ABS 101 | from .wide_wq import add as ADD 102 | from .wide_wq import divide as DIV 103 | from .wide_wq import log as LN # 自然对数 104 | from .wide_wq import log10 as LOG # 10为底的对数 105 | from .wide_wq import max_ as MAX 106 | from .wide_wq import mean as MEAN 107 | from .wide_wq import min_ as MIN 108 | from .wide_wq import multiply as MUL 109 | from .wide_wq import round_ as ROUND 110 | from .wide_wq import sign as SGN 111 | from .wide_wq import subtract as SUB 112 | from .wide_wq import if_else as IF 113 | from .wide_wq import ts_count as COUNT 114 | from .wide_wq import ts_delay as REF 115 | from .wide_wq import ts_delta as DIFF 116 | from .wide_wq import ts_max as HHV 117 | from .wide_wq import ts_arg_max as HHVBARS 118 | from .wide_wq import ts_mean as MA 119 | from .wide_wq import ts_min as LLV 120 | from .wide_wq import ts_arg_min as LLVBARS 121 | from .wide_wq import ts_sum as SUM 122 | from .wide_wq import rank as RANK 123 | from .wide_wq import power as POW 124 | 125 | ABS 126 | MAX 127 | MIN 128 | REF 129 | HHV 130 | HHVBARS 131 | MA 132 | LLV 133 | LLVBARS 134 | SUM 135 | ADD 136 | SUB 137 | MUL 138 | DIV 139 | ROUND 140 | MEAN 141 | LN 142 | LOG 143 | SGN 144 | DIFF 145 | IF 146 | COUNT 147 | RANK 148 | POW 149 | -------------------------------------------------------------------------------- /ta_cn/imports/wide_wq.py: -------------------------------------------------------------------------------- 1 | """ 2 | WQ公式,宽表模式,跳过空值 3 | """ 4 | from .long_wq import group_neutralize 5 | from ..utils_wide import wide_wraps, long_wraps 6 | from ..wq.arithmetic import abs_ 7 | from ..wq.arithmetic import add 8 | from ..wq.arithmetic import ceiling 9 | from ..wq.arithmetic import densify 10 | from ..wq.arithmetic import divide 11 | from ..wq.arithmetic import exp 12 | from ..wq.arithmetic import floor 13 | from ..wq.arithmetic import fraction 14 | from ..wq.arithmetic import inverse 15 | from ..wq.arithmetic import log 16 | from ..wq.arithmetic import log10 17 | from ..wq.arithmetic import log_diff 18 | from ..wq.arithmetic import max_ 19 | from ..wq.arithmetic import mean 20 | from ..wq.arithmetic import min_ 21 | from ..wq.arithmetic import multiply 22 | from ..wq.arithmetic import nan_mask 23 | from ..wq.arithmetic import nan_out 24 | from ..wq.arithmetic import power 25 | from ..wq.arithmetic import purify 26 | from ..wq.arithmetic import replace 27 | from 
..wq.arithmetic import reverse
28 | from ..wq.arithmetic import round_
29 | from ..wq.arithmetic import round_down
30 | from ..wq.arithmetic import s_log_1p
31 | from ..wq.arithmetic import sign
32 | from ..wq.arithmetic import signed_power
33 | from ..wq.arithmetic import sqrt
34 | from ..wq.arithmetic import subtract
35 | from ..wq.arithmetic import to_nan
36 | from ..wq.cross_sectional import rank
37 | from ..wq.cross_sectional import scale
38 | from ..wq.logical import if_else
39 | from ..wq.logical import less
40 | from ..wq.time_series import ts_arg_max
41 | from ..wq.time_series import ts_arg_min
42 | from ..wq.time_series import ts_corr
43 | from ..wq.time_series import ts_count
44 | from ..wq.time_series import ts_covariance
45 | from ..wq.time_series import ts_decay_linear
46 | from ..wq.time_series import ts_delay
47 | from ..wq.time_series import ts_delta
48 | from ..wq.time_series import ts_max
49 | from ..wq.time_series import ts_mean
50 | from ..wq.time_series import ts_min
51 | from ..wq.time_series import ts_product
52 | from ..wq.time_series import ts_rank
53 | from ..wq.time_series import ts_std_dev
54 | from ..wq.time_series import ts_sum
55 | 
56 | # Arithmetic Operators
57 | abs_ = wide_wraps(abs_, direction=None, to_kwargs={})
58 | add = wide_wraps(add, direction=None, input_num=2, to_kwargs={})
59 | ceiling = wide_wraps(ceiling, direction=None, to_kwargs={})
60 | densify = wide_wraps(densify, direction=None, to_kwargs={})
61 | divide = wide_wraps(divide, direction=None, input_num=2, to_kwargs={})
62 | exp = wide_wraps(exp, direction=None, to_kwargs={})
63 | floor = wide_wraps(floor, direction=None, to_kwargs={})
64 | fraction = wide_wraps(fraction, direction=None, to_kwargs={})
65 | inverse = wide_wraps(inverse, direction=None, to_kwargs={})
66 | log = wide_wraps(log, direction=None, to_kwargs={})
67 | log10 = wide_wraps(log10, direction=None, to_kwargs={})
68 | log_diff = wide_wraps(log_diff, direction=None, to_kwargs={})
69 | max_ = wide_wraps(max_, direction=None, input_num=2, to_kwargs={})
70 | mean = wide_wraps(mean, direction=None, input_num=2, to_kwargs={})
71 | min_ = wide_wraps(min_, direction=None, input_num=2, to_kwargs={})
72 | multiply = wide_wraps(multiply, direction=None, input_num=2, to_kwargs={})
73 | nan_mask = wide_wraps(nan_mask, direction=None, input_num=2, to_kwargs={})
74 | nan_out = wide_wraps(nan_out, direction=None, to_kwargs={2: 'lower', 3: 'upper'})
75 | power = wide_wraps(power, direction=None, input_num=2, to_kwargs={})
76 | purify = wide_wraps(purify, direction=None, to_kwargs={})
77 | replace = wide_wraps(replace, direction=None, to_kwargs={2: 'target', 3: 'dest'})
78 | reverse = wide_wraps(reverse, direction=None, to_kwargs={})
79 | round_ = wide_wraps(round_, direction=None, to_kwargs={})
80 | round_down = wide_wraps(round_down, direction=None, to_kwargs={})
81 | sign = wide_wraps(sign, direction=None, to_kwargs={})
82 | signed_power = wide_wraps(signed_power, direction=None, input_num=2, to_kwargs={})
83 | s_log_1p = wide_wraps(s_log_1p, direction=None, to_kwargs={})
84 | sqrt = wide_wraps(sqrt, direction=None, to_kwargs={})
85 | subtract = wide_wraps(subtract, direction=None, input_num=2, to_kwargs={})
86 | to_nan = wide_wraps(to_nan, direction=None, to_kwargs={2: 'value', 3: 'reverse'})
87 | 
88 | # Vector Operators
89 | # Logical Operators
90 | less = wide_wraps(less, input_num=2, to_kwargs={})
91 | if_else = wide_wraps(if_else, direction=None, input_num=3, to_kwargs={})
92 | 
93 | # Transformational Operators
94 | 
95 | # Cross Sectional Operators
96 | rank = wide_wraps(rank, direction=None, to_kwargs={})
97 | scale = wide_wraps(scale, direction=None, to_kwargs={1: 'scale'})
98 | 
99 | # Group Operators
100 | group_neutralize = long_wraps(group_neutralize, direction='right')
101 | 
102 | # Time Series Operators
103 | ts_arg_max = wide_wraps(ts_arg_max, to_kwargs={1: 'd'})
104 | ts_arg_min = wide_wraps(ts_arg_min, to_kwargs={1: 'd'})
105 | ts_corr = wide_wraps(ts_corr, input_num=2, to_kwargs={2: 'd'})
106 | ts_count = wide_wraps(ts_count, to_kwargs={1: 'd'})
107 | ts_covariance = wide_wraps(ts_covariance, input_num=2, to_kwargs={2: 'd'})
108 | ts_decay_linear = wide_wraps(ts_decay_linear, to_kwargs={1: 'd'})
109 | ts_delay = wide_wraps(ts_delay, to_kwargs={1: 'd'})
110 | ts_delta = wide_wraps(ts_delta, to_kwargs={1: 'd'})
111 | ts_max = wide_wraps(ts_max, to_kwargs={1: 'd'})
112 | ts_mean = wide_wraps(ts_mean, to_kwargs={1: 'd'})
113 | ts_min = wide_wraps(ts_min, to_kwargs={1: 'd'})
114 | ts_product = wide_wraps(ts_product, to_kwargs={1: 'd'})
115 | ts_rank = wide_wraps(ts_rank, to_kwargs={1: 'd'})
116 | ts_std_dev = wide_wraps(ts_std_dev, to_kwargs={1: 'd'})
117 | ts_sum = wide_wraps(ts_sum, to_kwargs={1: 'd'})
118 | 
119 | # Special Operators
-------------------------------------------------------------------------------- /ta_cn/imports/wq_long.py: --------------------------------------------------------------------------------
1 | """
2 | 公式转alpha101
3 | 
4 | 101 Formulaic Alphas
5 | 文档质量一般
6 | 1. 部分函数既有全小写,又有大小写混合
7 | 2. 很多参数应当是整数,但输入是小数,不得不做修正
8 | """
9 | from ..imports import long_wq as L_WQ
10 | 
11 | from ..utils import round_a_i, round_a_a_i
12 | 
13 | correlation = round_a_a_i(L_WQ.ts_corr)
14 | decay_linear = round_a_i(L_WQ.ts_decay_linear)
15 | 
16 | LessThan = L_WQ.less
17 | rank = L_WQ.rank
18 | ts_rank = round_a_i(L_WQ.ts_rank)
19 | scale = L_WQ.scale
20 | SignedPower = L_WQ.signed_power
21 | 
22 | IF = L_WQ.if_else
23 | abs = L_WQ.abs_
24 | log = L_WQ.log  # 这里是用的自然对数
25 | MAX = L_WQ.max_
26 | MIN = L_WQ.min_
27 | sign = L_WQ.sign
28 | 
29 | delta = round_a_i(L_WQ.ts_delta)
30 | ts_max = round_a_i(L_WQ.ts_max)
31 | ts_argmax = round_a_i(L_WQ.ts_arg_max)
32 | ts_min = round_a_i(L_WQ.ts_min)
33 | ts_argmin = round_a_i(L_WQ.ts_arg_min)
34 | product = round_a_i(L_WQ.ts_product)
35 | delay = round_a_i(L_WQ.ts_delay)
36 | sum = round_a_i(L_WQ.ts_sum)
37 | 
38 | covariance = round_a_a_i(L_WQ.ts_covariance)
39 | stddev = round_a_i(L_WQ.ts_std_dev)  # 引入的是全体标准差
40 | 
41 | indneutralize = L_WQ.group_neutralize
42 | 
43 | # 部分别名,这样官方公式可以减少改动
44 | Ts_Rank = ts_rank
45 | IndNeutralize = indneutralize
46 | Ts_ArgMax = ts_argmax
47 | Ts_ArgMin = ts_argmin
48 | LessThan = LessThan
49 | min = MIN
50 | max = MAX
51 | Sign = sign
52 | Log = log
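有了 wq_long.py 的这套别名,官方 101 Formulaic Alphas 的公式几乎可以照抄。下面以 Alpha#6 为例(数据与索引级别名同样是演示用的假设):

```python
import numpy as np
import pandas as pd
import ta_cn.imports.wq_long as wq

# 虚构长表数据,索引级别名的假设与前一个示例相同
dates = pd.date_range('2022-01-03', periods=30)
idx = pd.MultiIndex.from_product([dates, ['A', 'B']], names=['date', 'asset'])
open_ = pd.Series(np.random.rand(len(idx)) + 10, index=idx)
volume = pd.Series(np.random.rand(len(idx)) * 1e6, index=idx)

# Alpha#6: (-1 * correlation(open, volume, 10)),官方公式几乎原样照抄
alpha006 = -1 * wq.correlation(open_, volume, 10)
```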
-------------------------------------------------------------------------------- /ta_cn/imports/wq_wide.py: --------------------------------------------------------------------------------
1 | """
2 | 公式转alpha101
3 | 
4 | 101 Formulaic Alphas
5 | 文档质量一般
6 | 1. 部分函数既有全小写,又有大小写混合
7 | 2. 很多参数应当是整数,但输入是小数,不得不做修正
8 | """
9 | from ..imports import wide_wq as W_WQ
10 | from ..utils import round_a_i, round_a_a_i
11 | 
12 | correlation = round_a_a_i(W_WQ.ts_corr)
13 | decay_linear = round_a_i(W_WQ.ts_decay_linear)
14 | 
15 | LessThan = W_WQ.less
16 | rank = W_WQ.rank
17 | ts_rank = round_a_i(W_WQ.ts_rank)
18 | scale = W_WQ.scale
19 | SignedPower = W_WQ.signed_power
20 | 
21 | IF = W_WQ.if_else
22 | abs = W_WQ.abs_
23 | log = W_WQ.log  # 这里是用的自然对数
24 | MAX = W_WQ.max_
25 | MIN = W_WQ.min_
26 | sign = W_WQ.sign
27 | 
28 | delta = round_a_i(W_WQ.ts_delta)
29 | ts_max = round_a_i(W_WQ.ts_max)
30 | ts_argmax = round_a_i(W_WQ.ts_arg_max)
31 | ts_min = round_a_i(W_WQ.ts_min)
32 | ts_argmin = round_a_i(W_WQ.ts_arg_min)
33 | product = round_a_i(W_WQ.ts_product)
34 | delay = round_a_i(W_WQ.ts_delay)
35 | sum = round_a_i(W_WQ.ts_sum)
36 | 
37 | covariance = round_a_a_i(W_WQ.ts_covariance)
38 | stddev = round_a_i(W_WQ.ts_std_dev)  # 引入的是全体标准差
39 | 
40 | indneutralize = W_WQ.group_neutralize
41 | 
42 | # 部分别名,这样官方公式可以减少改动
43 | Ts_Rank = ts_rank
44 | IndNeutralize = indneutralize
45 | Ts_ArgMax = ts_argmax
46 | Ts_ArgMin = ts_argmin
47 | LessThan = LessThan
48 | min = MIN
49 | max = MAX
50 | Sign = sign
51 | Log = log
-------------------------------------------------------------------------------- /ta_cn/noise.py: --------------------------------------------------------------------------------
1 | """
2 | ATR与STD也是一种度量波动的方法,这里不再提供
3 | 
4 | 以下方法来自于Trading Systems and Methods, Chapter 1, Measuring Noise
5 | 
6 | References
7 | ----------
8 | https://zhuanlan.zhihu.com/p/544744582
9 | 
10 | """
11 | from .wq.arithmetic import abs_, sqrt, log
12 | from .wq.time_series import ts_delta, ts_sum, ts_max, ts_min
13 | 
14 | 
15 | def efficiency_ratio(x, d):
16 |     """效率系数。值越大,噪音越小。最大值为1,最小值为0
17 | 
18 |     本质上是位移除以路程
19 |     """
20 |     t1 = abs_(ts_delta(x, d))
21 |     t2 = ts_sum(abs_(ts_delta(x, 1)), d)
22 |     return t1 / t2
23 | 
24 | 
25 | def price_density(high, low, d):
26 |     """价格密度。值越大,噪音越大
27 | 
28 |     如果K线高低相连,上涨为1,下跌也为1
29 |     如果K线高低平行,值大于1,最大为d
30 |     """
31 |     t1 = ts_sum(high - low, d)
32 |     t2 = ts_max(high, d) - ts_min(low, d)  # 滚动窗口内的最高价与最低价
33 |     return t1 / t2
34 | 
35 | 
36 | def fractal_dimension(high, low, close, d):
37 |     """分形维度。值越大,噪音越大"""
38 |     t1 = ts_max(high, d) - ts_min(low, d)
39 |     t2 = ts_delta(close, 1)  # 下面平方后符号不再影响结果,无需取绝对值
40 |     t3 = (1 / d) ** 2
41 |     L = ts_sum(sqrt(t3 + (t2 / t1) ** 2), d)
42 |     return 1 + (log(L) + log(2)) / log(2 * d)
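noise.py 中的度量可以直接用一维序列做快速自检。下面是一个示意(假设这些 wq 版算子可直接接收一维 ndarray):

```python
import numpy as np
from ta_cn.noise import efficiency_ratio

# 纯趋势序列:位移约等于路程,效率系数应接近1
trend = np.arange(50, dtype=float)
print(efficiency_ratio(trend, 10))

# 来回震荡序列:路程远大于位移,效率系数接近0
zigzag = np.array([0.0, 1.0] * 25)
print(efficiency_ratio(zigzag, 10))
```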
-------------------------------------------------------------------------------- /ta_cn/preprocess.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from .utils import pd_to_np
4 | from .wq.cross_sectional import winsorize, zscore, scale
5 | 
6 | 
7 | def winsorize_mad(x, n=3, constant=1.4826):
8 |     """缩尾去极值,MAD法
9 | 
10 |     Parameters
11 |     ----------
12 |     x: array
13 |         需要缩尾去极值的数据
14 |     n: float
15 |         设置范围大小
16 |     constant: float
17 |         比例因子。为了能将MAD当作标准差的一致估计量而引入,k的取值取决于分布类型,正态分布下约等于1.4826
18 | 
19 |     .. math:: \hat{\sigma} = k \cdot MAD
20 | 
21 |     Returns
22 |     -------
23 |     array
24 |         `x` 经过MAD法缩尾去极值处理后的新数据
25 | 
26 |     References
27 |     ----------
28 |     https://en.wikipedia.org/wiki/Median_absolute_deviation
29 | 
30 | 
31 |     """
32 |     x = pd_to_np(x, copy=False)
33 |     if np.isnan(x).all():
34 |         return x
35 | 
36 |     axis = x.ndim - 1
37 |     _median = np.nanmedian(x, axis=axis, keepdims=True)
38 |     _mad = np.nanmedian(abs(x - _median), axis=axis, keepdims=True)
39 |     _mad = _mad * constant * n
40 | 
41 |     return np.clip(x, _median - _mad, _median + _mad)
42 | 
43 | 
44 | def winsorize_quantile(x, min_=0.1, max_=0.9):
45 |     """缩尾去极值,分位数法
46 | 
47 |     Parameters
48 |     ----------
49 |     x: array
50 |         需要缩尾去极值的数据
51 |     min_: float
52 |         设置下界分位数
53 |     max_: float
54 |         设置上界分位数
55 | 
56 |     Returns
57 |     -------
58 |     array
59 |         `x` 经过分位数法缩尾去极值处理后的新数据
60 | 
61 |     """
62 |     x = pd_to_np(x, copy=False)
63 |     # 全NaN时nanquantile会报 RuntimeWarning: All-NaN slice encountered,所以提前返回
64 |     if np.isnan(x).all():
65 |         return x
66 | 
67 |     axis = x.ndim - 1
68 |     q = np.nanquantile(x, [min_, max_], axis=axis, keepdims=True)
69 |     return np.clip(x, q[0], q[1])
70 | 
71 | 
72 | """
73 | 删除极值
74 | """
75 | 
76 | 
77 | def drop_quantile(x, min_=0.1, max_=0.9):
78 |     """删除去极值,分位数法
79 | 
80 |     Parameters
81 |     ----------
82 |     x: array
83 |         需要删除去极值的数据
84 |     min_: float
85 |         设置下界分位数
86 |     max_: float
87 |         设置上界分位数
88 | 
89 |     Returns
90 |     -------
91 |     array
92 |         `x` 经过分位数法删除去极值处理后的新数据
93 | 
94 |     """
95 |     x = pd_to_np(x, copy=False)
96 |     axis = x.ndim - 1
97 | 
98 |     q = np.nanquantile(x, [min_, max_], axis=axis, keepdims=True)
99 |     x = np.where((x < q[0]) | (x > q[1]), np.nan, x)
100 |     return x
101 | 
102 | 
103 | """
104 | 标准化
105 | """
106 | 
107 | 
108 | def fill_na(x):
109 |     """用中位数填充,还是用平均值填充?
110 | 
111 |     -1到1归一化的值,用0填充也行吧?
112 |     """
113 |     x = x.copy()
114 |     x[np.isnan(x)] = np.nanmedian(x)
115 |     return x
116 | 
117 | 
118 | # worldquant中函数的别名
119 | winsorize_3sigma = winsorize
120 | standardize_zscore = zscore
121 | standardize_minmax = scale
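preprocess.py 中缩尾函数的用法示意如下,离群值会被压缩到 中位数±n*1.4826*MAD 的边界上:

```python
import numpy as np
from ta_cn.preprocess import winsorize_mad

x = np.array([1.0, 2.0, 2.5, 3.0, 100.0])
# 中位数为2.5,MAD为0.5,上界约为 2.5 + 3 * 1.4826 * 0.5 ≈ 4.72
print(winsorize_mad(x, n=3))  # 100.0 被压缩到约4.72
```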
-------------------------------------------------------------------------------- /ta_cn/regress.py: --------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | import numba
4 | import numpy as np
5 | import pandas as pd
6 | 
7 | from . import bn_wraps as bn
8 | from . import talib as ta
9 | from .nb import numpy_rolling_apply_1, _rolling_func_1_1_nb, _rolling_func_2_1_nb, extend_shape
10 | from .utils import pd_to_np
11 | 
12 | _ta1d = ta.init(mode=1, skipna=False)
13 | _ta2d = ta.init(mode=2, skipna=False)
14 | 
15 | # 通达信,线性回归预测值
16 | FORCAST = _ta2d.LINEARREG
17 | # 通达信,线性回归斜率
18 | SLOPE = _ta2d.LINEARREG_SLOPE
19 | 
20 | 
21 | @numba.jit(nopython=True, cache=True, nogil=True)
22 | def _slope_y_nb(y, x, m_x):
23 |     """slope线性回归斜率。由于x固定,所以提前在外部计算,加快速度"""
24 |     m_y = np.mean(y)
25 |     return np.sum((x - m_x) * (y - m_y)) / np.sum((x - m_x) ** 2)
26 | 
27 | 
28 | def SLOPE_Y(real, timeperiod):
29 |     """numba版,输出结果与LINEARREG_SLOPE一样
30 | 
31 |     SLOPE_Y(real, timeperiod=14)
32 |     """
33 |     x = np.arange(timeperiod)
34 |     m_x = np.mean(x)
35 |     return numpy_rolling_apply_1([pd_to_np(real)], timeperiod, _rolling_func_1_1_nb, _slope_y_nb, x, m_x)
36 | 
37 | 
38 | @numba.jit(nopython=True, cache=True, nogil=True)
39 | def _slope_yx_nb(y, x):
40 |     """slope线性回归斜率。y与x是一直变化的"""
41 |     m_x = np.mean(x)
42 |     m_y = np.mean(y)
43 |     return np.sum((x - m_x) * (y - m_y)) / np.sum((x - m_x) ** 2)
44 | 
45 | 
46 | def SLOPE_YX(real0, real1, timeperiod):
47 |     """与talib.BETA不一样,talib将price转return后再回归
48 | 
49 |     SLOPE_YX(real0, real1, timeperiod=30)
50 |     """
51 |     return numpy_rolling_apply_1([pd_to_np(real0), pd_to_np(real1)],
52 |                                  timeperiod, _rolling_func_2_1_nb, _slope_yx_nb)
53 | 
54 | 
55 | def ts_simple_regress(y, x, d, lag=0, rettype=0):
56 |     """滚动一元线性回归。
57 | 
58 |     由于利用了bottleneck的滚动功能,比SLOPE_YX_NB快一些,但精度有少量差异,后期需再验证
59 | 
60 |     Parameters
61 |     ----------
62 |     y: array_like
63 |         回归因变量
64 |     x: array_like
65 |         回归自变量
66 |     d
67 |         移动计算窗口长度
68 |     lag
69 |         保留参数,当前未使用
70 |     rettype
71 |         0残差、1截距、2系数、3预测值、4残差平方和、5总平方和、6决定系数
72 |         可传入编号列表,一次取回多个结果
73 | 
74 |     Returns
75 |     -------
76 |     ndarray or tuple of ndarray
77 |         按`rettype`给定的编号(或编号列表)返回对应结果,
78 |         例如rettype=[2, 6]返回(系数, 决定系数)
79 | 
80 |     """
81 |     # 准备
82 |     outputs = {}
83 |     if not isinstance(rettype, list):
84 |         rettype = [rettype]
85 | 
86 |     # 计算
87 |     xy_ts_sum = bn.move_sum(np.multiply(x, y), window=d, axis=0)
88 |     xx_ts_sum = bn.move_sum(np.multiply(x, x), window=d, axis=0)
89 |     x_bar = bn.move_mean(x, window=d, axis=0)
90 |     y_bar = bn.move_mean(y, window=d, axis=0)
91 | 
92 |     up = xy_ts_sum - np.multiply(x_bar, y_bar) * d
93 |     down = xx_ts_sum - np.multiply(x_bar, x_bar) * d
94 |     beta_hat = up / down
95 |     intercept_hat = y_bar - np.multiply(beta_hat, x_bar)
96 |     y_hat = intercept_hat + np.multiply(beta_hat, x)
97 |     residual_hat = y - y_hat
98 | 
99 |     def _sse(x, y, d):
100 |         x1 = extend_shape(x, d - 1)
101 |         x2 = np.lib.stride_tricks.sliding_window_view(x1, d, axis=0)
102 |         z = x2 - np.expand_dims(y, axis=-1)
103 |         z = np.sum(z ** 2, axis=-1)
104 |         return z
105 | 
106 |     # 残差
107 |     outputs[0] = residual_hat
108 |     # 截距
109 |     outputs[1] = intercept_hat
110 |     # 系数
111 |     outputs[2] = beta_hat
112 |     # 预测值
113 |     outputs[3] = y_hat
114 | 
115 |     s = set([4, 5, 6]) & set(rettype)
116 |     if len(s) > 0:
117 |         # 残差平方和
118 |         outputs[4] = _sse(y, y_hat, d)
119 |         # 总平方和
120 |         outputs[5] = _sse(y, y_bar, d)
121 | 
122 |         # 决定系数
123 |         # 从“残差平方和的补”的角度来看
124 |         # outputs[6] = 1 - outputs[4] / outputs[5]  # 可能有负数
125 |         # 从“可解释方差”的角度来看
126 |         outputs[6] = _sse(y_hat, y_bar, d) / outputs[5]  # 没有负数
127 | 
128 |     # 输出
129 |     if len(rettype) == 1:
130 |         return outputs[rettype[0]]
131 |     else:
132 |         return tuple([outputs[r] for r in rettype])
133 | 
134 | 
135 | warnings.filterwarnings("ignore", category=numba.NumbaPerformanceWarning)
136 | 
137 | 
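ts_simple_regress 的 rettype 用法示意:一次取回系数与决定系数,正好对应后文 gdzq.py 中 RSRS_v3 的需求(数据为虚构的一维序列):

```python
import numpy as np
from ta_cn.regress import ts_simple_regress

# 虚构的一维高低价序列
low = np.cumsum(np.random.randn(100)) + 100
high = low + np.random.rand(100) + 0.5

# rettype传列表可一次取回多个结果:2为系数,6为决定系数
beta, r2 = ts_simple_regress(high, low, d=18, rettype=[2, 6])
# beta即RSRS斜率,标准化后乘r2就是研报中的RSRS_v3
```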
140 | """滚动多元""" 141 | if x.ndim == 3: 142 | for i, (yy, xx) in enumerate(zip(y, x)): 143 | out[i + timeperiod - 1] = func(yy, xx, *args) 144 | 145 | return out 146 | 147 | 148 | @numba.jit(nopython=True, cache=True, nogil=True) 149 | def _ts_ols_nb(y, x): 150 | """使用可逆矩阵计算多元回归。 151 | 152 | 由于sliding_window_view后的形状再enumerate后比较特殊,所以原公式的转置进行了调整 153 | """ 154 | # return np.dot((np.dot(np.linalg.inv(np.dot(x, x.T)), x)), y) 155 | return np.linalg.pinv(x.T).dot(y) 156 | 157 | 158 | def ts_multiple_regress(y, x, timeperiod=10, add_constant=True): 159 | """时序上滚动多元线性回归 160 | 161 | Parameters 162 | ---------- 163 | y: 1d array 164 | 因变量。一维 165 | x: 2d array 166 | 自变量。二维。一列为一个特征 167 | timeperiod:int 168 | 周期 169 | add_constant: bool 170 | 是否添加常量 171 | 172 | Returns 173 | ------- 174 | residual: 175 | 残差。与y形状类似,由实际y-预测y而得到 176 | y_hat: 177 | 预测y 178 | coef: 179 | 系数。与x形状类似,每个特征占一列。时序变化,所以每天都有一行 180 | 181 | """ 182 | _y = pd_to_np(y) 183 | _x = pd_to_np(x) 184 | # 拼接出y1x这种大矩阵 185 | _y1x = np.vstack((_y, np.ones_like(_y), _x.T)).T 186 | 187 | # 找到某行出现nan 188 | mask = ~np.any(np.isnan(_y1x), axis=1) 189 | 190 | # 位置1开始加了常量1,位置2开始没有常量1 191 | _1x = _y1x[mask, 2 - add_constant:] 192 | _1y = _y1x[mask, 0] 193 | 194 | coef = numpy_rolling_apply_1([_1x, _1y], timeperiod, _rolling_func_xy_nb, _ts_ols_nb) 195 | y_hat = np.full_like(_y, np.nan, dtype=_y.dtype) 196 | y_hat[mask] = np.sum(_1x * coef, axis=1) 197 | residual = _y - y_hat 198 | return residual, y_hat, coef 199 | 200 | 201 | @numba.jit(nopython=True, cache=True, nogil=True) 202 | def _cs_ols_nb(y, x): 203 | """使用可逆矩阵计算多元回归。 204 | 205 | 标准的多元回归 206 | """ 207 | # https://github.com/tirthajyoti/Machine-Learning-with-Python/blob/master/Regression/Linear_Regression_Methods.ipynb 208 | # 由于出现了不可逆,导致常用的inv失效,只能使用Moore-Penrose pseudoinverse 209 | # numpy.linalg.LinAlgError: Matrix is singular to machine precision.
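# 编者补充注释:np.linalg.pinv基于SVD,返回最小二乘意义下的最小范数解,
# 因此即使X'X奇异(如哑变量编码的行业列与常数列共线)也能数值稳定地求解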
210 | # return np.dot((np.dot(np.linalg.inv(np.dot(x.T, x)), x.T)), y) 211 | return np.linalg.pinv(x).dot(y) 212 | 213 | 214 | def multiple_regress(y, x, add_constant=True): 215 | """横截面上的多元回归。主要用于中性化多元回归场景 216 | 217 | 需要先按日期进行groupby,然后再应用回归函数 218 | """ 219 | _y = pd_to_np(y) 220 | _x = pd_to_np(x) 221 | # 拼接出y1x这种大矩阵 222 | _y1x = np.vstack((_y, np.ones_like(_y), _x.T)).T 223 | 224 | # 找到某行出现nan 225 | mask = ~np.any(np.isnan(_y1x), axis=1) 226 | 227 | # 位置1开始加了常量1,位置2开始没有常量1 228 | _1x = _y1x[mask, 2 - add_constant:] 229 | _1y = _y1x[mask, 0] 230 | 231 | coef = _cs_ols_nb(_1y, _1x) 232 | 233 | y_hat = np.full_like(_y, np.nan, dtype=_y.dtype) 234 | y_hat[mask] = _1x @ coef 235 | 236 | residual = _y - y_hat 237 | return residual, y_hat, coef 238 | 239 | 240 | def REGRESI(y, *args, timeperiod=60): 241 | if isinstance(y, pd.Series): 242 | x = pd.concat(args, axis=1) 243 | else: 244 | x = np.concatenate(args, axis=1) 245 | residual, y_hat, coef = ts_multiple_regress(y, x, timeperiod=timeperiod, add_constant=True) 246 | return residual 247 | -------------------------------------------------------------------------------- /ta_cn/research_report/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wukan1986/ta_cn/a569a618109daa804541c5d67fa4c0407f03fd49/ta_cn/research_report/__init__.py -------------------------------------------------------------------------------- /ta_cn/research_report/gdzq.py: -------------------------------------------------------------------------------- 1 | """ 2 | 光大证券,研报复现 3 | 4 | 光大金工跳槽为中金金工,RSRS改名为QRS 5 | """ 6 | from ..regress import ts_simple_regress 7 | from ..wq.time_series import ts_zscore 8 | 9 | 10 | def RSRS_v1(high, low, d): 11 | """阻力支撑相对强度(Resistance Support Relative Strength)""" 12 | beta = ts_simple_regress(high, low, d, lag=0, rettype=2) 13 | return beta 14 | 15 | 16 | def RSRS_v2(high, low, d, m): 17 | """阻力支撑相对强度,标准化""" 18 | beta = ts_simple_regress(high, low, d, lag=0, rettype=2) 19 | return ts_zscore(beta, m) 20 | 21 | 22 | def RSRS_v3(high, low, d, m): 23 | """阻力支撑相对强度,标准化。再乘R2 24 | 25 | 解释力度较弱时,数值乘R方会向0压缩 26 | """ 27 | beta, r2 = ts_simple_regress(high, low, d, lag=0, rettype=[2, 6]) 28 | return ts_zscore(beta, m) * r2 29 | -------------------------------------------------------------------------------- /ta_cn/research_report/gfzq.py: -------------------------------------------------------------------------------- 1 | """ 2 | 广发证券,研报复现 3 | 4 | 有关TD系列的指标,由于没有看懂,所以先放着 5 | 只测试了一下一维数据,等完全正确了再改成二维版并加快速 6 | 对TD有见解的朋友可以与我联系一起完善此部分 7 | 8 | """ 9 | import numba 10 | import numpy as np 11 | 12 | from ta_cn.wq.arithmetic import sign 13 | from ta_cn.wq.time_series import ts_delta 14 | 15 | 16 | @numba.jit(nopython=True, cache=True, nogil=True) 17 | def LLT(p, a): 18 | """LLT低延迟趋势线(Low-lag Trendline), 《短线择时策略研究之三:低延迟趋势线与交易性择时》 19 | 20 | """ 21 | # a = 2/(d+1) 22 | # 系数 23 | b1 = a - (a ** 2) / 4 24 | b2 = (a ** 2) / 2 25 | b3 = a - 3 * (a ** 2) / 4 26 | b4 = 2 * (1 - a) 27 | b5 = (1 - a) ** 2 28 | 29 | llt = p.copy() 30 | for t in range(2, len(p)): 31 | c1 = b1 * p[t] + b2 * p[t - 1] - b3 * p[t - 2] + b4 * llt[t - 1] - b5 * llt[t - 2] 32 | # 处理数据前段为空值的情况,中段出现空值将出错重新计算 33 | llt[t] = np.where(llt[t - 2] == llt[t - 2], c1, p[t]) 34 | return llt 35 | 36 | 37 | def TD9(close, n1=4, n2=9): 38 | """神奇九转。 单边行情下无效 39 | 40 | 返回持仓状态,不是买卖信息 41 | """ 42 | # 前4天 43 | ud = sign(ts_delta(close, n1)) 44 | setup = ud.copy() 45 | 46 | # 九转 47 | for i in range(n1, len(close)): 48 | # 相同就累加昨天的值,不同就设成1或-1,0 49 
| setup[i] = np.where(ud[i] == ud[i - 1], setup[i - 1], 0) + ud[i] 50 | 51 | # 反转信号。-9表示要开始买入,9表示要开始卖出 52 | signal = np.zeros_like(setup) 53 | 54 | if True: 55 | # 返回持仓状态 56 | signal[setup >= n2] = -1 57 | signal[setup <= -n2] = 1 58 | else: 59 | # 返回买卖动作 60 | signal[setup == n2] = -1 61 | signal[setup == -n2] = 1 62 | 63 | return signal 64 | 65 | 66 | def tom_demark_sequential(high, low, close, n1=4, n2=9, n3=13): 67 | # TODO: 等这里完全正确后再改成二维版 68 | # 前4天 69 | ud = sign(ts_delta(close, n1)) 70 | setup = ud.copy() 71 | 72 | # 九转 73 | for i in range(n1, len(close)): 74 | # 相同就累加昨天的值,不同就设成1或-1,0 75 | setup[i] = np.where(ud[i] == ud[i - 1], setup[i - 1], 0) + ud[i] 76 | 77 | # 启动后做多计数条件 78 | buy_cond = close < ts_delta(low, 2) 79 | # 启动后做空计数条件 80 | sell_cond = close > ts_delta(high, 2) 81 | 82 | buy_count = buy_cond * 0 83 | sell_count = sell_cond * 0 84 | 85 | # 计数 TODO: 这里逻辑可能有问题 86 | for i in range(n1 + n2 - 1, len(close)): 87 | if setup[i] <= -n2: 88 | buy_count[i] = buy_count[i - 1] + buy_cond[i] 89 | elif setup[i] >= n2: 90 | sell_count[i] = sell_count[i - 1] + sell_cond[i] 91 | 92 | signal = np.zeros_like(close) 93 | signal[buy_count >= n3] = 1 94 | signal[sell_count >= n3] = -1 95 | return signal 96 | 97 | 98 | def TD(high, low, close, n1, n2, n3): 99 | """ 100 | 101 | Parameters 102 | ---------- 103 | n1: 104 | 价格比较滞后期数 105 | n2: 106 | 价格关系单向连续个数 107 | n3: 108 | 计数阶段最终信号发出所需计数值 109 | 110 | Returns 111 | ------- 112 | 113 | """ 114 | ud = sign(ts_delta(close, n1)) 115 | ud_acc = ud.copy() 116 | buy_cnt = np.zeros_like(ud) 117 | sell_cnt = np.zeros_like(ud) 118 | for i in range(n1, len(close)): 119 | # 相同就累加昨天的值 120 | ud_acc[i] = np.where(ud[i] == ud[i - 1], ud_acc[i - 1], 0) + ud[i] 121 | 122 | if ud_acc[i] == -n2: 123 | # 昨天买入已经启动了 124 | a = close[i] >= high[i - 2] 125 | b = high[i] > high[i - 1] 126 | c = close[i] > close[i - 1] 127 | if a & b & c: 128 | pass 129 | if ud_acc[i] == n2: 130 | # 昨天卖出已经启动了 131 | a = close[i] <= low[i - 2] 132 | b = low[i] < low[i - 1] 133 | c = close[i] < close[i - 1] 134 | if a & b & c: 135 | pass 136 | 137 | print(ud) 138 | print(ud_acc) 139 | # return ud_acc 140 | 141 | 142 | x = np.array([0, 1, 2, 3, 4, 143 | 5, 6, 7, 8, 9, 144 | 10, 11, 12, 13, 14, 145 | 15, 16, 17, 18, 19, 146 | 5], dtype=float) 147 | tom_demark_sequential(x, 4, 9) 148 | # http://www.snailtoday.com/archives/5469 149 | # https://github.com/dachuanwud/coincock/blob/main/program/%E6%8B%A9%E6%97%B6%E7%AD%96%E7%95%A5_%E5%9B%9E%E6%B5%8B/Signals.py 150 | # https://github.com/Rebeccawing/Quantitative-Contest/blob/master/GFTD.py 151 | # https://github.com/dachuanwud/coincock/blob/main/program/%E6%8B%A9%E6%97%B6%E7%AD%96%E7%95%A5_%E5%9B%9E%E6%B5%8B/Signals.py 152 | # https://github.com/tongtong263/quant_class/blob/75dbfcbca0924a6f829bb9069a7a60bced379213/%E9%87%8F%E5%8C%96%E9%AB%98%E9%98%B6%E6%95%99%E7%A8%8B/xbx_stock_2019/program/%E6%8B%A9%E6%97%B6%E7%AD%96%E7%95%A5_%E5%9B%9E%E6%B5%8B/Signals.py 153 | -------------------------------------------------------------------------------- /ta_cn/slow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wukan1986/ta_cn/a569a618109daa804541c5d67fa4c0407f03fd49/ta_cn/slow/__init__.py -------------------------------------------------------------------------------- /ta_cn/slow/slow.py: -------------------------------------------------------------------------------- 1 | """这里存放因速度慢被淘汰的代码,因为有教学意义而保留""" 2 | import numpy as np 3 | import pandas as pd 4 | 5 | from .. 
import talib as ta 6 | from ..ema import EMA_1_PD, WS_SUM, SMA_CN 7 | from ..nb import fill_notna 8 | from ..tdx import HHV, LLV, MA, REF, SUM, MAX, ABS 9 | from ..tdx.over_bought_over_sold import ROC, TYPPRICE 10 | from ..tdx.reference import TR 11 | from ..tdx.statistics import AVEDEV 12 | from ..utils import np_to_pd 13 | 14 | _ta1d = ta.init(mode=1, skipna=False) 15 | _ta2d = ta.init(mode=2, skipna=False) 16 | 17 | 18 | def ATR_TA(high, low, close, timeperiod=14): 19 | """ATR真实波幅N日平均 20 | 21 | talib的ATR算法类似于SMA,所以要重写此处才与中国ATR相同 22 | """ 23 | return SMA_CN(TR(high, low, close), timeperiod) 24 | 25 | 26 | def ATR_CN(high, low, close, timeperiod=14): 27 | """ATR真实波幅N日平均 28 | 29 | talib的ATR算法类似于EMA,所以要重写此处才与中国ATR相同 30 | """ 31 | 32 | # 以下要慢一些,不采用了 33 | def func(high, low, close, timeperiod): 34 | return _ta1d.SMA(_ta1d.TRANGE(high, low, close), timeperiod) 35 | 36 | return _ta2d.tafunc_nditer_1(func, [high, low, close], {'timeperiod': timeperiod}, 37 | ['high', 'low', 'close'], ['real']) 38 | 39 | 40 | def MACD_CN(real: pd.DataFrame, fastperiod=12, slowperiod=26, signalperiod=9): 41 | # 中国区使用公式自行实现的算法,由于多次用到EMA,效率不高,建议直接使用MACDEXT 42 | DIF = EMA_1_PD(real, fastperiod) - EMA_1_PD(real, slowperiod) 43 | DEA = EMA_1_PD(DIF, signalperiod) 44 | MACD = (DIF - DEA) * 2 45 | return DIF, DEA, MACD 46 | 47 | 48 | def KDJ_CN(high, low, close, fastk_period=9, M1=3, M2=3): 49 | """KDJ指标""" 50 | hh = HHV(high, fastk_period) 51 | ll = LLV(low, fastk_period) 52 | RSV = (close - ll) / (hh - ll) * 100 53 | K = EMA_1_PD(RSV, (M1 * 2 - 1)) 54 | D = EMA_1_PD(K, (M2 * 2 - 1)) 55 | 56 | J = K * 3 - D * 2 57 | return K, D, J 58 | 59 | 60 | def TRIX_CN(real, timeperiod=12): 61 | """三重指数平滑均线 62 | 63 | 由EMA算法差异导致的不同 64 | """ 65 | TR = EMA_1_PD(EMA_1_PD(EMA_1_PD(real, timeperiod), timeperiod), timeperiod) 66 | return ROC(TR, 1) 67 | 68 | 69 | def CCI(high, low, close, timeperiod=14): 70 | """CCI顺势指标,talib版更快""" 71 | tp = TYPPRICE(high, low, close) 72 | return (tp - MA(tp, timeperiod)) / (0.015 * AVEDEV(tp, timeperiod)) 73 | 74 | 75 | def RSI(real, timeperiod=24): 76 | """RSI指标""" 77 | DIF = real - REF(real, 1) 78 | return SMA_CN(MAX(DIF, 0), timeperiod, 1) / SMA_CN(ABS(DIF), timeperiod, 1) * 100 79 | 80 | 81 | def WMA(real, timeperiod=5): 82 | """加权移动平均""" 83 | 84 | def func(x): 85 | # 复制于MyTT,比tqsdk中tafunc中计算要快 86 | return x[::-1].cumsum().sum() * 2 / timeperiod / (timeperiod + 1) 87 | 88 | return np_to_pd(real).rolling(timeperiod).apply(func, raw=True) 89 | 90 | 91 | def MFI(high, low, close, volume, timeperiod=14): 92 | """MFI指标""" 93 | tp = TYPPRICE(high, low, close) 94 | tpv = tp * volume 95 | # 比TALIB结果多一个数字,通过置空实现与TA-LIB完全一样 96 | tpv = fill_notna(tpv, fill_value=np.nan, n=1) 97 | 98 | is_raising = tp > REF(tp, 1) 99 | pos_sum = SUM(is_raising * tpv, timeperiod) 100 | neg_sum = SUM(~is_raising * tpv, timeperiod) 101 | return 100 * pos_sum / (pos_sum + neg_sum) 102 | 103 | 104 | def DM(high, low, timeperiod=14): 105 | """Directional Movement方向动量 106 | 107 | WS_SUM威尔德平滑求和 108 | """ 109 | HD = high - REF(high, 1) 110 | LD = REF(low, 1) - low 111 | 112 | # REF导至出现空,处理一下,防止空值出现 113 | HD[np.isnan(HD) & (~np.isnan(high))] = 0 114 | LD[np.isnan(LD) & (~np.isnan(low))] = 0 115 | 116 | # talib中是用的威尔德平滑 117 | PDM = WS_SUM(((HD > 0) & (HD > LD)) * HD, timeperiod) 118 | MDM = WS_SUM(((LD > 0) & (LD > HD)) * LD, timeperiod) 119 | return PDM, MDM 120 | 121 | 122 | def DM_CN(high, low, timeperiod=14): 123 | """中国版DM 124 | 125 | SUM滚动求和 126 | """ 127 | HD = high - REF(high, 1) 128 | LD = REF(low, 1) - low 129 | # 而中国版一般是直接滚动求和 
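# 编者补充注释(假设WS_SUM采用常见的威尔德递推 S[t] = S[t-1] - S[t-1]/N + x[t]):
# 威尔德平滑不会完全丢弃旧数据,而SUM是硬性的N期窗口,这正是talib与中国版的差异所在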
130 | PDM = SUM(((HD > 0) & (HD > LD)) * HD, timeperiod) 131 | MDM = SUM(((LD > 0) & (LD > HD)) * LD, timeperiod) 132 | return PDM, MDM 133 | 134 | 135 | def DI(high, low, close, timeperiod=14): 136 | """Directional Indicator方向指标""" 137 | tr = TR(high, low, close) 138 | # 数据有效区开始,值由nan要设成0 139 | tr[np.isnan(tr) & (~np.isnan(close))] = 0 140 | 141 | TRS = WS_SUM(tr, timeperiod) 142 | # 比talib多一个,删除它 143 | TRS = fill_notna(TRS, fill_value=np.nan, n=1) 144 | 145 | PDM, MDM = DM(high, low, timeperiod) 146 | PDI = PDM * 100 / TRS 147 | MDI = MDM * 100 / TRS 148 | return PDI, MDI 149 | 150 | 151 | def DI_CN(high, low, close, timeperiod=14): 152 | """中国版DI 153 | 154 | 区别是SUM与WS_SUM 155 | """ 156 | TRS = SUM(TR(high, low, close), timeperiod) 157 | PDM, MDM = DM_CN(high, low, timeperiod) 158 | PDI = PDM * 100 / TRS 159 | MDI = MDM * 100 / TRS 160 | return PDI, MDI 161 | 162 | 163 | def DMI(high, low, close, timeperiod=14): 164 | """趋向指标""" 165 | PDI, MDI = DI(high, low, close, timeperiod=timeperiod) 166 | ADX = SMA_CN(ABS(PDI - MDI) / (PDI + MDI) * 100, timeperiod) 167 | # 这里timeperiod-1,才正好与talib对应 168 | ADXR = (ADX + REF(ADX, timeperiod - 1)) / 2 169 | return PDI, MDI, ADX, ADXR 170 | 171 | 172 | def DMI_CN(high, low, close, timeperiod=14): 173 | # DI算法不同 174 | PDI, MDI = DI_CN(high, low, close, timeperiod=timeperiod) 175 | # ADX中的MA与SMA不同 176 | ADX = MA(ABS(PDI - MDI) / (PDI + MDI) * 100, timeperiod) 177 | # timeperiod与timeperiod-1不同 178 | ADXR = (ADX + REF(ADX, timeperiod)) / 2 179 | return PDI, MDI, ADX, ADXR 180 | 181 | 182 | def _AVEDEV(real, timeperiod: int = 20): 183 | """平均绝对偏差,慢,请用nb版""" 184 | 185 | def mad(x): 186 | return np.abs(x - x.mean()).mean() 187 | 188 | return np_to_pd(real).rolling(window=timeperiod).apply(mad, raw=True) 189 | 190 | 191 | def _SLOPE(S, N=14): 192 | return np_to_pd(S).rolling(N).apply(lambda x: np.polyfit(range(N), x, deg=1)[0], raw=True) 193 | 194 | 195 | def _FORCAST(S, N=14): 196 | return np_to_pd(S).rolling(N).apply(lambda x: np.polyval(np.polyfit(range(N), x, deg=1), N - 1), raw=True) 197 | -------------------------------------------------------------------------------- /ta_cn/split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .wq.cross_sectional import rank as RANK 4 | 5 | 6 | def cut(x, bins=[0, 10, 100, 1000, 10000], rank=True, pct=False): 7 | """返回的值最小值一般是1,但如果原值小于bins,那么最小值是0""" 8 | if rank: 9 | x = RANK(x, pct=pct) 10 | return np.digitize(x, bins, right=True) 11 | 12 | 13 | def qcut(x, q=[0, 0.5, 1], rank=True, pct=True): 14 | if rank: 15 | x = RANK(x, pct=pct) 16 | return np.digitize(x, bins=q, right=True) 17 | 18 | 19 | def top_k(x, bins=[0, 50, 100, 200]): 20 | """前N权重字典。正序,越小排名越靠前。所以用户可能需要主动取负数 21 | 22 | Parameters 23 | ---------- 24 | x 25 | bins 26 | 27 | """ 28 | labels = cut(x, bins=bins, rank=True, pct=False) 29 | # 前200=前50+前50到100+前100到200 30 | d = {k: np.where(labels <= i, x, np.nan) for i, k in enumerate(bins) if k > 0} 31 | 32 | return d 33 | 34 | 35 | def quantile_n(x, n=10): 36 | """分位数权重 37 | 38 | Parameters 39 | ---------- 40 | 41 | Returns 42 | ------- 43 | 44 | """ 45 | q = np.linspace(0, 1, n + 1) 46 | 47 | # 根据因子大小值进行百分位分层,数少就少分几层 48 | labels = qcut(x, q=q, rank=True, pct=True) 49 | 50 | # 因子分层,在这里就已经过滤只看几组,减少内存 51 | d = {k: np.where(labels == i, x, np.nan) for i, k in enumerate(q) if k > 0} 52 | 53 | return d 54 | -------------------------------------------------------------------------------- /ta_cn/talib/__init__.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Author :wukan 4 | # @License :(C) Copyright 2022, wukan 5 | # @Date :2022-06-16 6 | """ 7 | Examples 8 | -------- 9 | 只要替换导入即可直接支持二维矩阵 10 | >>> import talib as ta 11 | >>> import ta_cn.talib as ta 12 | >>> _ta1d = ta.init(mode=1, skipna=False, to_globals=False) 13 | >>> _ta2d = ta.init(mode=2, skipna=False, to_globals=False) 14 | 15 | """ 16 | from functools import wraps, reduce 17 | 18 | import numpy as np 19 | import talib as _talib 20 | from talib import abstract as _abstract 21 | 22 | 23 | def tafunc_nditer_1(tafunc, args, kwargs, input_names, output_names, skipna): 24 | """直接调用talib""" 25 | 26 | def ALL_NOTNA(*args): 27 | """多输入,同位置没有出现过nan,标记成True""" 28 | return reduce(lambda x, y: np.logical_and(x, ~np.isnan(y)), [True] + list(args)) 29 | 30 | # 不跳过空值,直接调用函数 31 | if not skipna: 32 | return tafunc(*args, **kwargs) 33 | 34 | # 取一个非数字,得用于得到形状 35 | real = args[0] 36 | 37 | output_num = len(output_names) 38 | 39 | outputs = [np.full_like(real, fill_value=np.nan) for _ in output_names] 40 | 41 | _notna = ALL_NOTNA(*args) 42 | # 只有不连续的nan才需要做切片 https://www.cnpython.com/qa/352363 43 | # TODO: https://stackoverflow.com/questions/41721674/find-consecutive-repeated-nan-in-a-numpy-array/41722059#41722059 44 | _in = [v[_notna] for v in args] 45 | 46 | if len(_in[0]) > 0: 47 | ta_out = tafunc(*_in, **kwargs) 48 | if output_num == 1: 49 | ta_out = [ta_out] 50 | 51 | for _i, _o in zip(outputs, ta_out): 52 | _i[_notna] = _o 53 | 54 | # 输出 55 | return outputs[0] if output_num == 1 else tuple(outputs) 56 | 57 | 58 | def tafunc_nditer_2(tafunc, args, kwargs, input_names, output_names, 59 | skipna, order='F'): 60 | """内部按列迭代函数,支持timeperiod等命名参数向量化 61 | 62 | Parameters 63 | ---------- 64 | tafunc 65 | 计算单列的函数 66 | args 67 | 位置参数 68 | kwargs 69 | 命名参数 70 | output_names: list 71 | tafunc输出参数名 72 | skipna: 73 | 如想跳过空值,需将数据堆叠到首或尾,实现连续计算 74 | order: 75 | F, 按列进行遍历 76 | C, 按行进行遍历 77 | 78 | Returns 79 | ------- 80 | tuple 81 | 输出元组 82 | 83 | """ 84 | 85 | def num_to_np(x, like): 86 | """将单字数字转成矩阵""" 87 | if hasattr(x, "__getitem__"): 88 | # 长度不足时,用最后一个值填充之后的参数 89 | y = np.full_like(like, fill_value=x[-1]) 90 | # 长度超出时,截断 91 | y[:len(x)] = x[:len(y)] 92 | return y 93 | # 单一值,填充成唯一值 94 | return np.full_like(like, fill_value=x) 95 | 96 | def last_isna(x): 97 | # 只检查最后一行 98 | return np.any([y[-1] != y[-1] for y in x]) 99 | 100 | real = args[0] 101 | 102 | if real.ndim == 1: 103 | return tafunc_nditer_1(tafunc, args, kwargs, input_names, output_names, skipna) 104 | 105 | # =====以下是二维====== 106 | inputs = [*args] 107 | 108 | # 输出缓存 109 | outputs = [np.full_like(real, fill_value=np.nan) for _ in output_names] 110 | kwargs = {k: num_to_np(v, real[0]) for k, v in kwargs.items()} 111 | 112 | # 只有一行输入时需要特别处理 113 | with np.nditer(inputs + outputs, 114 | flags=['external_loop'] if real.shape[0] > 1 else None, 115 | order=order, 116 | op_flags=[['readonly']] * len(inputs) + [['writeonly']] * len(outputs)) as it: 117 | for i, in_out in enumerate(it): 118 | if real.shape[0] == 1: 119 | # 需要将0维array改成1维,否则talib报错 120 | in_out = [v.reshape(1) for v in in_out] 121 | 122 | _in = in_out[:len(inputs)] # 分离输入 123 | # 最后一行出现了空 124 | if last_isna(_in): 125 | continue 126 | _out = in_out[-len(outputs):] # 分离输出 127 | 128 | # 切片得到每列的参数 129 | _kw = {k: v[i] for k, v in kwargs.items()} 130 | 131 | # 计算并封装 132 | ta_out = tafunc(*_in, **_kw) 133 | if not isinstance(ta_out, tuple): 134 | ta_out = 
tuple([ta_out]) 135 | 136 | for _i, _o in zip(_out, ta_out): 137 | _i[...] = _o 138 | 139 | # 输出 140 | if len(outputs) == 1: 141 | return outputs[0] 142 | return outputs 143 | 144 | 145 | def ta_decorator(func, mode, input_names, output_names, skipnan): 146 | # 设置对应处理函数 147 | ff = {1: tafunc_nditer_1, 2: tafunc_nditer_2}.get(mode) 148 | 149 | @wraps(func) 150 | def decorated(*args, **kwargs): 151 | return ff(func, args, kwargs, input_names, output_names, skipnan) 152 | 153 | return decorated 154 | 155 | 156 | def init(mode=1, skipna=False, to_globals=False): 157 | """初始化环境 158 | 159 | Parameters 160 | ---------- 161 | mode: int 162 | 1: 输入数据支持一维矩阵。数据使用位置,周期等使用命名 163 | 2. 输入参数支持一维向量。数据使用位置,周期等使用命名。否则报错 164 | skipna: bool 165 | 是否跳过空值。跳过空值功能会导致计算变慢。 166 | - 确信数据不会中途出现空值建议设置成False, 加快计算 167 | to_globals: bool 168 | 注册到包中 169 | 170 | """ 171 | assert mode in (1, 2) 172 | 173 | class TA_CN_LIB: 174 | pass 175 | 176 | lib = TA_CN_LIB() 177 | for i, func_name in enumerate(_talib.get_functions()): 178 | """talib遍历""" 179 | _ta_func = getattr(_talib, func_name) 180 | info = _abstract.Function(func_name).info 181 | output_names = info['output_names'] 182 | input_names = info['input_names'] 183 | 184 | # 创建函数 185 | f = ta_decorator(_ta_func, mode, input_names, output_names, skipna) 186 | setattr(lib, func_name, f) 187 | if to_globals: 188 | globals()[func_name] = f 189 | 190 | return lib 191 | 192 | 193 | # ============================================= 194 | 195 | TA_COMPATIBILITY_DEFAULT = 0 # 使用MA做第一个值 196 | TA_COMPATIBILITY_METASTOCK = 1 # 使用Price做第一个值 197 | 198 | _COMPATIBILITY_ENABLE_ = False 199 | 200 | 201 | def set_compatibility_enable(enable): 202 | """talib兼容性设置""" 203 | global _COMPATIBILITY_ENABLE_ 204 | _COMPATIBILITY_ENABLE_ = enable 205 | 206 | 207 | def set_compatibility(compatibility): 208 | """talib兼容性设置""" 209 | global _COMPATIBILITY_ENABLE_ 210 | if _COMPATIBILITY_ENABLE_: 211 | print('do talib.set_compatibility', compatibility) 212 | _talib.set_compatibility(compatibility) 213 | -------------------------------------------------------------------------------- /ta_cn/tdx/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 优先实现WorldQuant BRAIN中Fast Expression的函数 3 | 然后再实现通达信的函数 4 | 5 | 此文件中不跳过空值 6 | """ 7 | 8 | from ..wq.arithmetic import abs_ as ABS 9 | from ..wq.arithmetic import add as ADD 10 | from ..wq.arithmetic import divide as DIV 11 | from ..wq.arithmetic import log as LN # 自然对数 12 | from ..wq.arithmetic import log10 as LOG # 10为底的对数 13 | from ..wq.arithmetic import max_ as MAX 14 | from ..wq.arithmetic import mean as MEAN 15 | from ..wq.arithmetic import min_ as MIN 16 | from ..wq.arithmetic import multiply as MUL 17 | from ..wq.arithmetic import round_ as ROUND 18 | from ..wq.arithmetic import sign as SGN 19 | from ..wq.arithmetic import subtract as SUB 20 | from ..wq.logical import if_else as IF 21 | from ..wq.time_series import ts_count as COUNT 22 | from ..wq.time_series import ts_delay as REF 23 | from ..wq.time_series import ts_delta as DIFF 24 | from ..wq.time_series import ts_max as HHV 25 | from ..wq.time_series import ts_mean as MA 26 | from ..wq.time_series import ts_min as LLV 27 | from ..wq.time_series import ts_sum as SUM 28 | from ..wq.cross_sectional import rank as RANK 29 | 30 | ABS 31 | MAX 32 | MIN 33 | REF 34 | HHV 35 | MA 36 | LLV 37 | SUM 38 | ADD 39 | SUB 40 | MUL 41 | DIV 42 | ROUND 43 | MEAN 44 | LN 45 | LOG 46 | SGN 47 | DIFF 48 | IF 49 | COUNT 50 | RANK 51 | 
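以下是编者补充的用法示意(假设性示例,非仓库原有代码):上面的通达信别名只是wq算子的薄封装,因此可以直接对二维「时间×品种」矩阵书写通达信风格的公式;其中CLOSE为随机生成的占位数据,变量名与参数均为演示假设。

import numpy as np
from ta_cn.tdx import MA, REF, COUNT, RANK

CLOSE = np.random.rand(250, 500)  # 占位数据:250个交易日 x 500只股票
BIAS6 = (CLOSE / MA(CLOSE, 6) - 1) * 100  # 通达信风格的乖离率BIAS
PSY12 = COUNT(CLOSE > REF(CLOSE, 1), 12) / 12 * 100  # 心理线PSY
CS_RANK = RANK(BIAS6)  # 横截面排名,逐日计算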
-------------------------------------------------------------------------------- /ta_cn/tdx/logical.py: -------------------------------------------------------------------------------- 1 | import numba 2 | import numpy as np 3 | 4 | from . import SUM 5 | from .. import EPSILON 6 | from ..nb import numpy_rolling_apply_1, _rolling_func_1_1_nb 7 | from ..utils import np_to_pd, num_to_np, pd_to_np 8 | 9 | 10 | def CROSS(S1, S2): 11 | """向上金叉""" 12 | # 可能输入单个数字,需预处理 13 | S1 = num_to_np(S1, S2) 14 | # 处理精度问题 15 | # 1. S1<=S2时,需要给S2范围大一点,表示等于 16 | # 2. S1>S2时,同样要给S2范围大一点,排除等于 17 | S2 = num_to_np(S2, S1) + EPSILON 18 | 19 | arr = np.zeros_like(S1, dtype=bool) 20 | # 输入为Series时对齐有差异,前面需转成numpy 21 | arr[1:] = (S1 <= S2)[:-1] & (S1 > S2)[1:] 22 | return arr 23 | 24 | 25 | def EVERY(real, timeperiod): 26 | """最近timeperiod天是否全为True 27 | 28 | EVERY(real, timeperiod=5) 29 | """ 30 | return SUM(real, timeperiod) == timeperiod 31 | 32 | 33 | def EXIST(real, timeperiod): 34 | """最近timeperiod天是否存在一天为True 35 | 36 | EXIST(real, timeperiod=5) 37 | """ 38 | return SUM(real, timeperiod) > 0 39 | 40 | 41 | def BETWEEN(S, A, B): 42 | """BETWEEN(S,A,B)表示S处于A和B之间时返回1,否则返回0""" 43 | return ((A < S) & (S < B)) | ((A > S) & (S > B)) 44 | 45 | 46 | def VALUEWHEN(S, X): 47 | """条件跟随函数。当S条件成立时,取X的当前值,否则取VALUEWHEN的上个值""" 48 | return np_to_pd(np.where(S, X, np.nan)).ffill() 49 | 50 | 51 | @numba.jit(nopython=True, cache=True, nogil=True) 52 | def _last_nb(arr, n, m): 53 | """LAST(X,A,B),A>B,表示从前A日到前B日一直满足X条件""" 54 | return np.all(arr[:n - m]) 55 | 56 | 57 | def LAST(real, n, m): 58 | """LAST(X,A,B),A>B,表示从前A日到前B日一直满足X条件 59 | 60 | LAST(real, n=20, m=10) 61 | """ 62 | return numpy_rolling_apply_1([pd_to_np(real)], n, _rolling_func_1_1_nb, _last_nb, n, m) 63 | -------------------------------------------------------------------------------- /ta_cn/tdx/over_bought_over_sold.py: -------------------------------------------------------------------------------- 1 | from talib import MA_Type 2 | 3 | from . import MA, REF, HHV, LLV, MEAN 4 | from .reference import TR 5 | from ..
import talib as ta 6 | from ..talib import set_compatibility, TA_COMPATIBILITY_DEFAULT, TA_COMPATIBILITY_METASTOCK 7 | 8 | _ta1d = ta.init(mode=1, skipna=False, to_globals=False) 9 | _ta2d = ta.init(mode=2, skipna=False, to_globals=False) 10 | 11 | 12 | def ATR_CN(high, low, close, timeperiod): 13 | """ATR真实波幅N日平均 14 | 15 | ATR_CN(high, low, close, timeperiod=14) 16 | 17 | talib的ATR算法类似于EMA,所以要重写此处才与中国ATR相同 18 | """ 19 | return MA(TR(high, low, close), timeperiod) 20 | 21 | 22 | def BIAS(real, timeperiod): 23 | """BIAS乖离率 24 | 25 | BIAS(real, timeperiod=6) 26 | 27 | Parameters 28 | ---------- 29 | real 30 | timeperiod:int 31 | 常用参数:6,12,24 32 | 33 | """ 34 | return (real / MA(real, timeperiod) - 1) * 100 35 | 36 | 37 | def KDJ(high, low, close, fastk_period, M1, M2): 38 | """KDJ指标 39 | 40 | KDJ(high, low, close, fastk_period=9, M1=3, M2=3) 41 | 42 | talib中EMA的参数用法不同 43 | """ 44 | set_compatibility(TA_COMPATIBILITY_METASTOCK) 45 | 46 | K, D = _ta2d.STOCH(high, low, close, 47 | fastk_period=fastk_period, 48 | slowk_period=(M1 * 2 - 1), slowk_matype=MA_Type.EMA, 49 | slowd_period=(M2 * 2 - 1), slowd_matype=MA_Type.EMA) 50 | 51 | J = K * 3 - D * 2 52 | return K, D, J 53 | 54 | 55 | def ROC(real, timeperiod=12): # 变动率指标 56 | """ROC变动率指标 57 | 58 | Parameters 59 | ---------- 60 | real 61 | timeperiod:int 62 | 常用参数:12 63 | 64 | Examples 65 | -------- 66 | 股票软件上常再求一次MA 67 | >>> MA(ROC(CLOSE, 12), 6) 68 | 69 | """ 70 | if real.ndim == 2: 71 | return (real / REF(real, timeperiod) - 1) * 100 72 | else: 73 | return _ta1d.ROC(real, timeperiod=timeperiod) 74 | 75 | 76 | def TYPPRICE(high, low, close): 77 | """典型价格。高低收的平均值""" 78 | if close.ndim == 2: 79 | return MEAN(high, low, close) 80 | else: 81 | return _ta1d.TYPPRICE(high, low, close) 82 | 83 | 84 | def MEDPRICE(high, low): 85 | """中间价格。高低平均值""" 86 | if high.ndim == 2: 87 | return MEAN(high, low) 88 | else: 89 | return _ta1d.MEDPRICE(high, low) 90 | 91 | 92 | def WR(high, low, close, timeperiod=10): 93 | """W&R威廉指标 94 | 95 | Parameters 96 | ---------- 97 | high 98 | low 99 | close 100 | timeperiod: int 101 | 常用参数:10, 6 102 | 103 | Returns 104 | ------- 105 | 106 | """ 107 | 108 | if close.ndim == 2: 109 | hh = HHV(high, timeperiod) 110 | ll = LLV(low, timeperiod) 111 | return (hh - close) / (hh - ll) * 100 112 | else: 113 | return _ta1d.WILLR(high, low, close, timeperiod=timeperiod) 114 | 115 | 116 | def RSI(real, timeperiod=24): 117 | """RSI指标""" 118 | # 如果设置成1,将会多一个数字 119 | set_compatibility(TA_COMPATIBILITY_DEFAULT) 120 | 121 | return _ta2d.RSI(real, timeperiod=timeperiod) 122 | -------------------------------------------------------------------------------- /ta_cn/tdx/pressure_support.py: -------------------------------------------------------------------------------- 1 | from . import MA 2 | from .statistics import STDP 3 | 4 | 5 | def BOLL(real, timeperiod, nbdevup, nbdevdn): 6 | """BOLL指标,布林带 7 | 8 | 使用总体标准差 9 | 10 | BOLL(real, timeperiod=20, nbdevup=2, nbdevdn=2) 11 | 12 | References 13 | ---------- 14 | https://en.wikipedia.org/wiki/Bollinger_Bands 15 | 16 | """ 17 | MID = MA(real, timeperiod) 18 | # 这里是总体标准差,值比样本标准差小。部分软件使用样本标准差是错误的, 19 | _std = STDP(real, timeperiod) 20 | UPPER = MID + _std * nbdevup 21 | LOWER = MID - _std * nbdevdn 22 | return UPPER, MID, LOWER 23 | -------------------------------------------------------------------------------- /ta_cn/tdx/reference.py: -------------------------------------------------------------------------------- 1 | import numba 2 | import numpy as np 3 | 4 | from . import ABS, MAX, REF 5 | from .. 
import bn_wraps as bn 6 | from .. import talib as ta 7 | from ..nb import numpy_rolling_apply_1, _rolling_func_1_1_nb 8 | from ..utils import pd_to_np 9 | 10 | _ta1d = ta.init(mode=1, skipna=False, to_globals=False) 11 | _ta2d = ta.init(mode=2, skipna=False, to_globals=False) 12 | 13 | 14 | def CONST(real): 15 | """取A最后的值为常量""" 16 | return np.full_like(real, real[-1]) 17 | 18 | 19 | def SUMIF(real, condition, timeperiod): 20 | """!!!注意,condition位置""" 21 | return bn.move_sum(real * condition, window=timeperiod, axis=0) 22 | 23 | 24 | def TR(high, low, close): 25 | """TR真实波幅""" 26 | lc = REF(close, 1) 27 | return MAX(high - low, ABS(high - lc), ABS(lc - low)) 28 | 29 | 30 | @numba.jit(nopython=True, cache=True, nogil=True) 31 | def _filter_nb(arr, n): 32 | is_1d = arr.ndim == 1 33 | x = arr.shape[0] 34 | y = 1 if is_1d else arr.shape[1] 35 | 36 | for j in range(y): 37 | a = arr if is_1d else arr[:, j] 38 | 39 | # 为了跳过不必要部分,由for改while 40 | i = 0 41 | while i < x: 42 | if a[i]: 43 | a[i + 1:i + 1 + n] = 0 44 | i += n + 1 45 | else: 46 | i += 1 47 | return arr 48 | 49 | 50 | def FILTER(S, N): 51 | """FILTER函数,S满足条件后,将其后N周期内的数据置为0""" 52 | S = pd_to_np(S, copy=True) 53 | return _filter_nb(S, N) 54 | 55 | 56 | @numba.jit(nopython=True, cache=True, nogil=True) 57 | def _bars_last_nb(arr, out): 58 | """上一次条件成立到当前的周期数""" 59 | is_1d = arr.ndim == 1 60 | x = arr.shape[0] 61 | y = 1 if is_1d else arr.shape[1] 62 | 63 | for j in range(y): 64 | a = arr if is_1d else arr[:, j] 65 | b = out if is_1d else out[:, j] 66 | s = 0 67 | for i in range(x): 68 | if a[i]: 69 | s = 0 70 | b[i] = s 71 | s += 1 72 | 73 | return out 74 | 75 | 76 | def BARSLAST(S): 77 | """BARSLAST(X),上一次X不为0到现在的天数 78 | 79 | 成立当天输出0 80 | """ 81 | S = pd_to_np(S, copy=False) 82 | out = np.zeros_like(S, dtype=int) 83 | return _bars_last_nb(S, out) 84 | 85 | 86 | @numba.jit(nopython=True, cache=True, nogil=True) 87 | def _bars_last_count_nb(arr, out): 88 | """ 89 | 90 | Parameters 91 | ---------- 92 | arr 93 | out 94 | 95 | References 96 | ---------- 97 | https://stackoverflow.com/questions/18196811/cumsum-reset-at-nan 98 | 99 | """ 100 | is_1d = arr.ndim == 1 101 | x = arr.shape[0] 102 | y = 1 if is_1d else arr.shape[1] 103 | 104 | for j in range(y): 105 | a = arr if is_1d else arr[:, j] 106 | b = out if is_1d else out[:, j] 107 | s = 0 108 | for i in range(x): 109 | if a[i]: 110 | s += 1 111 | b[i] = s 112 | else: 113 | s = 0 114 | 115 | return out 116 | 117 | 118 | def BARSLASTCOUNT(S): 119 | """BARSLASTCOUNT(X),统计连续满足X条件的周期数 120 | 121 | 成立第一天输出1 122 | """ 123 | S = pd_to_np(S, copy=False) 124 | out = np.zeros_like(S, dtype=int) 125 | return _bars_last_count_nb(S, out) 126 | 127 | 128 | @numba.jit(nopython=True, cache=True, nogil=True) 129 | def _bars_since_n_nb(a, n): 130 | """BARSSINCEN(X,N):N周期内第一次X不为0到现在的天数""" 131 | for i, x in enumerate(a): 132 | if x: 133 | return n - i - 1 134 | return 0 135 | 136 | 137 | def BARSSINCEN(cond, timeperiod): 138 | """BARSSINCEN(X,N):N周期内第一次X不为0到现在的天数""" 139 | return numpy_rolling_apply_1([pd_to_np(cond)], timeperiod, _rolling_func_1_1_nb, _bars_since_n_nb, timeperiod) 140 | -------------------------------------------------------------------------------- /ta_cn/tdx/statistics.py: -------------------------------------------------------------------------------- 1 | import numba 2 | import numpy as np 3 | 4 | from .. 
import bn_wraps as bn 5 | from ..nb import numpy_rolling_apply_1, _rolling_func_1_1_nb 6 | from ..utils import pd_to_np 7 | 8 | 9 | @numba.jit(nopython=True, cache=True, nogil=True) 10 | def _avedev_nb(a): 11 | """avedev平均绝对偏差""" 12 | return np.mean(np.abs(a - np.mean(a))) 13 | 14 | 15 | def AVEDEV(real, timeperiod: int): 16 | """平均绝对偏差 17 | 18 | AVEDEV(real, timeperiod=20) 19 | """ 20 | return numpy_rolling_apply_1([pd_to_np(real)], timeperiod, _rolling_func_1_1_nb, _avedev_nb) 21 | 22 | 23 | def STD(x, d): 24 | """样本标准差""" 25 | return bn.move_std(x, window=d, axis=0, ddof=1) 26 | 27 | 28 | def STDP(x, d): 29 | """总体标准差""" 30 | return bn.move_std(x, window=d, axis=0, ddof=0) 31 | 32 | 33 | def VAR(x, d): 34 | """样本方差""" 35 | return bn.move_var(x, window=d, axis=0, ddof=1) 36 | 37 | 38 | def VARP(x, d): 39 | """总体方差""" 40 | return bn.move_var(x, window=d, axis=0, ddof=0) 41 | 42 | 43 | @numba.jit(nopython=True, cache=True, nogil=True) 44 | def _limit_count_nb(arr, out1, out2, d): 45 | is_1d = arr.ndim == 1 46 | x = arr.shape[0] 47 | y = 1 if is_1d else arr.shape[1] 48 | 49 | for j in range(y): 50 | a = arr if is_1d else arr[:, j] 51 | nn = out1 if is_1d else out1[:, j] 52 | mm = out2 if is_1d else out2[:, j] 53 | n = 0 # N天 54 | m = 0 # M板 55 | k = 0 # 连续False个数 56 | f = True # 前面的False不处理 57 | for i in range(x): 58 | if a[i]: 59 | # 正常统计 60 | k = 0 61 | n += 1 62 | m += 1 63 | nn[i] = n 64 | mm[i] = m 65 | f = False 66 | else: 67 | if f: 68 | continue 69 | k += 1 # 非False计数 70 | nn[i] = -k # 表示离上涨停的天数,-1表示昨天是涨停的 71 | if k > d: 72 | m = 0 73 | n = 0 74 | else: 75 | n += 1 76 | mm[i] = 0 77 | # N天M板 78 | return out1, out2 79 | 80 | 81 | def limit_count(x, d): 82 | """涨停统计 83 | 84 | BARSLASTCOUNT可以统计连板数。但无法统计N天M板这种情况. 85 | 86 | Parameters 87 | ---------- 88 | x 89 | d: int 90 | 0表示必须连板 91 | 1表示可以间隔1天板。出现5天3板,9天5板都是有可能的 92 | 2表示可以间隔2天板。所以4天2板这种情况一定要用2才能区分 93 | 3以此类推 94 | 95 | Returns 96 | ------- 97 | out1 98 | 总天数。小于0时,表示前几天是涨停。大于0时是累计天数 99 | out2 100 | 板数。累计连板数。断板时会延用上次板数d天时间 101 | 102 | Examples 103 | -------- 104 | >>> a = np.array([0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1]) 105 | >>> b, c = limit_count(a, 2) 106 | >>> print(a) 107 | >>> print(b) 108 | >>> print(c) 109 | 110 | >>> [0 1 0 1 1 0 0 1 1 0 0 0 0 1] 111 | >>> [0 1 -1 3 4 -1 -2 1 2 -1 -2 -3 -4 1] 112 | >>> [0 1 0 2 3 0 0 1 2 0 0 0 0 1] 113 | 114 | 解读一下,如: 115 | 0. 无 116 | 1. 1天1板 117 | 2. 昨天1天1板 118 | 3. 3天2板 119 | 4. 4天3板 120 | 5. 昨天4天3板 121 | 6. 前天4天3板 122 | 7. 1天1板 123 | 8. 2天2板 124 | 125 | """ 126 | x = pd_to_np(x, copy=False) 127 | out1 = np.zeros_like(x, dtype=int) 128 | out2 = np.zeros_like(x, dtype=int) 129 | return _limit_count_nb(x, out1, out2, d) 130 | -------------------------------------------------------------------------------- /ta_cn/tdx/trend.py: -------------------------------------------------------------------------------- 1 | from talib import MA_Type 2 | 3 | from . import COUNT, REF, DIFF, MA, MEAN 4 | from .. 
import talib as ta 5 | from ..talib import set_compatibility, TA_COMPATIBILITY_METASTOCK 6 | 7 | _ta1d = ta.init(mode=1, skipna=False) 8 | _ta2d = ta.init(mode=2, skipna=False) 9 | 10 | 11 | def BBI(real, timeperiod1: int, timeperiod2: int, timeperiod3: int, timeperiod4: int): 12 | """BBI多空指标 13 | 14 | BBI(real, timeperiod1=3, timeperiod2=6, timeperiod3=12, timeperiod4=24) 15 | 16 | """ 17 | return MEAN(MA(real, timeperiod1), 18 | MA(real, timeperiod2), 19 | MA(real, timeperiod3), 20 | MA(real, timeperiod4)) 21 | 22 | 23 | def DPO(real, timeperiod: int): 24 | """DPO区间震荡线 25 | 26 | DPO(real, timeperiod=20) 27 | 28 | Parameters 29 | ---------- 30 | real 31 | timeperiod:int 32 | 常用参数:12 33 | 34 | Examples 35 | -------- 36 | 股票软件上常再求一次MA 37 | >>> MA(DPO(CLOSE, 20), 6) 38 | 39 | """ 40 | return real - REF(MA(real, timeperiod), timeperiod // 2) 41 | 42 | 43 | def MACD(real, fastperiod: int, slowperiod: int, signalperiod: int): 44 | """MACD指标 45 | 46 | MACD(real, fastperiod=12, slowperiod=26, signalperiod=9) 47 | 48 | """ 49 | set_compatibility(TA_COMPATIBILITY_METASTOCK) 50 | 51 | macd, macdsignal, macdhist = _ta2d.MACDEXT(real, 52 | fastperiod=fastperiod, fastmatype=MA_Type.EMA, 53 | slowperiod=slowperiod, slowmatype=MA_Type.EMA, 54 | signalperiod=signalperiod, signalmatype=MA_Type.EMA) 55 | 56 | # macd起始位不是按slowperiod-1,而是按slowperiod+signalperiod-2,可能是为了三个输出的起始位相同 57 | # talib中的MACD没有*2 58 | return macd, macdsignal, macdhist * 2 59 | 60 | 61 | def MTM(real, timeperiod: int): 62 | """MTM动量指标 63 | 64 | MTM(real, timeperiod=12) 65 | 66 | Parameters 67 | ---------- 68 | real 69 | timeperiod:int 70 | 常用参数:12 71 | 72 | Examples 73 | -------- 74 | 股票软件上常再求一次MA 75 | >>> MA(MTM(CLOSE, 12), 6) 76 | 77 | """ 78 | if real.ndim == 2: 79 | return DIFF(real, timeperiod) 80 | else: 81 | return _ta1d.MOM(real, timeperiod=timeperiod) 82 | 83 | 84 | def PSY(real, timeperiod: int): 85 | """PSY心理线 86 | 87 | PSY(real, timeperiod=12) 88 | 89 | Parameters 90 | ---------- 91 | real 92 | timeperiod:int 93 | 常用参数:12 94 | 95 | Examples 96 | -------- 97 | 股票软件上常再求一次MA 98 | >>> MA(PSY(CLOSE, 12), 6) 99 | 100 | """ 101 | return COUNT(real > REF(real, 1), timeperiod) / timeperiod * 100 102 | 103 | 104 | def DM(high, low, timeperiod): 105 | """Directional Movement方向动量 106 | 107 | DM(high, low, timeperiod=14) 108 | 109 | WS_SUM威尔德平滑求和 110 | """ 111 | return _ta2d.PLUS_DM(high, low, timeperiod=timeperiod), _ta2d.MINUS_DM(high, low, timeperiod=timeperiod) 112 | 113 | 114 | def DI(high, low, close, timeperiod: int): 115 | """Directional Indicator方向指标 116 | 117 | DI(high, low, close, timeperiod=14) 118 | """ 119 | return _ta2d.PLUS_DI(high, low, close, timeperiod=timeperiod), _ta2d.MINUS_DI(high, low, close, 120 | timeperiod=timeperiod) 121 | 122 | 123 | def DMI(high, low, close, timeperiod: int): 124 | """趋向指标 125 | 126 | DMI(high, low, close, timeperiod=14) 127 | """ 128 | return (_ta2d.PLUS_DI(high, low, close, timeperiod=timeperiod), 129 | _ta2d.MINUS_DI(high, low, close, timeperiod=timeperiod), 130 | _ta2d.ADX(high, low, close, timeperiod=timeperiod), 131 | _ta2d.ADXR(high, low, close, timeperiod=timeperiod), 132 | ) 133 | 134 | 135 | def TRIX(real, timeperiod: int): 136 | """三重指数平滑均线 137 | 138 | TRIX(real, timeperiod=12) 139 | 140 | 由EMA算法差异导致的不同 141 | """ 142 | set_compatibility(TA_COMPATIBILITY_METASTOCK) 143 | 144 | return _ta2d.TRIX(real, timeperiod=timeperiod) 145 | -------------------------------------------------------------------------------- /ta_cn/tdx/volume.py: 
-------------------------------------------------------------------------------- 1 | from . import IF 2 | from . import SGN 3 | from . import SUM, DIFF, REF 4 | 5 | 6 | def OBV(real, volume, scale): # 能量潮指标 7 | """能量潮指标 8 | 9 | OBV(real, volume, scale=1) 10 | """ 11 | # 同花顺最后会除10000,但东方财富没有除 scale=1 / 10000 12 | return SUM(SGN(DIFF(real)) * volume, 0) * scale 13 | 14 | 15 | def VR(close, volume, timeperiod: int): 16 | """VR容量比率 17 | 18 | VR(close, volume, timeperiod=26) 19 | """ 20 | LC = REF(close, timeperiod=1) 21 | return SUM(IF(close > LC, volume, 0), timeperiod) / SUM(IF(close <= LC, volume, 0), timeperiod) * 100 22 | -------------------------------------------------------------------------------- /ta_cn/utils_long.py: -------------------------------------------------------------------------------- 1 | """ 2 | 长表处理工具 3 | """ 4 | from functools import wraps 5 | 6 | import pandas as pd 7 | 8 | from .utils import to_pd 9 | 10 | 11 | def series_groupby_apply(func, by='asset', dropna=True, to_kwargs={1: 'timeperiod'}, output_num=1): 12 | """普通指标转换成按分组处理的指标。只支持单参数 13 | 14 | Parameters 15 | ---------- 16 | func 17 | by 18 | dropna 19 | to_kwargs 20 | output_num 21 | 22 | Notes 23 | ----- 24 | 只能处理Series 25 | 26 | """ 27 | 28 | @wraps(func) 29 | def decorated(*args, **kwargs): 30 | # 参数位置调整,实现命名参数简化 31 | _kwargs = {k: args[i] for i, k in to_kwargs.items() if i < len(args)} 32 | s1 = args[0] 33 | 34 | if dropna: 35 | s2 = s1.dropna() 36 | else: 37 | s2 = s1 38 | 39 | if len(s2) == 0: 40 | if output_num == 1: 41 | return pd.Series(index=s1.index, dtype=float) 42 | else: 43 | return tuple([pd.Series(index=s1.index, dtype=float) for i in range(output_num)]) 44 | 45 | s3 = s2.groupby(by=by, group_keys=False).apply(to_pd(func), **_kwargs, **kwargs) 46 | 47 | if output_num == 1: 48 | # 单输出 49 | if len(s1) == len(s2): 50 | return s3 51 | else: 52 | # 由于长度变化了,只能重整长度 53 | return pd.Series(s3, index=s1.index) 54 | else: 55 | # 多输出 56 | return tuple([pd.concat([s[i] for s in s3]) for i in range(output_num)]) 57 | 58 | return decorated 59 | 60 | 61 | def dataframe_groupby_apply(func, by='asset', dropna=True, to_df=[0, 1], to_kwargs={2: 'timeperiod'}, output_num=1): 62 | """普通指标转换成按分组处理的指标。支持多输入 63 | 64 | Parameters 65 | ---------- 66 | func 67 | by 68 | dropna 69 | to_df 70 | to_kwargs 71 | output_num 72 | 73 | Notes 74 | ----- 75 | 既能处理DataFrame,又能处理Series,但考虑到效率,单输入时使用series_groupby_apply 76 | 77 | """ 78 | 79 | def get(i, k, args, kwargs): 80 | if i == k: 81 | return args[i] 82 | if isinstance(k, str): 83 | v = kwargs.get(k, None) 84 | if v is None: 85 | return args[i] 86 | return v 87 | 88 | @wraps(func) 89 | def decorated(*args, **kwargs): 90 | _kwargs = {k: args[i] for i, k in to_kwargs.items() if i < len(args)} 91 | s1 = pd.DataFrame({k: get(i, k, args, kwargs) for i, k in enumerate(to_df)}) 92 | 93 | if dropna: 94 | s2 = s1.dropna() 95 | else: 96 | s2 = s1 97 | 98 | if len(s2) == 0: 99 | if output_num == 1: 100 | return pd.Series(index=s1.index, dtype=float) 101 | else: 102 | return tuple([pd.Series(index=s1.index, dtype=float) for i in range(output_num)]) 103 | 104 | s3 = s2.groupby(by=by, group_keys=False).apply(to_pd(dataframe_split(func)), **_kwargs) 105 | 106 | if output_num == 1: 107 | # 单输出 108 | if len(s1) == len(s2): 109 | return s3 110 | else: 111 | # 由于长度变化了,只能重整长度 112 | return pd.Series(s3, index=s1.index) 113 | else: 114 | # 多输出 115 | return tuple([pd.concat([s[i] for s in s3]) for i in range(output_num)]) 116 | 117 | return decorated 118 | 119 | 120 | def dataframe_split(func):
"""将第一个DataFrame分拆传到指定函数""" 122 | 123 | @wraps(func) 124 | def decorated(df: pd.DataFrame, *args, **kwargs): 125 | ss = df.to_dict(orient='series') 126 | args_input = [v for k, v in ss.items() if isinstance(k, int)] 127 | kwargs_input = {k: v for k, v in ss.items() if not isinstance(k, int)} 128 | return func(*args_input, *args, **kwargs_input, **kwargs) 129 | 130 | return decorated 131 | -------------------------------------------------------------------------------- /ta_cn/wq/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 本项目参考了WorldQuant BRAIN中Fast Expression的语法 3 | 4 | References 5 | ---------- 6 | 1. https://platform.worldquantbrain.com/learn/data-and-operators/operators 7 | 2. https://platform.worldquantbrain.com/learn/data-and-operators/detailed-operator-descriptions 8 | """ 9 | -------------------------------------------------------------------------------- /ta_cn/wq/arithmetic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Arithmetic Operators 3 | """ 4 | import numpy as np 5 | from functools import reduce 6 | 7 | from .time_series import ts_delta 8 | 9 | 10 | def abs_(x): 11 | """Absolute value of x""" 12 | return np.absolute(x) 13 | 14 | 15 | def add(*args, filter=False): 16 | """Add all inputs (at least 2 inputs required). If filter = true, filter all input NaN to 0 before adding""" 17 | 18 | def _add(x, y): 19 | mask = np.isnan(x) & np.isnan(y) 20 | x = np.nan_to_num(x, nan=0) 21 | y = np.nan_to_num(y, nan=0) 22 | x = np.ma.masked_array(x, mask=mask) 23 | y = np.ma.masked_array(y, mask=x.mask) 24 | return (x + y).filled(np.nan) 25 | 26 | func = _add if filter else np.add 27 | 28 | with np.errstate(over='ignore'): 29 | return reduce(func, args) 30 | 31 | 32 | def ceiling(x): 33 | """Nearest larger integer""" 34 | return np.ceil(x) 35 | 36 | 37 | def divide(*args): 38 | """x / y 39 | 40 | RuntimeWarning: invalid value encountered in true_divide 41 | """ 42 | with np.errstate(divide='ignore', invalid='ignore'): 43 | return reduce(np.true_divide, args) 44 | 45 | 46 | def safe_divide(a, b, posinf_=np.inf, neginf_=-np.inf, nan_=np.nan, is_close=True, replace=True): 47 | """安全除法 48 | 1. 解决除0时产生异常值。如:1/0=inf -1/0=-inf 0/0=nan 49 | 2. 解决数字接近0时产生异常值 50 | 51 | Parameters 52 | ---------- 53 | a 54 | b 55 | posinf_: float 56 | 正数/0时,将inf替换成posinf_ 57 | neginf_: float 58 | 负数/0时,将-inf替换成neginf_ 59 | nan_: float 60 | 0/0时,将nan替换成nan_,其它nan/0, 0/nan, nan/nan还是保持nan 61 | is_close: bool 62 | 数字接近于0时是否当成0使用。防止除很小值时产生很大值 63 | replace: bool 64 | 是否进行安全除法替换 65 | 66 | """ 67 | if is_close: 68 | a_is_zero = np.isclose(a, 0) 69 | b_is_zero = np.isclose(b, 0) 70 | else: 71 | a_is_zero = a == 0 72 | b_is_zero = b == 0 73 | 74 | with np.errstate(divide='ignore', invalid='ignore'): 75 | out = a / b 76 | 77 | if replace: 78 | out[b_is_zero & (a > 0)] = posinf_ 79 | out[b_is_zero & (a < 0)] = neginf_ 80 | out[b_is_zero & a_is_zero] = nan_ 81 | 82 | return out 83 | 84 | 85 | def exp(x): 86 | """Natural exponential function: e^x""" 87 | return np.exp(x) 88 | 89 | 90 | def floor(x): 91 | """Nearest smaller integer""" 92 | return np.floor(x) 93 | 94 | 95 | def fraction(x): 96 | """This operator removes the whole number part and returns the remaining fraction part with sign.""" 97 | # return sign(x) * (abs(x) - floor(abs(x))) 98 | return sign(x) * (abs_(x) % 1.) 
99 | 100 | 101 | def inverse(x): 102 | """1 / x""" 103 | x = x.copy() 104 | x[x == 0] = np.nan 105 | return 1 / x 106 | 107 | 108 | def log(x): 109 | """Natural logarithm. For example: Log(high/low) uses natural logarithm of high/low ratio as stock weights.""" 110 | return np.log(np.where(x > 0, x, np.nan)) 111 | 112 | 113 | def log_diff(x): 114 | """Returns log(current value of input or x[t] ) - log(previous value of input or x[t-1]).""" 115 | return ts_delta(log(x), 1) 116 | 117 | 118 | def max_(*args): 119 | """Maximum value of all inputs. At least 2 inputs are required.""" 120 | return reduce(np.maximum, args) 121 | 122 | 123 | def min_(*args): 124 | """Minimum value of all inputs. At least 2 inputs are required.""" 125 | return reduce(np.minimum, args) 126 | 127 | 128 | def multiply(*args, filter=False): 129 | """Multiply all inputs. At least 2 inputs are required. Filter sets the NaN values to 1.""" 130 | 131 | def _multiply(x, y): 132 | mask = np.isnan(x) & np.isnan(y) 133 | x = np.nan_to_num(x, nan=1) 134 | y = np.nan_to_num(y, nan=1) 135 | x = np.ma.masked_array(x, mask=mask) 136 | y = np.ma.masked_array(y, mask=x.mask) 137 | return (x * y).filled(np.nan) 138 | 139 | func = _multiply if filter else np.multiply 140 | 141 | with np.errstate(over='ignore'): 142 | return reduce(func, args) 143 | 144 | 145 | def nan_mask(x, y): 146 | """replace input with NAN if input's corresponding mask value or the second input here, is negative.""" 147 | x = x.copy() 148 | x[y < 0] = np.nan 149 | return x 150 | 151 | 152 | def nan_out(x, lower=0, upper=0): 153 | """If x < lower or x > upper return NaN, else return x. At least one of "lower", "upper" is required.""" 154 | return np.where((x < lower) | (x > upper), np.nan, x) 155 | 156 | 157 | def power(x, y): 158 | """x ^ y""" 159 | with np.errstate(divide='ignore', invalid='ignore'): 160 | r = np.power(x, y) 161 | # 有可能产生inf,是否需要处理 162 | # r[np.isinf(r)] = np.nan 163 | return r 164 | 165 | 166 | def purify(x): 167 | """Clear infinities (+inf, -inf) by replacing with NaN.""" 168 | x = x.copy() 169 | x[np.isinf(x)] = np.nan 170 | return x 171 | 172 | 173 | def replace(x, target=[-np.inf, np.inf], dest=[np.nan, np.nan]): 174 | """Replace target values in input with destination values.""" 175 | x = x.copy() 176 | for t, d in zip(target, dest): 177 | x[x == t] = d 178 | return x 179 | 180 | 181 | def reverse(x): 182 | """- x""" 183 | return -x 184 | 185 | 186 | def round_(x): 187 | """Round input to closest integer.""" 188 | return np.around(x) 189 | 190 | 191 | def round_down(x, f=1): 192 | """Round input to greatest multiple of f less than input;""" 193 | with np.errstate(invalid='ignore', divide='ignore'): 194 | return x // f * f 195 | 196 | 197 | def sign(x): 198 | """if input = NaN; return NaN 199 | else if input > 0, return 1 200 | else if input < 0, return -1 201 | else if input = 0, return 0 202 | 203 | return (0 < x) * 1 - (x < 0) 204 | """ 205 | return np.sign(x) 206 | 207 | 208 | def signed_power(x, y): 209 | """x raised to the power of y such that final result preserves sign of x.""" 210 | with np.errstate(invalid='ignore', divide='ignore'): 211 | return sign(x) * (abs_(x) ** y) 212 | 213 | 214 | def s_log_1p(x): 215 | """Confine function to a shorter range using logarithm such that higher input remains higher and negative input remains negative as an output of resulting function and -1 or 1 is an asymptotic value.""" 216 | with np.errstate(invalid='ignore', divide='ignore'): 217 | return sign(x) * log(1 + abs_(x)) 218 | 219 | 220 | def 
sqrt(x): 221 | """Square root of x""" 222 | with np.errstate(invalid='ignore'): 223 | return np.sqrt(x) 224 | 225 | 226 | def subtract(*args, filter=False): 227 | """x-y. If filter = true, filter all input NaN to 0 before subtracting""" 228 | 229 | def _subtract(x, y): 230 | mask = np.isnan(x) & np.isnan(y) 231 | x = np.nan_to_num(x, nan=0) 232 | y = np.nan_to_num(y, nan=0) 233 | x = np.ma.masked_array(x, mask=mask) 234 | y = np.ma.masked_array(y, mask=x.mask) 235 | return (x - y).filled(np.nan) 236 | 237 | func = _subtract if filter else np.subtract 238 | 239 | with np.errstate(over='ignore'): 240 | return reduce(func, args) 241 | 242 | 243 | def to_nan(x, value=0, reverse=False): 244 | """Convert value to NaN or NaN to value if reverse=true""" 245 | x = x.copy() 246 | if reverse: 247 | x[np.isnan(x)] = value 248 | else: 249 | x[x == value] = np.nan 250 | return x 251 | 252 | 253 | def densify(x): 254 | """Converts a grouping field of many buckets into lesser number of only available buckets so as to make working with grouping fields computationally efficient.""" 255 | return x 256 | 257 | 258 | # ---------------- 259 | def log10(x): 260 | """10为底对数收益率""" 261 | return np.log10(np.where(x > 0, x, np.nan)) 262 | 263 | 264 | # 过于简单,直接添加于此 265 | def mean(*args, filter=True): 266 | """多个均值""" 267 | return add(*args, filter=filter) / len(args) 268 | -------------------------------------------------------------------------------- /ta_cn/wq/cross_sectional.py: -------------------------------------------------------------------------------- 1 | """ 2 | Cross Sectional Operators 3 | """ 4 | import numpy as np 5 | 6 | from .arithmetic import add, abs_ 7 | from .. import bn_wraps as bn 8 | from ..regress import multiple_regress 9 | from ..utils import pd_to_np 10 | 11 | 12 | def normalize(x, useStd=False, limit=0.0): 13 | """Calculates the mean value of all valid alpha values for a certain date, then subtracts that mean from each element.""" 14 | x = pd_to_np(x, copy=False) 15 | if np.isnan(x).all(): 16 | return x 17 | 18 | axis = x.ndim - 1 19 | t1 = np.nanmean(x, axis=axis, keepdims=True) 20 | if useStd: 21 | # 这里用ddof=1才能与文档示例的数值对应上 22 | t2 = np.nanstd(x, axis=axis, keepdims=True, ddof=1) 23 | with np.errstate(divide='ignore', invalid='ignore'): 24 | r = (x - t1) / t2 25 | else: 26 | r = (x - t1) 27 | 28 | return r if limit == 0 else np.clip(r, -limit, limit) 29 | 30 | 31 | def one_side(x, side='long'): 32 | """Shifts all instruments up or down so that the Alpha becomes long-only or short-only 33 | (if side = short), respectively.""" 34 | # TODO: 这里不确定,需再研究 35 | # [-1, 0, 1]+1=[0, 1, 2] 36 | # max([-1, 0, 1], 0)=[0,0,1] 37 | if side == 'long': 38 | return np.maximum(x, 0) 39 | else: 40 | return np.minimum(x, 0) 41 | 42 | 43 | def quantile(x, driver='gaussian', sigma=1.0): 44 | """Rank the raw vector, shift the ranked Alpha vector, apply distribution ( gaussian, cauchy, uniform ). If driver is uniform, it simply subtract each Alpha value with the mean of all Alpha values in the Alpha vector.""" 45 | pass 46 | 47 | 48 | def rank(x, rate=2, pct=True): 49 | """Ranks the input among all the instruments and returns an equally distributed number between 0.0 and 1.0. 
For precise sort, use the rate as 0.""" 50 | axis = x.ndim - 1 51 | t1 = bn.nanrankdata(x, axis=axis) 52 | 53 | if pct: 54 | t2 = np.nansum(~np.isnan(x), axis=axis, keepdims=True) 55 | with np.errstate(divide='ignore', invalid='ignore'): 56 | return t1 / t2 57 | else: 58 | return t1 59 | 60 | 61 | def rank_by_side(x, rate=2, scale=1): 62 | """Ranks positive and negative input separately and scale to book. For precise sorting use rate=0.""" 63 | pass 64 | 65 | 66 | def generalized_rank(open, m=1): 67 | """The idea is that difference between instrument values raised to the power of m is added to the rank of instrument with bigger value and subtracted from the rank of instrument with lesser value. More details in the notes at the end of page.""" 68 | pass 69 | 70 | 71 | def regression_neut(y, x): 72 | """Conducts the cross-sectional regression on the stocks with Y as target and X as the independent variable.""" 73 | return multiple_regress(y, x, add_constant=True)[0] 74 | 75 | 76 | def regression_proj(y, x): 77 | """Conducts the cross-sectional regression on the stocks with Y as target and X as the independent variable.""" 78 | return multiple_regress(y, x, add_constant=True)[1] 79 | 80 | 81 | def scale(x, scale=1, longscale=1, shortscale=1): 82 | """Scales input to booksize. We can also scale the long positions and short positions to separate scales by mentioning additional parameters to the operator.""" 83 | axis = x.ndim - 1 84 | if longscale != 1 or shortscale != 1: 85 | L = np.where(x > 0, x, np.nan) 86 | S = np.where(x < 0, x, np.nan) 87 | 88 | sum_l = np.nansum(abs_(L), axis=axis, keepdims=True) 89 | sum_s = np.nansum(abs_(S), axis=axis, keepdims=True) 90 | 91 | with np.errstate(divide='ignore', invalid='ignore'): 92 | return add(L / sum_l * longscale, S / sum_s * shortscale, filter=True) 93 | else: 94 | sum_x = np.nansum(abs_(x), axis=axis, keepdims=True) 95 | with np.errstate(divide='ignore', invalid='ignore'): 96 | return x / sum_x * scale 97 | 98 | 99 | def scale_down(x, constant=0): 100 | """Scales all values in each day proportionately between 0 and 1 such that minimum value maps to 0 and maximum value maps to 1. Constant is the offset by which final result is subtracted.""" 101 | axis = x.ndim - 1 102 | m1 = np.nanmin(x, axis=axis, keepdims=True) 103 | m2 = np.nanmax(x, axis=axis, keepdims=True) 104 | 105 | with np.errstate(divide='ignore', invalid='ignore'): 106 | return (x - m1) / (m2 - m1) - constant 107 | 108 | 109 | def truncate(x, maxPercent=0.01): 110 | """Operator truncates all values of x to maxPercent. Here, maxPercent is in decimal notation.""" 111 | axis = x.ndim - 1 112 | t1 = np.nansum(x, axis=axis, keepdims=True) * maxPercent 113 | 114 | return np.minimum(x, t1) 115 | 116 | 117 | def vector_neut(x, y): 118 | """For given vectors x and y, it finds a new vector x* (output) such that x* is orthogonal to y.""" 119 | pass 120 | 121 | 122 | def vector_proj(x, y): 123 | """Returns vector projection of x onto y. Algebraic and geometric details can be found on wiki""" 124 | pass 125 | 126 | 127 | def winsorize(x, std=4): 128 | """Winsorizes x to make sure that all values in x are between the lower and upper limits, which are specified as multiple of std. 
Details can be found on wiki""" 129 | x = pd_to_np(x, copy=False) 130 | axis = x.ndim - 1 131 | _mean = np.nanmean(x, axis=axis, keepdims=True) 132 | _std = np.nanstd(x, axis=axis, keepdims=True, ddof=0) * std 133 | 134 | return np.clip(x, _mean - _std, _mean + _std) 135 | 136 | 137 | def zscore(x): 138 | """Z-score is a numerical measurement that describes a value's relationship to the mean of a group of values. Z-score is measured in terms of standard deviations from the mean""" 139 | x = pd_to_np(x, copy=False) 140 | if np.isnan(x).all(): 141 | return x 142 | axis = x.ndim - 1 143 | _mean = np.nanmean(x, axis=axis, keepdims=True) 144 | _std = np.nanstd(x, axis=axis, keepdims=True, ddof=0) 145 | 146 | with np.errstate(divide='ignore', invalid='ignore'): 147 | return (x - _mean) / _std 148 | 149 | 150 | def rank_gmean_amean_diff(*args): 151 | """Operator returns difference of geometric mean and arithmetic mean of cross sectional rank of inputs.""" 152 | 153 | # TODO: 输入输出的形式还没搞清,核心功能已经实现先放这 154 | def _gmean(x): 155 | return np.exp(np.nanmean(np.log(x))) 156 | 157 | inputs = rank(np.array(args)) 158 | return _gmean(inputs) - np.nanmean(inputs) 159 | -------------------------------------------------------------------------------- /ta_cn/wq/group.py: -------------------------------------------------------------------------------- 1 | """ 2 | Group Operators 3 | 4 | 分组计算 5 | 返回的组内是一个值,还是多个值? 6 | 1. group_count等等肯定是一个值。输出索引是date,group 7 | 2. group_rank等等是输出多个值。输出索引是date,asset 8 | 9 | """ 10 | import numpy as np 11 | 12 | from .. import BY_GROUP 13 | from ..utils_long import dataframe_groupby_apply 14 | 15 | 16 | def group_backfill(x, group, d, std=4.0): 17 | """If a certain value for a certain date and instrument is NaN, from the set of same group instruments, calculate winsorized mean of all non-NaN values over last d days.""" 18 | pass 19 | 20 | 21 | def group_count(x, group): 22 | """Gives the number of instruments in the same group (e.g. sector) which have valid values of x. For example, x=1 gives the number of instruments in each group (without regard for whether any particular field has valid data).""" 23 | func = dataframe_groupby_apply(len, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}) 24 | 25 | return func(x, group) 26 | 27 | 28 | def group_extra(x, weight, group): 29 | """Replaces NaN values by their corresponding group means.""" 30 | pass 31 | # func = dataframe_groupby_apply(None, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}, dropna=False) 32 | # 33 | # return func(x, group) 34 | 35 | 36 | def group_max(x, group): 37 | """Maximum of x for all instruments in the same group.""" 38 | func = dataframe_groupby_apply(np.max, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}) 39 | 40 | return func(x, group) 41 | 42 | 43 | def group_mean(x, weight, group): 44 | """All elements in group equals to the mean value of the group. Mean = sum(data*weight) / sum(weight) in each group.""" 45 | 46 | # TODO: 这里的权重是与x等长的权重序列,还是与股票数一样的权重?
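# 编者补充注释:按当前实现,_mean在每个分组内收到的是与x逐元素对齐的weight切片,
# 即权重应为与x等长的序列(长表中每行对应一个权重)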
-------------------------------------------------------------------------------- /ta_cn/wq/group.py: -------------------------------------------------------------------------------- 1 | """ 2 | Group Operators 3 | 4 | Group-wise computation. 5 | Does each group return one value, or many values? 6 | 1. group_count and the like clearly return one value; the output index is (date, group) 7 | 2. group_rank and the like return many values; the output index is (date, asset) 8 | 9 | """ 10 | import numpy as np 11 | 12 | from .. import BY_GROUP 13 | from ..utils_long import dataframe_groupby_apply 14 | 15 | 16 | def group_backfill(x, group, d, std=4.0): 17 | """If a certain value for a certain date and instrument is NaN, compute the winsorized mean of all non-NaN values over the last d days from the set of same-group instruments.""" 18 | pass 19 | 20 | 21 | def group_count(x, group): 22 | """Gives the number of instruments in the same group (e.g. sector) which have valid values of x. For example, x=1 gives the number of instruments in each group (without regard for whether any particular field has valid data).""" 23 | func = dataframe_groupby_apply(len, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}) 24 | 25 | return func(x, group) 26 | 27 | 28 | def group_extra(x, weight, group): 29 | """Replaces NaN values by their corresponding group means.""" 30 | pass 31 | # func = dataframe_groupby_apply(None, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}, dropna=False) 32 | # 33 | # return func(x, group) 34 | 35 | 36 | def group_max(x, group): 37 | """Maximum of x for all instruments in the same group.""" 38 | func = dataframe_groupby_apply(np.max, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}) 39 | 40 | return func(x, group) 41 | 42 | 43 | def group_mean(x, weight, group): 44 | """All elements in the group equal the mean value of the group. Mean = sum(data*weight) / sum(weight) in each group.""" 45 | 46 | # TODO: is the weight a sequence as long as x, or one weight per stock? 47 | def _mean(x, weight): 48 | return np.average(x, weights=weight) 49 | 50 | func = dataframe_groupby_apply(_mean, by=BY_GROUP, to_df=[0, 1, 'group'], to_kwargs={}) 51 | 52 | return func(x, weight, group) 53 | 54 | 55 | def group_median(x, group): 56 | """All elements in the group equal the median value of the group.""" 57 | func = dataframe_groupby_apply(np.median, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}) 58 | 59 | return func(x, group) 60 | 61 | 62 | def group_min(x, group): 63 | """All elements in the group equal the min value of the group.""" 64 | func = dataframe_groupby_apply(np.min, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}) 65 | 66 | return func(x, group) 67 | 68 | 69 | def group_neutralize(x, group): 70 | """Neutralizes Alpha against groups. These groups can be subindustry, industry, sector, country or a constant.""" 71 | 72 | def _demean(x): 73 | """Industry neutralization; meant to be used together with groupby. 74 | 75 | RuntimeWarning: Mean of empty slice 76 | nanmean emits this warning when everything is NaN, and it is hard to suppress. 77 | """ 78 | return x - np.nanmean(x) 79 | 80 | func = dataframe_groupby_apply(_demean, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}) 81 | 82 | return func(x, group) 83 | 84 | 85 | def group_normalize(x, group, constantCheck=False, tolerance=0.01, scale=1): 86 | """Normalizes input such that each group's absolute sum is 1.""" 87 | 88 | # Note: group_normalize actually behaves like scale; the naming is inconsistent in quite a few places 89 | def _normalize(x, scale): 90 | sum_x = np.sum(abs(x)) 91 | return x / sum_x * scale 92 | 93 | func = dataframe_groupby_apply(_normalize, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={2: 'scale'}) 94 | 95 | return func(x, group, scale) 96 | 97 | 98 | def group_percentage(x, group, percentage=0.5): 99 | """All elements in the group equal the value at the given percentage of the group. 100 | Percentage = 0.5 means the value equals group_median(x, group)""" 101 | 102 | func = dataframe_groupby_apply(np.quantile, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={2: 'q'}) 103 | 104 | return func(x, group, percentage) 105 | 106 | 107 | def group_vector_proj(x, y, g): 108 | """ 109 | Similar to vector_proj(x, y), but x is projected onto y within each group, which can be any classifier such as subindustry, industry, sector, etc.""" 110 | pass 111 | 112 | 113 | def group_rank(x, group): 114 | """Each element in a group is assigned the corresponding rank within this group""" 115 | func = dataframe_groupby_apply(lambda x: x.rank(pct=True), by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}) 116 | 117 | return func(x, group) 118 | 119 |
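A hypothetical usage sketch (editor's addition; the exact index convention is defined by dataframe_groupby_apply in ta_cn/utils_long.py, so treat the layout below as an assumption): inputs are long-format Series aligned on a (date, asset) index, and `group` carries the classifier for each row.

import pandas as pd

idx = pd.MultiIndex.from_product(
    [pd.to_datetime(['2022-01-03', '2022-01-04']), list('ABC')],
    names=['date', 'asset'])
x = pd.Series([1., 2., 3., 4., 5., 6.], index=idx)
group = pd.Series(['tech', 'tech', 'bank'] * 2, index=idx)
print(group_neutralize(x, group))  # demean within each (date, group) bucket
print(group_rank(x, group))        # percentile rank within each bucket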
120 | def group_scale(x, group): 121 | """Normalizes the values in a group to be between 0 and 1: (x - groupmin) / (groupmax - groupmin)""" 122 | 123 | def _scale(x): 124 | t1 = np.min(x) 125 | t2 = np.max(x) 126 | return (x - t1) / (t2 - t1) 127 | 128 | func = dataframe_groupby_apply(_scale, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}) 129 | 130 | return func(x, group) 131 | 132 | 133 | def group_std_dev(x, group): 134 | """All elements in the group equal the standard deviation of the group.""" 135 | 136 | func = dataframe_groupby_apply(np.std, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}) 137 | 138 | return func(x, group) 139 | 140 | 141 | def group_sum(x, group): 142 | """Sum of x for all instruments in the same group.""" 143 | func = dataframe_groupby_apply(np.sum, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}) 144 | 145 | return func(x, group) 146 | 147 | 148 | def group_vector_neut(x, y, g): 149 | """Similar to vector_neut(x, y), but x is neutralized to y within each group g, which can be any classifier such as subindustry, industry, sector, etc.""" 150 | pass 151 | 152 | 153 | def group_zscore(x, group): 154 | """Calculates the group Z-score - a numerical measurement that describes a value's relationship to the mean of a group of values, measured in terms of standard deviations from the mean. zscore = (data - mean) / stddev of x for each instrument within its group.""" 155 | 156 | def _zscore(x): 157 | return (x - np.mean(x)) / np.std(x) 158 | 159 | func = dataframe_groupby_apply(_zscore, by=BY_GROUP, to_df=[0, 'group'], to_kwargs={}) 160 | 161 | return func(x, group) 162 | -------------------------------------------------------------------------------- /ta_cn/wq/logical.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logical Operators 3 | """ 4 | from functools import reduce 5 | 6 | import numpy as np 7 | 8 | 9 | def and_(*args): 10 | """Logical AND operator; returns true if both operands are true and false otherwise""" 11 | return reduce(lambda x, y: np.logical_and(x, y), list(args)) 12 | 13 | 14 | def or_(*args): 15 | """Logical OR operator; returns true if either or both inputs are true and false otherwise""" 16 | return reduce(lambda x, y: np.logical_or(x, y), args) 17 | 18 | 19 | def equal(input1, input2): 20 | """Returns true if both inputs are the same and false otherwise""" 21 | return input1 == input2 22 | 23 | 24 | def negate(input): 25 | """The result is true if the converted operand is false; the result is false if the converted operand is true""" 26 | return ~input 27 | 28 | 29 | def less(input1, input2): 30 | """If input1 < input2 return true, else return false""" 31 | # input1 < input2 can raise "ValueError: Can only compare identically-labeled Series objects", hence the subtraction 32 | return input1 - input2 < 0 33 | 34 | 35 | def if_else(input1, input2, input3): 36 | """If input1 is true then return input2, else return input3.""" 37 | return np.where(input1, input2, input3) 38 | 39 | 40 | def is_not_nan(input): 41 | """If (input != NaN) return 1 else return 0""" 42 | return input == input 43 | 44 | 45 | def is_nan(input): 46 | """If (input == NaN) return 1 else return 0""" 47 | return input != input 48 | 49 | 50 | def is_finite(input): 51 | """If (input == NaN or input == INF) return 0, else return 1""" 52 | return np.isfinite(input) 53 | 54 | 55 | def is_not_finite(input): 56 | """If (input == NaN or input == INF) return 1 else return 0""" 57 | return ~np.isfinite(input) 58 |
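A short editor-added note on the x != x idiom above: NaN is the only IEEE-754 value that compares unequal to itself, so is_nan/is_not_nan work on plain ndarrays and on pandas objects alike, without an explicit np.isnan call.

import numpy as np

a = np.array([1.0, np.nan, np.inf])
print(is_nan(a))     # [False  True False]
print(is_finite(a))  # [ True False False]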
1 | """ 2 | Special Operators 3 | """ 4 | 5 | 6 | def convert(x, mode="dollar2share"): 7 | """Convert dollars to share or share to dollar when mode = "share2dollar" 8 | """ 9 | pass 10 | 11 | 12 | def inst_pnl(x): 13 | """Generate pnl per instruments""" 14 | pass 15 | -------------------------------------------------------------------------------- /ta_cn/wq/transformational.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformational Operators 3 | """ 4 | import numba 5 | import numpy as np 6 | 7 | from ..utils import pd_to_np 8 | from ..wq.logical import if_else 9 | from ..wq.time_series import days_from_last_change 10 | 11 | 12 | def arc_cos(x): 13 | """If -1 <= x <= 1: arccos(x); else NaN""" 14 | return np.arccos(x) 15 | 16 | 17 | def arc_sin(x): 18 | """If -1 <= x <= 1: arcsin(x); else NaN""" 19 | return np.arcsin(x) 20 | 21 | 22 | def arc_tan(x): 23 | """This operator does inverse tangent of input. """ 24 | return np.arctan(x) 25 | 26 | 27 | def bucket(x, 28 | range="0, 1, 0.1", 29 | buckets="2,5,6,7,10", 30 | skipBegin=False, skipEnd=False, skipBoth=False, 31 | NANGroup=True): 32 | """Convert float values into indexes for user-specified buckets. Bucket is useful for creating group values, which can be passed to group operators as input.""" 33 | pass 34 | 35 | 36 | def clamp(x, lower=0, upper=0, inverse=False, mask=np.nan): 37 | """Limits input value between lower and upper bound in inverse = false mode (which is default). Alternatively, when inverse = true, values between bounds are replaced with mask, while values outside bounds are left as is.""" 38 | if inverse: 39 | # mask is one of: 'nearest_bound', 'mean', 'NAN' or any floating point number 40 | return if_else((x > lower) & (x < upper), mask, x) 41 | else: 42 | # q = if_else(x < lower, lower, x) 43 | # u = if_else(q > upper, upper, q) 44 | return np.clip(x, lower, upper) 45 | 46 | 47 | def filter(x, h="1, 2, 3, 4", t="0.5"): 48 | """Used to filter the value and allows to create filters like linear or exponential decay.""" 49 | pass 50 | 51 | 52 | def keep(x, f, period=5): 53 | """This operator outputs value x when f changes and continues to do that for “period” days after f stopped changing. After “period” days since last change of f, NaN is output.""" 54 | D = days_from_last_change(f) 55 | return trade_when(D < period, x, D > period) 56 | 57 | 58 | def left_tail(x, maximum=0): 59 | """NaN everything greater than maximum, maximum should be constant.""" 60 | return np.where(x > maximum, np.nan, x) 61 | 62 | 63 | def pasteurize(x): 64 | """Set to NaN if x is INF or if the underlying instrument is not in the Alpha universe""" 65 | # TODO: 不在票池中的的功能无法表示 66 | x = x.copy() 67 | x[np.isinf(x)] = np.nan 68 | return x 69 | 70 | 71 | def right_tail(x, minimum=0): 72 | """NaN everything less than minimum, minimum should be constant.""" 73 | return np.where(x < minimum, np.nan, x) 74 | 75 | 76 | def sigmoid(x): 77 | """Returns 1 / (1 + exp(-x))""" 78 | return 1 / (1 + np.exp(-x)) 79 | 80 | 81 | def tail(x, lower=0, upper=0, newval=0): 82 | """If (x > lower AND x < upper) return newval, else return x. Lower, upper, newval should be constants. 
""" 83 | return np.where((x > lower) & (x < upper), newval, x) 84 | 85 | 86 | def tanh(x): 87 | """Hyperbolic tangent of x""" 88 | return np.tanh(x) 89 | 90 | 91 | @numba.jit(nopython=True, cache=True, nogil=True) 92 | def _trade_when_nb(xx, yy, zz, out): 93 | is_1d = xx.ndim == 1 94 | x = xx.shape[0] 95 | y = 1 if is_1d else xx.shape[1] 96 | 97 | for j in range(y): 98 | a = xx if is_1d else xx[:, j] 99 | b = yy if is_1d else yy[:, j] 100 | c = zz if is_1d else zz[:, j] 101 | d = out if is_1d else out[:, j] 102 | last = np.nan 103 | for i in range(x): 104 | if c[i] > 0: 105 | d[i] = np.nan 106 | elif a[i] > 0: 107 | d[i] = b[i] 108 | else: 109 | d[i] = last 110 | last = d[i] 111 | 112 | return out 113 | 114 | 115 | def trade_when(x, y, z): 116 | """Used in order to change Alpha values only under a specified condition and to hold Alpha values in other cases. It also allows to close Alpha positions (assign NaN values) under a specified condition.""" 117 | x = pd_to_np(x, copy=False) 118 | y = pd_to_np(y, copy=False) 119 | z = pd_to_np(z, copy=False) 120 | out = np.empty_like(y) 121 | return _trade_when_nb(x, y, z, out) 122 | -------------------------------------------------------------------------------- /ta_cn/wq/vector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vector Operators 3 | 4 | # TODO: 没有搞明白,等看懂了再来实现 5 | """ 6 | 7 | 8 | def vec_avg(x): 9 | """Taking mean of the vector field x""" 10 | pass 11 | 12 | 13 | def vec_choose(x, nth=1): 14 | """Choosing kth item(indexed at 0) from each vector field x""" 15 | pass 16 | 17 | 18 | def vec_count(x): 19 | """Number of elements in vector field x""" 20 | pass 21 | 22 | 23 | def vec_ir(x): 24 | """Information Ratio (Mean / Standard Deviation) of vector field x""" 25 | pass 26 | 27 | 28 | def vec_kurtosis(x): 29 | """Kurtosis of vector field x""" 30 | pass 31 | 32 | 33 | def vec_max(x): 34 | """Maximum value form vector field x""" 35 | pass 36 | 37 | 38 | def vec_min(x): 39 | """Minimum value form vector field x""" 40 | pass 41 | 42 | 43 | def vec_norm(x): 44 | """Sum of all absolute values of vector field x""" 45 | pass 46 | 47 | 48 | def vec_percentage(x, percentage=0.5): 49 | """Percentile of vector field x""" 50 | pass 51 | 52 | 53 | def vec_powersum(x, constant=2): 54 | """Sum of power of vector field x""" 55 | pass 56 | 57 | 58 | def vec_range(x): 59 | """Difference between maximum and minimum element in vector field x""" 60 | pass 61 | 62 | 63 | def vec_skewness(x): 64 | """Skewness of vector field x""" 65 | pass 66 | 67 | 68 | def vec_stddev(x): 69 | """Standard Deviation of vector field x""" 70 | pass 71 | 72 | 73 | def vec_sum(x): 74 | """Sum of vector field x""" 75 | pass 76 | -------------------------------------------------------------------------------- /tests/atr_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ta_cn.talib as ta 8 | from ta_cn.slow import ATR 9 | 10 | if __name__ == '__main__': 11 | # 准备数据 12 | h = np.random.rand(10000000).reshape(-1, 50000) + 10 13 | l = np.random.rand(10000000).reshape(-1, 50000) 14 | c = np.random.rand(10000000).reshape(-1, 50000) 15 | 16 | c[:20, -1] = np.nan 17 | 18 | t1 = time.time() 19 | z1 = ta.ATR(h, l, c) 20 | t2 = time.time() 21 | z2 = ATR(h, l, c) 22 | t3 = time.time() 23 | 24 | print(t2 - t1, t3 - t2) 25 | 26 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 27 | 
-------------------------------------------------------------------------------- /ta_cn/wq/vector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vector Operators 3 | 4 | # TODO: not yet understood; implement once the semantics are clear 5 | """ 6 | 7 | 8 | def vec_avg(x): 9 | """Mean of the vector field x""" 10 | pass 11 | 12 | 13 | def vec_choose(x, nth=1): 14 | """Chooses the kth item (indexed from 0) from each vector field x""" 15 | pass 16 | 17 | 18 | def vec_count(x): 19 | """Number of elements in vector field x""" 20 | pass 21 | 22 | 23 | def vec_ir(x): 24 | """Information Ratio (Mean / Standard Deviation) of vector field x""" 25 | pass 26 | 27 | 28 | def vec_kurtosis(x): 29 | """Kurtosis of vector field x""" 30 | pass 31 | 32 | 33 | def vec_max(x): 34 | """Maximum value from vector field x""" 35 | pass 36 | 37 | 38 | def vec_min(x): 39 | """Minimum value from vector field x""" 40 | pass 41 | 42 | 43 | def vec_norm(x): 44 | """Sum of all absolute values of vector field x""" 45 | pass 46 | 47 | 48 | def vec_percentage(x, percentage=0.5): 49 | """Percentile of vector field x""" 50 | pass 51 | 52 | 53 | def vec_powersum(x, constant=2): 54 | """Sum of powers of vector field x""" 55 | pass 56 | 57 | 58 | def vec_range(x): 59 | """Difference between the maximum and minimum elements in vector field x""" 60 | pass 61 | 62 | 63 | def vec_skewness(x): 64 | """Skewness of vector field x""" 65 | pass 66 | 67 | 68 | def vec_stddev(x): 69 | """Standard deviation of vector field x""" 70 | pass 71 | 72 | 73 | def vec_sum(x): 74 | """Sum of vector field x""" 75 | pass 76 | -------------------------------------------------------------------------------- /tests/atr_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ta_cn.talib as ta 8 | from ta_cn.slow import ATR 9 | 10 | if __name__ == '__main__': 11 | # Prepare data 12 | h = np.random.rand(10000000).reshape(-1, 50000) + 10 13 | l = np.random.rand(10000000).reshape(-1, 50000) 14 | c = np.random.rand(10000000).reshape(-1, 50000) 15 | 16 | c[:20, -1] = np.nan 17 | 18 | t1 = time.time() 19 | z1 = ta.ATR(h, l, c) 20 | t2 = time.time() 21 | z2 = ATR(h, l, c) 22 | t3 = time.time() 23 | 24 | print(t2 - t1, t3 - t2) 25 | 26 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 27 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 28 | plt.show() 29 | -------------------------------------------------------------------------------- /tests/atr_cn.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from ta_cn.tdx.over_bought_over_sold import ATR_CN as ATR_CN1 8 | from ta_cn.slow import ATR_CN as ATR_CN2 9 | 10 | if __name__ == '__main__': 11 | # Prepare data 12 | h = np.random.rand(10000000).reshape(-1, 50000) + 10 13 | l = np.random.rand(10000000).reshape(-1, 50000) 14 | c = np.random.rand(10000000).reshape(-1, 50000) 15 | 16 | t1 = time.time() 17 | z1 = ATR_CN1(h, l, c) 18 | t2 = time.time() 19 | z2 = ATR_CN2(h, l, c) 20 | t3 = time.time() 21 | 22 | print(t2 - t1, t3 - t2) 23 | 24 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 25 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 26 | plt.show() 27 | -------------------------------------------------------------------------------- /tests/avedev_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from ta_cn.slow import _AVEDEV as AVEDEV2 8 | from ta_cn.tdx.statistics import AVEDEV as AVEDEV1 9 | 10 | if __name__ == '__main__': 11 | # Prepare data 12 | h = np.random.rand(100000).reshape(-1, 500) + 10 13 | l = np.random.rand(100000).reshape(-1, 500) 14 | c = np.random.rand(100000).reshape(-1, 500) 15 | 16 | c[:20, -1] = np.nan 17 | # Run once first so numba compiles 18 | z1 = AVEDEV1(c) 19 | t1 = time.time() 20 | z1 = AVEDEV1(c) 21 | t2 = time.time() 22 | z2 = AVEDEV2(c).values 23 | t3 = time.time() 24 | 25 | print(t2 - t1, t3 - t2) 26 | 27 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 28 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 29 | plt.show() 30 | -------------------------------------------------------------------------------- /tests/boll_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ta_cn.talib as ta 8 | from ta_cn.tdx.pressure_support import BOLL 9 | 10 | if __name__ == '__main__': 11 | # Prepare data 12 | h = np.random.rand(10000000).reshape(-1, 50000) + 10 13 | l = np.random.rand(10000000).reshape(-1, 50000) 14 | c = np.random.rand(10000000).reshape(-1, 50000) 15 | 16 | c[:20, -1] = np.nan 17 | 18 | t1 = time.time() 19 | x1, y1, z1 = ta.BBANDS(c, timeperiod=10, nbdevup=2, nbdevdn=2) 20 | t2 = time.time() 21 | x2, y2, z2 = BOLL(c, timeperiod=10, nbdevup=2, nbdevdn=2) 22 | t3 = time.time() 23 | 24 | print(t2 - t1, t3 - t2) 25 | 26 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 27 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 28 | plt.show() 29 | -------------------------------------------------------------------------------- /tests/cci_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ta_cn.talib as ta 8 | from ta_cn.slow import CCI 9 | 10 | if __name__ == '__main__': 11 | # Prepare data 12 | h = np.random.rand(1000000).reshape(-1, 5000) + 10 13 | l = np.random.rand(1000000).reshape(-1, 5000) 14 | c = np.random.rand(1000000).reshape(-1, 5000) 15 | 16 | c[:20, -1] = np.nan 17 | z2 = CCI(h, l, c)  # warm-up run before timing 18 | t1 = time.time()
19 | z1 = ta.CCI(h, l, c) 20 | t2 = time.time() 21 | z2 = CCI(h, l, c) 22 | t3 = time.time() 23 | 24 | print(t2 - t1, t3 - t2) 25 | 26 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 27 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 28 | plt.show() 29 | -------------------------------------------------------------------------------- /tests/chip_.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ta_cn.chip import chip, WINNER, COST 4 | 5 | high = np.array([10.4, 10.2, 10.2, 10.4, 10.5, 10.7, 10.7, 10.7, 10.8, 10.9]) 6 | low = np.array([10.3, 10.1, 10.2, 10.3, 10.4, 10.5, 10.6, 10.7, 10.8, 10.9]) 7 | avg = np.array([10.3, 10.1, 10.2, 10.3, 10.4, 10.5, 10.6, 10.7, 10.8, 10.9]) 8 | close = np.array([10.3, 10.1, 10.2, 10.3, 10.4, 10.5, 10.6, 10.7, 10.8, 10.9]) 9 | turnover = np.array([0.3, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]) 10 | 11 | step = 0.1 12 | out, columns = chip(high, low, avg, turnover, step=step) 13 | print(out.sum(axis=1)) 14 | print(columns) 15 | x = WINNER(out, columns, close) 16 | print(x) 17 | y = COST(out, columns, 0.85) 18 | print(y) 19 | -------------------------------------------------------------------------------- /tests/covar_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from ta_cn.nb import numpy_rolling_apply_1, _rolling_func_2_1_nb, _cov_nb 8 | from ta_cn.utils import np_to_pd, pd_to_np 9 | 10 | 11 | def COVAR(real0, real1, timeperiod=30): 12 | return numpy_rolling_apply_1([pd_to_np(real0), pd_to_np(real1)], 13 | timeperiod, _rolling_func_2_1_nb, _cov_nb) 14 | 15 | 16 | # Rolling correlation 17 | def correlation(x, y, window=30): 18 | return np_to_pd(x).rolling(window).corr(np_to_pd(y)) 19 | 20 | 21 | # Rolling covariance 22 | def covariance(x, y, window=30): 23 | return np_to_pd(x).rolling(window).cov(np_to_pd(y)) 24 | 25 | 26 | if __name__ == '__main__': 27 | # Prepare data 28 | h = np.random.rand(100000).reshape(-1, 500) + 10 29 | l = np.random.rand(100000).reshape(-1, 500) 30 | 31 | z1 = COVAR(h, l)  # warm-up run so numba compiles 32 | t1 = time.time() 33 | z1 = COVAR(h, l) 34 | t2 = time.time() 35 | z2 = covariance(h, l).values 36 | t3 = time.time() 37 | print(t2 - t1, t3 - t2) 38 | 39 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 40 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 41 | plt.show() 42 | -------------------------------------------------------------------------------- /tests/cross_.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | from pandas._testing import assert_numpy_array_equal 5 | 6 | from ta_cn.tdx.logical import CROSS 7 | 8 | names = ['rf'] 9 | dates = pd.date_range(start='2017-01-01', end='2017-12-31', freq=pd.tseries.offsets.BDay()) 10 | n = len(dates) 11 | rdf = pd.DataFrame( 12 | np.zeros((n, len(names))), 13 | index=dates, 14 | columns=names 15 | ) 16 | 17 | np.random.seed(1) 18 | rdf['rf'] = 0.
19 | 20 | pdf = 100 * np.cumprod(1 + rdf) 21 | 22 | factor = rdf.copy() 23 | factor[:] = 1.0 24 | factor[:][-100:] = 1.1 25 | factor[:][-50:] = 1.5 26 | 27 | close_h = pdf * factor 28 | close_q = pdf * (factor / factor.iloc[-1]) 29 | 30 | 31 | def calc(close): 32 | # Different adjustment methods gave different results; fixed now 33 | 34 | # Two extra points 35 | ma5 = close.rolling(10).mean() 36 | ma10 = close.rolling(20).mean() 37 | 38 | # One point fewer 39 | # ma5 = ta.MA(close.iloc[:, 0], 10) 40 | # ma10 = ta.MA(close.iloc[:, 0], 20) 41 | 42 | r = CROSS(ma5, ma10) 43 | 44 | df = pd.DataFrame() 45 | df['CLOSE'] = close 46 | df['CROSS'] = r.astype(float) 47 | df['MA5'] = ma5 48 | df['MA10'] = ma10 49 | df['5/10'] = ma5 / ma10 50 | 51 | return df 52 | 53 | 54 | a1 = calc(close_h) # Two spikes appeared; traced to precision differences 55 | a1.plot(secondary_y=['CROSS', '5/10']) 56 | 57 | a2 = calc(close_q) # After switching algorithms one more point disappeared 58 | a2.plot(secondary_y=['CROSS', '5/10']) 59 | 60 | assert_numpy_array_equal(a1['CROSS'].values, a2['CROSS'].values) 61 | 62 | plt.show() 63 | 64 | # 65 | -------------------------------------------------------------------------------- /tests/dmi_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from ta_cn.slow import DMI as DMI2 8 | from ta_cn.tdx.trend import DMI as DMI1 9 | 10 | if __name__ == '__main__': 11 | # Prepare data 12 | h = np.random.rand(1000).reshape(-1, 5) + 10 13 | l = np.random.rand(1000).reshape(-1, 5) 14 | c = np.random.rand(1000).reshape(-1, 5) 15 | 16 | h[:20, -1] = np.nan 17 | l[:20, -1] = np.nan 18 | c[:20, -1] = np.nan 19 | 20 | t1 = time.time() 21 | z1 = DMI1(h, l, c, 3)[3] # ADXR is what we inspect 22 | t2 = time.time() 23 | z2 = DMI2(h, l, c, 3)[3] 24 | t3 = time.time() 25 | 26 | print(t2 - t1, t3 - t2) 27 | 28 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 29 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 30 | 31 | plt.show() 32 | -------------------------------------------------------------------------------- /tests/ema_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from ta_cn.ema import EXPMEMA, EMA_0_TA, EMA_1_TA, EMA_1_PD 8 | from ta_cn.talib import set_compatibility_enable 9 | 10 | if __name__ == '__main__': 11 | set_compatibility_enable(True) 12 | 13 | # Prepare data 14 | h = np.random.rand(1000).reshape(-1, 2) + 10 15 | l = np.random.rand(1000).reshape(-1, 2) 16 | c = np.random.rand(1000).reshape(-1, 2) 17 | 18 | # c[:20, -1] = np.nan 19 | z2 = EXPMEMA(c)  # warm-up run before timing 20 | 21 | t1 = time.time() 22 | z1 = EMA_0_TA(c, timeperiod=24) 23 | t2 = time.time() 24 | z2 = EXPMEMA(c) 25 | t3 = time.time() 26 | 27 | print(t2 - t1, t3 - t2) 28 | 29 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 30 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 31 | # plt.show() 32 | 33 | t1 = time.time() 34 | z1 = EMA_1_TA(c) 35 | t2 = time.time() 36 | z2 = EMA_1_PD(c) 37 | t3 = time.time() 38 | 39 | print(t2 - t1, t3 - t2) 40 | 41 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 42 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 43 | plt.show() 44 | -------------------------------------------------------------------------------- /tests/forcast_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | import ta_cn.talib as ta
7 | 8 | # def FORCAST(S, N=14):  # Returns the prediction from an N-period linear regression of series S; jqz1226 improved it to output the whole series 9 | # return pd.DataFrame(S).rolling(N).apply(lambda x: np.polyval(np.polyfit(range(N), x, deg=1), N - 1), raw=True) 10 | from ta_cn.regress import FORCAST 11 | 12 | if __name__ == '__main__': 13 | # Prepare data 14 | h = np.random.rand(100000).reshape(-1, 500) + 10 15 | l = np.random.rand(100000).reshape(-1, 500) 16 | c = np.random.rand(100000).reshape(-1, 500) 17 | 18 | c[:20, -1] = np.nan 19 | # Run once first so numba compiles 20 | 21 | t1 = time.time() 22 | z1 = ta.LINEARREG(c) 23 | t2 = time.time() 24 | z2 = FORCAST(c) 25 | t3 = time.time() 26 | 27 | print(t2 - t1, t3 - t2) 28 | 29 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 30 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 31 | plt.show() 32 | 33 | FORCAST  # bare reference; has no effect 34 | -------------------------------------------------------------------------------- /tests/macd_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from ta_cn.slow import MACD_CN 8 | from ta_cn.talib import set_compatibility_enable, TA_COMPATIBILITY_METASTOCK, set_compatibility 9 | from ta_cn.tdx.trend import MACD 10 | 11 | if __name__ == '__main__': 12 | set_compatibility_enable(True) 13 | set_compatibility(TA_COMPATIBILITY_METASTOCK) 14 | set_compatibility_enable(False) 15 | 16 | # Prepare data 17 | h = np.random.rand(100000).reshape(-1, 500) + 10 18 | l = np.random.rand(100000).reshape(-1, 500) 19 | c = np.random.rand(100000).reshape(-1, 500) 20 | 21 | c[:20, -1] = np.nan 22 | 23 | t1 = time.time() 24 | z1 = MACD(c)[-1] 25 | t2 = time.time() 26 | z2 = MACD_CN(c)[-1] 27 | t3 = time.time() 28 | 29 | print(t2 - t1, t3 - t2) 30 | 31 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 32 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 33 | plt.show() 34 | -------------------------------------------------------------------------------- /tests/mfi_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ta_cn.talib as ta 8 | from ta_cn.slow import MFI 9 | 10 | if __name__ == '__main__': 11 | # Prepare data 12 | h = np.random.rand(10000000).reshape(-1, 50000) + 10 13 | l = np.random.rand(10000000).reshape(-1, 50000) 14 | c = np.random.rand(10000000).reshape(-1, 50000) 15 | v = np.random.rand(10000000).reshape(-1, 50000) * 1000 16 | 17 | c[:20, -1] = np.nan 18 | 19 | t1 = time.time() 20 | z1 = ta.MFI(h, l, c, v) 21 | t2 = time.time() 22 | z2 = MFI(h, l, c, v) 23 | t3 = time.time() 24 | 25 | print(t2 - t1, t3 - t2) 26 | 27 | pd.DataFrame({'TA': z1[:, 0], 'MY': z2[:, 0]}).plot() 28 | pd.DataFrame({'MY': z2[:, 0], 'TA': z1[:, 0]}).plot() 29 | plt.show() 30 | -------------------------------------------------------------------------------- /tests/obv_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ta_cn.talib as ta 8 | from ta_cn.tdx.volume import OBV 9 | 10 | if __name__ == '__main__': 11 | # Prepare data 12 | h = np.random.rand(10000000).reshape(-1, 50000) + 10 13 | l = np.random.rand(10000000).reshape(-1, 50000) 14 | c = np.random.rand(10000000).reshape(-1, 50000) 15 | v = np.random.rand(10000000).reshape(-1, 50000) * 1000 16 | 17 | c[:20, -1] = np.nan 18 | v[:20, -1] = np.nan 19 | 20 | t1 = time.time()
21 | z1 = ta.OBV(c, v) 22 | t2 = time.time() 23 | z2 = OBV(c, v, 1) 24 | t3 = time.time() 25 | 26 | print(t2 - t1, t3 - t2) 27 | 28 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 29 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 30 | plt.show() 31 | -------------------------------------------------------------------------------- /tests/ols_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import statsmodels.api as sm 5 | from statsmodels.regression.rolling import RollingOLS 6 | 7 | from ta_cn.regress import multiple_regress, ts_multiple_regress 8 | 9 | y = np.random.rand(1000) 10 | # y[1] = np.nan 11 | x = np.random.rand(1000 * 4).reshape(-1, 4) 12 | x = np.random.rand(1000 * 1)  # .reshape(-1, 4) 13 | # x[1, 1] = np.nan 14 | residual, y_hat, coef = multiple_regress(y, x, add_constant=False) 15 | t1 = time.time() 16 | residual, y_hat, coef = multiple_regress(y, x, add_constant=True) 17 | t2 = time.time() 18 | print(coef) 19 | print(residual[:10]) 20 | 21 | t = sm.add_constant(x) 22 | model = sm.OLS(y, t) 23 | results = model.fit() 24 | t3 = time.time() 25 | print(results.params) 26 | print((y - results.fittedvalues)[:10]) 27 | print(t2 - t1, t3 - t2) 28 | 29 | residual, y_hat, coef = ts_multiple_regress(y, x, timeperiod=80, add_constant=True) 30 | t1 = time.time() 31 | residual, y_hat, coef = ts_multiple_regress(y, x, timeperiod=10, add_constant=True) 32 | print(coef) 33 | t2 = time.time() 34 | t = sm.add_constant(x) 35 | model = RollingOLS(y, t, window=10) 36 | results = model.fit() 37 | 38 | t3 = time.time() 39 | print(results.params) 40 | print(t2 - t1, t3 - t2) 41 | -------------------------------------------------------------------------------- /tests/ray_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import polars as pl 3 | import ray 4 | 5 | ray.init() 6 | 7 | frames = [pl.DataFrame({'a': np.random.rand(1000), 8 | 'b': np.random.rand(1000), 9 | 'c': np.random.rand(1000)}) for i in range(10)] 10 | 11 | 12 | def process_a_frame(f: pl.DataFrame): 13 | # return f.a.max() + f.b.min() + f.c.mean() 14 | return f.select([ 15 | pl.col('a').max(), 16 | pl.col('b').min(), 17 | pl.col('c').mean() 18 | ]).sum(axis=1)[0] 19 | 20 | 21 | MP_STYLES = [None,  # No parallelism.
22 | 'ray']  # Ray parallelism 23 | MP_STYLE = None 24 | 25 | 26 | def pxmap(f, xs, mp_style): 27 | """Parallel map, implemented with different python parallel execution libraries.""" 28 | if mp_style not in MP_STYLES: 29 | print(f"Unrecognized mp_style {mp_style}") 30 | elif mp_style == 'ray': 31 | @ray.remote 32 | def g(x): 33 | return f(x) 34 | 35 | return ray.get([g.remote(x) for x in xs]) 36 | return [f(x) for x in xs] 37 | 38 | 39 | # https://github.com/pola-rs/polars/issues/1109 40 | 41 | _ = pxmap(process_a_frame, frames, 'ray') 42 | print(_) 43 | -------------------------------------------------------------------------------- /tests/reg_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import talib as ta 5 | 6 | from ta_cn.regress import SLOPE_YX, ts_simple_regress 7 | from ta_cn.tdx.reference import REF 8 | 9 | a = np.random.rand(1000)  # .reshape(-1, 10) 10 | b = np.random.rand(1000)  # .reshape(-1, 10) 11 | 12 | r1 = SLOPE_YX(a, b, 30) 13 | t1 = time.time() 14 | x, r2, z = ts_simple_regress(a, b, 30, rettype=[0, 2, 1]) 15 | t2 = time.time() 16 | t3 = time.time() 17 | print(t2 - t1, t3 - t2) 18 | print(r1[-10:]) 19 | print(r2[-10:]) 20 | ##################### 21 | 22 | c = ta.BETA(a, b, 30) 23 | print(c[-10:]) 24 | # Why does this calculation show a sizable error? 25 | c = SLOPE_YX(b / REF(b, 1) - 1, a / REF(a, 1) - 1, 30) 26 | print(c[-10:]) 27 | -------------------------------------------------------------------------------- /tests/rsi_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ta_cn.talib as ta 8 | from ta_cn.slow import RSI 9 | 10 | if __name__ == '__main__': 11 | # Prepare data 12 | h = np.random.rand(100000).reshape(-1, 500) + 10 13 | l = np.random.rand(100000).reshape(-1, 500) 14 | c = np.random.rand(100000).reshape(-1, 500) 15 | 16 | c[:20, -1] = np.nan 17 | 18 | t1 = time.time() 19 | z1 = ta.RSI(c, timeperiod=10) 20 | t2 = time.time() 21 | z2 = RSI(c, timeperiod=10) 22 | t3 = time.time() 23 | 24 | print(t2 - t1, t3 - t2) 25 | 26 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 27 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 28 | plt.show() 29 | -------------------------------------------------------------------------------- /tests/slope_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ta_cn.talib as ta 8 | from ta_cn.regress import SLOPE 9 | 10 | ta.init(mode=2, skipna=False, to_globals=True) 11 | 12 | if __name__ == '__main__': 13 | # Prepare data 14 | h = np.random.rand(1000).reshape(-1, 5) + 10 15 | l = np.random.rand(1000).reshape(-1, 5) 16 | c = np.random.rand(1000).reshape(-1, 5) 17 | 18 | c[:20, -1] = np.nan 19 | z2 = SLOPE(c)  # warm-up run before timing 20 | t1 = time.time() 21 | z1 = ta.LINEARREG_SLOPE(c) 22 | t2 = time.time() 23 | z2 = SLOPE(c) 24 | t3 = time.time() 25 | 26 | print(t2 - t1, t3 - t2) 27 | 28 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 29 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 30 | plt.show() 31 | -------------------------------------------------------------------------------- /tests/speed_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import bottleneck as bn 4 | import numpy as np 5 | import pandas as pd 6 | import polars as pl
7 | import talib 8 | 9 | import ta_cn.talib as ta 10 | 11 | ta2d = ta.init(mode=2, skipna=False, to_globals=False) 12 | 13 | pl.Config.set_tbl_rows(50) 14 | 15 | print(pl.__version__) 16 | 17 | 18 | def f1(df, ldf): 19 | # Fastest: each asset is processed as one contiguous block, so the underlying library is called fewer times 20 | return df.groupby(by='asset').apply(lambda x: bn.move_mean(x, 5, axis=0)) 21 | 22 | 23 | def f2(df, ldf): 24 | # The Rust internals are clearly well done 25 | a = ldf.select([ 26 | 'date', 'asset', 27 | pl.all().exclude(['date', 'asset']).rolling_mean(5).over('asset').suffix('_sma') 28 | ]) 29 | return a 30 | 31 | 32 | def f3(df, ldf): 33 | # Our own iteration wrapper also holds up, only a touch behind Rust 34 | return df.groupby(by='asset').apply(lambda x: ta2d.SMA(x, timeperiod=5)) 35 | 36 | 37 | def f4(df, ldf): 38 | # A touch worse than our own wrapper 39 | a = ldf.select([ 40 | 'date', 'asset', 41 | pl.all().exclude(['date', 'asset']).apply(lambda x: pl.Series(bn.move_mean(x, 5))).over('asset').suffix('_sma') 42 | ]) 43 | return a 44 | 45 | 46 | def f5(df, ldf): 47 | # A touch worse than our own wrapper 48 | a = ldf.select([ 49 | 'date', 'asset', 50 | pl.all().exclude(['date', 'asset']).apply(lambda x: pl.Series(talib.SMA(x.to_numpy(), 5))).over('asset').suffix( 51 | '_sma') 52 | ]) 53 | return a 54 | 55 | 56 | def f6(df, ldf): 57 | return df.groupby(by='asset').rolling(5).mean() 58 | 59 | 60 | if __name__ == '__main__': 61 | c = np.random.rand(1000000).reshape(-1, 5000) 62 | c = pd.DataFrame(c).stack() 63 | 64 | df = {i: c for i in range(100)} 65 | df = pd.DataFrame(df) 66 | df.index.names = ['date', 'asset'] 67 | df1 = pl.from_pandas(df.reset_index()) 68 | 69 | for f in [f1, f2, f3, f4, f5, f6]: 70 | t0 = time.time() 71 | f(df, df1) 72 | t1 = time.time() 73 | print(f.__name__, t1 - t0) 74 | 75 | # f1 1.5436947345733643 76 | # f2 3.0246715545654297 77 | # f3 6.506034851074219 78 | # f4 10.475429773330688 79 | # f5 10.043684959411621 80 | # f6 29.300340175628662 81 | -------------------------------------------------------------------------------- /tests/stddev_.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | import talib as ta 5 | from pandas._testing import assert_series_equal 6 | 7 | from ta_cn import EPSILON 8 | from ta_cn.tdx.logical import CROSS 9 | from ta_cn.wq.arithmetic import divide 10 | from ta_cn.wq.transformational import tail 11 | 12 | names = ['foo', 'bar', 'rf'] 13 | names = ['rf'] 14 | dates = pd.date_range(start='2017-01-01', end='2017-12-31', freq=pd.tseries.offsets.BDay()) 15 | n = len(dates) 16 | rdf = pd.DataFrame( 17 | np.zeros((n, len(names))), 18 | index=dates, 19 | columns=names 20 | ) 21 | 22 | np.random.seed(1) 23 | # rdf['foo'] = np.random.normal(loc=0.1 / n, scale=0.2 / np.sqrt(n), size=n) 24 | # rdf['bar'] = np.random.normal(loc=0.04 / n, scale=0.05 / np.sqrt(n), size=n) 25 | rdf['rf'] = 0.
26 | 27 | pdf = 100 * np.cumprod(1 + rdf) 28 | 29 | factor = rdf.copy() 30 | factor[:] = 1.0 31 | factor[:][-100:] = 1.1 32 | factor[:][-50:] = 1.5 33 | 34 | close_h = pdf * factor 35 | close_q = pdf * (factor / factor.iloc[-1]) 36 | 37 | 38 | def calc(close): 39 | # The STDDEV-to-close ratio is the same under both adjustments, but the VAR-to-close ratio is not; 40 | # VAR against close**2 is the same again, since VAR is the square of STDDEV 41 | ma5 = ta.VAR(close.iloc[:, 0], 10) 42 | ma10 = close.iloc[:, 0]**2 43 | 44 | ma5 = tail(ma5, lower=-EPSILON, upper=EPSILON, newval=0) 45 | 46 | r = CROSS(ma5, ma10) 47 | 48 | df = pd.DataFrame() 49 | df['CLOSE'] = close 50 | df['CROSS'] = r.astype(float) 51 | df['MA5'] = ma5 52 | df['MA10'] = ma10 53 | df['5/10'] = divide(ma5, ma10) 54 | 55 | return df 56 | 57 | 58 | a2 = calc(close_q)  # After switching algorithms one more point disappeared 59 | a2.plot(secondary_y=['CROSS', '5/10']) 60 | 61 | a1 = calc(close_h)  # Two spikes appeared; traced to precision differences 62 | a1.plot(secondary_y=['CROSS', '5/10']) 63 | 64 | plt.show() 65 | 66 | assert_series_equal(a1['5/10'], a2['5/10']) 67 | -------------------------------------------------------------------------------- /tests/tr_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ta_cn.talib as ta 8 | from ta_cn.tdx.reference import TR 9 | 10 | if __name__ == '__main__': 11 | # Prepare data 12 | h = np.random.rand(10000000).reshape(-1, 50000) + 10 13 | l = np.random.rand(10000000).reshape(-1, 50000) 14 | c = np.random.rand(10000000).reshape(-1, 50000) 15 | 16 | c[:20, -1] = np.nan 17 | 18 | t1 = time.time() 19 | z1 = ta.TRANGE(h, l, c) 20 | t2 = time.time() 21 | z2 = TR(h, l, c) 22 | t3 = time.time() 23 | 24 | print(t2 - t1, t3 - t2) 25 | 26 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 27 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 28 | plt.show() 29 | -------------------------------------------------------------------------------- /tests/trix_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from ta_cn.talib import set_compatibility_enable, TA_COMPATIBILITY_METASTOCK, set_compatibility 8 | from ta_cn.slow import TRIX_CN 9 | import ta_cn.talib as ta 10 | 11 | if __name__ == '__main__': 12 | set_compatibility_enable(True) 13 | set_compatibility(TA_COMPATIBILITY_METASTOCK) 14 | set_compatibility_enable(False) 15 | # Prepare data 16 | h = np.random.rand(100000).reshape(-1, 500) + 10 17 | l = np.random.rand(100000).reshape(-1, 500) 18 | c = np.random.rand(100000).reshape(-1, 500) 19 | 20 | c[:20, -1] = np.nan 21 | 22 | t1 = time.time() 23 | 24 | z1 = ta.TRIX(c, timeperiod=10) 25 | t2 = time.time() 26 | z2 = TRIX_CN(c, timeperiod=10) 27 | t3 = time.time() 28 | 29 | print(t2 - t1, t3 - t2) 30 | 31 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 32 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot() 33 | plt.show() 34 |
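An editor-added sketch of why macd_.py and trix_.py above toggle the compatibility switch (function names here are illustrative, not part of ta_cn): as I understand TA-Lib's compatibility flag, classic TA-Lib seeds the EMA with an SMA of the first n bars, while the MetaStock/China convention recurses from the very first value, so the two series differ early on and converge later.

import numpy as np

def ema_sma_seed(x, n):
    # classic TA-Lib style: seed with the SMA of the first n bars
    alpha = 2 / (n + 1)
    out = np.full(len(x), np.nan)
    out[n - 1] = x[:n].mean()
    for i in range(n, len(x)):
        out[i] = alpha * x[i] + (1 - alpha) * out[i - 1]
    return out

def ema_first_seed(x, n):
    # MetaStock/China style: recurse from the first value
    alpha = 2 / (n + 1)
    out = np.empty(len(x))
    out[0] = x[0]
    for i in range(1, len(x)):
        out[i] = alpha * x[i] + (1 - alpha) * out[i - 1]
    return out

x = np.arange(1.0, 11.0)
print(ema_sma_seed(x, 5))
print(ema_first_seed(x, 5))  # differs at the start, converges later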
-------------------------------------------------------------------------------- /tests/var_.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | import talib as ta 5 | from pandas._testing import assert_series_equal 6 | 7 | from ta_cn import EPSILON 8 | from ta_cn.tdx.logical import CROSS 9 | from ta_cn.wq.arithmetic import divide 10 | from ta_cn.wq.transformational import tail 11 | 12 | names = ['foo', 'bar', 'rf'] 13 | names = ['rf'] 14 | dates = pd.date_range(start='2017-01-01', end='2017-12-31', freq=pd.tseries.offsets.BDay()) 15 | n = len(dates) 16 | rdf = pd.DataFrame( 17 | np.zeros((n, len(names))), 18 | index=dates, 19 | columns=names 20 | ) 21 | 22 | np.random.seed(1) 23 | # rdf['foo'] = np.random.normal(loc=0.1 / n, scale=0.2 / np.sqrt(n), size=n) 24 | # rdf['bar'] = np.random.normal(loc=0.04 / n, scale=0.05 / np.sqrt(n), size=n) 25 | rdf['rf'] = 0. 26 | 27 | pdf = 100 * np.cumprod(1 + rdf) 28 | 29 | factor = rdf.copy() 30 | factor[:] = 1.0 31 | factor[:][-100:] = 1.1 32 | factor[:][-50:] = 1.5 33 | 34 | close_h = pdf * factor 35 | close_q = pdf * (factor / factor.iloc[-1]) 36 | 37 | 38 | def calc(close): 39 | # One point fewer 40 | ma5 = ta.VAR(close.iloc[:, 0], 10) 41 | ma10 = ta.VAR(close.iloc[:, 0], 20) 42 | 43 | # Due to rounding error many results are near-zero values like 1e-11; dividing by them amplifies the result 44 | ma5 = tail(ma5, lower=-EPSILON, upper=EPSILON, newval=0) 45 | ma10 = tail(ma10, lower=-EPSILON, upper=EPSILON, newval=0) 46 | 47 | r = CROSS(ma5, ma10) 48 | 49 | df = pd.DataFrame() 50 | df['CLOSE'] = close 51 | df['CROSS'] = r.astype(float) 52 | df['MA5'] = ma5 53 | df['MA10'] = ma10 54 | df['5/10'] = divide(ma5, ma10) 55 | 56 | return df 57 | 58 | 59 | a2 = calc(close_q)  # After switching algorithms one more point disappeared 60 | a2.plot(secondary_y=['CROSS', '5/10']) 61 | 62 | a1 = calc(close_h)  # Two spikes appeared; traced to precision differences 63 | a1.plot(secondary_y=['CROSS', '5/10']) 64 | 65 | plt.show() 66 | 67 | assert_series_equal(a1['5/10'], a2['5/10']) 68 | -------------------------------------------------------------------------------- /tests/wls_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import statsmodels.api as sm 5 | from statsmodels.regression.rolling import RollingOLS 6 | 7 | from ta_cn.regress import multiple_regress, ts_multiple_regress 8 | 9 | y = np.random.rand(1000) 10 | x = np.random.rand(1000 * 4).reshape(-1, 4) 11 | # coef, residual = multiple_regress(y, x) 12 | # t1 = time.time() 13 | # coef, residual = multiple_regress(y, x) 14 | # t2 = time.time() 15 | # print(residual) 16 | 17 | t = sm.add_constant(x) 18 | model = sm.WLS(y, t, weights=1) 19 | results = model.fit() 20 | t3 = time.time() 21 | # # print(y-results.fittedvalues) 22 | # print(t2 - t1, t3 - t2) 23 | # 24 | # coef, residual = ts_multiple_regress(y, x, timeperiod=80, add_constant=True) 25 | # t1 = time.time() 26 | # coef, residual = ts_multiple_regress(y, x, timeperiod=10, add_constant=True) 27 | # print(coef) 28 | # t2 = time.time() 29 | # t = sm.add_constant(x) 30 | # model = RollingOLS(y, t, window=10) 31 | # results = model.fit() 32 | # 33 | # t3 = time.time() 34 | # print(results.params) 35 | # print(t2 - t1, t3 - t2) 36 | -------------------------------------------------------------------------------- /tests/wma_.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ta_cn.talib as ta 8 | from ta_cn.slow import WMA 9 | 10 | if __name__ == '__main__': 11 | # Prepare data 12 | h = np.random.rand(100000).reshape(-1, 500) + 10 13 | l = np.random.rand(100000).reshape(-1, 500) 14 | c = np.random.rand(100000).reshape(-1, 500) 15 | 16 | c[:20, -1] = np.nan 17 | 18 | t1 = time.time() 19 | z1 = ta.WMA(c, timeperiod=10) 20 | t2 = time.time() 21 | z2 = WMA(c, timeperiod=10).values 22 | t3 = time.time() 23 | 24 | print(t2 - t1, t3 - t2) 25 | 26 | pd.DataFrame({'TA': z1[:, -1], 'MY': z2[:, -1]}).plot() 27 | pd.DataFrame({'MY': z2[:, -1], 'TA': z1[:, -1]}).plot()
28 | plt.show() 29 | -------------------------------------------------------------------------------- /加速.md: -------------------------------------------------------------------------------- 1 | # Acceleration 2 | 1. Replace critical functions with C or similar. The TA-Lib indicator library, for example, is implemented in C 3 | 2. Reduce the number of cross-language calls. Calling talib.SMA in a loop is worse than calling bottleneck.move_mean once 4 | 3. For ease of use, numba beats Cython, because Cython must be compiled separately for every Python version and operating system, which is too complicated 5 | 6 | ## Long tables vs. wide tables 7 | 1. Long table 8 |     1. Like a database: each column is a field, e.g. open/high/low/close. Two key columns must serve as the index, one for the stock code and one for the time 9 |     2. The two index columns contain many repeated values; storing a single field gives a utilization of only 1/3, so several fields must be stored together to be economical 10 |     3. The two key fields should ideally get a B+-tree index as in a database; otherwise they may be too large to load into memory 11 |     4. The sort order of the two indexes also affects speed. 12 |         1. Group by time first. The real-world scenario, since data is also generated in time order. Suited to cross sections, e.g. ranking the daily returns of all stocks on the same day 13 |         2. Group by stock first, sorted by time within each group. Suited to historical analysis: a single stock can be pulled out quickly, NaNs can be dropped up front, and computing technical indicators is convenient 14 |     5. Store the composite index in one file and every other field in its own file; adding a new field is then fast 15 | 2. Wide table 16 |     1. The whole table is one field, e.g. all closing prices; the column header is the stock code and the index is the time 17 |     2. Because listing dates differ, the 2-D matrix contains many NaNs, and suspensions add more; space utilization is low 18 |     3. NaN gaps in the middle of a suspension affect indicator computation in libraries like talib and need special handling 19 |     4. With block storage one stock is one column spread evenly through the file; loading that one stock means loading the whole file, which is simply infeasible for very large files 20 |     5. With columnar storage, reading the whole table means merging 4000+ stocks underneath, which costs a lot of space and time 21 |     6. The two indexes go in one file, the 2-D data in another 22 | ## Difficulties in computing technical indicators 23 | 1. With more than 4000 stocks, talib must be called at least 4000+ times; a single core can hardly finish within the 3 s before the next quote push arrives 24 | 2. Because of the GIL, IO-bound work can use multithreading, while compute-bound work can only use multiprocessing 25 |     1. The biggest multiprocessing problem is moving data across processes: it must be serialized and deserialized to reach the child process, and with several GB of data serialization is unrealistic 26 |     2. Shared memory and memory-mapped files are a decent solution 27 |         1. Wide table: the large file is loaded in full, straining memory 28 |         2. Long table: the index needs a good mechanism 29 |     3. Either way, very large data volumes remain troublesome 30 | 31 | ## Approach 32 | 1. NaN handling is genuinely troublesome 33 | 2. 2-D computation demands too much from the various indicator libraries. 34 | 35 | So stick with the long table combined with groupby, which is convenient: 36 | 1. For indicators, group by stock code and sort by time. For speed, best to finish the sorting right after each day's close 37 | 2. For cross sections, group by time 38 | 39 | Indicator computation becomes simple, but only a single core is used, so it is still slow, and the files stay large. Multiprocessing is again limited by data serialization 40 | 41 | ## Divide and conquer 42 | 1. Horizontal split by instrument (see the sketch at the end of this file) 43 |     1. 4000 stocks are too many for a single process, and 4000+ separate files are slow to open and process; they should be stored in groups 44 |     2. Using the last digit of the stock code spreads them evenly over 10 files, handled by 10 processes 45 |     3. Depending on the data volume and CPU core count, the last two digits can give 100 files, or mod 5 can give 20; a trade-off is needed 46 | 2. Vertical split by time 47 |     1. Long histories take long to compute; historical data can stay long, but the latency-sensitive part should only compute a short recent segment 48 |     2. Daily bars can be split by year, minute bars by month 49 |     3. Since indicators need warm-up data, each segment needs a slice of the preceding period's quotes 50 | 3. Cross-sectional handling 51 |     1. Cross-sectional indicators based on the raw quotes can be computed before the 4000-stock long table is split into 10 files 52 |     2. Once split into 10 files, a cross section requires merging the 10 files again; there are currently few such opportunities 53 | 54 | ## Data pipeline 55 | 1. Download all A-share data by day, which beats 4000+ per-stock downloads. The raw format is therefore a long table sorted by day 56 | 2. Load the table year by year, then normalize the headers as a hedge for a possible future data-source switch 57 | 3. A preliminary cross-sectional computation can be done here 58 | 4. Split the stocks into 10 files by the last digit of the stock code; this can be done at the same time 59 | 5. Compute time-series indicators on the 10 files with multiple processes 60 | 6. Then decide whether to merge for cross-sectional computation
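A minimal sketch of the horizontal split and the long/wide conversions described above (editor-added; column names are illustrative, and to_parquet needs a parquet engine such as pyarrow):

```python
import pandas as pd

df = pd.DataFrame({
    'date': pd.to_datetime(['2022-01-03'] * 4),
    'asset': ['600000', '600519', '000002', '300750'],
    'close': [10.0, 20.0, 30.0, 40.0],
})

# Horizontal split: 10 buckets keyed by the last digit of the stock code,
# one file per bucket so 10 processes can work independently.
for digit, part in df.groupby(df['asset'].str[-1]):
    part.to_parquet(f'close_{digit}.parquet')

# Long -> wide for cross-sectional work; wide -> long for per-stock indicators.
wide = df.pivot(index='date', columns='asset', values='close')
long_again = wide.stack().rename('close')
```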
-------------------------------------------------------------------------------- /参数.md: -------------------------------------------------------------------------------- 1 | # Parameters 2 | 1. variable argument 3 | 2. dynamic parameter 4 | 3. sequence parameter 5 | 6 | When a parameter changes every day and may also differ across stocks, what is the standard name for this situation? 7 | 8 | ## Classification 9 | 1. Independent of the previous result (MA-like) 10 |     1. MA and the like need a recent stretch of data for the computation 11 |     2. A window of data is required 12 | 2. Dependent on the previous result (EMA-like) 13 |     1. EMA and the like combine yesterday's result with today's data; real-time computation can be accelerated 14 |     2. Since the result depends on yesterday, computation must start several periods early for the result to be stable 15 | 16 | Or a combination of the two, like the KAMA indicator: 17 | 1. First use a window of data to compute displacement and distance, which gives the efficiency ratio 18 | 2. From the efficiency ratio and the agreed minimum and maximum periods derive the per-bar parameter alpha; pandas ewm cannot be used directly 19 | 20 | ## Handling missing parameter values 21 | 1. NaN. The indicator result is output as NaN too 22 | 2. bfill. Use the next period's parameter for this period 23 |     1. Uses a future parameter; must not be used!!! 24 | 3. ffill. Use the previous period's parameter for this period 25 |     1. EMA-like indicators must then be computed in a day-by-day loop. 26 |     2. MA-like indicators may have a fast vectorized method, so computing all segments' last days in one batch is faster 27 | 28 | ## How to turn a scalar parameter into a sequence or matrix parameter 29 | 1. Matrix parameters are the real world: today's you is no longer yesterday's you 30 | 2. Follow the order of events: traverse one row, then the next, i.e. process one day at a time 31 | 3. As with Point-In-Time financials, the observation point matters: every day you stand at the current observation point and look back at the values that were obtainable then 32 | 4. After computing a new matrix each day, keep only the latest row. Repeating this daily and stitching the rows together yields the desired result 33 | 5. Different window lengths for the same indicator make it hard to pick a suitable length when slicing uniformly -------------------------------------------------------------------------------- /复权.md: -------------------------------------------------------------------------------- 1 | # Price adjustment 2 | 1. Forward adjustment. New prices stay put, old prices are adjusted. Good for viewing holding cost 3 |     1. The old prices have changed, so using them directly introduces future data 4 |     2. Early prices may turn negative 5 | 2. Backward adjustment. Old prices stay put, new prices are adjusted. Good for viewing holding returns 6 |     1. Can be used to compute returns 7 |     2. Computing technical indicators on it does not match everyday usage habits 8 | 3. Dynamic forward adjustment. Take every day as the observation point, forward-adjust, record the latest day's result, then merge all the latest-day results. 9 |     1. The most realistic application scenario 10 |     2. Iteration is slow 11 |     3. Can be used to compute indicators 12 | 13 | ## Cases 14 | 1. rank(close): no time-series indicator is involved, so compute directly with no adjustment. Saving the last row of the computation gives the same result as the full computation 15 | 2. sma(close): a time-series indicator; needs dynamic adjustment before computing 16 | 3. sma(sma(close)): taking the last values and stitching should only be done on the outermost result 17 | 4. rank(sma(close)): the sma results' last rows can be stitched and then ranked; same as stitching at the very end 18 | 5. sma(rank(close)): whether and how adjusting rank affects the sma result still needs thought 19 | 20 | ## Summary 21 | 1. Every indicator has a time-series window. 22 |     1. Window = 1: unaffected by adjustment. Abbreviated below as 1() 23 |     2. Window > 1: needs dynamic adjustment. Abbreviated below as 2() 24 | 2. Indicators are used nested. 25 |     1. 1(1()): unaffected by adjustment 26 |     2. 2(2()): needs dynamic adjustment 27 |     3. 1(2()): adjusting and stitching the inner layer gives the same result as stitching once at the outermost layer, because the outer layer does not change the final conclusion 28 |     4. 2(1()): adjustment affects the results before and after the 1() 29 | 3. Dynamic adjustment with whole-series stitching gives the most uniform structure 30 | 31 | ## Plan 32 | 1. Fetch the adjustment factors in advance, cut out the part that needs no adjustment, and compute that in one go. 33 | 2. Where adjustment is required, loop over large blocks 34 | 35 | # A new idea 36 | For the same indicator, or indicators built on the same principle, relative comparisons generate signals at the same time, so dynamic forward adjustment is unnecessary; backward-adjusting in advance is enough. 37 | 38 | For example: 39 | 1. Fast and slow SMA: whatever the adjustment method, golden and death crosses occur at the same moments 40 | 2. ATR: absolute values differ across adjustments, but the ratio to Close is unchanged, so a 3x-ATR stop triggers at the same moment 41 | 3. STD: same as above 42 | 4. Whether other indicators share this property must be studied per algorithm 43 | 44 | So, with relative values, everything can be unified on backward adjustment? Needs code to verify: 45 | 1. CROSS of MAs, MA5/MA10: as expected 46 | 2. STDDEV10/STDDEV20: as expected 47 | 3. VAR10/CLOSE**2: as expected, i.e. VAR is STDDEV squared, so the close must be squared too -------------------------------------------------------------------------------- /指标对比.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wukan1986/ta_cn/a569a618109daa804541c5d67fa4c0407f03fd49/指标对比.xlsx --------------------------------------------------------------------------------