├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── examples
│   ├── dual_ema_on_apple.py
│   └── smart_beta.py
├── requirements.txt
├── setup.py
├── spectre
│   ├── __init__.py
│   ├── config.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── arrow.py
│   │   ├── csv.py
│   │   ├── dataloader.py
│   │   ├── iex.py
│   │   ├── memory.py
│   │   ├── quandl.py
│   │   └── yahoo.py
│   ├── factors
│   │   ├── __init__.py
│   │   ├── basic.py
│   │   ├── datafactor.py
│   │   ├── engine.py
│   │   ├── factor.py
│   │   ├── feature.py
│   │   ├── filter.py
│   │   ├── label.py
│   │   ├── multiprocessing.py
│   │   ├── statistical.py
│   │   └── technical.py
│   ├── parallel
│   │   ├── __init__.py
│   │   ├── algorithmic.py
│   │   └── constants.py
│   ├── plotting
│   │   ├── __init__.py
│   │   ├── chart.py
│   │   ├── factor_diagram.py
│   │   └── returns_chart.py
│   └── trading
│       ├── __init__.py
│       ├── algorithm.py
│       ├── blotter.py
│       ├── calendar.py
│       ├── event.py
│       ├── metric.py
│       ├── portfolio.py
│       ├── position.py
│       └── stopmodel.py
└── tests
    ├── __init__.py
    ├── benchmarks_spectre.ipynb
    ├── benchmarks_zipline.ipynb
    ├── data
    │   ├── 5mins
    │   │   ├── AAPL_2018.csv
    │   │   ├── AAPL_2019.csv
    │   │   ├── MSFT_2018.csv
    │   │   └── MSFT_2019.csv
    │   ├── daily
    │   │   ├── AAPL.csv
    │   │   └── MSFT.csv
    │   ├── dividends
    │   │   ├── AAPL.csv
    │   │   └── MSFT.csv
    │   └── splits
    │       ├── AAPL.csv
    │       └── MSFT.csv
    ├── test_blotter.py
    ├── test_custom_factor.py
    ├── test_data_factor.py
    ├── test_data_loader.py
    ├── test_event.py
    ├── test_factor.py
    ├── test_metric.py
    ├── test_parallel_algo.py
    └── test_trading_algorithm.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | /.vscode 106 | /.idea 107 | /.note.md 108 | /tests/data/yahoo 109 | /private_factors 110 | /tests/data/orders_2020-06-01.csv 111 | /tests/data/orders_2020-06-02.csv 112 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "spectre/data/iex_fetcher"] 2 | path = spectre/data/iex_fetcher 3 | url = https://github.com/Heerozh/iex_fetcher.git 4 | branch = lib 5 | -------------------------------------------------------------------------------- /examples/dual_ema_on_apple.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2020, Heerozh. All rights reserved. 
4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | 8 | from spectre import factors, trading 9 | from spectre.data import YahooDownloader, ArrowLoader 10 | import pandas as pd 11 | 12 | 13 | class AppleDualEma(trading.CustomAlgorithm): 14 | invested = False 15 | asset = 'AAPL' 16 | 17 | def initialize(self): 18 | # setup engine 19 | engine = self.get_factor_engine() 20 | engine.to_cuda() 21 | 22 | universe = factors.StaticAssets({self.asset}) 23 | engine.set_filter(universe) 24 | 25 | # add your factors 26 | fast_ema = factors.EMA(20) 27 | slow_ema = factors.EMA(40) 28 | engine.add(fast_ema, 'fast_ema') 29 | engine.add(slow_ema, 'slow_ema') 30 | engine.add(fast_ema > slow_ema, 'buy_signal') 31 | engine.add(fast_ema < slow_ema, 'sell_signal') 32 | engine.add(factors.OHLCV.close, 'price') 33 | 34 | # schedule rebalance before market close 35 | self.schedule_rebalance(trading.event.MarketClose(self.rebalance, offset_ns=-10000)) 36 | 37 | # simulation parameters 38 | self.blotter.capital_base = 10000 39 | self.blotter.set_commission(percentage=0, per_share=0.005, minimum=1) 40 | 41 | def rebalance(self, data: pd.DataFrame, history: pd.DataFrame): 42 | asset_data = data.loc[self.asset] 43 | buy, sell = False, False 44 | if asset_data.buy_signal and not self.invested: 45 | self.blotter.order(self.asset, 100) 46 | self.invested = True 47 | buy = True 48 | elif asset_data.sell_signal and self.invested: 49 | self.blotter.order(self.asset, -100) 50 | self.invested = False 51 | sell = True 52 | 53 | self.record(AAPL=asset_data.price, 54 | short_ema=asset_data.fast_ema, 55 | long_ema=asset_data.slow_ema, 56 | buy=buy, 57 | sell=sell) 58 | 59 | def terminate(self, records: 'pd.DataFrame'): 60 | # plotting results 61 | self.plot(benchmark='SPY') 62 | 63 | import matplotlib.pyplot as plt 64 | fig = plt.figure() 65 | ax1 = fig.add_subplot(211) 66 | ax1.set_ylabel('Price (USD)') 67 | 68 | records[['AAPL', 'short_ema', 'long_ema']].plot(ax=ax1) 69 | 70 | ax1.plot( 71 | records.index[records.buy], 72 | records.loc[records.buy, 'long_ema'], 73 | '^', 74 | markersize=10, 75 | color='m', 76 | ) 77 | ax1.plot( 78 | records.index[records.sell], 79 | records.loc[records.sell, 'short_ema'], 80 | 'v', 81 | markersize=10, 82 | color='k', 83 | ) 84 | plt.legend(loc=0) 85 | plt.gcf().set_size_inches(18, 8) 86 | 87 | plt.show() 88 | 89 | 90 | if __name__ == '__main__': 91 | import plotly.io as pio 92 | pio.renderers.default = "browser" 93 | 94 | import argparse 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument("--download", help="download yahoo data") 97 | args = parser.parse_args() 98 | 99 | if args.download: 100 | YahooDownloader.ingest( 101 | start_date="2001", save_to="./yahoo", 102 | symbols=None, skip_exists=True) 103 | 104 | loader = ArrowLoader('./yahoo/yahoo.feather') 105 | results = trading.run_backtest(loader, AppleDualEma, '2013-01-01', '2018-01-01') 106 | print(results.transactions) 107 | -------------------------------------------------------------------------------- /examples/smart_beta.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2020, Heerozh. All rights reserved. 
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 |
8 | from spectre import factors, trading
9 | from spectre.data import YahooDownloader, ArrowLoader
10 | import pandas as pd
11 |
12 |
13 | class SmartBeta(trading.CustomAlgorithm):
14 |     def initialize(self):
15 |         # setup engine
16 |         engine = self.get_factor_engine()
17 |         engine.to_cuda()
18 |
19 |         universe = factors.AverageDollarVolume(win=120).top(500)
20 |         engine.set_filter(universe)
21 |
22 |         # SP500 factor
23 |         sp500 = factors.AverageDollarVolume(win=63)
24 |         # our alpha puts more weight on NVDA! StaticAssets returns True (1) for NVDA
25 |         # and False (0) for the others
26 |         alpha = sp500 * (factors.StaticAssets({'NVDA'})*5 + 1)
27 |         engine.add(alpha.to_weight(demean=False), 'weight')
28 |
29 |         # schedule rebalance before market close
30 |         self.schedule_rebalance(trading.event.MarketClose(self.rebalance, offset_ns=-10000))
31 |
32 |         # simulation parameters
33 |         self.blotter.capital_base = 1000000
34 |         self.blotter.set_commission(percentage=0, per_share=0.005, minimum=1)
35 |
36 |     def rebalance(self, data: pd.DataFrame, history: pd.DataFrame):
37 |         self.blotter.batch_order_target_percent(data.index, data.weight)
38 |         # close asset positions that are no longer in our universe.
39 |         removes = self.blotter.portfolio.positions.keys() - set(data.index)
40 |         self.blotter.batch_order_target_percent(removes, [0] * len(removes))
41 |
42 |     def terminate(self, records: 'pd.DataFrame'):
43 |         # plotting results
44 |         self.plot(benchmark='SPY')
45 |
46 |
47 | if __name__ == '__main__':
48 |     import plotly.io as pio
49 |     pio.renderers.default = "browser"
50 |
51 |     import argparse
52 |     parser = argparse.ArgumentParser()
53 |     parser.add_argument("--download", help="download yahoo data")
54 |     args = parser.parse_args()
55 |
56 |     if args.download:
57 |         YahooDownloader.ingest(
58 |             start_date="2001", save_to="./yahoo",
59 |             symbols=None, skip_exists=True)
60 |
61 |     loader = ArrowLoader('./yahoo/yahoo.feather')
62 |     results = trading.run_backtest(loader, SmartBeta, '2013-01-01', '2018-01-01')
63 |     print(results.transactions)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | requests
2 | pyarrow
3 | numpy
4 | pandas>=0.22
5 | bs4
6 | lxml
7 | torch>=1.3
8 | plotly
9 | tqdm
10 |
11 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 |     name="spectre",
5 |     author='Zhang Jianhao',
6 |     author_email='heeroz@gmail.com',
7 |     description='GPU-accelerated Parallel quantitative trading library',
8 |     long_description=open('README.md', encoding='utf-8').read(),
9 |     license='Apache 2.0',
10 |     keywords='quantitative analysis backtesting parallel algorithmic trading',
11 |     url='https://github.com/Heerozh/spectre',
12 |     classifiers=[
13 |         "Programming Language :: Python :: 3",
14 |         "Operating System :: OS Independent",
15 |     ],
16 |     python_requires='>=3.5',
17 |
18 |     version="0.5",
19 |     packages=['spectre', 'spectre.data', 'spectre.factors', 'spectre.parallel', 'spectre.trading',
20 |               'spectre.plotting'],
21 | )
--------------------------------------------------------------------------------
/spectre/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from . import data
3 | from . 
import factors 4 | from . import parallel 5 | from . import trading 6 | from . import plotting 7 | 8 | 9 | __all__ = [ 10 | "data", 11 | "factors", 12 | "parallel", 13 | "trading", 14 | "plotting", 15 | ] 16 | -------------------------------------------------------------------------------- /spectre/config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Global: 5 | # default float type for engine (not affect dataloader) 6 | float_type = torch.float32 7 | 8 | -------------------------------------------------------------------------------- /spectre/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataloader import ( 2 | DataLoader, 3 | DataLoaderFastGetter, 4 | ) 5 | from .arrow import ( 6 | ArrowLoader, 7 | ) 8 | from .csv import ( 9 | CsvDirLoader, 10 | ) 11 | from .memory import ( 12 | MemoryLoader, 13 | ) 14 | from .quandl import ( 15 | QuandlLoader, 16 | ) 17 | from .yahoo import ( 18 | YahooDownloader, 19 | ) 20 | -------------------------------------------------------------------------------- /spectre/data/arrow.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | import pandas as pd 8 | import os 9 | import warnings 10 | from .dataloader import DataLoader 11 | 12 | 13 | class ArrowLoader(DataLoader): 14 | """ Read from persistent data. """ 15 | 16 | def __init__(self, path: str = None, keep_in_memory: bool = True) -> None: 17 | if not os.path.exists(path + '.meta'): 18 | raise FileNotFoundError(os.path.abspath(path + '.meta')) 19 | 20 | # pandas 0.22 has the fastest MultiIndex 21 | if pd.__version__.startswith('0.22'): 22 | import feather 23 | cols = feather.read_dataframe(path + '.meta') 24 | else: 25 | cols = pd.read_feather(path + '.meta') 26 | 27 | ohlcv = cols.ohlcv.values 28 | adjustments = cols.adjustments.values[:2] 29 | if adjustments[0] is None: 30 | adjustments = None 31 | super().__init__(path, ohlcv, adjustments) 32 | self.keep_in_memory = keep_in_memory 33 | self._cache = None 34 | self._filter = None 35 | 36 | @classmethod 37 | def _last_modified(cls, file_path) -> float: 38 | if not os.path.isfile(file_path): 39 | return 0 40 | else: 41 | return os.path.getmtime(file_path) 42 | 43 | @property 44 | def last_modified(self) -> float: 45 | return self._last_modified(self._path) 46 | 47 | @classmethod 48 | def ingest(cls, source: DataLoader, save_to, force: bool = False) -> None: 49 | if not force and (source.last_modified <= cls._last_modified(save_to)): 50 | warnings.warn("You called `ingest()`, but `source` seems unchanged, " 51 | "no ingestion required. 
Set `force=True` to re-ingest.",
52 |                           RuntimeWarning)
53 |             return
54 |
55 |         df = source.test_load()
56 |         df.reset_index(inplace=True)
57 |         df.to_feather(save_to)
58 |
59 |         meta = pd.DataFrame(columns=['ohlcv', 'adjustments'])
60 |         meta.ohlcv = source.ohlcv
61 |         meta.adjustments[:2] = source.adjustments
62 |         # meta.loc[:2, "adjustments"] = source.adjustments
63 |         meta.to_feather(save_to + '.meta')
64 |
65 |     def filter(self, func):
66 |         self._filter = func
67 |         self._cache = None
68 |
69 |     def _load(self) -> pd.DataFrame:
70 |         if self._cache is not None:
71 |             return self._cache
72 |
73 |         if pd.__version__.startswith('0.22'):
74 |             import feather
75 |             df = feather.read_dataframe(self._path)
76 |         else:
77 |             df = pd.read_feather(self._path)
78 |         df.set_index(['date', 'asset'], inplace=True)
79 |
80 |         if self._filter is not None:
81 |             df = self._filter(df)
82 |
83 |         if self.keep_in_memory:
84 |             self._cache = df
85 |         return df
86 |
--------------------------------------------------------------------------------
/spectre/data/csv.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import numpy as np
8 | import pandas as pd
9 | import os
10 | import glob
11 | import warnings
12 | from .dataloader import DataLoader
13 |
14 |
15 | class CsvDirLoader(DataLoader):
16 |     def __init__(self, prices_path: str, prices_by_year=False, earliest_date: pd.Timestamp = None,
17 |                  dividends_path=None, splits_path=None, file_pattern='*.csv',
18 |                  calender_asset: str = None, align_by_time=False,
19 |                  ohlcv=('open', 'high', 'low', 'close', 'volume'), adjustments=None,
20 |                  split_ratio_is_inverse=False, split_ratio_is_fraction=False,
21 |                  prices_index='date', dividends_index='exDate', splits_index='exDate', **read_csv):
22 |         """
23 |         Load data from a csv dir.
24 |         :param prices_path: prices csv folder, structured as one csv per stock.
25 |             When encountering duplicate index values in `prices_index`, the loader keeps the last
26 |             row and drops the others.
27 |         :param prices_by_year: If price file names are like 'spy_2017.csv', set this to True.
28 |         :param earliest_date: Data before this date will not be saved to memory. Note: Use the
29 |             same time zone as in the csv files.
30 |         :param dividends_path: dividends csv folder, structured as one csv per stock.
31 |             For duplicate data, the loader first drops exactly identical rows, and then rows with
32 |             the same `dividends_index` but different 'dividend amount' are summed up.
33 |             If `dividends_path` is not set, the `adjustments[0]` column is considered to be included
34 |             in the prices csv.
35 |         :param splits_path: splits csv folder, structured as one csv per stock.
36 |             When encountering duplicate index values in `splits_index`, the loader uses the last
37 |             non-NaN 'split ratio' and drops the others.
38 |             If `splits_path` is not set, the `adjustments[1]` column is considered to be included
39 |             in the prices csv.
40 |         :param file_pattern: csv file name pattern, default is '*.csv'.
41 |         :param calender_asset: asset name used as the trading calendar, like 'SPY', for cleaning up
42 |             non-trading-time data.
43 |         :param align_by_time: if True and `calender_asset` is not None, df index will be the product of
44 |             'date' and 'asset'.
45 |         :param ohlcv: Required, OHLCV column names. When you don't need to use `adjustments` and
46 |             `factors.OHLCV`, you can set this to None.
47 |         :param adjustments: Optional, `dividend amount` and `splits ratio` column names.
48 |         :param split_ratio_is_inverse: If the split ratio is calculated as to/from, set to True.
49 |             For example, a 2-for-1 split gives to/from = 2, a 2-for-3 reverse split gives to/from = 0.6666...
50 |         :param split_ratio_is_fraction: If the split ratio in csv is a fraction string, set to True.
51 |         :param prices_index: `index_col` for csv in prices_path.
52 |         :param dividends_index: `index_col` for csv in dividends_path.
53 |         :param splits_index: `index_col` for csv in splits_path.
54 |         :param read_csv: Parameters passed to every pd.read_csv call.
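
        Example (a minimal sketch; the directory layout, column names and dtypes
        here are illustrative assumptions, not requirements)::

            loader = CsvDirLoader(
                './data/daily', ohlcv=('open', 'high', 'low', 'close', 'volume'),
                prices_index='date', parse_dates=True,
                dtype={'open': np.float32, 'high': np.float32, 'low': np.float32,
                       'close': np.float32, 'volume': np.float64})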
55 |         """
56 |         if adjustments is None:
57 |             super().__init__(prices_path, ohlcv, None)
58 |         else:
59 |             super().__init__(prices_path, ohlcv, ('ex-dividend', 'split_ratio'))
60 |
61 |         assert 'index_col' not in read_csv, \
62 |             "`index_col` cannot be used here. Use `prices_index` and `dividends_index` and " \
63 |             "`splits_index` instead."
64 |         if 'dtype' not in read_csv:
65 |             warnings.warn("It is recommended to set the `dtype` parameter and use float32 whenever "
66 |                           "possible. Example: dtype = {'Open': np.float32, 'High': np.float32, "
67 |                           "'Low': np.float32, 'Close': np.float32, 'Volume': np.float64}",
68 |                           RuntimeWarning)
69 |         self._adjustment_cols = adjustments
70 |         self._split_ratio_is_inverse = split_ratio_is_inverse
71 |         self._split_ratio_is_fraction = split_ratio_is_fraction
72 |         self._prices_by_year = prices_by_year
73 |         self._earliest_date = earliest_date
74 |         self._dividends_path = dividends_path
75 |         self._splits_path = splits_path
76 |         self._file_pattern = file_pattern
77 |         self._calender = calender_asset
78 |         self._prices_index = prices_index
79 |         self._dividends_index = dividends_index
80 |         self._splits_index = splits_index
81 |         self._read_csv = read_csv
82 |         self._align_by_time = align_by_time
83 |
84 |     @property
85 |     def last_modified(self) -> float:
86 |         pattern = os.path.join(self._path, self._file_pattern)
87 |         files = glob.glob(pattern)
88 |         if len(files) == 0:
89 |             raise ValueError("Dir '{}' does not contain any csv files.".format(self._path))
90 |         return max([os.path.getmtime(fn) for fn in files])
91 |
92 |     def _walk_split_by_year_dir(self, csv_path, index_col):
93 |         years = set(pd.date_range(self._earliest_date or 0, pd.Timestamp.now()).year)
94 |         pattern = os.path.join(csv_path, self._file_pattern)
95 |         files = glob.glob(pattern)
96 |         assets = {}
97 |         for fn in files:
98 |             base = os.path.basename(fn)
99 |             symbol, year = base[:-9].upper(), int(base[-8:-4])  # like 'spy_2011.csv'
100 |             if year in years:
101 |                 if symbol in assets:
102 |                     assets[symbol].append(fn)
103 |                 else:
104 |                     assets[symbol] = [fn, ]
105 |
106 |         def multi_read_csv(file_list):
107 |             df = pd.concat([pd.read_csv(_fn, index_col=index_col, **self._read_csv)
108 |                             for _fn in file_list])
109 |             if not isinstance(df.index, pd.DatetimeIndex):
110 |                 raise ValueError(
111 |                     "df must be indexed by datetime, set correct `read_csv`, "
112 |                     "for example index_col='date', parse_dates=True. "
113 |                     "For mixed-timezone like daylight saving time, "
114 |                     "set date_parser=lambda col: pd.to_datetime(col, utc=True)")
115 |
116 |             return df[~df.index.duplicated(keep='last')]
117 |
118 |         dfs = {symbol: multi_read_csv(file_list) for symbol, file_list in assets.items()}
119 |         return dfs
" 113 | "For mixed-timezone like daylight saving time, " 114 | "set date_parser=lambda col: pd.to_datetime(col, utc=True)") 115 | 116 | return df[~df.index.duplicated(keep='last')] 117 | 118 | dfs = {symbol: multi_read_csv(file_list) for symbol, file_list in assets.items()} 119 | return dfs 120 | 121 | def _walk_dir(self, csv_path, index_col): 122 | pattern = os.path.join(csv_path, self._file_pattern) 123 | files = glob.glob(pattern) 124 | if len(files) == 0: 125 | raise ValueError("There are no files is {}".format(csv_path)) 126 | 127 | def symbol(file): 128 | return os.path.basename(file)[:-4].upper() 129 | 130 | def read_csv(file): 131 | df = pd.read_csv(file, index_col=index_col, **self._read_csv) 132 | if len(df.index.dropna()) == 0: 133 | return None 134 | if not isinstance(df.index, pd.DatetimeIndex): 135 | raise ValueError( 136 | "df must index by datetime, set correct `read_csv`, " 137 | "for example parse_dates=True. " 138 | "For mixed-timezone like daylight saving time, " 139 | "set date_parser=lambda col: pd.to_datetime(col, utc=True)") 140 | return df[self._earliest_date:] 141 | 142 | dfs = {symbol(fn): read_csv(fn) for fn in files} 143 | return dfs 144 | 145 | def _load(self): 146 | if self._prices_by_year: 147 | dfs = self._walk_split_by_year_dir(self._path, self._prices_index) 148 | else: 149 | dfs = self._walk_dir(self._path, self._prices_index) 150 | dfs = {k: v[~v.index.duplicated(keep='last')] for k, v in dfs.items() if v is not None} 151 | df = pd.concat(dfs, sort=False) 152 | df = df.rename_axis(['asset', 'date']) 153 | 154 | if self.ohlcv is not None: 155 | # 这里把0当成nan进行ffill,如果未来取消了,要先把0变成nan,然后df.ffill 156 | df[list(self.ohlcv)] = df[list(self.ohlcv)].replace(to_replace=0, method='ffill') 157 | 158 | if self._dividends_path is not None: 159 | dfs = self._walk_dir(self._dividends_path, self._dividends_index) 160 | ex_div_col = self._adjustment_cols[0] 161 | div_index = self._dividends_index 162 | 163 | def _agg_duplicated(_df): 164 | if _df is None or ex_div_col not in _df: 165 | return None 166 | _df = _df.reset_index().drop_duplicates() 167 | _df = _df.dropna(subset=[ex_div_col]) 168 | _df = _df.set_index(div_index)[ex_div_col] 169 | return _df.groupby(level=0).agg('sum') 170 | 171 | dfs = {k: _agg_duplicated(v) for k, v in dfs.items()} 172 | div = pd.concat(dfs, sort=False) 173 | div = div.reindex(df.index) 174 | div = div.fillna(0) 175 | div.name = self._adjustments[0] 176 | # div = df.rename_axis(['asset', 'date']) 177 | df = pd.concat([df, div], axis=1) 178 | 179 | if self._splits_path is not None: 180 | dfs = self._walk_dir(self._splits_path, self._splits_index) 181 | sp_rto_col = self._adjustment_cols[1] 182 | 183 | def _drop_na_and_duplicated(_df): 184 | if _df is None or sp_rto_col not in _df: 185 | return None 186 | _df = _df.dropna(subset=[sp_rto_col])[sp_rto_col] 187 | return _df[~_df.index.duplicated(keep='last')] 188 | 189 | dfs = {k: _drop_na_and_duplicated(v) for k, v in dfs.items()} 190 | splits = pd.concat(dfs, sort=False) 191 | if self._split_ratio_is_fraction: 192 | from fractions import Fraction 193 | 194 | def fraction_2_float(x): 195 | try: 196 | return float(Fraction(x)) 197 | except (ValueError, ZeroDivisionError): 198 | return np.nan 199 | 200 | splits = splits.apply(fraction_2_float) 201 | splits = splits.reindex(df.index) 202 | splits = splits.fillna(1) 203 | splits.name = self._adjustments[1] 204 | df = pd.concat([df, splits], axis=1) 205 | 206 | df = df.swaplevel(0, 1).sort_index(level=0) 207 | 208 | if self._calender: 209 | # drop the 
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from typing import Optional
8 | import pandas as pd
9 | import numpy as np
10 |
11 |
12 | class DataLoader:
13 |     def __init__(self, path: str, ohlcv=('open', 'high', 'low', 'close', 'volume'),
14 |                  adjustments=('ex-dividend', 'split_ratio')) -> None:
15 |         self._path = path
16 |         self._ohlcv = ohlcv
17 |         self._adjustments = adjustments
18 |
19 |     @property
20 |     def ohlcv(self):
21 |         return self._ohlcv
22 |
23 |     @property
24 |     def adjustments(self):
25 |         return self._adjustments
26 |
27 |     @property
28 |     def adjustment_multipliers(self):
29 |         return ['price_multi', 'vol_multi']
30 |
31 |     @property
32 |     def time_category(self):
33 |         return '_time_cat_id'
34 |
35 |     @property
36 |     def last_modified(self) -> float:
37 |         """ data source last modification time """
38 |         raise NotImplementedError("abstractmethod")
39 |
40 |     @property
41 |     def min_timedelta(self) -> pd.Timedelta:
42 |         """ Minimum time delta of date index """
43 |         date_idx = self.load().index.levels[0]
44 |         return min(date_idx[1:] - date_idx[:-1])
45 |
46 |     @classmethod
47 |     def _align_to(cls, df, calender_asset, align_by_time=False):
48 |         """ helper method for aligning the index """
49 |         index = df.loc[(slice(None), calender_asset), :].index.get_level_values(0)
50 |         df = df[df.index.get_level_values(0).isin(index)]
51 |         df.index = df.index.remove_unused_levels()
52 |         if align_by_time:
53 |             df = df.reindex(pd.MultiIndex.from_product(df.index.levels))
54 |             # df = df.unstack(level=1).stack(dropna=False)  # this is faster, but removed in newer pandas
55 |
56 |         def trim_nans(x):
57 |             dts = x.index.get_level_values(0)
58 |             first = x.first_valid_index()
59 |             last = x.last_valid_index()
60 |             if first is None:
61 |                 return None
62 |             mask = (dts >= first[0]) & (dts <= last[0])
63 |             return x[mask]
64 |
65 |         df = df.groupby(level=1, group_keys=False).apply(trim_nans)
66 |         return df
67 |
68 |     def _format(self, df, split_ratio_is_inverse=False) -> pd.DataFrame:
69 |         """
70 |         Format the data as we want it. df index must be in order [datetime, asset_name]
71 |         * change index name to ['date', 'asset']
72 |         * change asset column type to category
73 |         * convert date index to utc timezone
74 |         * create time_cat column
75 |         * create adjustment multipliers columns
76 |         """
77 |         # print(pd.Timestamp.now(), 'Formatting index...')
78 |         df = df.rename_axis(['date', 'asset'])
79 |         # speed up asset index search time
80 |         df = df.reset_index()
81 |         asset_type = pd.api.types.CategoricalDtype(categories=np.sort(pd.unique(df.asset)),
82 |                                                    ordered=True)
83 |         df.asset = df.asset.astype(asset_type)
84 |         # format index and convert to utc timezone-aware
85 |         df.set_index(['date', 'asset'], inplace=True)
86 |         if df.index.levels[0].tzinfo is None:
87 |             df = df.tz_localize('UTC', level=0, copy=False)
88 |         else:
89 |             df = df.tz_convert('UTC', level=0, copy=False)
90 |         df.sort_index(level=[0, 1], inplace=True)
91 |         # generate time key for parallel
92 |         # print(pd.Timestamp.now(), 'Formatting time key for gpu sorting...')
93 |         date_index = df.index.get_level_values(0)
94 |         unique_date = date_index.unique()
95 |         time_cat = dict(zip(unique_date, range(len(unique_date))))
96 |         # cat = np.fromiter(map(lambda x: time_cat[x], date_index), dtype=int)
97 |         df[self.time_category] = date_index.map(time_cat)
98 |         # print(pd.Timestamp.now(), 'Done.')
99 |
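        # Worked sketch of the adjustment block below (illustrative numbers): with
        # closes [10, 10, 5], a 2:1 split effective on the 3rd bar and no dividends,
        # sp_rto after the shift is [1, 0.5, 1]; the right-to-left cumprod then gives
        # price_multi = [0.5, 0.5, 1], so back-adjusted history 10 * 0.5 = 5 lines up
        # with the latest close.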
100 |         # Process dividends and split
101 |         if self.adjustments is not None:
102 |             div_col = self.adjustments[0]
103 |             spr_col = self.adjustments[1]
104 |             close_col = self.ohlcv[3]
105 |             price_multi_col = self.adjustment_multipliers[0]
106 |             vol_multi_col = self.adjustment_multipliers[1]
107 |             if split_ratio_is_inverse:
108 |                 df[spr_col] = 1 / df[spr_col]
109 |
110 |             # move ex-div up 1 row
111 |             groupby = df.groupby(level=1, observed=False)
112 |             last = pd.DataFrame.last_valid_index
113 |             ex_div = groupby[div_col].shift(-1)
114 |             ex_div.loc[groupby.apply(last)] = 0
115 |             sp_rto = groupby[spr_col].shift(-1)
116 |             sp_rto.loc[groupby.apply(last)] = 1
117 |
118 |             df[div_col] = ex_div
119 |             df[spr_col] = sp_rto
120 |
121 |             # generate dividend multipliers
122 |             price_multi = (1 - ex_div / df[close_col]) * sp_rto
123 |             price_multi = price_multi[::-1].groupby(level=1, observed=False).cumprod()[::-1]
124 |             df[price_multi_col] = price_multi.astype(np.float32)
125 |             vol_multi = (1 / sp_rto)[::-1].groupby(level=1, observed=False).cumprod()[::-1]
126 |             df[vol_multi_col] = vol_multi.astype(np.float32)
127 |
128 |         return df
129 |
130 |     def _load(self) -> pd.DataFrame:
131 |         """
132 |         Return dataframe with multi-index ['date', 'asset']
133 |
134 |         You need to call `self.test_load()` in your test case to check if the format
135 |         you returned is correct.
136 |         """
137 |         raise NotImplementedError("abstractmethod")
138 |
139 |     def test_load(self):
140 |         """
141 |         Basic test for the format returned by _load().
142 |         If you write your own Loader, call this method in your test case.
143 |         """
144 |         df = self._load()
145 |
146 |         assert df.index.names == ['date', 'asset'], \
147 |             "df.index.names should be ['date', 'asset'] "
148 |         assert not any(df.index.duplicated()), \
149 |             "There are duplicate indexes in df, you need to handle them."
150 |         assert df.index.is_monotonic_increasing, \
151 |             "df.index must be sorted, try using df.sort_index(level=0, inplace=True)"
152 |         assert str(df.index.levels[0].tzinfo) == 'UTC', \
153 |             "df.index.date must be UTC timezone."
154 |         assert df.index.levels[-1].ordered, \
155 |             "df.index.asset must be ordered categorical."
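        # (`time_category` is an integer id per unique timestamp; the engine uses it
        # as a grouping / sorting key on the GPU, see `_format` above.)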
156 | assert self.time_category in df, \ 157 | "You must create a time_category column, convert time to category id" 158 | 159 | if self.adjustments: 160 | assert all(x in df for x in self.adjustments), \ 161 | "Adjustments columns `{}` not found.".format(self.adjustments) 162 | assert all(x in df for x in self.adjustment_multipliers), \ 163 | "Adjustment multipliers columns `{}` not found.".format(self.adjustment_multipliers) 164 | assert not any(df[self.adjustments[0]].isna()), \ 165 | "There is nan value in ex-dividend column, should be filled with 0." 166 | assert not any(df[self.adjustments[1]].isna()), \ 167 | "There is nan value in split_ratio column, should be filled with 1." 168 | assert not any(df[self.time_category].isna()), \ 169 | "There is nan value in time_category column, should be filled with time id." 170 | return df 171 | 172 | def load(self, start: Optional[pd.Timestamp] = None, end: Optional[pd.Timestamp] = None, 173 | backwards: int = 0) -> pd.DataFrame: 174 | df = self._load() 175 | 176 | index = df.index.levels[0] 177 | 178 | if start is None: 179 | start = index[0] 180 | if end is None: 181 | end = index[-1] 182 | 183 | if index[0] > start: 184 | raise ValueError( 185 | f"`start` time ({start}) cannot be less " \ 186 | f"than earliest time of data: {index[0]}." 187 | ) 188 | 189 | if index[-1] < end: 190 | raise ValueError( 191 | f"`end` time ({end}) cannot be greater " \ 192 | f"than latest time of data: {index[-1]}." 193 | ) 194 | 195 | start_loc = index.get_indexer([start], method='bfill')[0] 196 | backward_loc = max(start_loc - backwards, 0) 197 | end_loc = index.get_indexer([end], method='ffill')[0] 198 | assert end_loc >= start_loc, 'There is no data between `start` and `end`.' 199 | 200 | backward_start = index[backward_loc] 201 | return df.loc[backward_start:end] 202 | 203 | 204 | class DataLoaderFastGetter: 205 | """Fast get method for dataloader's DataFrame""" 206 | class DictLikeCursor: 207 | def __init__(self, parent, row_slice, column_id): 208 | self.parent = parent 209 | self.row_slice = row_slice 210 | self.data = parent.raw_data[row_slice, column_id] 211 | self.index = parent.asset_index[row_slice] 212 | self.length = len(self.index) 213 | 214 | def get_datetime_index(self): 215 | return self.parent.indexes[0][self.row_slice] 216 | 217 | def __getitem__(self, asset): 218 | asset_id = self.parent.asset_to_code[asset] 219 | cursor_index = self.index 220 | i = cursor_index.searchsorted(asset_id) 221 | if i >= self.length: 222 | raise KeyError('{} not found'.format(asset)) 223 | if cursor_index[i] != asset_id: 224 | raise KeyError('{} not found'.format(asset)) 225 | return self.data[i] 226 | 227 | def items(self): 228 | idx = self.index 229 | code_to_asset = self.parent.code_to_asset 230 | for i in range(self.length): 231 | code = idx[i] 232 | name = code_to_asset[code] 233 | yield name, self.data[i] 234 | 235 | def get(self, asset, default=None): 236 | try: 237 | return self[asset] 238 | except KeyError: 239 | return default 240 | 241 | def __init__(self, df): 242 | cat = df.index.get_level_values(1) 243 | 244 | self.source = df 245 | self.raw_data = df.values 246 | self.columns = df.columns 247 | self.indexes = [df.index.get_level_values(0), cat] 248 | self.asset_index = cat.codes 249 | self.asset_to_code = {v: k for k, v in enumerate(cat.categories)} 250 | self.code_to_asset = dict(enumerate(cat.categories)) 251 | self.last_row_slice = None 252 | 253 | def get_slice(self, start, stop): 254 | if isinstance(start, slice): 255 | return start 256 | idx = 
self.indexes[0]
257 |         stop = stop or start
258 |         row_slice = slice(idx.searchsorted(start), idx.searchsorted(stop, side='right'))
259 |         return row_slice
260 |
261 |     def get_as_dict(self, start, stop=None, column_id=slice(None)):
262 |         row_slice = self.get_slice(start, stop)
263 |         cur = self.DictLikeCursor(self, row_slice, column_id)
264 |         self.last_row_slice = row_slice
265 |         return cur
266 |
267 |     def get_as_df(self, start, stop=None):
268 |         """550x faster than .loc[], 3x faster than .iloc[]"""
269 |         row_slice = self.get_slice(start, stop)
270 |         data = self.raw_data[row_slice]
271 |         index = self.indexes[1][row_slice]
272 |         self.last_row_slice = row_slice
273 |         return pd.DataFrame(data, index=index, columns=self.columns)
--------------------------------------------------------------------------------
/spectre/data/iex.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import os
8 | import numpy as np
9 | from . import ArrowLoader, CsvDirLoader
10 | from .iex_fetcher import iex
11 |
12 |
13 | class IexDownloader:  # todo IexDownloader unfinished
14 |     @classmethod
15 |     def _concat(cls):
16 |         pass
17 |
18 |     @classmethod
19 |     def ingest(cls, iex_key, save_to, range_='5y', symbols: list = None, skip_exists=True):
20 |         """
21 |         Download data from IEX. Please note that downloading all the data will cost around $60.
22 |         :param iex_key: your private IEX account api key.
23 |         :param save_to: path to folder
24 |         :param range_: historical range, supports 5y, 2y, 1y, ytd, 6m, 3m, 1m.
25 |         :param symbols: list of symbols to download. If None, download all stocks
26 |             (not including delisted).
27 |         :param skip_exists: skip if file exists, useful for resume from interruption.
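
        Example (a sketch only; the key and paths below are placeholders)::

            IexDownloader.ingest('pk_your_key', save_to='./iex', range_='5y',
                                 symbols=['AAPL', 'MSFT'], skip_exists=True)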
28 | """ 29 | from tqdm.auto import tqdm 30 | print("Download prices from IEX...") 31 | iex.init(iex_key, api='cloud') 32 | 33 | calender_asset = None 34 | if symbols is None: 35 | symbols = iex.Reference.symbols() 36 | types = (symbols.type == 'ad') | (symbols.type == 'cs') & (symbols.exchange != 'OTC') 37 | symbols = symbols[types].symbol.values 38 | symbols.extend(['SPY', 'QQQ']) 39 | calender_asset = 'SPY' 40 | 41 | def download(event, folder): 42 | for symbol in tqdm(symbols): 43 | csv_path = os.path.join(folder, '{}.csv'.format(symbol)) 44 | if os.path.exists(csv_path) and skip_exists: 45 | continue 46 | if event == 'chart': 47 | iex.Stock(symbol).chart(range_).to_csv(csv_path) 48 | elif event == 'dividends': 49 | iex.Stock(symbol).dividends(range_).to_csv(csv_path) 50 | elif event == 'splits': 51 | iex.Stock(symbol).splits(range_).to_csv(csv_path) 52 | 53 | print('Ingest prices...') 54 | prices_dir = os.path.join(save_to, 'daily') 55 | if not os.path.exists(prices_dir): 56 | os.makedirs(prices_dir) 57 | download('chart', prices_dir) 58 | 59 | print('Ingest dividends...') 60 | div_dir = os.path.join(save_to, 'dividends') 61 | if not os.path.exists(div_dir): 62 | os.makedirs(div_dir) 63 | download('dividends', div_dir) 64 | 65 | print('Ingest splits...') 66 | sp_dir = os.path.join(save_to, 'splits') 67 | if not os.path.exists(sp_dir): 68 | os.makedirs(sp_dir) 69 | download('splits', sp_dir) 70 | 71 | print('Converting...') 72 | use_cols = {'date', 'uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume', 'exDate', 'amount', 73 | 'ratio'} 74 | loader = CsvDirLoader( 75 | prices_dir, calender_asset=calender_asset, 76 | dividends_path=div_dir, splits_path=sp_dir, 77 | ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'), adjustments=('amount', 'ratio'), 78 | prices_index='date', dividends_index='exDate', splits_index='exDate', 79 | parse_dates=True, usecols=lambda x: x in use_cols, 80 | dtype={'uOpen': np.float32, 'uHigh': np.float32, 'uLow': np.float32, 81 | 'uClose': np.float32, 82 | 'uVolume': np.float64, 'amount': np.float64, 'ratio': np.float64}) 83 | 84 | arrow_file = os.path.join(save_to, 'yahoo.feather') 85 | ArrowLoader.ingest(source=loader, save_to=arrow_file, force=True) 86 | 87 | print('Ingest completed! Use `loader = spectre.data.ArrowLoader(r"{}")` ' 88 | 'to load your data.'.format(arrow_file)) 89 | 90 | @classmethod 91 | def update(cls, range_, temp_path, save_to): 92 | # todo download and concat 93 | pass 94 | -------------------------------------------------------------------------------- /spectre/data/memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | from .dataloader import DataLoader 8 | 9 | 10 | class MemoryLoader(DataLoader): 11 | """ Convert pd.Dataframe to spectre.data.DataLoader """ 12 | def __init__(self, df, ohlcv=None, adjustments=None) -> None: 13 | super().__init__("", ohlcv, adjustments) 14 | self.df = self._format(df) 15 | self.test_load() 16 | 17 | @property 18 | def last_modified(self) -> float: 19 | return 1 20 | 21 | def _load(self): 22 | return self.df 23 | -------------------------------------------------------------------------------- /spectre/data/quandl.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. 
3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | from zipfile import ZipFile
8 | import numpy as np
9 | import pandas as pd
10 | from .dataloader import DataLoader
11 |
12 |
13 | class QuandlLoader(DataLoader):
14 |     @property
15 |     def last_modified(self) -> float:
16 |         """ the quandl data is no longer updated, so return a fixed value """
17 |         return 1
18 |
19 |     def __init__(self, file: str, calender_asset='AAPL') -> None:
20 |         """
21 |         Usage:
22 |         download the data first:
23 |         https://www.quandl.com/api/v3/datatables/WIKI/PRICES.csv?qopts.export=true&api_key=[your_api_key]
24 |         then:
25 |         loader = data.QuandlLoader('./quandl/WIKI_PRICES.zip')
26 |         """
27 |         super().__init__(file,
28 |                          ohlcv=('open', 'high', 'low', 'close', 'volume'),
29 |                          adjustments=('ex-dividend', 'split_ratio'))
30 |         self._calender = calender_asset
31 |
32 |     def _load(self) -> pd.DataFrame:
33 |         with ZipFile(self._path) as pkg:
34 |             with pkg.open(pkg.namelist()[0]) as csv:
35 |                 df = pd.read_csv(csv, parse_dates=['date'],
36 |                                  usecols=['ticker', 'date', 'open', 'high', 'low', 'close',
37 |                                           'volume', 'ex-dividend', 'split_ratio', ],
38 |                                  dtype={
39 |                                      'open': np.float32, 'high': np.float32, 'low': np.float32,
40 |                                      'close': np.float32, 'volume': np.float64,
41 |                                      'ex-dividend': np.float64, 'split_ratio': np.float64
42 |                                  })
43 |
44 |         df.set_index(['date', 'ticker'], inplace=True)
45 |         df.loc[("2001-09-12", 'GMT'), 'split_ratio'] = 1  # fix nan
46 |         if self._calender:
47 |             df = self._align_to(df, self._calender)
48 |         df.sort_index(level=[0, 1], inplace=True)
49 |         df = self._format(df, split_ratio_is_inverse=True)
50 |
51 |         return df
--------------------------------------------------------------------------------
/spectre/data/yahoo.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Heerozh (Zhang Jianhao)
3 | @copyright: Copyright 2019, Heerozh. All rights reserved.
4 | @license: Apache 2.0
5 | @email: heeroz@gmail.com
6 | """
7 | import datetime
8 | import os
9 | import time
10 | import pandas as pd
11 | import numpy as np
12 |
13 | from .arrow import ArrowLoader
14 | from .csv import CsvDirLoader
15 |
16 |
17 | class YahooDownloader:
18 |
19 |     @classmethod
20 |     def ingest(cls, start_date: str, save_to: str, symbols: list = None, skip_exists=True) -> None:
21 |         """
22 |         Download data from yahoo.
23 |         :param start_date: first day of data to download, e.g. '2001'.
24 |         :param save_to: path to folder
25 |         :param symbols: list of symbols to download. If None, download SP500 components.
26 |         :param skip_exists: skip if file exists, useful for resume from interruption.
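
        Example (paths are placeholders; with symbols=None the current SP500
        component list is fetched automatically)::

            YahooDownloader.ingest(start_date='2001', save_to='./yahoo',
                                   symbols=['IBM', 'AAPL'], skip_exists=True)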
27 | """ 28 | import requests 29 | import re 30 | from tqdm.auto import tqdm 31 | 32 | print("Download prices from yahoo...") 33 | 34 | start_date = pd.to_datetime(start_date, utc=True) 35 | 36 | calender_asset = None 37 | if symbols is None: 38 | etf = pd.read_html(requests.get( 39 | 'https://etfdailynews.com/etf/spy/', headers={'User-agent': 'Mozilla/5.0'} 40 | ).text, attrs={'id': 'etfs-that-own'}) 41 | symbols = [x for x in etf[0].Symbol.values.tolist() if isinstance(x, str)] 42 | symbols.extend(['SPY', 'QQQ']) 43 | calender_asset = 'SPY' 44 | 45 | session = requests.Session() 46 | page = session.get('https://finance.yahoo.com/quote/IBM/history?p=IBM') 47 | # CrumbStore 48 | m = re.search('"CrumbStore":{"crumb":"(.*?)"}', page.text) 49 | crumb = m.group(1) 50 | crumb = crumb.encode('ascii').decode('unicode-escape') 51 | 52 | def download(event, folder): 53 | start = int(start_date.timestamp()) 54 | now = int(datetime.datetime.now().timestamp()) 55 | for symbol in tqdm(symbols): 56 | symbol = symbol.replace('.', '-') 57 | csv_path = os.path.join(folder, '{}.csv'.format(symbol)) 58 | if os.path.exists(csv_path) and skip_exists: 59 | continue 60 | url = "https://query1.finance.yahoo.com/v7/finance/download/" \ 61 | "{}?period1={}&period2={}&interval=1d&events={}&crumb={}".format( 62 | symbol, start, now, event, crumb) 63 | 64 | retry = 0.25 65 | while True: 66 | req = session.get(url) 67 | if req.status_code != requests.codes.ok: 68 | if 'No data found' in req.text: 69 | print('Symbol invalid, skipped: {}.'.format(symbol)) 70 | break 71 | retry *= 2 72 | if retry >= 5: 73 | print('Get {} failed, Over 4 retries, skipped, reason: {}'.format( 74 | symbol, req.text)) 75 | break 76 | else: 77 | time.sleep(retry) 78 | continue 79 | with open(csv_path, 'wb') as f: 80 | f.write(req.content) 81 | break 82 | 83 | print('Ingest prices...') 84 | prices_dir = os.path.join(save_to, 'daily') 85 | if not os.path.exists(prices_dir): 86 | os.makedirs(prices_dir) 87 | download('history', prices_dir) 88 | 89 | # yahoo prices data already split adjusted 90 | # print('Ingest dividends...') 91 | # div_dir = os.path.join(save_to, 'dividends') 92 | # if not os.path.exists(div_dir): 93 | # os.makedirs(div_dir) 94 | # download('div', div_dir) 95 | 96 | # print('Ingest splits...') 97 | # sp_dir = os.path.join(save_to, 'splits') 98 | # if not os.path.exists(sp_dir): 99 | # os.makedirs(sp_dir) 100 | # download('split', sp_dir) 101 | 102 | session.close() 103 | 104 | print('Converting...') 105 | loader = CsvDirLoader( 106 | prices_dir, calender_asset=calender_asset, 107 | # dividends_path=div_dir, 108 | # splits_path=sp_dir, 109 | ohlcv=('Open', 'High', 'Low', 'Close', 'Volume'), 110 | # adjustments=('Dividends', 'Stock Splits'), 111 | prices_index='Date', 112 | # dividends_index='Date', splits_index='Date', split_ratio_is_fraction=True, 113 | parse_dates=True, 114 | dtype={'Open': np.float32, 'High': np.float32, 'Low': np.float32, 115 | 'Close': np.float32, 116 | 'Volume': np.float64, 'Dividends': np.float64}) 117 | 118 | arrow_file = os.path.join(save_to, 'yahoo.feather') 119 | ArrowLoader.ingest(source=loader, save_to=arrow_file, force=True) 120 | 121 | print('Ingest completed! 
Use `loader = spectre.data.ArrowLoader(r"{}")` ' 122 | 'to load your data.'.format(arrow_file)) 123 | -------------------------------------------------------------------------------- /spectre/factors/__init__.py: -------------------------------------------------------------------------------- 1 | from .engine import ( 2 | FactorEngine, 3 | OHLCV, 4 | ) 5 | 6 | from .factor import ( 7 | BaseFactor, PlaceHolderFactor, 8 | CustomFactor, 9 | CrossSectionFactor, 10 | RankFactor, RollingRankFactor, 11 | ZScoreFactor, RollingZScoreFactor, 12 | XSMax, XSMin, 13 | QuantileClassifier, 14 | SumFactor, ProdFactor, UniqueTSSumFactor, 15 | MADClampFactor, 16 | WinsorizingFactor, 17 | IQRNormalityFactor, 18 | ) 19 | 20 | from .datafactor import ( 21 | ColumnDataFactor, 22 | AdjustedColumnDataFactor, 23 | AssetClassifierDataFactor, 24 | SeriesDataFactor, 25 | DatetimeDataFactor, 26 | ) 27 | 28 | from .filter import ( 29 | FilterFactor, 30 | StaticAssets, 31 | OneHotEncoder, 32 | AndFactor, 33 | AnyFilter, 34 | AllFilter, 35 | PlaceHolderFilter, 36 | ) 37 | 38 | from .multiprocessing import ( 39 | CPUParallelFactor 40 | ) 41 | 42 | from .basic import ( 43 | Returns, 44 | LogReturns, 45 | SimpleMovingAverage, MA, SMA, 46 | WeightedAverageValue, 47 | LinearWeightedAverage, 48 | VWAP, 49 | ExponentialWeightedMovingAverage, EMA, 50 | AverageDollarVolume, 51 | AnnualizedVolatility, 52 | ElementWiseMax, ElementWiseMin, 53 | RollingArgMax, RollingArgMin, 54 | ConstantsFactor, 55 | ) 56 | 57 | from .technical import ( 58 | BollingerBands, BBANDS, 59 | MovingAverageConvergenceDivergenceSignal, MACD, 60 | TrueRange, TRANGE, 61 | RSI, 62 | FastStochasticOscillator, STOCHF, 63 | ) 64 | 65 | from .statistical import ( 66 | StandardDeviation, STDDEV, 67 | RollingHigh, MAX, 68 | RollingLow, MIN, 69 | RollingLinearRegression, 70 | RollingMomentum, 71 | RollingQuantile, 72 | HalfLifeMeanReversion, 73 | RollingCorrelation, 74 | RollingCovariance, 75 | XSMaxCorrCoef, 76 | InformationCoefficient, RankWeightedInformationCoefficient, 77 | RollingInformationCoefficient, 78 | TTest1Samp, StudentCDF, 79 | CrossSectionR2, 80 | FactorWiseKthValue, FactorWiseZScore, 81 | ) 82 | 83 | from .feature import ( 84 | MarketDispersion, 85 | MarketReturn, 86 | MarketVolatility, 87 | AdvanceDeclineRatio, 88 | AssetData, 89 | MONTH, WEEKDAY, QUARTER, TIME, 90 | IS_JANUARY, IS_DECEMBER, IS_MONTH_END, IS_MONTH_START, IS_QUARTER_END, IS_QUARTER_START, 91 | ) 92 | 93 | from .label import ( 94 | RollingFirst, 95 | ForwardSignalData, 96 | ) 97 | -------------------------------------------------------------------------------- /spectre/factors/basic.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | from typing import Sequence 8 | import numpy as np 9 | import torch 10 | import math 11 | from .factor import BaseFactor, CustomFactor 12 | from ..parallel import nansum, nanmean 13 | from .engine import OHLCV 14 | from ..config import Global 15 | 16 | 17 | class Returns(CustomFactor): 18 | """ Returns by tick (not by time) """ 19 | inputs = [OHLCV.close] 20 | win = 2 21 | _min_win = 2 22 | 23 | def compute(self, closes): 24 | # missing data considered as delisted, calculated on the last day's data. 
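        # i.e. a simple window return, roughly closes[-1] / closes[0] - 1;
        # `last_nonnan` guards against a NaN close on the final bar of the window.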
25 |         return closes.last_nonnan(offset=1) / closes.first() - 1
26 |
27 |
28 | class LogReturns(CustomFactor):
29 |     inputs = [OHLCV.close]
30 |     win = 2
31 |     _min_win = 2
32 |
33 |     def compute(self, closes):
34 |         return (closes.last_nonnan() / closes.first()).log()
35 |
36 |
37 | class SimpleMovingAverage(CustomFactor):
38 |     inputs = [OHLCV.close]
39 |     _min_win = 2
40 |
41 |     def compute(self, data):
42 |         return data.nanmean()
43 |
44 |
45 | class WeightedAverageValue(CustomFactor):
46 |     _min_win = 2
47 |
48 |     def compute(self, base, weight):
49 |         def _weight_mean(_base, _weight):
50 |             return nansum(_base * _weight, dim=2) / nansum(_weight, dim=2)
51 |
52 |         return base.agg(_weight_mean, weight)
53 |
54 |
55 | class LinearWeightedAverage(CustomFactor):
56 |     _min_win = 2
57 |
58 |     def __init__(self, win=None, inputs=None):
59 |         super().__init__(win, inputs)
60 |         self.weight = torch.arange(1, self.win + 1, dtype=Global.float_type)
61 |         self.weight = self.weight / self.weight.sum()
62 |
63 |     def pre_compute_(self, engine, start, end) -> None:
64 |         super().pre_compute_(engine, start, end)
65 |         self.weight = self.weight.to(device=engine.device, copy=False)
66 |
67 |     def compute(self, base):
68 |         def _weight_mean(_base):
69 |             return nansum(_base * self.weight, dim=2)
70 |
71 |         return base.agg(_weight_mean)
72 |
73 |
74 | class VWAP(WeightedAverageValue):
75 |     inputs = (OHLCV.close, OHLCV.volume)
76 |
77 |
78 | class ExponentialWeightedMovingAverage(CustomFactor):
79 |     inputs = [OHLCV.close]
80 |     win = 2
81 |     _min_win = 2
82 |
83 |     def __init__(self, span: int = None, inputs: Sequence[BaseFactor] = None,
84 |                  adjust=False, half_life: float = None):
85 |         if span is not None:
86 |             self.alpha = (2.0 / (1.0 + span))
87 |             # Length required to achieve 99.97% accuracy: np.log(1-99.97/100) / np.log(alpha),
88 |             # which simplifies to about 4 * (span+1); 3.45 achieves 99.90%, 2.26 achieves 99.00%.
89 |             self.win = int(4.5 * (span + 1))
90 |         else:
91 |             self.alpha = 1 - math.exp(math.log(0.5) / half_life)
92 |             self.win = int(15 * half_life)
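            # e.g. half_life=10 gives alpha = 1 - 0.5 ** (1 / 10) ≈ 0.067 and
            # win = 150, i.e. 15 half-lives, leaving a truncated tail weight of
            # only about 0.5 ** 15 ≈ 3e-5.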
93 |
94 |         super().__init__(None, inputs)
95 |         self.adjust = adjust
96 |         self.weight = np.full(self.win, 1 - self.alpha) ** np.arange(self.win - 1, -1, -1)
97 |         if self.adjust:
98 |             self.weight = self.weight / sum(self.weight)  # to sum to one
99 |
100 |     def pre_compute_(self, engine, start, end) -> None:
101 |         super().pre_compute_(engine, start, end)
102 |         if not isinstance(self.weight, torch.Tensor):
103 |             self.weight = torch.tensor(self.weight, dtype=Global.float_type, device=engine.device)
104 |
105 |     def compute(self, data):
106 |         self.weight = self.weight.to(device=data.device)
107 |         weighted_mean = data.agg(lambda x: nansum(x * self.weight, dim=2))
108 |         if self.adjust:
109 |             return weighted_mean
110 |         else:
111 |             shifted = data.last().roll(self.win - 1, dims=1)
112 |             shifted[:, 0:self.win - 1] = 0
113 |             alpha = self.alpha
114 |             return alpha * weighted_mean + (shifted * (1 - alpha) ** self.win)
115 |
116 |
117 | class AverageDollarVolume(CustomFactor):
118 |     inputs = [OHLCV.close, OHLCV.volume]
119 |
120 |     def compute(self, closes, volumes):
121 |         if self.win == 1:
122 |             return closes * volumes
123 |         else:
124 |             return closes.agg(lambda c, v: nanmean(c * v, dim=2), volumes)
125 |
126 |
127 | class AnnualizedVolatility(CustomFactor):
128 |     inputs = [Returns(win=2), 252]
129 |     win = 20
130 |     _min_win = 2
131 |
132 |     def compute(self, returns, annualization_factor):
133 |         return returns.nanstd() * (annualization_factor ** .5)
134 |
135 |
136 | class ElementWiseMax(CustomFactor):
137 |     _min_win = 1
138 |
139 |     def __init__(self, win=None, inputs=None):
140 |         super().__init__(win, inputs)
141 |         assert self.win == 1
142 |
143 |     @classmethod
144 |     def binary_fill_na(cls, a, b, value):
145 |         a = a.clone()
146 |         b = b.clone()
147 |         if a.dtype != b.dtype or a.dtype not in (torch.float32, torch.float64, torch.float16):
148 |             a = a.to(Global.float_type)
149 |             b = b.to(Global.float_type)
150 |
151 |         a.masked_fill_(torch.isnan(a), value)
152 |         b.masked_fill_(torch.isnan(b), value)
153 |         return a, b
154 |
155 |     def compute(self, a, b):
156 |         ret = torch.max(*ElementWiseMax.binary_fill_na(a, b, -np.inf))
157 |         ret.masked_fill_(torch.isinf(ret), np.nan)
158 |         return ret
159 |
160 |
161 | class ElementWiseMin(CustomFactor):
162 |     _min_win = 1
163 |
164 |     def __init__(self, win=None, inputs=None):
165 |         super().__init__(win, inputs)
166 |         assert self.win == 1
167 |
168 |     def compute(self, a, b):
169 |         ret = torch.min(*ElementWiseMax.binary_fill_na(a, b, np.inf))
170 |         ret.masked_fill_(torch.isinf(ret), np.nan)
171 |         return ret
172 |
173 |
174 | class RollingArgMax(CustomFactor):
175 |     _min_win = 2
176 |
177 |     def compute(self, data):
178 |         def _argmax(_data):
179 |             ret = (_data.argmax(dim=2) + 1.) / self.win
180 |             return ret.to(Global.float_type)
181 |
182 |         return data.agg(_argmax)
183 |
184 |
185 | class RollingArgMin(CustomFactor):
186 |     _min_win = 2
187 |
188 |     def compute(self, data):
189 |         def _argmin(_data):
190 |             ret = (_data.argmin(dim=2) + 1.) 
/ self.win 191 | return ret.to(Global.float_type) 192 | 193 | return data.agg(_argmin) 194 | 195 | 196 | class ConstantsFactor(CustomFactor): 197 | def __init__(self, value, like=OHLCV.open): 198 | self.value = value 199 | super().__init__(1, inputs=[like]) 200 | 201 | def compute(self, x): 202 | return torch.full(x.shape, self.value, device=x.device, dtype=x.dtype) 203 | 204 | 205 | MA = SimpleMovingAverage 206 | SMA = SimpleMovingAverage 207 | EMA = ExponentialWeightedMovingAverage 208 | -------------------------------------------------------------------------------- /spectre/factors/datafactor.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | import numpy as np 8 | import torch 9 | import pandas as pd 10 | 11 | from typing import Optional, Sequence, Union 12 | from ..parallel import nanlast 13 | from .factor import BaseFactor, CustomFactor 14 | from ..config import Global 15 | 16 | 17 | class ColumnDataFactor(BaseFactor): 18 | def __init__(self, inputs: Optional[Sequence[str]] = None, 19 | should_delay=True, dtype=None) -> None: 20 | super().__init__() 21 | if inputs: 22 | self.inputs = inputs 23 | assert (3 > len(self.inputs) > 0), \ 24 | "ColumnDataFactor's `inputs` can only contains one data column and corresponding " \ 25 | "adjustments column" 26 | self._data = None 27 | self._multiplier = None 28 | self._should_delay = should_delay 29 | self.dtype = dtype 30 | 31 | @property 32 | def adjustments(self): 33 | return self._multiplier 34 | 35 | def get_total_backwards_(self) -> int: 36 | return 0 37 | 38 | def should_delay(self) -> bool: 39 | return self._should_delay 40 | 41 | def pre_compute_(self, engine, start, end) -> None: 42 | super().pre_compute_(engine, start, end) 43 | if self._data is None: 44 | self._data = engine.column_to_tensor_(self.inputs[0]) 45 | if self.dtype is not None: 46 | self._data = self._data.to(self.dtype) 47 | self._data = engine.group_by_(self._data, self.groupby) 48 | if len(self.inputs) > 1 and self.inputs[1] in engine.dataframe_: 49 | self._multiplier = engine.column_to_tensor_(self.inputs[1]) 50 | self._multiplier = engine.group_by_(self._multiplier, self.groupby) 51 | if self.dtype is not None: 52 | self._multiplier = self._multiplier.to(self.dtype) 53 | else: 54 | self._multiplier = None 55 | self._clean_required = True 56 | 57 | def clean_up_(self, force=False) -> None: 58 | super().clean_up_(force) 59 | self._data = None 60 | self._multiplier = None 61 | self._clean_required = False 62 | 63 | def compute_(self, stream: Union[torch.cuda.Stream, None]) -> torch.Tensor: 64 | return self._data 65 | 66 | def compute(self, *inputs: Sequence[torch.Tensor]) -> torch.Tensor: 67 | pass 68 | 69 | # def adjusted_shift(self, periods=1): 70 | # factor = AdjustedShiftFactor(win=periods, inputs=(self,)) 71 | # return factor 72 | # 73 | # 74 | # class AdjustedShiftFactor(CustomFactor): 75 | # """ Shift the root datafactor """ 76 | # 77 | # def compute(self, data) -> torch.Tensor: 78 | # return data.first() 79 | 80 | 81 | class AdjustedColumnDataFactor(CustomFactor): 82 | def __init__(self, data: ColumnDataFactor): 83 | super().__init__(1, (data,)) 84 | self.parent = data 85 | 86 | def compute(self, data) -> torch.Tensor: 87 | multi = self.parent.adjustments 88 | if multi is None: 89 | return data 90 | return data * multi / nanlast(multi, dim=1)[:, None] 
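        # Sanity sketch (illustrative numbers): for a 2:1 split between the 2nd and
        # 3rd bar, `multi` is e.g. [0.5, 0.5, 1.0] per asset; dividing by the last
        # multiplier rescales raw prices onto the latest share basis, so rolling
        # windows see a continuous price series.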
91 | 92 | 93 | class AssetClassifierDataFactor(BaseFactor): 94 | """ Dict to categorical output for asset, slow """ 95 | def __init__(self, sector: dict, default: int): 96 | super().__init__() 97 | self.sector = sector 98 | self.default = default 99 | self._data = None 100 | 101 | def get_total_backwards_(self) -> int: 102 | return 0 103 | 104 | def should_delay(self) -> bool: 105 | return False 106 | 107 | def pre_compute_(self, engine, start, end) -> None: 108 | super().pre_compute_(engine, start, end) 109 | assets = engine.dataframe_index[1] 110 | sector = self.sector 111 | default = self.default 112 | data = [sector.get(asset, default) for asset in assets] # slow 113 | data = torch.tensor(data, device=engine.device, dtype=Global.float_type) 114 | self._data = engine.group_by_(data, self.groupby) 115 | 116 | def clean_up_(self, force=False) -> None: 117 | super().clean_up_(force) 118 | self._data = None 119 | 120 | def compute_(self, stream: Union[torch.cuda.Stream, None]) -> torch.Tensor: 121 | return self._data 122 | 123 | def compute(self, *inputs: Sequence[torch.Tensor]) -> torch.Tensor: 124 | pass 125 | 126 | 127 | class SeriesDataFactor(ColumnDataFactor): 128 | """ Add series to engine, slow """ 129 | def __init__(self, series: pd.Series, fill_na=None, should_delay=True): 130 | self.series = series 131 | self.fill_na = fill_na 132 | assert series.index.names == ['date', 'asset'], \ 133 | "df.index.names should be ['date', 'asset'] " 134 | super().__init__(inputs=[str(series.name)], should_delay=should_delay) 135 | 136 | def pre_compute_(self, engine, start, end) -> None: 137 | if self.series.name not in engine.dataframe_.columns: 138 | engine._dataframe = engine.dataframe_.join(self.series) 139 | if self.fill_na is not None: 140 | engine._dataframe[self.series.name] = engine._dataframe[self.series.name].\ 141 | groupby(level=1).fillna(method=self.fill_na) 142 | super().pre_compute_(engine, start, end) 143 | 144 | 145 | class DatetimeDataFactor(BaseFactor): 146 | """ Datetime's attr to DataFactor """ 147 | _instance = {} 148 | 149 | def __new__(cls, attr): 150 | if attr not in cls._instance: 151 | cls._instance[attr] = super().__new__(cls) 152 | return cls._instance[attr] 153 | 154 | def __init__(self, attr) -> None: 155 | super().__init__() 156 | self._data = None 157 | self.attr = attr 158 | 159 | def get_total_backwards_(self) -> int: 160 | return 0 161 | 162 | def should_delay(self) -> bool: 163 | return False 164 | 165 | def pre_compute_(self, engine, start, end) -> None: 166 | super().pre_compute_(engine, start, end) 167 | if self._data is None: 168 | data = getattr(engine.dataframe_index[0], self.attr) # slow 169 | if not isinstance(data, np.ndarray): 170 | data = data.values 171 | data = torch.from_numpy(data).to( 172 | device=engine.device, dtype=Global.float_type, non_blocking=True) 173 | self._data = engine.group_by_(data, self.groupby) 174 | self._clean_required = True 175 | 176 | def clean_up_(self, force=False) -> None: 177 | super().clean_up_(force) 178 | self._data = None 179 | self._clean_required = False 180 | 181 | def compute_(self, stream: Union[torch.cuda.Stream, None]) -> torch.Tensor: 182 | return self._data 183 | 184 | def compute(self, *inputs: Sequence[torch.Tensor]) -> torch.Tensor: 185 | pass 186 | -------------------------------------------------------------------------------- /spectre/factors/feature.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: 
Copyright 2019-2020, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | import warnings 8 | from .datafactor import DatetimeDataFactor 9 | from .factor import CrossSectionFactor, CustomFactor 10 | from .basic import Returns 11 | from ..parallel import nanstd, nanmean, nansum 12 | from ..config import Global 13 | 14 | 15 | # ----------- Common Market Features ----------- 16 | 17 | 18 | class MarketDispersion(CrossSectionFactor): 19 | """Cross-section standard deviation of universe stocks returns.""" 20 | inputs = (Returns(), ) 21 | win = 1 22 | 23 | def compute(self, returns): 24 | ret = nanstd(returns, dim=1).unsqueeze(-1) 25 | return ret.expand(ret.shape[0], returns.shape[1]) 26 | 27 | 28 | class MarketReturn(CrossSectionFactor): 29 | """Cross-section mean returns of universe stocks.""" 30 | inputs = (Returns(), ) 31 | win = 1 32 | 33 | def compute(self, returns): 34 | ret = nanmean(returns, dim=1).unsqueeze(-1) 35 | return ret.expand(ret.shape[0], returns.shape[1]) 36 | 37 | 38 | class MarketVolatility(CustomFactor): 39 | """MarketReturn Rolling standard deviation.""" 40 | inputs = (MarketReturn(), 252) 41 | win = 252 42 | _min_win = 2 43 | 44 | def compute(self, returns, annualization_factor): 45 | return (returns.nanvar() * annualization_factor) ** 0.5 46 | 47 | 48 | class AdvanceDeclineRatio(CrossSectionFactor): 49 | """Need to work with MA, and could be applied to volume too""" 50 | inputs = (Returns(), ) 51 | win = 1 52 | 53 | def compute(self, returns): 54 | advancing = nansum(returns > 0, dim=1).to(Global.float_type) 55 | declining = nansum(returns < 0, dim=1).to(Global.float_type) 56 | ratio = (advancing / declining).unsqueeze(-1) 57 | return ratio.expand(ratio.shape[0], returns.shape[1]) 58 | 59 | 60 | # ----------- Asset-specific data ----------- 61 | 62 | 63 | class AssetData(CustomFactor): 64 | def __init__(self, asset, factor): 65 | self.asset = asset 66 | self.asset_ind = None 67 | super().__init__(win=1, inputs=[factor]) 68 | 69 | def pre_compute_(self, engine, start, end): 70 | super().pre_compute_(engine, start, end) 71 | if not engine.align_by_time: 72 | warnings.warn("Make sure your data is aligned by time, otherwise will cause data " 73 | "disorder. Or set engine.align_by_time = True.", 74 | RuntimeWarning) 75 | self.asset_ind = engine.dataframe_index[1].unique().categories.get_loc(self.asset) 76 | 77 | def compute(self, data): 78 | ret = data[self.asset_ind] 79 | return ret.expand(data.shape[0], ret.shape[0]) 80 | 81 | 82 | # ----------- Common Calendar Features ----------- 83 | 84 | 85 | MONTH = DatetimeDataFactor('month') 86 | WEEKDAY = DatetimeDataFactor('weekday') 87 | QUARTER = DatetimeDataFactor('quarter') 88 | TIME = DatetimeDataFactor('hour') + DatetimeDataFactor('minute') / 60.0 89 | 90 | IS_JANUARY = MONTH == 1 91 | IS_DECEMBER = MONTH == 12 92 | # Note: shift(-1) may fail the engine.test_lookahead_bias(), 93 | # but this method is the fastest, so be it. 94 | IS_MONTH_END = MONTH.shift(-1) != MONTH 95 | IS_MONTH_START = MONTH.shift(1) != MONTH 96 | IS_QUARTER_END = QUARTER.shift(-1) != QUARTER 97 | IS_QUARTER_START = QUARTER.shift(1) != QUARTER 98 | 99 | -------------------------------------------------------------------------------- /spectre/factors/filter.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved. 
4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | from abc import ABC 8 | from typing import Set 9 | from .factor import BaseFactor, CustomFactor, MultiRetSelector 10 | from ..parallel import Rolling 11 | import torch 12 | import warnings 13 | 14 | 15 | class FilterFactor(CustomFactor, ABC): 16 | def __getitem__(self, key): 17 | return FilterMultiRetSelector(inputs=(self, key)) 18 | 19 | def shift(self, periods=1): 20 | factor = FilterRawShiftFactor(inputs=(self,)) 21 | factor.periods = periods 22 | return factor 23 | 24 | def sum(self, win): 25 | raise ValueError("FilterFactor does not support `.sum()` method, " 26 | "please convert to float by using `filter_factor.float()`") 27 | 28 | def filter(self, mask): 29 | raise ValueError("FilterFactor does not support local filtering `.filter()` method, " 30 | "please convert to float by using `filter_factor.float()`") 31 | 32 | def ts_any(self, win): 33 | """ Return True if Rolling window contains any True """ 34 | return AnyFilter(win, inputs=(self,)) 35 | 36 | def ts_all(self, win): 37 | """ Return True if Rolling window all are True """ 38 | return AllFilter(win, inputs=(self,)) 39 | 40 | 41 | class FilterMultiRetSelector(MultiRetSelector, FilterFactor): 42 | """MultiRetSelector returns CustomFactor, we're override here as FilterFactor""" 43 | pass 44 | 45 | 46 | class PlaceHolderFilter(FilterFactor): 47 | def compute(self, data: torch.Tensor) -> torch.Tensor: 48 | return data 49 | 50 | 51 | class FilterRawShiftFactor(FilterFactor): 52 | """For "roll_cuda" not implemented for 'Bool' """ 53 | periods = 1 54 | 55 | def compute(self, data: torch.Tensor) -> torch.Tensor: 56 | shift = data.char().roll(self.periods, dims=1) 57 | if self.periods > 0: 58 | shift[:, 0:self.periods] = 0 59 | else: 60 | shift[:, self.periods:] = 0 61 | 62 | return shift.bool() 63 | 64 | 65 | class StaticAssets(FilterFactor): 66 | """Useful for remove specific outliers or debug some assets""" 67 | def __init__(self, assets: Set[str]): 68 | from .engine import OHLCV 69 | super().__init__(win=1, inputs=[OHLCV.open]) 70 | self.assets = assets 71 | 72 | def compute(self, data: torch.Tensor) -> torch.Tensor: 73 | s = self._revert_to_series(data) 74 | ret = s.index.isin(self.assets, level=1) 75 | return self._regroup(ret) 76 | 77 | 78 | class OneHotEncoder(FilterFactor): 79 | def __init__(self, input_: BaseFactor): 80 | super().__init__(1, [input_]) 81 | 82 | def compute(self, data: torch.Tensor) -> torch.Tensor: 83 | classes = torch.unique(data, sorted=False) 84 | classes = classes[~torch.isnan(classes)] 85 | one_hot = [] 86 | if classes.shape[0] > 1000: 87 | warnings.warn("One hot encoding with too many features: ({}). 
" 88 | .format(classes.shape[0]), 89 | RuntimeWarning) 90 | for i in range(classes.shape[0]): 91 | one_hot.append((data == classes[i]).unsqueeze(-1)) 92 | return torch.cat(one_hot, dim=-1) 93 | 94 | 95 | class AnyFilter(FilterFactor): 96 | _min_win = 2 97 | 98 | def compute(self, data: Rolling) -> torch.Tensor: 99 | return data.values.any(dim=2) 100 | 101 | 102 | class AllFilter(FilterFactor): 103 | _min_win = 2 104 | 105 | def compute(self, data: Rolling) -> torch.Tensor: 106 | return data.values.all(dim=2) 107 | 108 | 109 | class AnyNonNaNFactor(FilterFactor): 110 | _min_win = 2 111 | 112 | def compute(self, data: Rolling) -> torch.Tensor: 113 | return (~torch.isnan(data.values)).any(dim=2) 114 | 115 | 116 | class AllNonNaNFactor(FilterFactor): 117 | _min_win = 2 118 | 119 | def compute(self, data: Rolling) -> torch.Tensor: 120 | return (~torch.isnan(data.values)).all(dim=2) 121 | 122 | 123 | class InvertFactor(FilterFactor): 124 | def compute(self, left) -> torch.Tensor: 125 | return ~left 126 | 127 | 128 | class OrFactor(FilterFactor): 129 | def compute(self, left, right) -> torch.Tensor: 130 | return left | right 131 | 132 | 133 | class XorFactor(FilterFactor): 134 | def compute(self, left, right) -> torch.Tensor: 135 | return left ^ right 136 | 137 | 138 | class AndFactor(FilterFactor): 139 | def compute(self, left, right) -> torch.Tensor: 140 | return left & right 141 | 142 | 143 | class LtFactor(FilterFactor): 144 | def compute(self, left, right) -> torch.Tensor: 145 | return torch.lt(left, right) 146 | 147 | 148 | class LeFactor(FilterFactor): 149 | def compute(self, left, right) -> torch.Tensor: 150 | return torch.le(left, right) 151 | 152 | 153 | class GtFactor(FilterFactor): 154 | def compute(self, left, right) -> torch.Tensor: 155 | return torch.gt(left, right) 156 | 157 | 158 | class GeFactor(FilterFactor): 159 | def compute(self, left, right) -> torch.Tensor: 160 | return torch.ge(left, right) 161 | 162 | 163 | class EqFactor(FilterFactor): 164 | def compute(self, left, right) -> torch.Tensor: 165 | return torch.eq(left, right) 166 | 167 | 168 | class NeFactor(FilterFactor): 169 | def compute(self, left, right) -> torch.Tensor: 170 | return torch.ne(left, right) 171 | -------------------------------------------------------------------------------- /spectre/factors/label.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | from .factor import CustomFactor 8 | from ..parallel import masked_first 9 | 10 | 11 | class RollingFirst(CustomFactor): 12 | win = 2 13 | _min_win = 2 14 | 15 | def __init__(self, win, data, mask): 16 | super().__init__(win, inputs=(data, mask)) 17 | 18 | def compute(self, data, mask): 19 | def _first_filter(_data, _mask): 20 | first_signal_price = masked_first(_data, _mask, dim=2) 21 | return first_signal_price 22 | 23 | return data.agg(_first_filter, mask) 24 | 25 | 26 | class ForwardSignalData(RollingFirst): 27 | """Data in future window periods where signal = True. 
Lookahead biased.""" 28 | def __init__(self, win, data, signal): 29 | super().__init__(win, data.shift(-win+1), signal.shift(-win+1)) 30 | -------------------------------------------------------------------------------- /spectre/factors/multiprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | from typing import Optional, Sequence 8 | from .factor import BaseFactor, CustomFactor 9 | from .datafactor import ColumnDataFactor 10 | from ..parallel import Rolling 11 | import pandas as pd 12 | import numpy as np 13 | from multiprocessing import Pool, cpu_count 14 | from multiprocessing.pool import ThreadPool 15 | 16 | 17 | class CPFCaller: 18 | inputs = None 19 | win = None 20 | callback = None 21 | 22 | def split_call(self, splits): 23 | split_inputs = [[]] * len(self.inputs) 24 | for i, data in enumerate(self.inputs): 25 | if isinstance(data, pd.DataFrame): 26 | split_inputs[i] = [data.iloc[beg:end] for beg, end in splits] 27 | else: 28 | split_inputs[i] = [data] * len(splits) 29 | return np.array([self.callback(*params) for params in zip(*split_inputs)]) 30 | 31 | 32 | class CPUParallelFactor(CustomFactor): 33 | """ 34 | Use CPU multi-process/thread instead of GPU to process each window of data. 35 | Useful when your calculations can only be done in the CPU. 36 | 37 | The performance of this method is not so ideal, definitely not as fast as 38 | using the vectorization library directly. 39 | """ 40 | 41 | def __init__(self, win: Optional[int] = None, inputs: Optional[Sequence[BaseFactor]] = None, 42 | multiprocess=False, core=None): 43 | """ 44 | `multiprocess=True` may not working on windows If your code is written in a notebook cell. 45 | So it is recommended that you write the CPUParallelFactor code in a file. 
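A minimal subclass sketch (hypothetical names; assumes a single price-like input):

    class WindowMean(CPUParallelFactor):
        @staticmethod
        def mp_compute(window_df) -> np.array:
            # one rolling window: rows are dates, columns are assets
            return window_df.mean().values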
46 | """ 47 | super().__init__(win, inputs) 48 | 49 | for data in inputs: 50 | if isinstance(data, ColumnDataFactor): 51 | raise ValueError('Cannot use ColumnDataFactor in CPUParallelFactor, ' 52 | 'please use AdjustedColumnDataFactor instead.') 53 | if multiprocess: 54 | self.pool = Pool 55 | else: 56 | self.pool = ThreadPool 57 | if core is None: 58 | self.core = cpu_count() 59 | else: 60 | self.core = core 61 | 62 | def compute(self, *inputs): 63 | n_cores = self.core 64 | origin_input = None 65 | date_count = 0 66 | 67 | converted_inputs = [] 68 | for data in inputs: 69 | if isinstance(data, Rolling): 70 | s = self._revert_to_series(data.last()) 71 | unstacked = s.unstack(level=1) 72 | converted_inputs.append(unstacked) 73 | if origin_input is None: 74 | origin_input = s 75 | date_count = len(unstacked) 76 | else: 77 | converted_inputs.append(data) 78 | 79 | backwards = self.get_total_backwards_() 80 | first_win_beg = backwards - self.win + 1 81 | first_win_end = backwards + 1 82 | windows = date_count - backwards 83 | ranges = list(zip(range(first_win_beg, first_win_beg + windows), 84 | range(first_win_end, date_count + 1))) 85 | caller = CPFCaller() 86 | caller.inputs = converted_inputs 87 | caller.callback = type(self).mp_compute 88 | 89 | if len(ranges) < n_cores: 90 | n_cores = len(ranges) 91 | split_range = np.array_split(ranges, n_cores) 92 | 93 | with self.pool(n_cores) as p: 94 | pool_ret = p.map(caller.split_call, split_range) 95 | 96 | pool_ret = np.concatenate(pool_ret) 97 | ret = pd.Series(index=origin_input.index, dtype='float64').unstack(level=1) 98 | if pool_ret.shape != ret.iloc[backwards:].shape: 99 | raise ValueError('return value shape {} != original {}'.format( 100 | pool_ret.shape, ret.iloc[backwards:].shape)) 101 | 102 | ret.iloc[backwards:] = pool_ret 103 | ret = ret.stack(dropna=False)[origin_input.index] 104 | 105 | return self._regroup(ret) 106 | 107 | @staticmethod 108 | def mp_compute(*inputs) -> np.array: 109 | """ 110 | You will receive a window of the input data, type is pd.DataFrame. 111 | The table below is how it looks when win=3 112 | | date | A | AA | ... | ZET | 113 | |------------|------|--------|-----|--------| 114 | | 2020-01-01 | 11.1 | xxx.xx | ... | 123.45 | 115 | | 2020-01-02 | ... | ... | ... | ... | 116 | | 2020-01-03 | 22.2 | xxx.xx | ... | 234.56 | 117 | 118 | You should return an np.array of length `input.shape[1]` 119 | """ 120 | raise NotImplementedError("abstractmethod") 121 | -------------------------------------------------------------------------------- /spectre/factors/statistical.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019, Heerozh. All rights reserved. 
4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | import torch 8 | import math 9 | import numpy as np 10 | from .factor import CustomFactor, CrossSectionFactor 11 | from .engine import OHLCV 12 | from ..parallel import (linear_regression_1d, quantile, pearsonr, unmasked_mean, unmasked_sum, 13 | nanmean, nanstd, covariance, nanvar) 14 | from ..parallel import DeviceConstant 15 | from ..config import Global 16 | 17 | 18 | class StandardDeviation(CustomFactor): 19 | inputs = [OHLCV.close] 20 | _min_win = 2 21 | ddof = 0 22 | 23 | def compute(self, data): 24 | return data.nanstd(ddof=self.ddof) 25 | 26 | 27 | class XSStandardDeviation(CrossSectionFactor): 28 | inputs = [OHLCV.close] 29 | ddof = 0 30 | 31 | def compute(self, data): 32 | return nanstd(data, ddof=self.ddof).unsqueeze(-1).expand(data.shape[0], data.shape[1]) 33 | 34 | 35 | class RollingHigh(CustomFactor): 36 | inputs = (OHLCV.close,) 37 | win = 5 38 | _min_win = 2 39 | 40 | def compute(self, data): 41 | return data.nanmax() 42 | 43 | 44 | class RollingLow(CustomFactor): 45 | inputs = (OHLCV.close,) 46 | win = 5 47 | _min_win = 2 48 | 49 | def compute(self, data): 50 | return data.nanmin() 51 | 52 | 53 | class RollingLinearRegression(CustomFactor): 54 | _min_win = 2 55 | 56 | def __init__(self, win, x, y): 57 | super().__init__(win=win, inputs=[x, y]) 58 | 59 | def compute(self, x, y): 60 | def lin_reg(_y, _x=None): 61 | if _x is None: 62 | _x = DeviceConstant.get(_y.device).arange(_y.shape[2], dtype=_y.dtype) 63 | _x = _x.expand(_y.shape[0], _y.shape[1], _x.shape[0]) 64 | m, b = linear_regression_1d(_x, _y, dim=2) 65 | return torch.cat([m.unsqueeze(-1), b.unsqueeze(-1)], dim=-1) 66 | if x is None: 67 | return y.agg(lin_reg) 68 | else: 69 | return y.agg(lin_reg, x) 70 | 71 | @property 72 | def coef(self): 73 | return self[0] 74 | 75 | @property 76 | def intercept(self): 77 | return self[1] 78 | 79 | 80 | class RollingMomentum(CustomFactor): 81 | inputs = (OHLCV.close,) 82 | win = 20 83 | _min_win = 2 84 | 85 | def compute(self, prices): 86 | def polynomial_reg(_y): 87 | x = DeviceConstant.get(_y.device).arange(_y.shape[2], dtype=_y.dtype) 88 | ones = torch.ones(x.shape[0], device=_y.device, dtype=_y.dtype) 89 | x = torch.stack([ones, x, x ** 2]).T 90 | x = x.expand(_y.shape[0], _y.shape[1], x.shape[0], x.shape[1]) 91 | 92 | xt = x.transpose(-2, -1) 93 | ret = (xt @ x).inverse() @ xt @ _y.unsqueeze(-1) 94 | return ret.squeeze(-1) 95 | 96 | return prices.agg(polynomial_reg) 97 | 98 | @property 99 | def gain(self): 100 | """gain>0 means stock gaining, otherwise is losing.""" 101 | return self[1] 102 | 103 | @property 104 | def accelerate(self): 105 | """accelerate>0 means stock accelerating, otherwise is decelerating.""" 106 | return self[2] 107 | 108 | @property 109 | def intercept(self): 110 | return self[0] 111 | 112 | 113 | class RollingQuantile(CustomFactor): 114 | inputs = (OHLCV.close, 5) 115 | _min_win = 2 116 | 117 | def compute(self, data, bins): 118 | def _quantile(_data): 119 | return quantile(_data, bins, dim=2)[:, :, -1] 120 | return data.agg(_quantile) 121 | 122 | 123 | class HalfLifeMeanReversion(CustomFactor): 124 | _min_win = 2 125 | 126 | def __init__(self, win, data, mean, mask=None): 127 | lag = data.shift(1) - mean 128 | diff = data - data.shift(1) 129 | lag.set_mask(mask) 130 | diff.set_mask(mask) 131 | super().__init__(win=win, inputs=[lag, diff, math.log(2)]) 132 | 133 | def compute(self, lag, diff, ln2): 134 | def calc_h(_x, _y): 135 | _lambda, _ = linear_regression_1d(_x, _y, dim=2) 136 | 
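# (editor's note) `_lambda` is the AR(1) slope from regressing the one-step change
# `diff` on the de-meaned lagged level `lag`; the mean-reversion half-life is -ln(2)/lambda: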
return -ln2 / _lambda 137 | return lag.agg(calc_h, diff) 138 | 139 | 140 | class RollingCorrelation(CustomFactor): 141 | _min_win = 2 142 | 143 | def compute(self, x, y): 144 | def _corr(_x, _y): 145 | return pearsonr(_x, _y, dim=2, ddof=1) 146 | return x.agg(_corr, y) 147 | 148 | 149 | class RollingCovariance(CustomFactor): 150 | _min_win = 2 151 | 152 | def compute(self, x, y): 153 | def _cov(_x, _y): 154 | return covariance(_x, _y, dim=2, ddof=1) 155 | return x.agg(_cov, y) 156 | 157 | 158 | class XSMaxCorrCoef(CrossSectionFactor): 159 | """ 160 | Returns the maximum correlation coefficient for each x to others 161 | """ 162 | 163 | def compute(self, *xs): 164 | x = torch.stack(xs, dim=1) 165 | x_bar = nanmean(x, dim=2).unsqueeze(-1) 166 | demean = x - x_bar 167 | demean.masked_fill_(torch.isnan(demean), 0) 168 | cov = demean @ demean.transpose(-2, -1) 169 | cov = cov / (x.shape[-1] - 1) 170 | diag = cov[:, range(len(xs)), range(len(xs)), None] 171 | std = diag ** 0.5 172 | corr = cov / std / std.transpose(-2, -1) 173 | # set auto corr to zero 174 | corr[:, range(len(xs)), range(len(xs))] = 0 175 | max_corr = corr.max(dim=2).values.unsqueeze(-2) 176 | return max_corr.expand(x.shape[0], x.shape[2], x.shape[1]) 177 | 178 | 179 | class InformationCoefficient(CrossSectionFactor): 180 | """ 181 | Cross-Section IC, the ic value of all assets is the same. 182 | """ 183 | def __init__(self, x, y, mask=None, weight=None): 184 | super().__init__(win=1, inputs=[x, y, weight], mask=mask) 185 | 186 | def compute(self, x, y, weight): 187 | if weight is None: 188 | ic = pearsonr(x, y, dim=1, ddof=1) 189 | else: 190 | xy = x * y 191 | mask = torch.isnan(x * y) 192 | w = weight / unmasked_sum(weight, mask=mask, dim=1).unsqueeze(-1) 193 | x_bar = unmasked_sum(w * x, mask=mask, dim=1) 194 | y_bar = unmasked_sum(w * y, mask=mask, dim=1) 195 | cov_xy = unmasked_sum(w * xy, mask=mask, dim=1) - x_bar * y_bar 196 | var_x = unmasked_sum(w * x ** 2, mask=mask, dim=1) - x_bar ** 2 197 | var_y = unmasked_sum(w * y ** 2, mask=mask, dim=1) - y_bar ** 2 198 | ic = cov_xy / (var_x * var_y) ** 0.5 199 | return ic.unsqueeze(-1).expand(ic.shape[0], y.shape[1]) 200 | 201 | def to_ir(self, win, ddof=1): 202 | # Use CrossSectionFactor and unfold by self, because if use CustomFactor, the ir value 203 | # will inconsistent when some assets have no data (like newly listed), the ir value should 204 | # not be related to assets. 205 | class RollingIC2IR(CrossSectionFactor): 206 | def __init__(self, win_, inputs): 207 | super().__init__(1, inputs) 208 | self.rolling_win = win_ 209 | 210 | def compute(self, ic): 211 | x = ic[:, 0] 212 | nan_stack = x.new_full((self.rolling_win - 1,), np.nan) 213 | new_x = torch.cat((nan_stack, x), dim=0) 214 | rolling_ic = new_x.unfold(0, self.rolling_win, 1) 215 | 216 | # Fundamental Law of Active Management: ir = ic * sqrt(b), 1/sqrt(b) = std(ic) 217 | ir = nanmean(rolling_ic, dim=1) / nanstd(rolling_ic, dim=1, ddof=ddof) 218 | return ir.unsqueeze(-1).expand(ic.shape) 219 | return RollingIC2IR(win_=win, inputs=[self]) 220 | 221 | 222 | class RollingInformationCoefficient(RollingCorrelation): 223 | """ 224 | Rolling IC, Calculate IC between 2 historical data for each asset. 
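Usage sketch (names assumed): RollingInformationCoefficient(win=20, inputs=[my_factor, forward_returns]) yields each asset's rolling Pearson IC; .to_ir(win) then converts it to a rolling information ratio (mean of the ICs over their standard deviation).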
225 | """ 226 | def to_ir(self, win): 227 | std = StandardDeviation(win=win, inputs=(self,)) 228 | std.ddof = 1 229 | mean = self.sum(win) / win 230 | 231 | return mean / std 232 | 233 | 234 | class RankWeightedInformationCoefficient(InformationCoefficient): 235 | def __init__(self, x, y, half_life, mask=None): 236 | alpha = np.exp((np.log(0.5) / half_life)) 237 | y_rank = y.rank(ascending=False, mask=mask) - 1 238 | weight = alpha ** y_rank 239 | super().__init__(x, y, mask=mask, weight=weight) 240 | 241 | 242 | class TTest1Samp(CustomFactor): 243 | _min_win = 2 244 | 245 | def compute(self, a, pop_mean): 246 | def _ttest(_x): 247 | d = nanmean(_x, dim=2) - pop_mean 248 | v = nanvar(_x, dim=2, ddof=1) 249 | denom = torch.sqrt(v / self._min_win) 250 | t = d / denom 251 | return t 252 | return a.agg(_ttest) 253 | 254 | 255 | class StudentCDF(CrossSectionFactor): 256 | """ 257 | Note!! For performance, This factor uses Cross-section mean of t-value. 258 | """ 259 | DefaultPrecision = 0.001 260 | 261 | def compute(self, t, dof, precision): 262 | reduced_t = nanmean(t, dim=1) 263 | p = torch.zeros_like(reduced_t) 264 | dof = torch.tensor(dof, dtype=torch.float64, device=t.device) 265 | for i, v in enumerate(reduced_t.cpu()): 266 | if np.isnan(v): 267 | p[i] = torch.nan 268 | elif np.isinf(v): 269 | p[i] = 1 270 | elif v < -9: 271 | p[i] = 0 272 | else: 273 | x = torch.arange(-9, v, precision, device=t.device) 274 | p[i] = torch.e ** torch.lgamma((dof + 1) / 2) / ( 275 | torch.sqrt(dof * torch.pi) * torch.e ** torch.lgamma(dof / 2)) * ( 276 | torch.trapezoid((1 + x ** 2 / dof) ** (-dof / 2 - 1 / 2), x) 277 | ) 278 | return p.unsqueeze(-1).expand(t.shape) 279 | 280 | 281 | class CrossSectionR2(CrossSectionFactor): 282 | def __init__(self, y, y_pred, mask, total_r2=False): 283 | super().__init__(win=1, inputs=[y, y_pred], mask=mask) 284 | self.total_r2 = total_r2 285 | 286 | def compute(self, y, y_pred): 287 | mask = torch.isnan(y_pred) | torch.isnan(y) 288 | ss_err = unmasked_sum((y - y_pred) ** 2, mask, dim=1) 289 | if self.total_r2: 290 | # 按市场总回报来算r2的话,用这个。不然就是相对回报。 291 | ss_tot = unmasked_sum(y ** 2, mask, dim=1) 292 | else: 293 | y_bar = unmasked_mean(y, mask, dim=1).unsqueeze(-1) 294 | ss_tot = unmasked_sum((y - y_bar) ** 2, mask, dim=1) 295 | r2 = -ss_err / ss_tot + 1 296 | r2[(~mask).to(Global.float_type).sum(dim=1) < 2] = np.nan 297 | return r2.unsqueeze(-1).expand(r2.shape[0], y.shape[1]) 298 | 299 | 300 | class FactorWiseKthValue(CrossSectionFactor): 301 | """ The kth value of all factors sorted in ascending order, grouped by each datetime """ 302 | def __init__(self, kth, inputs=None): 303 | super().__init__(1, inputs) 304 | self.kth = kth 305 | 306 | def compute(self, *data): 307 | mx = torch.stack([nanmean(x, dim=1) for x in data], dim=-1) 308 | nans = torch.isnan(mx) 309 | mx.masked_fill_(nans, -np.inf) 310 | ret = torch.kthvalue(mx, self.kth, dim=1, keepdim=True).values 311 | return ret.expand(ret.shape[0], data[0].shape[1]) 312 | 313 | 314 | class FactorWiseZScore(CrossSectionFactor): 315 | def compute(self, *data): 316 | mx = torch.stack([nanmean(x, dim=1) for x in data], dim=-1) 317 | ret = (mx - nanmean(mx, dim=1).unsqueeze(-1)) / nanstd(mx, dim=1).unsqueeze(-1) 318 | return ret.unsqueeze(-2).repeat(1, data[0].shape[1], 1) 319 | 320 | 321 | STDDEV = StandardDeviation 322 | MAX = RollingHigh 323 | MIN = RollingLow 324 | -------------------------------------------------------------------------------- /spectre/factors/technical.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | from typing import Optional, Sequence 8 | from .factor import BaseFactor, CustomFactor 9 | from .basic import MA, EMA 10 | from .statistical import STDDEV 11 | from .engine import OHLCV 12 | from ..parallel import nanmean 13 | import numpy as np 14 | import torch 15 | 16 | 17 | class BollingerBands(CustomFactor): 18 | """ usage: BBANDS(win, inputs=[OHLCV.close, k]), k is constant normally 2 """ 19 | inputs = (OHLCV.close, 2) 20 | win = 20 21 | _min_win = 2 22 | 23 | def __init__(self, win: Optional[int] = None, inputs: Optional[Sequence[BaseFactor]] = None): 24 | super().__init__(win, inputs) 25 | if len(self.inputs) < 2: 26 | raise ValueError("BollingerBands's inputs needs 2 inputs, " 27 | "inputs=[OHLCV.close, k]), k is constant normally 2.") 28 | comm_inputs = (self.inputs[0],) 29 | k = self.inputs[1] 30 | self.inputs = (self.inputs[0], 31 | MA(win=self.win, inputs=comm_inputs), 32 | STDDEV(win=self.win, inputs=comm_inputs), 33 | k) 34 | self.win = 1 35 | 36 | def compute(self, closes, ma, std, k): 37 | d = k * std 38 | up = ma + d 39 | down = ma - d 40 | return torch.cat([up.unsqueeze(-1), ma.unsqueeze(-1), down.unsqueeze(-1)], dim=-1) 41 | 42 | def normalized(self): 43 | return NormalizedBollingerBands(self.win, self.inputs) 44 | 45 | 46 | class NormalizedBollingerBands(CustomFactor): 47 | def compute(self, closes, ma, std, k): 48 | return (closes - ma) / (k * std) 49 | 50 | 51 | class MovingAverageConvergenceDivergenceSignal(EMA): 52 | """ 53 | engine.add( MACD(fast, slow, sign, inputs=[OHLCV.close]) ) 54 | or 55 | engine.add( MACD().normalized() ) 56 | """ 57 | inputs = (OHLCV.close,) 58 | win = 9 59 | _min_win = 2 60 | 61 | def __init__(self, fast=12, slow=26, sign=9, inputs: Optional[Sequence[BaseFactor]] = None, 62 | adjust=False): 63 | super().__init__(sign, inputs, adjust) 64 | self.inputs = (EMA(inputs=self.inputs, span=fast) - EMA(inputs=self.inputs, span=slow),) 65 | 66 | def normalized(self): 67 | # In order not to double the calculation, reuse `inputs` factor here 68 | macd = self.inputs[0] 69 | sign = self 70 | return macd - sign 71 | 72 | 73 | class TrueRange(CustomFactor): 74 | """ATR = MA(14, inputs=(TrueRange(),))""" 75 | inputs = (OHLCV.high, OHLCV.low, OHLCV.close) 76 | win = 2 77 | _min_win = 2 78 | 79 | def compute(self, highs, lows, closes): 80 | high_to_low = highs.last() - lows.last() 81 | high_to_prev_close = (highs.last() - closes.first()).abs() 82 | low_to_prev_close = (lows.last() - closes.first()).abs() 83 | max1 = high_to_low.where(high_to_low > high_to_prev_close, high_to_prev_close) 84 | return max1.where(max1 > low_to_prev_close, low_to_prev_close) 85 | 86 | 87 | class RSI(CustomFactor): 88 | """ usage: RSI(win, inputs=[OHLCV.close]) """ 89 | inputs = (OHLCV.close,) 90 | win = 14 91 | _min_win = 2 92 | 93 | def __init__(self, win: Optional[int] = None, inputs: Optional[Sequence[BaseFactor]] = None): 94 | super().__init__(win, inputs) 95 | self.win = self.win + 1 # +1 for 1 day diff 96 | 97 | def compute(self, closes): 98 | def _rsi(_closes): 99 | shift = _closes.roll(1, dims=2) 100 | shift = shift.contiguous() 101 | shift[:, :, 0] = np.nan 102 | diff = _closes - shift 103 | del shift, _closes 104 | up = diff.clamp(min=0) 105 | down = diff.clamp(max=0) 106 | del diff 107 | # Cutler's RSI, more stable, 
independent to data length 108 | up = nanmean(up[:, :, 1:], dim=2) 109 | down = nanmean(down[:, :, 1:], dim=2).abs() 110 | return 100 - (100 / (1 + up / down)) 111 | # Wilder's RSI 112 | # up = up.ewm(com=14-1, adjust=False).mean() 113 | # down = down.ewm(com=14-1, adjust=False).mean().abs() 114 | return closes.agg(_rsi) 115 | 116 | def normalized(self): 117 | return self / 50 - 1 118 | 119 | 120 | class FastStochasticOscillator(CustomFactor): 121 | """ usage: STOCHF(win, inputs=[OHLCV.high, OHLCV.low, OHLCV.close]) """ 122 | inputs = (OHLCV.high, OHLCV.low, OHLCV.close) 123 | win = 14 124 | _min_win = 2 125 | 126 | def compute(self, highs, lows, closes): 127 | highest_highs = highs.nanmax() 128 | lowest_lows = lows.nanmin() 129 | k = (closes.last() - lowest_lows) / (highest_highs - lowest_lows) 130 | 131 | return k * 100 132 | 133 | def normalized(self): 134 | return self / 100 - 0.5 135 | 136 | 137 | BBANDS = BollingerBands 138 | MACD = MovingAverageConvergenceDivergenceSignal 139 | TRANGE = TrueRange 140 | STOCHF = FastStochasticOscillator 141 | -------------------------------------------------------------------------------- /spectre/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .algorithmic import ( 2 | ParallelGroupBy, DummyParallelGroupBy, 3 | Rolling, 4 | nansum, unmasked_sum, 5 | nanmean, unmasked_mean, 6 | nanvar, 7 | nanstd, 8 | masked_last, 9 | masked_first, 10 | nanlast, 11 | nanmax, 12 | nanmin, 13 | pad_2d, 14 | rankdata, 15 | covariance, 16 | pearsonr, 17 | spearman, 18 | linear_regression_1d, 19 | quantile, 20 | masked_kth_value_1d, 21 | clamp_1d_, 22 | ) 23 | 24 | from .constants import DeviceConstant 25 | -------------------------------------------------------------------------------- /spectre/parallel/constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TensorConstant: 5 | def __init__(self, device): 6 | self.device = device 7 | self.linspace_cache = {} 8 | self.r_linspace_cache = {} 9 | self.arange_cache = {} 10 | 11 | def linspace(self, size, dtype): 12 | if dtype in self.linspace_cache: 13 | w = self.linspace_cache[dtype] 14 | if size <= len(w): 15 | return w[:size] 16 | 17 | self.linspace_cache[dtype] = new = torch.linspace( 18 | 0.0, 0.9, size, dtype=dtype, device=self.device) 19 | return new 20 | 21 | def r_linspace(self, size, dtype): 22 | if dtype in self.r_linspace_cache: 23 | w = self.r_linspace_cache[dtype] 24 | if size <= len(w): 25 | return w[:size] 26 | 27 | self.r_linspace_cache[dtype] = new = torch.linspace( 28 | 0.9, 0.0, size, dtype=dtype, device=self.device) 29 | return new 30 | 31 | def arange(self, size, dtype): 32 | if dtype in self.arange_cache: 33 | w = self.arange_cache[dtype] 34 | if size <= len(w): 35 | return w[:size] 36 | 37 | self.arange_cache[dtype] = new = torch.arange(size, dtype=dtype, device=self.device) 38 | return new 39 | 40 | 41 | class DeviceConstant: 42 | constants = {} 43 | 44 | @classmethod 45 | def clean(cls): 46 | cls.constants = {} 47 | 48 | @classmethod 49 | def get(cls, device): 50 | if device in cls.constants: 51 | return cls.constants[device] 52 | 53 | cls.constants[device] = new = TensorConstant(device) 54 | return new 55 | 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /spectre/plotting/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .returns_chart import ( 3 | 
plot_quantile_and_cumulative_returns, 4 | cumulative_returns_fig, 5 | ) 6 | 7 | from .factor_diagram import ( 8 | plot_factor_diagram, 9 | ) 10 | 11 | from .chart import ( 12 | plot_chart, 13 | ) 14 | -------------------------------------------------------------------------------- /spectre/plotting/chart.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | import warnings 8 | 9 | 10 | def plot_chart(df_prices, ohlcv, df_factor, trace_types=None, styles=None, inline=True): 11 | import plotly.graph_objects as go 12 | 13 | # group df by asset 14 | asset_group = df_factor.index.get_level_values(1).remove_unused_categories() 15 | dfs = list(df_factor.groupby(asset_group)) 16 | if len(dfs) > 5: 17 | warnings.warn("Warning!! Too many assets {}, only plotting top 5.".format(len(dfs)), 18 | RuntimeWarning) 19 | dfs = dfs[:5] 20 | dfs = [(asset, factors) for (asset, factors) in dfs if factors.shape[0] > 0] 21 | trace_types = trace_types and trace_types or {} 22 | 23 | # init styles 24 | styles = styles and styles or {} 25 | styles['price'] = styles.get('price', {}) 26 | styles['volume'] = styles.get('volume', {}) 27 | 28 | # default styles 29 | styles['height'] = styles.get('height', 500) 30 | styles['price']['line'] = styles['price'].get('line', dict(width=1)) 31 | styles['price']['name'] = styles['price'].get('name', 'price') 32 | styles['volume']['opacity'] = styles['volume'].get('opacity', 0.2) 33 | styles['volume']['yaxis'] = styles['volume'].get('yaxis', 'y2') 34 | styles['volume']['name'] = styles['volume'].get('name', 'volume') 35 | 36 | # get y_axes 37 | y_axes = set() 38 | for k, v in styles.items(): 39 | if not isinstance(v, dict): 40 | continue 41 | if 'yaxis' in v: 42 | y_axes.add('yaxis' + v['yaxis'][1:]) 43 | if 'yref' in v: 44 | y_axes.add('yaxis' + v['yref'][1:]) 45 | 46 | figs = {} 47 | # plotting 48 | for i, (asset, factors) in enumerate(dfs): 49 | fig = go.Figure() 50 | figs[asset] = fig 51 | 52 | factors = factors.droplevel(level=1) 53 | start, end = factors.index[0], factors.index[-1] 54 | 55 | prices = df_prices.loc[(slice(start, end), asset), :].droplevel(level=1) 56 | index = prices.index.strftime("%y-%m-%d %H%M%S") 57 | if ohlcv[0] is not None: 58 | # add candlestick 59 | fig.add_trace( 60 | go.Candlestick(x=index, open=prices[ohlcv[0]], high=prices[ohlcv[1]], 61 | low=prices[ohlcv[2]], close=prices[ohlcv[3]], **styles['price'])) 62 | fig.add_trace( 63 | go.Bar(x=index, y=prices[ohlcv[4]], **styles['volume'])) 64 | else: 65 | # add line plot 66 | if ohlcv[3] is not None: 67 | fig.add_trace( 68 | go.Scatter(x=index, y=prices[ohlcv[3]], **styles['price'])) 69 | if ohlcv[4] is not None: 70 | fig.add_trace( 71 | go.Scatter(x=index, y=prices[ohlcv[4]], **styles['volume'])) 72 | 73 | # add factors 74 | for col in factors.columns: 75 | trace_type = trace_types.get(col, 'Scatter') 76 | if trace_type is None: 77 | continue 78 | style = styles.get(col, {}) 79 | style['name'] = style.get('name', col) 80 | fig.add_trace(getattr(go, trace_type)(x=index, y=factors[col], **style)) 81 | 82 | new_axis = dict(anchor="free", overlaying="y", side="right", position=1) 83 | alpha_ordered_axises = list(y_axes) 84 | alpha_ordered_axises.sort() 85 | for y_axis in alpha_ordered_axises: 86 | fig.update_layout(**{y_axis: new_axis}) 87 | new_axis['position'] -= 0.03 88 | x_right = 
new_axis['position'] + 0.03 89 | 90 | fig.update_layout(xaxis=dict(domain=[0, x_right])) 91 | fig.update_xaxes(rangeslider=dict(visible=False)) 92 | fig.update_yaxes(showgrid=False, scaleanchor="x", scaleratio=1) 93 | fig.update_layout(legend=dict(xanchor='right', x=x_right, y=1, bgcolor='rgba(0,0,0,0)')) 94 | fig.update_layout(height=styles['height'], barmode='group', bargap=0.5, margin={'t': 50}, 95 | title=asset) 96 | 97 | if inline: 98 | fig.show() 99 | 100 | return figs 101 | -------------------------------------------------------------------------------- /spectre/plotting/factor_diagram.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | from itertools import cycle, islice 8 | 9 | 10 | def plot_factor_diagram(factor): 11 | import plotly.graph_objects as go 12 | from ..factors import BaseFactor, CustomFactor 13 | from ..factors import ColumnDataFactor, DatetimeDataFactor 14 | 15 | color = [ 16 | "rgba(31, 119, 180, 0.8)", "rgba(255, 127, 14, 0.8)", "rgba(44, 160, 44, 0.8)", 17 | "rgba(214, 39, 40, 0.8)", "rgba(148, 103, 189, 0.8)", "rgba(140, 86, 75, 0.8)", 18 | "rgba(227, 119, 194, 0.8)", "rgba(127, 127, 127, 0.8)", "rgba(188, 189, 34, 0.8)", 19 | "rgba(23, 190, 207, 0.8)", "rgba(31, 119, 180, 0.8)", "rgba(255, 127, 14, 0.8)", 20 | "rgba(44, 160, 44, 0.8)", "rgba(214, 39, 40, 0.8)", "rgba(148, 103, 189, 0.8)", 21 | "rgba(140, 86, 75, 0.8)", "rgba(227, 119, 194, 0.8)", "rgba(127, 127, 127, 0.8)", 22 | "rgba(188, 189, 34, 0.8)", "rgba(23, 190, 207, 0.8)", "rgba(31, 119, 180, 0.8)", 23 | "rgba(255, 127, 14, 0.8)", "rgba(44, 160, 44, 0.8)", "rgba(214, 39, 40, 0.8)", 24 | "rgba(148, 103, 189, 0.8)", "rgba(140, 86, 75, 0.8)", "rgba(227, 119, 194, 0.8)", 25 | "rgba(127, 127, 127, 0.8)", "rgba(188, 189, 34, 0.8)", "rgba(23, 190, 207, 0.8)", 26 | "rgba(31, 119, 180, 0.8)", "rgba(255, 127, 14, 0.8)", "rgba(44, 160, 44, 0.8)", 27 | "rgba(214, 39, 40, 0.8)", "rgba(148, 103, 189, 0.8)", "magenta", 28 | "rgba(227, 119, 194, 0.8)", "rgba(127, 127, 127, 0.8)", "rgba(188, 189, 34, 0.8)", 29 | "rgba(23, 190, 207, 0.8)", "rgba(31, 119, 180, 0.8)", "rgba(255, 127, 14, 0.8)", 30 | "rgba(44, 160, 44, 0.8)", "rgba(214, 39, 40, 0.8)", "rgba(148, 103, 189, 0.8)", 31 | "rgba(140, 86, 75, 0.8)", "rgba(227, 119, 194, 0.8)", "rgba(127, 127, 127, 0.8)" 32 | ] 33 | 34 | factor_id = dict() 35 | label = [] 36 | source = [] 37 | target = [] 38 | value = [] 39 | line_label = [] 40 | 41 | def add_node(this, parent_label_id, parent_label, parent_win): 42 | class_id = id(this) 43 | 44 | if class_id in factor_id: 45 | this_label_id = factor_id[class_id] 46 | else: 47 | this_label_id = len(label) 48 | if isinstance(this, ColumnDataFactor): 49 | label.append(this.inputs[0]) 50 | if isinstance(this, DatetimeDataFactor): 51 | label.append(this.attr) 52 | else: 53 | label.append(type(this).__name__) 54 | 55 | if parent_label_id is not None: 56 | source.append(parent_label_id) 57 | target.append(this_label_id) 58 | value.append(parent_win) 59 | line_label.append(parent_label) 60 | 61 | if class_id in factor_id: 62 | return 63 | 64 | if isinstance(this, CustomFactor): 65 | this_win = this.win 66 | else: 67 | this_win = 1 68 | 69 | factor_id[class_id] = this_label_id 70 | if isinstance(this, CustomFactor): 71 | if this.inputs: 72 | for upstream in this.inputs: 73 | if isinstance(upstream, BaseFactor): 74 | 
add_node(upstream, this_label_id, 'inputs', this_win) 75 | 76 | if this._mask is not None: 77 | add_node(this._mask, this_label_id, 'mask', this_win) 78 | 79 | add_node(factor, None, None, None) 80 | 81 | fig = go.Figure(data=[go.Sankey( 82 | valueformat=".0f", 83 | valuesuffix="win", 84 | node=dict( 85 | pad=15, 86 | thickness=15, 87 | line=dict(color="black", width=0.5), 88 | label=label, 89 | color=list(islice(cycle(color), len(label))) 90 | ), 91 | # Add links 92 | link=dict( 93 | source=source, 94 | target=target, 95 | value=value, 96 | label=line_label 97 | ))]) 98 | 99 | fig.update_layout(title_text="Factor Diagram") 100 | fig.show() 101 | -------------------------------------------------------------------------------- /spectre/plotting/returns_chart.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | import math 8 | from itertools import cycle 9 | import sys 10 | 11 | 12 | DEFAULT_COLORS = [ 13 | 'rgb(99, 110, 250)', 'rgb(239, 85, 59)', 'rgb(0, 204, 150)', 14 | 'rgb(171, 99, 250)', 'rgb(255, 161, 90)', 'rgb(25, 211, 243)' 15 | ] 16 | 17 | 18 | def plot_quantile_and_cumulative_returns(factor_data, mean_ret): 19 | """ 20 | Plotly Installation: 21 | https://github.com/plotly/plotly.py#jupyterlab-support-python-35 22 | """ 23 | quantiles = mean_ret.index 24 | 25 | if 'plotly.graph_objects' not in sys.modules: 26 | print('Importing plotly, it may take a while...') 27 | import plotly.graph_objects as go 28 | import plotly.subplots as subplots 29 | 30 | x = quantiles 31 | factors = mean_ret.columns.levels[0] 32 | periods = list(mean_ret.columns.levels[1]) 33 | periods.sort(key=lambda cn: int(cn[:-1])) 34 | rows = math.ceil(len(factors)) 35 | 36 | colors = dict(zip(periods + ['to'], cycle(DEFAULT_COLORS))) 37 | quantile_styles = { 38 | period: {'name': period, 'legendgroup': period, 39 | 'hovertemplate': 'Quantile:%{x}
<br>
' 40 | 'Return: %{y:.2f}bps ±%{error_y.array:.2f}bps', 41 | 'marker': {'color': colors[period]}} 42 | for period in periods 43 | } 44 | cumulative_styles = { 45 | period: {'name': period, 'mode': 'lines', 'legendgroup': period, 'showlegend': False, 46 | 'hovertemplate': 'Date:%{x}
<br>
' 47 | 'Return: %{y:.3f}%', 48 | 'marker': {'color': colors[period]}} 49 | for period in periods 50 | } 51 | turnover_styles = {'opacity': 0.2, 'name': 'turnover', 'legendgroup': 'turnover', 52 | 'marker': {'color': colors['to']}} 53 | 54 | specs = [[{}, {"secondary_y": True}]] * rows 55 | fig = subplots.make_subplots( 56 | rows=rows, cols=2, 57 | vertical_spacing=0.06 / rows, 58 | horizontal_spacing=0.06, 59 | specs=specs, 60 | subplot_titles=['Quantile Return', 'Portfolio cumulative returns'], 61 | ) 62 | 63 | mean_ret = mean_ret * 10000 64 | for i, factor in enumerate(factors): 65 | row = i + 1 66 | weight_col = (factor, 'factor_weight') 67 | weighted = factor_data['Returns'].multiply(factor_data[weight_col], axis=0) 68 | factor_return = weighted.groupby(level='date').sum() 69 | for j, period in enumerate(periods): 70 | y = mean_ret.loc[:, (factor, period, 'mean')] 71 | err_y = mean_ret.loc[:, (factor, period, 'sem')] 72 | fig.add_trace(go.Bar( 73 | x=x, y=y, error_y=dict(type='data', array=err_y, thickness=0.2), 74 | yaxis='y1', **quantile_styles[period] 75 | ), row=row, col=1) 76 | quantile_styles[period]['showlegend'] = False 77 | 78 | open_period = period 79 | if period.endswith('D'): 80 | open_period = 'b' + open_period 81 | try: 82 | cum_ret = factor_return[period].resample(open_period).mean().dropna() 83 | except ValueError as e: 84 | print("pandas re-sampling failed, try set " 85 | "`engine.timezone = 'your local timezone'`") 86 | cum_ret = factor_return[period].resample(period).mean().dropna() 87 | cum_ret = (cum_ret + 1).cumprod() * 100 - 100 88 | fig.add_trace(go.Scatter( 89 | x=cum_ret.index, y=cum_ret.values, yaxis='y2', **cumulative_styles[period] 90 | ), row=row, col=2) 91 | 92 | fig.update_xaxes(type="category", row=row, col=1) 93 | 94 | fig.add_shape(go.layout.Shape( 95 | type="line", line=dict(width=1), 96 | y0=0, y1=0, x0=factor_return.index[0], x1=factor_return.index[-1], 97 | ), row=row, col=2) 98 | 99 | weight_diff = factor_data[weight_col].unstack(level=[1]).diff() 100 | to = weight_diff.abs().sum(axis=1) * 100 101 | resample = int(len(to) / 64) 102 | if resample > 0: 103 | to = to.fillna(0).rolling(resample).mean()[::resample] 104 | fig.add_trace(go.Bar(x=to.index, y=to.values, **turnover_styles), 105 | secondary_y=True, row=row, col=2) 106 | turnover_styles['showlegend'] = False 107 | 108 | fig.update_yaxes(title_text=factor, row=row, col=1, matches='y1') 109 | fig.update_yaxes(row=row, col=2, ticksuffix='%') 110 | fig.update_yaxes(row=row, col=2, secondary_y=False, matches='y2') 111 | 112 | fig.update_layout(height=300 * rows, barmode='group', bargap=0.5, margin={'t': 50}) 113 | return fig 114 | 115 | 116 | def cumulative_returns_fig(returns, positions, transactions, benchmark, annual_risk_free, start=0): 117 | from ..trading import turnover, sharpe_ratio, drawdown, annual_volatility 118 | 119 | import plotly.graph_objects as go 120 | import plotly.subplots as subplots 121 | 122 | fig = subplots.make_subplots(specs=[[{"secondary_y": True}]]) 123 | 124 | cum_ret = returns.iloc[start:] 125 | cum_ret = (cum_ret + 1).cumprod() 126 | fig.add_trace(go.Scatter(x=cum_ret.index, y=cum_ret.values * 100 - 100, name='portfolio', 127 | hovertemplate='Date:%{x}
<br>
Return: %{y:.3f}%')) 128 | fig.add_shape(go.layout.Shape(y0=0, y1=0, x0=cum_ret.index[0], x1=cum_ret.index[-1], 129 | type="line", line=dict(width=1))) 130 | 131 | if benchmark is not None: 132 | cum_bench = benchmark.iloc[start:] 133 | cum_bench = (cum_bench + 1).cumprod() 134 | fig.add_trace(go.Scatter(x=cum_bench.index, y=cum_bench.values * 100 - 100, 135 | name='benchmark', line=dict(width=0.5))) 136 | 137 | fig.add_shape(go.layout.Shape( 138 | type="rect", xref="x", yref="paper", opacity=0.5, line_width=0, 139 | fillcolor="LightGoldenrodYellow", layer="below", 140 | y0=0, y1=1, x0=cum_ret.idxmax(), x1=cum_ret[cum_ret.idxmax():].idxmin(), 141 | )) 142 | 143 | to = turnover(positions, transactions).iloc[start:] * 100 144 | resample = int(len(to) / 126) 145 | if resample > 0: 146 | to = to.fillna(0).rolling(resample).mean()[::resample] 147 | fig.add_trace(go.Bar(x=to.index, y=to.values, opacity=0.2, name='turnover'), 148 | secondary_y=True) 149 | 150 | sr = sharpe_ratio(returns, annual_risk_free) 151 | dd, ddd = drawdown(cum_ret) 152 | mdd = abs(dd.min()) 153 | mdd_dur = ddd.max() 154 | vol = annual_volatility(returns) * 100 155 | 156 | if benchmark is not None: 157 | bench_sr = sharpe_ratio(benchmark, annual_risk_free) 158 | bench_vol = annual_volatility(benchmark) * 100 159 | else: 160 | bench_sr = 0 161 | bench_vol = 0 162 | 163 | ann = go.layout.Annotation( 164 | x=0.01, y=0.98, xref="paper", yref="paper", 165 | showarrow=False, borderwidth=1, bordercolor='black', align='left', 166 | text="Overall (portfolio/benchmark)
<br>
" 167 | "SharpeRatio: {:.3f}/{:.3f}
" 168 | "MaxDrawDown: {:.2f}%, {} Days
" 169 | "AnnualVolatility: {:.2f}%/{:.2f}%" 170 | .format(sr, bench_sr, mdd * 100, mdd_dur, vol, bench_vol), 171 | ) 172 | 173 | fig.update_layout(height=400, annotations=[ann], margin={'t': 50}) 174 | fig.update_xaxes(tickformat='%Y-%m-%d') 175 | fig.update_yaxes(title_text='cumulative return', ticksuffix='%', secondary_y=False) 176 | fig.update_yaxes(title_text='turnover', ticksuffix='%', secondary_y=True) 177 | return fig 178 | -------------------------------------------------------------------------------- /spectre/trading/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trading algorithm architecture 3 | ------------------------------ 4 | 5 | +----------+ 6 | | +<---Network 7 | |Broker API| 8 | Live | +<--+ 9 | +----+-----+ | 10 | | | 11 | +------+-------+ | 12 | +----fire-EveryBarData-event------+LiveDataLoader| | 13 | | +------+-------+ | 14 | | | | 15 | | +------------------+ +---------+ +-----+------+ | 16 | | |MarketEventManager| |Algorithm| |FactorEngine| | 17 | | +---------+--------+ +----+----+ +-----+------+ | 18 | | | | | | 19 | | +----+-----+ +-----+-----+ +-+-+ | 20 | +----->+fire_event+-->+_run_engine+--->+run+---+ | 21 | +----+-----+ +-----+----++ +---+ | | 22 | | | ^ | | 23 | | | +--save-data--+ | 24 | | | | 25 | time +----+-----+ +----+----+ --------+ | 26 | trigger->+fire close+--->+rebalance+--->+Blotter+----+ 27 | +----------+ +---------+ +-------+ 28 | 29 | 30 | +-------------+ 31 | Back-test |CsvDataLoader| 32 | +-----+-------+ 33 | | 34 | +----------------------+ +---------+ +-----+------+ 35 | |SimulationEventManager| |Algorithm| |FactorEngine| 36 | +------------+---------+ +----+----+ +-----+------+ 37 | | | | 38 | +------+------+ +----+-----+ +-+-+ 39 | |loop data row+--->+run_engine+--->+run+---+ 40 | +------+---+--+ ++---+----++ +---+ | 41 | | ^ | | ^ | 42 | | +-return-+ +---return----+ 43 | | | 44 | | +----+----+ +-----------------+ 45 | +---------->+rebalance+-->+SimulationBlotter| 46 | +---------+ +-----------------+ 47 | 48 | 49 | Pseudo-code for back-test and live 50 | ---------------------------------- 51 | 52 | class MyAlg(trading.CustomAlgorithm): 53 | def initialize(self): 54 | engine = self.get_factor_engine() 55 | factor = .... 56 | factor_engine.add(factor, 'your_factor') 57 | 58 | # 10000 ns before market close 59 | self.schedule_rebalance(trading.event.MarketClose(self.rebalance, -10000)) 60 | 61 | self.blotter.set_commission() # only works on back-test 62 | self.blotter.set_slippage() # only works on back-test 63 | 64 | def rebalance(self, data, history): 65 | weight = data.your_factor / data.your_factor.sum() 66 | self.order_to_percent(data.index, weight) 67 | 68 | record(...) 69 | 70 | def terminate(self): 71 | plot() 72 | 73 | # Back-test 74 | ----------------- 75 | loader = spectre.data.CsvDirLoader(...) 
76 | blotter = spectre.trading.SimulationBlotter(loader) 77 | evt_mgr = spectre.trading.SimulationEventManager() 78 | alg = MyAlg(blotter, man=loader) 79 | evt_mgr.subscribe(alg) 80 | evt_mgr.subscribe(blotter) 81 | evt_mgr.run('2018-01-01', '2019-01-01') 82 | 83 | ## Or the helper function: 84 | spectre.trading.run_backtest(loader, MyAlg, '2018-01-01', '2019-01-01') 85 | 86 | # Live 87 | ---------------- 88 | class YourBrokerAPI: 89 | class LiveDataLoader(EventReceiver, DataLoader): 90 | def on_run(): 91 | self.schedule(event.Always(read_data)) 92 | def read_data(self): 93 | api.asio.read() 94 | agg_data(_cache) 95 | _cache.resample(self.rule) 96 | if new_bar: 97 | self.fire_event(event.EveryBarData) 98 | def load(...): 99 | return self._cache[xx:xx] 100 | ... 101 | 102 | broker_api = YourBrokerAPI() 103 | loader = broker_api.LiveDataLoader(rule='5mins') 104 | blotter = broker_api.LiveBlotter() 105 | 106 | evt_mgr = spectre.trading.MarketEventManager(calendar_2020) 107 | evt_mgr.subscribe(loader) 108 | 109 | alg = MyAlg(blotter, main=loader) 110 | evt_mgr.subscribe(alg) 111 | evt_mgr.run() 112 | 113 | """ 114 | from .event import ( 115 | Event, 116 | EveryBarData, 117 | Always, 118 | MarketOpen, 119 | MarketClose, 120 | EventReceiver, 121 | EventManager, 122 | MarketEventManager, 123 | ) 124 | from .calendar import ( 125 | Calendar, 126 | CNCalendar, 127 | JPCalendar, 128 | ) 129 | from .algorithm import ( 130 | CustomAlgorithm, 131 | SimulationEventManager 132 | ) 133 | from .stopmodel import ( 134 | StopModel, 135 | TrailingStopModel, 136 | PnLDecayTrailingStopModel, 137 | TimeDecayTrailingStopModel 138 | ) 139 | from .position import ( 140 | Position, 141 | ) 142 | from .portfolio import ( 143 | Portfolio, 144 | ) 145 | from .blotter import ( 146 | BaseBlotter, 147 | SimulationBlotter, 148 | ManualBlotter, 149 | CommissionModel, 150 | DailyCurbModel, 151 | ) 152 | from .metric import ( 153 | drawdown, 154 | sharpe_ratio, 155 | turnover, 156 | annual_volatility, 157 | ) 158 | 159 | 160 | def run_backtest(loader: 'DataLoader', alg_type: 'Type[CustomAlgorithm]', start, end, 161 | delay_factor=True, ohlcv=None): 162 | # force python to free memory, else may be encountering cuda out of memory 163 | import gc 164 | import pandas as pd 165 | gc.collect() 166 | 167 | _blotter = SimulationBlotter(loader, start=pd.to_datetime(start, utc=True), ohlcv=ohlcv) 168 | evt_mgr = SimulationEventManager() 169 | alg = alg_type(_blotter, main=loader) 170 | evt_mgr.subscribe(_blotter) 171 | evt_mgr.subscribe(alg) 172 | evt_mgr.run(start, end, delay_factor) 173 | 174 | return alg.results 175 | 176 | 177 | def get_algorithm_data(loader: 'DataLoader', alg_type: 'Type[CustomAlgorithm]', 178 | start, end, delay_factor=True): 179 | import pandas as pd 180 | start, end = pd.to_datetime(start, utc=True), pd.to_datetime(end, utc=True) 181 | 182 | _blotter = SimulationBlotter(loader, start=start) 183 | evt_mgr = SimulationEventManager() 184 | alg = alg_type(_blotter, main=loader) 185 | evt_mgr.subscribe(_blotter) 186 | evt_mgr.subscribe(alg) 187 | alg.on_run() 188 | 189 | return alg.run_engine(start, end, delay_factor) 190 | -------------------------------------------------------------------------------- /spectre/trading/calendar.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019, Heerozh. All rights reserved. 
4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | from collections import defaultdict 8 | from typing import Dict 9 | import pandas as pd 10 | 11 | 12 | class Calendar: 13 | """ 14 | Usage: 15 | call build() first, get business day calendar. 16 | and manually add holiday by calling set_as_holiday(). 17 | if open half-day, use remove_events() remove all events that day, and add_event() manually. 18 | US holiday calendar can found at https://iextrading.com/trading/ 19 | """ 20 | 21 | def __init__(self) -> None: 22 | self.events = defaultdict(list) 23 | self.timezone = None 24 | 25 | def build(self, start: str, end: str, daily_events: Dict[str, str], tz='UTC', freq='B', 26 | pop_passed=True): 27 | """ build("2020", {'Open': '9:00:00', 'Close': '15:00:00'}) """ 28 | self.timezone = tz 29 | days = pd.date_range(pd.Timestamp(start, tz=tz).normalize(), 30 | pd.Timestamp(end, tz=tz).normalize(), 31 | tz=tz, freq=freq) 32 | if len(days) == 0: 33 | raise ValueError("Empty date range between now({}) to end({})".format( 34 | pd.Timestamp.now(tz=tz).normalize(), end)) 35 | 36 | self.events = {name: [day + pd.Timedelta(time) for day in days] 37 | for name, time in daily_events.items()} 38 | 39 | if pop_passed: 40 | for k, _ in self.events.items(): 41 | self.pop_passed(k) 42 | 43 | def add_event(self, event: str, datetime: pd.Timestamp): 44 | self.events[event].append(datetime) 45 | self.events[event].sort() 46 | 47 | def remove_events(self, date: pd.Timestamp): 48 | self.events = { 49 | event: [dt for dt in dts if dt.normalize() != date] 50 | for event, dts in self.events.items() 51 | } 52 | 53 | def set_as_holiday(self, date: pd.Timestamp): 54 | return self.remove_events(date) 55 | 56 | def hr_now(self): 57 | """ Return now time """ 58 | # todo high res 59 | return pd.Timestamp.now(self.timezone) 60 | 61 | def pop_passed(self, event_name): 62 | """ Remove passed events """ 63 | now = self.hr_now() 64 | # every event is daily, so will not be overkilled 65 | dts = self.events[event_name] 66 | while True: 67 | if dts[0] <= now: 68 | del dts[0] 69 | else: 70 | break 71 | return self 72 | 73 | def today_next(self): 74 | """ Return today next events """ 75 | now = self.hr_now() 76 | return { 77 | event: dts[0] 78 | for event, dts in self.events.items() 79 | if dts[0].normalize() == now.normalize() 80 | } 81 | 82 | 83 | class CNCalendar(Calendar): 84 | """ 85 | CN holiday calendar: http://www.sse.com.cn/disclosure/dealinstruc/closed/ 86 | """ 87 | # yearly manually update 88 | closed = [ 89 | *pd.date_range('2020-06-25', '2020-06-28', freq='D'), 90 | *pd.date_range('2020-10-01', '2020-10-08', freq='D'), 91 | 92 | *pd.date_range('2021-01-01', '2021-01-03', freq='D'), 93 | *pd.date_range('2021-02-11', '2021-02-17', freq='D'), 94 | *pd.date_range('2021-04-03', '2021-04-05', freq='D'), 95 | *pd.date_range('2021-05-01', '2021-05-05', freq='D'), 96 | *pd.date_range('2021-06-12', '2021-06-14', freq='D'), 97 | *pd.date_range('2021-09-19', '2021-09-21', freq='D'), 98 | *pd.date_range('2021-10-01', '2021-10-07', freq='D'), 99 | 100 | *pd.date_range('2022-01-01', '2022-01-03', freq='D'), 101 | *pd.date_range('2022-01-31', '2022-02-06', freq='D'), 102 | *pd.date_range('2022-04-03', '2022-04-05', freq='D'), 103 | *pd.date_range('2022-04-30', '2022-05-04', freq='D'), 104 | *pd.date_range('2022-06-03', '2022-06-05', freq='D'), 105 | *pd.date_range('2022-09-10', '2022-09-12', freq='D'), 106 | *pd.date_range('2022-10-01', '2022-10-07', freq='D'), 107 | 108 | *pd.date_range('2023-01-01', '2023-01-02', 
freq='D'), 109 | *pd.date_range('2023-01-21', '2023-01-27', freq='D'), 110 | *pd.date_range('2023-04-05', '2023-04-05', freq='D'), 111 | *pd.date_range('2023-04-29', '2023-05-03', freq='D'), 112 | *pd.date_range('2023-06-22', '2023-06-24', freq='D'), 113 | *pd.date_range('2023-09-29', '2023-10-06', freq='D'), 114 | 115 | *pd.date_range('2024-01-01', '2024-01-01', freq='D'), 116 | *pd.date_range('2024-02-09', '2024-02-17', freq='D'), 117 | *pd.date_range('2024-04-04', '2024-04-06', freq='D'), 118 | *pd.date_range('2024-05-01', '2024-05-05', freq='D'), 119 | *pd.date_range('2024-06-10', '2024-06-10', freq='D'), 120 | *pd.date_range('2024-09-15', '2024-09-17', freq='D'), 121 | *pd.date_range('2024-10-01', '2024-10-07', freq='D'), 122 | ] 123 | 124 | daily_events = { 125 | 'DayStart': '00:00:00', 126 | 'PreOpen': '9:15:00', 127 | 'Open': '9:30:00', 128 | 'Lunch': '11:30:00', 129 | 'LunchEnd': '13:00:00', 130 | 'Close': '15:00:00', 131 | 'DayEnd': '23:59:59' 132 | } 133 | 134 | def __init__(self, start=None, pop_passed=True): 135 | super().__init__() 136 | timezone = 'Asia/Shanghai' 137 | if start is None: 138 | start = pd.Timestamp.now(self.timezone).normalize() 139 | assert start.year <= CNCalendar.closed[-1].year 140 | self.build( 141 | start=str(start), 142 | end=str(CNCalendar.closed[-1].year + 1), 143 | daily_events=self.daily_events, 144 | tz=timezone, pop_passed=pop_passed) 145 | for d in CNCalendar.closed: 146 | self.set_as_holiday(d.tz_localize(timezone)) 147 | 148 | 149 | class JPCalendar(Calendar): 150 | """ 151 | JP holiday calendar: https://www.jpx.co.jp/corporate/about-jpx/calendar/index.html 152 | """ 153 | closed = [ 154 | *pd.date_range(f'{pd.Timestamp.now().year}-01-01', 155 | f'{pd.Timestamp.now().year}-01-03', freq='D'), 156 | *pd.date_range(f'{pd.Timestamp.now().year+1}-01-01', 157 | f'{pd.Timestamp.now().year+1}-01-03', freq='D'), 158 | *pd.date_range(f'{pd.Timestamp.now().year}-12-31', 159 | f'{pd.Timestamp.now().year}-12-31', freq='D'), 160 | *pd.date_range(f'{pd.Timestamp.now().year+1}-12-31', 161 | f'{pd.Timestamp.now().year+1}-12-31', freq='D'), 162 | 163 | # yearly manually updated 164 | *pd.date_range('2023-01-09', '2023-01-09', freq='D'), 165 | *pd.date_range('2023-02-11', '2023-02-11', freq='D'), 166 | *pd.date_range('2023-02-23', '2023-02-23', freq='D'), 167 | *pd.date_range('2023-03-21', '2023-03-21', freq='D'), 168 | *pd.date_range('2023-04-29', '2023-04-29', freq='D'), 169 | *pd.date_range('2023-05-03', '2023-05-05', freq='D'), 170 | *pd.date_range('2023-07-17', '2023-07-17', freq='D'), 171 | *pd.date_range('2023-08-11', '2023-08-11', freq='D'), 172 | *pd.date_range('2023-09-18', '2023-09-18', freq='D'), 173 | *pd.date_range('2023-09-23', '2023-09-23', freq='D'), 174 | *pd.date_range('2023-10-09', '2023-10-09', freq='D'), 175 | *pd.date_range('2023-11-03', '2023-11-03', freq='D'), 176 | *pd.date_range('2023-11-23', '2023-11-23', freq='D'), 177 | 178 | *pd.date_range('2024-01-08', '2024-01-08', freq='D'), 179 | *pd.date_range('2024-02-11', '2024-02-12', freq='D'), 180 | *pd.date_range('2024-02-23', '2024-02-23', freq='D'), 181 | *pd.date_range('2024-03-20', '2024-03-20', freq='D'), 182 | *pd.date_range('2024-04-29', '2024-04-29', freq='D'), 183 | *pd.date_range('2024-05-03', '2024-05-06', freq='D'), 184 | *pd.date_range('2024-07-15', '2024-07-15', freq='D'), 185 | *pd.date_range('2024-08-11', '2024-08-12', freq='D'), 186 | *pd.date_range('2024-09-16', '2024-09-16', freq='D'), 187 | *pd.date_range('2024-09-22', '2024-09-23', freq='D'), 188 | 
149 | class JPCalendar(Calendar): 150 | """ 151 | JP holiday calendar: https://www.jpx.co.jp/corporate/about-jpx/calendar/index.html 152 | """ 153 | closed = [ 154 | *pd.date_range(f'{pd.Timestamp.now().year}-01-01', 155 | f'{pd.Timestamp.now().year}-01-03', freq='D'), 156 | *pd.date_range(f'{pd.Timestamp.now().year+1}-01-01', 157 | f'{pd.Timestamp.now().year+1}-01-03', freq='D'), 158 | *pd.date_range(f'{pd.Timestamp.now().year}-12-31', 159 | f'{pd.Timestamp.now().year}-12-31', freq='D'), 160 | *pd.date_range(f'{pd.Timestamp.now().year+1}-12-31', 161 | f'{pd.Timestamp.now().year+1}-12-31', freq='D'), 162 | 163 | # updated manually every year 164 | *pd.date_range('2023-01-09', '2023-01-09', freq='D'), 165 | *pd.date_range('2023-02-11', '2023-02-11', freq='D'), 166 | *pd.date_range('2023-02-23', '2023-02-23', freq='D'), 167 | *pd.date_range('2023-03-21', '2023-03-21', freq='D'), 168 | *pd.date_range('2023-04-29', '2023-04-29', freq='D'), 169 | *pd.date_range('2023-05-03', '2023-05-05', freq='D'), 170 | *pd.date_range('2023-07-17', '2023-07-17', freq='D'), 171 | *pd.date_range('2023-08-11', '2023-08-11', freq='D'), 172 | *pd.date_range('2023-09-18', '2023-09-18', freq='D'), 173 | *pd.date_range('2023-09-23', '2023-09-23', freq='D'), 174 | *pd.date_range('2023-10-09', '2023-10-09', freq='D'), 175 | *pd.date_range('2023-11-03', '2023-11-03', freq='D'), 176 | *pd.date_range('2023-11-23', '2023-11-23', freq='D'), 177 | 178 | *pd.date_range('2024-01-08', '2024-01-08', freq='D'), 179 | *pd.date_range('2024-02-11', '2024-02-12', freq='D'), 180 | *pd.date_range('2024-02-23', '2024-02-23', freq='D'), 181 | *pd.date_range('2024-03-20', '2024-03-20', freq='D'), 182 | *pd.date_range('2024-04-29', '2024-04-29', freq='D'), 183 | *pd.date_range('2024-05-03', '2024-05-06', freq='D'), 184 | *pd.date_range('2024-07-15', '2024-07-15', freq='D'), 185 | *pd.date_range('2024-08-11', '2024-08-12', freq='D'), 186 | *pd.date_range('2024-09-16', '2024-09-16', freq='D'), 187 | *pd.date_range('2024-09-22', '2024-09-23', freq='D'), 188 | *pd.date_range('2024-10-14', '2024-10-14', freq='D'), 189 | *pd.date_range('2024-11-03', '2024-11-04', freq='D'), 190 | *pd.date_range('2024-11-23', '2024-11-23', freq='D'), 191 | ] 192 | 193 | daily_events = { 194 | 'DayStart': '00:00:00', 195 | 'PreOpen': '8:00:00', 196 | 'Open': '9:00:00', 197 | 'Lunch': '11:30:00', 198 | 'LunchEnd': '12:30:00', 199 | 'Close': '15:00:00', 200 | 'DayEnd': '23:59:59' 201 | } 202 | 203 | def __init__(self, start=None, pop_passed=True): 204 | super().__init__() 205 | timezone = 'Asia/Tokyo' 206 | if start is None: 207 | start = pd.Timestamp.now(timezone).normalize() 208 | assert start.year <= self.closed[-1].year 209 | self.build( 210 | start=str(start), 211 | end=str(self.closed[-1].year + 1), 212 | daily_events=self.daily_events, 213 | tz=timezone, pop_passed=pop_passed) 214 | for d in self.closed: 215 | self.set_as_holiday(d.tz_localize(timezone)) 216 | -------------------------------------------------------------------------------- /spectre/trading/event.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | import time 8 | from typing import Type 9 | import pandas as pd 10 | from .calendar import Calendar 11 | 12 | 13 | class Event: 14 | def __init__(self, callback) -> None: 15 | self.callback = callback  # Callable[[object], None] 16 | 17 | def on_schedule(self, evt_mgr): 18 | pass 19 | 20 | def should_trigger(self) -> bool: 21 | raise NotImplementedError("abstractmethod") 22 | 23 | 24 | class EveryBarData(Event): 25 | """This event is triggered passively""" 26 | def should_trigger(self) -> bool: 27 | return False 28 | 29 | 30 | class Always(Event): 31 | """The Always event is useful for live data IO functions, e.g. asyncio's loop.run_until_complete()""" 32 | def should_trigger(self) -> bool: 33 | return True 34 | 35 | 36 | class CalendarEvent(Event): 37 | """ The following code is for live trading; it will not work in back-testing """ 38 | def __init__(self, calendar_event_name, callback, offset_ns=0) -> None: 39 | super().__init__(callback) 40 | self.offset = offset_ns 41 | self.calendar = None 42 | self.event_name = calendar_event_name 43 | self.trigger_time = None 44 | 45 | def on_schedule(self, evt_mgr): 46 | try: 47 | self.calendar = evt_mgr.calendar 48 | self.calculate_range() 49 | except AttributeError: 50 | pass 51 | 52 | def calculate_range(self): 53 | self.trigger_time = self.calendar.events[self.event_name][0] + \ 54 | pd.Timedelta(self.offset, unit='ns') 55 | 56 | def should_trigger(self) -> bool: 57 | if self.calendar.hr_now() >= self.trigger_time: 58 | self.calendar.pop_passed(self.event_name) 59 | self.calculate_range() 60 | return True 61 | return False 62 | 63 | 64 | class MarketOpen(CalendarEvent): 65 | """ Works on both live and backtest """ 66 | def __init__(self, callback, offset_ns=0) -> None: 67 | super().__init__('Open', callback, offset_ns) 68 | 69 | 70 | class MarketClose(CalendarEvent): 71 | """ Works on both live and backtest """ 72 | def __init__(self, callback, offset_ns=0) -> None: 73 | super().__init__('Close', callback, offset_ns) 74 | 75 | 76 | # ---------------------------------------------------------------- 77 | 78 | 79 | class EventReceiver: 80 | def __init__(self) -> None: 81 | self._event_manager = None 82 | 83 | def unsubscribe(self): 84 | if self._event_manager is not None: 85 | self._event_manager.unsubscribe(self) 86 | 87
| def schedule(self, evt: Event): 88 | self._event_manager.schedule(self, evt) 89 | 90 | def stop_event_manager(self): 91 | self._event_manager.stop() 92 | 93 | def fire_event(self, evt_type: Type[Event]): 94 | self._event_manager.fire_event(self, evt_type) 95 | 96 | def on_run(self): 97 | pass 98 | 99 | def on_end_of_run(self): 100 | pass 101 | 102 | 103 | class EventManager: 104 | def __init__(self) -> None: 105 | self._subscribers = dict() 106 | self._stop = False 107 | 108 | def subscribe(self, receiver: EventReceiver): 109 | assert receiver not in self._subscribers, 'Duplicate subscribe' 110 | self._subscribers[receiver] = [] 111 | receiver._event_manager = self 112 | 113 | def unsubscribe(self, receiver: EventReceiver): 114 | assert receiver in self._subscribers, 'Subscriber not exists' 115 | del self._subscribers[receiver] 116 | receiver._event_manager = None 117 | 118 | def schedule(self, receiver: EventReceiver, event: Event): 119 | self._subscribers[receiver].append(event) 120 | event.on_schedule(self) 121 | 122 | def fire_event(self, source, evt_type: Type[Event]): 123 | for r, events in self._subscribers.items(): 124 | for evt in events: 125 | if isinstance(evt, evt_type): 126 | evt.callback(source) 127 | 128 | def stop(self): 129 | self._stop = True 130 | 131 | def _beg_run(self): 132 | if not self._subscribers: 133 | raise ValueError("At least one subscriber.") 134 | 135 | for r, events in self._subscribers.items(): 136 | # clear scheduled events 137 | events.clear() 138 | r.on_run() 139 | 140 | def _end_run(self): 141 | for r in self._subscribers.keys(): 142 | r.on_end_of_run() 143 | 144 | def _run_once(self): 145 | for r, events in self._subscribers.items(): 146 | for event in events: 147 | if event.should_trigger(): 148 | event.callback(self) 149 | 150 | def run(self, *params): 151 | self._beg_run() 152 | 153 | while not self._stop: 154 | time.sleep(0.001) 155 | self._run_once() 156 | 157 | self._end_run() 158 | 159 | 160 | # ---------------------------------------------------------------- 161 | 162 | 163 | class MarketEventManager(EventManager): 164 | def __init__(self, calendar: Calendar) -> None: 165 | self.calendar = calendar 166 | super().__init__() 167 | 168 | -------------------------------------------------------------------------------- /spectre/trading/metric.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved. 
4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | import numpy as np 8 | import pandas as pd 9 | 10 | 11 | def drawdown(cumulative_returns): 12 | max_ret = cumulative_returns.cummax() 13 | dd = cumulative_returns / max_ret - 1 14 | dd_group = 0 15 | 16 | def drawdown_split(x): 17 | nonlocal dd_group 18 | if dd[x] == 0: 19 | dd_group += 1 20 | return dd_group 21 | 22 | dd_duration = dd.groupby(drawdown_split).cumcount() 23 | return dd, dd_duration 24 | 25 | 26 | def sharpe_ratio(daily_returns: pd.Series, annual_risk_free_rate): 27 | risk_adj_ret = daily_returns.sub(annual_risk_free_rate/252) 28 | annual_factor = np.sqrt(252) 29 | return annual_factor * risk_adj_ret.mean() / risk_adj_ret.std(ddof=1) 30 | 31 | 32 | def turnover(positions, transactions): 33 | if transactions.shape[0] == 0: 34 | return transactions.amount  # empty Series: no turnover to report 35 | value_trades = (transactions.amount * transactions.fill_price).abs() 36 | value_trades = value_trades.groupby(value_trades.index.normalize()).sum() 37 | return value_trades / positions.value.sum(axis=1) 38 | 39 | 40 | def annual_volatility(daily_returns: pd.Series): 41 | volatility = daily_returns.std(ddof=1) 42 | annual_factor = np.sqrt(252) 43 | return annual_factor * volatility 44 | -------------------------------------------------------------------------------- /spectre/trading/portfolio.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved. 4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | from typing import Union, Callable 8 | import pandas as pd 9 | import numpy as np 10 | from .position import Position 11 | from .stopmodel import StopModel 12 | 13 | 14 | class Portfolio: 15 | def __init__(self, stop_model: StopModel = None): 16 | self._history = [] 17 | self._positions = dict() 18 | self._cash = 0 19 | self._funds_change = [] 20 | self._current_dt = None 21 | self.stop_model = stop_model 22 | 23 | def set_stop_model(self, stop_model: StopModel): 24 | """ 25 | Set the portfolio's default stop model. 26 | A stop model can implement more strategic exit triggers than plain stop orders.
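        A sketch (TrailingStopModel is defined in stopmodel.py; the ratio is illustrative):
            portfolio.set_stop_model(TrailingStopModel(-0.05))  # trail longs 5% below their high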
27 | """ 28 | self.stop_model = stop_model 29 | 30 | @property 31 | def history(self): 32 | ret = pd.DataFrame(self._history + [self._get_today_record()]) 33 | ret.columns = pd.MultiIndex.from_tuples(ret.columns) 34 | ret = ret.set_index('index').sort_index(axis=0).sort_index(axis=1) 35 | return ret 36 | 37 | @property 38 | def fund_history(self): 39 | return pd.DataFrame(self._funds_change).set_index('index') 40 | 41 | @property 42 | def returns(self): 43 | if self._funds_change: 44 | value = self.history.value.sum(axis=1) 45 | funds = pd.DataFrame(self._funds_change) 46 | funds['index'] = funds['index'].fillna(value.index[0]) 47 | funds = funds.set_index('index').amount 48 | return value.sub(funds, fill_value=0) / value.shift(1) - 1 49 | else: 50 | return self.history.value.sum(axis=1).pct_change() 51 | 52 | @property 53 | def positions(self): 54 | return self._positions 55 | 56 | @property 57 | def cash(self): 58 | return self._cash 59 | 60 | @property 61 | def value(self): 62 | # for asset, shares in self.positions.items(): 63 | # if self._last_price[asset] != self._last_price[asset]: 64 | # raise ValueError('{}({}) is nan in {}'.format(asset, shares, self._current_dt)) 65 | values = [pos.value for asset, pos in self.positions.items() if pos.shares != 0] 66 | return sum(values) + self._cash 67 | 68 | @property 69 | def leverage(self): 70 | values = [pos.value for asset, pos in self.positions.items() if pos.shares != 0] 71 | return sum(np.abs(values)) / (sum(values) + self._cash) 72 | 73 | @property 74 | def current_dt(self): 75 | return self._current_dt 76 | 77 | def __repr__(self): 78 | return "" + str(self.history)[11:] 79 | 80 | def clear(self): 81 | self.__init__(self.stop_model) 82 | 83 | def shares(self, asset): 84 | try: 85 | return self._positions[asset].shares 86 | except KeyError: 87 | return 0 88 | 89 | def _get_today_record(self): 90 | current_date = self._current_dt.normalize() 91 | record = {('index', ''): current_date, ('value', 'cash'): self._cash} 92 | for asset, pos in self._positions.items(): 93 | record[('avg_px', asset)] = pos.average_price 94 | record[('shares', asset)] = pos.shares 95 | record[('value', asset)] = pos.value 96 | return record 97 | 98 | def set_datetime(self, dt): 99 | if isinstance(dt, str): 100 | dt = pd.Timestamp(dt) 101 | date = dt.normalize() 102 | if self._current_dt is not None: 103 | current_date = self._current_dt.normalize() 104 | if dt < self._current_dt: 105 | raise ValueError('Cannot set a date less than the current date') 106 | elif date > current_date: 107 | # today add to history 108 | self._history.append(self._get_today_record()) 109 | 110 | self._current_dt = dt 111 | for pos in self._positions.values(): 112 | pos.current_dt = dt 113 | 114 | def update(self, asset, amount, fill_price, commission) -> float: 115 | """asset position + amount, also calculation average_price and realized P&L""" 116 | assert self._current_dt is not None 117 | if amount == 0: 118 | return 0 119 | if asset in self._positions: 120 | empty, realized = self._positions[asset].update( 121 | amount, fill_price, commission, self._current_dt) 122 | if empty: 123 | del self._positions[asset] 124 | return realized 125 | else: 126 | self._positions[asset] = Position( 127 | amount, fill_price, commission, self._current_dt, self.stop_model) 128 | return 0 129 | 130 | def update_cash(self, amount, is_funds=False): 131 | """ is_funds: Is this cash update related to funds transfer, (deposits/withdraw) """ 132 | assert amount == amount 133 | self._cash += amount 134 | if 
141 | def process_split(self, asset, inverse_ratio: float, last_price): 142 | if asset not in self._positions: 143 | return 144 | pos = self._positions[asset] 145 | cash = pos.process_split(inverse_ratio, last_price) 146 | self.update_cash(cash) 147 | 148 | def process_dividend(self, asset, amount, tax): 149 | if asset not in self._positions: 150 | return 151 | pos = self._positions[asset] 152 | cash = pos.process_dividend(amount, tax) 153 | self.update_cash(cash) 154 | 155 | def process_borrow_interest(self, day_passed, money_interest_rate, stock_interest_rate): 156 | interest = 0 157 | for asset, pos in self._positions.items(): 158 | # Sometimes pos.last_price has not been updated by the time execution reaches here, which turns cash into NaN; the source of the missing update has not been found yet 159 | if pos.shares < 0 and pos.value == pos.value: 160 | interest += pos.value * (stock_interest_rate / 365) * day_passed 161 | if self._cash < 0: 162 | interest += self._cash * (money_interest_rate / 365) * day_passed 163 | self.update_cash(interest) 164 | 165 | def _update_value_func(self, func): 166 | for asset, pos in self._positions.items(): 167 | price = func(asset) 168 | if price and price == price: 169 | pos.last_price = price 170 | 171 | def _update_value_dict(self, prices): 172 | for asset, pos in self._positions.items(): 173 | price = prices.get(asset, np.nan) 174 | if price == price: 175 | pos.last_price = price 176 | 177 | def update_value(self, prices: Union[Callable, dict]): 178 | if callable(prices): 179 | self._update_value_func(prices) 180 | elif isinstance(prices, dict): 181 | self._update_value_dict(prices) 182 | else: 183 | raise ValueError('prices must be either callable or a dict') 184 | 185 | def check_stop_trigger(self, *args): 186 | ret = [] 187 | for asset in list(self._positions.keys()): 188 | pos = self._positions[asset] 189 | ret.append(pos.check_stop_trigger(asset, -pos.shares, *args)) 190 | return ret 191 | -------------------------------------------------------------------------------- /spectre/trading/position.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved.
4 | @license: Apache 2.0 5 | @email: heeroz@gmail.com 6 | """ 7 | import math 8 | from typing import Tuple 9 | 10 | from .stopmodel import StopModel 11 | 12 | 13 | def sign(x): 14 | return math.copysign(1, x) 15 | 16 | 17 | class Position: 18 | def __init__(self, shares: int, fill_price: float, commission: float, 19 | dt, stop_model: StopModel = None): 20 | self._shares = shares 21 | self._average_price = fill_price + commission / shares 22 | self._last_price = fill_price 23 | self._realized = 0 24 | self._open_dt = dt 25 | self.current_dt = dt 26 | self.stop_model = stop_model 27 | self.stop_tracker = None 28 | if stop_model is not None: 29 | self.stop_tracker = stop_model.new_tracker( 30 | fill_price, self._shares < 0) 31 | self.stop_tracker.tracking_position = self 32 | 33 | @property 34 | def open_dt(self): 35 | return self._open_dt 36 | 37 | @property 38 | def period(self): 39 | return self.current_dt - self._open_dt 40 | 41 | @property 42 | def value(self): 43 | return self._shares * self._last_price 44 | 45 | @property 46 | def shares(self): 47 | return self._shares 48 | 49 | @property 50 | def average_price(self): 51 | return self._average_price 52 | 53 | @property 54 | def last_price(self): 55 | return self._last_price 56 | 57 | @last_price.setter 58 | def last_price(self, last_price: float): 59 | self._last_price = last_price 60 | if self.stop_tracker: 61 | self.stop_tracker.update_price(last_price) 62 | 63 | @property 64 | def realized(self): 65 | return self._realized 66 | 67 | @property 68 | def unrealized(self): 69 | return (self._last_price - self._average_price) * self._shares 70 | 71 | @property 72 | def unrealized_percent(self): 73 | return (self._last_price / self._average_price - 1) * sign(self._shares) 74 | 75 | def update(self, amount: int, fill_price: float, commission: float, dt) -> Tuple[bool, float]: 76 | """ 77 | Add `amount` to the position; `fill_price` and `commission` are used to recalculate the average price and P&L. 78 | Returns (True, realized) when the position becomes empty.
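        A sketch with illustrative numbers: open 100 shares at $10, then sell 60 at $12:
            pos = Position(100, 10.0, 1.0, dt)
            empty, realized = pos.update(-60, 12.0, 1.0, dt)  # empty is False, realized > 0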
79 | """ 80 | before_shares = self._shares 81 | before_avg_px = self._average_price 82 | after_shares = before_shares + amount 83 | 84 | # If the position is reversed, it will be filled in 2 steps 85 | if after_shares != 0 and sign(after_shares) != sign(before_shares): 86 | fill_1 = amount - after_shares 87 | fill_2 = amount - fill_1 88 | per_comm = commission / amount 89 | # close position 90 | _, realized = self.update(fill_1, fill_price, per_comm * fill_1, dt) 91 | # open a new position 92 | self.__init__(fill_2, fill_price, per_comm * fill_2, dt, stop_model=self.stop_model) 93 | return False, realized 94 | else: 95 | cum_cost = self._average_price * before_shares + amount * fill_price + commission 96 | self._shares = after_shares 97 | if after_shares == 0: 98 | self._average_price = 0 99 | realized = -cum_cost - self._realized 100 | self._realized = -cum_cost 101 | self.last_price = fill_price 102 | return True, realized 103 | else: 104 | self._average_price = cum_cost / after_shares 105 | if after_shares < before_shares: 106 | realized = (before_avg_px - self._average_price) * abs(after_shares) 107 | else: 108 | realized = 0 109 | self._realized += realized 110 | self.last_price = fill_price 111 | return False, realized 112 | 113 | def process_split(self, inverse_ratio: float, last_price: float) -> float: 114 | if inverse_ratio != inverse_ratio or inverse_ratio == 1: 115 | return 0 116 | sp = self._shares * inverse_ratio 117 | cash = 0 118 | if inverse_ratio < 1: # reverse split remaining to cash 119 | remaining = int(self._shares - int(sp) / inverse_ratio) # for more precise 120 | if remaining != 0: 121 | cash = remaining * last_price 122 | self._shares = int(round(sp, 5)) 123 | self._average_price = self._average_price / inverse_ratio 124 | self.last_price = last_price / inverse_ratio 125 | 126 | if self.stop_tracker: 127 | self.stop_tracker.process_split(last_price) 128 | return cash 129 | 130 | def process_dividend(self, amount: float, tax: float) -> float: 131 | if amount != amount or amount == 0: 132 | return 0 133 | self._average_price -= amount 134 | self.last_price -= amount + tax 135 | cash = self._shares * amount 136 | return cash 137 | 138 | def check_stop_trigger(self, *args): 139 | if self.stop_tracker: 140 | return self.stop_tracker.check_trigger(*args) 141 | -------------------------------------------------------------------------------- /spectre/trading/stopmodel.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Heerozh (Zhang Jianhao) 3 | @copyright: Copyright 2019-2020, Heerozh. All rights reserved. 
29 | # ----------------------------------------------------------------------------- 30 | 31 | 32 | class StopTracker(PriceTracker): 33 | def __init__(self, current_price, stop_price, callback): 34 | super().__init__(current_price, lambda _, x: x)  # recorder that just keeps the latest price 35 | self._stop_price = stop_price 36 | self.stop_loss = stop_price < current_price 37 | self.callback = callback 38 | 39 | @property 40 | def stop_price(self): 41 | return self._stop_price 42 | 43 | def fire(self, *args): 44 | if callable(self.callback): 45 | return self.callback(*args) 46 | else: 47 | return self.callback 48 | 49 | def check_trigger(self, *args): 50 | if self.stop_loss: 51 | if self.last_price <= self.stop_price: 52 | return self.fire(*args) 53 | else: 54 | if self.last_price >= self.stop_price: 55 | return self.fire(*args) 56 | return False 57 | 58 | 59 | class StopModel: 60 | def __init__(self, ratio: float, callback=None): 61 | self.ratio = ratio 62 | self.callback = callback 63 | 64 | def new_tracker(self, current_price, inverse): 65 | if inverse: 66 | stop_price = current_price * (1 - self.ratio) 67 | else: 68 | stop_price = current_price * (1 + self.ratio) 69 | return StopTracker(current_price, stop_price, self.callback) 70 | 71 | 72 | # ----------------------------------------------------------------------------- 73 | 74 | 75 | class TrailingStopTracker(StopTracker): 76 | def __init__(self, current_price, ratio, callback): 77 | self.ratio = ratio 78 | stop_price = current_price * (1 + self.ratio) 79 | StopTracker.__init__(self, current_price, stop_price, callback=callback) 80 | PriceTracker.__init__(self, current_price, recorder=max if ratio < 0 else min) 81 | 82 | @property 83 | def stop_price(self): 84 | return self.recorded_price * (1 + self.ratio) 85 | 86 | 87 | class TrailingStopModel(StopModel): 88 | """ 89 | Unlike a trailing stop order, the ratio in this model is relative to the highest / lowest 90 | price, so -0.1 means the stop price is 90% of the highest price seen from now on; 0.1 means 91 | the stop price is 110% of the lowest price seen from now on.
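    For example, with TrailingStopModel(-0.05) a long position is stopped 5% below its
    running high, and a short position (the tracker's ratio is negated) 5% above its
    running low.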
92 | """ 93 | def new_tracker(self, current_price, inverse): 94 | ratio = -self.ratio if inverse else self.ratio 95 | return TrailingStopTracker(current_price, ratio, self.callback) 96 | 97 | 98 | # ----------------------------------------------------------------------------- 99 | 100 | 101 | class DecayTrailingStopTracker(TrailingStopTracker): 102 | def __init__(self, current_price, ratio, target, decay_rate, max_decay, callback): 103 | self.initial_ratio = ratio 104 | self.max_decay = max_decay 105 | self.decay_rate = decay_rate 106 | self.target = target 107 | super().__init__(current_price, ratio, callback) 108 | 109 | @property 110 | def current(self): 111 | raise NotImplementedError("abstractmethod") 112 | 113 | @property 114 | def stop_price(self): 115 | decay = max(self.decay_rate ** (self.current / self.target), self.max_decay) 116 | self.ratio = self.initial_ratio * decay 117 | return self.recorded_price * (1 + self.ratio) 118 | 119 | 120 | class PnLDecayTrailingStopTracker(DecayTrailingStopTracker): 121 | @property 122 | def current(self): 123 | pos = self.tracking_position 124 | pnl = (self.recorded_price / pos.average_price - 1) * sign(pos.shares) 125 | pnl = max(pnl, 0) if self.target > 0 else min(pnl, 0) 126 | return pnl 127 | 128 | 129 | class PnLDecayTrailingStopModel(StopModel): 130 | """ 131 | Exponential decay to the stop ratio: `ratio * decay_rate ^ (PnL% / PnL_target%)`. 132 | If it's stop gain model, `PnL_target` should be Loss Target (negative). 133 | 134 | So, the lower the `ratio` when PnL% approaches the target, and if PnL% exceeds PnL_target%, 135 | any small opposite changes will trigger stop. 136 | """ 137 | 138 | def __init__(self, ratio: float, pnl_target: float, callback=None, 139 | decay_rate=0.05, max_decay=0): 140 | super().__init__(ratio, callback) 141 | self.decay_rate = decay_rate 142 | self.pnl_target = pnl_target 143 | self.max_decay = max_decay 144 | 145 | def new_tracker(self, current_price, inverse): 146 | ratio = -self.ratio if inverse else self.ratio 147 | return PnLDecayTrailingStopTracker( 148 | current_price, ratio, self.pnl_target, self.decay_rate, self.max_decay, self.callback) 149 | 150 | 151 | class TimeDecayTrailingStopTracker(DecayTrailingStopTracker): 152 | @property 153 | def current(self): 154 | pos = self.tracking_position 155 | return pos.period 156 | 157 | 158 | class TimeDecayTrailingStopModel(StopModel): 159 | def __init__(self, ratio: float, period_target: 'pd.Timedelta', callback=None, 160 | decay_rate=0.05, max_decay=0): 161 | super().__init__(ratio, callback) 162 | self.decay_rate = decay_rate 163 | self.period_target = period_target 164 | self.max_decay = max_decay 165 | 166 | def new_tracker(self, current_price, inverse): 167 | ratio = -self.ratio if inverse else self.ratio 168 | return TimeDecayTrailingStopTracker( 169 | current_price, ratio, self.period_target, self.decay_rate, self.max_decay, 170 | self.callback) 171 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Heerozh/spectre/c0e7fd974227b376a5790cc66d49f65d58929af8/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/5mins/AAPL_2018.csv: -------------------------------------------------------------------------------- 1 | Date,Open,High,Low,Close,Volume,BarCount,Wap 2 | 2018-12-28 09:30:00-05:00,157.48,157.65,155.76,156.03,14530,4350,156.849 
3 | 2018-12-28 09:35:00-05:00,156.01,156.05,155.19,155.39,8493,4002,155.539 4 | 2018-12-28 09:40:00-05:00,155.42,156.09,155.19,156.05,6825,3040,155.681 5 | 2018-12-28 09:45:00-05:00,156.08,156.84,156.02,156.39,9933,4339,156.506 6 | 2018-12-28 09:50:00-05:00,156.4,156.59,155.84,156.42,7377,3527,156.208 7 | 2018-12-28 09:55:00-05:00,156.45,156.88,156.11,156.73,5646,2717,156.479 8 | 2018-12-28 10:00:00-05:00,156.75,156.84,155.8,156.0,5533,2701,156.255 9 | 2018-12-28 10:05:00-05:00,156.02,156.75,156.02,156.62,7276,3691,156.534 10 | 2018-12-28 10:10:00-05:00,156.63,156.77,155.81,156.06,6216,2970,156.293 11 | 2018-12-28 10:15:00-05:00,156.08,156.45,155.9,156.33,4541,2229,156.208 12 | 2018-12-28 10:20:00-05:00,156.35,156.5,155.59,155.73,5287,2537,156.113 13 | 2018-12-28 10:25:00-05:00,155.73,155.85,155.16,155.27,5340,2661,155.449 14 | 2018-12-28 10:30:00-05:00,155.26,155.78,155.12,155.29,4231,2074,155.413 15 | 2018-12-28 10:35:00-05:00,155.28,155.31,154.55,155.12,6388,3004,154.854 16 | 2018-12-28 10:40:00-05:00,155.1,155.44,154.9,154.9,4302,2056,155.183 17 | 2018-12-28 10:45:00-05:00,154.9,155.72,154.9,155.56,4212,2067,155.377 18 | 2018-12-28 10:50:00-05:00,155.59,155.99,155.49,155.59,5054,2425,155.797 19 | 2018-12-28 10:55:00-05:00,155.59,155.73,155.22,155.31,2992,1379,155.451 20 | 2018-12-28 11:00:00-05:00,155.32,155.5,155.07,155.16,2550,1394,155.303 21 | 2018-12-28 11:05:00-05:00,155.16,155.74,155.0,155.66,3333,1718,155.311 22 | 2018-12-28 11:10:00-05:00,155.64,155.9,155.59,155.78,4292,1816,155.775 23 | 2018-12-28 11:15:00-05:00,155.78,156.22,155.55,155.68,6413,2925,155.854 24 | 2018-12-28 11:20:00-05:00,155.66,155.84,155.44,155.64,2844,1457,155.617 25 | 2018-12-28 11:25:00-05:00,155.64,156.2,155.59,156.08,3821,1654,155.913 26 | 2018-12-28 11:30:00-05:00,156.07,156.25,155.87,156.22,3717,1776,156.109 27 | 2018-12-28 11:35:00-05:00,156.24,156.93,156.21,156.86,5446,2236,156.597 28 | 2018-12-28 11:40:00-05:00,156.86,157.11,156.73,156.74,4495,1927,156.897 29 | 2018-12-28 11:45:00-05:00,156.75,156.82,156.47,156.82,3335,1365,156.661 30 | 2018-12-28 11:50:00-05:00,156.81,157.2,156.63,157.02,3219,1377,156.944 31 | 2018-12-28 11:55:00-05:00,157.02,157.09,156.8,156.92,2776,1195,156.973 32 | 2018-12-28 12:00:00-05:00,156.93,157.23,156.68,157.22,3919,1594,156.939 33 | 2018-12-28 12:05:00-05:00,157.2,157.21,156.65,156.7,3105,1272,156.987 34 | 2018-12-28 12:10:00-05:00,156.72,156.81,156.4,156.49,3073,1286,156.564 35 | 2018-12-28 12:15:00-05:00,156.49,156.56,156.2,156.48,3369,1459,156.393 36 | 2018-12-28 12:20:00-05:00,156.48,156.96,156.48,156.83,3758,1671,156.719 37 | 2018-12-28 12:25:00-05:00,156.83,156.88,156.51,156.68,2637,1244,156.664 38 | 2018-12-28 12:30:00-05:00,156.67,156.92,156.61,156.77,2327,1135,156.753 39 | 2018-12-28 12:35:00-05:00,156.76,156.76,156.39,156.7,2507,1443,156.583 40 | 2018-12-28 12:40:00-05:00,156.71,156.81,156.54,156.57,2913,1232,156.679 41 | 2018-12-28 12:45:00-05:00,156.57,156.63,156.33,156.43,2341,1174,156.457 42 | 2018-12-28 12:50:00-05:00,156.43,156.64,156.22,156.51,2540,1175,156.426 43 | 2018-12-28 12:55:00-05:00,156.53,156.58,156.03,156.21,2431,1100,156.253 44 | 2018-12-28 13:00:00-05:00,156.22,156.36,156.14,156.17,1835,1017,156.251 45 | 2018-12-28 13:05:00-05:00,156.17,156.37,156.07,156.29,2076,1050,156.216 46 | 2018-12-28 13:10:00-05:00,156.29,156.38,156.07,156.29,2189,1182,156.255 47 | 2018-12-28 13:15:00-05:00,156.29,156.42,156.06,156.07,2135,855,156.229 48 | 2018-12-28 13:20:00-05:00,156.08,156.13,155.53,155.55,3625,1417,155.892 49 | 2018-12-28 
13:25:00-05:00,155.54,155.77,155.41,155.71,3012,1368,155.581 50 | 2018-12-28 13:30:00-05:00,155.68,155.76,155.27,155.4,2850,1304,155.468 51 | 2018-12-28 13:35:00-05:00,155.39,155.6,155.29,155.47,1733,704,155.429 52 | 2018-12-28 13:40:00-05:00,155.46,155.67,155.38,155.46,1712,710,155.528 53 | 2018-12-28 13:45:00-05:00,155.47,156.1,155.47,155.99,2716,1133,155.854 54 | 2018-12-28 13:50:00-05:00,155.98,156.24,155.82,155.9,2310,1072,156.029 55 | 2018-12-28 13:55:00-05:00,155.9,156.21,155.9,156.03,1814,770,156.031 56 | 2018-12-28 14:00:00-05:00,156.05,156.66,156.05,156.61,3788,1495,156.444 57 | 2018-12-28 14:05:00-05:00,156.63,156.98,156.62,156.91,3227,1420,156.81 58 | 2018-12-28 14:10:00-05:00,156.89,157.35,156.76,157.22,4298,1810,157.114 59 | 2018-12-28 14:15:00-05:00,157.23,157.58,157.22,157.26,4489,2017,157.43 60 | 2018-12-28 14:20:00-05:00,157.25,157.45,157.04,157.37,3218,1411,157.265 61 | 2018-12-28 14:25:00-05:00,157.37,157.58,157.22,157.29,2641,1237,157.412 62 | 2018-12-28 14:30:00-05:00,157.29,157.45,157.04,157.45,3288,1398,157.228 63 | 2018-12-28 14:35:00-05:00,157.43,157.65,157.32,157.63,3294,1656,157.468 64 | 2018-12-28 14:40:00-05:00,157.65,158.1,157.49,157.99,5939,2373,157.848 65 | 2018-12-28 14:45:00-05:00,158.0,158.05,157.81,157.87,3579,1737,157.946 66 | 2018-12-28 14:50:00-05:00,157.88,158.52,157.82,158.2,5207,2384,158.115 67 | 2018-12-28 14:55:00-05:00,158.19,158.47,158.11,158.41,4225,1955,158.303 68 | 2018-12-28 15:00:00-05:00,158.4,158.51,157.95,158.02,3988,2069,158.24 69 | 2018-12-28 15:05:00-05:00,158.02,158.04,157.12,157.39,5006,2637,157.536 70 | 2018-12-28 15:10:00-05:00,157.38,157.67,156.8,157.58,5464,2690,157.17 71 | 2018-12-28 15:15:00-05:00,157.58,157.59,156.84,156.95,3124,1621,157.235 72 | 2018-12-28 15:20:00-05:00,156.93,157.04,156.67,157.01,4691,2531,156.862 73 | 2018-12-28 15:25:00-05:00,157.01,157.13,156.23,156.54,5140,2827,156.53 74 | 2018-12-28 15:30:00-05:00,156.54,156.63,156.16,156.47,4636,2461,156.407 75 | 2018-12-28 15:35:00-05:00,156.47,156.71,156.14,156.2,4395,2461,156.387 76 | 2018-12-28 15:40:00-05:00,156.19,156.46,155.9,155.94,5763,3510,156.215 77 | 2018-12-28 15:45:00-05:00,155.92,156.24,155.71,156.1,5309,3007,155.956 78 | 2018-12-28 15:50:00-05:00,156.1,156.33,155.67,155.75,5943,3780,156.01 79 | 2018-12-28 15:55:00-05:00,155.75,156.5,155.68,156.23,12132,7388,156.05 80 | 2018-12-31 09:30:00-05:00,158.5,158.85,157.96,158.71,11487,3277,158.395 81 | 2018-12-31 09:35:00-05:00,158.69,159.36,158.68,159.04,8422,3572,159.054 82 | 2018-12-31 09:40:00-05:00,159.04,159.17,158.7,158.77,4954,2159,158.883 83 | 2018-12-31 09:45:00-05:00,158.78,158.88,158.29,158.31,4922,2010,158.581 84 | 2018-12-31 09:50:00-05:00,158.32,158.33,157.67,157.75,5488,2321,157.919 85 | 2018-12-31 09:55:00-05:00,157.74,157.94,157.48,157.93,6145,1986,157.715 86 | 2018-12-31 10:00:00-05:00,157.92,158.23,157.91,158.01,4460,1626,158.084 87 | 2018-12-31 10:05:00-05:00,157.99,158.24,157.45,157.47,5103,2032,157.748 88 | 2018-12-31 10:10:00-05:00,157.46,157.69,157.0,157.03,5865,2401,157.407 89 | 2018-12-31 10:15:00-05:00,157.03,157.17,156.88,157.09,4178,1808,157.021 90 | 2018-12-31 10:20:00-05:00,157.07,157.23,156.93,157.06,3010,1582,157.094 91 | 2018-12-31 10:25:00-05:00,157.05,157.48,157.05,157.46,3333,1669,157.276 92 | 2018-12-31 10:30:00-05:00,157.47,157.6,156.81,156.9,4413,1990,157.136 93 | 2018-12-31 10:35:00-05:00,156.9,156.9,156.51,156.7,3970,1727,156.698 94 | 2018-12-31 10:40:00-05:00,156.68,157.0,156.63,156.88,4122,1707,156.811 95 | 2018-12-31 
10:45:00-05:00,156.86,157.04,156.68,156.72,2654,1403,156.896 96 | 2018-12-31 10:50:00-05:00,156.71,156.81,156.54,156.55,2362,1134,156.675 97 | 2018-12-31 10:55:00-05:00,156.54,157.07,156.51,157.06,4184,1529,156.75 98 | 2018-12-31 11:00:00-05:00,157.06,157.39,156.91,157.36,3507,1648,157.141 99 | 2018-12-31 11:05:00-05:00,157.35,157.58,157.24,157.39,3958,1789,157.391 100 | 2018-12-31 11:10:00-05:00,157.4,157.43,157.11,157.33,3269,1602,157.257 101 | 2018-12-31 11:15:00-05:00,157.32,157.5,157.14,157.31,3153,1466,157.349 102 | 2018-12-31 11:20:00-05:00,157.32,157.35,156.94,156.98,3473,1694,157.074 103 | 2018-12-31 11:25:00-05:00,156.96,157.7,156.93,157.56,3962,1998,157.363 104 | 2018-12-31 11:30:00-05:00,157.52,157.62,157.07,157.17,3163,1530,157.399 105 | 2018-12-31 11:35:00-05:00,157.18,157.68,156.99,157.55,3753,1770,157.424 106 | 2018-12-31 11:40:00-05:00,157.54,157.61,157.33,157.6,3133,1718,157.455 107 | 2018-12-31 11:45:00-05:00,157.58,157.92,157.51,157.89,3430,1571,157.713 108 | 2018-12-31 11:50:00-05:00,157.89,158.17,157.69,158.08,4052,1510,157.994 109 | 2018-12-31 11:55:00-05:00,158.07,158.1,157.77,157.96,2303,1086,157.935 110 | 2018-12-31 12:00:00-05:00,157.99,158.02,157.74,157.91,2049,918,157.902 111 | 2018-12-31 12:05:00-05:00,157.94,158.1,157.8,158.05,2311,1026,157.975 112 | 2018-12-31 12:10:00-05:00,158.02,158.09,157.77,158.05,3076,1433,157.929 113 | 2018-12-31 12:15:00-05:00,158.05,158.24,157.83,158.14,3164,1309,158.073 114 | 2018-12-31 12:20:00-05:00,158.13,158.48,158.11,158.31,3742,1573,158.33 115 | 2018-12-31 12:25:00-05:00,158.3,158.58,158.26,158.53,3390,1408,158.442 116 | 2018-12-31 12:30:00-05:00,158.52,158.72,158.51,158.7,2436,1161,158.637 117 | 2018-12-31 12:35:00-05:00,158.69,158.85,158.63,158.83,2101,1090,158.749 118 | 2018-12-31 12:40:00-05:00,158.84,158.94,158.68,158.93,2507,1316,158.81 119 | 2018-12-31 12:45:00-05:00,158.92,158.93,158.58,158.6,1787,777,158.741 120 | 2018-12-31 12:50:00-05:00,158.6,158.61,158.3,158.38,2516,1178,158.407 121 | 2018-12-31 12:55:00-05:00,158.37,158.44,158.17,158.37,2206,1008,158.305 122 | 2018-12-31 13:00:00-05:00,158.38,158.48,158.13,158.24,2175,950,158.343 123 | 2018-12-31 13:05:00-05:00,158.22,158.3,158.09,158.23,2073,1001,158.201 124 | 2018-12-31 13:10:00-05:00,158.25,158.43,158.18,158.36,2329,1056,158.332 125 | 2018-12-31 13:15:00-05:00,158.36,158.48,158.28,158.3,2355,1013,158.406 126 | 2018-12-31 13:20:00-05:00,158.31,158.36,158.25,158.33,1367,589,158.298 127 | 2018-12-31 13:25:00-05:00,158.33,158.43,158.25,158.4,1178,533,158.35 128 | 2018-12-31 13:30:00-05:00,158.4,158.44,158.15,158.19,1647,828,158.28 129 | 2018-12-31 13:35:00-05:00,158.17,158.2,157.85,157.91,2175,999,158.002 130 | 2018-12-31 13:40:00-05:00,157.9,158.02,157.7,157.74,2708,1183,157.902 131 | 2018-12-31 13:45:00-05:00,157.75,157.8,157.63,157.76,1756,803,157.705 132 | 2018-12-31 13:50:00-05:00,157.76,157.85,157.43,157.55,2624,1153,157.616 133 | 2018-12-31 13:55:00-05:00,157.54,157.67,157.5,157.57,1422,762,157.589 134 | 2018-12-31 14:00:00-05:00,157.58,157.6,157.33,157.39,2295,1189,157.472 135 | 2018-12-31 14:05:00-05:00,157.4,157.82,157.36,157.68,2740,1145,157.673 136 | 2018-12-31 14:10:00-05:00,157.68,157.72,157.41,157.49,2317,906,157.563 137 | 2018-12-31 14:15:00-05:00,157.48,157.79,157.29,157.77,2965,1441,157.569 138 | 2018-12-31 14:20:00-05:00,157.77,157.79,157.2,157.35,3435,1518,157.469 139 | 2018-12-31 14:25:00-05:00,157.34,157.38,157.05,157.09,2088,905,157.244 140 | 2018-12-31 14:30:00-05:00,157.07,157.17,156.92,156.97,2976,1275,157.023 141 | 2018-12-31 
14:35:00-05:00,156.96,157.37,156.85,157.32,2535,1311,157.081 142 | 2018-12-31 14:40:00-05:00,157.32,157.46,157.22,157.45,2421,876,157.333 143 | 2018-12-31 14:45:00-05:00,157.42,157.61,157.34,157.44,1626,610,157.511 144 | 2018-12-31 14:50:00-05:00,157.45,157.74,157.44,157.56,2233,1126,157.636 145 | 2018-12-31 14:55:00-05:00,157.55,157.67,157.3,157.37,1931,903,157.517 146 | 2018-12-31 15:00:00-05:00,157.36,157.69,157.25,157.29,2686,1399,157.485 147 | 2018-12-31 15:05:00-05:00,157.31,157.47,157.26,157.29,1841,1014,157.384 148 | 2018-12-31 15:10:00-05:00,157.28,157.37,157.07,157.09,1909,989,157.217 149 | 2018-12-31 15:15:00-05:00,157.1,157.26,156.9,157.26,3223,1568,157.066 150 | 2018-12-31 15:20:00-05:00,157.25,157.36,156.99,157.07,2318,1119,157.187 151 | 2018-12-31 15:25:00-05:00,157.06,157.33,156.97,157.0,3231,1581,157.142 152 | 2018-12-31 15:30:00-05:00,157.0,157.44,157.0,157.3,2822,1376,157.304 153 | 2018-12-31 15:35:00-05:00,157.29,157.34,157.08,157.25,3303,1454,157.205 154 | 2018-12-31 15:40:00-05:00,157.26,157.48,157.14,157.34,4036,2078,157.337 155 | 2018-12-31 15:45:00-05:00,157.35,157.49,157.06,157.15,4345,2222,157.262 156 | 2018-12-31 15:50:00-05:00,157.14,157.27,157.02,157.12,5169,2549,157.148 157 | 2018-12-31 15:55:00-05:00,157.14,157.95,156.48,157.94,14342,8274,157.104 158 | -------------------------------------------------------------------------------- /tests/data/dividends/AAPL.csv: -------------------------------------------------------------------------------- 1 | exDate,paymentDate,recordDate,declaredDate,amount,flag,currency,description,frequency,date 2 | 2015-02-05,2015-02-12,2015-02-09,2015-01-27,0.47,Cash,USD,Ordinary Shares,quarterly,2019-12-07 3 | 2015-05-07,2015-05-14,2015-05-11,2015-04-27,0.52,Cash,USD,Ordinary Shares,quarterly,2019-12-07 4 | 2015-08-06,2015-08-13,2015-08-10,2015-07-21,0.52,Cash,USD,Ordinary Shares,quarterly,2019-12-07 5 | 2015-11-05,2015-11-12,2015-11-09,2015-10-27,0.52,Cash,USD,Ordinary Shares,quarterly,2019-12-07 6 | 2016-02-04,2016-02-11,2016-02-08,2016-01-26,0.52,Cash,USD,Ordinary Shares,quarterly,2019-12-07 7 | 2016-05-05,2016-05-12,2016-05-09,2016-04-26,0.57,Cash,USD,Ordinary Shares,quarterly,2019-12-07 8 | 2016-08-04,2016-08-11,2016-08-08,2016-07-26,0.57,Cash,USD,Ordinary Shares,quarterly,2019-12-07 9 | 2016-11-03,2016-11-10,2016-11-07,2016-10-25,0.57,Cash,USD,Ordinary Shares,quarterly,2019-12-07 10 | 2017-02-09,2017-02-16,2017-02-13,2017-01-31,0.57,Cash,USD,Ordinary Shares,quarterly,2019-12-07 11 | 2017-05-11,2017-05-18,2017-05-15,2017-05-02,0.63,Cash,USD,Ordinary Shares,quarterly,2019-12-07 12 | 2017-08-10,2017-08-17,2017-08-14,2017-08-01,0.63,Cash,USD,Ordinary Shares,quarterly,2019-12-07 13 | 2017-11-10,2017-11-16,2017-11-13,2017-11-02,0.63,Cash,USD,Ordinary Shares,quarterly,2019-12-07 14 | 2018-02-09,2018-02-15,2018-02-12,2018-02-01,0.63,Cash,USD,Ordinary Shares,quarterly,2019-12-07 15 | 2018-05-11,2018-05-17,2018-05-14,2018-05-01,0.73,Cash,USD,Ordinary Shares,quarterly,2019-12-07 16 | 2018-08-10,2018-08-16,2018-08-13,2018-07-31,0.73,Cash,USD,Ordinary Shares,quarterly,2019-12-07 17 | 2018-11-08,2018-11-15,2018-11-12,2018-11-01,0.73,Cash,USD,Ordinary Shares,quarterly,2019-12-07 18 | 2019-02-08,2019-02-14,2019-02-11,2019-01-29,0.73,Cash,USD,Ordinary Shares,quarterly,2019-12-07 19 | 2019-05-10,2019-05-16,2019-05-13,2019-04-30,0.77,Cash,USD,Ordinary Shares,quarterly,2019-12-07 20 | 2019-08-09,2019-08-15,2019-08-12,2019-07-30,0.77,Cash,USD,Ordinary Shares,quarterly,2019-12-07 21 | 
2019-11-07,2019-11-14,2019-11-11,2019-10-30,0.77,Cash,USD,Ordinary Shares,quarterly,2019-12-07 22 | -------------------------------------------------------------------------------- /tests/data/dividends/MSFT.csv: -------------------------------------------------------------------------------- 1 | exDate,paymentDate,recordDate,declaredDate,amount,flag,currency,description,frequency,date 2 | 2015-02-17,2015-03-12,2015-02-19,2014-12-03,0.31,Cash,USD,Ordinary Shares,quarterly,2019-12-07 3 | 2015-05-19,2015-06-11,2015-05-21,2015-03-10,0.31,Cash,USD,Ordinary Shares,quarterly,2019-12-07 4 | 2015-08-18,2015-09-10,2015-08-20,2015-06-09,0.31,Cash,USD,Ordinary Shares,quarterly,2019-12-07 5 | 2015-11-17,2015-12-10,2015-11-19,2015-09-15,0.36,Cash,USD,Ordinary Shares,quarterly,2019-12-07 6 | 2016-02-16,2016-03-10,2016-02-18,2015-12-02,0.36,Cash,USD,Ordinary Shares,quarterly,2019-12-07 7 | 2016-05-17,2016-06-09,2016-05-19,2016-03-15,0.36,Cash,USD,Ordinary Shares,quarterly,2019-12-07 8 | 2016-08-16,2016-09-08,2016-08-18,2016-06-14,0.36,Cash,USD,Ordinary Shares,quarterly,2019-12-07 9 | 2016-11-15,2016-12-08,2016-11-17,2016-09-20,0.39,Cash,USD,Ordinary Shares,quarterly,2019-12-07 10 | 2017-02-14,2017-03-09,2017-02-16,2016-11-30,0.39,Cash,USD,Ordinary Shares,quarterly,2019-12-07 11 | 2017-05-16,2017-06-08,2017-05-18,2017-03-14,0.39,Cash,USD,Ordinary Shares,quarterly,2019-12-07 12 | 2017-08-15,2017-09-14,2017-08-17,2017-06-13,0.39,Cash,USD,Ordinary Shares,quarterly,2019-12-07 13 | 2017-11-15,2017-12-14,2017-11-16,2017-09-19,0.42,Cash,USD,Ordinary Shares,quarterly,2019-12-07 14 | 2018-02-14,2018-03-08,2018-02-15,2017-11-29,0.42,Cash,USD,Ordinary Shares,quarterly,2019-12-07 15 | 2018-05-16,2018-06-14,2018-05-17,2018-03-12,0.42,Cash,USD,Ordinary Shares,quarterly,2019-12-07 16 | 2018-08-15,2018-09-13,2018-08-16,2018-06-13,0.42,Cash,USD,Ordinary Shares,quarterly,2019-12-07 17 | 2018-11-14,2018-12-13,2018-11-15,2018-09-18,0.46,Cash,USD,Ordinary Shares,quarterly,2019-12-07 18 | 2019-01-11,2019-03-14,2019-02-21,2018-11-28,0.46,Cash,USD,Ordinary Shares,quarterly,2019-12-07 19 | 2019-01-11,2019-03-14,2019-02-21,2018-11-28,nan,Cash,USD,Ordinary Shares,quarterly,2019-12-07 20 | 2019-01-11,2019-03-14,2019-02-21,2018-11-28,0.11,Cash,USD,Ordinary Shares,quarterly,2019-12-07 21 | 2019-05-15,2019-06-13,2019-05-16,2019-03-11,0.46,Cash,USD,Ordinary Shares,quarterly,2019-12-07 22 | 2019-08-14,2019-09-12,2019-08-15,2019-06-12,0.46,Cash,USD,Ordinary Shares,quarterly,2019-12-07 23 | 2019-11-20,2019-12-12,2019-11-21,2019-09-18,0.51,Cash,USD,Ordinary Shares,quarterly,2019-12-07 24 | -------------------------------------------------------------------------------- /tests/data/splits/AAPL.csv: -------------------------------------------------------------------------------- 1 | exDate,date 2 | ,2019-12-07 23:30:17.719069 3 | -------------------------------------------------------------------------------- /tests/data/splits/MSFT.csv: -------------------------------------------------------------------------------- 1 | exDate,declaredDate,ratio,toFactor,fromFactor,description,date 2 | 2019-01-14,2017-07-10,5,1,15,1-for-15 Reverse Split,2019-12-07 3 | 2019-01-14,2017-07-10,15,1,15,1-for-15 Reverse Split,2019-12-07 4 | 2019-01-14,2017-05-12,nan,1,15,1-for-15 Reverse Split,2019-12-07 5 | -------------------------------------------------------------------------------- /tests/test_custom_factor.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import spectre 3 | import numpy as np 4 | 
from numpy.testing import assert_array_equal 5 | import torch 6 | import warnings 7 | from os.path import dirname 8 | 9 | data_dir = dirname(__file__) + '/data/' 10 | 11 | 12 | class TestCustomFactorLib(unittest.TestCase): 13 | 14 | def test_custom_factor(self): 15 | warnings.filterwarnings("ignore", module='spectre') 16 | # test backward tree 17 | a = spectre.factors.CustomFactor(win=2) 18 | b = spectre.factors.CustomFactor(win=3, inputs=(a,)) 19 | c = spectre.factors.CustomFactor(win=3, inputs=(b,)) 20 | self.assertEqual(5, c.get_total_backwards_()) 21 | m = spectre.factors.CustomFactor(win=10) 22 | c.set_mask(m) 23 | self.assertEqual(9, c.get_total_backwards_()) 24 | 25 | a1 = spectre.factors.CustomFactor(win=10) 26 | a2 = spectre.factors.CustomFactor(win=5) 27 | b1 = spectre.factors.CustomFactor(win=20, inputs=(a1, a2)) 28 | b2 = spectre.factors.CustomFactor(win=100, inputs=(a2,)) 29 | c1 = spectre.factors.CustomFactor(win=100, inputs=(b1,)) 30 | self.assertEqual(9, a1.get_total_backwards_()) 31 | self.assertEqual(4, a2.get_total_backwards_()) 32 | self.assertEqual(28, b1.get_total_backwards_()) 33 | self.assertEqual(103, b2.get_total_backwards_()) 34 | self.assertEqual(127, c1.get_total_backwards_()) 35 | 36 | # test inheritance 37 | loader = spectre.data.CsvDirLoader( 38 | data_dir + '/daily/', ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'), 39 | prices_index='date', parse_dates=True, 40 | ) 41 | engine = spectre.factors.FactorEngine(loader) 42 | 43 | class TestFactor(spectre.factors.CustomFactor): 44 | inputs = [spectre.factors.OHLCV.open] 45 | 46 | def compute(self, close): 47 | return torch.tensor(np.arange(close.nelement()).reshape(close.shape)) 48 | 49 | class TestFactor2(spectre.factors.CustomFactor): 50 | inputs = [] 51 | 52 | def compute(self): 53 | return torch.tensor([1]) 54 | 55 | engine.add(TestFactor2(), 'test2') 56 | self.assertRaisesRegex(ValueError, "The return data shape.*test2.*", 57 | engine.run, '2019-01-11', '2019-01-15', False) 58 | engine.remove_all_factors() 59 | test_f1 = TestFactor() 60 | 61 | class TestFactor2(spectre.factors.CustomFactor): 62 | inputs = [test_f1] 63 | 64 | def compute(self, test_input): 65 | return torch.tensor(np.cumsum(test_input.numpy(), axis=1)) 66 | 67 | engine.add(test_f1, 'test1') 68 | self.assertRaisesRegex(KeyError, ".*exists.*", 69 | engine.add, TestFactor(), 'test1') 70 | 71 | engine.add(TestFactor2(), 'test2') 72 | 73 | for f in engine._factors.values(): 74 | f.pre_compute_(engine, '2019-01-11', '2019-01-15') 75 | self.assertEqual(2, test_f1._ref_count) 76 | for f in engine._factors.values(): 77 | f._ref_count = 0 78 | 79 | df = engine.run('2019-01-11', '2019-01-15', delay_factor=False) 80 | self.assertEqual(0, test_f1._ref_count) 81 | assert_array_equal([0, 3, 1, 4, 2, 5], df['test1'].values) 82 | assert_array_equal([0, 3, 1, 7, 3, 12], df['test2'].values) 83 | -------------------------------------------------------------------------------- /tests/test_data_factor.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import spectre 3 | from numpy.testing import assert_array_equal 4 | from os.path import dirname 5 | import warnings 6 | 7 | data_dir = dirname(__file__) + '/data/' 8 | 9 | 10 | class TestDataFactorLib(unittest.TestCase): 11 | def test_datafactor_value(self): 12 | warnings.filterwarnings("ignore", module='spectre') 13 | loader = spectre.data.CsvDirLoader( 14 | data_dir + '/daily/', 15 | ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'), 16 | 
prices_index='date', parse_dates=True, 17 | ) 18 | engine = spectre.factors.FactorEngine(loader) 19 | engine.add(spectre.factors.OHLCV.volume, 'CpVol') 20 | df = engine.run('2019-01-11', '2019-01-15') 21 | assert_array_equal(df.loc[(slice(None), 'AAPL'), 'CpVol'].values, 22 | (28065422, 33834032)) 23 | assert_array_equal(df.loc[(slice(None), 'MSFT'), 'CpVol'].values, 24 | (28627674, 28720936)) 25 | 26 | engine.add(spectre.factors.ColumnDataFactor(inputs=('changePercent',)), 'Chg') 27 | df = engine.run('2019-01-11', '2019-01-15') 28 | assert_array_equal(df.loc[(slice(None), 'AAPL'), 'Chg'].values, 29 | (-0.9835, -1.5724)) 30 | assert_array_equal(df.loc[(slice(None), 'MSFT'), 'Chg'].values, 31 | (-0.8025, -0.7489)) 32 | 33 | engine.remove_all_factors() 34 | engine.add(spectre.factors.OHLCV.open, 'open') 35 | df = engine.run('2019-01-11', '2019-01-15', delay_factor=False) 36 | assert_array_equal(df.loc[(slice(None), 'AAPL'), 'open'].values, 37 | (155.72, 155.19, 150.81)) 38 | assert_array_equal(df.loc[(slice(None), 'MSFT'), 'open'].values, 39 | (104.65, 104.9, 103.19)) 40 | 41 | df = engine.run('2019-01-11', '2019-01-15') 42 | assert_array_equal(df.loc[(slice(None), 'AAPL'), 'open'].values, 43 | (155.72, 155.19, 150.81)) 44 | assert_array_equal(df.loc[(slice(None), 'MSFT'), 'open'].values, 45 | (104.65, 104.9, 103.19)) -------------------------------------------------------------------------------- /tests/test_data_loader.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import spectre 3 | import os 4 | import pandas as pd 5 | import numpy as np 6 | from numpy.testing import assert_almost_equal 7 | from os.path import dirname 8 | import warnings 9 | 10 | data_dir = dirname(__file__) + '/data/' 11 | 12 | 13 | class TestDataLoaderLib(unittest.TestCase): 14 | def _assertDFFirstLastEqual(self, tdf, col, expected_first, expected_last): 15 | self.assertAlmostEqual(tdf.loc[tdf.index[0], col], expected_first) 16 | self.assertAlmostEqual(tdf.loc[tdf.index[-1], col], expected_last) 17 | 18 | def test_required_parameters(self): 19 | loader = spectre.data.CsvDirLoader(data_dir + '/daily/') 20 | self.assertRaisesRegex(ValueError, "df must index by datetime.*", 21 | loader.load, '2019-01-01', '2019-01-15', 0) 22 | loader = spectre.data.CsvDirLoader(data_dir + '/daily/', prices_index='date', ) 23 | self.assertRaisesRegex(ValueError, "df must index by datetime.*", 24 | loader.load, '2019-01-01', '2019-01-15', 0) 25 | 26 | def test_csv_loader_value(self): 27 | loader = spectre.data.CsvDirLoader( 28 | data_dir + '/daily/', calender_asset='AAPL', prices_index='date', parse_dates=True, ) 29 | start, end = pd.Timestamp('2019-01-01', tz='UTC'), pd.Timestamp('2019-01-15', tz='UTC') 30 | 31 | # test backward 32 | df = loader.load(start, end, 11) 33 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'AAPL'), :], 'close', 173.43, 158.09) 34 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'MSFT'), :], 'close', 106.57, 105.36) 35 | 36 | # test value 37 | df = loader.load(start, end, 0) 38 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'AAPL'), :], 'close', 160.35, 158.09) 39 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'MSFT'), :], 'close', 100.1, 105.36) 40 | self._assertDFFirstLastEqual(df.loc[(slice('2019-01-11', '2019-01-12'), 'MSFT'), :], 41 | 'close', 104.5, 104.5) 42 | start, end = pd.Timestamp('2019-01-11', tz='UTC'), pd.Timestamp('2019-01-12', tz='UTC') 43 | df = loader.load(start, end, 0) 44 | 
self._assertDFFirstLastEqual(df.loc[(slice(None), 'MSFT'), :], 'close', 104.5, 104.5) 45 | 46 | loader.test_load() 47 | 48 | def test_csv_split_loader_value(self): 49 | loader = spectre.data.CsvDirLoader( 50 | data_dir + '/5mins/', prices_by_year=True, prices_index='Date', parse_dates=True, 51 | ohlcv=None) 52 | start = pd.Timestamp('2019-01-02 14:30:00', tz='UTC') 53 | end = pd.Timestamp('2019-01-15', tz='UTC') 54 | loader.load(start, end, 0) 55 | 56 | start = pd.Timestamp('2018-12-31 14:50:00', tz='America/New_York').tz_convert('UTC') 57 | end = pd.Timestamp('2019-01-02 10:00:00', tz='America/New_York').tz_convert('UTC') 58 | df = loader.load(start, end, 0) 59 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'AAPL'), :], 'Open', 157.45, 155.17) 60 | self._assertDFFirstLastEqual(df.loc[(slice(None), 'MSFT'), :], 'Open', 101.44, 99.55) 61 | 62 | loader.test_load() 63 | 64 | def test_csv_div_split(self): 65 | warnings.filterwarnings("ignore", module='spectre') 66 | start, end = pd.Timestamp('2019-01-02', tz='UTC'), pd.Timestamp('2019-01-15', tz='UTC') 67 | loader = spectre.data.CsvDirLoader( 68 | prices_path=data_dir + '/daily/', earliest_date=start.tz_convert(None), 69 | calender_asset='AAPL', 70 | dividends_path=data_dir + '/dividends/', splits_path=data_dir + '/splits/', 71 | ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'), adjustments=('amount', 'ratio'), 72 | prices_index='date', dividends_index='exDate', splits_index='exDate', 73 | parse_dates=True, ) 74 | loader.test_load() 75 | 76 | df = loader.load(start, end, 0) 77 | 78 | # test value 79 | self.assertAlmostEqual(df.loc[('2019-01-09', 'MSFT'), 'ex-dividend'], 0.57) 80 | 81 | # test adjustments in engine 82 | engine = spectre.factors.FactorEngine(loader) 83 | engine.add(spectre.factors.AdjustedColumnDataFactor(spectre.factors.OHLCV.volume), 'vol') 84 | engine.add(spectre.factors.AdjustedColumnDataFactor(spectre.factors.OHLCV.open), 'open') 85 | df = engine.run(start, end, delay_factor=False) 86 | 87 | expected_msft_open = [1526.24849, 1548.329113, 1536.244448, 1541.16783, 1563.696033, 88 | 1585.47827, 1569.750105, 104.9, 103.19] 89 | expected_msft_vol = [2947962.0000, 3067160.6000, 2443784.2667, 2176777.6000, 90 | 2190846.8000, 2018093.5333, 1908511.6000, 28720936.0000, 32882983.0000] 91 | expected_aapl_open = [155.9200, 147.6300, 148.8400, 148.9000, 150.0000, 157.4400, 154.1000, 92 | 155.7200, 155.1900, 150.8100] 93 | expected_aapl_vol = [37932561, 92707401, 59457561, 56974905, 42839940, 45105063, 94 | 35793075, 28065422, 33834032, 29426699] 95 | 96 | assert_almost_equal(df.loc[(slice(None), 'MSFT'), 'open'], expected_msft_open, decimal=4) 97 | assert_almost_equal(df.loc[(slice(None), 'AAPL'), 'open'], expected_aapl_open, decimal=4) 98 | assert_almost_equal(df.loc[(slice(None), 'MSFT'), 'vol'], expected_msft_vol, decimal=0) 99 | assert_almost_equal(df.loc[(slice(None), 'AAPL'), 'vol'], expected_aapl_vol, decimal=4) 100 | 101 | # rolling adj test 102 | result = [] 103 | 104 | class RollingAdjTest(spectre.factors.CustomFactor): 105 | win = 10 106 | 107 | def compute(self, data): 108 | result.append(data.agg(lambda x: x[:, -1])) 109 | return data.last() 110 | 111 | engine = spectre.factors.FactorEngine(loader) 112 | engine.add(RollingAdjTest(inputs=[spectre.factors.OHLCV.volume]), 'vol') 113 | engine.add(RollingAdjTest(inputs=[spectre.factors.OHLCV.open]), 'open') 114 | engine.run(end, end, delay_factor=False) 115 | 116 | assert_almost_equal(result[0][0], expected_aapl_vol, decimal=4) 117 | assert_almost_equal(result[0][1], 
expected_msft_vol+[np.nan], decimal=0) 118 | assert_almost_equal(result[1][0], expected_aapl_open, decimal=4) 119 | assert_almost_equal(result[1][1], expected_msft_open+[np.nan], decimal=4) 120 | 121 | def test_no_ohlcv(self): 122 | warnings.filterwarnings("ignore", module='spectre') 123 | start, end = pd.Timestamp('2019-01-02'), pd.Timestamp('2019-01-15') 124 | loader = spectre.data.CsvDirLoader( 125 | prices_path=data_dir + '/daily/', earliest_date=start, calender_asset='AAPL', 126 | ohlcv=None, adjustments=None, 127 | prices_index='date', 128 | parse_dates=True, ) 129 | engine = spectre.factors.FactorEngine(loader) 130 | engine.add(spectre.factors.ColumnDataFactor(inputs=['uOpen']), 'open') 131 | engine.run(start, end, delay_factor=False) 132 | 133 | @unittest.skipUnless(os.getenv('COVERAGE_RUNNING'), "too slow, run manually") 134 | def test_yahoo(self): 135 | yahoo_path = data_dir + '/yahoo/' 136 | try: 137 | os.remove(yahoo_path + 'yahoo.feather') 138 | os.remove(yahoo_path + 'yahoo.feather.meta') 139 | except FileNotFoundError: 140 | pass 141 | 142 | spectre.data.YahooDownloader.ingest("2011", yahoo_path, ['IBM', 'AAPL'], skip_exists=False) 143 | loader = spectre.data.ArrowLoader(yahoo_path + 'yahoo.feather') 144 | df = loader._load() 145 | self.assertEqual(['AAPL', 'IBM'], list(df.index.levels[1])) 146 | 147 | @unittest.skipUnless(os.getenv('COVERAGE_RUNNING'), "too slow, run manually") 148 | def test_QuandlLoader(self): 149 | quandl_path = data_dir + '../../../historical_data/us/prices/quandl/' 150 | try: 151 | os.remove(quandl_path + 'wiki_prices.feather') 152 | os.remove(quandl_path + 'wiki_prices.feather.meta') 153 | except FileNotFoundError: 154 | pass 155 | 156 | spectre.data.ArrowLoader.ingest( 157 | spectre.data.QuandlLoader(quandl_path + 'WIKI_PRICES.zip'), 158 | quandl_path + 'wiki_prices.feather' 159 | ) 160 | 161 | loader = spectre.data.ArrowLoader(quandl_path + 'wiki_prices.feather') 162 | 163 | spectre.parallel.Rolling._split_multi = 80 164 | engine = spectre.factors.FactorEngine(loader) 165 | engine.add(spectre.factors.MA(100), 'ma') 166 | engine.to_cuda() 167 | df = engine.run("2014-01-02", "2014-01-02", delay_factor=False) 168 | # expected result comes from zipline 169 | assert_almost_equal(df.head().values.T, 170 | [[51.388700, 49.194407, 599.280580, 28.336585, 12.7058]], decimal=4) 171 | assert_almost_equal(df.tail().values.T, 172 | [[86.087988, 3.602880, 7.364000, 31.428209, 27.605950]], decimal=4) 173 | 174 | # test last line bug 175 | engine.run("2016-12-15", "2017-01-02") 176 | df = engine._dataframe.loc[(slice('2016-12-15', '2017-12-15'), 'STJ'), :] 177 | assert df.price_multi.values[-1] == 1 178 | 179 | def test_fast_get(self): 180 | loader = spectre.data.CsvDirLoader( 181 | data_dir + '/daily/', prices_index='date', parse_dates=True, ) 182 | df = loader.load()[list(loader.ohlcv)] 183 | getter = spectre.data.DataLoaderFastGetter(df) 184 | 185 | table = getter.get_as_dict(pd.Timestamp('2018-01-02', tz='UTC'), column_id=3) 186 | self.assertAlmostEqual(df.loc[("2018-01-02", 'MSFT')].close, table['MSFT']) 187 | self.assertAlmostEqual(df.loc[("2018-01-02", 'AAPL')].close, table['AAPL']) 188 | self.assertRaises(KeyError, table.__getitem__, 'A') 189 | table = dict(table.items()) 190 | self.assertAlmostEqual(df.loc[("2018-01-02", 'MSFT')].close, table['MSFT']) 191 | self.assertAlmostEqual(df.loc[("2018-01-02", 'AAPL')].close, table['AAPL']) 192 | 193 | table = getter.get_as_dict(pd.Timestamp('2018-01-02', tz='UTC')) 194 | 
np.testing.assert_array_almost_equal(df.loc[("2018-01-02", 'MSFT')].values, table['MSFT']) 195 | np.testing.assert_array_almost_equal(df.loc[("2018-01-02", 'AAPL')].values, table['AAPL']) 196 | 197 | result_df = getter.get_as_df(pd.Timestamp('2018-01-02', tz='UTC')) 198 | expected = df.xs("2018-01-02") 199 | pd.testing.assert_frame_equal(expected, result_df) 200 | 201 | table = getter.get_as_dict(pd.Timestamp('2019-01-05', tz='UTC'), column_id=3) 202 | self.assertTrue(np.isnan(table['MSFT'])) 203 | self.assertRaises(KeyError, table.__getitem__, 'AAPL') 204 | 205 | table = getter.get_as_dict(pd.Timestamp('2019-01-10', tz='UTC'), column_id=3) 206 | self.assertRaises(KeyError, table.__getitem__, 'MSFT') 207 | 208 | # test 5mins 209 | loader = spectre.data.CsvDirLoader( 210 | data_dir + '/5mins/', prices_by_year=True, prices_index='Date', parse_dates=True, 211 | ohlcv=None) 212 | df = loader.load() 213 | getter = spectre.data.DataLoaderFastGetter(df) 214 | table = getter.get_as_dict( 215 | pd.Timestamp('2018-12-20 00:00:00+00:00', tz='UTC'), 216 | pd.Timestamp('2018-12-20 23:59:59+00:00', tz='UTC'), 217 | column_id=0) 218 | self.assertTrue(len(table.get_datetime_index().normalize().unique()) == 1) 219 | -------------------------------------------------------------------------------- /tests/test_event.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import spectre 3 | from os.path import dirname 4 | import pandas as pd 5 | 6 | data_dir = dirname(__file__) + '/data/' 7 | 8 | 9 | class TestTradingEvent(unittest.TestCase): 10 | 11 | def test_event_mgr(self): 12 | class TestEventReceiver(spectre.trading.EventReceiver): 13 | fired = 0 14 | 15 | def on_run(self): 16 | self.schedule(spectre.trading.event.Always(self.test_always)) 17 | self.schedule(spectre.trading.event.EveryBarData(self.test_every_bar)) 18 | 19 | def test_always(self, _): 20 | self.fired += 1 21 | 22 | def test_every_bar(self, _): 23 | self.fired += 1 24 | if self.fired == 2: 25 | self.stop_event_manager() 26 | 27 | class StopFirer(spectre.trading.EventReceiver): 28 | def on_run(self): 29 | self.schedule(spectre.trading.event.Always(self.test)) 30 | 31 | def test(self, _): 32 | self.fire_event(spectre.trading.event.EveryBarData) 33 | 34 | rcv = TestEventReceiver() 35 | 36 | evt_mgr = spectre.trading.EventManager() 37 | evt_mgr.subscribe(rcv) 38 | rcv.unsubscribe() 39 | self.assertEqual(0, len(evt_mgr._subscribers)) 40 | self.assertRaisesRegex(ValueError, 'At least one subscriber.', evt_mgr.run) 41 | self.assertRaisesRegex(AssertionError, 'Subscriber not exists', evt_mgr.unsubscribe, rcv) 42 | 43 | evt_mgr.subscribe(rcv) 44 | evt_mgr.subscribe(StopFirer()) 45 | self.assertRaisesRegex(AssertionError, 'Duplicate subscribe', evt_mgr.subscribe, rcv) 46 | 47 | evt_mgr.run() 48 | self.assertEqual(2, rcv.fired) 49 | 50 | def test_calendar(self): 51 | tz = 'America/New_York' 52 | end = pd.Timestamp.now(tz=tz) + pd.DateOffset(days=10) 53 | first = pd.date_range(pd.Timestamp.now(tz=tz).normalize(), end, freq='B')[0] 54 | if pd.Timestamp.now(tz=tz) > (first + pd.Timedelta("9:00:00")): 55 | first = first + pd.offsets.BDay(1) 56 | holiday = first + pd.offsets.BDay(2) 57 | test_now = first + pd.offsets.BDay(1) + pd.Timedelta("10:00:00") 58 | 59 | calendar = spectre.trading.Calendar() 60 | calendar.build(start=str(pd.Timestamp.now(tz=tz).normalize()), end=str(end.date()), 61 | daily_events={'Open': '9:00:00', 'Close': '15:00:00'}, 62 | tz=tz) 63 | calendar.set_as_holiday(holiday) 64 | 65 | 
--------------------------------------------------------------------------------
/tests/test_event.py:
--------------------------------------------------------------------------------
import unittest
import spectre
from os.path import dirname
import pandas as pd

data_dir = dirname(__file__) + '/data/'


class TestTradingEvent(unittest.TestCase):

    def test_event_mgr(self):
        class TestEventReceiver(spectre.trading.EventReceiver):
            fired = 0

            def on_run(self):
                self.schedule(spectre.trading.event.Always(self.test_always))
                self.schedule(spectre.trading.event.EveryBarData(self.test_every_bar))

            def test_always(self, _):
                self.fired += 1

            def test_every_bar(self, _):
                self.fired += 1
                if self.fired == 2:
                    self.stop_event_manager()

        class StopFirer(spectre.trading.EventReceiver):
            def on_run(self):
                self.schedule(spectre.trading.event.Always(self.test))

            def test(self, _):
                self.fire_event(spectre.trading.event.EveryBarData)

        rcv = TestEventReceiver()

        evt_mgr = spectre.trading.EventManager()
        evt_mgr.subscribe(rcv)
        rcv.unsubscribe()
        self.assertEqual(0, len(evt_mgr._subscribers))
        self.assertRaisesRegex(ValueError, 'At least one subscriber.', evt_mgr.run)
        self.assertRaisesRegex(AssertionError, 'Subscriber not exists', evt_mgr.unsubscribe, rcv)

        evt_mgr.subscribe(rcv)
        evt_mgr.subscribe(StopFirer())
        self.assertRaisesRegex(AssertionError, 'Duplicate subscribe', evt_mgr.subscribe, rcv)

        evt_mgr.run()
        self.assertEqual(2, rcv.fired)

    def test_calendar(self):
        tz = 'America/New_York'
        end = pd.Timestamp.now(tz=tz) + pd.DateOffset(days=10)
        first = pd.date_range(pd.Timestamp.now(tz=tz).normalize(), end, freq='B')[0]
        if pd.Timestamp.now(tz=tz) > (first + pd.Timedelta("9:00:00")):
            first = first + pd.offsets.BDay(1)
        holiday = first + pd.offsets.BDay(2)
        test_now = first + pd.offsets.BDay(1) + pd.Timedelta("10:00:00")

        calendar = spectre.trading.Calendar()
        calendar.build(start=str(pd.Timestamp.now(tz=tz).normalize()), end=str(end.date()),
                       daily_events={'Open': '9:00:00', 'Close': '15:00:00'},
                       tz=tz)
        calendar.set_as_holiday(holiday)

        self.assertEqual(first + pd.Timedelta("9:00:00"),
                         calendar.events['Open'][0])

        # freeze "now" so pop_passed is deterministic
        calendar.hr_now = lambda: test_now

        calendar.pop_passed('Open')

        # the next Open skips the holiday
        self.assertEqual(test_now.normalize() + pd.offsets.BDay(2) + pd.Timedelta("9:00:00"),
                         calendar.events['Open'][0])

        # building with an end date before the start date should raise
        self.assertRaises(ValueError, calendar.build,
                          start=str(pd.Timestamp.now(tz=tz).normalize()), end='2019',
                          daily_events={'Open': '9:00:00', 'Close': '15:00:00'},
                          tz=tz)
--------------------------------------------------------------------------------
/tests/test_metric.py:
--------------------------------------------------------------------------------
import unittest
import spectre
import pandas as pd
from numpy import nan
from numpy.testing import assert_almost_equal


class TestMetric(unittest.TestCase):

    def test_metrics(self):
        ret = pd.Series([0.0022, 0.0090, -0.0067, 0.0052, 0.0030, -0.0012, -0.0091, 0.0082,
                         -0.0071, 0.0093],
                        index=pd.date_range('2040-01-01', periods=10))

        self.assertAlmostEqual(2.9101144, spectre.trading.sharpe_ratio(ret, 0.00))
        self.assertAlmostEqual(2.5492371, spectre.trading.sharpe_ratio(ret, 0.04))

        dd, ddu = spectre.trading.drawdown((ret + 1).cumprod())
        vol = spectre.trading.annual_volatility(ret)

        self.assertAlmostEqual(0.0102891, dd.abs().max())
        self.assertAlmostEqual(5, ddu.max())
        self.assertAlmostEqual(0.110841, vol)

        txn = pd.DataFrame([['AAPL', 384, 155.92, 157.09960, 1.92],
                            ['AAPL', -384, 158.61, 157.41695, 1.92]],
                           columns=['symbol', 'amount', 'price',
                                    'fill_price', 'commission'],
                           index=['2040-01-01', '2040-01-02'])
        txn.index = pd.to_datetime(txn.index)

        pos = pd.DataFrame([[384, 384 * 155.92, 10000],
                            [nan, nan, 10000 + 384 * 158.61]],
                           columns=pd.MultiIndex.from_tuples(
                               [('shares', 'AAPL'), ('value', 'AAPL'), ('value', 'cash')]),
                           index=['2040-01-01', '2040-01-02'])
        pos.index = pd.to_datetime(pos.index)
        assert_almost_equal([0.8633665, 0.8525076], spectre.trading.turnover(pos, txn))
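
The constants asserted in test_metrics follow from the usual annualization conventions. A worked sketch that reproduces them, assuming sqrt(252) annualization, sample standard deviation (ddof=1), and a linearly de-annualized risk-free rate; this is inferred from the asserted values, not taken from spectre's source:

import numpy as np
import pandas as pd

ret = pd.Series([0.0022, 0.0090, -0.0067, 0.0052, 0.0030, -0.0012, -0.0091, 0.0082,
                 -0.0071, 0.0093])


def sharpe(returns: pd.Series, annual_rf: float) -> float:
    # Daily excess return over a linearly de-annualized risk-free rate,
    # annualized by sqrt(252); a constant rf shifts the mean but not the std.
    excess = returns - annual_rf / 252
    return excess.mean() / excess.std(ddof=1) * np.sqrt(252)


print(sharpe(ret, 0.00))                # ~2.9101144
print(sharpe(ret, 0.04))                # ~2.5492371
print(ret.std(ddof=1) * np.sqrt(252))   # annual volatility, ~0.110841
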
"spectre.parallel.Rolling object(.|\n)*tensor(.|\n)*") 32 | s = spectre.parallel.Rolling(x, 3).sum() 33 | assert_almost_equal(expected.numpy(), s.numpy(), decimal=4) 34 | 35 | # test adjustment 36 | y = torch.tensor([[0.25, 0.25, 0.5, 1], 37 | [0.6, 0.75, 0.75, 1]]) 38 | s = spectre.parallel.Rolling(x, 3, y).sum() 39 | expected = torch.tensor([ 40 | [ 41 | np.nan, np.nan, 42 | sum([164.0000 / 2, 163.7100 / 2, 158.6100]), 43 | sum([163.7100 / 4, 158.6100 / 2, 145.230]), 44 | ], 45 | [ 46 | np.nan, np.nan, 47 | sum([104.6100 * (0.6 / 0.75), 104.4200, 101.3000]), 48 | sum([104.4200 * 0.75, 101.3000 * 0.75, 102.280]), 49 | ] 50 | ]) 51 | assert_almost_equal(expected.numpy(), s.numpy(), decimal=4) 52 | 53 | x = torch.zeros([1024, 102400], dtype=torch.float64) 54 | spectre.parallel.Rolling(x, 252).sum() 55 | 56 | def test_nan(self): 57 | # dim=1 58 | data = [[1, 2, 1], [4, np.nan, 2], [7, 8, 1]] 59 | result = spectre.parallel.nanmean(torch.tensor(data, dtype=torch.float)) 60 | expected = np.nanmean(data, axis=1) 61 | assert_almost_equal(expected, result, decimal=6) 62 | 63 | result = spectre.parallel.nanstd(torch.tensor(data, dtype=torch.float)) 64 | expected = np.nanstd(data, axis=1) 65 | assert_almost_equal(expected, result, decimal=6) 66 | 67 | result = spectre.parallel.nanstd(torch.tensor(data, dtype=torch.float), ddof=1) 68 | expected = np.nanstd(data, axis=1, ddof=1) 69 | assert_almost_equal(expected, result, decimal=6) 70 | 71 | # dim=2 72 | data = [[[np.nan, 1, 2], [1, 2, 1]], [[np.nan, 4, np.nan], [4, np.nan, 2]], 73 | [[np.nan, 7, 8], [7, 8, 1]]] 74 | result = spectre.parallel.nanmean(torch.tensor(data, dtype=torch.float), dim=2) 75 | expected = np.nanmean(data, axis=2) 76 | assert_almost_equal(expected, result, decimal=6) 77 | 78 | result = spectre.parallel.nanstd(torch.tensor(data, dtype=torch.float), dim=2) 79 | expected = np.nanstd(data, axis=2) 80 | assert_almost_equal(expected, result, decimal=6) 81 | 82 | # last 83 | data = [[1, 2, np.nan], [4, np.nan, 2], [7, 8, 1]] 84 | result = spectre.parallel.nanlast(torch.tensor(data, dtype=torch.float).cuda()) 85 | expected = [2., 2., 1.] 86 | assert_almost_equal(expected, result.cpu(), decimal=6) 87 | 88 | data = [[[1, 2, np.nan], [4, np.nan, 2], [7, 8, 1]]] 89 | result = spectre.parallel.nanlast(torch.tensor(data, dtype=torch.float).cuda(), dim=2) 90 | expected = [[2., 2., 1.]] 91 | assert_almost_equal(expected, result.cpu(), decimal=6) 92 | 93 | data = [1, 2, np.nan, 4, np.nan, 2, 7, 8, 1] 94 | result = spectre.parallel.nanlast(torch.tensor(data, dtype=torch.float), dim=0) 95 | expected = [1.] 96 | assert_almost_equal(expected, result, decimal=6) 97 | 98 | data = [1, 2, np.nan] 99 | mask = [False, True, True] 100 | result = spectre.parallel.masked_first( 101 | torch.tensor(data, dtype=torch.float), torch.tensor(mask, dtype=torch.bool), dim=0) 102 | expected = [2.] 
    def test_nan(self):
        # dim=1
        data = [[1, 2, 1], [4, np.nan, 2], [7, 8, 1]]
        result = spectre.parallel.nanmean(torch.tensor(data, dtype=torch.float))
        expected = np.nanmean(data, axis=1)
        assert_almost_equal(expected, result, decimal=6)

        result = spectre.parallel.nanstd(torch.tensor(data, dtype=torch.float))
        expected = np.nanstd(data, axis=1)
        assert_almost_equal(expected, result, decimal=6)

        result = spectre.parallel.nanstd(torch.tensor(data, dtype=torch.float), ddof=1)
        expected = np.nanstd(data, axis=1, ddof=1)
        assert_almost_equal(expected, result, decimal=6)

        # dim=2
        data = [[[np.nan, 1, 2], [1, 2, 1]], [[np.nan, 4, np.nan], [4, np.nan, 2]],
                [[np.nan, 7, 8], [7, 8, 1]]]
        result = spectre.parallel.nanmean(torch.tensor(data, dtype=torch.float), dim=2)
        expected = np.nanmean(data, axis=2)
        assert_almost_equal(expected, result, decimal=6)

        result = spectre.parallel.nanstd(torch.tensor(data, dtype=torch.float), dim=2)
        expected = np.nanstd(data, axis=2)
        assert_almost_equal(expected, result, decimal=6)

        # last
        data = [[1, 2, np.nan], [4, np.nan, 2], [7, 8, 1]]
        result = spectre.parallel.nanlast(torch.tensor(data, dtype=torch.float).cuda())
        expected = [2., 2., 1.]
        assert_almost_equal(expected, result.cpu(), decimal=6)

        data = [[[1, 2, np.nan], [4, np.nan, 2], [7, 8, 1]]]
        result = spectre.parallel.nanlast(torch.tensor(data, dtype=torch.float).cuda(), dim=2)
        expected = [[2., 2., 1.]]
        assert_almost_equal(expected, result.cpu(), decimal=6)

        data = [1, 2, np.nan, 4, np.nan, 2, 7, 8, 1]
        result = spectre.parallel.nanlast(torch.tensor(data, dtype=torch.float), dim=0)
        expected = [1.]
        assert_almost_equal(expected, result, decimal=6)

        data = [1, 2, np.nan]
        mask = [False, True, True]
        result = spectre.parallel.masked_first(
            torch.tensor(data, dtype=torch.float), torch.tensor(mask, dtype=torch.bool), dim=0)
        expected = [2.]
        assert_almost_equal(expected, result, decimal=6)

        mask = [False, False, True]
        result = spectre.parallel.masked_first(
            torch.tensor(data, dtype=torch.float), torch.tensor(mask, dtype=torch.bool), dim=0)
        expected = [np.nan]
        assert_almost_equal(expected, result, decimal=6)

        # nanmin/max
        data = [[1, 2, -14, np.nan, 2], [99999, 8, 1, np.nan, 2]]
        result = spectre.parallel.nanmax(torch.tensor(data, dtype=torch.float))
        expected = np.nanmax(data, axis=1)
        assert_almost_equal(expected, result, decimal=6)

        result = spectre.parallel.nanmin(torch.tensor(data, dtype=torch.float))
        expected = np.nanmin(data, axis=1)
        assert_almost_equal(expected, result, decimal=6)
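
    # The NaN-aware reducers above can be reproduced in plain PyTorch by
    # masking: zero out NaNs, sum, and divide by the non-NaN count. An
    # illustrative sketch of the idea, not spectre's actual implementation:
    def test_nanmean_masking_sketch(self):
        data = [[1., 2., 1.], [4., np.nan, 2.], [7., 8., 1.]]
        t = torch.tensor(data)
        mask = ~torch.isnan(t)
        total = torch.where(mask, t, torch.zeros_like(t)).sum(dim=1)
        assert_almost_equal(np.nanmean(data, axis=1), total / mask.sum(dim=1), decimal=6)
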
    def test_stat(self):
        x = torch.tensor([[1., 2, 3, 4, 5], [10, 12, 13, 14, 16], [2, 2, 2, 2, 2]])
        y = torch.tensor([[-1., 2, 3, 4, -5], [11, 12, -13, 14, 15], [2, 2, 2, 2, 2]])
        result = spectre.parallel.covariance(x, y, ddof=1)
        expected = np.cov(x, y, ddof=1)
        expected = expected[:x.shape[0], x.shape[0]:]
        assert_almost_equal(np.diag(expected), result, decimal=6)

        coef, intcp = spectre.parallel.linear_regression_1d(x, y)
        from sklearn.linear_model import LinearRegression
        for i in range(3):
            reg = LinearRegression().fit(x[i, :, None], y[i, :, None])
            assert_almost_equal(reg.coef_, coef[i], decimal=6)

        # test pearsonr
        result = spectre.parallel.pearsonr(x, y)
        from scipy import stats
        for i in range(3):
            expected, _ = stats.pearsonr(x[i].tolist(), y[i].tolist())
            assert_almost_equal(expected, result[i], decimal=6)

        # test spearman
        rank_x = spectre.parallel.rankdata(x)
        rank_y = spectre.parallel.rankdata(y)
        result = spectre.parallel.spearman(rank_x, rank_y)
        for i in range(3):
            expected, _ = stats.spearmanr(x[i].tolist(), y[i].tolist())
            if expected != expected:  # NaN check: scipy yields NaN for the constant row
                expected = 1
            assert_almost_equal(expected, result[i], decimal=6)

        # test quantile
        x = torch.tensor([[1, 2, np.nan, 3, 4, 5, 6], [3, 4, 5, 1.01, np.nan, 1.02, 1.03]])
        result = spectre.parallel.quantile(x, 5, dim=1)
        expected = pd.qcut(x[0].tolist(), 5, labels=False)
        assert_array_equal(expected, result[0])
        expected = pd.qcut(x[1].tolist(), 5, labels=False)
        assert_array_equal(expected, result[1])

        x = torch.tensor(
            [[[1, 2, np.nan, 3, 4, 5, 6],
              [3, 4, 5, 1.01, np.nan, 1.02, 1.03]],
             [[1, 2, 2.1, 3, 4, 5, 6],
              [3, 4, 5, np.nan, np.nan, 1.02, 1.03]]])
        result = spectre.parallel.quantile(x, 5, dim=2)
        expected = pd.qcut(x[0, 0].tolist(), 5, labels=False)
        assert_array_equal(expected, result[0, 0])
        expected = pd.qcut(x[0, 1].tolist(), 5, labels=False)
        assert_array_equal(expected, result[0, 1])
        expected = pd.qcut(x[1, 0].tolist(), 5, labels=False)
        assert_array_equal(expected, result[1, 0])
        expected = pd.qcut(x[1, 1].tolist(), 5, labels=False)
        assert_array_equal(expected, result[1, 1])

        # test squeeze bug
        x = torch.tensor([[1., 2, 3, 4, 5]])
        y = torch.tensor([[-1., 2, 3, 4, -5]])
        coef, intcp = spectre.parallel.linear_regression_1d(x, y)
        reg = LinearRegression().fit(x[0, :, None], y[0, :, None])
        assert_almost_equal(reg.coef_, coef[0], decimal=6)

        # test median
        x = torch.tensor([[1, 2, np.nan, 3, 4, 5, 6], [3, 4, 5, 1.01, np.nan, 1.02, 1.03]])
        u = torch.tensor([[True, False, True, True, True, True, True],
                          [True, True, True, True, True, True, True]])
        np_x = torch.masked_fill(x, ~u, np.nan)
        median = np.nanmedian(np_x.cpu(), axis=1)
        median = np.expand_dims(median, axis=1)
        expected, _ = spectre.parallel.masked_kth_value_1d(x, u, [0.5], dim=1)
        assert_almost_equal(median, expected[0], decimal=6)

    def test_pad2d(self):
        x = torch.tensor([[np.nan, 1, 1, np.nan, 1, np.nan, np.nan, 0, np.nan, 0, np.nan, np.nan, 0,
                           np.nan, -1, np.nan, -1, np.nan, np.nan, np.nan, 1],
                          [np.nan, 1, 0, np.nan, 1, np.nan, np.nan, 1, np.nan, -1, np.nan, -1, 0,
                           np.nan, -1, np.nan, -1, np.nan, np.nan, np.nan, 1]])
        result = spectre.parallel.pad_2d(x)

        expected = [pd.Series(x[0].numpy()).ffill(), pd.Series(x[1].numpy()).ffill()]
        assert_almost_equal(expected, result)
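
As the final assertion shows, pad_2d is a row-wise forward fill. A minimal sketch of the same operation in plain PyTorch (illustrative only, not spectre's implementation): carry the index of the last valid observation forward with a cumulative max, then gather.

import numpy as np
import torch


def pad_2d(x: torch.Tensor) -> torch.Tensor:
    # Replace each NaN position's index with 0, take the running maximum of
    # indices along dim=1, then gather: every NaN picks up the most recent
    # valid value; leading NaNs stay NaN.
    n_rows, n_cols = x.shape
    idx = torch.arange(n_cols).expand(n_rows, n_cols).clone()
    idx[torch.isnan(x)] = 0
    idx = idx.cummax(dim=1).values
    return x.gather(1, idx)


x = torch.tensor([[np.nan, 1., np.nan, np.nan, 0., np.nan]])
print(pad_2d(x))  # tensor([[nan, 1., 1., 1., 0., 0.]])
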
--------------------------------------------------------------------------------
/tests/test_trading_algorithm.py:
--------------------------------------------------------------------------------
import unittest
import spectre
import pandas as pd
from os.path import dirname
from numpy import nan
from numpy.testing import assert_array_equal, assert_almost_equal

data_dir = dirname(__file__) + '/data/'


class TestTradingAlgorithm(unittest.TestCase):

    def test_simulation_event_manager(self):
        loader = spectre.data.CsvDirLoader(
            data_dir + '/daily/', ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'),
            prices_index='date', parse_dates=True,
        )
        parent = self

        class MockTestAlg(spectre.trading.CustomAlgorithm):
            _data = None
            # records bar dates so the test can check the elapsed time
            _bar_dates = []
            # checks that the per-day sequence of events is correct
            _seq = 0

            def __init__(self):
                self.blotter = spectre.trading.SimulationBlotter(loader)

            def clear(self):
                pass

            def run_engine(self, start, end, _=False):
                engine = spectre.factors.FactorEngine(loader)
                f = spectre.factors.MA(5)
                engine.add(f, 'f')
                df = engine.run(start, end)
                return df, df.loc[df.index.get_level_values(0)[-1]]

            def initialize(self):
                self.schedule(spectre.trading.event.EveryBarData(
                    lambda x: self.test_every_bar(self._data)
                ))
                self.schedule(spectre.trading.event.MarketOpen(self.test_before_open, -1000))
                self.schedule(spectre.trading.event.MarketOpen(self.test_open, 0))
                self.schedule(spectre.trading.event.MarketClose(self.test_before_close, -1000))
                self.schedule(spectre.trading.event.MarketClose(self.test_close, 1000))

            def _run_engine(self, source):
                self._data, _ = self.run_engine(None, None)

            def on_run(self):
                self.schedule(spectre.trading.event.EveryBarData(
                    self._run_engine
                ))
                self.initialize()

            def on_end_of_run(self):
                pass

            def test_every_bar(self, data):
                self._seq += 1
                parent.assertEqual(1, self._seq)

                today = data.index.get_level_values(0)[-1]
                self._bar_dates.append(today)
                if today > pd.Timestamp("2019-01-10", tz='UTC'):
                    self.stop_event_manager()

            def test_before_open(self, source):
                self._seq += 1
                parent.assertEqual(2, self._seq)

            def test_open(self, source):
                self._seq += 1
                parent.assertEqual(3, self._seq)

            def test_before_close(self, source):
                self._seq += 1
                parent.assertEqual(4, self._seq)

            def test_close(self, source):
                self._seq += 1
                parent.assertEqual(5, self._seq)
                self._seq = 0

        rcv = MockTestAlg()

        evt_mgr = spectre.trading.SimulationEventManager()
        evt_mgr.subscribe(rcv)
        evt_mgr.run("2019-01-01", "2019-01-15")

        self.assertEqual(rcv._bar_dates[0], pd.Timestamp("2019-01-03", tz='UTC'))
        self.assertEqual(rcv._bar_dates[1], pd.Timestamp("2019-01-04", tz='UTC'))
        # the stop event fired on the expected bar
        self.assertEqual(rcv._bar_dates[-1], pd.Timestamp("2019-01-11", tz='UTC'))
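
    # The run_engine helper above is the minimal FactorEngine round trip:
    # construct from a loader, add a factor under a name, run over a date
    # range, and get a (date, asset) MultiIndex frame back. An illustrative
    # sketch, not part of the original suite:
    def test_factor_engine_round_trip_sketch(self):
        loader = spectre.data.CsvDirLoader(
            data_dir + '/daily/', ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'),
            prices_index='date', parse_dates=True,
        )
        engine = spectre.factors.FactorEngine(loader)
        engine.add(spectre.factors.MA(5), 'ma5')
        df = engine.run("2019-01-03", "2019-01-11")
        self.assertEqual(2, df.index.nlevels)  # (date, asset)
        self.assertIn('ma5', df.columns)
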
pd.Timestamp("2019-01-15", tz='UTC')]) 167 | expected.index.name = 'index' 168 | pd.testing.assert_frame_equal(expected, results.positions) 169 | 170 | def test_two_engine_algorithm(self): 171 | class TwoEngineAlg(spectre.trading.CustomAlgorithm): 172 | def initialize(self): 173 | engine_main = self.get_factor_engine('main') 174 | engine_test = self.get_factor_engine('test') 175 | 176 | ma5 = spectre.factors.MA(5) 177 | ma4 = spectre.factors.MA(4) 178 | engine_main.add(ma5, 'ma5') 179 | engine_test.add(ma4, 'ma4') 180 | 181 | self.schedule_rebalance(spectre.trading.event.MarketClose( 182 | self.rebalance, offset_ns=-10000)) 183 | 184 | self.blotter.set_commission(0, 0.005, 1) 185 | self.blotter.set_slippage(0, 0.4) 186 | 187 | def rebalance(self, data, history): 188 | mask = data['test'].ma4 > data['main'].ma5 189 | masked_test = data['test'][mask] 190 | assets = masked_test.index 191 | weights = masked_test.ma4 / masked_test.ma4.sum() 192 | for asset, weight in zip(assets, weights): 193 | self.blotter.order_target_percent(asset, weight) 194 | 195 | def terminate(self, _): 196 | pass 197 | 198 | loader = spectre.data.CsvDirLoader( 199 | data_dir + '/daily/', calender_asset='AAPL', 200 | ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'), 201 | dividends_path=data_dir + '/dividends/', splits_path=data_dir + '/splits/', 202 | adjustments=('amount', 'ratio'), 203 | prices_index='date', dividends_index='exDate', splits_index='exDate', parse_dates=True, 204 | ) 205 | blotter = spectre.trading.SimulationBlotter(loader) 206 | evt_mgr = spectre.trading.SimulationEventManager() 207 | alg = TwoEngineAlg(blotter, main=loader, test=loader) 208 | evt_mgr.subscribe(alg) 209 | evt_mgr.subscribe(blotter) 210 | 211 | evt_mgr.run("2019-01-10", "2019-01-15") 212 | first_run = str(blotter) 213 | evt_mgr.run("2019-01-10", "2019-01-15") 214 | 215 | # test two result should be the same. 
    def test_two_engine_algorithm(self):
        class TwoEngineAlg(spectre.trading.CustomAlgorithm):
            def initialize(self):
                engine_main = self.get_factor_engine('main')
                engine_test = self.get_factor_engine('test')

                ma5 = spectre.factors.MA(5)
                ma4 = spectre.factors.MA(4)
                engine_main.add(ma5, 'ma5')
                engine_test.add(ma4, 'ma4')

                self.schedule_rebalance(spectre.trading.event.MarketClose(
                    self.rebalance, offset_ns=-10000))

                self.blotter.set_commission(0, 0.005, 1)
                self.blotter.set_slippage(0, 0.4)

            def rebalance(self, data, history):
                mask = data['test'].ma4 > data['main'].ma5
                masked_test = data['test'][mask]
                assets = masked_test.index
                weights = masked_test.ma4 / masked_test.ma4.sum()
                for asset, weight in zip(assets, weights):
                    self.blotter.order_target_percent(asset, weight)

            def terminate(self, _):
                pass

        loader = spectre.data.CsvDirLoader(
            data_dir + '/daily/', calender_asset='AAPL',
            ohlcv=('uOpen', 'uHigh', 'uLow', 'uClose', 'uVolume'),
            dividends_path=data_dir + '/dividends/', splits_path=data_dir + '/splits/',
            adjustments=('amount', 'ratio'),
            prices_index='date', dividends_index='exDate', splits_index='exDate', parse_dates=True,
        )
        blotter = spectre.trading.SimulationBlotter(loader)
        evt_mgr = spectre.trading.SimulationEventManager()
        alg = TwoEngineAlg(blotter, main=loader, test=loader)
        evt_mgr.subscribe(alg)
        evt_mgr.subscribe(blotter)

        evt_mgr.run("2019-01-10", "2019-01-15")
        first_run = str(blotter)
        evt_mgr.run("2019-01-10", "2019-01-15")

        # two identical runs should produce identical results
        self.assertEqual(first_run, str(blotter))
        assert_array_equal(['AAPL', 'MSFT', 'AAPL', 'MSFT', 'AAPL'],
                           blotter.get_transactions().symbol.values)

    def test_record(self):
        recorder = spectre.trading.CustomAlgorithm(None, main=None)._recorder
        recorder.record("2019-01-10", dict(a=1, b=2))
        recorder.record("2019-01-11", dict(a=2, b=3, c=4))
        df = recorder.to_df()
        expected = pd.DataFrame([[1, 2, nan],
                                 [2, 3, 4]],
                                columns=['a', 'b', 'c'],
                                index=["2019-01-10",
                                       "2019-01-11"])
        expected.index.name = 'date'
        pd.testing.assert_frame_equal(expected, df)

    def test_intraday_algorithm(self):
        class IntradayAlg(spectre.trading.CustomAlgorithm):
            order_shares = 0.3  # target weight; the sign flips every rebalance

            def initialize(self):
                engine_main = self.get_factor_engine()
                ma5 = spectre.factors.MA(5)
                engine_main.add(ma5, 'ma5')

                self.schedule_rebalance(spectre.trading.event.MarketClose(
                    self.rebalance, offset_ns=-10000))

                self.blotter.set_commission(0, 0.005, 1)
                self.blotter.set_slippage(0, 0.4)

            def rebalance(self, data, history):
                self.blotter.order_target_percent('AAPL', self.order_shares)
                self.order_shares = -self.order_shares

            def terminate(self, _):
                pass

        loader = spectre.data.CsvDirLoader(
            data_dir + '/5mins/', prices_by_year=True, prices_index='Date',
            ohlcv=('Open', 'High', 'Low', 'Close', 'Volume'), parse_dates=True)
        results = spectre.trading.run_backtest(loader, IntradayAlg, "2019-01-01", "2019-01-05")

        self.assertAlmostEqual(157.92, results.transactions.loc['2019-01-02 20:55:00+00:00'].price)
        self.assertAlmostEqual(142.09, results.transactions.loc['2019-01-03 20:55:00+00:00'].price)
        self.assertAlmostEqual(148.26, results.transactions.loc['2019-01-04 20:55:00+00:00'].price)

        class IntradayAlgOpen(spectre.trading.CustomAlgorithm):
            order_shares = 0.3  # target weight; the sign flips every rebalance

            def initialize(self):
                engine_main = self.get_factor_engine()
                ma5 = spectre.factors.MA(5)
                engine_main.add(ma5, 'ma5')

                self.schedule_rebalance(spectre.trading.event.MarketOpen(
                    self.rebalance))

                self.blotter.set_commission(0, 0.005, 1)
                self.blotter.set_slippage(0, 0.4)

            def rebalance(self, data, history):
                self.blotter.order_target_percent('AAPL', self.order_shares)
                self.order_shares = -self.order_shares

            def terminate(self, _):
                pass

        loader = spectre.data.CsvDirLoader(
            data_dir + '/5mins/', prices_by_year=True, prices_index='Date',
            ohlcv=('Open', 'High', 'Low', 'Close', 'Volume'), parse_dates=True)
        results = spectre.trading.run_backtest(loader, IntradayAlgOpen, "2019-01-01", "2019-01-05")
        assert_almost_equal([143.95, 144.58], results.transactions.price)
--------------------------------------------------------------------------------
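
Taken together with tests/test_metric.py above, the metric helpers compose directly with backtest output. A hedged sketch, assuming the positions/transactions layout exercised in these tests; deriving returns from total position value is illustrative, not spectre's own reporting path:

import pandas as pd
import spectre


def summarize(positions: pd.DataFrame, transactions: pd.DataFrame) -> dict:
    """Headline stats from run_backtest output, e.g.
    summarize(results.positions, results.transactions)."""
    value = positions['value'].sum(axis=1)  # total equity including cash
    returns = value.pct_change().dropna()
    dd, dd_bars = spectre.trading.drawdown((returns + 1).cumprod())
    return {
        'sharpe': spectre.trading.sharpe_ratio(returns, 0.00),
        'annual_vol': spectre.trading.annual_volatility(returns),
        'max_drawdown': dd.abs().max(),
        'longest_drawdown_bars': dd_bars.max(),
        'turnover': spectre.trading.turnover(positions, transactions),
    }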